Fixed GSA-515 Nanoscheduler GSA-555 / Make NT and NCT work together

-- Can now say -nt 4 and -nct 4 to get 16 threads running for you!
-- TraversalEngines are now ThreadLocal variables in the MicroScheduler.
-- Misc. code cleanup, final variables, some contracts.
This commit is contained in:
Mark DePristo 2012-09-09 16:52:52 -04:00
parent 233f70f8ba
commit f713d400e2
4 changed files with 134 additions and 51 deletions

View File

@ -107,7 +107,7 @@ public class HierarchicalMicroScheduler extends MicroScheduler implements Hierar
this.traversalTasks = shardStrategy.iterator();
ReduceTree reduceTree = new ReduceTree(this);
final ReduceTree reduceTree = new ReduceTree(this);
initializeWalker(walker);
while (isShardTraversePending() || isTreeReducePending()) {
@ -301,17 +301,13 @@ public class HierarchicalMicroScheduler extends MicroScheduler implements Hierar
if (!traversalTasks.hasNext())
throw new IllegalStateException("Cannot traverse; no pending traversals exist.");
Shard shard = traversalTasks.next();
final Shard shard = traversalTasks.next();
// todo -- add ownership claim here
ShardTraverser traverser = new ShardTraverser(this,
traversalEngine,
walker,
shard,
outputTracker);
final ShardTraverser traverser = new ShardTraverser(this, walker, shard, outputTracker);
Future traverseResult = threadPool.submit(traverser);
final Future traverseResult = threadPool.submit(traverser);
// Add this traverse result to the reduce tree. The reduce tree will call a callback to throw its entries on the queue.
reduceTree.addEntry(traverseResult);
@ -326,7 +322,7 @@ public class HierarchicalMicroScheduler extends MicroScheduler implements Hierar
protected void queueNextTreeReduce( Walker walker ) {
if (reduceTasks.size() == 0)
throw new IllegalStateException("Cannot reduce; no pending reduces exist.");
TreeReduceTask reducer = reduceTasks.remove();
final TreeReduceTask reducer = reduceTasks.remove();
reducer.setWalker((TreeReducible) walker);
threadPool.submit(reducer);
@ -334,7 +330,7 @@ public class HierarchicalMicroScheduler extends MicroScheduler implements Hierar
/** Blocks until a free slot appears in the thread queue. */
protected void waitForFreeQueueSlot() {
ThreadPoolMonitor monitor = new ThreadPoolMonitor();
final ThreadPoolMonitor monitor = new ThreadPoolMonitor();
synchronized (monitor) {
threadPool.submit(monitor);
monitor.watch();
@ -346,8 +342,8 @@ public class HierarchicalMicroScheduler extends MicroScheduler implements Hierar
*
* @return A new, composite future of the result of this reduce.
*/
public Future notifyReduce( Future lhs, Future rhs ) {
TreeReduceTask reducer = new TreeReduceTask(new TreeReducer(this, lhs, rhs));
public Future notifyReduce( final Future lhs, final Future rhs ) {
final TreeReduceTask reducer = new TreeReduceTask(new TreeReducer(this, lhs, rhs));
reduceTasks.add(reducer);
return reducer;
}
@ -375,7 +371,7 @@ public class HierarchicalMicroScheduler extends MicroScheduler implements Hierar
return this.error;
}
private final RuntimeException toRuntimeException(final Throwable error) {
private RuntimeException toRuntimeException(final Throwable error) {
// If the error is already a Runtime, pass it along as is. Otherwise, wrap it.
if (error instanceof RuntimeException)
return (RuntimeException)error;
@ -386,7 +382,7 @@ public class HierarchicalMicroScheduler extends MicroScheduler implements Hierar
/** A small wrapper class that provides the TreeReducer interface along with the FutureTask semantics. */
private class TreeReduceTask extends FutureTask {
private TreeReducer treeReducer = null;
final private TreeReducer treeReducer;
public TreeReduceTask( TreeReducer treeReducer ) {
super(treeReducer);

View File

@ -69,7 +69,7 @@ public class LinearMicroScheduler extends MicroScheduler {
getReadIterator(shard), shard.getGenomeLocs(), SampleUtils.getSAMFileSamples(engine));
for(WindowMaker.WindowMakerIterator iterator: windowMaker) {
ShardDataProvider dataProvider = new LocusShardDataProvider(shard,iterator.getSourceInfo(),engine.getGenomeLocParser(),iterator.getLocus(),iterator,reference,rods);
Object result = traversalEngine.traverse(walker, dataProvider, accumulator.getReduceInit());
Object result = getTraversalEngine().traverse(walker, dataProvider, accumulator.getReduceInit());
accumulator.accumulate(dataProvider,result);
dataProvider.close();
if ( walker.isDone() ) break;
@ -78,7 +78,7 @@ public class LinearMicroScheduler extends MicroScheduler {
}
else {
ShardDataProvider dataProvider = new ReadShardDataProvider(shard,engine.getGenomeLocParser(),getReadIterator(shard),reference,rods);
Object result = traversalEngine.traverse(walker, dataProvider, accumulator.getReduceInit());
Object result = getTraversalEngine().traverse(walker, dataProvider, accumulator.getReduceInit());
accumulator.accumulate(dataProvider,result);
dataProvider.close();
}
@ -87,8 +87,8 @@ public class LinearMicroScheduler extends MicroScheduler {
}
// Special function call to empty out the work queue. Ugly for now but will be cleaned up when we eventually push this functionality more into the engine
if( traversalEngine instanceof TraverseActiveRegions ) {
final Object result = ((TraverseActiveRegions) traversalEngine).endTraversal(walker, accumulator.getReduceInit());
if( getTraversalEngine() instanceof TraverseActiveRegions ) {
final Object result = ((TraverseActiveRegions) getTraversalEngine()).endTraversal(walker, accumulator.getReduceInit());
accumulator.accumulate(null, result); // Assumes only used with StandardAccumulator
}

View File

@ -25,6 +25,7 @@
package org.broadinstitute.sting.gatk.executive;
import com.google.java.contract.Ensures;
import net.sf.picard.reference.IndexedFastaSequenceFile;
import org.apache.log4j.Logger;
import org.broadinstitute.sting.gatk.GenomeAnalysisEngine;
@ -50,6 +51,8 @@ import javax.management.ObjectName;
import java.io.File;
import java.lang.management.ManagementFactory;
import java.util.Collection;
import java.util.LinkedList;
import java.util.List;
import java.util.Map;
@ -78,7 +81,7 @@ public abstract class MicroScheduler implements MicroSchedulerMBean {
*/
protected final GenomeAnalysisEngine engine;
protected final TraversalEngine traversalEngine;
private final TraversalEngineCreator traversalEngineCreator;
protected final IndexedFastaSequenceFile reference;
private final SAMDataSource reads;
@ -110,11 +113,6 @@ public abstract class MicroScheduler implements MicroSchedulerMBean {
*/
public static MicroScheduler create(GenomeAnalysisEngine engine, Walker walker, SAMDataSource reads, IndexedFastaSequenceFile reference, Collection<ReferenceOrderedDataSource> rods, ThreadAllocation threadAllocation) {
if ( threadAllocation.isRunningInParallelMode() ) {
// TODO -- remove me when we fix running NCT within HMS
if ( threadAllocation.getNumDataThreads() > 1 && threadAllocation.getNumCPUThreadsPerDataThread() > 1)
throw new UserException("Currently the GATK does not support running CPU threads within data threads, " +
"please specify only one of NT and NCT");
logger.info(String.format("Running the GATK in parallel mode with %d CPU thread(s) for each of %d data thread(s)",
threadAllocation.getNumCPUThreadsPerDataThread(), threadAllocation.getNumDataThreads()));
}
@ -160,30 +158,12 @@ public abstract class MicroScheduler implements MicroSchedulerMBean {
this.reads = reads;
this.reference = reference;
this.rods = rods;
if (walker instanceof ReadWalker) {
traversalEngine = USE_NANOSCHEDULER_FOR_EVERYTHING || threadAllocation.getNumCPUThreadsPerDataThread() > 1
? new TraverseReadsNano(threadAllocation.getNumCPUThreadsPerDataThread())
: new TraverseReads();
} else if (walker instanceof LocusWalker) {
traversalEngine = USE_NANOSCHEDULER_FOR_EVERYTHING || threadAllocation.getNumCPUThreadsPerDataThread() > 1
? new TraverseLociNano(threadAllocation.getNumCPUThreadsPerDataThread())
: new TraverseLociLinear();
} else if (walker instanceof DuplicateWalker) {
traversalEngine = new TraverseDuplicates();
} else if (walker instanceof ReadPairWalker) {
traversalEngine = new TraverseReadPairs();
} else if (walker instanceof ActiveRegionWalker) {
traversalEngine = new TraverseActiveRegions();
} else {
throw new UnsupportedOperationException("Unable to determine traversal type, the walker is an unknown type.");
}
this.traversalEngineCreator = new TraversalEngineCreator(walker, threadAllocation);
final File progressLogFile = engine.getArguments() == null ? null : engine.getArguments().performanceLog;
this.progressMeter = new ProgressMeter(progressLogFile,
traversalEngine.getTraversalUnits(),
traversalEngineCreator.getTraversalUnits(),
engine.getRegionsOfGenomeBeingProcessed());
traversalEngine.initialize(engine, progressMeter);
// JMX does not allow multiple instances with the same ObjectName to be registered with the same platform MXBean.
// To get around this limitation and since we have no job identifier at this point, register a simple counter that
@ -249,8 +229,8 @@ public abstract class MicroScheduler implements MicroSchedulerMBean {
progressMeter.notifyDone(engine.getCumulativeMetrics().getNumIterations());
printReadFilteringStats();
// TODO -- generalize to all local thread copies
traversalEngine.shutdown();
for ( final TraversalEngine te : traversalEngineCreator.getCreatedEngines() )
te.shutdown();
// Print out the threading efficiency of this HMS, if state monitoring is enabled
if ( threadEfficiencyMonitor != null ) {
@ -317,4 +297,115 @@ public abstract class MicroScheduler implements MicroSchedulerMBean {
throw new ReviewedStingException("Unable to unregister microscheduler with JMX", ex);
}
}
/**
* Returns a traversal engine suitable for use in this thread.
*
* May create a new traversal engine for this thread, if this is the first
* time this thread ever asked for a TraversalEngine.
*
* @return a non-null TraversalEngine suitable for execution in this scheduler
*/
public TraversalEngine getTraversalEngine() {
return traversalEngineCreator.get();
}
/**
* ThreadLocal TraversalEngine creator
*
* TraversalEngines are thread local variables to the MicroScheduler. This is necessary
* because in the HMS case you have multiple threads executing a traversal engine independently, and
* these engines may need to create separate resources for efficiency or implementation reasons. For example,
* the nanoScheduler creates threads to implement the traversal, and this creation is instance specific.
* So each HMS thread needs to have it's own distinct copy of the traversal engine if it wants to have
* N data threads x M nano threads => N * M threads total.
*
* This class also tracks all created traversal engines so this microscheduler can properly
* shut them all down when the scheduling is done.
*/
private class TraversalEngineCreator extends ThreadLocal<TraversalEngine> {
final List<TraversalEngine> createdEngines = new LinkedList<TraversalEngine>();
final Walker walker;
final ThreadAllocation threadAllocation;
/**
* Creates an initialized TraversalEngine appropriate for walker and threadAllocation,
* and adds it to the list of created engines for later shutdown.
*
* @return a non-null traversal engine
*/
@Override
protected synchronized TraversalEngine initialValue() {
final TraversalEngine traversalEngine = createEngine();
traversalEngine.initialize(engine, progressMeter);
createdEngines.add(traversalEngine);
return traversalEngine;
}
/**
* Returns the traversal units for traversal engines created here.
*
* This (unfortunately) creates an uninitialized tmp. TraversalEngine so we can get
* it's traversal units, and then immediately shuts it down...
*
* @return the traversal unit as returned by getTraversalUnits of TraversalEngines created here
*/
protected String getTraversalUnits() {
final TraversalEngine tmp = createEngine();
final String units = tmp.getTraversalUnits();
tmp.shutdown();
return units;
}
/**
* Really make us a traversal engine of the appropriate type for walker and thread allocation
*
* @return a non-null uninitialized traversal engine
*/
@Ensures("result != null")
protected TraversalEngine createEngine() {
if (walker instanceof ReadWalker) {
if ( USE_NANOSCHEDULER_FOR_EVERYTHING || threadAllocation.getNumCPUThreadsPerDataThread() > 1 )
return new TraverseReadsNano(threadAllocation.getNumCPUThreadsPerDataThread());
else
return new TraverseReads();
} else if (walker instanceof LocusWalker) {
if ( USE_NANOSCHEDULER_FOR_EVERYTHING || threadAllocation.getNumCPUThreadsPerDataThread() > 1 )
return new TraverseLociNano(threadAllocation.getNumCPUThreadsPerDataThread());
else
return new TraverseLociLinear();
} else if (walker instanceof DuplicateWalker) {
return new TraverseDuplicates();
} else if (walker instanceof ReadPairWalker) {
return new TraverseReadPairs();
} else if (walker instanceof ActiveRegionWalker) {
return new TraverseActiveRegions();
} else {
throw new UnsupportedOperationException("Unable to determine traversal type, the walker is an unknown type.");
}
}
/**
* Create a TraversalEngineCreator that makes TraversalEngines appropriate for walker and threadAllocation
*
* @param walker the walker we need traversal engines for
* @param threadAllocation what kind of threading will we use in the traversal?
*/
@com.google.java.contract.Requires({"walker != null", "threadAllocation != null"})
public TraversalEngineCreator(final Walker walker, final ThreadAllocation threadAllocation) {
super();
this.walker = walker;
this.threadAllocation = threadAllocation;
}
/**
* Get the list of all traversal engines we've created
*
* @return a non-null list of engines created so far
*/
@Ensures("result != null")
public List<TraversalEngine> getCreatedEngines() {
return createdEngines;
}
}
}

View File

@ -5,7 +5,6 @@ import org.broadinstitute.sting.gatk.datasources.providers.LocusShardDataProvide
import org.broadinstitute.sting.gatk.datasources.providers.ShardDataProvider;
import org.broadinstitute.sting.gatk.datasources.reads.Shard;
import org.broadinstitute.sting.gatk.io.ThreadLocalOutputTracker;
import org.broadinstitute.sting.gatk.traversals.TraversalEngine;
import org.broadinstitute.sting.gatk.walkers.Walker;
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
@ -29,7 +28,6 @@ public class ShardTraverser implements Callable {
final private HierarchicalMicroScheduler microScheduler;
final private Walker walker;
final private Shard shard;
final private TraversalEngine traversalEngine;
final private ThreadLocalOutputTracker outputTracker;
private OutputMergeTask outputMergeTask;
@ -42,13 +40,11 @@ public class ShardTraverser implements Callable {
private boolean complete = false;
public ShardTraverser( HierarchicalMicroScheduler microScheduler,
TraversalEngine traversalEngine,
Walker walker,
Shard shard,
ThreadLocalOutputTracker outputTracker) {
this.microScheduler = microScheduler;
this.walker = walker;
this.traversalEngine = traversalEngine;
this.shard = shard;
this.outputTracker = outputTracker;
}
@ -65,7 +61,7 @@ public class ShardTraverser implements Callable {
for(WindowMaker.WindowMakerIterator iterator: windowMaker) {
final ShardDataProvider dataProvider = new LocusShardDataProvider(shard,iterator.getSourceInfo(),microScheduler.getEngine().getGenomeLocParser(),iterator.getLocus(),iterator,microScheduler.reference,microScheduler.rods);
accumulator = traversalEngine.traverse( walker, dataProvider, accumulator );
accumulator = microScheduler.getTraversalEngine().traverse(walker, dataProvider, accumulator);
dataProvider.close();
}