Fundamentally better model for the NanoScheduler

-- Now each map job reads a value, performs map, and does as much reducing as possible.  This ensures that we scale performance with the nct value, so -nct 2 should result in 2x performance, -nct 3 3x, etc.  All of this is accomplished using exactly NCT% of the CPU of the machine.
-- Has the additional value of actually simplifying the code
-- Resolves a long-standing annoyance with the nano scheduler.
This commit is contained in:
Mark DePristo 2012-12-17 21:01:50 -05:00
parent d0cd29cb36
commit 1ca13f9581
4 changed files with 82 additions and 176 deletions

View File

@ -1,16 +1,15 @@
package org.broadinstitute.sting.utils.nanoScheduler;
import org.apache.log4j.Logger;
import org.broadinstitute.sting.utils.MultiThreadedErrorTracker;
import java.util.Iterator;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
/**
* Producer Thread that reads input values from an inputReads and puts them into an output queue
* Helper class that allows multiple threads to read input values from
* an iterator, and tracks the number of items read from that iterator.
*/
class InputProducer<InputType> implements Runnable {
class InputProducer<InputType> {
private final static Logger logger = Logger.getLogger(InputProducer.class);
/**
@ -18,13 +17,6 @@ class InputProducer<InputType> implements Runnable {
*/
final Iterator<InputType> inputReader;
/**
* Where we put our input values for consumption
*/
final BlockingQueue<InputValue> outputQueue;
final MultiThreadedErrorTracker errorTracker;
/**
* Have we read the last value from inputReader?
*
@ -34,6 +26,14 @@ class InputProducer<InputType> implements Runnable {
*/
boolean readLastValue = false;
/**
* Once we've readLastValue, lastValue contains a continually
* updating InputValue where EOF is true. It's not necessarily
* a single value, as each read updates lastValue with the
* next EOF marker
*/
private InputValue lastValue = null;
int nRead = 0;
int inputID = -1;
@ -43,16 +43,9 @@ class InputProducer<InputType> implements Runnable {
*/
final CountDownLatch latch = new CountDownLatch(1);
public InputProducer(final Iterator<InputType> inputReader,
final MultiThreadedErrorTracker errorTracker,
final BlockingQueue<InputValue> outputQueue) {
public InputProducer(final Iterator<InputType> inputReader) {
if ( inputReader == null ) throw new IllegalArgumentException("inputReader cannot be null");
if ( errorTracker == null ) throw new IllegalArgumentException("errorTracker cannot be null");
if ( outputQueue == null ) throw new IllegalArgumentException("OutputQueue cannot be null");
this.inputReader = inputReader;
this.errorTracker = errorTracker;
this.outputQueue = outputQueue;
}
/**
@ -82,9 +75,8 @@ class InputProducer<InputType> implements Runnable {
* This method is synchronized, as it manipulates local state accessed across multiple threads.
*
* @return the next input stream value, or null if the stream contains no more elements
* @throws InterruptedException
*/
private synchronized InputType readNextItem() throws InterruptedException {
private synchronized InputType readNextItem() {
if ( ! inputReader.hasNext() ) {
// we are done, mark ourselves as such and return null
readLastValue = true;
@ -100,49 +92,60 @@ class InputProducer<InputType> implements Runnable {
}
/**
* Run this input producer, looping over all items in the input reader and
* enqueueing them as InputValues into the outputQueue. After the
* end of the stream has been encountered, any threads waiting because
* they called waitForDone() will be freed.
* Are there currently more values in the iterator?
*
* Note the word currently. It's possible that some already submitted
* job will read the remaining values from this InputProducer, so a
* true result does not guarantee that a later call to next will
* return a non-EOF value. That said, once this returns false it means
* that all of the possible values have been read
*
* @return true if a future call to next might return a non-EOF value, false if
* the underlying iterator is definitely empty
*/
public void run() {
try {
while ( true ) {
final InputType value = readNextItem();
if ( value == null ) {
if ( ! readLastValue )
throw new IllegalStateException("value == null but readLastValue is false!");
// add the EOF object so our consumer knows we are done in all inputs
// note that we do not increase inputID here, so that variable indicates the ID
// of the last real value read from the queue
outputQueue.put(new InputValue(inputID + 1));
break;
} else {
// add the actual value to the outputQueue
outputQueue.put(new InputValue(++inputID, value));
}
}
latch.countDown();
} catch (Throwable ex) {
errorTracker.notifyOfError(ex);
} finally {
// logger.info("Exiting input thread readLastValue = " + readLastValue);
}
public synchronized boolean hasNext() {
    // values may remain exactly when the inputs have not all been read yet
    final boolean exhausted = allInputsHaveBeenRead();
    return ! exhausted;
}
/**
* Block until all of the items have been read from inputReader.
* Get the next InputValue from this producer. The next value is
* either (1) the next value from the iterator, in which case the
* return value is an InputValue containing that value, or (2)
* an InputValue with the EOF marker, indicating that the underlying
* iterator has been exhausted.
*
* Note that this call doesn't actually read anything. You have to submit a thread
* to actually execute run() directly.
* This function never fails -- it can be called endlessly and
* while the underlying iterator has values it returns them, and then
* it returns a succession of EOF marking input values.
*
* @throws InterruptedException
* @return an InputValue containing the next value in the underlying
* iterator, or one with EOF marker, if the iterator is exhausted
*/
public void waitForDone() throws InterruptedException {
latch.await();
public synchronized InputValue next() {
    if ( ! readLastValue ) {
        // the iterator may still hold values, so try to pull the next one
        final InputType item = readNextItem();

        if ( item != null ) {
            // a real value: hand it back tagged with the next input ID
            return new InputValue(++inputID, item);
        }

        // a null item is only legal when readNextItem flagged the end of the stream
        if ( ! readLastValue )
            throw new IllegalStateException("value == null but readLastValue is false!");

        // first EOF marker, so callers know we are done with all inputs.
        // inputID itself is not incremented here, so that variable keeps
        // indicating the ID of the last real value read
        lastValue = new InputValue(inputID + 1);
        return lastValue;
    }

    // the last value was already read on an earlier call: derive the next
    // EOF marker from the previous one and remember it, so that successive
    // markers keep advancing their ids
    lastValue = lastValue.nextEOF();
    return lastValue;
}
/**

View File

@ -47,7 +47,6 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
final int bufferSize;
final int nThreads;
final ExecutorService inputExecutor;
final ExecutorService masterExecutor;
final ExecutorService mapExecutor;
final Semaphore runningMapJobSlots;
@ -75,14 +74,12 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
this.nThreads = nThreads;
if ( nThreads == 1 ) {
this.mapExecutor = this.inputExecutor = this.masterExecutor = null;
this.mapExecutor = this.masterExecutor = null;
runningMapJobSlots = null;
} else {
this.mapExecutor = Executors.newFixedThreadPool(nThreads - 1, new NamedThreadFactory("NS-map-thread-%d"));
runningMapJobSlots = new Semaphore(this.bufferSize);
this.inputExecutor = Executors.newSingleThreadExecutor(new NamedThreadFactory("NS-input-thread-%d"));
this.masterExecutor = Executors.newSingleThreadExecutor(new NamedThreadFactory("NS-master-thread-%d"));
this.mapExecutor = Executors.newFixedThreadPool(nThreads, new NamedThreadFactory("NS-map-thread-%d"));
runningMapJobSlots = new Semaphore(this.bufferSize);
}
}
@ -111,7 +108,6 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
*/
public void shutdown() {
if ( nThreads > 1 ) {
shutdownExecutor("inputExecutor", inputExecutor);
shutdownExecutor("mapExecutor", mapExecutor);
shutdownExecutor("masterExecutor", masterExecutor);
}
@ -323,7 +319,6 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
if ( errorTracker.hasAnErrorOccurred() ) {
masterExecutor.shutdownNow();
mapExecutor.shutdownNow();
inputExecutor.shutdownNow();
errorTracker.throwErrorIfPending();
}
}
@ -351,15 +346,8 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
@Override
public ReduceType call() {
// a blocking queue that limits the number of input datum to the requested buffer size
// note we need +1 because we continue to enqueue the lastObject
final BlockingQueue<InputProducer<InputType>.InputValue> inputQueue
= new LinkedBlockingDeque<InputProducer<InputType>.InputValue>(bufferSize+1);
// Create the input producer and start it running
final InputProducer<InputType> inputProducer =
new InputProducer<InputType>(inputReader, errorTracker, inputQueue);
inputExecutor.submit(inputProducer);
final InputProducer<InputType> inputProducer = new InputProducer<InputType>(inputReader);
// a priority queue that stores up to bufferSize elements
// produced by completed map jobs.
@ -376,7 +364,7 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
// acquire a slot to run a map job. Blocks if too many jobs are enqueued
runningMapJobSlots.acquire();
mapExecutor.submit(new MapReduceJob(inputQueue, mapResultQueue, map, reducer));
mapExecutor.submit(new ReadMapReduceJob(inputProducer, mapResultQueue, map, reducer));
nSubmittedJobs++;
}
@ -402,10 +390,6 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
// logger.warn("waiting for final reduce");
final ReduceType finalSum = reducer.waitForFinalReduce();
// now wait for the input provider thread to terminate
// logger.warn("waiting on inputProducer");
inputProducer.waitForDone();
// wait for all the map threads to finish by acquiring and then releasing all map job semaphores
// logger.warn("waiting on map");
runningMapJobSlots.acquire(bufferSize);
@ -434,17 +418,17 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
}
}
private class MapReduceJob implements Runnable {
final BlockingQueue<InputProducer<InputType>.InputValue> inputQueue;
private class ReadMapReduceJob implements Runnable {
final InputProducer<InputType> inputProducer;
final PriorityBlockingQueue<MapResult<MapType>> mapResultQueue;
final NSMapFunction<InputType, MapType> map;
final Reducer<MapType, ReduceType> reducer;
private MapReduceJob(BlockingQueue<InputProducer<InputType>.InputValue> inputQueue,
final PriorityBlockingQueue<MapResult<MapType>> mapResultQueue,
final NSMapFunction<InputType, MapType> map,
final Reducer<MapType, ReduceType> reducer) {
this.inputQueue = inputQueue;
private ReadMapReduceJob(final InputProducer<InputType> inputProducer,
final PriorityBlockingQueue<MapResult<MapType>> mapResultQueue,
final NSMapFunction<InputType, MapType> map,
final Reducer<MapType, ReduceType> reducer) {
this.inputProducer = inputProducer;
this.mapResultQueue = mapResultQueue;
this.map = map;
this.reducer = reducer;
@ -453,10 +437,10 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
@Override
public void run() {
try {
//debugPrint("Running MapReduceJob " + jobID);
final InputProducer<InputType>.InputValue inputWrapper = inputQueue.take();
final int jobID = inputWrapper.getId();
// get the next item from the input producer
final InputProducer<InputType>.InputValue inputWrapper = inputProducer.next();
// depending on inputWrapper, actually do some work or not, putting the outcome into the result object
final MapResult<MapType> result;
if ( ! inputWrapper.isEOFMarker() ) {
// just skip doing anything if we don't have work to do, which is possible
@ -468,23 +452,19 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
final MapType mapValue = map.apply(input);
// enqueue the result into the mapResultQueue
result = new MapResult<MapType>(mapValue, jobID);
result = new MapResult<MapType>(mapValue, inputWrapper.getId());
if ( progressFunction != null )
progressFunction.progress(input);
} else {
// push back the EOF marker so other waiting threads can read it
inputQueue.put(inputWrapper.nextEOF());
// if there's no input we push empty MapResults with jobIDs for synchronization with Reducer
result = new MapResult<MapType>(jobID);
result = new MapResult<MapType>(inputWrapper.getId());
}
mapResultQueue.put(result);
final int nReduced = reducer.reduceAsMuchAsPossible(mapResultQueue);
} catch (Throwable ex) {
// logger.warn("Map job got exception " + ex);
errorTracker.notifyOfError(ex);
} finally {
// we finished a map job, release the job queue semaphore

View File

@ -84,7 +84,7 @@ class Reducer<MapType, ReduceType> {
if ( nextMapResult == null ) {
return false;
} else if ( nextMapResult.getJobID() < prevJobID + 1 ) {
throw new IllegalStateException("Next job ID " + nextMapResult.getJobID() + " is < previous job id " + prevJobID);
throw new IllegalStateException("Next job ID " + nextMapResult.getJobID() + " is not < previous job id " + prevJobID);
} else if ( nextMapResult.getJobID() == prevJobID + 1 ) {
return true;
} else {

View File

@ -1,19 +1,13 @@
package org.broadinstitute.sting.utils.nanoScheduler;
import org.broadinstitute.sting.BaseTest;
import org.broadinstitute.sting.utils.MultiThreadedErrorTracker;
import org.testng.Assert;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.Semaphore;
/**
* UnitTests for the InputProducer
@ -42,34 +36,23 @@ public class InputProducerUnitTest extends BaseTest {
final List<Integer> elements = new ArrayList<Integer>(nElements);
for ( int i = 0; i < nElements; i++ ) elements.add(i);
final LinkedBlockingDeque<InputProducer<Integer>.InputValue> readQueue =
new LinkedBlockingDeque<InputProducer<Integer>.InputValue>(queueSize);
final InputProducer<Integer> ip = new InputProducer<Integer>(elements.iterator(), new MultiThreadedErrorTracker(), readQueue);
final ExecutorService es = Executors.newSingleThreadExecutor();
final InputProducer<Integer> ip = new InputProducer<Integer>(elements.iterator());
Assert.assertFalse(ip.allInputsHaveBeenRead(), "InputProvider said that all inputs have been read, but I haven't started reading yet");
Assert.assertEquals(ip.getNumInputValues(), -1, "InputProvider told me that the queue was done, but I haven't started reading yet");
es.submit(ip);
int lastValue = -1;
int nRead = 0;
while ( true ) {
while ( ip.hasNext() ) {
final int nTotalElements = ip.getNumInputValues();
final int observedQueueSize = readQueue.size();
Assert.assertTrue(observedQueueSize <= queueSize,
"Reader is enqueuing more elements " + observedQueueSize + " than allowed " + queueSize);
if ( nRead + observedQueueSize < nElements )
if ( nRead < nElements )
Assert.assertEquals(nTotalElements, -1, "getNumInputValues should have returned -1 with not all elements read");
// note, cannot test else case because elements input could have emptied between calls
final InputProducer<Integer>.InputValue value = readQueue.take();
final InputProducer<Integer>.InputValue value = ip.next();
if ( value.isEOFMarker() ) {
Assert.assertEquals(nRead, nElements, "Number of input values " + nRead + " not all that are expected " + nElements);
Assert.assertEquals(readQueue.size(), 0, "Last queue element found but queue contains more values!");
break;
} else {
Assert.assertTrue(lastValue < value.getValue(), "Read values coming out of order!");
@ -82,65 +65,5 @@ public class InputProducerUnitTest extends BaseTest {
Assert.assertTrue(ip.allInputsHaveBeenRead(), "InputProvider said that all inputs haven't been read, but I read them all");
Assert.assertEquals(ip.getNumInputValues(), nElements, "Wrong number of total elements getNumInputValues");
es.shutdownNow();
}
/**
 * Runs an InputProducer on a background thread and checks that, once
 * waitForDone() has returned, the producer reports every input as read and
 * the read queue holds all of the elements plus one extra entry.
 *
 * @param nElements number of integers to feed through the producer
 * @param queueSize supplied by the data provider; unused here because the read queue is unbounded
 * @throws InterruptedException if interrupted while waiting for the producer to finish
 */
@Test(enabled = true, dataProvider = "InputProducerTest", timeOut = NanoSchedulerUnitTest.NANO_SCHEDULE_MAX_RUNTIME)
public void testInputProducerLocking(final int nElements, final int queueSize) throws InterruptedException {
    final List<Integer> elements = new ArrayList<Integer>(nElements);
    for ( int i = 0; i < nElements; i++ ) elements.add(i);

    // no capacity given, so the producer can never block while enqueueing
    final LinkedBlockingDeque<InputProducer<Integer>.InputValue> readQueue =
            new LinkedBlockingDeque<InputProducer<Integer>.InputValue>();

    final InputProducer<Integer> ip = new InputProducer<Integer>(elements.iterator(), new MultiThreadedErrorTracker(), readQueue);
    final ExecutorService es = Executors.newSingleThreadExecutor();

    try {
        es.submit(ip);

        ip.waitForDone();

        Assert.assertEquals(ip.getNumInputValues(), nElements, "getNumInputValues should equal the number of elements once waitForDone has returned");
        Assert.assertEquals(readQueue.size(), nElements + 1, "readQueue should have had all elements read into it");
    } finally {
        // always stop the worker thread, even when an assertion fails, so it does not leak
        es.shutdownNow();
    }
}
/**
 * Iterator wrapper that lets a test control when each element is handed out.
 *
 * Every call to next() first waits for a permit granted by allowNext(), and
 * after fetching the element releases a permit that blockTillNext() waits on.
 * hasNext() is passed straight through and never blocks.
 */
final static class BlockingIterator<T> implements Iterator<T> {
    /** Permits that next() must acquire before reading; granted by allowNext() */
    final Semaphore blockNext = new Semaphore(0);

    /** Permits released by next() after each read; consumed by blockTillNext() */
    final Semaphore blockOnNext = new Semaphore(0);

    /** The iterator actually supplying the values */
    final Iterator<T> underlyingIterator;

    BlockingIterator(Iterator<T> underlyingIterator) {
        this.underlyingIterator = underlyingIterator;
    }

    /**
     * Allow exactly one more call to next() to proceed
     */
    public void allowNext() {
        blockNext.release(1);
    }

    /**
     * Block until one call to next() has fetched its value
     *
     * @throws InterruptedException if interrupted while waiting
     */
    public void blockTillNext() throws InterruptedException {
        blockOnNext.acquire(1);
    }

    @Override
    public boolean hasNext() {
        return underlyingIterator.hasNext();
    }

    /**
     * Wait for a permit from allowNext(), then return the next underlying value
     *
     * @return the next value of the underlying iterator
     * @throws RuntimeException wrapping the InterruptedException if the wait is interrupted;
     *                          the thread's interrupt status is restored before throwing
     */
    @Override
    public T next() {
        try {
            blockNext.acquire(1);
            T value = underlyingIterator.next();
            blockOnNext.release(1);
            return value;
        } catch (InterruptedException ex) {
            // acquire() cleared the interrupt flag when it threw; set it again
            // so that code further up the stack can still see the interruption
            Thread.currentThread().interrupt();
            throw new RuntimeException(ex);
        }
    }

    @Override
    public void remove() {
        throw new UnsupportedOperationException("x");
    }
}
}