From 00c06e9e52f416599bb9b906c32857848a9abd39 Mon Sep 17 00:00:00 2001 From: Michael McCowan Date: Tue, 4 Jun 2013 10:08:24 -0400 Subject: [PATCH] Performance improvements: - Memoized MathUtil's cumulative binomial probability function. - Reduced the default size of the read name map in reduced reads and handle its resets more efficiently. --- .../compression/reducereads/ReduceReads.java | 17 +++- .../broadinstitute/sting/utils/MathUtils.java | 77 +++++++++++++++---- .../sting/utils/MathUtilsUnitTest.java | 45 +++++++++-- 3 files changed, 112 insertions(+), 27 deletions(-) diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/ReduceReads.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/ReduceReads.java index eb55701ae..e636f8f17 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/ReduceReads.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/ReduceReads.java @@ -273,8 +273,9 @@ public class ReduceReads extends ReadWalker, Redu int nCompressedReads = 0; - Object2LongOpenHashMap readNameHash; // This hash will keep the name of the original read the new compressed name (a number). + private static int READ_NAME_HASH_DEFAULT_SIZE = 1000; Long nextReadNumber = 1L; // The next number to use for the compressed read name. + Object2LongOpenHashMap readNameHash; // This hash will keep the name of the original read the new compressed name (a number). 
ObjectSortedSet intervalList; @@ -313,7 +314,7 @@ public class ReduceReads extends ReadWalker, Redu knownSnpPositions = new ObjectAVLTreeSet(); GenomeAnalysisEngine toolkit = getToolkit(); - readNameHash = new Object2LongOpenHashMap(100000); // prepare the read name hash to keep track of what reads have had their read names compressed + this.resetReadNameHash(); // prepare the read name hash to keep track of what reads have had their read names compressed intervalList = new ObjectAVLTreeSet(); // get the interval list from the engine. If no interval list was provided, the walker will work in WGS mode if (toolkit.getIntervals() != null) @@ -335,6 +336,16 @@ public class ReduceReads extends ReadWalker, Redu } } + /** Initializer for {@link #readNameHash}. */ + private void resetReadNameHash() { + // If the hash grows large, subsequent clear operations can be very expensive, so trim the hash down if it grows beyond its default. + if (readNameHash == null || readNameHash.size() > READ_NAME_HASH_DEFAULT_SIZE) { + readNameHash = new Object2LongOpenHashMap(READ_NAME_HASH_DEFAULT_SIZE); + } else { + readNameHash.clear(); + } + } + /** * Takes in a read and prepares it for the SlidingWindow machinery by performing the * following optional clipping operations: @@ -471,7 +482,7 @@ public class ReduceReads extends ReadWalker, Redu // stash.compress(), the readNameHash can be cleared after the for() loop above. // The advantage of clearing the hash is that otherwise it holds all reads that have been encountered, // which can use a lot of memory and cause RR to slow to a crawl and/or run out of memory. 
- readNameHash.clear(); + this.resetReadNameHash(); } } else diff --git a/public/java/src/org/broadinstitute/sting/utils/MathUtils.java b/public/java/src/org/broadinstitute/sting/utils/MathUtils.java index dfd3537da..07aff5983 100644 --- a/public/java/src/org/broadinstitute/sting/utils/MathUtils.java +++ b/public/java/src/org/broadinstitute/sting/utils/MathUtils.java @@ -29,9 +29,8 @@ import com.google.java.contract.Ensures; import com.google.java.contract.Requires; import org.broadinstitute.sting.gatk.GenomeAnalysisEngine; import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; -import org.broadinstitute.sting.utils.exceptions.UserException; -import java.lang.IllegalArgumentException; +import javax.annotation.Nullable; import java.math.BigDecimal; import java.util.*; @@ -417,9 +416,35 @@ public class MathUtils { return log10BinomialCoefficient(n, k) + (n * FAIR_BINOMIAL_PROB_LOG10_0_5); } + /** A memoization container for {@link #binomialCumulativeProbability(int, int, int)}. Synchronized to accommodate multithreading. */ + private static final Map BINOMIAL_CUMULATIVE_PROBABILITY_MEMOIZATION_CACHE = + Collections.synchronizedMap(new LRUCache(10_000)); + + /** + * Primitive integer-triplet bijection into long. Returns null when the bijection function fails (in lieu of an exception), which will + * happen when: any value is negative or larger than a short. This method is optimized for speed; it is not intended to serve as a + * utility function. 
+ */ + @Nullable + static Long fastGenerateUniqueHashFromThreeIntegers(final int one, final int two, final int three) { + if (one < 0 || two < 0 || three < 0 || Short.MAX_VALUE < one || Short.MAX_VALUE < two || Short.MAX_VALUE < three) { + return null; + } else { + long result = 0; + result += (short) one; + result <<= 16; + result += (short) two; + result <<= 16; + result += (short) three; + return result; + } + } + /** * Performs the cumulative sum of binomial probabilities, where the probability calculation is done in log space. * Assumes that the probability of a successful hit is fair (i.e. 0.5). + * + * This pure function is memoized because of its expensive BigDecimal calculations. * * @param n number of attempts for the number of hits * @param k_start start (inclusive) of the cumulant sum (over hits) @@ -430,23 +455,41 @@ public class MathUtils { if ( k_end > n ) throw new IllegalArgumentException(String.format("Value for k_end (%d) is greater than n (%d)", k_end, n)); - double cumProb = 0.0; - double prevProb; - BigDecimal probCache = BigDecimal.ZERO; - - for (int hits = k_start; hits <= k_end; hits++) { - prevProb = cumProb; - final double probability = binomialProbability(n, hits); - cumProb += probability; - if (probability > 0 && cumProb - prevProb < probability / 2) { // loss of precision - probCache = probCache.add(new BigDecimal(prevProb)); - cumProb = 0.0; - hits--; // repeat loop - // prevProb changes at start of loop - } + // Fetch cached value, if applicable. 
+ final Long memoizationKey = fastGenerateUniqueHashFromThreeIntegers(n, k_start, k_end); + final Double memoizationCacheResult; + if (memoizationKey != null) { + memoizationCacheResult = BINOMIAL_CUMULATIVE_PROBABILITY_MEMOIZATION_CACHE.get(memoizationKey); + } else { + memoizationCacheResult = null; } - return probCache.add(new BigDecimal(cumProb)).doubleValue(); + final double result; + if (memoizationCacheResult != null) { + result = memoizationCacheResult; + } else { + double cumProb = 0.0; + double prevProb; + BigDecimal probCache = BigDecimal.ZERO; + + for (int hits = k_start; hits <= k_end; hits++) { + prevProb = cumProb; + final double probability = binomialProbability(n, hits); + cumProb += probability; + if (probability > 0 && cumProb - prevProb < probability / 2) { // loss of precision + probCache = probCache.add(new BigDecimal(prevProb)); + cumProb = 0.0; + hits--; // repeat loop + // prevProb changes at start of loop + } + } + + result = probCache.add(new BigDecimal(cumProb)).doubleValue(); + if (memoizationKey != null) { + BINOMIAL_CUMULATIVE_PROBABILITY_MEMOIZATION_CACHE.put(memoizationKey, result); + } + } + return result; } /** diff --git a/public/java/test/org/broadinstitute/sting/utils/MathUtilsUnitTest.java b/public/java/test/org/broadinstitute/sting/utils/MathUtilsUnitTest.java index e4c74a0ad..3933b3830 100644 --- a/public/java/test/org/broadinstitute/sting/utils/MathUtilsUnitTest.java +++ b/public/java/test/org/broadinstitute/sting/utils/MathUtilsUnitTest.java @@ -41,6 +41,35 @@ public class MathUtilsUnitTest extends BaseTest { public void init() { } + /** + * Tests that we get unique values for the valid (non-null-producing) input space for {@link MathUtils#fastGenerateUniqueHashFromThreeIntegers(int, int, int)}. 
+ */ + @Test + public void testGenerateUniqueHashFromThreePositiveIntegers() { + logger.warn("Executing testGenerateUniqueHashFromThreePositiveIntegers"); + + final Set observedLongs = new HashSet(); + for (short i = 0; i < Byte.MAX_VALUE; i++) { + for (short j = 0; j < Byte.MAX_VALUE; j++) { + for (short k = 0; k < Byte.MAX_VALUE; k++) { + final Long aLong = MathUtils.fastGenerateUniqueHashFromThreeIntegers(i, j, k); + //System.out.println(String.format("%s, %s, %s: %s", i, j, k, aLong)); + Assert.assertTrue(observedLongs.add(aLong)); + } + } + } + + for (short i = Byte.MAX_VALUE; i <= Short.MAX_VALUE && i > 0; i += 128) { + for (short j = Byte.MAX_VALUE; j <= Short.MAX_VALUE && j > 0; j += 128) { + for (short k = Byte.MAX_VALUE; k <= Short.MAX_VALUE && k > 0; k += 128) { + final Long aLong = MathUtils.fastGenerateUniqueHashFromThreeIntegers(i, j, k); + // System.out.println(String.format("%s, %s, %s: %s", i, j, k, aLong)); + Assert.assertTrue(observedLongs.add(aLong)); + } + } + } + } + /** * Tests that we get the right values from the binomial distribution */ @@ -64,13 +93,15 @@ public class MathUtilsUnitTest extends BaseTest { public void testCumulativeBinomialProbability() { logger.warn("Executing testCumulativeBinomialProbability"); - final int numTrials = 10; - for ( int i = 0; i < numTrials; i++ ) - Assert.assertEquals(MathUtils.binomialCumulativeProbability(numTrials, i, i), MathUtils.binomialProbability(numTrials, i), 1e-10, String.format("k=%d, n=%d", i, numTrials)); - - Assert.assertEquals(MathUtils.binomialCumulativeProbability(10, 0, 2), 0.05468750, 1e-7); - Assert.assertEquals(MathUtils.binomialCumulativeProbability(10, 0, 5), 0.62304687, 1e-7); - Assert.assertEquals(MathUtils.binomialCumulativeProbability(10, 0, 10), 1.0, 1e-7); + for (int j = 0; j < 2; j++) { // Test memoizing functionality, as well. 
+ final int numTrials = 10; + for ( int i = 0; i < numTrials; i++ ) + Assert.assertEquals(MathUtils.binomialCumulativeProbability(numTrials, i, i), MathUtils.binomialProbability(numTrials, i), 1e-10, String.format("k=%d, n=%d", i, numTrials)); + + Assert.assertEquals(MathUtils.binomialCumulativeProbability(10, 0, 2), 0.05468750, 1e-7); + Assert.assertEquals(MathUtils.binomialCumulativeProbability(10, 0, 5), 0.62304687, 1e-7); + Assert.assertEquals(MathUtils.binomialCumulativeProbability(10, 0, 10), 1.0, 1e-7); + } } /**