HaplotypeCaller no longer creates haplotypes that involve cycles in the SeqGraph

-- The kbest paths algorithm now takes an explicit set of starting and ending vertices, which is conceptually cleaner and works for either the cycle or no-cycle models.  Allowing cycles can be re-enabled with an HC command line switch.
This commit is contained in:
Mark DePristo 2013-04-04 15:15:10 -04:00
parent 5545c629f5
commit 9c7a35f73f
6 changed files with 183 additions and 69 deletions

View File

@ -95,22 +95,25 @@ public class DeBruijnAssembler extends LocalAssemblyEngine {
private final boolean debug;
private final boolean debugGraphTransformations;
private final int minKmer;
private final boolean allowCyclesInKmerGraphToGeneratePaths;
private final int onlyBuildKmersOfThisSizeWhenDebuggingGraphAlgorithms;
protected DeBruijnAssembler() {
this(false, -1, 11);
this(false, -1, 11, false);
}
public DeBruijnAssembler(final boolean debug,
final int debugGraphTransformations,
final int minKmer) {
final int minKmer,
final boolean allowCyclesInKmerGraphToGeneratePaths) {
super();
this.debug = debug;
this.debugGraphTransformations = debugGraphTransformations > 0;
this.onlyBuildKmersOfThisSizeWhenDebuggingGraphAlgorithms = debugGraphTransformations;
this.minKmer = minKmer;
this.allowCyclesInKmerGraphToGeneratePaths = allowCyclesInKmerGraphToGeneratePaths;
}
/**
@ -388,7 +391,12 @@ public class DeBruijnAssembler extends LocalAssemblyEngine {
}
for( final SeqGraph graph : graphs ) {
for ( final Path<SeqVertex> path : new KBestPaths<SeqVertex>().getKBestPaths(graph, NUM_BEST_PATHS_PER_KMER_GRAPH) ) {
final SeqVertex source = graph.getReferenceSourceVertex();
final SeqVertex sink = graph.getReferenceSinkVertex();
if ( source == null || sink == null ) throw new IllegalArgumentException("Both source and sink cannot be null but got " + source + " and sink " + sink + " for graph "+ graph);
final KBestPaths<SeqVertex> pathFinder = new KBestPaths<SeqVertex>(allowCyclesInKmerGraphToGeneratePaths);
for ( final Path<SeqVertex> path : pathFinder.getKBestPaths(graph, NUM_BEST_PATHS_PER_KMER_GRAPH, source, sink) ) {
// logger.info("Found path " + path);
Haplotype h = new Haplotype( path.getBases() );
if( !returnHaplotypes.contains(h) ) {

View File

@ -314,6 +314,11 @@ public class HaplotypeCaller extends ActiveRegionWalker<Integer, Integer> implem
@Argument(fullName="trimActiveRegions", shortName="trimActiveRegions", doc="If specified, we will trim down the active region from the full region (active + extension) to just the active interval for genotyping", required = false)
protected boolean trimActiveRegions = false;
// Hidden command-line switch; its value is handed to the DeBruijnAssembler constructor when the
// assembly engine is set up. Defaults to false.
@Hidden
@Argument(fullName="allowCyclesInKmerGraphToGeneratePaths", shortName="allowCyclesInKmerGraphToGeneratePaths", doc="If specified, we will allow cycles in the kmer graphs to generate paths with multiple copies of the path sequence rather than just the shortest paths", required = false)
protected boolean allowCyclesInKmerGraphToGeneratePaths = false;
// the UG engines
private UnifiedGenotyperEngine UG_engine = null;
private UnifiedGenotyperEngine UG_engine_simple_genotyper = null;
@ -424,7 +429,7 @@ public class HaplotypeCaller extends ActiveRegionWalker<Integer, Integer> implem
}
// setup the assembler
assemblyEngine = new DeBruijnAssembler( DEBUG, debugGraphTransformations, minKmer);
assemblyEngine = new DeBruijnAssembler(DEBUG, debugGraphTransformations, minKmer, allowCyclesInKmerGraphToGeneratePaths);
assemblyEngine.setErrorCorrectKmers(errorCorrectKmers);
assemblyEngine.setPruneFactor(MIN_PRUNE_FACTOR);
if ( graphWriter != null ) assemblyEngine.setGraphWriter(graphWriter);

View File

@ -137,6 +137,30 @@ public class BaseGraph<T extends BaseVertex> extends DefaultDirectedGraph<T, Bas
return outDegreeOf(v) == 0;
}
/**
 * Collect every vertex of this graph for which {@code isSource} returns true.
 *
 * The result is a fresh LinkedHashSet, so its iteration order follows the order
 * in which {@code vertexSet()} yields the vertices.
 *
 * @return a non-null (possibly empty) set of source vertices
 */
public Set<T> getSources() {
    final Set<T> sources = new LinkedHashSet<T>();
    for ( final T candidate : vertexSet() ) {
        if ( ! isSource(candidate) ) {
            continue;
        }
        sources.add(candidate);
    }
    return sources;
}
/**
 * Collect every vertex of this graph for which {@code isSink} returns true.
 *
 * The result is a fresh LinkedHashSet, so its iteration order follows the order
 * in which {@code vertexSet()} yields the vertices.
 *
 * @return a non-null (possibly empty) set of sink vertices
 */
public Set<T> getSinks() {
    final Set<T> sinks = new LinkedHashSet<T>();
    for ( final T candidate : vertexSet() ) {
        if ( ! isSink(candidate) ) {
            continue;
        }
        sinks.add(candidate);
    }
    return sinks;
}
/**
* Pull out the additional sequence implied by traversing this node in the graph
* @param v the vertex from which to pull out the additional base sequence

View File

@ -50,10 +50,7 @@ import com.google.common.collect.MinMaxPriorityQueue;
import com.google.java.contract.Ensures;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.List;
import java.util.*;
/**
* Class for finding the K best paths (as determined by the sum of multiplicities of the edges) in a graph.
@ -63,7 +60,23 @@ import java.util.List;
* Date: Mar 23, 2011
*/
public class KBestPaths<T extends BaseVertex> {
public KBestPaths() { }
// Search policy flag, fixed at construction time.
// When true the search only refuses to re-use an edge already on the current path;
// when false it refuses to step onto any vertex already on the current path.
private final boolean allowCycles;
/**
* Create a new KBestPaths finder that follows cycles in the graph.
*
* Equivalent to {@code new KBestPaths<T>(true)}.
*/
public KBestPaths() {
this(true);
}
/**
* Create a new KBestPaths finder with an explicit cycle policy.
*
* @param allowCycles should we allow paths that follow cycles in the graph? If false,
*                    a path is never extended onto a vertex it already contains.
*/
public KBestPaths(final boolean allowCycles) {
this.allowCycles = allowCycles;
}
protected static class MyInt { public int val = 0; }
@ -78,31 +91,61 @@ public class KBestPaths<T extends BaseVertex> {
}
/**
 * Find the best paths in the graph, keeping at most 1000 of them.
 *
 * @param graph the graph from which to pull paths
 * @return the result of {@link #getKBestPaths(BaseGraph, int)} called with k = 1000
 * @see #getKBestPaths(BaseGraph, int)
 */
public List<Path<T>> getKBestPaths( final BaseGraph<T> graph ) {
    final int maxPathsToKeep = 1000;
    return getKBestPaths(graph, maxPathsToKeep);
}
/**
 * Traverse the graph and pull out the best k paths, starting from every vertex in
 * {@code graph.getSources()} and ending at any vertex in {@code graph.getSinks()}.
 *
 * Paths are scored via their comparator function. The default being PathComparatorTotalScore()
 *
 * @param graph the graph from which to pull paths
 * @param k the number of paths to find
 * @return a list with at most k top-scoring paths from the graph
 * @see #getKBestPaths(BaseGraph, int, java.util.Set, java.util.Set)
 */
@Ensures({"result != null", "result.size() <= k"})
public List<Path<T>> getKBestPaths( final BaseGraph<T> graph, final int k ) {
    final Set<T> allSources = graph.getSources();
    final Set<T> allSinks = graph.getSinks();
    return getKBestPaths(graph, k, allSources, allSinks);
}
/**
 * Find the best paths between the given start and end vertices, keeping at most 1000 of them.
 *
 * @param graph the graph from which to pull paths
 * @param sources a set of vertices we want to start paths with
 * @param sinks a set of vertices we want to end paths with
 * @return the result of the four-argument overload called with k = 1000
 * @see #getKBestPaths(BaseGraph, int, java.util.Set, java.util.Set)
 */
public List<Path<T>> getKBestPaths( final BaseGraph<T> graph, final Set<T> sources, final Set<T> sinks ) {
    final int maxPathsToKeep = 1000;
    return getKBestPaths(graph, maxPathsToKeep, sources, sinks);
}
/**
 * Find the best paths from a single start vertex to a single end vertex,
 * keeping at most 1000 of them.
 *
 * @param graph the graph from which to pull paths
 * @param source the vertex every path must start with
 * @param sink the vertex every path must end with
 * @return the result of {@link #getKBestPaths(BaseGraph, int, BaseVertex, BaseVertex)} called with k = 1000
 */
public List<Path<T>> getKBestPaths( final BaseGraph<T> graph, final T source, final T sink ) {
    final int maxPathsToKeep = 1000;
    return getKBestPaths(graph, maxPathsToKeep, source, sink);
}
/**
 * Find the best k paths from a single start vertex to a single end vertex.
 *
 * Wraps {@code source} and {@code sink} in singleton sets and hands them to the set-based overload.
 *
 * @param graph the graph from which to pull paths
 * @param k the number of paths to find
 * @param source the vertex every path must start with
 * @param sink the vertex every path must end with
 * @return the result of the set-based overload for the two singleton sets
 * @see #getKBestPaths(BaseGraph, int, java.util.Set, java.util.Set)
 */
public List<Path<T>> getKBestPaths( final BaseGraph<T> graph, final int k, final T source, final T sink ) {
    final Set<T> onlySource = Collections.singleton(source);
    final Set<T> onlySink = Collections.singleton(sink);
    return getKBestPaths(graph, k, onlySource, onlySink);
}
/**
* Traverse the graph and pull out the best k paths.
* Paths are scored via their comparator function. The default being PathComparatorTotalScore()
* @param graph the graph from which to pull paths
* @param k the number of paths to find
* @param sources a set of vertices we want to start paths with
* @param sinks a set of vertices we want to end paths with
* @return a list with at most k top-scoring paths from the graph
*/
@Ensures({"result != null", "result.size() <= k"})
public List<Path<T>> getKBestPaths( final BaseGraph<T> graph, final int k, final Set<T> sources, final Set<T> sinks ) {
if( graph == null ) { throw new IllegalArgumentException("Attempting to traverse a null graph."); }
// a min max queue that will collect the best k paths
final MinMaxPriorityQueue<Path<T>> bestPaths = MinMaxPriorityQueue.orderedBy(new PathComparatorTotalScore()).maximumSize(k).create();
// run a DFS for best paths
for ( final T v : graph.vertexSet() ) {
if ( graph.inDegreeOf(v) == 0 ) {
findBestPaths(new Path<T>(v, graph), bestPaths, new MyInt());
}
for ( final T source : sources ) {
final Path<T> startingPath = new Path<T>(source, graph);
findBestPaths(startingPath, sinks, bestPaths, new MyInt());
}
// the MinMaxPriorityQueue iterator returns items in an arbitrary order, so we need to sort the final result
@ -111,9 +154,15 @@ public class KBestPaths<T extends BaseVertex> {
return toReturn;
}
private void findBestPaths( final Path<T> path, final MinMaxPriorityQueue<Path<T>> bestPaths, final MyInt n ) {
// did we hit the end of a path?
if ( allOutgoingEdgesHaveBeenVisited(path) ) {
/**
* Recursive algorithm to find the K best paths in the graph from the current path to any of the sinks
* @param path the current path progress
* @param sinks a set of nodes that are sinks. Will terminate and add a path if the last vertex of path is in this set
* @param bestPaths a collection that accumulates the completed paths.
* @param n used to limit the search by tracking the number of vertices visited across all paths
*/
private void findBestPaths( final Path<T> path, final Set<T> sinks, final Collection<Path<T>> bestPaths, final MyInt n ) {
if ( sinks.contains(path.getLastVertex())) {
bestPaths.add(path);
} else if( n.val > 10000 ) {
// do nothing, just return, as we've done too much work already
@ -122,31 +171,15 @@ public class KBestPaths<T extends BaseVertex> {
final ArrayList<BaseEdge> edgeArrayList = new ArrayList<BaseEdge>(path.getOutgoingEdgesOfLastVertex());
Collections.sort(edgeArrayList, new BaseEdge.EdgeWeightComparator());
for ( final BaseEdge edge : edgeArrayList ) {
final T target = path.getGraph().getEdgeTarget(edge);
// make sure the edge is not already in the path
if ( path.containsEdge(edge) )
continue;
final Path<T> newPath = new Path<T>(path, edge);
n.val++;
findBestPaths(newPath, bestPaths, n);
final boolean alreadyVisited = allowCycles ? path.containsEdge(edge) : path.containsVertex(target);
if ( ! alreadyVisited ) {
final Path<T> newPath = new Path<T>(path, edge);
n.val++;
findBestPaths(newPath, sinks, bestPaths, n);
}
}
}
}
/**
* Have all of the outgoing edges of the final vertex been visited?
*
* I.e., are all outgoing vertices of the current path in the list of edges of the graph?
*
* @param path the path to test
* @return true if all the outgoing edges at the end of this path have already been visited
*/
private boolean allOutgoingEdgesHaveBeenVisited( final Path<T> path ) {
for( final BaseEdge edge : path.getOutgoingEdgesOfLastVertex() ) {
if( !path.containsEdge(edge) ) { // TODO -- investigate allowing numInPath < 2 to allow cycles
return false;
}
}
return true;
}
}
}

View File

@ -148,6 +148,19 @@ public class Path<T extends BaseVertex> {
return edgesAsSet.contains(edge);
}
/**
 * Test whether a vertex is part of this path.
 *
 * @param v the vertex to look for; must not be null
 * @return true if v is among the vertices reported by {@code getVertices()}, false otherwise
 * @throws IllegalArgumentException if v is null
 */
public boolean containsVertex(final T v) {
    if ( v == null ) {
        throw new IllegalArgumentException("Vertex cannot be null");
    }
    // TODO -- warning: this lookup is expensive. Need to do vertex caching
    final boolean found = getVertices().contains(v);
    return found;
}
/**
* Check that two paths have the same edges and total score
* @param path the other path we might be the same as

View File

@ -55,10 +55,7 @@ import org.testng.Assert;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
import java.util.*;
/**
* Created with IntelliJ IDEA.
@ -70,15 +67,13 @@ public class KBestPathsUnitTest {
@DataProvider(name = "BasicPathFindingData")
public Object[][] makeBasicPathFindingData() {
List<Object[]> tests = new ArrayList<Object[]>();
// for ( final int nStartNodes : Arrays.asList(1) ) {
// for ( final int nBranchesPerBubble : Arrays.asList(2) ) {
// for ( final int nEndNodes : Arrays.asList(1) ) {
// for ( final boolean addCycle : Arrays.asList(true) ) {
for ( final int nStartNodes : Arrays.asList(1, 2, 3) ) {
for ( final int nBranchesPerBubble : Arrays.asList(2, 3) ) {
for ( final int nEndNodes : Arrays.asList(1, 2, 3) ) {
for ( final boolean addCycle : Arrays.asList(true, false) ) {
tests.add(new Object[]{nStartNodes, nBranchesPerBubble, nEndNodes, addCycle});
for ( final boolean allowCycles : Arrays.asList(false, true)) {
for ( final int nStartNodes : Arrays.asList(1, 2, 3) ) {
for ( final int nBranchesPerBubble : Arrays.asList(2, 3) ) {
for ( final int nEndNodes : Arrays.asList(1, 2, 3) ) {
for ( final boolean addCycle : Arrays.asList(true, false) ) {
tests.add(new Object[]{nStartNodes, nBranchesPerBubble, nEndNodes, addCycle, allowCycles});
}
}
}
}
@ -88,9 +83,9 @@ public class KBestPathsUnitTest {
}
private static int weight = 1;
final List<SeqVertex> createVertices(final SeqGraph graph, final int n, final SeqVertex source, final SeqVertex target) {
final Set<SeqVertex> createVertices(final SeqGraph graph, final int n, final SeqVertex source, final SeqVertex target) {
final List<String> seqs = Arrays.asList("A", "C", "G", "T");
final List<SeqVertex> vertices = new LinkedList<SeqVertex>();
final Set<SeqVertex> vertices = new LinkedHashSet<SeqVertex>();
for ( int i = 0; i < n; i++ ) {
final SeqVertex v = new SeqVertex(seqs.get(i));
graph.addVertex(v);
@ -102,22 +97,22 @@ public class KBestPathsUnitTest {
}
@Test(dataProvider = "BasicPathFindingData", enabled = true)
public void testBasicPathFinding(final int nStartNodes, final int nBranchesPerBubble, final int nEndNodes, final boolean addCycle) {
public void testBasicPathFinding(final int nStartNodes, final int nBranchesPerBubble, final int nEndNodes, final boolean addCycle, final boolean allowCycles) {
SeqGraph graph = new SeqGraph();
final SeqVertex middleTop = new SeqVertex("GTAC");
final SeqVertex middleBottom = new SeqVertex("ACTG");
graph.addVertices(middleTop, middleBottom);
final List<SeqVertex> starts = createVertices(graph, nStartNodes, null, middleTop);
final List<SeqVertex> bubbles = createVertices(graph, nBranchesPerBubble, middleTop, middleBottom);
final List<SeqVertex> ends = createVertices(graph, nEndNodes, middleBottom, null);
final Set<SeqVertex> starts = createVertices(graph, nStartNodes, null, middleTop);
final Set<SeqVertex> bubbles = createVertices(graph, nBranchesPerBubble, middleTop, middleBottom);
final Set<SeqVertex> ends = createVertices(graph, nEndNodes, middleBottom, null);
if ( addCycle ) graph.addEdge(middleBottom, middleBottom);
// enumerate all possible paths
final List<Path<SeqVertex>> paths = new KBestPaths<SeqVertex>().getKBestPaths(graph);
final List<Path<SeqVertex>> paths = new KBestPaths<SeqVertex>(allowCycles).getKBestPaths(graph, starts, ends);
final int expectedNumOfPaths = nStartNodes * nBranchesPerBubble * (addCycle ? 2 : 1) * nEndNodes;
final int expectedNumOfPaths = nStartNodes * nBranchesPerBubble * (addCycle && allowCycles ? 2 : 1) * nEndNodes;
Assert.assertEquals(paths.size(), expectedNumOfPaths, "Didn't find the expected number of paths");
int lastScore = Integer.MAX_VALUE;
@ -128,11 +123,47 @@ public class KBestPathsUnitTest {
// get the best path, and make sure it's the same as our optimal path overall
final Path best = paths.get(0);
final List<Path<SeqVertex>> justOne = new KBestPaths<SeqVertex>().getKBestPaths(graph, 1);
final List<Path<SeqVertex>> justOne = new KBestPaths<SeqVertex>(allowCycles).getKBestPaths(graph, 1, starts, ends);
Assert.assertEquals(justOne.size(), 1);
Assert.assertTrue(justOne.get(0).pathsAreTheSame(best), "Best path from complete enumerate " + best + " not the same as from k = 1 search " + justOne.get(0));
}
/**
 * With cycles disallowed, a graph containing both a self-loop and a back edge
 * must still produce exactly one path from the first vertex to the last.
 */
@Test
public void testPathFindingComplexCycle() {
    final SeqGraph graph = new SeqGraph();

    final SeqVertex v1 = new SeqVertex("A");
    final SeqVertex v2 = new SeqVertex("C");
    final SeqVertex v3 = new SeqVertex("G");
    final SeqVertex v4 = new SeqVertex("T");
    final SeqVertex v5 = new SeqVertex("AA");
    graph.addVertices(v1, v2, v3, v4, v5);

    // presumably links consecutive arguments into a chain v1 -> v2 -> v3 -> v4 -> v5
    graph.addEdges(v1, v2, v3, v4, v5);
    // self-loop on v3
    graph.addEdges(v3, v3);
    // back edge from v4 to v2
    graph.addEdges(v4, v2);

    // enumerate all possible paths without following cycles
    final KBestPaths<SeqVertex> acyclicFinder = new KBestPaths<SeqVertex>(false);
    final List<Path<SeqVertex>> found = acyclicFinder.getKBestPaths(graph, v1, v5);
    Assert.assertEquals(found.size(), 1, "Didn't find the expected number of paths");
}
/**
 * With cycles disallowed, a self-loop on the requested end vertex must not
 * produce additional paths: exactly one path from v1 to v3 is expected.
 */
@Test
public void testPathFindingCycleLastNode() {
    final SeqGraph graph = new SeqGraph();

    final SeqVertex v1 = new SeqVertex("A");
    final SeqVertex v2 = new SeqVertex("C");
    final SeqVertex v3 = new SeqVertex("G");
    graph.addVertices(v1, v2, v3);

    // presumably links consecutive arguments: v1 -> v2 -> v3, plus a self-loop v3 -> v3
    graph.addEdges(v1, v2, v3, v3);

    // enumerate all possible paths without following cycles
    final KBestPaths<SeqVertex> acyclicFinder = new KBestPaths<SeqVertex>(false);
    final List<Path<SeqVertex>> found = acyclicFinder.getKBestPaths(graph, v1, v3);
    Assert.assertEquals(found.size(), 1, "Didn't find the expected number of paths");
}
@DataProvider(name = "BasicBubbleDataProvider")
public Object[][] makeBasicBubbleDataProvider() {
List<Object[]> tests = new ArrayList<Object[]>();