Refactoring the Bayesian empirical quality estimates to be in a single unit-testable function.

This commit is contained in:
Ryan Poplin 2013-01-29 15:50:46 -05:00
parent 1d5b29e764
commit cba89e98ad
1 changed files with 20 additions and 143 deletions

View File

@ -57,6 +57,8 @@ import org.broadinstitute.sting.utils.recalibration.covariates.Covariate;
import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
import java.io.File;
import java.util.ArrayList;
import java.util.List;
/**
* Utility methods to facilitate on-the-fly base quality score recalibration.
@ -78,10 +80,6 @@ public class BaseRecalibration {
private final double globalQScorePrior;
private final boolean emitOriginalQuals;
private final NestedIntegerArray<Double> globalDeltaQs;
private final NestedIntegerArray<Double> deltaQReporteds;
/**
* Constructor using a GATK Report file
*
@ -105,44 +103,6 @@ public class BaseRecalibration {
this.preserveQLessThan = preserveQLessThan;
this.globalQScorePrior = globalQScorePrior;
this.emitOriginalQuals = emitOriginalQuals;
logger.info("Calculating cached tables...");
//
// Create a NestedIntegerArray<Double> that maps from rgKey x errorModel -> double,
// where the double is the result of this calculation. The entire calculation can
// be done upfront, on initialization of this BaseRecalibration structure
//
final NestedIntegerArray<RecalDatum> byReadGroupTable = recalibrationTables.getReadGroupTable();
globalDeltaQs = new NestedIntegerArray<Double>( byReadGroupTable.getDimensions() );
logger.info("Calculating global delta Q table...");
for ( NestedIntegerArray.Leaf<RecalDatum> leaf : byReadGroupTable.getAllLeaves() ) {
final int rgKey = leaf.keys[0];
final int eventIndex = leaf.keys[1];
final double globalDeltaQ = calculateGlobalDeltaQ(rgKey, EventType.eventFrom(eventIndex));
globalDeltaQs.put(globalDeltaQ, rgKey, eventIndex);
}
// The calculation of the deltaQ report is constant. key[0] and key[1] are the read group and qual, respectively
// and globalDeltaQ is a constant for the read group. So technically the delta Q reported is simply a lookup
// into a matrix indexed by rgGroup, qual, and event type.
// the code below actually creates this cache with a NestedIntegerArray calling into the actual
// calculateDeltaQReported code.
final NestedIntegerArray<RecalDatum> byQualTable = recalibrationTables.getQualityScoreTable();
deltaQReporteds = new NestedIntegerArray<Double>( byQualTable.getDimensions() );
logger.info("Calculating delta Q reported table...");
for ( NestedIntegerArray.Leaf<RecalDatum> leaf : byQualTable.getAllLeaves() ) {
final int rgKey = leaf.keys[0];
final int qual = leaf.keys[1];
final int eventIndex = leaf.keys[2];
final EventType event = EventType.eventFrom(eventIndex);
final double globalDeltaQ = getGlobalDeltaQ(rgKey, event);
final double deltaQReported = calculateDeltaQReported(rgKey, qual, event, globalDeltaQ, (byte)qual);
deltaQReporteds.put(deltaQReported, rgKey, qual, eventIndex);
}
logger.info("done calculating cache");
}
/**
@ -189,8 +149,8 @@ public class BaseRecalibration {
// the rg key is constant over the whole read, the global deltaQ is too
final int rgKey = fullReadKeySet[0][0];
final double globalDeltaQ = getGlobalDeltaQ(rgKey, errorModel);
final RecalDatum empiricalQualRG = recalibrationTables.getReadGroupTable().get(rgKey, errorModel.ordinal());
final double epsilon = ( globalQScorePrior > 0.0 && errorModel.equals(EventType.BASE_SUBSTITUTION) ? globalQScorePrior : empiricalQualRG.getEstimatedQReported() );
for (int offset = 0; offset < readLength; offset++) { // recalibrate all bases in the read
final byte origQual = quals[offset];
@ -199,11 +159,16 @@ public class BaseRecalibration {
if ( origQual >= preserveQLessThan ) {
// get the keyset for this base using the error model
final int[] keySet = fullReadKeySet[offset];
final double deltaQReported = getDeltaQReported(keySet[0], keySet[1], errorModel, globalDeltaQ);
final double deltaQCovariates = calculateDeltaQCovariates(recalibrationTables, keySet, errorModel, globalDeltaQ, deltaQReported, origQual);
final RecalDatum empiricalQualQS = recalibrationTables.getQualityScoreTable().get(keySet[0], keySet[1], errorModel.ordinal());
final List<RecalDatum> empiricalQualCovs = new ArrayList<RecalDatum>();
for (int i = 2; i < requestedCovariates.length; i++) {
if (keySet[i] < 0) {
continue;
}
empiricalQualCovs.add(recalibrationTables.getTable(i).get(keySet[0], keySet[1], keySet[i], errorModel.ordinal()));
}
// calculate the recalibrated qual using the BQSR formula
double recalibratedQualDouble = origQual + globalDeltaQ + deltaQReported + deltaQCovariates;
double recalibratedQualDouble = hierarchicalBayesianQualityEstimate( epsilon, empiricalQualRG, empiricalQualQS, empiricalQualCovs );
// recalibrated quality is bound between 1 and MAX_QUAL
final byte recalibratedQual = QualityUtils.boundQual(MathUtils.fastRound(recalibratedQualDouble), QualityUtils.MAX_RECALIBRATED_Q_SCORE);
@ -220,102 +185,14 @@ public class BaseRecalibration {
}
}
private double getGlobalDeltaQ(final int rgKey, final EventType errorModel) {
final Double cached = globalDeltaQs.get(rgKey, errorModel.ordinal());
if ( TEST_CACHING ) {
final double calcd = calculateGlobalDeltaQ(rgKey, errorModel);
if ( calcd != cached )
throw new IllegalStateException("calculated " + calcd + " and cached " + cached + " global delta q not equal at " + rgKey + " / " + errorModel);
protected double hierarchicalBayesianQualityEstimate( final double epsilon, final RecalDatum empiricalQualRG, final RecalDatum empiricalQualQS, final List<RecalDatum> empiricalQualCovs ) {
final double globalDeltaQ = ( empiricalQualRG == null ? 0.0 : empiricalQualRG.getEmpiricalQuality(epsilon) - epsilon );
final double deltaQReported = ( empiricalQualQS == null ? 0.0 : empiricalQualQS.getEmpiricalQuality(globalDeltaQ + epsilon) - (globalDeltaQ + epsilon) );
double deltaQCovariates = 0.0;
for( final RecalDatum empiricalQualCov : empiricalQualCovs ) {
deltaQCovariates += ( empiricalQualCov == null ? 0.0 : empiricalQualCov.getEmpiricalQuality(deltaQReported + globalDeltaQ + epsilon) - (deltaQReported + globalDeltaQ + epsilon) );
}
return cachedWithDefault(cached);
}
/**
 * Looks up the cached delta Q reported for a read group / quality score / event type.
 *
 * When TEST_CACHING is enabled, the value is additionally recomputed with
 * calculateDeltaQReported and compared against the cache as a consistency check.
 *
 * @param rgKey        read group key
 * @param qualKey      quality score key
 * @param errorModel   the event type
 * @param globalDeltaQ global delta Q for the read group; only used by the TEST_CACHING recomputation
 * @return the cached delta Q reported, or 0.0 when the cache holds no value for this key
 * @throws IllegalStateException if TEST_CACHING is enabled and the recomputed value differs from the cached one
 */
