More efficient ZipLinearChains algorithm

-- Goes through the graph looking for chains to zip, accumulates the vertices of the chains, and then finally go through and updates the graph in one big go.  Vastly more efficient than the previous version, but unfortunately doesn't actually work now
-- Also incorporate edge weight propagation into SeqGraph zipLinearChains.  The edge weights for all incoming and outgoing edges are now their previous value, plus the sum of the internal chain edges / n such edges
This commit is contained in:
Mark DePristo 2013-03-31 14:40:14 -04:00
parent 7105ad65a6
commit 2aac9e2782
2 changed files with 328 additions and 57 deletions

View File

@ -46,10 +46,13 @@
package org.broadinstitute.sting.gatk.walkers.haplotypecaller.graphs;
import org.apache.commons.lang.ArrayUtils;
import com.google.java.contract.Ensures;
import com.google.java.contract.Requires;
import java.io.File;
import java.util.HashSet;
import java.util.LinkedList;
import java.util.List;
import java.util.Set;
/**
@ -58,7 +61,7 @@ import java.util.Set;
* @author: depristo
* @since 03/2013
*/
public class SeqGraph extends BaseGraph<SeqVertex> {
public final class SeqGraph extends BaseGraph<SeqVertex> {
private final static boolean PRINT_SIMPLIFY_GRAPHS = false;
private final static int MIN_SUFFIX_TO_MERGE_TAILS = 5;
@ -118,18 +121,8 @@ public class SeqGraph extends BaseGraph<SeqVertex> {
/**
* Zip up all of the simple linear chains present in this graph.
*/
public boolean zipLinearChains() {
boolean foundOne = false;
while( zipOneLinearChain() ) {
// just keep going until zipOneLinearChain says its done
foundOne = true;
}
return foundOne;
}
/**
* Merge together two vertices in the graph v1 -> v2 into a single vertex v' containing v1 + v2 sequence
*
* Merges together all pairs of vertices in the graph v1 -> v2 into a single vertex v' containing v1 + v2 sequence
*
* Only works on vertices where v1's only outgoing edge is to v2 and v2's only incoming edge is from v1.
*
@ -137,44 +130,153 @@ public class SeqGraph extends BaseGraph<SeqVertex> {
*
* @return true if any such pair of vertices could be found, false otherwise
*/
protected boolean zipOneLinearChain() {
for( final BaseEdge e : edgeSet() ) {
final SeqVertex outgoingVertex = getEdgeTarget(e);
final SeqVertex incomingVertex = getEdgeSource(e);
if( !outgoingVertex.equals(incomingVertex)
&& outDegreeOf(incomingVertex) == 1 && inDegreeOf(outgoingVertex) == 1
&& isReferenceNode(incomingVertex) == isReferenceNode(outgoingVertex) ) {
final Set<BaseEdge> outEdges = outgoingEdgesOf(outgoingVertex);
final Set<BaseEdge> inEdges = incomingEdgesOf(incomingVertex);
final BaseEdge singleOutEdge = outEdges.isEmpty() ? null : outEdges.iterator().next();
final BaseEdge singleInEdge = inEdges.isEmpty() ? null : inEdges.iterator().next();
if( inEdges.size() == 1 && outEdges.size() == 1 ) {
singleInEdge.setMultiplicity( singleInEdge.getMultiplicity() + ( e.getMultiplicity() / 2 ) );
singleOutEdge.setMultiplicity( singleOutEdge.getMultiplicity() + ( e.getMultiplicity() / 2 ) );
} else if( inEdges.size() == 1 ) {
singleInEdge.setMultiplicity( Math.max(singleInEdge.getMultiplicity() + ( e.getMultiplicity() - 1 ), 0) );
} else if( outEdges.size() == 1 ) {
singleOutEdge.setMultiplicity( Math.max( singleOutEdge.getMultiplicity() + ( e.getMultiplicity() - 1 ), 0) );
}
final SeqVertex addedVertex = new SeqVertex( ArrayUtils.addAll(incomingVertex.getSequence(), outgoingVertex.getSequence()) );
addVertex(addedVertex);
for( final BaseEdge edge : outEdges ) {
addEdge(addedVertex, getEdgeTarget(edge), new BaseEdge(edge.isRef(), edge.getMultiplicity()));
}
for( final BaseEdge edge : inEdges ) {
addEdge(getEdgeSource(edge), addedVertex, new BaseEdge(edge.isRef(), edge.getMultiplicity()));
}
removeVertex(incomingVertex);
removeVertex(outgoingVertex);
return true;
}
public boolean zipLinearChains() {
// create the list of start sites [doesn't modify graph yet]
final List<SeqVertex> zipStarts = new LinkedList<SeqVertex>();
for ( final SeqVertex source : vertexSet() ) {
if ( isLinearChainStart(source) )
zipStarts.add(source);
}
return false;
if ( zipStarts.isEmpty() ) // nothing to do, as nothing could start a chain
return false;
// At this point, zipStarts contains all of the vertices in this graph that might start some linear
// chain of vertices. We walk through each start, building up the linear chain of vertices and then
// zipping them up with mergeLinearChain, if possible
boolean mergedOne = false;
for ( final SeqVertex zipStart : zipStarts ) {
final LinkedList<SeqVertex> linearChain = traceLinearChain(zipStart);
// merge the linearized chain, recording if we actually did some useful work
mergedOne |= mergeLinearChain(linearChain);
}
return mergedOne;
}
/**
* Is source vertex potentially a start of a linear chain of vertices?
*
* We are a start of a zip chain if our out degree is 1 and either the
* the vertex has no incoming connections or 2 or more (we must start a chain) or
* we have exactly one incoming vertex and that one has out-degree > 1 (i.e., source's incoming
* vertex couldn't be a start itself
*
* @param source a non-null vertex
* @return true if source might start a linear chain
*/
@Requires("source != null")
private boolean isLinearChainStart(final SeqVertex source) {
return outDegreeOf(source) == 1
&& ( inDegreeOf(source) != 1
|| outDegreeOf(incomingVerticesOf(source).iterator().next()) > 1 );
}
/**
* Get all of the vertices in a linear chain of vertices starting at zipStart
*
* Build a list of vertices (in order) starting from zipStart such that each sequential pair of vertices
* in the chain A and B can be zipped together.
*
* @param zipStart a vertex that starts a linear chain
* @return a list of vertices that comprise a linear chain starting with zipStart. The resulting
* list will always contain at least zipStart as the first element.
*/
@Requires("isLinearChainStart(zipStart)")
@Ensures({"result != null", "result.size() >= 1"})
private LinkedList<SeqVertex> traceLinearChain(final SeqVertex zipStart) {
final LinkedList<SeqVertex> linearChain = new LinkedList<SeqVertex>();
linearChain.add(zipStart);
boolean lastIsRef = isReferenceNode(zipStart); // remember because this calculation is expensive
SeqVertex last = zipStart;
while (true) {
if ( outDegreeOf(last) != 1 )
// cannot extend a chain from last if last has multiple outgoing branches
break;
// there can only be one (outgoing edge of last) by contract
final SeqVertex target = getEdgeTarget(outgoingEdgeOf(last));
if ( inDegreeOf(target) != 1 || last.equals(target) )
// cannot zip up a target that has multiple incoming nodes or that's a cycle to the last node
break;
final boolean targetIsRef = isReferenceNode(target);
if ( lastIsRef != targetIsRef ) // both our isRef states must be equal
break;
linearChain.add(target); // extend our chain by one
// update our last state to be the current state, and continue
last = target;
lastIsRef = targetIsRef;
}
return linearChain;
}
/**
* Merge a linear chain of vertices into a single combined vertex, and update this graph to such that
* the incoming edges into the first element of the linearChain and the outgoing edges from linearChain.getLast()
* all point to this new combined vertex.
*
* @param linearChain a non-empty chain of vertices that can be zipped up into a single vertex
* @return true if we actually merged at least two vertices together
*/
protected boolean mergeLinearChain(final LinkedList<SeqVertex> linearChain) {
if ( linearChain.isEmpty() ) throw new IllegalArgumentException("BUG: cannot have linear chain with 0 elements but got " + linearChain);
final SeqVertex first = linearChain.getFirst();
final SeqVertex last = linearChain.getLast();
if ( first == last ) return false; // only one element in the chain, cannot be extended
// create the combined vertex, and add it to the graph
// TODO -- performance problem -- can be optimized if we want
final List<byte[]> seqs = new LinkedList<byte[]>();
for ( SeqVertex v : linearChain ) seqs.add(v.getSequence());
final byte[] seqsCat = org.broadinstitute.sting.utils.Utils.concat(seqs.toArray(new byte[][]{}));
final SeqVertex addedVertex = new SeqVertex( seqsCat );
addVertex(addedVertex);
final Set<BaseEdge> inEdges = incomingEdgesOf(first);
final Set<BaseEdge> outEdges = outgoingEdgesOf(last);
final int nEdges = inEdges.size() + outEdges.size();
int sharedWeightAmongEdges = nEdges == 0 ? 0 : sumEdgeWeightAlongChain(linearChain) / nEdges;
final BaseEdge inc = new BaseEdge(false, sharedWeightAmongEdges); // template to make .add function call easy
// update the incoming and outgoing edges to point to the new vertex
for( final BaseEdge edge : outEdges ) { addEdge(addedVertex, getEdgeTarget(edge), new BaseEdge(edge).add(inc)); }
for( final BaseEdge edge : inEdges ) { addEdge(getEdgeSource(edge), addedVertex, new BaseEdge(edge).add(inc)); }
removeAllVertices(linearChain);
return true;
}
/**
* Get the sum of the edge weights on a linear chain of at least 2 elements
*
* @param chain a linear chain of vertices with at least 2 vertices
* @return the sum of the multiplicities along all edges connecting vertices within the chain
*/
@Requires({"chain != null", "chain.size() >= 2"})
private int sumEdgeWeightAlongChain(final LinkedList<SeqVertex> chain) {
int sum = 0;
SeqVertex prev = null;
for ( final SeqVertex v : chain ) {
if ( prev != null ) {
final BaseEdge e = getEdge(prev, v);
if ( e == null ) throw new IllegalStateException("Something wrong with the linear chain, got a null edge between " + prev + " and " + v);
sum += e.getMultiplicity();
}
prev = v;
}
return sum;
}
/**
@ -241,7 +343,7 @@ public class SeqGraph extends BaseGraph<SeqVertex> {
protected class MergeDiamonds extends VertexBasedTransformer {
@Override
protected boolean tryToTransform(final SeqVertex top) {
final Set<SeqVertex> middles = outgoingVerticesOf(top);
final List<SeqVertex> middles = outgoingVerticesOf(top);
if ( middles.size() <= 1 )
// we can only merge if there's at least two middle nodes
return false;
@ -295,7 +397,7 @@ public class SeqGraph extends BaseGraph<SeqVertex> {
protected class MergeTails extends VertexBasedTransformer {
@Override
protected boolean tryToTransform(final SeqVertex top) {
final Set<SeqVertex> tails = outgoingVerticesOf(top);
final List<SeqVertex> tails = outgoingVerticesOf(top);
if ( tails.size() <= 1 )
return false;
@ -379,7 +481,7 @@ public class SeqGraph extends BaseGraph<SeqVertex> {
protected class MergeHeadlessIncomingSources extends VertexBasedTransformer {
@Override
boolean tryToTransform(final SeqVertex bottom) {
final Set<SeqVertex> incoming = incomingVerticesOf(bottom);
final List<SeqVertex> incoming = incomingVerticesOf(bottom);
if ( incoming.size() <= 1 )
return false;

View File

@ -51,11 +51,15 @@ import org.testng.Assert;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import java.io.File;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.LinkedList;
import java.util.List;
public class SeqGraphUnitTest extends BaseTest {
private final static boolean DEBUG = true;
private class MergeNodesWithNoVariationTestProvider extends TestDataProvider {
public byte[] sequence;
public int KMER_LENGTH;
@ -98,7 +102,7 @@ public class SeqGraphUnitTest extends BaseTest {
return MergeNodesWithNoVariationTestProvider.getTests(MergeNodesWithNoVariationTestProvider.class);
}
@Test(dataProvider = "MergeNodesWithNoVariationTestProvider", enabled = true)
@Test(dataProvider = "MergeNodesWithNoVariationTestProvider", enabled = !DEBUG)
public void testMergeNodesWithNoVariation(MergeNodesWithNoVariationTestProvider cfg) {
logger.warn(String.format("Test: %s", cfg.toString()));
@ -178,7 +182,7 @@ public class SeqGraphUnitTest extends BaseTest {
return tests.toArray(new Object[][]{});
}
@Test(dataProvider = "IsDiamondData", enabled = true)
@Test(dataProvider = "IsDiamondData", enabled = !DEBUG)
public void testIsDiamond(final SeqGraph graph, final SeqVertex v, final boolean isRootOfDiamond) {
final SeqGraph.MergeDiamonds merger = graph.new MergeDiamonds();
merger.setDontModifyGraphEvenIfPossible();
@ -311,7 +315,7 @@ public class SeqGraphUnitTest extends BaseTest {
return tests.toArray(new Object[][]{});
}
@Test(dataProvider = "MergingData", enabled = true)
@Test(dataProvider = "MergingData", enabled = !DEBUG)
public void testMerging(final SeqGraph graph, final SeqGraph expected) {
final SeqGraph merged = (SeqGraph)graph.clone();
merged.simplifyGraph(1);
@ -333,7 +337,7 @@ public class SeqGraphUnitTest extends BaseTest {
//
// Should become A -> ACT -> C [ref and non-ref edges]
//
@Test
@Test(enabled = !DEBUG)
public void testBubbleSameBasesWithRef() {
final SeqGraph graph = new SeqGraph();
final SeqVertex top = new SeqVertex("A");
@ -351,4 +355,169 @@ public class SeqGraphUnitTest extends BaseTest {
actual.simplifyGraph();
Assert.assertTrue(BaseGraph.graphEquals(actual, expected), "Wrong merging result after complete merging");
}
@DataProvider(name = "LinearZipData")
public Object[][] makeLinearZipData() throws Exception {
List<Object[]> tests = new ArrayList<Object[]>();
SeqGraph graph = new SeqGraph();
SeqGraph expected = new SeqGraph();
// empty graph => empty graph
tests.add(new Object[]{graph.clone(), expected.clone()});
SeqVertex a1 = new SeqVertex("A");
SeqVertex c1 = new SeqVertex("C");
SeqVertex ac1 = new SeqVertex("AC");
// just a single vertex
graph.addVertices(a1, c1);
expected.addVertices(a1, c1);
tests.add(new Object[]{graph.clone(), expected.clone()});
graph.addEdges(a1, c1);
expected = new SeqGraph();
expected.addVertices(ac1);
tests.add(new Object[]{graph.clone(), expected.clone()});
// three long chain merged corrected
SeqVertex g1 = new SeqVertex("G");
graph.addVertices(g1);
graph.addEdges(c1, g1);
expected = new SeqGraph();
expected.addVertex(new SeqVertex("ACG"));
tests.add(new Object[]{graph.clone(), expected.clone()});
// adding something that isn't connected isn't a problem
SeqVertex t1 = new SeqVertex("T");
graph.addVertices(t1);
expected = new SeqGraph();
expected.addVertices(new SeqVertex("ACG"), new SeqVertex("T"));
tests.add(new Object[]{graph.clone(), expected.clone()});
// splitting chain with branch produces the correct zipped subgraphs
final SeqVertex a2 = new SeqVertex("A");
final SeqVertex c2 = new SeqVertex("C");
graph = new SeqGraph();
graph.addVertices(a1, c1, g1, t1, a2, c2);
graph.addEdges(a1, c1, g1, t1, a2);
graph.addEdges(g1, c2);
expected = new SeqGraph();
SeqVertex acg = new SeqVertex("ACG");
SeqVertex ta = new SeqVertex("TA");
expected.addVertices(acg, ta, c2);
expected.addEdges(acg, ta);
expected.addEdges(acg, c2);
tests.add(new Object[]{graph.clone(), expected.clone()});
// Can merge chains with loops in them
{
graph = new SeqGraph();
graph.addVertices(a1, c1, g1);
graph.addEdges(a1, c1, g1);
graph.addEdges(a1, a1);
expected = new SeqGraph();
SeqVertex ac = new SeqVertex("AC");
SeqVertex cg = new SeqVertex("CG");
expected.addVertices(a1, cg);
expected.addEdges(a1, cg);
expected.addEdges(a1, a1);
tests.add(new Object[]{graph.clone(), expected.clone()});
graph.removeEdge(a1, a1);
graph.addEdges(c1, c1);
tests.add(new Object[]{graph.clone(), graph.clone()});
graph.removeEdge(c1, c1);
graph.addEdges(g1, g1);
expected = new SeqGraph();
expected.addVertices(ac, g1);
expected.addEdges(ac, g1, g1);
tests.add(new Object[]{graph.clone(), expected.clone()});
}
// check building n element long chains
{
final List<String> bases = Arrays.asList("A", "C", "G", "T", "TT", "GG", "CC", "AA");
for ( final int len : Arrays.asList(1, 2, 10, 100, 1000)) {
graph = new SeqGraph();
expected = new SeqGraph();
SeqVertex last = null;
String expectedBases = "";
for ( int i = 0; i < len; i++ ) {
final String seq = bases.get(i % bases.size());
expectedBases += seq;
SeqVertex a = new SeqVertex(seq);
graph.addVertex(a);
if ( last != null ) graph.addEdge(last, a);
last = a;
}
expected.addVertex(new SeqVertex(expectedBases));
tests.add(new Object[]{graph.clone(), expected.clone()});
}
}
// check that edge connections are properly maintained
{
int edgeWeight = 1;
for ( final int nIncoming : Arrays.asList(0, 2, 5, 10) ) {
for ( final int nOutgoing : Arrays.asList(0, 2, 5, 10) ) {
graph = new SeqGraph();
expected = new SeqGraph();
graph.addVertices(a1, c1, g1);
graph.addEdges(a1, c1, g1);
expected.addVertex(acg);
for ( final SeqVertex v : makeVertices(nIncoming) ) {
final BaseEdge e = new BaseEdge(false, edgeWeight++);
graph.addVertices(v);
graph.addEdge(v, a1, e);
expected.addVertex(v);
expected.addEdge(v, acg, e);
}
for ( final SeqVertex v : makeVertices(nOutgoing) ) {
final BaseEdge e = new BaseEdge(false, edgeWeight++);
graph.addVertices(v);
graph.addEdge(g1, v, e);
expected.addVertex(v);
expected.addEdge(acg, v, e);
}
tests.add(new Object[]{graph, expected});
}
}
}
return tests.toArray(new Object[][]{});
}
private List<SeqVertex> makeVertices(final int n) {
final List<SeqVertex> vs = new LinkedList<SeqVertex>();
final List<String> bases = Arrays.asList("A", "C", "G", "T", "TT", "GG", "CC", "AA");
for ( int i = 0; i < n; i++ )
vs.add(new SeqVertex(bases.get(i % bases.size())));
return vs;
}
@Test(dataProvider = "LinearZipData", enabled = true)
public void testLinearZip(final SeqGraph graph, final SeqGraph expected) {
final SeqGraph merged = (SeqGraph)graph.clone();
merged.zipLinearChains();
try {
Assert.assertTrue(SeqGraph.graphEquals(merged, expected));
} catch (AssertionError e) {
if ( ! SeqGraph.graphEquals(merged, expected) ) {
graph.printGraph(new File("graph.dot"), 0);
merged.printGraph(new File("merged.dot"), 0);
expected.printGraph(new File("expected.dot"), 0);
}
throw e;
}
}
}