diff --git a/public/java/src/org/broadinstitute/sting/gatk/executive/HierarchicalMicroScheduler.java b/public/java/src/org/broadinstitute/sting/gatk/executive/HierarchicalMicroScheduler.java index 486e83e60..1bac72f3e 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/executive/HierarchicalMicroScheduler.java +++ b/public/java/src/org/broadinstitute/sting/gatk/executive/HierarchicalMicroScheduler.java @@ -107,7 +107,7 @@ public class HierarchicalMicroScheduler extends MicroScheduler implements Hierar this.traversalTasks = shardStrategy.iterator(); - ReduceTree reduceTree = new ReduceTree(this); + final ReduceTree reduceTree = new ReduceTree(this); initializeWalker(walker); while (isShardTraversePending() || isTreeReducePending()) { @@ -301,17 +301,13 @@ public class HierarchicalMicroScheduler extends MicroScheduler implements Hierar if (!traversalTasks.hasNext()) throw new IllegalStateException("Cannot traverse; no pending traversals exist."); - Shard shard = traversalTasks.next(); + final Shard shard = traversalTasks.next(); // todo -- add ownership claim here - ShardTraverser traverser = new ShardTraverser(this, - traversalEngine, - walker, - shard, - outputTracker); + final ShardTraverser traverser = new ShardTraverser(this, walker, shard, outputTracker); - Future traverseResult = threadPool.submit(traverser); + final Future traverseResult = threadPool.submit(traverser); // Add this traverse result to the reduce tree. The reduce tree will call a callback to throw its entries on the queue. 
reduceTree.addEntry(traverseResult); @@ -326,7 +322,7 @@ public class HierarchicalMicroScheduler extends MicroScheduler implements Hierar protected void queueNextTreeReduce( Walker walker ) { if (reduceTasks.size() == 0) throw new IllegalStateException("Cannot reduce; no pending reduces exist."); - TreeReduceTask reducer = reduceTasks.remove(); + final TreeReduceTask reducer = reduceTasks.remove(); reducer.setWalker((TreeReducible) walker); threadPool.submit(reducer); @@ -334,7 +330,7 @@ public class HierarchicalMicroScheduler extends MicroScheduler implements Hierar /** Blocks until a free slot appears in the thread queue. */ protected void waitForFreeQueueSlot() { - ThreadPoolMonitor monitor = new ThreadPoolMonitor(); + final ThreadPoolMonitor monitor = new ThreadPoolMonitor(); synchronized (monitor) { threadPool.submit(monitor); monitor.watch(); @@ -346,8 +342,8 @@ public class HierarchicalMicroScheduler extends MicroScheduler implements Hierar * * @return A new, composite future of the result of this reduce. */ - public Future notifyReduce( Future lhs, Future rhs ) { - TreeReduceTask reducer = new TreeReduceTask(new TreeReducer(this, lhs, rhs)); + public Future notifyReduce( final Future lhs, final Future rhs ) { + final TreeReduceTask reducer = new TreeReduceTask(new TreeReducer(this, lhs, rhs)); reduceTasks.add(reducer); return reducer; } @@ -375,7 +371,7 @@ public class HierarchicalMicroScheduler extends MicroScheduler implements Hierar return this.error; } - private final RuntimeException toRuntimeException(final Throwable error) { + private RuntimeException toRuntimeException(final Throwable error) { // If the error is already a Runtime, pass it along as is. Otherwise, wrap it. if (error instanceof RuntimeException) return (RuntimeException)error; @@ -386,7 +382,7 @@ public class HierarchicalMicroScheduler extends MicroScheduler implements Hierar /** A small wrapper class that provides the TreeReducer interface along with the FutureTask semantics. 
*/ private class TreeReduceTask extends FutureTask { - private TreeReducer treeReducer = null; + final private TreeReducer treeReducer; public TreeReduceTask( TreeReducer treeReducer ) { super(treeReducer); diff --git a/public/java/src/org/broadinstitute/sting/gatk/executive/LinearMicroScheduler.java b/public/java/src/org/broadinstitute/sting/gatk/executive/LinearMicroScheduler.java index 697e908fd..60f7317ba 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/executive/LinearMicroScheduler.java +++ b/public/java/src/org/broadinstitute/sting/gatk/executive/LinearMicroScheduler.java @@ -69,7 +69,7 @@ public class LinearMicroScheduler extends MicroScheduler { getReadIterator(shard), shard.getGenomeLocs(), SampleUtils.getSAMFileSamples(engine)); for(WindowMaker.WindowMakerIterator iterator: windowMaker) { ShardDataProvider dataProvider = new LocusShardDataProvider(shard,iterator.getSourceInfo(),engine.getGenomeLocParser(),iterator.getLocus(),iterator,reference,rods); - Object result = traversalEngine.traverse(walker, dataProvider, accumulator.getReduceInit()); + Object result = getTraversalEngine().traverse(walker, dataProvider, accumulator.getReduceInit()); accumulator.accumulate(dataProvider,result); dataProvider.close(); if ( walker.isDone() ) break; @@ -78,7 +78,7 @@ public class LinearMicroScheduler extends MicroScheduler { } else { ShardDataProvider dataProvider = new ReadShardDataProvider(shard,engine.getGenomeLocParser(),getReadIterator(shard),reference,rods); - Object result = traversalEngine.traverse(walker, dataProvider, accumulator.getReduceInit()); + Object result = getTraversalEngine().traverse(walker, dataProvider, accumulator.getReduceInit()); accumulator.accumulate(dataProvider,result); dataProvider.close(); } @@ -87,8 +87,8 @@ public class LinearMicroScheduler extends MicroScheduler { } // Special function call to empty out the work queue. 
Ugly for now but will be cleaned up when we eventually push this functionality more into the engine - if( traversalEngine instanceof TraverseActiveRegions ) { - final Object result = ((TraverseActiveRegions) traversalEngine).endTraversal(walker, accumulator.getReduceInit()); + if( getTraversalEngine() instanceof TraverseActiveRegions ) { + final Object result = ((TraverseActiveRegions) getTraversalEngine()).endTraversal(walker, accumulator.getReduceInit()); accumulator.accumulate(null, result); // Assumes only used with StandardAccumulator } diff --git a/public/java/src/org/broadinstitute/sting/gatk/executive/MicroScheduler.java b/public/java/src/org/broadinstitute/sting/gatk/executive/MicroScheduler.java index 3e843de3e..4024b247d 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/executive/MicroScheduler.java +++ b/public/java/src/org/broadinstitute/sting/gatk/executive/MicroScheduler.java @@ -25,6 +25,7 @@ package org.broadinstitute.sting.gatk.executive; +import com.google.java.contract.Ensures; import net.sf.picard.reference.IndexedFastaSequenceFile; import org.apache.log4j.Logger; import org.broadinstitute.sting.gatk.GenomeAnalysisEngine; @@ -50,6 +51,8 @@ import javax.management.ObjectName; import java.io.File; import java.lang.management.ManagementFactory; import java.util.Collection; +import java.util.LinkedList; +import java.util.List; import java.util.Map; @@ -78,7 +81,7 @@ public abstract class MicroScheduler implements MicroSchedulerMBean { */ protected final GenomeAnalysisEngine engine; - protected final TraversalEngine traversalEngine; + private final TraversalEngineCreator traversalEngineCreator; protected final IndexedFastaSequenceFile reference; private final SAMDataSource reads; @@ -110,11 +113,6 @@ public abstract class MicroScheduler implements MicroSchedulerMBean { */ public static MicroScheduler create(GenomeAnalysisEngine engine, Walker walker, SAMDataSource reads, IndexedFastaSequenceFile reference, Collection rods, ThreadAllocation 
threadAllocation) { if ( threadAllocation.isRunningInParallelMode() ) { - // TODO -- remove me when we fix running NCT within HMS - if ( threadAllocation.getNumDataThreads() > 1 && threadAllocation.getNumCPUThreadsPerDataThread() > 1) - throw new UserException("Currently the GATK does not support running CPU threads within data threads, " + - "please specify only one of NT and NCT"); - logger.info(String.format("Running the GATK in parallel mode with %d CPU thread(s) for each of %d data thread(s)", threadAllocation.getNumCPUThreadsPerDataThread(), threadAllocation.getNumDataThreads())); } @@ -160,30 +158,12 @@ public abstract class MicroScheduler implements MicroSchedulerMBean { this.reads = reads; this.reference = reference; this.rods = rods; - - if (walker instanceof ReadWalker) { - traversalEngine = USE_NANOSCHEDULER_FOR_EVERYTHING || threadAllocation.getNumCPUThreadsPerDataThread() > 1 - ? new TraverseReadsNano(threadAllocation.getNumCPUThreadsPerDataThread()) - : new TraverseReads(); - } else if (walker instanceof LocusWalker) { - traversalEngine = USE_NANOSCHEDULER_FOR_EVERYTHING || threadAllocation.getNumCPUThreadsPerDataThread() > 1 - ? new TraverseLociNano(threadAllocation.getNumCPUThreadsPerDataThread()) - : new TraverseLociLinear(); - } else if (walker instanceof DuplicateWalker) { - traversalEngine = new TraverseDuplicates(); - } else if (walker instanceof ReadPairWalker) { - traversalEngine = new TraverseReadPairs(); - } else if (walker instanceof ActiveRegionWalker) { - traversalEngine = new TraverseActiveRegions(); - } else { - throw new UnsupportedOperationException("Unable to determine traversal type, the walker is an unknown type."); - } + this.traversalEngineCreator = new TraversalEngineCreator(walker, threadAllocation); final File progressLogFile = engine.getArguments() == null ? 
null : engine.getArguments().performanceLog; this.progressMeter = new ProgressMeter(progressLogFile, - traversalEngine.getTraversalUnits(), + traversalEngineCreator.getTraversalUnits(), engine.getRegionsOfGenomeBeingProcessed()); - traversalEngine.initialize(engine, progressMeter); // JMX does not allow multiple instances with the same ObjectName to be registered with the same platform MXBean. // To get around this limitation and since we have no job identifier at this point, register a simple counter that @@ -249,8 +229,8 @@ public abstract class MicroScheduler implements MicroSchedulerMBean { progressMeter.notifyDone(engine.getCumulativeMetrics().getNumIterations()); printReadFilteringStats(); - // TODO -- generalize to all local thread copies - traversalEngine.shutdown(); + for ( final TraversalEngine te : traversalEngineCreator.getCreatedEngines() ) + te.shutdown(); // Print out the threading efficiency of this HMS, if state monitoring is enabled if ( threadEfficiencyMonitor != null ) { @@ -317,4 +297,115 @@ public abstract class MicroScheduler implements MicroSchedulerMBean { throw new ReviewedStingException("Unable to unregister microscheduler with JMX", ex); } } + + /** + * Returns a traversal engine suitable for use in this thread. + * + * May create a new traversal engine for this thread, if this is the first + * time this thread ever asked for a TraversalEngine. + * + * @return a non-null TraversalEngine suitable for execution in this scheduler + */ + public TraversalEngine getTraversalEngine() { + return traversalEngineCreator.get(); + } + + /** + * ThreadLocal TraversalEngine creator + * + * TraversalEngines are thread local variables to the MicroScheduler. This is necessary + * because in the HMS case you have multiple threads executing a traversal engine independently, and + * these engines may need to create separate resources for efficiency or implementation reasons. 
For example, + * the nanoScheduler creates threads to implement the traversal, and this creation is instance specific. + * So each HMS thread needs to have its own distinct copy of the traversal engine if it wants to have + * N data threads x M nano threads => N * M threads total. + * + * This class also tracks all created traversal engines so this microscheduler can properly + * shut them all down when the scheduling is done. + */ + private class TraversalEngineCreator extends ThreadLocal { + final List createdEngines = new LinkedList(); + final Walker walker; + final ThreadAllocation threadAllocation; + + /** + * Creates an initialized TraversalEngine appropriate for walker and threadAllocation, + * and adds it to the list of created engines for later shutdown. + * + * @return a non-null traversal engine + */ + @Override + protected synchronized TraversalEngine initialValue() { + final TraversalEngine traversalEngine = createEngine(); + traversalEngine.initialize(engine, progressMeter); + createdEngines.add(traversalEngine); + return traversalEngine; + } + + /** + * Returns the traversal units for traversal engines created here. + * + * This (unfortunately) creates an uninitialized tmp. TraversalEngine so we can get + * its traversal units, and then immediately shuts it down... 
+ * + * @return the traversal unit as returned by getTraversalUnits of TraversalEngines created here + */ + protected String getTraversalUnits() { + final TraversalEngine tmp = createEngine(); + final String units = tmp.getTraversalUnits(); + tmp.shutdown(); + return units; + } + + /** + * Really make us a traversal engine of the appropriate type for walker and thread allocation + * + * @return a non-null uninitialized traversal engine + */ + @Ensures("result != null") + protected TraversalEngine createEngine() { + if (walker instanceof ReadWalker) { + if ( USE_NANOSCHEDULER_FOR_EVERYTHING || threadAllocation.getNumCPUThreadsPerDataThread() > 1 ) + return new TraverseReadsNano(threadAllocation.getNumCPUThreadsPerDataThread()); + else + return new TraverseReads(); + } else if (walker instanceof LocusWalker) { + if ( USE_NANOSCHEDULER_FOR_EVERYTHING || threadAllocation.getNumCPUThreadsPerDataThread() > 1 ) + return new TraverseLociNano(threadAllocation.getNumCPUThreadsPerDataThread()); + else + return new TraverseLociLinear(); + } else if (walker instanceof DuplicateWalker) { + return new TraverseDuplicates(); + } else if (walker instanceof ReadPairWalker) { + return new TraverseReadPairs(); + } else if (walker instanceof ActiveRegionWalker) { + return new TraverseActiveRegions(); + } else { + throw new UnsupportedOperationException("Unable to determine traversal type, the walker is an unknown type."); + } + } + + /** + * Create a TraversalEngineCreator that makes TraversalEngines appropriate for walker and threadAllocation + * + * @param walker the walker we need traversal engines for + * @param threadAllocation what kind of threading will we use in the traversal? 
+ */ + @com.google.java.contract.Requires({"walker != null", "threadAllocation != null"}) + public TraversalEngineCreator(final Walker walker, final ThreadAllocation threadAllocation) { + super(); + this.walker = walker; + this.threadAllocation = threadAllocation; + } + + /** + * Get the list of all traversal engines we've created + * + * @return a non-null list of engines created so far + */ + @Ensures("result != null") + public List getCreatedEngines() { + return createdEngines; + } + } } diff --git a/public/java/src/org/broadinstitute/sting/gatk/executive/ShardTraverser.java b/public/java/src/org/broadinstitute/sting/gatk/executive/ShardTraverser.java index 790c6b3ed..e8f15ebef 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/executive/ShardTraverser.java +++ b/public/java/src/org/broadinstitute/sting/gatk/executive/ShardTraverser.java @@ -5,7 +5,6 @@ import org.broadinstitute.sting.gatk.datasources.providers.LocusShardDataProvide import org.broadinstitute.sting.gatk.datasources.providers.ShardDataProvider; import org.broadinstitute.sting.gatk.datasources.reads.Shard; import org.broadinstitute.sting.gatk.io.ThreadLocalOutputTracker; -import org.broadinstitute.sting.gatk.traversals.TraversalEngine; import org.broadinstitute.sting.gatk.walkers.Walker; import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; @@ -29,7 +28,6 @@ public class ShardTraverser implements Callable { final private HierarchicalMicroScheduler microScheduler; final private Walker walker; final private Shard shard; - final private TraversalEngine traversalEngine; final private ThreadLocalOutputTracker outputTracker; private OutputMergeTask outputMergeTask; @@ -42,13 +40,11 @@ public class ShardTraverser implements Callable { private boolean complete = false; public ShardTraverser( HierarchicalMicroScheduler microScheduler, - TraversalEngine traversalEngine, Walker walker, Shard shard, ThreadLocalOutputTracker outputTracker) { this.microScheduler = microScheduler; 
this.walker = walker; - this.traversalEngine = traversalEngine; this.shard = shard; this.outputTracker = outputTracker; } @@ -65,7 +61,7 @@ public class ShardTraverser implements Callable { for(WindowMaker.WindowMakerIterator iterator: windowMaker) { final ShardDataProvider dataProvider = new LocusShardDataProvider(shard,iterator.getSourceInfo(),microScheduler.getEngine().getGenomeLocParser(),iterator.getLocus(),iterator,microScheduler.reference,microScheduler.rods); - accumulator = traversalEngine.traverse( walker, dataProvider, accumulator ); + accumulator = microScheduler.getTraversalEngine().traverse(walker, dataProvider, accumulator); dataProvider.close(); }