diff --git a/public/scala/src/org/broadinstitute/sting/queue/QCommandLine.scala b/public/scala/src/org/broadinstitute/sting/queue/QCommandLine.scala index f4c4b613f..5b84bfd16 100644 --- a/public/scala/src/org/broadinstitute/sting/queue/QCommandLine.scala +++ b/public/scala/src/org/broadinstitute/sting/queue/QCommandLine.scala @@ -100,7 +100,7 @@ class QCommandLine extends CommandLineProgram with Logging { new PluginManager[QStatusMessenger](classOf[QStatusMessenger]) } - QFunction.parsingEngine = new ParsingEngine(this) + ClassFieldCache.parsingEngine = new ParsingEngine(this) /** * Takes the QScripts passed in, runs their script() methods, retrieves their generated @@ -127,6 +127,9 @@ class QCommandLine extends CommandLineProgram with Logging { for (script <- allQScripts) { logger.info("Scripting " + qScriptPluginManager.getName(script.getClass.asSubclass(classOf[QScript]))) loadArgumentsIntoObject(script) + // TODO: Pulling inputs can be time/io expensive! Some scripts are using the files to generate functions-- even for dry runs-- so pull it all down for now. + //if (settings.run) + script.pullInputs() script.qSettings = settings.qSettings try { script.script() @@ -138,10 +141,6 @@ class QCommandLine extends CommandLineProgram with Logging { logger.info("Added " + script.functions.size + " functions") } - if (settings.run) { - allQScripts.foreach(_.pullInputs()) - } - // Execute the job graph qGraph.run() @@ -170,7 +169,7 @@ class QCommandLine extends CommandLineProgram with Logging { if (settings.run) { allQScripts.foreach(_.pushOutputs()) for (statusMessenger <- allStatusMessengers) - statusMessenger.done() + statusMessenger.done(allQScripts.map(_.remoteOutputs)) } 0 } diff --git a/public/scala/src/org/broadinstitute/sting/queue/QScript.scala b/public/scala/src/org/broadinstitute/sting/queue/QScript.scala index da24b854e..ee2089dc5 100755 --- a/public/scala/src/org/broadinstitute/sting/queue/QScript.scala +++ b/public/scala/src/org/broadinstitute/sting/queue/QScript.scala @@ -28,8 +28,7 @@ import engine.JobRunInfo import org.broadinstitute.sting.queue.function.QFunction import annotation.target.field import util._ -import org.broadinstitute.sting.utils.classloader.JVMUtils -import java.lang.reflect.Field +import org.broadinstitute.sting.commandline.ArgumentSource /** * Defines a Queue pipeline as a collection of CommandLineFunctions. @@ -110,31 +109,29 @@ trait QScript extends Logging with PrimitiveOptionConversions with StringFileCon } def pullInputs() { - val inputs = getInputs - inputs.filter(_.isInstanceOf[RemoteFile]).map(_.asInstanceOf[RemoteFile]).foreach(_.pullToLocal()) + val inputs = ClassFieldCache.getFieldFiles(this, inputFields) + filterRemoteFiles(inputs).foreach(_.pullToLocal()) } def pushOutputs() { - val outputs = getOutputs - outputs.filter(_.isInstanceOf[RemoteFile]).map(_.asInstanceOf[RemoteFile]).foreach(_.pushToRemote()) + val outputs = ClassFieldCache.getFieldFiles(this, outputFields) + filterRemoteFiles(outputs).foreach(_.pushToRemote()) } - private def getInputs: Seq[File] = { - getFieldValues(classOf[Input]) - } + def remoteOutputs: Map[ArgumentSource, Seq[RemoteFile]] = + outputFields.map(field => (field -> filterRemoteFiles(ClassFieldCache.getFieldFiles(this, field)))).filter(tuple => !tuple._2.isEmpty).toMap - private def getOutputs: Seq[File] = { - getFieldValues(classOf[Output]) - } + private def filterRemoteFiles(fields: Seq[File]): Seq[RemoteFile] = + fields.filter(field => field != null && field.isInstanceOf[RemoteFile]).map(_.asInstanceOf[RemoteFile]) - private def getFieldValues(annotation: Class[_ <: java.lang.annotation.Annotation]): Seq[File] = { - val filtered: Seq[Field] = fields.filter(field => ReflectionUtils.hasAnnotation(field, annotation)) - val files = filtered.filter(field => classOf[File].isAssignableFrom(field.getType)).map(field => ReflectionUtils.getValue(this, field).asInstanceOf[File]) - val seqFiles = filtered.filter(field => classOf[Seq[File]].isAssignableFrom(field.getType)).map(field => ReflectionUtils.getValue(this, field).asInstanceOf[Seq[File]]) - seqFiles.foldLeft(files)(_ ++ _).filter(_ != null) - } - - private lazy val fields = collection.JavaConversions.asScalaBuffer(JVMUtils.getAllFields(this.getClass)).toSeq + /** The complete list of fields. */ + def functionFields: Seq[ArgumentSource] = ClassFieldCache.classFunctionFields(this.getClass) + /** The @Input fields. */ + def inputFields: Seq[ArgumentSource] = ClassFieldCache.classInputFields(this.getClass) + /** The @Output fields. */ + def outputFields: Seq[ArgumentSource] = ClassFieldCache.classOutputFields(this.getClass) + /** The @Argument fields. */ + def argumentFields: Seq[ArgumentSource] = ClassFieldCache.classArgumentFields(this.getClass) } object QScript { diff --git a/public/scala/src/org/broadinstitute/sting/queue/engine/QStatusMessenger.scala b/public/scala/src/org/broadinstitute/sting/queue/engine/QStatusMessenger.scala index c61f2ef1f..eeabe6d1d 100644 --- a/public/scala/src/org/broadinstitute/sting/queue/engine/QStatusMessenger.scala +++ b/public/scala/src/org/broadinstitute/sting/queue/engine/QStatusMessenger.scala @@ -1,10 +1,13 @@ package org.broadinstitute.sting.queue.engine +import org.broadinstitute.sting.commandline.ArgumentSource +import org.broadinstitute.sting.queue.util.RemoteFile + /** * Plugin to sends QStatus messages */ trait QStatusMessenger { def started() - def done() + def done(files: Seq[Map[ArgumentSource, Seq[RemoteFile]]]) def exit(message: String) } diff --git a/public/scala/src/org/broadinstitute/sting/queue/extensions/gatk/BamGatherFunction.scala b/public/scala/src/org/broadinstitute/sting/queue/extensions/gatk/BamGatherFunction.scala index 9522ec86c..a59f273ad 100644 --- a/public/scala/src/org/broadinstitute/sting/queue/extensions/gatk/BamGatherFunction.scala +++ b/public/scala/src/org/broadinstitute/sting/queue/extensions/gatk/BamGatherFunction.scala @@ -28,6 +28,7 @@ import org.broadinstitute.sting.queue.function.scattergather.GatherFunction import org.broadinstitute.sting.queue.extensions.picard.PicardBamFunction import org.broadinstitute.sting.queue.function.{RetryMemoryLimit, QFunction} import org.broadinstitute.sting.gatk.io.stubs.SAMFileWriterArgumentTypeDescriptor +import org.broadinstitute.sting.queue.util.ClassFieldCache /** * Merges BAM files using net.sf.picard.sam.MergeSamFiles. @@ -47,13 +48,13 @@ class BamGatherFunction extends GatherFunction with PicardBamFunction with Retry // bam_compression and index_output_bam_on_the_fly from SAMFileWriterArgumentTypeDescriptor // are added by the GATKExtensionsGenerator to the subclass of CommandLineGATK - val compression = QFunction.findField(originalFunction.getClass, SAMFileWriterArgumentTypeDescriptor.COMPRESSION_FULLNAME) + val compression = ClassFieldCache.findField(originalFunction.getClass, SAMFileWriterArgumentTypeDescriptor.COMPRESSION_FULLNAME) this.compressionLevel = originalGATK.getFieldValue(compression).asInstanceOf[Option[Int]] - val disableIndex = QFunction.findField(originalFunction.getClass, SAMFileWriterArgumentTypeDescriptor.DISABLE_INDEXING_FULLNAME) + val disableIndex = ClassFieldCache.findField(originalFunction.getClass, SAMFileWriterArgumentTypeDescriptor.DISABLE_INDEXING_FULLNAME) this.createIndex = Some(!originalGATK.getFieldValue(disableIndex).asInstanceOf[Boolean]) - val enableMD5 = QFunction.findField(originalFunction.getClass, SAMFileWriterArgumentTypeDescriptor.ENABLE_MD5_FULLNAME) + val enableMD5 = ClassFieldCache.findField(originalFunction.getClass, SAMFileWriterArgumentTypeDescriptor.ENABLE_MD5_FULLNAME) this.createMD5 = Some(originalGATK.getFieldValue(enableMD5).asInstanceOf[Boolean]) super.freezeFieldValues() diff --git a/public/scala/src/org/broadinstitute/sting/queue/extensions/gatk/VcfGatherFunction.scala b/public/scala/src/org/broadinstitute/sting/queue/extensions/gatk/VcfGatherFunction.scala index 75be4d773..fb22554f0 100644 --- a/public/scala/src/org/broadinstitute/sting/queue/extensions/gatk/VcfGatherFunction.scala +++ b/public/scala/src/org/broadinstitute/sting/queue/extensions/gatk/VcfGatherFunction.scala @@ -27,6 +27,7 @@ package org.broadinstitute.sting.queue.extensions.gatk import org.broadinstitute.sting.queue.function.scattergather.GatherFunction import org.broadinstitute.sting.queue.function.{RetryMemoryLimit, QFunction} import org.broadinstitute.sting.gatk.io.stubs.VCFWriterArgumentTypeDescriptor +import org.broadinstitute.sting.queue.util.ClassFieldCache /** * Merges a vcf text file. @@ -46,10 +47,10 @@ class VcfGatherFunction extends CombineVariants with GatherFunction with RetryMe // NO_HEADER and sites_only from VCFWriterArgumentTypeDescriptor // are added by the GATKExtensionsGenerator to the subclass of CommandLineGATK - val noHeader = QFunction.findField(originalFunction.getClass, VCFWriterArgumentTypeDescriptor.NO_HEADER_ARG_NAME) + val noHeader = ClassFieldCache.findField(originalFunction.getClass, VCFWriterArgumentTypeDescriptor.NO_HEADER_ARG_NAME) this.no_cmdline_in_header = originalGATK.getFieldValue(noHeader).asInstanceOf[Boolean] - val sitesOnly = QFunction.findField(originalFunction.getClass, VCFWriterArgumentTypeDescriptor.SITES_ONLY_ARG_NAME) + val sitesOnly = ClassFieldCache.findField(originalFunction.getClass, VCFWriterArgumentTypeDescriptor.SITES_ONLY_ARG_NAME) this.sites_only = originalGATK.getFieldValue(sitesOnly).asInstanceOf[Boolean] // ensure that the gather function receives the same unsafe parameter as the scattered function diff --git a/public/scala/src/org/broadinstitute/sting/queue/function/QFunction.scala b/public/scala/src/org/broadinstitute/sting/queue/function/QFunction.scala index aae846534..3849b976a 100644 --- a/public/scala/src/org/broadinstitute/sting/queue/function/QFunction.scala +++ b/public/scala/src/org/broadinstitute/sting/queue/function/QFunction.scala @@ -28,7 +28,6 @@ import java.io.File import java.lang.annotation.Annotation import org.broadinstitute.sting.commandline._ import org.broadinstitute.sting.queue.{QException, QSettings} -import collection.JavaConversions._ import java.lang.IllegalStateException import org.broadinstitute.sting.queue.util._ import org.broadinstitute.sting.utils.io.IOUtils @@ -194,13 +193,13 @@ trait QFunction extends Logging with QJobReport { def failOutputs: Seq[File] = statusPrefixes.map(path => new File(path + ".fail")) /** The complete list of fields on this CommandLineFunction. */ - def functionFields = QFunction.classFields(this.functionFieldClass).functionFields + def functionFields: Seq[ArgumentSource] = ClassFieldCache.classFunctionFields(this.functionFieldClass) /** The @Input fields on this CommandLineFunction. */ - def inputFields = QFunction.classFields(this.functionFieldClass).inputFields + def inputFields: Seq[ArgumentSource] = ClassFieldCache.classInputFields(this.functionFieldClass) /** The @Output fields on this CommandLineFunction. */ - def outputFields = QFunction.classFields(this.functionFieldClass).outputFields + def outputFields: Seq[ArgumentSource] = ClassFieldCache.classOutputFields(this.functionFieldClass) /** The @Argument fields on this CommandLineFunction. */ - def argumentFields = QFunction.classFields(this.functionFieldClass).argumentFields + def argumentFields: Seq[ArgumentSource] = ClassFieldCache.classArgumentFields(this.functionFieldClass) /** * Returns the class that should be used for looking up fields. @@ -475,79 +474,12 @@ trait QFunction extends Logging with QJobReport { * @param source Field to get the value for. * @return value of the field. */ - def getFieldValue(source: ArgumentSource) = ReflectionUtils.getValue(invokeObj(source), source.field) + def getFieldValue(source: ArgumentSource) = ClassFieldCache.getFieldValue(this, source) /** * Gets the value of a field. * @param source Field to set the value for. * @return value of the field. */ - def setFieldValue(source: ArgumentSource, value: Any) = ReflectionUtils.setValue(invokeObj(source), source.field, value) - - /** - * Walks gets the fields in this object or any collections in that object - * recursively to find the object holding the field to be retrieved or set. - * @param source Field find the invoke object for. - * @return Object to invoke the field on. - */ - private def invokeObj(source: ArgumentSource) = source.parentFields.foldLeft[AnyRef](this)(ReflectionUtils.getValue(_, _)) -} - -object QFunction { - var parsingEngine: ParsingEngine = _ - - /** - * The list of fields defined on a class - * @param clazz The class to lookup fields. - */ - private class ClassFields(clazz: Class[_]) { - /** The complete list of fields on this CommandLineFunction. */ - val functionFields: Seq[ArgumentSource] = parsingEngine.extractArgumentSources(clazz).toSeq - /** The @Input fields on this CommandLineFunction. */ - val inputFields = functionFields.filter(source => ReflectionUtils.hasAnnotation(source.field, classOf[Input])) - /** The @Output fields on this CommandLineFunction. */ - val outputFields = functionFields.filter(source => ReflectionUtils.hasAnnotation(source.field, classOf[Output])) - /** The @Argument fields on this CommandLineFunction. */ - val argumentFields = functionFields.filter(source => ReflectionUtils.hasAnnotation(source.field, classOf[Argument])) - } - - /** - * The mapping from class to fields. - */ - private var classFieldsMap = Map.empty[Class[_], ClassFields] - - /** - * Returns the field on clazz. - * @param clazz Class to search. - * @param name Name of the field to return. - * @return Argument source for the field. - */ - def findField(clazz: Class[_], name: String) = { - classFields(clazz).functionFields.find(_.field.getName == name) match { - case Some(source) => source - case None => throw new QException("Could not find a field on class %s with name %s".format(clazz, name)) - } - } - - /** - * Returns the fields for a class. - * @param clazz Class to retrieve fields for. - * @return the fields for the class. - */ - private def classFields(clazz: Class[_]) = { - classFieldsMap.get(clazz) match { - case Some(classFields) => classFields - case None => - val classFields = new ClassFields(clazz) - classFieldsMap += clazz -> classFields - classFields - } - } - - /** - * Returns the Seq of fields for a QFunction class. - * @param clazz Class to retrieve fields for. - * @return the fields of the class. - */ - def classFunctionFields(clazz: Class[_]) = classFields(clazz).functionFields + def setFieldValue(source: ArgumentSource, value: Any) = ClassFieldCache.setFieldValue(this, source, value) } diff --git a/public/scala/src/org/broadinstitute/sting/queue/function/scattergather/CloneFunction.scala b/public/scala/src/org/broadinstitute/sting/queue/function/scattergather/CloneFunction.scala index 686188e72..91cacbb71 100644 --- a/public/scala/src/org/broadinstitute/sting/queue/function/scattergather/CloneFunction.scala +++ b/public/scala/src/org/broadinstitute/sting/queue/function/scattergather/CloneFunction.scala @@ -25,13 +25,14 @@ package org.broadinstitute.sting.queue.function.scattergather import org.broadinstitute.sting.commandline.ArgumentSource -import org.broadinstitute.sting.queue.function.{QFunction, CommandLineFunction} +import org.broadinstitute.sting.queue.function.CommandLineFunction +import org.broadinstitute.sting.queue.util.ClassFieldCache /** * Shadow clones another command line function. */ object CloneFunction { - private lazy val cloneFunctionFields = QFunction.classFunctionFields(classOf[CloneFunction]) + private lazy val cloneFunctionFields = ClassFieldCache.classFunctionFields(classOf[CloneFunction]) } class CloneFunction extends CommandLineFunction { @@ -76,7 +77,7 @@ class CloneFunction extends CommandLineFunction { def commandLine = withScatterPart(() => originalFunction.commandLine) def getFieldValue(field: String): AnyRef = { - val source = QFunction.findField(originalFunction.getClass, field) + val source = ClassFieldCache.findField(originalFunction.getClass, field) getFieldValue(source) } @@ -98,7 +99,7 @@ class CloneFunction extends CommandLineFunction { } def setFieldValue(field: String, value: Any) { - val source = QFunction.findField(originalFunction.getClass, field) + val source = ClassFieldCache.findField(originalFunction.getClass, field) setFieldValue(source, value) } diff --git a/public/scala/src/org/broadinstitute/sting/queue/util/ClassFieldCache.scala b/public/scala/src/org/broadinstitute/sting/queue/util/ClassFieldCache.scala new file mode 100644 index 000000000..870dd5617 --- /dev/null +++ b/public/scala/src/org/broadinstitute/sting/queue/util/ClassFieldCache.scala @@ -0,0 +1,183 @@ +package org.broadinstitute.sting.queue.util + +import org.broadinstitute.sting.commandline._ +import scala.Some +import org.broadinstitute.sting.queue.QException +import collection.JavaConversions._ +import java.io.File + +/** + * Utilities and a static cache of argument fields for various classes populated by the parsingEngine. + * Because this class works with the ParsingEngine it can walk @ArgumentCollection hierarchies. + */ +object ClassFieldCache { + var parsingEngine: ParsingEngine = _ + + + // + // Field caching + // + + /** + * The list of fields defined on a class + * @param clazz The class to lookup fields. + */ + private class ClassFields(clazz: Class[_]) { + /** The complete list of fields on this CommandLineFunction. */ + val functionFields: Seq[ArgumentSource] = parsingEngine.extractArgumentSources(clazz).toSeq + /** The @Input fields on this CommandLineFunction. */ + val inputFields: Seq[ArgumentSource] = functionFields.filter(source => ReflectionUtils.hasAnnotation(source.field, classOf[Input])) + /** The @Output fields on this CommandLineFunction. */ + val outputFields: Seq[ArgumentSource] = functionFields.filter(source => ReflectionUtils.hasAnnotation(source.field, classOf[Output])) + /** The @Argument fields on this CommandLineFunction. */ + val argumentFields: Seq[ArgumentSource] = functionFields.filter(source => ReflectionUtils.hasAnnotation(source.field, classOf[Argument])) + } + + /** + * The mapping from class to fields. + */ + private var classFieldsMap = Map.empty[Class[_], ClassFields] + + /** + * Returns the fields for a class. + * @param clazz Class to retrieve fields for. + * @return the fields for the class. + */ + private def classFields(clazz: Class[_]): ClassFields = { + classFieldsMap.get(clazz) match { + case Some(classFields) => classFields + case None => + val classFields = new ClassFields(clazz) + classFieldsMap += clazz -> classFields + classFields + } + } + + /** + * Returns the field on clazz. + * @param clazz Class to search. + * @param name Name of the field to return. + * @return Argument source for the field. + */ + def findField(clazz: Class[_], name: String): ArgumentSource = { + classFields(clazz).functionFields.find(_.field.getName == name) match { + case Some(source) => source + case None => throw new QException("Could not find a field on class %s with name %s".format(clazz, name)) + } + } + + /** + * Returns the Seq of fields for a QFunction class. + * @param clazz Class to retrieve fields for. + * @return the fields of the class. + */ + def classFunctionFields(clazz: Class[_]): Seq[ArgumentSource] = classFields(clazz).functionFields + + /** + * Returns the Seq of inputs for a QFunction class. + * @param clazz Class to retrieve inputs for. + * @return the inputs of the class. + */ + def classInputFields(clazz: Class[_]): Seq[ArgumentSource] = classFields(clazz).inputFields + + /** + * Returns the Seq of outputs for a QFunction class. + * @param clazz Class to retrieve outputs for. + * @return the outputs of the class. + */ + def classOutputFields(clazz: Class[_]): Seq[ArgumentSource] = classFields(clazz).outputFields + + /** + * Returns the Seq of arguments for a QFunction class. + * @param clazz Class to retrieve arguments for. + * @return the arguments of the class. + */ + def classArgumentFields(clazz: Class[_]): Seq[ArgumentSource] = classFields(clazz).argumentFields + + + // + // get/set fields as AnyRef + // + + /** + * Gets the value of a field. + * @param obj Top level object storing the source info. + * @param source Field to get the value for. + * @return value of the field. + */ + def getFieldValue(obj: AnyRef, source: ArgumentSource) = ReflectionUtils.getValue(invokeObj(obj, source), source.field) + + /** + * Gets the value of a field. + * @param obj Top level object storing the source info. + * @param source Field to set the value for. + * @return value of the field. + */ + def setFieldValue(obj: AnyRef, source: ArgumentSource, value: Any) = ReflectionUtils.setValue(invokeObj(obj, source), source.field, value) + + /** + * Walks gets the fields in this object or any collections in that object + * recursively to find the object holding the field to be retrieved or set. + * @param obj Top level object storing the source info. + * @param source Field find the invoke object for. + * @return Object to invoke the field on. + */ + private def invokeObj(obj: AnyRef, source: ArgumentSource) = source.parentFields.foldLeft[AnyRef](obj)(ReflectionUtils.getValue(_, _)) + + + // + // get/set fields as java.io.File + // + + /** + * Gets the files from the fields. The fields must be a File, a FileExtension, or a Seq or Set of either. + * @param obj Top level object storing the source info. + * @param fields Fields to get files. + * @return for the fields. + */ + def getFieldFiles(obj: AnyRef, fields: Seq[ArgumentSource]): Seq[File] = { + var files: Seq[File] = Nil + for (field <- fields) + files ++= getFieldFiles(obj, field) + files.distinct + } + + /** + * Gets the files from the field. The field must be a File, a FileExtension, or a Seq or Set of either. + * @param obj Top level object storing the source info. + * @param field Field to get files. + * @return for the field. + */ + def getFieldFiles(obj: AnyRef, field: ArgumentSource): Seq[File] = { + var files: Seq[File] = Nil + CollectionUtils.foreach(getFieldValue(obj, field), (fieldValue) => { + val file = fieldValueToFile(field, fieldValue) + if (file != null) + files :+= file + }) + files.distinct + } + + /** + * Gets the file from the field. The field must be a File or a FileExtension and not a Seq or Set. + * @param obj Top level object storing the source info. + * @param field Field to get the file. + * @return for the field. + */ + def getFieldFile(obj: AnyRef, field: ArgumentSource): File = + fieldValueToFile(field, getFieldValue(obj, field)) + + /** + * Converts the field value to a file. The field must be a File or a FileExtension. + * @param field Field to get the file. + * @param value Value of the File or FileExtension or null. + * @return Null if value is null, otherwise the File. + * @throws QException if the value is not a File or FileExtension. + */ + private def fieldValueToFile(field: ArgumentSource, value: Any): File = value match { + case file: File => file + case null => null + case unknown => throw new QException("Non-file found. Try removing the annotation, change the annotation to @Argument, or extend File with FileExtension: %s: %s".format(field.field, unknown)) + } + +}