Moved the class field caching from QScript to a ClassFieldCache utility.

Using ClassFieldCache to pull values from QScript for passing to done() method of QStatusMessenger.
This commit is contained in:
kshakir 2012-10-16 18:49:10 -04:00
parent f0e04376ec
commit f93b279151
8 changed files with 227 additions and 110 deletions

View File

@ -100,7 +100,7 @@ class QCommandLine extends CommandLineProgram with Logging {
new PluginManager[QStatusMessenger](classOf[QStatusMessenger])
}
QFunction.parsingEngine = new ParsingEngine(this)
ClassFieldCache.parsingEngine = new ParsingEngine(this)
/**
* Takes the QScripts passed in, runs their script() methods, retrieves their generated
@ -127,6 +127,9 @@ class QCommandLine extends CommandLineProgram with Logging {
for (script <- allQScripts) {
logger.info("Scripting " + qScriptPluginManager.getName(script.getClass.asSubclass(classOf[QScript])))
loadArgumentsIntoObject(script)
// TODO: Pulling inputs can be time/io expensive! Some scripts are using the files to generate functions-- even for dry runs-- so pull it all down for now.
//if (settings.run)
script.pullInputs()
script.qSettings = settings.qSettings
try {
script.script()
@ -138,10 +141,6 @@ class QCommandLine extends CommandLineProgram with Logging {
logger.info("Added " + script.functions.size + " functions")
}
if (settings.run) {
allQScripts.foreach(_.pullInputs())
}
// Execute the job graph
qGraph.run()
@ -170,7 +169,7 @@ class QCommandLine extends CommandLineProgram with Logging {
if (settings.run) {
allQScripts.foreach(_.pushOutputs())
for (statusMessenger <- allStatusMessengers)
statusMessenger.done()
statusMessenger.done(allQScripts.map(_.remoteOutputs))
}
0
}

View File

@ -28,8 +28,7 @@ import engine.JobRunInfo
import org.broadinstitute.sting.queue.function.QFunction
import annotation.target.field
import util._
import org.broadinstitute.sting.utils.classloader.JVMUtils
import java.lang.reflect.Field
import org.broadinstitute.sting.commandline.ArgumentSource
/**
* Defines a Queue pipeline as a collection of CommandLineFunctions.
@ -110,31 +109,29 @@ trait QScript extends Logging with PrimitiveOptionConversions with StringFileCon
}
def pullInputs() {
val inputs = getInputs
inputs.filter(_.isInstanceOf[RemoteFile]).map(_.asInstanceOf[RemoteFile]).foreach(_.pullToLocal())
val inputs = ClassFieldCache.getFieldFiles(this, inputFields)
filterRemoteFiles(inputs).foreach(_.pullToLocal())
}
def pushOutputs() {
val outputs = getOutputs
outputs.filter(_.isInstanceOf[RemoteFile]).map(_.asInstanceOf[RemoteFile]).foreach(_.pushToRemote())
val outputs = ClassFieldCache.getFieldFiles(this, outputFields)
filterRemoteFiles(outputs).foreach(_.pushToRemote())
}
private def getInputs: Seq[File] = {
getFieldValues(classOf[Input])
}
def remoteOutputs: Map[ArgumentSource, Seq[RemoteFile]] =
outputFields.map(field => (field -> filterRemoteFiles(ClassFieldCache.getFieldFiles(this, field)))).filter(tuple => !tuple._2.isEmpty).toMap
private def getOutputs: Seq[File] = {
getFieldValues(classOf[Output])
}
private def filterRemoteFiles(fields: Seq[File]): Seq[RemoteFile] =
fields.filter(field => field != null && field.isInstanceOf[RemoteFile]).map(_.asInstanceOf[RemoteFile])
private def getFieldValues(annotation: Class[_ <: java.lang.annotation.Annotation]): Seq[File] = {
val filtered: Seq[Field] = fields.filter(field => ReflectionUtils.hasAnnotation(field, annotation))
val files = filtered.filter(field => classOf[File].isAssignableFrom(field.getType)).map(field => ReflectionUtils.getValue(this, field).asInstanceOf[File])
val seqFiles = filtered.filter(field => classOf[Seq[File]].isAssignableFrom(field.getType)).map(field => ReflectionUtils.getValue(this, field).asInstanceOf[Seq[File]])
seqFiles.foldLeft(files)(_ ++ _).filter(_ != null)
}
private lazy val fields = collection.JavaConversions.asScalaBuffer(JVMUtils.getAllFields(this.getClass)).toSeq
/** The complete list of fields. */
def functionFields: Seq[ArgumentSource] = ClassFieldCache.classFunctionFields(this.getClass)
/** The @Input fields. */
def inputFields: Seq[ArgumentSource] = ClassFieldCache.classInputFields(this.getClass)
/** The @Output fields. */
def outputFields: Seq[ArgumentSource] = ClassFieldCache.classOutputFields(this.getClass)
/** The @Argument fields. */
def argumentFields: Seq[ArgumentSource] = ClassFieldCache.classArgumentFields(this.getClass)
}
object QScript {

View File

@ -1,10 +1,13 @@
package org.broadinstitute.sting.queue.engine
import org.broadinstitute.sting.commandline.ArgumentSource
import org.broadinstitute.sting.queue.util.RemoteFile
/**
* Plugin to send QStatus messages
*/
trait QStatusMessenger {
def started()
def done()
def done(files: Seq[Map[ArgumentSource, Seq[RemoteFile]]])
def exit(message: String)
}

View File

@ -28,6 +28,7 @@ import org.broadinstitute.sting.queue.function.scattergather.GatherFunction
import org.broadinstitute.sting.queue.extensions.picard.PicardBamFunction
import org.broadinstitute.sting.queue.function.{RetryMemoryLimit, QFunction}
import org.broadinstitute.sting.gatk.io.stubs.SAMFileWriterArgumentTypeDescriptor
import org.broadinstitute.sting.queue.util.ClassFieldCache
/**
* Merges BAM files using net.sf.picard.sam.MergeSamFiles.
@ -47,13 +48,13 @@ class BamGatherFunction extends GatherFunction with PicardBamFunction with Retry
// bam_compression and index_output_bam_on_the_fly from SAMFileWriterArgumentTypeDescriptor
// are added by the GATKExtensionsGenerator to the subclass of CommandLineGATK
val compression = QFunction.findField(originalFunction.getClass, SAMFileWriterArgumentTypeDescriptor.COMPRESSION_FULLNAME)
val compression = ClassFieldCache.findField(originalFunction.getClass, SAMFileWriterArgumentTypeDescriptor.COMPRESSION_FULLNAME)
this.compressionLevel = originalGATK.getFieldValue(compression).asInstanceOf[Option[Int]]
val disableIndex = QFunction.findField(originalFunction.getClass, SAMFileWriterArgumentTypeDescriptor.DISABLE_INDEXING_FULLNAME)
val disableIndex = ClassFieldCache.findField(originalFunction.getClass, SAMFileWriterArgumentTypeDescriptor.DISABLE_INDEXING_FULLNAME)
this.createIndex = Some(!originalGATK.getFieldValue(disableIndex).asInstanceOf[Boolean])
val enableMD5 = QFunction.findField(originalFunction.getClass, SAMFileWriterArgumentTypeDescriptor.ENABLE_MD5_FULLNAME)
val enableMD5 = ClassFieldCache.findField(originalFunction.getClass, SAMFileWriterArgumentTypeDescriptor.ENABLE_MD5_FULLNAME)
this.createMD5 = Some(originalGATK.getFieldValue(enableMD5).asInstanceOf[Boolean])
super.freezeFieldValues()

View File

@ -27,6 +27,7 @@ package org.broadinstitute.sting.queue.extensions.gatk
import org.broadinstitute.sting.queue.function.scattergather.GatherFunction
import org.broadinstitute.sting.queue.function.{RetryMemoryLimit, QFunction}
import org.broadinstitute.sting.gatk.io.stubs.VCFWriterArgumentTypeDescriptor
import org.broadinstitute.sting.queue.util.ClassFieldCache
/**
* Merges a vcf text file.
@ -46,10 +47,10 @@ class VcfGatherFunction extends CombineVariants with GatherFunction with RetryMe
// NO_HEADER and sites_only from VCFWriterArgumentTypeDescriptor
// are added by the GATKExtensionsGenerator to the subclass of CommandLineGATK
val noHeader = QFunction.findField(originalFunction.getClass, VCFWriterArgumentTypeDescriptor.NO_HEADER_ARG_NAME)
val noHeader = ClassFieldCache.findField(originalFunction.getClass, VCFWriterArgumentTypeDescriptor.NO_HEADER_ARG_NAME)
this.no_cmdline_in_header = originalGATK.getFieldValue(noHeader).asInstanceOf[Boolean]
val sitesOnly = QFunction.findField(originalFunction.getClass, VCFWriterArgumentTypeDescriptor.SITES_ONLY_ARG_NAME)
val sitesOnly = ClassFieldCache.findField(originalFunction.getClass, VCFWriterArgumentTypeDescriptor.SITES_ONLY_ARG_NAME)
this.sites_only = originalGATK.getFieldValue(sitesOnly).asInstanceOf[Boolean]
// ensure that the gather function receives the same unsafe parameter as the scattered function

View File

@ -28,7 +28,6 @@ import java.io.File
import java.lang.annotation.Annotation
import org.broadinstitute.sting.commandline._
import org.broadinstitute.sting.queue.{QException, QSettings}
import collection.JavaConversions._
import java.lang.IllegalStateException
import org.broadinstitute.sting.queue.util._
import org.broadinstitute.sting.utils.io.IOUtils
@ -194,13 +193,13 @@ trait QFunction extends Logging with QJobReport {
def failOutputs: Seq[File] = statusPrefixes.map(path => new File(path + ".fail"))
/** The complete list of fields on this CommandLineFunction. */
def functionFields = QFunction.classFields(this.functionFieldClass).functionFields
def functionFields: Seq[ArgumentSource] = ClassFieldCache.classFunctionFields(this.functionFieldClass)
/** The @Input fields on this CommandLineFunction. */
def inputFields = QFunction.classFields(this.functionFieldClass).inputFields
def inputFields: Seq[ArgumentSource] = ClassFieldCache.classInputFields(this.functionFieldClass)
/** The @Output fields on this CommandLineFunction. */
def outputFields = QFunction.classFields(this.functionFieldClass).outputFields
def outputFields: Seq[ArgumentSource] = ClassFieldCache.classOutputFields(this.functionFieldClass)
/** The @Argument fields on this CommandLineFunction. */
def argumentFields = QFunction.classFields(this.functionFieldClass).argumentFields
def argumentFields: Seq[ArgumentSource] = ClassFieldCache.classArgumentFields(this.functionFieldClass)
/**
* Returns the class that should be used for looking up fields.
@ -475,79 +474,12 @@ trait QFunction extends Logging with QJobReport {
* @param source Field to get the value for.
* @return value of the field.
*/
def getFieldValue(source: ArgumentSource) = ReflectionUtils.getValue(invokeObj(source), source.field)
def getFieldValue(source: ArgumentSource) = ClassFieldCache.getFieldValue(this, source)
/**
* Sets the value of a field.
* @param source Field to set the value for.
* @return value of the field.
*/
def setFieldValue(source: ArgumentSource, value: Any) = ReflectionUtils.setValue(invokeObj(source), source.field, value)
/**
* Walks gets the fields in this object or any collections in that object
* recursively to find the object holding the field to be retrieved or set.
* @param source Field find the invoke object for.
* @return Object to invoke the field on.
*/
private def invokeObj(source: ArgumentSource) = source.parentFields.foldLeft[AnyRef](this)(ReflectionUtils.getValue(_, _))
}
object QFunction {
var parsingEngine: ParsingEngine = _
/**
* The list of fields defined on a class
* @param clazz The class to lookup fields.
*/
private class ClassFields(clazz: Class[_]) {
/** The complete list of fields on this CommandLineFunction. */
val functionFields: Seq[ArgumentSource] = parsingEngine.extractArgumentSources(clazz).toSeq
/** The @Input fields on this CommandLineFunction. */
val inputFields = functionFields.filter(source => ReflectionUtils.hasAnnotation(source.field, classOf[Input]))
/** The @Output fields on this CommandLineFunction. */
val outputFields = functionFields.filter(source => ReflectionUtils.hasAnnotation(source.field, classOf[Output]))
/** The @Argument fields on this CommandLineFunction. */
val argumentFields = functionFields.filter(source => ReflectionUtils.hasAnnotation(source.field, classOf[Argument]))
}
/**
* The mapping from class to fields.
*/
private var classFieldsMap = Map.empty[Class[_], ClassFields]
/**
* Returns the field on clazz.
* @param clazz Class to search.
* @param name Name of the field to return.
* @return Argument source for the field.
*/
def findField(clazz: Class[_], name: String) = {
classFields(clazz).functionFields.find(_.field.getName == name) match {
case Some(source) => source
case None => throw new QException("Could not find a field on class %s with name %s".format(clazz, name))
}
}
/**
* Returns the fields for a class.
* @param clazz Class to retrieve fields for.
* @return the fields for the class.
*/
private def classFields(clazz: Class[_]) = {
classFieldsMap.get(clazz) match {
case Some(classFields) => classFields
case None =>
val classFields = new ClassFields(clazz)
classFieldsMap += clazz -> classFields
classFields
}
}
/**
* Returns the Seq of fields for a QFunction class.
* @param clazz Class to retrieve fields for.
* @return the fields of the class.
*/
def classFunctionFields(clazz: Class[_]) = classFields(clazz).functionFields
def setFieldValue(source: ArgumentSource, value: Any) = ClassFieldCache.setFieldValue(this, source, value)
}

