diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/VariantEvalWalker.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/VariantEvalWalker.java index 04b44a841..cf9b82959 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/VariantEvalWalker.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/VariantEvalWalker.java @@ -21,7 +21,6 @@ import org.broadinstitute.sting.gatk.walkers.Window; import org.broadinstitute.sting.gatk.walkers.varianteval.evaluators.VariantEvaluator; import org.broadinstitute.sting.gatk.walkers.varianteval.stratifications.IntervalStratification; import org.broadinstitute.sting.gatk.walkers.varianteval.stratifications.VariantStratifier; -import org.broadinstitute.sting.gatk.walkers.varianteval.stratifications.manager.SetOfStates; import org.broadinstitute.sting.gatk.walkers.varianteval.stratifications.manager.StratificationManager; import org.broadinstitute.sting.gatk.walkers.varianteval.util.*; import org.broadinstitute.sting.utils.GenomeLoc; @@ -29,6 +28,7 @@ import org.broadinstitute.sting.utils.GenomeLocParser; import org.broadinstitute.sting.utils.SampleUtils; import org.broadinstitute.sting.utils.codecs.vcf.VCFHeader; import org.broadinstitute.sting.utils.codecs.vcf.VCFUtils; +import org.broadinstitute.sting.utils.collections.Pair; import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; import org.broadinstitute.sting.utils.exceptions.StingException; import org.broadinstitute.sting.utils.exceptions.UserException; @@ -225,6 +225,22 @@ public class VariantEvalWalker extends RodWalker implements Tr // The set of all possible evaluation contexts StratificationManager stratManager; + // TODO + // TODO + // TODO + // TODO + // TODO + // + // TODO -- StratificationManager should hold the master list of strats + + // TODO + // TODO + // TODO + // TODO + // TODO + + + /** * Initialize the stratifications, evaluations, evaluation contexts, and reporting object */ @@ -403,14 +419,18 @@ public class VariantEvalWalker extends RodWalker implements Tr final void createStratificationStates(final List stratificationObjects, final Set> evaluationObjects) { final List strats = new ArrayList(stratificationObjects); - stratManager = - new StratificationManager(strats); + stratManager = new StratificationManager(strats); logger.info("Creating " + stratManager.size() + " combinatorial stratification states"); for ( int i = 0; i < stratManager.size(); i++ ) { NewEvaluationContext ec = new NewEvaluationContext(); - ec.putAll(stratManager.getStateForKey(i)); - ec.addEvaluationClassList(this, null, evaluationObjects); + +// // todo -- remove me, tmp conversion +// for ( Pair stratState : stratManager.getStratsAndStatesForKey(i) ) { +// ec.put(stratState.getFirst(), stratState.getSecond()); +// } + + ec.addEvaluationClassList(this, evaluationObjects); stratManager.set(i, ec); } } @@ -518,23 +538,20 @@ public class VariantEvalWalker extends RodWalker implements Tr public void onTraversalDone(Integer result) { logger.info("Finalizing variant report"); - // TODO -- clean up -- this is deeply unsafe + // TODO -- VS should be sorted first with a TreeSet for ( int key = 0; key < stratManager.size(); key++ ) { - final Map stateValues = stratManager.getStateForKey(key); + final String stratStateString = stratManager.getStratsAndStatesForKeyString(key); + final List> stratsAndStates = stratManager.getStratsAndStatesForKey(key); final NewEvaluationContext nec = stratManager.get(key); - final Map stateKey = new HashMap(stateValues.size()); - for ( Map.Entry elt : stateValues.entrySet() ) - stateKey.put(elt.getKey().getName(), elt.getValue()); - - for ( VariantEvaluator ve : nec.getEvaluationClassList().values() ) { + for ( final VariantEvaluator ve : nec.getEvaluationClassList().values() ) { ve.finalizeEvaluation(); final String veName = ve.getSimpleName(); // ve.getClass().getSimpleName(); AnalysisModuleScanner scanner = new AnalysisModuleScanner(ve); Map datamap = scanner.getData(); - for (Field field : datamap.keySet()) { + for ( final Field field : datamap.keySet()) { try { field.setAccessible(true); @@ -544,52 +561,28 @@ public class VariantEvalWalker extends RodWalker implements Tr final String subTableName = veName + "." + field.getName(); final DataPoint dataPointAnn = datamap.get(field); - GATKReportTable table; - if (!report.hasTable(subTableName)) { - report.addTable(subTableName, dataPointAnn.description()); - table = report.getTable(subTableName); - - table.addPrimaryKey("entry", false); - table.addColumn(subTableName, subTableName); - - for ( VariantStratifier vs : stratificationObjects ) { - table.addColumn(vs.getName(), "unknown"); - } - - table.addColumn(t.getRowName(), "unknown"); - - for ( final Object o : t.getColumnKeys() ) { - final String c = o.toString(); - table.addColumn(c, 0.0); - } - } else { - table = report.getTable(subTableName); + if (! report.hasTable(subTableName)) { + configureNewReportTable(t, subTableName, dataPointAnn); } + final GATKReportTable table = report.getTable(subTableName); + for (int row = 0; row < t.getRowKeys().length; row++) { final String r = t.getRowKeys()[row].toString(); + final String newStratStateString = stratStateString + r; - for ( VariantStratifier vs : stratificationObjects ) { - final String columnName = vs.getName(); - table.set(stateKey.toString() + r, columnName, stateKey.get(columnName)); - } + setTableColumnNames(table, newStratStateString, stratsAndStates); for (int col = 0; col < t.getColumnKeys().length; col++) { final String c = t.getColumnKeys()[col].toString(); - final String newStateKey = stateKey.toString() + r; - table.set(newStateKey, c, t.getCell(row, col)); - table.set(newStateKey, t.getRowName(), r); + table.set(newStratStateString, c, t.getCell(row, col)); + table.set(newStratStateString, t.getRowName(), r); } } } else { - GATKReportTable table = report.getTable(veName); - - for ( VariantStratifier vs : stratificationObjects ) { - final String columnName = vs.getName(); - table.set(stateKey.toString(), columnName, stateKey.get(vs.getName())); - } - - table.set(stateKey.toString(), field.getName(), field.get(ve)); + final GATKReportTable table = report.getTable(veName); + setTableColumnNames(table, stratStateString, stratsAndStates); + table.set(stratStateString, field.getName(), field.get(ve)); } } catch (IllegalAccessException e) { throw new StingException("IllegalAccessException: " + e); @@ -600,6 +593,38 @@ public class VariantEvalWalker extends RodWalker implements Tr report.print(out); } + + private final void configureNewReportTable(final TableType t, final String subTableName, final DataPoint dataPointAnn) { + // basic table configuration. Set up primary key, dummy column names + report.addTable(subTableName, dataPointAnn.description()); + GATKReportTable table = report.getTable(subTableName); + + table.addPrimaryKey("entry", false); + table.addColumn(subTableName, subTableName); + + for ( VariantStratifier vs : stratificationObjects ) { + table.addColumn(vs.getName(), "unknown"); + } + + table.addColumn(t.getRowName(), "unknown"); + + for ( final Object o : t.getColumnKeys() ) { + final String c = o.toString(); + table.addColumn(c, 0.0); + } + } + + private final void setTableColumnNames(final GATKReportTable table, + final String primaryKey, + final List> stratsAndStates) { + for ( Pair stratAndState : stratsAndStates ) { + final VariantStratifier vs = stratAndState.getFirst(); + final String columnName = vs.getName(); + final Object strat = stratAndState.getSecond(); + table.set(primaryKey, columnName, strat); + } + + } // Accessors public Logger getLogger() { return logger; } diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/IntervalStratification.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/IntervalStratification.java index 7fe98ea21..e323b4434 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/IntervalStratification.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/IntervalStratification.java @@ -55,6 +55,10 @@ public class IntervalStratification extends VariantStratifier { final protected static Logger logger = Logger.getLogger(IntervalStratification.class); Map> intervalTreeByContig = null; + final List OVERLAPPING = Arrays.asList((Object)"all", (Object)"overlaps.intervals"); + final List NOT_OVERLAPPING = Arrays.asList((Object)"all", (Object)"outside.intervals"); + + @Override public void initialize() { if ( getVariantEvalWalker().intervalsFile == null ) @@ -79,7 +83,10 @@ public class IntervalStratification extends VariantStratifier { IntervalTree intervalTree = intervalTreeByContig.get(loc.getContig()); IntervalTree.Node node = intervalTree.minOverlapper(loc.getStart(), loc.getStop()); //logger.info(String.format("Overlap %s found %s", loc, node)); - return Collections.singletonList((Object)(node != null ? "overlaps.intervals" : "outside.intervals")); + if ( node != null ) + return OVERLAPPING; + else + return NOT_OVERLAPPING; } return Collections.emptyList(); diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/VariantStratifier.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/VariantStratifier.java index 2398605de..702a10b3d 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/VariantStratifier.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/VariantStratifier.java @@ -3,13 +3,13 @@ package org.broadinstitute.sting.gatk.walkers.varianteval.stratifications; import org.broadinstitute.sting.gatk.contexts.ReferenceContext; import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker; import org.broadinstitute.sting.gatk.walkers.varianteval.VariantEvalWalker; -import org.broadinstitute.sting.gatk.walkers.varianteval.stratifications.manager.SetOfStates; +import org.broadinstitute.sting.gatk.walkers.varianteval.stratifications.manager.Stratifier; import org.broadinstitute.sting.utils.variantcontext.VariantContext; import java.util.ArrayList; import java.util.List; -public abstract class VariantStratifier implements Comparable, SetOfStates { +public abstract class VariantStratifier implements Comparable, Stratifier { private VariantEvalWalker variantEvalWalker; final private String name; final protected ArrayList states = new ArrayList(); @@ -53,6 +53,11 @@ public abstract class VariantStratifier implements Comparable return this.getName().compareTo(o1.getName()); } + @Override + public String toString() { + return getName(); + } + public final String getName() { return name; } diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/manager/StratNode.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/manager/StratNode.java index b82fd2bc4..6b3375048 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/manager/StratNode.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/manager/StratNode.java @@ -59,7 +59,7 @@ import java.util.*; */ @Invariant({ "(isLeaf() && stratifier == null && subnodes.isEmpty()) || (!isLeaf() && stratifier != null && !subnodes.isEmpty())"}) -class StratNode implements Iterable> { +class StratNode implements Iterable> { int key = -1; final T stratifier; // TODO -- track state key that maps to root node diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/manager/StratNodeIterator.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/manager/StratNodeIterator.java index cda30a0c9..3aff4fe27 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/manager/StratNodeIterator.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/manager/StratNodeIterator.java @@ -34,7 +34,7 @@ import java.util.*; * @author Mark DePristo * @since 3/27/12 */ -class StratNodeIterator implements Iterator> { +class StratNodeIterator implements Iterator> { Queue>> iterators = new LinkedList>>(); Iterator> currentIterator; diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/manager/StratificationManager.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/manager/StratificationManager.java index 9f5a29fdb..a2653584e 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/manager/StratificationManager.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/manager/StratificationManager.java @@ -26,6 +26,7 @@ package org.broadinstitute.sting.gatk.walkers.varianteval.stratifications.manage import com.google.java.contract.Ensures; import com.google.java.contract.Requires; +import org.broadinstitute.sting.utils.collections.Pair; import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; import java.util.*; @@ -36,7 +37,7 @@ import java.util.*; * @author Mark DePristo * @since 3/27/12 */ -public class StratificationManager implements Map, V> { +public class StratificationManager implements Map, V> { private final StratNode root; private final int size; @@ -45,6 +46,7 @@ public class StratificationManager implements Map valuesByKey; private final ArrayList> stratifierValuesByKey; + private final ArrayList keyStrings; // ------------------------------------------------------------------------------------- // @@ -64,9 +66,11 @@ public class StratificationManager implements Map(size()); this.stratifierValuesByKey = new ArrayList>(size()); + this.keyStrings = new ArrayList(size()); for ( int i = 0; i < size(); i++ ) { this.valuesByKey.add(null); this.stratifierValuesByKey.add(null); + this.keyStrings.add(null); } assignStratifierValuesByKey(root); } @@ -140,6 +144,11 @@ public class StratificationManager implements Map getStratifiers() { + return stratifiers; + } + // ------------------------------------------------------------------------------------- // // mapping from states -> keys @@ -160,16 +169,39 @@ public class StratificationManager implements Map getStateForKey(final int key) { - final Map states = new HashMap(stratifiers.size()); + public List getStatesForKey(final int key) { + final List states = new ArrayList(stratifiers.size()); + for ( int i = 0; i < stratifiers.size(); i++ ) { + final Object stratValue = stratifierValuesByKey.get(key).get(i); + states.add(stratValue); + } + return states; + } + + public List> getStratsAndStatesForKey(final int key) { + final List> states = new ArrayList>(stratifiers.size()); for ( int i = 0; i < stratifiers.size(); i++ ) { final K strat = stratifiers.get(i); final Object stratValue = stratifierValuesByKey.get(key).get(i); - states.put(strat, stratValue); + states.add(new Pair(strat, stratValue)); } return states; } + public String getStratsAndStatesForKeyString(final int key) { + if ( keyStrings.get(key) == null ) { + StringBuilder b = new StringBuilder(); + for ( int i = 0; i < stratifiers.size(); i++ ) { + final K strat = stratifiers.get(i); + final Object stratValue = stratifierValuesByKey.get(key).get(i); + b.append(strat.toString()).append(":").append(stratValue.toString()); + } + keyStrings.set(key, b.toString()); + } + + return keyStrings.get(key); + } + // ------------------------------------------------------------------------------------- // // valuesByKey diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/manager/SetOfStates.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/manager/Stratifier.java similarity index 96% rename from public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/manager/SetOfStates.java rename to public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/manager/Stratifier.java index 7a65e62af..d77ef6eba 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/manager/SetOfStates.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/manager/Stratifier.java @@ -27,12 +27,12 @@ package org.broadinstitute.sting.gatk.walkers.varianteval.stratifications.manage import java.util.List; /** - * A basic interface for a class to be used with the StratificationStates system + * A basic interface for a class to be used with the StratificationManager system * * @author Mark DePristo * @since 3/28/12 */ -public interface SetOfStates { +public interface Stratifier { /** * @return a list of all objects states that may be provided by this States provider */ diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/util/NewEvaluationContext.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/util/NewEvaluationContext.java index 5dfc321a6..ef5579b01 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/util/NewEvaluationContext.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/util/NewEvaluationContext.java @@ -12,10 +12,10 @@ import org.broadinstitute.sting.utils.variantcontext.VariantContext; import java.util.*; -public class NewEvaluationContext extends HashMap { +public class NewEvaluationContext { // extends HashMap { private Map evaluationInstances; - public void addEvaluationClassList(VariantEvalWalker walker, StateKey stateKey, Set> evaluationClasses) { + public void addEvaluationClassList(VariantEvalWalker walker, Set> evaluationClasses) { evaluationInstances = new LinkedHashMap(evaluationClasses.size()); for ( final Class c : evaluationClasses ) { @@ -36,16 +36,6 @@ public class NewEvaluationContext extends HashMap { return new TreeMap(evaluationInstances); } - public StateKey makeStateKey() { - Map map = new HashMap(size()); - - for (Map.Entry elt : this.entrySet() ) { - map.put(elt.getKey().getName(), elt.getValue()); - } - - return new StateKey(map); - } - public void apply(RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context, VariantContext comp, VariantContext eval) { for ( final VariantEvaluator evaluation : evaluationInstances.values() ) { // the other updateN methods don't see a null context diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/util/StateKey.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/util/StateKey.java deleted file mode 100755 index a52f68a6c..000000000 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/util/StateKey.java +++ /dev/null @@ -1,104 +0,0 @@ -package org.broadinstitute.sting.gatk.walkers.varianteval.util; - -import java.util.Map; -import java.util.TreeMap; - -/** - * A final constant class representing the specific state configuration - * for a VariantEvaluator instance. - * - * The way this is currently implemented is by a map from the name of a VariantStratification to a - * specific state string. For example, the stratification Novelty has states all, known, novel. A - * specific variant and comp would be tagged as "known" by the stratification, and this could be - * represented here by the map (Novelty -> known). - * - * TODO -- PERFORMANCE PROBLEM -- MAD 03/27/12 - * TODO -- PERFORMANCE PROBLEM -- MAD 03/27/12 - * TODO -- PERFORMANCE PROBLEM -- MAD 03/27/12 - * TODO -- PERFORMANCE PROBLEM -- MAD 03/27/12 - * TODO -- PERFORMANCE PROBLEM -- MAD 03/27/12 - * - * I've been staring at this state key code for a while. It's just not right, and expensive to boot. - * Here are my thoughts for future work. The state key is both a key with specific state values for - * every stratification. For example, (known, sample1, ac=1). This capability is used in some places, - * such as below, to return a set of all states that should be updated given the eval and comp - * VCs. In principle there are a finite set of such combinations (the product of all states for all active - * stratifications at initialization). We could represent such keys as integers into the set of all combinations. - * - * Note that all of the code that manipulates these things is just terrible. It's all string manipulation and - * HashMaps. Since we are effectively always squaring off our VE analyses (i.e., we have a table with - * all variable values for all stratification combinations) it doesn't make sense to allow so much dynamicism. Instead - * we should just upfront create a giant table indexed by integer keys, and manage data via a simple map from - * specific strat state to this key. - * - * The reason this is so important is that >80% of the runtime of VE with VCFs with >1000 samples is spent in - * the initializeStateKey function. Instead, we should have code that looks like: - * - * init: - * allStates <- initializeCombinationalStateSpace - * - * map: - * for each eval / comp pair: - * for each relevantState based on eval / comp: - * allStates[relevantState].update(eval, comp) - * - * - */ -public final class StateKey { - /** High-performance cache of the toString operation for a constant class */ - private final String string; - private final TreeMap states; - - public StateKey(final Map states) { - this.states = new TreeMap(states); - this.string = formatString(); - } - - public StateKey(final StateKey toOverride, final String keyOverride, final Object valueOverride) { - if ( toOverride == null ) { - this.states = new TreeMap(); - } else { - this.states = new TreeMap(toOverride.states); - } - - this.states.put(keyOverride, valueOverride); - this.string = formatString(); - } - - @Override - public boolean equals(final Object o) { - if (this == o) return true; - if (o == null || getClass() != o.getClass()) return false; - - final StateKey stateKey = (StateKey) o; - - if (states != null ? !states.equals(stateKey.states) : stateKey.states != null) return false; - - return true; - } - - @Override - public int hashCode() { - return states.hashCode(); - } - - @Override - public String toString() { - return string; - } - - private final String formatString() { - StringBuilder b = new StringBuilder(); - - for ( Map.Entry entry : states.entrySet() ) { - b.append(String.format("%s:%s;", entry.getKey(), entry.getValue())); - } - - return b.toString(); - } - - // TODO -- might be slow because of tree map - public Object get(final String key) { - return states.get(key); - } -} diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/util/VariantEvalUtils.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/util/VariantEvalUtils.java index 81df7215a..66374abb7 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/util/VariantEvalUtils.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/util/VariantEvalUtils.java @@ -88,7 +88,7 @@ public class VariantEvalUtils { * @return set of stratifications to use */ public List initializeStratificationObjects(VariantEvalWalker variantEvalWalker, boolean noStandardStrats, String[] modulesToUse) { - List strats = new ArrayList(); + TreeSet strats = new TreeSet(); Set stratsToUse = new HashSet(); // Create a map for all stratification modules for easy lookup. @@ -139,7 +139,7 @@ public class VariantEvalUtils { } } - return strats; + return new ArrayList(strats); } /** diff --git a/public/java/test/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/manager/StratificationManagerUnitTest.java b/public/java/test/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/manager/StratificationManagerUnitTest.java index 93db1f9ad..2b6f5c712 100644 --- a/public/java/test/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/manager/StratificationManagerUnitTest.java +++ b/public/java/test/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/manager/StratificationManagerUnitTest.java @@ -31,6 +31,7 @@ package org.broadinstitute.sting.gatk.walkers.varianteval.stratifications.manage import org.broadinstitute.sting.BaseTest; import org.broadinstitute.sting.utils.Utils; +import org.broadinstitute.sting.utils.collections.Pair; import org.testng.Assert; import org.testng.annotations.BeforeClass; import org.testng.annotations.DataProvider; @@ -53,7 +54,7 @@ public class StratificationManagerUnitTest extends BaseTest { private class StratificationStatesTestProvider extends TestDataProvider { final List> allStates = new ArrayList>(); - final List asSetOfStates = new ArrayList(); + final List asSetOfStates = new ArrayList(); final int nStates; public StratificationStatesTestProvider(final List ... allStates) { @@ -64,7 +65,7 @@ public class StratificationManagerUnitTest extends BaseTest { } for ( List states : this.allStates ) { - asSetOfStates.add(new ListAsSetOfStates(states)); + asSetOfStates.add(new IntegerStratifier(states)); } this.nStates = Utils.nCombinations(allStates); @@ -79,7 +80,7 @@ public class StratificationManagerUnitTest extends BaseTest { return b.toString(); } - public List getStateSpaceList() { + public List getStateSpaceList() { return asSetOfStates; } @@ -118,10 +119,10 @@ public class StratificationManagerUnitTest extends BaseTest { } } - private class ListAsSetOfStates implements SetOfStates { + private class IntegerStratifier implements Stratifier { final List integers; - private ListAsSetOfStates(final List integers) { + private IntegerStratifier(final List integers) { this.integers = integers; } @@ -144,8 +145,8 @@ public class StratificationManagerUnitTest extends BaseTest { return StratificationStatesTestProvider.getTests(StratificationStatesTestProvider.class); } - private final StratificationManager createManager(StratificationStatesTestProvider cfg) { - final StratificationManager manager = new StratificationManager(cfg.getStateSpaceList()); + private final StratificationManager createManager(StratificationStatesTestProvider cfg) { + final StratificationManager manager = new StratificationManager(cfg.getStateSpaceList()); List values = cfg.values(); for ( int i = 0; i < cfg.nStates; i++ ) manager.set(i, values.get(i)); @@ -157,7 +158,7 @@ public class StratificationManagerUnitTest extends BaseTest { @Test(dataProvider = "StratificationStatesTestProvider") public void testLeafCount(StratificationStatesTestProvider cfg) { - final StratificationManager stratificationManager = createManager(cfg); + final StratificationManager stratificationManager = createManager(cfg); Assert.assertEquals(stratificationManager.size(), cfg.nStates); @@ -171,7 +172,7 @@ public class StratificationManagerUnitTest extends BaseTest { @Test(dataProvider = "StratificationStatesTestProvider") public void testKeys(StratificationStatesTestProvider cfg) { - final StratificationManager stratificationManager = createManager(cfg); + final StratificationManager stratificationManager = createManager(cfg); final Set seenKeys = new HashSet(cfg.nStates); for ( final StratNode node : stratificationManager.getRoot() ) { if ( node.isLeaf() ) { @@ -183,7 +184,7 @@ public class StratificationManagerUnitTest extends BaseTest { @Test(dataProvider = "StratificationStatesTestProvider") public void testFindSingleKeys(StratificationStatesTestProvider cfg) { - final StratificationManager stratificationManager = createManager(cfg); + final StratificationManager stratificationManager = createManager(cfg); final Set seenKeys = new HashSet(cfg.nStates); for ( List state : cfg.getAllCombinations() ) { final int key = stratificationManager.getKey(state); @@ -203,7 +204,7 @@ public class StratificationManagerUnitTest extends BaseTest { @Test(dataProvider = "StratificationStatesTestProvider") public void testFindMultipleKeys(StratificationStatesTestProvider cfg) { - final StratificationManager stratificationManager = createManager(cfg); + final StratificationManager stratificationManager = createManager(cfg); final List> states = new ArrayList>(cfg.allStates); final Set keys = stratificationManager.getKeys(states); Assert.assertEquals(keys.size(), cfg.nStates, "Find all states didn't find all of the expected unique keys"); @@ -230,8 +231,22 @@ public class StratificationManagerUnitTest extends BaseTest { @Test(dataProvider = "StratificationStatesTestProvider") public void testMapSet(StratificationStatesTestProvider cfg) { - final StratificationManager stratificationManager = createManager(cfg); + final StratificationManager stratificationManager = createManager(cfg); stratificationManager.set(0, -1); Assert.assertEquals((int)stratificationManager.get(0), -1); } + + @Test(dataProvider = "StratificationStatesTestProvider") + public void testStratifierByKey(StratificationStatesTestProvider cfg) { + final StratificationManager manager = createManager(cfg); + for ( int key = 0; key < cfg.nStates; key++ ) { + List> stratsAndStates = manager.getStratsAndStatesForKey(key); + final List strats = manager.getStatesForKey(key); + Assert.assertEquals((int)manager.get(strats), key, "Key -> strats -> key failed to return same key"); + + for ( int i = 0; i < strats.size(); i++ ) { + Assert.assertEquals(stratsAndStates.get(i).getSecond(), strats.get(i), "Strats and StratsAndStates differ"); + } + } + } } \ No newline at end of file