From 8cdeb51b78696340d9303d44342095bb82a40671 Mon Sep 17 00:00:00 2001 From: Mark DePristo Date: Tue, 4 Sep 2012 14:50:06 -0400 Subject: [PATCH] Cleanup printProgress in TraversalEngine -- Separate updating cumulative traversal metrics from printing progress. There's now an updateCumulativeMetrics function and a printProgress() that only takes a current position -- printProgress now soles relies on the time since the last progress to decide if it will print or not. No longer uses the number of cycles, since this isn't reliable in the case of nano scheduling -- GenomeAnalysisEngine now maintains a pointer to the master cumulative metrics. getCumulativeMetrics never returns null, which was handled in some parts of the code but not others. -- Update all of the traversals to use the new updateCumulativeMetrics, printProgress model -- Added progress callback to nano scheduler. Every bufferSize elements this callback is invoked, allowing us to smoothly update the progress meter in the NanoScheduler -- Rename MapFunction to NanoSchedulerMap and the same for reduce. --- .../sting/gatk/GenomeAnalysisEngine.java | 7 +- .../gatk/traversals/TraversalEngine.java | 152 ++++++++---------- .../traversals/TraverseActiveRegions.java | 3 +- .../gatk/traversals/TraverseDuplicates.java | 3 +- .../gatk/traversals/TraverseLociBase.java | 1 + .../gatk/traversals/TraverseLociLinear.java | 3 +- .../gatk/traversals/TraverseLociNano.java | 25 +-- .../gatk/traversals/TraverseReadPairs.java | 3 +- .../sting/gatk/traversals/TraverseReads.java | 7 +- .../gatk/traversals/TraverseReadsNano.java | 14 +- .../utils/nanoScheduler/NanoScheduler.java | 35 ++-- ...ion.java => NanoSchedulerMapFunction.java} | 2 +- .../NanoSchedulerProgressFunction.java | 12 ++ ....java => NanoSchedulerReduceFunction.java} | 2 +- .../nanoScheduler/NanoSchedulerUnitTest.java | 4 +- 15 files changed, 153 insertions(+), 120 deletions(-) rename public/java/src/org/broadinstitute/sting/utils/nanoScheduler/{MapFunction.java => NanoSchedulerMapFunction.java} (84%) create mode 100644 public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerProgressFunction.java rename public/java/src/org/broadinstitute/sting/utils/nanoScheduler/{ReduceFunction.java => NanoSchedulerReduceFunction.java} (87%) diff --git a/public/java/src/org/broadinstitute/sting/gatk/GenomeAnalysisEngine.java b/public/java/src/org/broadinstitute/sting/gatk/GenomeAnalysisEngine.java index b9b5e452d..1b4333ce2 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/GenomeAnalysisEngine.java +++ b/public/java/src/org/broadinstitute/sting/gatk/GenomeAnalysisEngine.java @@ -143,6 +143,8 @@ public class GenomeAnalysisEngine { */ private ThreadAllocation threadAllocation; + private ReadMetrics cumulativeMetrics = null; + /** * A currently hacky unique name for this GATK instance */ @@ -1035,7 +1037,10 @@ public class GenomeAnalysisEngine { * owned by the caller; the caller can do with the object what they wish. */ public ReadMetrics getCumulativeMetrics() { - return readsDataSource == null ? null : readsDataSource.getCumulativeReadMetrics(); + // todo -- probably shouldn't be lazy + if ( cumulativeMetrics == null ) + cumulativeMetrics = readsDataSource == null ? new ReadMetrics() : readsDataSource.getCumulativeReadMetrics(); + return cumulativeMetrics; } /** diff --git a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraversalEngine.java b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraversalEngine.java index 198f9342e..4422d49ae 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraversalEngine.java +++ b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraversalEngine.java @@ -44,24 +44,12 @@ import java.util.List; import java.util.Map; public abstract class TraversalEngine,ProviderType extends ShardDataProvider> { + /** our log, which we want to capture anything from this class */ + protected static final Logger logger = Logger.getLogger(TraversalEngine.class); + // Time in milliseconds since we initialized this engine private static final int HISTORY_WINDOW_SIZE = 50; - private static class ProcessingHistory { - double elapsedSeconds; - long unitsProcessed; - long bpProcessed; - GenomeLoc loc; - - public ProcessingHistory(double elapsedSeconds, GenomeLoc loc, long unitsProcessed, long bpProcessed) { - this.elapsedSeconds = elapsedSeconds; - this.loc = loc; - this.unitsProcessed = unitsProcessed; - this.bpProcessed = bpProcessed; - } - - } - /** lock object to sure updates to history are consistent across threads */ private static final Object lock = new Object(); LinkedList history = new LinkedList(); @@ -70,13 +58,12 @@ public abstract class TraversalEngine,Provide private SimpleTimer timer = null; // How long can we go without printing some progress info? - private static final int PRINT_PROGRESS_CHECK_FREQUENCY_IN_CYCLES = 1000; - private int printProgressCheckCounter = 0; private long lastProgressPrintTime = -1; // When was the last time we printed progress log? - private long MIN_ELAPSED_TIME_BEFORE_FIRST_PROGRESS = 30 * 1000; // in milliseconds - private long PROGRESS_PRINT_FREQUENCY = 10 * 1000; // in milliseconds - private final double TWO_HOURS_IN_SECONDS = 2.0 * 60.0 * 60.0; - private final double TWELVE_HOURS_IN_SECONDS = 12.0 * 60.0 * 60.0; + + private final static long MIN_ELAPSED_TIME_BEFORE_FIRST_PROGRESS = 30 * 1000; // in milliseconds + private final static double TWO_HOURS_IN_SECONDS = 2.0 * 60.0 * 60.0; + private final static double TWELVE_HOURS_IN_SECONDS = 12.0 * 60.0 * 60.0; + private long progressPrintFrequency = 10 * 1000; // in milliseconds private boolean progressMeterInitialized = false; // for performance log @@ -85,15 +72,12 @@ public abstract class TraversalEngine,Provide private File performanceLogFile; private PrintStream performanceLog = null; private long lastPerformanceLogPrintTime = -1; // When was the last time we printed to the performance log? - private final long PERFORMANCE_LOG_PRINT_FREQUENCY = PROGRESS_PRINT_FREQUENCY; // in milliseconds + private final long PERFORMANCE_LOG_PRINT_FREQUENCY = progressPrintFrequency; // in milliseconds /** Size, in bp, of the area we are processing. Updated once in the system in initial for performance reasons */ long targetSize = -1; GenomeLocSortedSet targetIntervals = null; - /** our log, which we want to capture anything from this class */ - protected static final Logger logger = Logger.getLogger(TraversalEngine.class); - protected GenomeAnalysisEngine engine; // ---------------------------------------------------------------------------------------------------- @@ -187,28 +171,34 @@ public abstract class TraversalEngine,Provide } /** - * Forward request to printProgress + * Update the cumulative traversal metrics according to the data in this shard * - * Assumes that one cycle has been completed - * - * @param shard the given shard currently being processed. - * @param loc the location + * @param shard a non-null shard */ - public void printProgress(Shard shard, GenomeLoc loc) { - // A bypass is inserted here for unit testing. - printProgress(loc,shard.getReadMetrics(),false, 1); + public void updateCumulativeMetrics(final Shard shard) { + updateCumulativeMetrics(shard.getReadMetrics()); + } + + /** + * Update the cumulative traversal metrics according to the data in this shard + * + * @param singleTraverseMetrics read metrics object containing the information about a single shard's worth + * of data processing + */ + public void updateCumulativeMetrics(final ReadMetrics singleTraverseMetrics) { + engine.getCumulativeMetrics().incrementMetrics(singleTraverseMetrics); } /** * Forward request to printProgress * - * @param shard the given shard currently being processed. + * Assumes that one cycle has been completed + * * @param loc the location - * @param nElapsedCycles the number of cycles (turns of map) that have occurred since the last call */ - public void printProgress(Shard shard, GenomeLoc loc, int nElapsedCycles) { + public void printProgress(final GenomeLoc loc) { // A bypass is inserted here for unit testing. - printProgress(loc,shard.getReadMetrics(),false, nElapsedCycles); + printProgress(loc, false); } /** @@ -216,18 +206,9 @@ public abstract class TraversalEngine,Provide * every M seconds, for N and M set in global variables. * * @param loc Current location, can be null if you are at the end of the traversal - * @param metrics Data processed since the last cumulative * @param mustPrint If true, will print out info, regardless of nRecords or time interval */ - private synchronized void printProgress(GenomeLoc loc, ReadMetrics metrics, boolean mustPrint, int nElapsedCycles) { - final int previousPrintCycle = printProgressCheckCounter / PRINT_PROGRESS_CHECK_FREQUENCY_IN_CYCLES; - final int newPrintCycle = (printProgressCheckCounter+nElapsedCycles) / PRINT_PROGRESS_CHECK_FREQUENCY_IN_CYCLES; - - printProgressCheckCounter += nElapsedCycles; // keep track of our number of cycles through printProgress - if ( newPrintCycle == previousPrintCycle && ! mustPrint ) - // don't do any work more often than PRINT_PROGRESS_CHECK_FREQUENCY_IN_CYCLES - return; - + private synchronized void printProgress(final GenomeLoc loc, boolean mustPrint) { if( ! progressMeterInitialized ) { logger.info("[INITIALIZATION COMPLETE; TRAVERSAL STARTING]"); logger.info(String.format("%15s processed.%s runtime per.1M.%s completed total.runtime remaining", @@ -236,37 +217,30 @@ public abstract class TraversalEngine,Provide } final long curTime = timer.currentTime(); - boolean printProgress = mustPrint || maxElapsedIntervalForPrinting(curTime, lastProgressPrintTime, PROGRESS_PRINT_FREQUENCY); + boolean printProgress = mustPrint || maxElapsedIntervalForPrinting(curTime, lastProgressPrintTime, progressPrintFrequency); boolean printLog = performanceLog != null && maxElapsedIntervalForPrinting(curTime, lastPerformanceLogPrintTime, PERFORMANCE_LOG_PRINT_FREQUENCY); if ( printProgress || printLog ) { - // getting and appending metrics data actually turns out to be quite a heavyweight - // operation. Postpone it until after determining whether to print the log message. - ReadMetrics cumulativeMetrics = engine.getCumulativeMetrics() != null ? engine.getCumulativeMetrics() : new ReadMetrics(); - if(metrics != null) - cumulativeMetrics.incrementMetrics(metrics); - - final long nRecords = cumulativeMetrics.getNumIterations(); - - ProcessingHistory last = updateHistory(loc,cumulativeMetrics); + final ProcessingHistory last = updateHistory(loc, engine.getCumulativeMetrics()); final AutoFormattingTime elapsed = new AutoFormattingTime(last.elapsedSeconds); - final AutoFormattingTime bpRate = new AutoFormattingTime(secondsPerMillionBP(last)); - final AutoFormattingTime unitRate = new AutoFormattingTime(secondsPerMillionElements(last)); - final double fractionGenomeTargetCompleted = calculateFractionGenomeTargetCompleted(last); + final AutoFormattingTime bpRate = new AutoFormattingTime(last.secondsPerMillionBP()); + final AutoFormattingTime unitRate = new AutoFormattingTime(last.secondsPerMillionElements()); + final double fractionGenomeTargetCompleted = last.calculateFractionGenomeTargetCompleted(targetSize); final AutoFormattingTime estTotalRuntime = new AutoFormattingTime(elapsed.getTimeInSeconds() / fractionGenomeTargetCompleted); final AutoFormattingTime timeToCompletion = new AutoFormattingTime(estTotalRuntime.getTimeInSeconds() - elapsed.getTimeInSeconds()); + final long nRecords = engine.getCumulativeMetrics().getNumIterations(); if ( printProgress ) { lastProgressPrintTime = curTime; // dynamically change the update rate so that short running jobs receive frequent updates while longer jobs receive fewer updates if ( estTotalRuntime.getTimeInSeconds() > TWELVE_HOURS_IN_SECONDS ) - PROGRESS_PRINT_FREQUENCY = 60 * 1000; // in milliseconds + progressPrintFrequency = 60 * 1000; // in milliseconds else if ( estTotalRuntime.getTimeInSeconds() > TWO_HOURS_IN_SECONDS ) - PROGRESS_PRINT_FREQUENCY = 30 * 1000; // in milliseconds + progressPrintFrequency = 30 * 1000; // in milliseconds else - PROGRESS_PRINT_FREQUENCY = 10 * 1000; // in milliseconds + progressPrintFrequency = 10 * 1000; // in milliseconds final String posName = loc == null ? (mustPrint ? "done" : "unmapped reads") : Integer.toString(loc.getStart()); logger.info(String.format("%15s %5.2e %s %s %5.1f%% %s %s", @@ -296,7 +270,7 @@ public abstract class TraversalEngine,Provide * @param metrics information about what's been processed already * @return */ - private final ProcessingHistory updateHistory(GenomeLoc loc, ReadMetrics metrics) { + private ProcessingHistory updateHistory(GenomeLoc loc, ReadMetrics metrics) { synchronized (lock) { if ( history.size() > HISTORY_WINDOW_SIZE ) history.pop(); @@ -309,26 +283,11 @@ public abstract class TraversalEngine,Provide } } - /** How long in seconds to process 1M traversal units? */ - private final double secondsPerMillionElements(ProcessingHistory last) { - return (last.elapsedSeconds * 1000000.0) / Math.max(last.unitsProcessed, 1); - } - - /** How long in seconds to process 1M bp on the genome? */ - private final double secondsPerMillionBP(ProcessingHistory last) { - return (last.elapsedSeconds * 1000000.0) / Math.max(last.bpProcessed, 1); - } - - /** What fractoin of the target intervals have we covered? */ - private final double calculateFractionGenomeTargetCompleted(ProcessingHistory last) { - return (1.0*last.bpProcessed) / targetSize; - } - /** * Called after a traversal to print out information about the traversal process */ public void printOnTraversalDone() { - printProgress(null, null, true, 1); + printProgress(null, true); final double elapsed = timer == null ? 0 : timer.getElapsedTime(); @@ -389,7 +348,7 @@ public abstract class TraversalEngine,Provide * @return Frequency, in seconds, of performance log writes. */ public long getPerformanceProgressPrintFrequencySeconds() { - return PROGRESS_PRINT_FREQUENCY; + return progressPrintFrequency; } /** @@ -397,6 +356,35 @@ public abstract class TraversalEngine,Provide * @param seconds number of seconds between messages indicating performance frequency. */ public void setPerformanceProgressPrintFrequencySeconds(long seconds) { - PROGRESS_PRINT_FREQUENCY = seconds; + progressPrintFrequency = seconds; + } + + private static class ProcessingHistory { + double elapsedSeconds; + long unitsProcessed; + long bpProcessed; + GenomeLoc loc; + + public ProcessingHistory(double elapsedSeconds, GenomeLoc loc, long unitsProcessed, long bpProcessed) { + this.elapsedSeconds = elapsedSeconds; + this.loc = loc; + this.unitsProcessed = unitsProcessed; + this.bpProcessed = bpProcessed; + } + + /** How long in seconds to process 1M traversal units? */ + private double secondsPerMillionElements() { + return (elapsedSeconds * 1000000.0) / Math.max(unitsProcessed, 1); + } + + /** How long in seconds to process 1M bp on the genome? */ + private double secondsPerMillionBP() { + return (elapsedSeconds * 1000000.0) / Math.max(bpProcessed, 1); + } + + /** What fractoin of the target intervals have we covered? */ + private double calculateFractionGenomeTargetCompleted(final long targetSize) { + return (1.0*bpProcessed) / targetSize; + } } } diff --git a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseActiveRegions.java b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseActiveRegions.java index ecaa15fe9..bbd9346b3 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseActiveRegions.java +++ b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseActiveRegions.java @@ -104,7 +104,8 @@ public class TraverseActiveRegions extends TraversalEngine extends TraversalEngine extends TraversalEngine result = traverse( walker, locusView, referenceView, referenceOrderedDataView, sum ); sum = result.reduceResult; dataProvider.getShard().getReadMetrics().incrementNumIterations(result.numIterations); + updateCumulativeMetrics(dataProvider.getShard()); } // We have a final map call to execute here to clean up the skipped based from the diff --git a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseLociLinear.java b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseLociLinear.java index 1dec3b238..22381092f 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseLociLinear.java +++ b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseLociLinear.java @@ -39,8 +39,7 @@ public class TraverseLociLinear extends TraverseLociBase { done = walker.isDone(); } - // TODO -- refactor printProgress to separate updating read metrics from printing progress - //printProgress(dataProvider.getShard(),locus.getLocation()); + printProgress(locus.getLocation()); } return new TraverseResults(numIterations, sum); diff --git a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseLociNano.java b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseLociNano.java index 4e6eb1915..73b73c002 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseLociNano.java +++ b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseLociNano.java @@ -8,9 +8,10 @@ import org.broadinstitute.sting.gatk.datasources.providers.ReferenceOrderedView; import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker; import org.broadinstitute.sting.gatk.walkers.LocusWalker; import org.broadinstitute.sting.utils.GenomeLoc; -import org.broadinstitute.sting.utils.nanoScheduler.MapFunction; import org.broadinstitute.sting.utils.nanoScheduler.NanoScheduler; -import org.broadinstitute.sting.utils.nanoScheduler.ReduceFunction; +import org.broadinstitute.sting.utils.nanoScheduler.NanoSchedulerMapFunction; +import org.broadinstitute.sting.utils.nanoScheduler.NanoSchedulerProgressFunction; +import org.broadinstitute.sting.utils.nanoScheduler.NanoSchedulerReduceFunction; import java.util.Iterator; @@ -26,6 +27,7 @@ public class TraverseLociNano extends TraverseLociBase { public TraverseLociNano(int nThreads) { nanoScheduler = new NanoScheduler(BUFFER_SIZE, nThreads); + nanoScheduler.setProgressFunction(new TraverseLociProgress()); } @Override @@ -41,11 +43,6 @@ public class TraverseLociNano extends TraverseLociBase { final MapDataIterator inputIterator = new MapDataIterator(locusView, referenceView, referenceOrderedDataView); final T result = nanoScheduler.execute(inputIterator, myMap, sum, myReduce); - // todo -- how do I print progress? -// final GATKSAMRecord lastRead = aggregatedInputs.get(aggregatedInputs.size() - 1).read; -// final GenomeLoc locus = engine.getGenomeLocParser().createGenomeLoc(lastRead); -// printProgress(dataProvider.getShard(), locus, aggregatedInputs.size()); - return new TraverseResults(inputIterator.numIterations, result); } @@ -156,7 +153,7 @@ public class TraverseLociNano extends TraverseLociBase { * * Applies walker.map to MapData, returning a MapResult object containing the result */ - private class TraverseLociMap implements MapFunction { + private class TraverseLociMap implements NanoSchedulerMapFunction { final LocusWalker walker; private TraverseLociMap(LocusWalker walker) { @@ -177,11 +174,11 @@ public class TraverseLociNano extends TraverseLociBase { } /** - * ReduceFunction for TraverseReads meeting NanoScheduler interface requirements + * NanoSchedulerReduceFunction for TraverseReads meeting NanoScheduler interface requirements * * Takes a MapResult object and applies the walkers reduce function to each map result, when applicable */ - private class TraverseLociReduce implements ReduceFunction { + private class TraverseLociReduce implements NanoSchedulerReduceFunction { final LocusWalker walker; private TraverseLociReduce(LocusWalker walker) { @@ -197,4 +194,12 @@ public class TraverseLociNano extends TraverseLociBase { return sum; } } + + private class TraverseLociProgress implements NanoSchedulerProgressFunction { + @Override + public void progress(MapData lastProcessedMap) { + if (lastProcessedMap.alignmentContext != null) + printProgress(lastProcessedMap.alignmentContext.getLocation()); + } + } } diff --git a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseReadPairs.java b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseReadPairs.java index ebaac40af..9b076fce4 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseReadPairs.java +++ b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseReadPairs.java @@ -65,7 +65,8 @@ public class TraverseReadPairs extends TraversalEngine extends TraversalEngine,Read sum = walker.reduce(x, sum); } - GenomeLoc locus = read.getReferenceIndex() == SAMRecord.NO_ALIGNMENT_REFERENCE_INDEX ? null : engine.getGenomeLocParser().createGenomeLoc(read.getReferenceName(),read.getAlignmentStart()); - printProgress(dataProvider.getShard(),locus); + final GenomeLoc locus = read.getReferenceIndex() == SAMRecord.NO_ALIGNMENT_REFERENCE_INDEX ? null : engine.getGenomeLocParser().createGenomeLoc(read.getReferenceName(),read.getAlignmentStart()); + + updateCumulativeMetrics(dataProvider.getShard()); + printProgress(locus); + done = walker.isDone(); } return sum; diff --git a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseReadsNano.java b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseReadsNano.java index 4bb700c37..5679747e1 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseReadsNano.java +++ b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseReadsNano.java @@ -35,9 +35,9 @@ import org.broadinstitute.sting.gatk.datasources.reads.ReadShard; import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker; import org.broadinstitute.sting.gatk.walkers.ReadWalker; import org.broadinstitute.sting.utils.GenomeLoc; -import org.broadinstitute.sting.utils.nanoScheduler.MapFunction; import org.broadinstitute.sting.utils.nanoScheduler.NanoScheduler; -import org.broadinstitute.sting.utils.nanoScheduler.ReduceFunction; +import org.broadinstitute.sting.utils.nanoScheduler.NanoSchedulerMapFunction; +import org.broadinstitute.sting.utils.nanoScheduler.NanoSchedulerReduceFunction; import org.broadinstitute.sting.utils.sam.GATKSAMRecord; import java.util.LinkedList; @@ -94,7 +94,9 @@ public class TraverseReadsNano extends TraversalEngine, final GATKSAMRecord lastRead = aggregatedInputs.get(aggregatedInputs.size() - 1).read; final GenomeLoc locus = engine.getGenomeLocParser().createGenomeLoc(lastRead); - printProgress(dataProvider.getShard(), locus, aggregatedInputs.size()); + + updateCumulativeMetrics(dataProvider.getShard()); + printProgress(locus); return result; } @@ -189,7 +191,7 @@ public class TraverseReadsNano extends TraversalEngine, * * Applies walker.map to MapData, returning a MapResult object containing the result */ - private class TraverseReadsMap implements MapFunction { + private class TraverseReadsMap implements NanoSchedulerMapFunction { final ReadWalker walker; private TraverseReadsMap(ReadWalker walker) { @@ -209,11 +211,11 @@ public class TraverseReadsNano extends TraversalEngine, } /** - * ReduceFunction for TraverseReads meeting NanoScheduler interface requirements + * NanoSchedulerReduceFunction for TraverseReads meeting NanoScheduler interface requirements * * Takes a MapResult object and applies the walkers reduce function to each map result, when applicable */ - private class TraverseReadsReduce implements ReduceFunction { + private class TraverseReadsReduce implements NanoSchedulerReduceFunction { final ReadWalker walker; private TraverseReadsReduce(ReadWalker walker) { diff --git a/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoScheduler.java b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoScheduler.java index f0e77354f..f0c2a6723 100644 --- a/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoScheduler.java +++ b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoScheduler.java @@ -54,6 +54,8 @@ public class NanoScheduler { boolean shutdown = false; boolean debug = false; + private NanoSchedulerProgressFunction progressFunction = null; + final SimpleTimer outsideSchedulerTimer = new SimpleTimer("outside"); final SimpleTimer inputTimer = new SimpleTimer("input"); final SimpleTimer mapTimer = new SimpleTimer("map"); @@ -148,6 +150,17 @@ public class NanoScheduler { this.debug = debug; } + /** + * Set the progress callback function to progressFunction + * + * The progress callback is invoked after each buffer size elements have been processed by map/reduce + * + * @param progressFunction a progress function to call, or null if you don't want any progress callback + */ + public void setProgressFunction(final NanoSchedulerProgressFunction progressFunction) { + this.progressFunction = progressFunction; + } + /** * Execute a map/reduce job with this nanoScheduler * @@ -168,9 +181,9 @@ public class NanoScheduler { * @return the last reduce value */ public ReduceType execute(final Iterator inputReader, - final MapFunction map, + final NanoSchedulerMapFunction map, final ReduceType initialValue, - final ReduceFunction reduce) { + final NanoSchedulerReduceFunction reduce) { if ( isShutdown() ) throw new IllegalStateException("execute called on already shutdown NanoScheduler"); if ( inputReader == null ) throw new IllegalArgumentException("inputReader cannot be null"); if ( map == null ) throw new IllegalArgumentException("map function cannot be null"); @@ -193,9 +206,9 @@ public class NanoScheduler { * @return the reduce result of this map/reduce job */ private ReduceType executeSingleThreaded(final Iterator inputReader, - final MapFunction map, + final NanoSchedulerMapFunction map, final ReduceType initialValue, - final ReduceFunction reduce) { + final NanoSchedulerReduceFunction reduce) { ReduceType sum = initialValue; while ( inputReader.hasNext() ) { final InputType input = inputReader.next(); @@ -211,9 +224,9 @@ public class NanoScheduler { * @return the reduce result of this map/reduce job */ private ReduceType executeMultiThreaded(final Iterator inputReader, - final MapFunction map, + final NanoSchedulerMapFunction map, final ReduceType initialValue, - final ReduceFunction reduce) { + final NanoSchedulerReduceFunction reduce) { debugPrint("Executing nanoScheduler"); ReduceType sum = initialValue; while ( inputReader.hasNext() ) { @@ -228,6 +241,8 @@ public class NanoScheduler { // send off the reduce job, and block until we get at least one reduce result sum = reduceSerial(reduce, mapQueue, sum); debugPrint(" Done with cycle of map/reduce"); + + if ( progressFunction != null ) progressFunction.progress(inputs.get(inputs.size()-1)); } catch (InterruptedException ex) { throw new ReviewedStingException("got execution exception", ex); } catch (ExecutionException ex) { @@ -239,7 +254,7 @@ public class NanoScheduler { } @Requires({"reduce != null", "! mapQueue.isEmpty()"}) - private ReduceType reduceSerial(final ReduceFunction reduce, + private ReduceType reduceSerial(final NanoSchedulerReduceFunction reduce, final Queue> mapQueue, final ReduceType initSum) throws InterruptedException, ExecutionException { @@ -280,7 +295,7 @@ public class NanoScheduler { } @Requires({"map != null", "! inputs.isEmpty()"}) - private Queue> submitMapJobs(final MapFunction map, + private Queue> submitMapJobs(final NanoSchedulerMapFunction map, final ExecutorService executor, final List inputs) { final Queue> mapQueue = new LinkedList>(); @@ -299,10 +314,10 @@ public class NanoScheduler { */ private class CallableMap implements Callable { final InputType input; - final MapFunction map; + final NanoSchedulerMapFunction map; @Requires({"map != null"}) - private CallableMap(final MapFunction map, final InputType inputs) { + private CallableMap(final NanoSchedulerMapFunction map, final InputType inputs) { this.input = inputs; this.map = map; } diff --git a/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/MapFunction.java b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerMapFunction.java similarity index 84% rename from public/java/src/org/broadinstitute/sting/utils/nanoScheduler/MapFunction.java rename to public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerMapFunction.java index 440c263b7..ddf4421d2 100644 --- a/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/MapFunction.java +++ b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerMapFunction.java @@ -9,7 +9,7 @@ package org.broadinstitute.sting.utils.nanoScheduler; * Date: 8/24/12 * Time: 9:49 AM */ -public interface MapFunction { +public interface NanoSchedulerMapFunction { /** * Return function on input, returning a value of ResultType * @param input diff --git a/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerProgressFunction.java b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerProgressFunction.java new file mode 100644 index 000000000..8631196a3 --- /dev/null +++ b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerProgressFunction.java @@ -0,0 +1,12 @@ +package org.broadinstitute.sting.utils.nanoScheduler; + +/** + * Created with IntelliJ IDEA. + * User: depristo + * Date: 9/4/12 + * Time: 2:10 PM + * To change this template use File | Settings | File Templates. + */ +public interface NanoSchedulerProgressFunction { + public void progress(final InputType lastMapInput); +} diff --git a/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/ReduceFunction.java b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerReduceFunction.java similarity index 87% rename from public/java/src/org/broadinstitute/sting/utils/nanoScheduler/ReduceFunction.java rename to public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerReduceFunction.java index 8f1b0eddd..7e58eeaf9 100644 --- a/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/ReduceFunction.java +++ b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerReduceFunction.java @@ -7,7 +7,7 @@ package org.broadinstitute.sting.utils.nanoScheduler; * Date: 8/24/12 * Time: 9:49 AM */ -public interface ReduceFunction { +public interface NanoSchedulerReduceFunction { /** * Combine one with sum into a new ReduceType * @param one the result of a map call on an input element diff --git a/public/java/test/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerUnitTest.java b/public/java/test/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerUnitTest.java index 1dcc243f2..0ec3035e2 100644 --- a/public/java/test/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerUnitTest.java +++ b/public/java/test/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerUnitTest.java @@ -21,11 +21,11 @@ import java.util.List; public class NanoSchedulerUnitTest extends BaseTest { public static final int NANO_SCHEDULE_MAX_RUNTIME = 60000; - private static class Map2x implements MapFunction { + private static class Map2x implements NanoSchedulerMapFunction { @Override public Integer apply(Integer input) { return input * 2; } } - private static class ReduceSum implements ReduceFunction { + private static class ReduceSum implements NanoSchedulerReduceFunction { int prevOne = Integer.MIN_VALUE; @Override public Integer apply(Integer one, Integer sum) {