diff --git a/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/MapFunction.java b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/MapFunction.java
new file mode 100644
index 000000000..dd18e09a9
--- /dev/null
+++ b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/MapFunction.java
@@ -0,0 +1,21 @@
+package org.broadinstitute.sting.utils.nanoScheduler;
+
+/**
+ * A function that maps from InputType -> ResultType
+ *
+ * User: depristo
+ * Date: 8/24/12
+ * Time: 9:49 AM
+ *
+ * @param <InputType> the type of the values consumed by the map
+ * @param <ResultType> the type of the values produced by the map
+ */
+public interface MapFunction<InputType, ResultType> {
+    /**
+     * Map a single input value to its result
+     *
+     * @param input the value to transform
+     * @return the result of applying this function to input
+     */
+    public ResultType apply(final InputType input);
+}
diff --git a/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/MapResult.java b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/MapResult.java
new file mode 100644
index 000000000..90e7c5908
--- /dev/null
+++ b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/MapResult.java
@@ -0,0 +1,47 @@
+package org.broadinstitute.sting.utils.nanoScheduler;
+
+/**
+ * Holds the value produced by a map call together with an integer id,
+ * and orders results by that id
+ *
+ * User: depristo
+ * Date: 8/24/12
+ * Time: 9:57 AM
+ *
+ * @param <MapType> the type of the value produced by the map function
+ */
+public class MapResult<MapType> implements Comparable<MapResult<MapType>> {
+    final Integer id;
+    final MapType value;
+
+    /**
+     * @param id the id used to order this result relative to other results
+     * @param value the value to carry
+     */
+    public MapResult(final int id, final MapType value) {
+        this.id = id;
+        this.value = value;
+    }
+
+    /**
+     * @return the id supplied at construction
+     */
+    public Integer getId() {
+        return id;
+    }
+
+    /**
+     * @return the value supplied at construction
+     */
+    public MapType getValue() {
+        return value;
+    }
+
+    /**
+     * Orders results by ascending id; the carried value is not consulted
+     */
+    @Override
+    public int compareTo(MapResult<MapType> o) {
+        return getId().compareTo(o.getId());
+    }
+}
diff --git a/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoScheduler.java b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoScheduler.java
new file mode 100644
index 000000000..48a941515
--- /dev/null
+++ b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoScheduler.java
@@ -0,0 +1,215 @@
+package org.broadinstitute.sting.utils.nanoScheduler;
+
+import com.google.java.contract.Ensures;
+import com.google.java.contract.Requires;
+import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
+
+import java.util.Iterator;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Queue;
+import java.util.concurrent.*;
+
+/**
+ * Framework for very fine grained MapReduce parallelism
+ *
+ * Reads up to bufferSize inputs at a time from inputReader, applies map to each of them
+ * (on a thread pool when nThreads > 1) and folds the map results into a single value
+ * with reduce, in the order in which the inputs were read.
+ *
+ * User: depristo
+ * Date: 8/24/12
+ * Time: 9:47 AM
+ *
+ * @param <InputType> the type of the values provided by inputReader
+ * @param <MapType> the type of the values produced by map
+ * @param <ReduceType> the type of the final reduce result
+ */
+public class NanoScheduler<InputType, MapType, ReduceType> {
+    final int bufferSize;
+    final int nThreads;
+    final Iterator<InputType> inputReader;
+    final MapFunction<InputType, MapType> map;
+    final ReduceFunction<MapType, ReduceType> reduce;
+
+    /**
+     * Create a new scheduler; no input is read until execute() is called
+     *
+     * @param bufferSize the maximum number of inputs read and mapped per batch, must be >= 1
+     * @param nThreads the number of threads to use, must be >= 1; 1 selects the single threaded path
+     * @param inputReader the source of input values, must not be null
+     * @param map the function applied to each input value, must not be null
+     * @param reduce the function combining the map results into the final value, must not be null
+     * @throws IllegalArgumentException if any argument violates the constraints above
+     */
+    public NanoScheduler(final int bufferSize,
+                         final int nThreads,
+                         final Iterator<InputType> inputReader,
+                         final MapFunction<InputType, MapType> map,
+                         final ReduceFunction<MapType, ReduceType> reduce) {
+        if ( bufferSize < 1 ) throw new IllegalArgumentException("bufferSize must be >= 1, got " + bufferSize);
+        if ( nThreads < 1 ) throw new IllegalArgumentException("nThreads must be >= 1, got " + nThreads);
+        if ( inputReader == null ) throw new IllegalArgumentException("inputReader cannot be null");
+        if ( map == null ) throw new IllegalArgumentException("map cannot be null");
+        if ( reduce == null ) throw new IllegalArgumentException("reduce cannot be null");
+
+        this.bufferSize = bufferSize;
+        this.inputReader = inputReader;
+        this.map = map;
+        this.reduce = reduce;
+        this.nThreads = nThreads;
+    }
+
+    /**
+     * @return the number of threads this scheduler was created with
+     */
+    public int getnThreads() {
+        return nThreads;
+    }
+
+    private int getBufferSize() {
+        return bufferSize;
+    }
+
+    /**
+     * Run the map/reduce job over everything remaining in inputReader
+     *
+     * @return the reduce result of this map/reduce job
+     */
+    public ReduceType execute() {
+        if ( getnThreads() == 1 ) {
+            return executeSingleThreaded();
+        } else {
+            return executeMultiThreaded();
+        }
+    }
+
+    /**
+     * Simple efficient reference implementation for single threaded execution
+     * @return the reduce result of this map/reduce job
+     */
+    private ReduceType executeSingleThreaded() {
+        ReduceType sum = reduce.init();
+        while ( inputReader.hasNext() ) {
+            final InputType input = inputReader.next();
+            final MapType mapValue = map.apply(input);
+            sum = reduce.apply(mapValue, sum);
+        }
+        return sum;
+    }
+
+    /**
+     * Efficient parallel version of Map/Reduce
+     *
+     * @return the reduce result of this map/reduce job
+     */
+    private ReduceType executeMultiThreaded() {
+        // the calling thread performs the reduce, so the pool only needs nThreads - 1 map workers
+        final ExecutorService executor = Executors.newFixedThreadPool(getnThreads() - 1);
+
+        ReduceType sum = reduce.init();
+        boolean completed = false;
+        try {
+            while ( inputReader.hasNext() ) {
+                // read in our input values
+                final Queue<InputType> inputs = readInputs();
+
+                // send jobs for map
+                final Queue<Future<MapType>> mapQueue = submitMapJobs(executor, inputs);
+
+                // reduce the map results in submission order, blocking on each until it is available
+                sum = reduceParallel(mapQueue, sum);
+            }
+            completed = true;
+        } catch (InterruptedException ex) {
+            // restore the interrupt status so code further up the stack can still observe it
+            Thread.currentThread().interrupt();
+            throw new ReviewedStingException("interrupted while waiting for map results", ex);
+        } catch (ExecutionException ex) {
+            throw new ReviewedStingException("got execution exception", ex);
+        } finally {
+            // always stop the pool: its worker threads would otherwise outlive a failed job
+            final List<Runnable> remaining = executor.shutdownNow();
+            // only complain about leftover tasks on the success path, so that an exception
+            // already propagating out of the try block is never replaced by this one
+            if ( completed && ! remaining.isEmpty() )
+                throw new ReviewedStingException("Remaining tasks found in the executor, unexpected behavior!");
+        }
+
+        return sum;
+    }
+
+    /**
+     * Reduce every map result in mapQueue into initSum, in queue order
+     *
+     * @param mapQueue the pending map results, in the order they must be reduced
+     * @param initSum the reduce value accumulated so far
+     * @return the reduce value after folding in all of the results in mapQueue
+     * @throws InterruptedException if interrupted while waiting for a map result
+     * @throws ExecutionException if a map job threw an exception
+     */
+    @Requires("! mapQueue.isEmpty()")
+    private ReduceType reduceParallel(final Queue<Future<MapType>> mapQueue, final ReduceType initSum)
+            throws InterruptedException, ExecutionException {
+        ReduceType sum = initSum;
+
+        // while mapQueue has something in it to reduce
+        for ( final Future<MapType> future : mapQueue ) {
+            // block until we get the value for this task
+            final MapType value = future.get();
+            sum = reduce.apply(value, sum);
+        }
+
+        return sum;
+    }
+
+    /**
+     * Read up to bufferSize elements from inputReader
+     *
+     * @return a queue of inputs read in, containing one or more values of InputType read in
+     */
+    @Requires("inputReader.hasNext()")
+    @Ensures("!result.isEmpty()")
+    private Queue<InputType> readInputs() {
+        final Queue<InputType> inputs = new LinkedList<InputType>();
+        while ( inputReader.hasNext() && inputs.size() < getBufferSize() ) {
+            inputs.add(inputReader.next());
+        }
+        return inputs;
+    }
+
+    /**
+     * Submit one map job per input to executor
+     *
+     * @param executor the pool that runs the map jobs
+     * @param inputs the values to map
+     * @return the futures of the submitted jobs, in the same order as inputs
+     */
+    @Ensures("result.size() == inputs.size()")
+    private Queue<Future<MapType>> submitMapJobs(final ExecutorService executor, final Queue<InputType> inputs) {
+        final Queue<Future<MapType>> mapQueue = new LinkedList<Future<MapType>>();
+
+        for ( final InputType input : inputs ) {
+            final CallableMap doMap = new CallableMap(input);
+            final Future<MapType> future = executor.submit(doMap);
+            mapQueue.add(future);
+        }
+
+        return mapQueue;
+    }
+
+    /**
+     * A simple callable version of the map function for use with the executor pool
+     */
+    private class CallableMap implements Callable<MapType> {
+        final InputType input;
+
+        private CallableMap(final InputType input) {
+            this.input = input;
+        }
+
+        @Override public MapType call() throws Exception {
+            return map.apply(input);
+        }
+    }
+}
diff --git a/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/ReduceFunction.java b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/ReduceFunction.java
new file mode 100644
index 000000000..274e22aff
--- /dev/null
+++ b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/ReduceFunction.java
@@ -0,0 +1,28 @@
+package org.broadinstitute.sting.utils.nanoScheduler;
+
+/**
+ * A function that combines a value of MapType with an accumulated ReduceType
+ * to produce the new accumulated ReduceType
+ *
+ * User: depristo
+ * Date: 8/24/12
+ * Time: 9:49 AM
+ *
+ * @param <MapType> the type of the values being combined
+ * @param <ReduceType> the type of the accumulated result
+ */
+public interface ReduceFunction<MapType, ReduceType> {
+    /**
+     * @return the starting value of the accumulated result, used before any value is combined
+     */
+    public ReduceType init();
+
+    /**
+     * Combine one value with the result accumulated so far
+     *
+     * @param one the value to fold in
+     * @param sum the result accumulated so far
+     * @return the new accumulated result
+     */
+    public ReduceType apply(MapType one, ReduceType sum);
+}
diff --git a/public/java/test/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerUnitTest.java b/public/java/test/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerUnitTest.java
new file mode 100644
index 000000000..18a9f3340
--- /dev/null
+++ b/public/java/test/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerUnitTest.java
@@ -0,0 +1,109 @@
+package org.broadinstitute.sting.utils.nanoScheduler;
+
+import org.broadinstitute.sting.BaseTest;
+import org.testng.Assert;
+import org.testng.annotations.DataProvider;
+import org.testng.annotations.Test;
+
+import java.util.*;
+
+/**
+ * UnitTests for the NanoScheduler
+ *
+ * User: depristo
+ * Date: 8/24/12
+ * Time: 11:25 AM
+ */
+public class NanoSchedulerUnitTest extends BaseTest {
+    /** Map function that doubles its input */
+    private static class Map2x implements MapFunction<Integer, Integer> {
+        @Override public Integer apply(Integer input) { return input * 2; }
+    }
+
+    /** Reduce function that adds up its inputs, starting from 0 */
+    private static class ReduceSum implements ReduceFunction<Integer, Integer> {
+        @Override public Integer init() { return 0; }
+        @Override public Integer apply(Integer one, Integer sum) { return one + sum; }
+    }
+
+    /**
+     * Directly compute the sum of 2 * i for every i in [start, end)
+     *
+     * Uses int arithmetic just like Map2x and ReduceSum, so for large ranges it
+     * wraps around in exactly the same way as the value computed by the scheduler
+     */
+    private static int sum2x(final int start, final int end) {
+        int sum = 0;
+        for ( int i = start; i < end; i++ )
+            sum += 2 * i;
+        return sum;
+    }
+
+    /**
+     * One test case: the scheduler configuration, the input range [start, end) and the expected sum
+     *
+     * NOTE(review): instances are created and discarded in the data provider below, so each one
+     * presumably registers itself with TestDataProvider on construction -- confirm against BaseTest
+     */
+    private class NanoSchedulerBasicTest extends TestDataProvider {
+        final int bufferSize, nThreads, start, end, expectedResult;
+
+        public NanoSchedulerBasicTest(final int bufferSize, final int nThreads, final int start, final int end) {
+            super(NanoSchedulerBasicTest.class);
+            this.bufferSize = bufferSize;
+            this.nThreads = nThreads;
+            this.start = start;
+            this.end = end;
+            this.expectedResult = sum2x(start, end);
+            setName(String.format("%s nt=%d buf=%d start=%d end=%d sum=%d",
+                    getClass().getSimpleName(), nThreads, bufferSize, start, end, expectedResult));
+        }
+
+        /** @return a fresh iterator over the integers in [start, end) */
+        public Iterator<Integer> makeReader() {
+            final List<Integer> ints = new ArrayList<Integer>();
+            for ( int i = start; i < end; i++ )
+                ints.add(i);
+            return ints.iterator();
+        }
+
+        public Map2x makeMap() { return new Map2x(); }
+        public ReduceSum makeReduce() { return new ReduceSum(); }
+    }
+
+    @DataProvider(name = "NanoSchedulerBasicTest")
+    public Object[][] createNanoSchedulerBasicTest() {
+        for ( final int bufferSize : Arrays.asList(1, 10, 10000, 1000000) ) {
+            for ( final int nt : Arrays.asList(1, 2, 4, 8, 16, 32) ) {
+                for ( final int start : Arrays.asList(0) ) {
+                    for ( final int end : Arrays.asList(1, 2, 11, 1000000) ) {
+                        new NanoSchedulerBasicTest(bufferSize, nt, start, end);
+                    }
+                }
+            }
+        }
+
+        return NanoSchedulerBasicTest.getTests(NanoSchedulerBasicTest.class);
+    }
+
+    /** Run one configuration and check the scheduler's sum against the directly computed one */
+    @Test(enabled = true, dataProvider = "NanoSchedulerBasicTest", timeOut = 2000)
+    public void testNanoSchedulerBasicTest(final NanoSchedulerBasicTest test) throws InterruptedException {
+        logger.warn("Running " + test);
+        final NanoScheduler<Integer, Integer, Integer> nanoScheduler =
+                new NanoScheduler<Integer, Integer, Integer>(test.bufferSize, test.nThreads,
+                        test.makeReader(), test.makeMap(), test.makeReduce());
+        final Integer sum = nanoScheduler.execute();
+        Assert.assertNotNull(sum);
+        Assert.assertEquals((int)sum, test.expectedResult, "NanoScheduler sum not the same as calculated directly");
+    }
+
+    /** Run each configuration ten times in a row, each run with a newly created scheduler */
+    @Test(enabled = true, dataProvider = "NanoSchedulerBasicTest", timeOut = 10000, dependsOnMethods = "testNanoSchedulerBasicTest")
+    public void testNanoSchedulerInLoop(final NanoSchedulerBasicTest test) throws InterruptedException {
+        logger.warn("Running " + test);
+        for ( int i = 0; i < 10; i++ ) {
+            testNanoSchedulerBasicTest(test);
+        }
+    }
+}