// gatk-3.8/scala/src/org/broadinstitute/sting/queue/engine/QGraph.scala
// (165 lines, 5.8 KiB, Scala)

package org.broadinstitute.sting.queue.engine
import org.jgrapht.graph.SimpleDirectedGraph
import scala.collection.JavaConversions
import scala.collection.JavaConversions._
import scala.collection.immutable.ListMap
import org.broadinstitute.sting.queue.function.{MappingFunction, CommandLineFunction, QFunction}
import org.broadinstitute.sting.queue.function.scattergather.ScatterGatherableFunction
import org.broadinstitute.sting.queue.util.{CollectionUtils, Logging}
import org.broadinstitute.sting.queue.QException
import org.jgrapht.alg.CycleDetector
import org.jgrapht.EdgeFactory
/**
 * A directed graph of Queue functions.
 *
 * Vertices are QNodes built from the non-null input values or the non-null output values of a
 * function; each edge is the QFunction that maps its input node to its output node.
 */
class QGraph extends Logging {
  /** When true, run still hands the graph to the scheduler even if validation reported errors. */
  var dryRun = true

  /** When true, scatter-gatherable functions are replaced by their generated functions when added. */
  var bsubAllJobs = false

  /** The underlying graph of nodes connected by functions. */
  val jobGraph = newGraph

  /** The number of edges in the graph that are CommandLineFunction instances. */
  def numJobs = JavaConversions.asSet(jobGraph.edgeSet).count(_.isInstanceOf[CommandLineFunction])

  /**
   * Adds a command line function to the graph, replacing any edges that already connect the same
   * input and output nodes.
   * @param command Function to add.
   */
  def add(command: CommandLineFunction): Unit =
    add(command, true)

  /**
   * Looks through functions with multiple inputs and outputs and adds mapping functions for single inputs and outputs.
   * Afterwards repeatedly removes the mapping edges that are filler until none remain, and finally
   * removes the vertices that are left without any edge.
   */
  def fillIn = {
    // clone since edgeSet is backed by the graph
    for (function <- JavaConversions.asSet(jobGraph.edgeSet).clone) {
      val inputs = function.inputs
      val outputs = function.outputs
      // These checks do not change while iterating, so evaluate them once per function.
      val multipleInputs = inputs.size > 1
      val multipleOutputs = outputs.size > 1

      for ((name, input) <- inputs) {
        addCollectionInputs(name, input)
        if (multipleInputs)
          addMappingEdge(ListMap(name -> input), inputs)
      }

      for ((name, output) <- outputs) {
        addCollectionOutputs(name, output)
        if (multipleOutputs)
          addMappingEdge(outputs, ListMap(name -> output))
      }
    }

    // Removing filler edges can turn other mapping edges into filler, so repeat until stable.
    var pruning = true
    while (pruning) {
      pruning = false
      val filler = jobGraph.edgeSet.filter(isFiller(_))
      if (!filler.isEmpty) {
        jobGraph.removeAllEdges(filler)
        pruning = true
      }
    }

    jobGraph.removeAllVertices(jobGraph.vertexSet.filter(isOrphan(_)))
  }

  /**
   * Validates the graph and then runs it.
   *
   * Logs an error for every CommandLineFunction that reports missing values and for every vertex
   * the cycle detector reports. The jobs are handed to a TopologicalJobScheduler with LsfJobRunner
   * only when no problem was found or when dryRun is set.
   */
  def run = {
    var isReady = true

    for (function <- JavaConversions.asSet(jobGraph.edgeSet)) {
      function match {
        case cmd: CommandLineFunction =>
          val missingValues = cmd.missingValues
          if (missingValues.size > 0) {
            isReady = false
            logger.error("Missing values for function: %s".format(cmd.commandLine))
            for (missing <- missingValues)
              logger.error("  " + missing)
          }
        case _ =>
      }
    }

    val detector = new CycleDetector(jobGraph)
    if (detector.detectCycles) {
      logger.error("Cycles were detected in the graph:")
      for (cycle <- detector.findCycles)
        logger.error("  " + cycle)
      isReady = false
    }

    if (isReady || this.dryRun)
      (new TopologicalJobScheduler(this) with LsfJobRunner).runJobs
  }

  /**
   * Creates the empty graph. Its edge factory builds a MappingFunction from the value map of the
   * source node to the value map of the target node.
   */
  private def newGraph = new SimpleDirectedGraph[QNode, QFunction](new EdgeFactory[QNode, QFunction] {
    def createEdge(input: QNode, output: QNode) = new MappingFunction(input.valueMap, output.valueMap)})

