Removing parallelism bottleneck in the GATK

-- GenomeLocParser cache was a major performance bottleneck in parallel GATK performance.  With 10 thread > 50% of each thread's time was spent blocking on the MasterSequencingDictionary object.  Made this a thread local variable.
-- Now we can run the GATK with 48 threads efficiently on GSA4!
  -- Running -nt 1 => 75 minutes (didn't let is run all of the way through so likely would take longer)
  -- Running -nt 24 => 3.81 minutes
This commit is contained in:
Mark DePristo 2012-08-13 15:59:35 -04:00
parent cbf290ada0
commit f277d7c09e
1 changed files with 48 additions and 33 deletions

View File

@ -43,9 +43,6 @@ import org.broadinstitute.sting.utils.variantcontext.VariantContext;
/** /**
* Factory class for creating GenomeLocs * Factory class for creating GenomeLocs
*/ */
@Invariant({
"logger != null",
"contigInfo != null"})
public final class GenomeLocParser { public final class GenomeLocParser {
private static Logger logger = Logger.getLogger(GenomeLocParser.class); private static Logger logger = Logger.getLogger(GenomeLocParser.class);
@ -54,20 +51,39 @@ public final class GenomeLocParser {
// Ugly global variable defining the optional ordering of contig elements // Ugly global variable defining the optional ordering of contig elements
// //
// -------------------------------------------------------------------------------------------------------------- // --------------------------------------------------------------------------------------------------------------
private final MasterSequenceDictionary contigInfo;
/**
* This single variable holds the underlying SamSequenceDictionary used by the GATK. We assume
* it is thread safe.
*/
final private SAMSequenceDictionary SINGLE_MASTER_SEQUENCE_DICTIONARY;
/**
* A thread-local caching contig info
*/
private final ThreadLocal<CachingSequenceDictionary> contigInfoPerThread =
new ThreadLocal<CachingSequenceDictionary>();
/**
* @return a caching sequence dictionary appropriate for this thread
*/
private CachingSequenceDictionary getContigInfo() {
if ( contigInfoPerThread.get() == null ) {
// initialize for this thread
logger.debug("Creating thread-local caching sequence dictionary for thread " + Thread.currentThread().getName());
contigInfoPerThread.set(new CachingSequenceDictionary(SINGLE_MASTER_SEQUENCE_DICTIONARY));
}
assert contigInfoPerThread.get() != null;
return contigInfoPerThread.get();
}
/** /**
* A wrapper class that provides efficient last used caching for the global * A wrapper class that provides efficient last used caching for the global
* SAMSequenceDictionary underlying all of the GATK engine capabilities * SAMSequenceDictionary underlying all of the GATK engine capabilities.
*/ */
// todo -- enable when CoFoJa developers identify the problem (likely thread unsafe invariants) private final class CachingSequenceDictionary {
// @Invariant({
// "dict != null",
// "dict.size() > 0",
// "lastSSR == null || dict.getSequence(lastContig).getSequenceIndex() == lastIndex",
// "lastSSR == null || dict.getSequence(lastContig).getSequenceName() == lastContig",
// "lastSSR == null || dict.getSequence(lastContig) == lastSSR"})
private final class MasterSequenceDictionary {
final private SAMSequenceDictionary dict; final private SAMSequenceDictionary dict;
// cache // cache
@ -76,7 +92,7 @@ public final class GenomeLocParser {
int lastIndex = -1; int lastIndex = -1;
@Requires({"dict != null", "dict.size() > 0"}) @Requires({"dict != null", "dict.size() > 0"})
public MasterSequenceDictionary(SAMSequenceDictionary dict) { public CachingSequenceDictionary(SAMSequenceDictionary dict) {
this.dict = dict; this.dict = dict;
} }
@ -111,7 +127,6 @@ public final class GenomeLocParser {
return lastSSR; return lastSSR;
else else
return updateCache(null, index); return updateCache(null, index);
} }
@Requires("contig != null") @Requires("contig != null")
@ -125,12 +140,12 @@ public final class GenomeLocParser {
} }
@Requires({"contig != null", "lastContig != null"}) @Requires({"contig != null", "lastContig != null"})
private final synchronized boolean isCached(final String contig) { private synchronized boolean isCached(final String contig) {
return lastContig.equals(contig); return lastContig.equals(contig);
} }
@Requires({"lastIndex != -1", "index >= 0"}) @Requires({"lastIndex != -1", "index >= 0"})
private final synchronized boolean isCached(final int index) { private synchronized boolean isCached(final int index) {
return lastIndex == index; return lastIndex == index;
} }
@ -144,7 +159,7 @@ public final class GenomeLocParser {
*/ */
@Requires("contig != null || index >= 0") @Requires("contig != null || index >= 0")
@Ensures("result != null") @Ensures("result != null")
private final synchronized SAMSequenceRecord updateCache(final String contig, int index ) { private synchronized SAMSequenceRecord updateCache(final String contig, int index ) {
SAMSequenceRecord rec = contig == null ? dict.getSequence(index) : dict.getSequence(contig); SAMSequenceRecord rec = contig == null ? dict.getSequence(index) : dict.getSequence(contig);
if ( rec == null ) { if ( rec == null ) {
throw new ReviewedStingException("BUG: requested unknown contig=" + contig + " index=" + index); throw new ReviewedStingException("BUG: requested unknown contig=" + contig + " index=" + index);
@ -174,7 +189,7 @@ public final class GenomeLocParser {
throw new UserException.CommandLineException("Failed to load reference dictionary"); throw new UserException.CommandLineException("Failed to load reference dictionary");
} }
contigInfo = new MasterSequenceDictionary(seqDict); SINGLE_MASTER_SEQUENCE_DICTIONARY = seqDict;
logger.debug(String.format("Prepared reference sequence contig dictionary")); logger.debug(String.format("Prepared reference sequence contig dictionary"));
for (SAMSequenceRecord contig : seqDict.getSequences()) { for (SAMSequenceRecord contig : seqDict.getSequences()) {
logger.debug(String.format(" %s (%d bp)", contig.getSequenceName(), contig.getSequenceLength())); logger.debug(String.format(" %s (%d bp)", contig.getSequenceName(), contig.getSequenceLength()));
@ -188,11 +203,11 @@ public final class GenomeLocParser {
* @return True if the contig is valid. False otherwise. * @return True if the contig is valid. False otherwise.
*/ */
public final boolean contigIsInDictionary(String contig) { public final boolean contigIsInDictionary(String contig) {
return contig != null && contigInfo.hasContig(contig); return contig != null && getContigInfo().hasContig(contig);
} }
public final boolean indexIsInDictionary(final int index) { public final boolean indexIsInDictionary(final int index) {
return index >= 0 && contigInfo.hasContig(index); return index >= 0 && getContigInfo().hasContig(index);
} }
@ -208,7 +223,7 @@ public final class GenomeLocParser {
public final SAMSequenceRecord getContigInfo(final String contig) { public final SAMSequenceRecord getContigInfo(final String contig) {
if ( contig == null || ! contigIsInDictionary(contig) ) if ( contig == null || ! contigIsInDictionary(contig) )
throw new UserException.MalformedGenomeLoc(String.format("Contig %s given as location, but this contig isn't present in the Fasta sequence dictionary", contig)); throw new UserException.MalformedGenomeLoc(String.format("Contig %s given as location, but this contig isn't present in the Fasta sequence dictionary", contig));
return contigInfo.getSequence(contig); return getContigInfo().getSequence(contig);
} }
/** /**
@ -226,9 +241,9 @@ public final class GenomeLocParser {
@Requires("contig != null") @Requires("contig != null")
protected int getContigIndexWithoutException(final String contig) { protected int getContigIndexWithoutException(final String contig) {
if ( contig == null || ! contigInfo.hasContig(contig) ) if ( contig == null || ! getContigInfo().hasContig(contig) )
return -1; return -1;
return contigInfo.getSequenceIndex(contig); return getContigInfo().getSequenceIndex(contig);
} }
/** /**
@ -236,7 +251,7 @@ public final class GenomeLocParser {
* @return * @return
*/ */
public final SAMSequenceDictionary getContigs() { public final SAMSequenceDictionary getContigs() {
return contigInfo.dict; return getContigInfo().dict;
} }
// -------------------------------------------------------------------------------------------------------------- // --------------------------------------------------------------------------------------------------------------
@ -291,7 +306,7 @@ public final class GenomeLocParser {
* @return true if it's valid, false otherwise. If exceptOnError, then throws a UserException if invalid * @return true if it's valid, false otherwise. If exceptOnError, then throws a UserException if invalid
*/ */
private boolean validateGenomeLoc(String contig, int contigIndex, int start, int stop, boolean mustBeOnReference, boolean exceptOnError) { private boolean validateGenomeLoc(String contig, int contigIndex, int start, int stop, boolean mustBeOnReference, boolean exceptOnError) {
if ( ! contigInfo.hasContig(contig) ) if ( ! getContigInfo().hasContig(contig) )
return vglHelper(exceptOnError, String.format("Unknown contig %s", contig)); return vglHelper(exceptOnError, String.format("Unknown contig %s", contig));
if (stop < start) if (stop < start)
@ -300,8 +315,8 @@ public final class GenomeLocParser {
if (contigIndex < 0) if (contigIndex < 0)
return vglHelper(exceptOnError, String.format("The contig index %d is less than 0", contigIndex)); return vglHelper(exceptOnError, String.format("The contig index %d is less than 0", contigIndex));
if (contigIndex >= contigInfo.getNSequences()) if (contigIndex >= getContigInfo().getNSequences())
return vglHelper(exceptOnError, String.format("The contig index %d is greater than the stored sequence count (%d)", contigIndex, contigInfo.getNSequences())); return vglHelper(exceptOnError, String.format("The contig index %d is greater than the stored sequence count (%d)", contigIndex, getContigInfo().getNSequences()));
if ( mustBeOnReference ) { if ( mustBeOnReference ) {
if (start < 1) if (start < 1)
@ -310,7 +325,7 @@ public final class GenomeLocParser {
if (stop < 1) if (stop < 1)
return vglHelper(exceptOnError, String.format("The stop position %d is less than 1", stop)); return vglHelper(exceptOnError, String.format("The stop position %d is less than 1", stop));
int contigSize = contigInfo.getSequence(contigIndex).getSequenceLength(); int contigSize = getContigInfo().getSequence(contigIndex).getSequenceLength();
if (start > contigSize || stop > contigSize) if (start > contigSize || stop > contigSize)
return vglHelper(exceptOnError, String.format("The genome loc coordinates %d-%d exceed the contig size (%d)", start, stop, contigSize)); return vglHelper(exceptOnError, String.format("The genome loc coordinates %d-%d exceed the contig size (%d)", start, stop, contigSize));
} }
@ -558,7 +573,7 @@ public final class GenomeLocParser {
@Requires("contigName != null") @Requires("contigName != null")
@Ensures("result != null") @Ensures("result != null")
public GenomeLoc createOverEntireContig(String contigName) { public GenomeLoc createOverEntireContig(String contigName) {
SAMSequenceRecord contig = contigInfo.getSequence(contigName); SAMSequenceRecord contig = getContigInfo().getSequence(contigName);
return createGenomeLoc(contigName,contig.getSequenceIndex(),1,contig.getSequenceLength(), true); return createGenomeLoc(contigName,contig.getSequenceIndex(),1,contig.getSequenceLength(), true);
} }
@ -573,7 +588,7 @@ public final class GenomeLocParser {
if (GenomeLoc.isUnmapped(loc)) if (GenomeLoc.isUnmapped(loc))
return null; return null;
String contigName = loc.getContig(); String contigName = loc.getContig();
SAMSequenceRecord contig = contigInfo.getSequence(contigName); SAMSequenceRecord contig = getContigInfo().getSequence(contigName);
int contigIndex = contig.getSequenceIndex(); int contigIndex = contig.getSequenceIndex();
int start = loc.getStart() - maxBasePairs; int start = loc.getStart() - maxBasePairs;
@ -598,7 +613,7 @@ public final class GenomeLocParser {
if (GenomeLoc.isUnmapped(loc)) if (GenomeLoc.isUnmapped(loc))
return loc; return loc;
final String contigName = loc.getContig(); final String contigName = loc.getContig();
final SAMSequenceRecord contig = contigInfo.getSequence(contigName); final SAMSequenceRecord contig = getContigInfo().getSequence(contigName);
final int contigIndex = contig.getSequenceIndex(); final int contigIndex = contig.getSequenceIndex();
final int contigLength = contig.getSequenceLength(); final int contigLength = contig.getSequenceLength();
@ -619,7 +634,7 @@ public final class GenomeLocParser {
if (GenomeLoc.isUnmapped(loc)) if (GenomeLoc.isUnmapped(loc))
return null; return null;
String contigName = loc.getContig(); String contigName = loc.getContig();
SAMSequenceRecord contig = contigInfo.getSequence(contigName); SAMSequenceRecord contig = getContigInfo().getSequence(contigName);
int contigIndex = contig.getSequenceIndex(); int contigIndex = contig.getSequenceIndex();
int contigLength = contig.getSequenceLength(); int contigLength = contig.getSequenceLength();