Working (efficient?) implementation of NanoScheduler

-- Groups inputs for each thread so that we don't have one thread execution per map() call
-- Added shutdown function
-- Documentation everywhere
-- Code cleanup
-- Extensive unittests
-- At this point I'm ready to integrate it into the engine for CPU parallel read walkers
This commit is contained in:
Mark DePristo 2012-08-24 15:34:23 -04:00
parent d6e6b30caf
commit 9de8077eeb
6 changed files with 265 additions and 102 deletions

View File

@ -810,4 +810,25 @@ public class Utils {
return Collections.unmodifiableMap(map); return Collections.unmodifiableMap(map);
} }
/**
* Divides the input list into a list of sublists, which contains group size elements (except potentially the last one)
*
* list = [A, B, C, D, E]
* groupSize = 2
* result = [[A, B], [C, D], [E]]
*
* @param list
* @param groupSize
* @return
*/
public static <T> List<List<T>> groupList(final List<T> list, final int groupSize) {
if ( groupSize < 1 ) throw new IllegalArgumentException("groupSize >= 1");
final List<List<T>> subLists = new LinkedList<List<T>>();
int n = list.size();
for ( int i = 0; i < n; i += groupSize ) {
subLists.add(list.subList(i, Math.min(i + groupSize, n)));
}
return subLists;
}
} }

View File

@ -3,10 +3,17 @@ package org.broadinstitute.sting.utils.nanoScheduler;
/** /**
* A function that maps from InputType -> ResultType * A function that maps from InputType -> ResultType
* *
* For use with the NanoScheduler
*
* User: depristo * User: depristo
* Date: 8/24/12 * Date: 8/24/12
* Time: 9:49 AM * Time: 9:49 AM
*/ */
public interface MapFunction<InputType, ResultType> { public interface MapFunction<InputType, ResultType> {
/**
* Return function on input, returning a value of ResultType
* @param input
* @return
*/
public ResultType apply(final InputType input); public ResultType apply(final InputType input);
} }

View File

@ -1,31 +0,0 @@
package org.broadinstitute.sting.utils.nanoScheduler;
/**
* Created with IntelliJ IDEA.
* User: depristo
* Date: 8/24/12
* Time: 9:57 AM
* To change this template use File | Settings | File Templates.
*/
public class MapResult<MapType> implements Comparable<MapResult<MapType>> {
final Integer id;
final MapType value;
public MapResult(final int id, final MapType value) {
this.id = id;
this.value = value;
}
public Integer getId() {
return id;
}
public MapType getValue() {
return value;
}
@Override
public int compareTo(MapResult<MapType> o) {
return getId().compareTo(o.getId());
}
}

View File

