BQSR optimization: make RecalibrationTables thread-local, and merge results in onTraversalDone

-- With the newer, faster BQSR, scaling was limited by the NestedIntegerArray.  The solution to this is to make the entire table thread-local, so that each nct thread has its own data and doesn't have any collisions.
-- Removed the previous partial solution of having a thread-local quality score table
-- Added a new argument -lowMemory
This commit is contained in:
Mark DePristo 2013-01-03 14:43:36 -05:00
parent 1ba8d47a81
commit 7df47418d8
5 changed files with 130 additions and 69 deletions

View File

@ -25,35 +25,21 @@ package org.broadinstitute.sting.gatk.walkers.bqsr;
* OTHER DEALINGS IN THE SOFTWARE.
*/
import org.apache.log4j.Logger;
import org.broadinstitute.sting.utils.classloader.ProtectedPackageSource;
import org.broadinstitute.sting.utils.collections.NestedIntegerArray;
import org.broadinstitute.sting.utils.recalibration.EventType;
import org.broadinstitute.sting.utils.recalibration.ReadCovariates;
import org.broadinstitute.sting.utils.recalibration.RecalDatum;
import org.broadinstitute.sting.utils.recalibration.RecalibrationTables;
import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
import java.util.LinkedList;
import java.util.List;
public class AdvancedRecalibrationEngine extends StandardRecalibrationEngine implements ProtectedPackageSource {
private final static Logger logger = Logger.getLogger(AdvancedRecalibrationEngine.class);
final List<NestedIntegerArray<RecalDatum>> allThreadLocalQualityScoreTables = new LinkedList<NestedIntegerArray<RecalDatum>>();
private ThreadLocal<NestedIntegerArray<RecalDatum>> threadLocalQualityScoreTables = new ThreadLocal<NestedIntegerArray<RecalDatum>>() {
@Override
protected synchronized NestedIntegerArray<RecalDatum> initialValue() {
final NestedIntegerArray<RecalDatum> table = recalibrationTables.makeQualityScoreTable();
allThreadLocalQualityScoreTables.add(table);
return table;
}
};
@Override
public void updateDataForRead( final ReadRecalibrationInfo recalInfo ) {
final GATKSAMRecord read = recalInfo.getRead();
final ReadCovariates readCovariates = recalInfo.getCovariatesValues();
final NestedIntegerArray<RecalDatum> qualityScoreTable = getThreadLocalQualityScoreTable();
final RecalibrationTables tables = getRecalibrationTables();
final NestedIntegerArray<RecalDatum> qualityScoreTable = tables.getQualityScoreTable();
for( int offset = 0; offset < read.getReadBases().length; offset++ ) {
if( ! recalInfo.skip(offset) ) {
@ -70,30 +56,10 @@ public class AdvancedRecalibrationEngine extends StandardRecalibrationEngine imp
if (keys[i] < 0)
continue;
incrementDatumOrPutIfNecessary(recalibrationTables.getTable(i), qual, isError, keys[0], keys[1], keys[i], eventIndex);
incrementDatumOrPutIfNecessary(tables.getTable(i), qual, isError, keys[0], keys[1], keys[i], eventIndex);
}
}
}
}
}
/**
* Get a NestedIntegerArray for a QualityScore table specific to this thread
* @return a non-null NestedIntegerArray ready to be used to collect calibration info for the quality score covariate
*/
private NestedIntegerArray<RecalDatum> getThreadLocalQualityScoreTable() {
return threadLocalQualityScoreTables.get();
}
@Override
public void finalizeData() {
// merge in all of the thread local tables
logger.info("Merging " + allThreadLocalQualityScoreTables.size() + " thread-local quality score tables");
for ( final NestedIntegerArray<RecalDatum> localTable : allThreadLocalQualityScoreTables ) {
recalibrationTables.combineQualityScoreTable(localTable);
}
allThreadLocalQualityScoreTables.clear(); // cleanup after ourselves
super.finalizeData();
}
}

View File

