Optimization of recalibrateRead

-- Refactor calculation so that upfront constant values are pre-computed, and cached, and their values just looked up during application
-- Trivial comment on how we might use BAQ better in BaseRecalibrator
This commit is contained in:
Mark DePristo 2012-12-13 17:21:36 -05:00
parent bd6cda7542
commit 1de2f527b9
4 changed files with 158 additions and 47 deletions

View File

@ -416,6 +416,7 @@ public class BaseRecalibrator extends ReadWalker<Long, Long> implements NanoSche
}
private byte[] calculateBAQArray( final GATKSAMRecord read ) {
// todo -- it would be good to directly use the BAQ qualities rather than encoding and decoding the result and using the special @ value
baq.baqRead(read, referenceReader, BAQ.CalculationMode.RECALCULATE, BAQ.QualityMode.ADD_TAG);
return BAQ.getBAQTag(read);
}

View File

@ -58,13 +58,20 @@ public class NestedIntegerArray<T> {
int dimensionsToPreallocate = Math.min(dimensions.length, NUM_DIMENSIONS_TO_PREALLOCATE);
logger.info(String.format("Creating NestedIntegerArray with dimensions %s", Arrays.toString(dimensions)));
logger.info(String.format("Pre-allocating first %d dimensions", dimensionsToPreallocate));
if ( logger.isDebugEnabled() ) logger.debug(String.format("Creating NestedIntegerArray with dimensions %s", Arrays.toString(dimensions)));
if ( logger.isDebugEnabled() ) logger.debug(String.format("Pre-allocating first %d dimensions", dimensionsToPreallocate));
data = new Object[dimensions[0]];
preallocateArray(data, 0, dimensionsToPreallocate);
logger.info(String.format("Done pre-allocating first %d dimensions", dimensionsToPreallocate));
if ( logger.isDebugEnabled() ) logger.debug(String.format("Done pre-allocating first %d dimensions", dimensionsToPreallocate));
}
/**
* @return the dimensions of this nested integer array. DO NOT MODIFY
*/
public int[] getDimensions() {
return dimensions;
}
/**

View File

@ -27,6 +27,7 @@ package org.broadinstitute.sting.utils.recalibration;
import net.sf.samtools.SAMTag;
import net.sf.samtools.SAMUtils;
import org.apache.log4j.Logger;
import org.broadinstitute.sting.utils.MathUtils;
import org.broadinstitute.sting.utils.QualityUtils;
import org.broadinstitute.sting.utils.collections.NestedIntegerArray;
@ -44,7 +45,8 @@ import java.io.File;
*/
public class BaseRecalibration {
private final static int MAXIMUM_RECALIBRATED_READ_LENGTH = 5000;
private static Logger logger = Logger.getLogger(BaseRecalibration.class);
private final static boolean TEST_CACHING = false;
private final QuantizationInfo quantizationInfo; // histogram containing the map for qual quantization (calculated after recalibration is done)
private final RecalibrationTables recalibrationTables;
@ -54,6 +56,10 @@ public class BaseRecalibration {
private final int preserveQLessThan;
private final boolean emitOriginalQuals;
private final NestedIntegerArray<Double> globalDeltaQs;
private final NestedIntegerArray<Double> deltaQReporteds;
/**
* Constructor using a GATK Report file
*
@ -76,6 +82,44 @@ public class BaseRecalibration {
this.disableIndelQuals = disableIndelQuals;
this.preserveQLessThan = preserveQLessThan;
this.emitOriginalQuals = emitOriginalQuals;
logger.info("Calculating cached tables...");
//
// Create a NestedIntegerArray<Double> that maps from rgKey x errorModel -> double,
// where the double is the result of this calculation. The entire calculation can
// be done upfront, on initialization of this BaseRecalibration structure
//
final NestedIntegerArray<RecalDatum> byReadGroupTable = recalibrationTables.getReadGroupTable();
globalDeltaQs = new NestedIntegerArray<Double>( byReadGroupTable.getDimensions() );
logger.info("Calculating global delta Q table...");
for ( NestedIntegerArray.Leaf<RecalDatum> leaf : byReadGroupTable.getAllLeaves() ) {
final int rgKey = leaf.keys[0];
final int eventIndex = leaf.keys[1];
final double globalDeltaQ = calculateGlobalDeltaQ(rgKey, EventType.eventFrom(eventIndex));
globalDeltaQs.put(globalDeltaQ, rgKey, eventIndex);
}
// The calculation of the deltaQ report is constant. key[0] and key[1] are the read group and qual, respectively
// and globalDeltaQ is a constant for the read group. So technically the delta Q reported is simply a lookup
// into a matrix indexed by rgGroup, qual, and event type.
// the code below actually creates this cache with a NestedIntegerArray calling into the actual
// calculateDeltaQReported code.
final NestedIntegerArray<RecalDatum> byQualTable = recalibrationTables.getQualityScoreTable();
deltaQReporteds = new NestedIntegerArray<Double>( byQualTable.getDimensions() );
logger.info("Calculating delta Q reported table...");
for ( NestedIntegerArray.Leaf<RecalDatum> leaf : byQualTable.getAllLeaves() ) {
final int rgKey = leaf.keys[0];
final int qual = leaf.keys[1];
final int eventIndex = leaf.keys[2];
final EventType event = EventType.eventFrom(eventIndex);
final double globalDeltaQ = getGlobalDeltaQ(rgKey, event);
final double deltaQReported = calculateDeltaQReported(rgKey, qual, event, globalDeltaQ, (byte)qual);
deltaQReporteds.put(deltaQReported, rgKey, qual, eventIndex);
}
logger.info("done calculating cache");
}
/**
@ -83,6 +127,18 @@ public class BaseRecalibration {
*
* It updates the base qualities of the read with the new recalibrated qualities (for all event types)
*
* Implements a serial recalibration of the reads using the combinational table.
* First, we perform a positional recalibration, and then a subsequent dinuc correction.
*
* Given the full recalibration table, we perform the following preprocessing steps:
*
* - calculate the global quality score shift across all data [DeltaQ]
* - calculate for each of cycle and dinuc the shift of the quality scores relative to the global shift
* -- i.e., DeltaQ(dinuc) = Sum(pos) Sum(Qual) Qempirical(pos, qual, dinuc) - Qreported(pos, qual, dinuc) / Npos * Nqual
* - The final shift equation is:
*
* Qrecal = Qreported + DeltaQ + DeltaQ(pos) + DeltaQ(dinuc) + DeltaQ( ... any other covariate ... )
*
* @param read the read to recalibrate
*/
public void recalibrateRead(final GATKSAMRecord read) {
@ -95,6 +151,7 @@ public class BaseRecalibration {
}
final ReadCovariates readCovariates = RecalUtils.computeCovariates(read, requestedCovariates);
final int readLength = read.getReadLength();
for (final EventType errorModel : EventType.values()) { // recalibrate all three quality strings
if (disableIndelQuals && errorModel != EventType.BASE_SUBSTITUTION) {
@ -103,58 +160,88 @@ public class BaseRecalibration {
}
final byte[] quals = read.getBaseQualities(errorModel);
final int[][] fullReadKeySet = readCovariates.getKeySet(errorModel); // get the keyset for this base using the error model
final int readLength = read.getReadLength();
// get the keyset for this base using the error model
final int[][] fullReadKeySet = readCovariates.getKeySet(errorModel);
// the rg key is constant over the whole read, the global deltaQ is too
final int rgKey = fullReadKeySet[0][0];
final double globalDeltaQ = getGlobalDeltaQ(rgKey, errorModel);
for (int offset = 0; offset < readLength; offset++) { // recalibrate all bases in the read
final byte origQual = quals[offset];
final byte originalQualityScore = quals[offset];
// only recalibrate usable qualities (the original quality will come from the instrument -- reported quality)
if ( origQual >= preserveQLessThan ) {
// get the keyset for this base using the error model
final int[] keySet = fullReadKeySet[offset];
final double deltaQReported = getDeltaQReported(keySet[0], keySet[1], errorModel, globalDeltaQ);
final double deltaQCovariates = calculateDeltaQCovariates(recalibrationTables, keySet, errorModel, globalDeltaQ, deltaQReported, origQual);
// calculate the recalibrated qual using the BQSR formula
double recalibratedQualDouble = origQual + globalDeltaQ + deltaQReported + deltaQCovariates;
// recalibrated quality is bound between 1 and MAX_QUAL
final byte recalibratedQual = QualityUtils.boundQual(MathUtils.fastRound(recalibratedQualDouble), QualityUtils.MAX_RECALIBRATED_Q_SCORE);
// return the quantized version of the recalibrated quality
final byte recalibratedQualityScore = quantizationInfo.getQuantizedQuals().get(recalibratedQual);
if (originalQualityScore >= preserveQLessThan) { // only recalibrate usable qualities (the original quality will come from the instrument -- reported quality)
final int[] keySet = fullReadKeySet[offset]; // get the keyset for this base using the error model
final byte recalibratedQualityScore = performSequentialQualityCalculation(keySet, errorModel); // recalibrate the base
quals[offset] = recalibratedQualityScore;
}
}
// finally update the base qualities in the read
read.setBaseQualities(quals, errorModel);
}
}
private double getGlobalDeltaQ(final int rgKey, final EventType errorModel) {
final Double cached = globalDeltaQs.get(rgKey, errorModel.index);
/**
* Implements a serial recalibration of the reads using the combinational table.
* First, we perform a positional recalibration, and then a subsequent dinuc correction.
*
* Given the full recalibration table, we perform the following preprocessing steps:
*
* - calculate the global quality score shift across all data [DeltaQ]
* - calculate for each of cycle and dinuc the shift of the quality scores relative to the global shift
* -- i.e., DeltaQ(dinuc) = Sum(pos) Sum(Qual) Qempirical(pos, qual, dinuc) - Qreported(pos, qual, dinuc) / Npos * Nqual
* - The final shift equation is:
*
* Qrecal = Qreported + DeltaQ + DeltaQ(pos) + DeltaQ(dinuc) + DeltaQ( ... any other covariate ... )
*
* @param key The list of Comparables that were calculated from the covariates
* @param errorModel the event type
* @return A recalibrated quality score as a byte
*/
private byte performSequentialQualityCalculation(final int[] key, final EventType errorModel) {
if ( TEST_CACHING ) {
final double calcd = calculateGlobalDeltaQ(rgKey, errorModel);
if ( calcd != cached )
throw new IllegalStateException("calculated " + calcd + " and cached " + cached + " global delta q not equal at " + rgKey + " / " + errorModel);
}
final byte qualFromRead = (byte)(long)key[1];
final double globalDeltaQ = calculateGlobalDeltaQ(recalibrationTables.getReadGroupTable(), key, errorModel);
final double deltaQReported = calculateDeltaQReported(recalibrationTables.getQualityScoreTable(), key, errorModel, globalDeltaQ, qualFromRead);
final double deltaQCovariates = calculateDeltaQCovariates(recalibrationTables, key, errorModel, globalDeltaQ, deltaQReported, qualFromRead);
double recalibratedQual = qualFromRead + globalDeltaQ + deltaQReported + deltaQCovariates; // calculate the recalibrated qual using the BQSR formula
recalibratedQual = QualityUtils.boundQual(MathUtils.fastRound(recalibratedQual), QualityUtils.MAX_RECALIBRATED_Q_SCORE); // recalibrated quality is bound between 1 and MAX_QUAL
return quantizationInfo.getQuantizedQuals().get((int) recalibratedQual); // return the quantized version of the recalibrated quality
return cachedWithDefault(cached);
}
private double calculateGlobalDeltaQ(final NestedIntegerArray<RecalDatum> table, final int[] key, final EventType errorModel) {
private double getDeltaQReported(final int rgKey, final int qualKey, final EventType errorModel, final double globalDeltaQ) {
final Double cached = deltaQReporteds.get(rgKey, qualKey, errorModel.index);
if ( TEST_CACHING ) {
final double calcd = calculateDeltaQReported(rgKey, qualKey, errorModel, globalDeltaQ, (byte)qualKey);
if ( calcd != cached )
throw new IllegalStateException("calculated " + calcd + " and cached " + cached + " global delta q not equal at " + rgKey + " / " + qualKey + " / " + errorModel);
}
return cachedWithDefault(cached);
}
/**
* @param d a Double (that may be null) that is the result of a delta Q calculation
* @return a double == d if d != null, or 0.0 if it is
*/
private double cachedWithDefault(final Double d) {
return d == null ? 0.0 : d;
}
/**
* Note that this calculation is a constant for each rgKey and errorModel. We need only
* compute this value once for all data.
*
* @param rgKey
* @param errorModel
* @return
*/
private double calculateGlobalDeltaQ(final int rgKey, final EventType errorModel) {
double result = 0.0;
final RecalDatum empiricalQualRG = table.get(key[0], errorModel.index);
final RecalDatum empiricalQualRG = recalibrationTables.getReadGroupTable().get(rgKey, errorModel.index);
if (empiricalQualRG != null) {
final double globalDeltaQEmpirical = empiricalQualRG.getEmpiricalQuality();
final double aggregrateQReported = empiricalQualRG.getEstimatedQReported();
@ -164,10 +251,10 @@ public class BaseRecalibration {
return result;
}
private double calculateDeltaQReported(final NestedIntegerArray<RecalDatum> table, final int[] key, final EventType errorModel, final double globalDeltaQ, final byte qualFromRead) {
private double calculateDeltaQReported(final int rgKey, final int qualKey, final EventType errorModel, final double globalDeltaQ, final byte qualFromRead) {
double result = 0.0;
final RecalDatum empiricalQualQS = table.get(key[0], key[1], errorModel.index);
final RecalDatum empiricalQualQS = recalibrationTables.getQualityScoreTable().get(rgKey, qualKey, errorModel.index);
if (empiricalQualQS != null) {
final double deltaQReportedEmpirical = empiricalQualQS.getEmpiricalQuality();
result = deltaQReportedEmpirical - qualFromRead - globalDeltaQ;
@ -184,12 +271,28 @@ public class BaseRecalibration {
if (key[i] < 0)
continue;
final RecalDatum empiricalQualCO = recalibrationTables.getTable(i).get(key[0], key[1], key[i], errorModel.index);
if (empiricalQualCO != null) {
final double deltaQCovariateEmpirical = empiricalQualCO.getEmpiricalQuality();
result += (deltaQCovariateEmpirical - qualFromRead - (globalDeltaQ + deltaQReported));
}
result += calculateDeltaQCovariate(recalibrationTables.getTable(i),
key[0], key[1], key[i], errorModel,
globalDeltaQ, deltaQReported, qualFromRead);
}
return result;
}
private double calculateDeltaQCovariate(final NestedIntegerArray<RecalDatum> table,
final int rgKey,
final int qualKey,
final int tableKey,
final EventType errorModel,
final double globalDeltaQ,
final double deltaQReported,
final byte qualFromRead) {
final RecalDatum empiricalQualCO = table.get(rgKey, qualKey, tableKey, errorModel.index);
if (empiricalQualCO != null) {
final double deltaQCovariateEmpirical = empiricalQualCO.getEmpiricalQuality();
return deltaQCovariateEmpirical - qualFromRead - (globalDeltaQ + deltaQReported);
} else {
return 0.0;
}
}
}

View File

@ -47,7 +47,7 @@ public class ReadCovariates {
final int[][][] cachedKeys = cache.get(readLength);
if ( cachedKeys == null ) {
// There's no cached value for read length so we need to create a new int[][][] array
logger.info("Keys cache miss for length " + readLength + " cache size " + cache.size());
if ( logger.isDebugEnabled() ) logger.debug("Keys cache miss for length " + readLength + " cache size " + cache.size());
keys = new int[EventType.values().length][readLength][numberOfCovariates];
cache.put(readLength, keys);
} else {