diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/AdvancedRecalibrationEngine.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/AdvancedRecalibrationEngine.java index 255f1fd05..3871101eb 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/AdvancedRecalibrationEngine.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/AdvancedRecalibrationEngine.java @@ -25,35 +25,21 @@ package org.broadinstitute.sting.gatk.walkers.bqsr; * OTHER DEALINGS IN THE SOFTWARE. */ -import org.apache.log4j.Logger; import org.broadinstitute.sting.utils.classloader.ProtectedPackageSource; import org.broadinstitute.sting.utils.collections.NestedIntegerArray; import org.broadinstitute.sting.utils.recalibration.EventType; import org.broadinstitute.sting.utils.recalibration.ReadCovariates; import org.broadinstitute.sting.utils.recalibration.RecalDatum; +import org.broadinstitute.sting.utils.recalibration.RecalibrationTables; import org.broadinstitute.sting.utils.sam.GATKSAMRecord; -import java.util.LinkedList; -import java.util.List; - public class AdvancedRecalibrationEngine extends StandardRecalibrationEngine implements ProtectedPackageSource { - private final static Logger logger = Logger.getLogger(AdvancedRecalibrationEngine.class); - - final List> allThreadLocalQualityScoreTables = new LinkedList>(); - private ThreadLocal> threadLocalQualityScoreTables = new ThreadLocal>() { - @Override - protected synchronized NestedIntegerArray initialValue() { - final NestedIntegerArray table = recalibrationTables.makeQualityScoreTable(); - allThreadLocalQualityScoreTables.add(table); - return table; - } - }; - @Override public void updateDataForRead( final ReadRecalibrationInfo recalInfo ) { final GATKSAMRecord read = recalInfo.getRead(); final ReadCovariates readCovariates = recalInfo.getCovariatesValues(); - final NestedIntegerArray qualityScoreTable = getThreadLocalQualityScoreTable(); + final RecalibrationTables 
tables = getRecalibrationTables(); + final NestedIntegerArray qualityScoreTable = tables.getQualityScoreTable(); for( int offset = 0; offset < read.getReadBases().length; offset++ ) { if( ! recalInfo.skip(offset) ) { @@ -70,30 +56,10 @@ public class AdvancedRecalibrationEngine extends StandardRecalibrationEngine imp if (keys[i] < 0) continue; - incrementDatumOrPutIfNecessary(recalibrationTables.getTable(i), qual, isError, keys[0], keys[1], keys[i], eventIndex); + incrementDatumOrPutIfNecessary(tables.getTable(i), qual, isError, keys[0], keys[1], keys[i], eventIndex); } } } } } - - /** - * Get a NestedIntegerArray for a QualityScore table specific to this thread - * @return a non-null NestedIntegerArray ready to be used to collect calibration info for the quality score covariate - */ - private NestedIntegerArray getThreadLocalQualityScoreTable() { - return threadLocalQualityScoreTables.get(); - } - - @Override - public void finalizeData() { - // merge in all of the thread local tables - logger.info("Merging " + allThreadLocalQualityScoreTables.size() + " thread-local quality score tables"); - for ( final NestedIntegerArray localTable : allThreadLocalQualityScoreTables ) { - recalibrationTables.combineQualityScoreTable(localTable); - } - allThreadLocalQualityScoreTables.clear(); // cleanup after ourselves - - super.finalizeData(); - } } diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BaseRecalibrator.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BaseRecalibrator.java index 7692c58e2..ffcfd6233 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BaseRecalibrator.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BaseRecalibrator.java @@ -120,9 +120,16 @@ public class BaseRecalibrator extends ReadWalker implements NanoSche @Argument(fullName = "bqsrBAQGapOpenPenalty", shortName="bqsrBAQGOP", doc="BQSR BAQ gap open penalty (Phred Scaled). Default value is 40. 
30 is perhaps better for whole genome call sets", required = false) public double BAQGOP = BAQ.DEFAULT_GOP; - private QuantizationInfo quantizationInfo; // an object that keeps track of the information necessary for quality score quantization + /** + * When you have nct > 1, BQSR uses nct times more memory to compute its recalibration tables, for efficiency + * purposes. If you have many covariates, and therefore are using a lot of memory, you can use this flag + * to safely access only one table. There may be some CPU cost, but as long as the table is really big + * there should be relatively little CPU costs. + */ + @Argument(fullName = "lowMemoryMode", shortName="lowMemoryMode", doc="Reduce memory usage in multi-threaded code at the expense of threading efficiency", required = false) + public boolean lowMemoryMode = false; - private RecalibrationTables recalibrationTables; + private QuantizationInfo quantizationInfo; // an object that keeps track of the information necessary for quality score quantization private Covariate[] requestedCovariates; // list to hold the all the covariate objects that were requested (required + standard + experimental) @@ -130,8 +137,6 @@ public class BaseRecalibrator extends ReadWalker implements NanoSche private int minimumQToUse; - protected static final String COVARS_ATTRIBUTE = "COVARS"; // used to store covariates array as a temporary attribute inside GATKSAMRecord.\ - private static final String NO_DBSNP_EXCEPTION = "This calculation is critically dependent on being able to skip over known variant sites. 
Please provide a VCF file containing known sites of genetic variation."; private BAQ baq; // BAQ the reads on the fly to generate the alignment uncertainty vector @@ -143,7 +148,6 @@ public class BaseRecalibrator extends ReadWalker implements NanoSche * Based on the covariates' estimates for initial capacity allocate the data hashmap */ public void initialize() { - baq = new BAQ(BAQGOP); // setup the BAQ object with the provided gap open penalty // check for unsupported access @@ -188,10 +192,11 @@ public class BaseRecalibrator extends ReadWalker implements NanoSche int numReadGroups = 0; for ( final SAMFileHeader header : getToolkit().getSAMFileHeaders() ) numReadGroups += header.getReadGroups().size(); - recalibrationTables = new RecalibrationTables(requestedCovariates, numReadGroups, RAC.RECAL_TABLE_UPDATE_LOG); recalibrationEngine = initializeRecalibrationEngine(); - recalibrationEngine.initialize(requestedCovariates, recalibrationTables); + recalibrationEngine.initialize(requestedCovariates, numReadGroups, RAC.RECAL_TABLE_UPDATE_LOG); + if ( lowMemoryMode ) + recalibrationEngine.enableLowMemoryMode(); minimumQToUse = getToolkit().getArguments().PRESERVE_QSCORES_LESS_THAN; referenceReader = getToolkit().getReferenceDataSource().getReference(); @@ -501,14 +506,18 @@ public class BaseRecalibrator extends ReadWalker implements NanoSche logger.info("Processed: " + result + " reads"); } + private RecalibrationTables getRecalibrationTable() { + return recalibrationEngine.getFinalRecalibrationTables(); + } + private void generatePlots() { File recalFile = getToolkit().getArguments().BQSR_RECAL_FILE; if (recalFile != null) { RecalibrationReport report = new RecalibrationReport(recalFile); - RecalUtils.generateRecalibrationPlot(RAC, report.getRecalibrationTables(), recalibrationTables, requestedCovariates); + RecalUtils.generateRecalibrationPlot(RAC, report.getRecalibrationTables(), getRecalibrationTable(), requestedCovariates); } else - 
RecalUtils.generateRecalibrationPlot(RAC, recalibrationTables, requestedCovariates); + RecalUtils.generateRecalibrationPlot(RAC, getRecalibrationTable(), requestedCovariates); } /** @@ -517,10 +526,10 @@ public class BaseRecalibrator extends ReadWalker implements NanoSche * generate a quantization map (recalibrated_qual -> quantized_qual) */ private void quantizeQualityScores() { - quantizationInfo = new QuantizationInfo(recalibrationTables, RAC.QUANTIZING_LEVELS); + quantizationInfo = new QuantizationInfo(getRecalibrationTable(), RAC.QUANTIZING_LEVELS); } private void generateReport() { - RecalUtils.outputRecalibrationReport(RAC, quantizationInfo, recalibrationTables, requestedCovariates, RAC.SORT_BY_ALL_COLUMNS); + RecalUtils.outputRecalibrationReport(RAC, quantizationInfo, getRecalibrationTable(), requestedCovariates, RAC.SORT_BY_ALL_COLUMNS); } } \ No newline at end of file diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/RecalibrationEngine.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/RecalibrationEngine.java index 5c002b7e5..6c3189be5 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/RecalibrationEngine.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/RecalibrationEngine.java @@ -5,6 +5,8 @@ import org.broadinstitute.sting.utils.recalibration.RecalibrationTables; import org.broadinstitute.sting.utils.recalibration.covariates.Covariate; import org.broadinstitute.sting.utils.sam.GATKSAMRecord; +import java.io.PrintStream; + /* * Copyright (c) 2009 The Broad Institute * @@ -40,9 +42,10 @@ public interface RecalibrationEngine { * The engine should collect match and mismatch data into the recalibrationTables data. 
* * @param covariates an array of the covariates we'll be using in this engine, order matters - * @param recalibrationTables the destination recalibrationTables where stats should be collected + * @param numReadGroups the number of read groups we should use for the recalibration tables + * @param maybeLogStream an optional print stream for logging calls to the nestedhashmap in the recalibration tables */ - public void initialize(final Covariate[] covariates, final RecalibrationTables recalibrationTables); + public void initialize(final Covariate[] covariates, final int numReadGroups, final PrintStream maybeLogStream); /** * Update the recalibration statistics using the information in recalInfo @@ -57,4 +60,8 @@ public interface RecalibrationEngine { * Called once after all calls to updateDataForRead have been issued. */ public void finalizeData(); + + public void enableLowMemoryMode(); + + public RecalibrationTables getFinalRecalibrationTables(); } diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/StandardRecalibrationEngine.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/StandardRecalibrationEngine.java index a6ab98e8b..0cd042eeb 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/StandardRecalibrationEngine.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/StandardRecalibrationEngine.java @@ -25,26 +25,64 @@ package org.broadinstitute.sting.gatk.walkers.bqsr; * OTHER DEALINGS IN THE SOFTWARE. 
*/ +import com.google.java.contract.Requires; +import org.apache.log4j.Logger; import org.broadinstitute.sting.utils.classloader.PublicPackageSource; import org.broadinstitute.sting.utils.collections.NestedIntegerArray; -import org.broadinstitute.sting.utils.recalibration.EventType; -import org.broadinstitute.sting.utils.recalibration.ReadCovariates; -import org.broadinstitute.sting.utils.recalibration.RecalDatum; -import org.broadinstitute.sting.utils.recalibration.RecalibrationTables; +import org.broadinstitute.sting.utils.recalibration.*; import org.broadinstitute.sting.utils.recalibration.covariates.Covariate; import org.broadinstitute.sting.utils.sam.GATKSAMRecord; +import java.io.PrintStream; +import java.util.LinkedList; +import java.util.List; + public class StandardRecalibrationEngine implements RecalibrationEngine, PublicPackageSource { + private static final Logger logger = Logger.getLogger(StandardRecalibrationEngine.class); protected Covariate[] covariates; - protected RecalibrationTables recalibrationTables; + private int numReadGroups; + private PrintStream maybeLogStream; + private boolean lowMemoryMode = false; + + private boolean finalized = false; + private RecalibrationTables mergedRecalibrationTables = null; + + private final List recalibrationTablesList = new LinkedList(); + + private final ThreadLocal threadLocalTables = new ThreadLocal() { + private synchronized RecalibrationTables makeAndCaptureTable() { + logger.info("Creating RecalibrationTable for " + Thread.currentThread()); + final RecalibrationTables newTable = new RecalibrationTables(covariates, numReadGroups, maybeLogStream); + recalibrationTablesList.add(newTable); + return newTable; + } + + @Override + protected synchronized RecalibrationTables initialValue() { + if ( lowMemoryMode ) { + return recalibrationTablesList.isEmpty() ? 
makeAndCaptureTable() : recalibrationTablesList.get(0); + } else { + return makeAndCaptureTable(); + } + } + }; + + protected RecalibrationTables getRecalibrationTables() { + return threadLocalTables.get(); + } + + public void enableLowMemoryMode() { + this.lowMemoryMode = true; + } @Override - public void initialize(final Covariate[] covariates, final RecalibrationTables recalibrationTables) { + public void initialize(final Covariate[] covariates, final int numReadGroups, final PrintStream maybeLogStream) { if ( covariates == null ) throw new IllegalArgumentException("Covariates cannot be null"); - if ( recalibrationTables == null ) throw new IllegalArgumentException("recalibrationTables cannot be null"); + if ( numReadGroups < 1 ) throw new IllegalArgumentException("numReadGroups must be >= 1 but got " + numReadGroups); this.covariates = covariates.clone(); - this.recalibrationTables = recalibrationTables; + this.numReadGroups = numReadGroups; + this.maybeLogStream = maybeLogStream; } @Override @@ -59,13 +97,13 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP final double isError = recalInfo.getErrorFraction(eventType, offset); final int[] keys = readCovariates.getKeySet(offset, eventType); - incrementDatumOrPutIfNecessary(recalibrationTables.getQualityScoreTable(), qual, isError, keys[0], keys[1], eventType.index); + incrementDatumOrPutIfNecessary(getRecalibrationTables().getQualityScoreTable(), qual, isError, keys[0], keys[1], eventType.index); for (int i = 2; i < covariates.length; i++) { if (keys[i] < 0) continue; - incrementDatumOrPutIfNecessary(recalibrationTables.getTable(i), qual, isError, keys[0], keys[1], keys[i], eventType.index); + incrementDatumOrPutIfNecessary(getRecalibrationTables().getTable(i), qual, isError, keys[0], keys[1], keys[i], eventType.index); } } } @@ -90,8 +128,13 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP */ @Override public void finalizeData() { - final 
NestedIntegerArray byReadGroupTable = recalibrationTables.getReadGroupTable(); - final NestedIntegerArray byQualTable = recalibrationTables.getQualityScoreTable(); + if ( finalized ) throw new IllegalStateException("FinalizeData() has already been called"); + + // merge all of the thread-local tables + mergedRecalibrationTables = mergeThreadLocalRecalibrationTables(); + + final NestedIntegerArray byReadGroupTable = mergedRecalibrationTables.getReadGroupTable(); + final NestedIntegerArray byQualTable = mergedRecalibrationTables.getQualityScoreTable(); // iterate over all values in the qual table for ( NestedIntegerArray.Leaf leaf : byQualTable.getAllLeaves() ) { @@ -108,6 +151,38 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP rgDatum.combine(qualDatum); } } + + finalized = true; + } + + /** + * Merge all of the thread local recalibration tables into a single one. + * + * Reuses one of the recalibration tables to hold the merged table, so this function can only be + * called once in the engine. + * + * @return the merged recalibration table + */ + @Requires("! finalized") + private RecalibrationTables mergeThreadLocalRecalibrationTables() { + if ( recalibrationTablesList.isEmpty() ) throw new IllegalStateException("recalibration tables list is empty"); + + RecalibrationTables merged = null; + for ( final RecalibrationTables table : recalibrationTablesList ) { + if ( merged == null ) + // fast path -- if there's only one table, just make it the merged one + merged = table; + else { + merged.combine(table); + } + } + + return merged; + } + + public RecalibrationTables getFinalRecalibrationTables() { + if ( ! 
finalized ) throw new IllegalStateException("Cannot get final recalibration tables until finalizeData() has been called"); + return mergedRecalibrationTables; } /** diff --git a/public/java/src/org/broadinstitute/sting/utils/recalibration/RecalibrationTables.java b/public/java/src/org/broadinstitute/sting/utils/recalibration/RecalibrationTables.java index 3f968d7f6..a6b1e13b9 100644 --- a/public/java/src/org/broadinstitute/sting/utils/recalibration/RecalibrationTables.java +++ b/public/java/src/org/broadinstitute/sting/utils/recalibration/RecalibrationTables.java @@ -123,12 +123,16 @@ public final class RecalibrationTables { } /** - * Merge in the quality score table information from qualityScoreTable into this - * recalibration table's quality score table. - * - * @param qualityScoreTable the quality score table we want to merge in + * Merge all of the tables from toMerge into this set of tables */ - public void combineQualityScoreTable(final NestedIntegerArray qualityScoreTable) { - RecalUtils.combineTables(getQualityScoreTable(), qualityScoreTable); + public void combine(final RecalibrationTables toMerge) { + if ( numTables() != toMerge.numTables() ) + throw new IllegalArgumentException("Attempting to merge RecalibrationTables with different sizes"); + + for ( int i = 0; i < numTables(); i++ ) { + final NestedIntegerArray myTable = this.getTable(i); + final NestedIntegerArray otherTable = toMerge.getTable(i); + RecalUtils.combineTables(myTable, otherTable); + } } }