FastBQSR/src/bqsr/qual_quantizer.h

315 lines
12 KiB
C
Raw Normal View History

/*
Description:
Copyright : All right reserved by ICT
Author : Zhang Zhonghai
Date : 2025/12/26
*/
#pragma once
#include <algorithm>
#include <cmath>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "qual_utils.h"

using std::set;
using std::vector;
/**
* A general algorithm for quantizing quality score distributions to use a specific number of levels
*
* Takes a histogram of quality scores and a desired number of levels and produces a
* map from original quality scores -> quantized quality scores.
*
* Note that this data structure is fairly heavy-weight, holding lots of debugging and
* calculation information. If you want to use it efficiently at scale with lots of
* read groups, the right way to do this is:
*
*   map<ReadGroup, vector<uint8_t>> quantMaps
*   for each read group rg:
*       hist = getQualHist(rg)
*       QualQuantizer qq(hist, nLevels, minInterestingQual)
*       qq.getOriginalToQuantizedMap(quantMaps[rg])
*
* This map would then be used to look up the appropriate original -> quantized
* quals for each read as it comes in.
*/
struct QualQuantizer {
/**
* Represents a contiguous interval of quality scores.
*
* qStart and qEnd are inclusive, so qStart = qEnd = 2 is the quality score bin of 2
*/
struct QualInterval {
int qStart, qEnd, fixedQual, level;
int64_t nObservations, nErrors;
set<QualInterval> subIntervals;
/** for debugging / visualization. When was this interval created? */
int mergeOrder;
void init(const int _qStart, const int _qEnd, const int64_t _nObservations, const int64_t _nErrors, const int _level, const int _fixedQual) {
qStart = _qStart;
qEnd = _qEnd;
nObservations = _nObservations;
nErrors = _nErrors;
level = _level;
fixedQual = _fixedQual;
}
QualInterval() {
qStart = -1;
qEnd = -1;
nObservations = -1;
nErrors = -1;
fixedQual = -1;
level = -1;
mergeOrder = 0;
}
QualInterval(const int qStart, const int qEnd, const int64_t nObservations, const int64_t nErrors, const int level) {
init(qStart, qEnd, nObservations, nErrors, level, -1);
}
QualInterval(const int qStart, const int qEnd, const int64_t nObservations, const int64_t nErrors, const int level,
const set<QualInterval>& _subIntervals) {
init(qStart, qEnd, nObservations, nErrors, level, -1);
subIntervals = _subIntervals;
}
QualInterval(const int qStart, const int qEnd, const int64_t nObservations, const int64_t nErrors, const int level, const int fixedQual) {
init(qStart, qEnd, nObservations, nErrors, level, fixedQual);
}
QualInterval(const int qStart, const int qEnd, const int64_t nObservations, const int64_t nErrors, const int level, const int fixedQual,
const set<QualInterval>& _subIntervals) {
init(qStart, qEnd, nObservations, nErrors, level, fixedQual);
subIntervals = _subIntervals;
}
/**
* @return Human readable name of this interval: e.g., 10-12
*/
string getName() const { return std::to_string(qStart) + "-" + std::to_string(qEnd); }
string toString() const {
return "QQ:" + getName();
}
/**
* @return true if this bin is using a fixed qual
*/
bool hasFixedQual() const { return fixedQual != -1; }
/**
* @return the error rate (in real space) of this interval, or 0 if there are no observations
*/
double getErrorRate() const {
if (hasFixedQual())
return QualityUtils::qualToErrorProb((uint8_t)fixedQual);
else if (nObservations == 0)
return 0.0;
else
return (nErrors + 1) / (1.0 * (nObservations + 1));
}
/**
* @return the QUAL of the error rate of this interval, or the fixed qual if this interval was created with a fixed qual.
*/
uint8_t getQual() const {
if (!hasFixedQual())
return QualityUtils::errorProbToQual(getErrorRate());
else
return (uint8_t)fixedQual;
}
int compareTo(const QualInterval& qi) const { return qStart < qi.qStart ? -1 : (qStart == qi.qStart ? 0 : 1); }
/**
* Create a interval representing the merge of this interval and toMerge
*
* Errors and observations are combined
* Subintervals updated in order of left to right (determined by qStart)
* Level is 1 + highest level of this and toMerge
* Order must be updated elsewhere
*
* @param toMerge
* @return newly created merged QualInterval
*/
QualInterval merge(const QualInterval& toMerge) const {
const QualInterval &left = this->compareTo(toMerge) < 0 ? *this : toMerge;
const QualInterval &right = this->compareTo(toMerge) < 0 ? toMerge : *this;
if (left.qEnd + 1 != right.qStart) {
// throw new GATKException("Attempting to merge non-contiguous intervals: left = " + left + " right = " + right);
std::cerr << "Attempting to merge non-contiguous intervals: left = " + left.toString() + " right = " + right.toString() << std::endl;
exit(1);
}
const int64_t nCombinedObs = left.nObservations + right.nObservations;
const int64_t nCombinedErr = left.nErrors + right.nErrors;
const int level = std::max(left.level, right.level) + 1;
set<QualInterval> subIntervals;
subIntervals.insert(left);
subIntervals.insert(right);
QualInterval merged(left.qStart, right.qEnd, nCombinedObs, nCombinedErr, level, subIntervals);
return merged;
}
double getPenalty(const int minInterestingQual) const { return calcPenalty(getErrorRate(), minInterestingQual); }
/**
* Calculate the penalty of this interval, given the overall error rate for the interval
*
* If the globalErrorRate is e, this value is:
*
* sum_i |log10(e_i) - log10(e)| * nObservations_i
*
* each the index i applies to all leaves of the tree accessible from this interval
* (found recursively from subIntervals as necessary)
*
* @param globalErrorRate overall error rate in real space against which we calculate the penalty
* @return the cost of approximating the bins in this interval with the globalErrorRate
*/
double calcPenalty(const double globalErrorRate, const int minInterestingQual) const {
if (globalErrorRate == 0.0) // there were no observations, so there's no penalty
return 0.0;
if (subIntervals.empty()) {
// this is leave node
if (this->qEnd <= minInterestingQual)
// It's free to merge up quality scores below the smallest interesting one
return 0;
else {
return (std::abs(std::log10(getErrorRate()) - std::log10(globalErrorRate))) * nObservations;
}
} else {
double sum = 0;
for (const QualInterval interval : subIntervals) sum += interval.calcPenalty(globalErrorRate, minInterestingQual);
return sum;
}
}
bool operator<(const QualInterval& o) const {
return qStart < o.qStart;
}
QualInterval& operator=(const QualInterval& o) {
if (this == &o) return *this;
init(o.qStart, o.qEnd, o.nObservations, o.nErrors, o.level, o.fixedQual);
mergeOrder = o.mergeOrder;
subIntervals.clear();
for (auto& val : o.subIntervals) subIntervals.insert(val);
return *this;
}
};
/**
* Inputs to the QualQuantizer
*/
const int nLevels, minInterestingQual;
vector<int64_t>& nObservationsPerQual;
QualQuantizer(vector<int64_t>& _nObservationsPerQual, const int _nLevels, const int _minInterestingQual)
: nObservationsPerQual(_nObservationsPerQual), nLevels(_nLevels), minInterestingQual(_minInterestingQual) {
quantize();
}
/** Sorted set of qual intervals.
*
* After quantize() this data structure contains only the top-level qual intervals
*/
set<QualInterval> quantizedIntervals;
/**
* Represents a contiguous interval of quality scores.
*
* qStart and qEnd are inclusive, so qStart = qEnd = 2 is the quality score bin of 2
*/
void getOriginalToQuantizedMap(vector<uint8_t>& quantMap) {
quantMap.resize(getNQualsInHistogram(), UINT8_MAX);
for (auto& interval : quantizedIntervals) {
for (int q = interval.qStart; q <= interval.qEnd; q++) {
quantMap[q] = interval.getQual();
}
}
// if (Collections.min(map) == Byte.MIN_VALUE) throw new GATKException("quantized quality score map contains an un-initialized value");
}
int getNQualsInHistogram() { return nObservationsPerQual.size(); }
/**
* Main method for computing the quantization intervals.
*
* Invoked in the constructor after all input variables are initialized. Walks
* over the inputs and builds the min. penalty forest of intervals with exactly nLevel
* root nodes. Finds this min. penalty forest via greedy search, so is not guarenteed
* to find the optimal combination.
*
* TODO: develop a smarter algorithm
*
* @return the forest of intervals with size == nLevels
*/
void quantize() {
// create intervals for each qual individually
auto& intervals = quantizedIntervals;
for (int qStart = 0; qStart < getNQualsInHistogram(); qStart++) {
const int64_t nObs = nObservationsPerQual.at(qStart);
const double errorRate = QualityUtils::qualToErrorProb((uint8_t)qStart);
const double nErrors = nObs * errorRate;
const QualInterval qi(qStart, qStart, nObs, (int)std::floor(nErrors), 0, (uint8_t)qStart);
intervals.insert(qi);
}
// greedy algorithm:
// while ( n intervals >= nLevels ):
// find intervals to merge with least penalty
// merge it
while (intervals.size() > nLevels) {
mergeLowestPenaltyIntervals(intervals);
}
}
/**
* Helper function that finds and merges together the lowest penalty pair of intervals
* @param intervals
*/
void mergeLowestPenaltyIntervals(set<QualInterval>& intervals) {
// setup the iterators
auto it1 = intervals.begin();
auto it1p = intervals.begin();
++it1p; // skip one
// walk over the pairs of left and right, keeping track of the pair with the lowest merge penalty
QualInterval minMerge;
// if (logger.isDebugEnabled()) logger.debug("mergeLowestPenaltyIntervals: " + intervals.size());
int lastMergeOrder = 0;
while (it1p != intervals.end()) {
const QualInterval& left = *it1;
const QualInterval& right = *it1p;
const QualInterval merged = left.merge(right);
lastMergeOrder = std::max(std::max(lastMergeOrder, left.mergeOrder), right.mergeOrder);
if (minMerge.qStart == -1 || (merged.getPenalty(minInterestingQual) < minMerge.getPenalty(minInterestingQual))) {
// if (logger.isDebugEnabled()) logger.debug(" Updating merge " + minMerge);
minMerge = merged; // merge two bins that when merged incur the lowest "penalty"
}
++it1;
++it1p;
}
// now actually go ahead and merge the minMerge pair
// if (logger.isDebugEnabled()) logger.debug(" => const min merge " + minMerge);
minMerge.mergeOrder = lastMergeOrder + 1;
// intervals.removeAll(minMerge.subIntervals);
for (auto &itr : minMerge.subIntervals) {
intervals.erase(itr);
}
intervals.insert(minMerge);
// if (logger.isDebugEnabled()) logger.debug("updated intervals: " + intervals);
}
};