Fully refactored, mostly cleaned up version of VariantEval using StratificationManager

This commit is contained in:
Mark DePristo 2012-03-29 10:34:41 -04:00
parent c8086a79e3
commit b335c22f6d
5 changed files with 106 additions and 93 deletions

View File

@ -1,5 +1,6 @@
package org.broadinstitute.sting.gatk.walkers.varianteval;
import com.google.java.contract.Ensures;
import com.google.java.contract.Requires;
import net.sf.picard.reference.IndexedFastaSequenceFile;
import net.sf.picard.util.IntervalTree;
@ -200,9 +201,6 @@ public class VariantEvalWalker extends RodWalker<Integer, Integer> implements Tr
private Set<String> sampleNamesForStratification = new TreeSet<String>();
private int numSamples = 0;
// The list of stratifiers and evaluators to use
private List<VariantStratifier> stratificationObjects = null;
// important stratifications
private boolean byFilterIsEnabled = false;
private boolean perSampleIsEnabled = false;
@ -223,23 +221,7 @@ public class VariantEvalWalker extends RodWalker<Integer, Integer> implements Tr
private IndexedFastaSequenceFile ancestralAlignments = null;
// The set of all possible evaluation contexts
StratificationManager<VariantStratifier, NewEvaluationContext> stratManager;
// TODO
// TODO
// TODO
// TODO
// TODO
//
// TODO -- StratificationManager should hold the master list of strats
// TODO
// TODO
// TODO
// TODO
// TODO
StratificationManager<VariantStratifier, EvaluationContext> stratManager;
/**
* Initialize the stratifications, evaluations, evaluation contexts, and reporting object
@ -285,8 +267,9 @@ public class VariantEvalWalker extends RodWalker<Integer, Integer> implements Tr
}
// Initialize the set of stratifications and evaluations to use
stratificationObjects = variantEvalUtils.initializeStratificationObjects(this, NO_STANDARD_STRATIFICATIONS, STRATIFICATIONS_TO_USE);
Set<Class<? extends VariantEvaluator>> evaluationObjects = variantEvalUtils.initializeEvaluationObjects(NO_STANDARD_MODULES, MODULES_TO_USE);
// The list of stratifiers and evaluators to use
final List<VariantStratifier> stratificationObjects = variantEvalUtils.initializeStratificationObjects(NO_STANDARD_STRATIFICATIONS, STRATIFICATIONS_TO_USE);
final Set<Class<? extends VariantEvaluator>> evaluationObjects = variantEvalUtils.initializeEvaluationObjects(NO_STANDARD_MODULES, MODULES_TO_USE);
for ( VariantStratifier vs : stratificationObjects ) {
if ( vs.getName().equals("Filter") )
@ -324,9 +307,19 @@ public class VariantEvalWalker extends RodWalker<Integer, Integer> implements Tr
if ( knownCNVsFile != null ) {
knownCNVsByContig = createIntervalTreeByContig(knownCNVsFile);
}
}
/**
 * Builds the stratification manager from the supplied stratifiers and fills every
 * key of the manager with its own, newly constructed EvaluationContext.
 *
 * @param stratificationObjects the stratifiers handed (as a copy) to the StratificationManager
 * @param evaluationObjects the evaluator classes passed to each new EvaluationContext
 */
