From cba89e98ad4879e66fc9bac9d5a102276b0517cb Mon Sep 17 00:00:00 2001 From: Ryan Poplin Date: Tue, 29 Jan 2013 15:50:46 -0500 Subject: [PATCH] Refactoring the Bayesian empirical quality estimates to be in a single unit-testable function. --- .../recalibration/BaseRecalibration.java | 163 +++--------------- 1 file changed, 20 insertions(+), 143 deletions(-) diff --git a/protected/java/src/org/broadinstitute/sting/utils/recalibration/BaseRecalibration.java b/protected/java/src/org/broadinstitute/sting/utils/recalibration/BaseRecalibration.java index 1f4e92ad7..c8d460308 100644 --- a/protected/java/src/org/broadinstitute/sting/utils/recalibration/BaseRecalibration.java +++ b/protected/java/src/org/broadinstitute/sting/utils/recalibration/BaseRecalibration.java @@ -57,6 +57,8 @@ import org.broadinstitute.sting.utils.recalibration.covariates.Covariate; import org.broadinstitute.sting.utils.sam.GATKSAMRecord; import java.io.File; +import java.util.ArrayList; +import java.util.List; /** * Utility methods to facilitate on-the-fly base quality score recalibration. @@ -78,10 +80,6 @@ public class BaseRecalibration { private final double globalQScorePrior; private final boolean emitOriginalQuals; - private final NestedIntegerArray globalDeltaQs; - private final NestedIntegerArray deltaQReporteds; - - /** * Constructor using a GATK Report file * @@ -105,44 +103,6 @@ public class BaseRecalibration { this.preserveQLessThan = preserveQLessThan; this.globalQScorePrior = globalQScorePrior; this.emitOriginalQuals = emitOriginalQuals; - - logger.info("Calculating cached tables..."); - - // - // Create a NestedIntegerArray that maps from rgKey x errorModel -> double, - // where the double is the result of this calculation. The entire calculation can - // be done upfront, on initialization of this BaseRecalibration structure - // - final NestedIntegerArray byReadGroupTable = recalibrationTables.getReadGroupTable(); - globalDeltaQs = new NestedIntegerArray( byReadGroupTable.getDimensions() ); - logger.info("Calculating global delta Q table..."); - for ( NestedIntegerArray.Leaf leaf : byReadGroupTable.getAllLeaves() ) { - final int rgKey = leaf.keys[0]; - final int eventIndex = leaf.keys[1]; - final double globalDeltaQ = calculateGlobalDeltaQ(rgKey, EventType.eventFrom(eventIndex)); - globalDeltaQs.put(globalDeltaQ, rgKey, eventIndex); - } - - - // The calculation of the deltaQ report is constant. key[0] and key[1] are the read group and qual, respectively - // and globalDeltaQ is a constant for the read group. So technically the delta Q reported is simply a lookup - // into a matrix indexed by rgGroup, qual, and event type. - // the code below actually creates this cache with a NestedIntegerArray calling into the actual - // calculateDeltaQReported code. - final NestedIntegerArray byQualTable = recalibrationTables.getQualityScoreTable(); - deltaQReporteds = new NestedIntegerArray( byQualTable.getDimensions() ); - logger.info("Calculating delta Q reported table..."); - for ( NestedIntegerArray.Leaf leaf : byQualTable.getAllLeaves() ) { - final int rgKey = leaf.keys[0]; - final int qual = leaf.keys[1]; - final int eventIndex = leaf.keys[2]; - final EventType event = EventType.eventFrom(eventIndex); - final double globalDeltaQ = getGlobalDeltaQ(rgKey, event); - final double deltaQReported = calculateDeltaQReported(rgKey, qual, event, globalDeltaQ, (byte)qual); - deltaQReporteds.put(deltaQReported, rgKey, qual, eventIndex); - } - - logger.info("done calculating cache"); } /** @@ -189,8 +149,8 @@ public class BaseRecalibration { // the rg key is constant over the whole read, the global deltaQ is too final int rgKey = fullReadKeySet[0][0]; - - final double globalDeltaQ = getGlobalDeltaQ(rgKey, errorModel); + final RecalDatum empiricalQualRG = recalibrationTables.getReadGroupTable().get(rgKey, errorModel.ordinal()); + final double epsilon = ( globalQScorePrior > 0.0 && errorModel.equals(EventType.BASE_SUBSTITUTION) ? globalQScorePrior : empiricalQualRG.getEstimatedQReported() ); for (int offset = 0; offset < readLength; offset++) { // recalibrate all bases in the read final byte origQual = quals[offset]; @@ -199,11 +159,16 @@ public class BaseRecalibration { if ( origQual >= preserveQLessThan ) { // get the keyset for this base using the error model final int[] keySet = fullReadKeySet[offset]; - final double deltaQReported = getDeltaQReported(keySet[0], keySet[1], errorModel, globalDeltaQ); - final double deltaQCovariates = calculateDeltaQCovariates(recalibrationTables, keySet, errorModel, globalDeltaQ, deltaQReported, origQual); + final RecalDatum empiricalQualQS = recalibrationTables.getQualityScoreTable().get(keySet[0], keySet[1], errorModel.ordinal()); + final List empiricalQualCovs = new ArrayList(); + for (int i = 2; i < requestedCovariates.length; i++) { + if (keySet[i] < 0) { + continue; + } + empiricalQualCovs.add(recalibrationTables.getTable(i).get(keySet[0], keySet[1], keySet[i], errorModel.ordinal())); + } - // calculate the recalibrated qual using the BQSR formula - double recalibratedQualDouble = origQual + globalDeltaQ + deltaQReported + deltaQCovariates; + double recalibratedQualDouble = hierarchicalBayesianQualityEstimate( epsilon, empiricalQualRG, empiricalQualQS, empiricalQualCovs ); // recalibrated quality is bound between 1 and MAX_QUAL final byte recalibratedQual = QualityUtils.boundQual(MathUtils.fastRound(recalibratedQualDouble), QualityUtils.MAX_RECALIBRATED_Q_SCORE); @@ -220,102 +185,14 @@ public class BaseRecalibration { } } - private double getGlobalDeltaQ(final int rgKey, final EventType errorModel) { - final Double cached = globalDeltaQs.get(rgKey, errorModel.ordinal()); - - if ( TEST_CACHING ) { - final double calcd = calculateGlobalDeltaQ(rgKey, errorModel); - if ( calcd != cached ) - throw new IllegalStateException("calculated " + calcd + " and cached " + cached + " global delta q not equal at " + rgKey + " / " + errorModel); + protected double hierarchicalBayesianQualityEstimate( final double epsilon, final RecalDatum empiricalQualRG, final RecalDatum empiricalQualQS, final List empiricalQualCovs ) { + final double globalDeltaQ = ( empiricalQualRG == null ? 0.0 : empiricalQualRG.getEmpiricalQuality(epsilon) - epsilon ); + final double deltaQReported = ( empiricalQualQS == null ? 0.0 : empiricalQualQS.getEmpiricalQuality(globalDeltaQ + epsilon) - (globalDeltaQ + epsilon) ); + double deltaQCovariates = 0.0; + for( final RecalDatum empiricalQualCov : empiricalQualCovs ) { + deltaQCovariates += ( empiricalQualCov == null ? 0.0 : empiricalQualCov.getEmpiricalQuality(deltaQReported + globalDeltaQ + epsilon) - (deltaQReported + globalDeltaQ + epsilon) ); } - return cachedWithDefault(cached); - } - - private double getDeltaQReported(final int rgKey, final int qualKey, final EventType errorModel, final double globalDeltaQ) { - final Double cached = deltaQReporteds.get(rgKey, qualKey, errorModel.ordinal()); - - if ( TEST_CACHING ) { - final double calcd = calculateDeltaQReported(rgKey, qualKey, errorModel, globalDeltaQ, (byte)qualKey); - if ( calcd != cached ) - throw new IllegalStateException("calculated " + calcd + " and cached " + cached + " global delta q not equal at " + rgKey + " / " + qualKey + " / " + errorModel); - } - - return cachedWithDefault(cached); - } - - /** - * @param d a Double (that may be null) that is the result of a delta Q calculation - * @return a double == d if d != null, or 0.0 if it is - */ - private double cachedWithDefault(final Double d) { - return d == null ? 0.0 : d; - } - - /** - * Note that this calculation is a constant for each rgKey and errorModel. We need only - * compute this value once for all data. - * - * @param rgKey read group key - * @param errorModel the event type - * @return global delta Q - */ - private double calculateGlobalDeltaQ(final int rgKey, final EventType errorModel) { - double result = 0.0; - - final RecalDatum empiricalQualRG = recalibrationTables.getReadGroupTable().get(rgKey, errorModel.ordinal()); - - if (empiricalQualRG != null) { - final double aggregrateQReported = ( globalQScorePrior > 0.0 && errorModel.equals(EventType.BASE_SUBSTITUTION) ? globalQScorePrior : empiricalQualRG.getEstimatedQReported() ); - final double globalDeltaQEmpirical = empiricalQualRG.getEmpiricalQuality(aggregrateQReported); - result = globalDeltaQEmpirical - aggregrateQReported; - } - - return result; - } - - private double calculateDeltaQReported(final int rgKey, final int qualKey, final EventType errorModel, final double globalDeltaQ, final byte qualFromRead) { - double result = 0.0; - - final RecalDatum empiricalQualQS = recalibrationTables.getQualityScoreTable().get(rgKey, qualKey, errorModel.ordinal()); - if (empiricalQualQS != null) { - final double deltaQReportedEmpirical = empiricalQualQS.getEmpiricalQuality(globalDeltaQ + qualFromRead); - result = deltaQReportedEmpirical - (globalDeltaQ + qualFromRead); - } - - return result; - } - - private double calculateDeltaQCovariates(final RecalibrationTables recalibrationTables, final int[] key, final EventType errorModel, final double globalDeltaQ, final double deltaQReported, final byte qualFromRead) { - double result = 0.0; - - // for all optional covariates - for (int i = 2; i < requestedCovariates.length; i++) { - if (key[i] < 0) - continue; - - result += calculateDeltaQCovariate(recalibrationTables.getTable(i), - key[0], key[1], key[i], errorModel, - globalDeltaQ, deltaQReported, qualFromRead); - } - - return result; - } - - private double calculateDeltaQCovariate(final NestedIntegerArray table, - final int rgKey, - final int qualKey, - final int tableKey, - final EventType errorModel, - final double globalDeltaQ, - final double deltaQReported, - final byte qualFromRead) { - final RecalDatum empiricalQualCO = table.get(rgKey, qualKey, tableKey, errorModel.ordinal()); - if (empiricalQualCO != null) { - final double deltaQCovariateEmpirical = empiricalQualCO.getEmpiricalQuality(deltaQReported + globalDeltaQ + qualFromRead); - return deltaQCovariateEmpirical - (deltaQReported + globalDeltaQ + qualFromRead); - } else { - return 0.0; - } + return epsilon + globalDeltaQ + deltaQReported + deltaQCovariates; } }