From aa7f172b18ff5ad8e5e881dce451c30ad362d61a Mon Sep 17 00:00:00 2001 From: Mark DePristo Date: Wed, 20 Mar 2013 22:40:10 -0400 Subject: [PATCH] Cap the computational cost of the kmer based error correction in the DeBruijnGraph -- Simply don't do more than MAX_CORRECTION_OPS_TO_ALLOW = 5000 * 1000 operations to correct a graph. If the number of ops would exceed this threshold, the original graph is used. -- Overall the algorithm is just extremely computational expensive, and actually doesn't implement the correct correction. So we live with this limitations while we continue to explore better algorithms -- Updating MD5s to reflect changes in assembly algorithms --- .../haplotypecaller/DeBruijnGraph.java | 25 ++-- .../haplotypecaller/KMerErrorCorrector.java | 135 ++++++++++++++---- ...lexAndSymbolicVariantsIntegrationTest.java | 2 +- .../HaplotypeCallerIntegrationTest.java | 2 +- 4 files changed, 127 insertions(+), 37 deletions(-) diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/DeBruijnGraph.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/DeBruijnGraph.java index d9df03539..0e20c311b 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/DeBruijnGraph.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/DeBruijnGraph.java @@ -90,7 +90,8 @@ public class DeBruijnGraph extends BaseGraph { /** * Error correct the kmers in this graph, returning a new graph built from those error corrected kmers - * @return a freshly allocated graph + * @return an error corrected version of this (freshly allocated graph) or simply this graph if for some reason + * we cannot actually do the error correction */ protected DeBruijnGraph errorCorrect() { final KMerErrorCorrector corrector = new KMerErrorCorrector(getKmerSize(), 1, 1, 5); // TODO -- should be static variables @@ -101,19 +102,23 @@ public class DeBruijnGraph extends BaseGraph { corrector.addKmer(kmer, e.isRef() ? 1000 : e.getMultiplicity()); } } - corrector.computeErrorCorrectionMap(); - final DeBruijnGraph correctedGraph = new DeBruijnGraph(getKmerSize()); + if ( corrector.computeErrorCorrectionMap() ) { + final DeBruijnGraph correctedGraph = new DeBruijnGraph(getKmerSize()); - for( final BaseEdge e : edgeSet() ) { - final byte[] source = corrector.getErrorCorrectedKmer(getEdgeSource(e).getSequence()); - final byte[] target = corrector.getErrorCorrectedKmer(getEdgeTarget(e).getSequence()); - if ( source != null && target != null ) { - correctedGraph.addKmersToGraph(source, target, e.isRef(), e.getMultiplicity()); + for( final BaseEdge e : edgeSet() ) { + final byte[] source = corrector.getErrorCorrectedKmer(getEdgeSource(e).getSequence()); + final byte[] target = corrector.getErrorCorrectedKmer(getEdgeTarget(e).getSequence()); + if ( source != null && target != null ) { + correctedGraph.addKmersToGraph(source, target, e.isRef(), e.getMultiplicity()); + } } - } - return correctedGraph; + return correctedGraph; + } else { + // the error correction wasn't possible, simply return this graph + return this; + } } /** diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/KMerErrorCorrector.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/KMerErrorCorrector.java index 05bd1b881..b051e5411 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/KMerErrorCorrector.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/KMerErrorCorrector.java @@ -46,6 +46,8 @@ package org.broadinstitute.sting.gatk.walkers.haplotypecaller; +import org.apache.log4j.Logger; + import java.util.*; /** @@ -69,15 +71,54 @@ import java.util.*; * TODO -- be added to hashmaps (more specifically, those don't implement .equals). A more efficient * TODO -- version would use the byte[] directly * + * TODO -- this is just not the right way to implement error correction in the graph. Basically, the + * right way to think about this is error correcting reads: + * + * * + * ACTGAT + * ACT + * CTG + * TGA + * GAT + * + * Now suppose the G is an error. What you are doing is asking for each 3mer in the read whether it's high quality + * or not. Suppose the answer is + * + * * + * ACTGAT + * ACT -- yes + * CTG -- no [CTG is unusual] + * TGA -- no [TGA is unusual] + * GAT -- yes [maybe GAT is just common, even through its an error] + * + * As we do this process it's clear how we can figure out which positions in the read likely harbor errors, and + * then go search around those bases in the read in an attempt to fix the read. We don't have to compute for + * every bad kmer it's best match, as that's just not the problem we are thinking looking to solve. We are actually + * looking for a change to a read such that all spanning kmers are well-supported. This class is being disabled + * until we figure implement this change. + * + * * User: depristo * Date: 3/8/13 * Time: 1:16 PM */ public class KMerErrorCorrector { + private final static Logger logger = Logger.getLogger(KMerErrorCorrector.class); + + /** + * The maximum number of bad kmer -> good kmer correction operations we'll consider doing before + * aborting for efficiency reasons. Basically, the current algorithm sucks, and is O(n^2), and + * so we cannot simply error correct 10K bad kmers against a db of 100K kmers if we ever want + * to finish running in a reasonable amount of time. This isn't worth fixing because fundamentally + * the entire error correction algorithm is just not right (i.e., it's correct but not ideal conceptually + * so we'll just fix the conceptual problem than the performance issue). + */ + private final static int MAX_CORRECTION_OPS_TO_ALLOW = 5000 * 1000; + /** * A map of for each kmer to its num occurrences in addKmers */ - Map countsByKMer = new HashMap(); + Map countsByKMer = new HashMap(); /** * A map from raw kmer -> error corrected kmer @@ -154,35 +195,45 @@ public class KMerErrorCorrector { * Indicate that no more kmers will be added to the kmer error corrector, so that the * error correction data structure should be computed from the added kmers. Enabled calls * to getErrorCorrectedKmer, and disable calls to addKmer. + * + * @return true if the error correction map could actually be computed, false if for any reason + * (efficiency, memory, we're out to lunch) a correction map couldn't be created. */ - public void computeErrorCorrectionMap() { + public boolean computeErrorCorrectionMap() { if ( countsByKMer == null ) throw new IllegalStateException("computeErrorCorrectionMap can only be called once"); - final LinkedList needsCorrection = new LinkedList(); - final LinkedList goodKmers = new LinkedList(); + final LinkedList needsCorrection = new LinkedList(); + final List goodKmers = new ArrayList(countsByKMer.size()); - rawToErrorCorrectedMap = new HashMap(); - for ( Map.Entry kmerCounts: countsByKMer.entrySet() ) { - if ( kmerCounts.getValue() <= maxCountToCorrect ) - needsCorrection.add(kmerCounts.getKey()); + rawToErrorCorrectedMap = new HashMap(countsByKMer.size()); + for ( final CountedKmer countedKmer: countsByKMer.values() ) { + if ( countedKmer.count <= maxCountToCorrect ) + needsCorrection.add(countedKmer); else { // todo -- optimization could make not in map mean == - rawToErrorCorrectedMap.put(kmerCounts.getKey(), kmerCounts.getKey()); + rawToErrorCorrectedMap.put(countedKmer.kmer, countedKmer.kmer); // only allow corrections to kmers with at least this count - if ( kmerCounts.getValue() >= minCountOfKmerToBeCorrection ) - goodKmers.add(kmerCounts.getKey()); + if ( countedKmer.count >= minCountOfKmerToBeCorrection ) + goodKmers.add(countedKmer); } } - for ( final String toCorrect : needsCorrection ) { - final String corrected = findClosestKMer(toCorrect, goodKmers); - rawToErrorCorrectedMap.put(toCorrect, corrected); - } - // cleanup memory -- we don't need the counts for each kmer any longer countsByKMer = null; + + if ( goodKmers.size() * needsCorrection.size() > MAX_CORRECTION_OPS_TO_ALLOW ) + return false; + else { + Collections.sort(goodKmers); + for ( final CountedKmer toCorrect : needsCorrection ) { + final String corrected = findClosestKMer(toCorrect, goodKmers); + rawToErrorCorrectedMap.put(toCorrect.kmer, corrected); + } + + return true; + } } protected void addKmer(final String rawKmer, final int kmerCount) { @@ -190,30 +241,42 @@ public class KMerErrorCorrector { if ( kmerCount < 0 ) throw new IllegalArgumentException("bad kmerCount " + kmerCount); if ( countsByKMer == null ) throw new IllegalStateException("Cannot add kmers to an already finalized error corrector"); - final Integer countFromMap = countsByKMer.get(rawKmer); - final int count = countFromMap == null ? 0 : countFromMap; - countsByKMer.put(rawKmer, count + kmerCount); + CountedKmer countFromMap = countsByKMer.get(rawKmer); + if ( countFromMap == null ) { + countFromMap = new CountedKmer(rawKmer); + countsByKMer.put(rawKmer, countFromMap); + } + countFromMap.count += kmerCount; } - protected String findClosestKMer(final String kmer, final Collection goodKmers) { + protected String findClosestKMer(final CountedKmer kmer, final Collection goodKmers) { String bestMatch = null; int minMismatches = Integer.MAX_VALUE; - for ( final String goodKmer : goodKmers ) { - final int mismatches = countMismatches(kmer, goodKmer); + for ( final CountedKmer goodKmer : goodKmers ) { + final int mismatches = countMismatches(kmer.kmer, goodKmer.kmer, minMismatches); if ( mismatches < minMismatches ) { minMismatches = mismatches; - bestMatch = goodKmer; + bestMatch = goodKmer.kmer; } + + // if we find an edit-distance 1 result, abort early, as we know there can be no edit distance 0 results + if ( mismatches == 1 ) + break; } return minMismatches > maxMismatchesToCorrect ? null : bestMatch; } - protected int countMismatches(final String one, final String two) { + protected int countMismatches(final String one, final String two, final int currentBest) { int mismatches = 0; - for ( int i = 0; i < one.length(); i++ ) + for ( int i = 0; i < one.length(); i++ ) { mismatches += one.charAt(i) == two.charAt(i) ? 0 : 1; + if ( mismatches > currentBest ) + break; + if ( mismatches > maxMismatchesToCorrect ) + return Integer.MAX_VALUE; + } return mismatches; } @@ -238,4 +301,26 @@ public class KMerErrorCorrector { b.append("\n}"); return b.toString(); } + + private static class CountedKmer implements Comparable { + final String kmer; + int count; + + private CountedKmer(String kmer) { + this.kmer = kmer; + } + + @Override + public String toString() { + return "CountedKmer{" + + "kmer='" + kmer + '\'' + + ", count=" + count + + '}'; + } + + @Override + public int compareTo(CountedKmer o) { + return o.count - count; + } + } } diff --git a/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCallerComplexAndSymbolicVariantsIntegrationTest.java b/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCallerComplexAndSymbolicVariantsIntegrationTest.java index fd16ed856..12dc71799 100644 --- a/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCallerComplexAndSymbolicVariantsIntegrationTest.java +++ b/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCallerComplexAndSymbolicVariantsIntegrationTest.java @@ -63,7 +63,7 @@ public class HaplotypeCallerComplexAndSymbolicVariantsIntegrationTest extends Wa @Test public void testHaplotypeCallerMultiSampleComplex() { - HCTestComplexVariants(privateTestDir + "AFR.complex.variants.bam", "", "2b9355ab532314bce157c918c7606409"); + HCTestComplexVariants(privateTestDir + "AFR.complex.variants.bam", "", "91f4880910e436bf5aca0abbebd58948"); } private void HCTestSymbolicVariants(String bam, String args, String md5) { diff --git a/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCallerIntegrationTest.java b/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCallerIntegrationTest.java index c93e54f87..5ee0a6b81 100644 --- a/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCallerIntegrationTest.java +++ b/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCallerIntegrationTest.java @@ -85,7 +85,7 @@ public class HaplotypeCallerIntegrationTest extends WalkerTest { @Test public void testHaplotypeCallerMultiSampleGGA() { HCTest(CEUTRIO_BAM, "--max_alternate_alleles 3 -gt_mode GENOTYPE_GIVEN_ALLELES -out_mode EMIT_ALL_SITES -alleles " + validationDataLocation + "combined.phase1.chr20.raw.indels.sites.vcf", - "9f9062a6eb93f984658492400102b0c7"); + "d41a886f69a67e01af2ba1a6b4a681d9"); } private void HCTestIndelQualityScores(String bam, String args, String md5) {