Official parallel CountCovariates, passes all integration tests. Now poster-child example of parallelism in GATK (Matt H). Apparent general performance improvements throughout too.

git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@3833 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
depristo 2010-07-19 22:13:18 +00:00
parent 0b56003d1a
commit c47a5ff5ab
3 changed files with 113 additions and 89 deletions

View File

@ -70,7 +70,7 @@ import java.util.Map;
@WalkerName( "CountCovariates" )
@ReadFilters( {ZeroMappingQualityReadFilter.class} ) // Filter out all reads with zero mapping quality
@Requires( {DataSource.READS, DataSource.REFERENCE, DataSource.REFERENCE_BASES} ) // This walker requires both -I input.bam and -R reference.fasta
public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> implements TreeReducible<PrintStream> {
public class CovariateCounterWalker extends LocusWalker<CovariateCounterWalker.CountedData, CovariateCounterWalker.CountedData> implements TreeReducible<CovariateCounterWalker.CountedData> {
/////////////////////////////
// Constants
@ -96,8 +96,6 @@ public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> im
private String[] COVARIATES = null;
@Argument(fullName="standard_covs", shortName="standard", doc="Use the standard set of covariates in addition to the ones listed using the -cov argument", required=false)
private boolean USE_STANDARD_COVARIATES = false;
@Argument(fullName="process_nth_locus", shortName="pN", required=false, doc="Only process every Nth covered locus we see.")
private int PROCESS_EVERY_NTH_LOCUS = 1;
/////////////////////////////
// Debugging-only Arguments
@ -110,17 +108,39 @@ public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> im
/////////////////////////////
private final RecalDataManager dataManager = new RecalDataManager(); // Holds the data HashMap, mostly used by TableRecalibrationWalker to create collapsed data hashmaps
private final ArrayList<Covariate> requestedCovariates = new ArrayList<Covariate>(); // A list to hold the covariate objects that were requested
private long countedSites = 0; // Number of loci used in the calculations, used for reporting in the output file
private long countedBases = 0; // Number of bases used in the calculations, used for reporting in the output file
private long skippedSites = 0; // Number of loci skipped because it was a dbSNP site, used for reporting in the output file
private long solidInsertedReferenceBases = 0; // Number of bases where we believe SOLID has inserted the reference because the color space is inconsistent with the read base
private long otherColorSpaceInconsistency = 0; // Number of bases where the color space is inconsistent with the read but the reference wasn't inserted.
private int numUnprocessed = 0; // Number of consecutive loci skipped because we are only processing every Nth site
private final Pair<Long, Long> dbSNP_counts = new Pair<Long, Long>(0L, 0L); // mismatch/base counts for dbSNP loci
private final Pair<Long, Long> novel_counts = new Pair<Long, Long>(0L, 0L); // mismatch/base counts for non-dbSNP loci
private static final double DBSNP_VS_NOVEL_MISMATCH_RATE = 2.0; // rate at which dbSNP sites (on an individual level) mismatch relative to novel sites (determined by looking at NA12878)
private int DBSNP_VALIDATION_CHECK_FREQUENCY = 1000000; // how often to validate dbsnp mismatch rate (in terms of loci seen)
private int lociSinceLastDbsnpCheck = 0; // loci since last dbsnp validation
private static int DBSNP_VALIDATION_CHECK_FREQUENCY = 1000000; // how often to validate dbsnp mismatch rate (in terms of loci seen)
/**
 * Bundle of the tallies gathered while counting covariates. Instances can be
 * merged with {@link #add(CountedData)}, which folds another instance's tallies
 * into this one.
 */
public static class CountedData {
    // Number of loci used in the calculations, used for reporting in the output file
    private long countedSites;
    // Number of bases used in the calculations, used for reporting in the output file
    private long countedBases;
    // Number of loci skipped because it was a dbSNP site, used for reporting in the output file
    private long skippedSites;
    // Number of bases where we believe SOLID has inserted the reference because the color space is inconsistent with the read base
    private long solidInsertedReferenceBases;
    // Number of bases where the color space is inconsistent with the read but the reference wasn't inserted
    private long otherColorSpaceInconsistency;

    // Mismatch / total-base counts for dbSNP loci
    private long dbSNPCountsMM;
    private long dbSNPCountsBases;

    // Mismatch / total-base counts for non-dbSNP loci
    private long novelCountsMM;
    private long novelCountsBases;

    // Loci seen since the last dbSNP validation
    private int lociSinceLastDbsnpCheck;

    /**
     * Accumulates every tally of {@code other} into this instance.
     *
     * @param other the instance whose tallies are added to this one; it is not modified
     * @return this object, after the addition
     */
    public CountedData add(CountedData other) {
        this.countedSites += other.countedSites;
        this.countedBases += other.countedBases;
        this.skippedSites += other.skippedSites;
        this.solidInsertedReferenceBases += other.solidInsertedReferenceBases;
        this.otherColorSpaceInconsistency += other.otherColorSpaceInconsistency;
        this.dbSNPCountsMM += other.dbSNPCountsMM;
        this.dbSNPCountsBases += other.dbSNPCountsBases;
        this.novelCountsMM += other.novelCountsMM;
        this.novelCountsBases += other.novelCountsBases;
        this.lociSinceLastDbsnpCheck += other.lociSinceLastDbsnpCheck;
        return this;
    }
}
//---------------------------------------------------------------------------------------------------------------
//
@ -136,7 +156,6 @@ public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> im
if( RAC.FORCE_READ_GROUP != null ) { RAC.DEFAULT_READ_GROUP = RAC.FORCE_READ_GROUP; }
if( RAC.FORCE_PLATFORM != null ) { RAC.DEFAULT_PLATFORM = RAC.FORCE_PLATFORM; }
DBSNP_VALIDATION_CHECK_FREQUENCY *= PROCESS_EVERY_NTH_LOCUS;
// Get a list of all available covariates
final List<Class<? extends Covariate>> covariateClasses = PackageUtils.getClassesImplementingInterface( Covariate.class );
@ -253,7 +272,7 @@ public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> im
* @param context The alignment context
* @return The CountedData tallied for this locus, which the reduce step adds to the running sum
*/
public Integer map( RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context ) {
public CountedData map( RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context ) {
// Pull out data for this locus for all the input RODs and check if this is a known variant site in any of them
boolean isSNP = false;
@ -266,9 +285,8 @@ public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> im
// Only use data from non-dbsnp sites
// Assume every mismatch at a non-dbsnp site is indicative of poor quality
if( !isSNP && ( ++numUnprocessed >= PROCESS_EVERY_NTH_LOCUS ) ) {
numUnprocessed = 0; // Reset the counter because we are processing this very locus
CountedData counter = new CountedData();
if( !isSNP ) {
// For each read at this locus
for( PileupElement p : context.getBasePileup() ) {
GATKSAMRecord gatkRead = (GATKSAMRecord) p.getRead();
@ -309,33 +327,25 @@ public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> im
!RecalDataManager.isInconsistentColorSpace( gatkRead, offset ) ) {
// This base finally passed all the checks for a good base, so add it to the big data hashmap
updateDataFromRead( gatkRead, offset, refBase );
updateDataFromRead( counter, gatkRead, offset, refBase );
} else { // calculate SOLID reference insertion rate
if( refBase == bases[offset] ) {
solidInsertedReferenceBases++;
counter.solidInsertedReferenceBases++;
} else {
otherColorSpaceInconsistency++;
counter.otherColorSpaceInconsistency++;
}
}
}
}
}
countedSites++;
counter.countedSites++;
} else { // We skipped over the dbSNP site, and we are only processing every Nth locus
skippedSites++;
if( isSNP ) {
updateMismatchCounts(dbSNP_counts, context, ref.getBase()); // For sanity check to ensure novel mismatch rate vs dnsnp mismatch rate is reasonable
}
counter.skippedSites++;
updateMismatchCounts(counter, context, ref.getBase()); // For sanity check to ensure novel mismatch rate vs dnsnp mismatch rate is reasonable
}
// Do a dbSNP sanity check every so often
if( ++lociSinceLastDbsnpCheck == DBSNP_VALIDATION_CHECK_FREQUENCY ) {
lociSinceLastDbsnpCheck = 0;
validateDbsnpMismatchRate();
}
return 1; // This value isn't actually used anywhere
return counter; // This value isn't actually used anywhere
}
@ -343,11 +353,11 @@ public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> im
/**
* Update the mismatch / total_base counts for a given class of loci.
*
* @param counts The counts to be updated
* @param counter The CountedData to be updated
* @param context The AlignmentContext which holds the reads covered by this locus
* @param refBase The reference base
*/
private static void updateMismatchCounts(final Pair<Long, Long> counts, final AlignmentContext context, final byte refBase) {
private static void updateMismatchCounts(CountedData counter, final AlignmentContext context, final byte refBase) {
for( PileupElement p : context.getBasePileup() ) {
final byte readBase = p.getBase();
final int readBaseIndex = BaseUtils.simpleBaseToBaseIndex(readBase);
@ -355,31 +365,13 @@ public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> im
if( readBaseIndex != -1 && refBaseIndex != -1 ) {
if( readBaseIndex != refBaseIndex ) {
counts.first++;
counter.novelCountsMM++;
}
counts.second++;
counter.novelCountsBases++;
}
}
}
/**
* Validate the dbSNP reference mismatch rates.
*/
private void validateDbsnpMismatchRate() {
if( novel_counts.second == 0L || dbSNP_counts.second == 0L ) {
return;
}
final double fractionMM_novel = (double)novel_counts.first / (double)novel_counts.second;
final double fractionMM_dbsnp = (double)dbSNP_counts.first / (double)dbSNP_counts.second;
if( fractionMM_dbsnp < DBSNP_VS_NOVEL_MISMATCH_RATE * fractionMM_novel ) {
Utils.warnUser("The variation rate at the supplied list of known variant sites seems suspiciously low. Please double-check that the correct ROD is being used. " +
String.format("[dbSNP variation rate = %.4f, novel variation rate = %.4f]", fractionMM_dbsnp, fractionMM_novel) );
DBSNP_VALIDATION_CHECK_FREQUENCY *= 2; // Don't annoyingly output the warning message every megabase of a large file
}
}
/**
* Major workhorse routine for this walker.
* Loop through the list of requested covariates and pick out the value from the read, offset, and reference
@ -391,7 +383,7 @@ public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> im
* @param offset The offset in the read for this locus
* @param refBase The reference base at this locus
*/
private void updateDataFromRead(final GATKSAMRecord gatkRead, final int offset, final byte refBase) {
private void updateDataFromRead(CountedData counter, final GATKSAMRecord gatkRead, final int offset, final byte refBase) {
final Object[][] covars = (Comparable[][]) gatkRead.getTemporaryAttribute(COVARS_ATTRIBUTE);
final Object[] key = covars[offset];
@ -399,8 +391,8 @@ public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> im
final NestedHashMap data = dataManager.data; //optimization - create local reference
RecalDatumOptimized datum = (RecalDatumOptimized) data.get( key );
if( datum == null ) { // key doesn't exist yet in the map so make a new bucket and add it
datum = new RecalDatumOptimized(); // initialized with zeros, will be incremented at end of method
data.put( datum, (Object[])key );
// initialized with zeros, will be incremented at end of method
datum = (RecalDatumOptimized)data.put( new RecalDatumOptimized(), true, (Object[])key );
}
// Need the bases to determine whether or not we have a mismatch
@ -409,9 +401,9 @@ public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> im
// Add one to the number of observations and potentially one to the number of mismatches
datum.incrementBaseCounts( base, refBase );
countedBases++;
novel_counts.second++;
novel_counts.first += datum.getNumMismatches() - curMismatches; // For sanity check to ensure novel mismatch rate vs dnsnp mismatch rate is reasonable
counter.countedBases++;
counter.novelCountsBases++;
counter.novelCountsMM += datum.getNumMismatches() - curMismatches; // For sanity check to ensure novel mismatch rate vs dnsnp mismatch rate is reasonable
}
@ -422,56 +414,75 @@ public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> im
//---------------------------------------------------------------------------------------------------------------
/**
* Initialize the reudce step by creating a PrintStream from the filename specified as an argument to the walker.
* Initialize the reduce step with a new, empty CountedData accumulator.
* @return A new CountedData with all of its counters at zero
*/
public PrintStream reduceInit() {
return RECAL_FILE;
public CountedData reduceInit() {
    // Start the reduction from a fresh accumulator
    final CountedData emptyTotals = new CountedData();
    return emptyTotals;
}
/**
* Adds the mapped per-locus counts to the running sum and periodically runs the dbSNP mismatch-rate sanity check.
* @param value Result of the map. This value is immediately ignored.
* @param recalTableStream The PrintStream used to output the CSV data
* @return returns The PrintStream used to output the CSV data
* @param mapped The CountedData produced by map for one locus; it is added into sum
* @param sum The running CountedData total that mapped is added to
* @return returns The updated sum
*/
public PrintStream reduce( Integer value, PrintStream recalTableStream ) {
return recalTableStream; // Nothing to do here, just return our open stream
public CountedData reduce( CountedData mapped, CountedData sum ) {
    // Fold this locus into the running total, then do a dbSNP sanity check every so often
    final CountedData merged = sum.add(mapped);
    return validatingDbsnpMismatchRate(merged);
}
public PrintStream treeReduce( PrintStream recalTableStream1, PrintStream recalTableStream2 ) {
return recalTableStream1; // Nothing to do here, just return our open stream
/**
 * Runs the periodic dbSNP-versus-novel mismatch-rate sanity check on a counter.
 *
 * Every call increments the counter's lociSinceLastDbsnpCheck. Once that value reaches
 * DBSNP_VALIDATION_CHECK_FREQUENCY it is reset to zero and, provided both base tallies
 * are non-zero, the two mismatch fractions are compared. If the dbSNP fraction is lower
 * than DBSNP_VS_NOVEL_MISMATCH_RATE times the novel fraction, the user is warned and the
 * check frequency is doubled.
 *
 * @param counter the CountedData to check; its lociSinceLastDbsnpCheck field is modified
 * @return the same counter instance that was passed in
 */
private CountedData validatingDbsnpMismatchRate(CountedData counter) {
    counter.lociSinceLastDbsnpCheck++;
    if( counter.lociSinceLastDbsnpCheck < DBSNP_VALIDATION_CHECK_FREQUENCY ) {
        return counter; // not yet time for another check
    }
    counter.lociSinceLastDbsnpCheck = 0;

    // Both rates need a non-zero denominator before they can be compared
    final boolean haveBothTallies = counter.novelCountsBases != 0L && counter.dbSNPCountsBases != 0L;
    if( !haveBothTallies ) {
        return counter;
    }

    final double novelRate = (double)counter.novelCountsMM / (double)counter.novelCountsBases;
    final double dbsnpRate = (double)counter.dbSNPCountsMM / (double)counter.dbSNPCountsBases;
    if( dbsnpRate < DBSNP_VS_NOVEL_MISMATCH_RATE * novelRate ) {
        final String rates = String.format("[dbSNP variation rate = %.4f, novel variation rate = %.4f]", dbsnpRate, novelRate);
        Utils.warnUser("The variation rate at the supplied list of known variant sites seems suspiciously low. Please double-check that the correct ROD is being used. " + rates);
        DBSNP_VALIDATION_CHECK_FREQUENCY *= 2; // Don't annoyingly output the warning message every megabase of a large file
    }
    return counter;
}
/**
 * Merges two partial sums: the counts of sum2 are added into sum1, and the result is
 * passed through the periodic dbSNP mismatch-rate sanity check.
 *
 * @param sum1 the partial sum that receives the counts of sum2
 * @param sum2 the partial sum whose counts are added into sum1
 * @return the merged sum
 */
public CountedData treeReduce( CountedData sum1, CountedData sum2 ) {
    final CountedData combined = sum1.add(sum2);
    return validatingDbsnpMismatchRate(combined);
}
/**
* Write out the full data hashmap to disk in CSV format
* @param recalTableStream The PrintStream to write out to
* @param sum The CountedData to write out to RECAL_FILE
*/
public void onTraversalDone( PrintStream recalTableStream ) {
public void onTraversalDone( CountedData sum ) {
logger.info( "Writing raw recalibration data..." );
outputToCSV( recalTableStream );
outputToCSV( sum, RECAL_FILE );
logger.info( "...done!" );
recalTableStream.close();
RECAL_FILE.close();
}
/**
* For each entry (key-value pair) in the data hashmap output the Covariate's values as well as the RecalDatum's data in CSV format
* @param sum The CountedData whose site and base counts are printed as '#'-prefixed lines
* @param recalTableStream The PrintStream to write out to
*/
private void outputToCSV( final PrintStream recalTableStream ) {
private void outputToCSV( CountedData sum, final PrintStream recalTableStream ) {
recalTableStream.printf("# Counted Sites %d%n", sum.countedSites);
recalTableStream.printf("# Counted Bases %d%n", sum.countedBases);
recalTableStream.printf("# Skipped Sites %d%n", sum.skippedSites);
recalTableStream.printf("# Fraction Skipped 1 / %.0f bp%n", (double)sum.countedSites / sum.skippedSites);
recalTableStream.printf("# Counted Sites %d%n", countedSites);
recalTableStream.printf("# Counted Bases %d%n", countedBases);
recalTableStream.printf("# Skipped Sites %d%n", skippedSites);
if( PROCESS_EVERY_NTH_LOCUS == 1 ) {
recalTableStream.printf("# Fraction Skipped 1 / %.0f bp%n", (double)countedSites / skippedSites);
} else {
recalTableStream.printf("# Percent Skipped %.4f%n", 100.0 * (double)skippedSites / ((double)countedSites+skippedSites));
}
if( solidInsertedReferenceBases != 0 ) {
recalTableStream.printf("# Fraction SOLiD inserted reference 1 / %.0f bases%n", (double) countedBases / solidInsertedReferenceBases);
recalTableStream.printf("# Fraction other color space inconsistencies 1 / %.0f bases%n", (double) countedBases / otherColorSpaceInconsistency);
if( sum.solidInsertedReferenceBases != 0 ) {
recalTableStream.printf("# Fraction SOLiD inserted reference 1 / %.0f bases%n", (double) sum.countedBases / sum.solidInsertedReferenceBases);
recalTableStream.printf("# Fraction other color space inconsistencies 1 / %.0f bases%n", (double) sum.countedBases / sum.otherColorSpaceInconsistency);
}
// Output header saying which covariates were used and in what order

View File

@ -53,12 +53,23 @@ public class NestedHashMap{
}
public synchronized void put( final Object value, final Object... keys ) {
this.put(value, false, keys );
}
public synchronized Object put( final Object value, boolean keepOldBindingIfPresent, final Object... keys ) {
Map map = this.data;
final int keysLength = keys.length;
for( int iii = 0; iii < keysLength; iii++ ) {
if( iii == keysLength - 1 ) {
map.put(keys[iii], value);
if ( keepOldBindingIfPresent && map.containsKey(keys[iii]) ) {
// Guards against several threads calling put() to initialize the same key:
// when a binding already exists, keep it and return the value already bound to keys[iii]
return map.get(keys[iii]);
} else {
// we are a new binding, put it in the map
map.put(keys[iii], value);
return value;
}
} else {
Map tmp = (Map) map.get(keys[iii]);
if( tmp == null ) {
@ -68,5 +79,7 @@ public class NestedHashMap{
map = tmp;
}
}
return value; // todo -- should never reach this point
}
}

View File

@ -22,7 +22,7 @@ public class RecalibrationWalkersIntegrationTest extends WalkerTest {
e.put( validationDataLocation + "NA12873.454.SRP000031.2009_06.chr1.10_20mb.bam", "596a9ec9cbc1da70481e45a5a588a41d" );
e.put( validationDataLocation + "NA12878.1kg.p2.chr1_10mb_11_mb.allTechs.bam", "507dbd3ba6f54e066d04c4d24f59c3ab" );
for ( String parallelism : Arrays.asList("") ) { // todo -- enable parallel tests. They work but there's a system bug Arrays.asList("", " -nt 4")) {
for ( String parallelism : Arrays.asList("", " -nt 4")) {
for ( Map.Entry<String, String> entry : e.entrySet() ) {
String bam = entry.getKey();
String md5 = entry.getValue();