Moved the maximum number of intervals check from FCP to the Queue core so that scatter gather will no longer blow up if you specify a scatter count that is too high.

Moved the BamListWriter from FCP to ListWriterFunction in the Queue core.
Added an ExampleCountLoci QScript along with an example pipeline integration test which checks MD5s.
Added a few more utility methods to PipelineTest including a currentGATK variable that points to the GATK jar.

git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@5121 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
kshakir 2011-01-28 23:33:58 +00:00
parent 04d66a7d0d
commit 2ef66af903
12 changed files with 191 additions and 48 deletions

View File

@ -124,7 +124,7 @@ public class IntervalUtils {
}
//if we have an empty list, throw an exception. If they specified intersection and there are no items, this is bad.
if (retList == null || retList.size() == 0)
if (retList.size() == 0)
throw new UserException.BadInput("The INTERSECTION of your -BTI and -L options produced no intervals.");
// we don't need to add the rest of remaining locations, since we know they don't overlap. return what we have
@ -227,6 +227,31 @@ public class IntervalUtils {
return contigs;
}
/**
 * Counts the number of interval files an interval list can be split into using scatterIntervalArguments.
 * @param reference The reference for the intervals.
 * @param intervals The interval as strings or file paths.
 * @param splitByContig If true then one contig will not be written to multiple files.
 * @return The maximum number of parts the intervals can be split into.
 */
public static int countIntervalArguments(File reference, List<String> intervals, boolean splitByContig) {
    ReferenceDataSource referenceDataSource = new ReferenceDataSource(reference);
    List<GenomeLoc> parsedLocs = parseIntervalArguments(referenceDataSource, intervals);

    // Without the contig restriction every loc may go to its own file.
    if (!splitByContig)
        return parsedLocs.size();

    // When one contig must not span multiple files, each run of consecutive
    // locs sharing a contig collapses into a single output part.
    int parts = 0;
    String previousContig = null;
    for (GenomeLoc parsedLoc : parsedLocs) {
        String currentContig = parsedLoc.getContig();
        if (previousContig == null || !previousContig.equals(currentContig)) {
            parts++;
            previousContig = currentContig;
        }
    }
    return parts;
}
/**
* Splits an interval list into multiple files.
* @param reference The reference for the intervals.

View File

@ -15,6 +15,7 @@ import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
/**
@ -117,6 +118,16 @@ public class IntervalUtilsUnitTest extends BaseTest {
Assert.assertEquals(IntervalUtils.distinctContigs(reference, Arrays.asList("chr2:1-1", "chr1:1-1", "chr3:2-2")), Arrays.asList("chr1","chr2","chr3"));
}
@Test
public void testCountIntervals() {
    // Each case pairs an interval-argument list with the expected counts for
    // splitByContig == false and splitByContig == true, in that order.
    String[][] intervalArgs = {
            {},
            {"chr1", "chr2", "chr3"},
            {"chr1:1-2", "chr1:4-5", "chr2:1-1", "chr3:2-2"}
    };
    int[] expectedWithoutContigSplit = {45, 3, 4};
    int[] expectedWithContigSplit = {45, 3, 3};

    for (int i = 0; i < intervalArgs.length; i++) {
        List<String> args = Arrays.asList(intervalArgs[i]);
        Assert.assertEquals(IntervalUtils.countIntervalArguments(reference, args, false), expectedWithoutContigSplit[i]);
        Assert.assertEquals(IntervalUtils.countIntervalArguments(reference, args, true), expectedWithContigSplit[i]);
    }
}
@Test
public void testBasicScatter() {
GenomeLoc chr1 = genomeLocParser.parseGenomeInterval("chr1");

View File

@ -0,0 +1,33 @@
import org.broadinstitute.sting.queue.QScript
import org.broadinstitute.sting.queue.extensions.gatk._
/**
 * An introductory pipeline whose integration test checks the MD5 of the @Output.
 */
