From d665a8ba0caf7a5a5da592ca6a1e52571c8a5d59 Mon Sep 17 00:00:00 2001 From: Ryan Poplin Date: Mon, 28 Jan 2013 15:56:33 -0500 Subject: [PATCH] The Bayesian calculation of Qemp in the BQSR is now hierarchical. This fixes issues in which the covariate bins were very sparse and the prior estimate being used was the original quality score. This resulted in large correction factors for each covariate which breaks the equation. There is also now a new option, qlobalQScorePrior, which can be used to ignore the given (very high) quality scores and instead use this value as the prior. --- .../recalibration/BQSRReadTransformer.java | 2 +- .../recalibration/BaseRecalibration.java | 58 ++++++++++--------- .../sting/utils/recalibration/RecalDatum.java | 46 +++++++++++---- .../arguments/GATKArgumentCollection.java | 3 + .../utils/recalibration/BQSRArgumentSet.java | 8 +++ 5 files changed, 80 insertions(+), 37 deletions(-) diff --git a/protected/java/src/org/broadinstitute/sting/utils/recalibration/BQSRReadTransformer.java b/protected/java/src/org/broadinstitute/sting/utils/recalibration/BQSRReadTransformer.java index f6e63deec..c85072fa2 100644 --- a/protected/java/src/org/broadinstitute/sting/utils/recalibration/BQSRReadTransformer.java +++ b/protected/java/src/org/broadinstitute/sting/utils/recalibration/BQSRReadTransformer.java @@ -67,7 +67,7 @@ public class BQSRReadTransformer extends ReadTransformer { this.enabled = engine.hasBQSRArgumentSet(); if ( enabled ) { final BQSRArgumentSet args = engine.getBQSRArgumentSet(); - this.bqsr = new BaseRecalibration(args.getRecalFile(), args.getQuantizationLevels(), args.shouldDisableIndelQuals(), args.getPreserveQscoresLessThan(), args.shouldEmitOriginalQuals()); + this.bqsr = new BaseRecalibration(args.getRecalFile(), args.getQuantizationLevels(), args.shouldDisableIndelQuals(), args.getPreserveQscoresLessThan(), args.shouldEmitOriginalQuals(), args.getGlobalQScorePrior()); } final BQSRMode mode = WalkerManager.getWalkerAnnotation(walker, BQSRMode.class); return mode.ApplicationTime(); diff --git a/protected/java/src/org/broadinstitute/sting/utils/recalibration/BaseRecalibration.java b/protected/java/src/org/broadinstitute/sting/utils/recalibration/BaseRecalibration.java index 828f91c6f..6852cc40f 100644 --- a/protected/java/src/org/broadinstitute/sting/utils/recalibration/BaseRecalibration.java +++ b/protected/java/src/org/broadinstitute/sting/utils/recalibration/BaseRecalibration.java @@ -75,12 +75,12 @@ public class BaseRecalibration { private final boolean disableIndelQuals; private final int preserveQLessThan; + private final double globalQScorePrior; private final boolean emitOriginalQuals; private final NestedIntegerArray globalDeltaQs; private final NestedIntegerArray deltaQReporteds; - /** * Constructor using a GATK Report file * @@ -89,7 +89,7 @@ public class BaseRecalibration { * @param disableIndelQuals if true, do not emit base indel qualities * @param preserveQLessThan preserve quality scores less than this value */ - public BaseRecalibration(final File RECAL_FILE, final int quantizationLevels, final boolean disableIndelQuals, final int preserveQLessThan, final boolean emitOriginalQuals) { + public BaseRecalibration(final File RECAL_FILE, final int quantizationLevels, final boolean disableIndelQuals, final int preserveQLessThan, final boolean emitOriginalQuals, final double globalQScorePrior) { RecalibrationReport recalibrationReport = new RecalibrationReport(RECAL_FILE); recalibrationTables = recalibrationReport.getRecalibrationTables(); @@ -102,6 +102,7 @@ public class BaseRecalibration { this.disableIndelQuals = disableIndelQuals; this.preserveQLessThan = preserveQLessThan; + this.globalQScorePrior = globalQScorePrior; this.emitOriginalQuals = emitOriginalQuals; logger.info("Calculating cached tables..."); @@ -112,13 +113,16 @@ public class BaseRecalibration { // be done upfront, on initialization of this BaseRecalibration structure // final NestedIntegerArray byReadGroupTable = recalibrationTables.getReadGroupTable(); - globalDeltaQs = new NestedIntegerArray( byReadGroupTable.getDimensions() ); + final NestedIntegerArray byQualTable = recalibrationTables.getQualityScoreTable(); + + globalDeltaQs = new NestedIntegerArray( byQualTable.getDimensions() ); logger.info("Calculating global delta Q table..."); - for ( NestedIntegerArray.Leaf leaf : byReadGroupTable.getAllLeaves() ) { + for ( NestedIntegerArray.Leaf leaf : byQualTable.getAllLeaves() ) { final int rgKey = leaf.keys[0]; - final int eventIndex = leaf.keys[1]; - final double globalDeltaQ = calculateGlobalDeltaQ(rgKey, EventType.eventFrom(eventIndex)); - globalDeltaQs.put(globalDeltaQ, rgKey, eventIndex); + final int qual = leaf.keys[1]; + final int eventIndex = leaf.keys[2]; + final double globalDeltaQ = calculateGlobalDeltaQ(rgKey, EventType.eventFrom(eventIndex), (byte) qual); + globalDeltaQs.put(globalDeltaQ, rgKey, qual, eventIndex); } @@ -127,7 +131,6 @@ public class BaseRecalibration { // into a matrix indexed by rgGroup, qual, and event type. // the code below actually creates this cache with a NestedIntegerArray calling into the actual // calculateDeltaQReported code. - final NestedIntegerArray byQualTable = recalibrationTables.getQualityScoreTable(); deltaQReporteds = new NestedIntegerArray( byQualTable.getDimensions() ); logger.info("Calculating delta Q reported table..."); for ( NestedIntegerArray.Leaf leaf : byQualTable.getAllLeaves() ) { @@ -135,7 +138,7 @@ public class BaseRecalibration { final int qual = leaf.keys[1]; final int eventIndex = leaf.keys[2]; final EventType event = EventType.eventFrom(eventIndex); - final double globalDeltaQ = getGlobalDeltaQ(rgKey, event); + final double globalDeltaQ = getGlobalDeltaQ(rgKey, event, (byte)qual); final double deltaQReported = calculateDeltaQReported(rgKey, qual, event, globalDeltaQ, (byte)qual); deltaQReporteds.put(deltaQReported, rgKey, qual, eventIndex); } @@ -188,17 +191,17 @@ public class BaseRecalibration { // the rg key is constant over the whole read, the global deltaQ is too final int rgKey = fullReadKeySet[0][0]; - final double globalDeltaQ = getGlobalDeltaQ(rgKey, errorModel); for (int offset = 0; offset < readLength; offset++) { // recalibrate all bases in the read final byte origQual = quals[offset]; + final double globalDeltaQ = getGlobalDeltaQ(rgKey, errorModel, origQual); // only recalibrate usable qualities (the original quality will come from the instrument -- reported quality) if ( origQual >= preserveQLessThan ) { // get the keyset for this base using the error model final int[] keySet = fullReadKeySet[offset]; - final double deltaQReported = getDeltaQReported(keySet[0], keySet[1], errorModel, globalDeltaQ); - final double deltaQCovariates = calculateDeltaQCovariates(recalibrationTables, keySet, errorModel, globalDeltaQ, deltaQReported, origQual); + final double deltaQReported = getDeltaQReported(keySet[0], keySet[1], errorModel, globalDeltaQ, origQual); + final double deltaQCovariates = calculateDeltaQCovariates(recalibrationTables, keySet, errorModel, deltaQReported, globalDeltaQ, origQual); // calculate the recalibrated qual using the BQSR formula double recalibratedQualDouble = origQual + globalDeltaQ + deltaQReported + deltaQCovariates; @@ -218,11 +221,12 @@ public class BaseRecalibration { } } - private double getGlobalDeltaQ(final int rgKey, final EventType errorModel) { - final Double cached = globalDeltaQs.get(rgKey, errorModel.ordinal()); + private double getGlobalDeltaQ(final int rgKey, final EventType errorModel, final byte qualFromRead) { + + final Double cached = globalDeltaQs.get(rgKey, (int) qualFromRead, errorModel.ordinal()); if ( TEST_CACHING ) { - final double calcd = calculateGlobalDeltaQ(rgKey, errorModel); + final double calcd = calculateGlobalDeltaQ(rgKey, errorModel, qualFromRead); if ( calcd != cached ) throw new IllegalStateException("calculated " + calcd + " and cached " + cached + " global delta q not equal at " + rgKey + " / " + errorModel); } @@ -230,7 +234,8 @@ public class BaseRecalibration { return cachedWithDefault(cached); } - private double getDeltaQReported(final int rgKey, final int qualKey, final EventType errorModel, final double globalDeltaQ) { + private double getDeltaQReported(final int rgKey, final int qualKey, final EventType errorModel, final double globalDeltaQ, final byte origQual) { + final Double cached = deltaQReporteds.get(rgKey, qualKey, errorModel.ordinal()); if ( TEST_CACHING ) { @@ -240,6 +245,7 @@ public class BaseRecalibration { } return cachedWithDefault(cached); + } /** @@ -258,14 +264,14 @@ public class BaseRecalibration { * @param errorModel the event type * @return global delta Q */ - private double calculateGlobalDeltaQ(final int rgKey, final EventType errorModel) { + private double calculateGlobalDeltaQ(final int rgKey, final EventType errorModel, final byte qualFromRead) { double result = 0.0; final RecalDatum empiricalQualRG = recalibrationTables.getReadGroupTable().get(rgKey, errorModel.ordinal()); if (empiricalQualRG != null) { - final double globalDeltaQEmpirical = empiricalQualRG.getEmpiricalQuality(); - final double aggregrateQReported = empiricalQualRG.getEstimatedQReported(); + final double aggregrateQReported = ( globalQScorePrior > 0.0 && errorModel.equals(EventType.BASE_SUBSTITUTION) ? globalQScorePrior : qualFromRead ); + final double globalDeltaQEmpirical = empiricalQualRG.getEmpiricalQuality( aggregrateQReported ); result = globalDeltaQEmpirical - aggregrateQReported; } @@ -277,14 +283,14 @@ public class BaseRecalibration { final RecalDatum empiricalQualQS = recalibrationTables.getQualityScoreTable().get(rgKey, qualKey, errorModel.ordinal()); if (empiricalQualQS != null) { - final double deltaQReportedEmpirical = empiricalQualQS.getEmpiricalQuality(); - result = deltaQReportedEmpirical - qualFromRead - globalDeltaQ; + final double deltaQReportedEmpirical = empiricalQualQS.getEmpiricalQuality( qualFromRead + globalDeltaQ ); + result = deltaQReportedEmpirical - ( qualFromRead + globalDeltaQ ); } return result; } - private double calculateDeltaQCovariates(final RecalibrationTables recalibrationTables, final int[] key, final EventType errorModel, final double globalDeltaQ, final double deltaQReported, final byte qualFromRead) { + private double calculateDeltaQCovariates(final RecalibrationTables recalibrationTables, final int[] key, final EventType errorModel, final double deltaQReported, final double globalDeltaQ, final byte qualFromRead) { double result = 0.0; // for all optional covariates @@ -294,7 +300,7 @@ public class BaseRecalibration { result += calculateDeltaQCovariate(recalibrationTables.getTable(i), key[0], key[1], key[i], errorModel, - globalDeltaQ, deltaQReported, qualFromRead); + deltaQReported, globalDeltaQ, qualFromRead); } return result; @@ -305,13 +311,13 @@ public class BaseRecalibration { final int qualKey, final int tableKey, final EventType errorModel, - final double globalDeltaQ, final double deltaQReported, + final double globalDeltaQ, final byte qualFromRead) { final RecalDatum empiricalQualCO = table.get(rgKey, qualKey, tableKey, errorModel.ordinal()); if (empiricalQualCO != null) { - final double deltaQCovariateEmpirical = empiricalQualCO.getEmpiricalQuality(); - return deltaQCovariateEmpirical - qualFromRead - (globalDeltaQ + deltaQReported); + final double deltaQCovariateEmpirical = empiricalQualCO.getEmpiricalQuality( deltaQReported + globalDeltaQ + qualFromRead ); + return deltaQCovariateEmpirical - ( deltaQReported + globalDeltaQ + qualFromRead ); } else { return 0.0; } diff --git a/protected/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatum.java b/protected/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatum.java index 9430caaac..743c56ee5 100644 --- a/protected/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatum.java +++ b/protected/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatum.java @@ -78,6 +78,9 @@ import org.apache.commons.math.optimization.fitting.GaussianFunction; import org.broadinstitute.sting.utils.MathUtils; import org.broadinstitute.sting.utils.QualityUtils; +import java.util.HashMap; +import java.util.Map; + /** * An individual piece of recalibration data. Each bin counts up the number of observations and the number @@ -111,6 +114,11 @@ public class RecalDatum { */ private double empiricalQuality; + /** + * the empirical quality for datums that have been collapsed together (by read group and reported quality, for example) + */ + private Map empiricalQualityMap; + /** * number of bases seen in total */ @@ -148,6 +156,7 @@ public class RecalDatum { numMismatches = _numMismatches; estimatedQReported = reportedQuality; empiricalQuality = UNINITIALIZED; + empiricalQualityMap = new HashMap(); } /** @@ -159,6 +168,7 @@ public class RecalDatum { this.numMismatches = copy.getNumMismatches(); this.estimatedQReported = copy.estimatedQReported; this.empiricalQuality = copy.empiricalQuality; + empiricalQualityMap = copy.empiricalQualityMap; } /** @@ -172,6 +182,7 @@ public class RecalDatum { increment(other.getNumObservations(), other.getNumMismatches()); estimatedQReported = -10 * Math.log10(sumErrors / getNumObservations()); empiricalQuality = UNINITIALIZED; + empiricalQualityMap = new HashMap(); } public synchronized void setEstimatedQReported(final double estimatedQReported) { @@ -314,22 +325,37 @@ public class RecalDatum { return getNumObservations() * QualityUtils.qualToErrorProb(estimatedQReported); } + /** + * Calculate empirical quality score from mismatches, observations, and the given conditional prior (expensive operation) + */ + public double getEmpiricalQuality(final double conditionalPriorQ) { + + final int priorQKey = MathUtils.fastRound(conditionalPriorQ); + Double returnQ = empiricalQualityMap.get(priorQKey); + + if( returnQ == null ) { + // smoothing is one error and one non-error observation + final long mismatches = (long)(getNumMismatches() + 0.5) + SMOOTHING_CONSTANT; + final long observations = getNumObservations() + SMOOTHING_CONSTANT + SMOOTHING_CONSTANT; + + final double empiricalQual = RecalDatum.bayesianEstimateOfEmpiricalQuality(observations, mismatches, conditionalPriorQ); + + // This is the old and busted point estimate approach: + //final double empiricalQual = -10 * Math.log10(getEmpiricalErrorRate()); + + returnQ = Math.min(empiricalQual, (double) QualityUtils.MAX_RECALIBRATED_Q_SCORE); + empiricalQualityMap.put( priorQKey, returnQ ); + } + return returnQ; + } + /** * Calculate and cache the empirical quality score from mismatches and observations (expensive operation) */ @Requires("empiricalQuality == UNINITIALIZED") @Ensures("empiricalQuality != UNINITIALIZED") private synchronized void calcEmpiricalQuality() { - - // smoothing is one error and one non-error observation - final long mismatches = (long)(getNumMismatches() + 0.5) + SMOOTHING_CONSTANT; - final long observations = getNumObservations() + SMOOTHING_CONSTANT + SMOOTHING_CONSTANT; - - final double empiricalQual = RecalDatum.bayesianEstimateOfEmpiricalQuality(observations, mismatches, getEstimatedQReported()); - - // This is the old and busted point estimate approach: - //final double empiricalQual = -10 * Math.log10(getEmpiricalErrorRate()); - + final double empiricalQual = getEmpiricalQuality(getEstimatedQReported()); empiricalQuality = Math.min(empiricalQual, (double) QualityUtils.MAX_RECALIBRATED_Q_SCORE); } diff --git a/public/java/src/org/broadinstitute/sting/gatk/arguments/GATKArgumentCollection.java b/public/java/src/org/broadinstitute/sting/gatk/arguments/GATKArgumentCollection.java index 9cd88001c..6520947d4 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/arguments/GATKArgumentCollection.java +++ b/public/java/src/org/broadinstitute/sting/gatk/arguments/GATKArgumentCollection.java @@ -241,6 +241,9 @@ public class GATKArgumentCollection { @Argument(fullName = "preserve_qscores_less_than", shortName = "preserveQ", doc = "Bases with quality scores less than this threshold won't be recalibrated (with -BQSR)", required = false) public int PRESERVE_QSCORES_LESS_THAN = QualityUtils.MIN_USABLE_Q_SCORE; + @Argument(fullName = "qlobalQScorePrior", shortName = "qlobalQScorePrior", doc = "The global Qscore Bayesian prior to use in the BQSR. If specified, this value will be used as the prior for all mismatch quality scores instead of the actual reported quality score", required = false) + public double globalQScorePrior = -1.0; + // -------------------------------------------------------------------------------------------------------------- // // Other utility arguments diff --git a/public/java/src/org/broadinstitute/sting/utils/recalibration/BQSRArgumentSet.java b/public/java/src/org/broadinstitute/sting/utils/recalibration/BQSRArgumentSet.java index dbf70f4ce..600700484 100644 --- a/public/java/src/org/broadinstitute/sting/utils/recalibration/BQSRArgumentSet.java +++ b/public/java/src/org/broadinstitute/sting/utils/recalibration/BQSRArgumentSet.java @@ -36,6 +36,7 @@ public class BQSRArgumentSet { private boolean disableIndelQuals; private boolean emitOriginalQuals; private int PRESERVE_QSCORES_LESS_THAN; + private double globalQScorePrior; public BQSRArgumentSet(final GATKArgumentCollection args) { this.BQSR_RECAL_FILE = args.BQSR_RECAL_FILE; @@ -43,6 +44,7 @@ public class BQSRArgumentSet { this.disableIndelQuals = args.disableIndelQuals; this.emitOriginalQuals = args.emitOriginalQuals; this.PRESERVE_QSCORES_LESS_THAN = args.PRESERVE_QSCORES_LESS_THAN; + this.globalQScorePrior = args.globalQScorePrior; } public File getRecalFile() { return BQSR_RECAL_FILE; } @@ -55,6 +57,8 @@ public class BQSRArgumentSet { public int getPreserveQscoresLessThan() { return PRESERVE_QSCORES_LESS_THAN; } + public double getGlobalQScorePrior() { return globalQScorePrior; } + public void setRecalFile(final File BQSR_RECAL_FILE) { this.BQSR_RECAL_FILE = BQSR_RECAL_FILE; } @@ -74,4 +78,8 @@ public class BQSRArgumentSet { public void setPreserveQscoresLessThan(final int PRESERVE_QSCORES_LESS_THAN) { this.PRESERVE_QSCORES_LESS_THAN = PRESERVE_QSCORES_LESS_THAN; } + + public void setGlobalQScorePrior(final double globalQScorePrior) { + this.globalQScorePrior = globalQScorePrior; + } }