Cap the computational cost of the kmer based error correction in the DeBruijnGraph

-- Simply don't do more than MAX_CORRECTION_OPS_TO_ALLOW = 5000 * 1000 operations to correct a graph.  If the number of ops would exceed this threshold, the original graph is used.
-- Overall the algorithm is just extremely computationally expensive, and actually doesn't implement the correct correction.  So we live with these limitations while we continue to explore better algorithms
-- Updating MD5s to reflect changes in assembly algorithms
This commit is contained in:
Mark DePristo 2013-03-20 22:40:10 -04:00
parent d94b3f85bc
commit aa7f172b18
4 changed files with 127 additions and 37 deletions

View File

@ -90,7 +90,8 @@ public class DeBruijnGraph extends BaseGraph<DeBruijnVertex> {
/**
* Error correct the kmers in this graph, returning a new graph built from those error corrected kmers
* @return a freshly allocated graph
* @return an error corrected version of this (freshly allocated graph) or simply this graph if for some reason
* we cannot actually do the error correction
*/
protected DeBruijnGraph errorCorrect() {
final KMerErrorCorrector corrector = new KMerErrorCorrector(getKmerSize(), 1, 1, 5); // TODO -- should be static variables
@ -101,19 +102,23 @@ public class DeBruijnGraph extends BaseGraph<DeBruijnVertex> {
corrector.addKmer(kmer, e.isRef() ? 1000 : e.getMultiplicity());
}
}
corrector.computeErrorCorrectionMap();
final DeBruijnGraph correctedGraph = new DeBruijnGraph(getKmerSize());
if ( corrector.computeErrorCorrectionMap() ) {
final DeBruijnGraph correctedGraph = new DeBruijnGraph(getKmerSize());
for( final BaseEdge e : edgeSet() ) {
final byte[] source = corrector.getErrorCorrectedKmer(getEdgeSource(e).getSequence());
final byte[] target = corrector.getErrorCorrectedKmer(getEdgeTarget(e).getSequence());
if ( source != null && target != null ) {
correctedGraph.addKmersToGraph(source, target, e.isRef(), e.getMultiplicity());
for( final BaseEdge e : edgeSet() ) {
final byte[] source = corrector.getErrorCorrectedKmer(getEdgeSource(e).getSequence());
final byte[] target = corrector.getErrorCorrectedKmer(getEdgeTarget(e).getSequence());
if ( source != null && target != null ) {
correctedGraph.addKmersToGraph(source, target, e.isRef(), e.getMultiplicity());
}
}
}
return correctedGraph;
return correctedGraph;
} else {
// the error correction wasn't possible, simply return this graph
return this;
}
}
/**

View File

@ -46,6 +46,8 @@
package org.broadinstitute.sting.gatk.walkers.haplotypecaller;
import org.apache.log4j.Logger;
import java.util.*;
/**
@ -69,15 +71,54 @@ import java.util.*;
* TODO -- be added to hashmaps (more specifically, those don't implement .equals). A more efficient
* TODO -- version would use the byte[] directly
*
* TODO -- this is just not the right way to implement error correction in the graph. Basically, the
* right way to think about this is error correcting reads:
*
* *
* ACTGAT
* ACT
* CTG
* TGA
* GAT
*
* Now suppose the G is an error. What you are doing is asking for each 3mer in the read whether it's high quality
* or not. Suppose the answer is
*
* *
* ACTGAT
* ACT -- yes
* CTG -- no [CTG is unusual]
* TGA -- no [TGA is unusual]
* GAT -- yes [maybe GAT is just common, even though it's an error]
*
* As we do this process it's clear how we can figure out which positions in the read likely harbor errors, and
* then go search around those bases in the read in an attempt to fix the read. We don't have to compute for
* every bad kmer its best match, as that's just not the problem we are looking to solve. We are actually
* looking for a change to a read such that all spanning kmers are well-supported. This class is being disabled
* until we implement this change.
*
*
* User: depristo
* Date: 3/8/13
* Time: 1:16 PM
*/
public class KMerErrorCorrector {
private final static Logger logger = Logger.getLogger(KMerErrorCorrector.class);
/**
* The maximum number of bad kmer -> good kmer correction operations we'll consider doing before
* aborting for efficiency reasons. Basically, the current algorithm sucks, and is O(n^2), and
* so we cannot simply error correct 10K bad kmers against a db of 100K kmers if we ever want
* to finish running in a reasonable amount of time. This isn't worth fixing because fundamentally
* the entire error correction algorithm is just not right (i.e., it's correct but not ideal conceptually
* so we'll just fix the conceptual problem rather than the performance issue).
*/
private final static int MAX_CORRECTION_OPS_TO_ALLOW = 5000 * 1000;
/**
* A map of for each kmer to its num occurrences in addKmers
*/
Map<String, Integer> countsByKMer = new HashMap<String, Integer>();
Map<String, CountedKmer> countsByKMer = new HashMap<String, CountedKmer>();
/**
* A map from raw kmer -> error corrected kmer
@ -154,35 +195,45 @@ public class KMerErrorCorrector {
* Indicate that no more kmers will be added to the kmer error corrector, so that the
* error correction data structure should be computed from the added kmers. Enables calls
* to getErrorCorrectedKmer, and disables calls to addKmer.
*
* @return true if the error correction map could actually be computed, false if for any reason
* (efficiency, memory, we're out to lunch) a correction map couldn't be created.
*/
public void computeErrorCorrectionMap() {
public boolean computeErrorCorrectionMap() {
if ( countsByKMer == null )
throw new IllegalStateException("computeErrorCorrectionMap can only be called once");
final LinkedList<String> needsCorrection = new LinkedList<String>();
final LinkedList<String> goodKmers = new LinkedList<String>();
final LinkedList<CountedKmer> needsCorrection = new LinkedList<CountedKmer>();
final List<CountedKmer> goodKmers = new ArrayList<CountedKmer>(countsByKMer.size());
rawToErrorCorrectedMap = new HashMap<String, String>();
for ( Map.Entry<String, Integer> kmerCounts: countsByKMer.entrySet() ) {
if ( kmerCounts.getValue() <= maxCountToCorrect )
needsCorrection.add(kmerCounts.getKey());
rawToErrorCorrectedMap = new HashMap<String, String>(countsByKMer.size());
for ( final CountedKmer countedKmer: countsByKMer.values() ) {
if ( countedKmer.count <= maxCountToCorrect )
needsCorrection.add(countedKmer);
else {
// todo -- optimization could make not in map mean ==
rawToErrorCorrectedMap.put(kmerCounts.getKey(), kmerCounts.getKey());
rawToErrorCorrectedMap.put(countedKmer.kmer, countedKmer.kmer);
// only allow corrections to kmers with at least this count
if ( kmerCounts.getValue() >= minCountOfKmerToBeCorrection )
goodKmers.add(kmerCounts.getKey());
if ( countedKmer.count >= minCountOfKmerToBeCorrection )
goodKmers.add(countedKmer);
}
}
for ( final String toCorrect : needsCorrection ) {
final String corrected = findClosestKMer(toCorrect, goodKmers);
rawToErrorCorrectedMap.put(toCorrect, corrected);
}
// cleanup memory -- we don't need the counts for each kmer any longer
countsByKMer = null;
if ( goodKmers.size() * needsCorrection.size() > MAX_CORRECTION_OPS_TO_ALLOW )
return false;
else {
Collections.sort(goodKmers);
for ( final CountedKmer toCorrect : needsCorrection ) {
final String corrected = findClosestKMer(toCorrect, goodKmers);
rawToErrorCorrectedMap.put(toCorrect.kmer, corrected);
}
return true;
}
}
protected void addKmer(final String rawKmer, final int kmerCount) {
@ -190,30 +241,42 @@ public class KMerErrorCorrector {
if ( kmerCount < 0 ) throw new IllegalArgumentException("bad kmerCount " + kmerCount);
if ( countsByKMer == null ) throw new IllegalStateException("Cannot add kmers to an already finalized error corrector");
final Integer countFromMap = countsByKMer.get(rawKmer);
final int count = countFromMap == null ? 0 : countFromMap;
countsByKMer.put(rawKmer, count + kmerCount);
CountedKmer countFromMap = countsByKMer.get(rawKmer);
if ( countFromMap == null ) {
countFromMap = new CountedKmer(rawKmer);
countsByKMer.put(rawKmer, countFromMap);
}
countFromMap.count += kmerCount;
}
protected String findClosestKMer(final String kmer, final Collection<String> goodKmers) {
protected String findClosestKMer(final CountedKmer kmer, final Collection<CountedKmer> goodKmers) {
String bestMatch = null;
int minMismatches = Integer.MAX_VALUE;
for ( final String goodKmer : goodKmers ) {
final int mismatches = countMismatches(kmer, goodKmer);
for ( final CountedKmer goodKmer : goodKmers ) {
final int mismatches = countMismatches(kmer.kmer, goodKmer.kmer, minMismatches);
if ( mismatches < minMismatches ) {
minMismatches = mismatches;
bestMatch = goodKmer;
bestMatch = goodKmer.kmer;
}
// if we find an edit-distance 1 result, abort early, as we know there can be no edit distance 0 results
if ( mismatches == 1 )
break;
}
return minMismatches > maxMismatchesToCorrect ? null : bestMatch;
}
protected int countMismatches(final String one, final String two) {
protected int countMismatches(final String one, final String two, final int currentBest) {
int mismatches = 0;
for ( int i = 0; i < one.length(); i++ )
for ( int i = 0; i < one.length(); i++ ) {
mismatches += one.charAt(i) == two.charAt(i) ? 0 : 1;
if ( mismatches > currentBest )
break;
if ( mismatches > maxMismatchesToCorrect )
return Integer.MAX_VALUE;
}
return mismatches;
}
@ -238,4 +301,26 @@ public class KMerErrorCorrector {
b.append("\n}");
return b.toString();
}
/**
 * A kmer paired with the number of times it was observed.
 *
 * Sorts in DESCENDING count order (most frequent kmer first), so that
 * findClosestKMer examines the most common candidate corrections first.
 */
private static class CountedKmer implements Comparable<CountedKmer> {
    /** The kmer sequence itself; immutable once created. */
    final String kmer;
    /** Number of observations of this kmer; incremented as kmers are added. */
    int count;

    private CountedKmer(final String kmer) {
        this.kmer = kmer;
    }

    @Override
    public String toString() {
        return "CountedKmer{" +
                "kmer='" + kmer + '\'' +
                ", count=" + count +
                '}';
    }

    @Override
    public int compareTo(final CountedKmer o) {
        // Integer.compare avoids the integer-overflow bug of the naive
        // "o.count - count" subtraction idiom while keeping the same
        // descending-by-count ordering.
        return Integer.compare(o.count, count);
    }
}
}

View File

@ -63,7 +63,7 @@ public class HaplotypeCallerComplexAndSymbolicVariantsIntegrationTest extends Wa
@Test
public void testHaplotypeCallerMultiSampleComplex() {
HCTestComplexVariants(privateTestDir + "AFR.complex.variants.bam", "", "2b9355ab532314bce157c918c7606409");
HCTestComplexVariants(privateTestDir + "AFR.complex.variants.bam", "", "91f4880910e436bf5aca0abbebd58948");
}
private void HCTestSymbolicVariants(String bam, String args, String md5) {

View File

@ -85,7 +85,7 @@ public class HaplotypeCallerIntegrationTest extends WalkerTest {
@Test
public void testHaplotypeCallerMultiSampleGGA() {
HCTest(CEUTRIO_BAM, "--max_alternate_alleles 3 -gt_mode GENOTYPE_GIVEN_ALLELES -out_mode EMIT_ALL_SITES -alleles " + validationDataLocation + "combined.phase1.chr20.raw.indels.sites.vcf",
"9f9062a6eb93f984658492400102b0c7");
"d41a886f69a67e01af2ba1a6b4a681d9");
}
private void HCTestIndelQualityScores(String bam, String args, String md5) {