@ -2,6 +2,8 @@ package org.broadinstitute.sting.utils.nanoScheduler;
import com.google.java.contract.Ensures; import com.google.java.contract.Ensures;
import com.google.java.contract.Requires; import com.google.java.contract.Requires;
import org.apache.log4j.Logger;
import org.broadinstitute.sting.utils.Utils;
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
import java.util.Iterator; import java.util.Iterator;
@ -13,45 +15,147 @@ import java.util.concurrent.*;
/** /**
* Framework for very fine grained MapReduce parallelism * Framework for very fine grained MapReduce parallelism
* *
* The overall framework works like this
*
* nano <- new Nanoschedule(bufferSize, numberOfMapElementsToProcessTogether, nThreads)
* List[Input] outerData : outerDataLoop )
* result = nano.execute(outerData.iterator(), map, reduce)
*
* bufferSize determines how many elements from the input stream are read in one go by the
* nanoscheduler. The scheduler may hold up to bufferSize in memory at one time, as well
* as up to inputBufferSize map results as well.
*
* numberOfMapElementsToProcessTogether determines how many input elements are processed
* together each thread cycle. For example, if this value is 10, then the input data
* is grouped together in units of 10 elements each, and map called on each in term. The more
* heavy-weight the map function is, in terms of CPU costs, the more it makes sense to
* have this number be small. The lighter the CPU cost per element, though, the more this
* parameter introduces overhead due to need to context switch among threads to process
* each input element. A value of -1 lets the nanoscheduler guess at a reasonable trade-off value.
*
* nThreads is a bit obvious yes? Note though that the nanoscheduler assumes that it gets 1 thread
* from its client during the execute call, as this call blocks until all work is done. The caller
* thread is put to work by execute to help with the processing of the data. So in reality the
* nanoScheduler only spawn nThreads - 1 additional workers (if this is > 1).
*
* User: depristo * User: depristo
* Date: 8/24/12 * Date: 8/24/12
* Time: 9:47 AM * Time: 9:47 AM
*/ */
public class NanoScheduler<InputType, MapType, ReduceType> { public class NanoScheduler<InputType, MapType, ReduceType> {
final int bufferSize; private static Logger logger = Logger.getLogger(NanoScheduler.class);
final int nThreads;
final Iterator<InputType> inputReader;
final MapFunction<InputType, MapType> map;
final ReduceFunction<MapType, ReduceType> reduce;
final int bufferSize;
final int mapGroupSize;
final int nThreads;
final ExecutorService executor;
boolean shutdown = false;
/**
* Create a new nanoschedule with the desire characteristics requested by the argument
*
* @param bufferSize the number of input elements to read in each scheduling cycle.
* @param mapGroupSize How many inputs should be grouped together per map? If -1 we make a reasonable guess
* @param nThreads the number of threads to use to get work done, in addition to the thread calling execute
*/
public NanoScheduler(final int bufferSize, public NanoScheduler(final int bufferSize,
final int nThreads, final int mapGroupSize,
final Iterator<InputType> inputReader, final int nThreads) {
final MapFunction<InputType, MapType> map,
final ReduceFunction<MapType, ReduceType> reduce) {
if ( bufferSize < 1 ) throw new IllegalArgumentException("bufferSize must be >= 1, got " + bufferSize); if ( bufferSize < 1 ) throw new IllegalArgumentException("bufferSize must be >= 1, got " + bufferSize);
if ( nThreads < 1 ) throw new IllegalArgumentException("nThreads must be >= 1, got " + nThreads); if ( nThreads < 1 ) throw new IllegalArgumentException("nThreads must be >= 1, got " + nThreads);
if ( mapGroupSize > bufferSize ) throw new IllegalArgumentException("mapGroupSize " + mapGroupSize + " must be <= bufferSize " + bufferSize);
if ( mapGroupSize == 0 || mapGroupSize < -1 ) throw new IllegalArgumentException("mapGroupSize cannot be <= 0" + mapGroupSize);
this.bufferSize = bufferSize; this.bufferSize = bufferSize;
this.inputReader = inputReader;
this.map = map;
this.reduce = reduce;
this.nThreads = nThreads; this.nThreads = nThreads;
if ( mapGroupSize == -1 ) {
this.mapGroupSize = (int)Math.ceil(this.bufferSize / (10.0*this.nThreads));
logger.info(String.format("Dynamically setting grouping size to %d based on buffer size %d and n threads %d",
this.mapGroupSize, this.bufferSize, this.nThreads));
} else {
this.mapGroupSize = mapGroupSize;
}
this.executor = nThreads == 1 ? null : Executors.newFixedThreadPool(nThreads - 1);
} }
/**
* The number of parallel map threads in use with this NanoScheduler
* @return
*/
public int getnThreads() { public int getnThreads() {
return nThreads; return nThreads;
} }
private int getBufferSize() { /**
* The input buffer size used by this NanoScheduler
* @return
*/
public int getBufferSize() {
return bufferSize; return bufferSize;
} }
public ReduceType execute() { /**
* The grouping size used by this NanoScheduler
* @return
*/
public int getMapGroupSize() {
return mapGroupSize;
}
/**
* Tells this nanoScheduler to shutdown immediately, releasing all its resources.
*
* After this call, execute cannot be invoked without throwing an error
*/
public void shutdown() {
if ( executor != null ) {
final List<Runnable> remaining = executor.shutdownNow();
if ( ! remaining.isEmpty() )
throw new IllegalStateException("Remaining tasks found in the executor, unexpected behavior!");
}
shutdown = true;
}
/**
* @return true if this nanoScheduler is shutdown, or false if its still open for business
*/
public boolean isShutdown() {
return shutdown;
}
/**
* Execute a map/reduce job with this nanoScheduler
*
* Data comes from inputReader. Will be read until hasNext() == false.
* map is called on each element provided by inputReader. No order of operations is guarenteed
* reduce is called in order of the input data provided by inputReader on the result of map() applied
* to each element.
*
* Note that the caller thread is put to work with this function call. The call doesn't return
* until all elements have been processes.
*
* It is safe to call this function repeatedly on a single nanoScheduler, at least until the
* shutdown method is called.
*
* @param inputReader
* @param map
* @param reduce
* @return
*/
public ReduceType execute(final Iterator<InputType> inputReader,
final MapFunction<InputType, MapType> map,
final ReduceType initialValue,
final ReduceFunction<MapType, ReduceType> reduce) {
if ( isShutdown() )
throw new IllegalStateException("execute called on already shutdown NanoScheduler");
if ( getnThreads() == 1 ) { if ( getnThreads() == 1 ) {
return executeSingleThreaded(); return executeSingleThreaded(inputReader, map, initialValue, reduce);
} else { } else {
return executeMultiThreaded(); return executeMultiThreaded(inputReader, map, initialValue, reduce);
} }
} }
@ -59,8 +163,11 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
* Simple efficient reference implementation for single threaded execution * Simple efficient reference implementation for single threaded execution
* @return the reduce result of this map/reduce job * @return the reduce result of this map/reduce job
*/ */
private ReduceType executeSingleThreaded() { private ReduceType executeSingleThreaded(final Iterator<InputType> inputReader,
ReduceType sum = reduce.init(); final MapFunction<InputType, MapType> map,
final ReduceType initialValue,
final ReduceFunction<MapType, ReduceType> reduce) {
ReduceType sum = initialValue;
while ( inputReader.hasNext() ) { while ( inputReader.hasNext() ) {
final InputType input = inputReader.next(); final InputType input = inputReader.next();
final MapType mapValue = map.apply(input); final MapType mapValue = map.apply(input);
@ -74,20 +181,21 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
* *
* @return the reduce result of this map/reduce job * @return the reduce result of this map/reduce job
*/ */
private ReduceType executeMultiThreaded() { private ReduceType executeMultiThreaded(final Iterator<InputType> inputReader,
final ExecutorService executor = Executors.newFixedThreadPool(getnThreads() - 1); final MapFunction<InputType, MapType> map,
final ReduceType initialValue,
ReduceType sum = reduce.init(); final ReduceFunction<MapType, ReduceType> reduce) {
ReduceType sum = initialValue;
while ( inputReader.hasNext() ) { while ( inputReader.hasNext() ) {
try { try {
// read in our input values // read in our input values
final Queue<InputType> inputs = readInputs(); final List<InputType> inputs = readInputs(inputReader);
// send jobs for map // send jobs for map
final Queue<Future<MapType>> mapQueue = submitMapJobs(executor, inputs); final Queue<Future<List<MapType>>> mapQueue = submitMapJobs(map, executor, inputs);
// send off the reduce job, and block until we get at least one reduce result // send off the reduce job, and block until we get at least one reduce result
sum = reduceParallel(mapQueue, sum); sum = reduceParallel(reduce, mapQueue, sum);
} catch (InterruptedException ex) { } catch (InterruptedException ex) {
throw new ReviewedStingException("got execution exception", ex); throw new ReviewedStingException("got execution exception", ex);
} catch (ExecutionException ex) { } catch (ExecutionException ex) {
@ -95,23 +203,20 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
} }
} }
final List<Runnable> remaining = executor.shutdownNow();
if ( ! remaining.isEmpty() )
throw new ReviewedStingException("Remaining tasks found in the executor, unexpected behavior!");
return sum; return sum;
} }
@Requires("! mapQueue.isEmpty()") @Requires("! mapQueue.isEmpty()")
private ReduceType reduceParallel(final Queue<Future<MapType>> mapQueue, final ReduceType initSum) private ReduceType reduceParallel(final ReduceFunction<MapType, ReduceType> reduce,
final Queue<Future<List<MapType>>> mapQueue,
final ReduceType initSum)
throws InterruptedException, ExecutionException { throws InterruptedException, ExecutionException {
ReduceType sum = initSum; ReduceType sum = initSum;
// while mapQueue has something in it to reduce // while mapQueue has something in it to reduce
for ( final Future<MapType> future : mapQueue ) { for ( final Future<List<MapType>> future : mapQueue ) {
// block until we get the value for this task for ( final MapType value : future.get() ) // block until we get the values for this task
final MapType value = future.get(); sum = reduce.apply(value, sum);
sum = reduce.apply(value, sum);
} }
return sum; return sum;
@ -124,9 +229,9 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
*/ */
@Requires("inputReader.hasNext()") @Requires("inputReader.hasNext()")
@Ensures("!result.isEmpty()") @Ensures("!result.isEmpty()")
private Queue<InputType> readInputs() { private List<InputType> readInputs(final Iterator<InputType> inputReader) {
int n = 0; int n = 0;
final Queue<InputType> inputs = new LinkedList<InputType>(); final List<InputType> inputs = new LinkedList<InputType>();
while ( inputReader.hasNext() && n < getBufferSize() ) { while ( inputReader.hasNext() && n < getBufferSize() ) {
final InputType input = inputReader.next(); final InputType input = inputReader.next();
inputs.add(input); inputs.add(input);
@ -136,12 +241,14 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
} }
@Ensures("result.size() == inputs.size()") @Ensures("result.size() == inputs.size()")
private Queue<Future<MapType>> submitMapJobs(final ExecutorService executor, final Queue<InputType> inputs) { private Queue<Future<List<MapType>>> submitMapJobs(final MapFunction<InputType, MapType> map,
final Queue<Future<MapType>> mapQueue = new LinkedList<Future<MapType>>(); final ExecutorService executor,
final List<InputType> inputs) {
final Queue<Future<List<MapType>>> mapQueue = new LinkedList<Future<List<MapType>>>();
for ( final InputType input : inputs ) { for ( final List<InputType> subinputs : Utils.groupList(inputs, getMapGroupSize()) ) {
final CallableMap doMap = new CallableMap(input); final CallableMap doMap = new CallableMap(map, subinputs);
final Future<MapType> future = executor.submit(doMap); final Future<List<MapType>> future = executor.submit(doMap);
mapQueue.add(future); mapQueue.add(future);
} }
@ -151,15 +258,20 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
/** /**
* A simple callable version of the map function for use with the executor pool * A simple callable version of the map function for use with the executor pool
*/ */
private class CallableMap implements Callable<MapType> { private class CallableMap implements Callable<List<MapType>> {
final InputType input; final List<InputType> inputs;
final MapFunction<InputType, MapType> map;
private CallableMap(final InputType input) { private CallableMap(final MapFunction<InputType, MapType> map, final List<InputType> inputs) {
this.input = input; this.inputs = inputs;
this.map = map;
} }
@Override public MapType call() throws Exception { @Override public List<MapType> call() throws Exception {
return map.apply(input); final List<MapType> outputs = new LinkedList<MapType>();
for ( final InputType input : inputs )
outputs.add(map.apply(input));
return outputs;
} }
} }
} }

View File

@ -1,13 +1,18 @@
package org.broadinstitute.sting.utils.nanoScheduler; package org.broadinstitute.sting.utils.nanoScheduler;
/** /**
* A function that maps from InputType -> ResultType * A function that combines a value of MapType with an existing ReduceValue into a new ResultType
* *
* User: depristo * User: depristo
* Date: 8/24/12 * Date: 8/24/12
* Time: 9:49 AM * Time: 9:49 AM
*/ */
public interface ReduceFunction<MapType, ReduceType> { public interface ReduceFunction<MapType, ReduceType> {
public ReduceType init(); /**
* Combine one with sum into a new ReduceType
* @param one the result of a map call on an input element
* @param sum the cumulative reduce result over all previous map calls
* @return
*/
public ReduceType apply(MapType one, ReduceType sum); public ReduceType apply(MapType one, ReduceType sum);
} }

View File

@ -21,7 +21,6 @@ public class NanoSchedulerUnitTest extends BaseTest {
} }
private class ReduceSum implements ReduceFunction<Integer, Integer> { private class ReduceSum implements ReduceFunction<Integer, Integer> {
@Override public Integer init() { return 0; }
@Override public Integer apply(Integer one, Integer sum) { return one + sum; } @Override public Integer apply(Integer one, Integer sum) { return one + sum; }
} }
@ -33,17 +32,18 @@ public class NanoSchedulerUnitTest extends BaseTest {
} }
private class NanoSchedulerBasicTest extends TestDataProvider { private class NanoSchedulerBasicTest extends TestDataProvider {
final int bufferSize, nThreads, start, end, expectedResult; final int bufferSize, mapGroupSize, nThreads, start, end, expectedResult;
public NanoSchedulerBasicTest(final int bufferSize, final int nThreads, final int start, final int end) { public NanoSchedulerBasicTest(final int bufferSize, final int mapGroupSize, final int nThreads, final int start, final int end) {
super(NanoSchedulerBasicTest.class); super(NanoSchedulerBasicTest.class);
this.bufferSize = bufferSize; this.bufferSize = bufferSize;
this.mapGroupSize = mapGroupSize;
this.nThreads = nThreads; this.nThreads = nThreads;
this.start = start; this.start = start;
this.end = end; this.end = end;
this.expectedResult = sum2x(start, end); this.expectedResult = sum2x(start, end);
setName(String.format("%s nt=%d buf=%d start=%d end=%d sum=%d", setName(String.format("%s nt=%d buf=%d mapGroupSize=%d start=%d end=%d sum=%d",
getClass().getSimpleName(), nThreads, bufferSize, start, end, expectedResult)); getClass().getSimpleName(), nThreads, bufferSize, mapGroupSize, start, end, expectedResult));
} }
public Iterator<Integer> makeReader() { public Iterator<Integer> makeReader() {
@ -54,16 +54,22 @@ public class NanoSchedulerUnitTest extends BaseTest {
} }
public Map2x makeMap() { return new Map2x(); } public Map2x makeMap() { return new Map2x(); }
public Integer initReduce() { return 0; }
public ReduceSum makeReduce() { return new ReduceSum(); } public ReduceSum makeReduce() { return new ReduceSum(); }
} }
static NanoSchedulerBasicTest exampleTest = null;
@DataProvider(name = "NanoSchedulerBasicTest") @DataProvider(name = "NanoSchedulerBasicTest")
public Object[][] createNanoSchedulerBasicTest() { public Object[][] createNanoSchedulerBasicTest() {
for ( final int bufferSize : Arrays.asList(1, 10, 10000, 1000000) ) { for ( final int bufferSize : Arrays.asList(1, 10, 1000, 1000000) ) {
for ( final int nt : Arrays.asList(1, 2, 4, 8, 16, 32) ) { for ( final int mapGroupSize : Arrays.asList(-1, 1, 10, 100, 1000) ) {
for ( final int start : Arrays.asList(0) ) { if ( mapGroupSize <= bufferSize ) {
for ( final int end : Arrays.asList(1, 2, 11, 1000000) ) { for ( final int nt : Arrays.asList(1, 2, 4) ) {
new NanoSchedulerBasicTest(bufferSize, nt, start, end); for ( final int start : Arrays.asList(0) ) {
for ( final int end : Arrays.asList(1, 2, 11, 10000, 100000) ) {
exampleTest = new NanoSchedulerBasicTest(bufferSize, mapGroupSize, nt, start, end);
}
}
} }
} }
} }
@ -72,22 +78,65 @@ public class NanoSchedulerUnitTest extends BaseTest {
return NanoSchedulerBasicTest.getTests(NanoSchedulerBasicTest.class); return NanoSchedulerBasicTest.getTests(NanoSchedulerBasicTest.class);
} }
@Test(enabled = true, dataProvider = "NanoSchedulerBasicTest", timeOut = 2000) @Test(enabled = true, dataProvider = "NanoSchedulerBasicTest")
public void testNanoSchedulerBasicTest(final NanoSchedulerBasicTest test) throws InterruptedException { public void testSingleThreadedNanoScheduler(final NanoSchedulerBasicTest test) throws InterruptedException {
logger.warn("Running " + test); logger.warn("Running " + test);
final NanoScheduler<Integer, Integer, Integer> nanoScheduler = if ( test.nThreads == 1 )
new NanoScheduler<Integer, Integer, Integer>(test.bufferSize, test.nThreads, testNanoScheduler(test);
test.makeReader(), test.makeMap(), test.makeReduce());
final Integer sum = nanoScheduler.execute();
Assert.assertNotNull(sum);
Assert.assertEquals((int)sum, test.expectedResult, "NanoScheduler sum not the same as calculated directly");
} }
@Test(enabled = true, dataProvider = "NanoSchedulerBasicTest", timeOut = 10000, dependsOnMethods = "testNanoSchedulerBasicTest") @Test(enabled = true, dataProvider = "NanoSchedulerBasicTest", timeOut = 10000, dependsOnMethods = "testSingleThreadedNanoScheduler")
public void testNanoSchedulerInLoop(final NanoSchedulerBasicTest test) throws InterruptedException { public void testMultiThreadedNanoScheduler(final NanoSchedulerBasicTest test) throws InterruptedException {
logger.warn("Running " + test); logger.warn("Running " + test);
for ( int i = 0; i < 10; i++ ) { if ( test.nThreads >= 1 )
testNanoSchedulerBasicTest(test); testNanoScheduler(test);
}
private void testNanoScheduler(final NanoSchedulerBasicTest test) throws InterruptedException {
final NanoScheduler<Integer, Integer, Integer> nanoScheduler =
new NanoScheduler<Integer, Integer, Integer>(test.bufferSize, test.mapGroupSize, test.nThreads);
Assert.assertEquals(nanoScheduler.getBufferSize(), test.bufferSize, "bufferSize argument");
Assert.assertTrue(nanoScheduler.getMapGroupSize() >= test.mapGroupSize, "mapGroupSize argument");
Assert.assertEquals(nanoScheduler.getnThreads(), test.nThreads, "nThreads argument");
final Integer sum = nanoScheduler.execute(test.makeReader(), test.makeMap(), test.initReduce(), test.makeReduce());
Assert.assertNotNull(sum);
Assert.assertEquals((int)sum, test.expectedResult, "NanoScheduler sum not the same as calculated directly");
nanoScheduler.shutdown();
}
@Test(enabled = true, dataProvider = "NanoSchedulerBasicTest", dependsOnMethods = "testMultiThreadedNanoScheduler")
public void testNanoSchedulerInLoop(final NanoSchedulerBasicTest test) throws InterruptedException {
if ( test.bufferSize > 1 && (test.mapGroupSize > 1 || test.mapGroupSize == -1)) {
logger.warn("Running " + test);
final NanoScheduler<Integer, Integer, Integer> nanoScheduler =
new NanoScheduler<Integer, Integer, Integer>(test.bufferSize, test.mapGroupSize, test.nThreads);
// test reusing the scheduler
for ( int i = 0; i < 10; i++ ) {
final Integer sum = nanoScheduler.execute(test.makeReader(), test.makeMap(), test.initReduce(), test.makeReduce());
Assert.assertNotNull(sum);
Assert.assertEquals((int)sum, test.expectedResult, "NanoScheduler sum not the same as calculated directly");
}
nanoScheduler.shutdown();
} }
} }
@Test()
public void testShutdown() throws InterruptedException {
final NanoScheduler<Integer, Integer, Integer> nanoScheduler = new NanoScheduler<Integer, Integer, Integer>(1, 1, 2);
Assert.assertFalse(nanoScheduler.isShutdown(), "scheduler should be alive");
nanoScheduler.shutdown();
Assert.assertTrue(nanoScheduler.isShutdown(), "scheduler should be dead");
}
@Test(expectedExceptions = IllegalStateException.class)
public void testShutdownExecuteFailure() throws InterruptedException {
final NanoScheduler<Integer, Integer, Integer> nanoScheduler = new NanoScheduler<Integer, Integer, Integer>(1, 1, 2);
nanoScheduler.shutdown();
nanoScheduler.execute(exampleTest.makeReader(), exampleTest.makeMap(), exampleTest.initReduce(), exampleTest.makeReduce());
}
} }