From 2aac9e2782aaac2aaf60e06ca6734415c8d06743 Mon Sep 17 00:00:00 2001 From: Mark DePristo Date: Sun, 31 Mar 2013 14:40:14 -0400 Subject: [PATCH] More efficient ZipLinearChains algorithm -- Goes through the graph looking for chains to zip, accumulates the vertices of the chains, and then finally go through and updates the graph in one big go. Vastly more efficient than the previous version, but unfortunately doesn't actually work now -- Also incorporate edge weight propagation into SeqGraph zipLinearChains. The edge weights for all incoming and outgoing edges are now their previous value, plus the sum of the internal chain edges / n such edges --- .../haplotypecaller/graphs/SeqGraph.java | 208 +++++++++++++----- .../graphs/SeqGraphUnitTest.java | 177 ++++++++++++++- 2 files changed, 328 insertions(+), 57 deletions(-) diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/SeqGraph.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/SeqGraph.java index 400b5c7ee..d08c2f211 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/SeqGraph.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/SeqGraph.java @@ -46,10 +46,13 @@ package org.broadinstitute.sting.gatk.walkers.haplotypecaller.graphs; -import org.apache.commons.lang.ArrayUtils; +import com.google.java.contract.Ensures; +import com.google.java.contract.Requires; import java.io.File; import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; import java.util.Set; /** @@ -58,7 +61,7 @@ import java.util.Set; * @author: depristo * @since 03/2013 */ -public class SeqGraph extends BaseGraph { +public final class SeqGraph extends BaseGraph { private final static boolean PRINT_SIMPLIFY_GRAPHS = false; private final static int MIN_SUFFIX_TO_MERGE_TAILS = 5; @@ -118,18 +121,8 @@ public class SeqGraph extends BaseGraph { /** * Zip up all of the simple linear chains present in this graph. - */ - public boolean zipLinearChains() { - boolean foundOne = false; - while( zipOneLinearChain() ) { - // just keep going until zipOneLinearChain says its done - foundOne = true; - } - return foundOne; - } - - /** - * Merge together two vertices in the graph v1 -> v2 into a single vertex v' containing v1 + v2 sequence + * + * Merges together all pairs of vertices in the graph v1 -> v2 into a single vertex v' containing v1 + v2 sequence * * Only works on vertices where v1's only outgoing edge is to v2 and v2's only incoming edge is from v1. * @@ -137,44 +130,153 @@ public class SeqGraph extends BaseGraph { * * @return true if any such pair of vertices could be found, false otherwise */ - protected boolean zipOneLinearChain() { - for( final BaseEdge e : edgeSet() ) { - final SeqVertex outgoingVertex = getEdgeTarget(e); - final SeqVertex incomingVertex = getEdgeSource(e); - if( !outgoingVertex.equals(incomingVertex) - && outDegreeOf(incomingVertex) == 1 && inDegreeOf(outgoingVertex) == 1 - && isReferenceNode(incomingVertex) == isReferenceNode(outgoingVertex) ) { - - final Set outEdges = outgoingEdgesOf(outgoingVertex); - final Set inEdges = incomingEdgesOf(incomingVertex); - final BaseEdge singleOutEdge = outEdges.isEmpty() ? null : outEdges.iterator().next(); - final BaseEdge singleInEdge = inEdges.isEmpty() ? null : inEdges.iterator().next(); - - if( inEdges.size() == 1 && outEdges.size() == 1 ) { - singleInEdge.setMultiplicity( singleInEdge.getMultiplicity() + ( e.getMultiplicity() / 2 ) ); - singleOutEdge.setMultiplicity( singleOutEdge.getMultiplicity() + ( e.getMultiplicity() / 2 ) ); - } else if( inEdges.size() == 1 ) { - singleInEdge.setMultiplicity( Math.max(singleInEdge.getMultiplicity() + ( e.getMultiplicity() - 1 ), 0) ); - } else if( outEdges.size() == 1 ) { - singleOutEdge.setMultiplicity( Math.max( singleOutEdge.getMultiplicity() + ( e.getMultiplicity() - 1 ), 0) ); - } - - final SeqVertex addedVertex = new SeqVertex( ArrayUtils.addAll(incomingVertex.getSequence(), outgoingVertex.getSequence()) ); - addVertex(addedVertex); - for( final BaseEdge edge : outEdges ) { - addEdge(addedVertex, getEdgeTarget(edge), new BaseEdge(edge.isRef(), edge.getMultiplicity())); - } - for( final BaseEdge edge : inEdges ) { - addEdge(getEdgeSource(edge), addedVertex, new BaseEdge(edge.isRef(), edge.getMultiplicity())); - } - - removeVertex(incomingVertex); - removeVertex(outgoingVertex); - return true; - } + public boolean zipLinearChains() { + // create the list of start sites [doesn't modify graph yet] + final List zipStarts = new LinkedList(); + for ( final SeqVertex source : vertexSet() ) { + if ( isLinearChainStart(source) ) + zipStarts.add(source); } - return false; + if ( zipStarts.isEmpty() ) // nothing to do, as nothing could start a chain + return false; + + // At this point, zipStarts contains all of the vertices in this graph that might start some linear + // chain of vertices. We walk through each start, building up the linear chain of vertices and then + // zipping them up with mergeLinearChain, if possible + boolean mergedOne = false; + for ( final SeqVertex zipStart : zipStarts ) { + final LinkedList linearChain = traceLinearChain(zipStart); + + // merge the linearized chain, recording if we actually did some useful work + mergedOne |= mergeLinearChain(linearChain); + } + + return mergedOne; + } + + /** + * Is source vertex potentially a start of a linear chain of vertices? + * + * We are a start of a zip chain if our out degree is 1 and either the + * the vertex has no incoming connections or 2 or more (we must start a chain) or + * we have exactly one incoming vertex and that one has out-degree > 1 (i.e., source's incoming + * vertex couldn't be a start itself + * + * @param source a non-null vertex + * @return true if source might start a linear chain + */ + @Requires("source != null") + private boolean isLinearChainStart(final SeqVertex source) { + return outDegreeOf(source) == 1 + && ( inDegreeOf(source) != 1 + || outDegreeOf(incomingVerticesOf(source).iterator().next()) > 1 ); + } + + /** + * Get all of the vertices in a linear chain of vertices starting at zipStart + * + * Build a list of vertices (in order) starting from zipStart such that each sequential pair of vertices + * in the chain A and B can be zipped together. + * + * @param zipStart a vertex that starts a linear chain + * @return a list of vertices that comprise a linear chain starting with zipStart. The resulting + * list will always contain at least zipStart as the first element. + */ + @Requires("isLinearChainStart(zipStart)") + @Ensures({"result != null", "result.size() >= 1"}) + private LinkedList traceLinearChain(final SeqVertex zipStart) { + final LinkedList linearChain = new LinkedList(); + linearChain.add(zipStart); + + boolean lastIsRef = isReferenceNode(zipStart); // remember because this calculation is expensive + SeqVertex last = zipStart; + while (true) { + if ( outDegreeOf(last) != 1 ) + // cannot extend a chain from last if last has multiple outgoing branches + break; + + // there can only be one (outgoing edge of last) by contract + final SeqVertex target = getEdgeTarget(outgoingEdgeOf(last)); + + if ( inDegreeOf(target) != 1 || last.equals(target) ) + // cannot zip up a target that has multiple incoming nodes or that's a cycle to the last node + break; + + final boolean targetIsRef = isReferenceNode(target); + if ( lastIsRef != targetIsRef ) // both our isRef states must be equal + break; + + linearChain.add(target); // extend our chain by one + + // update our last state to be the current state, and continue + last = target; + lastIsRef = targetIsRef; + } + + return linearChain; + } + + /** + * Merge a linear chain of vertices into a single combined vertex, and update this graph to such that + * the incoming edges into the first element of the linearChain and the outgoing edges from linearChain.getLast() + * all point to this new combined vertex. + * + * @param linearChain a non-empty chain of vertices that can be zipped up into a single vertex + * @return true if we actually merged at least two vertices together + */ + protected boolean mergeLinearChain(final LinkedList linearChain) { + if ( linearChain.isEmpty() ) throw new IllegalArgumentException("BUG: cannot have linear chain with 0 elements but got " + linearChain); + + final SeqVertex first = linearChain.getFirst(); + final SeqVertex last = linearChain.getLast(); + + if ( first == last ) return false; // only one element in the chain, cannot be extended + + // create the combined vertex, and add it to the graph + // TODO -- performance problem -- can be optimized if we want + final List seqs = new LinkedList(); + for ( SeqVertex v : linearChain ) seqs.add(v.getSequence()); + final byte[] seqsCat = org.broadinstitute.sting.utils.Utils.concat(seqs.toArray(new byte[][]{})); + final SeqVertex addedVertex = new SeqVertex( seqsCat ); + addVertex(addedVertex); + + final Set inEdges = incomingEdgesOf(first); + final Set outEdges = outgoingEdgesOf(last); + + final int nEdges = inEdges.size() + outEdges.size(); + int sharedWeightAmongEdges = nEdges == 0 ? 0 : sumEdgeWeightAlongChain(linearChain) / nEdges; + final BaseEdge inc = new BaseEdge(false, sharedWeightAmongEdges); // template to make .add function call easy + + // update the incoming and outgoing edges to point to the new vertex + for( final BaseEdge edge : outEdges ) { addEdge(addedVertex, getEdgeTarget(edge), new BaseEdge(edge).add(inc)); } + for( final BaseEdge edge : inEdges ) { addEdge(getEdgeSource(edge), addedVertex, new BaseEdge(edge).add(inc)); } + + removeAllVertices(linearChain); + return true; + } + + /** + * Get the sum of the edge weights on a linear chain of at least 2 elements + * + * @param chain a linear chain of vertices with at least 2 vertices + * @return the sum of the multiplicities along all edges connecting vertices within the chain + */ + @Requires({"chain != null", "chain.size() >= 2"}) + private int sumEdgeWeightAlongChain(final LinkedList chain) { + int sum = 0; + SeqVertex prev = null; + + for ( final SeqVertex v : chain ) { + if ( prev != null ) { + final BaseEdge e = getEdge(prev, v); + if ( e == null ) throw new IllegalStateException("Something wrong with the linear chain, got a null edge between " + prev + " and " + v); + sum += e.getMultiplicity(); + } + prev = v; + } + + return sum; } /** @@ -241,7 +343,7 @@ public class SeqGraph extends BaseGraph { protected class MergeDiamonds extends VertexBasedTransformer { @Override protected boolean tryToTransform(final SeqVertex top) { - final Set middles = outgoingVerticesOf(top); + final List middles = outgoingVerticesOf(top); if ( middles.size() <= 1 ) // we can only merge if there's at least two middle nodes return false; @@ -295,7 +397,7 @@ public class SeqGraph extends BaseGraph { protected class MergeTails extends VertexBasedTransformer { @Override protected boolean tryToTransform(final SeqVertex top) { - final Set tails = outgoingVerticesOf(top); + final List tails = outgoingVerticesOf(top); if ( tails.size() <= 1 ) return false; @@ -379,7 +481,7 @@ public class SeqGraph extends BaseGraph { protected class MergeHeadlessIncomingSources extends VertexBasedTransformer { @Override boolean tryToTransform(final SeqVertex bottom) { - final Set incoming = incomingVerticesOf(bottom); + final List incoming = incomingVerticesOf(bottom); if ( incoming.size() <= 1 ) return false; diff --git a/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/SeqGraphUnitTest.java b/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/SeqGraphUnitTest.java index cbd7b1063..698b83199 100644 --- a/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/SeqGraphUnitTest.java +++ b/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/SeqGraphUnitTest.java @@ -51,11 +51,15 @@ import org.testng.Assert; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; +import java.io.File; import java.util.ArrayList; import java.util.Arrays; +import java.util.LinkedList; import java.util.List; public class SeqGraphUnitTest extends BaseTest { + private final static boolean DEBUG = true; + private class MergeNodesWithNoVariationTestProvider extends TestDataProvider { public byte[] sequence; public int KMER_LENGTH; @@ -98,7 +102,7 @@ public class SeqGraphUnitTest extends BaseTest { return MergeNodesWithNoVariationTestProvider.getTests(MergeNodesWithNoVariationTestProvider.class); } - @Test(dataProvider = "MergeNodesWithNoVariationTestProvider", enabled = true) + @Test(dataProvider = "MergeNodesWithNoVariationTestProvider", enabled = !DEBUG) public void testMergeNodesWithNoVariation(MergeNodesWithNoVariationTestProvider cfg) { logger.warn(String.format("Test: %s", cfg.toString())); @@ -178,7 +182,7 @@ public class SeqGraphUnitTest extends BaseTest { return tests.toArray(new Object[][]{}); } - @Test(dataProvider = "IsDiamondData", enabled = true) + @Test(dataProvider = "IsDiamondData", enabled = !DEBUG) public void testIsDiamond(final SeqGraph graph, final SeqVertex v, final boolean isRootOfDiamond) { final SeqGraph.MergeDiamonds merger = graph.new MergeDiamonds(); merger.setDontModifyGraphEvenIfPossible(); @@ -311,7 +315,7 @@ public class SeqGraphUnitTest extends BaseTest { return tests.toArray(new Object[][]{}); } - @Test(dataProvider = "MergingData", enabled = true) + @Test(dataProvider = "MergingData", enabled = !DEBUG) public void testMerging(final SeqGraph graph, final SeqGraph expected) { final SeqGraph merged = (SeqGraph)graph.clone(); merged.simplifyGraph(1); @@ -333,7 +337,7 @@ public class SeqGraphUnitTest extends BaseTest { // // Should become A -> ACT -> C [ref and non-ref edges] // - @Test + @Test(enabled = !DEBUG) public void testBubbleSameBasesWithRef() { final SeqGraph graph = new SeqGraph(); final SeqVertex top = new SeqVertex("A"); @@ -351,4 +355,169 @@ public class SeqGraphUnitTest extends BaseTest { actual.simplifyGraph(); Assert.assertTrue(BaseGraph.graphEquals(actual, expected), "Wrong merging result after complete merging"); } + + @DataProvider(name = "LinearZipData") + public Object[][] makeLinearZipData() throws Exception { + List tests = new ArrayList(); + + SeqGraph graph = new SeqGraph(); + SeqGraph expected = new SeqGraph(); + + // empty graph => empty graph + tests.add(new Object[]{graph.clone(), expected.clone()}); + + SeqVertex a1 = new SeqVertex("A"); + SeqVertex c1 = new SeqVertex("C"); + SeqVertex ac1 = new SeqVertex("AC"); + + // just a single vertex + graph.addVertices(a1, c1); + expected.addVertices(a1, c1); + + tests.add(new Object[]{graph.clone(), expected.clone()}); + + graph.addEdges(a1, c1); + expected = new SeqGraph(); + expected.addVertices(ac1); + tests.add(new Object[]{graph.clone(), expected.clone()}); + + // three long chain merged corrected + SeqVertex g1 = new SeqVertex("G"); + graph.addVertices(g1); + graph.addEdges(c1, g1); + expected = new SeqGraph(); + expected.addVertex(new SeqVertex("ACG")); + tests.add(new Object[]{graph.clone(), expected.clone()}); + + // adding something that isn't connected isn't a problem + SeqVertex t1 = new SeqVertex("T"); + graph.addVertices(t1); + expected = new SeqGraph(); + expected.addVertices(new SeqVertex("ACG"), new SeqVertex("T")); + tests.add(new Object[]{graph.clone(), expected.clone()}); + + // splitting chain with branch produces the correct zipped subgraphs + final SeqVertex a2 = new SeqVertex("A"); + final SeqVertex c2 = new SeqVertex("C"); + graph = new SeqGraph(); + graph.addVertices(a1, c1, g1, t1, a2, c2); + graph.addEdges(a1, c1, g1, t1, a2); + graph.addEdges(g1, c2); + expected = new SeqGraph(); + SeqVertex acg = new SeqVertex("ACG"); + SeqVertex ta = new SeqVertex("TA"); + expected.addVertices(acg, ta, c2); + expected.addEdges(acg, ta); + expected.addEdges(acg, c2); + tests.add(new Object[]{graph.clone(), expected.clone()}); + + // Can merge chains with loops in them + { + graph = new SeqGraph(); + graph.addVertices(a1, c1, g1); + graph.addEdges(a1, c1, g1); + graph.addEdges(a1, a1); + expected = new SeqGraph(); + + SeqVertex ac = new SeqVertex("AC"); + SeqVertex cg = new SeqVertex("CG"); + + expected.addVertices(a1, cg); + expected.addEdges(a1, cg); + expected.addEdges(a1, a1); + tests.add(new Object[]{graph.clone(), expected.clone()}); + + graph.removeEdge(a1, a1); + graph.addEdges(c1, c1); + tests.add(new Object[]{graph.clone(), graph.clone()}); + + graph.removeEdge(c1, c1); + graph.addEdges(g1, g1); + expected = new SeqGraph(); + expected.addVertices(ac, g1); + expected.addEdges(ac, g1, g1); + tests.add(new Object[]{graph.clone(), expected.clone()}); + } + + // check building n element long chains + { + final List bases = Arrays.asList("A", "C", "G", "T", "TT", "GG", "CC", "AA"); + for ( final int len : Arrays.asList(1, 2, 10, 100, 1000)) { + graph = new SeqGraph(); + expected = new SeqGraph(); + SeqVertex last = null; + String expectedBases = ""; + for ( int i = 0; i < len; i++ ) { + final String seq = bases.get(i % bases.size()); + expectedBases += seq; + SeqVertex a = new SeqVertex(seq); + graph.addVertex(a); + if ( last != null ) graph.addEdge(last, a); + last = a; + } + expected.addVertex(new SeqVertex(expectedBases)); + tests.add(new Object[]{graph.clone(), expected.clone()}); + } + } + + // check that edge connections are properly maintained + { + int edgeWeight = 1; + for ( final int nIncoming : Arrays.asList(0, 2, 5, 10) ) { + for ( final int nOutgoing : Arrays.asList(0, 2, 5, 10) ) { + graph = new SeqGraph(); + expected = new SeqGraph(); + + graph.addVertices(a1, c1, g1); + graph.addEdges(a1, c1, g1); + expected.addVertex(acg); + + for ( final SeqVertex v : makeVertices(nIncoming) ) { + final BaseEdge e = new BaseEdge(false, edgeWeight++); + graph.addVertices(v); + graph.addEdge(v, a1, e); + expected.addVertex(v); + expected.addEdge(v, acg, e); + } + + for ( final SeqVertex v : makeVertices(nOutgoing) ) { + final BaseEdge e = new BaseEdge(false, edgeWeight++); + graph.addVertices(v); + graph.addEdge(g1, v, e); + expected.addVertex(v); + expected.addEdge(acg, v, e); + } + + tests.add(new Object[]{graph, expected}); + } + } + } + + return tests.toArray(new Object[][]{}); + } + + private List makeVertices(final int n) { + final List vs = new LinkedList(); + final List bases = Arrays.asList("A", "C", "G", "T", "TT", "GG", "CC", "AA"); + + for ( int i = 0; i < n; i++ ) + vs.add(new SeqVertex(bases.get(i % bases.size()))); + return vs; + } + + @Test(dataProvider = "LinearZipData", enabled = true) + public void testLinearZip(final SeqGraph graph, final SeqGraph expected) { + final SeqGraph merged = (SeqGraph)graph.clone(); + merged.zipLinearChains(); + try { + Assert.assertTrue(SeqGraph.graphEquals(merged, expected)); + } catch (AssertionError e) { + if ( ! SeqGraph.graphEquals(merged, expected) ) { + graph.printGraph(new File("graph.dot"), 0); + merged.printGraph(new File("merged.dot"), 0); + expected.printGraph(new File("expected.dot"), 0); + } + throw e; + } + } }