From b335c22f6d7f1788638e09af8a6dcc59dec87d40 Mon Sep 17 00:00:00 2001 From: Mark DePristo Date: Thu, 29 Mar 2012 10:34:41 -0400 Subject: [PATCH] Fully refactored, mostly cleaned up version of VariantEval using StratificationManager --- .../varianteval/VariantEvalWalker.java | 137 ++++++++++-------- .../evaluators/GenotypePhasingEvaluator.java | 4 +- .../evaluators/VariantEvaluator.java | 7 +- ...ionContext.java => EvaluationContext.java} | 31 ++-- .../varianteval/util/VariantEvalUtils.java | 20 ++- 5 files changed, 106 insertions(+), 93 deletions(-) rename public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/util/{NewEvaluationContext.java => EvaluationContext.java} (68%) diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/VariantEvalWalker.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/VariantEvalWalker.java index cf9b82959..f2423da33 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/VariantEvalWalker.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/VariantEvalWalker.java @@ -1,5 +1,6 @@ package org.broadinstitute.sting.gatk.walkers.varianteval; +import com.google.java.contract.Ensures; import com.google.java.contract.Requires; import net.sf.picard.reference.IndexedFastaSequenceFile; import net.sf.picard.util.IntervalTree; @@ -200,9 +201,6 @@ public class VariantEvalWalker extends RodWalker implements Tr private Set sampleNamesForStratification = new TreeSet(); private int numSamples = 0; - // The list of stratifiers and evaluators to use - private List stratificationObjects = null; - // important stratifications private boolean byFilterIsEnabled = false; private boolean perSampleIsEnabled = false; @@ -223,23 +221,7 @@ public class VariantEvalWalker extends RodWalker implements Tr private IndexedFastaSequenceFile ancestralAlignments = null; // The set of all possible evaluation contexts - StratificationManager stratManager; - - // TODO - // TODO - // TODO - // TODO - // TODO - // - // TODO -- StratificationManager should hold the master list of strats - - // TODO - // TODO - // TODO - // TODO - // TODO - - + StratificationManager stratManager; /** * Initialize the stratifications, evaluations, evaluation contexts, and reporting object @@ -285,8 +267,9 @@ public class VariantEvalWalker extends RodWalker implements Tr } // Initialize the set of stratifications and evaluations to use - stratificationObjects = variantEvalUtils.initializeStratificationObjects(this, NO_STANDARD_STRATIFICATIONS, STRATIFICATIONS_TO_USE); - Set> evaluationObjects = variantEvalUtils.initializeEvaluationObjects(NO_STANDARD_MODULES, MODULES_TO_USE); + // The list of stratifiers and evaluators to use + final List stratificationObjects = variantEvalUtils.initializeStratificationObjects(NO_STANDARD_STRATIFICATIONS, STRATIFICATIONS_TO_USE); + final Set> evaluationObjects = variantEvalUtils.initializeEvaluationObjects(NO_STANDARD_MODULES, MODULES_TO_USE); for ( VariantStratifier vs : stratificationObjects ) { if ( vs.getName().equals("Filter") ) @@ -324,9 +307,19 @@ public class VariantEvalWalker extends RodWalker implements Tr if ( knownCNVsFile != null ) { knownCNVsByContig = createIntervalTreeByContig(knownCNVsFile); } - } + final void createStratificationStates(final List stratificationObjects, final Set> evaluationObjects) { + final List strats = new ArrayList(stratificationObjects); + stratManager = new StratificationManager(strats); + + logger.info("Creating " + stratManager.size() + " combinatorial stratification states"); + for ( int i = 0; i < stratManager.size(); i++ ) { + EvaluationContext ec = new EvaluationContext(this, evaluationObjects); + stratManager.set(i, ec); + } + } + public final Map> createIntervalTreeByContig(final IntervalBinding intervals) { final Map> byContig = new HashMap>(); @@ -390,7 +383,7 @@ public class VariantEvalWalker extends RodWalker implements Tr // find the comp final VariantContext comp = findMatchingComp(eval, compSet); - for ( NewEvaluationContext nec : getEvaluationContexts(tracker, ref, eval, evalRod.getName(), comp, compRod.getName(), sampleName) ) { + for ( EvaluationContext nec : getEvaluationContexts(tracker, ref, eval, evalRod.getName(), comp, compRod.getName(), sampleName) ) { // eval against the comp synchronized (nec) { @@ -417,29 +410,32 @@ public class VariantEvalWalker extends RodWalker implements Tr return null; } - final void createStratificationStates(final List stratificationObjects, final Set> evaluationObjects) { - final List strats = new ArrayList(stratificationObjects); - stratManager = new StratificationManager(strats); - - logger.info("Creating " + stratManager.size() + " combinatorial stratification states"); - for ( int i = 0; i < stratManager.size(); i++ ) { - NewEvaluationContext ec = new NewEvaluationContext(); - -// // todo -- remove me, tmp conversion -// for ( Pair stratState : stratManager.getStratsAndStatesForKey(i) ) { -// ec.put(stratState.getFirst(), stratState.getSecond()); -// } - - ec.addEvaluationClassList(this, evaluationObjects); - stratManager.set(i, ec); - } - } - /** * Given specific eval and comp VCs and the sample name, return an iterable * over all of the applicable state keys. * - * See header of StateKey for performance problems... + * this code isn't structured yet for efficiency. Here we currently are + * doing the following inefficient algorithm: + * + * for each strat: + * get list of relevant states that eval and comp according to strat + * add this list of states to a list of list states + * + * then + * + * ask the strat manager to look up all of the keys associated with the combinations + * of these states. For example, suppose we have a single variant S. We have active + * strats EvalRod, CompRod, and Novelty. We produce a list that looks like: + * + * L = [[Eval], [Comp], [All, Novel]] + * + * We then go through the strat manager tree to produce the keys associated with these states: + * + * K = [0, 1] where EVAL x COMP x ALL = 0 and EVAL x COMP x NOVEL = 1 + * + * It's clear that a better + * + * TODO -- create an inline version that doesn't create the intermediate list of list * * @param tracker * @param ref @@ -450,7 +446,7 @@ public class VariantEvalWalker extends RodWalker implements Tr * @param sampleName * @return */ - private Collection getEvaluationContexts(final RefMetaDataTracker tracker, + private Collection getEvaluationContexts(final RefMetaDataTracker tracker, final ReferenceContext ref, final VariantContext eval, final String evalName, @@ -458,10 +454,9 @@ public class VariantEvalWalker extends RodWalker implements Tr final String compName, final String sampleName ) { final List> states = new LinkedList>(); - for ( final VariantStratifier vs : stratificationObjects ) { + for ( final VariantStratifier vs : stratManager.getStratifiers() ) { states.add(vs.getRelevantStates(ref, tracker, comp, compName, eval, evalName, sampleName)); } - return stratManager.values(states); } @@ -538,15 +533,13 @@ public class VariantEvalWalker extends RodWalker implements Tr public void onTraversalDone(Integer result) { logger.info("Finalizing variant report"); - // TODO -- VS should be sorted first with a TreeSet for ( int key = 0; key < stratManager.size(); key++ ) { final String stratStateString = stratManager.getStratsAndStatesForKeyString(key); final List> stratsAndStates = stratManager.getStratsAndStatesForKey(key); - final NewEvaluationContext nec = stratManager.get(key); + final EvaluationContext nec = stratManager.get(key); - for ( final VariantEvaluator ve : nec.getEvaluationClassList().values() ) { + for ( final VariantEvaluator ve : nec.getVariantEvaluators() ) { ve.finalizeEvaluation(); - final String veName = ve.getSimpleName(); // ve.getClass().getSimpleName(); AnalysisModuleScanner scanner = new AnalysisModuleScanner(ve); Map datamap = scanner.getData(); @@ -558,12 +551,11 @@ public class VariantEvalWalker extends RodWalker implements Tr if (field.get(ve) instanceof TableType) { TableType t = (TableType) field.get(ve); - final String subTableName = veName + "." + field.getName(); + final String subTableName = ve.getSimpleName() + "." + field.getName(); final DataPoint dataPointAnn = datamap.get(field); - if (! report.hasTable(subTableName)) { + if (! report.hasTable(subTableName)) configureNewReportTable(t, subTableName, dataPointAnn); - } final GATKReportTable table = report.getTable(subTableName); @@ -580,7 +572,7 @@ public class VariantEvalWalker extends RodWalker implements Tr } } } else { - final GATKReportTable table = report.getTable(veName); + final GATKReportTable table = report.getTable(ve.getSimpleName()); setTableColumnNames(table, stratStateString, stratsAndStates); table.set(stratStateString, field.getName(), field.get(ve)); } @@ -593,37 +585,56 @@ public class VariantEvalWalker extends RodWalker implements Tr report.print(out); } - + + /** + * A common utility function to set up the GATKReportTable for an embedded TableType in + * a VariantEvaluation + * + * @param t + * @param subTableName + * @param dataPointAnn + */ + @Requires({"t != null", "subTableName != null", "dataPointAnn != null", "!report.hasTable(subTableName)"}) + @Ensures({"report.hasTable(subTableName)"}) private final void configureNewReportTable(final TableType t, final String subTableName, final DataPoint dataPointAnn) { // basic table configuration. Set up primary key, dummy column names report.addTable(subTableName, dataPointAnn.description()); - GATKReportTable table = report.getTable(subTableName); + final GATKReportTable table = report.getTable(subTableName); table.addPrimaryKey("entry", false); table.addColumn(subTableName, subTableName); - for ( VariantStratifier vs : stratificationObjects ) { + for ( final VariantStratifier vs : stratManager.getStratifiers() ) { table.addColumn(vs.getName(), "unknown"); } table.addColumn(t.getRowName(), "unknown"); for ( final Object o : t.getColumnKeys() ) { - final String c = o.toString(); - table.addColumn(c, 0.0); + table.addColumn(o.toString(), 0.0); } } - + + /** + * Common utility to configure a GATKReportTable columns + * + * Sets the column names to the strat names in stratsAndStates for the primary key in table + * + * @param table + * @param primaryKey + * @param stratsAndStates + */ private final void setTableColumnNames(final GATKReportTable table, final String primaryKey, final List> stratsAndStates) { - for ( Pair stratAndState : stratsAndStates ) { + for ( final Pair stratAndState : stratsAndStates ) { final VariantStratifier vs = stratAndState.getFirst(); final String columnName = vs.getName(); final Object strat = stratAndState.getSecond(); + if ( columnName == null || strat == null ) + throw new ReviewedStingException("Unexpected null variant stratifier state at " + table + " key = " + primaryKey); table.set(primaryKey, columnName, strat); } - } // Accessors @@ -635,8 +646,6 @@ public class VariantEvalWalker extends RodWalker implements Tr public double getMendelianViolationQualThreshold() { return MENDELIAN_VIOLATION_QUAL_THRESHOLD; } - public List getStratificationObjects() { return stratificationObjects; } - public static String getAllSampleName() { return ALL_SAMPLE_NAME; } public List> getKnowns() { return knowns; } diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/evaluators/GenotypePhasingEvaluator.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/evaluators/GenotypePhasingEvaluator.java index 41979798e..266c4fa89 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/evaluators/GenotypePhasingEvaluator.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/evaluators/GenotypePhasingEvaluator.java @@ -9,7 +9,7 @@ import org.broadinstitute.sting.gatk.walkers.phasing.ReadBackedPhasingWalker; import org.broadinstitute.sting.gatk.walkers.varianteval.VariantEvalWalker; import org.broadinstitute.sting.gatk.walkers.varianteval.util.Analysis; import org.broadinstitute.sting.gatk.walkers.varianteval.util.DataPoint; -import org.broadinstitute.sting.gatk.walkers.varianteval.util.NewEvaluationContext; +import org.broadinstitute.sting.gatk.walkers.varianteval.util.EvaluationContext; import org.broadinstitute.sting.gatk.walkers.varianteval.util.TableType; import org.broadinstitute.sting.utils.GenomeLoc; import org.broadinstitute.sting.utils.MathUtils; @@ -85,7 +85,7 @@ public class GenotypePhasingEvaluator extends VariantEvaluator { return update2(eval,comp,tracker,ref,context,null); } - public String update2(VariantContext eval, VariantContext comp, RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context, NewEvaluationContext group) { + public String update2(VariantContext eval, VariantContext comp, RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context, EvaluationContext group) { //public String update2(VariantContext eval, VariantContext comp, RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context, VariantEvalWalker.EvaluationContext group) { Reasons interesting = new Reasons(); if (ref == null) diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/evaluators/VariantEvaluator.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/evaluators/VariantEvaluator.java index 226429439..35a100bd9 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/evaluators/VariantEvaluator.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/evaluators/VariantEvaluator.java @@ -6,7 +6,7 @@ import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker; import org.broadinstitute.sting.gatk.walkers.varianteval.VariantEvalWalker; import org.broadinstitute.sting.utils.variantcontext.VariantContext; -public abstract class VariantEvaluator { +public abstract class VariantEvaluator implements Comparable { private VariantEvalWalker walker; private final String simpleName; @@ -99,4 +99,9 @@ public abstract class VariantEvaluator { public String getSimpleName() { return simpleName; } + + @Override + public int compareTo(final VariantEvaluator variantEvaluator) { + return getSimpleName().compareTo(variantEvaluator.getSimpleName()); + } } diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/util/NewEvaluationContext.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/util/EvaluationContext.java similarity index 68% rename from public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/util/NewEvaluationContext.java rename to public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/util/EvaluationContext.java index ef5579b01..5679299e2 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/util/NewEvaluationContext.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/util/EvaluationContext.java @@ -12,18 +12,18 @@ import org.broadinstitute.sting.utils.variantcontext.VariantContext; import java.util.*; -public class NewEvaluationContext { // extends HashMap { - private Map evaluationInstances; +public final class EvaluationContext { + // NOTE: must be hashset to avoid O(log n) cost of iteration in the very frequently called apply function + private final HashSet evaluationInstances; - public void addEvaluationClassList(VariantEvalWalker walker, Set> evaluationClasses) { - evaluationInstances = new LinkedHashMap(evaluationClasses.size()); + public EvaluationContext(final VariantEvalWalker walker, final Set> evaluationClasses) { + evaluationInstances = new HashSet(evaluationClasses.size()); for ( final Class c : evaluationClasses ) { try { final VariantEvaluator eval = c.newInstance(); eval.initialize(walker); - - evaluationInstances.put(c.getSimpleName(), eval); + evaluationInstances.add(eval); } catch (InstantiationException e) { throw new StingException("Unable to instantiate eval module '" + c.getSimpleName() + "'"); } catch (IllegalAccessException e) { @@ -32,12 +32,17 @@ public class NewEvaluationContext { // extends HashMap getEvaluationClassList() { - return new TreeMap(evaluationInstances); + /** + * Returns a sorted set of VariantEvaluators + * + * @return + */ + public final TreeSet getVariantEvaluators() { + return new TreeSet(evaluationInstances); } - public void apply(RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context, VariantContext comp, VariantContext eval) { - for ( final VariantEvaluator evaluation : evaluationInstances.values() ) { + public final void apply(RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context, VariantContext comp, VariantContext eval) { + for ( final VariantEvaluator evaluation : evaluationInstances ) { // the other updateN methods don't see a null context if ( tracker == null ) continue; @@ -48,13 +53,9 @@ public class NewEvaluationContext { // extends HashMap initializeStratificationObjects(VariantEvalWalker variantEvalWalker, boolean noStandardStrats, String[] modulesToUse) { + public List initializeStratificationObjects(boolean noStandardStrats, String[] modulesToUse) { TreeSet strats = new TreeSet(); Set stratsToUse = new HashSet(); @@ -189,26 +188,25 @@ public class VariantEvalUtils { * @return an initialized report object */ public GATKReport initializeGATKReport(Collection stratificationObjects, Set> evaluationObjects) { - GATKReport report = new GATKReport(); + final GATKReport report = new GATKReport(); for (Class ve : evaluationObjects) { - String tableName = ve.getSimpleName(); - String tableDesc = ve.getAnnotation(Analysis.class).description(); + final String tableName = ve.getSimpleName(); + final String tableDesc = ve.getAnnotation(Analysis.class).description(); report.addTable(tableName, tableDesc); - GATKReportTable table = report.getTable(tableName); + final GATKReportTable table = report.getTable(tableName); table.addPrimaryKey("entry", false); table.addColumn(tableName, tableName); - for (VariantStratifier vs : stratificationObjects) { - String columnName = vs.getName(); - + for (final VariantStratifier vs : stratificationObjects) { + final String columnName = vs.getName(); table.addColumn(columnName, "unknown"); } try { - VariantEvaluator vei = ve.newInstance(); + final VariantEvaluator vei = ve.newInstance(); vei.initialize(variantEvalWalker); AnalysisModuleScanner scanner = new AnalysisModuleScanner(vei); @@ -218,7 +216,7 @@ public class VariantEvalUtils { field.setAccessible(true); if (!(field.get(vei) instanceof TableType)) { - String format = datamap.get(field).format(); + final String format = datamap.get(field).format(); table.addColumn(field.getName(), true, format); } }