diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCaller.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCaller.java index b94b74748..9b9c3924b 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCaller.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCaller.java @@ -80,8 +80,6 @@ import org.broadinstitute.sting.utils.activeregion.ActivityProfileState; import org.broadinstitute.sting.utils.clipping.ReadClipper; import org.broadinstitute.sting.utils.exceptions.UserException; import org.broadinstitute.sting.utils.fasta.CachingIndexedFastaSequenceFile; -import org.broadinstitute.sting.utils.fragments.FragmentCollection; -import org.broadinstitute.sting.utils.fragments.FragmentUtils; import org.broadinstitute.sting.utils.genotyper.PerReadAlleleLikelihoodMap; import org.broadinstitute.sting.utils.haplotype.*; import org.broadinstitute.sting.utils.haplotypeBAMWriter.HaplotypeBAMWriter; @@ -270,6 +268,10 @@ public class HaplotypeCaller extends ActiveRegionWalker, In @Argument(fullName="dontIncreaseKmerSizesForCycles", shortName="dontIncreaseKmerSizesForCycles", doc="Should we disable the iterating over kmer sizes when graph cycles are detected?", required = false) protected boolean dontIncreaseKmerSizesForCycles = false; + @Advanced + @Argument(fullName="numPruningSamples", shortName="numPruningSamples", doc="The number of samples that must pass the minPuning factor in order for the path to be kept", required = false) + protected int numPruningSamples = 1; + /** * Assembly graph can be quite complex, and could imply a very large number of possible haplotypes. Each haplotype * considered requires N PairHMM evaluations if there are N reads across all samples. In order to control the @@ -539,7 +541,7 @@ public class HaplotypeCaller extends ActiveRegionWalker, In final int maxAllowedPathsForReadThreadingAssembler = Math.max(maxPathsPerSample * nSamples, MIN_PATHS_PER_GRAPH); assemblyEngine = useDebruijnAssembler ? new DeBruijnAssembler(minKmerForDebruijnAssembler, onlyUseKmerSizeForDebruijnAssembler) - : new ReadThreadingAssembler(maxAllowedPathsForReadThreadingAssembler, kmerSizes, dontIncreaseKmerSizesForCycles); + : new ReadThreadingAssembler(maxAllowedPathsForReadThreadingAssembler, kmerSizes, dontIncreaseKmerSizesForCycles, numPruningSamples); assemblyEngine.setErrorCorrectKmers(errorCorrectKmers); assemblyEngine.setPruneFactor(MIN_PRUNE_FACTOR); diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/MultiSampleEdge.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/MultiSampleEdge.java index c1937e5c8..978d83eb4 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/MultiSampleEdge.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/graphs/MultiSampleEdge.java @@ -46,6 +46,8 @@ package org.broadinstitute.sting.gatk.walkers.haplotypecaller.graphs; +import java.util.PriorityQueue; + /** * edge class for connecting nodes in the graph that tracks some per-sample information * @@ -63,32 +65,43 @@ package org.broadinstitute.sting.gatk.walkers.haplotypecaller.graphs; * e.getPruningMultiplicity() // = 3 */ public class MultiSampleEdge extends BaseEdge { - private int maxSingleSampleMultiplicity, currentSingleSampleMultiplicity; + private int currentSingleSampleMultiplicity; + private final int singleSampleCapacity; + private final PriorityQueue singleSampleMultiplicities; /** * Create a new MultiSampleEdge with weight multiplicity and, if isRef == true, indicates a path through the reference * * @param isRef indicates whether this edge is a path through the reference * @param multiplicity the number of observations of this edge in this sample + * @param singleSampleCapacity the max number of samples to track edge multiplicities */ - public MultiSampleEdge(final boolean isRef, final int multiplicity) { + public MultiSampleEdge(final boolean isRef, final int multiplicity, final int singleSampleCapacity) { super(isRef, multiplicity); - maxSingleSampleMultiplicity = multiplicity; + + if( singleSampleCapacity <= 0 ) { throw new IllegalArgumentException("singleSampleCapacity must be > 0 but found: " + singleSampleCapacity); } + singleSampleMultiplicities = new PriorityQueue<>(singleSampleCapacity); + singleSampleMultiplicities.add(multiplicity); currentSingleSampleMultiplicity = multiplicity; + this.singleSampleCapacity = singleSampleCapacity; } @Override public MultiSampleEdge copy() { - return new MultiSampleEdge(isRef(), getMultiplicity()); // TODO -- should I copy values for other features? + return new MultiSampleEdge(isRef(), getMultiplicity(), singleSampleCapacity); // TODO -- should I copy values for other features? } /** - * update the max single sample multiplicity based on the current single sample multiplicity, and + * update the single sample multiplicities by adding the current single sample multiplicity to the priority queue, and * reset the current single sample multiplicity to 0. */ public void flushSingleSampleMultiplicity() { - if ( currentSingleSampleMultiplicity > maxSingleSampleMultiplicity ) - maxSingleSampleMultiplicity = currentSingleSampleMultiplicity; + singleSampleMultiplicities.add(currentSingleSampleMultiplicity); + if( singleSampleMultiplicities.size() == singleSampleCapacity + 1 ) { + singleSampleMultiplicities.poll(); // remove the lowest multiplicity from the list + } else if( singleSampleMultiplicities.size() > singleSampleCapacity + 1 ) { + throw new IllegalStateException("Somehow the per sample multiplicity list has grown too big: " + singleSampleMultiplicities); + } currentSingleSampleMultiplicity = 0; } @@ -100,20 +113,12 @@ public class MultiSampleEdge extends BaseEdge { @Override public int getPruningMultiplicity() { - return getMaxSingleSampleMultiplicity(); + return singleSampleMultiplicities.peek(); } @Override public String getDotLabel() { - return super.getDotLabel() + "/" + getMaxSingleSampleMultiplicity(); - } - - /** - * Get the maximum multiplicity for this edge seen in any single sample - * @return an integer >= 0 - */ - public int getMaxSingleSampleMultiplicity() { - return maxSingleSampleMultiplicity; + return super.getDotLabel() + "/" + getPruningMultiplicity(); } /** only provided for testing purposes */ diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/readthreading/ReadThreadingAssembler.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/readthreading/ReadThreadingAssembler.java index fc0f781c5..672c61c0f 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/readthreading/ReadThreadingAssembler.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/readthreading/ReadThreadingAssembler.java @@ -71,6 +71,7 @@ public class ReadThreadingAssembler extends LocalAssemblyEngine { private final int maxAllowedPathsForReadThreadingAssembler; private final boolean dontIncreaseKmerSizesForCycles; + private final int numPruningSamples; private boolean requireReasonableNumberOfPaths = false; protected boolean removePathsNotConnectedToRef = true; private boolean justReturnRawGraph = false; @@ -80,15 +81,16 @@ public class ReadThreadingAssembler extends LocalAssemblyEngine { this(DEFAULT_NUM_PATHS_PER_GRAPH, Arrays.asList(25)); } - public ReadThreadingAssembler(final int maxAllowedPathsForReadThreadingAssembler, final List kmerSizes, final boolean dontIncreaseKmerSizesForCycles) { + public ReadThreadingAssembler(final int maxAllowedPathsForReadThreadingAssembler, final List kmerSizes, final boolean dontIncreaseKmerSizesForCycles, final int numPruningSamples) { super(maxAllowedPathsForReadThreadingAssembler); this.kmerSizes = kmerSizes; this.maxAllowedPathsForReadThreadingAssembler = maxAllowedPathsForReadThreadingAssembler; this.dontIncreaseKmerSizesForCycles = dontIncreaseKmerSizesForCycles; + this.numPruningSamples = numPruningSamples; } public ReadThreadingAssembler(final int maxAllowedPathsForReadThreadingAssembler, final List kmerSizes) { - this(maxAllowedPathsForReadThreadingAssembler, kmerSizes, true); + this(maxAllowedPathsForReadThreadingAssembler, kmerSizes, true, 1); } /** for testing purposes */ @@ -139,7 +141,7 @@ public class ReadThreadingAssembler extends LocalAssemblyEngine { final int kmerSize, final List activeAlleleHaplotypes, final boolean allowLowComplexityGraphs) { - final ReadThreadingGraph rtgraph = new ReadThreadingGraph(kmerSize, debugGraphTransformations, minBaseQualityToUseInAssembly); + final ReadThreadingGraph rtgraph = new ReadThreadingGraph(kmerSize, debugGraphTransformations, minBaseQualityToUseInAssembly, numPruningSamples); // add the reference sequence to the graph rtgraph.addSequence("ref", refHaplotype.getBases(), null, true); diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/readthreading/ReadThreadingGraph.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/readthreading/ReadThreadingGraph.java index 0844f979b..7d7df2c06 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/readthreading/ReadThreadingGraph.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/readthreading/ReadThreadingGraph.java @@ -67,13 +67,24 @@ import java.util.*; public class ReadThreadingGraph extends BaseGraph { /** - * Edge factory that creates non-reference multiplicity 1 edges + * Edge factory that encapsulates the numPruningSamples assembly parameter */ private static class MyEdgeFactory implements EdgeFactory { - @Override - public MultiSampleEdge createEdge(MultiDeBruijnVertex sourceVertex, MultiDeBruijnVertex targetVertex) { - return new MultiSampleEdge(false, 1); + final int numPruningSamples; + + public MyEdgeFactory(int numPruningSamples) { + this.numPruningSamples = numPruningSamples; } + + @Override + public MultiSampleEdge createEdge(final MultiDeBruijnVertex sourceVertex, final MultiDeBruijnVertex targetVertex) { + return new MultiSampleEdge(false, 1, numPruningSamples); + } + + public MultiSampleEdge createEdge(final boolean isRef, final int multiplicity) { + return new MultiSampleEdge(isRef, multiplicity, numPruningSamples); + } + } private final static Logger logger = Logger.getLogger(ReadThreadingGraph.class); @@ -88,7 +99,7 @@ public class ReadThreadingGraph extends BaseGraph> pending = new LinkedHashMap>(); + private final Map> pending = new LinkedHashMap<>(); /** * A set of non-unique kmers that cannot be used as merge points in the graph @@ -117,19 +128,19 @@ public class ReadThreadingGraph extends BaseGraph= 1 */ - protected ReadThreadingGraph(final int kmerSize, final boolean debugGraphTransformations, final byte minBaseQualityToUseInAssembly) { - super(kmerSize, new MyEdgeFactory()); + protected ReadThreadingGraph(final int kmerSize, final boolean debugGraphTransformations, final byte minBaseQualityToUseInAssembly, final int numPruningSamples) { + super(kmerSize, new MyEdgeFactory(numPruningSamples)); if ( kmerSize < 1 ) throw new IllegalArgumentException("bad minkKmerSize " + kmerSize); this.kmerSize = kmerSize; @@ -324,7 +335,7 @@ public class ReadThreadingGraph extends BaseGraph countsPerSample; + final int numSamplesPruning; + public MultiplicityTestProvider(final List countsPerSample, final int numSamplesPruning) { + this.countsPerSample = countsPerSample; + this.numSamplesPruning = numSamplesPruning; + } + } + @DataProvider(name = "MultiplicityData") public Object[][] makeMultiplicityData() { - List tests = new ArrayList(); + List tests = new ArrayList<>(); final List countsPerSample = Arrays.asList(0, 1, 2, 3, 4, 5); - for ( final int nSamples : Arrays.asList(1, 2, 3, 4, 5)) { - for ( final List perm : Utils.makePermutations(countsPerSample, nSamples, false) ) { - tests.add(new Object[]{perm}); + for ( final int numSamplesPruning : Arrays.asList(1, 2, 3) ) { + for ( final int nSamples : Arrays.asList(1, 2, 3, 4, 5)) { + for ( final List perm : Utils.makePermutations(countsPerSample, nSamples, false) ) { + tests.add(new Object[]{new MultiplicityTestProvider(perm, numSamplesPruning)}); + } } } @@ -77,15 +87,15 @@ public class MultiSampleEdgeUnitTest extends BaseTest { * Example testng test using MyDataProvider */ @Test(dataProvider = "MultiplicityData") - public void testMultiplicity(final List countsPerSample) { - final MultiSampleEdge edge = new MultiSampleEdge(false, 0); + public void testMultiplicity(final MultiplicityTestProvider cfg) { + final MultiSampleEdge edge = new MultiSampleEdge(false, 0, cfg.numSamplesPruning); Assert.assertEquals(edge.getMultiplicity(), 0); Assert.assertEquals(edge.getPruningMultiplicity(), 0); int total = 0; - for ( int i = 0; i < countsPerSample.size(); i++ ) { + for ( int i = 0; i < cfg.countsPerSample.size(); i++ ) { int countForSample = 0; - for ( int count = 0; count < countsPerSample.get(i); count++ ) { + for ( int count = 0; count < cfg.countsPerSample.get(i); count++ ) { edge.incMultiplicity(1); total++; countForSample++; @@ -95,9 +105,11 @@ public class MultiSampleEdgeUnitTest extends BaseTest { edge.flushSingleSampleMultiplicity(); } - final int max = MathUtils.arrayMax(ArrayUtils.toPrimitive(countsPerSample.toArray(new Integer[countsPerSample.size()]))); + ArrayList counts = new ArrayList<>(cfg.countsPerSample); + counts.add(0); + Collections.sort(counts); + final int prune = counts.get(Math.max(counts.size() - cfg.numSamplesPruning, 0)); Assert.assertEquals(edge.getMultiplicity(), total); - Assert.assertEquals(edge.getPruningMultiplicity(), max); - Assert.assertEquals(edge.getMaxSingleSampleMultiplicity(), max); + Assert.assertEquals(edge.getPruningMultiplicity(), prune); } }