Fundamentally better model for the NanoScheduler
-- Now each map job reads a value, performs map, and does as much reducing as possible. This ensures that we scale performance with the nct value, so -nct 2 should result in 2x performance, -nct 3 3x, etc. All of this is accomplished using exactly NCT% of the CPU of the machine. -- Has the additional value of actually simplifying the code -- Resolves a long-standing annoyance with the nano scheduler.
This commit is contained in:
parent
d0cd29cb36
commit
1ca13f9581
|
|
@ -1,16 +1,15 @@
|
|||
package org.broadinstitute.sting.utils.nanoScheduler;
|
||||
|
||||
import org.apache.log4j.Logger;
|
||||
import org.broadinstitute.sting.utils.MultiThreadedErrorTracker;
|
||||
|
||||
import java.util.Iterator;
|
||||
import java.util.concurrent.BlockingQueue;
|
||||
import java.util.concurrent.CountDownLatch;
|
||||
|
||||
/**
|
||||
* Producer Thread that reads input values from an inputReads and puts them into an output queue
|
||||
* Helper class that allows multiple threads to reads input values from
|
||||
* an iterator, and track the number of items read from that iterator.
|
||||
*/
|
||||
class InputProducer<InputType> implements Runnable {
|
||||
class InputProducer<InputType> {
|
||||
private final static Logger logger = Logger.getLogger(InputProducer.class);
|
||||
|
||||
/**
|
||||
|
|
@ -18,13 +17,6 @@ class InputProducer<InputType> implements Runnable {
|
|||
*/
|
||||
final Iterator<InputType> inputReader;
|
||||
|
||||
/**
|
||||
* Where we put our input values for consumption
|
||||
*/
|
||||
final BlockingQueue<InputValue> outputQueue;
|
||||
|
||||
final MultiThreadedErrorTracker errorTracker;
|
||||
|
||||
/**
|
||||
* Have we read the last value from inputReader?
|
||||
*
|
||||
|
|
@ -34,6 +26,14 @@ class InputProducer<InputType> implements Runnable {
|
|||
*/
|
||||
boolean readLastValue = false;
|
||||
|
||||
/**
|
||||
* Once we've readLastValue, lastValue contains a continually
|
||||
* updating InputValue where EOF is true. It's not necessarily
|
||||
* a single value, as each read updates lastValue with the
|
||||
* next EOF marker
|
||||
*/
|
||||
private InputValue lastValue = null;
|
||||
|
||||
int nRead = 0;
|
||||
int inputID = -1;
|
||||
|
||||
|
|
@ -43,16 +43,9 @@ class InputProducer<InputType> implements Runnable {
|
|||
*/
|
||||
final CountDownLatch latch = new CountDownLatch(1);
|
||||
|
||||
public InputProducer(final Iterator<InputType> inputReader,
|
||||
final MultiThreadedErrorTracker errorTracker,
|
||||
final BlockingQueue<InputValue> outputQueue) {
|
||||
public InputProducer(final Iterator<InputType> inputReader) {
|
||||
if ( inputReader == null ) throw new IllegalArgumentException("inputReader cannot be null");
|
||||
if ( errorTracker == null ) throw new IllegalArgumentException("errorTracker cannot be null");
|
||||
if ( outputQueue == null ) throw new IllegalArgumentException("OutputQueue cannot be null");
|
||||
|
||||
this.inputReader = inputReader;
|
||||
this.errorTracker = errorTracker;
|
||||
this.outputQueue = outputQueue;
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
@ -82,9 +75,8 @@ class InputProducer<InputType> implements Runnable {
|
|||
* This method is synchronized, as it manipulates local state accessed across multiple threads.
|
||||
*
|
||||
* @return the next input stream value, or null if the stream contains no more elements
|
||||
* @throws InterruptedException
|
||||
*/
|
||||
private synchronized InputType readNextItem() throws InterruptedException {
|
||||
private synchronized InputType readNextItem() {
|
||||
if ( ! inputReader.hasNext() ) {
|
||||
// we are done, mark ourselves as such and return null
|
||||
readLastValue = true;
|
||||
|
|
@ -100,49 +92,60 @@ class InputProducer<InputType> implements Runnable {
|
|||
}
|
||||
|
||||
/**
|
||||
* Run this input producer, looping over all items in the input reader and
|
||||
* enqueueing them as InputValues into the outputQueue. After the
|
||||
* end of the stream has been encountered, any threads waiting because
|
||||
* they called waitForDone() will be freed.
|
||||
* Are there currently more values in the iterator?
|
||||
*
|
||||
* Note the word currently. It's possible that some already submitted
|
||||
* job will read a value from this InputProvider, so in some sense
|
||||
* there are no more values and in the future there'll be no next
|
||||
* value. That said, once this returns false it means that all
|
||||
* of the possible values have been read
|
||||
*
|
||||
* @return true if a future call to next might return a non-EOF value, false if
|
||||
* the underlying iterator is definitely empty
|
||||
*/
|
||||
public void run() {
|
||||
try {
|
||||
while ( true ) {
|
||||
final InputType value = readNextItem();
|
||||
|
||||
if ( value == null ) {
|
||||
if ( ! readLastValue )
|
||||
throw new IllegalStateException("value == null but readLastValue is false!");
|
||||
|
||||
// add the EOF object so our consumer knows we are done in all inputs
|
||||
// note that we do not increase inputID here, so that variable indicates the ID
|
||||
// of the last real value read from the queue
|
||||
outputQueue.put(new InputValue(inputID + 1));
|
||||
break;
|
||||
} else {
|
||||
// add the actual value to the outputQueue
|
||||
outputQueue.put(new InputValue(++inputID, value));
|
||||
}
|
||||
}
|
||||
|
||||
latch.countDown();
|
||||
} catch (Throwable ex) {
|
||||
errorTracker.notifyOfError(ex);
|
||||
} finally {
|
||||
// logger.info("Exiting input thread readLastValue = " + readLastValue);
|
||||
}
|
||||
public synchronized boolean hasNext() {
|
||||
return ! allInputsHaveBeenRead();
|
||||
}
|
||||
|
||||
/**
|
||||
* Block until all of the items have been read from inputReader.
|
||||
* Get the next InputValue from this producer. The next value is
|
||||
* either (1) the next value from the iterator, in which case the
|
||||
* the return value is an InputValue containing that value, or (2)
|
||||
* an InputValue with the EOF marker, indicating that the underlying
|
||||
* iterator has been exhausted.
|
||||
*
|
||||
* Note that this call doesn't actually read anything. You have to submit a thread
|
||||
* to actually execute run() directly.
|
||||
* This function never fails -- it can be called endlessly and
|
||||
* while the underlying iterator has values it returns them, and then
|
||||
* it returns a succession of EOF marking input values.
|
||||
*
|
||||
* @throws InterruptedException
|
||||
* @return an InputValue containing the next value in the underlying
|
||||
* iterator, or one with EOF marker, if the iterator is exhausted
|
||||
*/
|
||||
public void waitForDone() throws InterruptedException {
|
||||
latch.await();
|
||||
public synchronized InputValue next() {
|
||||
if ( readLastValue ) {
|
||||
// we read the last value, so our value is the next
|
||||
// EOF marker based on the last value. Make sure to
|
||||
// update the last value so the markers keep incrementing
|
||||
// their job ids
|
||||
lastValue = lastValue.nextEOF();
|
||||
return lastValue;
|
||||
} else {
|
||||
final InputType value = readNextItem();
|
||||
|
||||
if ( value == null ) {
|
||||
if ( ! readLastValue )
|
||||
throw new IllegalStateException("value == null but readLastValue is false!");
|
||||
|
||||
// add the EOF object so our consumer knows we are done in all inputs
|
||||
// note that we do not increase inputID here, so that variable indicates the ID
|
||||
// of the last real value read from the queue
|
||||
lastValue = new InputValue(inputID + 1);
|
||||
return lastValue;
|
||||
} else {
|
||||
// add the actual value to the outputQueue
|
||||
return new InputValue(++inputID, value);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
|
|
|
|||
|
|
@ -47,7 +47,6 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
|
|||
|
||||
final int bufferSize;
|
||||
final int nThreads;
|
||||
final ExecutorService inputExecutor;
|
||||
final ExecutorService masterExecutor;
|
||||
final ExecutorService mapExecutor;
|
||||
final Semaphore runningMapJobSlots;
|
||||
|
|
@ -75,14 +74,12 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
|
|||
this.nThreads = nThreads;
|
||||
|
||||
if ( nThreads == 1 ) {
|
||||
this.mapExecutor = this.inputExecutor = this.masterExecutor = null;
|
||||
this.mapExecutor = this.masterExecutor = null;
|
||||
runningMapJobSlots = null;
|
||||
} else {
|
||||
this.mapExecutor = Executors.newFixedThreadPool(nThreads - 1, new NamedThreadFactory("NS-map-thread-%d"));
|
||||
runningMapJobSlots = new Semaphore(this.bufferSize);
|
||||
|
||||
this.inputExecutor = Executors.newSingleThreadExecutor(new NamedThreadFactory("NS-input-thread-%d"));
|
||||
this.masterExecutor = Executors.newSingleThreadExecutor(new NamedThreadFactory("NS-master-thread-%d"));
|
||||
this.mapExecutor = Executors.newFixedThreadPool(nThreads, new NamedThreadFactory("NS-map-thread-%d"));
|
||||
runningMapJobSlots = new Semaphore(this.bufferSize);
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -111,7 +108,6 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
|
|||
*/
|
||||
public void shutdown() {
|
||||
if ( nThreads > 1 ) {
|
||||
shutdownExecutor("inputExecutor", inputExecutor);
|
||||
shutdownExecutor("mapExecutor", mapExecutor);
|
||||
shutdownExecutor("masterExecutor", masterExecutor);
|
||||
}
|
||||
|
|
@ -323,7 +319,6 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
|
|||
if ( errorTracker.hasAnErrorOccurred() ) {
|
||||
masterExecutor.shutdownNow();
|
||||
mapExecutor.shutdownNow();
|
||||
inputExecutor.shutdownNow();
|
||||
errorTracker.throwErrorIfPending();
|
||||
}
|
||||
}
|
||||
|
|
@ -351,15 +346,8 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
|
|||
|
||||
@Override
|
||||
public ReduceType call() {
|
||||
// a blocking queue that limits the number of input datum to the requested buffer size
|
||||
// note we need +1 because we continue to enqueue the lastObject
|
||||
final BlockingQueue<InputProducer<InputType>.InputValue> inputQueue
|
||||
= new LinkedBlockingDeque<InputProducer<InputType>.InputValue>(bufferSize+1);
|
||||
|
||||
// Create the input producer and start it running
|
||||
final InputProducer<InputType> inputProducer =
|
||||
new InputProducer<InputType>(inputReader, errorTracker, inputQueue);
|
||||
inputExecutor.submit(inputProducer);
|
||||
final InputProducer<InputType> inputProducer = new InputProducer<InputType>(inputReader);
|
||||
|
||||
// a priority queue that stores up to bufferSize elements
|
||||
// produced by completed map jobs.
|
||||
|
|
@ -376,7 +364,7 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
|
|||
// acquire a slot to run a map job. Blocks if too many jobs are enqueued
|
||||
runningMapJobSlots.acquire();
|
||||
|
||||
mapExecutor.submit(new MapReduceJob(inputQueue, mapResultQueue, map, reducer));
|
||||
mapExecutor.submit(new ReadMapReduceJob(inputProducer, mapResultQueue, map, reducer));
|
||||
nSubmittedJobs++;
|
||||
}
|
||||
|
||||
|
|
@ -402,10 +390,6 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
|
|||
// logger.warn("waiting for final reduce");
|
||||
final ReduceType finalSum = reducer.waitForFinalReduce();
|
||||
|
||||
// now wait for the input provider thread to terminate
|
||||
// logger.warn("waiting on inputProducer");
|
||||
inputProducer.waitForDone();
|
||||
|
||||
// wait for all the map threads to finish by acquiring and then releasing all map job semaphores
|
||||
// logger.warn("waiting on map");
|
||||
runningMapJobSlots.acquire(bufferSize);
|
||||
|
|
@ -434,17 +418,17 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
|
|||
}
|
||||
}
|
||||
|
||||
private class MapReduceJob implements Runnable {
|
||||
final BlockingQueue<InputProducer<InputType>.InputValue> inputQueue;
|
||||
private class ReadMapReduceJob implements Runnable {
|
||||
final InputProducer<InputType> inputProducer;
|
||||
final PriorityBlockingQueue<MapResult<MapType>> mapResultQueue;
|
||||
final NSMapFunction<InputType, MapType> map;
|
||||
final Reducer<MapType, ReduceType> reducer;
|
||||
|
||||
private MapReduceJob(BlockingQueue<InputProducer<InputType>.InputValue> inputQueue,
|
||||
final PriorityBlockingQueue<MapResult<MapType>> mapResultQueue,
|
||||
final NSMapFunction<InputType, MapType> map,
|
||||
final Reducer<MapType, ReduceType> reducer) {
|
||||
this.inputQueue = inputQueue;
|
||||
private ReadMapReduceJob(final InputProducer<InputType> inputProducer,
|
||||
final PriorityBlockingQueue<MapResult<MapType>> mapResultQueue,
|
||||
final NSMapFunction<InputType, MapType> map,
|
||||
final Reducer<MapType, ReduceType> reducer) {
|
||||
this.inputProducer = inputProducer;
|
||||
this.mapResultQueue = mapResultQueue;
|
||||
this.map = map;
|
||||
this.reducer = reducer;
|
||||
|
|
@ -453,10 +437,10 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
|
|||
@Override
|
||||
public void run() {
|
||||
try {
|
||||
//debugPrint("Running MapReduceJob " + jobID);
|
||||
final InputProducer<InputType>.InputValue inputWrapper = inputQueue.take();
|
||||
final int jobID = inputWrapper.getId();
|
||||
// get the next item from the input producer
|
||||
final InputProducer<InputType>.InputValue inputWrapper = inputProducer.next();
|
||||
|
||||
// depending on inputWrapper, actually do some work or not, putting result input result object
|
||||
final MapResult<MapType> result;
|
||||
if ( ! inputWrapper.isEOFMarker() ) {
|
||||
// just skip doing anything if we don't have work to do, which is possible
|
||||
|
|
@ -468,23 +452,19 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
|
|||
final MapType mapValue = map.apply(input);
|
||||
|
||||
// enqueue the result into the mapResultQueue
|
||||
result = new MapResult<MapType>(mapValue, jobID);
|
||||
result = new MapResult<MapType>(mapValue, inputWrapper.getId());
|
||||
|
||||
if ( progressFunction != null )
|
||||
progressFunction.progress(input);
|
||||
} else {
|
||||
// push back the EOF marker so other waiting threads can read it
|
||||
inputQueue.put(inputWrapper.nextEOF());
|
||||
|
||||
// if there's no input we push empty MapResults with jobIDs for synchronization with Reducer
|
||||
result = new MapResult<MapType>(jobID);
|
||||
result = new MapResult<MapType>(inputWrapper.getId());
|
||||
}
|
||||
|
||||
mapResultQueue.put(result);
|
||||
|
||||
final int nReduced = reducer.reduceAsMuchAsPossible(mapResultQueue);
|
||||
} catch (Throwable ex) {
|
||||
// logger.warn("Map job got exception " + ex);
|
||||
errorTracker.notifyOfError(ex);
|
||||
} finally {
|
||||
// we finished a map job, release the job queue semaphore
|
||||
|
|
|
|||
|
|
@ -84,7 +84,7 @@ class Reducer<MapType, ReduceType> {
|
|||
if ( nextMapResult == null ) {
|
||||
return false;
|
||||
} else if ( nextMapResult.getJobID() < prevJobID + 1 ) {
|
||||
throw new IllegalStateException("Next job ID " + nextMapResult.getJobID() + " is < previous job id " + prevJobID);
|
||||
throw new IllegalStateException("Next job ID " + nextMapResult.getJobID() + " is not < previous job id " + prevJobID);
|
||||
} else if ( nextMapResult.getJobID() == prevJobID + 1 ) {
|
||||
return true;
|
||||
} else {
|
||||
|
|
|
|||
|
|
@ -1,19 +1,13 @@
|
|||
package org.broadinstitute.sting.utils.nanoScheduler;
|
||||
|
||||
import org.broadinstitute.sting.BaseTest;
|
||||
import org.broadinstitute.sting.utils.MultiThreadedErrorTracker;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.DataProvider;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.concurrent.ExecutorService;
|
||||
import java.util.concurrent.Executors;
|
||||
import java.util.concurrent.LinkedBlockingDeque;
|
||||
import java.util.concurrent.Semaphore;
|
||||
|
||||
/**
|
||||
* UnitTests for the InputProducer
|
||||
|
|
@ -42,34 +36,23 @@ public class InputProducerUnitTest extends BaseTest {
|
|||
final List<Integer> elements = new ArrayList<Integer>(nElements);
|
||||
for ( int i = 0; i < nElements; i++ ) elements.add(i);
|
||||
|
||||
final LinkedBlockingDeque<InputProducer<Integer>.InputValue> readQueue =
|
||||
new LinkedBlockingDeque<InputProducer<Integer>.InputValue>(queueSize);
|
||||
|
||||
final InputProducer<Integer> ip = new InputProducer<Integer>(elements.iterator(), new MultiThreadedErrorTracker(), readQueue);
|
||||
|
||||
final ExecutorService es = Executors.newSingleThreadExecutor();
|
||||
final InputProducer<Integer> ip = new InputProducer<Integer>(elements.iterator());
|
||||
|
||||
Assert.assertFalse(ip.allInputsHaveBeenRead(), "InputProvider said that all inputs have been read, but I haven't started reading yet");
|
||||
Assert.assertEquals(ip.getNumInputValues(), -1, "InputProvider told me that the queue was done, but I haven't started reading yet");
|
||||
|
||||
es.submit(ip);
|
||||
|
||||
int lastValue = -1;
|
||||
int nRead = 0;
|
||||
while ( true ) {
|
||||
while ( ip.hasNext() ) {
|
||||
final int nTotalElements = ip.getNumInputValues();
|
||||
final int observedQueueSize = readQueue.size();
|
||||
Assert.assertTrue(observedQueueSize <= queueSize,
|
||||
"Reader is enqueuing more elements " + observedQueueSize + " than allowed " + queueSize);
|
||||
|
||||
if ( nRead + observedQueueSize < nElements )
|
||||
if ( nRead < nElements )
|
||||
Assert.assertEquals(nTotalElements, -1, "getNumInputValues should have returned -1 with not all elements read");
|
||||
// note, cannot test else case because elements input could have emptied between calls
|
||||
|
||||
final InputProducer<Integer>.InputValue value = readQueue.take();
|
||||
final InputProducer<Integer>.InputValue value = ip.next();
|
||||
if ( value.isEOFMarker() ) {
|
||||
Assert.assertEquals(nRead, nElements, "Number of input values " + nRead + " not all that are expected " + nElements);
|
||||
Assert.assertEquals(readQueue.size(), 0, "Last queue element found but queue contains more values!");
|
||||
break;
|
||||
} else {
|
||||
Assert.assertTrue(lastValue < value.getValue(), "Read values coming out of order!");
|
||||
|
|
@ -82,65 +65,5 @@ public class InputProducerUnitTest extends BaseTest {
|
|||
|
||||
Assert.assertTrue(ip.allInputsHaveBeenRead(), "InputProvider said that all inputs haven't been read, but I read them all");
|
||||
Assert.assertEquals(ip.getNumInputValues(), nElements, "Wrong number of total elements getNumInputValues");
|
||||
es.shutdownNow();
|
||||
}
|
||||
|
||||
@Test(enabled = true, dataProvider = "InputProducerTest", timeOut = NanoSchedulerUnitTest.NANO_SCHEDULE_MAX_RUNTIME)
|
||||
public void testInputProducerLocking(final int nElements, final int queueSize) throws InterruptedException {
|
||||
final List<Integer> elements = new ArrayList<Integer>(nElements);
|
||||
for ( int i = 0; i < nElements; i++ ) elements.add(i);
|
||||
|
||||
final LinkedBlockingDeque<InputProducer<Integer>.InputValue> readQueue =
|
||||
new LinkedBlockingDeque<InputProducer<Integer>.InputValue>();
|
||||
|
||||
final InputProducer<Integer> ip = new InputProducer<Integer>(elements.iterator(), new MultiThreadedErrorTracker(), readQueue);
|
||||
|
||||
final ExecutorService es = Executors.newSingleThreadExecutor();
|
||||
es.submit(ip);
|
||||
|
||||
ip.waitForDone();
|
||||
|
||||
Assert.assertEquals(ip.getNumInputValues(), nElements, "InputProvider told me that the queue was done, but I haven't started reading yet");
|
||||
Assert.assertEquals(readQueue.size(), nElements + 1, "readQueue should have had all elements read into it");
|
||||
}
|
||||
|
||||
final static class BlockingIterator<T> implements Iterator<T> {
|
||||
final Semaphore blockNext = new Semaphore(0);
|
||||
final Semaphore blockOnNext = new Semaphore(0);
|
||||
final Iterator<T> underlyingIterator;
|
||||
|
||||
BlockingIterator(Iterator<T> underlyingIterator) {
|
||||
this.underlyingIterator = underlyingIterator;
|
||||
}
|
||||
|
||||
public void allowNext() {
|
||||
blockNext.release(1);
|
||||
}
|
||||
|
||||
public void blockTillNext() throws InterruptedException {
|
||||
blockOnNext.acquire(1);
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasNext() {
|
||||
return underlyingIterator.hasNext();
|
||||
}
|
||||
|
||||
@Override
|
||||
public T next() {
|
||||
try {
|
||||
blockNext.acquire(1);
|
||||
T value = underlyingIterator.next();
|
||||
blockOnNext.release(1);
|
||||
return value;
|
||||
} catch (InterruptedException ex) {
|
||||
throw new RuntimeException(ex);
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void remove() {
|
||||
throw new UnsupportedOperationException("x");
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue