Official parallel CountCovariates, passes all integration tests. Now poster-child example of parallelism in GATK (Matt H). Apparent general performance improvements throughout too.
git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@3833 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
parent
0b56003d1a
commit
c47a5ff5ab
|
|
@ -70,7 +70,7 @@ import java.util.Map;
|
|||
@WalkerName( "CountCovariates" )
|
||||
@ReadFilters( {ZeroMappingQualityReadFilter.class} ) // Filter out all reads with zero mapping quality
|
||||
@Requires( {DataSource.READS, DataSource.REFERENCE, DataSource.REFERENCE_BASES} ) // This walker requires both -I input.bam and -R reference.fasta
|
||||
public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> implements TreeReducible<PrintStream> {
|
||||
public class CovariateCounterWalker extends LocusWalker<CovariateCounterWalker.CountedData, CovariateCounterWalker.CountedData> implements TreeReducible<CovariateCounterWalker.CountedData> {
|
||||
|
||||
/////////////////////////////
|
||||
// Constants
|
||||
|
|
@ -96,8 +96,6 @@ public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> im
|
|||
private String[] COVARIATES = null;
|
||||
@Argument(fullName="standard_covs", shortName="standard", doc="Use the standard set of covariates in addition to the ones listed using the -cov argument", required=false)
|
||||
private boolean USE_STANDARD_COVARIATES = false;
|
||||
@Argument(fullName="process_nth_locus", shortName="pN", required=false, doc="Only process every Nth covered locus we see.")
|
||||
private int PROCESS_EVERY_NTH_LOCUS = 1;
|
||||
|
||||
/////////////////////////////
|
||||
// Debugging-only Arguments
|
||||
|
|
@ -110,17 +108,39 @@ public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> im
|
|||
/////////////////////////////
|
||||
private final RecalDataManager dataManager = new RecalDataManager(); // Holds the data HashMap, mostly used by TableRecalibrationWalker to create collapsed data hashmaps
|
||||
private final ArrayList<Covariate> requestedCovariates = new ArrayList<Covariate>(); // A list to hold the covariate objects that were requested
|
||||
private long countedSites = 0; // Number of loci used in the calculations, used for reporting in the output file
|
||||
private long countedBases = 0; // Number of bases used in the calculations, used for reporting in the output file
|
||||
private long skippedSites = 0; // Number of loci skipped because it was a dbSNP site, used for reporting in the output file
|
||||
private long solidInsertedReferenceBases = 0; // Number of bases where we believe SOLID has inserted the reference because the color space is inconsistent with the read base
|
||||
private long otherColorSpaceInconsistency = 0; // Number of bases where the color space is inconsistent with the read but the reference wasn't inserted.
|
||||
private int numUnprocessed = 0; // Number of consecutive loci skipped because we are only processing every Nth site
|
||||
private final Pair<Long, Long> dbSNP_counts = new Pair<Long, Long>(0L, 0L); // mismatch/base counts for dbSNP loci
|
||||
private final Pair<Long, Long> novel_counts = new Pair<Long, Long>(0L, 0L); // mismatch/base counts for non-dbSNP loci
|
||||
private static final double DBSNP_VS_NOVEL_MISMATCH_RATE = 2.0; // rate at which dbSNP sites (on an individual level) mismatch relative to novel sites (determined by looking at NA12878)
|
||||
private int DBSNP_VALIDATION_CHECK_FREQUENCY = 1000000; // how often to validate dbsnp mismatch rate (in terms of loci seen)
|
||||
private int lociSinceLastDbsnpCheck = 0; // loci since last dbsnp validation
|
||||
private static int DBSNP_VALIDATION_CHECK_FREQUENCY = 1000000; // how often to validate dbsnp mismatch rate (in terms of loci seen)
|
||||
|
||||
public static class CountedData {
|
||||
private long countedSites = 0; // Number of loci used in the calculations, used for reporting in the output file
|
||||
private long countedBases = 0; // Number of bases used in the calculations, used for reporting in the output file
|
||||
private long skippedSites = 0; // Number of loci skipped because it was a dbSNP site, used for reporting in the output file
|
||||
private long solidInsertedReferenceBases = 0; // Number of bases where we believe SOLID has inserted the reference because the color space is inconsistent with the read base
|
||||
private long otherColorSpaceInconsistency = 0; // Number of bases where the color space is inconsistent with the read but the reference wasn't inserted.
|
||||
|
||||
private long dbSNPCountsMM = 0, dbSNPCountsBases = 0; // mismatch/base counts for dbSNP loci
|
||||
private long novelCountsMM = 0, novelCountsBases = 0; // mismatch/base counts for non-dbSNP loci
|
||||
private int lociSinceLastDbsnpCheck = 0; // loci since last dbsnp validation
|
||||
|
||||
/**
|
||||
* Adds the values of other to this, returning this
|
||||
* @param other
|
||||
* @return this object
|
||||
*/
|
||||
public CountedData add(CountedData other) {
|
||||
countedSites += other.countedSites;
|
||||
countedBases += other.countedBases;
|
||||
skippedSites += other.skippedSites;
|
||||
solidInsertedReferenceBases += other.solidInsertedReferenceBases;
|
||||
otherColorSpaceInconsistency += other.otherColorSpaceInconsistency;
|
||||
dbSNPCountsMM += other.dbSNPCountsMM;
|
||||
dbSNPCountsBases += other.dbSNPCountsBases;
|
||||
novelCountsMM += other.novelCountsMM;
|
||||
novelCountsBases += other.novelCountsBases;
|
||||
lociSinceLastDbsnpCheck += other.lociSinceLastDbsnpCheck;
|
||||
return this;
|
||||
}
|
||||
}
|
||||
|
||||
//---------------------------------------------------------------------------------------------------------------
|
||||
//
|
||||
|
|
@ -136,7 +156,6 @@ public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> im
|
|||
|
||||
if( RAC.FORCE_READ_GROUP != null ) { RAC.DEFAULT_READ_GROUP = RAC.FORCE_READ_GROUP; }
|
||||
if( RAC.FORCE_PLATFORM != null ) { RAC.DEFAULT_PLATFORM = RAC.FORCE_PLATFORM; }
|
||||
DBSNP_VALIDATION_CHECK_FREQUENCY *= PROCESS_EVERY_NTH_LOCUS;
|
||||
|
||||
// Get a list of all available covariates
|
||||
final List<Class<? extends Covariate>> covariateClasses = PackageUtils.getClassesImplementingInterface( Covariate.class );
|
||||
|
|
@ -253,7 +272,7 @@ public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> im
|
|||
* @param context The alignment context
|
||||
* @return Returns 1, but this value isn't used in the reduce step
|
||||
*/
|
||||
public Integer map( RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context ) {
|
||||
public CountedData map( RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context ) {
|
||||
|
||||
// Pull out data for this locus for all the input RODs and check if this is a known variant site in any of them
|
||||
boolean isSNP = false;
|
||||
|
|
@ -266,9 +285,8 @@ public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> im
|
|||
|
||||
// Only use data from non-dbsnp sites
|
||||
// Assume every mismatch at a non-dbsnp site is indicative of poor quality
|
||||
if( !isSNP && ( ++numUnprocessed >= PROCESS_EVERY_NTH_LOCUS ) ) {
|
||||
numUnprocessed = 0; // Reset the counter because we are processing this very locus
|
||||
|
||||
CountedData counter = new CountedData();
|
||||
if( !isSNP ) {
|
||||
// For each read at this locus
|
||||
for( PileupElement p : context.getBasePileup() ) {
|
||||
GATKSAMRecord gatkRead = (GATKSAMRecord) p.getRead();
|
||||
|
|
@ -309,33 +327,25 @@ public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> im
|
|||
!RecalDataManager.isInconsistentColorSpace( gatkRead, offset ) ) {
|
||||
|
||||
// This base finally passed all the checks for a good base, so add it to the big data hashmap
|
||||
updateDataFromRead( gatkRead, offset, refBase );
|
||||
updateDataFromRead( counter, gatkRead, offset, refBase );
|
||||
|
||||
} else { // calculate SOLID reference insertion rate
|
||||
if( refBase == bases[offset] ) {
|
||||
solidInsertedReferenceBases++;
|
||||
counter.solidInsertedReferenceBases++;
|
||||
} else {
|
||||
otherColorSpaceInconsistency++;
|
||||
counter.otherColorSpaceInconsistency++;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
countedSites++;
|
||||
counter.countedSites++;
|
||||
} else { // We skipped over the dbSNP site, and we are only processing every Nth locus
|
||||
skippedSites++;
|
||||
if( isSNP ) {
|
||||
updateMismatchCounts(dbSNP_counts, context, ref.getBase()); // For sanity check to ensure novel mismatch rate vs dnsnp mismatch rate is reasonable
|
||||
}
|
||||
counter.skippedSites++;
|
||||
updateMismatchCounts(counter, context, ref.getBase()); // For sanity check to ensure novel mismatch rate vs dnsnp mismatch rate is reasonable
|
||||
}
|
||||
|
||||
// Do a dbSNP sanity check every so often
|
||||
if( ++lociSinceLastDbsnpCheck == DBSNP_VALIDATION_CHECK_FREQUENCY ) {
|
||||
lociSinceLastDbsnpCheck = 0;
|
||||
validateDbsnpMismatchRate();
|
||||
}
|
||||
|
||||
return 1; // This value isn't actually used anywhere
|
||||
return counter; // This value isn't actually used anywhere
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -343,11 +353,11 @@ public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> im
|
|||
/**
|
||||
* Update the mismatch / total_base counts for a given class of loci.
|
||||
*
|
||||
* @param counts The counts to be updated
|
||||
* @param counter The CountedData to be updated
|
||||
* @param context The AlignmentContext which holds the reads covered by this locus
|
||||
* @param refBase The reference base
|
||||
*/
|
||||
private static void updateMismatchCounts(final Pair<Long, Long> counts, final AlignmentContext context, final byte refBase) {
|
||||
private static void updateMismatchCounts(CountedData counter, final AlignmentContext context, final byte refBase) {
|
||||
for( PileupElement p : context.getBasePileup() ) {
|
||||
final byte readBase = p.getBase();
|
||||
final int readBaseIndex = BaseUtils.simpleBaseToBaseIndex(readBase);
|
||||
|
|
@ -355,31 +365,13 @@ public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> im
|
|||
|
||||
if( readBaseIndex != -1 && refBaseIndex != -1 ) {
|
||||
if( readBaseIndex != refBaseIndex ) {
|
||||
counts.first++;
|
||||
counter.novelCountsMM++;
|
||||
}
|
||||
counts.second++;
|
||||
counter.novelCountsBases++;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Validate the dbSNP reference mismatch rates.
|
||||
*/
|
||||
private void validateDbsnpMismatchRate() {
|
||||
if( novel_counts.second == 0L || dbSNP_counts.second == 0L ) {
|
||||
return;
|
||||
}
|
||||
|
||||
final double fractionMM_novel = (double)novel_counts.first / (double)novel_counts.second;
|
||||
final double fractionMM_dbsnp = (double)dbSNP_counts.first / (double)dbSNP_counts.second;
|
||||
|
||||
if( fractionMM_dbsnp < DBSNP_VS_NOVEL_MISMATCH_RATE * fractionMM_novel ) {
|
||||
Utils.warnUser("The variation rate at the supplied list of known variant sites seems suspiciously low. Please double-check that the correct ROD is being used. " +
|
||||
String.format("[dbSNP variation rate = %.4f, novel variation rate = %.4f]", fractionMM_dbsnp, fractionMM_novel) );
|
||||
DBSNP_VALIDATION_CHECK_FREQUENCY *= 2; // Don't annoyingly output the warning message every megabase of a large file
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Major workhorse routine for this walker.
|
||||
* Loop through the list of requested covariates and pick out the value from the read, offset, and reference
|
||||
|
|
@ -391,7 +383,7 @@ public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> im
|
|||
* @param offset The offset in the read for this locus
|
||||
* @param refBase The reference base at this locus
|
||||
*/
|
||||
private void updateDataFromRead(final GATKSAMRecord gatkRead, final int offset, final byte refBase) {
|
||||
private void updateDataFromRead(CountedData counter, final GATKSAMRecord gatkRead, final int offset, final byte refBase) {
|
||||
final Object[][] covars = (Comparable[][]) gatkRead.getTemporaryAttribute(COVARS_ATTRIBUTE);
|
||||
final Object[] key = covars[offset];
|
||||
|
||||
|
|
@ -399,8 +391,8 @@ public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> im
|
|||
final NestedHashMap data = dataManager.data; //optimization - create local reference
|
||||
RecalDatumOptimized datum = (RecalDatumOptimized) data.get( key );
|
||||
if( datum == null ) { // key doesn't exist yet in the map so make a new bucket and add it
|
||||
datum = new RecalDatumOptimized(); // initialized with zeros, will be incremented at end of method
|
||||
data.put( datum, (Object[])key );
|
||||
// initialized with zeros, will be incremented at end of method
|
||||
datum = (RecalDatumOptimized)data.put( new RecalDatumOptimized(), true, (Object[])key );
|
||||
}
|
||||
|
||||
// Need the bases to determine whether or not we have a mismatch
|
||||
|
|
@ -409,9 +401,9 @@ public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> im
|
|||
|
||||
// Add one to the number of observations and potentially one to the number of mismatches
|
||||
datum.incrementBaseCounts( base, refBase );
|
||||
countedBases++;
|
||||
novel_counts.second++;
|
||||
novel_counts.first += datum.getNumMismatches() - curMismatches; // For sanity check to ensure novel mismatch rate vs dnsnp mismatch rate is reasonable
|
||||
counter.countedBases++;
|
||||
counter.novelCountsBases++;
|
||||
counter.novelCountsMM += datum.getNumMismatches() - curMismatches; // For sanity check to ensure novel mismatch rate vs dnsnp mismatch rate is reasonable
|
||||
}
|
||||
|
||||
|
||||
|
|
@ -422,56 +414,75 @@ public class CovariateCounterWalker extends LocusWalker<Integer, PrintStream> im
|
|||
//---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
/**
|
||||
* Initialize the reudce step by creating a PrintStream from the filename specified as an argument to the walker.
|
||||
* Initialize the reduce step by creating a PrintStream from the filename specified as an argument to the walker.
|
||||
* @return returns A PrintStream created from the -recalFile filename argument specified to the walker
|
||||
*/
|
||||
public PrintStream reduceInit() {
|
||||
return RECAL_FILE;
|
||||
public CountedData reduceInit() {
|
||||
return new CountedData();
|
||||
}
|
||||
|
||||
/**
|
||||
* The Reduce method doesn't do anything for this walker.
|
||||
* @param value Result of the map. This value is immediately ignored.
|
||||
* @param recalTableStream The PrintStream used to output the CSV data
|
||||
* @return returns The PrintStream used to output the CSV data
|
||||
* @param mapped Result of the map. This value is immediately ignored.
|
||||
* @param sum The summing CountedData used to output the CSV data
|
||||
* @return returns The sum used to output the CSV data
|
||||
*/
|
||||
public PrintStream reduce( Integer value, PrintStream recalTableStream ) {
|
||||
return recalTableStream; // Nothing to do here, just return our open stream
|
||||
public CountedData reduce( CountedData mapped, CountedData sum ) {
|
||||
// Do a dbSNP sanity check every so often
|
||||
return validatingDbsnpMismatchRate(sum.add(mapped));
|
||||
}
|
||||
|
||||
public PrintStream treeReduce( PrintStream recalTableStream1, PrintStream recalTableStream2 ) {
|
||||
return recalTableStream1; // Nothing to do here, just return our open stream
|
||||
/**
|
||||
* Validate the dbSNP reference mismatch rates.
|
||||
*/
|
||||
private CountedData validatingDbsnpMismatchRate(CountedData counter) {
|
||||
if( ++counter.lociSinceLastDbsnpCheck >= DBSNP_VALIDATION_CHECK_FREQUENCY ) {
|
||||
counter.lociSinceLastDbsnpCheck = 0;
|
||||
|
||||
if( counter.novelCountsBases != 0L && counter.dbSNPCountsBases != 0L ) {
|
||||
final double fractionMM_novel = (double)counter.novelCountsMM / (double)counter.novelCountsBases;
|
||||
final double fractionMM_dbsnp = (double)counter.dbSNPCountsMM / (double)counter.dbSNPCountsBases;
|
||||
|
||||
if( fractionMM_dbsnp < DBSNP_VS_NOVEL_MISMATCH_RATE * fractionMM_novel ) {
|
||||
Utils.warnUser("The variation rate at the supplied list of known variant sites seems suspiciously low. Please double-check that the correct ROD is being used. " +
|
||||
String.format("[dbSNP variation rate = %.4f, novel variation rate = %.4f]", fractionMM_dbsnp, fractionMM_novel) );
|
||||
DBSNP_VALIDATION_CHECK_FREQUENCY *= 2; // Don't annoyingly output the warning message every megabase of a large file
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return counter;
|
||||
}
|
||||
|
||||
public CountedData treeReduce( CountedData sum1, CountedData sum2 ) {
|
||||
return validatingDbsnpMismatchRate(sum1.add(sum2));
|
||||
}
|
||||
|
||||
/**
|
||||
* Write out the full data hashmap to disk in CSV format
|
||||
* @param recalTableStream The PrintStream to write out to
|
||||
* @param sum The CountedData to write out to RECAL_FILE
|
||||
*/
|
||||
public void onTraversalDone( PrintStream recalTableStream ) {
|
||||
public void onTraversalDone( CountedData sum ) {
|
||||
logger.info( "Writing raw recalibration data..." );
|
||||
outputToCSV( recalTableStream );
|
||||
outputToCSV( sum, RECAL_FILE );
|
||||
logger.info( "...done!" );
|
||||
|
||||
recalTableStream.close();
|
||||
RECAL_FILE.close();
|
||||
}
|
||||
|
||||
/**
|
||||
* For each entry (key-value pair) in the data hashmap output the Covariate's values as well as the RecalDatum's data in CSV format
|
||||
* @param recalTableStream The PrintStream to write out to
|
||||
*/
|
||||
private void outputToCSV( final PrintStream recalTableStream ) {
|
||||
private void outputToCSV( CountedData sum, final PrintStream recalTableStream ) {
|
||||
recalTableStream.printf("# Counted Sites %d%n", sum.countedSites);
|
||||
recalTableStream.printf("# Counted Bases %d%n", sum.countedBases);
|
||||
recalTableStream.printf("# Skipped Sites %d%n", sum.skippedSites);
|
||||
recalTableStream.printf("# Fraction Skipped 1 / %.0f bp%n", (double)sum.countedSites / sum.skippedSites);
|
||||
|
||||
recalTableStream.printf("# Counted Sites %d%n", countedSites);
|
||||
recalTableStream.printf("# Counted Bases %d%n", countedBases);
|
||||
recalTableStream.printf("# Skipped Sites %d%n", skippedSites);
|
||||
if( PROCESS_EVERY_NTH_LOCUS == 1 ) {
|
||||
recalTableStream.printf("# Fraction Skipped 1 / %.0f bp%n", (double)countedSites / skippedSites);
|
||||
} else {
|
||||
recalTableStream.printf("# Percent Skipped %.4f%n", 100.0 * (double)skippedSites / ((double)countedSites+skippedSites));
|
||||
}
|
||||
if( solidInsertedReferenceBases != 0 ) {
|
||||
recalTableStream.printf("# Fraction SOLiD inserted reference 1 / %.0f bases%n", (double) countedBases / solidInsertedReferenceBases);
|
||||
recalTableStream.printf("# Fraction other color space inconsistencies 1 / %.0f bases%n", (double) countedBases / otherColorSpaceInconsistency);
|
||||
if( sum.solidInsertedReferenceBases != 0 ) {
|
||||
recalTableStream.printf("# Fraction SOLiD inserted reference 1 / %.0f bases%n", (double) sum.countedBases / sum.solidInsertedReferenceBases);
|
||||
recalTableStream.printf("# Fraction other color space inconsistencies 1 / %.0f bases%n", (double) sum.countedBases / sum.otherColorSpaceInconsistency);
|
||||
}
|
||||
|
||||
// Output header saying which covariates were used and in what order
|
||||
|
|
|
|||
|
|
@ -53,12 +53,23 @@ public class NestedHashMap{
|
|||
}
|
||||
|
||||
public synchronized void put( final Object value, final Object... keys ) {
|
||||
this.put(value, false, keys );
|
||||
}
|
||||
|
||||
public synchronized Object put( final Object value, boolean keepOldBindingIfPresent, final Object... keys ) {
|
||||
Map map = this.data;
|
||||
final int keysLength = keys.length;
|
||||
for( int iii = 0; iii < keysLength; iii++ ) {
|
||||
if( iii == keysLength - 1 ) {
|
||||
map.put(keys[iii], value);
|
||||
if ( keepOldBindingIfPresent && map.containsKey(keys[iii]) ) {
|
||||
// this code test is for parallel protection when you call put() multiple times in different threads
|
||||
// to initialize the map. It returns the already bound key[iii] -> value
|
||||
return map.get(keys[iii]);
|
||||
} else {
|
||||
// we are a new binding, put it in the map
|
||||
map.put(keys[iii], value);
|
||||
return value;
|
||||
}
|
||||
} else {
|
||||
Map tmp = (Map) map.get(keys[iii]);
|
||||
if( tmp == null ) {
|
||||
|
|
@ -68,5 +79,7 @@ public class NestedHashMap{
|
|||
map = tmp;
|
||||
}
|
||||
}
|
||||
|
||||
return value; // todo -- should never reach this point
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -22,7 +22,7 @@ public class RecalibrationWalkersIntegrationTest extends WalkerTest {
|
|||
e.put( validationDataLocation + "NA12873.454.SRP000031.2009_06.chr1.10_20mb.bam", "596a9ec9cbc1da70481e45a5a588a41d" );
|
||||
e.put( validationDataLocation + "NA12878.1kg.p2.chr1_10mb_11_mb.allTechs.bam", "507dbd3ba6f54e066d04c4d24f59c3ab" );
|
||||
|
||||
for ( String parallelism : Arrays.asList("") ) { // todo -- enable parallel tests. They work but there's a system bug Arrays.asList("", " -nt 4")) {
|
||||
for ( String parallelism : Arrays.asList("", " -nt 4")) {
|
||||
for ( Map.Entry<String, String> entry : e.entrySet() ) {
|
||||
String bam = entry.getKey();
|
||||
String md5 = entry.getValue();
|
||||
|
|
|
|||
Loading…
Reference in New Issue