From 9c7a35f73fe5bcff75abbad10af3065bf589e381 Mon Sep 17 00:00:00 2001 From: Mark DePristo Date: Thu, 4 Apr 2013 15:15:10 -0400 Subject: [PATCH] HaplotypeCaller no longer creates haplotypes that involve cycles in the SeqGraph -- The kbest paths algorithm now takes an explicit set of starting and ending vertices, which is conceptually cleaner and works for either the cycle or no-cycle models. Allowing cycles can be re-enabled with an HC command line switch. --- .../haplotypecaller/DeBruijnAssembler.java | 14 ++- .../haplotypecaller/HaplotypeCaller.java | 7 +- .../haplotypecaller/graphs/BaseGraph.java | 24 ++++ .../haplotypecaller/graphs/KBestPaths.java | 119 +++++++++++------- .../walkers/haplotypecaller/graphs/Path.java | 13 ++ .../graphs/KBestPathsUnitTest.java | 75 +++++++---- 6 files changed, 183 insertions(+), 69 deletions(-) diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/DeBruijnAssembler.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/DeBruijnAssembler.java index 40a6a79e0..11701a73b 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/DeBruijnAssembler.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/DeBruijnAssembler.java @@ -95,22 +95,25 @@ public class DeBruijnAssembler extends LocalAssemblyEngine { private final boolean debug; private final boolean debugGraphTransformations; private final int minKmer; + private final boolean allowCyclesInKmerGraphToGeneratePaths; private final int onlyBuildKmersOfThisSizeWhenDebuggingGraphAlgorithms; protected DeBruijnAssembler() { - this(false, -1, 11); + this(false, -1, 11, false); } public DeBruijnAssembler(final boolean debug, final int debugGraphTransformations, - final int minKmer) { + final int minKmer, + final boolean allowCyclesInKmerGraphToGeneratePaths) { super(); this.debug = debug; this.debugGraphTransformations = debugGraphTransformations > 0; 
this.onlyBuildKmersOfThisSizeWhenDebuggingGraphAlgorithms = debugGraphTransformations; this.minKmer = minKmer; + this.allowCyclesInKmerGraphToGeneratePaths = allowCyclesInKmerGraphToGeneratePaths; } /** @@ -388,7 +391,12 @@ public class DeBruijnAssembler extends LocalAssemblyEngine { } for( final SeqGraph graph : graphs ) { - for ( final Path path : new KBestPaths().getKBestPaths(graph, NUM_BEST_PATHS_PER_KMER_GRAPH) ) { + final SeqVertex source = graph.getReferenceSourceVertex(); + final SeqVertex sink = graph.getReferenceSinkVertex(); + if ( source == null || sink == null ) throw new IllegalArgumentException("Both source and sink cannot be null but got " + source + " and sink " + sink + " for graph "+ graph); + + final KBestPaths pathFinder = new KBestPaths(allowCyclesInKmerGraphToGeneratePaths); + for ( final Path path : pathFinder.getKBestPaths(graph, NUM_BEST_PATHS_PER_KMER_GRAPH, source, sink) ) { // logger.info("Found path " + path); Haplotype h = new Haplotype( path.getBases() ); if( !returnHaplotypes.contains(h) ) { diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCaller.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCaller.java index bce179ee1..80276f7be 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCaller.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCaller.java @@ -314,6 +314,11 @@ public class HaplotypeCaller extends ActiveRegionWalker implem @Argument(fullName="trimActiveRegions", shortName="trimActiveRegions", doc="If specified, we will trim down the active region from the full region (active + extension) to just the active interval for genotyping", required = false) protected boolean trimActiveRegions = false; + @Hidden + @Argument(fullName="allowCyclesInKmerGraphToGeneratePaths", shortName="allowCyclesInKmerGraphToGeneratePaths", doc="If specified, we will allow 
cycles in the kmer graphs to generate paths with multiple copies of the path sequenece rather than just the shortest paths", required = false) + protected boolean allowCyclesInKmerGraphToGeneratePaths = false; + + // the UG engines private UnifiedGenotyperEngine UG_engine = null; private UnifiedGenotyperEngine UG_engine_simple_genotyper = null; @@ -424,7 +429,7 @@ public class HaplotypeCaller extends ActiveRegionWalker implem } // setup the assembler - assemblyEngine = new DeBruijnAssembler( DEBUG, debugGraphTransformations, minKmer); + assemblyEngine = new DeBruijnAssembler(DEBUG, debugGraphTransformations, minKmer, allowCyclesInKmerGraphToGeneratePaths); assemblyEngine.setErrorCorrectKmers(errorCorrectKmers); assemblyEngine.setPruneFactor(MIN_PRUNE_FACTOR); if ( graphWriter != null ) assemblyEngine.setGraphWriter(graphWriter); diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/BaseGraph.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/BaseGraph.java index 5d591fd5c..7ce57e2e7 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/BaseGraph.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/BaseGraph.java @@ -137,6 +137,30 @@ public class BaseGraph extends DefaultDirectedGraph getSources() { + final Set set = new LinkedHashSet(); + for ( final T v : vertexSet() ) + if ( isSource(v) ) + set.add(v); + return set; + } + + /** + * Get the set of sink vertices of this graph + * @return a non-null set + */ + public Set getSinks() { + final Set set = new LinkedHashSet(); + for ( final T v : vertexSet() ) + if ( isSink(v) ) + set.add(v); + return set; + } + /** * Pull out the additional sequence implied by traversing this node in the graph * @param v the vertex from which to pull out the additional base sequence diff --git 
a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/KBestPaths.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/KBestPaths.java index 1dc712c67..466148588 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/KBestPaths.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/KBestPaths.java @@ -50,10 +50,7 @@ import com.google.common.collect.MinMaxPriorityQueue; import com.google.java.contract.Ensures; import java.io.Serializable; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Comparator; -import java.util.List; +import java.util.*; /** * Class for finding the K best paths (as determined by the sum of multiplicities of the edges) in a graph. @@ -63,7 +60,23 @@ import java.util.List; * Date: Mar 23, 2011 */ public class KBestPaths { - public KBestPaths() { } + private final boolean allowCycles; + + /** + * Create a new KBestPaths finder that follows cycles in the graph + */ + public KBestPaths() { + this(true); + } + + /** + * Create a new KBestPaths finder + * + * @param allowCycles should we allow paths that follow cycles in the graph? + */ + public KBestPaths(final boolean allowCycles) { + this.allowCycles = allowCycles; + } protected static class MyInt { public int val = 0; } @@ -78,31 +91,61 @@ public class KBestPaths { } /** - * @see #getKBestPaths(BaseGraph, int) retriving the first 1000 paths + * @see #getKBestPaths(BaseGraph, int) retrieving the best 1000 paths */ public List> getKBestPaths( final BaseGraph graph ) { return getKBestPaths(graph, 1000); } /** - * Traverse the graph and pull out the best k paths. - * Paths are scored via their comparator function.
The default being PathComparatorTotalScore() - * @param graph the graph from which to pull paths - * @param k the number of paths to find - * @return a list with at most k top-scoring paths from the graph + * @see #getKBestPaths(BaseGraph, int, java.util.Set, java.util.Set) retrieving the best k paths + * starting from all source vertices and ending with all sink vertices */ - @Ensures({"result != null", "result.size() <= k"}) public List> getKBestPaths( final BaseGraph graph, final int k ) { + return getKBestPaths(graph, k, graph.getSources(), graph.getSinks()); + } + + /** + * @see #getKBestPaths(BaseGraph, int, java.util.Set, java.util.Set) with k=1000 + */ + public List> getKBestPaths( final BaseGraph graph, final Set sources, final Set sinks ) { + return getKBestPaths(graph, 1000, sources, sinks); + } + + /** + * @see #getKBestPaths(BaseGraph, int, java.util.Set, java.util.Set) with k=1000 + */ + public List> getKBestPaths( final BaseGraph graph, final T source, final T sink ) { + return getKBestPaths(graph, 1000, source, sink); + } + + /** + * @see #getKBestPaths(BaseGraph, int, java.util.Set, java.util.Set) with singleton source and sink sets + */ + public List> getKBestPaths( final BaseGraph graph, final int k, final T source, final T sink ) { + return getKBestPaths(graph, k, Collections.singleton(source), Collections.singleton(sink)); + } + + /** + * Traverse the graph and pull out the best k paths. + * Paths are scored via their comparator function.
The default being PathComparatorTotalScore() + * @param graph the graph from which to pull paths + * @param k the number of paths to find + * @param sources a set of vertices we want to start paths with + * @param sinks a set of vertices we want to end paths with + * @return a list with at most k top-scoring paths from the graph + */ + @Ensures({"result != null", "result.size() <= k"}) + public List> getKBestPaths( final BaseGraph graph, final int k, final Set sources, final Set sinks ) { if( graph == null ) { throw new IllegalArgumentException("Attempting to traverse a null graph."); } // a min max queue that will collect the best k paths final MinMaxPriorityQueue> bestPaths = MinMaxPriorityQueue.orderedBy(new PathComparatorTotalScore()).maximumSize(k).create(); // run a DFS for best paths - for ( final T v : graph.vertexSet() ) { - if ( graph.inDegreeOf(v) == 0 ) { - findBestPaths(new Path(v, graph), bestPaths, new MyInt()); - } + for ( final T source : sources ) { + final Path startingPath = new Path(source, graph); + findBestPaths(startingPath, sinks, bestPaths, new MyInt()); } // the MinMaxPriorityQueue iterator returns items in an arbitrary order, so we need to sort the final result @@ -111,9 +154,15 @@ public class KBestPaths { return toReturn; } - private void findBestPaths( final Path path, final MinMaxPriorityQueue> bestPaths, final MyInt n ) { - // did we hit the end of a path? - if ( allOutgoingEdgesHaveBeenVisited(path) ) { + /** + * Recursive algorithm to find the K best paths in the graph from the current path to any of the sinks + * @param path the current path progress + * @param sinks a set of nodes that are sinks. Will terminate and add a path if the last vertex of path is in this set + * @param bestPaths a collection that accumulates completed paths.
+ * @param n used to limit the search by tracking the number of vertices visited across all paths + */ + private void findBestPaths( final Path path, final Set sinks, final Collection> bestPaths, final MyInt n ) { + if ( sinks.contains(path.getLastVertex())) { bestPaths.add(path); } else if( n.val > 10000 ) { // do nothing, just return, as we've done too much work already @@ -122,31 +171,15 @@ public class KBestPaths { final ArrayList edgeArrayList = new ArrayList(path.getOutgoingEdgesOfLastVertex()); Collections.sort(edgeArrayList, new BaseEdge.EdgeWeightComparator()); for ( final BaseEdge edge : edgeArrayList ) { + final T target = path.getGraph().getEdgeTarget(edge); // make sure the edge is not already in the path - if ( path.containsEdge(edge) ) - continue; - - final Path newPath = new Path(path, edge); - n.val++; - findBestPaths(newPath, bestPaths, n); + final boolean alreadyVisited = allowCycles ? path.containsEdge(edge) : path.containsVertex(target); + if ( ! alreadyVisited ) { + final Path newPath = new Path(path, edge); + n.val++; + findBestPaths(newPath, sinks, bestPaths, n); + } } } } - - /** - * Have all of the outgoing edges of the final vertex been visited? - * - * I.e., are all outgoing vertices of the current path in the list of edges of the graph? 
- * - * @param path the path to test - * @return true if all the outgoing edges at the end of this path have already been visited - */ - private boolean allOutgoingEdgesHaveBeenVisited( final Path path ) { - for( final BaseEdge edge : path.getOutgoingEdgesOfLastVertex() ) { - if( !path.containsEdge(edge) ) { // TODO -- investigate allowing numInPath < 2 to allow cycles - return false; - } - } - return true; - } -} +} \ No newline at end of file diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/Path.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/Path.java index 50ca91d41..252ae3449 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/Path.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/Path.java @@ -148,6 +148,19 @@ public class Path { return edgesAsSet.contains(edge); } + /** + * Does this path contain the given vertex? + * + * @param v a non-null vertex + * @return true if v occurs within this path, false otherwise + */ + public boolean containsVertex(final T v) { + if ( v == null ) throw new IllegalArgumentException("Vertex cannot be null"); + + // TODO -- warning this is expensive.
Need to do vertex caching + return getVertices().contains(v); + } + /** * Check that two paths have the same edges and total score * @param path the other path we might be the same as diff --git a/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/KBestPathsUnitTest.java b/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/KBestPathsUnitTest.java index d20a0f778..3c6327842 100644 --- a/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/KBestPathsUnitTest.java +++ b/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/KBestPathsUnitTest.java @@ -55,10 +55,7 @@ import org.testng.Assert; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.LinkedList; -import java.util.List; +import java.util.*; /** * Created with IntelliJ IDEA. @@ -70,15 +67,13 @@ public class KBestPathsUnitTest { @DataProvider(name = "BasicPathFindingData") public Object[][] makeBasicPathFindingData() { List tests = new ArrayList(); -// for ( final int nStartNodes : Arrays.asList(1) ) { -// for ( final int nBranchesPerBubble : Arrays.asList(2) ) { -// for ( final int nEndNodes : Arrays.asList(1) ) { -// for ( final boolean addCycle : Arrays.asList(true) ) { - for ( final int nStartNodes : Arrays.asList(1, 2, 3) ) { - for ( final int nBranchesPerBubble : Arrays.asList(2, 3) ) { - for ( final int nEndNodes : Arrays.asList(1, 2, 3) ) { - for ( final boolean addCycle : Arrays.asList(true, false) ) { - tests.add(new Object[]{nStartNodes, nBranchesPerBubble, nEndNodes, addCycle}); + for ( final boolean allowCycles : Arrays.asList(false, true)) { + for ( final int nStartNodes : Arrays.asList(1, 2, 3) ) { + for ( final int nBranchesPerBubble : Arrays.asList(2, 3) ) { + for ( final int nEndNodes : Arrays.asList(1, 2, 3) ) { + for ( final boolean addCycle : Arrays.asList(true, false) ) { + 
tests.add(new Object[]{nStartNodes, nBranchesPerBubble, nEndNodes, addCycle, allowCycles}); + } } } } @@ -88,9 +83,9 @@ public class KBestPathsUnitTest { } private static int weight = 1; - final List createVertices(final SeqGraph graph, final int n, final SeqVertex source, final SeqVertex target) { + final Set createVertices(final SeqGraph graph, final int n, final SeqVertex source, final SeqVertex target) { final List seqs = Arrays.asList("A", "C", "G", "T"); - final List vertices = new LinkedList(); + final Set vertices = new LinkedHashSet(); for ( int i = 0; i < n; i++ ) { final SeqVertex v = new SeqVertex(seqs.get(i)); graph.addVertex(v); @@ -102,22 +97,22 @@ public class KBestPathsUnitTest { } @Test(dataProvider = "BasicPathFindingData", enabled = true) - public void testBasicPathFinding(final int nStartNodes, final int nBranchesPerBubble, final int nEndNodes, final boolean addCycle) { + public void testBasicPathFinding(final int nStartNodes, final int nBranchesPerBubble, final int nEndNodes, final boolean addCycle, final boolean allowCycles) { SeqGraph graph = new SeqGraph(); final SeqVertex middleTop = new SeqVertex("GTAC"); final SeqVertex middleBottom = new SeqVertex("ACTG"); graph.addVertices(middleTop, middleBottom); - final List starts = createVertices(graph, nStartNodes, null, middleTop); - final List bubbles = createVertices(graph, nBranchesPerBubble, middleTop, middleBottom); - final List ends = createVertices(graph, nEndNodes, middleBottom, null); + final Set starts = createVertices(graph, nStartNodes, null, middleTop); + final Set bubbles = createVertices(graph, nBranchesPerBubble, middleTop, middleBottom); + final Set ends = createVertices(graph, nEndNodes, middleBottom, null); if ( addCycle ) graph.addEdge(middleBottom, middleBottom); // enumerate all possible paths - final List> paths = new KBestPaths().getKBestPaths(graph); + final List> paths = new KBestPaths(allowCycles).getKBestPaths(graph, starts, ends); - final int expectedNumOfPaths = 
nStartNodes * nBranchesPerBubble * (addCycle ? 2 : 1) * nEndNodes; + final int expectedNumOfPaths = nStartNodes * nBranchesPerBubble * (addCycle && allowCycles ? 2 : 1) * nEndNodes; Assert.assertEquals(paths.size(), expectedNumOfPaths, "Didn't find the expected number of paths"); int lastScore = Integer.MAX_VALUE; @@ -128,11 +123,47 @@ public class KBestPathsUnitTest { // get the best path, and make sure it's the same as our optimal path overall final Path best = paths.get(0); - final List> justOne = new KBestPaths().getKBestPaths(graph, 1); + final List> justOne = new KBestPaths(allowCycles).getKBestPaths(graph, 1, starts, ends); Assert.assertEquals(justOne.size(), 1); Assert.assertTrue(justOne.get(0).pathsAreTheSame(best), "Best path from complete enumerate " + best + " not the same as from k = 1 search " + justOne.get(0)); } + @Test + public void testPathFindingComplexCycle() { + SeqGraph graph = new SeqGraph(); + + final SeqVertex v1 = new SeqVertex("A"); + final SeqVertex v2 = new SeqVertex("C"); + final SeqVertex v3 = new SeqVertex("G"); + final SeqVertex v4 = new SeqVertex("T"); + final SeqVertex v5 = new SeqVertex("AA"); + graph.addVertices(v1, v2, v3, v4, v5); + graph.addEdges(v1, v2, v3, v4, v5); + graph.addEdges(v3, v3); + graph.addEdges(v4, v2); + + // enumerate all possible paths + final List> paths = new KBestPaths(false).getKBestPaths(graph, v1, v5); + + Assert.assertEquals(paths.size(), 1, "Didn't find the expected number of paths"); + } + + @Test + public void testPathFindingCycleLastNode() { + SeqGraph graph = new SeqGraph(); + + final SeqVertex v1 = new SeqVertex("A"); + final SeqVertex v2 = new SeqVertex("C"); + final SeqVertex v3 = new SeqVertex("G"); + graph.addVertices(v1, v2, v3); + graph.addEdges(v1, v2, v3, v3); + + // enumerate all possible paths + final List> paths = new KBestPaths(false).getKBestPaths(graph, v1, v3); + + Assert.assertEquals(paths.size(), 1, "Didn't find the expected number of paths"); + } + @DataProvider(name = 
"BasicBubbleDataProvider") public Object[][] makeBasicBubbleDataProvider() { List tests = new ArrayList();