@ -120,9 +120,16 @@ public class BaseRecalibrator extends ReadWalker<Long, Long> implements NanoSche
@Argument(fullName = "bqsrBAQGapOpenPenalty", shortName="bqsrBAQGOP", doc="BQSR BAQ gap open penalty (Phred Scaled). Default value is 40. 30 is perhaps better for whole genome call sets", required = false)
public double BAQGOP = BAQ.DEFAULT_GOP;
private QuantizationInfo quantizationInfo; // an object that keeps track of the information necessary for quality score quantization
/**
* When you have nct > 1, BQSR uses nct times more memory to compute its recalibration tables, for efficiency
* purposes. If you have many covariates, and therefore are using a lot of memory, you can use this flag
* to safely access only one table. There may be some CPU cost, but as long as the table is really big
* there should be relatively little CPU costs.
*/
@Argument(fullName = "lowMemoryMode", shortName="lowMemoryMode", doc="Reduce memory usage in multi-threaded code at the expense of threading efficiency", required = false)
public boolean lowMemoryMode = false;
private RecalibrationTables recalibrationTables;
private QuantizationInfo quantizationInfo; // an object that keeps track of the information necessary for quality score quantization
private Covariate[] requestedCovariates; // list to hold the all the covariate objects that were requested (required + standard + experimental)
@ -130,8 +137,6 @@ public class BaseRecalibrator extends ReadWalker<Long, Long> implements NanoSche
private int minimumQToUse;
protected static final String COVARS_ATTRIBUTE = "COVARS"; // used to store covariates array as a temporary attribute inside GATKSAMRecord.\
private static final String NO_DBSNP_EXCEPTION = "This calculation is critically dependent on being able to skip over known variant sites. Please provide a VCF file containing known sites of genetic variation.";
private BAQ baq; // BAQ the reads on the fly to generate the alignment uncertainty vector
@ -143,7 +148,6 @@ public class BaseRecalibrator extends ReadWalker<Long, Long> implements NanoSche
* Based on the covariates' estimates for initial capacity allocate the data hashmap
*/
public void initialize() {
baq = new BAQ(BAQGOP); // setup the BAQ object with the provided gap open penalty
// check for unsupported access
@ -188,10 +192,11 @@ public class BaseRecalibrator extends ReadWalker<Long, Long> implements NanoSche
int numReadGroups = 0;
for ( final SAMFileHeader header : getToolkit().getSAMFileHeaders() )
numReadGroups += header.getReadGroups().size();
recalibrationTables = new RecalibrationTables(requestedCovariates, numReadGroups, RAC.RECAL_TABLE_UPDATE_LOG);
recalibrationEngine = initializeRecalibrationEngine();
recalibrationEngine.initialize(requestedCovariates, recalibrationTables);
recalibrationEngine.initialize(requestedCovariates, numReadGroups, RAC.RECAL_TABLE_UPDATE_LOG);
if ( lowMemoryMode )
recalibrationEngine.enableLowMemoryMode();
minimumQToUse = getToolkit().getArguments().PRESERVE_QSCORES_LESS_THAN;
referenceReader = getToolkit().getReferenceDataSource().getReference();
@ -501,14 +506,18 @@ public class BaseRecalibrator extends ReadWalker<Long, Long> implements NanoSche
logger.info("Processed: " + result + " reads");
}
private RecalibrationTables getRecalibrationTable() {
return recalibrationEngine.getFinalRecalibrationTables();
}
private void generatePlots() {
File recalFile = getToolkit().getArguments().BQSR_RECAL_FILE;
if (recalFile != null) {
RecalibrationReport report = new RecalibrationReport(recalFile);
RecalUtils.generateRecalibrationPlot(RAC, report.getRecalibrationTables(), recalibrationTables, requestedCovariates);
RecalUtils.generateRecalibrationPlot(RAC, report.getRecalibrationTables(), getRecalibrationTable(), requestedCovariates);
}
else
RecalUtils.generateRecalibrationPlot(RAC, recalibrationTables, requestedCovariates);
RecalUtils.generateRecalibrationPlot(RAC, getRecalibrationTable(), requestedCovariates);
}
/**
@ -517,10 +526,10 @@ public class BaseRecalibrator extends ReadWalker<Long, Long> implements NanoSche
* generate a quantization map (recalibrated_qual -> quantized_qual)
*/
private void quantizeQualityScores() {
quantizationInfo = new QuantizationInfo(recalibrationTables, RAC.QUANTIZING_LEVELS);
quantizationInfo = new QuantizationInfo(getRecalibrationTable(), RAC.QUANTIZING_LEVELS);
}
private void generateReport() {
RecalUtils.outputRecalibrationReport(RAC, quantizationInfo, recalibrationTables, requestedCovariates, RAC.SORT_BY_ALL_COLUMNS);
RecalUtils.outputRecalibrationReport(RAC, quantizationInfo, getRecalibrationTable(), requestedCovariates, RAC.SORT_BY_ALL_COLUMNS);
}
}