View File

@ -25,13 +25,14 @@
package org.broadinstitute.sting.queue.function.scattergather
import org.broadinstitute.sting.commandline.ArgumentSource
import org.broadinstitute.sting.queue.function.{QFunction, CommandLineFunction}
import org.broadinstitute.sting.queue.function.CommandLineFunction
import org.broadinstitute.sting.queue.util.ClassFieldCache
/**
* Shadow clones another command line function.
*/
object CloneFunction {
private lazy val cloneFunctionFields = QFunction.classFunctionFields(classOf[CloneFunction])
private lazy val cloneFunctionFields = ClassFieldCache.classFunctionFields(classOf[CloneFunction])
}
class CloneFunction extends CommandLineFunction {
@ -76,7 +77,7 @@ class CloneFunction extends CommandLineFunction {
def commandLine = withScatterPart(() => originalFunction.commandLine)
def getFieldValue(field: String): AnyRef = {
val source = QFunction.findField(originalFunction.getClass, field)
val source = ClassFieldCache.findField(originalFunction.getClass, field)
getFieldValue(source)
}
@ -98,7 +99,7 @@ class CloneFunction extends CommandLineFunction {
}
def setFieldValue(field: String, value: Any) {
val source = QFunction.findField(originalFunction.getClass, field)
val source = ClassFieldCache.findField(originalFunction.getClass, field)
setFieldValue(source, value)
}

View File

@ -0,0 +1,183 @@
package org.broadinstitute.sting.queue.util
import org.broadinstitute.sting.commandline._
import scala.Some
import org.broadinstitute.sting.queue.QException
import collection.JavaConversions._
import java.io.File
/**
* Utilities and a static cache of argument fields for various classes populated by the parsingEngine.
* Because this class works with the ParsingEngine it can walk @ArgumentCollection hierarchies.
*/
object ClassFieldCache {
  /** Parsing engine used to extract the argument sources of a class. Must be assigned before any lookup. */
  var parsingEngine: ParsingEngine = _

  //
  // Field caching
  //

  /**
   * The list of fields defined on a class.
   * The sources are extracted once via the parsingEngine when the instance is created,
   * then partitioned by annotation.
   * @param clazz The class to lookup fields.
   */
  private class ClassFields(clazz: Class[_]) {
    /** The complete list of fields on this CommandLineFunction. */
    val functionFields: Seq[ArgumentSource] = parsingEngine.extractArgumentSources(clazz).toSeq

    /** The @Input fields on this CommandLineFunction. */
    val inputFields: Seq[ArgumentSource] = annotatedWith(classOf[Input])

    /** The @Output fields on this CommandLineFunction. */
    val outputFields: Seq[ArgumentSource] = annotatedWith(classOf[Output])

    /** The @Argument fields on this CommandLineFunction. */
    val argumentFields: Seq[ArgumentSource] = annotatedWith(classOf[Argument])

    /**
     * Selects the function fields carrying the annotation.
     * @param annotation Annotation class to look for on each field.
     * @return the subset of functionFields with the annotation.
     */
    private def annotatedWith(annotation: Class[_ <: java.lang.annotation.Annotation]): Seq[ArgumentSource] =
      functionFields.filter(source => ReflectionUtils.hasAnnotation(source.field, annotation))
  }

  /**
   * The mapping from class to fields.
   * NOTE(review): this is an unsynchronized var; presumably only touched from one thread -- confirm against callers.
   */
  private var classFieldsMap = Map.empty[Class[_], ClassFields]

  /**
   * Returns the fields for a class, extracting and caching them on the first request.
   * @param clazz Class to retrieve fields for.
   * @return the fields for the class.
   */
  private def classFields(clazz: Class[_]): ClassFields =
    classFieldsMap.getOrElse(clazz, {
      // Cache miss: extract the fields once and remember them for later lookups.
      val extracted = new ClassFields(clazz)
      classFieldsMap += clazz -> extracted
      extracted
    })

  /**
   * Returns the field on clazz.
   * @param clazz Class to search.
   * @param name Name of the field to return.
   * @return Argument source for the field.
   * @throws QException if no field with that name exists on the class.
   */
  def findField(clazz: Class[_], name: String): ArgumentSource =
    classFields(clazz).functionFields
      .find(_.field.getName == name)
      .getOrElse(throw new QException("Could not find a field on class %s with name %s".format(clazz, name)))

  /**
   * Returns the Seq of fields for a QFunction class.
   * @param clazz Class to retrieve fields for.
   * @return the fields of the class.
   */
  def classFunctionFields(clazz: Class[_]): Seq[ArgumentSource] = classFields(clazz).functionFields

  /**
   * Returns the Seq of inputs for a QFunction class.
   * @param clazz Class to retrieve inputs for.
   * @return the inputs of the class.
   */
  def classInputFields(clazz: Class[_]): Seq[ArgumentSource] = classFields(clazz).inputFields

  /**
   * Returns the Seq of outputs for a QFunction class.
   * @param clazz Class to retrieve outputs for.
   * @return the outputs of the class.
   */
  def classOutputFields(clazz: Class[_]): Seq[ArgumentSource] = classFields(clazz).outputFields

  /**
   * Returns the Seq of arguments for a QFunction class.
   * @param clazz Class to retrieve arguments for.
   * @return the arguments of the class.
   */
  def classArgumentFields(clazz: Class[_]): Seq[ArgumentSource] = classFields(clazz).argumentFields

  //
  // get/set fields as AnyRef
  //

  /**
   * Gets the value of a field.
   * @param obj Top level object storing the source info.
   * @param source Field to get the value for.
   * @return value of the field.
   */
  def getFieldValue(obj: AnyRef, source: ArgumentSource) = ReflectionUtils.getValue(invokeObj(obj, source), source.field)

  /**
   * Sets the value of a field.
   * @param obj Top level object storing the source info.
   * @param source Field to set the value for.
   * @param value New value for the field.
   */
  def setFieldValue(obj: AnyRef, source: ArgumentSource, value: Any) = ReflectionUtils.setValue(invokeObj(obj, source), source.field, value)

  /**
   * Walks the parent fields of the source, starting at obj, to find the object
   * holding the field to be retrieved or set.
   * @param obj Top level object storing the source info.
   * @param source Field to find the invoke object for.
   * @return Object to invoke the field on.
   */
  private def invokeObj(obj: AnyRef, source: ArgumentSource) = source.parentFields.foldLeft[AnyRef](obj)(ReflectionUtils.getValue(_, _))

  //
  // get/set fields as java.io.File
  //

  /**
   * Gets the files from the fields. The fields must be a File, a FileExtension, or a Seq or Set of either.
   * @param obj Top level object storing the source info.
   * @param fields Fields to get files.
   * @return the distinct non-null files for the fields, in field order.
   */
  def getFieldFiles(obj: AnyRef, fields: Seq[ArgumentSource]): Seq[File] =
    fields.flatMap(field => getFieldFiles(obj, field)).distinct

  /**
   * Gets the files from the field. The field must be a File, a FileExtension, or a Seq or Set of either.
   * @param obj Top level object storing the source info.
   * @param field Field to get files.
   * @return the distinct non-null files for the field.
   */
  def getFieldFiles(obj: AnyRef, field: ArgumentSource): Seq[File] = {
    // Accumulate in a builder: appending to an immutable List inside the callback is O(n) per element.
    val files = List.newBuilder[File]
    CollectionUtils.foreach(getFieldValue(obj, field), (fieldValue) => {
      // Null entries are skipped rather than collected.
      Option(fieldValueToFile(field, fieldValue)).foreach(files += _)
    })
    files.result().distinct
  }

  /**
   * Gets the file from the field. The field must be a File or a FileExtension and not a Seq or Set.
   * @param obj Top level object storing the source info.
   * @param field Field to get the file.
   * @return the file for the field, or null if the field value is null.
   */
  def getFieldFile(obj: AnyRef, field: ArgumentSource): File =
    fieldValueToFile(field, getFieldValue(obj, field))

  /**
   * Converts the field value to a file. The field must be a File or a FileExtension.
   * @param field Field to get the file.
   * @param value Value of the File or FileExtension or null.
   * @return Null if value is null, otherwise the File.
   * @throws QException if the value is not a File or FileExtension.
   */
  private def fieldValueToFile(field: ArgumentSource, value: Any): File = value match {
    case file: File => file
    case null => null
    case unknown => throw new QException("Non-file found. Try removing the annotation, change the annotation to @Argument, or extend File with FileExtension: %s: %s".format(field.field, unknown))
  }
}