VariantEval scalability optimizations

-- StateKey no longer extends TreeMap.  It's now a final immutable data structure that caches it's toString and hashcode values.  TODO optimizations to entirely remove the TreeMap and just store the HashMap for performance and use the tree for the sorted tostring function.
-- NewEvaluationContext has a method makeStateKey() that contains all of the functionality that once was spread around VEUtils
-- AnalysisModuleScanner uses an annotationCache to speed up the reflections getAnnotations() call when invoked over and over on the same objects.  Still expensive to convert each field to a string for the cache, but the only way around that is a complete refactoring of the toTransversalDone of VE
-- VariantEvaluator base class has a cached getSimpleName() function
-- VEUtils: general cleanup due to refactoring of StateKey
-- VEWalker: much better iteration of map data structures.  If you need access to iterate over all key/value pairs use the Map.Entry construct with entrySet.  This is far better than iterating over the keys and calling get() on each key.
This commit is contained in:
Mark DePristo 2012-03-26 07:42:15 -04:00
parent 8f0828daa6
commit 6be5e82860
6 changed files with 101 additions and 44 deletions

View File

@ -482,11 +482,13 @@ public class VariantEvalWalker extends RodWalker<Integer, Integer> implements Tr
public void onTraversalDone(Integer result) { public void onTraversalDone(Integer result) {
logger.info("Finalizing variant report"); logger.info("Finalizing variant report");
for ( StateKey stateKey : evaluationContexts.keySet() ) { for ( Map.Entry<StateKey, NewEvaluationContext> ecElt : evaluationContexts.entrySet() ) {
NewEvaluationContext nec = evaluationContexts.get(stateKey); final StateKey stateKey = ecElt.getKey();
final NewEvaluationContext nec = ecElt.getValue();
for ( VariantEvaluator ve : nec.getEvaluationClassList().values() ) { for ( VariantEvaluator ve : nec.getEvaluationClassList().values() ) {
ve.finalizeEvaluation(); ve.finalizeEvaluation();
final String veName = ve.getSimpleName(); // ve.getClass().getSimpleName();
AnalysisModuleScanner scanner = new AnalysisModuleScanner(ve); AnalysisModuleScanner scanner = new AnalysisModuleScanner(ve);
Map<Field, DataPoint> datamap = scanner.getData(); Map<Field, DataPoint> datamap = scanner.getData();
@ -498,7 +500,7 @@ public class VariantEvalWalker extends RodWalker<Integer, Integer> implements Tr
if (field.get(ve) instanceof TableType) { if (field.get(ve) instanceof TableType) {
TableType t = (TableType) field.get(ve); TableType t = (TableType) field.get(ve);
final String subTableName = ve.getClass().getSimpleName() + "." + field.getName(); final String subTableName = veName + "." + field.getName();
final DataPoint dataPointAnn = datamap.get(field); final DataPoint dataPointAnn = datamap.get(field);
GATKReportTable table; GATKReportTable table;
@ -539,11 +541,10 @@ public class VariantEvalWalker extends RodWalker<Integer, Integer> implements Tr
} }
} }
} else { } else {
GATKReportTable table = report.getTable(ve.getClass().getSimpleName()); GATKReportTable table = report.getTable(veName);
for ( VariantStratifier vs : stratificationObjects ) { for ( VariantStratifier vs : stratificationObjects ) {
String columnName = vs.getName(); final String columnName = vs.getName();
table.set(stateKey.toString(), columnName, stateKey.get(vs.getName())); table.set(stateKey.toString(), columnName, stateKey.get(vs.getName()));
} }

View File

@ -8,6 +8,11 @@ import org.broadinstitute.sting.utils.variantcontext.VariantContext;
public abstract class VariantEvaluator { public abstract class VariantEvaluator {
private VariantEvalWalker walker; private VariantEvalWalker walker;
private final String simpleName;
protected VariantEvaluator() {
this.simpleName = getClass().getSimpleName();
}
public void initialize(VariantEvalWalker walker) { public void initialize(VariantEvalWalker walker) {
this.walker = walker; this.walker = walker;
@ -90,4 +95,8 @@ public abstract class VariantEvaluator {
protected static String formattedRatio(final int num, final int denom) { protected static String formattedRatio(final int num, final int denom) {
return denom == 0 ? "NA" : String.format("%.2f", num / (1.0 * denom)); return denom == 0 ? "NA" : String.format("%.2f", num / (1.0 * denom));
} }
public String getSimpleName() {
return simpleName;
}
} }

View File

@ -27,6 +27,7 @@ import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
import java.lang.annotation.Annotation; import java.lang.annotation.Annotation;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.util.HashMap;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
import java.util.Map; import java.util.Map;
@ -40,6 +41,7 @@ import java.util.Map;
* the object, a Mashalling object can serialize or deserialize a analysis module. * the object, a Mashalling object can serialize or deserialize a analysis module.
*/ */
public class AnalysisModuleScanner { public class AnalysisModuleScanner {
final private static Map<String, Annotation[]> annotationCache = new HashMap<String, Annotation[]>();
// what we extracted from the class // what we extracted from the class
private Map<Field, DataPoint> datums = new LinkedHashMap<Field, DataPoint>(); // the data we've discovered private Map<Field, DataPoint> datums = new LinkedHashMap<Field, DataPoint>(); // the data we've discovered
@ -84,13 +86,23 @@ public class AnalysisModuleScanner {
// get the fields from the class, and extract // get the fields from the class, and extract
for ( Class superCls = cls; superCls != null; superCls=superCls.getSuperclass() ) { for ( Class superCls = cls; superCls != null; superCls=superCls.getSuperclass() ) {
for (Field f : superCls.getDeclaredFields()) for (Field f : superCls.getDeclaredFields())
for (Annotation annotation : f.getAnnotations()) { for (Annotation annotation : getAnnotations(f)) {
if (annotation.annotationType().equals(DataPoint.class)) if (annotation.annotationType().equals(DataPoint.class))
datums.put(f,(DataPoint) annotation); datums.put(f,(DataPoint) annotation);
} }
} }
} }
private Annotation[] getAnnotations(final Field field) {
final String fieldName = field.toString();
Annotation[] annotations = annotationCache.get(fieldName);
if ( annotations == null ) {
annotations = field.getAnnotations();
annotationCache.put(fieldName, annotations);
}
return annotations;
}
/** /**
* *
* @return a map of the datum annotations found * @return a map of the datum annotations found

View File

@ -36,6 +36,16 @@ public class NewEvaluationContext extends HashMap<VariantStratifier, String> {
return new TreeMap<String, VariantEvaluator>(evaluationInstances); return new TreeMap<String, VariantEvaluator>(evaluationInstances);
} }
public StateKey makeStateKey() {
Map<String, String> map = new HashMap<String, String>(size());
for (Map.Entry<VariantStratifier, String> elt : this.entrySet() ) {
map.put(elt.getKey().getName(), elt.getValue());
}
return new StateKey(map);
}
public void apply(RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context, VariantContext comp, VariantContext eval) { public void apply(RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context, VariantContext comp, VariantContext eval) {
for ( final VariantEvaluator evaluation : evaluationInstances.values() ) { for ( final VariantEvaluator evaluation : evaluationInstances.values() ) {
// the other updateN methods don't see a null context // the other updateN methods don't see a null context

View File

@ -3,25 +3,67 @@ package org.broadinstitute.sting.gatk.walkers.varianteval.util;
import java.util.Map; import java.util.Map;
import java.util.TreeMap; import java.util.TreeMap;
public class StateKey extends TreeMap<String, String> { /**
// public int hashCode() { * A final constant class representing the specific state configuration
// int hashCode = 1; * for a VariantEvaluator instance.
// *
// for (final Map.Entry<String,String> pair : this.entrySet()) { * TODO optimizations to entirely remove the TreeMap and just store the HashMap for performance and use the tree for the sorted tostring function.
// hashCode *= pair.getKey().hashCode() + pair.getValue().hashCode(); */
// } public final class StateKey {
// /** High-performance cache of the toString operation for a constant class */
// return hashCode; private final String string;
// } private final TreeMap<String, String> states;
public String toString() { public StateKey(final Map<String, String> states) {
String value = ""; this.states = new TreeMap<String, String>(states);
this.string = formatString();
for ( final String key : this.keySet() ) {
//value += "\tstate " + key + ":" + this.get(key) + "\n";
value += String.format("%s:%s;", key, this.get(key));
} }
return value; public StateKey(final StateKey toOverride, final String keyOverride, final String valueOverride) {
if ( toOverride == null ) {
this.states = new TreeMap<String, String>();
} else {
this.states = new TreeMap<String, String>(toOverride.states);
}
this.states.put(keyOverride, valueOverride);
this.string = formatString();
}
@Override
public boolean equals(final Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
final StateKey stateKey = (StateKey) o;
if (states != null ? !states.equals(stateKey.states) : stateKey.states != null) return false;
return true;
}
@Override
public int hashCode() {
return states.hashCode();
}
@Override
public String toString() {
return string;
}
private final String formatString() {
StringBuilder b = new StringBuilder();
for ( Map.Entry<String, String> entry : states.entrySet() ) {
b.append(String.format("%s:%s;", entry.getKey(), entry.getValue()));
}
return b.toString();
}
// TODO -- might be slow because of tree map
public String get(final String key) {
return states.get(key);
} }
} }

View File

@ -214,20 +214,9 @@ public class VariantEvalUtils {
ecs.putAll(initializeEvaluationContexts(stratificationObjects, evaluationObjects, newStratStack, nec)); ecs.putAll(initializeEvaluationContexts(stratificationObjects, evaluationObjects, newStratStack, nec));
} }
} else { } else {
HashMap<StateKey, NewEvaluationContext> necs = new HashMap<StateKey, NewEvaluationContext>(); final StateKey stateKey = ec.makeStateKey();
StateKey stateKey = new StateKey();
for (VariantStratifier vs : ec.keySet()) {
String state = ec.get(vs);
stateKey.put(vs.getName(), state);
}
ec.addEvaluationClassList(variantEvalWalker, stateKey, evaluationObjects); ec.addEvaluationClassList(variantEvalWalker, stateKey, evaluationObjects);
return new HashMap<StateKey, NewEvaluationContext>(Collections.singletonMap(stateKey, ec));
necs.put(stateKey, ec);
return necs;
} }
return ecs; return ecs;
@ -428,14 +417,8 @@ public class VariantEvalUtils {
HashMap<VariantStratifier, List<String>> oneSetOfStates = newStateStack.pop(); HashMap<VariantStratifier, List<String>> oneSetOfStates = newStateStack.pop();
VariantStratifier vs = oneSetOfStates.keySet().iterator().next(); VariantStratifier vs = oneSetOfStates.keySet().iterator().next();
for (String state : oneSetOfStates.get(vs)) { for (final String state : oneSetOfStates.get(vs)) {
StateKey newStateKey = new StateKey(); final StateKey newStateKey = new StateKey(stateKey, vs.getName(), state);
if (stateKey != null) {
newStateKey.putAll(stateKey);
}
newStateKey.put(vs.getName(), state);
initializeStateKeys(stateMap, newStateStack, newStateKey, stateKeys); initializeStateKeys(stateMap, newStateStack, newStateKey, stateKeys);
} }
} else { } else {