Fast algorithm for determining which kmers are good in a read

-- old algorithm was O(kmerSize * readLen) for each read.  New algorithm is O(readLen)
-- Added real unit tests for addReadKmersToGraph, which adds the kmers from reads to the graph.  Using a builder is great because we can create a MockBuilder that captures all of the calls, and then verify that all of the added kmers are the ones we'd expect.
This commit is contained in:
Mark DePristo 2013-04-08 17:19:08 -04:00
parent bf42be44fc
commit fb86887bf2
5 changed files with 153 additions and 33 deletions

View File

@ -55,11 +55,11 @@ import org.apache.commons.lang.ArrayUtils;
import org.apache.log4j.Logger;
import org.broadinstitute.sting.gatk.walkers.haplotypecaller.graphs.*;
import org.broadinstitute.sting.utils.GenomeLoc;
import org.broadinstitute.sting.utils.exceptions.UserException;
import org.broadinstitute.sting.utils.haplotype.Haplotype;
import org.broadinstitute.sting.utils.MathUtils;
import org.broadinstitute.sting.utils.SWPairwiseAlignment;
import org.broadinstitute.sting.utils.activeregion.ActiveRegion;
import org.broadinstitute.sting.utils.exceptions.UserException;
import org.broadinstitute.sting.utils.haplotype.Haplotype;
import org.broadinstitute.sting.utils.sam.AlignmentUtils;
import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
import org.broadinstitute.sting.utils.sam.ReadUtils;
@ -283,30 +283,27 @@ public class DeBruijnAssembler extends LocalAssemblyEngine {
final byte[] sequence = read.getReadBases();
final byte[] qualities = read.getBaseQualities();
final byte[] reducedReadCounts = read.getReducedReadCounts(); // will be null if read is not reduced
if( sequence.length > kmerLength + KMER_OVERLAP ) {
final int kmersInSequence = sequence.length - kmerLength + 1;
for( int iii = 0; iii < kmersInSequence - 1; iii++ ) {
// TODO -- this is quite expense as it does O(kmerLength^2 work) per read
// if the qualities of all the bases in the kmers are high enough
boolean badKmer = false;
for( int jjj = iii; jjj < iii + kmerLength + 1; jjj++) {
if( qualities[jjj] < minBaseQualityToUseInAssembly ) {
badKmer = true;
break;
}
}
if( !badKmer ) {
if ( sequence.length > kmerLength + KMER_OVERLAP ) {
int lastGood = -1; // the index of the last good base we've seen
for( int end = 0; end < sequence.length; end++ ) {
if ( qualities[end] < minBaseQualityToUseInAssembly ) {
lastGood = -1; // reset the last good base
} else if ( lastGood == -1 ) {
lastGood = end; // we're at a good base, the last good one is us
} else if ( end - kmerLength >= lastGood ) {
// end - kmerLength (the start) is at or after the lastGood base, so every base in [start, end] is good and that kmer is good
final int start = end - kmerLength;
// how many observations of this kmer have we seen? A normal read counts for 1, but
// a reduced read might imply a higher multiplicity for the edge
int countNumber = 1;
if( read.isReducedRead() ) {
if ( read.isReducedRead() ) {
// compute mean number of reduced read counts in current kmer span
// precise rounding can make a difference with low consensus counts
// TODO -- optimization: should extend arrayMax function to take start stop values
countNumber = MathUtils.arrayMax(Arrays.copyOfRange(reducedReadCounts, iii, iii + kmerLength));
countNumber = MathUtils.arrayMax(Arrays.copyOfRange(reducedReadCounts, start, end));
}
builder.addKmerPairFromSeqToGraph(sequence, iii, countNumber);
builder.addKmerPairFromSeqToGraph(sequence, start, countNumber);
}
}
}

View File

@ -46,32 +46,56 @@
package org.broadinstitute.sting.gatk.walkers.haplotypecaller;
import org.apache.log4j.Logger;
import org.broadinstitute.sting.gatk.walkers.haplotypecaller.graphs.DeBruijnGraph;
/**
* Fast approach to building a DeBruijnGraph
*
* Follows the model:
*
* for each X that has bases for the final graph:
* addKmer pair (single kmer with kmer size + 1 spanning the pair)
*
* flushKmersToGraph
*
* User: depristo
* Date: 4/7/13
* Time: 4:14 PM
*/
public class DeBruijnGraphBuilder {
private final static Logger logger = Logger.getLogger(DeBruijnGraphBuilder.class);
/** The size of the kmer graph we want to build */
private final int kmerSize;
private final DeBruijnGraph graph;
private final KmerCounter counter;
/** The graph we're going to add kmers to */
private final DeBruijnGraph graph;
/** keeps counts of all kmer pairs added since the last flush */
private final KMerCounter counter;
/**
* Create a new builder that will write out kmers to graph
*
* @param graph a non-null graph that can contain already added kmers
*/
public DeBruijnGraphBuilder(final DeBruijnGraph graph) {
if ( graph == null ) throw new IllegalArgumentException("Graph cannot be null");
this.kmerSize = graph.getKmerSize();
this.graph = graph;
this.counter = new KmerCounter(kmerSize + 1);
this.counter = new KMerCounter(kmerSize + 1);
}
/**
 * Accessor for the graph this builder writes kmers into
 *
 * @return the graph handed to the constructor, never null
 */
public DeBruijnGraph getGraph() {
    return this.graph;
}
/**
 * Accessor for the kmer size of the graph being built
 *
 * @return the kmer size, a positive integer
 */
public int getKmerSize() {
    return this.kmerSize;
}
@ -93,14 +117,30 @@ public class DeBruijnGraphBuilder {
addKmerPair(kmerPair, multiplicity);
}
/**
 * Add a single kmer pair to this builder
 *
 * The pair is only counted here; the corresponding edge is written to the graph
 * when flushKmersToGraph is called.
 *
 * @param kmerPair a single kmer of exactly kmerSize + 1 bases.  The kmer of kmerSize bases starting at
 *                 offset 0 and the kmer of kmerSize bases starting at offset 1 are the two kmers that
 *                 will have an edge added between them
 * @param multiplicity the desired multiplicity of this edge
 * @throws IllegalArgumentException if kmerPair does not have exactly kmerSize + 1 bases
 */
public void addKmerPair(final Kmer kmerPair, final int multiplicity) {
    final int requiredLength = kmerSize + 1;
    // the sum is computed before concatenation; writing "..." + kmerSize + 1 would append the digit 1
    // to the text of kmerSize (e.g. "31" for kmerSize 3) rather than report the real required length
    if ( kmerPair.length() != requiredLength )
        throw new IllegalArgumentException("kmer pair must be of length kmerSize + 1 = " + requiredLength + " but got " + kmerPair.length());
    counter.addKmer(kmerPair, multiplicity);
}
/**
* Flushes the currently added kmers to the graph
*
* After this function is called the builder is reset to an empty state
*
* This flushing is expensive, so many kmers should be added to the builder before flushing. The most
* efficient workflow is to add all of the kmers of a particular class (all ref bases, or all read bases)
* then and do one flush when completed
*
* @param addRefEdges should the kmers present in the builder be added to the graph with isRef = true for the edges?
*/
public void flushKmersToGraph(final boolean addRefEdges) {
for ( final KmerCounter.CountedKmer countedKmer : counter.getCountedKmers() ) {
for ( final KMerCounter.CountedKmer countedKmer : counter.getCountedKmers() ) {
final byte[] first = countedKmer.getKmer().subKmer(0, kmerSize).bases();
final byte[] second = countedKmer.getKmer().subKmer(1, kmerSize).bases();
graph.addKmersToGraph(first, second, addRefEdges, countedKmer.getCount());

View File

@ -59,8 +59,8 @@ import java.util.Map;
* Date: 3/8/13
* Time: 1:16 PM
*/
public class KmerCounter {
//private final static Logger logger = Logger.getLogger(KmerCounter.class);
public class KMerCounter {
//private final static Logger logger = Logger.getLogger(KMerCounter.class);
/**
* A map of for each kmer to its num occurrences in addKmers
@ -73,7 +73,7 @@ public class KmerCounter {
*
* @param kmerLength the length of kmers we'll be counting to error correct, must be >= 1
*/
public KmerCounter(final int kmerLength) {
public KMerCounter(final int kmerLength) {
if ( kmerLength < 1 ) throw new IllegalArgumentException("kmerLength must be > 0 but got " + kmerLength);
this.kmerLength = kmerLength;
}
@ -89,10 +89,17 @@ public class KmerCounter {
return counted == null ? 0 : counted.count;
}
/**
 * Provides all of the kmers counted so far, in no particular order
 *
 * @return a non-null collection of the counted kmers
 */
public Collection<CountedKmer> getCountedKmers() {
    final Collection<CountedKmer> counted = countsByKMer.values();
    return counted;
}
/**
 * Discard every count accumulated so far, leaving this counter empty
 */
public void clear() {
    this.countsByKMer.clear();
}
@ -117,7 +124,7 @@ public class KmerCounter {
@Override
public String toString() {
final StringBuilder b = new StringBuilder("KmerCounter{");
final StringBuilder b = new StringBuilder("KMerCounter{");
b.append("counting ").append(countsByKMer.size()).append(" distinct kmers");
b.append("\n}");
return b.toString();

View File

@ -55,13 +55,16 @@ package org.broadinstitute.sting.gatk.walkers.haplotypecaller;
import net.sf.samtools.Cigar;
import net.sf.samtools.CigarElement;
import net.sf.samtools.CigarOperator;
import net.sf.samtools.SAMFileHeader;
import org.broadinstitute.sting.BaseTest;
import org.broadinstitute.sting.gatk.walkers.haplotypecaller.graphs.DeBruijnGraph;
import org.broadinstitute.sting.utils.haplotype.Haplotype;
import org.broadinstitute.sting.utils.Utils;
import org.broadinstitute.sting.utils.sam.AlignmentUtils;
import org.broadinstitute.sting.utils.sam.ArtificialSAMUtils;
import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
import org.testng.Assert;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import java.util.*;
@ -122,4 +125,77 @@ public class DeBruijnAssemblerUnitTest extends BaseTest {
}
}
/**
 * Test double for DeBruijnGraphBuilder that records every kmer pair it is asked to add
 * instead of counting it, so a test can inspect exactly which pairs were requested.
 */
private static class MockBuilder extends DeBruijnGraphBuilder {
    /** every kmer pair passed to addKmerPair, in call order (multiplicity is not recorded) */
    public final List<Kmer> addedPairs = new ArrayList<Kmer>();

    /**
     * @param kmerSize the kmer size of the throw-away graph backing this builder
     */
    private MockBuilder(final int kmerSize) {
        super(new DeBruijnGraph(kmerSize));
    }

    /**
     * Record the pair rather than forwarding it to the real counter
     */
    @Override
    public void addKmerPair(final Kmer kmerPair, final int multiplicity) {
        // separator added so the label and the kmer do not run together in the log
        logger.info("addKmerPair " + kmerPair);
        addedPairs.add(kmerPair);
    }

    @Override
    public void flushKmersToGraph(final boolean addRefEdges) {
        // intentionally empty: nothing is accumulated in the real counter, so there is nothing to flush
    }
}
/**
 * Builds the test cases for testAddReadKmersToGraph
 *
 * Each case is { read bases, kmer size, list of bad-quality sites }, generated for
 * kmer sizes 3, 4 and 5 and for 0, 1 and 2 bad sites drawn from all positions of the read
 * via Utils.makePermutations.
 */
@DataProvider(name = "AddReadKmersToGraph")
public Object[][] makeAddReadKmersToGraphData() {
    final String bases = "ACGTAACCGGTTAAACCCGGGTTT";
    final int readLen = bases.length();

    // every position in the read is a candidate bad-quality site
    final List<Integer> candidateSites = new ArrayList<Integer>(readLen);
    for ( int site = 0; site < readLen; site++ ) {
        candidateSites.add(site);
    }

    final List<Object[]> cases = new ArrayList<Object[]>();
    for ( final int kmerSize : Arrays.asList(3, 4, 5) ) {
        for ( final int nBadQuals : Arrays.asList(0, 1, 2) ) {
            for ( final List<Integer> badStarts : Utils.makePermutations(candidateSites, nBadQuals, false) ) {
                cases.add(new Object[]{bases, kmerSize, badStarts});
            }
        }
    }

    return cases.toArray(new Object[cases.size()][]);
}
/**
 * Checks that addReadKmersToGraph hands the builder only the kmer pairs whose bases all
 * meet the assembler's minimum base quality
 *
 * The expectation is computed by brute force: every window of kmerSize + 1 bases is
 * examined base by base, independently of the algorithm under test.
 *
 * @param bases the read bases
 * @param kmerSize the kmer size of the graph being built
 * @param badQualsSites read offsets whose base quality is forced to 0
 */
@Test(dataProvider = "AddReadKmersToGraph")
public void testAddReadKmersToGraph(final String bases, final int kmerSize, final List<Integer> badQualsSites) {
    final int readLen = bases.length();
    final int pairLen = kmerSize + 1; // +1 is for pairing
    final DeBruijnAssembler assembler = new DeBruijnAssembler();
    final MockBuilder builder = new MockBuilder(kmerSize);

    // start from Utils.dupBytes((byte)20, readLen) and zero out the requested sites
    final byte[] quals = Utils.dupBytes((byte)20, readLen);
    for ( final int badSite : badQualsSites ) {
        quals[badSite] = 0;
    }

    final SAMFileHeader header = ArtificialSAMUtils.createArtificialSamHeader(1, 1, 1000);
    final GATKSAMRecord read = ArtificialSAMUtils.createArtificialRead(header, "myRead", 0, 1, readLen);
    read.setReadBases(bases.getBytes());
    read.setBaseQualities(quals);

    // brute-force expectation: a window counts only if it fits in the read and all its quals pass
    final Set<String> expectedBases = new HashSet<String>();
    int nExpectedStarts = 0;
    for ( int start = 0; start + pairLen <= readLen; start++ ) {
        boolean allGood = true;
        for ( int offset = 0; offset < pairLen && allGood; offset++ ) {
            allGood = quals[start + offset] >= assembler.getMinBaseQualityToUseInAssembly();
        }
        if ( allGood ) {
            nExpectedStarts++;
            expectedBases.add(bases.substring(start, start + pairLen));
        }
    }

    assembler.addReadKmersToGraph(builder, Arrays.asList(read));

    Assert.assertEquals(builder.addedPairs.size(), nExpectedStarts);
    for ( final Kmer addedKmer : builder.addedPairs ) {
        final String addedBases = new String(addedKmer.bases());
        Assert.assertTrue(expectedBases.contains(addedBases), "Couldn't find kmer " + addedKmer + " among all expected kmers " + expectedBases);
    }
}
}

View File

@ -50,10 +50,10 @@ import org.broadinstitute.sting.BaseTest;
import org.testng.Assert;
import org.testng.annotations.Test;
public class KmerCounterUnitTest extends BaseTest {
public class KMerCounterCaseFixUnitTest extends BaseTest {
@Test
public void testMyData() {
final KmerCounter counter = new KmerCounter(3);
final KMerCounter counter = new KMerCounter(3);
Assert.assertNotNull(counter.toString());
@ -78,7 +78,7 @@ public class KmerCounterUnitTest extends BaseTest {
Assert.assertNotNull(counter.toString());
}
private void testCounting(final KmerCounter counter, final String in, final int expectedCount) {
private void testCounting(final KMerCounter counter, final String in, final int expectedCount) {
Assert.assertEquals(counter.getKmerCount(new Kmer(in)), expectedCount);
}
}