Optimization of recalibrateRead
-- Refactor the calculation so that values that are constant up front are pre-computed and cached, and then simply looked up during application -- Add a trivial comment on how we might use BAQ better in BaseRecalibrator
This commit is contained in:
parent
bd6cda7542
commit
1de2f527b9
|
|
@ -416,6 +416,7 @@ public class BaseRecalibrator extends ReadWalker<Long, Long> implements NanoSche
|
|||
}
|
||||
|
||||
/**
 * Runs BAQ on a read and returns the resulting BAQ tag.
 *
 * Calls baq.baqRead against the reference reader in RECALCULATE mode with the ADD_TAG quality
 * mode, then reads the tag back from the read via BAQ.getBAQTag.
 *
 * @param read the read to run BAQ on
 * @return the value returned by BAQ.getBAQTag for the read
 */
private byte[] calculateBAQArray( final GATKSAMRecord read ) {
|
||||
// todo -- it would be good to directly use the BAQ qualities rather than encoding and decoding the result and using the special @ value
|
||||
baq.baqRead(read, referenceReader, BAQ.CalculationMode.RECALCULATE, BAQ.QualityMode.ADD_TAG);
|
||||
return BAQ.getBAQTag(read);
|
||||
}
|
||||
|
|
|
|||
|
|
@ -58,13 +58,20 @@ public class NestedIntegerArray<T> {
|
|||
|
||||
int dimensionsToPreallocate = Math.min(dimensions.length, NUM_DIMENSIONS_TO_PREALLOCATE);
|
||||
|
||||
logger.info(String.format("Creating NestedIntegerArray with dimensions %s", Arrays.toString(dimensions)));
|
||||
logger.info(String.format("Pre-allocating first %d dimensions", dimensionsToPreallocate));
|
||||
if ( logger.isDebugEnabled() ) logger.debug(String.format("Creating NestedIntegerArray with dimensions %s", Arrays.toString(dimensions)));
|
||||
if ( logger.isDebugEnabled() ) logger.debug(String.format("Pre-allocating first %d dimensions", dimensionsToPreallocate));
|
||||
|
||||
data = new Object[dimensions[0]];
|
||||
preallocateArray(data, 0, dimensionsToPreallocate);
|
||||
|
||||
logger.info(String.format("Done pre-allocating first %d dimensions", dimensionsToPreallocate));
|
||||
if ( logger.isDebugEnabled() ) logger.debug(String.format("Done pre-allocating first %d dimensions", dimensionsToPreallocate));
|
||||
}
|
||||
|
||||
/**
|
||||
* Returns the dimensions array held by this nested integer array. The array is returned
* directly, without a defensive copy.
*
* @return the dimensions of this nested integer array. DO NOT MODIFY: changes to the returned
*         array would be visible to this object
|
||||
*/
|
||||
public int[] getDimensions() {
|
||||
return dimensions;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -27,6 +27,7 @@ package org.broadinstitute.sting.utils.recalibration;
|
|||
|
||||
import net.sf.samtools.SAMTag;
|
||||
import net.sf.samtools.SAMUtils;
|
||||
import org.apache.log4j.Logger;
|
||||
import org.broadinstitute.sting.utils.MathUtils;
|
||||
import org.broadinstitute.sting.utils.QualityUtils;
|
||||
import org.broadinstitute.sting.utils.collections.NestedIntegerArray;
|
||||
|
|
@ -44,7 +45,8 @@ import java.io.File;
|
|||
*/
|
||||
|
||||
public class BaseRecalibration {
|
||||
private final static int MAXIMUM_RECALIBRATED_READ_LENGTH = 5000;
|
||||
private static Logger logger = Logger.getLogger(BaseRecalibration.class);
|
||||
private final static boolean TEST_CACHING = false;
|
||||
|
||||
private final QuantizationInfo quantizationInfo; // histogram containing the map for qual quantization (calculated after recalibration is done)
|
||||
private final RecalibrationTables recalibrationTables;
|
||||
|
|
@ -54,6 +56,10 @@ public class BaseRecalibration {
|
|||
private final int preserveQLessThan;
|
||||
private final boolean emitOriginalQuals;
|
||||
|
||||
private final NestedIntegerArray<Double> globalDeltaQs;
|
||||
private final NestedIntegerArray<Double> deltaQReporteds;
|
||||
|
||||
|
||||
/**
|
||||
* Constructor using a GATK Report file
|
||||
*
|
||||
|
|
@ -76,6 +82,44 @@ public class BaseRecalibration {
|
|||
this.disableIndelQuals = disableIndelQuals;
|
||||
this.preserveQLessThan = preserveQLessThan;
|
||||
this.emitOriginalQuals = emitOriginalQuals;
|
||||
|
||||
logger.info("Calculating cached tables...");
|
||||
|
||||
//
|
||||
// Create a NestedIntegerArray<Double> that maps from rgKey x errorModel -> double,
|
||||
// where the double is the result of this calculation. The entire calculation can
|
||||
// be done upfront, on initialization of this BaseRecalibration structure
|
||||
//
|
||||
final NestedIntegerArray<RecalDatum> byReadGroupTable = recalibrationTables.getReadGroupTable();
|
||||
globalDeltaQs = new NestedIntegerArray<Double>( byReadGroupTable.getDimensions() );
|
||||
logger.info("Calculating global delta Q table...");
|
||||
for ( NestedIntegerArray.Leaf<RecalDatum> leaf : byReadGroupTable.getAllLeaves() ) {
|
||||
final int rgKey = leaf.keys[0];
|
||||
final int eventIndex = leaf.keys[1];
|
||||
final double globalDeltaQ = calculateGlobalDeltaQ(rgKey, EventType.eventFrom(eventIndex));
|
||||
globalDeltaQs.put(globalDeltaQ, rgKey, eventIndex);
|
||||
}
|
||||
|
||||
|
||||
// The calculation of the delta Q reported is constant. key[0] and key[1] are the read group and qual, respectively,
|
||||
// and globalDeltaQ is a constant for the read group. So technically the delta Q reported is simply a lookup
|
||||
// into a matrix indexed by read group, qual, and event type.
|
||||
// the code below actually creates this cache with a NestedIntegerArray calling into the actual
|
||||
// calculateDeltaQReported code.
|
||||
final NestedIntegerArray<RecalDatum> byQualTable = recalibrationTables.getQualityScoreTable();
|
||||
deltaQReporteds = new NestedIntegerArray<Double>( byQualTable.getDimensions() );
|
||||
logger.info("Calculating delta Q reported table...");
|
||||
for ( NestedIntegerArray.Leaf<RecalDatum> leaf : byQualTable.getAllLeaves() ) {
|
||||
final int rgKey = leaf.keys[0];
|
||||
final int qual = leaf.keys[1];
|
||||
final int eventIndex = leaf.keys[2];
|
||||
final EventType event = EventType.eventFrom(eventIndex);
|
||||
final double globalDeltaQ = getGlobalDeltaQ(rgKey, event);
|
||||
final double deltaQReported = calculateDeltaQReported(rgKey, qual, event, globalDeltaQ, (byte)qual);
|
||||
deltaQReporteds.put(deltaQReported, rgKey, qual, eventIndex);
|
||||
}
|
||||
|
||||
logger.info("done calculating cache");
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -83,6 +127,18 @@ public class BaseRecalibration {
|
|||
*
|
||||
* It updates the base qualities of the read with the new recalibrated qualities (for all event types)
|
||||
*
|
||||
* Implements a serial recalibration of the reads using the combinational table.
|
||||
* First, we perform a positional recalibration, and then a subsequent dinuc correction.
|
||||
*
|
||||
* Given the full recalibration table, we perform the following preprocessing steps:
|
||||
*
|
||||
* - calculate the global quality score shift across all data [DeltaQ]
|
||||
* - calculate for each of cycle and dinuc the shift of the quality scores relative to the global shift
|
||||
* -- i.e., DeltaQ(dinuc) = Sum(pos) Sum(Qual) Qempirical(pos, qual, dinuc) - Qreported(pos, qual, dinuc) / Npos * Nqual
|
||||
* - The final shift equation is:
|
||||
*
|
||||
* Qrecal = Qreported + DeltaQ + DeltaQ(pos) + DeltaQ(dinuc) + DeltaQ( ... any other covariate ... )
|
||||
*
|
||||
* @param read the read to recalibrate
|
||||
*/
|
||||
public void recalibrateRead(final GATKSAMRecord read) {
|
||||
|
|
@ -95,6 +151,7 @@ public class BaseRecalibration {
|
|||
}
|
||||
|
||||
final ReadCovariates readCovariates = RecalUtils.computeCovariates(read, requestedCovariates);
|
||||
final int readLength = read.getReadLength();
|
||||
|
||||
for (final EventType errorModel : EventType.values()) { // recalibrate all three quality strings
|
||||
if (disableIndelQuals && errorModel != EventType.BASE_SUBSTITUTION) {
|
||||
|
|
@ -103,58 +160,88 @@ public class BaseRecalibration {
|
|||
}
|
||||
|
||||
final byte[] quals = read.getBaseQualities(errorModel);
|
||||
final int[][] fullReadKeySet = readCovariates.getKeySet(errorModel); // get the keyset for this base using the error model
|
||||
|
||||
final int readLength = read.getReadLength();
|
||||
// get the keyset for this base using the error model
|
||||
final int[][] fullReadKeySet = readCovariates.getKeySet(errorModel);
|
||||
|
||||
// the rg key is constant over the whole read, the global deltaQ is too
|
||||
final int rgKey = fullReadKeySet[0][0];
|
||||
|
||||
final double globalDeltaQ = getGlobalDeltaQ(rgKey, errorModel);
|
||||
|
||||
for (int offset = 0; offset < readLength; offset++) { // recalibrate all bases in the read
|
||||
final byte origQual = quals[offset];
|
||||
|
||||
final byte originalQualityScore = quals[offset];
|
||||
// only recalibrate usable qualities (the original quality will come from the instrument -- reported quality)
|
||||
if ( origQual >= preserveQLessThan ) {
|
||||
// get the keyset for this base using the error model
|
||||
final int[] keySet = fullReadKeySet[offset];
|
||||
final double deltaQReported = getDeltaQReported(keySet[0], keySet[1], errorModel, globalDeltaQ);
|
||||
final double deltaQCovariates = calculateDeltaQCovariates(recalibrationTables, keySet, errorModel, globalDeltaQ, deltaQReported, origQual);
|
||||
|
||||
// calculate the recalibrated qual using the BQSR formula
|
||||
double recalibratedQualDouble = origQual + globalDeltaQ + deltaQReported + deltaQCovariates;
|
||||
|
||||
// recalibrated quality is bound between 1 and MAX_QUAL
|
||||
final byte recalibratedQual = QualityUtils.boundQual(MathUtils.fastRound(recalibratedQualDouble), QualityUtils.MAX_RECALIBRATED_Q_SCORE);
|
||||
|
||||
// return the quantized version of the recalibrated quality
|
||||
final byte recalibratedQualityScore = quantizationInfo.getQuantizedQuals().get(recalibratedQual);
|
||||
|
||||
if (originalQualityScore >= preserveQLessThan) { // only recalibrate usable qualities (the original quality will come from the instrument -- reported quality)
|
||||
final int[] keySet = fullReadKeySet[offset]; // get the keyset for this base using the error model
|
||||
final byte recalibratedQualityScore = performSequentialQualityCalculation(keySet, errorModel); // recalibrate the base
|
||||
quals[offset] = recalibratedQualityScore;
|
||||
}
|
||||
}
|
||||
|
||||
// finally update the base qualities in the read
|
||||
read.setBaseQualities(quals, errorModel);
|
||||
}
|
||||
}
|
||||
|
||||
private double getGlobalDeltaQ(final int rgKey, final EventType errorModel) {
|
||||
final Double cached = globalDeltaQs.get(rgKey, errorModel.index);
|
||||
|
||||
/**
|
||||
* Implements a serial recalibration of the reads using the combinational table.
|
||||
* First, we perform a positional recalibration, and then a subsequent dinuc correction.
|
||||
*
|
||||
* Given the full recalibration table, we perform the following preprocessing steps:
|
||||
*
|
||||
* - calculate the global quality score shift across all data [DeltaQ]
|
||||
* - calculate for each of cycle and dinuc the shift of the quality scores relative to the global shift
|
||||
* -- i.e., DeltaQ(dinuc) = Sum(pos) Sum(Qual) Qempirical(pos, qual, dinuc) - Qreported(pos, qual, dinuc) / Npos * Nqual
|
||||
* - The final shift equation is:
|
||||
*
|
||||
* Qrecal = Qreported + DeltaQ + DeltaQ(pos) + DeltaQ(dinuc) + DeltaQ( ... any other covariate ... )
|
||||
*
|
||||
* @param key The list of Comparables that were calculated from the covariates
|
||||
* @param errorModel the event type
|
||||
* @return A recalibrated quality score as a byte
|
||||
*/
|
||||
private byte performSequentialQualityCalculation(final int[] key, final EventType errorModel) {
|
||||
if ( TEST_CACHING ) {
|
||||
final double calcd = calculateGlobalDeltaQ(rgKey, errorModel);
|
||||
if ( calcd != cached )
|
||||
throw new IllegalStateException("calculated " + calcd + " and cached " + cached + " global delta q not equal at " + rgKey + " / " + errorModel);
|
||||
}
|
||||
|
||||
final byte qualFromRead = (byte)(long)key[1];
|
||||
final double globalDeltaQ = calculateGlobalDeltaQ(recalibrationTables.getReadGroupTable(), key, errorModel);
|
||||
final double deltaQReported = calculateDeltaQReported(recalibrationTables.getQualityScoreTable(), key, errorModel, globalDeltaQ, qualFromRead);
|
||||
final double deltaQCovariates = calculateDeltaQCovariates(recalibrationTables, key, errorModel, globalDeltaQ, deltaQReported, qualFromRead);
|
||||
|
||||
double recalibratedQual = qualFromRead + globalDeltaQ + deltaQReported + deltaQCovariates; // calculate the recalibrated qual using the BQSR formula
|
||||
recalibratedQual = QualityUtils.boundQual(MathUtils.fastRound(recalibratedQual), QualityUtils.MAX_RECALIBRATED_Q_SCORE); // recalibrated quality is bound between 1 and MAX_QUAL
|
||||
|
||||
return quantizationInfo.getQuantizedQuals().get((int) recalibratedQual); // return the quantized version of the recalibrated quality
|
||||
return cachedWithDefault(cached);
|
||||
}
|
||||
|
||||
private double calculateGlobalDeltaQ(final NestedIntegerArray<RecalDatum> table, final int[] key, final EventType errorModel) {
|
||||
private double getDeltaQReported(final int rgKey, final int qualKey, final EventType errorModel, final double globalDeltaQ) {
|
||||
final Double cached = deltaQReporteds.get(rgKey, qualKey, errorModel.index);
|
||||
|
||||
if ( TEST_CACHING ) {
|
||||
final double calcd = calculateDeltaQReported(rgKey, qualKey, errorModel, globalDeltaQ, (byte)qualKey);
|
||||
if ( calcd != cached )
|
||||
throw new IllegalStateException("calculated " + calcd + " and cached " + cached + " global delta q not equal at " + rgKey + " / " + qualKey + " / " + errorModel);
|
||||
}
|
||||
|
||||
return cachedWithDefault(cached);
|
||||
}
|
||||
|
||||
/**
|
||||
* Converts a possibly-null cached delta Q value into a primitive double.
*
* @param d a Double (that may be null) that is the result of a delta Q calculation
|
||||
* @return d as a double if d is not null, or 0.0 if d is null
|
||||
*/
|
||||
private double cachedWithDefault(final Double d) {
|
||||
return d == null ? 0.0 : d;
|
||||
}
|
||||
|
||||
/**
|
||||
* Note that this calculation is a constant for each rgKey and errorModel. We need only
|
||||
* compute this value once for all data.
|
||||
*
|
||||
* @param rgKey
|
||||
* @param errorModel
|
||||
* @return
|
||||
*/
|
||||
private double calculateGlobalDeltaQ(final int rgKey, final EventType errorModel) {
|
||||
double result = 0.0;
|
||||
|
||||
final RecalDatum empiricalQualRG = table.get(key[0], errorModel.index);
|
||||
final RecalDatum empiricalQualRG = recalibrationTables.getReadGroupTable().get(rgKey, errorModel.index);
|
||||
|
||||
if (empiricalQualRG != null) {
|
||||
final double globalDeltaQEmpirical = empiricalQualRG.getEmpiricalQuality();
|
||||
final double aggregrateQReported = empiricalQualRG.getEstimatedQReported();
|
||||
|
|
@ -164,10 +251,10 @@ public class BaseRecalibration {
|
|||
return result;
|
||||
}
|
||||
|
||||
private double calculateDeltaQReported(final NestedIntegerArray<RecalDatum> table, final int[] key, final EventType errorModel, final double globalDeltaQ, final byte qualFromRead) {
|
||||
private double calculateDeltaQReported(final int rgKey, final int qualKey, final EventType errorModel, final double globalDeltaQ, final byte qualFromRead) {
|
||||
double result = 0.0;
|
||||
|
||||
final RecalDatum empiricalQualQS = table.get(key[0], key[1], errorModel.index);
|
||||
final RecalDatum empiricalQualQS = recalibrationTables.getQualityScoreTable().get(rgKey, qualKey, errorModel.index);
|
||||
if (empiricalQualQS != null) {
|
||||
final double deltaQReportedEmpirical = empiricalQualQS.getEmpiricalQuality();
|
||||
result = deltaQReportedEmpirical - qualFromRead - globalDeltaQ;
|
||||
|
|
@ -184,12 +271,28 @@ public class BaseRecalibration {
|
|||
if (key[i] < 0)
|
||||
continue;
|
||||
|
||||
final RecalDatum empiricalQualCO = recalibrationTables.getTable(i).get(key[0], key[1], key[i], errorModel.index);
|
||||
if (empiricalQualCO != null) {
|
||||
final double deltaQCovariateEmpirical = empiricalQualCO.getEmpiricalQuality();
|
||||
result += (deltaQCovariateEmpirical - qualFromRead - (globalDeltaQ + deltaQReported));
|
||||
}
|
||||
result += calculateDeltaQCovariate(recalibrationTables.getTable(i),
|
||||
key[0], key[1], key[i], errorModel,
|
||||
globalDeltaQ, deltaQReported, qualFromRead);
|
||||
}
|
||||
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
 * Computes the delta Q contribution of a single covariate table entry.
 *
 * Looks up the RecalDatum stored at (rgKey, qualKey, tableKey, errorModel.index) in the given
 * table. If an entry exists, the result is its empirical quality minus qualFromRead minus
 * (globalDeltaQ + deltaQReported); if there is no entry, the result is 0.0.
 *
 * @param table          the covariate table to look the datum up in
 * @param rgKey          the read group key
 * @param qualKey        the quality score key
 * @param tableKey       the key of the covariate value within the table
 * @param errorModel     the event type
 * @param globalDeltaQ   the global delta Q already computed for the read group
 * @param deltaQReported the delta Q reported already computed for the read group and quality
 * @param qualFromRead   the quality score that is subtracted from the empirical quality
 * @return the delta Q for this covariate, or 0.0 when the table has no entry for the keys
 */
private double calculateDeltaQCovariate(final NestedIntegerArray<RecalDatum> table,
|
||||
final int rgKey,
|
||||
final int qualKey,
|
||||
final int tableKey,
|
||||
final EventType errorModel,
|
||||
final double globalDeltaQ,
|
||||
final double deltaQReported,
|
||||
final byte qualFromRead) {
|
||||
final RecalDatum empiricalQualCO = table.get(rgKey, qualKey, tableKey, errorModel.index);
|
||||
if (empiricalQualCO != null) {
|
||||
final double deltaQCovariateEmpirical = empiricalQualCO.getEmpiricalQuality();
|
||||
// shift relative to what the global and reported-quality corrections already account for
return deltaQCovariateEmpirical - qualFromRead - (globalDeltaQ + deltaQReported);
|
||||
} else {
|
||||
// no data for this covariate value: it contributes no correction
return 0.0;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -47,7 +47,7 @@ public class ReadCovariates {
|
|||
final int[][][] cachedKeys = cache.get(readLength);
|
||||
if ( cachedKeys == null ) {
|
||||
// There's no cached value for read length so we need to create a new int[][][] array
|
||||
logger.info("Keys cache miss for length " + readLength + " cache size " + cache.size());
|
||||
if ( logger.isDebugEnabled() ) logger.debug("Keys cache miss for length " + readLength + " cache size " + cache.size());
|
||||
keys = new int[EventType.values().length][readLength][numberOfCovariates];
|
||||
cache.put(readLength, keys);
|
||||
} else {
|
||||
|
|
|
|||
Loading…
Reference in New Issue