From 1f254d29df1c249de917a85ae5c8a756742f9697 Mon Sep 17 00:00:00 2001 From: Ryan Poplin Date: Mon, 28 Jan 2013 22:16:43 -0500 Subject: [PATCH] Don't set the empirical quality when reading in the recal table because then we won't be using the new quality estimates for the prior since the value is cached. --- .../recalibration/BaseRecalibration.java | 60 +++++++++---------- .../sting/utils/recalibration/RecalDatum.java | 58 ++++++------------ .../recalibration/RecalibrationReport.java | 4 +- 3 files changed, 51 insertions(+), 71 deletions(-) diff --git a/protected/java/src/org/broadinstitute/sting/utils/recalibration/BaseRecalibration.java b/protected/java/src/org/broadinstitute/sting/utils/recalibration/BaseRecalibration.java index 6852cc40f..ba18f5c96 100644 --- a/protected/java/src/org/broadinstitute/sting/utils/recalibration/BaseRecalibration.java +++ b/protected/java/src/org/broadinstitute/sting/utils/recalibration/BaseRecalibration.java @@ -81,9 +81,10 @@ public class BaseRecalibration { private final NestedIntegerArray globalDeltaQs; private final NestedIntegerArray deltaQReporteds; + /** * Constructor using a GATK Report file - * + * * @param RECAL_FILE a GATK Report file containing the recalibration information * @param quantizationLevels number of bins to quantize the quality scores * @param disableIndelQuals if true, do not emit base indel qualities @@ -113,16 +114,13 @@ public class BaseRecalibration { // be done upfront, on initialization of this BaseRecalibration structure // final NestedIntegerArray byReadGroupTable = recalibrationTables.getReadGroupTable(); - final NestedIntegerArray byQualTable = recalibrationTables.getQualityScoreTable(); - - globalDeltaQs = new NestedIntegerArray( byQualTable.getDimensions() ); + globalDeltaQs = new NestedIntegerArray( byReadGroupTable.getDimensions() ); logger.info("Calculating global delta Q table..."); - for ( NestedIntegerArray.Leaf leaf : byQualTable.getAllLeaves() ) { + for ( NestedIntegerArray.Leaf leaf : byReadGroupTable.getAllLeaves() ) { final int rgKey = leaf.keys[0]; - final int qual = leaf.keys[1]; - final int eventIndex = leaf.keys[2]; - final double globalDeltaQ = calculateGlobalDeltaQ(rgKey, EventType.eventFrom(eventIndex), (byte) qual); - globalDeltaQs.put(globalDeltaQ, rgKey, qual, eventIndex); + final int eventIndex = leaf.keys[1]; + final double globalDeltaQ = calculateGlobalDeltaQ(rgKey, EventType.eventFrom(eventIndex)); + globalDeltaQs.put(globalDeltaQ, rgKey, eventIndex); } @@ -131,6 +129,7 @@ public class BaseRecalibration { // into a matrix indexed by rgGroup, qual, and event type. // the code below actually creates this cache with a NestedIntegerArray calling into the actual // calculateDeltaQReported code. + final NestedIntegerArray byQualTable = recalibrationTables.getQualityScoreTable(); deltaQReporteds = new NestedIntegerArray( byQualTable.getDimensions() ); logger.info("Calculating delta Q reported table..."); for ( NestedIntegerArray.Leaf leaf : byQualTable.getAllLeaves() ) { @@ -138,7 +137,7 @@ public class BaseRecalibration { final int qual = leaf.keys[1]; final int eventIndex = leaf.keys[2]; final EventType event = EventType.eventFrom(eventIndex); - final double globalDeltaQ = getGlobalDeltaQ(rgKey, event, (byte)qual); + final double globalDeltaQ = getGlobalDeltaQ(rgKey, event); final double deltaQReported = calculateDeltaQReported(rgKey, qual, event, globalDeltaQ, (byte)qual); deltaQReporteds.put(deltaQReported, rgKey, qual, eventIndex); } @@ -191,17 +190,17 @@ public class BaseRecalibration { // the rg key is constant over the whole read, the global deltaQ is too final int rgKey = fullReadKeySet[0][0]; + final double globalDeltaQ = getGlobalDeltaQ(rgKey, errorModel); for (int offset = 0; offset < readLength; offset++) { // recalibrate all bases in the read final byte origQual = quals[offset]; - final double globalDeltaQ = getGlobalDeltaQ(rgKey, errorModel, origQual); // only recalibrate usable qualities (the original quality will come from the instrument -- reported quality) if ( origQual >= preserveQLessThan ) { // get the keyset for this base using the error model final int[] keySet = fullReadKeySet[offset]; - final double deltaQReported = getDeltaQReported(keySet[0], keySet[1], errorModel, globalDeltaQ, origQual); - final double deltaQCovariates = calculateDeltaQCovariates(recalibrationTables, keySet, errorModel, deltaQReported, globalDeltaQ, origQual); + final double deltaQReported = getDeltaQReported(keySet[0], keySet[1], errorModel, globalDeltaQ); + final double deltaQCovariates = calculateDeltaQCovariates(recalibrationTables, keySet, errorModel, globalDeltaQ, deltaQReported, origQual); // calculate the recalibrated qual using the BQSR formula double recalibratedQualDouble = origQual + globalDeltaQ + deltaQReported + deltaQCovariates; @@ -213,6 +212,10 @@ public class BaseRecalibration { final byte recalibratedQualityScore = quantizationInfo.getQuantizedQuals().get(recalibratedQual); quals[offset] = recalibratedQualityScore; + if( quals[offset] > QualityUtils.MAX_REASONABLE_Q_SCORE ) { + System.out.println("A"); + //calculateDeltaQCovariates(recalibrationTables, keySet, errorModel, globalDeltaQ, deltaQReported, origQual); + } } } @@ -221,12 +224,11 @@ public class BaseRecalibration { } } - private double getGlobalDeltaQ(final int rgKey, final EventType errorModel, final byte qualFromRead) { - - final Double cached = globalDeltaQs.get(rgKey, (int) qualFromRead, errorModel.ordinal()); + private double getGlobalDeltaQ(final int rgKey, final EventType errorModel) { + final Double cached = globalDeltaQs.get(rgKey, errorModel.ordinal()); if ( TEST_CACHING ) { - final double calcd = calculateGlobalDeltaQ(rgKey, errorModel, qualFromRead); + final double calcd = calculateGlobalDeltaQ(rgKey, errorModel); if ( calcd != cached ) throw new IllegalStateException("calculated " + calcd + " and cached " + cached + " global delta q not equal at " + rgKey + " / " + errorModel); } @@ -234,8 +236,7 @@ public class BaseRecalibration { return cachedWithDefault(cached); } - private double getDeltaQReported(final int rgKey, final int qualKey, final EventType errorModel, final double globalDeltaQ, final byte origQual) { - + private double getDeltaQReported(final int rgKey, final int qualKey, final EventType errorModel, final double globalDeltaQ) { final Double cached = deltaQReporteds.get(rgKey, qualKey, errorModel.ordinal()); if ( TEST_CACHING ) { @@ -245,7 +246,6 @@ public class BaseRecalibration { } return cachedWithDefault(cached); - } /** @@ -264,14 +264,14 @@ public class BaseRecalibration { * @param errorModel the event type * @return global delta Q */ - private double calculateGlobalDeltaQ(final int rgKey, final EventType errorModel, final byte qualFromRead) { + private double calculateGlobalDeltaQ(final int rgKey, final EventType errorModel) { double result = 0.0; final RecalDatum empiricalQualRG = recalibrationTables.getReadGroupTable().get(rgKey, errorModel.ordinal()); if (empiricalQualRG != null) { - final double aggregrateQReported = ( globalQScorePrior > 0.0 && errorModel.equals(EventType.BASE_SUBSTITUTION) ? globalQScorePrior : qualFromRead ); - final double globalDeltaQEmpirical = empiricalQualRG.getEmpiricalQuality( aggregrateQReported ); + final double aggregrateQReported = ( globalQScorePrior > 0.0 && errorModel.equals(EventType.BASE_SUBSTITUTION) ? globalQScorePrior : empiricalQualRG.getEstimatedQReported() ); + final double globalDeltaQEmpirical = empiricalQualRG.getEmpiricalQuality(aggregrateQReported); result = globalDeltaQEmpirical - aggregrateQReported; } @@ -283,14 +283,14 @@ public class BaseRecalibration { final RecalDatum empiricalQualQS = recalibrationTables.getQualityScoreTable().get(rgKey, qualKey, errorModel.ordinal()); if (empiricalQualQS != null) { - final double deltaQReportedEmpirical = empiricalQualQS.getEmpiricalQuality( qualFromRead + globalDeltaQ ); - result = deltaQReportedEmpirical - ( qualFromRead + globalDeltaQ ); + final double deltaQReportedEmpirical = empiricalQualQS.getEmpiricalQuality(globalDeltaQ + qualFromRead); + result = deltaQReportedEmpirical - (globalDeltaQ + qualFromRead); } return result; } - private double calculateDeltaQCovariates(final RecalibrationTables recalibrationTables, final int[] key, final EventType errorModel, final double deltaQReported, final double globalDeltaQ, final byte qualFromRead) { + private double calculateDeltaQCovariates(final RecalibrationTables recalibrationTables, final int[] key, final EventType errorModel, final double globalDeltaQ, final double deltaQReported, final byte qualFromRead) { double result = 0.0; // for all optional covariates @@ -300,7 +300,7 @@ public class BaseRecalibration { result += calculateDeltaQCovariate(recalibrationTables.getTable(i), key[0], key[1], key[i], errorModel, - deltaQReported, globalDeltaQ, qualFromRead); + globalDeltaQ, deltaQReported, qualFromRead); } return result; @@ -311,13 +311,13 @@ public class BaseRecalibration { final int qualKey, final int tableKey, final EventType errorModel, - final double deltaQReported, final double globalDeltaQ, + final double deltaQReported, final byte qualFromRead) { final RecalDatum empiricalQualCO = table.get(rgKey, qualKey, tableKey, errorModel.ordinal()); if (empiricalQualCO != null) { - final double deltaQCovariateEmpirical = empiricalQualCO.getEmpiricalQuality( deltaQReported + globalDeltaQ + qualFromRead ); - return deltaQCovariateEmpirical - ( deltaQReported + globalDeltaQ + qualFromRead ); + final double deltaQCovariateEmpirical = empiricalQualCO.getEmpiricalQuality(deltaQReported + globalDeltaQ + qualFromRead); + return deltaQCovariateEmpirical - (deltaQReported + globalDeltaQ + qualFromRead); } else { return 0.0; } diff --git a/protected/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatum.java b/protected/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatum.java index 743c56ee5..67794c248 100644 --- a/protected/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatum.java +++ b/protected/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatum.java @@ -77,9 +77,7 @@ import com.google.java.contract.Requires; import org.apache.commons.math.optimization.fitting.GaussianFunction; import org.broadinstitute.sting.utils.MathUtils; import org.broadinstitute.sting.utils.QualityUtils; - -import java.util.HashMap; -import java.util.Map; +import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; /** @@ -114,11 +112,6 @@ public class RecalDatum { */ private double empiricalQuality; - /** - * the empirical quality for datums that have been collapsed together (by read group and reported quality, for example) - */ - private Map empiricalQualityMap; - /** * number of bases seen in total */ @@ -156,7 +149,6 @@ public class RecalDatum { numMismatches = _numMismatches; estimatedQReported = reportedQuality; empiricalQuality = UNINITIALIZED; - empiricalQualityMap = new HashMap(); } /** @@ -168,7 +160,6 @@ public class RecalDatum { this.numMismatches = copy.getNumMismatches(); this.estimatedQReported = copy.estimatedQReported; this.empiricalQuality = copy.empiricalQuality; - empiricalQualityMap = copy.empiricalQualityMap; } /** @@ -182,7 +173,6 @@ public class RecalDatum { increment(other.getNumObservations(), other.getNumMismatches()); estimatedQReported = -10 * Math.log10(sumErrors / getNumObservations()); empiricalQuality = UNINITIALIZED; - empiricalQualityMap = new HashMap(); } public synchronized void setEstimatedQReported(final double estimatedQReported) { @@ -232,8 +222,13 @@ public class RecalDatum { } public final double getEmpiricalQuality() { - if (empiricalQuality == UNINITIALIZED) - calcEmpiricalQuality(); + return getEmpiricalQuality(getEstimatedQReported()); + } + + public final double getEmpiricalQuality(final double conditionalPrior) { + if (empiricalQuality == UNINITIALIZED) { + calcEmpiricalQuality(conditionalPrior); + } return empiricalQuality; } @@ -325,37 +320,22 @@ public class RecalDatum { return getNumObservations() * QualityUtils.qualToErrorProb(estimatedQReported); } - /** - * Calculate empirical quality score from mismatches, observations, and the given conditional prior (expensive operation) - */ - public double getEmpiricalQuality(final double conditionalPriorQ) { - - final int priorQKey = MathUtils.fastRound(conditionalPriorQ); - Double returnQ = empiricalQualityMap.get(priorQKey); - - if( returnQ == null ) { - // smoothing is one error and one non-error observation - final long mismatches = (long)(getNumMismatches() + 0.5) + SMOOTHING_CONSTANT; - final long observations = getNumObservations() + SMOOTHING_CONSTANT + SMOOTHING_CONSTANT; - - final double empiricalQual = RecalDatum.bayesianEstimateOfEmpiricalQuality(observations, mismatches, conditionalPriorQ); - - // This is the old and busted point estimate approach: - //final double empiricalQual = -10 * Math.log10(getEmpiricalErrorRate()); - - returnQ = Math.min(empiricalQual, (double) QualityUtils.MAX_RECALIBRATED_Q_SCORE); - empiricalQualityMap.put( priorQKey, returnQ ); - } - return returnQ; - } - /** * Calculate and cache the empirical quality score from mismatches and observations (expensive operation) */ @Requires("empiricalQuality == UNINITIALIZED") @Ensures("empiricalQuality != UNINITIALIZED") - private synchronized void calcEmpiricalQuality() { - final double empiricalQual = getEmpiricalQuality(getEstimatedQReported()); + private synchronized void calcEmpiricalQuality(final double conditionalPrior) { + + // smoothing is one error and one non-error observation + final long mismatches = (long)(getNumMismatches() + 0.5) + SMOOTHING_CONSTANT; + final long observations = getNumObservations() + SMOOTHING_CONSTANT + SMOOTHING_CONSTANT; + + final double empiricalQual = RecalDatum.bayesianEstimateOfEmpiricalQuality(observations, mismatches, conditionalPrior); + + // This is the old and busted point estimate approach: + //final double empiricalQual = -10 * Math.log10(getEmpiricalErrorRate()); + empiricalQuality = Math.min(empiricalQual, (double) QualityUtils.MAX_RECALIBRATED_Q_SCORE); } diff --git a/protected/java/src/org/broadinstitute/sting/utils/recalibration/RecalibrationReport.java b/protected/java/src/org/broadinstitute/sting/utils/recalibration/RecalibrationReport.java index 3fdbd63bd..bf8f00ca4 100644 --- a/protected/java/src/org/broadinstitute/sting/utils/recalibration/RecalibrationReport.java +++ b/protected/java/src/org/broadinstitute/sting/utils/recalibration/RecalibrationReport.java @@ -246,7 +246,7 @@ public class RecalibrationReport { private RecalDatum getRecalDatum(final GATKReportTable reportTable, final int row, final boolean hasEstimatedQReportedColumn) { final long nObservations = asLong(reportTable.get(row, RecalUtils.NUMBER_OBSERVATIONS_COLUMN_NAME)); final double nErrors = asDouble(reportTable.get(row, RecalUtils.NUMBER_ERRORS_COLUMN_NAME)); - final double empiricalQuality = asDouble(reportTable.get(row, RecalUtils.EMPIRICAL_QUALITY_COLUMN_NAME)); + //final double empiricalQuality = asDouble(reportTable.get(row, RecalUtils.EMPIRICAL_QUALITY_COLUMN_NAME)); // the estimatedQreported column only exists in the ReadGroup table final double estimatedQReported = hasEstimatedQReportedColumn ? @@ -255,7 +255,7 @@ public class RecalibrationReport { final RecalDatum datum = new RecalDatum(nObservations, nErrors, (byte)1); datum.setEstimatedQReported(estimatedQReported); - datum.setEmpiricalQuality(empiricalQuality); + //datum.setEmpiricalQuality(empiricalQuality); return datum; }