diff --git a/java/src/org/broadinstitute/sting/utils/interval/IntervalUtils.java b/java/src/org/broadinstitute/sting/utils/interval/IntervalUtils.java index cd3d603a1..1d148ac95 100644 --- a/java/src/org/broadinstitute/sting/utils/interval/IntervalUtils.java +++ b/java/src/org/broadinstitute/sting/utils/interval/IntervalUtils.java @@ -124,7 +124,7 @@ public class IntervalUtils { } //if we have an empty list, throw an exception. If they specified intersection and there are no items, this is bad. - if (retList == null || retList.size() == 0) + if (retList.size() == 0) throw new UserException.BadInput("The INTERSECTION of your -BTI and -L options produced no intervals."); // we don't need to add the rest of remaining locations, since we know they don't overlap. return what we have @@ -227,6 +227,31 @@ public class IntervalUtils { return contigs; } + /** + * Counts the number of interval files an interval list can be split into using scatterIntervalArguments. + * @param reference The reference for the intervals. + * @param intervals The interval as strings or file paths. + * @param splitByContig If true then one contig will not be written to multiple files. + * @return The maximum number of parts the intervals can be split into. + */ + public static int countIntervalArguments(File reference, List intervals, boolean splitByContig) { + ReferenceDataSource referenceSource = new ReferenceDataSource(reference); + List locs = parseIntervalArguments(referenceSource, intervals); + int maxFiles = 0; + if (splitByContig) { + String contig = null; + for (GenomeLoc loc: locs) { + if (contig == null || !contig.equals(loc.getContig())) { + maxFiles++; + contig = loc.getContig(); + } + } + } else { + maxFiles = locs.size(); + } + return maxFiles; + } + /** * Splits an interval list into multiple files. * @param reference The reference for the intervals. diff --git a/java/test/org/broadinstitute/sting/utils/interval/IntervalUtilsUnitTest.java b/java/test/org/broadinstitute/sting/utils/interval/IntervalUtilsUnitTest.java index d366085d9..7711759a5 100644 --- a/java/test/org/broadinstitute/sting/utils/interval/IntervalUtilsUnitTest.java +++ b/java/test/org/broadinstitute/sting/utils/interval/IntervalUtilsUnitTest.java @@ -15,6 +15,7 @@ import java.io.File; import java.io.FileNotFoundException; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.List; /** @@ -117,6 +118,16 @@ public class IntervalUtilsUnitTest extends BaseTest { Assert.assertEquals(IntervalUtils.distinctContigs(reference, Arrays.asList("chr2:1-1", "chr1:1-1", "chr3:2-2")), Arrays.asList("chr1","chr2","chr3")); } + @Test + public void testCountIntervals() { + Assert.assertEquals(IntervalUtils.countIntervalArguments(reference, Collections.emptyList(), false), 45); + Assert.assertEquals(IntervalUtils.countIntervalArguments(reference, Collections.emptyList(), true), 45); + Assert.assertEquals(IntervalUtils.countIntervalArguments(reference, Arrays.asList("chr1", "chr2", "chr3"), false), 3); + Assert.assertEquals(IntervalUtils.countIntervalArguments(reference, Arrays.asList("chr1", "chr2", "chr3"), true), 3); + Assert.assertEquals(IntervalUtils.countIntervalArguments(reference, Arrays.asList("chr1:1-2", "chr1:4-5", "chr2:1-1", "chr3:2-2"), false), 4); + Assert.assertEquals(IntervalUtils.countIntervalArguments(reference, Arrays.asList("chr1:1-2", "chr1:4-5", "chr2:1-1", "chr3:2-2"), true), 3); + } + @Test public void testBasicScatter() { GenomeLoc chr1 = genomeLocParser.parseGenomeInterval("chr1"); diff --git a/scala/qscript/examples/ExampleCountLoci.scala b/scala/qscript/examples/ExampleCountLoci.scala new file mode 100644 index 000000000..d3e1add25 --- /dev/null +++ b/scala/qscript/examples/ExampleCountLoci.scala @@ -0,0 +1,33 @@ +import org.broadinstitute.sting.queue.QScript +import org.broadinstitute.sting.queue.extensions.gatk._ + +/** + * An introductory pipeline with integration tests testing the MD5 of the @Output. + */ +class ExampleCountLoci extends QScript { + @Input(doc="The path to the GenomeAnalysisTK.jar file.", shortName="gatk") + var gatkJar: File = null + + @Input(doc="The reference file for the bam files.", shortName="R") + var referenceFile: File = null + + @Input(doc="One or more bam files.", shortName="I") + var bamFiles: List[File] = Nil + + @Input(doc="Intervals to traverse.", shortName="L", required=false) + var intervals: List[String] = Nil + + @Output + var out: File = _ + + def script = { + val countLoci = new CountLoci + countLoci.jarFile = gatkJar + countLoci.reference_sequence = referenceFile + countLoci.input_file = bamFiles + countLoci.intervalsString = intervals + countLoci.out = out + countLoci.memoryLimit = Some(1) + add(countLoci) + } +} diff --git a/scala/qscript/playground/fullCallingPipeline.q b/scala/qscript/playground/fullCallingPipeline.q index 6799a47a9..805b79632 100755 --- a/scala/qscript/playground/fullCallingPipeline.q +++ b/scala/qscript/playground/fullCallingPipeline.q @@ -1,19 +1,14 @@ -import java.io.PrintWriter import org.broadinstitute.sting.commandline.ArgumentSource import org.broadinstitute.sting.datasources.pipeline.Pipeline -import org.broadinstitute.sting.gatk.DownsampleType import org.broadinstitute.sting.queue.extensions.gatk._ import org.broadinstitute.sting.queue.extensions.picard.PicardBamJarFunction import org.broadinstitute.sting.queue.extensions.samtools._ +import org.broadinstitute.sting.queue.function.ListWriterFunction import org.broadinstitute.sting.queue.function.scattergather.{GatherFunction, CloneFunction, ScatterFunction} -import org.broadinstitute.sting.queue.util.IOUtils import org.broadinstitute.sting.queue.QScript import collection.JavaConversions._ -import org.broadinstitute.sting.utils.exceptions.UserException -import org.broadinstitute.sting.utils.interval.IntervalUtils import org.broadinstitute.sting.utils.yaml.YamlUtils -import org.broadinstitute.sting.utils.report.VE2ReportFactory.VE2TemplateType class fullCallingPipeline extends QScript { qscript => @@ -105,9 +100,6 @@ class fullCallingPipeline extends QScript { //val seq = qscript.machine //val expKind = qscript.protocol - // get max num contigs for indel cleaning parallelism, plus 1 for -L unmapped - val numContigs = IntervalUtils.distinctContigs(qscript.pipeline.getProject.getReferenceFile).size + 1 - for ( sample <- recalibratedSamples ) { val sampleId = sample.getId // put unclean bams in unclean genotypers in advance, create the extension files @@ -136,7 +128,7 @@ class fullCallingPipeline extends QScript { realigner.targetIntervals = targetCreator.out realigner.intervals = Nil realigner.intervalsString = Nil - realigner.scatterCount = num_cleaner_scatter_jobs min numContigs + realigner.scatterCount = num_cleaner_scatter_jobs realigner.rodBind :+= RodBind("dbsnp", dbsnpType, qscript.pipeline.getProject.getDbsnpFile) realigner.rodBind :+= RodBind("indels", "VCF", swapExt(realigner.reference_sequence.getParentFile, realigner.reference_sequence, "fasta", "1kg_pilot_indels.vcf")) @@ -384,21 +376,9 @@ class fullCallingPipeline extends QScript { // 5. Make the bam list val listOfBams = new File(base +".BamFiles.list") - class BamListWriter extends InProcessFunction { - @Input(doc="bamFiles") var bamFiles: List[File] = Nil - @Output(doc="bamList") var bamList: File = _ - - def run { - val writer = new PrintWriter(bamList) - for (bamFile <- bamFiles) - writer.println(bamFile.toString) - writer.close() - } - } - - val writeBamList = new BamListWriter - writeBamList.bamFiles = bamFiles - writeBamList.bamList = listOfBams + val writeBamList = new ListWriterFunction + writeBamList.inputFiles = bamFiles + writeBamList.listFile = listOfBams writeBamList.analysisName = base + "_BamList" writeBamList.jobOutputFile = new File(".queue/logs/SNPCalling/bamlist.out") diff --git a/scala/src/org/broadinstitute/sting/queue/extensions/gatk/IntervalScatterFunction.scala b/scala/src/org/broadinstitute/sting/queue/extensions/gatk/IntervalScatterFunction.scala index 25d3bb699..5e3384020 100644 --- a/scala/src/org/broadinstitute/sting/queue/extensions/gatk/IntervalScatterFunction.scala +++ b/scala/src/org/broadinstitute/sting/queue/extensions/gatk/IntervalScatterFunction.scala @@ -1,12 +1,12 @@ package org.broadinstitute.sting.queue.extensions.gatk -import org.broadinstitute.sting.commandline.ArgumentSource import org.broadinstitute.sting.utils.interval.IntervalUtils import java.io.File import collection.JavaConversions._ import org.broadinstitute.sting.queue.util.IOUtils import org.broadinstitute.sting.queue.function.scattergather.{CloneFunction, ScatterGatherableFunction, ScatterFunction} import org.broadinstitute.sting.queue.function.{QFunction, InProcessFunction} +import org.broadinstitute.sting.commandline.{Output, ArgumentSource} /** * An interval scatter function. @@ -15,7 +15,7 @@ class IntervalScatterFunction extends ScatterFunction with InProcessFunction { var splitByContig = false /** The total number of clone jobs that will be created. */ - private var scatterCount: Int = _ + var scatterCount: Int = _ /** The reference sequence for the GATK function. */ private var referenceSequence: File = _ @@ -32,6 +32,9 @@ class IntervalScatterFunction extends ScatterFunction with InProcessFunction { /** Whether the laster scatter job should also include any unmapped reads. */ private var includeUnmapped: Boolean = _ + @Output(doc="Scattered intervals") + var scatterParts: List[File] = Nil + /** * Checks if the function is scatter gatherable. * @param originalFunction Function to check. @@ -54,7 +57,6 @@ class IntervalScatterFunction extends ScatterFunction with InProcessFunction { val gatk = originalFunction.asInstanceOf[CommandLineGATK] this.intervalsField = QFunction.findField(originalFunction.getClass, "intervals") this.intervalsStringField = QFunction.findField(originalFunction.getClass, "intervalsString") - this.scatterCount = originalFunction.scatterCount this.referenceSequence = gatk.reference_sequence if (gatk.intervals.isEmpty && gatk.intervalsString.isEmpty) { this.intervals ++= IntervalUtils.distinctContigs(this.referenceSequence).toList @@ -64,11 +66,14 @@ class IntervalScatterFunction extends ScatterFunction with InProcessFunction { this.intervals ++= gatk.intervalsString.filterNot(interval => IntervalUtils.isUnmapped(interval)) this.includeUnmapped = gatk.intervalsString.exists(interval => IntervalUtils.isUnmapped(interval)) } + + val maxScatterCount = IntervalUtils.countIntervalArguments(this.referenceSequence, this.intervals, this.splitByContig) + this.scatterCount = maxScatterCount min originalFunction.scatterCount } def initCloneInputs(cloneFunction: CloneFunction, index: Int) = { cloneFunction.setFieldValue(this.intervalsField, List(new File("scatter.intervals"))) - if (index == scatterCount && includeUnmapped) + if (index == this.scatterCount && this.includeUnmapped) cloneFunction.setFieldValue(this.intervalsStringField, List("unmapped")) else cloneFunction.setFieldValue(this.intervalsStringField, List.empty[String]) diff --git a/scala/src/org/broadinstitute/sting/queue/function/ListWriterFunction.scala b/scala/src/org/broadinstitute/sting/queue/function/ListWriterFunction.scala new file mode 100644 index 000000000..f60302ef4 --- /dev/null +++ b/scala/src/org/broadinstitute/sting/queue/function/ListWriterFunction.scala @@ -0,0 +1,33 @@ +package org.broadinstitute.sting.queue.function + +import org.broadinstitute.sting.commandline.{Input, Output} +import java.io.{PrintWriter, File} +import org.apache.commons.io.IOUtils + +/** + * Writes a list of inputs to an output file. + * Custom formats can override addFile. + */ +class ListWriterFunction extends InProcessFunction { + @Input(doc="input files") var inputFiles: List[File] = Nil + @Output(doc="output file") var listFile: File = _ + + def run { + val writer = new PrintWriter(listFile) + try { + for (inputFile <- inputFiles) + addFile(writer, inputFile) + } finally { + IOUtils.closeQuietly(writer) + } + } + + /** + * Adds the inputFile to the output list. + * @param writer Output file. + * @param inputFile File to add to the output file. + */ + def addFile(writer: PrintWriter, inputFile: File) { + writer.println(inputFile.toString) + } +} diff --git a/scala/src/org/broadinstitute/sting/queue/function/scattergather/ScatterFunction.scala b/scala/src/org/broadinstitute/sting/queue/function/scattergather/ScatterFunction.scala index 86e9ab921..790c575bf 100644 --- a/scala/src/org/broadinstitute/sting/queue/function/scattergather/ScatterFunction.scala +++ b/scala/src/org/broadinstitute/sting/queue/function/scattergather/ScatterFunction.scala @@ -1,7 +1,7 @@ package org.broadinstitute.sting.queue.function.scattergather import java.io.File -import org.broadinstitute.sting.commandline.{Input, Output} +import org.broadinstitute.sting.commandline.Input import org.broadinstitute.sting.queue.function.QFunction /** @@ -11,9 +11,6 @@ trait ScatterFunction extends QFunction { @Input(doc="Original inputs to scatter") var originalInputs: Set[File] = _ - @Output(doc="Scattered parts of the original inputs, one set per temp directory") - var scatterParts: List[File] = Nil - /** * Returns true if the scatter function can scatter this original function. * @param originalFunction The original function to check. @@ -27,6 +24,11 @@ trait ScatterFunction extends QFunction { */ def setScatterGatherable(originalFunction: ScatterGatherableFunction) + /** + * After a call to setScatterGatherable(), returns the number of clones that should be created. + */ + def scatterCount: Int + /** * Initializes the input fields for the clone function. * The input values should be set to their defaults diff --git a/scala/src/org/broadinstitute/sting/queue/function/scattergather/ScatterGatherableFunction.scala b/scala/src/org/broadinstitute/sting/queue/function/scattergather/ScatterGatherableFunction.scala index 4609f6f18..a4b31b519 100644 --- a/scala/src/org/broadinstitute/sting/queue/function/scattergather/ScatterGatherableFunction.scala +++ b/scala/src/org/broadinstitute/sting/queue/function/scattergather/ScatterGatherableFunction.scala @@ -11,7 +11,7 @@ import org.broadinstitute.sting.queue.QException */ trait ScatterGatherableFunction extends CommandLineFunction { - /** Number of parts to scatter the function into" */ + /** Maximum number of parts to scatter the function into. */ var scatterCount: Int = 1 /** scatter gather directory */ @@ -79,10 +79,13 @@ trait ScatterGatherableFunction extends CommandLineFunction { initScatterFunction(scatterFunction) functions :+= scatterFunction + // Ask the scatter function how many clones to create. + val numClones = scatterFunction.scatterCount + // Create the gather functions for each output field var gatherFunctions = Map.empty[ArgumentSource, GatherFunction] var gatherOutputs = Map.empty[ArgumentSource, File] - var gatherAddOrder = this.scatterCount + 2 + var gatherAddOrder = numClones + 2 for (gatherField <- outputFieldsWithValues) { val gatherFunction = this.newGatherFunction(gatherField) val gatherOutput = getFieldFile(gatherField) @@ -100,7 +103,7 @@ trait ScatterGatherableFunction extends CommandLineFunction { // Create the clone functions for running the parallel jobs var cloneFunctions = List.empty[CloneFunction] - for (i <- 1 to this.scatterCount) { + for (i <- 1 to numClones) { val cloneFunction = this.newCloneFunction() syncFunction(cloneFunction) diff --git a/scala/test/org/broadinstitute/sting/queue/pipeline/PipelineTest.scala b/scala/test/org/broadinstitute/sting/queue/pipeline/PipelineTest.scala index a7dfae4f3..b4c210ccf 100644 --- a/scala/test/org/broadinstitute/sting/queue/pipeline/PipelineTest.scala +++ b/scala/test/org/broadinstitute/sting/queue/pipeline/PipelineTest.scala @@ -11,6 +11,43 @@ import java.text.SimpleDateFormat import org.broadinstitute.sting.{WalkerTest, BaseTest} object PipelineTest { + + /** The path to the current Sting directory. Useful when specifying Sting resources. */ + val currentStingDir = new File(".").getAbsolutePath + + /** The path to the current build of the GATK jar in the currentStingDir. */ + val currentGATK = new File(currentStingDir, "dist/GenomeAnalysisTK.jar") + + /** + * Returns the top level output path to this test. + * @param testName The name of the test passed to PipelineTest.executeTest() + * @return the top level output path to this test. + */ + def testDir(testName: String) = "pipelinetests/%s/".format(testName) + + /** + * Returns the directory where relative output files will be written for this test. + * @param testName The name of the test passed to PipelineTest.executeTest() + * @return the directory where relative output files will be written for this test. + */ + def runDir(testName: String) = testDir(testName) + "run/" + + /** + * Returns the directory where temp files will be written for this test. + * @param testName The name of the test passed to PipelineTest.executeTest() + * @return the directory where temp files will be written for this test. + */ + def tempDir(testName: String) = testDir(testName) + "temp/" + + /** + * Encapsulates a file MD5 + * @param testName The name of the test also passed to PipelineTest.executeTest(). + * @param filePath The file path of the output file, relative to the directory the pipeline is run in. + * @param md5 The expected MD5 + * @return a file md5 that can be appended to the PipelineTestSpec.fileMD5s + */ + def fileMD5(testName: String, filePath: String, md5: String) = (new File(runDir(testName) + filePath), md5) + private var runningCommandLines = Set.empty[QCommandLine] private val validationReportsDataLocation = "/humgen/gsa-hpprojects/GATK/validationreports/submitted/" @@ -34,10 +71,11 @@ object PipelineTest { var failed = 0 for ((file, expectedMD5) <- fileMD5s) { val calculatedMD5 = BaseTest.testFileMD5(name, file, expectedMD5, false) - failed += 1 + if (expectedMD5 != "" && expectedMD5 != calculatedMD5) + failed += 1 } if (failed > 0) - Assert.fail("%d MD5%s did not match.".format(failed, TextFormatUtils.plural(failed))) + Assert.fail("%d of %d MD5%s did not match.".format(failed, fileMD5s.size, TextFormatUtils.plural(failed))) } private def validateEval(name: String, evalSpec: PipelineTestEvalSpec) { @@ -124,11 +162,6 @@ object PipelineTest { } } - val currentDir = new File(".").getAbsolutePath - def testDir(testName: String) = "pipelinetests/%s/".format(testName) - def runDir(testName: String) = testDir(testName) + "run/" - def tempDir(testName: String) = testDir(testName) + "temp/" - Runtime.getRuntime.addShutdownHook(new Thread { /** Cleanup as the JVM shuts down. */ override def run { diff --git a/scala/test/org/broadinstitute/sting/queue/pipeline/examples/ExampleCountLociPipelineTest.scala b/scala/test/org/broadinstitute/sting/queue/pipeline/examples/ExampleCountLociPipelineTest.scala new file mode 100644 index 000000000..59d4d5162 --- /dev/null +++ b/scala/test/org/broadinstitute/sting/queue/pipeline/examples/ExampleCountLociPipelineTest.scala @@ -0,0 +1,20 @@ +package org.broadinstitute.sting.queue.pipeline.examples + +import org.testng.annotations.Test +import org.broadinstitute.sting.queue.pipeline.{PipelineTest, PipelineTestSpec} +import org.broadinstitute.sting.BaseTest + +class ExampleCountLociPipelineTest { + @Test + def testCountLoci { + var testName = "countloci" + var testOut = "count.out" + val spec = new PipelineTestSpec + // TODO: Use a variable instead of "hour" + spec.args = "-S scala/qscript/examples/ExampleCountLoci.scala -gatk %s -R %s -I %s -o %s -jobQueue hour".format( + PipelineTest.currentGATK, BaseTest.hg18Reference, BaseTest.validationDataLocation + "small_bam_for_countloci.bam", testOut + ) + spec.fileMD5s += PipelineTest.fileMD5(testName, testOut, "67823e4722495eb10a5e4c42c267b3a6") + PipelineTest.executeTest(testName, spec) + } +} diff --git a/scala/test/org/broadinstitute/sting/queue/pipeline/examples/HelloWorldPipelineTest.scala b/scala/test/org/broadinstitute/sting/queue/pipeline/examples/HelloWorldPipelineTest.scala index 6b43479bb..70b188c3a 100644 --- a/scala/test/org/broadinstitute/sting/queue/pipeline/examples/HelloWorldPipelineTest.scala +++ b/scala/test/org/broadinstitute/sting/queue/pipeline/examples/HelloWorldPipelineTest.scala @@ -9,8 +9,6 @@ class HelloWorldPipelineTest { var testName = "helloworld" val spec = new PipelineTestSpec spec.args = "-S scala/qscript/examples/HelloWorld.scala -jobPrefix HelloWorld -jobQueue hour" - // TODO: working example of MD5 usage. - // spec.fileMD5s += new File(PipelineTest.runDir(testName) + "hello.out") -> "0123456789abcdef0123456789abcdef" PipelineTest.executeTest(testName, spec) } } diff --git a/scala/test/org/broadinstitute/sting/queue/pipeline/playground/FullCallingPipelineTest.scala b/scala/test/org/broadinstitute/sting/queue/pipeline/playground/FullCallingPipelineTest.scala index 51d1f999e..7e8335c42 100644 --- a/scala/test/org/broadinstitute/sting/queue/pipeline/playground/FullCallingPipelineTest.scala +++ b/scala/test/org/broadinstitute/sting/queue/pipeline/playground/FullCallingPipelineTest.scala @@ -93,8 +93,8 @@ class FullCallingPipelineTest { var pipelineCommand = ("-retry 1 -S scala/qscript/playground/fullCallingPipeline.q" + " -jobProject %s -Y %s" + " -tearScript %s/R/DataProcessingReport/GetTearsheetStats.R" + - " --gatkjar %s/dist/GenomeAnalysisTK.jar") - .format(projectName, yamlFile, PipelineTest.currentDir, PipelineTest.currentDir) + " --gatkjar %s") + .format(projectName, yamlFile, PipelineTest.currentStingDir, PipelineTest.currentGATK) if (!dataset.runIndelRealigner) { pipelineCommand += " -skipCleaning"