private double getDeltaQReported(final int rgKey, final int qualKey, final EventType errorModel, final double globalDeltaQ) {
    // A missing (null) cache entry is mapped to 0.0 up front, so every use below works on a primitive.
    final double cached = cachedWithDefault(deltaQReporteds.get(rgKey, qualKey, errorModel.ordinal()));
    if ( TEST_CACHING ) {
        final double calcd = calculateDeltaQReported(rgKey, qualKey, errorModel, globalDeltaQ, (byte)qualKey);
        // Comparing primitives avoids the unboxing NullPointerException a null cache entry would cause.
        if ( calcd != cached ) {
            throw new IllegalStateException("calculated " + calcd + " and cached " + cached + " delta q reported not equal at " + rgKey + " / " + qualKey + " / " + errorModel);
        }
    }
    return cached;
}
/**
 * Unboxes a possibly-null delta Q value, substituting zero when it is absent.
 *
 * @param d boxed result of a delta Q calculation; may be null
 * @return the value of d, or 0.0 when d is null
 */
private double cachedWithDefault(final Double d) {
    if ( d != null ) {
        return d;
    }
    return 0.0;
}
/**
 * Computes the global delta Q for one read group and event type: the empirical
 * quality of the read-group datum minus the reported quality it is measured against.
 *
 * The result depends only on rgKey and errorModel, so it needs to be computed
 * just once per pair.
 *
 * @param rgKey read group key
 * @param errorModel the event type
 * @return global delta Q, or 0.0 when the read group table has no datum for this key
 */
