gatk-3.8/scala/src/org/broadinstitute/sting/queue/engine/QGraph.scala

150 lines
5.1 KiB
Scala
Raw Normal View History

package org.broadinstitute.sting.queue.engine
import org.jgrapht.graph.SimpleDirectedGraph
import scala.collection.JavaConversions
import scala.collection.JavaConversions._
import org.broadinstitute.sting.queue.function.{MappingFunction, CommandLineFunction, QFunction}
import org.broadinstitute.sting.queue.function.scattergather.ScatterGatherableFunction
import org.broadinstitute.sting.queue.util.{CollectionUtils, Logging}
import org.broadinstitute.sting.queue.QException
import org.jgrapht.alg.CycleDetector
import org.jgrapht.EdgeFactory
/**
 * A directed graph of the functions to run, backed by a JGraphT SimpleDirectedGraph.
 *
 * Each vertex is a QNode built from the inputs or the outputs of a function, and each
 * edge is the QFunction leading from its input node to its output node.
 */
class QGraph extends Logging {
  /** When set, run hands the graph to the scheduler even if validation logged errors. */
  var dryRun = true
  /** When set, scatter-gatherable functions are replaced by their generated functions as they are added. */
  var bsubAllJobs = false
  // Not read inside this class; presumably consumed by the scheduler/runner -- TODO confirm.
  var bsubWaitJobs = false
  // Not read inside this class; presumably consumed by the scheduler/runner -- TODO confirm.
  var properties = Map.empty[String, String]
  /** The underlying graph: QNode vertices connected by QFunction edges. */
  val jobGraph = newGraph

  /** The number of edges in the graph that are CommandLineFunctions. */
  def numJobs = JavaConversions.asSet(jobGraph.edgeSet).count(_.isInstanceOf[CommandLineFunction])

  /**
   * Adds a command line function to the graph.
   * @param command Function to add.
   */
  def add(command: CommandLineFunction): Unit = addFunction(command)

  /**
   * Looks through functions with multiple inputs and outputs and adds mapping functions for single inputs and outputs.
   * Afterwards repeatedly removes filler mapping edges (see isFiller) until none remain,
   * and finally removes the vertices that are left without any edge.
   * @return The result of the final removeAllVertices call on the graph.
   */
  def fillIn = {
    // clone since edgeSet is backed by the graph
    for (function <- JavaConversions.asSet(jobGraph.edgeSet).clone) {
      addCollectionOutputs(function.outputs)
      addCollectionInputs(function.inputs)
    }

    // Removing one filler edge can turn a neighboring mapping edge into filler,
    // so keep scanning until a pass finds nothing to remove.
    var pruning = true
    while (pruning) {
      val filler = jobGraph.edgeSet.filter(isFiller(_))
      pruning = !filler.isEmpty
      if (pruning)
        jobGraph.removeAllEdges(filler)
    }

    jobGraph.removeAllVertices(jobGraph.vertexSet.filter(isOrphan(_)))
  }

  /**
   * Validates the graph and then runs the jobs through a TopologicalJobScheduler with an LsfJobRunner.
   * Validation logs an error for every CommandLineFunction that reports missing values and
   * for every cycle found in the graph. The scheduler is only invoked when validation found
   * no problems or when dryRun is set.
   */
  def run = {
    var isReady = true

    for (function <- JavaConversions.asSet(jobGraph.edgeSet)) {
      function match {
        case cmd: CommandLineFunction =>
          val missingValues = cmd.missingValues
          if (missingValues.size > 0) {
            isReady = false
            logger.error("Missing values for function: %s".format(cmd.commandLine))
            for (missing <- missingValues)
              logger.error(" " + missing)
          }
        case _ =>
          // Only command line functions are checked for missing values.
      }
    }

    val detector = new CycleDetector(jobGraph)
    if (detector.detectCycles) {
      logger.error("Cycles were detected in the graph:")
      for (cycle <- detector.findCycles)
        logger.error(" " + cycle)
      isReady = false
    }

    if (isReady || this.dryRun)
      (new TopologicalJobScheduler(this) with LsfJobRunner).runJobs
  }

  /**
   * Creates the empty graph. Its edge factory builds a MappingFunction between the items
   * of the two nodes whenever the graph has to create an edge on its own.
   */
  private def newGraph = new SimpleDirectedGraph[QNode, QFunction](new EdgeFactory[QNode, QFunction] {
    def createEdge(input: QNode, output: QNode) = new MappingFunction(input.items, output.items)
  })

