Refactoring the Bayesian empirical quality estimates to be in a single unit-testable function.
This commit is contained in:
parent
1d5b29e764
commit
cba89e98ad
|
|
@ -57,6 +57,8 @@ import org.broadinstitute.sting.utils.recalibration.covariates.Covariate;
|
|||
import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
|
||||
|
||||
import java.io.File;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
/**
|
||||
* Utility methods to facilitate on-the-fly base quality score recalibration.
|
||||
|
|
@ -78,10 +80,6 @@ public class BaseRecalibration {
|
|||
private final double globalQScorePrior;
|
||||
private final boolean emitOriginalQuals;
|
||||
|
||||
private final NestedIntegerArray<Double> globalDeltaQs;
|
||||
private final NestedIntegerArray<Double> deltaQReporteds;
|
||||
|
||||
|
||||
/**
|
||||
* Constructor using a GATK Report file
|
||||
*
|
||||
|
|
@ -105,44 +103,6 @@ public class BaseRecalibration {
|
|||
this.preserveQLessThan = preserveQLessThan;
|
||||
this.globalQScorePrior = globalQScorePrior;
|
||||
this.emitOriginalQuals = emitOriginalQuals;
|
||||
|
||||
logger.info("Calculating cached tables...");
|
||||
|
||||
//
|
||||
// Create a NestedIntegerArray<Double> that maps from rgKey x errorModel -> double,
|
||||
// where the double is the result of this calculation. The entire calculation can
|
||||
// be done upfront, on initialization of this BaseRecalibration structure
|
||||
//
|
||||
final NestedIntegerArray<RecalDatum> byReadGroupTable = recalibrationTables.getReadGroupTable();
|
||||
globalDeltaQs = new NestedIntegerArray<Double>( byReadGroupTable.getDimensions() );
|
||||
logger.info("Calculating global delta Q table...");
|
||||
for ( NestedIntegerArray.Leaf<RecalDatum> leaf : byReadGroupTable.getAllLeaves() ) {
|
||||
final int rgKey = leaf.keys[0];
|
||||
final int eventIndex = leaf.keys[1];
|
||||
final double globalDeltaQ = calculateGlobalDeltaQ(rgKey, EventType.eventFrom(eventIndex));
|
||||
globalDeltaQs.put(globalDeltaQ, rgKey, eventIndex);
|
||||
}
|
||||
|
||||
|
||||
// The calculation of the deltaQ report is constant. key[0] and key[1] are the read group and qual, respectively
|
||||
// and globalDeltaQ is a constant for the read group. So technically the delta Q reported is simply a lookup
|
||||
// into a matrix indexed by rgGroup, qual, and event type.
|
||||
// the code below actually creates this cache with a NestedIntegerArray calling into the actual
|
||||
// calculateDeltaQReported code.
|
||||
final NestedIntegerArray<RecalDatum> byQualTable = recalibrationTables.getQualityScoreTable();
|
||||
deltaQReporteds = new NestedIntegerArray<Double>( byQualTable.getDimensions() );
|
||||
logger.info("Calculating delta Q reported table...");
|
||||
for ( NestedIntegerArray.Leaf<RecalDatum> leaf : byQualTable.getAllLeaves() ) {
|
||||
final int rgKey = leaf.keys[0];
|
||||
final int qual = leaf.keys[1];
|
||||
final int eventIndex = leaf.keys[2];
|
||||
final EventType event = EventType.eventFrom(eventIndex);
|
||||
final double globalDeltaQ = getGlobalDeltaQ(rgKey, event);
|
||||
final double deltaQReported = calculateDeltaQReported(rgKey, qual, event, globalDeltaQ, (byte)qual);
|
||||
deltaQReporteds.put(deltaQReported, rgKey, qual, eventIndex);
|
||||
}
|
||||
|
||||
logger.info("done calculating cache");
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -189,8 +149,8 @@ public class BaseRecalibration {
|
|||
|
||||
// the rg key is constant over the whole read, the global deltaQ is too
|
||||
final int rgKey = fullReadKeySet[0][0];
|
||||
|
||||
final double globalDeltaQ = getGlobalDeltaQ(rgKey, errorModel);
|
||||
final RecalDatum empiricalQualRG = recalibrationTables.getReadGroupTable().get(rgKey, errorModel.ordinal());
|
||||
final double epsilon = ( globalQScorePrior > 0.0 && errorModel.equals(EventType.BASE_SUBSTITUTION) ? globalQScorePrior : empiricalQualRG.getEstimatedQReported() );
|
||||
|
||||
for (int offset = 0; offset < readLength; offset++) { // recalibrate all bases in the read
|
||||
final byte origQual = quals[offset];
|
||||
|
|
@ -199,11 +159,16 @@ public class BaseRecalibration {
|
|||
if ( origQual >= preserveQLessThan ) {
|
||||
// get the keyset for this base using the error model
|
||||
final int[] keySet = fullReadKeySet[offset];
|
||||
final double deltaQReported = getDeltaQReported(keySet[0], keySet[1], errorModel, globalDeltaQ);
|
||||
final double deltaQCovariates = calculateDeltaQCovariates(recalibrationTables, keySet, errorModel, globalDeltaQ, deltaQReported, origQual);
|
||||
final RecalDatum empiricalQualQS = recalibrationTables.getQualityScoreTable().get(keySet[0], keySet[1], errorModel.ordinal());
|
||||
final List<RecalDatum> empiricalQualCovs = new ArrayList<RecalDatum>();
|
||||
for (int i = 2; i < requestedCovariates.length; i++) {
|
||||
if (keySet[i] < 0) {
|
||||
continue;
|
||||
}
|
||||
empiricalQualCovs.add(recalibrationTables.getTable(i).get(keySet[0], keySet[1], keySet[i], errorModel.ordinal()));
|
||||
}
|
||||
|
||||
// calculate the recalibrated qual using the BQSR formula
|
||||
double recalibratedQualDouble = origQual + globalDeltaQ + deltaQReported + deltaQCovariates;
|
||||
double recalibratedQualDouble = hierarchicalBayesianQualityEstimate( epsilon, empiricalQualRG, empiricalQualQS, empiricalQualCovs );
|
||||
|
||||
// recalibrated quality is bound between 1 and MAX_QUAL
|
||||
final byte recalibratedQual = QualityUtils.boundQual(MathUtils.fastRound(recalibratedQualDouble), QualityUtils.MAX_RECALIBRATED_Q_SCORE);
|
||||
|
|
@ -220,102 +185,14 @@ public class BaseRecalibration {
|
|||
}
|
||||
}
|
||||
|
||||
private double getGlobalDeltaQ(final int rgKey, final EventType errorModel) {
|
||||
final Double cached = globalDeltaQs.get(rgKey, errorModel.ordinal());
|
||||
|
||||
if ( TEST_CACHING ) {
|
||||
final double calcd = calculateGlobalDeltaQ(rgKey, errorModel);
|
||||
if ( calcd != cached )
|
||||
throw new IllegalStateException("calculated " + calcd + " and cached " + cached + " global delta q not equal at " + rgKey + " / " + errorModel);
|
||||
protected double hierarchicalBayesianQualityEstimate( final double epsilon, final RecalDatum empiricalQualRG, final RecalDatum empiricalQualQS, final List<RecalDatum> empiricalQualCovs ) {
|
||||
final double globalDeltaQ = ( empiricalQualRG == null ? 0.0 : empiricalQualRG.getEmpiricalQuality(epsilon) - epsilon );
|
||||
final double deltaQReported = ( empiricalQualQS == null ? 0.0 : empiricalQualQS.getEmpiricalQuality(globalDeltaQ + epsilon) - (globalDeltaQ + epsilon) );
|
||||
double deltaQCovariates = 0.0;
|
||||
for( final RecalDatum empiricalQualCov : empiricalQualCovs ) {
|
||||
deltaQCovariates += ( empiricalQualCov == null ? 0.0 : empiricalQualCov.getEmpiricalQuality(deltaQReported + globalDeltaQ + epsilon) - (deltaQReported + globalDeltaQ + epsilon) );
|
||||
}
|
||||
|
||||
return cachedWithDefault(cached);
|
||||
}
|
||||
|
||||
private double getDeltaQReported(final int rgKey, final int qualKey, final EventType errorModel, final double globalDeltaQ) {
|
||||
final Double cached = deltaQReporteds.get(rgKey, qualKey, errorModel.ordinal());
|
||||
|
||||
if ( TEST_CACHING ) {
|
||||
final double calcd = calculateDeltaQReported(rgKey, qualKey, errorModel, globalDeltaQ, (byte)qualKey);
|
||||
if ( calcd != cached )
|
||||
throw new IllegalStateException("calculated " + calcd + " and cached " + cached + " global delta q not equal at " + rgKey + " / " + qualKey + " / " + errorModel);
|
||||
}
|
||||
|
||||
return cachedWithDefault(cached);
|
||||
}
|
||||
|
||||
/**
|
||||
* @param d a Double (that may be null) that is the result of a delta Q calculation
|
||||
* @return a double == d if d != null, or 0.0 if it is
|
||||
*/
|
||||
private double cachedWithDefault(final Double d) {
|
||||
return d == null ? 0.0 : d;
|
||||
}
|
||||
|
||||
/**
|
||||
* Note that this calculation is a constant for each rgKey and errorModel. We need only
|
||||
* compute this value once for all data.
|
||||
*
|
||||
* @param rgKey read group key
|
||||
* @param errorModel the event type
|
||||
* @return global delta Q
|
||||
*/
|
||||
private double calculateGlobalDeltaQ(final int rgKey, final EventType errorModel) {
|
||||
double result = 0.0;
|
||||
|
||||
final RecalDatum empiricalQualRG = recalibrationTables.getReadGroupTable().get(rgKey, errorModel.ordinal());
|
||||
|
||||
if (empiricalQualRG != null) {
|
||||
final double aggregrateQReported = ( globalQScorePrior > 0.0 && errorModel.equals(EventType.BASE_SUBSTITUTION) ? globalQScorePrior : empiricalQualRG.getEstimatedQReported() );
|
||||
final double globalDeltaQEmpirical = empiricalQualRG.getEmpiricalQuality(aggregrateQReported);
|
||||
result = globalDeltaQEmpirical - aggregrateQReported;
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private double calculateDeltaQReported(final int rgKey, final int qualKey, final EventType errorModel, final double globalDeltaQ, final byte qualFromRead) {
|
||||
double result = 0.0;
|
||||
|
||||
final RecalDatum empiricalQualQS = recalibrationTables.getQualityScoreTable().get(rgKey, qualKey, errorModel.ordinal());
|
||||
if (empiricalQualQS != null) {
|
||||
final double deltaQReportedEmpirical = empiricalQualQS.getEmpiricalQuality(globalDeltaQ + qualFromRead);
|
||||
result = deltaQReportedEmpirical - (globalDeltaQ + qualFromRead);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private double calculateDeltaQCovariates(final RecalibrationTables recalibrationTables, final int[] key, final EventType errorModel, final double globalDeltaQ, final double deltaQReported, final byte qualFromRead) {
|
||||
double result = 0.0;
|
||||
|
||||
// for all optional covariates
|
||||
for (int i = 2; i < requestedCovariates.length; i++) {
|
||||
if (key[i] < 0)
|
||||
continue;
|
||||
|
||||
result += calculateDeltaQCovariate(recalibrationTables.getTable(i),
|
||||
key[0], key[1], key[i], errorModel,
|
||||
globalDeltaQ, deltaQReported, qualFromRead);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
private double calculateDeltaQCovariate(final NestedIntegerArray<RecalDatum> table,
|
||||
final int rgKey,
|
||||
final int qualKey,
|
||||
final int tableKey,
|
||||
final EventType errorModel,
|
||||
final double globalDeltaQ,
|
||||
final double deltaQReported,
|
||||
final byte qualFromRead) {
|
||||
final RecalDatum empiricalQualCO = table.get(rgKey, qualKey, tableKey, errorModel.ordinal());
|
||||
if (empiricalQualCO != null) {
|
||||
final double deltaQCovariateEmpirical = empiricalQualCO.getEmpiricalQuality(deltaQReported + globalDeltaQ + qualFromRead);
|
||||
return deltaQCovariateEmpirical - (deltaQReported + globalDeltaQ + qualFromRead);
|
||||
} else {
|
||||
return 0.0;
|
||||
}
|
||||
return epsilon + globalDeltaQ + deltaQReported + deltaQCovariates;
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue