NanoScheduler no longer groups inputs, each map() call is interlaced now

-- Maximizes the efficiency of the threads
-- Simplifies interface (yea!)
-- Reduces number of combinatorial tests that need to be performed
This commit is contained in:
Mark DePristo 2012-08-31 20:10:26 -04:00
parent 397a5551ef
commit 6055101df8
3 changed files with 43 additions and 78 deletions

View File

@ -55,13 +55,11 @@ public class TraverseReadsNano<M,T> extends TraversalEngine<M,T,ReadWalker<M,T>,
/** our log, which we want to capture anything from this class */ /** our log, which we want to capture anything from this class */
protected static final Logger logger = Logger.getLogger(TraverseReadsNano.class); protected static final Logger logger = Logger.getLogger(TraverseReadsNano.class);
private static final boolean DEBUG = false; private static final boolean DEBUG = false;
private static final int MIN_GROUP_SIZE = 100;
final NanoScheduler<MapData, M, T> nanoScheduler; final NanoScheduler<MapData, M, T> nanoScheduler;
public TraverseReadsNano(int nThreads) { public TraverseReadsNano(int nThreads) {
final int bufferSize = ReadShard.getReadBufferSize() + 1; // actually has 1 more than max final int bufferSize = ReadShard.getReadBufferSize() + 1; // actually has 1 more than max
final int mapGroupSize = (int)Math.max(Math.ceil(bufferSize / 50.0 + 1), MIN_GROUP_SIZE); nanoScheduler = new NanoScheduler<MapData, M, T>(bufferSize, nThreads);
nanoScheduler = new NanoScheduler<MapData, M, T>(bufferSize, mapGroupSize, nThreads);
} }
@Override @Override

View File

@ -3,7 +3,6 @@ package org.broadinstitute.sting.utils.nanoScheduler;
import com.google.java.contract.Ensures; import com.google.java.contract.Ensures;
import com.google.java.contract.Requires; import com.google.java.contract.Requires;
import org.apache.log4j.Logger; import org.apache.log4j.Logger;
import org.broadinstitute.sting.utils.Utils;
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
import java.util.Iterator; import java.util.Iterator;
@ -47,7 +46,6 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
private final static boolean ALLOW_SINGLE_THREAD_FASTPATH = true; private final static boolean ALLOW_SINGLE_THREAD_FASTPATH = true;
final int bufferSize; final int bufferSize;
final int mapGroupSize;
final int nThreads; final int nThreads;
final ExecutorService executor; final ExecutorService executor;
boolean shutdown = false; boolean shutdown = false;
@ -57,29 +55,15 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
* Create a new nanoschedule with the desire characteristics requested by the argument * Create a new nanoschedule with the desire characteristics requested by the argument
* *
* @param bufferSize the number of input elements to read in each scheduling cycle. * @param bufferSize the number of input elements to read in each scheduling cycle.
* @param mapGroupSize How many inputs should be grouped together per map? If -1 we make a reasonable guess
* @param nThreads the number of threads to use to get work done, in addition to the thread calling execute * @param nThreads the number of threads to use to get work done, in addition to the thread calling execute
*/ */
public NanoScheduler(final int bufferSize, public NanoScheduler(final int bufferSize,
final int mapGroupSize,
final int nThreads) { final int nThreads) {
if ( bufferSize < 1 ) throw new IllegalArgumentException("bufferSize must be >= 1, got " + bufferSize); if ( bufferSize < 1 ) throw new IllegalArgumentException("bufferSize must be >= 1, got " + bufferSize);
if ( nThreads < 1 ) throw new IllegalArgumentException("nThreads must be >= 1, got " + nThreads); if ( nThreads < 1 ) throw new IllegalArgumentException("nThreads must be >= 1, got " + nThreads);
if ( mapGroupSize > bufferSize ) throw new IllegalArgumentException("mapGroupSize " + mapGroupSize + " must be <= bufferSize " + bufferSize);
if ( mapGroupSize == 0 || mapGroupSize < -1 ) throw new IllegalArgumentException("mapGroupSize cannot be <= 0" + mapGroupSize);
this.bufferSize = bufferSize; this.bufferSize = bufferSize;
this.nThreads = nThreads; this.nThreads = nThreads;
if ( mapGroupSize == -1 ) {
this.mapGroupSize = (int)Math.ceil(this.bufferSize / (10.0*this.nThreads));
logger.info(String.format("Dynamically setting grouping size to %d based on buffer size %d and n threads %d",
this.mapGroupSize, this.bufferSize, this.nThreads));
} else {
this.mapGroupSize = mapGroupSize;
}
this.executor = nThreads == 1 ? null : Executors.newFixedThreadPool(nThreads); this.executor = nThreads == 1 ? null : Executors.newFixedThreadPool(nThreads);
} }
@ -101,15 +85,6 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
return bufferSize; return bufferSize;
} }
/**
* The grouping size used by this NanoScheduler
* @return
*/
@Ensures("result > 0")
public int getMapGroupSize() {
return mapGroupSize;
}
/** /**
* Tells this nanoScheduler to shutdown immediately, releasing all its resources. * Tells this nanoScheduler to shutdown immediately, releasing all its resources.
* *
@ -214,10 +189,10 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
final List<InputType> inputs = readInputs(inputReader); final List<InputType> inputs = readInputs(inputReader);
// send jobs for map // send jobs for map
final Queue<Future<List<MapType>>> mapQueue = submitMapJobs(map, executor, inputs); final Queue<Future<MapType>> mapQueue = submitMapJobs(map, executor, inputs);
// send off the reduce job, and block until we get at least one reduce result // send off the reduce job, and block until we get at least one reduce result
sum = reduceParallel(reduce, mapQueue, sum); sum = reduceSerial(reduce, mapQueue, sum);
} catch (InterruptedException ex) { } catch (InterruptedException ex) {
throw new ReviewedStingException("got execution exception", ex); throw new ReviewedStingException("got execution exception", ex);
} catch (ExecutionException ex) { } catch (ExecutionException ex) {
@ -229,15 +204,15 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
} }
@Requires({"reduce != null", "! mapQueue.isEmpty()"}) @Requires({"reduce != null", "! mapQueue.isEmpty()"})
private ReduceType reduceParallel(final ReduceFunction<MapType, ReduceType> reduce, private ReduceType reduceSerial(final ReduceFunction<MapType, ReduceType> reduce,
final Queue<Future<List<MapType>>> mapQueue, final Queue<Future<MapType>> mapQueue,
final ReduceType initSum) final ReduceType initSum)
throws InterruptedException, ExecutionException { throws InterruptedException, ExecutionException {
ReduceType sum = initSum; ReduceType sum = initSum;
// while mapQueue has something in it to reduce // while mapQueue has something in it to reduce
for ( final Future<List<MapType>> future : mapQueue ) { for ( final Future<MapType> future : mapQueue ) {
for ( final MapType value : future.get() ) // block until we get the values for this task final MapType value = future.get(); // block until we get the values for this task
sum = reduce.apply(value, sum); sum = reduce.apply(value, sum);
} }
@ -247,7 +222,7 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
/** /**
* Read up to inputBufferSize elements from inputReader * Read up to inputBufferSize elements from inputReader
* *
* @return a queue of inputs read in, containing one or more values of InputType read in * @return a queue of input read in, containing one or more values of InputType read in
*/ */
@Requires("inputReader.hasNext()") @Requires("inputReader.hasNext()")
@Ensures("!result.isEmpty()") @Ensures("!result.isEmpty()")
@ -263,14 +238,14 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
} }
@Requires({"map != null", "! inputs.isEmpty()"}) @Requires({"map != null", "! inputs.isEmpty()"})
private Queue<Future<List<MapType>>> submitMapJobs(final MapFunction<InputType, MapType> map, private Queue<Future<MapType>> submitMapJobs(final MapFunction<InputType, MapType> map,
final ExecutorService executor, final ExecutorService executor,
final List<InputType> inputs) { final List<InputType> inputs) {
final Queue<Future<List<MapType>>> mapQueue = new LinkedList<Future<List<MapType>>>(); final Queue<Future<MapType>> mapQueue = new LinkedList<Future<MapType>>();
for ( final List<InputType> subinputs : Utils.groupList(inputs, getMapGroupSize()) ) { for ( final InputType input : inputs ) {
final CallableMap doMap = new CallableMap(map, subinputs); final CallableMap doMap = new CallableMap(map, input);
final Future<List<MapType>> future = executor.submit(doMap); final Future<MapType> future = executor.submit(doMap);
mapQueue.add(future); mapQueue.add(future);
} }
@ -280,23 +255,18 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
/** /**
* A simple callable version of the map function for use with the executor pool * A simple callable version of the map function for use with the executor pool
*/ */
private class CallableMap implements Callable<List<MapType>> { private class CallableMap implements Callable<MapType> {
final List<InputType> inputs; final InputType input;
final MapFunction<InputType, MapType> map; final MapFunction<InputType, MapType> map;
@Requires({"map != null", "inputs.size() <= getMapGroupSize()"}) @Requires({"map != null"})
private CallableMap(final MapFunction<InputType, MapType> map, final List<InputType> inputs) { private CallableMap(final MapFunction<InputType, MapType> map, final InputType inputs) {
this.inputs = inputs; this.input = inputs;
this.map = map; this.map = map;
} }
@Ensures("result.size() == inputs.size()") @Override public MapType call() throws Exception {
@Override public List<MapType> call() throws Exception { return map.apply(input);
final List<MapType> outputs = new LinkedList<MapType>();
for ( final InputType input : inputs )
outputs.add(map.apply(input));
debugPrint(" Processed %d elements with map", outputs.size());
return outputs;
} }
} }
} }

