diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/SetOfStates.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/SetOfStates.java index 564aeaef3..30b432c63 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/SetOfStates.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/SetOfStates.java @@ -27,32 +27,14 @@ package org.broadinstitute.sting.gatk.walkers.varianteval.stratifications; import java.util.List; /** -* [Short one sentence description of this walker] -*

-*

-* [Functionality of this walker] -*

-*

-*

Input

-*

-* [Input description] -*

-*

-*

Output

-*

-* [Output description] -*

-*

-*

Examples

-*
-*    java
-*      -jar GenomeAnalysisTK.jar
-*      -T $WalkerName
-*  
-* -* @author Your Name -* @since Date created -*/ -public interface SetOfStates { - public List getAllStates(); + * A basic interface for a class to be used with the StratificationStates system + * + * @author Mark DePristo + * @since 3/28/12 + */ +public interface SetOfStates { + /** + * @return a list of all objects states that may be provided by this States provider + */ + public List getAllStates(); } diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/StratNode.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/StratNode.java index 1a7e2dde7..f350df47d 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/StratNode.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/StratNode.java @@ -24,12 +24,12 @@ package org.broadinstitute.sting.gatk.walkers.varianteval.stratifications; +import com.google.java.contract.Ensures; +import com.google.java.contract.Invariant; +import com.google.java.contract.Requires; import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; -import java.util.Collections; -import java.util.Iterator; -import java.util.List; -import java.util.Map; +import java.util.*; /** * Helper class representing a tree of stratification splits, where leaf nodes @@ -49,35 +49,48 @@ import java.util.Map; * This code allows us to efficiently look up a state key (A=2, B=3) and map it * to a specific key (an integer) that's unique over the tree * + * Note the structure of this tree is that the keys are -1 for all internal nodes, and + * leafs are the only nodes with meaningful keys. So for a tree with 2N nodes N of these + * will be internal, with no keys, and meaningful maps from states -> subtrees. The + * other N nodes are leafs, with meaningful keys, empty maps, and null stratification objects + * * @author Mark DePristo * @since 3/27/12 */ -public class StratNode implements Iterable> { +@Invariant({ + "(isLeaf() && stratifier == null && subnodes.isEmpty()) || (!isLeaf() && stratifier != null && !subnodes.isEmpty())"}) +class StratNode implements Iterable> { int key = -1; final T stratifier; - final Map> subnodes; + final Map> subnodes; - public StratNode() { + protected StratNode() { this.subnodes = Collections.emptyMap(); this.stratifier = null; } - StratNode(final T stratifier, final Map> subnodes) { + protected StratNode(final T stratifier, final Map> subnodes) { this.stratifier = stratifier; this.subnodes = subnodes; } + @Requires("key >= 0") public void setKey(final int key) { if ( ! isLeaf() ) throw new ReviewedStingException("Cannot set key of non-leaf node"); this.key = key; } - public int find(final List states, int offset) { + @Requires({ + "states != null", + "offset >= 0", + "offset <= states.size()" + }) + public int find(final List states, int offset) { if ( isLeaf() ) // we're here! return key; else { - final String state = states.get(offset); + final Object state = states.get(offset); StratNode subnode = subnodes.get(state); if ( subnode == null ) throw new ReviewedStingException("Couldn't find state for " + state + " at node " + this); @@ -86,6 +99,28 @@ public class StratNode implements Iterable> } } + @Requires({ + "multipleStates != null", + "offset >= 0", + "offset <= multipleStates.size()", + "keys != null", + "offset == multipleStates.size() || multipleStates.get(offset) != null"}) + public void find(final List> multipleStates, final int offset, final HashSet keys) { + if ( isLeaf() ) // we're here! + keys.add(key); + else { + for ( final Object state : multipleStates.get(offset) ) { + // loop over all of the states at this offset + final StratNode subnode = subnodes.get(state); + if ( subnode == null ) + throw new ReviewedStingException("Couldn't find state for " + state + " at node " + this); + else + subnode.find(multipleStates, offset+1, keys); + } + } + } + + @Ensures("result >= 0") public int getKey() { if ( ! isLeaf() ) throw new ReviewedStingException("Cannot get key of non-leaf node"); @@ -93,10 +128,11 @@ public class StratNode implements Iterable> return key; } - protected Map> getSubnodes() { + protected Map> getSubnodes() { return subnodes; } + @Ensures("result >= 0") public int size() { if ( isLeaf() ) return 1; @@ -109,9 +145,19 @@ public class StratNode implements Iterable> return stratifier; } - public boolean isLeaf() { return stratifier == null; } + /** + * @return true if this node is a leaf + */ + public boolean isLeaf() { + return stratifier == null; + } + /** + * Returns an iterator over this node and all subnodes including internal and leaf nodes + * @return + */ @Override + @Ensures("result != null") public Iterator> iterator() { return new StratNodeIterator(this); } diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/StratificationStates.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/StratificationStates.java index 7f1c75fa9..b6ee7d807 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/StratificationStates.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/StratificationStates.java @@ -50,9 +50,9 @@ public class StratificationStates { return new StratNode(); } else { // we are in the middle of the tree - final Collection states = first.getAllStates(); - final LinkedHashMap> subNodes = new LinkedHashMap>(states.size()); - for ( final String state : states ) { + final Collection states = first.getAllStates(); + final LinkedHashMap> subNodes = new LinkedHashMap>(states.size()); + for ( final Object state : states ) { // have to copy because poll modifies the queue final Queue copy = new LinkedList(strats); subNodes.put(state, buildStratificationTree(copy)); @@ -64,19 +64,38 @@ public class StratificationStates { public int getNStates() { return root.size(); } - + public StratNode getRoot() { return root; } - public int getKey(final List states) { + public int getKey(final List states) { return root.find(states, 0); } + public Set getKeys(final List> allStates) { + final HashSet keys = new HashSet(); + root.find(allStates, 0, keys); + return keys; + } + private void assignKeys(final StratNode root, int key) { for ( final StratNode node : root ) { if ( node.isLeaf() ) node.setKey(key++); } } + + public static List> combineStates(final List first, final List second) { + List> combined = new ArrayList>(first.size()); + for ( int i = 0; i < first.size(); i++ ) { + final Object firstI = first.get(i); + final Object secondI = second.get(i); + if ( firstI.equals(secondI) ) + combined.add(Collections.singletonList(firstI)); + else + combined.add(Arrays.asList(firstI, secondI)); + } + return combined; + } } diff --git a/public/java/src/org/broadinstitute/sting/utils/Utils.java b/public/java/src/org/broadinstitute/sting/utils/Utils.java index 130a7fa2f..f91066b0c 100755 --- a/public/java/src/org/broadinstitute/sting/utils/Utils.java +++ b/public/java/src/org/broadinstitute/sting/utils/Utils.java @@ -25,6 +25,7 @@ package org.broadinstitute.sting.utils; +import com.google.java.contract.Requires; import net.sf.samtools.SAMFileHeader; import net.sf.samtools.SAMProgramRecord; import net.sf.samtools.util.StringUtil; @@ -710,4 +711,36 @@ public class Utils { } return list; } + + /** + * Returns the number of combinations represented by this collection + * of collection of options. + * + * For example, if this is [[A, B], [C, D], [E, F, G]] returns 2 * 2 * 3 = 12 + * + * @param options + * @param + * @return + */ + @Requires("options != null") + public static int nCombinations(final Collection[] options) { + int nStates = 1; + for ( Collection states : options ) { + nStates *= states.size(); + } + return nStates; + } + + @Requires("options != null") + public static int nCombinations(final List> options) { + if ( options.isEmpty() ) + return 0; + else { + int nStates = 1; + for ( Collection states : options ) { + nStates *= states.size(); + } + return nStates; + } + } } diff --git a/public/java/test/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/StratificationStatesUnitTest.java b/public/java/test/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/StratificationStatesUnitTest.java index 946aef4a9..d6291b812 100644 --- a/public/java/test/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/StratificationStatesUnitTest.java +++ b/public/java/test/org/broadinstitute/sting/gatk/walkers/varianteval/stratifications/StratificationStatesUnitTest.java @@ -30,6 +30,7 @@ package org.broadinstitute.sting.gatk.walkers.varianteval.stratifications; import org.broadinstitute.sting.BaseTest; +import org.broadinstitute.sting.utils.Utils; import org.testng.Assert; import org.testng.annotations.BeforeClass; import org.testng.annotations.DataProvider; @@ -51,46 +52,58 @@ public class StratificationStatesUnitTest extends BaseTest { // -------------------------------------------------------------------------------- private class StratificationStatesTestProvider extends TestDataProvider { - final List> allStates; + final List> allStates = new ArrayList>(); final List asSetOfStates = new ArrayList(); final int nStates; public StratificationStatesTestProvider(final List ... allStates) { super(StratificationStatesTestProvider.class); - this.allStates = Arrays.asList(allStates); + + for ( List states : allStates ) { + this.allStates.add(new ArrayList(states)); + } - int nStates = 1; - for ( List states : this.allStates ) { - nStates *= states.size(); + for ( List states : this.allStates ) { asSetOfStates.add(new ListAsSetOfStates(states)); } - this.nStates = nStates; - } -// private String getName() { -// return String.format("probs=%s expectedRegions=%s", Utils.join(",", probs), Utils.join(",", expectedRegions)); -// } + this.nStates = Utils.nCombinations(allStates); + setName(getName()); + } + + private String getName() { + StringBuilder b = new StringBuilder(); + int c = 1; + for ( List state : allStates ) + b.append(String.format("%d = [%s] ", c++, Utils.join(",", state))); + return b.toString(); + } + public List getStateSpaceList() { return asSetOfStates; } - public Queue> getAllCombinations() { - return getAllCombinations(new LinkedList>(allStates)); + public Queue> getAllCombinations() { + return getAllCombinations(new LinkedList>(allStates)); } - private Queue> getAllCombinations(Queue> states) { + private Queue> getAllCombinations(Queue> states) { if ( states.isEmpty() ) - return new LinkedList>(); + return new LinkedList>(); else { - List head = states.poll(); - Queue> substates = getAllCombinations(states); - Queue> newStates = new LinkedList>(); - for ( int e : head) { - for ( List state : substates ) { - List newState = new LinkedList(); - newState.add(Integer.toString(e)); - newState.addAll(state); - newStates.add(newState); + List head = states.poll(); + Queue> substates = getAllCombinations(states); + Queue> newStates = new LinkedList>(); + for ( final Object e : head) { + if ( substates.isEmpty() ) { + newStates.add(new LinkedList(Collections.singleton(e))); + } else { + for ( final List state : substates ) { + List newState = new LinkedList(); + newState.add(e); + newState.addAll(state); + newStates.add(newState); + } } } return newStates; @@ -99,16 +112,14 @@ public class StratificationStatesUnitTest extends BaseTest { } private class ListAsSetOfStates implements SetOfStates { - final List integers; + final List integers; - private ListAsSetOfStates(final List integers) { - this.integers = new ArrayList(integers.size()); - for ( int i : integers ) - this.integers.add(Integer.toString(i)); + private ListAsSetOfStates(final List integers) { + this.integers = integers; } - + @Override - public List getAllStates() { + public List getAllStates() { return integers; } } @@ -127,8 +138,8 @@ public class StratificationStatesUnitTest extends BaseTest { } @Test(dataProvider = "StratificationStatesTestProvider") - public void testStratificationStatesTestProvider(StratificationStatesTestProvider cfg) { - StratificationStates stratificationStates = new StratificationStates(cfg.getStateSpaceList()); + public void testLeafCount(StratificationStatesTestProvider cfg) { + final StratificationStates stratificationStates = new StratificationStates(cfg.getStateSpaceList()); Assert.assertEquals(stratificationStates.getNStates(), cfg.nStates); @@ -138,20 +149,55 @@ public class StratificationStatesUnitTest extends BaseTest { nLeafs++; } Assert.assertEquals(nLeafs, cfg.nStates, "Unexpected number of leaves"); - - Set seenKeys = new HashSet(cfg.nStates); + } + + @Test(dataProvider = "StratificationStatesTestProvider") + public void testKeys(StratificationStatesTestProvider cfg) { + final StratificationStates stratificationStates = new StratificationStates(cfg.getStateSpaceList()); + final Set seenKeys = new HashSet(cfg.nStates); for ( final StratNode node : stratificationStates.getRoot() ) { if ( node.isLeaf() ) { Assert.assertFalse(seenKeys.contains(node.getKey()), "Already seen the key"); seenKeys.add(node.getKey()); } } + } - seenKeys.clear(); - for ( List state : cfg.getAllCombinations() ) { + @Test(dataProvider = "StratificationStatesTestProvider") + public void testFindSingleKeys(StratificationStatesTestProvider cfg) { + final StratificationStates stratificationStates = new StratificationStates(cfg.getStateSpaceList()); + final Set seenKeys = new HashSet(cfg.nStates); + for ( List state : cfg.getAllCombinations() ) { final int key = stratificationStates.getKey(state); Assert.assertFalse(seenKeys.contains(key), "Already saw state mapping to this key"); seenKeys.add(key); } } + + @Test(dataProvider = "StratificationStatesTestProvider") + public void testFindMultipleKeys(StratificationStatesTestProvider cfg) { + final StratificationStates stratificationStates = new StratificationStates(cfg.getStateSpaceList()); + final List> states = new ArrayList>(cfg.allStates); + final Set keys = stratificationStates.getKeys(states); + Assert.assertEquals(keys.size(), cfg.nStates, "Find all states didn't find all of the expected unique keys"); + + final Queue> combinations = cfg.getAllCombinations(); + while ( ! combinations.isEmpty() ) { + List first = combinations.poll(); + List second = combinations.peek(); + if ( second != null ) { + List> combined = StratificationStates.combineStates(first, second); + int nExpectedKeys = Utils.nCombinations(combined); + + final int key1 = stratificationStates.getKey(first); + final int key2 = stratificationStates.getKey(second); + final Set keysCombined = stratificationStates.getKeys(combined); + + Assert.assertTrue(keysCombined.contains(key1), "couldn't find key in data set"); + Assert.assertTrue(keysCombined.contains(key2), "couldn't find key in data set"); + + Assert.assertEquals(keysCombined.size(), nExpectedKeys); + } + } + } } \ No newline at end of file