BQSR optimization: make RecalibrationTables thread-local, and merge results in onTraversalDone
-- With the newer, faster BQSR, scaling was limited by the NestedIntegerArray. The solution is to make the entire RecalibrationTables thread-local, so that each nct thread has its own data and there are no collisions between threads. -- Removed the previous partial solution of having a thread-local quality score table. -- Added a new argument -lowMemoryMode, which reduces memory usage in multi-threaded runs at the expense of threading efficiency.
This commit is contained in:
parent
1ba8d47a81
commit
7df47418d8
|
|
@ -25,35 +25,21 @@ package org.broadinstitute.sting.gatk.walkers.bqsr;
|
|||
* OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
import org.apache.log4j.Logger;
|
||||
import org.broadinstitute.sting.utils.classloader.ProtectedPackageSource;
|
||||
import org.broadinstitute.sting.utils.collections.NestedIntegerArray;
|
||||
import org.broadinstitute.sting.utils.recalibration.EventType;
|
||||
import org.broadinstitute.sting.utils.recalibration.ReadCovariates;
|
||||
import org.broadinstitute.sting.utils.recalibration.RecalDatum;
|
||||
import org.broadinstitute.sting.utils.recalibration.RecalibrationTables;
|
||||
import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
|
||||
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
|
||||
public class AdvancedRecalibrationEngine extends StandardRecalibrationEngine implements ProtectedPackageSource {
|
||||
private final static Logger logger = Logger.getLogger(AdvancedRecalibrationEngine.class);
|
||||
|
||||
final List<NestedIntegerArray<RecalDatum>> allThreadLocalQualityScoreTables = new LinkedList<NestedIntegerArray<RecalDatum>>();
|
||||
private ThreadLocal<NestedIntegerArray<RecalDatum>> threadLocalQualityScoreTables = new ThreadLocal<NestedIntegerArray<RecalDatum>>() {
|
||||
@Override
|
||||
protected synchronized NestedIntegerArray<RecalDatum> initialValue() {
|
||||
final NestedIntegerArray<RecalDatum> table = recalibrationTables.makeQualityScoreTable();
|
||||
allThreadLocalQualityScoreTables.add(table);
|
||||
return table;
|
||||
}
|
||||
};
|
||||
|
||||
@Override
|
||||
public void updateDataForRead( final ReadRecalibrationInfo recalInfo ) {
|
||||
final GATKSAMRecord read = recalInfo.getRead();
|
||||
final ReadCovariates readCovariates = recalInfo.getCovariatesValues();
|
||||
final NestedIntegerArray<RecalDatum> qualityScoreTable = getThreadLocalQualityScoreTable();
|
||||
final RecalibrationTables tables = getRecalibrationTables();
|
||||
final NestedIntegerArray<RecalDatum> qualityScoreTable = tables.getQualityScoreTable();
|
||||
|
||||
for( int offset = 0; offset < read.getReadBases().length; offset++ ) {
|
||||
if( ! recalInfo.skip(offset) ) {
|
||||
|
|
@ -70,30 +56,10 @@ public class AdvancedRecalibrationEngine extends StandardRecalibrationEngine imp
|
|||
if (keys[i] < 0)
|
||||
continue;
|
||||
|
||||
incrementDatumOrPutIfNecessary(recalibrationTables.getTable(i), qual, isError, keys[0], keys[1], keys[i], eventIndex);
|
||||
incrementDatumOrPutIfNecessary(tables.getTable(i), qual, isError, keys[0], keys[1], keys[i], eventIndex);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Get a NestedIntegerArray for a QualityScore table specific to this thread
|
||||
* @return a non-null NestedIntegerArray ready to be used to collect calibration info for the quality score covariate
|
||||
*/
|
||||
private NestedIntegerArray<RecalDatum> getThreadLocalQualityScoreTable() {
|
||||
return threadLocalQualityScoreTables.get();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void finalizeData() {
|
||||
// merge in all of the thread local tables
|
||||
logger.info("Merging " + allThreadLocalQualityScoreTables.size() + " thread-local quality score tables");
|
||||
for ( final NestedIntegerArray<RecalDatum> localTable : allThreadLocalQualityScoreTables ) {
|
||||
recalibrationTables.combineQualityScoreTable(localTable);
|
||||
}
|
||||
allThreadLocalQualityScoreTables.clear(); // cleanup after ourselves
|
||||
|
||||
super.finalizeData();
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -120,9 +120,16 @@ public class BaseRecalibrator extends ReadWalker<Long, Long> implements NanoSche
|
|||
@Argument(fullName = "bqsrBAQGapOpenPenalty", shortName="bqsrBAQGOP", doc="BQSR BAQ gap open penalty (Phred Scaled). Default value is 40. 30 is perhaps better for whole genome call sets", required = false)
|
||||
public double BAQGOP = BAQ.DEFAULT_GOP;
|
||||
|
||||
private QuantizationInfo quantizationInfo; // an object that keeps track of the information necessary for quality score quantization
|
||||
/**
|
||||
* When you have nct > 1, BQSR uses nct times more memory to compute its recalibration tables, for efficiency
|
||||
* purposes. If you have many covariates, and therefore are using a lot of memory, you can use this flag
|
||||
* to safely access only one table. There may be some CPU cost, but as long as the table is really big
|
||||
* there should be relatively little CPU costs.
|
||||
*/
|
||||
@Argument(fullName = "lowMemoryMode", shortName="lowMemoryMode", doc="Reduce memory usage in multi-threaded code at the expense of threading efficiency", required = false)
|
||||
public boolean lowMemoryMode = false;
|
||||
|
||||
private RecalibrationTables recalibrationTables;
|
||||
private QuantizationInfo quantizationInfo; // an object that keeps track of the information necessary for quality score quantization
|
||||
|
||||
private Covariate[] requestedCovariates; // list to hold the all the covariate objects that were requested (required + standard + experimental)
|
||||
|
||||
|
|
@ -130,8 +137,6 @@ public class BaseRecalibrator extends ReadWalker<Long, Long> implements NanoSche
|
|||
|
||||
private int minimumQToUse;
|
||||
|
||||
protected static final String COVARS_ATTRIBUTE = "COVARS"; // used to store covariates array as a temporary attribute inside GATKSAMRecord.\
|
||||
|
||||
private static final String NO_DBSNP_EXCEPTION = "This calculation is critically dependent on being able to skip over known variant sites. Please provide a VCF file containing known sites of genetic variation.";
|
||||
|
||||
private BAQ baq; // BAQ the reads on the fly to generate the alignment uncertainty vector
|
||||
|
|
@ -143,7 +148,6 @@ public class BaseRecalibrator extends ReadWalker<Long, Long> implements NanoSche
|
|||
* Based on the covariates' estimates for initial capacity allocate the data hashmap
|
||||
*/
|
||||
public void initialize() {
|
||||
|
||||
baq = new BAQ(BAQGOP); // setup the BAQ object with the provided gap open penalty
|
||||
|
||||
// check for unsupported access
|
||||
|
|
@ -188,10 +192,11 @@ public class BaseRecalibrator extends ReadWalker<Long, Long> implements NanoSche
|
|||
int numReadGroups = 0;
|
||||
for ( final SAMFileHeader header : getToolkit().getSAMFileHeaders() )
|
||||
numReadGroups += header.getReadGroups().size();
|
||||
recalibrationTables = new RecalibrationTables(requestedCovariates, numReadGroups, RAC.RECAL_TABLE_UPDATE_LOG);
|
||||
|
||||
recalibrationEngine = initializeRecalibrationEngine();
|
||||
recalibrationEngine.initialize(requestedCovariates, recalibrationTables);
|
||||
recalibrationEngine.initialize(requestedCovariates, numReadGroups, RAC.RECAL_TABLE_UPDATE_LOG);
|
||||
if ( lowMemoryMode )
|
||||
recalibrationEngine.enableLowMemoryMode();
|
||||
|
||||
minimumQToUse = getToolkit().getArguments().PRESERVE_QSCORES_LESS_THAN;
|
||||
referenceReader = getToolkit().getReferenceDataSource().getReference();
|
||||
|
|
@ -501,14 +506,18 @@ public class BaseRecalibrator extends ReadWalker<Long, Long> implements NanoSche
|
|||
logger.info("Processed: " + result + " reads");
|
||||
}
|
||||
|
||||
private RecalibrationTables getRecalibrationTable() {
|
||||
return recalibrationEngine.getFinalRecalibrationTables();
|
||||
}
|
||||
|
||||
private void generatePlots() {
|
||||
File recalFile = getToolkit().getArguments().BQSR_RECAL_FILE;
|
||||
if (recalFile != null) {
|
||||
RecalibrationReport report = new RecalibrationReport(recalFile);
|
||||
RecalUtils.generateRecalibrationPlot(RAC, report.getRecalibrationTables(), recalibrationTables, requestedCovariates);
|
||||
RecalUtils.generateRecalibrationPlot(RAC, report.getRecalibrationTables(), getRecalibrationTable(), requestedCovariates);
|
||||
}
|
||||
else
|
||||
RecalUtils.generateRecalibrationPlot(RAC, recalibrationTables, requestedCovariates);
|
||||
RecalUtils.generateRecalibrationPlot(RAC, getRecalibrationTable(), requestedCovariates);
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -517,10 +526,10 @@ public class BaseRecalibrator extends ReadWalker<Long, Long> implements NanoSche
|
|||
* generate a quantization map (recalibrated_qual -> quantized_qual)
|
||||
*/
|
||||
private void quantizeQualityScores() {
|
||||
quantizationInfo = new QuantizationInfo(recalibrationTables, RAC.QUANTIZING_LEVELS);
|
||||
quantizationInfo = new QuantizationInfo(getRecalibrationTable(), RAC.QUANTIZING_LEVELS);
|
||||
}
|
||||
|
||||
private void generateReport() {
|
||||
RecalUtils.outputRecalibrationReport(RAC, quantizationInfo, recalibrationTables, requestedCovariates, RAC.SORT_BY_ALL_COLUMNS);
|
||||
RecalUtils.outputRecalibrationReport(RAC, quantizationInfo, getRecalibrationTable(), requestedCovariates, RAC.SORT_BY_ALL_COLUMNS);
|
||||
}
|
||||
}
|
||||
|
|
@ -5,6 +5,8 @@ import org.broadinstitute.sting.utils.recalibration.RecalibrationTables;
|
|||
import org.broadinstitute.sting.utils.recalibration.covariates.Covariate;
|
||||
import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
|
||||
|
||||
import java.io.PrintStream;
|
||||
|
||||
/*
|
||||
* Copyright (c) 2009 The Broad Institute
|
||||
*
|
||||
|
|
@ -40,9 +42,10 @@ public interface RecalibrationEngine {
|
|||
* The engine should collect match and mismatch data into the recalibrationTables data.
|
||||
*
|
||||
* @param covariates an array of the covariates we'll be using in this engine, order matters
|
||||
* @param recalibrationTables the destination recalibrationTables where stats should be collected
|
||||
* @param numReadGroups the number of read groups we should use for the recalibration tables
|
||||
* @param maybeLogStream an optional print stream for logging calls to the nestedhashmap in the recalibration tables
|
||||
*/
|
||||
public void initialize(final Covariate[] covariates, final RecalibrationTables recalibrationTables);
|
||||
public void initialize(final Covariate[] covariates, final int numReadGroups, final PrintStream maybeLogStream);
|
||||
|
||||
/**
|
||||
* Update the recalibration statistics using the information in recalInfo
|
||||
|
|
@ -57,4 +60,8 @@ public interface RecalibrationEngine {
|
|||
* Called once after all calls to updateDataForRead have been issued.
|
||||
*/
|
||||
public void finalizeData();
|
||||
|
||||
public void enableLowMemoryMode();
|
||||
|
||||
public RecalibrationTables getFinalRecalibrationTables();
|
||||
}
|
||||
|
|
|
|||
|
|
@ -25,26 +25,64 @@ package org.broadinstitute.sting.gatk.walkers.bqsr;
|
|||
* OTHER DEALINGS IN THE SOFTWARE.
|
||||
*/
|
||||
|
||||
import com.google.java.contract.Requires;
|
||||
import org.apache.log4j.Logger;
|
||||
import org.broadinstitute.sting.utils.classloader.PublicPackageSource;
|
||||
import org.broadinstitute.sting.utils.collections.NestedIntegerArray;
|
||||
import org.broadinstitute.sting.utils.recalibration.EventType;
|
||||
import org.broadinstitute.sting.utils.recalibration.ReadCovariates;
|
||||
import org.broadinstitute.sting.utils.recalibration.RecalDatum;
|
||||
import org.broadinstitute.sting.utils.recalibration.RecalibrationTables;
|
||||
import org.broadinstitute.sting.utils.recalibration.*;
|
||||
import org.broadinstitute.sting.utils.recalibration.covariates.Covariate;
|
||||
import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
|
||||
|
||||
import java.io.PrintStream;
|
||||
import java.util.LinkedList;
|
||||
import java.util.List;
|
||||
|
||||
public class StandardRecalibrationEngine implements RecalibrationEngine, PublicPackageSource {
|
||||
private static final Logger logger = Logger.getLogger(StandardRecalibrationEngine.class);
|
||||
protected Covariate[] covariates;
|
||||
protected RecalibrationTables recalibrationTables;
|
||||
private int numReadGroups;
|
||||
private PrintStream maybeLogStream;
|
||||
private boolean lowMemoryMode = false;
|
||||
|
||||
private boolean finalized = false;
|
||||
private RecalibrationTables mergedRecalibrationTables = null;
|
||||
|
||||
private final List<RecalibrationTables> recalibrationTablesList = new LinkedList<RecalibrationTables>();
|
||||
|
||||
private final ThreadLocal<RecalibrationTables> threadLocalTables = new ThreadLocal<RecalibrationTables>() {
|
||||
private synchronized RecalibrationTables makeAndCaptureTable() {
|
||||
logger.info("Creating RecalibrationTable for " + Thread.currentThread());
|
||||
final RecalibrationTables newTable = new RecalibrationTables(covariates, numReadGroups, maybeLogStream);
|
||||
recalibrationTablesList.add(newTable);
|
||||
return newTable;
|
||||
}
|
||||
|
||||
@Override
|
||||
protected synchronized RecalibrationTables initialValue() {
|
||||
if ( lowMemoryMode ) {
|
||||
return recalibrationTablesList.isEmpty() ? makeAndCaptureTable() : recalibrationTablesList.get(0);
|
||||
} else {
|
||||
return makeAndCaptureTable();
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
protected RecalibrationTables getRecalibrationTables() {
|
||||
return threadLocalTables.get();
|
||||
}
|
||||
|
||||
public void enableLowMemoryMode() {
|
||||
this.lowMemoryMode = true;
|
||||
}
|
||||
|
||||
@Override
|
||||
public void initialize(final Covariate[] covariates, final RecalibrationTables recalibrationTables) {
|
||||
public void initialize(final Covariate[] covariates, final int numReadGroups, final PrintStream maybeLogStream) {
|
||||
if ( covariates == null ) throw new IllegalArgumentException("Covariates cannot be null");
|
||||
if ( recalibrationTables == null ) throw new IllegalArgumentException("recalibrationTables cannot be null");
|
||||
if ( numReadGroups < 1 ) throw new IllegalArgumentException("numReadGroups must be >= 1 but got " + numReadGroups);
|
||||
|
||||
this.covariates = covariates.clone();
|
||||
this.recalibrationTables = recalibrationTables;
|
||||
this.numReadGroups = numReadGroups;
|
||||
this.maybeLogStream = maybeLogStream;
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -59,13 +97,13 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP
|
|||
final double isError = recalInfo.getErrorFraction(eventType, offset);
|
||||
final int[] keys = readCovariates.getKeySet(offset, eventType);
|
||||
|
||||
incrementDatumOrPutIfNecessary(recalibrationTables.getQualityScoreTable(), qual, isError, keys[0], keys[1], eventType.index);
|
||||
incrementDatumOrPutIfNecessary(getRecalibrationTables().getQualityScoreTable(), qual, isError, keys[0], keys[1], eventType.index);
|
||||
|
||||
for (int i = 2; i < covariates.length; i++) {
|
||||
if (keys[i] < 0)
|
||||
continue;
|
||||
|
||||
incrementDatumOrPutIfNecessary(recalibrationTables.getTable(i), qual, isError, keys[0], keys[1], keys[i], eventType.index);
|
||||
incrementDatumOrPutIfNecessary(getRecalibrationTables().getTable(i), qual, isError, keys[0], keys[1], keys[i], eventType.index);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -90,8 +128,13 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP
|
|||
*/
|
||||
@Override
|
||||
public void finalizeData() {
|
||||
final NestedIntegerArray<RecalDatum> byReadGroupTable = recalibrationTables.getReadGroupTable();
|
||||
final NestedIntegerArray<RecalDatum> byQualTable = recalibrationTables.getQualityScoreTable();
|
||||
if ( finalized ) throw new IllegalStateException("FinalizeData() has already been called");
|
||||
|
||||
// merge all of the thread-local tables
|
||||
mergedRecalibrationTables = mergeThreadLocalRecalibrationTables();
|
||||
|
||||
final NestedIntegerArray<RecalDatum> byReadGroupTable = mergedRecalibrationTables.getReadGroupTable();
|
||||
final NestedIntegerArray<RecalDatum> byQualTable = mergedRecalibrationTables.getQualityScoreTable();
|
||||
|
||||
// iterate over all values in the qual table
|
||||
for ( NestedIntegerArray.Leaf<RecalDatum> leaf : byQualTable.getAllLeaves() ) {
|
||||
|
|
@ -108,6 +151,38 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP
|
|||
rgDatum.combine(qualDatum);
|
||||
}
|
||||
}
|
||||
|
||||
finalized = true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Merge all of the thread local recalibration tables into a single one.
|
||||
*
|
||||
* Reuses one of the recalibration tables to hold the merged table, so this function can only be
|
||||
* called once in the engine.
|
||||
*
|
||||
* @return the merged recalibration table
|
||||
*/
|
||||
@Requires("! finalized")
|
||||
private RecalibrationTables mergeThreadLocalRecalibrationTables() {
|
||||
if ( recalibrationTablesList.isEmpty() ) throw new IllegalStateException("recalibration tables list is empty");
|
||||
|
||||
RecalibrationTables merged = null;
|
||||
for ( final RecalibrationTables table : recalibrationTablesList ) {
|
||||
if ( merged == null )
|
||||
// fast path -- if there's only only one table, so just make it the merged one
|
||||
merged = table;
|
||||
else {
|
||||
merged.combine(table);
|
||||
}
|
||||
}
|
||||
|
||||
return merged;
|
||||
}
|
||||
|
||||
public RecalibrationTables getFinalRecalibrationTables() {
|
||||
if ( ! finalized ) throw new IllegalStateException("Cannot get final recalibration tables until finalizeData() has been called");
|
||||
return mergedRecalibrationTables;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -123,12 +123,16 @@ public final class RecalibrationTables {
|
|||
}
|
||||
|
||||
/**
|
||||
* Merge in the quality score table information from qualityScoreTable into this
|
||||
* recalibration table's quality score table.
|
||||
*
|
||||
* @param qualityScoreTable the quality score table we want to merge in
|
||||
* Merge all of the tables from toMerge into this set of tables
|
||||
*/
|
||||
public void combineQualityScoreTable(final NestedIntegerArray<RecalDatum> qualityScoreTable) {
|
||||
RecalUtils.combineTables(getQualityScoreTable(), qualityScoreTable);
|
||||
public void combine(final RecalibrationTables toMerge) {
|
||||
if ( numTables() != toMerge.numTables() )
|
||||
throw new IllegalArgumentException("Attempting to merge RecalibrationTables with different sizes");
|
||||
|
||||
for ( int i = 0; i < numTables(); i++ ) {
|
||||
final NestedIntegerArray<RecalDatum> myTable = this.getTable(i);
|
||||
final NestedIntegerArray<RecalDatum> otherTable = toMerge.getTable(i);
|
||||
RecalUtils.combineTables(myTable, otherTable);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue