diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/SeqGraph.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/SeqGraph.java index 400b5c7ee..d08c2f211 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/SeqGraph.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/SeqGraph.java @@ -46,10 +46,13 @@ package org.broadinstitute.sting.gatk.walkers.haplotypecaller.graphs; -import org.apache.commons.lang.ArrayUtils; +import com.google.java.contract.Ensures; +import com.google.java.contract.Requires; import java.io.File; import java.util.HashSet; +import java.util.LinkedList; +import java.util.List; import java.util.Set; /** @@ -58,7 +61,7 @@ import java.util.Set; * @author: depristo * @since 03/2013 */ -public class SeqGraph extends BaseGraph { +public final class SeqGraph extends BaseGraph { private final static boolean PRINT_SIMPLIFY_GRAPHS = false; private final static int MIN_SUFFIX_TO_MERGE_TAILS = 5; @@ -118,18 +121,8 @@ public class SeqGraph extends BaseGraph { /** * Zip up all of the simple linear chains present in this graph. - */ - public boolean zipLinearChains() { - boolean foundOne = false; - while( zipOneLinearChain() ) { - // just keep going until zipOneLinearChain says its done - foundOne = true; - } - return foundOne; - } - - /** - * Merge together two vertices in the graph v1 -> v2 into a single vertex v' containing v1 + v2 sequence + * + * Merges together all pairs of vertices in the graph v1 -> v2 into a single vertex v' containing v1 + v2 sequence * * Only works on vertices where v1's only outgoing edge is to v2 and v2's only incoming edge is from v1. * @@ -137,44 +130,153 @@ public class SeqGraph extends BaseGraph { * * @return true if any such pair of vertices could be found, false otherwise */ - protected boolean zipOneLinearChain() { - for( final BaseEdge e : edgeSet() ) { - final SeqVertex outgoingVertex = getEdgeTarget(e); - final SeqVertex incomingVertex = getEdgeSource(e); - if( !outgoingVertex.equals(incomingVertex) - && outDegreeOf(incomingVertex) == 1 && inDegreeOf(outgoingVertex) == 1 - && isReferenceNode(incomingVertex) == isReferenceNode(outgoingVertex) ) { - - final Set outEdges = outgoingEdgesOf(outgoingVertex); - final Set inEdges = incomingEdgesOf(incomingVertex); - final BaseEdge singleOutEdge = outEdges.isEmpty() ? null : outEdges.iterator().next(); - final BaseEdge singleInEdge = inEdges.isEmpty() ? null : inEdges.iterator().next(); - - if( inEdges.size() == 1 && outEdges.size() == 1 ) { - singleInEdge.setMultiplicity( singleInEdge.getMultiplicity() + ( e.getMultiplicity() / 2 ) ); - singleOutEdge.setMultiplicity( singleOutEdge.getMultiplicity() + ( e.getMultiplicity() / 2 ) ); - } else if( inEdges.size() == 1 ) { - singleInEdge.setMultiplicity( Math.max(singleInEdge.getMultiplicity() + ( e.getMultiplicity() - 1 ), 0) ); - } else if( outEdges.size() == 1 ) { - singleOutEdge.setMultiplicity( Math.max( singleOutEdge.getMultiplicity() + ( e.getMultiplicity() - 1 ), 0) ); - } - - final SeqVertex addedVertex = new SeqVertex( ArrayUtils.addAll(incomingVertex.getSequence(), outgoingVertex.getSequence()) ); - addVertex(addedVertex); - for( final BaseEdge edge : outEdges ) { - addEdge(addedVertex, getEdgeTarget(edge), new BaseEdge(edge.isRef(), edge.getMultiplicity())); - } - for( final BaseEdge edge : inEdges ) { - addEdge(getEdgeSource(edge), addedVertex, new BaseEdge(edge.isRef(), edge.getMultiplicity())); - } - - removeVertex(incomingVertex); - removeVertex(outgoingVertex); - return true; - } + public boolean zipLinearChains() { + // create the list of start sites [doesn't modify graph yet] + final List zipStarts = new LinkedList(); + for ( final SeqVertex source : vertexSet() ) { + if ( isLinearChainStart(source) ) + zipStarts.add(source); } - return false; + if ( zipStarts.isEmpty() ) // nothing to do, as nothing could start a chain + return false; + + // At this point, zipStarts contains all of the vertices in this graph that might start some linear + // chain of vertices. We walk through each start, building up the linear chain of vertices and then + // zipping them up with mergeLinearChain, if possible + boolean mergedOne = false; + for ( final SeqVertex zipStart : zipStarts ) { + final LinkedList linearChain = traceLinearChain(zipStart); + + // merge the linearized chain, recording if we actually did some useful work + mergedOne |= mergeLinearChain(linearChain); + } + + return mergedOne; + } + + /** + * Is source vertex potentially a start of a linear chain of vertices? + * + * We are a start of a zip chain if our out degree is 1 and either the + * the vertex has no incoming connections or 2 or more (we must start a chain) or + * we have exactly one incoming vertex and that one has out-degree > 1 (i.e., source's incoming + * vertex couldn't be a start itself + * + * @param source a non-null vertex + * @return true if source might start a linear chain + */ + @Requires("source != null") + private boolean isLinearChainStart(final SeqVertex source) { + return outDegreeOf(source) == 1 + && ( inDegreeOf(source) != 1 + || outDegreeOf(incomingVerticesOf(source).iterator().next()) > 1 ); + } + + /** + * Get all of the vertices in a linear chain of vertices starting at zipStart + * + * Build a list of vertices (in order) starting from zipStart such that each sequential pair of vertices + * in the chain A and B can be zipped together. + * + * @param zipStart a vertex that starts a linear chain + * @return a list of vertices that comprise a linear chain starting with zipStart. The resulting + * list will always contain at least zipStart as the first element. + */ + @Requires("isLinearChainStart(zipStart)") + @Ensures({"result != null", "result.size() >= 1"}) + private LinkedList traceLinearChain(final SeqVertex zipStart) { + final LinkedList linearChain = new LinkedList(); + linearChain.add(zipStart); + + boolean lastIsRef = isReferenceNode(zipStart); // remember because this calculation is expensive + SeqVertex last = zipStart; + while (true) { + if ( outDegreeOf(last) != 1 ) + // cannot extend a chain from last if last has multiple outgoing branches + break; + + // there can only be one (outgoing edge of last) by contract + final SeqVertex target = getEdgeTarget(outgoingEdgeOf(last)); + + if ( inDegreeOf(target) != 1 || last.equals(target) ) + // cannot zip up a target that has multiple incoming nodes or that's a cycle to the last node + break; + + final boolean targetIsRef = isReferenceNode(target); + if ( lastIsRef != targetIsRef ) // both our isRef states must be equal + break; + + linearChain.add(target); // extend our chain by one + + // update our last state to be the current state, and continue + last = target; + lastIsRef = targetIsRef; + } + + return linearChain; + } + + /** + * Merge a linear chain of vertices into a single combined vertex, and update this graph to such that + * the incoming edges into the first element of the linearChain and the outgoing edges from linearChain.getLast() + * all point to this new combined vertex. + * + * @param linearChain a non-empty chain of vertices that can be zipped up into a single vertex + * @return true if we actually merged at least two vertices together + */ + protected boolean mergeLinearChain(final LinkedList linearChain) { + if ( linearChain.isEmpty() ) throw new IllegalArgumentException("BUG: cannot have linear chain with 0 elements but got " + linearChain); + + final SeqVertex first = linearChain.getFirst(); + final SeqVertex last = linearChain.getLast(); + + if ( first == last ) return false; // only one element in the chain, cannot be extended + + // create the combined vertex, and add it to the graph + // TODO -- performance problem -- can be optimized if we want + final List seqs = new LinkedList(); + for ( SeqVertex v : linearChain ) seqs.add(v.getSequence()); + final byte[] seqsCat = org.broadinstitute.sting.utils.Utils.concat(seqs.toArray(new byte[][]{})); + final SeqVertex addedVertex = new SeqVertex( seqsCat ); + addVertex(addedVertex); + + final Set inEdges = incomingEdgesOf(first); + final Set outEdges = outgoingEdgesOf(last); + + final int nEdges = inEdges.size() + outEdges.size(); + int sharedWeightAmongEdges = nEdges == 0 ? 0 : sumEdgeWeightAlongChain(linearChain) / nEdges; + final BaseEdge inc = new BaseEdge(false, sharedWeightAmongEdges); // template to make .add function call easy + + // update the incoming and outgoing edges to point to the new vertex + for( final BaseEdge edge : outEdges ) { addEdge(addedVertex, getEdgeTarget(edge), new BaseEdge(edge).add(inc)); } + for( final BaseEdge edge : inEdges ) { addEdge(getEdgeSource(edge), addedVertex, new BaseEdge(edge).add(inc)); } + + removeAllVertices(linearChain); + return true; + } + + /** + * Get the sum of the edge weights on a linear chain of at least 2 elements + * + * @param chain a linear chain of vertices with at least 2 vertices + * @return the sum of the multiplicities along all edges connecting vertices within the chain + */ + @Requires({"chain != null", "chain.size() >= 2"}) + private int sumEdgeWeightAlongChain(final LinkedList chain) { + int sum = 0; + SeqVertex prev = null; + + for ( final SeqVertex v : chain ) { + if ( prev != null ) { + final BaseEdge e = getEdge(prev, v); + if ( e == null ) throw new IllegalStateException("Something wrong with the linear chain, got a null edge between " + prev + " and " + v); + sum += e.getMultiplicity(); + } + prev = v; + } + + return sum; } /** @@ -241,7 +343,7 @@ public class SeqGraph extends BaseGraph { protected class MergeDiamonds extends VertexBasedTransformer { @Override protected boolean tryToTransform(final SeqVertex top) { - final Set middles = outgoingVerticesOf(top); + final List middles = outgoingVerticesOf(top); if ( middles.size() <= 1 ) // we can only merge if there's at least two middle nodes return false; @@ -295,7 +397,7 @@ public class SeqGraph extends BaseGraph { protected class MergeTails extends VertexBasedTransformer { @Override protected boolean tryToTransform(final SeqVertex top) { - final Set tails = outgoingVerticesOf(top); + final List tails = outgoingVerticesOf(top); if ( tails.size() <= 1 ) return false; @@ -379,7 +481,7 @@ public class SeqGraph extends BaseGraph { protected class MergeHeadlessIncomingSources extends VertexBasedTransformer { @Override boolean tryToTransform(final SeqVertex bottom) { - final Set incoming = incomingVerticesOf(bottom); + final List incoming = incomingVerticesOf(bottom); if ( incoming.size() <= 1 ) return false; diff --git a/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/SeqGraphUnitTest.java b/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/SeqGraphUnitTest.java index cbd7b1063..698b83199 100644 --- a/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/SeqGraphUnitTest.java +++ b/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/SeqGraphUnitTest.java @@ -51,11 +51,15 @@ import org.testng.Assert; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; +import java.io.File; import java.util.ArrayList; import java.util.Arrays; +import java.util.LinkedList; import java.util.List; public class SeqGraphUnitTest extends BaseTest { + private final static boolean DEBUG = true; + private class MergeNodesWithNoVariationTestProvider extends TestDataProvider { public byte[] sequence; public int KMER_LENGTH; @@ -98,7 +102,7 @@ public class SeqGraphUnitTest extends BaseTest { return MergeNodesWithNoVariationTestProvider.getTests(MergeNodesWithNoVariationTestProvider.class); } - @Test(dataProvider = "MergeNodesWithNoVariationTestProvider", enabled = true) + @Test(dataProvider = "MergeNodesWithNoVariationTestProvider", enabled = !DEBUG) public void testMergeNodesWithNoVariation(MergeNodesWithNoVariationTestProvider cfg) { logger.warn(String.format("Test: %s", cfg.toString())); @@ -178,7 +182,7 @@ public class SeqGraphUnitTest extends BaseTest { return tests.toArray(new Object[][]{}); } - @Test(dataProvider = "IsDiamondData", enabled = true) + @Test(dataProvider = "IsDiamondData", enabled = !DEBUG) public void testIsDiamond(final SeqGraph graph, final SeqVertex v, final boolean isRootOfDiamond) { final SeqGraph.MergeDiamonds merger = graph.new MergeDiamonds(); merger.setDontModifyGraphEvenIfPossible(); @@ -311,7 +315,7 @@ public class SeqGraphUnitTest extends BaseTest { return tests.toArray(new Object[][]{}); } - @Test(dataProvider = "MergingData", enabled = true) + @Test(dataProvider = "MergingData", enabled = !DEBUG) public void testMerging(final SeqGraph graph, final SeqGraph expected) { final SeqGraph merged = (SeqGraph)graph.clone(); merged.simplifyGraph(1); @@ -333,7 +337,7 @@ public class SeqGraphUnitTest extends BaseTest { // // Should become A -> ACT -> C [ref and non-ref edges] // - @Test + @Test(enabled = !DEBUG) public void testBubbleSameBasesWithRef() { final SeqGraph graph = new SeqGraph(); final SeqVertex top = new SeqVertex("A"); @@ -351,4 +355,169 @@ public class SeqGraphUnitTest extends BaseTest { actual.simplifyGraph(); Assert.assertTrue(BaseGraph.graphEquals(actual, expected), "Wrong merging result after complete merging"); } + + @DataProvider(name = "LinearZipData") + public Object[][] makeLinearZipData() throws Exception { + List tests = new ArrayList(); + + SeqGraph graph = new SeqGraph(); + SeqGraph expected = new SeqGraph(); + + // empty graph => empty graph + tests.add(new Object[]{graph.clone(), expected.clone()}); + + SeqVertex a1 = new SeqVertex("A"); + SeqVertex c1 = new SeqVertex("C"); + SeqVertex ac1 = new SeqVertex("AC"); + + // just a single vertex + graph.addVertices(a1, c1); + expected.addVertices(a1, c1); + + tests.add(new Object[]{graph.clone(), expected.clone()}); + + graph.addEdges(a1, c1); + expected = new SeqGraph(); + expected.addVertices(ac1); + tests.add(new Object[]{graph.clone(), expected.clone()}); + + // three long chain merged corrected + SeqVertex g1 = new SeqVertex("G"); + graph.addVertices(g1); + graph.addEdges(c1, g1); + expected = new SeqGraph(); + expected.addVertex(new SeqVertex("ACG")); + tests.add(new Object[]{graph.clone(), expected.clone()}); + + // adding something that isn't connected isn't a problem + SeqVertex t1 = new SeqVertex("T"); + graph.addVertices(t1); + expected = new SeqGraph(); + expected.addVertices(new SeqVertex("ACG"), new SeqVertex("T")); + tests.add(new Object[]{graph.clone(), expected.clone()}); + + // splitting chain with branch produces the correct zipped subgraphs + final SeqVertex a2 = new SeqVertex("A"); + final SeqVertex c2 = new SeqVertex("C"); + graph = new SeqGraph(); + graph.addVertices(a1, c1, g1, t1, a2, c2); + graph.addEdges(a1, c1, g1, t1, a2); + graph.addEdges(g1, c2); + expected = new SeqGraph(); + SeqVertex acg = new SeqVertex("ACG"); + SeqVertex ta = new SeqVertex("TA"); + expected.addVertices(acg, ta, c2); + expected.addEdges(acg, ta); + expected.addEdges(acg, c2); + tests.add(new Object[]{graph.clone(), expected.clone()}); + + // Can merge chains with loops in them + { + graph = new SeqGraph(); + graph.addVertices(a1, c1, g1); + graph.addEdges(a1, c1, g1); + graph.addEdges(a1, a1); + expected = new SeqGraph(); + + SeqVertex ac = new SeqVertex("AC"); + SeqVertex cg = new SeqVertex("CG"); + + expected.addVertices(a1, cg); + expected.addEdges(a1, cg); + expected.addEdges(a1, a1); + tests.add(new Object[]{graph.clone(), expected.clone()}); + + graph.removeEdge(a1, a1); + graph.addEdges(c1, c1); + tests.add(new Object[]{graph.clone(), graph.clone()}); + + graph.removeEdge(c1, c1); + graph.addEdges(g1, g1); + expected = new SeqGraph(); + expected.addVertices(ac, g1); + expected.addEdges(ac, g1, g1); + tests.add(new Object[]{graph.clone(), expected.clone()}); + } + + // check building n element long chains + { + final List bases = Arrays.asList("A", "C", "G", "T", "TT", "GG", "CC", "AA"); + for ( final int len : Arrays.asList(1, 2, 10, 100, 1000)) { + graph = new SeqGraph(); + expected = new SeqGraph(); + SeqVertex last = null; + String expectedBases = ""; + for ( int i = 0; i < len; i++ ) { + final String seq = bases.get(i % bases.size()); + expectedBases += seq; + SeqVertex a = new SeqVertex(seq); + graph.addVertex(a); + if ( last != null ) graph.addEdge(last, a); + last = a; + } + expected.addVertex(new SeqVertex(expectedBases)); + tests.add(new Object[]{graph.clone(), expected.clone()}); + } + } + + // check that edge connections are properly maintained + { + int edgeWeight = 1; + for ( final int nIncoming : Arrays.asList(0, 2, 5, 10) ) { + for ( final int nOutgoing : Arrays.asList(0, 2, 5, 10) ) { + graph = new SeqGraph(); + expected = new SeqGraph(); + + graph.addVertices(a1, c1, g1); + graph.addEdges(a1, c1, g1); + expected.addVertex(acg); + + for ( final SeqVertex v : makeVertices(nIncoming) ) { + final BaseEdge e = new BaseEdge(false, edgeWeight++); + graph.addVertices(v); + graph.addEdge(v, a1, e); + expected.addVertex(v); + expected.addEdge(v, acg, e); + } + + for ( final SeqVertex v : makeVertices(nOutgoing) ) { + final BaseEdge e = new BaseEdge(false, edgeWeight++); + graph.addVertices(v); + graph.addEdge(g1, v, e); + expected.addVertex(v); + expected.addEdge(acg, v, e); + } + + tests.add(new Object[]{graph, expected}); + } + } + } + + return tests.toArray(new Object[][]{}); + } + + private List makeVertices(final int n) { + final List vs = new LinkedList(); + final List bases = Arrays.asList("A", "C", "G", "T", "TT", "GG", "CC", "AA"); + + for ( int i = 0; i < n; i++ ) + vs.add(new SeqVertex(bases.get(i % bases.size()))); + return vs; + } + + @Test(dataProvider = "LinearZipData", enabled = true) + public void testLinearZip(final SeqGraph graph, final SeqGraph expected) { + final SeqGraph merged = (SeqGraph)graph.clone(); + merged.zipLinearChains(); + try { + Assert.assertTrue(SeqGraph.graphEquals(merged, expected)); + } catch (AssertionError e) { + if ( ! SeqGraph.graphEquals(merged, expected) ) { + graph.printGraph(new File("graph.dot"), 0); + merged.printGraph(new File("merged.dot"), 0); + expected.printGraph(new File("expected.dot"), 0); + } + throw e; + } + } }