diff --git a/build.xml b/build.xml index f681ddafa..0d1deba29 100644 --- a/build.xml +++ b/build.xml @@ -577,6 +577,7 @@ docletpathref="doclet.classpath" classpathref="external.dependencies" classpath="${java.classes}" + maxmemory="2g" additionalparam="-build-timestamp "${build.timestamp}" -absolute-version ${build.version} -out ${basedir}/${resource.path} -quiet"> @@ -780,6 +781,7 @@ docletpathref="doclet.classpath" classpathref="external.dependencies" classpath="${java.classes}" + maxmemory="2g" additionalparam="${gatkdocs.include.hidden.arg} -private -build-timestamp "${build.timestamp}" -absolute-version ${build.version} -quiet"> diff --git a/public/java/src/org/broadinstitute/sting/gatk/GenomeAnalysisEngine.java b/public/java/src/org/broadinstitute/sting/gatk/GenomeAnalysisEngine.java index b9b5e452d..fa28b02cd 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/GenomeAnalysisEngine.java +++ b/public/java/src/org/broadinstitute/sting/gatk/GenomeAnalysisEngine.java @@ -143,6 +143,8 @@ public class GenomeAnalysisEngine { */ private ThreadAllocation threadAllocation; + private ReadMetrics cumulativeMetrics = null; + /** * A currently hacky unique name for this GATK instance */ @@ -398,28 +400,22 @@ public class GenomeAnalysisEngine { * Parse out the thread allocation from the given command-line argument. */ private void determineThreadAllocation() { - Tags tags = parsingEngine.getTags(argCollection.numberOfThreads); + if ( argCollection.numberOfDataThreads < 1 ) throw new UserException.BadArgumentValue("num_threads", "cannot be less than 1, but saw " + argCollection.numberOfDataThreads); + if ( argCollection.numberOfCPUThreadsPerDataThread < 1 ) throw new UserException.BadArgumentValue("num_cpu_threads", "cannot be less than 1, but saw " + argCollection.numberOfCPUThreadsPerDataThread); + if ( argCollection.numberOfIOThreads < 0 ) throw new UserException.BadArgumentValue("num_io_threads", "cannot be less than 0, but saw " + argCollection.numberOfIOThreads); - // TODO: Kill this complicated logic once Queue supports arbitrary tagged parameters. - Integer numCPUThreads = null; - if(tags.containsKey("cpu") && argCollection.numberOfCPUThreads != null) - throw new UserException("Number of CPU threads specified both directly on the command-line and as a tag to the nt argument. Please specify only one or the other."); - else if(tags.containsKey("cpu")) - numCPUThreads = Integer.parseInt(tags.getValue("cpu")); - else if(argCollection.numberOfCPUThreads != null) - numCPUThreads = argCollection.numberOfCPUThreads; - - Integer numIOThreads = null; - if(tags.containsKey("io") && argCollection.numberOfIOThreads != null) - throw new UserException("Number of IO threads specified both directly on the command-line and as a tag to the nt argument. Please specify only one or the other."); - else if(tags.containsKey("io")) - numIOThreads = Integer.parseInt(tags.getValue("io")); - else if(argCollection.numberOfIOThreads != null) - numIOThreads = argCollection.numberOfIOThreads; - - this.threadAllocation = new ThreadAllocation(argCollection.numberOfThreads, numCPUThreads, numIOThreads, ! argCollection.disableEfficiencyMonitor); + this.threadAllocation = new ThreadAllocation(argCollection.numberOfDataThreads, + argCollection.numberOfCPUThreadsPerDataThread, + argCollection.numberOfIOThreads, + ! argCollection.disableEfficiencyMonitor); } + public int getTotalNumberOfThreads() { + return this.threadAllocation == null ? 1 : threadAllocation.getTotalNumThreads(); + } + + + /** * Allow subclasses and others within this package direct access to the walker manager. * @return The walker manager used by this package. @@ -1035,7 +1031,10 @@ public class GenomeAnalysisEngine { * owned by the caller; the caller can do with the object what they wish. */ public ReadMetrics getCumulativeMetrics() { - return readsDataSource == null ? null : readsDataSource.getCumulativeReadMetrics(); + // todo -- probably shouldn't be lazy + if ( cumulativeMetrics == null ) + cumulativeMetrics = readsDataSource == null ? new ReadMetrics() : readsDataSource.getCumulativeReadMetrics(); + return cumulativeMetrics; } /** diff --git a/public/java/src/org/broadinstitute/sting/gatk/ReadMetrics.java b/public/java/src/org/broadinstitute/sting/gatk/ReadMetrics.java index ceaa30f01..bfea0b1e1 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/ReadMetrics.java +++ b/public/java/src/org/broadinstitute/sting/gatk/ReadMetrics.java @@ -27,7 +27,6 @@ package org.broadinstitute.sting.gatk; import net.sf.picard.filter.SamRecordFilter; import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; -import java.util.Collections; import java.util.HashMap; import java.util.Map; import java.util.TreeMap; @@ -119,11 +118,18 @@ public class ReadMetrics implements Cloneable { return nRecords; } + /** + * Increments the number of 'iterations' (one call of filter/map/reduce sequence) completed. + */ + public void incrementNumIterations(final long by) { + nRecords += by; + } + /** * Increments the number of 'iterations' (one call of filter/map/reduce sequence) completed. */ public void incrementNumIterations() { - nRecords++; + incrementNumIterations(1); } public long getNumReadsSeen() { diff --git a/public/java/src/org/broadinstitute/sting/gatk/arguments/GATKArgumentCollection.java b/public/java/src/org/broadinstitute/sting/gatk/arguments/GATKArgumentCollection.java index 72cb5e02f..b9e44d87b 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/arguments/GATKArgumentCollection.java +++ b/public/java/src/org/broadinstitute/sting/gatk/arguments/GATKArgumentCollection.java @@ -41,7 +41,9 @@ import org.broadinstitute.sting.utils.interval.IntervalMergingRule; import org.broadinstitute.sting.utils.interval.IntervalSetRule; import java.io.File; -import java.util.*; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; /** * @author aaron @@ -197,6 +199,12 @@ public class GATKArgumentCollection { // performance log arguments // // -------------------------------------------------------------------------------------------------------------- + + /** + * The file name for the GATK performance log output, or null if you don't want to generate the + * detailed performance logging table. This table is suitable for importing into R or any + * other analysis software that can read tsv files + */ @Argument(fullName = "performanceLog", shortName="PF", doc="If provided, a GATK runtime performance log will be written to this file", required = false) public File performanceLog = null; @@ -279,9 +287,32 @@ public class GATKArgumentCollection { @Argument(fullName = "unsafe", shortName = "U", doc = "If set, enables unsafe operations: nothing will be checked at runtime. For expert users only who know what they are doing. We do not support usage of this argument.", required = false) public ValidationExclusion.TYPE unsafe; - /** How many threads should be allocated to this analysis. */ - @Argument(fullName = "num_threads", shortName = "nt", doc = "How many threads should be allocated to running this analysis.", required = false) - public Integer numberOfThreads = 1; + // -------------------------------------------------------------------------------------------------------------- + // + // Multi-threading arguments + // + // -------------------------------------------------------------------------------------------------------------- + + /** + * How many data threads should be allocated to this analysis? Data threads contains N cpu threads per + * data thread, and act as completely data parallel processing, increasing the memory usage of GATK + * by M data threads. Data threads generally scale extremely effectively, up to 24 cores + */ + @Argument(fullName = "num_threads", shortName = "nt", doc = "How many data threads should be allocated to running this analysis.", required = false) + public Integer numberOfDataThreads = 1; + + /** + * How many CPU threads should be allocated per data thread? Each CPU thread operates the map + * cycle independently, but may run into earlier scaling problems with IO than data threads. Has + * the benefit of not requiring X times as much memory per thread as data threads do, but rather + * only a constant overhead. + */ + @Argument(fullName="num_cpu_threads_per_data_thread", shortName = "cnt", doc="How many CPU threads should be allocated per data thread to running this analysis?", required = false) + public int numberOfCPUThreadsPerDataThread = 1; + + @Argument(fullName="num_io_threads", shortName = "nit", doc="How many of the given threads should be allocated to IO", required = false) + @Hidden + public int numberOfIOThreads = 0; /** * By default the GATK monitors its own efficiency, but this can have a itsy-bitsy tiny @@ -291,17 +322,6 @@ public class GATKArgumentCollection { @Argument(fullName = "disableThreadEfficiencyMonitor", shortName = "dtem", doc = "Disable GATK efficiency monitoring", required = false) public Boolean disableEfficiencyMonitor = false; - /** - * The following two arguments (num_cpu_threads, num_io_threads are TEMPORARY since Queue cannot currently support arbitrary tagged data types. - * TODO: Kill this when I can do a tagged integer in Queue. - */ - @Argument(fullName="num_cpu_threads", shortName = "nct", doc="How many of the given threads should be allocated to the CPU", required = false) - @Hidden - public Integer numberOfCPUThreads = null; - @Argument(fullName="num_io_threads", shortName = "nit", doc="How many of the given threads should be allocated to IO", required = false) - @Hidden - public Integer numberOfIOThreads = null; - @Argument(fullName = "num_bam_file_handles", shortName = "bfh", doc="The total number of BAM file handles to keep open simultaneously", required=false) public Integer numberOfBAMFileHandles = null; diff --git a/public/java/src/org/broadinstitute/sting/gatk/executive/HierarchicalMicroScheduler.java b/public/java/src/org/broadinstitute/sting/gatk/executive/HierarchicalMicroScheduler.java index 9198d210d..f1d2f7b5b 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/executive/HierarchicalMicroScheduler.java +++ b/public/java/src/org/broadinstitute/sting/gatk/executive/HierarchicalMicroScheduler.java @@ -8,6 +8,7 @@ import org.broadinstitute.sting.gatk.datasources.reads.Shard; import org.broadinstitute.sting.gatk.datasources.rmd.ReferenceOrderedDataSource; import org.broadinstitute.sting.gatk.io.OutputTracker; import org.broadinstitute.sting.gatk.io.ThreadLocalOutputTracker; +import org.broadinstitute.sting.gatk.resourcemanagement.ThreadAllocation; import org.broadinstitute.sting.gatk.walkers.TreeReducible; import org.broadinstitute.sting.gatk.walkers.Walker; import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; @@ -76,21 +77,21 @@ public class HierarchicalMicroScheduler extends MicroScheduler implements Hierar /** * Create a new hierarchical microscheduler to process the given reads and reference. * - * @param walker the walker used to process the dataset. - * @param reads Reads file(s) to process. - * @param reference Reference for driving the traversal. - * @param nThreadsToUse maximum number of threads to use to do the work + * @param walker the walker used to process the dataset. + * @param reads Reads file(s) to process. + * @param reference Reference for driving the traversal. + * @param threadAllocation How should we apply multi-threaded execution? */ protected HierarchicalMicroScheduler(final GenomeAnalysisEngine engine, final Walker walker, final SAMDataSource reads, final IndexedFastaSequenceFile reference, final Collection rods, - final int nThreadsToUse, - final boolean monitorThreadPerformance ) { - super(engine, walker, reads, reference, rods, nThreadsToUse); + final ThreadAllocation threadAllocation) { + super(engine, walker, reads, reference, rods, threadAllocation); - if ( monitorThreadPerformance ) { + final int nThreadsToUse = threadAllocation.getNumDataThreads(); + if ( threadAllocation.monitorThreadEfficiency() ) { final EfficiencyMonitoringThreadFactory monitoringThreadFactory = new EfficiencyMonitoringThreadFactory(nThreadsToUse); setThreadEfficiencyMonitor(monitoringThreadFactory); this.threadPool = Executors.newFixedThreadPool(nThreadsToUse, monitoringThreadFactory); diff --git a/public/java/src/org/broadinstitute/sting/gatk/executive/LinearMicroScheduler.java b/public/java/src/org/broadinstitute/sting/gatk/executive/LinearMicroScheduler.java index 5bcb16c94..ceb4a6f9b 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/executive/LinearMicroScheduler.java +++ b/public/java/src/org/broadinstitute/sting/gatk/executive/LinearMicroScheduler.java @@ -10,6 +10,7 @@ import org.broadinstitute.sting.gatk.datasources.reads.Shard; import org.broadinstitute.sting.gatk.datasources.rmd.ReferenceOrderedDataSource; import org.broadinstitute.sting.gatk.io.DirectOutputTracker; import org.broadinstitute.sting.gatk.io.OutputTracker; +import org.broadinstitute.sting.gatk.resourcemanagement.ThreadAllocation; import org.broadinstitute.sting.gatk.traversals.TraverseActiveRegions; import org.broadinstitute.sting.gatk.walkers.Walker; import org.broadinstitute.sting.utils.SampleUtils; @@ -39,13 +40,11 @@ public class LinearMicroScheduler extends MicroScheduler { final SAMDataSource reads, final IndexedFastaSequenceFile reference, final Collection rods, - final int numThreads, // may be > 1 if are nanoScheduling - final boolean monitorThreadPerformance ) { - super(engine, walker, reads, reference, rods, numThreads); + final ThreadAllocation threadAllocation) { + super(engine, walker, reads, reference, rods, threadAllocation); - if ( monitorThreadPerformance ) + if ( threadAllocation.monitorThreadEfficiency() ) setThreadEfficiencyMonitor(new ThreadEfficiencyMonitor()); - } /** @@ -60,11 +59,12 @@ public class LinearMicroScheduler extends MicroScheduler { boolean done = walker.isDone(); int counter = 0; + + traversalEngine.startTimersIfNecessary(); for (Shard shard : shardStrategy ) { if ( done || shard == null ) // we ran out of shards that aren't owned break; - traversalEngine.startTimersIfNecessary(); if(shard.getShardType() == Shard.ShardType.LOCUS) { WindowMaker windowMaker = new WindowMaker(shard, engine.getGenomeLocParser(), getReadIterator(shard), shard.getGenomeLocs(), SampleUtils.getSAMFileSamples(engine)); diff --git a/public/java/src/org/broadinstitute/sting/gatk/executive/MicroScheduler.java b/public/java/src/org/broadinstitute/sting/gatk/executive/MicroScheduler.java index 417a0982f..1da712e8a 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/executive/MicroScheduler.java +++ b/public/java/src/org/broadinstitute/sting/gatk/executive/MicroScheduler.java @@ -59,6 +59,8 @@ import java.util.Collection; /** Shards and schedules data in manageable chunks. */ public abstract class MicroScheduler implements MicroSchedulerMBean { + // TODO -- remove me and retire non nano scheduled versions of traversals + private final static boolean USE_NANOSCHEDULER_FOR_EVERYTHING = true; protected static final Logger logger = Logger.getLogger(MicroScheduler.class); /** @@ -100,27 +102,30 @@ public abstract class MicroScheduler implements MicroSchedulerMBean { * @return The best-fit microscheduler. */ public static MicroScheduler create(GenomeAnalysisEngine engine, Walker walker, SAMDataSource reads, IndexedFastaSequenceFile reference, Collection rods, ThreadAllocation threadAllocation) { - if (threadAllocation.getNumCPUThreads() > 1) { + if ( threadAllocation.isRunningInParallelMode() ) + logger.info(String.format("Running the GATK in parallel mode with %d CPU thread(s) for each of %d data thread(s)", + threadAllocation.getNumCPUThreadsPerDataThread(), threadAllocation.getNumDataThreads())); + + if ( threadAllocation.getNumDataThreads() > 1 ) { if (walker.isReduceByInterval()) throw new UserException.BadArgumentValue("nt", String.format("The analysis %s aggregates results by interval. Due to a current limitation of the GATK, analyses of this type do not currently support parallel execution. Please run your analysis without the -nt option.", engine.getWalkerName(walker.getClass()))); - logger.info(String.format("Running the GATK in parallel mode with %d concurrent threads",threadAllocation.getNumCPUThreads())); - - if ( walker instanceof ReadWalker ) { - if ( ! (walker instanceof ThreadSafeMapReduce) ) badNT(engine, walker); - return new LinearMicroScheduler(engine, walker, reads, reference, rods, threadAllocation.getNumCPUThreads(), threadAllocation.monitorThreadEfficiency()); + if ( ! (walker instanceof TreeReducible) ) { + throw badNT("nt", engine, walker); } else { - // TODO -- update test for when nano scheduling only is an option - if ( ! (walker instanceof TreeReducible) ) badNT(engine, walker); - return new HierarchicalMicroScheduler(engine, walker, reads, reference, rods, threadAllocation.getNumCPUThreads(), threadAllocation.monitorThreadEfficiency()); + return new HierarchicalMicroScheduler(engine, walker, reads, reference, rods, threadAllocation); } } else { - return new LinearMicroScheduler(engine, walker, reads, reference, rods, threadAllocation.getNumCPUThreads(), threadAllocation.monitorThreadEfficiency()); + if ( threadAllocation.getNumCPUThreadsPerDataThread() > 1 && ! (walker instanceof NanoSchedulable) ) + throw badNT("cnt", engine, walker); + return new LinearMicroScheduler(engine, walker, reads, reference, rods, threadAllocation); } } - private static void badNT(final GenomeAnalysisEngine engine, final Walker walker) { - throw new UserException.BadArgumentValue("nt", String.format("The analysis %s currently does not support parallel execution. Please run your analysis without the -nt option.", engine.getWalkerName(walker.getClass()))); + private static UserException badNT(final String parallelArg, final GenomeAnalysisEngine engine, final Walker walker) { + throw new UserException.BadArgumentValue("nt", + String.format("The analysis %s currently does not support parallel execution with %s. " + + "Please run your analysis without the %s option.", engine.getWalkerName(walker.getClass()), parallelArg, parallelArg)); } /** @@ -130,23 +135,27 @@ public abstract class MicroScheduler implements MicroSchedulerMBean { * @param reads The reads. * @param reference The reference. * @param rods the rods to include in the traversal - * @param numThreads the number of threads we are using in the underlying traversal + * @param threadAllocation the allocation of threads to use in the underlying traversal */ protected MicroScheduler(final GenomeAnalysisEngine engine, final Walker walker, final SAMDataSource reads, final IndexedFastaSequenceFile reference, final Collection rods, - final int numThreads) { + final ThreadAllocation threadAllocation) { this.engine = engine; this.reads = reads; this.reference = reference; this.rods = rods; if (walker instanceof ReadWalker) { - traversalEngine = numThreads > 1 ? new TraverseReadsNano(numThreads) : new TraverseReads(); + traversalEngine = USE_NANOSCHEDULER_FOR_EVERYTHING || threadAllocation.getNumCPUThreadsPerDataThread() > 1 + ? new TraverseReadsNano(threadAllocation.getNumCPUThreadsPerDataThread()) + : new TraverseReads(); } else if (walker instanceof LocusWalker) { - traversalEngine = new TraverseLoci(); + traversalEngine = USE_NANOSCHEDULER_FOR_EVERYTHING || threadAllocation.getNumCPUThreadsPerDataThread() > 1 + ? new TraverseLociNano(threadAllocation.getNumCPUThreadsPerDataThread()) + : new TraverseLociLinear(); } else if (walker instanceof DuplicateWalker) { traversalEngine = new TraverseDuplicates(); } else if (walker instanceof ReadPairWalker) { diff --git a/public/java/src/org/broadinstitute/sting/gatk/io/stubs/VariantContextWriterStub.java b/public/java/src/org/broadinstitute/sting/gatk/io/stubs/VariantContextWriterStub.java index 260a7efda..ee1dc63e6 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/io/stubs/VariantContextWriterStub.java +++ b/public/java/src/org/broadinstitute/sting/gatk/io/stubs/VariantContextWriterStub.java @@ -32,9 +32,9 @@ import org.broadinstitute.sting.utils.classloader.JVMUtils; import org.broadinstitute.sting.utils.codecs.vcf.VCFHeader; import org.broadinstitute.sting.utils.codecs.vcf.VCFHeaderLine; import org.broadinstitute.sting.utils.codecs.vcf.VCFUtils; +import org.broadinstitute.sting.utils.variantcontext.VariantContext; import org.broadinstitute.sting.utils.variantcontext.writer.Options; import org.broadinstitute.sting.utils.variantcontext.writer.VariantContextWriter; -import org.broadinstitute.sting.utils.variantcontext.VariantContext; import org.broadinstitute.sting.utils.variantcontext.writer.VariantContextWriterFactory; import java.io.File; @@ -269,7 +269,7 @@ public class VariantContextWriterStub implements Stub, Var * @return */ public boolean alsoWriteBCFForTest() { - return engine.getArguments().numberOfThreads == 1 && // only works single threaded + return engine.getArguments().numberOfDataThreads == 1 && // only works single threaded ! isCompressed() && // for non-compressed outputs getFile() != null && // that are going to disk engine.getArguments().generateShadowBCF; // and we actually want to do it diff --git a/public/java/src/org/broadinstitute/sting/gatk/phonehome/GATKRunReport.java b/public/java/src/org/broadinstitute/sting/gatk/phonehome/GATKRunReport.java index 6f3f175a2..51fed470f 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/phonehome/GATKRunReport.java +++ b/public/java/src/org/broadinstitute/sting/gatk/phonehome/GATKRunReport.java @@ -218,7 +218,7 @@ public class GATKRunReport { // if there was an exception, capture it this.mException = e == null ? null : new ExceptionToXML(e); - numThreads = engine.getArguments().numberOfThreads; + numThreads = engine.getTotalNumberOfThreads(); percentTimeRunning = getThreadEfficiencyPercent(engine, ThreadEfficiencyMonitor.State.USER_CPU); percentTimeBlocking = getThreadEfficiencyPercent(engine, ThreadEfficiencyMonitor.State.BLOCKING); percentTimeWaiting = getThreadEfficiencyPercent(engine, ThreadEfficiencyMonitor.State.WAITING); diff --git a/public/java/src/org/broadinstitute/sting/gatk/resourcemanagement/ThreadAllocation.java b/public/java/src/org/broadinstitute/sting/gatk/resourcemanagement/ThreadAllocation.java index caae55ac5..c86f06c25 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/resourcemanagement/ThreadAllocation.java +++ b/public/java/src/org/broadinstitute/sting/gatk/resourcemanagement/ThreadAllocation.java @@ -24,7 +24,7 @@ package org.broadinstitute.sting.gatk.resourcemanagement; -import org.broadinstitute.sting.utils.exceptions.UserException; +import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; /** * Models how threads are distributed between various components of the GATK. @@ -33,7 +33,12 @@ public class ThreadAllocation { /** * The number of CPU threads to be used by the GATK. */ - private final int numCPUThreads; + private final int numDataThreads; + + /** + * The number of CPU threads per data thread for GATK processing + */ + private final int numCPUThreadsPerDataThread; /** * Number of threads to devote exclusively to IO. Default is 0. @@ -45,8 +50,12 @@ public class ThreadAllocation { */ private final boolean monitorEfficiency; - public int getNumCPUThreads() { - return numCPUThreads; + public int getNumDataThreads() { + return numDataThreads; + } + + public int getNumCPUThreadsPerDataThread() { + return numCPUThreadsPerDataThread; } public int getNumIOThreads() { @@ -57,47 +66,50 @@ public class ThreadAllocation { return monitorEfficiency; } + /** + * Are we running in parallel mode? + * + * @return true if any parallel processing is enabled + */ + public boolean isRunningInParallelMode() { + return getTotalNumThreads() > 1; + } + + /** + * What is the total number of threads in use by the GATK? + * + * @return the sum of all thread allocations in this object + */ + public int getTotalNumThreads() { + return getNumDataThreads() * getNumCPUThreadsPerDataThread() + getNumIOThreads(); + } + /** * Construct the default thread allocation. */ public ThreadAllocation() { - this(1, null, null, false); + this(1, 1, 0, false); } /** * Set up the thread allocation. Default allocation is 1 CPU thread, 0 IO threads. * (0 IO threads means that no threads are devoted exclusively to IO; they're inline on the CPU thread). - * @param totalThreads Complete number of threads to allocate. - * @param numCPUThreads Total number of threads allocated to the traversal. + * @param numDataThreads Total number of threads allocated to the traversal. + * @param numCPUThreadsPerDataThread The number of CPU threads per data thread to allocate * @param numIOThreads Total number of threads allocated exclusively to IO. + * @param monitorEfficiency should we monitor threading efficiency in the GATK? */ - public ThreadAllocation(final int totalThreads, final Integer numCPUThreads, final Integer numIOThreads, final boolean monitorEfficiency) { - // If no allocation information is present, allocate all threads to CPU - if(numCPUThreads == null && numIOThreads == null) { - this.numCPUThreads = totalThreads; - this.numIOThreads = 0; - } - // If only CPU threads are specified, allocate remainder to IO (minimum 0 dedicated IO threads). - else if(numIOThreads == null) { - if(numCPUThreads > totalThreads) - throw new UserException(String.format("Invalid thread allocation. User requested %d threads in total, but the count of cpu threads (%d) is higher than the total threads",totalThreads,numCPUThreads)); - this.numCPUThreads = numCPUThreads; - this.numIOThreads = totalThreads - numCPUThreads; - } - // If only IO threads are specified, allocate remainder to CPU (minimum 1 dedicated CPU thread). - else if(numCPUThreads == null) { - if(numIOThreads > totalThreads) - throw new UserException(String.format("Invalid thread allocation. User requested %d threads in total, but the count of io threads (%d) is higher than the total threads",totalThreads,numIOThreads)); - this.numCPUThreads = Math.max(1,totalThreads-numIOThreads); - this.numIOThreads = numIOThreads; - } - else { - if(numCPUThreads + numIOThreads != totalThreads) - throw new UserException(String.format("Invalid thread allocation. User requested %d threads in total, but the count of cpu threads (%d) + the count of io threads (%d) does not match",totalThreads,numCPUThreads,numIOThreads)); - this.numCPUThreads = numCPUThreads; - this.numIOThreads = numIOThreads; - } + public ThreadAllocation(final int numDataThreads, + final int numCPUThreadsPerDataThread, + final int numIOThreads, + final boolean monitorEfficiency) { + if ( numDataThreads < 1 ) throw new ReviewedStingException("numDataThreads cannot be less than 1, but saw " + numDataThreads); + if ( numCPUThreadsPerDataThread < 1 ) throw new ReviewedStingException("numCPUThreadsPerDataThread cannot be less than 1, but saw " + numCPUThreadsPerDataThread); + if ( numIOThreads < 0 ) throw new ReviewedStingException("numIOThreads cannot be less than 0, but saw " + numIOThreads); + this.numDataThreads = numDataThreads; + this.numCPUThreadsPerDataThread = numCPUThreadsPerDataThread; + this.numIOThreads = numIOThreads; this.monitorEfficiency = monitorEfficiency; } } diff --git a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraversalEngine.java b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraversalEngine.java index abc71e549..8c617e4dc 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraversalEngine.java +++ b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraversalEngine.java @@ -44,24 +44,12 @@ import java.util.List; import java.util.Map; public abstract class TraversalEngine,ProviderType extends ShardDataProvider> { + /** our log, which we want to capture anything from this class */ + protected static final Logger logger = Logger.getLogger(TraversalEngine.class); + // Time in milliseconds since we initialized this engine private static final int HISTORY_WINDOW_SIZE = 50; - private static class ProcessingHistory { - double elapsedSeconds; - long unitsProcessed; - long bpProcessed; - GenomeLoc loc; - - public ProcessingHistory(double elapsedSeconds, GenomeLoc loc, long unitsProcessed, long bpProcessed) { - this.elapsedSeconds = elapsedSeconds; - this.loc = loc; - this.unitsProcessed = unitsProcessed; - this.bpProcessed = bpProcessed; - } - - } - /** lock object to sure updates to history are consistent across threads */ private static final Object lock = new Object(); LinkedList history = new LinkedList(); @@ -70,13 +58,12 @@ public abstract class TraversalEngine,Provide private SimpleTimer timer = null; // How long can we go without printing some progress info? - private static final int PRINT_PROGRESS_CHECK_FREQUENCY_IN_CYCLES = 1000; - private int printProgressCheckCounter = 0; private long lastProgressPrintTime = -1; // When was the last time we printed progress log? - private long MIN_ELAPSED_TIME_BEFORE_FIRST_PROGRESS = 30 * 1000; // in milliseconds - private long PROGRESS_PRINT_FREQUENCY = 10 * 1000; // in milliseconds - private final double TWO_HOURS_IN_SECONDS = 2.0 * 60.0 * 60.0; - private final double TWELVE_HOURS_IN_SECONDS = 12.0 * 60.0 * 60.0; + + private final static long MIN_ELAPSED_TIME_BEFORE_FIRST_PROGRESS = 30 * 1000; // in milliseconds + private final static double TWO_HOURS_IN_SECONDS = 2.0 * 60.0 * 60.0; + private final static double TWELVE_HOURS_IN_SECONDS = 12.0 * 60.0 * 60.0; + private long progressPrintFrequency = 10 * 1000; // in milliseconds private boolean progressMeterInitialized = false; // for performance log @@ -85,15 +72,12 @@ public abstract class TraversalEngine,Provide private File performanceLogFile; private PrintStream performanceLog = null; private long lastPerformanceLogPrintTime = -1; // When was the last time we printed to the performance log? - private final long PERFORMANCE_LOG_PRINT_FREQUENCY = PROGRESS_PRINT_FREQUENCY; // in milliseconds + private final long PERFORMANCE_LOG_PRINT_FREQUENCY = progressPrintFrequency; // in milliseconds /** Size, in bp, of the area we are processing. Updated once in the system in initial for performance reasons */ long targetSize = -1; GenomeLocSortedSet targetIntervals = null; - /** our log, which we want to capture anything from this class */ - protected static final Logger logger = Logger.getLogger(TraversalEngine.class); - protected GenomeAnalysisEngine engine; // ---------------------------------------------------------------------------------------------------- @@ -186,15 +170,35 @@ public abstract class TraversalEngine,Provide return elapsed > printFreq && elapsed > MIN_ELAPSED_TIME_BEFORE_FIRST_PROGRESS; } + /** + * Update the cumulative traversal metrics according to the data in this shard + * + * @param shard a non-null shard + */ + public void updateCumulativeMetrics(final Shard shard) { + updateCumulativeMetrics(shard.getReadMetrics()); + } + + /** + * Update the cumulative traversal metrics according to the data in this shard + * + * @param singleTraverseMetrics read metrics object containing the information about a single shard's worth + * of data processing + */ + public void updateCumulativeMetrics(final ReadMetrics singleTraverseMetrics) { + engine.getCumulativeMetrics().incrementMetrics(singleTraverseMetrics); + } + /** * Forward request to printProgress * - * @param shard the given shard currently being processed. + * Assumes that one cycle has been completed + * * @param loc the location */ - public void printProgress(Shard shard, GenomeLoc loc) { + public void printProgress(final GenomeLoc loc) { // A bypass is inserted here for unit testing. - printProgress(loc,shard.getReadMetrics(),false); + printProgress(loc, false); } /** @@ -202,15 +206,10 @@ public abstract class TraversalEngine,Provide * every M seconds, for N and M set in global variables. * * @param loc Current location, can be null if you are at the end of the traversal - * @param metrics Data processed since the last cumulative * @param mustPrint If true, will print out info, regardless of nRecords or time interval */ - private void printProgress(GenomeLoc loc, ReadMetrics metrics, boolean mustPrint) { - if ( mustPrint || printProgressCheckCounter++ % PRINT_PROGRESS_CHECK_FREQUENCY_IN_CYCLES != 0 ) - // don't do any work more often than PRINT_PROGRESS_CHECK_FREQUENCY_IN_CYCLES - return; - - if(!progressMeterInitialized && mustPrint == false ) { + private synchronized void printProgress(final GenomeLoc loc, boolean mustPrint) { + if( ! progressMeterInitialized ) { logger.info("[INITIALIZATION COMPLETE; TRAVERSAL STARTING]"); logger.info(String.format("%15s processed.%s runtime per.1M.%s completed total.runtime remaining", "Location", getTraversalType(), getTraversalType())); @@ -218,40 +217,34 @@ public abstract class TraversalEngine,Provide } final long curTime = timer.currentTime(); - boolean printProgress = mustPrint || maxElapsedIntervalForPrinting(curTime, lastProgressPrintTime, PROGRESS_PRINT_FREQUENCY); + boolean printProgress = mustPrint || maxElapsedIntervalForPrinting(curTime, lastProgressPrintTime, progressPrintFrequency); boolean printLog = performanceLog != null && maxElapsedIntervalForPrinting(curTime, lastPerformanceLogPrintTime, PERFORMANCE_LOG_PRINT_FREQUENCY); if ( printProgress || printLog ) { - // getting and appending metrics data actually turns out to be quite a heavyweight - // operation. Postpone it until after determining whether to print the log message. - ReadMetrics cumulativeMetrics = engine.getCumulativeMetrics() != null ? engine.getCumulativeMetrics() : new ReadMetrics(); - if(metrics != null) - cumulativeMetrics.incrementMetrics(metrics); - - final long nRecords = cumulativeMetrics.getNumIterations(); - - ProcessingHistory last = updateHistory(loc,cumulativeMetrics); + final ProcessingHistory last = updateHistory(loc, engine.getCumulativeMetrics()); final AutoFormattingTime elapsed = new AutoFormattingTime(last.elapsedSeconds); - final AutoFormattingTime bpRate = new AutoFormattingTime(secondsPerMillionBP(last)); - final AutoFormattingTime unitRate = new AutoFormattingTime(secondsPerMillionElements(last)); - final double fractionGenomeTargetCompleted = calculateFractionGenomeTargetCompleted(last); + final AutoFormattingTime bpRate = new AutoFormattingTime(last.secondsPerMillionBP()); + final AutoFormattingTime unitRate = new AutoFormattingTime(last.secondsPerMillionElements()); + final double fractionGenomeTargetCompleted = last.calculateFractionGenomeTargetCompleted(targetSize); final AutoFormattingTime estTotalRuntime = new AutoFormattingTime(elapsed.getTimeInSeconds() / fractionGenomeTargetCompleted); final AutoFormattingTime timeToCompletion = new AutoFormattingTime(estTotalRuntime.getTimeInSeconds() - elapsed.getTimeInSeconds()); + final long nRecords = engine.getCumulativeMetrics().getNumIterations(); if ( printProgress ) { lastProgressPrintTime = curTime; // dynamically change the update rate so that short running jobs receive frequent updates while longer jobs receive fewer updates if ( estTotalRuntime.getTimeInSeconds() > TWELVE_HOURS_IN_SECONDS ) - PROGRESS_PRINT_FREQUENCY = 60 * 1000; // in milliseconds + progressPrintFrequency = 60 * 1000; // in milliseconds else if ( estTotalRuntime.getTimeInSeconds() > TWO_HOURS_IN_SECONDS ) - PROGRESS_PRINT_FREQUENCY = 30 * 1000; // in milliseconds + progressPrintFrequency = 30 * 1000; // in milliseconds else - PROGRESS_PRINT_FREQUENCY = 10 * 1000; // in milliseconds + progressPrintFrequency = 10 * 1000; // in milliseconds - logger.info(String.format("%15s %5.2e %s %s %4.1f%% %s %s", - loc == null ? "done with mapped reads" : loc, nRecords*1.0, elapsed, unitRate, + final String posName = loc == null ? (mustPrint ? "done" : "unmapped reads") : String.format("%s:%d", loc.getContig(), loc.getStart()); + logger.info(String.format("%15s %5.2e %s %s %5.1f%% %s %s", + posName, nRecords*1.0, elapsed, unitRate, 100*fractionGenomeTargetCompleted, estTotalRuntime, timeToCompletion)); } @@ -277,7 +270,7 @@ public abstract class TraversalEngine,Provide * @param metrics information about what's been processed already * @return */ - private final ProcessingHistory updateHistory(GenomeLoc loc, ReadMetrics metrics) { + private ProcessingHistory updateHistory(GenomeLoc loc, ReadMetrics metrics) { synchronized (lock) { if ( history.size() > HISTORY_WINDOW_SIZE ) history.pop(); @@ -290,26 +283,11 @@ public abstract class TraversalEngine,Provide } } - /** How long in seconds to process 1M traversal units? */ - private final double secondsPerMillionElements(ProcessingHistory last) { - return (last.elapsedSeconds * 1000000.0) / Math.max(last.unitsProcessed, 1); - } - - /** How long in seconds to process 1M bp on the genome? */ - private final double secondsPerMillionBP(ProcessingHistory last) { - return (last.elapsedSeconds * 1000000.0) / Math.max(last.bpProcessed, 1); - } - - /** What fractoin of the target intervals have we covered? */ - private final double calculateFractionGenomeTargetCompleted(ProcessingHistory last) { - return (1.0*last.bpProcessed) / targetSize; - } - /** * Called after a traversal to print out information about the traversal process */ public void printOnTraversalDone() { - printProgress(null, null, true); + printProgress(null, true); final double elapsed = timer == null ? 0 : timer.getElapsedTime(); @@ -370,7 +348,7 @@ public abstract class TraversalEngine,Provide * @return Frequency, in seconds, of performance log writes. */ public long getPerformanceProgressPrintFrequencySeconds() { - return PROGRESS_PRINT_FREQUENCY; + return progressPrintFrequency; } /** @@ -378,6 +356,35 @@ public abstract class TraversalEngine,Provide * @param seconds number of seconds between messages indicating performance frequency. */ public void setPerformanceProgressPrintFrequencySeconds(long seconds) { - PROGRESS_PRINT_FREQUENCY = seconds; + progressPrintFrequency = seconds; + } + + private static class ProcessingHistory { + double elapsedSeconds; + long unitsProcessed; + long bpProcessed; + GenomeLoc loc; + + public ProcessingHistory(double elapsedSeconds, GenomeLoc loc, long unitsProcessed, long bpProcessed) { + this.elapsedSeconds = elapsedSeconds; + this.loc = loc; + this.unitsProcessed = unitsProcessed; + this.bpProcessed = bpProcessed; + } + + /** How long in seconds to process 1M traversal units? */ + private double secondsPerMillionElements() { + return (elapsedSeconds * 1000000.0) / Math.max(unitsProcessed, 1); + } + + /** How long in seconds to process 1M bp on the genome? */ + private double secondsPerMillionBP() { + return (elapsedSeconds * 1000000.0) / Math.max(bpProcessed, 1); + } + + /** What fractoin of the target intervals have we covered? */ + private double calculateFractionGenomeTargetCompleted(final long targetSize) { + return (1.0*bpProcessed) / targetSize; + } } } diff --git a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseActiveRegions.java b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseActiveRegions.java index ecaa15fe9..bbd9346b3 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseActiveRegions.java +++ b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseActiveRegions.java @@ -104,7 +104,8 @@ public class TraverseActiveRegions extends TraversalEngine extends TraversalEngine extends TraversalEngine,LocusShardDataProvider> { +public abstract class TraverseLociBase extends TraversalEngine,LocusShardDataProvider> { /** * our log, which we want to capture anything from this class */ protected static final Logger logger = Logger.getLogger(TraversalEngine.class); @Override - protected String getTraversalType() { + protected final String getTraversalType() { return "sites"; } + protected static class TraverseResults { + final int numIterations; + final T reduceResult; + + public TraverseResults(int numIterations, T reduceResult) { + this.numIterations = numIterations; + this.reduceResult = reduceResult; + } + } + + protected abstract TraverseResults traverse( final LocusWalker walker, + final LocusView locusView, + final LocusReferenceView referenceView, + final ReferenceOrderedView referenceOrderedDataView, + final T sum); + @Override public T traverse( LocusWalker walker, LocusShardDataProvider dataProvider, T sum) { - logger.debug(String.format("TraverseLoci.traverse: Shard is %s", dataProvider)); + logger.debug(String.format("TraverseLociBase.traverse: Shard is %s", dataProvider)); - LocusView locusView = getLocusView( walker, dataProvider ); - boolean done = false; + final LocusView locusView = getLocusView( walker, dataProvider ); if ( locusView.hasNext() ) { // trivial optimization to avoid unnecessary processing when there's nothing here at all - //ReferenceOrderedView referenceOrderedDataView = new ReferenceOrderedView( dataProvider ); ReferenceOrderedView referenceOrderedDataView = null; if ( WalkerManager.getWalkerDataSource(walker) != DataSource.REFERENCE_ORDERED_DATA ) @@ -44,43 +56,24 @@ public class TraverseLoci extends TraversalEngine,Locu else referenceOrderedDataView = (RodLocusView)locusView; - LocusReferenceView referenceView = new LocusReferenceView( walker, dataProvider ); + final LocusReferenceView referenceView = new LocusReferenceView( walker, dataProvider ); - // We keep processing while the next reference location is within the interval - while( locusView.hasNext() && ! done ) { - AlignmentContext locus = locusView.next(); - GenomeLoc location = locus.getLocation(); - - dataProvider.getShard().getReadMetrics().incrementNumIterations(); - - // create reference context. Note that if we have a pileup of "extended events", the context will - // hold the (longest) stretch of deleted reference bases (if deletions are present in the pileup). - ReferenceContext refContext = referenceView.getReferenceContext(location); - - // Iterate forward to get all reference ordered data covering this location - final RefMetaDataTracker tracker = referenceOrderedDataView.getReferenceOrderedDataAtLocus(locus.getLocation(), refContext); - - final boolean keepMeP = walker.filter(tracker, refContext, locus); - if (keepMeP) { - M x = walker.map(tracker, refContext, locus); - sum = walker.reduce(x, sum); - done = walker.isDone(); - } - - printProgress(dataProvider.getShard(),locus.getLocation()); - } + final TraverseResults result = traverse( walker, locusView, referenceView, referenceOrderedDataView, sum ); + sum = result.reduceResult; + dataProvider.getShard().getReadMetrics().incrementNumIterations(result.numIterations); + updateCumulativeMetrics(dataProvider.getShard()); } // We have a final map call to execute here to clean up the skipped based from the // last position in the ROD to that in the interval if ( WalkerManager.getWalkerDataSource(walker) == DataSource.REFERENCE_ORDERED_DATA && ! walker.isDone() ) { // only do this if the walker isn't done! - RodLocusView rodLocusView = (RodLocusView)locusView; - long nSkipped = rodLocusView.getLastSkippedBases(); + final RodLocusView rodLocusView = (RodLocusView)locusView; + final long nSkipped = rodLocusView.getLastSkippedBases(); if ( nSkipped > 0 ) { - GenomeLoc site = rodLocusView.getLocOneBeyondShard(); - AlignmentContext ac = new AlignmentContext(site, new ReadBackedPileupImpl(site), nSkipped); - M x = walker.map(null, null, ac); + final GenomeLoc site = rodLocusView.getLocOneBeyondShard(); + final AlignmentContext ac = new AlignmentContext(site, new ReadBackedPileupImpl(site), nSkipped); + final M x = walker.map(null, null, ac); sum = walker.reduce(x, sum); } } @@ -90,14 +83,14 @@ public class TraverseLoci extends TraversalEngine,Locu /** * Gets the best view of loci for this walker given the available data. The view will function as a 'trigger track' - * of sorts, providing a consistent interface so that TraverseLoci doesn't need to be reimplemented for any new datatype + * of sorts, providing a consistent interface so that TraverseLociBase doesn't need to be reimplemented for any new datatype * that comes along. * @param walker walker to interrogate. * @param dataProvider Data which which to drive the locus view. * @return A view of the locus data, where one iteration of the locus view maps to one iteration of the traversal. */ private LocusView getLocusView( Walker walker, LocusShardDataProvider dataProvider ) { - DataSource dataSource = WalkerManager.getWalkerDataSource(walker); + final DataSource dataSource = WalkerManager.getWalkerDataSource(walker); if( dataSource == DataSource.READS ) return new CoveredLocusView(dataProvider); else if( dataSource == DataSource.REFERENCE ) //|| ! GenomeAnalysisEngine.instance.getArguments().enableRodWalkers ) diff --git a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseLociLinear.java b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseLociLinear.java new file mode 100755 index 000000000..22381092f --- /dev/null +++ b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseLociLinear.java @@ -0,0 +1,47 @@ +package org.broadinstitute.sting.gatk.traversals; + +import org.broadinstitute.sting.gatk.contexts.AlignmentContext; +import org.broadinstitute.sting.gatk.contexts.ReferenceContext; +import org.broadinstitute.sting.gatk.datasources.providers.LocusReferenceView; +import org.broadinstitute.sting.gatk.datasources.providers.LocusView; +import org.broadinstitute.sting.gatk.datasources.providers.ReferenceOrderedView; +import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker; +import org.broadinstitute.sting.gatk.walkers.LocusWalker; +import org.broadinstitute.sting.utils.GenomeLoc; + +/** + * A simple solution to iterating over all reference positions over a series of genomic locations. + */ +public class TraverseLociLinear extends TraverseLociBase { + + @Override + protected TraverseResults traverse(LocusWalker walker, LocusView locusView, LocusReferenceView referenceView, ReferenceOrderedView referenceOrderedDataView, T sum) { + // We keep processing while the next reference location is within the interval + boolean done = false; + int numIterations = 0; + + while( locusView.hasNext() && ! done ) { + numIterations++; + final AlignmentContext locus = locusView.next(); + final GenomeLoc location = locus.getLocation(); + + // create reference context. Note that if we have a pileup of "extended events", the context will + // hold the (longest) stretch of deleted reference bases (if deletions are present in the pileup). + final ReferenceContext refContext = referenceView.getReferenceContext(location); + + // Iterate forward to get all reference ordered data covering this location + final RefMetaDataTracker tracker = referenceOrderedDataView.getReferenceOrderedDataAtLocus(locus.getLocation(), refContext); + + final boolean keepMeP = walker.filter(tracker, refContext, locus); + if (keepMeP) { + final M x = walker.map(tracker, refContext, locus); + sum = walker.reduce(x, sum); + done = walker.isDone(); + } + + printProgress(locus.getLocation()); + } + + return new TraverseResults(numIterations, sum); + } +} diff --git a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseLociNano.java b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseLociNano.java new file mode 100755 index 000000000..73b73c002 --- /dev/null +++ b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseLociNano.java @@ -0,0 +1,205 @@ +package org.broadinstitute.sting.gatk.traversals; + +import org.broadinstitute.sting.gatk.contexts.AlignmentContext; +import org.broadinstitute.sting.gatk.contexts.ReferenceContext; +import org.broadinstitute.sting.gatk.datasources.providers.LocusReferenceView; +import org.broadinstitute.sting.gatk.datasources.providers.LocusView; +import org.broadinstitute.sting.gatk.datasources.providers.ReferenceOrderedView; +import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker; +import org.broadinstitute.sting.gatk.walkers.LocusWalker; +import org.broadinstitute.sting.utils.GenomeLoc; +import org.broadinstitute.sting.utils.nanoScheduler.NanoScheduler; +import org.broadinstitute.sting.utils.nanoScheduler.NanoSchedulerMapFunction; +import org.broadinstitute.sting.utils.nanoScheduler.NanoSchedulerProgressFunction; +import org.broadinstitute.sting.utils.nanoScheduler.NanoSchedulerReduceFunction; + +import java.util.Iterator; + +/** + * A simple solution to iterating over all reference positions over a series of genomic locations. + */ +public class TraverseLociNano extends TraverseLociBase { + /** our log, which we want to capture anything from this class */ + private static final boolean DEBUG = false; + private static final int BUFFER_SIZE = 1000; + + final NanoScheduler nanoScheduler; + + public TraverseLociNano(int nThreads) { + nanoScheduler = new NanoScheduler(BUFFER_SIZE, nThreads); + nanoScheduler.setProgressFunction(new TraverseLociProgress()); + } + + @Override + protected TraverseResults traverse(final LocusWalker walker, + final LocusView locusView, + final LocusReferenceView referenceView, + final ReferenceOrderedView referenceOrderedDataView, + final T sum) { + nanoScheduler.setDebug(DEBUG); + final TraverseLociMap myMap = new TraverseLociMap(walker); + final TraverseLociReduce myReduce = new TraverseLociReduce(walker); + + final MapDataIterator inputIterator = new MapDataIterator(locusView, referenceView, referenceOrderedDataView); + final T result = nanoScheduler.execute(inputIterator, myMap, sum, myReduce); + + return new TraverseResults(inputIterator.numIterations, result); + } + + /** + * Create iterator that provides inputs for all map calls into MapData, to be provided + * to NanoScheduler for Map/Reduce + */ + private class MapDataIterator implements Iterator { + final LocusView locusView; + final LocusReferenceView referenceView; + final ReferenceOrderedView referenceOrderedDataView; + int numIterations = 0; + + private MapDataIterator(LocusView locusView, LocusReferenceView referenceView, ReferenceOrderedView referenceOrderedDataView) { + this.locusView = locusView; + this.referenceView = referenceView; + this.referenceOrderedDataView = referenceOrderedDataView; + } + + @Override + public boolean hasNext() { + return locusView.hasNext(); + } + + @Override + public MapData next() { + final AlignmentContext locus = locusView.next(); + final GenomeLoc location = locus.getLocation(); + + //logger.info("Pulling data from MapDataIterator at " + location); + + // create reference context. Note that if we have a pileup of "extended events", the context will + // hold the (longest) stretch of deleted reference bases (if deletions are present in the pileup). + final ReferenceContext refContext = referenceView.getReferenceContext(location); + + // Iterate forward to get all reference ordered data covering this location + final RefMetaDataTracker tracker = referenceOrderedDataView.getReferenceOrderedDataAtLocus(location, refContext); + + numIterations++; + return new MapData(locus, refContext, tracker); + } + + @Override + public void remove() { + throw new UnsupportedOperationException("Cannot remove elements from MapDataIterator"); + } + } + + @Override + public void printOnTraversalDone() { + nanoScheduler.shutdown(); + super.printOnTraversalDone(); + } + + /** + * The input data needed for each map call. The read, the reference, and the RODs + */ + private class MapData { + final AlignmentContext alignmentContext; + final ReferenceContext refContext; + final RefMetaDataTracker tracker; + + private MapData(final AlignmentContext alignmentContext, ReferenceContext refContext, RefMetaDataTracker tracker) { + this.alignmentContext = alignmentContext; + this.refContext = refContext; + this.tracker = tracker; + } + + @Override + public String toString() { + return "MapData " + alignmentContext.getLocation(); + } + } + + /** + * Contains the results of a map call, indicating whether the call was good, filtered, or done + */ + private class MapResult { + final M value; + final boolean reduceMe; + + /** + * Create a MapResult with value that should be reduced + * + * @param value the value to reduce + */ + private MapResult(final M value) { + this.value = value; + this.reduceMe = true; + } + + /** + * Create a MapResult that shouldn't be reduced + */ + private MapResult() { + this.value = null; + this.reduceMe = false; + } + } + + /** + * A static object that tells reduce that the result of map should be skipped (filtered or done) + */ + private final MapResult SKIP_REDUCE = new MapResult(); + + /** + * MapFunction for TraverseReads meeting NanoScheduler interface requirements + * + * Applies walker.map to MapData, returning a MapResult object containing the result + */ + private class TraverseLociMap implements NanoSchedulerMapFunction { + final LocusWalker walker; + + private TraverseLociMap(LocusWalker walker) { + this.walker = walker; + } + + @Override + public MapResult apply(final MapData data) { + if ( ! walker.isDone() ) { + final boolean keepMeP = walker.filter(data.tracker, data.refContext, data.alignmentContext); + if (keepMeP) { + final M x = walker.map(data.tracker, data.refContext, data.alignmentContext); + return new MapResult(x); + } + } + return SKIP_REDUCE; + } + } + + /** + * NanoSchedulerReduceFunction for TraverseReads meeting NanoScheduler interface requirements + * + * Takes a MapResult object and applies the walkers reduce function to each map result, when applicable + */ + private class TraverseLociReduce implements NanoSchedulerReduceFunction { + final LocusWalker walker; + + private TraverseLociReduce(LocusWalker walker) { + this.walker = walker; + } + + @Override + public T apply(MapResult one, T sum) { + if ( one.reduceMe ) + // only run reduce on values that aren't DONE or FAILED + return walker.reduce(one.value, sum); + else + return sum; + } + } + + private class TraverseLociProgress implements NanoSchedulerProgressFunction { + @Override + public void progress(MapData lastProcessedMap) { + if (lastProcessedMap.alignmentContext != null) + printProgress(lastProcessedMap.alignmentContext.getLocation()); + } + } +} diff --git a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseReadPairs.java b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseReadPairs.java index ebaac40af..9b076fce4 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseReadPairs.java +++ b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseReadPairs.java @@ -65,7 +65,8 @@ public class TraverseReadPairs extends TraversalEngine extends TraversalEngine,Read sum = walker.reduce(x, sum); } - GenomeLoc locus = read.getReferenceIndex() == SAMRecord.NO_ALIGNMENT_REFERENCE_INDEX ? null : engine.getGenomeLocParser().createGenomeLoc(read.getReferenceName(),read.getAlignmentStart()); - printProgress(dataProvider.getShard(),locus); + final GenomeLoc locus = read.getReferenceIndex() == SAMRecord.NO_ALIGNMENT_REFERENCE_INDEX ? null : engine.getGenomeLocParser().createGenomeLoc(read.getReferenceName(),read.getAlignmentStart()); + + updateCumulativeMetrics(dataProvider.getShard()); + printProgress(locus); + done = walker.isDone(); } return sum; diff --git a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseReadsNano.java b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseReadsNano.java index b397cb8c0..5679747e1 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseReadsNano.java +++ b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseReadsNano.java @@ -34,34 +34,34 @@ import org.broadinstitute.sting.gatk.datasources.providers.ReadView; import org.broadinstitute.sting.gatk.datasources.reads.ReadShard; import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker; import org.broadinstitute.sting.gatk.walkers.ReadWalker; -import org.broadinstitute.sting.utils.nanoScheduler.MapFunction; +import org.broadinstitute.sting.utils.GenomeLoc; import org.broadinstitute.sting.utils.nanoScheduler.NanoScheduler; -import org.broadinstitute.sting.utils.nanoScheduler.ReduceFunction; +import org.broadinstitute.sting.utils.nanoScheduler.NanoSchedulerMapFunction; +import org.broadinstitute.sting.utils.nanoScheduler.NanoSchedulerReduceFunction; import org.broadinstitute.sting.utils.sam.GATKSAMRecord; -import java.util.ArrayList; +import java.util.LinkedList; import java.util.List; /** - * @author aaron + * A nano-scheduling version of TraverseReads. + * + * Implements the traversal of a walker that accepts individual reads, the reference, and + * RODs per map call. Directly supports shared memory parallelism via NanoScheduler + * + * @author depristo * @version 1.0 - * @date Apr 24, 2009 - *

- * Class TraverseReads - *

- * This class handles traversing by reads in the new shardable style + * @date 9/2/2012 */ public class TraverseReadsNano extends TraversalEngine,ReadShardDataProvider> { /** our log, which we want to capture anything from this class */ protected static final Logger logger = Logger.getLogger(TraverseReadsNano.class); private static final boolean DEBUG = false; - private static final int MIN_GROUP_SIZE = 100; - final NanoScheduler nanoScheduler; + final NanoScheduler nanoScheduler; public TraverseReadsNano(int nThreads) { final int bufferSize = ReadShard.getReadBufferSize() + 1; // actually has 1 more than max - final int mapGroupSize = (int)Math.max(Math.ceil(bufferSize / 50.0 + 1), MIN_GROUP_SIZE); - nanoScheduler = new NanoScheduler(bufferSize, mapGroupSize, nThreads); + nanoScheduler = new NanoScheduler(bufferSize, nThreads); } @Override @@ -89,19 +89,32 @@ public class TraverseReadsNano extends TraversalEngine, final TraverseReadsMap myMap = new TraverseReadsMap(walker); final TraverseReadsReduce myReduce = new TraverseReadsReduce(walker); - T result = nanoScheduler.execute(aggregateMapData(dataProvider).iterator(), myMap, sum, myReduce); - // TODO -- how do we print progress? - //printProgress(dataProvider.getShard(), ???); + final List aggregatedInputs = aggregateMapData(dataProvider); + final T result = nanoScheduler.execute(aggregatedInputs.iterator(), myMap, sum, myReduce); + + final GATKSAMRecord lastRead = aggregatedInputs.get(aggregatedInputs.size() - 1).read; + final GenomeLoc locus = engine.getGenomeLocParser().createGenomeLoc(lastRead); + + updateCumulativeMetrics(dataProvider.getShard()); + printProgress(locus); return result; } + /** + * Aggregate all of the inputs for all map calls into MapData, to be provided + * to NanoScheduler for Map/Reduce + * + * @param dataProvider the source of our data + * @return a linked list of MapData objects holding the read, ref, and ROD info for every map/reduce + * should execute + */ private List aggregateMapData(final ReadShardDataProvider dataProvider) { final ReadView reads = new ReadView(dataProvider); final ReadReferenceView reference = new ReadReferenceView(dataProvider); final ReadBasedReferenceOrderedView rodView = new ReadBasedReferenceOrderedView(dataProvider); - final List mapData = new ArrayList(); // TODO -- need size of reads + final List mapData = new LinkedList(); for ( final SAMRecord read : reads ) { final ReferenceContext refContext = ! read.getReadUnmappedFlag() ? reference.getReferenceContext(read) @@ -127,19 +140,9 @@ public class TraverseReadsNano extends TraversalEngine, super.printOnTraversalDone(); } - private class TraverseReadsReduce implements ReduceFunction { - final ReadWalker walker; - - private TraverseReadsReduce(ReadWalker walker) { - this.walker = walker; - } - - @Override - public T apply(M one, T sum) { - return walker.reduce(one, sum); - } - } - + /** + * The input data needed for each map call. The read, the reference, and the RODs + */ private class MapData { final GATKSAMRecord read; final ReferenceContext refContext; @@ -152,7 +155,43 @@ public class TraverseReadsNano extends TraversalEngine, } } - private class TraverseReadsMap implements MapFunction { + /** + * Contains the results of a map call, indicating whether the call was good, filtered, or done + */ + private class MapResult { + final M value; + final boolean reduceMe; + + /** + * Create a MapResult with value that should be reduced + * + * @param value the value to reduce + */ + private MapResult(final M value) { + this.value = value; + this.reduceMe = true; + } + + /** + * Create a MapResult that shouldn't be reduced + */ + private MapResult() { + this.value = null; + this.reduceMe = false; + } + } + + /** + * A static object that tells reduce that the result of map should be skipped (filtered or done) + */ + private final MapResult SKIP_REDUCE = new MapResult(); + + /** + * MapFunction for TraverseReads meeting NanoScheduler interface requirements + * + * Applies walker.map to MapData, returning a MapResult object containing the result + */ + private class TraverseReadsMap implements NanoSchedulerMapFunction { final ReadWalker walker; private TraverseReadsMap(ReadWalker walker) { @@ -160,15 +199,36 @@ public class TraverseReadsNano extends TraversalEngine, } @Override - public M apply(final MapData data) { + public MapResult apply(final MapData data) { if ( ! walker.isDone() ) { final boolean keepMeP = walker.filter(data.refContext, data.read); - if (keepMeP) { - return walker.map(data.refContext, data.read, data.tracker); - } + if (keepMeP) + return new MapResult(walker.map(data.refContext, data.read, data.tracker)); } - return null; // TODO -- what should we return in the case where the walker is done or the read is filtered? + return SKIP_REDUCE; + } + } + + /** + * NanoSchedulerReduceFunction for TraverseReads meeting NanoScheduler interface requirements + * + * Takes a MapResult object and applies the walkers reduce function to each map result, when applicable + */ + private class TraverseReadsReduce implements NanoSchedulerReduceFunction { + final ReadWalker walker; + + private TraverseReadsReduce(ReadWalker walker) { + this.walker = walker; + } + + @Override + public T apply(MapResult one, T sum) { + if ( one.reduceMe ) + // only run reduce on values that aren't DONE or FAILED + return walker.reduce(one.value, sum); + else + return sum; } } } diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/FlagStat.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/FlagStat.java index 14d14aca5..b4ef66aaf 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/FlagStat.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/FlagStat.java @@ -45,7 +45,7 @@ import java.text.NumberFormat; */ @DocumentedGATKFeature( groupName = "Quality Control and Simple Analysis Tools", extraDocs = {CommandLineGATK.class} ) @Requires({DataSource.READS}) -public class FlagStat extends ReadWalker implements ThreadSafeMapReduce { +public class FlagStat extends ReadWalker implements NanoSchedulable { @Output PrintStream out; diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/ThreadSafeMapReduce.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/NanoSchedulable.java similarity index 97% rename from public/java/src/org/broadinstitute/sting/gatk/walkers/ThreadSafeMapReduce.java rename to public/java/src/org/broadinstitute/sting/gatk/walkers/NanoSchedulable.java index 1ce469f8c..731ce7e4e 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/ThreadSafeMapReduce.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/NanoSchedulable.java @@ -27,5 +27,5 @@ package org.broadinstitute.sting.gatk.walkers; * declare that their map function is thread-safe and so multiple * map calls can be run in parallel in the same JVM instance. */ -public interface ThreadSafeMapReduce { +public interface NanoSchedulable { } diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/Pileup.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/Pileup.java index 52c6e1560..a3efea9f1 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/Pileup.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/Pileup.java @@ -52,7 +52,7 @@ import java.util.List; * samtools pileup [-f in.ref.fasta] [-t in.ref_list] [-l in.site_list] [-iscg] [-T theta] [-N nHap] [-r pairDiffRate] */ @DocumentedGATKFeature( groupName = "Quality Control and Simple Analysis Tools", extraDocs = {CommandLineGATK.class} ) -public class Pileup extends LocusWalker implements TreeReducible { +public class Pileup extends LocusWalker implements TreeReducible, NanoSchedulable { private static final String verboseDelimiter = "@"; // it's ugly to use "@" but it's literally the only usable character not allowed in read names @@ -70,27 +70,32 @@ public class Pileup extends LocusWalker implements TreeReducib @Input(fullName="metadata",shortName="metadata",doc="Add these ROD bindings to the output Pileup", required=false) public List> rods = Collections.emptyList(); - public void initialize() { - } - - public Integer map(RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context) { - - String rods = getReferenceOrderedData( tracker ); + @Override + public String map(RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context) { + final String rods = getReferenceOrderedData( tracker ); ReadBackedPileup basePileup = context.getBasePileup(); - out.printf("%s %s", basePileup.getPileupString((char)ref.getBase()), rods); - if ( SHOW_VERBOSE ) - out.printf(" %s", createVerboseOutput(basePileup)); - out.println(); - return 1; + final StringBuilder s = new StringBuilder(); + s.append(String.format("%s %s", basePileup.getPileupString((char)ref.getBase()), rods)); + if ( SHOW_VERBOSE ) + s.append(" ").append(createVerboseOutput(basePileup)); + s.append("\n"); + + return s.toString(); } // Given result of map function + @Override public Integer reduceInit() { return 0; } - public Integer reduce(Integer value, Integer sum) { - return treeReduce(sum,value); + + @Override + public Integer reduce(String value, Integer sum) { + out.print(value); + return sum + 1; } + + @Override public Integer treeReduce(Integer lhs, Integer rhs) { return lhs + rhs; } diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/PrintReads.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/PrintReads.java index a5d4b45b6..37176cbf9 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/PrintReads.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/PrintReads.java @@ -93,7 +93,7 @@ import java.util.*; @ReadTransformersMode(ApplicationTime = ReadTransformer.ApplicationTime.HANDLED_IN_WALKER) @BAQMode(QualityMode = BAQ.QualityMode.ADD_TAG, ApplicationTime = ReadTransformer.ApplicationTime.HANDLED_IN_WALKER) @Requires({DataSource.READS, DataSource.REFERENCE}) -public class PrintReads extends ReadWalker implements ThreadSafeMapReduce { +public class PrintReads extends ReadWalker implements NanoSchedulable { @Output(doc="Write output to this BAM filename instead of STDOUT", required = true) SAMFileWriter out; @@ -228,7 +228,6 @@ public class PrintReads extends ReadWalker impleme GATKSAMRecord workingRead = readIn; for ( final ReadTransformer transformer : readTransformers ) { - if ( logger.isDebugEnabled() ) logger.debug("Applying transformer " + transformer + " to read " + readIn.getReadName()); workingRead = transformer.apply(workingRead); } diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/TreeReducible.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/TreeReducible.java index 8621c0e9d..c950e07e4 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/TreeReducible.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/TreeReducible.java @@ -13,7 +13,7 @@ package org.broadinstitute.sting.gatk.walkers; * shards of the data can reduce with each other, and the composite result * can be reduced with other composite results. */ -public interface TreeReducible extends ThreadSafeMapReduce { +public interface TreeReducible { /** * A composite, 'reduce of reduces' function. * @param lhs 'left-most' portion of data in the composite reduce. diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BaseRecalibrator.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BaseRecalibrator.java index 443b493be..43aa85a05 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BaseRecalibrator.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BaseRecalibrator.java @@ -109,7 +109,7 @@ import java.util.ArrayList; @ReadFilters({MappingQualityZeroFilter.class, MappingQualityUnavailableFilter.class}) // only look at covered loci, not every loci of the reference file @Requires({DataSource.READS, DataSource.REFERENCE}) // filter out all reads with zero or unavailable mapping quality @PartitionBy(PartitionType.LOCUS) // this walker requires both -I input.bam and -R reference.fasta -public class BaseRecalibrator extends LocusWalker implements TreeReducible { +public class BaseRecalibrator extends LocusWalker implements TreeReducible, NanoSchedulable { @ArgumentCollection private final RecalibrationArgumentCollection RAC = new RecalibrationArgumentCollection(); // all the command line arguments for BQSR and it's covariates diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/genotyper/UnifiedGenotyper.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/genotyper/UnifiedGenotyper.java index 93928a780..32ceff715 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/genotyper/UnifiedGenotyper.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/genotyper/UnifiedGenotyper.java @@ -125,7 +125,7 @@ import java.util.*; // TODO -- When LocusIteratorByState gets cleaned up, we should enable multiple @By sources: // TODO -- @By( {DataSource.READS, DataSource.REFERENCE_ORDERED_DATA} ) @Downsample(by=DownsampleType.BY_SAMPLE, toCoverage=250) -public class UnifiedGenotyper extends LocusWalker, UnifiedGenotyper.UGStatistics> implements TreeReducible, AnnotatorCompatible { +public class UnifiedGenotyper extends LocusWalker, UnifiedGenotyper.UGStatistics> implements TreeReducible, AnnotatorCompatible, NanoSchedulable { @ArgumentCollection private UnifiedArgumentCollection UAC = new UnifiedArgumentCollection(); diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/qc/CountLoci.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/qc/CountLoci.java index bd10eab87..cd295f26e 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/qc/CountLoci.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/qc/CountLoci.java @@ -6,6 +6,7 @@ import org.broadinstitute.sting.gatk.contexts.AlignmentContext; import org.broadinstitute.sting.gatk.contexts.ReferenceContext; import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker; import org.broadinstitute.sting.gatk.walkers.LocusWalker; +import org.broadinstitute.sting.gatk.walkers.NanoSchedulable; import org.broadinstitute.sting.gatk.walkers.TreeReducible; import org.broadinstitute.sting.utils.help.DocumentedGATKFeature; @@ -40,7 +41,7 @@ import java.io.PrintStream; * */ @DocumentedGATKFeature( groupName = "Quality Control and Simple Analysis Tools", extraDocs = {CommandLineGATK.class} ) -public class CountLoci extends LocusWalker implements TreeReducible { +public class CountLoci extends LocusWalker implements TreeReducible, NanoSchedulable { @Output(doc="Write count to this file instead of STDOUT") PrintStream out; diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/qc/CountRODs.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/qc/CountRODs.java index 9915d617e..ab37a2322 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/qc/CountRODs.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/qc/CountRODs.java @@ -37,6 +37,7 @@ import org.broadinstitute.sting.gatk.contexts.AlignmentContext; import org.broadinstitute.sting.gatk.contexts.ReferenceContext; import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker; import org.broadinstitute.sting.gatk.refdata.utils.RODRecordList; +import org.broadinstitute.sting.gatk.walkers.NanoSchedulable; import org.broadinstitute.sting.gatk.walkers.RodWalker; import org.broadinstitute.sting.gatk.walkers.TreeReducible; import org.broadinstitute.sting.utils.GenomeLoc; @@ -73,7 +74,7 @@ import java.util.*; * */ @DocumentedGATKFeature( groupName = "Quality Control and Simple Analysis Tools", extraDocs = {CommandLineGATK.class} ) -public class CountRODs extends RodWalker, Long>> implements TreeReducible, Long>> { +public class CountRODs extends RodWalker, Long>> implements TreeReducible, Long>>, NanoSchedulable { @Output public PrintStream out; diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/qc/CountReads.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/qc/CountReads.java index 856ea77f5..301fa5b9b 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/qc/CountReads.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/qc/CountReads.java @@ -4,9 +4,9 @@ import org.broadinstitute.sting.gatk.CommandLineGATK; import org.broadinstitute.sting.gatk.contexts.ReferenceContext; import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker; import org.broadinstitute.sting.gatk.walkers.DataSource; +import org.broadinstitute.sting.gatk.walkers.NanoSchedulable; import org.broadinstitute.sting.gatk.walkers.ReadWalker; import org.broadinstitute.sting.gatk.walkers.Requires; -import org.broadinstitute.sting.gatk.walkers.ThreadSafeMapReduce; import org.broadinstitute.sting.utils.help.DocumentedGATKFeature; import org.broadinstitute.sting.utils.sam.GATKSAMRecord; @@ -41,7 +41,7 @@ import org.broadinstitute.sting.utils.sam.GATKSAMRecord; */ @DocumentedGATKFeature( groupName = "Quality Control and Simple Analysis Tools", extraDocs = {CommandLineGATK.class} ) @Requires({DataSource.READS, DataSource.REFERENCE}) -public class CountReads extends ReadWalker implements ThreadSafeMapReduce { +public class CountReads extends ReadWalker implements NanoSchedulable { public Integer map(ReferenceContext ref, GATKSAMRecord read, RefMetaDataTracker tracker) { return 1; } diff --git a/public/java/src/org/broadinstitute/sting/utils/SimpleTimer.java b/public/java/src/org/broadinstitute/sting/utils/SimpleTimer.java index 15d34a348..b3a9986c5 100644 --- a/public/java/src/org/broadinstitute/sting/utils/SimpleTimer.java +++ b/public/java/src/org/broadinstitute/sting/utils/SimpleTimer.java @@ -1,18 +1,42 @@ package org.broadinstitute.sting.utils; +import com.google.java.contract.Ensures; +import com.google.java.contract.Requires; + +import java.util.concurrent.TimeUnit; + /** - * A useful simple system for timing code. This code is not thread safe! + * A useful simple system for timing code with nano second resolution + * + * Note that this code is not thread-safe. If you have a single timer + * being started and stopped by multiple threads you will need to protect the + * calls to avoid meaningless results of having multiple starts and stops + * called sequentially. * * User: depristo * Date: Dec 10, 2010 * Time: 9:07:44 AM */ public class SimpleTimer { - final private String name; - private long elapsed = 0l; - private long startTime = 0l; - boolean running = false; + protected static final double NANO_TO_SECOND_DOUBLE = 1.0 / TimeUnit.SECONDS.toNanos(1); + private final String name; + + /** + * The elapsedTimeNano time in nanoSeconds of this timer. The elapsedTimeNano time is the + * sum of times between starts/restrats and stops. + */ + private long elapsedTimeNano = 0l; + + /** + * The start time of the last start/restart in nanoSeconds + */ + private long startTimeNano = 0l; + + /** + * Is this timer currently running (i.e., the last call was start/restart) + */ + private boolean running = false; /** * Creates an anonymous simple timer @@ -25,7 +49,8 @@ public class SimpleTimer { * Creates a simple timer named name * @param name of the timer, must not be null */ - public SimpleTimer(String name) { + public SimpleTimer(final String name) { + if ( name == null ) throw new IllegalArgumentException("SimpleTimer name cannot be null"); this.name = name; } @@ -37,27 +62,27 @@ public class SimpleTimer { } /** - * Starts the timer running, and sets the elapsed time to 0. This is equivalent to + * Starts the timer running, and sets the elapsedTimeNano time to 0. This is equivalent to * resetting the time to have no history at all. * * @return this object, for programming convenience */ + @Ensures("elapsedTimeNano == 0l") public synchronized SimpleTimer start() { - elapsed = 0l; - restart(); - return this; + elapsedTimeNano = 0l; + return restart(); } /** - * Starts the timer running, without reseting the elapsed time. This function may be + * Starts the timer running, without resetting the elapsedTimeNano time. This function may be * called without first calling start(). The only difference between start and restart - * is that start resets the elapsed time, while restart does not. + * is that start resets the elapsedTimeNano time, while restart does not. * * @return this object, for programming convenience */ public synchronized SimpleTimer restart() { running = true; - startTime = currentTime(); + startTimeNano = currentTimeNano(); return this; } @@ -71,29 +96,53 @@ public class SimpleTimer { /** * @return A convenience function to obtain the current time in milliseconds from this timer */ - public synchronized long currentTime() { + public long currentTime() { return System.currentTimeMillis(); } /** - * Stops the timer. Increases the elapsed time by difference between start and now. The - * timer must be running in order to call stop + * @return A convenience function to obtain the current time in nanoSeconds from this timer + */ + public long currentTimeNano() { + return System.nanoTime(); + } + + /** + * Stops the timer. Increases the elapsedTimeNano time by difference between start and now. + * + * It's ok to call stop on a timer that's not running. It has no effect on the timer. * * @return this object, for programming convenience */ + @Requires("startTimeNano != 0l") public synchronized SimpleTimer stop() { - running = false; - elapsed += currentTime() - startTime; + if ( running ) { + running = false; + elapsedTimeNano += currentTimeNano() - startTimeNano; + } return this; } /** - * Returns the total elapsed time of all start/stops of this timer. If the timer is currently + * Returns the total elapsedTimeNano time of all start/stops of this timer. If the timer is currently * running, includes the difference from currentTime() and the start as well * * @return this time, in seconds */ public synchronized double getElapsedTime() { - return (running ? (currentTime() - startTime + elapsed) : elapsed) / 1000.0; + return nanoToSecondsAsDouble(getElapsedTimeNano()); + } + + protected static double nanoToSecondsAsDouble(final long nano) { + return nano * NANO_TO_SECOND_DOUBLE; + } + + /** + * @see #getElapsedTime() but returns the result in nanoseconds + * + * @return the elapsed time in nanoseconds + */ + public synchronized long getElapsedTimeNano() { + return running ? (currentTimeNano() - startTimeNano + elapsedTimeNano) : elapsedTimeNano; } } diff --git a/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoScheduler.java b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoScheduler.java index 668c82524..24db0f7dc 100644 --- a/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoScheduler.java +++ b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoScheduler.java @@ -3,7 +3,9 @@ package org.broadinstitute.sting.utils.nanoScheduler; import com.google.java.contract.Ensures; import com.google.java.contract.Requires; import org.apache.log4j.Logger; -import org.broadinstitute.sting.utils.Utils; +import org.broadinstitute.sting.utils.AutoFormattingTime; +import org.broadinstitute.sting.utils.SimpleTimer; +import org.broadinstitute.sting.utils.collections.Pair; import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; import java.util.Iterator; @@ -45,42 +47,41 @@ import java.util.concurrent.*; public class NanoScheduler { private final static Logger logger = Logger.getLogger(NanoScheduler.class); private final static boolean ALLOW_SINGLE_THREAD_FASTPATH = true; + private final static boolean LOG_MAP_TIMES = false; + private final static boolean TIME_CALLS = true; final int bufferSize; - final int mapGroupSize; final int nThreads; - final ExecutorService executor; + final ExecutorService inputExecutor; + final ExecutorService mapExecutor; boolean shutdown = false; boolean debug = false; + private NanoSchedulerProgressFunction progressFunction = null; + + final SimpleTimer outsideSchedulerTimer = new SimpleTimer("outside"); + final SimpleTimer inputTimer = new SimpleTimer("input"); + final SimpleTimer mapTimer = new SimpleTimer("map"); + final SimpleTimer reduceTimer = new SimpleTimer("reduce"); + /** * Create a new nanoschedule with the desire characteristics requested by the argument * * @param bufferSize the number of input elements to read in each scheduling cycle. - * @param mapGroupSize How many inputs should be grouped together per map? If -1 we make a reasonable guess * @param nThreads the number of threads to use to get work done, in addition to the thread calling execute */ public NanoScheduler(final int bufferSize, - final int mapGroupSize, final int nThreads) { if ( bufferSize < 1 ) throw new IllegalArgumentException("bufferSize must be >= 1, got " + bufferSize); if ( nThreads < 1 ) throw new IllegalArgumentException("nThreads must be >= 1, got " + nThreads); - if ( mapGroupSize > bufferSize ) throw new IllegalArgumentException("mapGroupSize " + mapGroupSize + " must be <= bufferSize " + bufferSize); - if ( mapGroupSize == 0 || mapGroupSize < -1 ) throw new IllegalArgumentException("mapGroupSize cannot be <= 0" + mapGroupSize); - this.bufferSize = bufferSize; this.nThreads = nThreads; + this.mapExecutor = nThreads == 1 ? null : Executors.newFixedThreadPool(nThreads-1); + this.inputExecutor = Executors.newSingleThreadExecutor(); - if ( mapGroupSize == -1 ) { - this.mapGroupSize = (int)Math.ceil(this.bufferSize / (10.0*this.nThreads)); - logger.info(String.format("Dynamically setting grouping size to %d based on buffer size %d and n threads %d", - this.mapGroupSize, this.bufferSize, this.nThreads)); - } else { - this.mapGroupSize = mapGroupSize; - } - - this.executor = nThreads == 1 ? null : Executors.newFixedThreadPool(nThreads); + // start timing the time spent outside of the nanoScheduler + outsideSchedulerTimer.start(); } /** @@ -101,27 +102,35 @@ public class NanoScheduler { return bufferSize; } - /** - * The grouping size used by this NanoScheduler - * @return - */ - @Ensures("result > 0") - public int getMapGroupSize() { - return mapGroupSize; - } - /** * Tells this nanoScheduler to shutdown immediately, releasing all its resources. * * After this call, execute cannot be invoked without throwing an error */ public void shutdown() { - if ( executor != null ) { - final List remaining = executor.shutdownNow(); + outsideSchedulerTimer.stop(); + + if ( mapExecutor != null ) { + final List remaining = mapExecutor.shutdownNow(); if ( ! remaining.isEmpty() ) - throw new IllegalStateException("Remaining tasks found in the executor, unexpected behavior!"); + throw new IllegalStateException("Remaining tasks found in the mapExecutor, unexpected behavior!"); } shutdown = true; + + if (TIME_CALLS) { + printTimerInfo("Input time", inputTimer); + printTimerInfo("Map time", mapTimer); + printTimerInfo("Reduce time", reduceTimer); + printTimerInfo("Outside time", outsideSchedulerTimer); + } + } + + private void printTimerInfo(final String label, final SimpleTimer timer) { + final double total = inputTimer.getElapsedTime() + mapTimer.getElapsedTime() + + reduceTimer.getElapsedTime() + outsideSchedulerTimer.getElapsedTime(); + final double myTimeInSec = timer.getElapsedTime(); + final double myTimePercent = myTimeInSec / total * 100; + logger.info(String.format("%s: %s (%5.2f%%)", label, new AutoFormattingTime(myTimeInSec), myTimePercent)); } /** @@ -145,6 +154,17 @@ public class NanoScheduler { this.debug = debug; } + /** + * Set the progress callback function to progressFunction + * + * The progress callback is invoked after each buffer size elements have been processed by map/reduce + * + * @param progressFunction a progress function to call, or null if you don't want any progress callback + */ + public void setProgressFunction(final NanoSchedulerProgressFunction progressFunction) { + this.progressFunction = progressFunction; + } + /** * Execute a map/reduce job with this nanoScheduler * @@ -159,25 +179,31 @@ public class NanoScheduler { * It is safe to call this function repeatedly on a single nanoScheduler, at least until the * shutdown method is called. * - * @param inputReader - * @param map - * @param reduce - * @return + * @param inputReader an iterator providing us with the input data to nanoSchedule map/reduce over + * @param map the map function from input type -> map type, will be applied in parallel to each input + * @param reduce the reduce function from map type + reduce type -> reduce type to be applied in order to map results + * @return the last reduce value */ public ReduceType execute(final Iterator inputReader, - final MapFunction map, + final NanoSchedulerMapFunction map, final ReduceType initialValue, - final ReduceFunction reduce) { + final NanoSchedulerReduceFunction reduce) { if ( isShutdown() ) throw new IllegalStateException("execute called on already shutdown NanoScheduler"); if ( inputReader == null ) throw new IllegalArgumentException("inputReader cannot be null"); if ( map == null ) throw new IllegalArgumentException("map function cannot be null"); if ( reduce == null ) throw new IllegalArgumentException("reduce function cannot be null"); + outsideSchedulerTimer.stop(); + + ReduceType result; if ( ALLOW_SINGLE_THREAD_FASTPATH && getnThreads() == 1 ) { - return executeSingleThreaded(inputReader, map, initialValue, reduce); + result = executeSingleThreaded(inputReader, map, initialValue, reduce); } else { - return executeMultiThreaded(inputReader, map, initialValue, reduce); + result = executeMultiThreaded(inputReader, map, initialValue, reduce); } + + outsideSchedulerTimer.restart(); + return result; } /** @@ -185,15 +211,36 @@ public class NanoScheduler { * @return the reduce result of this map/reduce job */ private ReduceType executeSingleThreaded(final Iterator inputReader, - final MapFunction map, + final NanoSchedulerMapFunction map, final ReduceType initialValue, - final ReduceFunction reduce) { + final NanoSchedulerReduceFunction reduce) { ReduceType sum = initialValue; + int i = 0; + + // start timer to ensure that both hasNext and next are caught by the timer + if ( TIME_CALLS ) inputTimer.restart(); while ( inputReader.hasNext() ) { final InputType input = inputReader.next(); + if ( TIME_CALLS ) inputTimer.stop(); + + // map + if ( TIME_CALLS ) mapTimer.restart(); + final long preMapTime = LOG_MAP_TIMES ? 0 : mapTimer.currentTimeNano(); final MapType mapValue = map.apply(input); + if ( LOG_MAP_TIMES ) logger.info("MAP TIME " + (mapTimer.currentTimeNano() - preMapTime)); + if ( TIME_CALLS ) mapTimer.stop(); + + if ( i++ % bufferSize == 0 && progressFunction != null ) + progressFunction.progress(input); + + // reduce + if ( TIME_CALLS ) reduceTimer.restart(); sum = reduce.apply(mapValue, sum); + if ( TIME_CALLS ) reduceTimer.stop(); + + if ( TIME_CALLS ) inputTimer.restart(); } + return sum; } @@ -203,21 +250,36 @@ public class NanoScheduler { * @return the reduce result of this map/reduce job */ private ReduceType executeMultiThreaded(final Iterator inputReader, - final MapFunction map, + final NanoSchedulerMapFunction map, final ReduceType initialValue, - final ReduceFunction reduce) { + final NanoSchedulerReduceFunction reduce) { debugPrint("Executing nanoScheduler"); ReduceType sum = initialValue; - while ( inputReader.hasNext() ) { + boolean done = false; + + final BlockingQueue inputQueue = new LinkedBlockingDeque(bufferSize); + + inputExecutor.submit(new InputProducer(inputReader, inputQueue)); + + while ( ! done ) { try { - // read in our input values - final List inputs = readInputs(inputReader); + final Pair, Boolean> readResults = readInputs(inputQueue); + final List inputs = readResults.getFirst(); + done = readResults.getSecond(); - // send jobs for map - final Queue>> mapQueue = submitMapJobs(map, executor, inputs); + if ( ! inputs.isEmpty() ) { + // send jobs for map + final Queue> mapQueue = submitMapJobs(map, mapExecutor, inputs); - // send off the reduce job, and block until we get at least one reduce result - sum = reduceParallel(reduce, mapQueue, sum); + // send off the reduce job, and block until we get at least one reduce result + sum = reduceSerial(reduce, mapQueue, sum); + debugPrint(" Done with cycle of map/reduce"); + + if ( progressFunction != null ) progressFunction.progress(inputs.get(inputs.size()-1)); + } else { + // we must be done + if ( ! done ) throw new IllegalStateException("Inputs empty but not done"); + } } catch (InterruptedException ex) { throw new ReviewedStingException("got execution exception", ex); } catch (ExecutionException ex) { @@ -229,16 +291,19 @@ public class NanoScheduler { } @Requires({"reduce != null", "! mapQueue.isEmpty()"}) - private ReduceType reduceParallel(final ReduceFunction reduce, - final Queue>> mapQueue, - final ReduceType initSum) + private ReduceType reduceSerial(final NanoSchedulerReduceFunction reduce, + final Queue> mapQueue, + final ReduceType initSum) throws InterruptedException, ExecutionException { ReduceType sum = initSum; // while mapQueue has something in it to reduce - for ( final Future> future : mapQueue ) { - for ( final MapType value : future.get() ) // block until we get the values for this task - sum = reduce.apply(value, sum); + for ( final Future future : mapQueue ) { + final MapType value = future.get(); // block until we get the values for this task + + if ( TIME_CALLS ) reduceTimer.restart(); + sum = reduce.apply(value, sum); + if ( TIME_CALLS ) reduceTimer.stop(); } return sum; @@ -247,30 +312,81 @@ public class NanoScheduler { /** * Read up to inputBufferSize elements from inputReader * - * @return a queue of inputs read in, containing one or more values of InputType read in + * @return a queue of input read in, containing one or more values of InputType read in */ - @Requires("inputReader.hasNext()") - @Ensures("!result.isEmpty()") - private List readInputs(final Iterator inputReader) { + @Requires("inputReader != null") + @Ensures("result != null") + private Pair, Boolean> readInputs(final BlockingQueue inputReader) throws InterruptedException { int n = 0; final List inputs = new LinkedList(); - while ( inputReader.hasNext() && n < getBufferSize() ) { - final InputType input = inputReader.next(); - inputs.add(input); - n++; + boolean done = false; + + while ( ! done && n < getBufferSize() ) { + final InputDatum input = inputReader.take(); + done = input.isLast(); + if ( ! done ) { + inputs.add(input.datum); + n++; + } + } + + return new Pair, Boolean>(inputs, done); + } + + private class InputProducer implements Runnable { + final Iterator inputReader; + final BlockingQueue outputQueue; + + public InputProducer(final Iterator inputReader, final BlockingQueue outputQueue) { + this.inputReader = inputReader; + this.outputQueue = outputQueue; + } + + public void run() { + try { + while ( inputReader.hasNext() ) { + if ( TIME_CALLS ) inputTimer.restart(); + final InputType input = inputReader.next(); + if ( TIME_CALLS ) inputTimer.stop(); + outputQueue.put(new InputDatum(input)); + } + + // add the EOF object so we know we are done + outputQueue.put(new InputDatum()); + } catch (InterruptedException ex) { + throw new ReviewedStingException("got execution exception", ex); + } + } + } + + private class InputDatum { + final boolean isLast; + final InputType datum; + + private InputDatum(final InputType datum) { + isLast = false; + this.datum = datum; + } + + private InputDatum() { + isLast = true; + this.datum = null; + } + + public boolean isLast() { + return isLast; } - return inputs; } @Requires({"map != null", "! inputs.isEmpty()"}) - private Queue>> submitMapJobs(final MapFunction map, - final ExecutorService executor, - final List inputs) { - final Queue>> mapQueue = new LinkedList>>(); + private Queue> submitMapJobs(final NanoSchedulerMapFunction map, + final ExecutorService executor, + final List inputs) { + final Queue> mapQueue = new LinkedList>(); - for ( final List subinputs : Utils.groupList(inputs, getMapGroupSize()) ) { - final CallableMap doMap = new CallableMap(map, subinputs); - final Future> future = executor.submit(doMap); + for ( final InputType input : inputs ) { + final CallableMap doMap = new CallableMap(map, input); + final Future future = executor.submit(doMap); mapQueue.add(future); } @@ -280,23 +396,22 @@ public class NanoScheduler { /** * A simple callable version of the map function for use with the executor pool */ - private class CallableMap implements Callable> { - final List inputs; - final MapFunction map; + private class CallableMap implements Callable { + final InputType input; + final NanoSchedulerMapFunction map; - @Requires({"map != null", "inputs.size() <= getMapGroupSize()"}) - private CallableMap(final MapFunction map, final List inputs) { - this.inputs = inputs; + @Requires({"map != null"}) + private CallableMap(final NanoSchedulerMapFunction map, final InputType inputs) { + this.input = inputs; this.map = map; } - @Ensures("result.size() == inputs.size()") - @Override public List call() throws Exception { - final List outputs = new LinkedList(); - for ( final InputType input : inputs ) - outputs.add(map.apply(input)); - debugPrint(" Processed %d elements with map", outputs.size()); - return outputs; + @Override public MapType call() throws Exception { + if ( TIME_CALLS ) mapTimer.restart(); + if ( debug ) debugPrint("\t\tmap " + input); + final MapType result = map.apply(input); + if ( TIME_CALLS ) mapTimer.stop(); + return result; } } } diff --git a/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/MapFunction.java b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerMapFunction.java similarity index 84% rename from public/java/src/org/broadinstitute/sting/utils/nanoScheduler/MapFunction.java rename to public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerMapFunction.java index 440c263b7..ddf4421d2 100644 --- a/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/MapFunction.java +++ b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerMapFunction.java @@ -9,7 +9,7 @@ package org.broadinstitute.sting.utils.nanoScheduler; * Date: 8/24/12 * Time: 9:49 AM */ -public interface MapFunction { +public interface NanoSchedulerMapFunction { /** * Return function on input, returning a value of ResultType * @param input diff --git a/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerProgressFunction.java b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerProgressFunction.java new file mode 100644 index 000000000..8631196a3 --- /dev/null +++ b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerProgressFunction.java @@ -0,0 +1,12 @@ +package org.broadinstitute.sting.utils.nanoScheduler; + +/** + * Created with IntelliJ IDEA. + * User: depristo + * Date: 9/4/12 + * Time: 2:10 PM + * To change this template use File | Settings | File Templates. + */ +public interface NanoSchedulerProgressFunction { + public void progress(final InputType lastMapInput); +} diff --git a/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/ReduceFunction.java b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerReduceFunction.java similarity index 87% rename from public/java/src/org/broadinstitute/sting/utils/nanoScheduler/ReduceFunction.java rename to public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerReduceFunction.java index 8f1b0eddd..7e58eeaf9 100644 --- a/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/ReduceFunction.java +++ b/public/java/src/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerReduceFunction.java @@ -7,7 +7,7 @@ package org.broadinstitute.sting.utils.nanoScheduler; * Date: 8/24/12 * Time: 9:49 AM */ -public interface ReduceFunction { +public interface NanoSchedulerReduceFunction { /** * Combine one with sum into a new ReduceType * @param one the result of a map call on an input element diff --git a/public/java/test/org/broadinstitute/sting/WalkerTest.java b/public/java/test/org/broadinstitute/sting/WalkerTest.java index 7e38c00f3..bcfd00aed 100755 --- a/public/java/test/org/broadinstitute/sting/WalkerTest.java +++ b/public/java/test/org/broadinstitute/sting/WalkerTest.java @@ -40,13 +40,13 @@ import org.broadinstitute.sting.utils.collections.Pair; import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; import org.broadinstitute.sting.utils.exceptions.StingException; import org.broadinstitute.sting.utils.variantcontext.VariantContextTestProvider; - -import java.io.*; - import org.testng.Assert; import org.testng.annotations.AfterSuite; import org.testng.annotations.BeforeMethod; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.PrintStream; import java.text.SimpleDateFormat; import java.util.*; @@ -251,20 +251,43 @@ public class WalkerTest extends BaseTest { return false; } - protected Pair, List> executeTestParallel(final String name, WalkerTestSpec spec) { - return executeTest(name, spec, Arrays.asList(1, 4)); + public enum ParallelTestType { + TREE_REDUCIBLE, + NANO_SCHEDULED, + BOTH } - protected Pair, List> executeTest(final String name, WalkerTestSpec spec, List parallelThreads) { + protected Pair, List> executeTestParallel(final String name, WalkerTestSpec spec, ParallelTestType testType) { + final List ntThreads = testType == ParallelTestType.TREE_REDUCIBLE || testType == ParallelTestType.BOTH ? Arrays.asList(1, 4) : Collections.emptyList(); + final List cntThreads = testType == ParallelTestType.NANO_SCHEDULED || testType == ParallelTestType.BOTH ? Arrays.asList(1, 4) : Collections.emptyList(); + + return executeTest(name, spec, ntThreads, cntThreads); + } + + protected Pair, List> executeTestParallel(final String name, WalkerTestSpec spec) { + return executeTestParallel(name, spec, ParallelTestType.TREE_REDUCIBLE); + } + + protected Pair, List> executeTest(final String name, WalkerTestSpec spec, List ntThreads, List cpuThreads) { String originalArgs = spec.args; Pair, List> results = null; - for ( int nt : parallelThreads ) { + boolean ran1 = false; + for ( int nt : ntThreads ) { String extra = nt == 1 ? "" : (" -nt " + nt); + ran1 = ran1 || nt == 1; spec.args = originalArgs + extra; results = executeTest(name + "-nt-" + nt, spec); } + for ( int cnt : cpuThreads ) { + if ( cnt != 1 ) { + String extra = " -cnt " + cnt; + spec.args = originalArgs + extra; + results = executeTest(name + "-cnt-" + cnt, spec); + } + } + return results; } diff --git a/public/java/test/org/broadinstitute/sting/utils/SimpleTimerUnitTest.java b/public/java/test/org/broadinstitute/sting/utils/SimpleTimerUnitTest.java index 7a2696b7b..7285c00ac 100755 --- a/public/java/test/org/broadinstitute/sting/utils/SimpleTimerUnitTest.java +++ b/public/java/test/org/broadinstitute/sting/utils/SimpleTimerUnitTest.java @@ -1,12 +1,12 @@ package org.broadinstitute.sting.utils; import org.broadinstitute.sting.BaseTest; -import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; import org.testng.Assert; -import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import java.io.File; +import java.util.Arrays; +import java.util.List; +import java.util.concurrent.TimeUnit; public class SimpleTimerUnitTest extends BaseTest { private final static String NAME = "unit.test.timer"; @@ -17,33 +17,88 @@ public class SimpleTimerUnitTest extends BaseTest { Assert.assertEquals(t.getName(), NAME, "Name is not the provided one"); Assert.assertFalse(t.isRunning(), "Initial state of the timer is running"); Assert.assertEquals(t.getElapsedTime(), 0.0, "New timer elapsed time should be 0"); + Assert.assertEquals(t.getElapsedTimeNano(), 0l, "New timer elapsed time nano should be 0"); t.start(); Assert.assertTrue(t.isRunning(), "Started timer isn't running"); Assert.assertTrue(t.getElapsedTime() >= 0.0, "Elapsed time should be >= 0"); + Assert.assertTrue(t.getElapsedTimeNano() >= 0.0, "Elapsed time nano should be >= 0"); + long n1 = t.getElapsedTimeNano(); double t1 = t.getElapsedTime(); idleLoop(); // idle loop to wait a tiny bit of time + long n2 = t.getElapsedTimeNano(); double t2 = t.getElapsedTime(); Assert.assertTrue(t2 >= t1, "T2 >= T1 for a running time"); + Assert.assertTrue(n2 >= n1, "T2 >= T1 nano for a running time"); t.stop(); Assert.assertFalse(t.isRunning(), "Stopped timer still running"); + long n3 = t.getElapsedTimeNano(); double t3 = t.getElapsedTime(); idleLoop(); // idle loop to wait a tiny bit of time double t4 = t.getElapsedTime(); + long n4 = t.getElapsedTimeNano(); Assert.assertTrue(t4 == t3, "Elapsed times for two calls of stop timer not the same"); + Assert.assertTrue(n4 == n3, "Elapsed times for two calls of stop timer not the same"); t.restart(); idleLoop(); // idle loop to wait a tiny bit of time double t5 = t.getElapsedTime(); + long n5 = t.getElapsedTimeNano(); Assert.assertTrue(t.isRunning(), "Restarted timer should be running"); idleLoop(); // idle loop to wait a tiny bit of time double t6 = t.getElapsedTime(); + long n6 = t.getElapsedTimeNano(); Assert.assertTrue(t5 >= t4, "Restarted timer elapsed time should be after elapsed time preceding the restart"); Assert.assertTrue(t6 >= t5, "Second elapsed time not after the first in restarted timer"); + Assert.assertTrue(n5 >= n4, "Restarted timer elapsed time nano should be after elapsed time preceding the restart"); + Assert.assertTrue(n6 >= n5, "Second elapsed time nano not after the first in restarted timer"); + + final List secondTimes = Arrays.asList(t1, t2, t3, t4, t5, t6); + final List nanoTimes = Arrays.asList(n1, n2, n3, n4, n5, n6); + for ( int i = 0; i < nanoTimes.size(); i++ ) + Assert.assertEquals( + SimpleTimer.nanoToSecondsAsDouble(nanoTimes.get(i)), + secondTimes.get(i), 1e-1, "Nanosecond and second timer disagree"); } - private final static void idleLoop() { + @Test + public void testNanoResolution() { + SimpleTimer t = new SimpleTimer(NAME); + + // test the nanosecond resolution + long n7 = t.currentTimeNano(); + int sum = 0; + for ( int i = 0; i < 100; i++) sum += i; + long n8 = t.currentTimeNano(); + final long delta = n8 - n7; + final long oneMilliInNano = TimeUnit.MILLISECONDS.toNanos(1); + logger.warn("nanoTime before nano operation " + n7); + logger.warn("nanoTime after nano operation of summing 100 ints " + n8 + ", sum = " + sum + " time delta " + delta + " vs. 1 millsecond in nano " + oneMilliInNano); + Assert.assertTrue(n8 > n7, "SimpleTimer doesn't appear to have nanoSecond resolution: n8 " + n8 + " <= n7 " + n7); + Assert.assertTrue(delta < oneMilliInNano, + "SimpleTimer doesn't appear to have nanoSecond resolution: time delta is " + delta + " vs 1 millisecond in nano " + oneMilliInNano); + } + + @Test + public void testMeaningfulTimes() { + SimpleTimer t = new SimpleTimer(NAME); + + t.start(); + for ( int i = 0; i < 100; i++ ) ; + long nano = t.getElapsedTimeNano(); + double secs = t.getElapsedTime(); + + Assert.assertTrue(secs > 0, "Seconds timer doesn't appear to count properly: elapsed time is " + secs); + Assert.assertTrue(secs < 0.01, "Fast operation said to take longer than 10 milliseconds: elapsed time in seconds " + secs); + + Assert.assertTrue(nano > 0, "Nanosecond timer doesn't appear to count properly: elapsed time is " + nano); + final long maxTimeInMicro = 100; + final long maxTimeInNano = TimeUnit.MICROSECONDS.toNanos(100); + Assert.assertTrue(nano < maxTimeInNano, "Fast operation said to take longer than " + maxTimeInMicro + " microseconds: elapsed time in nano " + nano + " micro " + TimeUnit.NANOSECONDS.toMicros(nano)); + } + + private static void idleLoop() { for ( int i = 0; i < 100000; i++ ) ; // idle loop to wait a tiny bit of time } } \ No newline at end of file diff --git a/public/java/test/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerUnitTest.java b/public/java/test/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerUnitTest.java index 89506dcb1..ddfc3cecd 100644 --- a/public/java/test/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerUnitTest.java +++ b/public/java/test/org/broadinstitute/sting/utils/nanoScheduler/NanoSchedulerUnitTest.java @@ -5,7 +5,10 @@ import org.testng.Assert; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; -import java.util.*; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; /** * UnitTests for the NanoScheduler @@ -18,11 +21,11 @@ import java.util.*; public class NanoSchedulerUnitTest extends BaseTest { public static final int NANO_SCHEDULE_MAX_RUNTIME = 60000; - private static class Map2x implements MapFunction { + private static class Map2x implements NanoSchedulerMapFunction { @Override public Integer apply(Integer input) { return input * 2; } } - private static class ReduceSum implements ReduceFunction { + private static class ReduceSum implements NanoSchedulerReduceFunction { int prevOne = Integer.MIN_VALUE; @Override public Integer apply(Integer one, Integer sum) { @@ -31,6 +34,16 @@ public class NanoSchedulerUnitTest extends BaseTest { } } + private static class ProgressCallback implements NanoSchedulerProgressFunction { + int callBacks = 0; + + @Override + public void progress(Integer lastMapInput) { + callBacks++; + } + } + + private static int sum2x(final int start, final int end) { int sum = 0; for ( int i = start; i < end; i++ ) @@ -39,18 +52,17 @@ public class NanoSchedulerUnitTest extends BaseTest { } private static class NanoSchedulerBasicTest extends TestDataProvider { - final int bufferSize, mapGroupSize, nThreads, start, end, expectedResult; + final int bufferSize, nThreads, start, end, expectedResult; - public NanoSchedulerBasicTest(final int bufferSize, final int mapGroupSize, final int nThreads, final int start, final int end) { + public NanoSchedulerBasicTest(final int bufferSize, final int nThreads, final int start, final int end) { super(NanoSchedulerBasicTest.class); this.bufferSize = bufferSize; - this.mapGroupSize = mapGroupSize; this.nThreads = nThreads; this.start = start; this.end = end; this.expectedResult = sum2x(start, end); - setName(String.format("%s nt=%d buf=%d mapGroupSize=%d start=%d end=%d sum=%d", - getClass().getSimpleName(), nThreads, bufferSize, mapGroupSize, start, end, expectedResult)); + setName(String.format("%s nt=%d buf=%d start=%d end=%d sum=%d", + getClass().getSimpleName(), nThreads, bufferSize, start, end, expectedResult)); } public Iterator makeReader() { @@ -60,6 +72,11 @@ public class NanoSchedulerUnitTest extends BaseTest { return ints.iterator(); } + public int nExpectedCallbacks() { + int nElements = Math.max(end - start, 0); + return nElements / bufferSize; + } + public Map2x makeMap() { return new Map2x(); } public Integer initReduce() { return 0; } public ReduceSum makeReduce() { return new ReduceSum(); } @@ -69,14 +86,10 @@ public class NanoSchedulerUnitTest extends BaseTest { @DataProvider(name = "NanoSchedulerBasicTest") public Object[][] createNanoSchedulerBasicTest() { for ( final int bufferSize : Arrays.asList(1, 10, 1000, 1000000) ) { - for ( final int mapGroupSize : Arrays.asList(-1, 1, 10, 100, 1000) ) { - if ( mapGroupSize <= bufferSize ) { - for ( final int nt : Arrays.asList(1, 2, 4) ) { - for ( final int start : Arrays.asList(0) ) { - for ( final int end : Arrays.asList(1, 2, 11, 10000, 100000) ) { - exampleTest = new NanoSchedulerBasicTest(bufferSize, mapGroupSize, nt, start, end); - } - } + for ( final int nt : Arrays.asList(1, 2, 4) ) { + for ( final int start : Arrays.asList(0) ) { + for ( final int end : Arrays.asList(0, 1, 2, 11, 10000, 100000) ) { + exampleTest = new NanoSchedulerBasicTest(bufferSize, nt, start, end); } } } @@ -101,25 +114,29 @@ public class NanoSchedulerUnitTest extends BaseTest { private void testNanoScheduler(final NanoSchedulerBasicTest test) throws InterruptedException { final NanoScheduler nanoScheduler = - new NanoScheduler(test.bufferSize, test.mapGroupSize, test.nThreads); + new NanoScheduler(test.bufferSize, test.nThreads); + + final ProgressCallback callback = new ProgressCallback(); + nanoScheduler.setProgressFunction(callback); Assert.assertEquals(nanoScheduler.getBufferSize(), test.bufferSize, "bufferSize argument"); - Assert.assertTrue(nanoScheduler.getMapGroupSize() >= test.mapGroupSize, "mapGroupSize argument"); Assert.assertEquals(nanoScheduler.getnThreads(), test.nThreads, "nThreads argument"); final Integer sum = nanoScheduler.execute(test.makeReader(), test.makeMap(), test.initReduce(), test.makeReduce()); Assert.assertNotNull(sum); Assert.assertEquals((int)sum, test.expectedResult, "NanoScheduler sum not the same as calculated directly"); + + Assert.assertTrue(callback.callBacks >= test.nExpectedCallbacks(), "Not enough callbacks detected. Expected at least " + test.nExpectedCallbacks() + " but saw only " + callback.callBacks); nanoScheduler.shutdown(); } @Test(enabled = true, dataProvider = "NanoSchedulerBasicTest", dependsOnMethods = "testMultiThreadedNanoScheduler", timeOut = NANO_SCHEDULE_MAX_RUNTIME) public void testNanoSchedulerInLoop(final NanoSchedulerBasicTest test) throws InterruptedException { - if ( test.bufferSize > 1 && (test.mapGroupSize > 1 || test.mapGroupSize == -1)) { + if ( test.bufferSize > 1) { logger.warn("Running " + test); final NanoScheduler nanoScheduler = - new NanoScheduler(test.bufferSize, test.mapGroupSize, test.nThreads); + new NanoScheduler(test.bufferSize, test.nThreads); // test reusing the scheduler for ( int i = 0; i < 10; i++ ) { @@ -134,7 +151,7 @@ public class NanoSchedulerUnitTest extends BaseTest { @Test(timeOut = NANO_SCHEDULE_MAX_RUNTIME) public void testShutdown() throws InterruptedException { - final NanoScheduler nanoScheduler = new NanoScheduler(1, 1, 2); + final NanoScheduler nanoScheduler = new NanoScheduler(1, 2); Assert.assertFalse(nanoScheduler.isShutdown(), "scheduler should be alive"); nanoScheduler.shutdown(); Assert.assertTrue(nanoScheduler.isShutdown(), "scheduler should be dead"); @@ -142,15 +159,16 @@ public class NanoSchedulerUnitTest extends BaseTest { @Test(expectedExceptions = IllegalStateException.class, timeOut = NANO_SCHEDULE_MAX_RUNTIME) public void testShutdownExecuteFailure() throws InterruptedException { - final NanoScheduler nanoScheduler = new NanoScheduler(1, 1, 2); + final NanoScheduler nanoScheduler = new NanoScheduler(1, 2); nanoScheduler.shutdown(); nanoScheduler.execute(exampleTest.makeReader(), exampleTest.makeMap(), exampleTest.initReduce(), exampleTest.makeReduce()); } public static void main(String [ ] args) { - final NanoSchedulerBasicTest test = new NanoSchedulerBasicTest(1000, 100, Integer.valueOf(args[0]), 0, Integer.valueOf(args[1])); + final NanoSchedulerBasicTest test = new NanoSchedulerBasicTest(1000, Integer.valueOf(args[0]), 0, Integer.valueOf(args[1])); final NanoScheduler nanoScheduler = - new NanoScheduler(test.bufferSize, test.mapGroupSize, test.nThreads); + new NanoScheduler(test.bufferSize, test.nThreads); + nanoScheduler.setDebug(true); final Integer sum = nanoScheduler.execute(test.makeReader(), test.makeMap(), test.initReduce(), test.makeReduce()); System.out.printf("Sum = %d, expected =%d%n", sum, test.expectedResult);