Intermediate commit for new hyper parallel NanoScheduler

-- There's a logic bug right now, but I'm going to squash it...
This commit is contained in:
Mark DePristo 2012-09-06 14:33:31 -04:00
parent 576c7280d9
commit 9d12935986
3 changed files with 207 additions and 88 deletions

View File

@ -5,13 +5,11 @@ import com.google.java.contract.Requires;
import org.apache.log4j.Logger;
import org.broadinstitute.sting.utils.AutoFormattingTime;
import org.broadinstitute.sting.utils.SimpleTimer;
import org.broadinstitute.sting.utils.collections.Pair;
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
import org.broadinstitute.sting.utils.threading.NamedThreadFactory;
import java.util.Iterator;
import java.util.LinkedList;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.*;
/**
@ -52,7 +50,9 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
final int bufferSize;
final int nThreads;
final ExecutorService inputExecutor;
final ExecutorService reduceExecutor;
final ExecutorService mapExecutor;
boolean shutdown = false;
boolean debug = false;
@ -77,8 +77,14 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
this.bufferSize = bufferSize;
this.nThreads = nThreads;
this.mapExecutor = nThreads == 1 ? null : Executors.newFixedThreadPool(nThreads-1);
this.inputExecutor = Executors.newSingleThreadExecutor();
if ( nThreads == 1 ) {
this.mapExecutor = this.inputExecutor = this.reduceExecutor = null;
} else {
this.mapExecutor = Executors.newFixedThreadPool(nThreads-1, new NamedThreadFactory("NS-map-thread-%d"));
this.inputExecutor = Executors.newSingleThreadExecutor(new NamedThreadFactory("NS-input-thread-%d"));
this.reduceExecutor = Executors.newSingleThreadExecutor(new NamedThreadFactory("NS-reduce-thread-%d"));
}
// start timing the time spent outside of the nanoScheduler
outsideSchedulerTimer.start();
@ -110,11 +116,9 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
public void shutdown() {
outsideSchedulerTimer.stop();
if ( mapExecutor != null ) {
final List<Runnable> remaining = mapExecutor.shutdownNow();
if ( ! remaining.isEmpty() )
throw new IllegalStateException("Remaining tasks found in the mapExecutor, unexpected behavior!");
}
shutdownExecutor("inputExecutor", inputExecutor);
shutdownExecutor("mapExecutor", mapExecutor);
shutdownExecutor("reduceExecutor", reduceExecutor);
shutdown = true;
if (TIME_CALLS) {
@ -125,6 +129,31 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
}
}
/**
 * Shuts down one executor service and verifies that it was in a clean state.
 *
 * A null executor is silently ignored (the constructor leaves all executors null when
 * nThreads == 1). Otherwise the executor must still be running, and shutdownNow() must
 * report no tasks that were still waiting to execute.
 *
 * @param name a string name for the executorService, used only in error messages
 * @param executorService the executorService to shut down, may be null
 * @throws IllegalStateException if the executor was already shut down, or if tasks were
 *                               still queued when it was shut down
 */
private void shutdownExecutor(final String name, final ExecutorService executorService) {
    // nothing to shut down
    if ( executorService == null )
        return;

    final boolean alreadyDown = executorService.isShutdown() || executorService.isTerminated();
    if ( alreadyDown )
        throw new IllegalStateException("Executor service " + name + " is already shut down!");

    // shutdownNow hands back the tasks that never started running
    final List<Runnable> neverStarted = executorService.shutdownNow();
    final int nLeft = neverStarted.size();
    if ( nLeft > 0 )
        throw new IllegalStateException(nLeft + " remaining tasks found in an executor " + name + ", unexpected behavior!");
}
/**
* Print to logger.info timing information from timer, with name label
*
* @param label the name of the timer to display. Should be human readable
* @param timer the timer whose elapsed time we will display
*/
@Requires({"label != null", "timer != null"})
private void printTimerInfo(final String label, final SimpleTimer timer) {
final double total = inputTimer.getElapsedTime() + mapTimer.getElapsedTime()
+ reduceTimer.getElapsedTime() + outsideSchedulerTimer.getElapsedTime();
@ -140,16 +169,30 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
return shutdown;
}
/**
* Reports whether verbose debugging output is currently enabled for this scheduler.
*
* debugPrint only writes its message to the logger when this method returns true.
*
* @return true if we are displaying verbose debugging information about the scheduling
*/
public boolean isDebug() {
return debug;
}
/**
 * Logs a String.format-style message at info level, but only when verbose debugging is on.
 *
 * The logged line is prefixed with the id of the calling thread. When debugging is off the
 * message is never formatted.
 *
 * @param format the format argument suitable for String.format
 * @param args the arguments for String.format
 */
