New StratificationManager based VariantEval passes unmodified integration tests

-- Now needs cleanup and optimizations
This commit is contained in:
Mark DePristo 2012-03-29 09:46:55 -04:00
parent d37f31e349
commit c8086a79e3
11 changed files with 159 additions and 189 deletions

View File

@ -21,7 +21,6 @@ import org.broadinstitute.sting.gatk.walkers.Window;
import org.broadinstitute.sting.gatk.walkers.varianteval.evaluators.VariantEvaluator;
import org.broadinstitute.sting.gatk.walkers.varianteval.stratifications.IntervalStratification;
import org.broadinstitute.sting.gatk.walkers.varianteval.stratifications.VariantStratifier;
import org.broadinstitute.sting.gatk.walkers.varianteval.stratifications.manager.SetOfStates;
import org.broadinstitute.sting.gatk.walkers.varianteval.stratifications.manager.StratificationManager;
import org.broadinstitute.sting.gatk.walkers.varianteval.util.*;
import org.broadinstitute.sting.utils.GenomeLoc;
@ -29,6 +28,7 @@ import org.broadinstitute.sting.utils.GenomeLocParser;
import org.broadinstitute.sting.utils.SampleUtils;
import org.broadinstitute.sting.utils.codecs.vcf.VCFHeader;
import org.broadinstitute.sting.utils.codecs.vcf.VCFUtils;
import org.broadinstitute.sting.utils.collections.Pair;
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
import org.broadinstitute.sting.utils.exceptions.StingException;
import org.broadinstitute.sting.utils.exceptions.UserException;
@ -225,6 +225,22 @@ public class VariantEvalWalker extends RodWalker<Integer, Integer> implements Tr
// The set of all possible evaluation contexts
StratificationManager<VariantStratifier, NewEvaluationContext> stratManager;
// TODO
// TODO
// TODO
// TODO
// TODO
//
// TODO -- StratificationManager should hold the master list of strats
// TODO
// TODO
// TODO
// TODO
// TODO
/**
* Initialize the stratifications, evaluations, evaluation contexts, and reporting object
*/
@ -403,14 +419,18 @@ public class VariantEvalWalker extends RodWalker<Integer, Integer> implements Tr
final void createStratificationStates(final List<VariantStratifier> stratificationObjects, final Set<Class<? extends VariantEvaluator>> evaluationObjects) {
final List<VariantStratifier> strats = new ArrayList<VariantStratifier>(stratificationObjects);
stratManager =
new StratificationManager<VariantStratifier, NewEvaluationContext>(strats);
stratManager = new StratificationManager<VariantStratifier, NewEvaluationContext>(strats);
logger.info("Creating " + stratManager.size() + " combinatorial stratification states");
for ( int i = 0; i < stratManager.size(); i++ ) {
NewEvaluationContext ec = new NewEvaluationContext();
ec.putAll(stratManager.getStateForKey(i));
ec.addEvaluationClassList(this, null, evaluationObjects);
// // todo -- remove me, tmp conversion
// for ( Pair<VariantStratifier, Object> stratState : stratManager.getStratsAndStatesForKey(i) ) {
// ec.put(stratState.getFirst(), stratState.getSecond());
// }
ec.addEvaluationClassList(this, evaluationObjects);
stratManager.set(i, ec);
}
}
@ -518,23 +538,20 @@ public class VariantEvalWalker extends RodWalker<Integer, Integer> implements Tr
public void onTraversalDone(Integer result) {
logger.info("Finalizing variant report");
// TODO -- clean up -- this is deeply unsafe
// TODO -- VS should be sorted first with a TreeSet
for ( int key = 0; key < stratManager.size(); key++ ) {
final Map<VariantStratifier, Object> stateValues = stratManager.getStateForKey(key);
final String stratStateString = stratManager.getStratsAndStatesForKeyString(key);
final List<Pair<VariantStratifier, Object>> stratsAndStates = stratManager.getStratsAndStatesForKey(key);
final NewEvaluationContext nec = stratManager.get(key);
final Map<String, Object> stateKey = new HashMap<String, Object>(stateValues.size());
for ( Map.Entry<VariantStratifier, Object> elt : stateValues.entrySet() )
stateKey.put(elt.getKey().getName(), elt.getValue());
for ( VariantEvaluator ve : nec.getEvaluationClassList().values() ) {
for ( final VariantEvaluator ve : nec.getEvaluationClassList().values() ) {
ve.finalizeEvaluation();
final String veName = ve.getSimpleName(); // ve.getClass().getSimpleName();
AnalysisModuleScanner scanner = new AnalysisModuleScanner(ve);
Map<Field, DataPoint> datamap = scanner.getData();
for (Field field : datamap.keySet()) {
for ( final Field field : datamap.keySet()) {
try {
field.setAccessible(true);
@ -544,52 +561,28 @@ public class VariantEvalWalker extends RodWalker<Integer, Integer> implements Tr
final String subTableName = veName + "." + field.getName();
final DataPoint dataPointAnn = datamap.get(field);
GATKReportTable table;
if (!report.hasTable(subTableName)) {
report.addTable(subTableName, dataPointAnn.description());
table = report.getTable(subTableName);
table.addPrimaryKey("entry", false);
table.addColumn(subTableName, subTableName);
for ( VariantStratifier vs : stratificationObjects ) {
table.addColumn(vs.getName(), "unknown");
}
table.addColumn(t.getRowName(), "unknown");
for ( final Object o : t.getColumnKeys() ) {
final String c = o.toString();
table.addColumn(c, 0.0);
}
} else {
table = report.getTable(subTableName);
if (! report.hasTable(subTableName)) {
configureNewReportTable(t, subTableName, dataPointAnn);
}
final GATKReportTable table = report.getTable(subTableName);
for (int row = 0; row < t.getRowKeys().length; row++) {
final String r = t.getRowKeys()[row].toString();
final String newStratStateString = stratStateString + r;
for ( VariantStratifier vs : stratificationObjects ) {
final String columnName = vs.getName();
table.set(stateKey.toString() + r, columnName, stateKey.get(columnName));
}
setTableColumnNames(table, newStratStateString, stratsAndStates);
for (int col = 0; col < t.getColumnKeys().length; col++) {
final String c = t.getColumnKeys()[col].toString();
final String newStateKey = stateKey.toString() + r;
table.set(newStateKey, c, t.getCell(row, col));
table.set(newStateKey, t.getRowName(), r);
table.set(newStratStateString, c, t.getCell(row, col));
table.set(newStratStateString, t.getRowName(), r);
}
}
} else {
GATKReportTable table = report.getTable(veName);
for ( VariantStratifier vs : stratificationObjects ) {
final String columnName = vs.getName();
table.set(stateKey.toString(), columnName, stateKey.get(vs.getName()));
}
table.set(stateKey.toString(), field.getName(), field.get(ve));
final GATKReportTable table = report.getTable(veName);
setTableColumnNames(table, stratStateString, stratsAndStates);
table.set(stratStateString, field.getName(), field.get(ve));
}
} catch (IllegalAccessException e) {
throw new StingException("IllegalAccessException: " + e);
@ -600,6 +593,38 @@ public class VariantEvalWalker extends RodWalker<Integer, Integer> implements Tr
report.print(out);
}
private final void configureNewReportTable(final TableType t, final String subTableName, final DataPoint dataPointAnn) {
// basic table configuration. Set up primary key, dummy column names
report.addTable(subTableName, dataPointAnn.description());
GATKReportTable table = report.getTable(subTableName);
table.addPrimaryKey("entry", false);
table.addColumn(subTableName, subTableName);
for ( VariantStratifier vs : stratificationObjects ) {
table.addColumn(vs.getName(), "unknown");
}
table.addColumn(t.getRowName(), "unknown");
for ( final Object o : t.getColumnKeys() ) {
final String c = o.toString();
table.addColumn(c, 0.0);
}
}
private final void setTableColumnNames(final GATKReportTable table,
final String primaryKey,
final List<Pair<VariantStratifier, Object>> stratsAndStates) {
for ( Pair<VariantStratifier, Object> stratAndState : stratsAndStates ) {
final VariantStratifier vs = stratAndState.getFirst();
final String columnName = vs.getName();
final Object strat = stratAndState.getSecond();
table.set(primaryKey, columnName, strat);
}
}
// Accessors
public Logger getLogger() { return logger; }

View File

@ -55,6 +55,10 @@ public class IntervalStratification extends VariantStratifier {
final protected static Logger logger = Logger.getLogger(IntervalStratification.class);
Map<String, IntervalTree<GenomeLoc>> intervalTreeByContig = null;
final List<Object> OVERLAPPING = Arrays.asList((Object)"all", (Object)"overlaps.intervals");
final List<Object> NOT_OVERLAPPING = Arrays.asList((Object)"all", (Object)"outside.intervals");
@Override
public void initialize() {
if ( getVariantEvalWalker().intervalsFile == null )
@ -79,7 +83,10 @@ public class IntervalStratification extends VariantStratifier {
IntervalTree<GenomeLoc> intervalTree = intervalTreeByContig.get(loc.getContig());
IntervalTree.Node<GenomeLoc> node = intervalTree.minOverlapper(loc.getStart(), loc.getStop());
//logger.info(String.format("Overlap %s found %s", loc, node));
return Collections.singletonList((Object)(node != null ? "overlaps.intervals" : "outside.intervals"));
if ( node != null )
return OVERLAPPING;
else
return NOT_OVERLAPPING;
}
return Collections.emptyList();

View File

@ -3,13 +3,13 @@ package org.broadinstitute.sting.gatk.walkers.varianteval.stratifications;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.walkers.varianteval.VariantEvalWalker;
import org.broadinstitute.sting.gatk.walkers.varianteval.stratifications.manager.SetOfStates;
import org.broadinstitute.sting.gatk.walkers.varianteval.stratifications.manager.Stratifier;
import org.broadinstitute.sting.utils.variantcontext.VariantContext;
import java.util.ArrayList;
import java.util.List;
public abstract class VariantStratifier implements Comparable<VariantStratifier>, SetOfStates {
public abstract class VariantStratifier implements Comparable<VariantStratifier>, Stratifier {
private VariantEvalWalker variantEvalWalker;
final private String name;
final protected ArrayList<Object> states = new ArrayList<Object>();
@ -53,6 +53,11 @@ public abstract class VariantStratifier implements Comparable<VariantStratifier>
return this.getName().compareTo(o1.getName());
}
@Override
public String toString() {
return getName();
}
public final String getName() {
return name;
}

View File

@ -59,7 +59,7 @@ import java.util.*;
*/
@Invariant({
"(isLeaf() && stratifier == null && subnodes.isEmpty()) || (!isLeaf() && stratifier != null && !subnodes.isEmpty())"})
class StratNode<T extends SetOfStates> implements Iterable<StratNode<T>> {
class StratNode<T extends Stratifier> implements Iterable<StratNode<T>> {
int key = -1;
final T stratifier;
// TODO -- track state key that maps to root node

View File

@ -34,7 +34,7 @@ import java.util.*;
* @author Mark DePristo
* @since 3/27/12
*/
class StratNodeIterator<T extends SetOfStates> implements Iterator<StratNode<T>> {
class StratNodeIterator<T extends Stratifier> implements Iterator<StratNode<T>> {
Queue<Iterator<StratNode<T>>> iterators = new LinkedList<Iterator<StratNode<T>>>();
Iterator<StratNode<T>> currentIterator;

View File

@ -26,6 +26,7 @@ package org.broadinstitute.sting.gatk.walkers.varianteval.stratifications.manage
import com.google.java.contract.Ensures;
import com.google.java.contract.Requires;
import org.broadinstitute.sting.utils.collections.Pair;
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
import java.util.*;
@ -36,7 +37,7 @@ import java.util.*;
* @author Mark DePristo
* @since 3/27/12
*/
public class StratificationManager<K extends SetOfStates, V> implements Map<List<Object>, V> {
public class StratificationManager<K extends Stratifier, V> implements Map<List<Object>, V> {
private final StratNode<K> root;
private final int size;
@ -45,6 +46,7 @@ public class StratificationManager<K extends SetOfStates, V> implements Map<List
// values associated with each key
private final ArrayList<V> valuesByKey;
private final ArrayList<List<Object>> stratifierValuesByKey;
private final ArrayList<String> keyStrings;
// -------------------------------------------------------------------------------------
//
@ -64,9 +66,11 @@ public class StratificationManager<K extends SetOfStates, V> implements Map<List
this.valuesByKey = new ArrayList<V>(size());
this.stratifierValuesByKey = new ArrayList<List<Object>>(size());
this.keyStrings = new ArrayList<String>(size());
for ( int i = 0; i < size(); i++ ) {
this.valuesByKey.add(null);
this.stratifierValuesByKey.add(null);
this.keyStrings.add(null);
}
assignStratifierValuesByKey(root);
}
@ -140,6 +144,11 @@ public class StratificationManager<K extends SetOfStates, V> implements Map<List
return root;
}
@Ensures("result != null")
public List<K> getStratifiers() {
return stratifiers;
}
// -------------------------------------------------------------------------------------
//
// mapping from states -> keys
@ -160,16 +169,39 @@ public class StratificationManager<K extends SetOfStates, V> implements Map<List
return keys;
}
public Map<K, Object> getStateForKey(final int key) {
final Map<K, Object> states = new HashMap<K, Object>(stratifiers.size());
public List<Object> getStatesForKey(final int key) {
final List<Object> states = new ArrayList<Object>(stratifiers.size());
for ( int i = 0; i < stratifiers.size(); i++ ) {
final Object stratValue = stratifierValuesByKey.get(key).get(i);
states.add(stratValue);
}
return states;
}
public List<Pair<K, Object>> getStratsAndStatesForKey(final int key) {
final List<Pair<K, Object>> states = new ArrayList<Pair<K, Object>>(stratifiers.size());
for ( int i = 0; i < stratifiers.size(); i++ ) {
final K strat = stratifiers.get(i);
final Object stratValue = stratifierValuesByKey.get(key).get(i);
states.put(strat, stratValue);
states.add(new Pair<K, Object>(strat, stratValue));
}
return states;
}
public String getStratsAndStatesForKeyString(final int key) {
if ( keyStrings.get(key) == null ) {
StringBuilder b = new StringBuilder();
for ( int i = 0; i < stratifiers.size(); i++ ) {
final K strat = stratifiers.get(i);
final Object stratValue = stratifierValuesByKey.get(key).get(i);
b.append(strat.toString()).append(":").append(stratValue.toString());
}
keyStrings.set(key, b.toString());
}
return keyStrings.get(key);
}
// -------------------------------------------------------------------------------------
//
// valuesByKey

View File

@ -27,12 +27,12 @@ package org.broadinstitute.sting.gatk.walkers.varianteval.stratifications.manage
import java.util.List;
/**
* A basic interface for a class to be used with the StratificationStates system
* A basic interface for a class to be used with the StratificationManager system
*
* @author Mark DePristo
* @since 3/28/12
*/
public interface SetOfStates<Object> {
public interface Stratifier<Object> {
/**
* @return a list of all objects states that may be provided by this States provider
*/

View File

@ -12,10 +12,10 @@ import org.broadinstitute.sting.utils.variantcontext.VariantContext;
import java.util.*;
public class NewEvaluationContext extends HashMap<VariantStratifier, Object> {
public class NewEvaluationContext { // extends HashMap<VariantStratifier, Object> {
private Map<String, VariantEvaluator> evaluationInstances;
public void addEvaluationClassList(VariantEvalWalker walker, StateKey stateKey, Set<Class<? extends VariantEvaluator>> evaluationClasses) {
public void addEvaluationClassList(VariantEvalWalker walker, Set<Class<? extends VariantEvaluator>> evaluationClasses) {
evaluationInstances = new LinkedHashMap<String, VariantEvaluator>(evaluationClasses.size());
for ( final Class<? extends VariantEvaluator> c : evaluationClasses ) {
@ -36,16 +36,6 @@ public class NewEvaluationContext extends HashMap<VariantStratifier, Object> {
return new TreeMap<String, VariantEvaluator>(evaluationInstances);
}
public StateKey makeStateKey() {
Map<String, Object> map = new HashMap<String, Object>(size());
for (Map.Entry<VariantStratifier, Object> elt : this.entrySet() ) {
map.put(elt.getKey().getName(), elt.getValue());
}
return new StateKey(map);
}
public void apply(RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context, VariantContext comp, VariantContext eval) {
for ( final VariantEvaluator evaluation : evaluationInstances.values() ) {
// the other updateN methods don't see a null context

View File

@ -1,104 +0,0 @@
package org.broadinstitute.sting.gatk.walkers.varianteval.util;
import java.util.Map;
import java.util.TreeMap;
/**
* A final constant class representing the specific state configuration
* for a VariantEvaluator instance.
*
* The way this is currently implemented is by a map from the name of a VariantStratification to a
* specific state string. For example, the stratification Novelty has states all, known, novel. A
* specific variant and comp would be tagged as "known" by the stratification, and this could be
* represented here by the map (Novelty -> known).
*
* TODO -- PERFORMANCE PROBLEM -- MAD 03/27/12
* TODO -- PERFORMANCE PROBLEM -- MAD 03/27/12
* TODO -- PERFORMANCE PROBLEM -- MAD 03/27/12
* TODO -- PERFORMANCE PROBLEM -- MAD 03/27/12
* TODO -- PERFORMANCE PROBLEM -- MAD 03/27/12
*
* I've been staring at this state key code for a while. It's just not right, and expensive to boot.
* Here are my thoughts for future work. The state key is both a key with specific state values for
* every stratification. For example, (known, sample1, ac=1). This capability is used in some places,
* such as below, to return a set of all states that should be updated given the eval and comp
* VCs. In principle there are a finite set of such combinations (the product of all states for all active
* stratifications at initialization). We could represent such keys as integers into the set of all combinations.
*
* Note that all of the code that manipulates these things is just terrible. It's all string manipulation and
* HashMaps. Since we are effectively always squaring off our VE analyses (i.e., we have a table with
* all variable values for all stratification combinations) it doesn't make sense to allow so much dynamicism. Instead
* we should just upfront create a giant table indexed by integer keys, and manage data via a simple map from
* specific strat state to this key.
*
* The reason this is so important is that >80% of the runtime of VE with VCFs with >1000 samples is spent in
* the initializeStateKey function. Instead, we should have code that looks like:
*
* init:
* allStates <- initializeCombinationalStateSpace
*
* map:
* for each eval / comp pair:
* for each relevantState based on eval / comp:
* allStates[relevantState].update(eval, comp)
*
*
*/
public final class StateKey {
/** High-performance cache of the toString operation for a constant class */
private final String string;
private final TreeMap<String, Object> states;
public StateKey(final Map<String, Object> states) {
this.states = new TreeMap<String, Object>(states);
this.string = formatString();
}
public StateKey(final StateKey toOverride, final String keyOverride, final Object valueOverride) {
if ( toOverride == null ) {
this.states = new TreeMap<String, Object>();
} else {
this.states = new TreeMap<String, Object>(toOverride.states);
}
this.states.put(keyOverride, valueOverride);
this.string = formatString();
}
@Override
public boolean equals(final Object o) {
if (this == o) return true;
if (o == null || getClass() != o.getClass()) return false;
final StateKey stateKey = (StateKey) o;
if (states != null ? !states.equals(stateKey.states) : stateKey.states != null) return false;
return true;
}
@Override
public int hashCode() {
return states.hashCode();
}
@Override
public String toString() {
return string;
}
private final String formatString() {
StringBuilder b = new StringBuilder();
for ( Map.Entry<String, Object> entry : states.entrySet() ) {
b.append(String.format("%s:%s;", entry.getKey(), entry.getValue()));
}
return b.toString();
}
// TODO -- might be slow because of tree map
public Object get(final String key) {
return states.get(key);
}
}

View File

@ -88,7 +88,7 @@ public class VariantEvalUtils {
* @return set of stratifications to use
*/
public List<VariantStratifier> initializeStratificationObjects(VariantEvalWalker variantEvalWalker, boolean noStandardStrats, String[] modulesToUse) {
List<VariantStratifier> strats = new ArrayList<VariantStratifier>();
TreeSet<VariantStratifier> strats = new TreeSet<VariantStratifier>();
Set<String> stratsToUse = new HashSet<String>();
// Create a map for all stratification modules for easy lookup.
@ -139,7 +139,7 @@ public class VariantEvalUtils {
}
}
return strats;
return new ArrayList<VariantStratifier>(strats);
}
/**

View File

@ -31,6 +31,7 @@ package org.broadinstitute.sting.gatk.walkers.varianteval.stratifications.manage
import org.broadinstitute.sting.BaseTest;
import org.broadinstitute.sting.utils.Utils;
import org.broadinstitute.sting.utils.collections.Pair;
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.DataProvider;
@ -53,7 +54,7 @@ public class StratificationManagerUnitTest extends BaseTest {
private class StratificationStatesTestProvider extends TestDataProvider {
final List<List<Object>> allStates = new ArrayList<List<Object>>();
final List<ListAsSetOfStates> asSetOfStates = new ArrayList<ListAsSetOfStates>();
final List<IntegerStratifier> asSetOfStates = new ArrayList<IntegerStratifier>();
final int nStates;
public StratificationStatesTestProvider(final List<Integer> ... allStates) {
@ -64,7 +65,7 @@ public class StratificationManagerUnitTest extends BaseTest {
}
for ( List<Object> states : this.allStates ) {
asSetOfStates.add(new ListAsSetOfStates(states));
asSetOfStates.add(new IntegerStratifier(states));
}
this.nStates = Utils.nCombinations(allStates);
@ -79,7 +80,7 @@ public class StratificationManagerUnitTest extends BaseTest {
return b.toString();
}
public List<ListAsSetOfStates> getStateSpaceList() {
public List<IntegerStratifier> getStateSpaceList() {
return asSetOfStates;
}
@ -118,10 +119,10 @@ public class StratificationManagerUnitTest extends BaseTest {
}
}
private class ListAsSetOfStates implements SetOfStates {
private class IntegerStratifier implements Stratifier {
final List<Object> integers;
private ListAsSetOfStates(final List<Object> integers) {
private IntegerStratifier(final List<Object> integers) {
this.integers = integers;
}
@ -144,8 +145,8 @@ public class StratificationManagerUnitTest extends BaseTest {
return StratificationStatesTestProvider.getTests(StratificationStatesTestProvider.class);
}
private final StratificationManager<ListAsSetOfStates, Integer> createManager(StratificationStatesTestProvider cfg) {
final StratificationManager<ListAsSetOfStates, Integer> manager = new StratificationManager<ListAsSetOfStates, Integer>(cfg.getStateSpaceList());
private final StratificationManager<IntegerStratifier, Integer> createManager(StratificationStatesTestProvider cfg) {
final StratificationManager<IntegerStratifier, Integer> manager = new StratificationManager<IntegerStratifier, Integer>(cfg.getStateSpaceList());
List<Integer> values = cfg.values();
for ( int i = 0; i < cfg.nStates; i++ )
manager.set(i, values.get(i));
@ -157,7 +158,7 @@ public class StratificationManagerUnitTest extends BaseTest {
@Test(dataProvider = "StratificationStatesTestProvider")
public void testLeafCount(StratificationStatesTestProvider cfg) {
final StratificationManager<ListAsSetOfStates, Integer> stratificationManager = createManager(cfg);
final StratificationManager<IntegerStratifier, Integer> stratificationManager = createManager(cfg);
Assert.assertEquals(stratificationManager.size(), cfg.nStates);
@ -171,7 +172,7 @@ public class StratificationManagerUnitTest extends BaseTest {
@Test(dataProvider = "StratificationStatesTestProvider")
public void testKeys(StratificationStatesTestProvider cfg) {
final StratificationManager<ListAsSetOfStates, Integer> stratificationManager = createManager(cfg);
final StratificationManager<IntegerStratifier, Integer> stratificationManager = createManager(cfg);
final Set<Integer> seenKeys = new HashSet<Integer>(cfg.nStates);
for ( final StratNode node : stratificationManager.getRoot() ) {
if ( node.isLeaf() ) {
@ -183,7 +184,7 @@ public class StratificationManagerUnitTest extends BaseTest {
@Test(dataProvider = "StratificationStatesTestProvider")
public void testFindSingleKeys(StratificationStatesTestProvider cfg) {
final StratificationManager<ListAsSetOfStates, Integer> stratificationManager = createManager(cfg);
final StratificationManager<IntegerStratifier, Integer> stratificationManager = createManager(cfg);
final Set<Integer> seenKeys = new HashSet<Integer>(cfg.nStates);
for ( List<Object> state : cfg.getAllCombinations() ) {
final int key = stratificationManager.getKey(state);
@ -203,7 +204,7 @@ public class StratificationManagerUnitTest extends BaseTest {
@Test(dataProvider = "StratificationStatesTestProvider")
public void testFindMultipleKeys(StratificationStatesTestProvider cfg) {
final StratificationManager<ListAsSetOfStates, Integer> stratificationManager = createManager(cfg);
final StratificationManager<IntegerStratifier, Integer> stratificationManager = createManager(cfg);
final List<List<Object>> states = new ArrayList<List<Object>>(cfg.allStates);
final Set<Integer> keys = stratificationManager.getKeys(states);
Assert.assertEquals(keys.size(), cfg.nStates, "Find all states didn't find all of the expected unique keys");
@ -230,8 +231,22 @@ public class StratificationManagerUnitTest extends BaseTest {
@Test(dataProvider = "StratificationStatesTestProvider")
public void testMapSet(StratificationStatesTestProvider cfg) {
final StratificationManager<ListAsSetOfStates, Integer> stratificationManager = createManager(cfg);
final StratificationManager<IntegerStratifier, Integer> stratificationManager = createManager(cfg);
stratificationManager.set(0, -1);
Assert.assertEquals((int)stratificationManager.get(0), -1);
}
@Test(dataProvider = "StratificationStatesTestProvider")
public void testStratifierByKey(StratificationStatesTestProvider cfg) {
final StratificationManager<IntegerStratifier, Integer> manager = createManager(cfg);
for ( int key = 0; key < cfg.nStates; key++ ) {
List<Pair<IntegerStratifier, Object>> stratsAndStates = manager.getStratsAndStatesForKey(key);
final List<Object> strats = manager.getStatesForKey(key);
Assert.assertEquals((int)manager.get(strats), key, "Key -> strats -> key failed to return same key");
for ( int i = 0; i < strats.size(); i++ ) {
Assert.assertEquals(stratsAndStates.get(i).getSecond(), strats.get(i), "Strats and StratsAndStates differ");
}
}
}
}