diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/SetOfStates.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/SetOfStates.java
index 564aeaef3..30b432c63 100644
--- a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/SetOfStates.java
+++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/SetOfStates.java
@@ -27,32 +27,14 @@ package org.broadinstitute.sting.gatk.walkers.varianteval.stratifications;
import java.util.List;
/**
-* [Short one sentence description of this walker]
-*
-*
-* [Functionality of this walker]
-*
-*
-* Input
-*
-* [Input description]
-*
-*
-* Output
-*
-* [Output description]
-*
-*
-* Examples
-*
-* java
-* -jar GenomeAnalysisTK.jar
-* -T $WalkerName
-*
-*
-* @author Your Name
-* @since Date created
-*/
-public interface SetOfStates {
- public List getAllStates();
+ * A basic interface for a class to be used with the StratificationStates system
+ *
+ * @author Mark DePristo
+ * @since 3/28/12
+ */
+public interface SetOfStates {
+ /**
+ * @return a list of all state objects that may be provided by this States provider
+ */
+ public List getAllStates();
}
diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/StratNode.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/StratNode.java
index 1a7e2dde7..f350df47d 100644
--- a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/StratNode.java
+++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/StratNode.java
@@ -24,12 +24,12 @@
package org.broadinstitute.sting.gatk.walkers.varianteval.stratifications;
+import com.google.java.contract.Ensures;
+import com.google.java.contract.Invariant;
+import com.google.java.contract.Requires;
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
-import java.util.Collections;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
/**
* Helper class representing a tree of stratification splits, where leaf nodes
@@ -49,35 +49,48 @@ import java.util.Map;
* This code allows us to efficiently look up a state key (A=2, B=3) and map it
* to a specific key (an integer) that's unique over the tree
*
+ * Note the structure of this tree is that the keys are -1 for all internal nodes, and
+ * leaves are the only nodes with meaningful keys. So the internal nodes of the tree
+ * have no keys, but meaningful maps from states -> subtrees and non-null stratifiers. The
+ * leaves have meaningful keys, empty maps, and null stratification objects
+ *
* @author Mark DePristo
* @since 3/27/12
*/
-public class StratNode implements Iterable> {
+@Invariant({
+ "(isLeaf() && stratifier == null && subnodes.isEmpty()) || (!isLeaf() && stratifier != null && !subnodes.isEmpty())"})
+class StratNode implements Iterable> {
int key = -1;
final T stratifier;
- final Map> subnodes;
+ final Map> subnodes;
- public StratNode() {
+ protected StratNode() {
this.subnodes = Collections.emptyMap();
this.stratifier = null;
}
- StratNode(final T stratifier, final Map> subnodes) {
+ protected StratNode(final T stratifier, final Map> subnodes) {
this.stratifier = stratifier;
this.subnodes = subnodes;
}
+ @Requires("key >= 0")
public void setKey(final int key) {
if ( ! isLeaf() )
throw new ReviewedStingException("Cannot set key of non-leaf node");
this.key = key;
}
- public int find(final List states, int offset) {
+ @Requires({
+ "states != null",
+ "offset >= 0",
+ "offset <= states.size()"
+ })
+ public int find(final List states, int offset) {
if ( isLeaf() ) // we're here!
return key;
else {
- final String state = states.get(offset);
+ final Object state = states.get(offset);
StratNode subnode = subnodes.get(state);
if ( subnode == null )
throw new ReviewedStingException("Couldn't find state for " + state + " at node " + this);
@@ -86,6 +99,28 @@ public class StratNode implements Iterable>
}
}
+ @Requires({
+ "multipleStates != null",
+ "offset >= 0",
+ "offset <= multipleStates.size()",
+ "keys != null",
+ "offset == multipleStates.size() || multipleStates.get(offset) != null"})
+ public void find(final List> multipleStates, final int offset, final HashSet keys) {
+ if ( isLeaf() ) // we're here!
+ keys.add(key);
+ else {
+ for ( final Object state : multipleStates.get(offset) ) {
+ // loop over all of the states at this offset
+ final StratNode subnode = subnodes.get(state);
+ if ( subnode == null )
+ throw new ReviewedStingException("Couldn't find state for " + state + " at node " + this);
+ else
+ subnode.find(multipleStates, offset+1, keys);
+ }
+ }
+ }
+
+ @Ensures("result >= 0")
public int getKey() {
if ( ! isLeaf() )
throw new ReviewedStingException("Cannot get key of non-leaf node");
@@ -93,10 +128,11 @@ public class StratNode implements Iterable>
return key;
}
- protected Map> getSubnodes() {
+ protected Map> getSubnodes() {
return subnodes;
}
+ @Ensures("result >= 0")
public int size() {
if ( isLeaf() )
return 1;
@@ -109,9 +145,19 @@ public class StratNode implements Iterable>
return stratifier;
}
- public boolean isLeaf() { return stratifier == null; }
+ /**
+ * @return true if this node is a leaf
+ */
+ public boolean isLeaf() {
+ return stratifier == null;
+ }
+ /**
+ * Returns an iterator over this node and all subnodes including internal and leaf nodes
+ * @return a non-null iterator over this node and all of its subnodes
+ */
@Override
+ @Ensures("result != null")
public Iterator> iterator() {
return new StratNodeIterator(this);
}
diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/StratificationStates.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/StratificationStates.java
index 7f1c75fa9..b6ee7d807 100644
--- a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/StratificationStates.java
+++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/StratificationStates.java
@@ -50,9 +50,9 @@ public class StratificationStates {
return new StratNode();
} else {
// we are in the middle of the tree
- final Collection states = first.getAllStates();
- final LinkedHashMap> subNodes = new LinkedHashMap>(states.size());
- for ( final String state : states ) {
+ final Collection states = first.getAllStates();
+ final LinkedHashMap> subNodes = new LinkedHashMap>(states.size());
+ for ( final Object state : states ) {
// have to copy because poll modifies the queue
final Queue copy = new LinkedList(strats);
subNodes.put(state, buildStratificationTree(copy));
@@ -64,19 +64,38 @@ public class StratificationStates {
public int getNStates() {
return root.size();
}
-
+
public StratNode getRoot() {
return root;
}
- public int getKey(final List states) {
+ public int getKey(final List states) {
return root.find(states, 0);
}
+ public Set getKeys(final List> allStates) {
+ final HashSet keys = new HashSet();
+ root.find(allStates, 0, keys);
+ return keys;
+ }
+
private void assignKeys(final StratNode root, int key) {
for ( final StratNode node : root ) {
if ( node.isLeaf() )
node.setKey(key++);
}
}
+
+ public static List> combineStates(final List first, final List second) {
+ List> combined = new ArrayList>(first.size());
+ for ( int i = 0; i < first.size(); i++ ) {
+ final Object firstI = first.get(i);
+ final Object secondI = second.get(i);
+ if ( firstI.equals(secondI) )
+ combined.add(Collections.singletonList(firstI));
+ else
+ combined.add(Arrays.asList(firstI, secondI));
+ }
+ return combined;
+ }
}
diff --git a/public/java/src/org/broadinstitute/sting/utils/Utils.java b/public/java/src/org/broadinstitute/sting/utils/Utils.java
index 130a7fa2f..f91066b0c 100755
--- a/public/java/src/org/broadinstitute/sting/utils/Utils.java
+++ b/public/java/src/org/broadinstitute/sting/utils/Utils.java
@@ -25,6 +25,7 @@
package org.broadinstitute.sting.utils;
+import com.google.java.contract.Requires;
import net.sf.samtools.SAMFileHeader;
import net.sf.samtools.SAMProgramRecord;
import net.sf.samtools.util.StringUtil;
@@ -710,4 +711,36 @@ public class Utils {
}
return list;
}
+
+ /**
+ * Returns the number of combinations represented by this collection
+ * of collections of options.
+ *
+ * For example, if this is [[A, B], [C, D], [E, F, G]] returns 2 * 2 * 3 = 12
+ *
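+ * @param options the array of option collections whose sizes are multiplied together
+ * @param <T> the element type of the option collections
+ * @return the product of the sizes of all collections in options (1 if options is empty)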
+ */
+ @Requires("options != null")
+ public static int nCombinations(final Collection[] options) {
+ int nStates = 1;
+ for ( Collection states : options ) {
+ nStates *= states.size();
+ }
+ return nStates;
+ }
+
+ @Requires("options != null")
+ public static int nCombinations(final List> options) {
+ if ( options.isEmpty() )
+ return 0;
+ else {
+ int nStates = 1;
+ for ( Collection states : options ) {
+ nStates *= states.size();
+ }
+ return nStates;
+ }
+ }
}
diff --git a/public/java/test/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/StratificationStatesUnitTest.java b/public/java/test/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/StratificationStatesUnitTest.java
index 946aef4a9..d6291b812 100644
--- a/public/java/test/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/StratificationStatesUnitTest.java
+++ b/public/java/test/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/StratificationStatesUnitTest.java
@@ -30,6 +30,7 @@ package org.broadinstitute.sting.gatk.walkers.varianteval.stratifications;
import org.broadinstitute.sting.BaseTest;
+import org.broadinstitute.sting.utils.Utils;
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.DataProvider;
@@ -51,46 +52,58 @@ public class StratificationStatesUnitTest extends BaseTest {
// --------------------------------------------------------------------------------
private class StratificationStatesTestProvider extends TestDataProvider {
- final List> allStates;
+ final List> allStates = new ArrayList>();
final List asSetOfStates = new ArrayList();
final int nStates;
public StratificationStatesTestProvider(final List ... allStates) {
super(StratificationStatesTestProvider.class);
- this.allStates = Arrays.asList(allStates);
+
+ for ( List states : allStates ) {
+ this.allStates.add(new ArrayList(states));
+ }
- int nStates = 1;
- for ( List states : this.allStates ) {
- nStates *= states.size();
+ for ( List states : this.allStates ) {
asSetOfStates.add(new ListAsSetOfStates(states));
}
- this.nStates = nStates;
- }
-// private String getName() {
-// return String.format("probs=%s expectedRegions=%s", Utils.join(",", probs), Utils.join(",", expectedRegions));
-// }
+ this.nStates = Utils.nCombinations(allStates);
+ setName(getName());
+ }
+
+ private String getName() {
+ StringBuilder b = new StringBuilder();
+ int c = 1;
+ for ( List state : allStates )
+ b.append(String.format("%d = [%s] ", c++, Utils.join(",", state)));
+ return b.toString();
+ }
+
public List getStateSpaceList() {
return asSetOfStates;
}
- public Queue> getAllCombinations() {
- return getAllCombinations(new LinkedList>(allStates));
+ public Queue> getAllCombinations() {
+ return getAllCombinations(new LinkedList>(allStates));
}
- private Queue> getAllCombinations(Queue> states) {
+ private Queue> getAllCombinations(Queue> states) {
if ( states.isEmpty() )
- return new LinkedList>();
+ return new LinkedList>();
else {
- List head = states.poll();
- Queue> substates = getAllCombinations(states);
- Queue> newStates = new LinkedList>();
- for ( int e : head) {
- for ( List state : substates ) {
- List newState = new LinkedList();
- newState.add(Integer.toString(e));
- newState.addAll(state);
- newStates.add(newState);
+ List head = states.poll();
+ Queue> substates = getAllCombinations(states);
+ Queue> newStates = new LinkedList>();
+ for ( final Object e : head) {
+ if ( substates.isEmpty() ) {
+ newStates.add(new LinkedList(Collections.singleton(e)));
+ } else {
+ for ( final List state : substates ) {
+ List newState = new LinkedList();
+ newState.add(e);
+ newState.addAll(state);
+ newStates.add(newState);
+ }
}
}
return newStates;
@@ -99,16 +112,14 @@ public class StratificationStatesUnitTest extends BaseTest {
}
private class ListAsSetOfStates implements SetOfStates {
- final List integers;
+ final List integers;
- private ListAsSetOfStates(final List integers) {
- this.integers = new ArrayList(integers.size());
- for ( int i : integers )
- this.integers.add(Integer.toString(i));
+ private ListAsSetOfStates(final List integers) {
+ this.integers = integers;
}
-
+
@Override
- public List getAllStates() {
+ public List getAllStates() {
return integers;
}
}
@@ -127,8 +138,8 @@ public class StratificationStatesUnitTest extends BaseTest {
}
@Test(dataProvider = "StratificationStatesTestProvider")
- public void testStratificationStatesTestProvider(StratificationStatesTestProvider cfg) {
- StratificationStates stratificationStates = new StratificationStates(cfg.getStateSpaceList());
+ public void testLeafCount(StratificationStatesTestProvider cfg) {
+ final StratificationStates stratificationStates = new StratificationStates(cfg.getStateSpaceList());
Assert.assertEquals(stratificationStates.getNStates(), cfg.nStates);
@@ -138,20 +149,55 @@ public class StratificationStatesUnitTest extends BaseTest {
nLeafs++;
}
Assert.assertEquals(nLeafs, cfg.nStates, "Unexpected number of leaves");
-
- Set seenKeys = new HashSet(cfg.nStates);
+ }
+
+ @Test(dataProvider = "StratificationStatesTestProvider")
+ public void testKeys(StratificationStatesTestProvider cfg) {
+ final StratificationStates stratificationStates = new StratificationStates(cfg.getStateSpaceList());
+ final Set seenKeys = new HashSet(cfg.nStates);
for ( final StratNode node : stratificationStates.getRoot() ) {
if ( node.isLeaf() ) {
Assert.assertFalse(seenKeys.contains(node.getKey()), "Already seen the key");
seenKeys.add(node.getKey());
}
}
+ }
- seenKeys.clear();
- for ( List state : cfg.getAllCombinations() ) {
+ @Test(dataProvider = "StratificationStatesTestProvider")
+ public void testFindSingleKeys(StratificationStatesTestProvider cfg) {
+ final StratificationStates stratificationStates = new StratificationStates(cfg.getStateSpaceList());
+ final Set seenKeys = new HashSet(cfg.nStates);
+ for ( List state : cfg.getAllCombinations() ) {
final int key = stratificationStates.getKey(state);
Assert.assertFalse(seenKeys.contains(key), "Already saw state mapping to this key");
seenKeys.add(key);
}
}
+
+ @Test(dataProvider = "StratificationStatesTestProvider")
+ public void testFindMultipleKeys(StratificationStatesTestProvider cfg) {
+ final StratificationStates stratificationStates = new StratificationStates(cfg.getStateSpaceList());
+ final List> states = new ArrayList>(cfg.allStates);
+ final Set keys = stratificationStates.getKeys(states);
+ Assert.assertEquals(keys.size(), cfg.nStates, "Find all states didn't find all of the expected unique keys");
+
+ final Queue> combinations = cfg.getAllCombinations();
+ while ( ! combinations.isEmpty() ) {
+ List first = combinations.poll();
+ List second = combinations.peek();
+ if ( second != null ) {
+ List> combined = StratificationStates.combineStates(first, second);
+ int nExpectedKeys = Utils.nCombinations(combined);
+
+ final int key1 = stratificationStates.getKey(first);
+ final int key2 = stratificationStates.getKey(second);
+ final Set keysCombined = stratificationStates.getKeys(combined);
+
+ Assert.assertTrue(keysCombined.contains(key1), "couldn't find key in data set");
+ Assert.assertTrue(keysCombined.contains(key2), "couldn't find key in data set");
+
+ Assert.assertEquals(keysCombined.size(), nExpectedKeys);
+ }
+ }
+ }
}
\ No newline at end of file