  /**
   * Freezes the function and adds it to the graph as an edge from its inputs to its outputs.
   * When bsubAllJobs is set and the function is scatter-gatherable, the functions it generates
   * are added (recursively) in its place. Any existing edges between the same two nodes are
   * removed before the new edge is inserted.
   * @param f Function to add.
   * @throws QException wrapping any exception raised while adding the function.
   */
  private def addFunction(f: QFunction): Unit = {
    try {
      f.freeze
      f match {
        case scatterGather: ScatterGatherableFunction if (bsubAllJobs && scatterGather.scatterGatherable) =>
          val functions = scatterGather.generateFunctions()
          if (logger.isTraceEnabled)
            logger.trace("Scattered into %d parts: %s".format(functions.size, functions))
          functions.foreach(addFunction(_))
        case _ =>
          val inputs = QNode(f.inputs)
          val outputs = QNode(f.outputs)
          val newSource = jobGraph.addVertex(inputs)
          val newTarget = jobGraph.addVertex(outputs)
          val removedEdges = jobGraph.removeAllEdges(inputs, outputs)
          // addEdge reports false when the graph already holds this edge, i.e. nothing was inserted.
          val added = jobGraph.addEdge(inputs, outputs, f)
          if (logger.isTraceEnabled) {
            logger.trace("Mapped from: " + inputs)
            logger.trace("Mapped to: " + outputs)
            logger.trace("Mapped via: " + f)
            logger.trace("Removed edges: " + removedEdges)
            logger.trace("New source?: " + newSource)
            logger.trace("New target?: " + newTarget)
            logger.trace("Added edge?: " + added)
            logger.trace("")
          }
      }
    } catch {
      case e: Exception =>
        throw new QException("Error adding function: " + f, e)
    }
  }

  /**
   * Adds mapping edges leading from each item to the collection passed along with it.
   * NOTE(review): relies on CollectionUtils.foreach supplying (item, collection) pairs
   * for the value -- confirm against CollectionUtils.
   * @param value Inputs of a function.
   */
  private def addCollectionInputs(value: Any): Unit = {
    CollectionUtils.foreach(value, (item, collection) =>
      addMappingEdge(item, collection))
  }

  /**
   * Adds mapping edges leading from the collection to each item passed along with it.
   * NOTE(review): relies on CollectionUtils.foreach supplying (item, collection) pairs
   * for the value -- confirm against CollectionUtils.
   * @param value Outputs of a function.
   */
  private def addCollectionOutputs(value: Any): Unit = {
    CollectionUtils.foreach(value, (item, collection) =>
      addMappingEdge(collection, item))
  }

  /**
   * Adds a MappingFunction from input to output unless both are the same set or the graph
   * already has an edge between the two nodes in either direction.
   * @param input Source value; wrapped in a set when it is not one already.
   * @param output Target value; wrapped in a set when it is not one already.
   */
  private def addMappingEdge(input: Any, output: Any): Unit = {
    val inputSet = asSet(input)
    val outputSet = asSet(output)
    val hasEdge = inputSet == outputSet ||
      jobGraph.getEdge(QNode(inputSet), QNode(outputSet)) != null ||
      jobGraph.getEdge(QNode(outputSet), QNode(inputSet)) != null
    if (!hasEdge)
      addFunction(new MappingFunction(inputSet, outputSet))
  }

  /**
   * Returns the value itself when it already is an immutable Set, otherwise a one element set holding it.
   * Keep this name: it shadows the JavaConversions.asSet import inside this class, which is
   * why the Java-to-Scala conversions above are written with the explicit JavaConversions prefix.
   */
  private def asSet(value: Any): Set[Any] = value match {
    // The element type is erased, so the cast only widens the static type.
    case set: Set[_] => set.asInstanceOf[Set[Any]]
    case other => Set(other)
  }

  /** True when the edge is a MappingFunction. */
  private def isMappingEdge(edge: QFunction) =
    edge.isInstanceOf[MappingFunction]

  /**
   * True when the edge is a mapping edge that leads nowhere or comes from nowhere:
   * its target has no outgoing edges or its source has no incoming edges.
   */
  private def isFiller(edge: QFunction) =
    isMappingEdge(edge) &&
      (jobGraph.outgoingEdgesOf(jobGraph.getEdgeTarget(edge)).isEmpty ||
        jobGraph.incomingEdgesOf(jobGraph.getEdgeSource(edge)).isEmpty)

  /** True when the node has neither incoming nor outgoing edges. */
  private def isOrphan(node: QNode) =
    jobGraph.incomingEdgesOf(node).isEmpty && jobGraph.outgoingEdgesOf(node).isEmpty
}