@Requires("format != null")
private void debugPrint(final String format, Object ... args) {
    if ( ! isDebug() )
        return;

    final long threadId = Thread.currentThread().getId();
    final String message = String.format(format, args);
    logger.info("Thread " + threadId + ":" + message);
}
/**
* Turn verbose debugging on or off.
*
* Stores the flag that isDebug() reports and that debugPrint consults before logging.
*
* NOTE(review): debug is a plain (non-volatile) boolean that is also read from map jobs;
* a change made while jobs are running may not be seen by them immediately -- confirm
* whether that matters.
*
* @param debug true if we want verbose debugging, false to turn it off
*/
public void setDebug(boolean debug) {
this.debug = debug;
}
@ -179,6 +222,9 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
* It is safe to call this function repeatedly on a single nanoScheduler, at least until the
* shutdown method is called.
*
* Note that this function goes through a single threaded fast path if the number of threads
* is 1.
*
* @param inputReader an iterator providing us with the input data to nanoSchedule map/reduce over
* @param map the map function from input type -> map type, will be applied in parallel to each input
* @param reduce the reduce function from map type + reduce type -> reduce type to be applied in order to map results
@ -207,9 +253,11 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
}
/**
* Simple efficient reference implementation for single threaded execution
* Simple efficient reference implementation for single threaded execution.
*
* @return the reduce result of this map/reduce job
*/
@Requires({"inputReader != null", "map != null", "reduce != null"})
private ReduceType executeSingleThreaded(final Iterator<InputType> inputReader,
final NanoSchedulerMapFunction<InputType, MapType> map,
final ReduceType initialValue,
@ -249,88 +297,111 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
*
* @return the reduce result of this map/reduce job
*/
@Requires({"inputReader != null", "map != null", "reduce != null"})
private ReduceType executeMultiThreaded(final Iterator<InputType> inputReader,
final NanoSchedulerMapFunction<InputType, MapType> map,
final ReduceType initialValue,
final NanoSchedulerReduceFunction<MapType, ReduceType> reduce) {
debugPrint("Executing nanoScheduler");
ReduceType sum = initialValue;
boolean done = false;
// a completion service that tracks when jobs complete, so we can wait in this thread
// until all of the map jobs are completed, without having to shut down the executor itself
final ExecutorCompletionService<Integer> mapJobCompletionService =
new ExecutorCompletionService<Integer>(mapExecutor);
// a blocking queue that limits the number of input datum to the requested buffer size
final BlockingQueue<InputDatum> inputQueue = new LinkedBlockingDeque<InputDatum>(bufferSize);
// a priority queue that stores up to bufferSize * MAP_QUEUE_SCALE_FACTOR elements
// produced by completed map jobs.
final PriorityBlockingQueue<MapResult> mapResultQueue = new PriorityBlockingQueue<MapResult>(bufferSize*100);
// TODO -- the logic of this blocking queue is wrong! We need to wait for map jobs in order, not just
// -- in the order in which they are produced
// TODO -- map executor must have fixed size map jobs queue
inputExecutor.submit(new InputProducer(inputReader, inputQueue));
final Future<ReduceType> reduceResult = reduceExecutor.submit(new ReducerThread(reduce, initialValue, mapResultQueue));
while ( ! done ) {
try {
final Pair<List<InputType>, Boolean> readResults = readInputs(inputQueue);
final List<InputType> inputs = readResults.getFirst();
done = readResults.getSecond();
try {
int numJobs = 0;
while ( true ) {
// block on input
final InputDatum inputEnqueueWrapped = inputQueue.take();
if ( ! inputs.isEmpty() ) {
// send jobs for map
final Queue<Future<MapType>> mapQueue = submitMapJobs(map, mapExecutor, inputs);
if ( ! inputEnqueueWrapped.isLast() ) {
// get the object itself
final InputType input = inputEnqueueWrapped.datum;
// the next map call has id + 1
numJobs++;
// send job for map via the completion service
final CallableMap doMap = new CallableMap(map, numJobs, input, mapResultQueue);
mapJobCompletionService.submit(doMap, numJobs);
// send off the reduce job, and block until we get at least one reduce result
sum = reduceSerial(reduce, mapQueue, sum);
debugPrint(" Done with cycle of map/reduce");
if ( progressFunction != null ) progressFunction.progress(inputs.get(inputs.size()-1));
if ( progressFunction != null ) // TODO -- don't cycle so often
progressFunction.progress(input);
} else {
// we must be done
if ( ! done ) throw new IllegalStateException("Inputs empty but not done");
waitForLastJob(mapJobCompletionService, numJobs);
mapResultQueue.add(new MapResult());
return reduceResult.get(); // wait for our result of reduce
}
} catch (InterruptedException ex) {
throw new ReviewedStingException("got execution exception", ex);
} catch (ExecutionException ex) {
throw new ReviewedStingException("got execution exception", ex);
}
} catch (InterruptedException ex) {
throw new ReviewedStingException("got execution exception", ex);
} catch (ExecutionException ex) {
throw new ReviewedStingException("got execution exception", ex);
}
return sum;
}
/**
 * Folds the results of the submitted map jobs into a single reduce value, in the iteration
 * order of mapQueue, blocking on each future until its map result is available.
 *
 * @param reduce the reduce function applied as reduce.apply(mapValue, runningSum)
 * @param mapQueue the futures of the map jobs to reduce over
 * @param initSum the starting value of the reduction
 * @return the reduce value after every future in mapQueue has been applied
 * @throws InterruptedException if interrupted while waiting on a map result
 * @throws ExecutionException if a map job failed
 */
@Requires({"reduce != null", "! mapQueue.isEmpty()"})
private ReduceType reduceSerial(final NanoSchedulerReduceFunction<MapType, ReduceType> reduce,
                                final Queue<Future<MapType>> mapQueue,
                                final ReduceType initSum)
        throws InterruptedException, ExecutionException {
    ReduceType accumulated = initSum;

    final Iterator<Future<MapType>> pending = mapQueue.iterator();
    while ( pending.hasNext() ) {
        // block until this map task has produced its value
        final MapType mapped = pending.next().get();

        // only the reduce call itself is charged to the reduce timer
        if ( TIME_CALLS ) reduceTimer.restart();
        accumulated = reduce.apply(mapped, accumulated);
        if ( TIME_CALLS ) reduceTimer.stop();
    }

    return accumulated;
}
/**
* Read up to inputBufferSize elements from inputReader
*
* @return a queue of input read in, containing one or more values of InputType read in
* Helper routine that will wait until the last map job finishes running
* by taking numJob values from the executor completion service, using
* the blocking take() call.
*/
@Requires("inputReader != null")
@Ensures("result != null")
private Pair<List<InputType>, Boolean> readInputs(final BlockingQueue<InputDatum> inputReader) throws InterruptedException {
int n = 0;
final List<InputType> inputs = new LinkedList<InputType>();
boolean done = false;
/**
 * Blocks until numJobs map jobs have completed, by taking that many completed futures from
 * the completion service.
 *
 * Each completed future is also inspected with get(), so that an exception thrown inside a
 * map job is reported here instead of being silently dropped (which would let the reduce
 * finish with that job's result missing).
 *
 * @param mapJobCompletionService the completion service the map jobs were submitted through
 * @param numJobs the number of submitted map jobs to wait for; values <= 0 return immediately
 * @throws InterruptedException if interrupted while waiting for a job to complete
 * @throws ReviewedStingException if any of the awaited map jobs failed with an exception
 */
private void waitForLastJob(final ExecutorCompletionService<Integer> mapJobCompletionService,
                            final int numJobs ) throws InterruptedException {
    for ( int i = 0; i < numJobs; i++ ) {
        try {
            // take() blocks for the next finished job; get() surfaces its failure, if any
            mapJobCompletionService.take().get();
        } catch (ExecutionException ex) {
            throw new ReviewedStingException("got execution exception", ex);
        }
    }
}
while ( ! done && n < getBufferSize() ) {
final InputDatum input = inputReader.take();
done = input.isLast();
if ( ! done ) {
inputs.add(input.datum);
n++;
}
private class ReducerThread implements Callable {
final NanoSchedulerReduceFunction<MapType, ReduceType> reduce;
ReduceType sum;
final PriorityBlockingQueue<MapResult> mapResultQueue;
public ReducerThread(final NanoSchedulerReduceFunction<MapType, ReduceType> reduce,
final ReduceType sum,
final PriorityBlockingQueue<MapResult> mapResultQueue) {
this.reduce = reduce;
this.sum = sum;
this.mapResultQueue = mapResultQueue;
}
return new Pair<List<InputType>, Boolean>(inputs, done);
public ReduceType call() {
try {
while ( true ) {
final MapResult result = mapResultQueue.take();
//System.out.println("Reduce of map result " + result.id + " with sum " + sum);
if ( result.isLast() ) {
//System.out.println("Saw last! " + result.id);
return sum;
}
else {
if ( TIME_CALLS ) reduceTimer.restart();
sum = reduce.apply(result.datum, sum);
if ( TIME_CALLS ) reduceTimer.stop();
}
}
} catch (InterruptedException ex) {
//System.out.println("Interrupted");
throw new ReviewedStingException("got execution exception", ex);
}
}
}
private class InputProducer implements Runnable {
@ -359,16 +430,16 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
}
}
private class InputDatum {
private class BlockingDatum<T> {
final boolean isLast;
final InputType datum;
final T datum;
private InputDatum(final InputType datum) {
private BlockingDatum(final T datum) {
isLast = false;
this.datum = datum;
}
private InputDatum() {
private BlockingDatum() {
isLast = true;
this.datum = null;
}
@ -378,40 +449,56 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
}
}
@Requires({"map != null", "! inputs.isEmpty()"})
private Queue<Future<MapType>> submitMapJobs(final NanoSchedulerMapFunction<InputType, MapType> map,
final ExecutorService executor,
final List<InputType> inputs) {
final Queue<Future<MapType>> mapQueue = new LinkedList<Future<MapType>>();
for ( final InputType input : inputs ) {
final CallableMap doMap = new CallableMap(map, input);
final Future<MapType> future = executor.submit(doMap);
mapQueue.add(future);
/**
 * An element of the input queue: either a wrapper around one input value, or (when built
 * with the no-argument constructor) the marker whose isLast flag is set and whose datum
 * is null.
 */
private class InputDatum extends BlockingDatum<InputType> {
    /** Wrap an actual input value */
    private InputDatum(InputType datum) {
        super(datum);
    }

    /** Create the last-element marker */
    private InputDatum() {
        super();
    }
}
private class MapResult extends BlockingDatum<MapType> implements Comparable<MapResult> {
final Integer id;
private MapResult(MapType datum, Integer id) {
super(datum);
this.id = id;
}
return mapQueue;
private MapResult() {
this.id = Integer.MAX_VALUE;
}
@Override
public int compareTo(MapResult o) {
return id.compareTo(o.id);
}
}
/**
* A simple callable version of the map function for use with the executor pool
*/
private class CallableMap implements Callable<MapType> {
private class CallableMap implements Runnable {
final int id;
final InputType input;
final NanoSchedulerMapFunction<InputType, MapType> map;
final PriorityBlockingQueue<MapResult> mapResultQueue;
@Requires({"map != null"})
private CallableMap(final NanoSchedulerMapFunction<InputType, MapType> map, final InputType inputs) {
this.input = inputs;
private CallableMap(final NanoSchedulerMapFunction<InputType, MapType> map,
final int id,
final InputType input,
final PriorityBlockingQueue<MapResult> mapResultQueue) {
this.id = id;
this.input = input;
this.map = map;
this.mapResultQueue = mapResultQueue;
}
@Override public MapType call() throws Exception {
@Override public void run() {
if ( TIME_CALLS ) mapTimer.restart();
if ( debug ) debugPrint("\t\tmap " + input);
final MapType result = map.apply(input);
if ( TIME_CALLS ) mapTimer.stop();
return result;
mapResultQueue.add(new MapResult(result, id));
}
}
}

View File

@ -0,0 +1,26 @@
package org.broadinstitute.sting.utils.threading;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicInteger;
/**
 * Thread factory that produces threads with a given name pattern.
 *
 * The pattern is handed to String.format together with a single integer, taken from a
 * counter that is shared by all NamedThreadFactory instances, so a pattern such as
 * "NS-map-thread-%d" yields a distinct name for every thread created by any factory.
 *
 * The counter is atomic, so newThread may safely be called from several threads at once.
 *
 * User: depristo
 * Date: 9/5/12
 * Time: 9:22 PM
 *
 */
public class NamedThreadFactory implements ThreadFactory {
    /** Number handed to the next created thread; shared across all factories */
    private static final AtomicInteger nextId = new AtomicInteger(0);

    final String format;

    /**
     * @param format a String.format pattern accepting one integer argument (e.g. "worker-%d")
     * @throws java.util.IllegalFormatException if format is not a valid pattern for an integer
     */
    public NamedThreadFactory(String format) {
        this.format = format;
        // fail fast on a bad pattern; the formatted value is not needed, and the counter is not advanced
        String.format(format, nextId.get());
    }

    /**
     * @param r the runnable the new thread will execute
     * @return a new, unstarted thread named by applying the pattern to the next counter value
     */
    @Override
    public Thread newThread(Runnable r) {
        return new Thread(r, String.format(format, nextId.getAndIncrement()));
    }
}

View File

@ -1,5 +1,6 @@
package org.broadinstitute.sting.utils.nanoScheduler;
import org.apache.log4j.BasicConfigurator;
import org.broadinstitute.sting.BaseTest;
import org.testng.Assert;
import org.testng.annotations.DataProvider;
@ -165,6 +166,10 @@ public class NanoSchedulerUnitTest extends BaseTest {
}
public static void main(String [ ] args) {
org.apache.log4j.Logger logger = org.apache.log4j.Logger.getRootLogger();
BasicConfigurator.configure();
logger.setLevel(org.apache.log4j.Level.DEBUG);
final NanoSchedulerBasicTest test = new NanoSchedulerBasicTest(1000, Integer.valueOf(args[0]), 0, Integer.valueOf(args[1]));
final NanoScheduler<Integer, Integer, Integer> nanoScheduler =
new NanoScheduler<Integer, Integer, Integer>(test.bufferSize, test.nThreads);
@ -172,5 +177,6 @@ public class NanoSchedulerUnitTest extends BaseTest {
final Integer sum = nanoScheduler.execute(test.makeReader(), test.makeMap(), test.initReduce(), test.makeReduce());
System.out.printf("Sum = %d, expected =%d%n", sum, test.expectedResult);
nanoScheduler.shutdown();
}
}