diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngine.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngine.java index afc30318c..63aa54fa5 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngine.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngine.java @@ -147,7 +147,7 @@ public class LikelihoodCalculationEngine { for( int jjj = 0; jjj < numHaplotypes; jjj++ ) { final Haplotype haplotype = haplotypes.get(jjj); - final int haplotypeStart = ( previousHaplotypeSeen == null ? 0 : computeFirstDifferingPosition(haplotype.getBases(), previousHaplotypeSeen.getBases()) ); + final int haplotypeStart = ( previousHaplotypeSeen == null ? 0 : PairHMM.findFirstPositionWhereHaplotypesDiffer(haplotype.getBases(), previousHaplotypeSeen.getBases()) ); previousHaplotypeSeen = haplotype; perReadAlleleLikelihoodMap.add(read, alleleVersions.get(haplotype), @@ -158,15 +158,6 @@ public class LikelihoodCalculationEngine { return perReadAlleleLikelihoodMap; } - private static int computeFirstDifferingPosition( final byte[] b1, final byte[] b2 ) { - for( int iii = 0; iii < b1.length && iii < b2.length; iii++ ) { - if( b1[iii] != b2[iii] ) { - return iii; - } - } - return Math.min(b1.length, b2.length); - } - @Requires({"alleleOrdering.size() > 0"}) @Ensures({"result.length == result[0].length", "result.length == alleleOrdering.size()"}) public static double[][] computeDiploidHaplotypeLikelihoods( final String sample, diff --git a/protected/java/src/org/broadinstitute/sting/utils/pairhmm/LoglessCachingPairHMM.java b/protected/java/src/org/broadinstitute/sting/utils/pairhmm/LoglessCachingPairHMM.java index 6dc500711..6f8bec94f 100644 --- a/protected/java/src/org/broadinstitute/sting/utils/pairhmm/LoglessCachingPairHMM.java +++ b/protected/java/src/org/broadinstitute/sting/utils/pairhmm/LoglessCachingPairHMM.java @@ -46,22 +46,25 @@ package org.broadinstitute.sting.utils.pairhmm; +import com.google.java.contract.Ensures; +import com.google.java.contract.Requires; import org.broadinstitute.sting.utils.QualityUtils; -import java.util.Arrays; - /** * Created with IntelliJ IDEA. * User: rpoplin, carneiro * Date: 10/16/12 */ - public class LoglessCachingPairHMM extends PairHMM { protected static final double SCALE_FACTOR_LOG10 = 300.0; double[][] constantMatrix = null; // The cache double[][] distanceMatrix = null; // The cache + boolean constantsAreInitialized = false; + /** + * Cached data structure that describes the first row's edge condition in the HMM + */ protected static final double [] firstRowConstantMatrix = { QualityUtils.qualToProb((byte) (DEFAULT_GOP + DEFAULT_GOP)), QualityUtils.qualToProb(DEFAULT_GCP), @@ -71,53 +74,48 @@ public class LoglessCachingPairHMM extends PairHMM { 1.0 }; + /** + * {@inheritDoc} + */ @Override - public void initialize( final int READ_MAX_LENGTH, final int HAPLOTYPE_MAX_LENGTH ) { - super.initialize(READ_MAX_LENGTH, HAPLOTYPE_MAX_LENGTH); + public void initialize( final int readMaxLength, final int haplotypeMaxLength) { + super.initialize(readMaxLength, haplotypeMaxLength); - constantMatrix = new double[X_METRIC_LENGTH][6]; - distanceMatrix = new double[X_METRIC_LENGTH][Y_METRIC_LENGTH]; - - // TODO -- this shouldn't be necessary - for( int iii=0; iii < X_METRIC_LENGTH; iii++ ) { - Arrays.fill(matchMetricArray[iii], 0.0); - Arrays.fill(XMetricArray[iii], 0.0); - Arrays.fill(YMetricArray[iii], 0.0); - } - - // the initial condition - matchMetricArray[1][1] = Math.pow(10.0, SCALE_FACTOR_LOG10) / nPotentialXStarts; // Math.log10(1.0); - firstRowConstantMatrix[4] = firstRowConstantMatrix[5] = 1.0; - - // fill in the first row - for( int jjj = 2; jjj < Y_METRIC_LENGTH; jjj++ ) { - updateCell(1, jjj, 1.0, firstRowConstantMatrix, matchMetricArray, XMetricArray, YMetricArray); - } + constantMatrix = new double[X_METRIC_MAX_LENGTH][6]; + distanceMatrix = new double[X_METRIC_MAX_LENGTH][Y_METRIC_MAX_LENGTH]; } + /** + * {@inheritDoc} + */ @Override - public double computeReadLikelihoodGivenHaplotypeLog10( final byte[] haplotypeBases, - final byte[] readBases, - final byte[] readQuals, - final byte[] insertionGOP, - final byte[] deletionGOP, - final byte[] overallGCP, - final int hapStartIndex, - final boolean recacheReadValues ) { - - if ( recacheReadValues ) - initializeConstants( insertionGOP, deletionGOP, overallGCP ); + public double subComputeReadLikelihoodGivenHaplotypeLog10( final byte[] haplotypeBases, + final byte[] readBases, + final byte[] readQuals, + final byte[] insertionGOP, + final byte[] deletionGOP, + final byte[] overallGCP, + final int hapStartIndex, + final boolean recacheReadValues ) { + if ( ! constantsAreInitialized || recacheReadValues ) + initializeConstants( haplotypeBases.length, readBases.length, insertionGOP, deletionGOP, overallGCP ); initializeDistanceMatrix( haplotypeBases, readBases, readQuals, hapStartIndex ); - for (int i = 2; i < X_METRIC_LENGTH; i++) { - for (int j = hapStartIndex+1; j < Y_METRIC_LENGTH; j++) { + // NOTE NOTE NOTE -- because of caching we need to only operate over X and Y according to this + // read and haplotype lengths, not the max lengths + final int readXMetricLength = readBases.length + 2; + final int hapYMetricLength = haplotypeBases.length + 2; + + for (int i = 2; i < readXMetricLength; i++) { + // +1 here is because hapStartIndex is 0-based, but our matrices are 1 based + for (int j = hapStartIndex+1; j < hapYMetricLength; j++) { updateCell(i, j, distanceMatrix[i][j], constantMatrix[i], matchMetricArray, XMetricArray, YMetricArray); } } // final probability is the log10 sum of the last element in all three state arrays - final int endI = X_METRIC_LENGTH - 1; - final int endJ = Y_METRIC_LENGTH - 1; + final int endI = readXMetricLength - 1; + final int endJ = hapYMetricLength - 1; return Math.log10( matchMetricArray[endI][endJ] + XMetricArray[endI][endJ] + YMetricArray[endI][endJ] ) - SCALE_FACTOR_LOG10; } @@ -152,13 +150,32 @@ public class LoglessCachingPairHMM extends PairHMM { /** * Initializes the matrix that holds all the constants related to quality scores. * + * @param haplotypeSize the number of bases in the haplotype we are testing + * @param readSize the number of bases in the read we are testing * @param insertionGOP insertion quality scores of the read * @param deletionGOP deletion quality scores of the read * @param overallGCP overall gap continuation penalty */ - public void initializeConstants( final byte[] insertionGOP, - final byte[] deletionGOP, - final byte[] overallGCP ) { + @Requires({ + "haplotypeSize > 0", + "readSize > 0", + "insertionGOP != null && insertionGOP.length == readSize", + "deletionGOP != null && deletionGOP.length == readSize", + "overallGCP != null && overallGCP.length == readSize" + }) + @Ensures("constantsAreInitialized") + private void initializeConstants( final int haplotypeSize, + final int readSize, + final byte[] insertionGOP, + final byte[] deletionGOP, + final byte[] overallGCP ) { + // the initial condition -- must be here because it needs that actual read and haplotypes, not the maximum in init + matchMetricArray[1][1] = Math.pow(10.0, SCALE_FACTOR_LOG10) / getNPotentialXStarts(haplotypeSize, readSize); + + // fill in the first row + for( int jjj = 2; jjj < Y_METRIC_MAX_LENGTH; jjj++ ) { + updateCell(1, jjj, 1.0, firstRowConstantMatrix, matchMetricArray, XMetricArray, YMetricArray); + } final int l = insertionGOP.length; constantMatrix[1] = firstRowConstantMatrix; @@ -173,6 +190,9 @@ public class LoglessCachingPairHMM extends PairHMM { } constantMatrix[l+1][4] = 1.0; constantMatrix[l+1][5] = 1.0; + + // note that we initialized the constants + constantsAreInitialized = true; } /** diff --git a/protected/java/test/org/broadinstitute/sting/utils/pairhmm/PairHMMUnitTest.java b/protected/java/test/org/broadinstitute/sting/utils/pairhmm/PairHMMUnitTest.java index c463b7f44..9de562aa5 100644 --- a/protected/java/test/org/broadinstitute/sting/utils/pairhmm/PairHMMUnitTest.java +++ b/protected/java/test/org/broadinstitute/sting/utils/pairhmm/PairHMMUnitTest.java @@ -52,6 +52,7 @@ package org.broadinstitute.sting.utils.pairhmm; import org.broadinstitute.sting.BaseTest; import org.broadinstitute.sting.gatk.GenomeAnalysisEngine; import org.broadinstitute.sting.utils.BaseUtils; +import org.broadinstitute.sting.utils.MathUtils; import org.broadinstitute.sting.utils.QualityUtils; import org.broadinstitute.sting.utils.Utils; import org.testng.Assert; @@ -64,16 +65,15 @@ import java.util.List; import java.util.Random; public class PairHMMUnitTest extends BaseTest { + private final static boolean ALLOW_READS_LONGER_THAN_HAPLOTYPE = true; private final static boolean DEBUG = false; - final static boolean EXTENSIVE_TESTING = false; // TODO -- should be true - PairHMM exactHMM = new Log10PairHMM(true); // the log truth implementation - PairHMM originalHMM = new Log10PairHMM(false); // the reference implementation - PairHMM loglessHMM = new LoglessCachingPairHMM(); + final static boolean EXTENSIVE_TESTING = true; + final PairHMM exactHMM = new Log10PairHMM(true); // the log truth implementation + final PairHMM originalHMM = new Log10PairHMM(false); // the reference implementation + final PairHMM loglessHMM = new LoglessCachingPairHMM(); private List getHMMs() { - // TODO -- re-enable loglessHMM tests - return Arrays.asList(exactHMM, originalHMM); - //return Arrays.asList(exactHMM, originalHMM, cachingHMM, loglessHMM); + return Arrays.asList(exactHMM, originalHMM, loglessHMM); } // -------------------------------------------------------------------------------- @@ -109,8 +109,9 @@ public class PairHMMUnitTest extends BaseTest { readBasesWithContext = asBytes(read, false, false); } - public double expectedLogL() { - return (expectedQual / -10.0) + 0.03 ; + public double expectedLogL(final PairHMM hmm) { + return (expectedQual / -10.0) + 0.03 + + hmm.getNPotentialXStartsLikelihoodPenaltyLog10(refBasesWithContext.length, readBasesWithContext.length); } public double getTolerance(final PairHMM hmm) { @@ -127,7 +128,7 @@ public class PairHMMUnitTest extends BaseTest { } public double toleranceFromReference() { - return 1E-4; + return 1E-3; // has to be very tolerant -- this approximation is quite approximate } public double toleranceFromExact() { @@ -239,10 +240,10 @@ public class PairHMMUnitTest extends BaseTest { for( int iii = 0; iii < readSize; iii++) { read += (char) BaseUtils.BASES[random.nextInt(4)]; } - new BasicLikelihoodTestProvider(ref, read, baseQual, indelQual, indelQual, -0, gcp); - new BasicLikelihoodTestProvider(ref, read, baseQual, indelQual, indelQual, -0, gcp, true, false); - new BasicLikelihoodTestProvider(ref, read, baseQual, indelQual, indelQual, -0, gcp, false, true); - new BasicLikelihoodTestProvider(ref, read, baseQual, indelQual, indelQual, -0, gcp, true, true); + + for ( final boolean leftFlank : Arrays.asList(true, false) ) + for ( final boolean rightFlank : Arrays.asList(true, false) ) + new BasicLikelihoodTestProvider(ref, read, baseQual, indelQual, indelQual, -0, gcp, leftFlank, rightFlank); } } } @@ -254,26 +255,32 @@ public class PairHMMUnitTest extends BaseTest { @Test(enabled = !DEBUG, dataProvider = "BasicLikelihoodTestProvider") public void testBasicLikelihoods(BasicLikelihoodTestProvider cfg) { - final double exactLogL = cfg.calcLogL( exactHMM, true ); - for ( final PairHMM hmm : getHMMs() ) { - double actualLogL = cfg.calcLogL( hmm, true ); - double expectedLogL = cfg.expectedLogL(); + if ( ALLOW_READS_LONGER_THAN_HAPLOTYPE || cfg.read.length() <= cfg.ref.length() ) { + final double exactLogL = cfg.calcLogL( exactHMM, true ); + for ( final PairHMM hmm : getHMMs() ) { + double actualLogL = cfg.calcLogL( hmm, true ); + double expectedLogL = cfg.expectedLogL(hmm); - // compare to our theoretical expectation with appropriate tolerance - Assert.assertEquals(actualLogL, expectedLogL, cfg.toleranceFromTheoretical(), "Failed with hmm " + hmm); - // compare to the exact reference implementation with appropriate tolerance - Assert.assertEquals(actualLogL, exactLogL, cfg.getTolerance(hmm), "Failed with hmm " + hmm); + // compare to our theoretical expectation with appropriate tolerance + Assert.assertEquals(actualLogL, expectedLogL, cfg.toleranceFromTheoretical(), "Failed with hmm " + hmm); + // compare to the exact reference implementation with appropriate tolerance + Assert.assertEquals(actualLogL, exactLogL, cfg.getTolerance(hmm), "Failed with hmm " + hmm); + Assert.assertTrue(MathUtils.goodLog10Probability(actualLogL), "Bad log10 likelihood " + actualLogL); + } } } @Test(enabled = !DEBUG, dataProvider = "OptimizedLikelihoodTestProvider") public void testOptimizedLikelihoods(BasicLikelihoodTestProvider cfg) { - double exactLogL = cfg.calcLogL( exactHMM, false ); + if ( ALLOW_READS_LONGER_THAN_HAPLOTYPE || cfg.read.length() <= cfg.ref.length() ) { + double exactLogL = cfg.calcLogL( exactHMM, false ); - for ( final PairHMM hmm : getHMMs() ) { - double calculatedLogL = cfg.calcLogL( hmm, false ); - // compare to the exact reference implementation with appropriate tolerance - Assert.assertEquals(calculatedLogL, exactLogL, cfg.getTolerance(hmm), String.format("Test: logL calc=%.2f expected=%.2f for %s with hmm %s", calculatedLogL, exactLogL, cfg.toString(), hmm)); + for ( final PairHMM hmm : getHMMs() ) { + double calculatedLogL = cfg.calcLogL( hmm, false ); + // compare to the exact reference implementation with appropriate tolerance + Assert.assertEquals(calculatedLogL, exactLogL, cfg.getTolerance(hmm), String.format("Test: logL calc=%.2f expected=%.2f for %s with hmm %s", calculatedLogL, exactLogL, cfg.toString(), hmm)); + Assert.assertTrue(MathUtils.goodLog10Probability(calculatedLogL), "Bad log10 likelihood " + calculatedLogL); + } } } @@ -304,7 +311,8 @@ public class PairHMMUnitTest extends BaseTest { System.out.format("H:%s\nR: %s\n Pos:%d Result:%4.2f\n",new String(haplotype1), new String(mread), k,res1); - Assert.assertEquals(res1, -2.0, 1e-2); + // - log10 is because of number of start positions + Assert.assertEquals(res1, -2.0 - Math.log10(originalHMM.getNPotentialXStarts(haplotype1.length, mread.length)), 1e-2); } } @@ -335,7 +343,8 @@ public class PairHMMUnitTest extends BaseTest { System.out.format("H:%s\nR: %s\n Pos:%d Result:%4.2f\n",new String(haplotype1), new String(mread), k,res1); - Assert.assertEquals(res1, -2.0, 1e-2); + // - log10 is because of number of start positions + Assert.assertEquals(res1, -2.0 - Math.log10(originalHMM.getNPotentialXStarts(haplotype1.length, mread.length)), 1e-2); } } @@ -343,19 +352,22 @@ public class PairHMMUnitTest extends BaseTest { public Object[][] makeHMMProvider() { List tests = new ArrayList(); - // TODO -- reenable -// for ( final PairHMM hmm : getHMMs() ) -// tests.add(new Object[]{hmm}); - tests.add(new Object[]{loglessHMM}); + for ( final int readSize : Arrays.asList(1, 2, 5, 10) ) { + for ( final int refSize : Arrays.asList(1, 2, 5, 10) ) { + if ( refSize > readSize ) { + for ( final PairHMM hmm : getHMMs() ) + tests.add(new Object[]{hmm, readSize, refSize}); + } + } + } return tests.toArray(new Object[][]{}); } - // TODO -- generalize provider to include read and ref base sizes - @Test(dataProvider = "HMMProvider") - void testMultipleReadMatchesInHaplotype(final PairHMM hmm) { - byte[] readBases = "AAAAAAAAAAAA".getBytes(); - byte[] refBases = "CCAAAAAAAAAAAAAAGGA".getBytes(); + @Test(enabled = !DEBUG, dataProvider = "HMMProvider") + void testMultipleReadMatchesInHaplotype(final PairHMM hmm, final int readSize, final int refSize) { + byte[] readBases = Utils.dupBytes((byte)'A', readSize); + byte[] refBases = ("CC" + new String(Utils.dupBytes((byte)'A', refSize)) + "GGA").getBytes(); byte baseQual = 20; byte insQual = 37; byte delQual = 37; @@ -369,10 +381,10 @@ public class PairHMMUnitTest extends BaseTest { Assert.assertTrue(d <= 0.0, "Likelihoods should be <= 0 but got "+ d); } - @Test(dataProvider = "HMMProvider") - void testAllMatchingRead(final PairHMM hmm) { - byte[] readBases = "AAA".getBytes(); - byte[] refBases = "AAAAA".getBytes(); + @Test(enabled = !DEBUG, dataProvider = "HMMProvider") + void testAllMatchingRead(final PairHMM hmm, final int readSize, final int refSize) { + byte[] readBases = Utils.dupBytes((byte)'A', readSize); + byte[] refBases = Utils.dupBytes((byte)'A', refSize); byte baseQual = 20; byte insQual = 100; byte delQual = 100; @@ -386,4 +398,243 @@ public class PairHMMUnitTest extends BaseTest { final double expected = Math.log10(Math.pow(1.0 - QualityUtils.qualToErrorProb(baseQual), readBases.length)); Assert.assertEquals(d, expected, 1e-3, "Likelihoods should sum to just the error prob of the read"); } + + @DataProvider(name = "HMMProviderWithBigReads") + public Object[][] makeBigReadHMMProvider() { + List tests = new ArrayList(); + + final String read1 = "ACCAAGTAGTCACCGT"; + final String ref1 = "ACCAAGTAGTCACCGTAACG"; + + for ( final int nReadCopies : Arrays.asList(1, 2, 10, 20, 50) ) { + for ( final int nRefCopies : Arrays.asList(1, 2, 10, 20, 100) ) { + if ( nRefCopies > nReadCopies ) { + for ( final PairHMM hmm : getHMMs() ) { + final String read = Utils.dupString(read1, nReadCopies); + final String ref = Utils.dupString(ref1, nRefCopies); + tests.add(new Object[]{hmm, read, ref}); + } + } + } + } + + return tests.toArray(new Object[][]{}); + } + + @Test(enabled = !DEBUG, dataProvider = "HMMProviderWithBigReads") + void testReallyBigReads(final PairHMM hmm, final String read, final String ref) { + byte[] readBases = read.getBytes(); + byte[] refBases = ref.getBytes(); + byte baseQual = 30; + byte insQual = 40; + byte delQual = 40; + byte gcp = 10; + hmm.initialize(readBases.length, refBases.length); + double d = hmm.computeReadLikelihoodGivenHaplotypeLog10( refBases, readBases, + Utils.dupBytes(baseQual, readBases.length), + Utils.dupBytes(insQual, readBases.length), + Utils.dupBytes(delQual, readBases.length), + Utils.dupBytes(gcp, readBases.length), 0, true); + Assert.assertTrue(MathUtils.goodLog10Probability(d), "Likelihoods = " + d +" was bad for a read with " + read.length() + " bases and ref with " + ref.length() + " bases"); + } + + @Test(enabled = !DEBUG) + void testPreviousBadValue() { + byte[] readBases = "A".getBytes(); + byte[] refBases = "AT".getBytes(); + byte baseQual = 30; + byte insQual = 40; + byte delQual = 40; + byte gcp = 10; + + exactHMM.initialize(readBases.length, refBases.length); + double d = exactHMM.computeReadLikelihoodGivenHaplotypeLog10( refBases, readBases, + Utils.dupBytes(baseQual, readBases.length), + Utils.dupBytes(insQual, readBases.length), + Utils.dupBytes(delQual, readBases.length), + Utils.dupBytes(gcp, readBases.length), 0, true); + //exactHMM.dumpMatrices(); + + loglessHMM.initialize(readBases.length, refBases.length); + double logless = loglessHMM.computeReadLikelihoodGivenHaplotypeLog10( refBases, readBases, + Utils.dupBytes(baseQual, readBases.length), + Utils.dupBytes(insQual, readBases.length), + Utils.dupBytes(delQual, readBases.length), + Utils.dupBytes(gcp, readBases.length), 0, true); + loglessHMM.dumpMatrices(); + } + + @DataProvider(name = "JustHMMProvider") + public Object[][] makeJustHMMProvider() { + List tests = new ArrayList(); + + for ( final PairHMM hmm : getHMMs() ) { + tests.add(new Object[]{hmm}); + } + + return tests.toArray(new Object[][]{}); + } + + @Test(enabled = !DEBUG, dataProvider = "JustHMMProvider") + void testMaxLengthsBiggerThanProvidedRead(final PairHMM hmm) { + for ( int nExtraMaxSize = 0; nExtraMaxSize < 100; nExtraMaxSize++ ) { + byte[] readBases = "CTATCTTAGTAAGCCCCCATACCTGCAAATTTCAGGATGTCTCCTCCAAAAATCAACA".getBytes(); + byte[] refBases = "CTATCTTAGTAAGCCCCCATACCTGCAAATTTCAGGATGTCTCCTCCAAAAATCAAAACTTCTGAGAAAAAAAAAAAAAATTAAATCAAACCCTGATTCCTTAAAGGTAGTAAAAAAACATCATTCTTTCTTAGTGGAATAGAAACTAGGTCAAAAGAACAGTGATTC".getBytes(); + byte gcp = 10; + + byte[] quals = new byte[]{35,34,31,32,35,34,32,31,36,30,31,32,36,34,33,32,32,32,33,32,30,35,33,35,36,36,33,33,33,32,32,32,37,33,36,35,33,32,34,31,36,35,35,35,35,33,34,31,31,30,28,27,26,29,26,25,29,29}; + byte[] insQual = new byte[]{46,46,46,46,46,47,45,46,45,48,47,44,45,48,46,43,43,42,48,48,45,47,47,48,48,47,48,45,38,47,45,39,47,48,47,47,48,46,49,48,49,48,46,47,48,44,44,43,39,32,34,36,46,48,46,44,45,45}; + byte[] delQual = new byte[]{44,44,44,43,45,44,43,42,45,46,45,43,44,47,45,40,40,40,45,46,43,45,45,44,46,46,46,43,35,44,43,36,44,45,46,46,44,44,47,43,47,45,45,45,46,45,45,46,44,35,35,35,45,47,45,44,44,43}; + + final int maxHaplotypeLength = refBases.length + nExtraMaxSize; + final int maxReadLength = readBases.length + nExtraMaxSize; + + hmm.initialize(maxReadLength, maxHaplotypeLength); + double d = hmm.computeReadLikelihoodGivenHaplotypeLog10( refBases, readBases, + quals, + insQual, + delQual, + Utils.dupBytes(gcp, readBases.length), 0, true); + Assert.assertTrue(MathUtils.goodLog10Probability(d), "Likelihoods = " + d +" was bad for a read with " + readBases.length + " bases and ref with " + refBases.length + " bases"); + } + } + + @DataProvider(name = "HaplotypeIndexingProvider") + public Object[][] makeHaplotypeIndexingProvider() { + List tests = new ArrayList(); + + final String root1 = "ACGTGTCAAACCGGGTT"; + final String root2 = "ACGTGTCACACTGGGTT"; // differs in two locations + + final String read1 = "ACGTGTCACACTGGATT"; // 1 diff from 2, 2 diff from root1 + final String read2 = root1; // same as root1 + final String read3 = root2; // same as root2 + final String read4 = "ACGTGTCACACTGGATTCGAT"; + final String read5 = "CCAGTAACGTGTCACACTGGATTCGAT"; + +// for ( final String read : Arrays.asList(read2) ) { + for ( final String read : Arrays.asList(read1, read2, read3, read4, read5) ) { + for ( final PairHMM hmm : getHMMs() ) { +// int readLength = read.length(); { + for ( int readLength = 10; readLength < read.length(); readLength++ ) { + final String myRead = read.substring(0, readLength); + tests.add(new Object[]{hmm, root1, root2, myRead}); + } + } + } + + return tests.toArray(new Object[][]{}); + } + + @Test(enabled = !DEBUG, dataProvider = "HaplotypeIndexingProvider") + void testHaplotypeIndexing(final PairHMM hmm, final String root1, final String root2, final String read) { + final double TOLERANCE = 1e-9; + final String prefix = "AACCGGTTTTTGGGCCCAAACGTACGTACAGTTGGTCAACATCGATCAGGTTCCGGAGTAC"; + + final int maxReadLength = read.length(); + final int maxHaplotypeLength = prefix.length() + root1.length(); + + // the initialization occurs once, at the start of the evalution of reads + hmm.initialize(maxReadLength, maxHaplotypeLength); + + for ( int prefixStart = prefix.length(); prefixStart >= 0; prefixStart-- ) { + final String myPrefix = prefix.substring(prefixStart, prefix.length()); + final String hap1 = myPrefix + root1; + final String hap2 = myPrefix + root2; + + final int hapStart = PairHMM.findFirstPositionWhereHaplotypesDiffer(hap1.getBytes(), hap2.getBytes()); + + final double actual1 = testHaplotypeIndexingCalc(hmm, hap1, read, 0, true); + final double actual2 = testHaplotypeIndexingCalc(hmm, hap2, read, hapStart, false); + final double expected2 = testHaplotypeIndexingCalc(hmm, hap2, read, 0, true); + Assert.assertEquals(actual2, expected2, TOLERANCE, "Caching calculation failed for read " + read + " against haplotype with prefix '" + myPrefix + + "' expected " + expected2 + " but got " + actual2 + " with hapStart of " + hapStart); + } + } + + private double testHaplotypeIndexingCalc(final PairHMM hmm, final String hap, final String read, final int hapStart, final boolean recache) { + final byte[] readBases = read.getBytes(); + final byte[] baseQuals = Utils.dupBytes((byte)30, readBases.length); + final byte[] insQuals = Utils.dupBytes((byte)45, readBases.length); + final byte[] delQuals = Utils.dupBytes((byte)40, readBases.length); + final byte[] gcp = Utils.dupBytes((byte)10, readBases.length); + double d = hmm.computeReadLikelihoodGivenHaplotypeLog10( + hap.getBytes(), readBases, baseQuals, insQuals, delQuals, gcp, + hapStart, recache); + Assert.assertTrue(MathUtils.goodLog10Probability(d), "Likelihoods = " + d + " was bad for read " + read + " and ref " + hap + " with hapStart " + hapStart); + return d; + } + + @Test(enabled = !DEBUG) + public void testFindFirstPositionWhereHaplotypesDiffer() { + for ( int haplotypeSize1 = 10; haplotypeSize1 < 30; haplotypeSize1++ ) { + for ( int haplotypeSize2 = 10; haplotypeSize2 < 50; haplotypeSize2++ ) { + final int maxLength = Math.max(haplotypeSize1, haplotypeSize2); + final int minLength = Math.min(haplotypeSize1, haplotypeSize2); + for ( int differingSite = 0; differingSite < maxLength + 1; differingSite++) { + for ( final boolean oneIsDiff : Arrays.asList(true, false) ) { + final byte[] hap1 = Utils.dupBytes((byte)'A', haplotypeSize1); + final byte[] hap2 = Utils.dupBytes((byte)'A', haplotypeSize2); + + final int expected = oneIsDiff + ? makeDiff(hap1, differingSite, minLength) + : makeDiff(hap2, differingSite, minLength); + final int actual = PairHMM.findFirstPositionWhereHaplotypesDiffer(hap1, hap2); + Assert.assertEquals(actual, expected, "Bad differing site for " + new String(hap1) + " vs. " + new String(hap2)); + } + } + } + } + } + + private int makeDiff(final byte[] bytes, final int site, final int minSize) { + if ( site < bytes.length ) { + bytes[site] = 'C'; + return Math.min(site, minSize); + } else + return minSize; + } + + @DataProvider(name = "UninitializedHMMs") + public Object[][] makeUninitializedHMMs() { + List tests = new ArrayList(); + + tests.add(new Object[]{new LoglessCachingPairHMM()}); + tests.add(new Object[]{new Log10PairHMM(true)}); + + return tests.toArray(new Object[][]{}); + } + + @Test(enabled = true, expectedExceptions = IllegalStateException.class, dataProvider = "UninitializedHMMs") + public void testNoInitializeCall(final PairHMM hmm) { + byte[] readBases = "A".getBytes(); + byte[] refBases = "AT".getBytes(); + byte[] baseQuals = Utils.dupBytes((byte)30, readBases.length); + + // didn't call initialize => should exception out + double d = hmm.computeReadLikelihoodGivenHaplotypeLog10( refBases, readBases, + baseQuals, baseQuals, baseQuals, baseQuals, 0, true); + } + + @Test(enabled = true, expectedExceptions = IllegalArgumentException.class, dataProvider = "JustHMMProvider") + public void testHapTooLong(final PairHMM hmm) { + byte[] readBases = "AAA".getBytes(); + byte[] refBases = "AAAT".getBytes(); + byte[] baseQuals = Utils.dupBytes((byte)30, readBases.length); + + hmm.initialize(3, 3); + double d = hmm.computeReadLikelihoodGivenHaplotypeLog10( refBases, readBases, + baseQuals, baseQuals, baseQuals, baseQuals, 0, true); + } + + @Test(enabled = true, expectedExceptions = IllegalArgumentException.class, dataProvider = "JustHMMProvider") + public void testReadTooLong(final PairHMM hmm) { + byte[] readBases = "AAA".getBytes(); + byte[] refBases = "AAAT".getBytes(); + byte[] baseQuals = Utils.dupBytes((byte)30, readBases.length); + + hmm.initialize(2, 3); + double d = hmm.computeReadLikelihoodGivenHaplotypeLog10( refBases, readBases, + baseQuals, baseQuals, baseQuals, baseQuals, 0, true); + } } \ No newline at end of file diff --git a/public/java/src/org/broadinstitute/sting/utils/Utils.java b/public/java/src/org/broadinstitute/sting/utils/Utils.java index 77f3a84c3..d009ba5bc 100644 --- a/public/java/src/org/broadinstitute/sting/utils/Utils.java +++ b/public/java/src/org/broadinstitute/sting/utils/Utils.java @@ -308,6 +308,22 @@ public class Utils { return join(separator, Arrays.asList(objects)); } + /** + * Create a new string thats a n duplicate copies of s + * @param s the string to duplicate + * @param nCopies how many copies? + * @return a string + */ + public static String dupString(final String s, int nCopies) { + if ( s == null || s.equals("") ) throw new IllegalArgumentException("Bad s " + s); + if ( nCopies < 1 ) throw new IllegalArgumentException("nCopies must be >= 1 but got " + nCopies); + + final StringBuilder b = new StringBuilder(); + for ( int i = 0; i < nCopies; i++ ) + b.append(s); + return b.toString(); + } + public static String dupString(char c, int nCopies) { char[] chars = new char[nCopies]; Arrays.fill(chars, c); diff --git a/public/java/src/org/broadinstitute/sting/utils/pairhmm/Log10PairHMM.java b/public/java/src/org/broadinstitute/sting/utils/pairhmm/Log10PairHMM.java index ea2f18f0e..c9d364aac 100644 --- a/public/java/src/org/broadinstitute/sting/utils/pairhmm/Log10PairHMM.java +++ b/public/java/src/org/broadinstitute/sting/utils/pairhmm/Log10PairHMM.java @@ -25,7 +25,6 @@ package org.broadinstitute.sting.utils.pairhmm; -import com.google.java.contract.Ensures; import com.google.java.contract.Requires; import org.broadinstitute.sting.utils.MathUtils; import org.broadinstitute.sting.utils.QualityUtils; @@ -39,6 +38,9 @@ import java.util.Arrays; * Date: 3/1/12 */ public class Log10PairHMM extends PairHMM { + /** + * Should we use exact log10 calculation (true), or an approximation (false)? + */ private final boolean doExactLog10; /** @@ -58,29 +60,35 @@ public class Log10PairHMM extends PairHMM { return doExactLog10; } + /** + * {@inheritDoc} + */ @Override - public void initialize( final int READ_MAX_LENGTH, final int HAPLOTYPE_MAX_LENGTH ) { - super.initialize(READ_MAX_LENGTH, HAPLOTYPE_MAX_LENGTH); + public void initialize( final int readMaxLength, final int haplotypeMaxLength) { + super.initialize(readMaxLength, haplotypeMaxLength); - for( int iii=0; iii < X_METRIC_LENGTH; iii++ ) { + for( int iii=0; iii < X_METRIC_MAX_LENGTH; iii++ ) { Arrays.fill(matchMetricArray[iii], Double.NEGATIVE_INFINITY); Arrays.fill(XMetricArray[iii], Double.NEGATIVE_INFINITY); Arrays.fill(YMetricArray[iii], Double.NEGATIVE_INFINITY); } - - // the initial condition - matchMetricArray[1][1] = 0.0; //Math.log10(1.0 / nPotentialXStarts); } + /** + * {@inheritDoc} + */ @Override - public double computeReadLikelihoodGivenHaplotypeLog10( final byte[] haplotypeBases, - final byte[] readBases, - final byte[] readQuals, - final byte[] insertionGOP, - final byte[] deletionGOP, - final byte[] overallGCP, - final int hapStartIndex, - final boolean recacheReadValues ) { + public double subComputeReadLikelihoodGivenHaplotypeLog10( final byte[] haplotypeBases, + final byte[] readBases, + final byte[] readQuals, + final byte[] insertionGOP, + final byte[] deletionGOP, + final byte[] overallGCP, + final int hapStartIndex, + final boolean recacheReadValues ) { + // the initial condition -- must be in subComputeReadLikelihoodGivenHaplotypeLog10 because it needs that actual + // read and haplotypes, not the maximum + matchMetricArray[1][1] = getNPotentialXStartsLikelihoodPenaltyLog10(haplotypeBases.length, readBases.length); // M, X, and Y arrays are of size read and haplotype + 1 because of an extra column for initial conditions and + 1 to consider the final base in a non-global alignment final int X_METRIC_LENGTH = readBases.length + 2; @@ -106,13 +114,22 @@ public class Log10PairHMM extends PairHMM { return myLog10SumLog10(new double[]{matchMetricArray[endI][endJ], XMetricArray[endI][endJ], YMetricArray[endI][endJ]}); } + /** + * Compute the log10SumLog10 of the values + * + * NOTE NOTE NOTE + * + * Log10PairHMM depends critically on this function tolerating values that are all -Infinity + * and the sum returning -Infinity. Note good. Needs to be fixed. + * + * NOTE NOTE NOTE + * + * @param values an array of log10 probabilities that need to be summed + * @return the log10 of the sum of the probabilities + */ @Requires("values != null") - @Ensures("MathUtils.goodLog10Probability(result)") private double myLog10SumLog10(final double[] values) { - if ( doExactLog10 ) - return MathUtils.log10sumLog10(values); - else - return MathUtils.approximateLog10SumLog10(values); + return doExactLog10 ? MathUtils.log10sumLog10(values) : MathUtils.approximateLog10SumLog10(values); } private void updateCell( final int indI, final int indJ, final byte[] haplotypeBases, final byte[] readBases, diff --git a/public/java/src/org/broadinstitute/sting/utils/pairhmm/PairHMM.java b/public/java/src/org/broadinstitute/sting/utils/pairhmm/PairHMM.java index d76afff4e..f898faaf3 100644 --- a/public/java/src/org/broadinstitute/sting/utils/pairhmm/PairHMM.java +++ b/public/java/src/org/broadinstitute/sting/utils/pairhmm/PairHMM.java @@ -27,20 +27,25 @@ package org.broadinstitute.sting.utils.pairhmm; import com.google.java.contract.Ensures; import com.google.java.contract.Requires; +import org.apache.log4j.Logger; +import org.broadinstitute.sting.utils.MathUtils; /** - * Created with IntelliJ IDEA. + * Util class for performing the pair HMM for local alignment. Figure 4.3 in Durbin 1998 book. + * * User: rpoplin * Date: 10/16/12 */ public abstract class PairHMM { + protected final static Logger logger = Logger.getLogger(PairHMM.class); + protected static final Byte MAX_CACHED_QUAL = Byte.MAX_VALUE; protected static final byte DEFAULT_GOP = (byte) 45; protected static final byte DEFAULT_GCP = (byte) 10; public enum HMM_IMPLEMENTATION { /* Very slow implementation which uses very accurate log10 sum functions. Only meant to be used as a reference test implementation */ - EXACT, // TODO -- merge with original, using boolean parameter to determine accuracy of HMM + EXACT, /* PairHMM as implemented for the UnifiedGenotyper. Uses log10 sum functions accurate to only 1E-4 */ ORIGINAL, /* Optimized version of the PairHMM which caches per-read computations and operations in real space to avoid costly sums of log10'ed likelihoods */ @@ -50,34 +55,172 @@ public abstract class PairHMM { protected double[][] matchMetricArray = null; protected double[][] XMetricArray = null; protected double[][] YMetricArray = null; - protected int X_METRIC_LENGTH, Y_METRIC_LENGTH; - protected int nPotentialXStarts = 0; + protected int maxHaplotypeLength, maxReadLength; + protected int X_METRIC_MAX_LENGTH, Y_METRIC_MAX_LENGTH; + private boolean initialized = false; + + /** + * Initialize this PairHMM, making it suitable to run against a read and haplotype with given lengths + * @param readMaxLength the max length of reads we want to use with this PairHMM + * @param haplotypeMaxLength the max length of haplotypes we want to use with this PairHMM + */ + public void initialize( final int readMaxLength, final int haplotypeMaxLength ) { + if ( readMaxLength <= 0 ) throw new IllegalArgumentException("READ_MAX_LENGTH must be > 0 but got " + readMaxLength); + if ( haplotypeMaxLength <= 0 ) throw new IllegalArgumentException("HAPLOTYPE_MAX_LENGTH must be > 0 but got " + haplotypeMaxLength); + + maxHaplotypeLength = haplotypeMaxLength; + maxReadLength = readMaxLength; - public void initialize( final int READ_MAX_LENGTH, final int HAPLOTYPE_MAX_LENGTH ) { // M, X, and Y arrays are of size read and haplotype + 1 because of an extra column for initial conditions and + 1 to consider the final base in a non-global alignment - X_METRIC_LENGTH = READ_MAX_LENGTH + 2; - Y_METRIC_LENGTH = HAPLOTYPE_MAX_LENGTH + 2; + X_METRIC_MAX_LENGTH = readMaxLength + 2; + Y_METRIC_MAX_LENGTH = haplotypeMaxLength + 2; - // the number of potential start sites for the read against the haplotype - // for example, a 3 bp read against a 5 bp haplotype could potentially start at 1, 2, 3 = 5 - 3 + 1 = 3 - nPotentialXStarts = HAPLOTYPE_MAX_LENGTH - READ_MAX_LENGTH + 1; - - // TODO -- add meaningful runtime checks on params - - matchMetricArray = new double[X_METRIC_LENGTH][Y_METRIC_LENGTH]; - XMetricArray = new double[X_METRIC_LENGTH][Y_METRIC_LENGTH]; - YMetricArray = new double[X_METRIC_LENGTH][Y_METRIC_LENGTH]; + matchMetricArray = new double[X_METRIC_MAX_LENGTH][Y_METRIC_MAX_LENGTH]; + XMetricArray = new double[X_METRIC_MAX_LENGTH][Y_METRIC_MAX_LENGTH]; + YMetricArray = new double[X_METRIC_MAX_LENGTH][Y_METRIC_MAX_LENGTH]; + initialized = true; } + /** + * Compute the total probability of read arising from haplotypeBases given base substitution, insertion, and deletion + * probabilities. + * + * Note on using hapStartIndex. This allows you to compute the exact true likelihood of a full haplotypes + * given a read, assuming that the previous calculation read over a full haplotype, recaching the read values, + * starting only at the place where the new haplotype bases and the previous haplotype bases different. This + * index is 0-based, and can be computed with findFirstPositionWhereHaplotypesDiffer given the two haplotypes. + * Note that this assumes that the read and all associated quals values are the same. + * + * @param haplotypeBases the full sequence (in standard SAM encoding) of the haplotype, must be >= than read bases in length + * @param readBases the bases (in standard encoding) of the read, must be <= haplotype bases in length + * @param readQuals the phred-scaled per base substitition quality scores of read. Must be the same length as readBases + * @param insertionGOP the phred-scaled per base insertion quality scores of read. Must be the same length as readBases + * @param deletionGOP the phred-scaled per base deletion quality scores of read. Must be the same length as readBases + * @param overallGCP the phred-scaled gap continuation penalties scores of read. Must be the same length as readBases + * @param hapStartIndex start the hmm calculation at this offset in haplotype bases. Used in the caching calculation + * where multiple haplotypes are used, and they only diff starting at hapStartIndex + * @param recacheReadValues if false, we don't recalculate any cached results, assuming that readBases and its associated + * parameters are the same, and only the haplotype bases are changing underneath us + * @return the log10 probability of read coming from the haplotype under the provided error model + */ + public final double computeReadLikelihoodGivenHaplotypeLog10( final byte[] haplotypeBases, + final byte[] readBases, + final byte[] readQuals, + final byte[] insertionGOP, + final byte[] deletionGOP, + final byte[] overallGCP, + final int hapStartIndex, + final boolean recacheReadValues ) { + if ( ! initialized ) throw new IllegalStateException("Must call initialize before calling computeReadLikelihoodGivenHaplotypeLog10"); + if ( haplotypeBases == null ) throw new IllegalArgumentException("haplotypeBases cannot be null"); + if ( haplotypeBases.length > maxHaplotypeLength ) throw new IllegalArgumentException("Haplotype bases is too long, got " + haplotypeBases.length + " but max is " + maxHaplotypeLength); + if ( readBases == null ) throw new IllegalArgumentException("readBases cannot be null"); + if ( readBases.length > maxReadLength ) throw new IllegalArgumentException("readBases is too long, got " + readBases.length + " but max is " + maxReadLength); + if ( readQuals.length != readBases.length ) throw new IllegalArgumentException("Read bases and read quals aren't the same size: " + readBases.length + " vs " + readQuals.length); + if ( insertionGOP.length != readBases.length ) throw new IllegalArgumentException("Read bases and read insertion quals aren't the same size: " + readBases.length + " vs " + insertionGOP.length); + if ( deletionGOP.length != readBases.length ) throw new IllegalArgumentException("Read bases and read deletion quals aren't the same size: " + readBases.length + " vs " + deletionGOP.length); + if ( overallGCP.length != readBases.length ) throw new IllegalArgumentException("Read bases and overall GCP aren't the same size: " + readBases.length + " vs " + overallGCP.length); + if ( hapStartIndex < 0 || hapStartIndex > haplotypeBases.length ) throw new IllegalArgumentException("hapStartIndex is bad, must be between 0 and haplotype length " + haplotypeBases.length + " but got " + hapStartIndex); + + final double result = subComputeReadLikelihoodGivenHaplotypeLog10(haplotypeBases, readBases, readQuals, insertionGOP, deletionGOP, overallGCP, hapStartIndex, recacheReadValues); + + if ( MathUtils.goodLog10Probability(result) ) + return result; + else + throw new IllegalStateException("Bad likelihoods detected: " + result); +// return result; + } + + /** + * To be overloaded by subclasses to actually do calculation for #computeReadLikelihoodGivenHaplotypeLog10 + */ @Requires({"readBases.length == readQuals.length", "readBases.length == insertionGOP.length", "readBases.length == deletionGOP.length", "readBases.length == overallGCP.length", "matchMetricArray!=null", "XMetricArray!=null", "YMetricArray!=null"}) - @Ensures({"!Double.isInfinite(result)", "!Double.isNaN(result)", "result <= 0.0"}) // Result should be a proper log10 likelihood - public abstract double computeReadLikelihoodGivenHaplotypeLog10( final byte[] haplotypeBases, - final byte[] readBases, - final byte[] readQuals, - final byte[] insertionGOP, - final byte[] deletionGOP, - final byte[] overallGCP, - final int hapStartIndex, - final boolean recacheReadValues ); + protected abstract double subComputeReadLikelihoodGivenHaplotypeLog10( final byte[] haplotypeBases, + final byte[] readBases, + final byte[] readQuals, + final byte[] insertionGOP, + final byte[] deletionGOP, + final byte[] overallGCP, + final int hapStartIndex, + final boolean recacheReadValues ); + + /** + * How many potential starting locations are a read with readSize bases against a haplotype with haplotypeSize bases? + * + * for example, a 3 bp read against a 5 bp haplotype could potentially start at 1, 2, 3 = 5 - 3 + 1 = 3 + * the max value is necessary in the case where the read is longer than the haplotype, in which case + * there's a single unique start site by assumption + * + * @param haplotypeSize the number of bases in the haplotype we are testing + * @param readSize the number of bases in the read we are testing + * @return a positive integer >= 1 + */ + @Ensures("result >= 1") + protected int getNPotentialXStarts(final int haplotypeSize, final int readSize) { + return Math.max(haplotypeSize - readSize + 1, 1); + } + + /** + * The the log10 probability penalty for the number of potential start sites of the read aginst the haplotype + * + * @param haplotypeSize the number of bases in the haplotype we are testing + * @param readSize the number of bases in the read we are testing + * @return a log10 probability + */ + @Ensures("MathUtils.goodLog10Probability(result)") + protected double getNPotentialXStartsLikelihoodPenaltyLog10(final int haplotypeSize, final int readSize) { + return - Math.log10(getNPotentialXStarts(haplotypeSize, readSize)); + } + + /** + * Print out the core hmm matrices for debugging + */ + protected void dumpMatrices() { + dumpMatrix("matchMetricArray", matchMetricArray); + dumpMatrix("XMetricArray", XMetricArray); + dumpMatrix("YMetricArray", YMetricArray); + } + + /** + * Print out in a human readable form the matrix for debugging + * @param name the name of this matrix + * @param matrix the matrix of values + */ + @Requires({"name != null", "matrix != null"}) + private void dumpMatrix(final String name, final double[][] matrix) { + System.out.printf("%s%n", name); + for ( int i = 0; i < matrix.length; i++) { + System.out.printf("\t%s[%d]", name, i); + for ( int j = 0; j < matrix[i].length; j++ ) { + if ( Double.isInfinite(matrix[i][j]) ) + System.out.printf(" %15s", String.format("%f", matrix[i][j])); + else + System.out.printf(" % 15.5e", matrix[i][j]); + } + System.out.println(); + } + } + + /** + * Compute the first position at which two haplotypes differ + * + * If the haplotypes are exact copies of each other, returns the min length of the two haplotypes. + * + * @param haplotype1 the first haplotype1 + * @param haplotype2 the second haplotype1 + * @return the index of the first position in haplotype1 and haplotype2 where the byte isn't the same + */ + public static int findFirstPositionWhereHaplotypesDiffer(final byte[] haplotype1, final byte[] haplotype2) { + if ( haplotype1 == null || haplotype1.length == 0 ) throw new IllegalArgumentException("Haplotype1 is bad " + haplotype1); + if ( haplotype2 == null || haplotype2.length == 0 ) throw new IllegalArgumentException("Haplotype2 is bad " + haplotype2); + + for( int iii = 0; iii < haplotype1.length && iii < haplotype2.length; iii++ ) { + if( haplotype1[iii] != haplotype2[iii] ) { + return iii; + } + } + + return Math.min(haplotype1.length, haplotype2.length); + } }