private double calculateGlobalDeltaQ(final int rgKey, final EventType errorModel) {
    final RecalDatum readGroupDatum = recalibrationTables.getReadGroupTable().get(rgKey, errorModel.ordinal());
    if ( readGroupDatum == null ) {
        return 0.0;
    }

    // A positive globalQScorePrior replaces the estimated reported quality, but only for base substitutions.
    final boolean usePrior = globalQScorePrior > 0.0 && errorModel.equals(EventType.BASE_SUBSTITUTION);
    final double reportedQ;
    if ( usePrior ) {
        reportedQ = globalQScorePrior;
    } else {
        reportedQ = readGroupDatum.getEstimatedQReported();
    }

    return readGroupDatum.getEmpiricalQuality(reportedQ) - reportedQ;
}
/**
 * Computes the delta Q reported for a read group / quality score / event type:
 * the empirical quality of the quality-score datum minus the baseline
 * (globalDeltaQ + qualFromRead) passed to getEmpiricalQuality.
 *
 * @param rgKey        read group key
 * @param qualKey      quality score key
 * @param errorModel   the event type
 * @param globalDeltaQ global delta Q of the read group
 * @param qualFromRead quality score taken from the read
 * @return delta Q reported, or 0.0 when the quality score table has no datum for this key
 */
private double calculateDeltaQReported(final int rgKey, final int qualKey, final EventType errorModel, final double globalDeltaQ, final byte qualFromRead) {
    final RecalDatum qualityDatum = recalibrationTables.getQualityScoreTable().get(rgKey, qualKey, errorModel.ordinal());
    if ( qualityDatum == null ) {
        return 0.0;
    }

    final double baseline = globalDeltaQ + qualFromRead;
    return qualityDatum.getEmpiricalQuality(baseline) - baseline;
}
/**
 * Sums the per-covariate delta Q contributions over the optional covariates,
 * i.e. those at index 2 and above in requestedCovariates.
 *
 * @param recalibrationTables tables providing one covariate table per index
 * @param key                 key set for the base; key[0] and key[1] are passed through as the read group and quality keys
 * @param errorModel          the event type
 * @param globalDeltaQ        global delta Q of the read group
 * @param deltaQReported      delta Q reported for this base
 * @param qualFromRead        quality score taken from the read
 * @return the sum of the covariate delta Qs; covariates with a negative key contribute nothing
 */
private double calculateDeltaQCovariates(final RecalibrationTables recalibrationTables, final int[] key, final EventType errorModel, final double globalDeltaQ, final double deltaQReported, final byte qualFromRead) {
    double total = 0.0;
    for (int covIndex = 2; covIndex < requestedCovariates.length; covIndex++) {
        final int covKey = key[covIndex];
        // a negative key means there is no value for this covariate at this base
        if (covKey >= 0) {
            total += calculateDeltaQCovariate(recalibrationTables.getTable(covIndex),
                    key[0], key[1], covKey, errorModel,
                    globalDeltaQ, deltaQReported, qualFromRead);
        }
    }
    return total;
}
private double calculateDeltaQCovariate(final NestedIntegerArray<RecalDatum> table,
final int rgKey,
final int qualKey,
final int tableKey,
final EventType errorModel,
final double globalDeltaQ,
final double deltaQReported,
final byte qualFromRead) {
final RecalDatum empiricalQualCO = table.get(rgKey, qualKey, tableKey, errorModel.ordinal());
if (empiricalQualCO != null) {
final double deltaQCovariateEmpirical = empiricalQualCO.getEmpiricalQuality(deltaQReported + globalDeltaQ + qualFromRead);
return deltaQCovariateEmpirical - (deltaQReported + globalDeltaQ + qualFromRead);
} else {
return 0.0;
}
return epsilon + globalDeltaQ + deltaQReported + deltaQCovariates;
}
}