Optimizations for BaseRecalibrator

-- No longer recomputes the overall read group table at each update.  This derived table is now computed only once, at the end of the computation, using the ByQual table as input.  Reduces BQSR runtime by 1/3 in my test.
This commit is contained in:
Mark DePristo 2012-12-12 13:54:14 -05:00
parent e6f468b647
commit 5ec25797b3
7 changed files with 57 additions and 72 deletions

View File

@ -25,13 +25,11 @@ package org.broadinstitute.sting.gatk.walkers.bqsr;
* OTHER DEALINGS IN THE SOFTWARE.
*/
import org.broadinstitute.sting.utils.recalibration.covariates.Covariate;
import org.broadinstitute.sting.utils.BaseUtils;
import org.broadinstitute.sting.utils.classloader.ProtectedPackageSource;
import org.broadinstitute.sting.utils.pileup.PileupElement;
import org.broadinstitute.sting.utils.recalibration.EventType;
import org.broadinstitute.sting.utils.recalibration.ReadCovariates;
import org.broadinstitute.sting.utils.recalibration.RecalibrationTables;
import org.broadinstitute.sting.utils.recalibration.covariates.Covariate;
import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
import org.broadinstitute.sting.utils.threading.ThreadLocalArray;
@ -67,8 +65,6 @@ public class AdvancedRecalibrationEngine extends StandardRecalibrationEngine imp
final byte qual = tempQualArray[eventIndex];
final double isError = tempFractionalErrorArray[eventIndex];
combineDatumOrPutIfNecessary(recalibrationTables.getReadGroupTable(), qual, isError, keys[0], eventIndex);
incrementDatumOrPutIfNecessary(recalibrationTables.getQualityScoreTable(), qual, isError, keys[0], keys[1], eventIndex);
for (int i = 2; i < covariates.length; i++) {

View File

@ -452,6 +452,8 @@ public class BaseRecalibrator extends ReadWalker<Long, Long> implements NanoSche
@Override
public void onTraversalDone(Long result) {
recalibrationEngine.finalizeData();
logger.info("Calculating quantized quality scores...");
quantizeQualityScores();

View File

@ -1,8 +1,7 @@
package org.broadinstitute.sting.gatk.walkers.bqsr;
import org.broadinstitute.sting.utils.recalibration.covariates.Covariate;
import org.broadinstitute.sting.utils.pileup.PileupElement;
import org.broadinstitute.sting.utils.recalibration.RecalibrationTables;
import org.broadinstitute.sting.utils.recalibration.covariates.Covariate;
import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
/*
@ -34,4 +33,6 @@ public interface RecalibrationEngine {
public void initialize(final Covariate[] covariates, final RecalibrationTables recalibrationTables);
public void updateDataForRead(final GATKSAMRecord read, final boolean[] skip, final double[] snpErrors, final double[] insertionErrors, final double[] deletionErrors);
public void finalizeData();
}

View File

@ -25,15 +25,13 @@ package org.broadinstitute.sting.gatk.walkers.bqsr;
* OTHER DEALINGS IN THE SOFTWARE.
*/
import org.broadinstitute.sting.utils.recalibration.covariates.Covariate;
import org.broadinstitute.sting.utils.BaseUtils;
import org.broadinstitute.sting.utils.classloader.PublicPackageSource;
import org.broadinstitute.sting.utils.collections.NestedIntegerArray;
import org.broadinstitute.sting.utils.pileup.PileupElement;
import org.broadinstitute.sting.utils.recalibration.EventType;
import org.broadinstitute.sting.utils.recalibration.ReadCovariates;
import org.broadinstitute.sting.utils.recalibration.RecalDatum;
import org.broadinstitute.sting.utils.recalibration.RecalibrationTables;
import org.broadinstitute.sting.utils.recalibration.covariates.Covariate;
import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
public class StandardRecalibrationEngine implements RecalibrationEngine, PublicPackageSource {
@ -58,8 +56,6 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP
final int[] keys = readCovariates.getKeySet(offset, EventType.BASE_SUBSTITUTION);
final int eventIndex = EventType.BASE_SUBSTITUTION.index;
combineDatumOrPutIfNecessary(recalibrationTables.getReadGroupTable(), qual, isError, keys[0], eventIndex);
incrementDatumOrPutIfNecessary(recalibrationTables.getQualityScoreTable(), qual, isError, keys[0], keys[1], eventIndex);
for (int i = 2; i < covariates.length; i++) {
@ -93,6 +89,34 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP
return (ReadCovariates) read.getTemporaryAttribute(BaseRecalibrator.COVARS_ATTRIBUTE);
}
/**
 * Builds the derived (summary) recalibration tables from the principal ones.
 *
 * Expects the by-quality-score table to be completely populated before this is called.
 * Every leaf of that table is folded into the by-read-group table, under the leaf's
 * read group key and event index.
 */
@Override
public void finalizeData() {
    final NestedIntegerArray<RecalDatum> readGroupTable = recalibrationTables.getReadGroupTable();
    final NestedIntegerArray<RecalDatum> qualityScoreTable = recalibrationTables.getQualityScoreTable();

    // walk every populated cell of the qual table
    for ( final NestedIntegerArray.Leaf<RecalDatum> entry : qualityScoreTable.getAllLeaves() ) {
        // qual-table key layout used here: [0] = read group, [2] = event index
        final int readGroup = entry.keys[0];
        final int event = entry.keys[2];
        final RecalDatum summary = readGroupTable.get(readGroup, event);
        final RecalDatum perQual = entry.value;

        if ( summary != null ) {
            // a summary datum already exists for this read group / event: fold this one in
            summary.combine(perQual);
            continue;
        }

        // first datum seen for this read group / event: seed the summary with a copy
        readGroupTable.put(new RecalDatum(perQual), readGroup, event);
    }
}
/**
* Increments the RecalDatum at the specified position in the specified table, or put a new item there
* if there isn't already one.
@ -121,34 +145,4 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP
existingDatum.increment(1.0, isError);
}
}
/**
 * Merges a freshly created RecalDatum into the cell of the given table addressed by keys,
 * inserting it as the cell's first datum when the cell is still empty.
 *
 * Thread-safe WITHOUT synchronization: it depends on NestedIntegerArray.put() returning false
 * when another thread has filled the same cell while our put was in flight, in which case the
 * winner's datum is fetched and ours is combined into it.
 *
 * @param table   the table that holds/will hold our item
 * @param qual    qual for this event
 * @param isError error value for this event
 * @param keys    location in table of our item
 */
protected void combineDatumOrPutIfNecessary( final NestedIntegerArray<RecalDatum> table, final byte qual, final double isError, final int... keys ) {
    final RecalDatum current = table.get(keys);
    final RecalDatum fresh = createDatumObject(qual, isError);

    if ( current != null ) {
        // Common case: the cell is already populated, so just merge into it
        current.combine(fresh);
        return;
    }

    // Empty cell: attempt the insert; a true return means our datum is now the cell's content
    if ( table.put(fresh, keys) ) {
        return;
    }

    // Lost the race to another thread's insert; its datum must exist now, so merge ours into it
    table.get(keys).combine(fresh);
}
}

View File

@ -174,23 +174,23 @@ public class NestedIntegerArray<T> {
}
}
// A leaf of the nested array: the integer key path that addresses a stored value, plus the value itself.
// NOTE(review): this hunk appears to list the pre-change (raw, Object-valued) and post-change
// (generic, T-valued) declarations on adjacent lines; only one form of each pair should exist
// in the real file -- confirm against the repository.
public static class Leaf {
public static class Leaf<T> {
// key path identifying this leaf
public final int[] keys;
// the value stored at that key path
public final Object value;
public final T value;
// Stores the given key array and value as-is (no defensive copy of keys is made here).
public Leaf(final int[] keys, final Object value) {
public Leaf(final int[] keys, final T value) {
this.keys = keys;
this.value = value;
}
}
// Collects the leaves of this array into a new list by calling fillAllLeaves on `data`
// with an empty starting key path, then returns that list.
// NOTE(review): the raw-typed and generic signature/initializer lines both appear in this hunk,
// presumably the before/after sides of the change -- confirm which pair is current.
public List<Leaf> getAllLeaves() {
final List<Leaf> result = new ArrayList<Leaf>();
public List<Leaf<T>> getAllLeaves() {
final List<Leaf<T>> result = new ArrayList<Leaf<T>>();
// start the walk at the root array with an empty key path
fillAllLeaves(data, new int[0], result);
return result;
}
private void fillAllLeaves(final Object[] array, final int[] path, final List<Leaf> result) {
private void fillAllLeaves(final Object[] array, final int[] path, final List<Leaf<T>> result) {
for ( int key = 0; key < array.length; key++ ) {
final Object value = array[key];
if ( value == null )
@ -199,7 +199,7 @@ public class NestedIntegerArray<T> {
if ( value instanceof Object[] ) {
fillAllLeaves((Object[]) value, newPath, result);
} else {
result.add(new Leaf(newPath, value));
result.add(new Leaf<T>(newPath, (T)value));
}
}
}

View File

@ -191,9 +191,9 @@ public class RecalDatum {
return (byte)(Math.round(getEmpiricalQuality()));
}
//---------------------------------------------------------------------------------------------------------------
//---------------------------------------------------------------------------------------------------------------
//
// increment methods
// toString methods
//
//---------------------------------------------------------------------------------------------------------------
@ -206,22 +206,6 @@ public class RecalDatum {
return String.format("%s,%.2f,%.2f", toString(), getEstimatedQReported(), getEmpiricalQuality() - getEstimatedQReported());
}
// /**
// * We don't compare the estimated quality reported because it may be different when read from
// * report tables.
// *
// * @param o the other recal datum
// * @return true if the two recal datums have the same number of observations, errors and empirical quality.
// */
// @Override
// public boolean equals(Object o) {
// if (!(o instanceof RecalDatum))
// return false;
// RecalDatum other = (RecalDatum) o;
// return super.equals(o) &&
// MathUtils.compareDoubles(this.empiricalQuality, other.empiricalQuality, 0.001) == 0;
// }
//---------------------------------------------------------------------------------------------------------------
//
// increment methods
@ -270,9 +254,7 @@ public class RecalDatum {
// Records a single observation, counting it as a mismatch when isError is true.
// The contract below states that numObservations grows by exactly one and numMismatches never decreases.
@Ensures({"numObservations == old(numObservations) + 1", "numMismatches >= old(numMismatches)"})
public synchronized void increment(final boolean isError) {
// NOTE(review): the explicit counter updates and the delegating increment(1, ...) call below
// look like the before/after bodies of this change shown together; keeping both would count
// the observation twice -- confirm only one form exists in the real file.
incrementNumObservations(1);
if ( isError )
incrementNumMismatches(1);
// delegating form: one observation, with an error weight of 1 or 0.0
increment(1, isError ? 1 : 0.0);
}
// -------------------------------------------------------------------------------------
@ -286,7 +268,7 @@ public class RecalDatum {
*/
// Contract: may only run while empiricalQuality is still UNINITIALIZED, and leaves it initialized.
@Requires("empiricalQuality == UNINITIALIZED")
@Ensures("empiricalQuality != UNINITIALIZED")
// NOTE(review): two signature lines appear here (with and without `final`), presumably the
// before/after sides of the change -- confirm which one is current.
private synchronized final void calcEmpiricalQuality() {
private synchronized void calcEmpiricalQuality() {
// convert the empirical error rate to a quality score: Q = -10 * log10(errorRate)
final double empiricalQual = -10 * Math.log10(getEmpiricalErrorRate());
// cap the stored value at QualityUtils.MAX_RECALIBRATED_Q_SCORE
empiricalQuality = Math.min(empiricalQual, (double) QualityUtils.MAX_RECALIBRATED_Q_SCORE);
}

View File

@ -1,11 +1,13 @@
package org.broadinstitute.sting.utils.recalibration.covariates;
import org.broadinstitute.sting.utils.recalibration.ReadCovariates;
import org.broadinstitute.sting.gatk.walkers.bqsr.RecalibrationArgumentCollection;
import org.broadinstitute.sting.utils.recalibration.ReadCovariates;
import org.broadinstitute.sting.utils.sam.GATKSAMReadGroupRecord;
import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
/*
* Copyright (c) 2009 The Broad Institute
@ -77,6 +79,14 @@ public class ReadGroupCovariate implements RequiredCovariate {
return keyForReadGroup((String) value);
}
/**
 * Exposes the read group name -> integer key assignments known to this covariate.
 *
 * NOTE(review): the result comes straight from readGroupLookupTable.entrySet(); if that table is a
 * java.util.Map (as the imports suggest), this is a view of the table rather than a copy -- confirm
 * before mutating it or iterating while other threads may still be adding read groups.
 *
 * @return the set of (read group name, integer key) entries for all read groups in this covariate
 */
public Set<Map.Entry<String, Integer>> getKeyMap() {
    final Set<Map.Entry<String, Integer>> nameToKeyEntries = readGroupLookupTable.entrySet();
    return nameToKeyEntries;
}
private int keyForReadGroup(final String readGroupId) {
// Rather than synchronize this entire method (which would be VERY expensive for walkers like the BQSR),
// synchronize only the table updates.