class ExampleCountLoci extends QScript {
  @Input(doc="The path to the GenomeAnalysisTK.jar file.", shortName="gatk")
  var gatkJar: File = null

  @Input(doc="The reference file for the bam files.", shortName="R")
  var referenceFile: File = null

  @Input(doc="One or more bam files.", shortName="I")
  var bamFiles: List[File] = Nil

  @Input(doc="Intervals to traverse.", shortName="L", required=false)
  var intervals: List[String] = Nil

  @Output
  var out: File = _

  // Builds a single CountLoci job wired up from the script inputs.
  def script = {
    val walker = new CountLoci
    walker.jarFile = gatkJar
    walker.reference_sequence = referenceFile
    walker.input_file = bamFiles
    walker.intervalsString = intervals
    walker.memoryLimit = Some(1)
    walker.out = out
    add(walker)
  }
}

View File

@ -1,19 +1,14 @@
import java.io.PrintWriter
import org.broadinstitute.sting.commandline.ArgumentSource
import org.broadinstitute.sting.datasources.pipeline.Pipeline
import org.broadinstitute.sting.gatk.DownsampleType
import org.broadinstitute.sting.queue.extensions.gatk._
import org.broadinstitute.sting.queue.extensions.picard.PicardBamJarFunction
import org.broadinstitute.sting.queue.extensions.samtools._
import org.broadinstitute.sting.queue.function.ListWriterFunction
import org.broadinstitute.sting.queue.function.scattergather.{GatherFunction, CloneFunction, ScatterFunction}
import org.broadinstitute.sting.queue.util.IOUtils
import org.broadinstitute.sting.queue.QScript
import collection.JavaConversions._
import org.broadinstitute.sting.utils.exceptions.UserException
import org.broadinstitute.sting.utils.interval.IntervalUtils
import org.broadinstitute.sting.utils.yaml.YamlUtils
import org.broadinstitute.sting.utils.report.VE2ReportFactory.VE2TemplateType
class fullCallingPipeline extends QScript {
qscript =>
@ -105,9 +100,6 @@ class fullCallingPipeline extends QScript {
//val seq = qscript.machine
//val expKind = qscript.protocol
// get max num contigs for indel cleaning parallelism, plus 1 for -L unmapped
val numContigs = IntervalUtils.distinctContigs(qscript.pipeline.getProject.getReferenceFile).size + 1
for ( sample <- recalibratedSamples ) {
val sampleId = sample.getId
// put unclean bams in unclean genotypers in advance, create the extension files
@ -136,7 +128,7 @@ class fullCallingPipeline extends QScript {
realigner.targetIntervals = targetCreator.out
realigner.intervals = Nil
realigner.intervalsString = Nil
realigner.scatterCount = num_cleaner_scatter_jobs min numContigs
realigner.scatterCount = num_cleaner_scatter_jobs
realigner.rodBind :+= RodBind("dbsnp", dbsnpType, qscript.pipeline.getProject.getDbsnpFile)
realigner.rodBind :+= RodBind("indels", "VCF", swapExt(realigner.reference_sequence.getParentFile, realigner.reference_sequence, "fasta", "1kg_pilot_indels.vcf"))
@ -384,21 +376,9 @@ class fullCallingPipeline extends QScript {
// 5. Make the bam list
val listOfBams = new File(base +".BamFiles.list")
class BamListWriter extends InProcessFunction {
@Input(doc="bamFiles") var bamFiles: List[File] = Nil
@Output(doc="bamList") var bamList: File = _
def run {
val writer = new PrintWriter(bamList)
for (bamFile <- bamFiles)
writer.println(bamFile.toString)
writer.close()
}
}
val writeBamList = new BamListWriter
writeBamList.bamFiles = bamFiles
writeBamList.bamList = listOfBams
val writeBamList = new ListWriterFunction
writeBamList.inputFiles = bamFiles
writeBamList.listFile = listOfBams
writeBamList.analysisName = base + "_BamList"
writeBamList.jobOutputFile = new File(".queue/logs/SNPCalling/bamlist.out")

View File

@ -1,12 +1,12 @@
package org.broadinstitute.sting.queue.extensions.gatk
import org.broadinstitute.sting.commandline.ArgumentSource
import org.broadinstitute.sting.utils.interval.IntervalUtils
import java.io.File
import collection.JavaConversions._
import org.broadinstitute.sting.queue.util.IOUtils
import org.broadinstitute.sting.queue.function.scattergather.{CloneFunction, ScatterGatherableFunction, ScatterFunction}
import org.broadinstitute.sting.queue.function.{QFunction, InProcessFunction}
import org.broadinstitute.sting.commandline.{Output, ArgumentSource}
/**
* An interval scatter function.
@ -15,7 +15,7 @@ class IntervalScatterFunction extends ScatterFunction with InProcessFunction {
var splitByContig = false
/** The total number of clone jobs that will be created. */
private var scatterCount: Int = _
var scatterCount: Int = _
/** The reference sequence for the GATK function. */
private var referenceSequence: File = _
@ -32,6 +32,9 @@ class IntervalScatterFunction extends ScatterFunction with InProcessFunction {
/** Whether the last scatter job should also include any unmapped reads. */
private var includeUnmapped: Boolean = _
@Output(doc="Scattered intervals")
var scatterParts: List[File] = Nil
/**
* Checks if the function is scatter gatherable.
* @param originalFunction Function to check.
@ -54,7 +57,6 @@ class IntervalScatterFunction extends ScatterFunction with InProcessFunction {
val gatk = originalFunction.asInstanceOf[CommandLineGATK]
this.intervalsField = QFunction.findField(originalFunction.getClass, "intervals")
this.intervalsStringField = QFunction.findField(originalFunction.getClass, "intervalsString")
this.scatterCount = originalFunction.scatterCount
this.referenceSequence = gatk.reference_sequence
if (gatk.intervals.isEmpty && gatk.intervalsString.isEmpty) {
this.intervals ++= IntervalUtils.distinctContigs(this.referenceSequence).toList
@ -64,11 +66,14 @@ class IntervalScatterFunction extends ScatterFunction with InProcessFunction {
this.intervals ++= gatk.intervalsString.filterNot(interval => IntervalUtils.isUnmapped(interval))
this.includeUnmapped = gatk.intervalsString.exists(interval => IntervalUtils.isUnmapped(interval))
}
val maxScatterCount = IntervalUtils.countIntervalArguments(this.referenceSequence, this.intervals, this.splitByContig)
this.scatterCount = maxScatterCount min originalFunction.scatterCount
}
def initCloneInputs(cloneFunction: CloneFunction, index: Int) = {
cloneFunction.setFieldValue(this.intervalsField, List(new File("scatter.intervals")))
if (index == scatterCount && includeUnmapped)
if (index == this.scatterCount && this.includeUnmapped)
cloneFunction.setFieldValue(this.intervalsStringField, List("unmapped"))
else
cloneFunction.setFieldValue(this.intervalsStringField, List.empty[String])

View File

@ -0,0 +1,33 @@
package org.broadinstitute.sting.queue.function
import org.broadinstitute.sting.commandline.{Input, Output}
import java.io.{PrintWriter, File}
import org.apache.commons.io.IOUtils
/**
 * Writes a list of inputs to an output file, one entry per line.
 * Custom formats can override addFile.
 */
class ListWriterFunction extends InProcessFunction {
  @Input(doc="input files") var inputFiles: List[File] = Nil
  @Output(doc="output file") var listFile: File = _

  def run {
    val listWriter = new PrintWriter(listFile)
    try {
      inputFiles.foreach(addFile(listWriter, _))
    } finally {
      // Always release the writer, even if addFile throws.
      IOUtils.closeQuietly(listWriter)
    }
  }

  /**
   * Adds the inputFile to the output list. Subclasses override this to change the line format.
   * @param writer Output file.
   * @param inputFile File to add to the output file.
   */
  def addFile(writer: PrintWriter, inputFile: File) {
    writer.println(inputFile.toString)
  }
}

View File

@ -1,7 +1,7 @@
package org.broadinstitute.sting.queue.function.scattergather
import java.io.File
import org.broadinstitute.sting.commandline.{Input, Output}
import org.broadinstitute.sting.commandline.Input
import org.broadinstitute.sting.queue.function.QFunction
/**
@ -11,9 +11,6 @@ trait ScatterFunction extends QFunction {
@Input(doc="Original inputs to scatter")
var originalInputs: Set[File] = _
@Output(doc="Scattered parts of the original inputs, one set per temp directory")
var scatterParts: List[File] = Nil
/**
* Returns true if the scatter function can scatter this original function.
* @param originalFunction The original function to check.
@ -27,6 +24,11 @@ trait ScatterFunction extends QFunction {
*/
def setScatterGatherable(originalFunction: ScatterGatherableFunction)
/**
* After a call to setScatterGatherable(), returns the number of clones that should be created.
*/
def scatterCount: Int
/**
* Initializes the input fields for the clone function.
* The input values should be set to their defaults

View File

@ -11,7 +11,7 @@ import org.broadinstitute.sting.queue.QException
*/
trait ScatterGatherableFunction extends CommandLineFunction {
/** Number of parts to scatter the function into" */
/** Maximum number of parts to scatter the function into. */
var scatterCount: Int = 1
/** scatter gather directory */
@ -79,10 +79,13 @@ trait ScatterGatherableFunction extends CommandLineFunction {
initScatterFunction(scatterFunction)
functions :+= scatterFunction
// Ask the scatter function how many clones to create.
val numClones = scatterFunction.scatterCount
// Create the gather functions for each output field
var gatherFunctions = Map.empty[ArgumentSource, GatherFunction]
var gatherOutputs = Map.empty[ArgumentSource, File]
var gatherAddOrder = this.scatterCount + 2
var gatherAddOrder = numClones + 2
for (gatherField <- outputFieldsWithValues) {
val gatherFunction = this.newGatherFunction(gatherField)
val gatherOutput = getFieldFile(gatherField)
@ -100,7 +103,7 @@ trait ScatterGatherableFunction extends CommandLineFunction {
// Create the clone functions for running the parallel jobs
var cloneFunctions = List.empty[CloneFunction]
for (i <- 1 to this.scatterCount) {
for (i <- 1 to numClones) {
val cloneFunction = this.newCloneFunction()
syncFunction(cloneFunction)

View File

@ -11,6 +11,43 @@ import java.text.SimpleDateFormat
import org.broadinstitute.sting.{WalkerTest, BaseTest}
object PipelineTest {
/** The path to the current Sting directory. Useful when specifying Sting resources. */
val currentStingDir = new File(".").getAbsolutePath
/** The path to the current build of the GATK jar in the currentStingDir. */
val currentGATK = new File(currentStingDir, "dist/GenomeAnalysisTK.jar")
/**
* Returns the top level output path to this test.
* @param testName The name of the test passed to PipelineTest.executeTest()
* @return the top level output path to this test.
*/
def testDir(testName: String) = "pipelinetests/%s/".format(testName)
/**
* Returns the directory where relative output files will be written for this test.
* @param testName The name of the test passed to PipelineTest.executeTest()
* @return the directory where relative output files will be written for this test.
*/
def runDir(testName: String) = testDir(testName) + "run/"
/**
* Returns the directory where temp files will be written for this test.
* @param testName The name of the test passed to PipelineTest.executeTest()
* @return the directory where temp files will be written for this test.
*/
def tempDir(testName: String) = testDir(testName) + "temp/"
/**
* Encapsulates a file MD5
* @param testName The name of the test also passed to PipelineTest.executeTest().
* @param filePath The file path of the output file, relative to the directory the pipeline is run in.
* @param md5 The expected MD5
* @return a file md5 that can be appended to the PipelineTestSpec.fileMD5s
*/
def fileMD5(testName: String, filePath: String, md5: String) = (new File(runDir(testName) + filePath), md5)
private var runningCommandLines = Set.empty[QCommandLine]
private val validationReportsDataLocation = "/humgen/gsa-hpprojects/GATK/validationreports/submitted/"
@ -34,10 +71,11 @@ object PipelineTest {
var failed = 0
for ((file, expectedMD5) <- fileMD5s) {
val calculatedMD5 = BaseTest.testFileMD5(name, file, expectedMD5, false)
failed += 1
if (expectedMD5 != "" && expectedMD5 != calculatedMD5)
failed += 1
}
if (failed > 0)
Assert.fail("%d MD5%s did not match.".format(failed, TextFormatUtils.plural(failed)))
Assert.fail("%d of %d MD5%s did not match.".format(failed, fileMD5s.size, TextFormatUtils.plural(failed)))
}
private def validateEval(name: String, evalSpec: PipelineTestEvalSpec) {
@ -124,11 +162,6 @@ object PipelineTest {
}
}
val currentDir = new File(".").getAbsolutePath
def testDir(testName: String) = "pipelinetests/%s/".format(testName)
def runDir(testName: String) = testDir(testName) + "run/"
def tempDir(testName: String) = testDir(testName) + "temp/"
Runtime.getRuntime.addShutdownHook(new Thread {
/** Cleanup as the JVM shuts down. */
override def run {

View File

@ -0,0 +1,20 @@
package org.broadinstitute.sting.queue.pipeline.examples
import org.testng.annotations.Test
import org.broadinstitute.sting.queue.pipeline.{PipelineTest, PipelineTestSpec}
import org.broadinstitute.sting.BaseTest
class ExampleCountLociPipelineTest {
  /**
   * Runs the ExampleCountLoci QScript over a small validation bam and
   * checks the MD5 of the count output file.
   */
  @Test
  def testCountLoci {
    // vals, not vars: neither name is ever reassigned.
    val testName = "countloci"
    val testOut = "count.out"
    val spec = new PipelineTestSpec
    // TODO: Use a variable instead of "hour"
    spec.args = "-S scala/qscript/examples/ExampleCountLoci.scala -gatk %s -R %s -I %s -o %s -jobQueue hour".format(
      PipelineTest.currentGATK, BaseTest.hg18Reference, BaseTest.validationDataLocation + "small_bam_for_countloci.bam", testOut
    )
    spec.fileMD5s += PipelineTest.fileMD5(testName, testOut, "67823e4722495eb10a5e4c42c267b3a6")
    PipelineTest.executeTest(testName, spec)
  }
}

View File

@ -9,8 +9,6 @@ class HelloWorldPipelineTest {
var testName = "helloworld"
val spec = new PipelineTestSpec
spec.args = "-S scala/qscript/examples/HelloWorld.scala -jobPrefix HelloWorld -jobQueue hour"
// TODO: working example of MD5 usage.
// spec.fileMD5s += new File(PipelineTest.runDir(testName) + "hello.out") -> "0123456789abcdef0123456789abcdef"
PipelineTest.executeTest(testName, spec)
}
}

View File

@ -93,8 +93,8 @@ class FullCallingPipelineTest {
var pipelineCommand = ("-retry 1 -S scala/qscript/playground/fullCallingPipeline.q" +
" -jobProject %s -Y %s" +
" -tearScript %s/R/DataProcessingReport/GetTearsheetStats.R" +
" --gatkjar %s/dist/GenomeAnalysisTK.jar")
.format(projectName, yamlFile, PipelineTest.currentDir, PipelineTest.currentDir)
" --gatkjar %s")
.format(projectName, yamlFile, PipelineTest.currentStingDir, PipelineTest.currentGATK)
if (!dataset.runIndelRealigner) {
pipelineCommand += " -skipCleaning"