Merge branch 'nanoScheduler'

This commit is contained in:
Mark DePristo 2012-09-05 21:10:08 -04:00
commit e77abfa82d
37 changed files with 1065 additions and 420 deletions

View File

@ -577,6 +577,7 @@
docletpathref="doclet.classpath"
classpathref="external.dependencies"
classpath="${java.classes}"
maxmemory="2g"
additionalparam="-build-timestamp "${build.timestamp}" -absolute-version ${build.version} -out ${basedir}/${resource.path} -quiet">
<sourcefiles>
<union>
@ -780,6 +781,7 @@
docletpathref="doclet.classpath"
classpathref="external.dependencies"
classpath="${java.classes}"
maxmemory="2g"
additionalparam="${gatkdocs.include.hidden.arg} -private -build-timestamp &quot;${build.timestamp}&quot; -absolute-version ${build.version} -quiet"> <!-- -test to only do DocumentationTest walker -->
<sourcefiles>
<fileset refid="java.source.files"/>

View File

@ -143,6 +143,8 @@ public class GenomeAnalysisEngine {
*/
private ThreadAllocation threadAllocation;
private ReadMetrics cumulativeMetrics = null;
/**
* A currently hacky unique name for this GATK instance
*/
@ -398,28 +400,22 @@ public class GenomeAnalysisEngine {
* Parse out the thread allocation from the given command-line argument.
*/
private void determineThreadAllocation() {
Tags tags = parsingEngine.getTags(argCollection.numberOfThreads);
if ( argCollection.numberOfDataThreads < 1 ) throw new UserException.BadArgumentValue("num_threads", "cannot be less than 1, but saw " + argCollection.numberOfDataThreads);
if ( argCollection.numberOfCPUThreadsPerDataThread < 1 ) throw new UserException.BadArgumentValue("num_cpu_threads", "cannot be less than 1, but saw " + argCollection.numberOfCPUThreadsPerDataThread);
if ( argCollection.numberOfIOThreads < 0 ) throw new UserException.BadArgumentValue("num_io_threads", "cannot be less than 0, but saw " + argCollection.numberOfIOThreads);
// TODO: Kill this complicated logic once Queue supports arbitrary tagged parameters.
Integer numCPUThreads = null;
if(tags.containsKey("cpu") && argCollection.numberOfCPUThreads != null)
throw new UserException("Number of CPU threads specified both directly on the command-line and as a tag to the nt argument. Please specify only one or the other.");
else if(tags.containsKey("cpu"))
numCPUThreads = Integer.parseInt(tags.getValue("cpu"));
else if(argCollection.numberOfCPUThreads != null)
numCPUThreads = argCollection.numberOfCPUThreads;
Integer numIOThreads = null;
if(tags.containsKey("io") && argCollection.numberOfIOThreads != null)
throw new UserException("Number of IO threads specified both directly on the command-line and as a tag to the nt argument. Please specify only one or the other.");
else if(tags.containsKey("io"))
numIOThreads = Integer.parseInt(tags.getValue("io"));
else if(argCollection.numberOfIOThreads != null)
numIOThreads = argCollection.numberOfIOThreads;
this.threadAllocation = new ThreadAllocation(argCollection.numberOfThreads, numCPUThreads, numIOThreads, ! argCollection.disableEfficiencyMonitor);
this.threadAllocation = new ThreadAllocation(argCollection.numberOfDataThreads,
argCollection.numberOfCPUThreadsPerDataThread,
argCollection.numberOfIOThreads,
! argCollection.disableEfficiencyMonitor);
}
public int getTotalNumberOfThreads() {
return this.threadAllocation == null ? 1 : threadAllocation.getTotalNumThreads();
}
/**
* Allow subclasses and others within this package direct access to the walker manager.
* @return The walker manager used by this package.
@ -1035,7 +1031,10 @@ public class GenomeAnalysisEngine {
* owned by the caller; the caller can do with the object what they wish.
*/
public ReadMetrics getCumulativeMetrics() {
return readsDataSource == null ? null : readsDataSource.getCumulativeReadMetrics();
// todo -- probably shouldn't be lazy
if ( cumulativeMetrics == null )
cumulativeMetrics = readsDataSource == null ? new ReadMetrics() : readsDataSource.getCumulativeReadMetrics();
return cumulativeMetrics;
}
/**

View File

@ -27,7 +27,6 @@ package org.broadinstitute.sting.gatk;
import net.sf.picard.filter.SamRecordFilter;
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;
@ -119,11 +118,18 @@ public class ReadMetrics implements Cloneable {
return nRecords;
}
/**
* Increments the number of 'iterations' (one call of the filter/map/reduce sequence) completed by the given amount.
*/
public void incrementNumIterations(final long by) {
nRecords += by;
}
/**
* Increments the number of 'iterations' (one call of filter/map/reduce sequence) completed.
*/
public void incrementNumIterations() {
nRecords++;
incrementNumIterations(1);
}
public long getNumReadsSeen() {

View File

@ -41,7 +41,9 @@ import org.broadinstitute.sting.utils.interval.IntervalMergingRule;
import org.broadinstitute.sting.utils.interval.IntervalSetRule;
import java.io.File;
import java.util.*;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
/**
* @author aaron
@ -197,6 +199,12 @@ public class GATKArgumentCollection {
// performance log arguments
//
// --------------------------------------------------------------------------------------------------------------
/**
* The file name for the GATK performance log output, or null if you don't want to generate the
* detailed performance logging table. This table is suitable for importing into R or any
* other analysis software that can read tsv files
*/
@Argument(fullName = "performanceLog", shortName="PF", doc="If provided, a GATK runtime performance log will be written to this file", required = false)
public File performanceLog = null;
@ -279,9 +287,32 @@ public class GATKArgumentCollection {
@Argument(fullName = "unsafe", shortName = "U", doc = "If set, enables unsafe operations: nothing will be checked at runtime. For expert users only who know what they are doing. We do not support usage of this argument.", required = false)
public ValidationExclusion.TYPE unsafe;
/** How many threads should be allocated to this analysis. */
@Argument(fullName = "num_threads", shortName = "nt", doc = "How many threads should be allocated to running this analysis.", required = false)
public Integer numberOfThreads = 1;
// --------------------------------------------------------------------------------------------------------------
//
// Multi-threading arguments
//
// --------------------------------------------------------------------------------------------------------------
/**
* How many data threads should be allocated to this analysis? Each data thread contains N CPU threads
* and acts as a completely data-parallel processing unit, increasing the memory usage of the GATK
* by a factor of M for M data threads. Data threads generally scale extremely effectively, up to 24 cores.
*/
@Argument(fullName = "num_threads", shortName = "nt", doc = "How many data threads should be allocated to running this analysis.", required = false)
public Integer numberOfDataThreads = 1;
/**
* How many CPU threads should be allocated per data thread? Each CPU thread operates the map
* cycle independently, but may run into IO-related scaling problems earlier than data threads do. CPU
* threads have the benefit of not requiring X times as much memory per thread as data threads do, but
* rather only a constant overhead.
*/
@Argument(fullName="num_cpu_threads_per_data_thread", shortName = "cnt", doc="How many CPU threads should be allocated per data thread to running this analysis?", required = false)
public int numberOfCPUThreadsPerDataThread = 1;
@Argument(fullName="num_io_threads", shortName = "nit", doc="How many of the given threads should be allocated to IO", required = false)
@Hidden
public int numberOfIOThreads = 0;
/**
* By default the GATK monitors its own efficiency, but this can have an itsy-bitsy tiny
@ -291,17 +322,6 @@ public class GATKArgumentCollection {
@Argument(fullName = "disableThreadEfficiencyMonitor", shortName = "dtem", doc = "Disable GATK efficiency monitoring", required = false)
public Boolean disableEfficiencyMonitor = false;
/**
* The following two arguments (num_cpu_threads, num_io_threads are TEMPORARY since Queue cannot currently support arbitrary tagged data types.
* TODO: Kill this when I can do a tagged integer in Queue.
*/
@Argument(fullName="num_cpu_threads", shortName = "nct", doc="How many of the given threads should be allocated to the CPU", required = false)
@Hidden
public Integer numberOfCPUThreads = null;
@Argument(fullName="num_io_threads", shortName = "nit", doc="How many of the given threads should be allocated to IO", required = false)
@Hidden
public Integer numberOfIOThreads = null;
@Argument(fullName = "num_bam_file_handles", shortName = "bfh", doc="The total number of BAM file handles to keep open simultaneously", required=false)
public Integer numberOfBAMFileHandles = null;

View File

@ -8,6 +8,7 @@ import org.broadinstitute.sting.gatk.datasources.reads.Shard;
import org.broadinstitute.sting.gatk.datasources.rmd.ReferenceOrderedDataSource;
import org.broadinstitute.sting.gatk.io.OutputTracker;
import org.broadinstitute.sting.gatk.io.ThreadLocalOutputTracker;
import org.broadinstitute.sting.gatk.resourcemanagement.ThreadAllocation;
import org.broadinstitute.sting.gatk.walkers.TreeReducible;
import org.broadinstitute.sting.gatk.walkers.Walker;
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
@ -76,21 +77,21 @@ public class HierarchicalMicroScheduler extends MicroScheduler implements Hierar
/**
* Create a new hierarchical microscheduler to process the given reads and reference.
*
* @param walker the walker used to process the dataset.
* @param reads Reads file(s) to process.
* @param reference Reference for driving the traversal.
* @param nThreadsToUse maximum number of threads to use to do the work
* @param walker the walker used to process the dataset.
* @param reads Reads file(s) to process.
* @param reference Reference for driving the traversal.
* @param threadAllocation How should we apply multi-threaded execution?
*/
protected HierarchicalMicroScheduler(final GenomeAnalysisEngine engine,
final Walker walker,
final SAMDataSource reads,
final IndexedFastaSequenceFile reference,
final Collection<ReferenceOrderedDataSource> rods,
final int nThreadsToUse,
final boolean monitorThreadPerformance ) {
super(engine, walker, reads, reference, rods, nThreadsToUse);
final ThreadAllocation threadAllocation) {
super(engine, walker, reads, reference, rods, threadAllocation);
if ( monitorThreadPerformance ) {
final int nThreadsToUse = threadAllocation.getNumDataThreads();
if ( threadAllocation.monitorThreadEfficiency() ) {
final EfficiencyMonitoringThreadFactory monitoringThreadFactory = new EfficiencyMonitoringThreadFactory(nThreadsToUse);
setThreadEfficiencyMonitor(monitoringThreadFactory);
this.threadPool = Executors.newFixedThreadPool(nThreadsToUse, monitoringThreadFactory);

View File

@ -10,6 +10,7 @@ import org.broadinstitute.sting.gatk.datasources.reads.Shard;
import org.broadinstitute.sting.gatk.datasources.rmd.ReferenceOrderedDataSource;
import org.broadinstitute.sting.gatk.io.DirectOutputTracker;
import org.broadinstitute.sting.gatk.io.OutputTracker;
import org.broadinstitute.sting.gatk.resourcemanagement.ThreadAllocation;
import org.broadinstitute.sting.gatk.traversals.TraverseActiveRegions;
import org.broadinstitute.sting.gatk.walkers.Walker;
import org.broadinstitute.sting.utils.SampleUtils;
@ -39,13 +40,11 @@ public class LinearMicroScheduler extends MicroScheduler {
final SAMDataSource reads,
final IndexedFastaSequenceFile reference,
final Collection<ReferenceOrderedDataSource> rods,
final int numThreads, // may be > 1 if are nanoScheduling
final boolean monitorThreadPerformance ) {
super(engine, walker, reads, reference, rods, numThreads);
final ThreadAllocation threadAllocation) {
super(engine, walker, reads, reference, rods, threadAllocation);
if ( monitorThreadPerformance )
if ( threadAllocation.monitorThreadEfficiency() )
setThreadEfficiencyMonitor(new ThreadEfficiencyMonitor());
}
/**
@ -60,11 +59,12 @@ public class LinearMicroScheduler extends MicroScheduler {
boolean done = walker.isDone();
int counter = 0;
traversalEngine.startTimersIfNecessary();
for (Shard shard : shardStrategy ) {
if ( done || shard == null ) // we ran out of shards that aren't owned
break;
traversalEngine.startTimersIfNecessary();
if(shard.getShardType() == Shard.ShardType.LOCUS) {
WindowMaker windowMaker = new WindowMaker(shard, engine.getGenomeLocParser(),
getReadIterator(shard), shard.getGenomeLocs(), SampleUtils.getSAMFileSamples(engine));

View File

@ -59,6 +59,8 @@ import java.util.Collection;
/** Shards and schedules data in manageable chunks. */
public abstract class MicroScheduler implements MicroSchedulerMBean {
// TODO -- remove me and retire non nano scheduled versions of traversals
private final static boolean USE_NANOSCHEDULER_FOR_EVERYTHING = true;
protected static final Logger logger = Logger.getLogger(MicroScheduler.class);
/**
@ -100,27 +102,30 @@ public abstract class MicroScheduler implements MicroSchedulerMBean {
* @return The best-fit microscheduler.
*/
public static MicroScheduler create(GenomeAnalysisEngine engine, Walker walker, SAMDataSource reads, IndexedFastaSequenceFile reference, Collection<ReferenceOrderedDataSource> rods, ThreadAllocation threadAllocation) {
if (threadAllocation.getNumCPUThreads() > 1) {
if ( threadAllocation.isRunningInParallelMode() )
logger.info(String.format("Running the GATK in parallel mode with %d CPU thread(s) for each of %d data thread(s)",
threadAllocation.getNumCPUThreadsPerDataThread(), threadAllocation.getNumDataThreads()));
if ( threadAllocation.getNumDataThreads() > 1 ) {
if (walker.isReduceByInterval())
throw new UserException.BadArgumentValue("nt", String.format("The analysis %s aggregates results by interval. Due to a current limitation of the GATK, analyses of this type do not currently support parallel execution. Please run your analysis without the -nt option.", engine.getWalkerName(walker.getClass())));
logger.info(String.format("Running the GATK in parallel mode with %d concurrent threads",threadAllocation.getNumCPUThreads()));
if ( walker instanceof ReadWalker ) {
if ( ! (walker instanceof ThreadSafeMapReduce) ) badNT(engine, walker);
return new LinearMicroScheduler(engine, walker, reads, reference, rods, threadAllocation.getNumCPUThreads(), threadAllocation.monitorThreadEfficiency());
if ( ! (walker instanceof TreeReducible) ) {
throw badNT("nt", engine, walker);
} else {
// TODO -- update test for when nano scheduling only is an option
if ( ! (walker instanceof TreeReducible) ) badNT(engine, walker);
return new HierarchicalMicroScheduler(engine, walker, reads, reference, rods, threadAllocation.getNumCPUThreads(), threadAllocation.monitorThreadEfficiency());
return new HierarchicalMicroScheduler(engine, walker, reads, reference, rods, threadAllocation);
}
} else {
return new LinearMicroScheduler(engine, walker, reads, reference, rods, threadAllocation.getNumCPUThreads(), threadAllocation.monitorThreadEfficiency());
if ( threadAllocation.getNumCPUThreadsPerDataThread() > 1 && ! (walker instanceof NanoSchedulable) )
throw badNT("cnt", engine, walker);
return new LinearMicroScheduler(engine, walker, reads, reference, rods, threadAllocation);
}
}
private static void badNT(final GenomeAnalysisEngine engine, final Walker walker) {
throw new UserException.BadArgumentValue("nt", String.format("The analysis %s currently does not support parallel execution. Please run your analysis without the -nt option.", engine.getWalkerName(walker.getClass())));
private static UserException badNT(final String parallelArg, final GenomeAnalysisEngine engine, final Walker walker) {
throw new UserException.BadArgumentValue("nt",
String.format("The analysis %s currently does not support parallel execution with %s. " +
"Please run your analysis without the %s option.", engine.getWalkerName(walker.getClass()), parallelArg, parallelArg));
}
/**
@ -130,23 +135,27 @@ public abstract class MicroScheduler implements MicroSchedulerMBean {
* @param reads The reads.
* @param reference The reference.
* @param rods the rods to include in the traversal
* @param numThreads the number of threads we are using in the underlying traversal
* @param threadAllocation the allocation of threads to use in the underlying traversal
*/
protected MicroScheduler(final GenomeAnalysisEngine engine,
final Walker walker,
final SAMDataSource reads,
final IndexedFastaSequenceFile reference,
final Collection<ReferenceOrderedDataSource> rods,
final int numThreads) {
final ThreadAllocation threadAllocation) {
this.engine = engine;
this.reads = reads;
this.reference = reference;
this.rods = rods;
if (walker instanceof ReadWalker) {
traversalEngine = numThreads > 1 ? new TraverseReadsNano(numThreads) : new TraverseReads();
traversalEngine = USE_NANOSCHEDULER_FOR_EVERYTHING || threadAllocation.getNumCPUThreadsPerDataThread() > 1
? new TraverseReadsNano(threadAllocation.getNumCPUThreadsPerDataThread())
: new TraverseReads();
} else if (walker instanceof LocusWalker) {
traversalEngine = new TraverseLoci();
traversalEngine = USE_NANOSCHEDULER_FOR_EVERYTHING || threadAllocation.getNumCPUThreadsPerDataThread() > 1
? new TraverseLociNano(threadAllocation.getNumCPUThreadsPerDataThread())
: new TraverseLociLinear();
} else if (walker instanceof DuplicateWalker) {
traversalEngine = new TraverseDuplicates();
} else if (walker instanceof ReadPairWalker) {

View File

@ -32,9 +32,9 @@ import org.broadinstitute.sting.utils.classloader.JVMUtils;
import org.broadinstitute.sting.utils.codecs.vcf.VCFHeader;
import org.broadinstitute.sting.utils.codecs.vcf.VCFHeaderLine;
import org.broadinstitute.sting.utils.codecs.vcf.VCFUtils;
import org.broadinstitute.sting.utils.variantcontext.VariantContext;
import org.broadinstitute.sting.utils.variantcontext.writer.Options;
import org.broadinstitute.sting.utils.variantcontext.writer.VariantContextWriter;
import org.broadinstitute.sting.utils.variantcontext.VariantContext;
import org.broadinstitute.sting.utils.variantcontext.writer.VariantContextWriterFactory;
import java.io.File;
@ -269,7 +269,7 @@ public class VariantContextWriterStub implements Stub<VariantContextWriter>, Var
* @return
*/
public boolean alsoWriteBCFForTest() {
return engine.getArguments().numberOfThreads == 1 && // only works single threaded
return engine.getArguments().numberOfDataThreads == 1 && // only works single threaded
! isCompressed() && // for non-compressed outputs
getFile() != null && // that are going to disk
engine.getArguments().generateShadowBCF; // and we actually want to do it

View File

@ -218,7 +218,7 @@ public class GATKRunReport {
// if there was an exception, capture it
this.mException = e == null ? null : new ExceptionToXML(e);
numThreads = engine.getArguments().numberOfThreads;
numThreads = engine.getTotalNumberOfThreads();
percentTimeRunning = getThreadEfficiencyPercent(engine, ThreadEfficiencyMonitor.State.USER_CPU);
percentTimeBlocking = getThreadEfficiencyPercent(engine, ThreadEfficiencyMonitor.State.BLOCKING);
percentTimeWaiting = getThreadEfficiencyPercent(engine, ThreadEfficiencyMonitor.State.WAITING);

View File

@ -24,7 +24,7 @@
package org.broadinstitute.sting.gatk.resourcemanagement;
import org.broadinstitute.sting.utils.exceptions.UserException;
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
/**
* Models how threads are distributed between various components of the GATK.
@ -33,7 +33,12 @@ public class ThreadAllocation {
/**
* The number of CPU threads to be used by the GATK.
*/
private final int numCPUThreads;
private final int numDataThreads;
/**
* The number of CPU threads per data thread for GATK processing
*/
private final int numCPUThreadsPerDataThread;
/**
* Number of threads to devote exclusively to IO. Default is 0.
@ -45,8 +50,12 @@ public class ThreadAllocation {
*/
private final boolean monitorEfficiency;
public int getNumCPUThreads() {
return numCPUThreads;
public int getNumDataThreads() {
return numDataThreads;
}
public int getNumCPUThreadsPerDataThread() {
return numCPUThreadsPerDataThread;
}
public int getNumIOThreads() {
@ -57,47 +66,50 @@ public class ThreadAllocation {
return monitorEfficiency;
}
/**
* Are we running in parallel mode?
*
* @return true if any parallel processing is enabled
*/
public boolean isRunningInParallelMode() {
return getTotalNumThreads() > 1;
}
/**
* What is the total number of threads in use by the GATK?
*
* @return the sum of all thread allocations in this object
*/
public int getTotalNumThreads() {
return getNumDataThreads() * getNumCPUThreadsPerDataThread() + getNumIOThreads();
}
/**
* Construct the default thread allocation.
*/
public ThreadAllocation() {
this(1, null, null, false);
this(1, 1, 0, false);
}
/**
* Set up the thread allocation. Default allocation is 1 CPU thread, 0 IO threads.
* (0 IO threads means that no threads are devoted exclusively to IO; they're inline on the CPU thread).
* @param totalThreads Complete number of threads to allocate.
* @param numCPUThreads Total number of threads allocated to the traversal.
* @param numDataThreads Total number of threads allocated to the traversal.
* @param numCPUThreadsPerDataThread The number of CPU threads per data thread to allocate
* @param numIOThreads Total number of threads allocated exclusively to IO.
* @param monitorEfficiency should we monitor threading efficiency in the GATK?
*/
public ThreadAllocation(final int totalThreads, final Integer numCPUThreads, final Integer numIOThreads, final boolean monitorEfficiency) {
// If no allocation information is present, allocate all threads to CPU
if(numCPUThreads == null && numIOThreads == null) {
this.numCPUThreads = totalThreads;
this.numIOThreads = 0;
}
// If only CPU threads are specified, allocate remainder to IO (minimum 0 dedicated IO threads).
else if(numIOThreads == null) {
if(numCPUThreads > totalThreads)
throw new UserException(String.format("Invalid thread allocation. User requested %d threads in total, but the count of cpu threads (%d) is higher than the total threads",totalThreads,numCPUThreads));
this.numCPUThreads = numCPUThreads;
this.numIOThreads = totalThreads - numCPUThreads;
}
// If only IO threads are specified, allocate remainder to CPU (minimum 1 dedicated CPU thread).
else if(numCPUThreads == null) {
if(numIOThreads > totalThreads)
throw new UserException(String.format("Invalid thread allocation. User requested %d threads in total, but the count of io threads (%d) is higher than the total threads",totalThreads,numIOThreads));
this.numCPUThreads = Math.max(1,totalThreads-numIOThreads);
this.numIOThreads = numIOThreads;
}
else {
if(numCPUThreads + numIOThreads != totalThreads)
throw new UserException(String.format("Invalid thread allocation. User requested %d threads in total, but the count of cpu threads (%d) + the count of io threads (%d) does not match",totalThreads,numCPUThreads,numIOThreads));
this.numCPUThreads = numCPUThreads;
this.numIOThreads = numIOThreads;
}
public ThreadAllocation(final int numDataThreads,
final int numCPUThreadsPerDataThread,
final int numIOThreads,
final boolean monitorEfficiency) {
if ( numDataThreads < 1 ) throw new ReviewedStingException("numDataThreads cannot be less than 1, but saw " + numDataThreads);
if ( numCPUThreadsPerDataThread < 1 ) throw new ReviewedStingException("numCPUThreadsPerDataThread cannot be less than 1, but saw " + numCPUThreadsPerDataThread);
if ( numIOThreads < 0 ) throw new ReviewedStingException("numIOThreads cannot be less than 0, but saw " + numIOThreads);
this.numDataThreads = numDataThreads;
this.numCPUThreadsPerDataThread = numCPUThreadsPerDataThread;
this.numIOThreads = numIOThreads;
this.monitorEfficiency = monitorEfficiency;
}
}

View File

@ -44,24 +44,12 @@ import java.util.List;
import java.util.Map;
public abstract class TraversalEngine<M,T,WalkerType extends Walker<M,T>,ProviderType extends ShardDataProvider> {
/** our log, which we want to capture anything from this class */
protected static final Logger logger = Logger.getLogger(TraversalEngine.class);
// Time in milliseconds since we initialized this engine
private static final int HISTORY_WINDOW_SIZE = 50;
private static class ProcessingHistory {
double elapsedSeconds;
long unitsProcessed;
long bpProcessed;
GenomeLoc loc;
public ProcessingHistory(double elapsedSeconds, GenomeLoc loc, long unitsProcessed, long bpProcessed) {
this.elapsedSeconds = elapsedSeconds;
this.loc = loc;
this.unitsProcessed = unitsProcessed;
this.bpProcessed = bpProcessed;
}
}
/** lock object to sure updates to history are consistent across threads */
private static final Object lock = new Object();
LinkedList<ProcessingHistory> history = new LinkedList<ProcessingHistory>();
@ -70,13 +58,12 @@ public abstract class TraversalEngine<M,T,WalkerType extends Walker<M,T>,Provide
private SimpleTimer timer = null;
// How long can we go without printing some progress info?
private static final int PRINT_PROGRESS_CHECK_FREQUENCY_IN_CYCLES = 1000;
private int printProgressCheckCounter = 0;
private long lastProgressPrintTime = -1; // When was the last time we printed progress log?
private long MIN_ELAPSED_TIME_BEFORE_FIRST_PROGRESS = 30 * 1000; // in milliseconds
private long PROGRESS_PRINT_FREQUENCY = 10 * 1000; // in milliseconds
private final double TWO_HOURS_IN_SECONDS = 2.0 * 60.0 * 60.0;
private final double TWELVE_HOURS_IN_SECONDS = 12.0 * 60.0 * 60.0;
private final static long MIN_ELAPSED_TIME_BEFORE_FIRST_PROGRESS = 30 * 1000; // in milliseconds
private final static double TWO_HOURS_IN_SECONDS = 2.0 * 60.0 * 60.0;
private final static double TWELVE_HOURS_IN_SECONDS = 12.0 * 60.0 * 60.0;
private long progressPrintFrequency = 10 * 1000; // in milliseconds
private boolean progressMeterInitialized = false;
// for performance log
@ -85,15 +72,12 @@ public abstract class TraversalEngine<M,T,WalkerType extends Walker<M,T>,Provide
private File performanceLogFile;
private PrintStream performanceLog = null;
private long lastPerformanceLogPrintTime = -1; // When was the last time we printed to the performance log?
private final long PERFORMANCE_LOG_PRINT_FREQUENCY = PROGRESS_PRINT_FREQUENCY; // in milliseconds
private final long PERFORMANCE_LOG_PRINT_FREQUENCY = progressPrintFrequency; // in milliseconds
/** Size, in bp, of the area we are processing. Updated once in the system in initial for performance reasons */
long targetSize = -1;
GenomeLocSortedSet targetIntervals = null;
/** our log, which we want to capture anything from this class */
protected static final Logger logger = Logger.getLogger(TraversalEngine.class);
protected GenomeAnalysisEngine engine;
// ----------------------------------------------------------------------------------------------------
@ -186,15 +170,35 @@ public abstract class TraversalEngine<M,T,WalkerType extends Walker<M,T>,Provide
return elapsed > printFreq && elapsed > MIN_ELAPSED_TIME_BEFORE_FIRST_PROGRESS;
}
/**
* Update the cumulative traversal metrics according to the data in this shard
*
* @param shard a non-null shard
*/
public void updateCumulativeMetrics(final Shard shard) {
updateCumulativeMetrics(shard.getReadMetrics());
}
/**
* Update the cumulative traversal metrics according to the data in this shard
*
* @param singleTraverseMetrics read metrics object containing the information about a single shard's worth
* of data processing
*/
public void updateCumulativeMetrics(final ReadMetrics singleTraverseMetrics) {
engine.getCumulativeMetrics().incrementMetrics(singleTraverseMetrics);
}
/**
* Forward request to printProgress
*
* @param shard the given shard currently being processed.
* Assumes that one cycle has been completed
*
* @param loc the location
*/
public void printProgress(Shard shard, GenomeLoc loc) {
public void printProgress(final GenomeLoc loc) {
// A bypass is inserted here for unit testing.
printProgress(loc,shard.getReadMetrics(),false);
printProgress(loc, false);
}
/**
@ -202,15 +206,10 @@ public abstract class TraversalEngine<M,T,WalkerType extends Walker<M,T>,Provide
* every M seconds, for N and M set in global variables.
*
* @param loc Current location, can be null if you are at the end of the traversal
* @param metrics Data processed since the last cumulative
* @param mustPrint If true, will print out info, regardless of nRecords or time interval
*/
private void printProgress(GenomeLoc loc, ReadMetrics metrics, boolean mustPrint) {
if ( mustPrint || printProgressCheckCounter++ % PRINT_PROGRESS_CHECK_FREQUENCY_IN_CYCLES != 0 )
// don't do any work more often than PRINT_PROGRESS_CHECK_FREQUENCY_IN_CYCLES
return;
if(!progressMeterInitialized && mustPrint == false ) {
private synchronized void printProgress(final GenomeLoc loc, boolean mustPrint) {
if( ! progressMeterInitialized ) {
logger.info("[INITIALIZATION COMPLETE; TRAVERSAL STARTING]");
logger.info(String.format("%15s processed.%s runtime per.1M.%s completed total.runtime remaining",
"Location", getTraversalType(), getTraversalType()));
@ -218,40 +217,34 @@ public abstract class TraversalEngine<M,T,WalkerType extends Walker<M,T>,Provide
}
final long curTime = timer.currentTime();
boolean printProgress = mustPrint || maxElapsedIntervalForPrinting(curTime, lastProgressPrintTime, PROGRESS_PRINT_FREQUENCY);
boolean printProgress = mustPrint || maxElapsedIntervalForPrinting(curTime, lastProgressPrintTime, progressPrintFrequency);
boolean printLog = performanceLog != null && maxElapsedIntervalForPrinting(curTime, lastPerformanceLogPrintTime, PERFORMANCE_LOG_PRINT_FREQUENCY);
if ( printProgress || printLog ) {
// getting and appending metrics data actually turns out to be quite a heavyweight
// operation. Postpone it until after determining whether to print the log message.
ReadMetrics cumulativeMetrics = engine.getCumulativeMetrics() != null ? engine.getCumulativeMetrics() : new ReadMetrics();
if(metrics != null)
cumulativeMetrics.incrementMetrics(metrics);
final long nRecords = cumulativeMetrics.getNumIterations();
ProcessingHistory last = updateHistory(loc,cumulativeMetrics);
final ProcessingHistory last = updateHistory(loc, engine.getCumulativeMetrics());
final AutoFormattingTime elapsed = new AutoFormattingTime(last.elapsedSeconds);
final AutoFormattingTime bpRate = new AutoFormattingTime(secondsPerMillionBP(last));
final AutoFormattingTime unitRate = new AutoFormattingTime(secondsPerMillionElements(last));
final double fractionGenomeTargetCompleted = calculateFractionGenomeTargetCompleted(last);
final AutoFormattingTime bpRate = new AutoFormattingTime(last.secondsPerMillionBP());
final AutoFormattingTime unitRate = new AutoFormattingTime(last.secondsPerMillionElements());
final double fractionGenomeTargetCompleted = last.calculateFractionGenomeTargetCompleted(targetSize);
final AutoFormattingTime estTotalRuntime = new AutoFormattingTime(elapsed.getTimeInSeconds() / fractionGenomeTargetCompleted);
final AutoFormattingTime timeToCompletion = new AutoFormattingTime(estTotalRuntime.getTimeInSeconds() - elapsed.getTimeInSeconds());
final long nRecords = engine.getCumulativeMetrics().getNumIterations();
if ( printProgress ) {
lastProgressPrintTime = curTime;
// dynamically change the update rate so that short running jobs receive frequent updates while longer jobs receive fewer updates
if ( estTotalRuntime.getTimeInSeconds() > TWELVE_HOURS_IN_SECONDS )
PROGRESS_PRINT_FREQUENCY = 60 * 1000; // in milliseconds
progressPrintFrequency = 60 * 1000; // in milliseconds
else if ( estTotalRuntime.getTimeInSeconds() > TWO_HOURS_IN_SECONDS )
PROGRESS_PRINT_FREQUENCY = 30 * 1000; // in milliseconds
progressPrintFrequency = 30 * 1000; // in milliseconds
else
PROGRESS_PRINT_FREQUENCY = 10 * 1000; // in milliseconds
progressPrintFrequency = 10 * 1000; // in milliseconds
logger.info(String.format("%15s %5.2e %s %s %4.1f%% %s %s",
loc == null ? "done with mapped reads" : loc, nRecords*1.0, elapsed, unitRate,
final String posName = loc == null ? (mustPrint ? "done" : "unmapped reads") : String.format("%s:%d", loc.getContig(), loc.getStart());
logger.info(String.format("%15s %5.2e %s %s %5.1f%% %s %s",
posName, nRecords*1.0, elapsed, unitRate,
100*fractionGenomeTargetCompleted, estTotalRuntime, timeToCompletion));
}
@ -277,7 +270,7 @@ public abstract class TraversalEngine<M,T,WalkerType extends Walker<M,T>,Provide
* @param metrics information about what's been processed already
* @return
*/
private final ProcessingHistory updateHistory(GenomeLoc loc, ReadMetrics metrics) {
private ProcessingHistory updateHistory(GenomeLoc loc, ReadMetrics metrics) {
synchronized (lock) {
if ( history.size() > HISTORY_WINDOW_SIZE )
history.pop();
@ -290,26 +283,11 @@ public abstract class TraversalEngine<M,T,WalkerType extends Walker<M,T>,Provide
}
}
/** How long in seconds to process 1M traversal units? */
private final double secondsPerMillionElements(ProcessingHistory last) {
return (last.elapsedSeconds * 1000000.0) / Math.max(last.unitsProcessed, 1);
}
/** How long in seconds to process 1M bp on the genome? */
private final double secondsPerMillionBP(ProcessingHistory last) {
return (last.elapsedSeconds * 1000000.0) / Math.max(last.bpProcessed, 1);
}
/** What fraction of the target intervals have we covered? */
private final double calculateFractionGenomeTargetCompleted(ProcessingHistory last) {
return (1.0*last.bpProcessed) / targetSize;
}
/**
* Called after a traversal to print out information about the traversal process
*/
public void printOnTraversalDone() {
printProgress(null, null, true);
printProgress(null, true);
final double elapsed = timer == null ? 0 : timer.getElapsedTime();
@ -370,7 +348,7 @@ public abstract class TraversalEngine<M,T,WalkerType extends Walker<M,T>,Provide
* @return Frequency, in seconds, of performance log writes.
*/
public long getPerformanceProgressPrintFrequencySeconds() {
return PROGRESS_PRINT_FREQUENCY;
return progressPrintFrequency;
}
/**
@ -378,6 +356,35 @@ public abstract class TraversalEngine<M,T,WalkerType extends Walker<M,T>,Provide
* @param seconds number of seconds between messages indicating performance frequency.
*/
public void setPerformanceProgressPrintFrequencySeconds(long seconds) {
PROGRESS_PRINT_FREQUENCY = seconds;
progressPrintFrequency = seconds;
}
/**
 * Immutable-in-practice snapshot of traversal progress: the wall-clock time
 * elapsed and the amount of work completed up to a given genomic position.
 * Instances are queued in the enclosing engine's progress-history window.
 */
private static class ProcessingHistory {
    double elapsedSeconds;   // wall-clock seconds elapsed when this snapshot was taken
    long unitsProcessed;     // number of traversal units (map calls) completed so far
    long bpProcessed;        // number of reference base pairs covered so far
    GenomeLoc loc;           // genomic position reached at snapshot time

    public ProcessingHistory(double elapsedSeconds, GenomeLoc loc, long unitsProcessed, long bpProcessed) {
        this.elapsedSeconds = elapsedSeconds;
        this.loc = loc;
        this.unitsProcessed = unitsProcessed;
        this.bpProcessed = bpProcessed;
    }

    /** Seconds needed to process one million traversal units at the observed rate. */
    private double secondsPerMillionElements() {
        final long units = Math.max(unitsProcessed, 1);  // guard against divide-by-zero at startup
        return elapsedSeconds * 1000000.0 / units;
    }

    /** Seconds needed to process one million reference base pairs at the observed rate. */
    private double secondsPerMillionBP() {
        final long bp = Math.max(bpProcessed, 1);  // guard against divide-by-zero at startup
        return elapsedSeconds * 1000000.0 / bp;
    }

    /** Fraction of the target intervals (targetSize bp in total) covered so far. */
    private double calculateFractionGenomeTargetCompleted(final long targetSize) {
        return bpProcessed / (double) targetSize;
    }
}
}

View File

@ -104,7 +104,8 @@ public class TraverseActiveRegions <M,T> extends TraversalEngine<M,T,ActiveRegio
prevLoc = location;
printProgress(dataProvider.getShard(), locus.getLocation());
updateCumulativeMetrics(dataProvider.getShard());
printProgress(locus.getLocation());
}
// Take the individual isActive calls and integrate them into contiguous active regions and

View File

@ -196,7 +196,8 @@ public class TraverseDuplicates<M,T> extends TraversalEngine<M,T,DuplicateWalker
sum = walker.reduce(x, sum);
}
printProgress(dataProvider.getShard(),site);
updateCumulativeMetrics(dataProvider.getShard());
printProgress(site);
done = walker.isDone();
}

View File

@ -3,9 +3,7 @@ package org.broadinstitute.sting.gatk.traversals;
import org.apache.log4j.Logger;
import org.broadinstitute.sting.gatk.WalkerManager;
import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
import org.broadinstitute.sting.gatk.datasources.providers.*;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.walkers.DataSource;
import org.broadinstitute.sting.gatk.walkers.LocusWalker;
import org.broadinstitute.sting.gatk.walkers.Walker;
@ -15,28 +13,42 @@ import org.broadinstitute.sting.utils.pileup.ReadBackedPileupImpl;
/**
* A simple solution to iterating over all reference positions over a series of genomic locations.
*/
public class TraverseLoci<M,T> extends TraversalEngine<M,T,LocusWalker<M,T>,LocusShardDataProvider> {
public abstract class TraverseLociBase<M,T> extends TraversalEngine<M,T,LocusWalker<M,T>,LocusShardDataProvider> {
/**
* our log, which we want to capture anything from this class
*/
protected static final Logger logger = Logger.getLogger(TraversalEngine.class);
@Override
protected String getTraversalType() {
protected final String getTraversalType() {
return "sites";
}
protected static class TraverseResults<T> {
final int numIterations;
final T reduceResult;
public TraverseResults(int numIterations, T reduceResult) {
this.numIterations = numIterations;
this.reduceResult = reduceResult;
}
}
protected abstract TraverseResults<T> traverse( final LocusWalker<M,T> walker,
final LocusView locusView,
final LocusReferenceView referenceView,
final ReferenceOrderedView referenceOrderedDataView,
final T sum);
@Override
public T traverse( LocusWalker<M,T> walker,
LocusShardDataProvider dataProvider,
T sum) {
logger.debug(String.format("TraverseLoci.traverse: Shard is %s", dataProvider));
logger.debug(String.format("TraverseLociBase.traverse: Shard is %s", dataProvider));
LocusView locusView = getLocusView( walker, dataProvider );
boolean done = false;
final LocusView locusView = getLocusView( walker, dataProvider );
if ( locusView.hasNext() ) { // trivial optimization to avoid unnecessary processing when there's nothing here at all
//ReferenceOrderedView referenceOrderedDataView = new ReferenceOrderedView( dataProvider );
ReferenceOrderedView referenceOrderedDataView = null;
if ( WalkerManager.getWalkerDataSource(walker) != DataSource.REFERENCE_ORDERED_DATA )
@ -44,43 +56,24 @@ public class TraverseLoci<M,T> extends TraversalEngine<M,T,LocusWalker<M,T>,Locu
else
referenceOrderedDataView = (RodLocusView)locusView;
LocusReferenceView referenceView = new LocusReferenceView( walker, dataProvider );
final LocusReferenceView referenceView = new LocusReferenceView( walker, dataProvider );
// We keep processing while the next reference location is within the interval
while( locusView.hasNext() && ! done ) {
AlignmentContext locus = locusView.next();
GenomeLoc location = locus.getLocation();
dataProvider.getShard().getReadMetrics().incrementNumIterations();
// create reference context. Note that if we have a pileup of "extended events", the context will
// hold the (longest) stretch of deleted reference bases (if deletions are present in the pileup).
ReferenceContext refContext = referenceView.getReferenceContext(location);
// Iterate forward to get all reference ordered data covering this location
final RefMetaDataTracker tracker = referenceOrderedDataView.getReferenceOrderedDataAtLocus(locus.getLocation(), refContext);
final boolean keepMeP = walker.filter(tracker, refContext, locus);
if (keepMeP) {
M x = walker.map(tracker, refContext, locus);
sum = walker.reduce(x, sum);
done = walker.isDone();
}
printProgress(dataProvider.getShard(),locus.getLocation());
}
final TraverseResults<T> result = traverse( walker, locusView, referenceView, referenceOrderedDataView, sum );
sum = result.reduceResult;
dataProvider.getShard().getReadMetrics().incrementNumIterations(result.numIterations);
updateCumulativeMetrics(dataProvider.getShard());
}
// We have a final map call to execute here to clean up the skipped based from the
// last position in the ROD to that in the interval
if ( WalkerManager.getWalkerDataSource(walker) == DataSource.REFERENCE_ORDERED_DATA && ! walker.isDone() ) {
// only do this if the walker isn't done!
RodLocusView rodLocusView = (RodLocusView)locusView;
long nSkipped = rodLocusView.getLastSkippedBases();
final RodLocusView rodLocusView = (RodLocusView)locusView;
final long nSkipped = rodLocusView.getLastSkippedBases();
if ( nSkipped > 0 ) {
GenomeLoc site = rodLocusView.getLocOneBeyondShard();
AlignmentContext ac = new AlignmentContext(site, new ReadBackedPileupImpl(site), nSkipped);
M x = walker.map(null, null, ac);
final GenomeLoc site = rodLocusView.getLocOneBeyondShard();
final AlignmentContext ac = new AlignmentContext(site, new ReadBackedPileupImpl(site), nSkipped);
final M x = walker.map(null, null, ac);
sum = walker.reduce(x, sum);
}
}
@ -90,14 +83,14 @@ public class TraverseLoci<M,T> extends TraversalEngine<M,T,LocusWalker<M,T>,Locu
/**
* Gets the best view of loci for this walker given the available data. The view will function as a 'trigger track'
* of sorts, providing a consistent interface so that TraverseLoci doesn't need to be reimplemented for any new datatype
* of sorts, providing a consistent interface so that TraverseLociBase doesn't need to be reimplemented for any new datatype
* that comes along.
* @param walker walker to interrogate.
* @param dataProvider Data with which to drive the locus view.
* @return A view of the locus data, where one iteration of the locus view maps to one iteration of the traversal.
*/
private LocusView getLocusView( Walker<M,T> walker, LocusShardDataProvider dataProvider ) {
DataSource dataSource = WalkerManager.getWalkerDataSource(walker);
final DataSource dataSource = WalkerManager.getWalkerDataSource(walker);
if( dataSource == DataSource.READS )
return new CoveredLocusView(dataProvider);
else if( dataSource == DataSource.REFERENCE ) //|| ! GenomeAnalysisEngine.instance.getArguments().enableRodWalkers )

View File

@ -0,0 +1,47 @@
package org.broadinstitute.sting.gatk.traversals;
import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
import org.broadinstitute.sting.gatk.datasources.providers.LocusReferenceView;
import org.broadinstitute.sting.gatk.datasources.providers.LocusView;
import org.broadinstitute.sting.gatk.datasources.providers.ReferenceOrderedView;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.walkers.LocusWalker;
import org.broadinstitute.sting.utils.GenomeLoc;
/**
* A single-threaded implementation of iterating over all reference positions across a series of genomic locations.
*/
public class TraverseLociLinear<M,T> extends TraverseLociBase<M,T> {
    /**
     * Walk each locus of the view in order on the calling thread, applying the
     * walker's filter/map/reduce at every position.
     *
     * @param walker the walker supplying filter, map, and reduce
     * @param locusView source of AlignmentContexts, one per locus
     * @param referenceView source of the reference bases at each locus
     * @param referenceOrderedDataView source of the reference-ordered data (RODs) at each locus
     * @param sum the initial reduce value
     * @return the number of loci visited together with the final reduce result
     */
    @Override
    protected TraverseResults<T> traverse(LocusWalker<M, T> walker, LocusView locusView, LocusReferenceView referenceView, ReferenceOrderedView referenceOrderedDataView, T sum) {
        int nSites = 0;                // loci pulled from the view so far
        boolean walkerIsDone = false;  // set when the walker signals early termination

        // Keep consuming loci until the view is exhausted or the walker says it's done
        for ( ; locusView.hasNext() && ! walkerIsDone; nSites++ ) {
            final AlignmentContext locus = locusView.next();
            final GenomeLoc position = locus.getLocation();

            // Reference context at this position.  With a pileup of "extended events" this
            // holds the (longest) stretch of deleted reference bases, if deletions are present.
            final ReferenceContext refContext = referenceView.getReferenceContext(position);

            // Advance the ROD view to gather all reference-ordered data covering this position
            final RefMetaDataTracker tracker = referenceOrderedDataView.getReferenceOrderedDataAtLocus(position, refContext);

            if ( walker.filter(tracker, refContext, locus) ) {
                sum = walker.reduce(walker.map(tracker, refContext, locus), sum);
                walkerIsDone = walker.isDone();
            }

            printProgress(position);
        }

        return new TraverseResults<T>(nSites, sum);
    }
}

View File

@ -0,0 +1,205 @@
package org.broadinstitute.sting.gatk.traversals;
import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
import org.broadinstitute.sting.gatk.datasources.providers.LocusReferenceView;
import org.broadinstitute.sting.gatk.datasources.providers.LocusView;
import org.broadinstitute.sting.gatk.datasources.providers.ReferenceOrderedView;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.walkers.LocusWalker;
import org.broadinstitute.sting.utils.GenomeLoc;
import org.broadinstitute.sting.utils.nanoScheduler.NanoScheduler;
import org.broadinstitute.sting.utils.nanoScheduler.NanoSchedulerMapFunction;
import org.broadinstitute.sting.utils.nanoScheduler.NanoSchedulerProgressFunction;
import org.broadinstitute.sting.utils.nanoScheduler.NanoSchedulerReduceFunction;
import java.util.Iterator;
/**
* A shared-memory parallel (NanoScheduler-based) implementation of iterating over all reference positions across a series of genomic locations.
*/
public class TraverseLociNano<M,T> extends TraverseLociBase<M,T> {
    /** When true, the NanoScheduler emits verbose debugging output */
    private static final boolean DEBUG = false;

    /** Number of map inputs the NanoScheduler buffers ahead of its map threads */
    private static final int BUFFER_SIZE = 1000;

    /** Executes the map (MapData -> MapResult) and reduce (MapResult x T -> T) phases across threads */
    final NanoScheduler<MapData, MapResult, T> nanoScheduler;

    /**
     * Create a nano-scheduled locus traversal.
     *
     * @param nThreads number of threads the NanoScheduler may use for map calls
     */
    public TraverseLociNano(int nThreads) {
        nanoScheduler = new NanoScheduler<MapData, MapResult, T>(BUFFER_SIZE, nThreads);
        nanoScheduler.setProgressFunction(new TraverseLociProgress());
    }

    /**
     * Traverse the loci in locusView, applying the walker's filter/map/reduce at each position.
     *
     * Builds a per-locus MapData iterator on the calling thread and hands it to the
     * NanoScheduler, which evaluates TraverseLociMap and folds the results with
     * TraverseLociReduce.
     *
     * @param walker the walker supplying filter, map, and reduce
     * @param locusView source of AlignmentContexts, one per locus
     * @param referenceView source of the reference bases at each locus
     * @param referenceOrderedDataView source of the reference-ordered data (RODs) at each locus
     * @param sum the initial reduce value
     * @return the number of loci visited together with the final reduce result
     */
    @Override
    protected TraverseResults<T> traverse(final LocusWalker<M, T> walker,
                                          final LocusView locusView,
                                          final LocusReferenceView referenceView,
                                          final ReferenceOrderedView referenceOrderedDataView,
                                          final T sum) {
        nanoScheduler.setDebug(DEBUG);
        final TraverseLociMap myMap = new TraverseLociMap(walker);
        final TraverseLociReduce myReduce = new TraverseLociReduce(walker);

        final MapDataIterator inputIterator = new MapDataIterator(locusView, referenceView, referenceOrderedDataView);
        final T result = nanoScheduler.execute(inputIterator, myMap, sum, myReduce);

        // numIterations was accumulated by the iterator as loci were pulled from it
        return new TraverseResults<T>(inputIterator.numIterations, result);
    }

    /**
     * Create iterator that provides inputs for all map calls into MapData, to be provided
     * to NanoScheduler for Map/Reduce
     */
    private class MapDataIterator implements Iterator<MapData> {
        final LocusView locusView;
        final LocusReferenceView referenceView;
        final ReferenceOrderedView referenceOrderedDataView;
        // running count of loci handed out; read back by traverse() after execution
        int numIterations = 0;

        private MapDataIterator(LocusView locusView, LocusReferenceView referenceView, ReferenceOrderedView referenceOrderedDataView) {
            this.locusView = locusView;
            this.referenceView = referenceView;
            this.referenceOrderedDataView = referenceOrderedDataView;
        }

        @Override
        public boolean hasNext() {
            return locusView.hasNext();
        }

        /**
         * Bundle the next locus with its reference context and ROD tracker.
         * Both views are advanced in locus order, so this must run on the single
         * producer thread (NOTE(review): the NanoScheduler is presumed to pull
         * inputs from one thread only — confirm against its contract).
         */
        @Override
        public MapData next() {
            final AlignmentContext locus = locusView.next();
            final GenomeLoc location = locus.getLocation();

            // create reference context. Note that if we have a pileup of "extended events", the context will
            // hold the (longest) stretch of deleted reference bases (if deletions are present in the pileup).
            final ReferenceContext refContext = referenceView.getReferenceContext(location);

            // Iterate forward to get all reference ordered data covering this location
            final RefMetaDataTracker tracker = referenceOrderedDataView.getReferenceOrderedDataAtLocus(location, refContext);

            numIterations++;
            return new MapData(locus, refContext, tracker);
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException("Cannot remove elements from MapDataIterator");
        }
    }

    /**
     * Shut down the NanoScheduler before the standard end-of-traversal reporting.
     */
    @Override
    public void printOnTraversalDone() {
        nanoScheduler.shutdown();
        super.printOnTraversalDone();
    }

    /**
     * The input data needed for each map call: the locus pileup, the reference, and the RODs
     */
    private class MapData {
        final AlignmentContext alignmentContext;
        final ReferenceContext refContext;
        final RefMetaDataTracker tracker;

        private MapData(final AlignmentContext alignmentContext, ReferenceContext refContext, RefMetaDataTracker tracker) {
            this.alignmentContext = alignmentContext;
            this.refContext = refContext;
            this.tracker = tracker;
        }

        @Override
        public String toString() {
            return "MapData " + alignmentContext.getLocation();
        }
    }

    /**
     * Contains the results of a map call, indicating whether the call was good, filtered, or done
     */
    private class MapResult {
        final M value;           // the walker.map result; null when reduceMe is false
        final boolean reduceMe;  // false when the locus was filtered out or the walker was done

        /**
         * Create a MapResult with value that should be reduced
         *
         * @param value the value to reduce
         */
        private MapResult(final M value) {
            this.value = value;
            this.reduceMe = true;
        }

        /**
         * Create a MapResult that shouldn't be reduced
         */
        private MapResult() {
            this.value = null;
            this.reduceMe = false;
        }
    }

    /**
     * A shared sentinel telling reduce that the result of map should be skipped (filtered or done)
     */
    private final MapResult SKIP_REDUCE = new MapResult();

    /**
     * MapFunction for TraverseLociNano meeting NanoScheduler interface requirements
     *
     * Applies walker.map to MapData, returning a MapResult object containing the result.
     * NOTE(review): presumably invoked concurrently by the NanoScheduler's map threads,
     * so walker.filter/map are expected to be thread-safe — confirm (cf. NanoSchedulable).
     */
    private class TraverseLociMap implements NanoSchedulerMapFunction<MapData, MapResult> {
        final LocusWalker<M,T> walker;

        private TraverseLociMap(LocusWalker<M, T> walker) {
            this.walker = walker;
        }

        @Override
        public MapResult apply(final MapData data) {
            if ( ! walker.isDone() ) {
                final boolean keepMeP = walker.filter(data.tracker, data.refContext, data.alignmentContext);
                if (keepMeP) {
                    final M x = walker.map(data.tracker, data.refContext, data.alignmentContext);
                    return new MapResult(x);
                }
            }
            // walker finished or the locus was filtered: nothing to reduce
            return SKIP_REDUCE;
        }
    }

    /**
     * NanoSchedulerReduceFunction for TraverseLociNano meeting NanoScheduler interface requirements
     *
     * Takes a MapResult object and applies the walkers reduce function to each map result, when applicable
     */
    private class TraverseLociReduce implements NanoSchedulerReduceFunction<MapResult, T> {
        final LocusWalker<M,T> walker;

        private TraverseLociReduce(LocusWalker<M, T> walker) {
            this.walker = walker;
        }

        @Override
        public T apply(MapResult one, T sum) {
            if ( one.reduceMe )
                // only run reduce on values that aren't DONE or FAILED
                return walker.reduce(one.value, sum);
            else
                return sum;
        }
    }

    /** Progress callback invoked by the NanoScheduler; forwards the last processed position to printProgress */
    private class TraverseLociProgress implements NanoSchedulerProgressFunction<MapData> {
        @Override
        public void progress(MapData lastProcessedMap) {
            if (lastProcessedMap.alignmentContext != null)
                printProgress(lastProcessedMap.alignmentContext.getLocation());
        }
    }
}

View File

@ -65,7 +65,8 @@ public class TraverseReadPairs<M,T> extends TraversalEngine<M,T, ReadPairWalker<
pairs.clear();
pairs.add(read);
printProgress(dataProvider.getShard(),null);
updateCumulativeMetrics(dataProvider.getShard());
printProgress(null);
}
done = walker.isDone();

View File

@ -99,8 +99,11 @@ public class TraverseReads<M,T> extends TraversalEngine<M,T,ReadWalker<M,T>,Read
sum = walker.reduce(x, sum);
}
GenomeLoc locus = read.getReferenceIndex() == SAMRecord.NO_ALIGNMENT_REFERENCE_INDEX ? null : engine.getGenomeLocParser().createGenomeLoc(read.getReferenceName(),read.getAlignmentStart());
printProgress(dataProvider.getShard(),locus);
final GenomeLoc locus = read.getReferenceIndex() == SAMRecord.NO_ALIGNMENT_REFERENCE_INDEX ? null : engine.getGenomeLocParser().createGenomeLoc(read.getReferenceName(),read.getAlignmentStart());
updateCumulativeMetrics(dataProvider.getShard());
printProgress(locus);
done = walker.isDone();
}
return sum;

View File

@ -34,34 +34,34 @@ import org.broadinstitute.sting.gatk.datasources.providers.ReadView;
import org.broadinstitute.sting.gatk.datasources.reads.ReadShard;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.walkers.ReadWalker;
import org.broadinstitute.sting.utils.nanoScheduler.MapFunction;
import org.broadinstitute.sting.utils.GenomeLoc;
import org.broadinstitute.sting.utils.nanoScheduler.NanoScheduler;
import org.broadinstitute.sting.utils.nanoScheduler.ReduceFunction;
import org.broadinstitute.sting.utils.nanoScheduler.NanoSchedulerMapFunction;
import org.broadinstitute.sting.utils.nanoScheduler.NanoSchedulerReduceFunction;
import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
import java.util.ArrayList;
import java.util.LinkedList;
import java.util.List;
/**
* @author aaron
* A nano-scheduling version of TraverseReads.
*
* Implements the traversal of a walker that accepts individual reads, the reference, and
* RODs per map call. Directly supports shared memory parallelism via NanoScheduler
*
* @author depristo
* @version 1.0
* @date Apr 24, 2009
* <p/>
* Class TraverseReads
* <p/>
* This class handles traversing by reads in the new shardable style
* @date 9/2/2012
*/
public class TraverseReadsNano<M,T> extends TraversalEngine<M,T,ReadWalker<M,T>,ReadShardDataProvider> {
/** our log, which we want to capture anything from this class */
protected static final Logger logger = Logger.getLogger(TraverseReadsNano.class);
private static final boolean DEBUG = false;
private static final int MIN_GROUP_SIZE = 100;
final NanoScheduler<MapData, M, T> nanoScheduler;
final NanoScheduler<MapData, MapResult, T> nanoScheduler;
public TraverseReadsNano(int nThreads) {
final int bufferSize = ReadShard.getReadBufferSize() + 1; // actually has 1 more than max
final int mapGroupSize = (int)Math.max(Math.ceil(bufferSize / 50.0 + 1), MIN_GROUP_SIZE);
nanoScheduler = new NanoScheduler<MapData, M, T>(bufferSize, mapGroupSize, nThreads);
nanoScheduler = new NanoScheduler<MapData, MapResult, T>(bufferSize, nThreads);
}
@Override
@ -89,19 +89,32 @@ public class TraverseReadsNano<M,T> extends TraversalEngine<M,T,ReadWalker<M,T>,
final TraverseReadsMap myMap = new TraverseReadsMap(walker);
final TraverseReadsReduce myReduce = new TraverseReadsReduce(walker);
T result = nanoScheduler.execute(aggregateMapData(dataProvider).iterator(), myMap, sum, myReduce);
// TODO -- how do we print progress?
//printProgress(dataProvider.getShard(), ???);
final List<MapData> aggregatedInputs = aggregateMapData(dataProvider);
final T result = nanoScheduler.execute(aggregatedInputs.iterator(), myMap, sum, myReduce);
final GATKSAMRecord lastRead = aggregatedInputs.get(aggregatedInputs.size() - 1).read;
final GenomeLoc locus = engine.getGenomeLocParser().createGenomeLoc(lastRead);
updateCumulativeMetrics(dataProvider.getShard());
printProgress(locus);
return result;
}
/**
* Aggregate all of the inputs for all map calls into MapData, to be provided
* to NanoScheduler for Map/Reduce
*
* @param dataProvider the source of our data
* @return a linked list of MapData objects holding the read, ref, and ROD info for every map/reduce
* should execute
*/
private List<MapData> aggregateMapData(final ReadShardDataProvider dataProvider) {
final ReadView reads = new ReadView(dataProvider);
final ReadReferenceView reference = new ReadReferenceView(dataProvider);
final ReadBasedReferenceOrderedView rodView = new ReadBasedReferenceOrderedView(dataProvider);
final List<MapData> mapData = new ArrayList<MapData>(); // TODO -- need size of reads
final List<MapData> mapData = new LinkedList<MapData>();
for ( final SAMRecord read : reads ) {
final ReferenceContext refContext = ! read.getReadUnmappedFlag()
? reference.getReferenceContext(read)
@ -127,19 +140,9 @@ public class TraverseReadsNano<M,T> extends TraversalEngine<M,T,ReadWalker<M,T>,
super.printOnTraversalDone();
}
private class TraverseReadsReduce implements ReduceFunction<M, T> {
final ReadWalker<M,T> walker;
private TraverseReadsReduce(ReadWalker<M, T> walker) {
this.walker = walker;
}
@Override
public T apply(M one, T sum) {
return walker.reduce(one, sum);
}
}
/**
* The input data needed for each map call. The read, the reference, and the RODs
*/
private class MapData {
final GATKSAMRecord read;
final ReferenceContext refContext;
@ -152,7 +155,43 @@ public class TraverseReadsNano<M,T> extends TraversalEngine<M,T,ReadWalker<M,T>,
}
}
private class TraverseReadsMap implements MapFunction<MapData, M> {
/**
* Contains the results of a map call, indicating whether the call was good, filtered, or done
*/
private class MapResult {
final M value;
final boolean reduceMe;
/**
* Create a MapResult with value that should be reduced
*
* @param value the value to reduce
*/
private MapResult(final M value) {
this.value = value;
this.reduceMe = true;
}
/**
* Create a MapResult that shouldn't be reduced
*/
private MapResult() {
this.value = null;
this.reduceMe = false;
}
}
/**
* A static object that tells reduce that the result of map should be skipped (filtered or done)
*/
private final MapResult SKIP_REDUCE = new MapResult();
/**
* MapFunction for TraverseReads meeting NanoScheduler interface requirements
*
* Applies walker.map to MapData, returning a MapResult object containing the result
*/
private class TraverseReadsMap implements NanoSchedulerMapFunction<MapData, MapResult> {
final ReadWalker<M,T> walker;
private TraverseReadsMap(ReadWalker<M, T> walker) {
@ -160,15 +199,36 @@ public class TraverseReadsNano<M,T> extends TraversalEngine<M,T,ReadWalker<M,T>,
}
@Override
public M apply(final MapData data) {
public MapResult apply(final MapData data) {
if ( ! walker.isDone() ) {
final boolean keepMeP = walker.filter(data.refContext, data.read);
if (keepMeP) {
return walker.map(data.refContext, data.read, data.tracker);
}
if (keepMeP)
return new MapResult(walker.map(data.refContext, data.read, data.tracker));
}
return null; // TODO -- what should we return in the case where the walker is done or the read is filtered?
return SKIP_REDUCE;
}
}
/**
* NanoSchedulerReduceFunction for TraverseReads meeting NanoScheduler interface requirements
*
* Takes a MapResult object and applies the walkers reduce function to each map result, when applicable
*/
private class TraverseReadsReduce implements NanoSchedulerReduceFunction<MapResult, T> {
final ReadWalker<M,T> walker;
private TraverseReadsReduce(ReadWalker<M, T> walker) {
this.walker = walker;
}
@Override
public T apply(MapResult one, T sum) {
if ( one.reduceMe )
// only run reduce on values that aren't DONE or FAILED
return walker.reduce(one.value, sum);
else
return sum;
}
}
}

View File

@ -45,7 +45,7 @@ import java.text.NumberFormat;
*/
@DocumentedGATKFeature( groupName = "Quality Control and Simple Analysis Tools", extraDocs = {CommandLineGATK.class} )
@Requires({DataSource.READS})
public class FlagStat extends ReadWalker<FlagStat.FlagStatus, FlagStat.FlagStatus> implements ThreadSafeMapReduce {
public class FlagStat extends ReadWalker<FlagStat.FlagStatus, FlagStat.FlagStatus> implements NanoSchedulable {
@Output
PrintStream out;

View File

@ -27,5 +27,5 @@ package org.broadinstitute.sting.gatk.walkers;
* declare that their map function is thread-safe and so multiple
* map calls can be run in parallel in the same JVM instance.
*/
public interface ThreadSafeMapReduce {
public interface NanoSchedulable {
}

View File

@ -52,7 +52,7 @@ import java.util.List;
* samtools pileup [-f in.ref.fasta] [-t in.ref_list] [-l in.site_list] [-iscg] [-T theta] [-N nHap] [-r pairDiffRate] <in.alignment>
*/
@DocumentedGATKFeature( groupName = "Quality Control and Simple Analysis Tools", extraDocs = {CommandLineGATK.class} )
public class Pileup extends LocusWalker<Integer, Integer> implements TreeReducible<Integer> {
public class Pileup extends LocusWalker<String, Integer> implements TreeReducible<Integer>, NanoSchedulable {
private static final String verboseDelimiter = "@"; // it's ugly to use "@" but it's literally the only usable character not allowed in read names
@ -70,27 +70,32 @@ public class Pileup extends LocusWalker<Integer, Integer> implements TreeReducib
@Input(fullName="metadata",shortName="metadata",doc="Add these ROD bindings to the output Pileup", required=false)
public List<RodBinding<Feature>> rods = Collections.emptyList();
public void initialize() {
}
public Integer map(RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context) {
String rods = getReferenceOrderedData( tracker );
@Override
public String map(RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context) {
final String rods = getReferenceOrderedData( tracker );
ReadBackedPileup basePileup = context.getBasePileup();
out.printf("%s %s", basePileup.getPileupString((char)ref.getBase()), rods);
if ( SHOW_VERBOSE )
out.printf(" %s", createVerboseOutput(basePileup));
out.println();
return 1;
final StringBuilder s = new StringBuilder();
s.append(String.format("%s %s", basePileup.getPileupString((char)ref.getBase()), rods));
if ( SHOW_VERBOSE )
s.append(" ").append(createVerboseOutput(basePileup));
s.append("\n");
return s.toString();
}
// Given result of map function
@Override
public Integer reduceInit() { return 0; }
public Integer reduce(Integer value, Integer sum) {
return treeReduce(sum,value);
@Override
public Integer reduce(String value, Integer sum) {
out.print(value);
return sum + 1;
}
@Override
public Integer treeReduce(Integer lhs, Integer rhs) {
return lhs + rhs;
}

View File

@ -93,7 +93,7 @@ import java.util.*;
@ReadTransformersMode(ApplicationTime = ReadTransformer.ApplicationTime.HANDLED_IN_WALKER)
@BAQMode(QualityMode = BAQ.QualityMode.ADD_TAG, ApplicationTime = ReadTransformer.ApplicationTime.HANDLED_IN_WALKER)
@Requires({DataSource.READS, DataSource.REFERENCE})
public class PrintReads extends ReadWalker<GATKSAMRecord, SAMFileWriter> implements ThreadSafeMapReduce {
public class PrintReads extends ReadWalker<GATKSAMRecord, SAMFileWriter> implements NanoSchedulable {
@Output(doc="Write output to this BAM filename instead of STDOUT", required = true)
SAMFileWriter out;
@ -228,7 +228,6 @@ public class PrintReads extends ReadWalker<GATKSAMRecord, SAMFileWriter> impleme
GATKSAMRecord workingRead = readIn;
for ( final ReadTransformer transformer : readTransformers ) {
if ( logger.isDebugEnabled() ) logger.debug("Applying transformer " + transformer + " to read " + readIn.getReadName());
workingRead = transformer.apply(workingRead);
}

View File

@ -13,7 +13,7 @@ package org.broadinstitute.sting.gatk.walkers;
* shards of the data can reduce with each other, and the composite result
* can be reduced with other composite results.
*/
public interface TreeReducible<ReduceType> extends ThreadSafeMapReduce {
public interface TreeReducible<ReduceType> {
/**
* A composite, 'reduce of reduces' function.
* @param lhs 'left-most' portion of data in the composite reduce.

View File

@ -109,7 +109,7 @@ import java.util.ArrayList;
@ReadFilters({MappingQualityZeroFilter.class, MappingQualityUnavailableFilter.class}) // only look at covered loci, not every loci of the reference file
@Requires({DataSource.READS, DataSource.REFERENCE}) // filter out all reads with zero or unavailable mapping quality
@PartitionBy(PartitionType.LOCUS) // this walker requires both -I input.bam and -R reference.fasta
public class BaseRecalibrator extends LocusWalker<Long, Long> implements TreeReducible<Long> {
public class BaseRecalibrator extends LocusWalker<Long, Long> implements TreeReducible<Long>, NanoSchedulable {
@ArgumentCollection
private final RecalibrationArgumentCollection RAC = new RecalibrationArgumentCollection(); // all the command line arguments for BQSR and it's covariates

View File

@ -125,7 +125,7 @@ import java.util.*;
// TODO -- When LocusIteratorByState gets cleaned up, we should enable multiple @By sources:
// TODO -- @By( {DataSource.READS, DataSource.REFERENCE_ORDERED_DATA} )
@Downsample(by=DownsampleType.BY_SAMPLE, toCoverage=250)
public class UnifiedGenotyper extends LocusWalker<List<VariantCallContext>, UnifiedGenotyper.UGStatistics> implements TreeReducible<UnifiedGenotyper.UGStatistics>, AnnotatorCompatible {
public class UnifiedGenotyper extends LocusWalker<List<VariantCallContext>, UnifiedGenotyper.UGStatistics> implements TreeReducible<UnifiedGenotyper.UGStatistics>, AnnotatorCompatible, NanoSchedulable {
@ArgumentCollection
private UnifiedArgumentCollection UAC = new UnifiedArgumentCollection();

View File

@ -6,6 +6,7 @@ import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.walkers.LocusWalker;
import org.broadinstitute.sting.gatk.walkers.NanoSchedulable;
import org.broadinstitute.sting.gatk.walkers.TreeReducible;
import org.broadinstitute.sting.utils.help.DocumentedGATKFeature;
@ -40,7 +41,7 @@ import java.io.PrintStream;
*
*/
@DocumentedGATKFeature( groupName = "Quality Control and Simple Analysis Tools", extraDocs = {CommandLineGATK.class} )
public class CountLoci extends LocusWalker<Integer, Long> implements TreeReducible<Long> {
public class CountLoci extends LocusWalker<Integer, Long> implements TreeReducible<Long>, NanoSchedulable {
@Output(doc="Write count to this file instead of STDOUT")
PrintStream out;

View File

@ -37,6 +37,7 @@ import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.refdata.utils.RODRecordList;
import org.broadinstitute.sting.gatk.walkers.NanoSchedulable;
import org.broadinstitute.sting.gatk.walkers.RodWalker;
import org.broadinstitute.sting.gatk.walkers.TreeReducible;
import org.broadinstitute.sting.utils.GenomeLoc;
@ -73,7 +74,7 @@ import java.util.*;
*
*/
@DocumentedGATKFeature( groupName = "Quality Control and Simple Analysis Tools", extraDocs = {CommandLineGATK.class} )
public class CountRODs extends RodWalker<CountRODs.Datum, Pair<ExpandingArrayList<Long>, Long>> implements TreeReducible<Pair<ExpandingArrayList<Long>, Long>> {
public class CountRODs extends RodWalker<CountRODs.Datum, Pair<ExpandingArrayList<Long>, Long>> implements TreeReducible<Pair<ExpandingArrayList<Long>, Long>>, NanoSchedulable {
@Output
public PrintStream out;

View File

@ -4,9 +4,9 @@ import org.broadinstitute.sting.gatk.CommandLineGATK;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.walkers.DataSource;
import org.broadinstitute.sting.gatk.walkers.NanoSchedulable;
import org.broadinstitute.sting.gatk.walkers.ReadWalker;
import org.broadinstitute.sting.gatk.walkers.Requires;
import org.broadinstitute.sting.gatk.walkers.ThreadSafeMapReduce;
import org.broadinstitute.sting.utils.help.DocumentedGATKFeature;
import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
@ -41,7 +41,7 @@ import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
*/
@DocumentedGATKFeature( groupName = "Quality Control and Simple Analysis Tools", extraDocs = {CommandLineGATK.class} )
@Requires({DataSource.READS, DataSource.REFERENCE})
public class CountReads extends ReadWalker<Integer, Integer> implements ThreadSafeMapReduce {
public class CountReads extends ReadWalker<Integer, Integer> implements NanoSchedulable {
public Integer map(ReferenceContext ref, GATKSAMRecord read, RefMetaDataTracker tracker) {
return 1;
}

View File

@ -1,18 +1,42 @@
package org.broadinstitute.sting.utils;
import com.google.java.contract.Ensures;
import com.google.java.contract.Requires;
import java.util.concurrent.TimeUnit;
/**
* A useful simple system for timing code. This code is not thread safe!
 * A useful simple system for timing code with nanosecond resolution
*
* Note that this code is not thread-safe. If you have a single timer
* being started and stopped by multiple threads you will need to protect the
* calls to avoid meaningless results of having multiple starts and stops
* called sequentially.
*
* User: depristo
* Date: Dec 10, 2010
* Time: 9:07:44 AM
*/
public class SimpleTimer {
final private String name;
private long elapsed = 0l;
private long startTime = 0l;
boolean running = false;
protected static final double NANO_TO_SECOND_DOUBLE = 1.0 / TimeUnit.SECONDS.toNanos(1);
private final String name;
/**
* The elapsedTimeNano time in nanoSeconds of this timer. The elapsedTimeNano time is the
 * sum of times between starts/restarts and stops.
*/
private long elapsedTimeNano = 0l;
/**
* The start time of the last start/restart in nanoSeconds
*/
private long startTimeNano = 0l;
/**
* Is this timer currently running (i.e., the last call was start/restart)
*/
private boolean running = false;
/**
* Creates an anonymous simple timer
@ -25,7 +49,8 @@ public class SimpleTimer {
* Creates a simple timer named name
* @param name of the timer, must not be null
*/
public SimpleTimer(String name) {
public SimpleTimer(final String name) {
if ( name == null ) throw new IllegalArgumentException("SimpleTimer name cannot be null");
this.name = name;
}
@ -37,27 +62,27 @@ public class SimpleTimer {
}
/**
* Starts the timer running, and sets the elapsed time to 0. This is equivalent to
* Starts the timer running, and sets the elapsedTimeNano time to 0. This is equivalent to
* resetting the time to have no history at all.
*
* @return this object, for programming convenience
*/
@Ensures("elapsedTimeNano == 0l")
public synchronized SimpleTimer start() {
elapsed = 0l;
restart();
return this;
elapsedTimeNano = 0l;
return restart();
}
/**
* Starts the timer running, without reseting the elapsed time. This function may be
* Starts the timer running, without resetting the elapsedTimeNano time. This function may be
* called without first calling start(). The only difference between start and restart
* is that start resets the elapsed time, while restart does not.
* is that start resets the elapsedTimeNano time, while restart does not.
*
* @return this object, for programming convenience
*/
public synchronized SimpleTimer restart() {
running = true;
startTime = currentTime();
startTimeNano = currentTimeNano();
return this;
}
@ -71,29 +96,53 @@ public class SimpleTimer {
/**
* @return A convenience function to obtain the current time in milliseconds from this timer
*/
public synchronized long currentTime() {
public long currentTime() {
return System.currentTimeMillis();
}
/**
* Stops the timer. Increases the elapsed time by difference between start and now. The
* timer must be running in order to call stop
* @return A convenience function to obtain the current time in nanoSeconds from this timer
*/
public long currentTimeNano() {
return System.nanoTime();
}
/**
* Stops the timer. Increases the elapsedTimeNano time by difference between start and now.
*
* It's ok to call stop on a timer that's not running. It has no effect on the timer.
*
* @return this object, for programming convenience
*/
@Requires("startTimeNano != 0l")
public synchronized SimpleTimer stop() {
running = false;
elapsed += currentTime() - startTime;
if ( running ) {
running = false;
elapsedTimeNano += currentTimeNano() - startTimeNano;
}
return this;
}
/**
* Returns the total elapsed time of all start/stops of this timer. If the timer is currently
* Returns the total elapsedTimeNano time of all start/stops of this timer. If the timer is currently
* running, includes the difference from currentTime() and the start as well
*
* @return this time, in seconds
*/
public synchronized double getElapsedTime() {
return (running ? (currentTime() - startTime + elapsed) : elapsed) / 1000.0;
return nanoToSecondsAsDouble(getElapsedTimeNano());
}
protected static double nanoToSecondsAsDouble(final long nano) {
return nano * NANO_TO_SECOND_DOUBLE;
}
/**
* @see #getElapsedTime() but returns the result in nanoseconds
*
* @return the elapsed time in nanoseconds
*/
public synchronized long getElapsedTimeNano() {
return running ? (currentTimeNano() - startTimeNano + elapsedTimeNano) : elapsedTimeNano;
}
}

View File

@ -3,7 +3,9 @@ package org.broadinstitute.sting.utils.nanoScheduler;
import com.google.java.contract.Ensures;
import com.google.java.contract.Requires;
import org.apache.log4j.Logger;
import org.broadinstitute.sting.utils.Utils;
import org.broadinstitute.sting.utils.AutoFormattingTime;
import org.broadinstitute.sting.utils.SimpleTimer;
import org.broadinstitute.sting.utils.collections.Pair;
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
import java.util.Iterator;
@ -45,42 +47,41 @@ import java.util.concurrent.*;
public class NanoScheduler<InputType, MapType, ReduceType> {
private final static Logger logger = Logger.getLogger(NanoScheduler.class);
private final static boolean ALLOW_SINGLE_THREAD_FASTPATH = true;
private final static boolean LOG_MAP_TIMES = false;
private final static boolean TIME_CALLS = true;
final int bufferSize;
final int mapGroupSize;
final int nThreads;
final ExecutorService executor;
final ExecutorService inputExecutor;
final ExecutorService mapExecutor;
boolean shutdown = false;
boolean debug = false;
private NanoSchedulerProgressFunction<InputType> progressFunction = null;
final SimpleTimer outsideSchedulerTimer = new SimpleTimer("outside");
final SimpleTimer inputTimer = new SimpleTimer("input");
final SimpleTimer mapTimer = new SimpleTimer("map");
final SimpleTimer reduceTimer = new SimpleTimer("reduce");
/**
 * Create a new NanoScheduler with the desired characteristics requested by the arguments
*
* @param bufferSize the number of input elements to read in each scheduling cycle.
* @param mapGroupSize How many inputs should be grouped together per map? If -1 we make a reasonable guess
* @param nThreads the number of threads to use to get work done, in addition to the thread calling execute
*/
public NanoScheduler(final int bufferSize,
final int mapGroupSize,
final int nThreads) {
if ( bufferSize < 1 ) throw new IllegalArgumentException("bufferSize must be >= 1, got " + bufferSize);
if ( nThreads < 1 ) throw new IllegalArgumentException("nThreads must be >= 1, got " + nThreads);
if ( mapGroupSize > bufferSize ) throw new IllegalArgumentException("mapGroupSize " + mapGroupSize + " must be <= bufferSize " + bufferSize);
if ( mapGroupSize == 0 || mapGroupSize < -1 ) throw new IllegalArgumentException("mapGroupSize cannot be <= 0" + mapGroupSize);
this.bufferSize = bufferSize;
this.nThreads = nThreads;
this.mapExecutor = nThreads == 1 ? null : Executors.newFixedThreadPool(nThreads-1);
this.inputExecutor = Executors.newSingleThreadExecutor();
if ( mapGroupSize == -1 ) {
this.mapGroupSize = (int)Math.ceil(this.bufferSize / (10.0*this.nThreads));
logger.info(String.format("Dynamically setting grouping size to %d based on buffer size %d and n threads %d",
this.mapGroupSize, this.bufferSize, this.nThreads));
} else {
this.mapGroupSize = mapGroupSize;
}
this.executor = nThreads == 1 ? null : Executors.newFixedThreadPool(nThreads);
// start timing the time spent outside of the nanoScheduler
outsideSchedulerTimer.start();
}
/**
@ -101,27 +102,35 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
return bufferSize;
}
/**
* The grouping size used by this NanoScheduler
* @return
*/
@Ensures("result > 0")
public int getMapGroupSize() {
return mapGroupSize;
}
/**
* Tells this nanoScheduler to shutdown immediately, releasing all its resources.
*
* After this call, execute cannot be invoked without throwing an error
*/
public void shutdown() {
if ( executor != null ) {
final List<Runnable> remaining = executor.shutdownNow();
outsideSchedulerTimer.stop();
if ( mapExecutor != null ) {
final List<Runnable> remaining = mapExecutor.shutdownNow();
if ( ! remaining.isEmpty() )
throw new IllegalStateException("Remaining tasks found in the executor, unexpected behavior!");
throw new IllegalStateException("Remaining tasks found in the mapExecutor, unexpected behavior!");
}
shutdown = true;
if (TIME_CALLS) {
printTimerInfo("Input time", inputTimer);
printTimerInfo("Map time", mapTimer);
printTimerInfo("Reduce time", reduceTimer);
printTimerInfo("Outside time", outsideSchedulerTimer);
}
}
private void printTimerInfo(final String label, final SimpleTimer timer) {
final double total = inputTimer.getElapsedTime() + mapTimer.getElapsedTime()
+ reduceTimer.getElapsedTime() + outsideSchedulerTimer.getElapsedTime();
final double myTimeInSec = timer.getElapsedTime();
final double myTimePercent = myTimeInSec / total * 100;
logger.info(String.format("%s: %s (%5.2f%%)", label, new AutoFormattingTime(myTimeInSec), myTimePercent));
}
/**
@ -145,6 +154,17 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
this.debug = debug;
}
/**
* Set the progress callback function to progressFunction
*
* The progress callback is invoked after each buffer size elements have been processed by map/reduce
*
* @param progressFunction a progress function to call, or null if you don't want any progress callback
*/
public void setProgressFunction(final NanoSchedulerProgressFunction<InputType> progressFunction) {
this.progressFunction = progressFunction;
}
/**
* Execute a map/reduce job with this nanoScheduler
*
@ -159,25 +179,31 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
* It is safe to call this function repeatedly on a single nanoScheduler, at least until the
* shutdown method is called.
*
* @param inputReader
* @param map
* @param reduce
* @return
* @param inputReader an iterator providing us with the input data to nanoSchedule map/reduce over
* @param map the map function from input type -> map type, will be applied in parallel to each input
* @param reduce the reduce function from map type + reduce type -> reduce type to be applied in order to map results
* @return the last reduce value
*/
public ReduceType execute(final Iterator<InputType> inputReader,
final MapFunction<InputType, MapType> map,
final NanoSchedulerMapFunction<InputType, MapType> map,
final ReduceType initialValue,
final ReduceFunction<MapType, ReduceType> reduce) {
final NanoSchedulerReduceFunction<MapType, ReduceType> reduce) {
if ( isShutdown() ) throw new IllegalStateException("execute called on already shutdown NanoScheduler");
if ( inputReader == null ) throw new IllegalArgumentException("inputReader cannot be null");
if ( map == null ) throw new IllegalArgumentException("map function cannot be null");
if ( reduce == null ) throw new IllegalArgumentException("reduce function cannot be null");
outsideSchedulerTimer.stop();
ReduceType result;
if ( ALLOW_SINGLE_THREAD_FASTPATH && getnThreads() == 1 ) {
return executeSingleThreaded(inputReader, map, initialValue, reduce);
result = executeSingleThreaded(inputReader, map, initialValue, reduce);
} else {
return executeMultiThreaded(inputReader, map, initialValue, reduce);
result = executeMultiThreaded(inputReader, map, initialValue, reduce);
}
outsideSchedulerTimer.restart();
return result;
}
/**
@ -185,15 +211,36 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
* @return the reduce result of this map/reduce job
*/
private ReduceType executeSingleThreaded(final Iterator<InputType> inputReader,
final MapFunction<InputType, MapType> map,
final NanoSchedulerMapFunction<InputType, MapType> map,
final ReduceType initialValue,
final ReduceFunction<MapType, ReduceType> reduce) {
final NanoSchedulerReduceFunction<MapType, ReduceType> reduce) {
ReduceType sum = initialValue;
int i = 0;
// start timer to ensure that both hasNext and next are caught by the timer
if ( TIME_CALLS ) inputTimer.restart();
while ( inputReader.hasNext() ) {
final InputType input = inputReader.next();
if ( TIME_CALLS ) inputTimer.stop();
// map
if ( TIME_CALLS ) mapTimer.restart();
final long preMapTime = LOG_MAP_TIMES ? 0 : mapTimer.currentTimeNano();
final MapType mapValue = map.apply(input);
if ( LOG_MAP_TIMES ) logger.info("MAP TIME " + (mapTimer.currentTimeNano() - preMapTime));
if ( TIME_CALLS ) mapTimer.stop();
if ( i++ % bufferSize == 0 && progressFunction != null )
progressFunction.progress(input);
// reduce
if ( TIME_CALLS ) reduceTimer.restart();
sum = reduce.apply(mapValue, sum);
if ( TIME_CALLS ) reduceTimer.stop();
if ( TIME_CALLS ) inputTimer.restart();
}
return sum;
}
@ -203,21 +250,36 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
* @return the reduce result of this map/reduce job
*/
private ReduceType executeMultiThreaded(final Iterator<InputType> inputReader,
final MapFunction<InputType, MapType> map,
final NanoSchedulerMapFunction<InputType, MapType> map,
final ReduceType initialValue,
final ReduceFunction<MapType, ReduceType> reduce) {
final NanoSchedulerReduceFunction<MapType, ReduceType> reduce) {
debugPrint("Executing nanoScheduler");
ReduceType sum = initialValue;
while ( inputReader.hasNext() ) {
boolean done = false;
final BlockingQueue<InputDatum> inputQueue = new LinkedBlockingDeque<InputDatum>(bufferSize);
inputExecutor.submit(new InputProducer(inputReader, inputQueue));
while ( ! done ) {
try {
// read in our input values
final List<InputType> inputs = readInputs(inputReader);
final Pair<List<InputType>, Boolean> readResults = readInputs(inputQueue);
final List<InputType> inputs = readResults.getFirst();
done = readResults.getSecond();
// send jobs for map
final Queue<Future<List<MapType>>> mapQueue = submitMapJobs(map, executor, inputs);
if ( ! inputs.isEmpty() ) {
// send jobs for map
final Queue<Future<MapType>> mapQueue = submitMapJobs(map, mapExecutor, inputs);
// send off the reduce job, and block until we get at least one reduce result
sum = reduceParallel(reduce, mapQueue, sum);
// send off the reduce job, and block until we get at least one reduce result
sum = reduceSerial(reduce, mapQueue, sum);
debugPrint(" Done with cycle of map/reduce");
if ( progressFunction != null ) progressFunction.progress(inputs.get(inputs.size()-1));
} else {
// we must be done
if ( ! done ) throw new IllegalStateException("Inputs empty but not done");
}
} catch (InterruptedException ex) {
throw new ReviewedStingException("got execution exception", ex);
} catch (ExecutionException ex) {
@ -229,16 +291,19 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
}
@Requires({"reduce != null", "! mapQueue.isEmpty()"})
private ReduceType reduceParallel(final ReduceFunction<MapType, ReduceType> reduce,
final Queue<Future<List<MapType>>> mapQueue,
final ReduceType initSum)
private ReduceType reduceSerial(final NanoSchedulerReduceFunction<MapType, ReduceType> reduce,
final Queue<Future<MapType>> mapQueue,
final ReduceType initSum)
throws InterruptedException, ExecutionException {
ReduceType sum = initSum;
// while mapQueue has something in it to reduce
for ( final Future<List<MapType>> future : mapQueue ) {
for ( final MapType value : future.get() ) // block until we get the values for this task
sum = reduce.apply(value, sum);
for ( final Future<MapType> future : mapQueue ) {
final MapType value = future.get(); // block until we get the values for this task
if ( TIME_CALLS ) reduceTimer.restart();
sum = reduce.apply(value, sum);
if ( TIME_CALLS ) reduceTimer.stop();
}
return sum;
@ -247,30 +312,81 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
/**
* Read up to inputBufferSize elements from inputReader
*
* @return a queue of inputs read in, containing one or more values of InputType read in
* @return a queue of input read in, containing one or more values of InputType read in
*/
@Requires("inputReader.hasNext()")
@Ensures("!result.isEmpty()")
private List<InputType> readInputs(final Iterator<InputType> inputReader) {
@Requires("inputReader != null")
@Ensures("result != null")
private Pair<List<InputType>, Boolean> readInputs(final BlockingQueue<InputDatum> inputReader) throws InterruptedException {
int n = 0;
final List<InputType> inputs = new LinkedList<InputType>();
while ( inputReader.hasNext() && n < getBufferSize() ) {
final InputType input = inputReader.next();
inputs.add(input);
n++;
boolean done = false;
while ( ! done && n < getBufferSize() ) {
final InputDatum input = inputReader.take();
done = input.isLast();
if ( ! done ) {
inputs.add(input.datum);
n++;
}
}
return new Pair<List<InputType>, Boolean>(inputs, done);
}
private class InputProducer implements Runnable {
final Iterator<InputType> inputReader;
final BlockingQueue<InputDatum> outputQueue;
public InputProducer(final Iterator<InputType> inputReader, final BlockingQueue<InputDatum> outputQueue) {
this.inputReader = inputReader;
this.outputQueue = outputQueue;
}
public void run() {
try {
while ( inputReader.hasNext() ) {
if ( TIME_CALLS ) inputTimer.restart();
final InputType input = inputReader.next();
if ( TIME_CALLS ) inputTimer.stop();
outputQueue.put(new InputDatum(input));
}
// add the EOF object so we know we are done
outputQueue.put(new InputDatum());
} catch (InterruptedException ex) {
throw new ReviewedStingException("got execution exception", ex);
}
}
}
private class InputDatum {
final boolean isLast;
final InputType datum;
private InputDatum(final InputType datum) {
isLast = false;
this.datum = datum;
}
private InputDatum() {
isLast = true;
this.datum = null;
}
public boolean isLast() {
return isLast;
}
return inputs;
}
@Requires({"map != null", "! inputs.isEmpty()"})
private Queue<Future<List<MapType>>> submitMapJobs(final MapFunction<InputType, MapType> map,
final ExecutorService executor,
final List<InputType> inputs) {
final Queue<Future<List<MapType>>> mapQueue = new LinkedList<Future<List<MapType>>>();
private Queue<Future<MapType>> submitMapJobs(final NanoSchedulerMapFunction<InputType, MapType> map,
final ExecutorService executor,
final List<InputType> inputs) {
final Queue<Future<MapType>> mapQueue = new LinkedList<Future<MapType>>();
for ( final List<InputType> subinputs : Utils.groupList(inputs, getMapGroupSize()) ) {
final CallableMap doMap = new CallableMap(map, subinputs);
final Future<List<MapType>> future = executor.submit(doMap);
for ( final InputType input : inputs ) {
final CallableMap doMap = new CallableMap(map, input);
final Future<MapType> future = executor.submit(doMap);
mapQueue.add(future);
}
@ -280,23 +396,22 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
/**
* A simple callable version of the map function for use with the executor pool
*/
private class CallableMap implements Callable<List<MapType>> {
final List<InputType> inputs;
final MapFunction<InputType, MapType> map;
private class CallableMap implements Callable<MapType> {
final InputType input;
final NanoSchedulerMapFunction<InputType, MapType> map;
@Requires({"map != null", "inputs.size() <= getMapGroupSize()"})
private CallableMap(final MapFunction<InputType, MapType> map, final List<InputType> inputs) {
this.inputs = inputs;
@Requires({"map != null"})
private CallableMap(final NanoSchedulerMapFunction<InputType, MapType> map, final InputType inputs) {
this.input = inputs;
this.map = map;
}
@Ensures("result.size() == inputs.size()")
@Override public List<MapType> call() throws Exception {
final List<MapType> outputs = new LinkedList<MapType>();
for ( final InputType input : inputs )
outputs.add(map.apply(input));
debugPrint(" Processed %d elements with map", outputs.size());
return outputs;
@Override public MapType call() throws Exception {
if ( TIME_CALLS ) mapTimer.restart();
if ( debug ) debugPrint("\t\tmap " + input);
final MapType result = map.apply(input);
if ( TIME_CALLS ) mapTimer.stop();
return result;
}
}
}

View File

@ -9,7 +9,7 @@ package org.broadinstitute.sting.utils.nanoScheduler;
* Date: 8/24/12
* Time: 9:49 AM
*/
public interface MapFunction<InputType, ResultType> {
public interface NanoSchedulerMapFunction<InputType, ResultType> {
/**
* Return function on input, returning a value of ResultType
* @param input

View File

@ -0,0 +1,12 @@
package org.broadinstitute.sting.utils.nanoScheduler;
/**
* Created with IntelliJ IDEA.
* User: depristo
* Date: 9/4/12
* Time: 2:10 PM
* To change this template use File | Settings | File Templates.
*/
public interface NanoSchedulerProgressFunction<InputType> {
public void progress(final InputType lastMapInput);
}

View File

@ -7,7 +7,7 @@ package org.broadinstitute.sting.utils.nanoScheduler;
* Date: 8/24/12
* Time: 9:49 AM
*/
public interface ReduceFunction<MapType, ReduceType> {
public interface NanoSchedulerReduceFunction<MapType, ReduceType> {
/**
* Combine one with sum into a new ReduceType
* @param one the result of a map call on an input element

View File

@ -40,13 +40,13 @@ import org.broadinstitute.sting.utils.collections.Pair;
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
import org.broadinstitute.sting.utils.exceptions.StingException;
import org.broadinstitute.sting.utils.variantcontext.VariantContextTestProvider;
import java.io.*;
import org.testng.Assert;
import org.testng.annotations.AfterSuite;
import org.testng.annotations.BeforeMethod;
import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.PrintStream;
import java.text.SimpleDateFormat;
import java.util.*;
@ -251,20 +251,43 @@ public class WalkerTest extends BaseTest {
return false;
}
protected Pair<List<File>, List<String>> executeTestParallel(final String name, WalkerTestSpec spec) {
return executeTest(name, spec, Arrays.asList(1, 4));
public enum ParallelTestType {
TREE_REDUCIBLE,
NANO_SCHEDULED,
BOTH
}
protected Pair<List<File>, List<String>> executeTest(final String name, WalkerTestSpec spec, List<Integer> parallelThreads) {
protected Pair<List<File>, List<String>> executeTestParallel(final String name, WalkerTestSpec spec, ParallelTestType testType) {
final List<Integer> ntThreads = testType == ParallelTestType.TREE_REDUCIBLE || testType == ParallelTestType.BOTH ? Arrays.asList(1, 4) : Collections.<Integer>emptyList();
final List<Integer> cntThreads = testType == ParallelTestType.NANO_SCHEDULED || testType == ParallelTestType.BOTH ? Arrays.asList(1, 4) : Collections.<Integer>emptyList();
return executeTest(name, spec, ntThreads, cntThreads);
}
protected Pair<List<File>, List<String>> executeTestParallel(final String name, WalkerTestSpec spec) {
return executeTestParallel(name, spec, ParallelTestType.TREE_REDUCIBLE);
}
protected Pair<List<File>, List<String>> executeTest(final String name, WalkerTestSpec spec, List<Integer> ntThreads, List<Integer> cpuThreads) {
String originalArgs = spec.args;
Pair<List<File>, List<String>> results = null;
for ( int nt : parallelThreads ) {
boolean ran1 = false;
for ( int nt : ntThreads ) {
String extra = nt == 1 ? "" : (" -nt " + nt);
ran1 = ran1 || nt == 1;
spec.args = originalArgs + extra;
results = executeTest(name + "-nt-" + nt, spec);
}
for ( int cnt : cpuThreads ) {
if ( cnt != 1 ) {
String extra = " -cnt " + cnt;
spec.args = originalArgs + extra;
results = executeTest(name + "-cnt-" + cnt, spec);
}
}
return results;
}

View File

@ -1,12 +1,12 @@
package org.broadinstitute.sting.utils;
import org.broadinstitute.sting.BaseTest;
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
import java.io.File;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.TimeUnit;
public class SimpleTimerUnitTest extends BaseTest {
private final static String NAME = "unit.test.timer";
@ -17,33 +17,88 @@ public class SimpleTimerUnitTest extends BaseTest {
Assert.assertEquals(t.getName(), NAME, "Name is not the provided one");
Assert.assertFalse(t.isRunning(), "Initial state of the timer is running");
Assert.assertEquals(t.getElapsedTime(), 0.0, "New timer elapsed time should be 0");
Assert.assertEquals(t.getElapsedTimeNano(), 0l, "New timer elapsed time nano should be 0");
t.start();
Assert.assertTrue(t.isRunning(), "Started timer isn't running");
Assert.assertTrue(t.getElapsedTime() >= 0.0, "Elapsed time should be >= 0");
Assert.assertTrue(t.getElapsedTimeNano() >= 0.0, "Elapsed time nano should be >= 0");
long n1 = t.getElapsedTimeNano();
double t1 = t.getElapsedTime();
idleLoop(); // idle loop to wait a tiny bit of time
long n2 = t.getElapsedTimeNano();
double t2 = t.getElapsedTime();
Assert.assertTrue(t2 >= t1, "T2 >= T1 for a running time");
Assert.assertTrue(n2 >= n1, "T2 >= T1 nano for a running time");
t.stop();
Assert.assertFalse(t.isRunning(), "Stopped timer still running");
long n3 = t.getElapsedTimeNano();
double t3 = t.getElapsedTime();
idleLoop(); // idle loop to wait a tiny bit of time
double t4 = t.getElapsedTime();
long n4 = t.getElapsedTimeNano();
Assert.assertTrue(t4 == t3, "Elapsed times for two calls of stop timer not the same");
Assert.assertTrue(n4 == n3, "Elapsed times for two calls of stop timer not the same");
t.restart();
idleLoop(); // idle loop to wait a tiny bit of time
double t5 = t.getElapsedTime();
long n5 = t.getElapsedTimeNano();
Assert.assertTrue(t.isRunning(), "Restarted timer should be running");
idleLoop(); // idle loop to wait a tiny bit of time
double t6 = t.getElapsedTime();
long n6 = t.getElapsedTimeNano();
Assert.assertTrue(t5 >= t4, "Restarted timer elapsed time should be after elapsed time preceding the restart");
Assert.assertTrue(t6 >= t5, "Second elapsed time not after the first in restarted timer");
Assert.assertTrue(n5 >= n4, "Restarted timer elapsed time nano should be after elapsed time preceding the restart");
Assert.assertTrue(n6 >= n5, "Second elapsed time nano not after the first in restarted timer");
final List<Double> secondTimes = Arrays.asList(t1, t2, t3, t4, t5, t6);
final List<Long> nanoTimes = Arrays.asList(n1, n2, n3, n4, n5, n6);
for ( int i = 0; i < nanoTimes.size(); i++ )
Assert.assertEquals(
SimpleTimer.nanoToSecondsAsDouble(nanoTimes.get(i)),
secondTimes.get(i), 1e-1, "Nanosecond and second timer disagree");
}
private final static void idleLoop() {
@Test
public void testNanoResolution() {
SimpleTimer t = new SimpleTimer(NAME);
// test the nanosecond resolution
long n7 = t.currentTimeNano();
int sum = 0;
for ( int i = 0; i < 100; i++) sum += i;
long n8 = t.currentTimeNano();
final long delta = n8 - n7;
final long oneMilliInNano = TimeUnit.MILLISECONDS.toNanos(1);
logger.warn("nanoTime before nano operation " + n7);
logger.warn("nanoTime after nano operation of summing 100 ints " + n8 + ", sum = " + sum + " time delta " + delta + " vs. 1 millsecond in nano " + oneMilliInNano);
Assert.assertTrue(n8 > n7, "SimpleTimer doesn't appear to have nanoSecond resolution: n8 " + n8 + " <= n7 " + n7);
Assert.assertTrue(delta < oneMilliInNano,
"SimpleTimer doesn't appear to have nanoSecond resolution: time delta is " + delta + " vs 1 millisecond in nano " + oneMilliInNano);
}
@Test
public void testMeaningfulTimes() {
SimpleTimer t = new SimpleTimer(NAME);
t.start();
for ( int i = 0; i < 100; i++ ) ;
long nano = t.getElapsedTimeNano();
double secs = t.getElapsedTime();
Assert.assertTrue(secs > 0, "Seconds timer doesn't appear to count properly: elapsed time is " + secs);
Assert.assertTrue(secs < 0.01, "Fast operation said to take longer than 10 milliseconds: elapsed time in seconds " + secs);
Assert.assertTrue(nano > 0, "Nanosecond timer doesn't appear to count properly: elapsed time is " + nano);
final long maxTimeInMicro = 100;
final long maxTimeInNano = TimeUnit.MICROSECONDS.toNanos(100);
Assert.assertTrue(nano < maxTimeInNano, "Fast operation said to take longer than " + maxTimeInMicro + " microseconds: elapsed time in nano " + nano + " micro " + TimeUnit.NANOSECONDS.toMicros(nano));
}
private static void idleLoop() {
for ( int i = 0; i < 100000; i++ ) ; // idle loop to wait a tiny bit of time
}
}

View File

@ -5,7 +5,10 @@ import org.testng.Assert;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import java.util.*;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
/**
* UnitTests for the NanoScheduler
@ -18,11 +21,11 @@ import java.util.*;
public class NanoSchedulerUnitTest extends BaseTest {
public static final int NANO_SCHEDULE_MAX_RUNTIME = 60000;
private static class Map2x implements MapFunction<Integer, Integer> {
private static class Map2x implements NanoSchedulerMapFunction<Integer, Integer> {
@Override public Integer apply(Integer input) { return input * 2; }
}
private static class ReduceSum implements ReduceFunction<Integer, Integer> {
private static class ReduceSum implements NanoSchedulerReduceFunction<Integer, Integer> {
int prevOne = Integer.MIN_VALUE;
@Override public Integer apply(Integer one, Integer sum) {
@ -31,6 +34,16 @@ public class NanoSchedulerUnitTest extends BaseTest {
}
}
private static class ProgressCallback implements NanoSchedulerProgressFunction<Integer> {
int callBacks = 0;
@Override
public void progress(Integer lastMapInput) {
callBacks++;
}
}
private static int sum2x(final int start, final int end) {
int sum = 0;
for ( int i = start; i < end; i++ )
@ -39,18 +52,17 @@ public class NanoSchedulerUnitTest extends BaseTest {
}
private static class NanoSchedulerBasicTest extends TestDataProvider {
final int bufferSize, mapGroupSize, nThreads, start, end, expectedResult;
final int bufferSize, nThreads, start, end, expectedResult;
public NanoSchedulerBasicTest(final int bufferSize, final int mapGroupSize, final int nThreads, final int start, final int end) {
public NanoSchedulerBasicTest(final int bufferSize, final int nThreads, final int start, final int end) {
super(NanoSchedulerBasicTest.class);
this.bufferSize = bufferSize;
this.mapGroupSize = mapGroupSize;
this.nThreads = nThreads;
this.start = start;
this.end = end;
this.expectedResult = sum2x(start, end);
setName(String.format("%s nt=%d buf=%d mapGroupSize=%d start=%d end=%d sum=%d",
getClass().getSimpleName(), nThreads, bufferSize, mapGroupSize, start, end, expectedResult));
setName(String.format("%s nt=%d buf=%d start=%d end=%d sum=%d",
getClass().getSimpleName(), nThreads, bufferSize, start, end, expectedResult));
}
public Iterator<Integer> makeReader() {
@ -60,6 +72,11 @@ public class NanoSchedulerUnitTest extends BaseTest {
return ints.iterator();
}
public int nExpectedCallbacks() {
int nElements = Math.max(end - start, 0);
return nElements / bufferSize;
}
public Map2x makeMap() { return new Map2x(); }
public Integer initReduce() { return 0; }
public ReduceSum makeReduce() { return new ReduceSum(); }
@ -69,14 +86,10 @@ public class NanoSchedulerUnitTest extends BaseTest {
@DataProvider(name = "NanoSchedulerBasicTest")
public Object[][] createNanoSchedulerBasicTest() {
for ( final int bufferSize : Arrays.asList(1, 10, 1000, 1000000) ) {
for ( final int mapGroupSize : Arrays.asList(-1, 1, 10, 100, 1000) ) {
if ( mapGroupSize <= bufferSize ) {
for ( final int nt : Arrays.asList(1, 2, 4) ) {
for ( final int start : Arrays.asList(0) ) {
for ( final int end : Arrays.asList(1, 2, 11, 10000, 100000) ) {
exampleTest = new NanoSchedulerBasicTest(bufferSize, mapGroupSize, nt, start, end);
}
}
for ( final int nt : Arrays.asList(1, 2, 4) ) {
for ( final int start : Arrays.asList(0) ) {
for ( final int end : Arrays.asList(0, 1, 2, 11, 10000, 100000) ) {
exampleTest = new NanoSchedulerBasicTest(bufferSize, nt, start, end);
}
}
}
@ -101,25 +114,29 @@ public class NanoSchedulerUnitTest extends BaseTest {
private void testNanoScheduler(final NanoSchedulerBasicTest test) throws InterruptedException {
final NanoScheduler<Integer, Integer, Integer> nanoScheduler =
new NanoScheduler<Integer, Integer, Integer>(test.bufferSize, test.mapGroupSize, test.nThreads);
new NanoScheduler<Integer, Integer, Integer>(test.bufferSize, test.nThreads);
final ProgressCallback callback = new ProgressCallback();
nanoScheduler.setProgressFunction(callback);
Assert.assertEquals(nanoScheduler.getBufferSize(), test.bufferSize, "bufferSize argument");
Assert.assertTrue(nanoScheduler.getMapGroupSize() >= test.mapGroupSize, "mapGroupSize argument");
Assert.assertEquals(nanoScheduler.getnThreads(), test.nThreads, "nThreads argument");
final Integer sum = nanoScheduler.execute(test.makeReader(), test.makeMap(), test.initReduce(), test.makeReduce());
Assert.assertNotNull(sum);
Assert.assertEquals((int)sum, test.expectedResult, "NanoScheduler sum not the same as calculated directly");
Assert.assertTrue(callback.callBacks >= test.nExpectedCallbacks(), "Not enough callbacks detected. Expected at least " + test.nExpectedCallbacks() + " but saw only " + callback.callBacks);
nanoScheduler.shutdown();
}
@Test(enabled = true, dataProvider = "NanoSchedulerBasicTest", dependsOnMethods = "testMultiThreadedNanoScheduler", timeOut = NANO_SCHEDULE_MAX_RUNTIME)
public void testNanoSchedulerInLoop(final NanoSchedulerBasicTest test) throws InterruptedException {
if ( test.bufferSize > 1 && (test.mapGroupSize > 1 || test.mapGroupSize == -1)) {
if ( test.bufferSize > 1) {
logger.warn("Running " + test);
final NanoScheduler<Integer, Integer, Integer> nanoScheduler =
new NanoScheduler<Integer, Integer, Integer>(test.bufferSize, test.mapGroupSize, test.nThreads);
new NanoScheduler<Integer, Integer, Integer>(test.bufferSize, test.nThreads);
// test reusing the scheduler
for ( int i = 0; i < 10; i++ ) {
@ -134,7 +151,7 @@ public class NanoSchedulerUnitTest extends BaseTest {
@Test(timeOut = NANO_SCHEDULE_MAX_RUNTIME)
public void testShutdown() throws InterruptedException {
final NanoScheduler<Integer, Integer, Integer> nanoScheduler = new NanoScheduler<Integer, Integer, Integer>(1, 1, 2);
final NanoScheduler<Integer, Integer, Integer> nanoScheduler = new NanoScheduler<Integer, Integer, Integer>(1, 2);
Assert.assertFalse(nanoScheduler.isShutdown(), "scheduler should be alive");
nanoScheduler.shutdown();
Assert.assertTrue(nanoScheduler.isShutdown(), "scheduler should be dead");
@ -142,15 +159,16 @@ public class NanoSchedulerUnitTest extends BaseTest {
@Test(expectedExceptions = IllegalStateException.class, timeOut = NANO_SCHEDULE_MAX_RUNTIME)
public void testShutdownExecuteFailure() throws InterruptedException {
final NanoScheduler<Integer, Integer, Integer> nanoScheduler = new NanoScheduler<Integer, Integer, Integer>(1, 1, 2);
final NanoScheduler<Integer, Integer, Integer> nanoScheduler = new NanoScheduler<Integer, Integer, Integer>(1, 2);
nanoScheduler.shutdown();
nanoScheduler.execute(exampleTest.makeReader(), exampleTest.makeMap(), exampleTest.initReduce(), exampleTest.makeReduce());
}
public static void main(String [ ] args) {
final NanoSchedulerBasicTest test = new NanoSchedulerBasicTest(1000, 100, Integer.valueOf(args[0]), 0, Integer.valueOf(args[1]));
final NanoSchedulerBasicTest test = new NanoSchedulerBasicTest(1000, Integer.valueOf(args[0]), 0, Integer.valueOf(args[1]));
final NanoScheduler<Integer, Integer, Integer> nanoScheduler =
new NanoScheduler<Integer, Integer, Integer>(test.bufferSize, test.mapGroupSize, test.nThreads);
new NanoScheduler<Integer, Integer, Integer>(test.bufferSize, test.nThreads);
nanoScheduler.setDebug(true);
final Integer sum = nanoScheduler.execute(test.makeReader(), test.makeMap(), test.initReduce(), test.makeReduce());
System.out.printf("Sum = %d, expected =%d%n", sum, test.expectedResult);