  /**
   * Freezes a function and adds it to the graph as an edge from the node of its non-null inputs
   * to the node of its non-null outputs.
   *
   * When bsubAllJobs is set and the function is scatter-gatherable, the functions it generates
   * are added instead of the function itself.
   *
   * NOTE(review): jgrapht documents that a SimpleDirectedGraph does not permit loops, so addEdge
   * is expected to throw if inputs and outputs collapse to the same QNode -- confirm whether the
   * mapping edges created by fillIn can hit this.
   *
   * @param f Function to add.
   * @param replace If true, edges already connecting the same two nodes are removed first.
   * @throws QException wrapping any exception raised while adding the function.
   */
  private def add(f: QFunction, replace: Boolean): Unit = {
    try {
      f.freeze

      f match {
        case scatterGather: ScatterGatherableFunction if (bsubAllJobs && scatterGather.scatterGatherable) =>
          val functions = scatterGather.generateFunctions()
          if (logger.isTraceEnabled)
            logger.trace("Scattered into %d parts: %s".format(functions.size, functions))
          functions.foreach(add(_))
        case _ =>
          val inputs = QNode(f.inputs.values.filter(_ != null).toSet)
          val outputs = QNode(f.outputs.values.filter(_ != null).toSet)
          val newSource = jobGraph.addVertex(inputs)
          val newTarget = jobGraph.addVertex(outputs)
          val removedEdges = if (replace) jobGraph.removeAllEdges(inputs, outputs) else Nil
          // addEdge reports whether the edge was inserted; without replace, an existing edge
          // between the two nodes wins and the new one is silently declined.
          val added = jobGraph.addEdge(inputs, outputs, f)
          if (logger.isTraceEnabled) {
            logger.trace("Mapped from:   " + inputs)
            logger.trace("Mapped to:     " + outputs)
            logger.trace("Mapped via:    " + f)
            logger.trace("Removed edges: " + removedEdges)
            logger.trace("Edge added?:   " + added)
            logger.trace("New source?:   " + newSource)
            logger.trace("New target?:   " + newTarget)
            logger.trace("")
          }
      }
    } catch {
      case e: Exception =>
        throw new QException("Error adding function: " + f, e)
    }
  }

  /**
   * Adds a mapping edge from each item to its collection, for every (item, collection) pair that
   * CollectionUtils.foreach reports for the input value.
   * @param name Name of the input.
   * @param value Value of the input.
   */
  private def addCollectionInputs(name: String, value: Any): Unit = {
    CollectionUtils.foreach(value, (item, collection) =>
      addMappingEdge(ListMap(name -> item), ListMap(name -> collection)))
  }

  /**
   * Adds a mapping edge from each collection to its item, for every (item, collection) pair that
   * CollectionUtils.foreach reports for the output value.
   * @param name Name of the output.
   * @param value Value of the output.
   */
  private def addCollectionOutputs(name: String, value: Any): Unit = {
    CollectionUtils.foreach(value, (item, collection) =>
      addMappingEdge(ListMap(name -> collection), ListMap(name -> item)))
  }

  /** Adds a MappingFunction between the two value maps without replacing existing edges. */
  private def addMappingEdge(input: ListMap[String, Any], output: ListMap[String, Any]) =
    add(new MappingFunction(input, output), false)

  /** Returns true if the edge is a MappingFunction. */
  private def isMappingEdge(edge: QFunction) =
    edge.isInstanceOf[MappingFunction]

  /**
   * Returns true if the edge is a mapping edge that can be pruned: its target has no outgoing
   * edges, its source has no incoming edges, or either end is a loopback node.
   */
  private def isFiller(edge: QFunction) =
    isMappingEdge(edge) && {
      val source = jobGraph.getEdgeSource(edge)
      val target = jobGraph.getEdgeTarget(edge)
      jobGraph.outgoingEdgesOf(target).size == 0 ||
        jobGraph.incomingEdgesOf(source).size == 0 ||
        isLoopback(source) ||
        isLoopback(target)
    }

  /**
   * Returns true if the node has exactly one incoming and one outgoing edge, both are mapping
   * edges, and the incoming edge starts at the node where the outgoing edge ends.
   */
  private def isLoopback(node: QNode) = {
    val incoming = jobGraph.incomingEdgesOf(node)
    val outgoing = jobGraph.outgoingEdgesOf(node)
    incoming.size == 1 && outgoing.size == 1 &&
      isMappingEdge(incoming.head) && isMappingEdge(outgoing.head) &&
      jobGraph.getEdgeSource(incoming.head) == jobGraph.getEdgeTarget(outgoing.head)
  }

  /** Returns true if the node has neither incoming nor outgoing edges. */
  private def isOrphan(node: QNode) =
    (jobGraph.incomingEdgesOf(node).size + jobGraph.outgoingEdgesOf(node).size) == 0
}