From 1ca13f958111b7441c5c337148acf9b80fb3419e Mon Sep 17 00:00:00 2001 From: Mark DePristo Date: Mon, 17 Dec 2012 21:01:50 -0500 Subject: [PATCH] Fundamentally better model for the NanoScheduler -- Now each map job reads a value, performs map, and does as much reducing as possible. This ensures that we scale performance with the nct value, so -nct 2 should result in 2x performance, -nct 3 3x, etc. All of this is accomplished using exactly NCT% of the CPU of the machine. -- Has the additional value of actually simplifying the code -- Resolves a long-standing annoyance with the nano scheduler. --- .../utils/nanoScheduler/InputProducer.java | 117 +++++++++--------- .../utils/nanoScheduler/NanoScheduler.java | 54 +++----- .../sting/utils/nanoScheduler/Reducer.java | 2 +- .../nanoScheduler/InputProducerUnitTest.java | 85 +------------ 4 files changed, 82 insertions(+), 176 deletions(-) diff --git a/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/InputProducer.java b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/InputProducer.java index 0e0237412..84bb8d45f 100644 --- a/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/InputProducer.java +++ b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/InputProducer.java @@ -1,16 +1,15 @@ package org.broadinstitute.sting.utils.nanoScheduler; import org.apache.log4j.Logger; -import org.broadinstitute.sting.utils.MultiThreadedErrorTracker; import java.util.Iterator; -import java.util.concurrent.BlockingQueue; import java.util.concurrent.CountDownLatch; /** - * Producer Thread that reads input values from an inputReads and puts them into an output queue + * Helper class that allows multiple threads to read input values from + * an iterator, and track the number of items read from that iterator. 
*/ -class InputProducer implements Runnable { +class InputProducer { private final static Logger logger = Logger.getLogger(InputProducer.class); /** @@ -18,13 +17,6 @@ class InputProducer implements Runnable { */ final Iterator inputReader; - /** - * Where we put our input values for consumption - */ - final BlockingQueue outputQueue; - - final MultiThreadedErrorTracker errorTracker; - /** * Have we read the last value from inputReader? * @@ -34,6 +26,14 @@ class InputProducer implements Runnable { */ boolean readLastValue = false; + /** + * Once we've readLastValue, lastValue contains a continually + * updating InputValue where EOF is true. It's not necessarily + * a single value, as each read updates lastValue with the + * next EOF marker + */ + private InputValue lastValue = null; + int nRead = 0; int inputID = -1; @@ -43,16 +43,9 @@ class InputProducer implements Runnable { */ final CountDownLatch latch = new CountDownLatch(1); - public InputProducer(final Iterator inputReader, - final MultiThreadedErrorTracker errorTracker, - final BlockingQueue outputQueue) { + public InputProducer(final Iterator inputReader) { if ( inputReader == null ) throw new IllegalArgumentException("inputReader cannot be null"); - if ( errorTracker == null ) throw new IllegalArgumentException("errorTracker cannot be null"); - if ( outputQueue == null ) throw new IllegalArgumentException("OutputQueue cannot be null"); - this.inputReader = inputReader; - this.errorTracker = errorTracker; - this.outputQueue = outputQueue; } /** @@ -82,9 +75,8 @@ class InputProducer implements Runnable { * This method is synchronized, as it manipulates local state accessed across multiple threads. * * @return the next input stream value, or null if the stream contains no more elements - * @throws InterruptedException */ - private synchronized InputType readNextItem() throws InterruptedException { + private synchronized InputType readNextItem() { if ( ! 
inputReader.hasNext() ) { // we are done, mark ourselves as such and return null readLastValue = true; @@ -100,49 +92,60 @@ class InputProducer implements Runnable { } /** - * Run this input producer, looping over all items in the input reader and - * enqueueing them as InputValues into the outputQueue. After the - * end of the stream has been encountered, any threads waiting because - * they called waitForDone() will be freed. + * Are there currently more values in the iterator? + * + * Note the word currently. It's possible that some already submitted + * job will read a value from this InputProducer, so in some sense + * there are no more values and in the future there'll be no next + * value. That said, once this returns false it means that all + * of the possible values have been read. + * + * @return true if a future call to next might return a non-EOF value, false if + * the underlying iterator is definitely empty */ - public void run() { - try { - while ( true ) { - final InputType value = readNextItem(); - - if ( value == null ) { - if ( ! readLastValue ) - throw new IllegalStateException("value == null but readLastValue is false!"); - - // add the EOF object so our consumer knows we are done in all inputs - // note that we do not increase inputID here, so that variable indicates the ID - // of the last real value read from the queue - outputQueue.put(new InputValue(inputID + 1)); - break; - } else { - // add the actual value to the outputQueue - outputQueue.put(new InputValue(++inputID, value)); - } - } - - latch.countDown(); - } catch (Throwable ex) { - errorTracker.notifyOfError(ex); - } finally { -// logger.info("Exiting input thread readLastValue = " + readLastValue); - } + public synchronized boolean hasNext() { + return ! allInputsHaveBeenRead(); } /** - * Block until all of the items have been read from inputReader. + * Get the next InputValue from this producer. 
The next value is + * either (1) the next value from the iterator, in which case the + * return value is an InputValue containing that value, or (2) + * an InputValue with the EOF marker, indicating that the underlying + * iterator has been exhausted. * - * Note that this call doesn't actually read anything. You have to submit a thread - * to actually execute run() directly. + * This function never fails -- it can be called endlessly and + * while the underlying iterator has values it returns them, and then + * it returns a succession of EOF marking input values. * - * @throws InterruptedException + * @return an InputValue containing the next value in the underlying + * iterator, or one with EOF marker, if the iterator is exhausted */ - public void waitForDone() throws InterruptedException { - latch.await(); + public synchronized InputValue next() { + if ( readLastValue ) { + // we read the last value, so our value is the next + // EOF marker based on the last value. Make sure to + // update the last value so the markers keep incrementing + // their job ids + lastValue = lastValue.nextEOF(); + return lastValue; + } else { + final InputType value = readNextItem(); + + if ( value == null ) { + if ( ! 
readLastValue ) + throw new IllegalStateException("value == null but readLastValue is false!"); + + // add the EOF object so our consumer knows we are done in all inputs + // note that we do not increase inputID here, so that variable indicates the ID + // of the last real value read from the queue + lastValue = new InputValue(inputID + 1); + return lastValue; + } else { + // add the actual value to the outputQueue + return new InputValue(++inputID, value); + } + } } /** diff --git a/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoScheduler.java b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoScheduler.java index 4cc91faa4..38a1d7b8f 100644 --- a/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoScheduler.java +++ b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoScheduler.java @@ -47,7 +47,6 @@ public class NanoScheduler { final int bufferSize; final int nThreads; - final ExecutorService inputExecutor; final ExecutorService masterExecutor; final ExecutorService mapExecutor; final Semaphore runningMapJobSlots; @@ -75,14 +74,12 @@ public class NanoScheduler { this.nThreads = nThreads; if ( nThreads == 1 ) { - this.mapExecutor = this.inputExecutor = this.masterExecutor = null; + this.mapExecutor = this.masterExecutor = null; runningMapJobSlots = null; } else { - this.mapExecutor = Executors.newFixedThreadPool(nThreads - 1, new NamedThreadFactory("NS-map-thread-%d")); - runningMapJobSlots = new Semaphore(this.bufferSize); - - this.inputExecutor = Executors.newSingleThreadExecutor(new NamedThreadFactory("NS-input-thread-%d")); this.masterExecutor = Executors.newSingleThreadExecutor(new NamedThreadFactory("NS-master-thread-%d")); + this.mapExecutor = Executors.newFixedThreadPool(nThreads, new NamedThreadFactory("NS-map-thread-%d")); + runningMapJobSlots = new Semaphore(this.bufferSize); } } @@ -111,7 +108,6 @@ public class NanoScheduler { */ public void shutdown() { if ( nThreads > 1 ) { - 
shutdownExecutor("inputExecutor", inputExecutor); shutdownExecutor("mapExecutor", mapExecutor); shutdownExecutor("masterExecutor", masterExecutor); } @@ -323,7 +319,6 @@ public class NanoScheduler { if ( errorTracker.hasAnErrorOccurred() ) { masterExecutor.shutdownNow(); mapExecutor.shutdownNow(); - inputExecutor.shutdownNow(); errorTracker.throwErrorIfPending(); } } @@ -351,15 +346,8 @@ public class NanoScheduler { @Override public ReduceType call() { - // a blocking queue that limits the number of input datum to the requested buffer size - // note we need +1 because we continue to enqueue the lastObject - final BlockingQueue.InputValue> inputQueue - = new LinkedBlockingDeque.InputValue>(bufferSize+1); - // Create the input producer and start it running - final InputProducer inputProducer = - new InputProducer(inputReader, errorTracker, inputQueue); - inputExecutor.submit(inputProducer); + final InputProducer inputProducer = new InputProducer(inputReader); // a priority queue that stores up to bufferSize elements // produced by completed map jobs. @@ -376,7 +364,7 @@ public class NanoScheduler { // acquire a slot to run a map job. 
Blocks if too many jobs are enqueued runningMapJobSlots.acquire(); - mapExecutor.submit(new MapReduceJob(inputQueue, mapResultQueue, map, reducer)); + mapExecutor.submit(new ReadMapReduceJob(inputProducer, mapResultQueue, map, reducer)); nSubmittedJobs++; } @@ -402,10 +390,6 @@ public class NanoScheduler { // logger.warn("waiting for final reduce"); final ReduceType finalSum = reducer.waitForFinalReduce(); - // now wait for the input provider thread to terminate -// logger.warn("waiting on inputProducer"); - inputProducer.waitForDone(); - // wait for all the map threads to finish by acquiring and then releasing all map job semaphores // logger.warn("waiting on map"); runningMapJobSlots.acquire(bufferSize); @@ -434,17 +418,17 @@ public class NanoScheduler { } } - private class MapReduceJob implements Runnable { - final BlockingQueue.InputValue> inputQueue; + private class ReadMapReduceJob implements Runnable { + final InputProducer inputProducer; final PriorityBlockingQueue> mapResultQueue; final NSMapFunction map; final Reducer reducer; - private MapReduceJob(BlockingQueue.InputValue> inputQueue, - final PriorityBlockingQueue> mapResultQueue, - final NSMapFunction map, - final Reducer reducer) { - this.inputQueue = inputQueue; + private ReadMapReduceJob(final InputProducer inputProducer, + final PriorityBlockingQueue> mapResultQueue, + final NSMapFunction map, + final Reducer reducer) { + this.inputProducer = inputProducer; this.mapResultQueue = mapResultQueue; this.map = map; this.reducer = reducer; @@ -453,10 +437,10 @@ public class NanoScheduler { @Override public void run() { try { - //debugPrint("Running MapReduceJob " + jobID); - final InputProducer.InputValue inputWrapper = inputQueue.take(); - final int jobID = inputWrapper.getId(); + // get the next item from the input producer + final InputProducer.InputValue inputWrapper = inputProducer.next(); + // depending on inputWrapper, actually do some work or not, putting result input result object final 
MapResult result; if ( ! inputWrapper.isEOFMarker() ) { // just skip doing anything if we don't have work to do, which is possible @@ -468,23 +452,19 @@ public class NanoScheduler { final MapType mapValue = map.apply(input); // enqueue the result into the mapResultQueue - result = new MapResult(mapValue, jobID); + result = new MapResult(mapValue, inputWrapper.getId()); if ( progressFunction != null ) progressFunction.progress(input); } else { - // push back the EOF marker so other waiting threads can read it - inputQueue.put(inputWrapper.nextEOF()); - // if there's no input we push empty MapResults with jobIDs for synchronization with Reducer - result = new MapResult(jobID); + result = new MapResult(inputWrapper.getId()); } mapResultQueue.put(result); final int nReduced = reducer.reduceAsMuchAsPossible(mapResultQueue); } catch (Throwable ex) { -// logger.warn("Map job got exception " + ex); errorTracker.notifyOfError(ex); } finally { // we finished a map job, release the job queue semaphore diff --git a/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/Reducer.java b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/Reducer.java index 5cae28187..a7b94e323 100644 --- a/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/Reducer.java +++ b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/Reducer.java @@ -84,7 +84,7 @@ class Reducer { if ( nextMapResult == null ) { return false; } else if ( nextMapResult.getJobID() < prevJobID + 1 ) { - throw new IllegalStateException("Next job ID " + nextMapResult.getJobID() + " is < previous job id " + prevJobID); + throw new IllegalStateException("Next job ID " + nextMapResult.getJobID() + " is not < previous job id " + prevJobID); } else if ( nextMapResult.getJobID() == prevJobID + 1 ) { return true; } else { diff --git a/public/java/test/org/broadinstitute/sting/utils/nanoScheduler/InputProducerUnitTest.java 
b/public/java/test/org/broadinstitute/sting/utils/nanoScheduler/InputProducerUnitTest.java index 489adab6b..9ccfb9229 100644 --- a/public/java/test/org/broadinstitute/sting/utils/nanoScheduler/InputProducerUnitTest.java +++ b/public/java/test/org/broadinstitute/sting/utils/nanoScheduler/InputProducerUnitTest.java @@ -1,19 +1,13 @@ package org.broadinstitute.sting.utils.nanoScheduler; import org.broadinstitute.sting.BaseTest; -import org.broadinstitute.sting.utils.MultiThreadedErrorTracker; import org.testng.Assert; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.util.ArrayList; import java.util.Arrays; -import java.util.Iterator; import java.util.List; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; -import java.util.concurrent.LinkedBlockingDeque; -import java.util.concurrent.Semaphore; /** * UnitTests for the InputProducer @@ -42,34 +36,23 @@ public class InputProducerUnitTest extends BaseTest { final List elements = new ArrayList(nElements); for ( int i = 0; i < nElements; i++ ) elements.add(i); - final LinkedBlockingDeque.InputValue> readQueue = - new LinkedBlockingDeque.InputValue>(queueSize); - - final InputProducer ip = new InputProducer(elements.iterator(), new MultiThreadedErrorTracker(), readQueue); - - final ExecutorService es = Executors.newSingleThreadExecutor(); + final InputProducer ip = new InputProducer(elements.iterator()); Assert.assertFalse(ip.allInputsHaveBeenRead(), "InputProvider said that all inputs have been read, but I haven't started reading yet"); Assert.assertEquals(ip.getNumInputValues(), -1, "InputProvider told me that the queue was done, but I haven't started reading yet"); - es.submit(ip); - int lastValue = -1; int nRead = 0; - while ( true ) { + while ( ip.hasNext() ) { final int nTotalElements = ip.getNumInputValues(); - final int observedQueueSize = readQueue.size(); - Assert.assertTrue(observedQueueSize <= queueSize, - "Reader is enqueuing more 
elements " + observedQueueSize + " than allowed " + queueSize); - if ( nRead + observedQueueSize < nElements ) + if ( nRead < nElements ) Assert.assertEquals(nTotalElements, -1, "getNumInputValues should have returned -1 with not all elements read"); // note, cannot test else case because elements input could have emptied between calls - final InputProducer.InputValue value = readQueue.take(); + final InputProducer.InputValue value = ip.next(); if ( value.isEOFMarker() ) { Assert.assertEquals(nRead, nElements, "Number of input values " + nRead + " not all that are expected " + nElements); - Assert.assertEquals(readQueue.size(), 0, "Last queue element found but queue contains more values!"); break; } else { Assert.assertTrue(lastValue < value.getValue(), "Read values coming out of order!"); @@ -82,65 +65,5 @@ public class InputProducerUnitTest extends BaseTest { Assert.assertTrue(ip.allInputsHaveBeenRead(), "InputProvider said that all inputs haven't been read, but I read them all"); Assert.assertEquals(ip.getNumInputValues(), nElements, "Wrong number of total elements getNumInputValues"); - es.shutdownNow(); - } - - @Test(enabled = true, dataProvider = "InputProducerTest", timeOut = NanoSchedulerUnitTest.NANO_SCHEDULE_MAX_RUNTIME) - public void testInputProducerLocking(final int nElements, final int queueSize) throws InterruptedException { - final List elements = new ArrayList(nElements); - for ( int i = 0; i < nElements; i++ ) elements.add(i); - - final LinkedBlockingDeque.InputValue> readQueue = - new LinkedBlockingDeque.InputValue>(); - - final InputProducer ip = new InputProducer(elements.iterator(), new MultiThreadedErrorTracker(), readQueue); - - final ExecutorService es = Executors.newSingleThreadExecutor(); - es.submit(ip); - - ip.waitForDone(); - - Assert.assertEquals(ip.getNumInputValues(), nElements, "InputProvider told me that the queue was done, but I haven't started reading yet"); - Assert.assertEquals(readQueue.size(), nElements + 1, "readQueue 
should have had all elements read into it"); - } - - final static class BlockingIterator implements Iterator { - final Semaphore blockNext = new Semaphore(0); - final Semaphore blockOnNext = new Semaphore(0); - final Iterator underlyingIterator; - - BlockingIterator(Iterator underlyingIterator) { - this.underlyingIterator = underlyingIterator; - } - - public void allowNext() { - blockNext.release(1); - } - - public void blockTillNext() throws InterruptedException { - blockOnNext.acquire(1); - } - - @Override - public boolean hasNext() { - return underlyingIterator.hasNext(); - } - - @Override - public T next() { - try { - blockNext.acquire(1); - T value = underlyingIterator.next(); - blockOnNext.release(1); - return value; - } catch (InterruptedException ex) { - throw new RuntimeException(ex); - } - } - - @Override - public void remove() { - throw new UnsupportedOperationException("x"); - } } }