final void createStratificationStates(final List<VariantStratifier> stratificationObjects, final Set<Class<? extends VariantEvaluator>> evaluationObjects) {
    // the manager receives its own copy of the stratifier list
    stratManager = new StratificationManager<VariantStratifier, EvaluationContext>(
            new ArrayList<VariantStratifier>(stratificationObjects));

    logger.info("Creating " + stratManager.size() + " combinatorial stratification states");

    // one separate evaluation context per stratification key
    int key = 0;
    while ( key < stratManager.size() ) {
        stratManager.set(key, new EvaluationContext(this, evaluationObjects));
        key++;
    }
}
public final Map<String, IntervalTree<GenomeLoc>> createIntervalTreeByContig(final IntervalBinding<Feature> intervals) {
final Map<String, IntervalTree<GenomeLoc>> byContig = new HashMap<String, IntervalTree<GenomeLoc>>();
@ -390,7 +383,7 @@ public class VariantEvalWalker extends RodWalker<Integer, Integer> implements Tr
// find the comp
final VariantContext comp = findMatchingComp(eval, compSet);
for ( NewEvaluationContext nec : getEvaluationContexts(tracker, ref, eval, evalRod.getName(), comp, compRod.getName(), sampleName) ) {
for ( EvaluationContext nec : getEvaluationContexts(tracker, ref, eval, evalRod.getName(), comp, compRod.getName(), sampleName) ) {
// eval against the comp
synchronized (nec) {
@ -417,29 +410,32 @@ public class VariantEvalWalker extends RodWalker<Integer, Integer> implements Tr
return null;
}
/**
 * Creates the StratificationManager over a copy of the given stratifiers and stores a
 * new NewEvaluationContext, loaded with the evaluation classes, under every key.
 *
 * @param stratificationObjects the stratifiers used to build the manager
 * @param evaluationObjects the evaluator classes added to each context
 */
final void createStratificationStates(final List<VariantStratifier> stratificationObjects, final Set<Class<? extends VariantEvaluator>> evaluationObjects) {
// copy the incoming list before handing it to the manager
final List<VariantStratifier> strats = new ArrayList<VariantStratifier>(stratificationObjects);
stratManager = new StratificationManager<VariantStratifier, NewEvaluationContext>(strats);
logger.info("Creating " + stratManager.size() + " combinatorial stratification states");
// populate each key of the manager with its own context
for ( int i = 0; i < stratManager.size(); i++ ) {
NewEvaluationContext ec = new NewEvaluationContext();
// // todo -- remove me, tmp conversion
// for ( Pair<VariantStratifier, Object> stratState : stratManager.getStratsAndStatesForKey(i) ) {
// ec.put(stratState.getFirst(), stratState.getSecond());
// }
// instantiate the evaluators for this context, then register it under key i
ec.addEvaluationClassList(this, evaluationObjects);
stratManager.set(i, ec);
}
}
/**
* Given specific eval and comp VCs and the sample name, return an iterable
* over all of the applicable state keys.
*
* See header of StateKey for performance problems...
* this code isn't structured yet for efficiency. Here we currently are
* doing the following inefficient algorithm:
*
* for each strat:
* get the list of relevant states for eval and comp according to strat
* add this list of states to a list of list states
*
* then
*
* ask the strat manager to look up all of the keys associated with the combinations
* of these states. For example, suppose we have a single variant S. We have active
* strats EvalRod, CompRod, and Novelty. We produce a list that looks like:
*
* L = [[Eval], [Comp], [All, Novel]]
*
* We then go through the strat manager tree to produce the keys associated with these states:
*
* K = [0, 1] where EVAL x COMP x ALL = 0 and EVAL x COMP x NOVEL = 1
*
* It's clear that a better approach exists; see the TODO below.
*
* TODO -- create an inline version that doesn't create the intermediate list of list
*
* @param tracker
* @param ref
@ -450,7 +446,7 @@ public class VariantEvalWalker extends RodWalker<Integer, Integer> implements Tr
* @param sampleName
* @return
*/
private Collection<NewEvaluationContext> getEvaluationContexts(final RefMetaDataTracker tracker,
private Collection<EvaluationContext> getEvaluationContexts(final RefMetaDataTracker tracker,
final ReferenceContext ref,
final VariantContext eval,
final String evalName,
@ -458,10 +454,9 @@ public class VariantEvalWalker extends RodWalker<Integer, Integer> implements Tr
final String compName,
final String sampleName ) {
final List<List<Object>> states = new LinkedList<List<Object>>();
for ( final VariantStratifier vs : stratificationObjects ) {
for ( final VariantStratifier vs : stratManager.getStratifiers() ) {
states.add(vs.getRelevantStates(ref, tracker, comp, compName, eval, evalName, sampleName));
}
return stratManager.values(states);
}
@ -538,15 +533,13 @@ public class VariantEvalWalker extends RodWalker<Integer, Integer> implements Tr
public void onTraversalDone(Integer result) {
logger.info("Finalizing variant report");
// TODO -- VS should be sorted first with a TreeSet
for ( int key = 0; key < stratManager.size(); key++ ) {
final String stratStateString = stratManager.getStratsAndStatesForKeyString(key);
final List<Pair<VariantStratifier, Object>> stratsAndStates = stratManager.getStratsAndStatesForKey(key);
final NewEvaluationContext nec = stratManager.get(key);
final EvaluationContext nec = stratManager.get(key);
for ( final VariantEvaluator ve : nec.getEvaluationClassList().values() ) {
for ( final VariantEvaluator ve : nec.getVariantEvaluators() ) {
ve.finalizeEvaluation();
final String veName = ve.getSimpleName(); // ve.getClass().getSimpleName();
AnalysisModuleScanner scanner = new AnalysisModuleScanner(ve);
Map<Field, DataPoint> datamap = scanner.getData();
@ -558,12 +551,11 @@ public class VariantEvalWalker extends RodWalker<Integer, Integer> implements Tr
if (field.get(ve) instanceof TableType) {
TableType t = (TableType) field.get(ve);
final String subTableName = veName + "." + field.getName();
final String subTableName = ve.getSimpleName() + "." + field.getName();
final DataPoint dataPointAnn = datamap.get(field);
if (! report.hasTable(subTableName)) {
if (! report.hasTable(subTableName))
configureNewReportTable(t, subTableName, dataPointAnn);
}
final GATKReportTable table = report.getTable(subTableName);
@ -580,7 +572,7 @@ public class VariantEvalWalker extends RodWalker<Integer, Integer> implements Tr
}
}
} else {
final GATKReportTable table = report.getTable(veName);
final GATKReportTable table = report.getTable(ve.getSimpleName());
setTableColumnNames(table, stratStateString, stratsAndStates);
table.set(stratStateString, field.getName(), field.get(ve));
}
@ -593,37 +585,56 @@ public class VariantEvalWalker extends RodWalker<Integer, Integer> implements Tr
report.print(out);
}
/**
* A common utility function to set up the GATKReportTable for an embedded TableType in
* a VariantEvaluation
*
* @param t
* @param subTableName
* @param dataPointAnn
*/
@Requires({"t != null", "subTableName != null", "dataPointAnn != null", "!report.hasTable(subTableName)"})
@Ensures({"report.hasTable(subTableName)"})
private final void configureNewReportTable(final TableType t, final String subTableName, final DataPoint dataPointAnn) {
// basic table configuration. Set up primary key, dummy column names
report.addTable(subTableName, dataPointAnn.description());
GATKReportTable table = report.getTable(subTableName);
final GATKReportTable table = report.getTable(subTableName);
table.addPrimaryKey("entry", false);
table.addColumn(subTableName, subTableName);
for ( VariantStratifier vs : stratificationObjects ) {
for ( final VariantStratifier vs : stratManager.getStratifiers() ) {
table.addColumn(vs.getName(), "unknown");
}
table.addColumn(t.getRowName(), "unknown");
for ( final Object o : t.getColumnKeys() ) {
final String c = o.toString();
table.addColumn(c, 0.0);
table.addColumn(o.toString(), 0.0);
}
}
/**
* Common utility to configure the columns of a GATKReportTable
*
* Sets the column names to the strat names in stratsAndStates for the primary key in table
*
* @param table
* @param primaryKey
* @param stratsAndStates
*/
private final void setTableColumnNames(final GATKReportTable table,
final String primaryKey,
final List<Pair<VariantStratifier, Object>> stratsAndStates) {
for ( Pair<VariantStratifier, Object> stratAndState : stratsAndStates ) {
for ( final Pair<VariantStratifier, Object> stratAndState : stratsAndStates ) {
final VariantStratifier vs = stratAndState.getFirst();
final String columnName = vs.getName();
final Object strat = stratAndState.getSecond();
if ( columnName == null || strat == null )
throw new ReviewedStingException("Unexpected null variant stratifier state at " + table + " key = " + primaryKey);
table.set(primaryKey, columnName, strat);
}
}
// Accessors
@ -635,8 +646,6 @@ public class VariantEvalWalker extends RodWalker<Integer, Integer> implements Tr
public double getMendelianViolationQualThreshold() { return MENDELIAN_VIOLATION_QUAL_THRESHOLD; }
public List<VariantStratifier> getStratificationObjects() { return stratificationObjects; }
public static String getAllSampleName() { return ALL_SAMPLE_NAME; }
public List<RodBinding<VariantContext>> getKnowns() { return knowns; }

View File

@ -9,7 +9,7 @@ import org.broadinstitute.sting.gatk.walkers.phasing.ReadBackedPhasingWalker;
import org.broadinstitute.sting.gatk.walkers.varianteval.VariantEvalWalker;
import org.broadinstitute.sting.gatk.walkers.varianteval.util.Analysis;
import org.broadinstitute.sting.gatk.walkers.varianteval.util.DataPoint;
import org.broadinstitute.sting.gatk.walkers.varianteval.util.NewEvaluationContext;
import org.broadinstitute.sting.gatk.walkers.varianteval.util.EvaluationContext;
import org.broadinstitute.sting.gatk.walkers.varianteval.util.TableType;
import org.broadinstitute.sting.utils.GenomeLoc;
import org.broadinstitute.sting.utils.MathUtils;
@ -85,7 +85,7 @@ public class GenotypePhasingEvaluator extends VariantEvaluator {
return update2(eval,comp,tracker,ref,context,null);
}
public String update2(VariantContext eval, VariantContext comp, RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context, NewEvaluationContext group) {
public String update2(VariantContext eval, VariantContext comp, RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context, EvaluationContext group) {
//public String update2(VariantContext eval, VariantContext comp, RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context, VariantEvalWalker.EvaluationContext group) {
Reasons interesting = new Reasons();
if (ref == null)

View File

@ -6,7 +6,7 @@ import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.walkers.varianteval.VariantEvalWalker;
import org.broadinstitute.sting.utils.variantcontext.VariantContext;
public abstract class VariantEvaluator {
public abstract class VariantEvaluator implements Comparable<VariantEvaluator> {
private VariantEvalWalker walker;
private final String simpleName;
@ -99,4 +99,9 @@ public abstract class VariantEvaluator {
/**
 * Returns the value of this evaluator's simpleName field.
 *
 * @return the simple name stored for this evaluator
 */
public String getSimpleName() {
return simpleName;
}
/**
 * Orders evaluators by the lexicographic order of their simple names.
 *
 * @param variantEvaluator the evaluator to compare this one against
 * @return the result of comparing this evaluator's simple name with the other's
 */
@Override
public int compareTo(final VariantEvaluator variantEvaluator) {
    final String thisName = getSimpleName();
    final String otherName = variantEvaluator.getSimpleName();
    return thisName.compareTo(otherName);
}
}

View File

@ -12,18 +12,18 @@ import org.broadinstitute.sting.utils.variantcontext.VariantContext;
import java.util.*;
public class NewEvaluationContext { // extends HashMap<VariantStratifier, Object> {
private Map<String, VariantEvaluator> evaluationInstances;
public final class EvaluationContext {
// NOTE: must be hashset to avoid O(log n) cost of iteration in the very frequently called apply function
private final HashSet<VariantEvaluator> evaluationInstances;
public void addEvaluationClassList(VariantEvalWalker walker, Set<Class<? extends VariantEvaluator>> evaluationClasses) {
evaluationInstances = new LinkedHashMap<String, VariantEvaluator>(evaluationClasses.size());
public EvaluationContext(final VariantEvalWalker walker, final Set<Class<? extends VariantEvaluator>> evaluationClasses) {
evaluationInstances = new HashSet<VariantEvaluator>(evaluationClasses.size());
for ( final Class<? extends VariantEvaluator> c : evaluationClasses ) {
try {
final VariantEvaluator eval = c.newInstance();
eval.initialize(walker);
evaluationInstances.put(c.getSimpleName(), eval);
evaluationInstances.add(eval);
} catch (InstantiationException e) {
throw new StingException("Unable to instantiate eval module '" + c.getSimpleName() + "'");
} catch (IllegalAccessException e) {
@ -32,12 +32,17 @@ public class NewEvaluationContext { // extends HashMap<VariantStratifier, Object
}
}
public TreeMap<String, VariantEvaluator> getEvaluationClassList() {
return new TreeMap<String, VariantEvaluator>(evaluationInstances);
/**
 * Builds a new sorted set holding this context's VariantEvaluators.
 *
 * A fresh TreeSet is created on every call, so the returned set is
 * separate from the internal collection.
 *
 * @return a newly allocated TreeSet containing the evaluation instances
 */
public final TreeSet<VariantEvaluator> getVariantEvaluators() {
    final TreeSet<VariantEvaluator> sortedEvaluators = new TreeSet<VariantEvaluator>();
    sortedEvaluators.addAll(evaluationInstances);
    return sortedEvaluators;
}
public void apply(RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context, VariantContext comp, VariantContext eval) {
for ( final VariantEvaluator evaluation : evaluationInstances.values() ) {
public final void apply(RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context, VariantContext comp, VariantContext eval) {
for ( final VariantEvaluator evaluation : evaluationInstances ) {
// the other updateN methods don't see a null context
if ( tracker == null )
continue;
@ -48,13 +53,9 @@ public class NewEvaluationContext { // extends HashMap<VariantStratifier, Object
if (eval != null) {
evaluation.update1(eval, tracker, ref, context);
}
break;
case 2:
//if (eval != null) {
evaluation.update2(eval, comp, tracker, ref, context);
//}
evaluation.update2(eval, comp, tracker, ref, context);
break;
default:
throw new ReviewedStingException("BUG: Unexpected evaluation order " + evaluation);

View File

@ -82,12 +82,11 @@ public class VariantEvalUtils {
/**
* Initialize required, standard and user-specified stratification objects
*
* @param variantEvalWalker the parent walker
* @param noStandardStrats don't use the standard stratifications
* @param modulesToUse the list of stratification modules to use
* @return set of stratifications to use
*/
public List<VariantStratifier> initializeStratificationObjects(VariantEvalWalker variantEvalWalker, boolean noStandardStrats, String[] modulesToUse) {
public List<VariantStratifier> initializeStratificationObjects(boolean noStandardStrats, String[] modulesToUse) {
TreeSet<VariantStratifier> strats = new TreeSet<VariantStratifier>();
Set<String> stratsToUse = new HashSet<String>();
@ -189,26 +188,25 @@ public class VariantEvalUtils {
* @return an initialized report object
*/
public GATKReport initializeGATKReport(Collection<VariantStratifier> stratificationObjects, Set<Class<? extends VariantEvaluator>> evaluationObjects) {
GATKReport report = new GATKReport();
final GATKReport report = new GATKReport();
for (Class<? extends VariantEvaluator> ve : evaluationObjects) {
String tableName = ve.getSimpleName();
String tableDesc = ve.getAnnotation(Analysis.class).description();
final String tableName = ve.getSimpleName();
final String tableDesc = ve.getAnnotation(Analysis.class).description();
report.addTable(tableName, tableDesc);
GATKReportTable table = report.getTable(tableName);
final GATKReportTable table = report.getTable(tableName);
table.addPrimaryKey("entry", false);
table.addColumn(tableName, tableName);
for (VariantStratifier vs : stratificationObjects) {
String columnName = vs.getName();
for (final VariantStratifier vs : stratificationObjects) {
final String columnName = vs.getName();
table.addColumn(columnName, "unknown");
}
try {
VariantEvaluator vei = ve.newInstance();
final VariantEvaluator vei = ve.newInstance();
vei.initialize(variantEvalWalker);
AnalysisModuleScanner scanner = new AnalysisModuleScanner(vei);
@ -218,7 +216,7 @@ public class VariantEvalUtils {
field.setAccessible(true);
if (!(field.get(vei) instanceof TableType)) {
String format = datamap.get(field).format();
final String format = datamap.get(field).format();
table.addColumn(field.getName(), true, format);
}
}