From c47a5ff5abc0ea90db23155866db59e6cbb27095 Mon Sep 17 00:00:00 2001 From: depristo Date: Mon, 19 Jul 2010 22:13:18 +0000 Subject: [PATCH] Official parallel CountCovariates, passes all integration tests. Now poster-child example of parallelism in GATK (Matt H). Apparent general performance improvements throughout too. git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@3833 348d0f76-0448-11de-a6fe-93d51630548a --- .../recalibration/CovariateCounterWalker.java | 185 ++++++++++-------- .../utils/collections/NestedHashMap.java | 15 +- .../RecalibrationWalkersIntegrationTest.java | 2 +- 3 files changed, 113 insertions(+), 89 deletions(-) diff --git a/java/src/org/broadinstitute/sting/gatk/walkers/recalibration/CovariateCounterWalker.java b/java/src/org/broadinstitute/sting/gatk/walkers/recalibration/CovariateCounterWalker.java index 060404343..fcdd6519d 100755 --- a/java/src/org/broadinstitute/sting/gatk/walkers/recalibration/CovariateCounterWalker.java +++ b/java/src/org/broadinstitute/sting/gatk/walkers/recalibration/CovariateCounterWalker.java @@ -70,7 +70,7 @@ import java.util.Map; @WalkerName( "CountCovariates" ) @ReadFilters( {ZeroMappingQualityReadFilter.class} ) // Filter out all reads with zero mapping quality @Requires( {DataSource.READS, DataSource.REFERENCE, DataSource.REFERENCE_BASES} ) // This walker requires both -I input.bam and -R reference.fasta -public class CovariateCounterWalker extends LocusWalker implements TreeReducible { +public class CovariateCounterWalker extends LocusWalker implements TreeReducible { ///////////////////////////// // Constants @@ -96,8 +96,6 @@ public class CovariateCounterWalker extends LocusWalker im private String[] COVARIATES = null; @Argument(fullName="standard_covs", shortName="standard", doc="Use the standard set of covariates in addition to the ones listed using the -cov argument", required=false) private boolean USE_STANDARD_COVARIATES = false; - @Argument(fullName="process_nth_locus", shortName="pN", required=false, doc="Only process every Nth covered locus we see.") - private int PROCESS_EVERY_NTH_LOCUS = 1; ///////////////////////////// // Debugging-only Arguments @@ -110,17 +108,39 @@ public class CovariateCounterWalker extends LocusWalker im ///////////////////////////// private final RecalDataManager dataManager = new RecalDataManager(); // Holds the data HashMap, mostly used by TableRecalibrationWalker to create collapsed data hashmaps private final ArrayList requestedCovariates = new ArrayList(); // A list to hold the covariate objects that were requested - private long countedSites = 0; // Number of loci used in the calculations, used for reporting in the output file - private long countedBases = 0; // Number of bases used in the calculations, used for reporting in the output file - private long skippedSites = 0; // Number of loci skipped because it was a dbSNP site, used for reporting in the output file - private long solidInsertedReferenceBases = 0; // Number of bases where we believe SOLID has inserted the reference because the color space is inconsistent with the read base - private long otherColorSpaceInconsistency = 0; // Number of bases where the color space is inconsistent with the read but the reference wasn't inserted. - private int numUnprocessed = 0; // Number of consecutive loci skipped because we are only processing every Nth site - private final Pair dbSNP_counts = new Pair(0L, 0L); // mismatch/base counts for dbSNP loci - private final Pair novel_counts = new Pair(0L, 0L); // mismatch/base counts for non-dbSNP loci private static final double DBSNP_VS_NOVEL_MISMATCH_RATE = 2.0; // rate at which dbSNP sites (on an individual level) mismatch relative to novel sites (determined by looking at NA12878) - private int DBSNP_VALIDATION_CHECK_FREQUENCY = 1000000; // how often to validate dbsnp mismatch rate (in terms of loci seen) - private int lociSinceLastDbsnpCheck = 0; // loci since last dbsnp validation + private static int DBSNP_VALIDATION_CHECK_FREQUENCY = 1000000; // how often to validate dbsnp mismatch rate (in terms of loci seen) + + public static class CountedData { + private long countedSites = 0; // Number of loci used in the calculations, used for reporting in the output file + private long countedBases = 0; // Number of bases used in the calculations, used for reporting in the output file + private long skippedSites = 0; // Number of loci skipped because it was a dbSNP site, used for reporting in the output file + private long solidInsertedReferenceBases = 0; // Number of bases where we believe SOLID has inserted the reference because the color space is inconsistent with the read base + private long otherColorSpaceInconsistency = 0; // Number of bases where the color space is inconsistent with the read but the reference wasn't inserted. + + private long dbSNPCountsMM = 0, dbSNPCountsBases = 0; // mismatch/base counts for dbSNP loci + private long novelCountsMM = 0, novelCountsBases = 0; // mismatch/base counts for non-dbSNP loci + private int lociSinceLastDbsnpCheck = 0; // loci since last dbsnp validation + + /** + * Adds the values of other to this, returning this + * @param other + * @return this object + */ + public CountedData add(CountedData other) { + countedSites += other.countedSites; + countedBases += other.countedBases; + skippedSites += other.skippedSites; + solidInsertedReferenceBases += other.solidInsertedReferenceBases; + otherColorSpaceInconsistency += other.otherColorSpaceInconsistency; + dbSNPCountsMM += other.dbSNPCountsMM; + dbSNPCountsBases += other.dbSNPCountsBases; + novelCountsMM += other.novelCountsMM; + novelCountsBases += other.novelCountsBases; + lociSinceLastDbsnpCheck += other.lociSinceLastDbsnpCheck; + return this; + } + } //--------------------------------------------------------------------------------------------------------------- // @@ -136,7 +156,6 @@ public class CovariateCounterWalker extends LocusWalker im if( RAC.FORCE_READ_GROUP != null ) { RAC.DEFAULT_READ_GROUP = RAC.FORCE_READ_GROUP; } if( RAC.FORCE_PLATFORM != null ) { RAC.DEFAULT_PLATFORM = RAC.FORCE_PLATFORM; } - DBSNP_VALIDATION_CHECK_FREQUENCY *= PROCESS_EVERY_NTH_LOCUS; // Get a list of all available covariates final List> covariateClasses = PackageUtils.getClassesImplementingInterface( Covariate.class ); @@ -253,7 +272,7 @@ public class CovariateCounterWalker extends LocusWalker im * @param context The alignment context * @return Returns 1, but this value isn't used in the reduce step */ - public Integer map( RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context ) { + public CountedData map( RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context ) { // Pull out data for this locus for all the input RODs and check if this is a known variant site in any of them boolean isSNP = false; @@ -266,9 +285,8 @@ public class CovariateCounterWalker extends LocusWalker im // Only use data from non-dbsnp sites // Assume every mismatch at a non-dbsnp site is indicative of poor quality - if( !isSNP && ( ++numUnprocessed >= PROCESS_EVERY_NTH_LOCUS ) ) { - numUnprocessed = 0; // Reset the counter because we are processing this very locus - + CountedData counter = new CountedData(); + if( !isSNP ) { // For each read at this locus for( PileupElement p : context.getBasePileup() ) { GATKSAMRecord gatkRead = (GATKSAMRecord) p.getRead(); @@ -309,33 +327,25 @@ public class CovariateCounterWalker extends LocusWalker im !RecalDataManager.isInconsistentColorSpace( gatkRead, offset ) ) { // This base finally passed all the checks for a good base, so add it to the big data hashmap - updateDataFromRead( gatkRead, offset, refBase ); + updateDataFromRead( counter, gatkRead, offset, refBase ); } else { // calculate SOLID reference insertion rate if( refBase == bases[offset] ) { - solidInsertedReferenceBases++; + counter.solidInsertedReferenceBases++; } else { - otherColorSpaceInconsistency++; + counter.otherColorSpaceInconsistency++; } } } } } - countedSites++; + counter.countedSites++; } else { // We skipped over the dbSNP site, and we are only processing every Nth locus - skippedSites++; - if( isSNP ) { - updateMismatchCounts(dbSNP_counts, context, ref.getBase()); // For sanity check to ensure novel mismatch rate vs dnsnp mismatch rate is reasonable - } + counter.skippedSites++; + updateMismatchCounts(counter, context, ref.getBase()); // For sanity check to ensure novel mismatch rate vs dnsnp mismatch rate is reasonable } - // Do a dbSNP sanity check every so often - if( ++lociSinceLastDbsnpCheck == DBSNP_VALIDATION_CHECK_FREQUENCY ) { - lociSinceLastDbsnpCheck = 0; - validateDbsnpMismatchRate(); - } - - return 1; // This value isn't actually used anywhere + return counter; // This value isn't actually used anywhere } @@ -343,11 +353,11 @@ public class CovariateCounterWalker extends LocusWalker im /** * Update the mismatch / total_base counts for a given class of loci. * - * @param counts The counts to be updated + * @param counter The CountedData to be updated * @param context The AlignmentContext which holds the reads covered by this locus * @param refBase The reference base */ - private static void updateMismatchCounts(final Pair counts, final AlignmentContext context, final byte refBase) { + private static void updateMismatchCounts(CountedData counter, final AlignmentContext context, final byte refBase) { for( PileupElement p : context.getBasePileup() ) { final byte readBase = p.getBase(); final int readBaseIndex = BaseUtils.simpleBaseToBaseIndex(readBase); @@ -355,31 +365,13 @@ public class CovariateCounterWalker extends LocusWalker im if( readBaseIndex != -1 && refBaseIndex != -1 ) { if( readBaseIndex != refBaseIndex ) { - counts.first++; + counter.novelCountsMM++; } - counts.second++; + counter.novelCountsBases++; } } } - /** - * Validate the dbSNP reference mismatch rates. - */ - private void validateDbsnpMismatchRate() { - if( novel_counts.second == 0L || dbSNP_counts.second == 0L ) { - return; - } - - final double fractionMM_novel = (double)novel_counts.first / (double)novel_counts.second; - final double fractionMM_dbsnp = (double)dbSNP_counts.first / (double)dbSNP_counts.second; - - if( fractionMM_dbsnp < DBSNP_VS_NOVEL_MISMATCH_RATE * fractionMM_novel ) { - Utils.warnUser("The variation rate at the supplied list of known variant sites seems suspiciously low. Please double-check that the correct ROD is being used. " + - String.format("[dbSNP variation rate = %.4f, novel variation rate = %.4f]", fractionMM_dbsnp, fractionMM_novel) ); - DBSNP_VALIDATION_CHECK_FREQUENCY *= 2; // Don't annoyingly output the warning message every megabase of a large file - } - } - /** * Major workhorse routine for this walker. * Loop through the list of requested covariates and pick out the value from the read, offset, and reference @@ -391,7 +383,7 @@ public class CovariateCounterWalker extends LocusWalker im * @param offset The offset in the read for this locus * @param refBase The reference base at this locus */ - private void updateDataFromRead(final GATKSAMRecord gatkRead, final int offset, final byte refBase) { + private void updateDataFromRead(CountedData counter, final GATKSAMRecord gatkRead, final int offset, final byte refBase) { final Object[][] covars = (Comparable[][]) gatkRead.getTemporaryAttribute(COVARS_ATTRIBUTE); final Object[] key = covars[offset]; @@ -399,8 +391,8 @@ public class CovariateCounterWalker extends LocusWalker im final NestedHashMap data = dataManager.data; //optimization - create local reference RecalDatumOptimized datum = (RecalDatumOptimized) data.get( key ); if( datum == null ) { // key doesn't exist yet in the map so make a new bucket and add it - datum = new RecalDatumOptimized(); // initialized with zeros, will be incremented at end of method - data.put( datum, (Object[])key ); + // initialized with zeros, will be incremented at end of method + datum = (RecalDatumOptimized)data.put( new RecalDatumOptimized(), true, (Object[])key ); } // Need the bases to determine whether or not we have a mismatch @@ -409,9 +401,9 @@ public class CovariateCounterWalker extends LocusWalker im // Add one to the number of observations and potentially one to the number of mismatches datum.incrementBaseCounts( base, refBase ); - countedBases++; - novel_counts.second++; - novel_counts.first += datum.getNumMismatches() - curMismatches; // For sanity check to ensure novel mismatch rate vs dnsnp mismatch rate is reasonable + counter.countedBases++; + counter.novelCountsBases++; + counter.novelCountsMM += datum.getNumMismatches() - curMismatches; // For sanity check to ensure novel mismatch rate vs dnsnp mismatch rate is reasonable } @@ -422,56 +414,75 @@ public class CovariateCounterWalker extends LocusWalker im //--------------------------------------------------------------------------------------------------------------- /** - * Initialize the reudce step by creating a PrintStream from the filename specified as an argument to the walker. + * Initialize the reduce step by creating a PrintStream from the filename specified as an argument to the walker. * @return returns A PrintStream created from the -recalFile filename argument specified to the walker */ - public PrintStream reduceInit() { - return RECAL_FILE; + public CountedData reduceInit() { + return new CountedData(); } /** * The Reduce method doesn't do anything for this walker. - * @param value Result of the map. This value is immediately ignored. - * @param recalTableStream The PrintStream used to output the CSV data - * @return returns The PrintStream used to output the CSV data + * @param mapped Result of the map. This value is immediately ignored. + * @param sum The summing CountedData used to output the CSV data + * @return returns The sum used to output the CSV data */ - public PrintStream reduce( Integer value, PrintStream recalTableStream ) { - return recalTableStream; // Nothing to do here, just return our open stream + public CountedData reduce( CountedData mapped, CountedData sum ) { + // Do a dbSNP sanity check every so often + return validatingDbsnpMismatchRate(sum.add(mapped)); } - public PrintStream treeReduce( PrintStream recalTableStream1, PrintStream recalTableStream2 ) { - return recalTableStream1; // Nothing to do here, just return our open stream + /** + * Validate the dbSNP reference mismatch rates. + */ + private CountedData validatingDbsnpMismatchRate(CountedData counter) { + if( ++counter.lociSinceLastDbsnpCheck >= DBSNP_VALIDATION_CHECK_FREQUENCY ) { + counter.lociSinceLastDbsnpCheck = 0; + + if( counter.novelCountsBases != 0L && counter.dbSNPCountsBases != 0L ) { + final double fractionMM_novel = (double)counter.novelCountsMM / (double)counter.novelCountsBases; + final double fractionMM_dbsnp = (double)counter.dbSNPCountsMM / (double)counter.dbSNPCountsBases; + + if( fractionMM_dbsnp < DBSNP_VS_NOVEL_MISMATCH_RATE * fractionMM_novel ) { + Utils.warnUser("The variation rate at the supplied list of known variant sites seems suspiciously low. Please double-check that the correct ROD is being used. " + + String.format("[dbSNP variation rate = %.4f, novel variation rate = %.4f]", fractionMM_dbsnp, fractionMM_novel) ); + DBSNP_VALIDATION_CHECK_FREQUENCY *= 2; // Don't annoyingly output the warning message every megabase of a large file + } + } + } + + return counter; + } + + public CountedData treeReduce( CountedData sum1, CountedData sum2 ) { + return validatingDbsnpMismatchRate(sum1.add(sum2)); } /** * Write out the full data hashmap to disk in CSV format - * @param recalTableStream The PrintStream to write out to + * @param sum The CountedData to write out to RECAL_FILE */ - public void onTraversalDone( PrintStream recalTableStream ) { + public void onTraversalDone( CountedData sum ) { logger.info( "Writing raw recalibration data..." ); - outputToCSV( recalTableStream ); + outputToCSV( sum, RECAL_FILE ); logger.info( "...done!" ); - recalTableStream.close(); + RECAL_FILE.close(); } /** * For each entry (key-value pair) in the data hashmap output the Covariate's values as well as the RecalDatum's data in CSV format * @param recalTableStream The PrintStream to write out to */ - private void outputToCSV( final PrintStream recalTableStream ) { + private void outputToCSV( CountedData sum, final PrintStream recalTableStream ) { + recalTableStream.printf("# Counted Sites %d%n", sum.countedSites); + recalTableStream.printf("# Counted Bases %d%n", sum.countedBases); + recalTableStream.printf("# Skipped Sites %d%n", sum.skippedSites); + recalTableStream.printf("# Fraction Skipped 1 / %.0f bp%n", (double)sum.countedSites / sum.skippedSites); - recalTableStream.printf("# Counted Sites %d%n", countedSites); - recalTableStream.printf("# Counted Bases %d%n", countedBases); - recalTableStream.printf("# Skipped Sites %d%n", skippedSites); - if( PROCESS_EVERY_NTH_LOCUS == 1 ) { - recalTableStream.printf("# Fraction Skipped 1 / %.0f bp%n", (double)countedSites / skippedSites); - } else { - recalTableStream.printf("# Percent Skipped %.4f%n", 100.0 * (double)skippedSites / ((double)countedSites+skippedSites)); - } - if( solidInsertedReferenceBases != 0 ) { - recalTableStream.printf("# Fraction SOLiD inserted reference 1 / %.0f bases%n", (double) countedBases / solidInsertedReferenceBases); - recalTableStream.printf("# Fraction other color space inconsistencies 1 / %.0f bases%n", (double) countedBases / otherColorSpaceInconsistency); + if( sum.solidInsertedReferenceBases != 0 ) { + recalTableStream.printf("# Fraction SOLiD inserted reference 1 / %.0f bases%n", (double) sum.countedBases / sum.solidInsertedReferenceBases); + recalTableStream.printf("# Fraction other color space inconsistencies 1 / %.0f bases%n", (double) sum.countedBases / sum.otherColorSpaceInconsistency); } // Output header saying which covariates were used and in what order diff --git a/java/src/org/broadinstitute/sting/utils/collections/NestedHashMap.java b/java/src/org/broadinstitute/sting/utils/collections/NestedHashMap.java index e383a6f7c..587bcc724 100755 --- a/java/src/org/broadinstitute/sting/utils/collections/NestedHashMap.java +++ b/java/src/org/broadinstitute/sting/utils/collections/NestedHashMap.java @@ -53,12 +53,23 @@ public class NestedHashMap{ } public synchronized void put( final Object value, final Object... keys ) { + this.put(value, false, keys ); + } + public synchronized Object put( final Object value, boolean keepOldBindingIfPresent, final Object... keys ) { Map map = this.data; final int keysLength = keys.length; for( int iii = 0; iii < keysLength; iii++ ) { if( iii == keysLength - 1 ) { - map.put(keys[iii], value); + if ( keepOldBindingIfPresent && map.containsKey(keys[iii]) ) { + // this code test is for parallel protection when you call put() multiple times in different threads + // to initialize the map. It returns the already bound key[iii] -> value + return map.get(keys[iii]); + } else { + // we are a new binding, put it in the map + map.put(keys[iii], value); + return value; + } } else { Map tmp = (Map) map.get(keys[iii]); if( tmp == null ) { @@ -68,5 +79,7 @@ public class NestedHashMap{ map = tmp; } } + + return value; // todo -- should never reach this point } } diff --git a/java/test/org/broadinstitute/sting/gatk/walkers/recalibration/RecalibrationWalkersIntegrationTest.java b/java/test/org/broadinstitute/sting/gatk/walkers/recalibration/RecalibrationWalkersIntegrationTest.java index ff828f071..b906cdd4e 100755 --- a/java/test/org/broadinstitute/sting/gatk/walkers/recalibration/RecalibrationWalkersIntegrationTest.java +++ b/java/test/org/broadinstitute/sting/gatk/walkers/recalibration/RecalibrationWalkersIntegrationTest.java @@ -22,7 +22,7 @@ public class RecalibrationWalkersIntegrationTest extends WalkerTest { e.put( validationDataLocation + "NA12873.454.SRP000031.2009_06.chr1.10_20mb.bam", "596a9ec9cbc1da70481e45a5a588a41d" ); e.put( validationDataLocation + "NA12878.1kg.p2.chr1_10mb_11_mb.allTechs.bam", "507dbd3ba6f54e066d04c4d24f59c3ab" ); - for ( String parallelism : Arrays.asList("") ) { // todo -- enable parallel tests. They work but there's a system bug Arrays.asList("", " -nt 4")) { + for ( String parallelism : Arrays.asList("", " -nt 4")) { for ( Map.Entry entry : e.entrySet() ) { String bam = entry.getKey(); String md5 = entry.getValue();