diff --git a/public/java/src/org/broadinstitute/sting/gatk/executive/LinearMicroScheduler.java b/public/java/src/org/broadinstitute/sting/gatk/executive/LinearMicroScheduler.java index 60f7317ba..09b18bfe1 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/executive/LinearMicroScheduler.java +++ b/public/java/src/org/broadinstitute/sting/gatk/executive/LinearMicroScheduler.java @@ -11,6 +11,7 @@ import org.broadinstitute.sting.gatk.datasources.rmd.ReferenceOrderedDataSource; import org.broadinstitute.sting.gatk.io.DirectOutputTracker; import org.broadinstitute.sting.gatk.io.OutputTracker; import org.broadinstitute.sting.gatk.resourcemanagement.ThreadAllocation; +import org.broadinstitute.sting.gatk.traversals.TraversalEngine; import org.broadinstitute.sting.gatk.traversals.TraverseActiveRegions; import org.broadinstitute.sting.gatk.walkers.Walker; import org.broadinstitute.sting.utils.SampleUtils; @@ -60,6 +61,7 @@ public class LinearMicroScheduler extends MicroScheduler { boolean done = walker.isDone(); int counter = 0; + final TraversalEngine traversalEngine = borrowTraversalEngine(); for (Shard shard : shardStrategy ) { if ( done || shard == null ) // we ran out of shards that aren't owned break; @@ -69,7 +71,7 @@ public class LinearMicroScheduler extends MicroScheduler { getReadIterator(shard), shard.getGenomeLocs(), SampleUtils.getSAMFileSamples(engine)); for(WindowMaker.WindowMakerIterator iterator: windowMaker) { ShardDataProvider dataProvider = new LocusShardDataProvider(shard,iterator.getSourceInfo(),engine.getGenomeLocParser(),iterator.getLocus(),iterator,reference,rods); - Object result = getTraversalEngine().traverse(walker, dataProvider, accumulator.getReduceInit()); + Object result = traversalEngine.traverse(walker, dataProvider, accumulator.getReduceInit()); accumulator.accumulate(dataProvider,result); dataProvider.close(); if ( walker.isDone() ) break; @@ -78,7 +80,7 @@ public class LinearMicroScheduler extends MicroScheduler { } else { ShardDataProvider dataProvider = new ReadShardDataProvider(shard,engine.getGenomeLocParser(),getReadIterator(shard),reference,rods); - Object result = getTraversalEngine().traverse(walker, dataProvider, accumulator.getReduceInit()); + Object result = traversalEngine.traverse(walker, dataProvider, accumulator.getReduceInit()); accumulator.accumulate(dataProvider,result); dataProvider.close(); } @@ -87,14 +89,15 @@ public class LinearMicroScheduler extends MicroScheduler { } // Special function call to empty out the work queue. Ugly for now but will be cleaned up when we eventually push this functionality more into the engine - if( getTraversalEngine() instanceof TraverseActiveRegions ) { - final Object result = ((TraverseActiveRegions) getTraversalEngine()).endTraversal(walker, accumulator.getReduceInit()); + if( traversalEngine instanceof TraverseActiveRegions ) { + final Object result = ((TraverseActiveRegions) traversalEngine).endTraversal(walker, accumulator.getReduceInit()); accumulator.accumulate(null, result); // Assumes only used with StandardAccumulator } Object result = accumulator.finishTraversal(); outputTracker.close(); + returnTraversalEngine(traversalEngine); cleanup(); executionIsDone(); diff --git a/public/java/src/org/broadinstitute/sting/gatk/executive/MicroScheduler.java b/public/java/src/org/broadinstitute/sting/gatk/executive/MicroScheduler.java index 893548a9b..030f8d0f2 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/executive/MicroScheduler.java +++ b/public/java/src/org/broadinstitute/sting/gatk/executive/MicroScheduler.java @@ -63,14 +63,36 @@ import java.util.Map; * Time: 12:37:23 PM * * General base class for all scheduling algorithms + * Shards and schedules data in manageable chunks. + * + * Creates N TraversalEngines for each data thread for the MicroScheduler. This is necessary + * because in the HMS case you have multiple threads executing a traversal engine independently, and + * these engines may need to create separate resources for efficiency or implementation reasons. For example, + * the nanoScheduler creates threads to implement the traversal, and this creation is instance specific. + * So each HMS thread needs to have it's own distinct copy of the traversal engine if it wants to have + * N data threads x M nano threads => N * M threads total. These are borrowed from this microscheduler + * and returned when done. Also allows us to tracks all created traversal engines so this microscheduler + * can properly shut them all down when the scheduling is done. + * */ - -/** Shards and schedules data in manageable chunks. */ public abstract class MicroScheduler implements MicroSchedulerMBean { // TODO -- remove me and retire non nano scheduled versions of traversals private final static boolean USE_NANOSCHEDULER_FOR_EVERYTHING = true; protected static final Logger logger = Logger.getLogger(MicroScheduler.class); + /** + * The list of all Traversal engines we've created in this micro scheduler + */ + final List allCreatedTraversalEngines = new LinkedList(); + + /** + * All available engines. Engines are borrowed and returned when a subclass is actually + * going to execute the engine on some data. This allows us to have N copies for + * N data parallel executions, but without the dangerous code of having local + * ThreadLocal variables. + */ + final LinkedList availableTraversalEngines = new LinkedList(); + /** * Counts the number of instances of the class that are currently alive. */ @@ -81,7 +103,6 @@ public abstract class MicroScheduler implements MicroSchedulerMBean { */ protected final GenomeAnalysisEngine engine; - private TraversalEngineCreator traversalEngineCreator; protected final IndexedFastaSequenceFile reference; private final SAMDataSource reads; @@ -158,13 +179,27 @@ public abstract class MicroScheduler implements MicroSchedulerMBean { this.reads = reads; this.reference = reference; this.rods = rods; - this.traversalEngineCreator = new TraversalEngineCreator(walker, threadAllocation); final File progressLogFile = engine.getArguments() == null ? null : engine.getArguments().performanceLog; + + // Creates uninitialized TraversalEngines appropriate for walker and threadAllocation, + // and adds it to the list of created engines for later shutdown. + for ( int i = 0; i < threadAllocation.getNumDataThreads(); i++ ) { + final TraversalEngine traversalEngine = createTraversalEngine(walker, threadAllocation); + allCreatedTraversalEngines.add(traversalEngine); + availableTraversalEngines.add(traversalEngine); + } + logger.info("Creating " + threadAllocation.getNumDataThreads() + " traversal engines"); + + // Create our progress meter this.progressMeter = new ProgressMeter(progressLogFile, - traversalEngineCreator.getTraversalUnits(), + availableTraversalEngines.peek().getTraversalUnits(), engine.getRegionsOfGenomeBeingProcessed()); + // Now that we have a progress meter, go through and initialize the traversal engines + for ( final TraversalEngine traversalEngine : allCreatedTraversalEngines ) + traversalEngine.initialize(engine, progressMeter); + // JMX does not allow multiple instances with the same ObjectName to be registered with the same platform MXBean. // To get around this limitation and since we have no job identifier at this point, register a simple counter that // will count the number of instances of this object that have been created in this JVM. @@ -179,6 +214,35 @@ public abstract class MicroScheduler implements MicroSchedulerMBean { } } + /** + * Really make us a traversal engine of the appropriate type for walker and thread allocation + * + * @return a non-null uninitialized traversal engine + */ + @Ensures("result != null") + private TraversalEngine createTraversalEngine(final Walker walker, final ThreadAllocation threadAllocation) { + if (walker instanceof ReadWalker) { + if ( USE_NANOSCHEDULER_FOR_EVERYTHING || threadAllocation.getNumCPUThreadsPerDataThread() > 1 ) + return new TraverseReadsNano(threadAllocation.getNumCPUThreadsPerDataThread()); + else + return new TraverseReads(); + } else if (walker instanceof LocusWalker) { + if ( USE_NANOSCHEDULER_FOR_EVERYTHING || threadAllocation.getNumCPUThreadsPerDataThread() > 1 ) + return new TraverseLociNano(threadAllocation.getNumCPUThreadsPerDataThread()); + else + return new TraverseLociLinear(); + } else if (walker instanceof DuplicateWalker) { + return new TraverseDuplicates(); + } else if (walker instanceof ReadPairWalker) { + return new TraverseReadPairs(); + } else if (walker instanceof ActiveRegionWalker) { + return new TraverseActiveRegions(); + } else { + throw new UnsupportedOperationException("Unable to determine traversal type, the walker is an unknown type."); + } + } + + /** * Return the ThreadEfficiencyMonitor we are using to track our resource utilization, if there is one * @@ -228,9 +292,7 @@ public abstract class MicroScheduler implements MicroSchedulerMBean { protected void executionIsDone() { progressMeter.notifyDone(engine.getCumulativeMetrics().getNumIterations()); printReadFilteringStats(); - - traversalEngineCreator.shutdown(); - traversalEngineCreator = null; + shutdownTraversalEngines(); // Print out the threading efficiency of this HMS, if state monitoring is enabled if ( threadEfficiencyMonitor != null ) { @@ -240,6 +302,23 @@ public abstract class MicroScheduler implements MicroSchedulerMBean { } } + /** + * Shutdown all of the created engines, and clear the list of created engines, dropping + * pointers to the traversal engines + */ + public synchronized void shutdownTraversalEngines() { + if ( availableTraversalEngines.size() != allCreatedTraversalEngines.size() ) + throw new IllegalStateException("Shutting down TraversalEngineCreator but not all engines " + + "have been returned. Expected " + allCreatedTraversalEngines.size() + " but only " + availableTraversalEngines.size() + + " have been returned"); + + for ( final TraversalEngine te : allCreatedTraversalEngines) + te.shutdown(); + + allCreatedTraversalEngines.clear(); + availableTraversalEngines.clear(); + } + /** * Prints out information about number of reads observed and filtering, if any reads were used in the traversal * @@ -301,111 +380,34 @@ public abstract class MicroScheduler implements MicroSchedulerMBean { /** * Returns a traversal engine suitable for use in this thread. * - * May create a new traversal engine for this thread, if this is the first - * time this thread ever asked for a TraversalEngine. + * Pops the next available engine from the available ones maintained by this + * microscheduler. Note that it's a runtime error to pop a traversal engine + * from this scheduler if there are none available. Callers that + * once pop'd an engine for use must return it with returnTraversalEngine * * @return a non-null TraversalEngine suitable for execution in this scheduler */ - public TraversalEngine getTraversalEngine() { - return traversalEngineCreator.get(); + @Ensures("result != null") + protected synchronized TraversalEngine borrowTraversalEngine() { + if ( availableTraversalEngines.isEmpty() ) + throw new IllegalStateException("no traversal engines were available"); + else { + return availableTraversalEngines.pop(); + } } /** - * ThreadLocal TraversalEngine creator + * Return a borrowed traversal engine to this MicroScheduler, for later use + * in another traversal execution * - * TraversalEngines are thread local variables to the MicroScheduler. This is necessary - * because in the HMS case you have multiple threads executing a traversal engine independently, and - * these engines may need to create separate resources for efficiency or implementation reasons. For example, - * the nanoScheduler creates threads to implement the traversal, and this creation is instance specific. - * So each HMS thread needs to have it's own distinct copy of the traversal engine if it wants to have - * N data threads x M nano threads => N * M threads total. - * - * This class also tracks all created traversal engines so this microscheduler can properly - * shut them all down when the scheduling is done. + * @param traversalEngine the borrowed traversal engine. Must have been previously borrowed. */ - private class TraversalEngineCreator extends ThreadLocal { - final List createdEngines = new LinkedList(); - final Walker walker; - final ThreadAllocation threadAllocation; + protected synchronized void returnTraversalEngine(final TraversalEngine traversalEngine) { + if ( traversalEngine == null ) + throw new IllegalArgumentException("Attempting to push a null traversal engine"); + if ( ! allCreatedTraversalEngines.contains(traversalEngine) ) + throw new IllegalArgumentException("Attempting to push a traversal engine not created by this MicroScheduler" + engine); - /** - * Creates an initialized TraversalEngine appropriate for walker and threadAllocation, - * and adds it to the list of created engines for later shutdown. - * - * @return a non-null traversal engine - */ - @Override - protected synchronized TraversalEngine initialValue() { - final TraversalEngine traversalEngine = createEngine(); - traversalEngine.initialize(engine, progressMeter); - createdEngines.add(traversalEngine); - return traversalEngine; - } - - /** - * Returns the traversal units for traversal engines created here. - * - * This (unfortunately) creates an uninitialized tmp. TraversalEngine so we can get - * it's traversal units, and then immediately shuts it down... - * - * @return the traversal unit as returned by getTraversalUnits of TraversalEngines created here - */ - protected String getTraversalUnits() { - final TraversalEngine tmp = createEngine(); - final String units = tmp.getTraversalUnits(); - tmp.shutdown(); - return units; - } - - /** - * Really make us a traversal engine of the appropriate type for walker and thread allocation - * - * @return a non-null uninitialized traversal engine - */ - @Ensures("result != null") - protected TraversalEngine createEngine() { - if (walker instanceof ReadWalker) { - if ( USE_NANOSCHEDULER_FOR_EVERYTHING || threadAllocation.getNumCPUThreadsPerDataThread() > 1 ) - return new TraverseReadsNano(threadAllocation.getNumCPUThreadsPerDataThread()); - else - return new TraverseReads(); - } else if (walker instanceof LocusWalker) { - if ( USE_NANOSCHEDULER_FOR_EVERYTHING || threadAllocation.getNumCPUThreadsPerDataThread() > 1 ) - return new TraverseLociNano(threadAllocation.getNumCPUThreadsPerDataThread()); - else - return new TraverseLociLinear(); - } else if (walker instanceof DuplicateWalker) { - return new TraverseDuplicates(); - } else if (walker instanceof ReadPairWalker) { - return new TraverseReadPairs(); - } else if (walker instanceof ActiveRegionWalker) { - return new TraverseActiveRegions(); - } else { - throw new UnsupportedOperationException("Unable to determine traversal type, the walker is an unknown type."); - } - } - - /** - * Create a TraversalEngineCreator that makes TraversalEngines appropriate for walker and threadAllocation - * - * @param walker the walker we need traversal engines for - * @param threadAllocation what kind of threading will we use in the traversal? - */ - @com.google.java.contract.Requires({"walker != null", "threadAllocation != null"}) - public TraversalEngineCreator(final Walker walker, final ThreadAllocation threadAllocation) { - super(); - this.walker = walker; - this.threadAllocation = threadAllocation; - } - - /** - * Shutdown all of the created engines, and clear the list of created engines, dropping - * pointers to the traversal engines - */ - public synchronized void shutdown() { - for ( final TraversalEngine te : traversalEngineCreator.createdEngines ) - te.shutdown(); - createdEngines.clear(); - } + availableTraversalEngines.push(traversalEngine); } } diff --git a/public/java/src/org/broadinstitute/sting/gatk/executive/ShardTraverser.java b/public/java/src/org/broadinstitute/sting/gatk/executive/ShardTraverser.java index e8f15ebef..d632892d5 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/executive/ShardTraverser.java +++ b/public/java/src/org/broadinstitute/sting/gatk/executive/ShardTraverser.java @@ -5,6 +5,7 @@ import org.broadinstitute.sting.gatk.datasources.providers.LocusShardDataProvide import org.broadinstitute.sting.gatk.datasources.providers.ShardDataProvider; import org.broadinstitute.sting.gatk.datasources.reads.Shard; import org.broadinstitute.sting.gatk.io.ThreadLocalOutputTracker; +import org.broadinstitute.sting.gatk.traversals.TraversalEngine; import org.broadinstitute.sting.gatk.walkers.Walker; import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; @@ -50,6 +51,7 @@ public class ShardTraverser implements Callable { } public Object call() { + final TraversalEngine traversalEngine = microScheduler.borrowTraversalEngine(); try { final long startTime = System.currentTimeMillis(); @@ -61,7 +63,7 @@ public class ShardTraverser implements Callable { for(WindowMaker.WindowMakerIterator iterator: windowMaker) { final ShardDataProvider dataProvider = new LocusShardDataProvider(shard,iterator.getSourceInfo(),microScheduler.getEngine().getGenomeLocParser(),iterator.getLocus(),iterator,microScheduler.reference,microScheduler.rods); - accumulator = microScheduler.getTraversalEngine().traverse(walker, dataProvider, accumulator); + accumulator = traversalEngine.traverse(walker, dataProvider, accumulator); dataProvider.close(); } @@ -79,6 +81,7 @@ public class ShardTraverser implements Callable { } finally { synchronized(this) { complete = true; + microScheduler.returnTraversalEngine(traversalEngine); notifyAll(); } }