View File

@ -5,6 +5,8 @@ import org.broadinstitute.sting.utils.recalibration.RecalibrationTables;
import org.broadinstitute.sting.utils.recalibration.covariates.Covariate;
import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
import java.io.PrintStream;
/*
* Copyright (c) 2009 The Broad Institute
*
@ -40,9 +42,10 @@ public interface RecalibrationEngine {
* The engine should collect match and mismatch data into the recalibrationTables data.
*
* @param covariates an array of the covariates we'll be using in this engine, order matters
* @param recalibrationTables the destination recalibrationTables where stats should be collected
* @param numReadGroups the number of read groups we should use for the recalibration tables
* @param maybeLogStream an optional print stream for logging calls to the nestedhashmap in the recalibration tables
*/
public void initialize(final Covariate[] covariates, final RecalibrationTables recalibrationTables);
public void initialize(final Covariate[] covariates, final int numReadGroups, final PrintStream maybeLogStream);
/**
* Update the recalibration statistics using the information in recalInfo
@ -57,4 +60,8 @@ public interface RecalibrationEngine {
* Called once after all calls to updateDataForRead have been issued.
*/
public void finalizeData();
public void enableLowMemoryMode();
public RecalibrationTables getFinalRecalibrationTables();
}

View File

@ -25,26 +25,64 @@ package org.broadinstitute.sting.gatk.walkers.bqsr;
* OTHER DEALINGS IN THE SOFTWARE.
*/
import com.google.java.contract.Requires;
import org.apache.log4j.Logger;
import org.broadinstitute.sting.utils.classloader.PublicPackageSource;
import org.broadinstitute.sting.utils.collections.NestedIntegerArray;
import org.broadinstitute.sting.utils.recalibration.EventType;
import org.broadinstitute.sting.utils.recalibration.ReadCovariates;
import org.broadinstitute.sting.utils.recalibration.RecalDatum;
import org.broadinstitute.sting.utils.recalibration.RecalibrationTables;
import org.broadinstitute.sting.utils.recalibration.*;
import org.broadinstitute.sting.utils.recalibration.covariates.Covariate;
import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
import java.io.PrintStream;
import java.util.LinkedList;
import java.util.List;
public class StandardRecalibrationEngine implements RecalibrationEngine, PublicPackageSource {
private static final Logger logger = Logger.getLogger(StandardRecalibrationEngine.class);
protected Covariate[] covariates;
protected RecalibrationTables recalibrationTables;
private int numReadGroups;
private PrintStream maybeLogStream;
private boolean lowMemoryMode = false;
private boolean finalized = false;
private RecalibrationTables mergedRecalibrationTables = null;
private final List<RecalibrationTables> recalibrationTablesList = new LinkedList<RecalibrationTables>();
private final ThreadLocal<RecalibrationTables> threadLocalTables = new ThreadLocal<RecalibrationTables>() {
private synchronized RecalibrationTables makeAndCaptureTable() {
logger.info("Creating RecalibrationTable for " + Thread.currentThread());
final RecalibrationTables newTable = new RecalibrationTables(covariates, numReadGroups, maybeLogStream);
recalibrationTablesList.add(newTable);
return newTable;
}
@Override
protected synchronized RecalibrationTables initialValue() {
if ( lowMemoryMode ) {
return recalibrationTablesList.isEmpty() ? makeAndCaptureTable() : recalibrationTablesList.get(0);
} else {
return makeAndCaptureTable();
}
}
};
protected RecalibrationTables getRecalibrationTables() {
return threadLocalTables.get();
}
public void enableLowMemoryMode() {
this.lowMemoryMode = true;
}
@Override
public void initialize(final Covariate[] covariates, final RecalibrationTables recalibrationTables) {
public void initialize(final Covariate[] covariates, final int numReadGroups, final PrintStream maybeLogStream) {
if ( covariates == null ) throw new IllegalArgumentException("Covariates cannot be null");
if ( recalibrationTables == null ) throw new IllegalArgumentException("recalibrationTables cannot be null");
if ( numReadGroups < 1 ) throw new IllegalArgumentException("numReadGroups must be >= 1 but got " + numReadGroups);
this.covariates = covariates.clone();
this.recalibrationTables = recalibrationTables;
this.numReadGroups = numReadGroups;
this.maybeLogStream = maybeLogStream;
}
@Override
@ -59,13 +97,13 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP
final double isError = recalInfo.getErrorFraction(eventType, offset);
final int[] keys = readCovariates.getKeySet(offset, eventType);
incrementDatumOrPutIfNecessary(recalibrationTables.getQualityScoreTable(), qual, isError, keys[0], keys[1], eventType.index);
incrementDatumOrPutIfNecessary(getRecalibrationTables().getQualityScoreTable(), qual, isError, keys[0], keys[1], eventType.index);
for (int i = 2; i < covariates.length; i++) {
if (keys[i] < 0)
continue;
incrementDatumOrPutIfNecessary(recalibrationTables.getTable(i), qual, isError, keys[0], keys[1], keys[i], eventType.index);
incrementDatumOrPutIfNecessary(getRecalibrationTables().getTable(i), qual, isError, keys[0], keys[1], keys[i], eventType.index);
}
}
}
@ -90,8 +128,13 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP
*/
@Override
public void finalizeData() {
final NestedIntegerArray<RecalDatum> byReadGroupTable = recalibrationTables.getReadGroupTable();
final NestedIntegerArray<RecalDatum> byQualTable = recalibrationTables.getQualityScoreTable();
if ( finalized ) throw new IllegalStateException("FinalizeData() has already been called");
// merge all of the thread-local tables
mergedRecalibrationTables = mergeThreadLocalRecalibrationTables();
final NestedIntegerArray<RecalDatum> byReadGroupTable = mergedRecalibrationTables.getReadGroupTable();
final NestedIntegerArray<RecalDatum> byQualTable = mergedRecalibrationTables.getQualityScoreTable();
// iterate over all values in the qual table
for ( NestedIntegerArray.Leaf<RecalDatum> leaf : byQualTable.getAllLeaves() ) {
@ -108,6 +151,38 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP
rgDatum.combine(qualDatum);
}
}
finalized = true;
}
/**
* Merge all of the thread local recalibration tables into a single one.
*
* Reuses one of the recalibration tables to hold the merged table, so this function can only be
* called once in the engine.
*
* @return the merged recalibration table
*/
@Requires("! finalized")
private RecalibrationTables mergeThreadLocalRecalibrationTables() {
if ( recalibrationTablesList.isEmpty() ) throw new IllegalStateException("recalibration tables list is empty");
RecalibrationTables merged = null;
for ( final RecalibrationTables table : recalibrationTablesList ) {
if ( merged == null )
// fast path -- if there's only only one table, so just make it the merged one
merged = table;
else {
merged.combine(table);
}
}
return merged;
}
public RecalibrationTables getFinalRecalibrationTables() {
if ( ! finalized ) throw new IllegalStateException("Cannot get final recalibration tables until finalizeData() has been called");
return mergedRecalibrationTables;
}
/**

View File

@ -123,12 +123,16 @@ public final class RecalibrationTables {
}
/**
* Merge in the quality score table information from qualityScoreTable into this
* recalibration table's quality score table.
*
* @param qualityScoreTable the quality score table we want to merge in
* Merge all of the tables from toMerge into into this set of tables
*/
public void combineQualityScoreTable(final NestedIntegerArray<RecalDatum> qualityScoreTable) {
RecalUtils.combineTables(getQualityScoreTable(), qualityScoreTable);
public void combine(final RecalibrationTables toMerge) {
if ( numTables() != toMerge.numTables() )
throw new IllegalArgumentException("Attempting to merge RecalibrationTables with different sizes");
for ( int i = 0; i < numTables(); i++ ) {
final NestedIntegerArray<RecalDatum> myTable = this.getTable(i);
final NestedIntegerArray<RecalDatum> otherTable = toMerge.getTable(i);
RecalUtils.combineTables(myTable, otherTable);
}
}
}