From 1de2f527b95872d06320107fe47489e1756937f6 Mon Sep 17 00:00:00 2001 From: Mark DePristo Date: Thu, 13 Dec 2012 17:21:36 -0500 Subject: [PATCH] Optimization of recalibrateRead -- Refactor calculation so that upfront constant values are pre-computed, and cached, and their values just looked up during application -- Trivial comment on how we might use BAQ better in BaseRecalibrator --- .../gatk/walkers/bqsr/BaseRecalibrator.java | 1 + .../utils/collections/NestedIntegerArray.java | 13 +- .../recalibration/BaseRecalibration.java | 189 ++++++++++++++---- .../utils/recalibration/ReadCovariates.java | 2 +- 4 files changed, 158 insertions(+), 47 deletions(-) diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BaseRecalibrator.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BaseRecalibrator.java index f037f861f..4d7dbc912 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BaseRecalibrator.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BaseRecalibrator.java @@ -416,6 +416,7 @@ public class BaseRecalibrator extends ReadWalker implements NanoSche } private byte[] calculateBAQArray( final GATKSAMRecord read ) { + // todo -- it would be good to directly use the BAQ qualities rather than encoding and decoding the result and using the special @ value baq.baqRead(read, referenceReader, BAQ.CalculationMode.RECALCULATE, BAQ.QualityMode.ADD_TAG); return BAQ.getBAQTag(read); } diff --git a/public/java/src/org/broadinstitute/sting/utils/collections/NestedIntegerArray.java b/public/java/src/org/broadinstitute/sting/utils/collections/NestedIntegerArray.java index 2e45eabe1..890a9b488 100755 --- a/public/java/src/org/broadinstitute/sting/utils/collections/NestedIntegerArray.java +++ b/public/java/src/org/broadinstitute/sting/utils/collections/NestedIntegerArray.java @@ -58,13 +58,20 @@ public class NestedIntegerArray { int dimensionsToPreallocate = Math.min(dimensions.length, 
NUM_DIMENSIONS_TO_PREALLOCATE); - logger.info(String.format("Creating NestedIntegerArray with dimensions %s", Arrays.toString(dimensions))); - logger.info(String.format("Pre-allocating first %d dimensions", dimensionsToPreallocate)); + if ( logger.isDebugEnabled() ) logger.debug(String.format("Creating NestedIntegerArray with dimensions %s", Arrays.toString(dimensions))); + if ( logger.isDebugEnabled() ) logger.debug(String.format("Pre-allocating first %d dimensions", dimensionsToPreallocate)); data = new Object[dimensions[0]]; preallocateArray(data, 0, dimensionsToPreallocate); - logger.info(String.format("Done pre-allocating first %d dimensions", dimensionsToPreallocate)); + if ( logger.isDebugEnabled() ) logger.debug(String.format("Done pre-allocating first %d dimensions", dimensionsToPreallocate)); + } + + /** + * @return the dimensions of this nested integer array. DO NOT MODIFY + */ + public int[] getDimensions() { + return dimensions; } /** diff --git a/public/java/src/org/broadinstitute/sting/utils/recalibration/BaseRecalibration.java b/public/java/src/org/broadinstitute/sting/utils/recalibration/BaseRecalibration.java index 7d63996c3..567514f8c 100644 --- a/public/java/src/org/broadinstitute/sting/utils/recalibration/BaseRecalibration.java +++ b/public/java/src/org/broadinstitute/sting/utils/recalibration/BaseRecalibration.java @@ -27,6 +27,7 @@ package org.broadinstitute.sting.utils.recalibration; import net.sf.samtools.SAMTag; import net.sf.samtools.SAMUtils; +import org.apache.log4j.Logger; import org.broadinstitute.sting.utils.MathUtils; import org.broadinstitute.sting.utils.QualityUtils; import org.broadinstitute.sting.utils.collections.NestedIntegerArray; @@ -44,7 +45,8 @@ import java.io.File; */ public class BaseRecalibration { - private final static int MAXIMUM_RECALIBRATED_READ_LENGTH = 5000; + private static Logger logger = Logger.getLogger(BaseRecalibration.class); + private final static boolean TEST_CACHING = false; private final 
QuantizationInfo quantizationInfo; // histogram containing the map for qual quantization (calculated after recalibration is done) private final RecalibrationTables recalibrationTables; @@ -54,6 +56,10 @@ public class BaseRecalibration { private final int preserveQLessThan; private final boolean emitOriginalQuals; + private final NestedIntegerArray globalDeltaQs; + private final NestedIntegerArray deltaQReporteds; + + /** * Constructor using a GATK Report file * @@ -76,6 +82,44 @@ public class BaseRecalibration { this.disableIndelQuals = disableIndelQuals; this.preserveQLessThan = preserveQLessThan; this.emitOriginalQuals = emitOriginalQuals; + + logger.info("Calculating cached tables..."); + + // + // Create a NestedIntegerArray that maps from rgKey x errorModel -> double, + // where the double is the result of this calculation. The entire calculation can + // be done upfront, on initialization of this BaseRecalibration structure + // + final NestedIntegerArray byReadGroupTable = recalibrationTables.getReadGroupTable(); + globalDeltaQs = new NestedIntegerArray( byReadGroupTable.getDimensions() ); + logger.info("Calculating global delta Q table..."); + for ( NestedIntegerArray.Leaf leaf : byReadGroupTable.getAllLeaves() ) { + final int rgKey = leaf.keys[0]; + final int eventIndex = leaf.keys[1]; + final double globalDeltaQ = calculateGlobalDeltaQ(rgKey, EventType.eventFrom(eventIndex)); + globalDeltaQs.put(globalDeltaQ, rgKey, eventIndex); + } + + + // The calculation of the deltaQ report is constant. key[0] and key[1] are the read group and qual, respectively + // and globalDeltaQ is a constant for the read group. So technically the delta Q reported is simply a lookup + // into a matrix indexed by rgGroup, qual, and event type. + // the code below actually creates this cache with a NestedIntegerArray calling into the actual + // calculateDeltaQReported code. 
+ final NestedIntegerArray byQualTable = recalibrationTables.getQualityScoreTable(); + deltaQReporteds = new NestedIntegerArray( byQualTable.getDimensions() ); + logger.info("Calculating delta Q reported table..."); + for ( NestedIntegerArray.Leaf leaf : byQualTable.getAllLeaves() ) { + final int rgKey = leaf.keys[0]; + final int qual = leaf.keys[1]; + final int eventIndex = leaf.keys[2]; + final EventType event = EventType.eventFrom(eventIndex); + final double globalDeltaQ = getGlobalDeltaQ(rgKey, event); + final double deltaQReported = calculateDeltaQReported(rgKey, qual, event, globalDeltaQ, (byte)qual); + deltaQReporteds.put(deltaQReported, rgKey, qual, eventIndex); + } + + logger.info("done calculating cache"); } /** @@ -83,6 +127,18 @@ public class BaseRecalibration { * * It updates the base qualities of the read with the new recalibrated qualities (for all event types) * + * Implements a serial recalibration of the reads using the combinational table. + * First, we perform a positional recalibration, and then a subsequent dinuc correction. + * + * Given the full recalibration table, we perform the following preprocessing steps: + * + * - calculate the global quality score shift across all data [DeltaQ] + * - calculate for each of cycle and dinuc the shift of the quality scores relative to the global shift + * -- i.e., DeltaQ(dinuc) = Sum(pos) Sum(Qual) Qempirical(pos, qual, dinuc) - Qreported(pos, qual, dinuc) / Npos * Nqual + * - The final shift equation is: + * + * Qrecal = Qreported + DeltaQ + DeltaQ(pos) + DeltaQ(dinuc) + DeltaQ( ... any other covariate ... 
) + * * @param read the read to recalibrate */ public void recalibrateRead(final GATKSAMRecord read) { @@ -95,6 +151,7 @@ public class BaseRecalibration { } final ReadCovariates readCovariates = RecalUtils.computeCovariates(read, requestedCovariates); + final int readLength = read.getReadLength(); for (final EventType errorModel : EventType.values()) { // recalibrate all three quality strings if (disableIndelQuals && errorModel != EventType.BASE_SUBSTITUTION) { @@ -103,58 +160,88 @@ public class BaseRecalibration { } final byte[] quals = read.getBaseQualities(errorModel); - final int[][] fullReadKeySet = readCovariates.getKeySet(errorModel); // get the keyset for this base using the error model - final int readLength = read.getReadLength(); + // get the keyset for this base using the error model + final int[][] fullReadKeySet = readCovariates.getKeySet(errorModel); + + // the rg key is constant over the whole read, the global deltaQ is too + final int rgKey = fullReadKeySet[0][0]; + + final double globalDeltaQ = getGlobalDeltaQ(rgKey, errorModel); + for (int offset = 0; offset < readLength; offset++) { // recalibrate all bases in the read + final byte origQual = quals[offset]; - final byte originalQualityScore = quals[offset]; + // only recalibrate usable qualities (the original quality will come from the instrument -- reported quality) + if ( origQual >= preserveQLessThan ) { + // get the keyset for this base using the error model + final int[] keySet = fullReadKeySet[offset]; + final double deltaQReported = getDeltaQReported(keySet[0], keySet[1], errorModel, globalDeltaQ); + final double deltaQCovariates = calculateDeltaQCovariates(recalibrationTables, keySet, errorModel, globalDeltaQ, deltaQReported, origQual); + + // calculate the recalibrated qual using the BQSR formula + double recalibratedQualDouble = origQual + globalDeltaQ + deltaQReported + deltaQCovariates; + + // recalibrated quality is bound between 1 and MAX_QUAL + final byte recalibratedQual = 
QualityUtils.boundQual(MathUtils.fastRound(recalibratedQualDouble), QualityUtils.MAX_RECALIBRATED_Q_SCORE); + + // return the quantized version of the recalibrated quality + final byte recalibratedQualityScore = quantizationInfo.getQuantizedQuals().get(recalibratedQual); - if (originalQualityScore >= preserveQLessThan) { // only recalibrate usable qualities (the original quality will come from the instrument -- reported quality) - final int[] keySet = fullReadKeySet[offset]; // get the keyset for this base using the error model - final byte recalibratedQualityScore = performSequentialQualityCalculation(keySet, errorModel); // recalibrate the base quals[offset] = recalibratedQualityScore; } } + + // finally update the base qualities in the read read.setBaseQualities(quals, errorModel); } } + private double getGlobalDeltaQ(final int rgKey, final EventType errorModel) { + final Double cached = globalDeltaQs.get(rgKey, errorModel.index); - /** - * Implements a serial recalibration of the reads using the combinational table. - * First, we perform a positional recalibration, and then a subsequent dinuc correction. - * - * Given the full recalibration table, we perform the following preprocessing steps: - * - * - calculate the global quality score shift across all data [DeltaQ] - * - calculate for each of cycle and dinuc the shift of the quality scores relative to the global shift - * -- i.e., DeltaQ(dinuc) = Sum(pos) Sum(Qual) Qempirical(pos, qual, dinuc) - Qreported(pos, qual, dinuc) / Npos * Nqual - * - The final shift equation is: - * - * Qrecal = Qreported + DeltaQ + DeltaQ(pos) + DeltaQ(dinuc) + DeltaQ( ... any other covariate ... 
) - * - * @param key The list of Comparables that were calculated from the covariates - * @param errorModel the event type - * @return A recalibrated quality score as a byte - */ - private byte performSequentialQualityCalculation(final int[] key, final EventType errorModel) { + if ( TEST_CACHING ) { + final double calcd = calculateGlobalDeltaQ(rgKey, errorModel); + if ( calcd != cached ) + throw new IllegalStateException("calculated " + calcd + " and cached " + cached + " global delta q not equal at " + rgKey + " / " + errorModel); + } - final byte qualFromRead = (byte)(long)key[1]; - final double globalDeltaQ = calculateGlobalDeltaQ(recalibrationTables.getReadGroupTable(), key, errorModel); - final double deltaQReported = calculateDeltaQReported(recalibrationTables.getQualityScoreTable(), key, errorModel, globalDeltaQ, qualFromRead); - final double deltaQCovariates = calculateDeltaQCovariates(recalibrationTables, key, errorModel, globalDeltaQ, deltaQReported, qualFromRead); - - double recalibratedQual = qualFromRead + globalDeltaQ + deltaQReported + deltaQCovariates; // calculate the recalibrated qual using the BQSR formula - recalibratedQual = QualityUtils.boundQual(MathUtils.fastRound(recalibratedQual), QualityUtils.MAX_RECALIBRATED_Q_SCORE); // recalibrated quality is bound between 1 and MAX_QUAL - - return quantizationInfo.getQuantizedQuals().get((int) recalibratedQual); // return the quantized version of the recalibrated quality + return cachedWithDefault(cached); } - private double calculateGlobalDeltaQ(final NestedIntegerArray table, final int[] key, final EventType errorModel) { + private double getDeltaQReported(final int rgKey, final int qualKey, final EventType errorModel, final double globalDeltaQ) { + final Double cached = deltaQReporteds.get(rgKey, qualKey, errorModel.index); + + if ( TEST_CACHING ) { + final double calcd = calculateDeltaQReported(rgKey, qualKey, errorModel, globalDeltaQ, (byte)qualKey); + if ( calcd != cached ) + throw new 
IllegalStateException("calculated " + calcd + " and cached " + cached + " global delta q not equal at " + rgKey + " / " + qualKey + " / " + errorModel); + } + + return cachedWithDefault(cached); + } + + /** + * @param d a Double (that may be null) that is the result of a delta Q calculation + * @return the value of d if d != null, or 0.0 if d is null + */ + private double cachedWithDefault(final Double d) { + return d == null ? 0.0 : d; + } + + /** + * Note that this calculation is a constant for each rgKey and errorModel. We need only + * compute this value once for all data. + * + * @param rgKey the read group key used to look up the read group table + * @param errorModel the event type + * @return the global delta Q for the given read group key and event type + */ + private double calculateGlobalDeltaQ(final int rgKey, final EventType errorModel) { double result = 0.0; - final RecalDatum empiricalQualRG = table.get(key[0], errorModel.index); + final RecalDatum empiricalQualRG = recalibrationTables.getReadGroupTable().get(rgKey, errorModel.index); + if (empiricalQualRG != null) { final double globalDeltaQEmpirical = empiricalQualRG.getEmpiricalQuality(); final double aggregrateQReported = empiricalQualRG.getEstimatedQReported(); @@ -164,10 +251,10 @@ public class BaseRecalibration { return result; } - private double calculateDeltaQReported(final NestedIntegerArray table, final int[] key, final EventType errorModel, final double globalDeltaQ, final byte qualFromRead) { + private double calculateDeltaQReported(final int rgKey, final int qualKey, final EventType errorModel, final double globalDeltaQ, final byte qualFromRead) { double result = 0.0; - final RecalDatum empiricalQualQS = table.get(key[0], key[1], errorModel.index); + final RecalDatum empiricalQualQS = recalibrationTables.getQualityScoreTable().get(rgKey, qualKey, errorModel.index); if (empiricalQualQS != null) { final double deltaQReportedEmpirical = empiricalQualQS.getEmpiricalQuality(); result = deltaQReportedEmpirical - qualFromRead - globalDeltaQ; @@ -184,12 +271,28 @@ public class BaseRecalibration { if (key[i] < 0) continue; - final 
RecalDatum empiricalQualCO = recalibrationTables.getTable(i).get(key[0], key[1], key[i], errorModel.index); - if (empiricalQualCO != null) { - final double deltaQCovariateEmpirical = empiricalQualCO.getEmpiricalQuality(); - result += (deltaQCovariateEmpirical - qualFromRead - (globalDeltaQ + deltaQReported)); - } + result += calculateDeltaQCovariate(recalibrationTables.getTable(i), + key[0], key[1], key[i], errorModel, + globalDeltaQ, deltaQReported, qualFromRead); } + return result; } + + private double calculateDeltaQCovariate(final NestedIntegerArray table, + final int rgKey, + final int qualKey, + final int tableKey, + final EventType errorModel, + final double globalDeltaQ, + final double deltaQReported, + final byte qualFromRead) { + final RecalDatum empiricalQualCO = table.get(rgKey, qualKey, tableKey, errorModel.index); + if (empiricalQualCO != null) { + final double deltaQCovariateEmpirical = empiricalQualCO.getEmpiricalQuality(); + return deltaQCovariateEmpirical - qualFromRead - (globalDeltaQ + deltaQReported); + } else { + return 0.0; + } + } } diff --git a/public/java/src/org/broadinstitute/sting/utils/recalibration/ReadCovariates.java b/public/java/src/org/broadinstitute/sting/utils/recalibration/ReadCovariates.java index ae2b0ad28..4ddcb2b92 100644 --- a/public/java/src/org/broadinstitute/sting/utils/recalibration/ReadCovariates.java +++ b/public/java/src/org/broadinstitute/sting/utils/recalibration/ReadCovariates.java @@ -47,7 +47,7 @@ public class ReadCovariates { final int[][][] cachedKeys = cache.get(readLength); if ( cachedKeys == null ) { // There's no cached value for read length so we need to create a new int[][][] array - logger.info("Keys cache miss for length " + readLength + " cache size " + cache.size()); + if ( logger.isDebugEnabled() ) logger.debug("Keys cache miss for length " + readLength + " cache size " + cache.size()); keys = new int[EventType.values().length][readLength][numberOfCovariates]; cache.put(readLength, keys); } else {