The Bayesian calculation of Qemp in the BQSR is now hierarchical. This fixes issues in which the covariate bins were very sparse and the prior estimate being used was the original quality score. This resulted in large correction factors for each covariate which breaks the equation. There is also now a new option, qlobalQScorePrior, which can be used to ignore the given (very high) quality scores and instead use this value as the prior.

This commit is contained in:
Ryan Poplin 2013-01-28 15:56:33 -05:00
parent aab160372a
commit d665a8ba0c
5 changed files with 80 additions and 37 deletions

View File

@ -67,7 +67,7 @@ public class BQSRReadTransformer extends ReadTransformer {
this.enabled = engine.hasBQSRArgumentSet(); this.enabled = engine.hasBQSRArgumentSet();
if ( enabled ) { if ( enabled ) {
final BQSRArgumentSet args = engine.getBQSRArgumentSet(); final BQSRArgumentSet args = engine.getBQSRArgumentSet();
this.bqsr = new BaseRecalibration(args.getRecalFile(), args.getQuantizationLevels(), args.shouldDisableIndelQuals(), args.getPreserveQscoresLessThan(), args.shouldEmitOriginalQuals()); this.bqsr = new BaseRecalibration(args.getRecalFile(), args.getQuantizationLevels(), args.shouldDisableIndelQuals(), args.getPreserveQscoresLessThan(), args.shouldEmitOriginalQuals(), args.getGlobalQScorePrior());
} }
final BQSRMode mode = WalkerManager.getWalkerAnnotation(walker, BQSRMode.class); final BQSRMode mode = WalkerManager.getWalkerAnnotation(walker, BQSRMode.class);
return mode.ApplicationTime(); return mode.ApplicationTime();

View File

@ -75,12 +75,12 @@ public class BaseRecalibration {
private final boolean disableIndelQuals; private final boolean disableIndelQuals;
private final int preserveQLessThan; private final int preserveQLessThan;
private final double globalQScorePrior;
private final boolean emitOriginalQuals; private final boolean emitOriginalQuals;
private final NestedIntegerArray<Double> globalDeltaQs; private final NestedIntegerArray<Double> globalDeltaQs;
private final NestedIntegerArray<Double> deltaQReporteds; private final NestedIntegerArray<Double> deltaQReporteds;
/** /**
* Constructor using a GATK Report file * Constructor using a GATK Report file
* *
@ -89,7 +89,7 @@ public class BaseRecalibration {
* @param disableIndelQuals if true, do not emit base indel qualities * @param disableIndelQuals if true, do not emit base indel qualities
* @param preserveQLessThan preserve quality scores less than this value * @param preserveQLessThan preserve quality scores less than this value
*/ */
public BaseRecalibration(final File RECAL_FILE, final int quantizationLevels, final boolean disableIndelQuals, final int preserveQLessThan, final boolean emitOriginalQuals) { public BaseRecalibration(final File RECAL_FILE, final int quantizationLevels, final boolean disableIndelQuals, final int preserveQLessThan, final boolean emitOriginalQuals, final double globalQScorePrior) {
RecalibrationReport recalibrationReport = new RecalibrationReport(RECAL_FILE); RecalibrationReport recalibrationReport = new RecalibrationReport(RECAL_FILE);
recalibrationTables = recalibrationReport.getRecalibrationTables(); recalibrationTables = recalibrationReport.getRecalibrationTables();
@ -102,6 +102,7 @@ public class BaseRecalibration {
this.disableIndelQuals = disableIndelQuals; this.disableIndelQuals = disableIndelQuals;
this.preserveQLessThan = preserveQLessThan; this.preserveQLessThan = preserveQLessThan;
this.globalQScorePrior = globalQScorePrior;
this.emitOriginalQuals = emitOriginalQuals; this.emitOriginalQuals = emitOriginalQuals;
logger.info("Calculating cached tables..."); logger.info("Calculating cached tables...");
@ -112,13 +113,16 @@ public class BaseRecalibration {
// be done upfront, on initialization of this BaseRecalibration structure // be done upfront, on initialization of this BaseRecalibration structure
// //
final NestedIntegerArray<RecalDatum> byReadGroupTable = recalibrationTables.getReadGroupTable(); final NestedIntegerArray<RecalDatum> byReadGroupTable = recalibrationTables.getReadGroupTable();
globalDeltaQs = new NestedIntegerArray<Double>( byReadGroupTable.getDimensions() ); final NestedIntegerArray<RecalDatum> byQualTable = recalibrationTables.getQualityScoreTable();
globalDeltaQs = new NestedIntegerArray<Double>( byQualTable.getDimensions() );
logger.info("Calculating global delta Q table..."); logger.info("Calculating global delta Q table...");
for ( NestedIntegerArray.Leaf<RecalDatum> leaf : byReadGroupTable.getAllLeaves() ) { for ( NestedIntegerArray.Leaf<RecalDatum> leaf : byQualTable.getAllLeaves() ) {
final int rgKey = leaf.keys[0]; final int rgKey = leaf.keys[0];
final int eventIndex = leaf.keys[1]; final int qual = leaf.keys[1];
final double globalDeltaQ = calculateGlobalDeltaQ(rgKey, EventType.eventFrom(eventIndex)); final int eventIndex = leaf.keys[2];
globalDeltaQs.put(globalDeltaQ, rgKey, eventIndex); final double globalDeltaQ = calculateGlobalDeltaQ(rgKey, EventType.eventFrom(eventIndex), (byte) qual);
globalDeltaQs.put(globalDeltaQ, rgKey, qual, eventIndex);
} }
@ -127,7 +131,6 @@ public class BaseRecalibration {
// into a matrix indexed by rgGroup, qual, and event type. // into a matrix indexed by rgGroup, qual, and event type.
// the code below actually creates this cache with a NestedIntegerArray calling into the actual // the code below actually creates this cache with a NestedIntegerArray calling into the actual
// calculateDeltaQReported code. // calculateDeltaQReported code.
final NestedIntegerArray<RecalDatum> byQualTable = recalibrationTables.getQualityScoreTable();
deltaQReporteds = new NestedIntegerArray<Double>( byQualTable.getDimensions() ); deltaQReporteds = new NestedIntegerArray<Double>( byQualTable.getDimensions() );
logger.info("Calculating delta Q reported table..."); logger.info("Calculating delta Q reported table...");
for ( NestedIntegerArray.Leaf<RecalDatum> leaf : byQualTable.getAllLeaves() ) { for ( NestedIntegerArray.Leaf<RecalDatum> leaf : byQualTable.getAllLeaves() ) {
@ -135,7 +138,7 @@ public class BaseRecalibration {
final int qual = leaf.keys[1]; final int qual = leaf.keys[1];
final int eventIndex = leaf.keys[2]; final int eventIndex = leaf.keys[2];
final EventType event = EventType.eventFrom(eventIndex); final EventType event = EventType.eventFrom(eventIndex);
final double globalDeltaQ = getGlobalDeltaQ(rgKey, event); final double globalDeltaQ = getGlobalDeltaQ(rgKey, event, (byte)qual);
final double deltaQReported = calculateDeltaQReported(rgKey, qual, event, globalDeltaQ, (byte)qual); final double deltaQReported = calculateDeltaQReported(rgKey, qual, event, globalDeltaQ, (byte)qual);
deltaQReporteds.put(deltaQReported, rgKey, qual, eventIndex); deltaQReporteds.put(deltaQReported, rgKey, qual, eventIndex);
} }
@ -188,17 +191,17 @@ public class BaseRecalibration {
// the rg key is constant over the whole read, the global deltaQ is too // the rg key is constant over the whole read, the global deltaQ is too
final int rgKey = fullReadKeySet[0][0]; final int rgKey = fullReadKeySet[0][0];
final double globalDeltaQ = getGlobalDeltaQ(rgKey, errorModel);
for (int offset = 0; offset < readLength; offset++) { // recalibrate all bases in the read for (int offset = 0; offset < readLength; offset++) { // recalibrate all bases in the read
final byte origQual = quals[offset]; final byte origQual = quals[offset];
final double globalDeltaQ = getGlobalDeltaQ(rgKey, errorModel, origQual);
// only recalibrate usable qualities (the original quality will come from the instrument -- reported quality) // only recalibrate usable qualities (the original quality will come from the instrument -- reported quality)
if ( origQual >= preserveQLessThan ) { if ( origQual >= preserveQLessThan ) {
// get the keyset for this base using the error model // get the keyset for this base using the error model
final int[] keySet = fullReadKeySet[offset]; final int[] keySet = fullReadKeySet[offset];
final double deltaQReported = getDeltaQReported(keySet[0], keySet[1], errorModel, globalDeltaQ); final double deltaQReported = getDeltaQReported(keySet[0], keySet[1], errorModel, globalDeltaQ, origQual);
final double deltaQCovariates = calculateDeltaQCovariates(recalibrationTables, keySet, errorModel, globalDeltaQ, deltaQReported, origQual); final double deltaQCovariates = calculateDeltaQCovariates(recalibrationTables, keySet, errorModel, deltaQReported, globalDeltaQ, origQual);
// calculate the recalibrated qual using the BQSR formula // calculate the recalibrated qual using the BQSR formula
double recalibratedQualDouble = origQual + globalDeltaQ + deltaQReported + deltaQCovariates; double recalibratedQualDouble = origQual + globalDeltaQ + deltaQReported + deltaQCovariates;
@ -218,11 +221,12 @@ public class BaseRecalibration {
} }
} }
private double getGlobalDeltaQ(final int rgKey, final EventType errorModel) { private double getGlobalDeltaQ(final int rgKey, final EventType errorModel, final byte qualFromRead) {
final Double cached = globalDeltaQs.get(rgKey, errorModel.ordinal());
final Double cached = globalDeltaQs.get(rgKey, (int) qualFromRead, errorModel.ordinal());
if ( TEST_CACHING ) { if ( TEST_CACHING ) {
final double calcd = calculateGlobalDeltaQ(rgKey, errorModel); final double calcd = calculateGlobalDeltaQ(rgKey, errorModel, qualFromRead);
if ( calcd != cached ) if ( calcd != cached )
throw new IllegalStateException("calculated " + calcd + " and cached " + cached + " global delta q not equal at " + rgKey + " / " + errorModel); throw new IllegalStateException("calculated " + calcd + " and cached " + cached + " global delta q not equal at " + rgKey + " / " + errorModel);
} }
@ -230,7 +234,8 @@ public class BaseRecalibration {
return cachedWithDefault(cached); return cachedWithDefault(cached);
} }
private double getDeltaQReported(final int rgKey, final int qualKey, final EventType errorModel, final double globalDeltaQ) { private double getDeltaQReported(final int rgKey, final int qualKey, final EventType errorModel, final double globalDeltaQ, final byte origQual) {
final Double cached = deltaQReporteds.get(rgKey, qualKey, errorModel.ordinal()); final Double cached = deltaQReporteds.get(rgKey, qualKey, errorModel.ordinal());
if ( TEST_CACHING ) { if ( TEST_CACHING ) {
@ -240,6 +245,7 @@ public class BaseRecalibration {
} }
return cachedWithDefault(cached); return cachedWithDefault(cached);
} }
/** /**
@ -258,14 +264,14 @@ public class BaseRecalibration {
* @param errorModel the event type * @param errorModel the event type
* @return global delta Q * @return global delta Q
*/ */
private double calculateGlobalDeltaQ(final int rgKey, final EventType errorModel) { private double calculateGlobalDeltaQ(final int rgKey, final EventType errorModel, final byte qualFromRead) {
double result = 0.0; double result = 0.0;
final RecalDatum empiricalQualRG = recalibrationTables.getReadGroupTable().get(rgKey, errorModel.ordinal()); final RecalDatum empiricalQualRG = recalibrationTables.getReadGroupTable().get(rgKey, errorModel.ordinal());
if (empiricalQualRG != null) { if (empiricalQualRG != null) {
final double globalDeltaQEmpirical = empiricalQualRG.getEmpiricalQuality(); final double aggregrateQReported = ( globalQScorePrior > 0.0 && errorModel.equals(EventType.BASE_SUBSTITUTION) ? globalQScorePrior : qualFromRead );
final double aggregrateQReported = empiricalQualRG.getEstimatedQReported(); final double globalDeltaQEmpirical = empiricalQualRG.getEmpiricalQuality( aggregrateQReported );
result = globalDeltaQEmpirical - aggregrateQReported; result = globalDeltaQEmpirical - aggregrateQReported;
} }
@ -277,14 +283,14 @@ public class BaseRecalibration {
final RecalDatum empiricalQualQS = recalibrationTables.getQualityScoreTable().get(rgKey, qualKey, errorModel.ordinal()); final RecalDatum empiricalQualQS = recalibrationTables.getQualityScoreTable().get(rgKey, qualKey, errorModel.ordinal());
if (empiricalQualQS != null) { if (empiricalQualQS != null) {
final double deltaQReportedEmpirical = empiricalQualQS.getEmpiricalQuality(); final double deltaQReportedEmpirical = empiricalQualQS.getEmpiricalQuality( qualFromRead + globalDeltaQ );
result = deltaQReportedEmpirical - qualFromRead - globalDeltaQ; result = deltaQReportedEmpirical - ( qualFromRead + globalDeltaQ );
} }
return result; return result;
} }
private double calculateDeltaQCovariates(final RecalibrationTables recalibrationTables, final int[] key, final EventType errorModel, final double globalDeltaQ, final double deltaQReported, final byte qualFromRead) { private double calculateDeltaQCovariates(final RecalibrationTables recalibrationTables, final int[] key, final EventType errorModel, final double deltaQReported, final double globalDeltaQ, final byte qualFromRead) {
double result = 0.0; double result = 0.0;
// for all optional covariates // for all optional covariates
@ -294,7 +300,7 @@ public class BaseRecalibration {
result += calculateDeltaQCovariate(recalibrationTables.getTable(i), result += calculateDeltaQCovariate(recalibrationTables.getTable(i),
key[0], key[1], key[i], errorModel, key[0], key[1], key[i], errorModel,
globalDeltaQ, deltaQReported, qualFromRead); deltaQReported, globalDeltaQ, qualFromRead);
} }
return result; return result;
@ -305,13 +311,13 @@ public class BaseRecalibration {
final int qualKey, final int qualKey,
final int tableKey, final int tableKey,
final EventType errorModel, final EventType errorModel,
final double globalDeltaQ,
final double deltaQReported, final double deltaQReported,
final double globalDeltaQ,
final byte qualFromRead) { final byte qualFromRead) {
final RecalDatum empiricalQualCO = table.get(rgKey, qualKey, tableKey, errorModel.ordinal()); final RecalDatum empiricalQualCO = table.get(rgKey, qualKey, tableKey, errorModel.ordinal());
if (empiricalQualCO != null) { if (empiricalQualCO != null) {
final double deltaQCovariateEmpirical = empiricalQualCO.getEmpiricalQuality(); final double deltaQCovariateEmpirical = empiricalQualCO.getEmpiricalQuality( deltaQReported + globalDeltaQ + qualFromRead );
return deltaQCovariateEmpirical - qualFromRead - (globalDeltaQ + deltaQReported); return deltaQCovariateEmpirical - ( deltaQReported + globalDeltaQ + qualFromRead );
} else { } else {
return 0.0; return 0.0;
} }

View File

@ -78,6 +78,9 @@ import org.apache.commons.math.optimization.fitting.GaussianFunction;
import org.broadinstitute.sting.utils.MathUtils; import org.broadinstitute.sting.utils.MathUtils;
import org.broadinstitute.sting.utils.QualityUtils; import org.broadinstitute.sting.utils.QualityUtils;
import java.util.HashMap;
import java.util.Map;
/** /**
* An individual piece of recalibration data. Each bin counts up the number of observations and the number * An individual piece of recalibration data. Each bin counts up the number of observations and the number
@ -111,6 +114,11 @@ public class RecalDatum {
*/ */
private double empiricalQuality; private double empiricalQuality;
/**
* the empirical quality for datums that have been collapsed together (by read group and reported quality, for example)
*/
private Map<Integer,Double> empiricalQualityMap;
/** /**
* number of bases seen in total * number of bases seen in total
*/ */
@ -148,6 +156,7 @@ public class RecalDatum {
numMismatches = _numMismatches; numMismatches = _numMismatches;
estimatedQReported = reportedQuality; estimatedQReported = reportedQuality;
empiricalQuality = UNINITIALIZED; empiricalQuality = UNINITIALIZED;
empiricalQualityMap = new HashMap<Integer, Double>();
} }
/** /**
@ -159,6 +168,7 @@ public class RecalDatum {
this.numMismatches = copy.getNumMismatches(); this.numMismatches = copy.getNumMismatches();
this.estimatedQReported = copy.estimatedQReported; this.estimatedQReported = copy.estimatedQReported;
this.empiricalQuality = copy.empiricalQuality; this.empiricalQuality = copy.empiricalQuality;
empiricalQualityMap = copy.empiricalQualityMap;
} }
/** /**
@ -172,6 +182,7 @@ public class RecalDatum {
increment(other.getNumObservations(), other.getNumMismatches()); increment(other.getNumObservations(), other.getNumMismatches());
estimatedQReported = -10 * Math.log10(sumErrors / getNumObservations()); estimatedQReported = -10 * Math.log10(sumErrors / getNumObservations());
empiricalQuality = UNINITIALIZED; empiricalQuality = UNINITIALIZED;
empiricalQualityMap = new HashMap<Integer, Double>();
} }
public synchronized void setEstimatedQReported(final double estimatedQReported) { public synchronized void setEstimatedQReported(final double estimatedQReported) {
@ -314,22 +325,37 @@ public class RecalDatum {
return getNumObservations() * QualityUtils.qualToErrorProb(estimatedQReported); return getNumObservations() * QualityUtils.qualToErrorProb(estimatedQReported);
} }
/**
* Calculate empirical quality score from mismatches, observations, and the given conditional prior (expensive operation)
*/
public double getEmpiricalQuality(final double conditionalPriorQ) {
final int priorQKey = MathUtils.fastRound(conditionalPriorQ);
Double returnQ = empiricalQualityMap.get(priorQKey);
if( returnQ == null ) {
// smoothing is one error and one non-error observation
final long mismatches = (long)(getNumMismatches() + 0.5) + SMOOTHING_CONSTANT;
final long observations = getNumObservations() + SMOOTHING_CONSTANT + SMOOTHING_CONSTANT;
final double empiricalQual = RecalDatum.bayesianEstimateOfEmpiricalQuality(observations, mismatches, conditionalPriorQ);
// This is the old and busted point estimate approach:
//final double empiricalQual = -10 * Math.log10(getEmpiricalErrorRate());
returnQ = Math.min(empiricalQual, (double) QualityUtils.MAX_RECALIBRATED_Q_SCORE);
empiricalQualityMap.put( priorQKey, returnQ );
}
return returnQ;
}
/** /**
* Calculate and cache the empirical quality score from mismatches and observations (expensive operation) * Calculate and cache the empirical quality score from mismatches and observations (expensive operation)
*/ */
@Requires("empiricalQuality == UNINITIALIZED") @Requires("empiricalQuality == UNINITIALIZED")
@Ensures("empiricalQuality != UNINITIALIZED") @Ensures("empiricalQuality != UNINITIALIZED")
private synchronized void calcEmpiricalQuality() { private synchronized void calcEmpiricalQuality() {
final double empiricalQual = getEmpiricalQuality(getEstimatedQReported());
// smoothing is one error and one non-error observation
final long mismatches = (long)(getNumMismatches() + 0.5) + SMOOTHING_CONSTANT;
final long observations = getNumObservations() + SMOOTHING_CONSTANT + SMOOTHING_CONSTANT;
final double empiricalQual = RecalDatum.bayesianEstimateOfEmpiricalQuality(observations, mismatches, getEstimatedQReported());
// This is the old and busted point estimate approach:
//final double empiricalQual = -10 * Math.log10(getEmpiricalErrorRate());
empiricalQuality = Math.min(empiricalQual, (double) QualityUtils.MAX_RECALIBRATED_Q_SCORE); empiricalQuality = Math.min(empiricalQual, (double) QualityUtils.MAX_RECALIBRATED_Q_SCORE);
} }

View File

@ -241,6 +241,9 @@ public class GATKArgumentCollection {
@Argument(fullName = "preserve_qscores_less_than", shortName = "preserveQ", doc = "Bases with quality scores less than this threshold won't be recalibrated (with -BQSR)", required = false) @Argument(fullName = "preserve_qscores_less_than", shortName = "preserveQ", doc = "Bases with quality scores less than this threshold won't be recalibrated (with -BQSR)", required = false)
public int PRESERVE_QSCORES_LESS_THAN = QualityUtils.MIN_USABLE_Q_SCORE; public int PRESERVE_QSCORES_LESS_THAN = QualityUtils.MIN_USABLE_Q_SCORE;
@Argument(fullName = "qlobalQScorePrior", shortName = "qlobalQScorePrior", doc = "The global Qscore Bayesian prior to use in the BQSR. If specified, this value will be used as the prior for all mismatch quality scores instead of the actual reported quality score", required = false)
public double globalQScorePrior = -1.0;
// -------------------------------------------------------------------------------------------------------------- // --------------------------------------------------------------------------------------------------------------
// //
// Other utility arguments // Other utility arguments

View File

@ -36,6 +36,7 @@ public class BQSRArgumentSet {
private boolean disableIndelQuals; private boolean disableIndelQuals;
private boolean emitOriginalQuals; private boolean emitOriginalQuals;
private int PRESERVE_QSCORES_LESS_THAN; private int PRESERVE_QSCORES_LESS_THAN;
private double globalQScorePrior;
public BQSRArgumentSet(final GATKArgumentCollection args) { public BQSRArgumentSet(final GATKArgumentCollection args) {
this.BQSR_RECAL_FILE = args.BQSR_RECAL_FILE; this.BQSR_RECAL_FILE = args.BQSR_RECAL_FILE;
@ -43,6 +44,7 @@ public class BQSRArgumentSet {
this.disableIndelQuals = args.disableIndelQuals; this.disableIndelQuals = args.disableIndelQuals;
this.emitOriginalQuals = args.emitOriginalQuals; this.emitOriginalQuals = args.emitOriginalQuals;
this.PRESERVE_QSCORES_LESS_THAN = args.PRESERVE_QSCORES_LESS_THAN; this.PRESERVE_QSCORES_LESS_THAN = args.PRESERVE_QSCORES_LESS_THAN;
this.globalQScorePrior = args.globalQScorePrior;
} }
public File getRecalFile() { return BQSR_RECAL_FILE; } public File getRecalFile() { return BQSR_RECAL_FILE; }
@ -55,6 +57,8 @@ public class BQSRArgumentSet {
public int getPreserveQscoresLessThan() { return PRESERVE_QSCORES_LESS_THAN; } public int getPreserveQscoresLessThan() { return PRESERVE_QSCORES_LESS_THAN; }
public double getGlobalQScorePrior() { return globalQScorePrior; }
public void setRecalFile(final File BQSR_RECAL_FILE) { public void setRecalFile(final File BQSR_RECAL_FILE) {
this.BQSR_RECAL_FILE = BQSR_RECAL_FILE; this.BQSR_RECAL_FILE = BQSR_RECAL_FILE;
} }
@ -74,4 +78,8 @@ public class BQSRArgumentSet {
public void setPreserveQscoresLessThan(final int PRESERVE_QSCORES_LESS_THAN) { public void setPreserveQscoresLessThan(final int PRESERVE_QSCORES_LESS_THAN) {
this.PRESERVE_QSCORES_LESS_THAN = PRESERVE_QSCORES_LESS_THAN; this.PRESERVE_QSCORES_LESS_THAN = PRESERVE_QSCORES_LESS_THAN;
} }
public void setGlobalQScorePrior(final double globalQScorePrior) {
this.globalQScorePrior = globalQScorePrior;
}
} }