diff --git a/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoScheduler.java b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoScheduler.java index 24db0f7dc..fe8731d3b 100644 --- a/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoScheduler.java +++ b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoScheduler.java @@ -5,13 +5,11 @@ import com.google.java.contract.Requires; import org.apache.log4j.Logger; import org.broadinstitute.sting.utils.AutoFormattingTime; import org.broadinstitute.sting.utils.SimpleTimer; -import org.broadinstitute.sting.utils.collections.Pair; import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; +import org.broadinstitute.sting.utils.threading.NamedThreadFactory; import java.util.Iterator; -import java.util.LinkedList; import java.util.List; -import java.util.Queue; import java.util.concurrent.*; /** @@ -52,7 +50,9 @@ public class NanoScheduler { final int bufferSize; final int nThreads; + final ExecutorService inputExecutor; + final ExecutorService reduceExecutor; final ExecutorService mapExecutor; boolean shutdown = false; boolean debug = false; @@ -77,8 +77,14 @@ public class NanoScheduler { this.bufferSize = bufferSize; this.nThreads = nThreads; - this.mapExecutor = nThreads == 1 ? null : Executors.newFixedThreadPool(nThreads-1); - this.inputExecutor = Executors.newSingleThreadExecutor(); + + if ( nThreads == 1 ) { + this.mapExecutor = this.inputExecutor = this.reduceExecutor = null; + } else { + this.mapExecutor = Executors.newFixedThreadPool(nThreads-1, new NamedThreadFactory("NS-map-thread-%d")); + this.inputExecutor = Executors.newSingleThreadExecutor(new NamedThreadFactory("NS-input-thread-%d")); + this.reduceExecutor = Executors.newSingleThreadExecutor(new NamedThreadFactory("NS-reduce-thread-%d")); + } // start timing the time spent outside of the nanoScheduler outsideSchedulerTimer.start(); @@ -110,11 +116,9 @@ public class NanoScheduler { public void shutdown() { outsideSchedulerTimer.stop(); - if ( mapExecutor != null ) { - final List remaining = mapExecutor.shutdownNow(); - if ( ! remaining.isEmpty() ) - throw new IllegalStateException("Remaining tasks found in the mapExecutor, unexpected behavior!"); - } + shutdownExecutor("inputExecutor", inputExecutor); + shutdownExecutor("mapExecutor", mapExecutor); + shutdownExecutor("reduceExecutor", reduceExecutor); shutdown = true; if (TIME_CALLS) { @@ -125,6 +129,31 @@ public class NanoScheduler { } } + /** + * Helper function to cleanly shutdown an execution service, checking that the execution + * state is clean when it's done. + * + * @param name a string name for error messages for the executorService we are shutting down + * @param executorService the executorService to shut down + */ + private void shutdownExecutor(final String name, final ExecutorService executorService) { + if ( executorService != null ) { + if ( executorService.isShutdown() || executorService.isTerminated() ) + throw new IllegalStateException("Executor service " + name + " is already shut down!"); + + final List remaining = executorService.shutdownNow(); + if ( ! remaining.isEmpty() ) + throw new IllegalStateException(remaining.size() + " remaining tasks found in an executor " + name + ", unexpected behavior!"); + } + } + + /** + * Print to logger.info timing information from timer, with name label + * + * @param label the name of the timer to display. Should be human readable + * @param timer the timer whose elapsed time we will display + */ + @Requires({"label != null", "timer != null"}) private void printTimerInfo(final String label, final SimpleTimer timer) { final double total = inputTimer.getElapsedTime() + mapTimer.getElapsedTime() + reduceTimer.getElapsedTime() + outsideSchedulerTimer.getElapsedTime(); @@ -140,16 +169,30 @@ public class NanoScheduler { return shutdown; } + /** + * @return are we displaying verbose debugging information about the scheduling? + */ public boolean isDebug() { return debug; } + /** + * Helper function to display a String.formatted message if we are doing verbose debugging + * + * @param format the format argument suitable for String.format + * @param args the arguments for String.format + */ + @Requires("format != null") private void debugPrint(final String format, Object ... args) { if ( isDebug() ) logger.info("Thread " + Thread.currentThread().getId() + ":" + String.format(format, args)); } - + /** + * Turn on/off verbose debugging + * + * @param debug true if we want verbose debugging + */ public void setDebug(boolean debug) { this.debug = debug; } @@ -179,6 +222,9 @@ public class NanoScheduler { * It is safe to call this function repeatedly on a single nanoScheduler, at least until the * shutdown method is called. * + * Note that this function goes through a single threaded fast path if the number of threads + * is 1. + * * @param inputReader an iterator providing us with the input data to nanoSchedule map/reduce over * @param map the map function from input type -> map type, will be applied in parallel to each input * @param reduce the reduce function from map type + reduce type -> reduce type to be applied in order to map results @@ -207,9 +253,11 @@ public class NanoScheduler { } /** - * Simple efficient reference implementation for single threaded execution + * Simple efficient reference implementation for single threaded execution. + * * @return the reduce result of this map/reduce job */ + @Requires({"inputReader != null", "map != null", "reduce != null"}) private ReduceType executeSingleThreaded(final Iterator inputReader, final NanoSchedulerMapFunction map, final ReduceType initialValue, @@ -249,88 +297,111 @@ public class NanoScheduler { * * @return the reduce result of this map/reduce job */ + @Requires({"inputReader != null", "map != null", "reduce != null"}) private ReduceType executeMultiThreaded(final Iterator inputReader, final NanoSchedulerMapFunction map, final ReduceType initialValue, final NanoSchedulerReduceFunction reduce) { debugPrint("Executing nanoScheduler"); - ReduceType sum = initialValue; - boolean done = false; + // a completion service that tracks when jobs complete, so we can wait in this thread + // until all of the map jobs are completed, without having to shut down the executor itself + final ExecutorCompletionService mapJobCompletionService = + new ExecutorCompletionService(mapExecutor); + + // a blocking queue that limits the number of input datum to the requested buffer size final BlockingQueue inputQueue = new LinkedBlockingDeque(bufferSize); + // a priority queue that stores up to bufferSize * MAP_QUEUE_SCALE_FACTOR elements + // produced by completed map jobs. + final PriorityBlockingQueue mapResultQueue = new PriorityBlockingQueue(bufferSize*100); + + // TODO -- the logic of this blocking queue is wrong! We need to wait for map jobs in order, not just + // -- in the order in which they are produced + + // TODO -- map executor must have fixed size map jobs queue + inputExecutor.submit(new InputProducer(inputReader, inputQueue)); + final Future reduceResult = reduceExecutor.submit(new ReducerThread(reduce, initialValue, mapResultQueue)); - while ( ! done ) { - try { - final Pair, Boolean> readResults = readInputs(inputQueue); - final List inputs = readResults.getFirst(); - done = readResults.getSecond(); + try { + int numJobs = 0; + while ( true ) { + // block on input + final InputDatum inputEnqueueWrapped = inputQueue.take(); - if ( ! inputs.isEmpty() ) { - // send jobs for map - final Queue> mapQueue = submitMapJobs(map, mapExecutor, inputs); + if ( ! inputEnqueueWrapped.isLast() ) { + // get the object itself + final InputType input = inputEnqueueWrapped.datum; + + // the next map call has id + 1 + numJobs++; + + // send job for map via the completion service + final CallableMap doMap = new CallableMap(map, numJobs, input, mapResultQueue); + mapJobCompletionService.submit(doMap, numJobs); - // send off the reduce job, and block until we get at least one reduce result - sum = reduceSerial(reduce, mapQueue, sum); debugPrint(" Done with cycle of map/reduce"); - if ( progressFunction != null ) progressFunction.progress(inputs.get(inputs.size()-1)); + if ( progressFunction != null ) // TODO -- don't cycle so often + progressFunction.progress(input); } else { - // we must be done - if ( ! done ) throw new IllegalStateException("Inputs empty but not done"); + waitForLastJob(mapJobCompletionService, numJobs); + mapResultQueue.add(new MapResult()); + return reduceResult.get(); // wait for our result of reduce } - } catch (InterruptedException ex) { - throw new ReviewedStingException("got execution exception", ex); - } catch (ExecutionException ex) { - throw new ReviewedStingException("got execution exception", ex); } + } catch (InterruptedException ex) { + throw new ReviewedStingException("got execution exception", ex); + } catch (ExecutionException ex) { + throw new ReviewedStingException("got execution exception", ex); } - - return sum; - } - - @Requires({"reduce != null", "! mapQueue.isEmpty()"}) - private ReduceType reduceSerial(final NanoSchedulerReduceFunction reduce, - final Queue> mapQueue, - final ReduceType initSum) - throws InterruptedException, ExecutionException { - ReduceType sum = initSum; - - // while mapQueue has something in it to reduce - for ( final Future future : mapQueue ) { - final MapType value = future.get(); // block until we get the values for this task - - if ( TIME_CALLS ) reduceTimer.restart(); - sum = reduce.apply(value, sum); - if ( TIME_CALLS ) reduceTimer.stop(); - } - - return sum; } /** - * Read up to inputBufferSize elements from inputReader - * - * @return a queue of input read in, containing one or more values of InputType read in + * Helper routine that will wait until the last map job finishes running + * by taking numJob values from the executor completion service, using + * the blocking take() call. */ - @Requires("inputReader != null") - @Ensures("result != null") - private Pair, Boolean> readInputs(final BlockingQueue inputReader) throws InterruptedException { - int n = 0; - final List inputs = new LinkedList(); - boolean done = false; + private void waitForLastJob(final ExecutorCompletionService mapJobCompletionService, + final int numJobs ) throws InterruptedException { + for ( int i = 0; i < numJobs; i++ ) + mapJobCompletionService.take(); + } - while ( ! done && n < getBufferSize() ) { - final InputDatum input = inputReader.take(); - done = input.isLast(); - if ( ! done ) { - inputs.add(input.datum); - n++; - } + private class ReducerThread implements Callable { + final NanoSchedulerReduceFunction reduce; + ReduceType sum; + final PriorityBlockingQueue mapResultQueue; + + public ReducerThread(final NanoSchedulerReduceFunction reduce, + final ReduceType sum, + final PriorityBlockingQueue mapResultQueue) { + this.reduce = reduce; + this.sum = sum; + this.mapResultQueue = mapResultQueue; } - return new Pair, Boolean>(inputs, done); + public ReduceType call() { + try { + while ( true ) { + final MapResult result = mapResultQueue.take(); + //System.out.println("Reduce of map result " + result.id + " with sum " + sum); + if ( result.isLast() ) { + //System.out.println("Saw last! " + result.id); + return sum; + } + else { + if ( TIME_CALLS ) reduceTimer.restart(); + sum = reduce.apply(result.datum, sum); + if ( TIME_CALLS ) reduceTimer.stop(); + } + } + } catch (InterruptedException ex) { + //System.out.println("Interrupted"); + throw new ReviewedStingException("got execution exception", ex); + } + } } private class InputProducer implements Runnable { @@ -359,16 +430,16 @@ public class NanoScheduler { } } - private class InputDatum { + private class BlockingDatum { final boolean isLast; - final InputType datum; + final T datum; - private InputDatum(final InputType datum) { + private BlockingDatum(final T datum) { isLast = false; this.datum = datum; } - private InputDatum() { + private BlockingDatum() { isLast = true; this.datum = null; } @@ -378,40 +449,56 @@ public class NanoScheduler { } } - @Requires({"map != null", "! inputs.isEmpty()"}) - private Queue> submitMapJobs(final NanoSchedulerMapFunction map, - final ExecutorService executor, - final List inputs) { - final Queue> mapQueue = new LinkedList>(); - for ( final InputType input : inputs ) { - final CallableMap doMap = new CallableMap(map, input); - final Future future = executor.submit(doMap); - mapQueue.add(future); + private class InputDatum extends BlockingDatum { + private InputDatum(InputType datum) { super(datum); } + private InputDatum() { } + } + + private class MapResult extends BlockingDatum implements Comparable { + final Integer id; + + private MapResult(MapType datum, Integer id) { + super(datum); + this.id = id; } - return mapQueue; + private MapResult() { + this.id = Integer.MAX_VALUE; + } + + @Override + public int compareTo(MapResult o) { + return id.compareTo(o.id); + } } /** * A simple callable version of the map function for use with the executor pool */ - private class CallableMap implements Callable { + private class CallableMap implements Runnable { + final int id; final InputType input; final NanoSchedulerMapFunction map; + final PriorityBlockingQueue mapResultQueue; @Requires({"map != null"}) - private CallableMap(final NanoSchedulerMapFunction map, final InputType inputs) { - this.input = inputs; + private CallableMap(final NanoSchedulerMapFunction map, + final int id, + final InputType input, + final PriorityBlockingQueue mapResultQueue) { + this.id = id; + this.input = input; this.map = map; + this.mapResultQueue = mapResultQueue; } - @Override public MapType call() throws Exception { + @Override public void run() { if ( TIME_CALLS ) mapTimer.restart(); if ( debug ) debugPrint("\t\tmap " + input); final MapType result = map.apply(input); if ( TIME_CALLS ) mapTimer.stop(); - return result; + mapResultQueue.add(new MapResult(result, id)); } } } diff --git a/public/java/src/org/broadinstitute/sting/utils/threading/NamedThreadFactory.java b/public/java/src/org/broadinstitute/sting/utils/threading/NamedThreadFactory.java new file mode 100644 index 000000000..b25375b87 --- /dev/null +++ b/public/java/src/org/broadinstitute/sting/utils/threading/NamedThreadFactory.java @@ -0,0 +1,26 @@ +package org.broadinstitute.sting.utils.threading; + +import java.util.concurrent.ThreadFactory; + +/** + * Thread factor that produces threads with a given name pattern + * + * User: depristo + * Date: 9/5/12 + * Time: 9:22 PM + * + */ +public class NamedThreadFactory implements ThreadFactory { + static int id = 0; + final String format; + + public NamedThreadFactory(String format) { + this.format = format; + String.format(format, id); // test the name + } + + @Override + public Thread newThread(Runnable r) { + return new Thread(r, String.format(format, id++)); + } +} diff --git a/public/java/test/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerUnitTest.java b/public/java/test/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerUnitTest.java index ddfc3cecd..21ac6dcec 100644 --- a/public/java/test/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerUnitTest.java +++ b/public/java/test/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerUnitTest.java @@ -1,5 +1,6 @@ package org.broadinstitute.sting.utils.nanoScheduler; +import org.apache.log4j.BasicConfigurator; import org.broadinstitute.sting.BaseTest; import org.testng.Assert; import org.testng.annotations.DataProvider; @@ -165,6 +166,10 @@ public class NanoSchedulerUnitTest extends BaseTest { } public static void main(String [ ] args) { + org.apache.log4j.Logger logger = org.apache.log4j.Logger.getRootLogger(); + BasicConfigurator.configure(); + logger.setLevel(org.apache.log4j.Level.DEBUG); + final NanoSchedulerBasicTest test = new NanoSchedulerBasicTest(1000, Integer.valueOf(args[0]), 0, Integer.valueOf(args[1])); final NanoScheduler nanoScheduler = new NanoScheduler(test.bufferSize, test.nThreads); @@ -172,5 +177,6 @@ public class NanoSchedulerUnitTest extends BaseTest { final Integer sum = nanoScheduler.execute(test.makeReader(), test.makeMap(), test.initReduce(), test.makeReduce()); System.out.printf("Sum = %d, expected =%d%n", sum, test.expectedResult); + nanoScheduler.shutdown(); } }