Removing parallelism bottleneck in the GATK

-- GenomeLocParser cache was a major performance bottleneck in parallel GATK performance.  With 10 thread > 50% of each thread's time was spent blocking on the MasterSequencingDictionary object.  Made this a thread local variable.
-- Now we can run the GATK with 48 threads efficiently on GSA4!
  -- Running -nt 1 => 75 minutes (didn't let is run all of the way through so likely would take longer)
  -- Running -nt 24 => 3.81 minutes
This commit is contained in:
Mark DePristo 2012-08-13 15:59:35 -04:00
parent cbf290ada0
commit f277d7c09e
1 changed files with 48 additions and 33 deletions

View File

@ -43,9 +43,6 @@ import org.broadinstitute.sting.utils.variantcontext.VariantContext;
/**
* Factory class for creating GenomeLocs
*/
@Invariant({
"logger != null",
"contigInfo != null"})
public final class GenomeLocParser {
private static Logger logger = Logger.getLogger(GenomeLocParser.class);
@ -54,20 +51,39 @@ public final class GenomeLocParser {
// Ugly global variable defining the optional ordering of contig elements
//
// --------------------------------------------------------------------------------------------------------------
private final MasterSequenceDictionary contigInfo;
/**
* This single variable holds the underlying SamSequenceDictionary used by the GATK. We assume
* it is thread safe.
*/
final private SAMSequenceDictionary SINGLE_MASTER_SEQUENCE_DICTIONARY;
/**
* A thread-local caching contig info
*/
private final ThreadLocal<CachingSequenceDictionary> contigInfoPerThread =
new ThreadLocal<CachingSequenceDictionary>();
/**
* @return a caching sequence dictionary appropriate for this thread
*/
private CachingSequenceDictionary getContigInfo() {
if ( contigInfoPerThread.get() == null ) {
// initialize for this thread
logger.debug("Creating thread-local caching sequence dictionary for thread " + Thread.currentThread().getName());
contigInfoPerThread.set(new CachingSequenceDictionary(SINGLE_MASTER_SEQUENCE_DICTIONARY));
}
assert contigInfoPerThread.get() != null;
return contigInfoPerThread.get();
}
/**
* A wrapper class that provides efficient last used caching for the global
* SAMSequenceDictionary underlying all of the GATK engine capabilities
* SAMSequenceDictionary underlying all of the GATK engine capabilities.
*/
// todo -- enable when CoFoJa developers identify the problem (likely thread unsafe invariants)
// @Invariant({
// "dict != null",
// "dict.size() > 0",
// "lastSSR == null || dict.getSequence(lastContig).getSequenceIndex() == lastIndex",
// "lastSSR == null || dict.getSequence(lastContig).getSequenceName() == lastContig",
// "lastSSR == null || dict.getSequence(lastContig) == lastSSR"})
private final class MasterSequenceDictionary {
private final class CachingSequenceDictionary {
final private SAMSequenceDictionary dict;
// cache
@ -76,7 +92,7 @@ public final class GenomeLocParser {
int lastIndex = -1;
@Requires({"dict != null", "dict.size() > 0"})
public MasterSequenceDictionary(SAMSequenceDictionary dict) {
public CachingSequenceDictionary(SAMSequenceDictionary dict) {
this.dict = dict;
}
@ -111,7 +127,6 @@ public final class GenomeLocParser {
return lastSSR;
else
return updateCache(null, index);
}
@Requires("contig != null")
@ -125,12 +140,12 @@ public final class GenomeLocParser {
}
@Requires({"contig != null", "lastContig != null"})
private final synchronized boolean isCached(final String contig) {
private synchronized boolean isCached(final String contig) {
return lastContig.equals(contig);
}
@Requires({"lastIndex != -1", "index >= 0"})
private final synchronized boolean isCached(final int index) {
private synchronized boolean isCached(final int index) {
return lastIndex == index;
}
@ -144,7 +159,7 @@ public final class GenomeLocParser {
*/
@Requires("contig != null || index >= 0")
@Ensures("result != null")
private final synchronized SAMSequenceRecord updateCache(final String contig, int index ) {
private synchronized SAMSequenceRecord updateCache(final String contig, int index ) {
SAMSequenceRecord rec = contig == null ? dict.getSequence(index) : dict.getSequence(contig);
if ( rec == null ) {
throw new ReviewedStingException("BUG: requested unknown contig=" + contig + " index=" + index);
@ -174,7 +189,7 @@ public final class GenomeLocParser {
throw new UserException.CommandLineException("Failed to load reference dictionary");
}
contigInfo = new MasterSequenceDictionary(seqDict);
SINGLE_MASTER_SEQUENCE_DICTIONARY = seqDict;
logger.debug(String.format("Prepared reference sequence contig dictionary"));
for (SAMSequenceRecord contig : seqDict.getSequences()) {
logger.debug(String.format(" %s (%d bp)", contig.getSequenceName(), contig.getSequenceLength()));
@ -188,11 +203,11 @@ public final class GenomeLocParser {
* @return True if the contig is valid. False otherwise.
*/
public final boolean contigIsInDictionary(String contig) {
return contig != null && contigInfo.hasContig(contig);
return contig != null && getContigInfo().hasContig(contig);
}
public final boolean indexIsInDictionary(final int index) {
return index >= 0 && contigInfo.hasContig(index);
return index >= 0 && getContigInfo().hasContig(index);
}
@ -208,7 +223,7 @@ public final class GenomeLocParser {
public final SAMSequenceRecord getContigInfo(final String contig) {
if ( contig == null || ! contigIsInDictionary(contig) )
throw new UserException.MalformedGenomeLoc(String.format("Contig %s given as location, but this contig isn't present in the Fasta sequence dictionary", contig));
return contigInfo.getSequence(contig);
return getContigInfo().getSequence(contig);
}
/**
@ -226,9 +241,9 @@ public final class GenomeLocParser {
@Requires("contig != null")
protected int getContigIndexWithoutException(final String contig) {
if ( contig == null || ! contigInfo.hasContig(contig) )
if ( contig == null || ! getContigInfo().hasContig(contig) )
return -1;
return contigInfo.getSequenceIndex(contig);
return getContigInfo().getSequenceIndex(contig);
}
/**
@ -236,7 +251,7 @@ public final class GenomeLocParser {
* @return
*/
public final SAMSequenceDictionary getContigs() {
return contigInfo.dict;
return getContigInfo().dict;
}
// --------------------------------------------------------------------------------------------------------------
@ -291,7 +306,7 @@ public final class GenomeLocParser {
* @return true if it's valid, false otherwise. If exceptOnError, then throws a UserException if invalid
*/
private boolean validateGenomeLoc(String contig, int contigIndex, int start, int stop, boolean mustBeOnReference, boolean exceptOnError) {
if ( ! contigInfo.hasContig(contig) )
if ( ! getContigInfo().hasContig(contig) )
return vglHelper(exceptOnError, String.format("Unknown contig %s", contig));
if (stop < start)
@ -300,8 +315,8 @@ public final class GenomeLocParser {
if (contigIndex < 0)
return vglHelper(exceptOnError, String.format("The contig index %d is less than 0", contigIndex));
if (contigIndex >= contigInfo.getNSequences())
return vglHelper(exceptOnError, String.format("The contig index %d is greater than the stored sequence count (%d)", contigIndex, contigInfo.getNSequences()));
if (contigIndex >= getContigInfo().getNSequences())
return vglHelper(exceptOnError, String.format("The contig index %d is greater than the stored sequence count (%d)", contigIndex, getContigInfo().getNSequences()));
if ( mustBeOnReference ) {
if (start < 1)
@ -310,7 +325,7 @@ public final class GenomeLocParser {
if (stop < 1)
return vglHelper(exceptOnError, String.format("The stop position %d is less than 1", stop));
int contigSize = contigInfo.getSequence(contigIndex).getSequenceLength();
int contigSize = getContigInfo().getSequence(contigIndex).getSequenceLength();
if (start > contigSize || stop > contigSize)
return vglHelper(exceptOnError, String.format("The genome loc coordinates %d-%d exceed the contig size (%d)", start, stop, contigSize));
}
@ -558,7 +573,7 @@ public final class GenomeLocParser {
@Requires("contigName != null")
@Ensures("result != null")
public GenomeLoc createOverEntireContig(String contigName) {
SAMSequenceRecord contig = contigInfo.getSequence(contigName);
SAMSequenceRecord contig = getContigInfo().getSequence(contigName);
return createGenomeLoc(contigName,contig.getSequenceIndex(),1,contig.getSequenceLength(), true);
}
@ -573,7 +588,7 @@ public final class GenomeLocParser {
if (GenomeLoc.isUnmapped(loc))
return null;
String contigName = loc.getContig();
SAMSequenceRecord contig = contigInfo.getSequence(contigName);
SAMSequenceRecord contig = getContigInfo().getSequence(contigName);
int contigIndex = contig.getSequenceIndex();
int start = loc.getStart() - maxBasePairs;
@ -598,7 +613,7 @@ public final class GenomeLocParser {
if (GenomeLoc.isUnmapped(loc))
return loc;
final String contigName = loc.getContig();
final SAMSequenceRecord contig = contigInfo.getSequence(contigName);
final SAMSequenceRecord contig = getContigInfo().getSequence(contigName);
final int contigIndex = contig.getSequenceIndex();
final int contigLength = contig.getSequenceLength();
@ -619,7 +634,7 @@ public final class GenomeLocParser {
if (GenomeLoc.isUnmapped(loc))
return null;
String contigName = loc.getContig();
SAMSequenceRecord contig = contigInfo.getSequence(contigName);
SAMSequenceRecord contig = getContigInfo().getSequence(contigName);
int contigIndex = contig.getSequenceIndex();
int contigLength = contig.getSequenceLength();