View File

@ -5,7 +5,10 @@ import org.testng.Assert;
import org.testng.annotations.DataProvider; import org.testng.annotations.DataProvider;
import org.testng.annotations.Test; import org.testng.annotations.Test;
import java.util.*; import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
/** /**
* UnitTests for the NanoScheduler * UnitTests for the NanoScheduler
@ -39,18 +42,17 @@ public class NanoSchedulerUnitTest extends BaseTest {
} }
private static class NanoSchedulerBasicTest extends TestDataProvider { private static class NanoSchedulerBasicTest extends TestDataProvider {
final int bufferSize, mapGroupSize, nThreads, start, end, expectedResult; final int bufferSize, nThreads, start, end, expectedResult;
public NanoSchedulerBasicTest(final int bufferSize, final int mapGroupSize, final int nThreads, final int start, final int end) { public NanoSchedulerBasicTest(final int bufferSize, final int nThreads, final int start, final int end) {
super(NanoSchedulerBasicTest.class); super(NanoSchedulerBasicTest.class);
this.bufferSize = bufferSize; this.bufferSize = bufferSize;
this.mapGroupSize = mapGroupSize;
this.nThreads = nThreads; this.nThreads = nThreads;
this.start = start; this.start = start;
this.end = end; this.end = end;
this.expectedResult = sum2x(start, end); this.expectedResult = sum2x(start, end);
setName(String.format("%s nt=%d buf=%d mapGroupSize=%d start=%d end=%d sum=%d", setName(String.format("%s nt=%d buf=%d start=%d end=%d sum=%d",
getClass().getSimpleName(), nThreads, bufferSize, mapGroupSize, start, end, expectedResult)); getClass().getSimpleName(), nThreads, bufferSize, start, end, expectedResult));
} }
public Iterator<Integer> makeReader() { public Iterator<Integer> makeReader() {
@ -69,14 +71,10 @@ public class NanoSchedulerUnitTest extends BaseTest {
@DataProvider(name = "NanoSchedulerBasicTest") @DataProvider(name = "NanoSchedulerBasicTest")
public Object[][] createNanoSchedulerBasicTest() { public Object[][] createNanoSchedulerBasicTest() {
for ( final int bufferSize : Arrays.asList(1, 10, 1000, 1000000) ) { for ( final int bufferSize : Arrays.asList(1, 10, 1000, 1000000) ) {
for ( final int mapGroupSize : Arrays.asList(-1, 1, 10, 100, 1000) ) {
if ( mapGroupSize <= bufferSize ) {
for ( final int nt : Arrays.asList(1, 2, 4) ) { for ( final int nt : Arrays.asList(1, 2, 4) ) {
for ( final int start : Arrays.asList(0) ) { for ( final int start : Arrays.asList(0) ) {
for ( final int end : Arrays.asList(1, 2, 11, 10000, 100000) ) { for ( final int end : Arrays.asList(1, 2, 11, 10000, 100000) ) {
exampleTest = new NanoSchedulerBasicTest(bufferSize, mapGroupSize, nt, start, end); exampleTest = new NanoSchedulerBasicTest(bufferSize, nt, start, end);
}
}
} }
} }
} }
@ -101,10 +99,9 @@ public class NanoSchedulerUnitTest extends BaseTest {
private void testNanoScheduler(final NanoSchedulerBasicTest test) throws InterruptedException { private void testNanoScheduler(final NanoSchedulerBasicTest test) throws InterruptedException {
final NanoScheduler<Integer, Integer, Integer> nanoScheduler = final NanoScheduler<Integer, Integer, Integer> nanoScheduler =
new NanoScheduler<Integer, Integer, Integer>(test.bufferSize, test.mapGroupSize, test.nThreads); new NanoScheduler<Integer, Integer, Integer>(test.bufferSize, test.nThreads);
Assert.assertEquals(nanoScheduler.getBufferSize(), test.bufferSize, "bufferSize argument"); Assert.assertEquals(nanoScheduler.getBufferSize(), test.bufferSize, "bufferSize argument");
Assert.assertTrue(nanoScheduler.getMapGroupSize() >= test.mapGroupSize, "mapGroupSize argument");
Assert.assertEquals(nanoScheduler.getnThreads(), test.nThreads, "nThreads argument"); Assert.assertEquals(nanoScheduler.getnThreads(), test.nThreads, "nThreads argument");
final Integer sum = nanoScheduler.execute(test.makeReader(), test.makeMap(), test.initReduce(), test.makeReduce()); final Integer sum = nanoScheduler.execute(test.makeReader(), test.makeMap(), test.initReduce(), test.makeReduce());
@ -115,11 +112,11 @@ public class NanoSchedulerUnitTest extends BaseTest {
@Test(enabled = true, dataProvider = "NanoSchedulerBasicTest", dependsOnMethods = "testMultiThreadedNanoScheduler", timeOut = NANO_SCHEDULE_MAX_RUNTIME) @Test(enabled = true, dataProvider = "NanoSchedulerBasicTest", dependsOnMethods = "testMultiThreadedNanoScheduler", timeOut = NANO_SCHEDULE_MAX_RUNTIME)
public void testNanoSchedulerInLoop(final NanoSchedulerBasicTest test) throws InterruptedException { public void testNanoSchedulerInLoop(final NanoSchedulerBasicTest test) throws InterruptedException {
if ( test.bufferSize > 1 && (test.mapGroupSize > 1 || test.mapGroupSize == -1)) { if ( test.bufferSize > 1) {
logger.warn("Running " + test); logger.warn("Running " + test);
final NanoScheduler<Integer, Integer, Integer> nanoScheduler = final NanoScheduler<Integer, Integer, Integer> nanoScheduler =
new NanoScheduler<Integer, Integer, Integer>(test.bufferSize, test.mapGroupSize, test.nThreads); new NanoScheduler<Integer, Integer, Integer>(test.bufferSize, test.nThreads);
// test reusing the scheduler // test reusing the scheduler
for ( int i = 0; i < 10; i++ ) { for ( int i = 0; i < 10; i++ ) {
@ -134,7 +131,7 @@ public class NanoSchedulerUnitTest extends BaseTest {
@Test(timeOut = NANO_SCHEDULE_MAX_RUNTIME) @Test(timeOut = NANO_SCHEDULE_MAX_RUNTIME)
public void testShutdown() throws InterruptedException { public void testShutdown() throws InterruptedException {
final NanoScheduler<Integer, Integer, Integer> nanoScheduler = new NanoScheduler<Integer, Integer, Integer>(1, 1, 2); final NanoScheduler<Integer, Integer, Integer> nanoScheduler = new NanoScheduler<Integer, Integer, Integer>(1, 2);
Assert.assertFalse(nanoScheduler.isShutdown(), "scheduler should be alive"); Assert.assertFalse(nanoScheduler.isShutdown(), "scheduler should be alive");
nanoScheduler.shutdown(); nanoScheduler.shutdown();
Assert.assertTrue(nanoScheduler.isShutdown(), "scheduler should be dead"); Assert.assertTrue(nanoScheduler.isShutdown(), "scheduler should be dead");
@ -142,15 +139,15 @@ public class NanoSchedulerUnitTest extends BaseTest {
@Test(expectedExceptions = IllegalStateException.class, timeOut = NANO_SCHEDULE_MAX_RUNTIME) @Test(expectedExceptions = IllegalStateException.class, timeOut = NANO_SCHEDULE_MAX_RUNTIME)
public void testShutdownExecuteFailure() throws InterruptedException { public void testShutdownExecuteFailure() throws InterruptedException {
final NanoScheduler<Integer, Integer, Integer> nanoScheduler = new NanoScheduler<Integer, Integer, Integer>(1, 1, 2); final NanoScheduler<Integer, Integer, Integer> nanoScheduler = new NanoScheduler<Integer, Integer, Integer>(1, 2);
nanoScheduler.shutdown(); nanoScheduler.shutdown();
nanoScheduler.execute(exampleTest.makeReader(), exampleTest.makeMap(), exampleTest.initReduce(), exampleTest.makeReduce()); nanoScheduler.execute(exampleTest.makeReader(), exampleTest.makeMap(), exampleTest.initReduce(), exampleTest.makeReduce());
} }
public static void main(String [ ] args) { public static void main(String [ ] args) {
final NanoSchedulerBasicTest test = new NanoSchedulerBasicTest(1000, 100, Integer.valueOf(args[0]), 0, Integer.valueOf(args[1])); final NanoSchedulerBasicTest test = new NanoSchedulerBasicTest(1000, Integer.valueOf(args[0]), 0, Integer.valueOf(args[1]));
final NanoScheduler<Integer, Integer, Integer> nanoScheduler = final NanoScheduler<Integer, Integer, Integer> nanoScheduler =
new NanoScheduler<Integer, Integer, Integer>(test.bufferSize, test.mapGroupSize, test.nThreads); new NanoScheduler<Integer, Integer, Integer>(test.bufferSize, test.nThreads);
final Integer sum = nanoScheduler.execute(test.makeReader(), test.makeMap(), test.initReduce(), test.makeReduce()); final Integer sum = nanoScheduler.execute(test.makeReader(), test.makeMap(), test.initReduce(), test.makeReduce());
System.out.printf("Sum = %d, expected =%d%n", sum, test.expectedResult); System.out.printf("Sum = %d, expected =%d%n", sum, test.expectedResult);