Adding new pruning parameter to ReadThreadingAssembler

-- numPruningSamples allows one to specify that the minPruning factor must be met by this many samples for a path to be considered good (e.g. seen twice in three samples). By default this is just one sample.
-- adding unit test to test this new functionality
This commit is contained in:
Ryan Poplin 2013-06-17 14:02:54 -04:00
parent a6a58cbc78
commit 8511c4385c
5 changed files with 80 additions and 48 deletions

View File

@ -80,8 +80,6 @@ import org.broadinstitute.sting.utils.activeregion.ActivityProfileState;
import org.broadinstitute.sting.utils.clipping.ReadClipper;
import org.broadinstitute.sting.utils.exceptions.UserException;
import org.broadinstitute.sting.utils.fasta.CachingIndexedFastaSequenceFile;
import org.broadinstitute.sting.utils.fragments.FragmentCollection;
import org.broadinstitute.sting.utils.fragments.FragmentUtils;
import org.broadinstitute.sting.utils.genotyper.PerReadAlleleLikelihoodMap;
import org.broadinstitute.sting.utils.haplotype.*;
import org.broadinstitute.sting.utils.haplotypeBAMWriter.HaplotypeBAMWriter;
@ -270,6 +268,10 @@ public class HaplotypeCaller extends ActiveRegionWalker<List<VariantContext>, In
@Argument(fullName="dontIncreaseKmerSizesForCycles", shortName="dontIncreaseKmerSizesForCycles", doc="Should we disable the iterating over kmer sizes when graph cycles are detected?", required = false)
protected boolean dontIncreaseKmerSizesForCycles = false;
@Advanced
@Argument(fullName="numPruningSamples", shortName="numPruningSamples", doc="The number of samples that must pass the minPuning factor in order for the path to be kept", required = false)
protected int numPruningSamples = 1;
/**
* Assembly graph can be quite complex, and could imply a very large number of possible haplotypes. Each haplotype
* considered requires N PairHMM evaluations if there are N reads across all samples. In order to control the
@ -539,7 +541,7 @@ public class HaplotypeCaller extends ActiveRegionWalker<List<VariantContext>, In
final int maxAllowedPathsForReadThreadingAssembler = Math.max(maxPathsPerSample * nSamples, MIN_PATHS_PER_GRAPH);
assemblyEngine = useDebruijnAssembler
? new DeBruijnAssembler(minKmerForDebruijnAssembler, onlyUseKmerSizeForDebruijnAssembler)
: new ReadThreadingAssembler(maxAllowedPathsForReadThreadingAssembler, kmerSizes, dontIncreaseKmerSizesForCycles);
: new ReadThreadingAssembler(maxAllowedPathsForReadThreadingAssembler, kmerSizes, dontIncreaseKmerSizesForCycles, numPruningSamples);
assemblyEngine.setErrorCorrectKmers(errorCorrectKmers);
assemblyEngine.setPruneFactor(MIN_PRUNE_FACTOR);

View File

@ -46,6 +46,8 @@
package org.broadinstitute.sting.gatk.walkers.haplotypecaller.graphs;
import java.util.PriorityQueue;
/**
* edge class for connecting nodes in the graph that tracks some per-sample information
*
@ -63,32 +65,43 @@ package org.broadinstitute.sting.gatk.walkers.haplotypecaller.graphs;
* e.getPruningMultiplicity() // = 3
*/
public class MultiSampleEdge extends BaseEdge {
private int maxSingleSampleMultiplicity, currentSingleSampleMultiplicity;
private int currentSingleSampleMultiplicity;
private final int singleSampleCapacity;
private final PriorityQueue<Integer> singleSampleMultiplicities;
/**
* Create a new MultiSampleEdge with weight multiplicity and, if isRef == true, indicates a path through the reference
*
* @param isRef indicates whether this edge is a path through the reference
* @param multiplicity the number of observations of this edge in this sample
* @param singleSampleCapacity the max number of samples to track edge multiplicities
*/
public MultiSampleEdge(final boolean isRef, final int multiplicity) {
public MultiSampleEdge(final boolean isRef, final int multiplicity, final int singleSampleCapacity) {
super(isRef, multiplicity);
maxSingleSampleMultiplicity = multiplicity;
if( singleSampleCapacity <= 0 ) { throw new IllegalArgumentException("singleSampleCapacity must be > 0 but found: " + singleSampleCapacity); }
singleSampleMultiplicities = new PriorityQueue<>(singleSampleCapacity);
singleSampleMultiplicities.add(multiplicity);
currentSingleSampleMultiplicity = multiplicity;
this.singleSampleCapacity = singleSampleCapacity;
}
@Override
public MultiSampleEdge copy() {
return new MultiSampleEdge(isRef(), getMultiplicity()); // TODO -- should I copy values for other features?
return new MultiSampleEdge(isRef(), getMultiplicity(), singleSampleCapacity); // TODO -- should I copy values for other features?
}
/**
* update the max single sample multiplicity based on the current single sample multiplicity, and
* update the single sample multiplicities by adding the current single sample multiplicity to the priority queue, and
* reset the current single sample multiplicity to 0.
*/
public void flushSingleSampleMultiplicity() {
if ( currentSingleSampleMultiplicity > maxSingleSampleMultiplicity )
maxSingleSampleMultiplicity = currentSingleSampleMultiplicity;
singleSampleMultiplicities.add(currentSingleSampleMultiplicity);
if( singleSampleMultiplicities.size() == singleSampleCapacity + 1 ) {
singleSampleMultiplicities.poll(); // remove the lowest multiplicity from the list
} else if( singleSampleMultiplicities.size() > singleSampleCapacity + 1 ) {
throw new IllegalStateException("Somehow the per sample multiplicity list has grown too big: " + singleSampleMultiplicities);
}
currentSingleSampleMultiplicity = 0;
}
@ -100,20 +113,12 @@ public class MultiSampleEdge extends BaseEdge {
@Override
public int getPruningMultiplicity() {
return getMaxSingleSampleMultiplicity();
return singleSampleMultiplicities.peek();
}
@Override
public String getDotLabel() {
return super.getDotLabel() + "/" + getMaxSingleSampleMultiplicity();
}
/**
* Get the maximum multiplicity for this edge seen in any single sample
* @return an integer >= 0
*/
public int getMaxSingleSampleMultiplicity() {
return maxSingleSampleMultiplicity;
return super.getDotLabel() + "/" + getPruningMultiplicity();
}
/** only provided for testing purposes */

View File

@ -71,6 +71,7 @@ public class ReadThreadingAssembler extends LocalAssemblyEngine {
private final int maxAllowedPathsForReadThreadingAssembler;
private final boolean dontIncreaseKmerSizesForCycles;
private final int numPruningSamples;
private boolean requireReasonableNumberOfPaths = false;
protected boolean removePathsNotConnectedToRef = true;
private boolean justReturnRawGraph = false;
@ -80,15 +81,16 @@ public class ReadThreadingAssembler extends LocalAssemblyEngine {
this(DEFAULT_NUM_PATHS_PER_GRAPH, Arrays.asList(25));
}
public ReadThreadingAssembler(final int maxAllowedPathsForReadThreadingAssembler, final List<Integer> kmerSizes, final boolean dontIncreaseKmerSizesForCycles) {
public ReadThreadingAssembler(final int maxAllowedPathsForReadThreadingAssembler, final List<Integer> kmerSizes, final boolean dontIncreaseKmerSizesForCycles, final int numPruningSamples) {
super(maxAllowedPathsForReadThreadingAssembler);
this.kmerSizes = kmerSizes;
this.maxAllowedPathsForReadThreadingAssembler = maxAllowedPathsForReadThreadingAssembler;
this.dontIncreaseKmerSizesForCycles = dontIncreaseKmerSizesForCycles;
this.numPruningSamples = numPruningSamples;
}
public ReadThreadingAssembler(final int maxAllowedPathsForReadThreadingAssembler, final List<Integer> kmerSizes) {
this(maxAllowedPathsForReadThreadingAssembler, kmerSizes, true);
this(maxAllowedPathsForReadThreadingAssembler, kmerSizes, true, 1);
}
/** for testing purposes */
@ -139,7 +141,7 @@ public class ReadThreadingAssembler extends LocalAssemblyEngine {
final int kmerSize,
final List<Haplotype> activeAlleleHaplotypes,
final boolean allowLowComplexityGraphs) {
final ReadThreadingGraph rtgraph = new ReadThreadingGraph(kmerSize, debugGraphTransformations, minBaseQualityToUseInAssembly);
final ReadThreadingGraph rtgraph = new ReadThreadingGraph(kmerSize, debugGraphTransformations, minBaseQualityToUseInAssembly, numPruningSamples);
// add the reference sequence to the graph
rtgraph.addSequence("ref", refHaplotype.getBases(), null, true);

View File

@ -67,13 +67,24 @@ import java.util.*;
public class ReadThreadingGraph extends BaseGraph<MultiDeBruijnVertex, MultiSampleEdge> {
/**
* Edge factory that creates non-reference multiplicity 1 edges
* Edge factory that encapsulates the numPruningSamples assembly parameter
*/
private static class MyEdgeFactory implements EdgeFactory<MultiDeBruijnVertex, MultiSampleEdge> {
@Override
public MultiSampleEdge createEdge(MultiDeBruijnVertex sourceVertex, MultiDeBruijnVertex targetVertex) {
return new MultiSampleEdge(false, 1);
final int numPruningSamples;
public MyEdgeFactory(int numPruningSamples) {
this.numPruningSamples = numPruningSamples;
}
@Override
public MultiSampleEdge createEdge(final MultiDeBruijnVertex sourceVertex, final MultiDeBruijnVertex targetVertex) {
return new MultiSampleEdge(false, 1, numPruningSamples);
}
public MultiSampleEdge createEdge(final boolean isRef, final int multiplicity) {
return new MultiSampleEdge(isRef, multiplicity, numPruningSamples);
}
}
private final static Logger logger = Logger.getLogger(ReadThreadingGraph.class);
@ -88,7 +99,7 @@ public class ReadThreadingGraph extends BaseGraph<MultiDeBruijnVertex, MultiSamp
/**
* Sequences added for read threading before we've actually built the graph
*/
private final Map<String, List<SequenceForKmers>> pending = new LinkedHashMap<String, List<SequenceForKmers>>();
private final Map<String, List<SequenceForKmers>> pending = new LinkedHashMap<>();
/**
* A set of non-unique kmers that cannot be used as merge points in the graph
@ -117,19 +128,19 @@ public class ReadThreadingGraph extends BaseGraph<MultiDeBruijnVertex, MultiSamp
private boolean alreadyBuilt;
public ReadThreadingGraph() {
this(25, false, (byte)6);
this(25, false, (byte)6, 1);
}
public ReadThreadingGraph(final int kmerSize) {
this(kmerSize, false, (byte)6);
this(kmerSize, false, (byte)6, 1);
}
/**
* Create a new ReadThreadingAssembler using kmerSize for matching
* @param kmerSize must be >= 1
*/
protected ReadThreadingGraph(final int kmerSize, final boolean debugGraphTransformations, final byte minBaseQualityToUseInAssembly) {
super(kmerSize, new MyEdgeFactory());
protected ReadThreadingGraph(final int kmerSize, final boolean debugGraphTransformations, final byte minBaseQualityToUseInAssembly, final int numPruningSamples) {
super(kmerSize, new MyEdgeFactory(numPruningSamples));
if ( kmerSize < 1 ) throw new IllegalArgumentException("bad minkKmerSize " + kmerSize);
this.kmerSize = kmerSize;
@ -324,7 +335,7 @@ public class ReadThreadingGraph extends BaseGraph<MultiDeBruijnVertex, MultiSamp
final int altIndexToMerge = Math.max(danglingTailMergeResult.cigar.getReadLength() - matchingSuffix - 1, 0);
final int refIndexToMerge = lastRefIndex - matchingSuffix + 1;
addEdge(danglingTailMergeResult.danglingPath.get(altIndexToMerge), danglingTailMergeResult.referencePath.get(refIndexToMerge), new MultiSampleEdge(false, 1));
addEdge(danglingTailMergeResult.danglingPath.get(altIndexToMerge), danglingTailMergeResult.referencePath.get(refIndexToMerge), ((MyEdgeFactory)getEdgeFactory()).createEdge(false, 1));
return 1;
}
@ -708,7 +719,7 @@ public class ReadThreadingGraph extends BaseGraph<MultiDeBruijnVertex, MultiSamp
// either use our unique merge vertex, or create a new one in the chain
final MultiDeBruijnVertex nextVertex = uniqueMergeVertex == null ? createVertex(kmer) : uniqueMergeVertex;
addEdge(prevVertex, nextVertex, new MultiSampleEdge(isRef, count));
addEdge(prevVertex, nextVertex, ((MyEdgeFactory)getEdgeFactory()).createEdge(isRef, count));
return nextVertex;
}

View File

@ -54,19 +54,29 @@ import org.testng.Assert;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.*;
public class MultiSampleEdgeUnitTest extends BaseTest {
private class MultiplicityTestProvider {
final List<Integer> countsPerSample;
final int numSamplesPruning;
public MultiplicityTestProvider(final List<Integer> countsPerSample, final int numSamplesPruning) {
this.countsPerSample = countsPerSample;
this.numSamplesPruning = numSamplesPruning;
}
}
@DataProvider(name = "MultiplicityData")
public Object[][] makeMultiplicityData() {
List<Object[]> tests = new ArrayList<Object[]>();
List<Object[]> tests = new ArrayList<>();
final List<Integer> countsPerSample = Arrays.asList(0, 1, 2, 3, 4, 5);
for ( final int nSamples : Arrays.asList(1, 2, 3, 4, 5)) {
for ( final List<Integer> perm : Utils.makePermutations(countsPerSample, nSamples, false) ) {
tests.add(new Object[]{perm});
for ( final int numSamplesPruning : Arrays.asList(1, 2, 3) ) {
for ( final int nSamples : Arrays.asList(1, 2, 3, 4, 5)) {
for ( final List<Integer> perm : Utils.makePermutations(countsPerSample, nSamples, false) ) {
tests.add(new Object[]{new MultiplicityTestProvider(perm, numSamplesPruning)});
}
}
}
@ -77,15 +87,15 @@ public class MultiSampleEdgeUnitTest extends BaseTest {
* Example testng test using MyDataProvider
*/
@Test(dataProvider = "MultiplicityData")
public void testMultiplicity(final List<Integer> countsPerSample) {
final MultiSampleEdge edge = new MultiSampleEdge(false, 0);
public void testMultiplicity(final MultiplicityTestProvider cfg) {
final MultiSampleEdge edge = new MultiSampleEdge(false, 0, cfg.numSamplesPruning);
Assert.assertEquals(edge.getMultiplicity(), 0);
Assert.assertEquals(edge.getPruningMultiplicity(), 0);
int total = 0;
for ( int i = 0; i < countsPerSample.size(); i++ ) {
for ( int i = 0; i < cfg.countsPerSample.size(); i++ ) {
int countForSample = 0;
for ( int count = 0; count < countsPerSample.get(i); count++ ) {
for ( int count = 0; count < cfg.countsPerSample.get(i); count++ ) {
edge.incMultiplicity(1);
total++;
countForSample++;
@ -95,9 +105,11 @@ public class MultiSampleEdgeUnitTest extends BaseTest {
edge.flushSingleSampleMultiplicity();
}
final int max = MathUtils.arrayMax(ArrayUtils.toPrimitive(countsPerSample.toArray(new Integer[countsPerSample.size()])));
ArrayList<Integer> counts = new ArrayList<>(cfg.countsPerSample);
counts.add(0);
Collections.sort(counts);
final int prune = counts.get(Math.max(counts.size() - cfg.numSamplesPruning, 0));
Assert.assertEquals(edge.getMultiplicity(), total);
Assert.assertEquals(edge.getPruningMultiplicity(), max);
Assert.assertEquals(edge.getMaxSingleSampleMultiplicity(), max);
Assert.assertEquals(edge.getPruningMultiplicity(), prune);
}
}