Merge pull request #263 from broadinstitute/mccowan_reduce_reads_performance

Reduce reads performance improvements
This commit is contained in:
Eric Banks 2013-06-10 06:19:38 -07:00
commit cbb6c7ae92
3 changed files with 112 additions and 27 deletions

View File

@ -273,8 +273,9 @@ public class ReduceReads extends ReadWalker<ObjectArrayList<GATKSAMRecord>, Redu
int nCompressedReads = 0;
Object2LongOpenHashMap<String> readNameHash; // This hash will keep the name of the original read the new compressed name (a number).
private static int READ_NAME_HASH_DEFAULT_SIZE = 1000;
Long nextReadNumber = 1L; // The next number to use for the compressed read name.
Object2LongOpenHashMap<String> readNameHash; // This hash will keep the name of the original read the new compressed name (a number).
ObjectSortedSet<GenomeLoc> intervalList;
@ -313,7 +314,7 @@ public class ReduceReads extends ReadWalker<ObjectArrayList<GATKSAMRecord>, Redu
knownSnpPositions = new ObjectAVLTreeSet<GenomeLoc>();
GenomeAnalysisEngine toolkit = getToolkit();
readNameHash = new Object2LongOpenHashMap<String>(100000); // prepare the read name hash to keep track of what reads have had their read names compressed
this.resetReadNameHash(); // prepare the read name hash to keep track of what reads have had their read names compressed
intervalList = new ObjectAVLTreeSet<GenomeLoc>(); // get the interval list from the engine. If no interval list was provided, the walker will work in WGS mode
if (toolkit.getIntervals() != null)
@ -335,6 +336,16 @@ public class ReduceReads extends ReadWalker<ObjectArrayList<GATKSAMRecord>, Redu
}
}
/** (Re-)creates or clears {@link #readNameHash} so a fresh compression pass can begin. */
private void resetReadNameHash() {
    // Clearing a fastutil hash that has grown very large can be expensive, so rebuild it
    // at the default capacity whenever it has outgrown that size (or was never created).
    final boolean rebuild = readNameHash == null || readNameHash.size() > READ_NAME_HASH_DEFAULT_SIZE;
    if (rebuild) {
        readNameHash = new Object2LongOpenHashMap<String>(READ_NAME_HASH_DEFAULT_SIZE);
    } else {
        readNameHash.clear();
    }
}
/**
* Takes in a read and prepares it for the SlidingWindow machinery by performing the
* following optional clipping operations:
@ -471,7 +482,7 @@ public class ReduceReads extends ReadWalker<ObjectArrayList<GATKSAMRecord>, Redu
// stash.compress(), the readNameHash can be cleared after the for() loop above.
// The advantage of clearing the hash is that otherwise it holds all reads that have been encountered,
// which can use a lot of memory and cause RR to slow to a crawl and/or run out of memory.
readNameHash.clear();
this.resetReadNameHash();
}
} else

View File

@ -29,9 +29,8 @@ import com.google.java.contract.Ensures;
import com.google.java.contract.Requires;
import org.broadinstitute.sting.gatk.GenomeAnalysisEngine;
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
import org.broadinstitute.sting.utils.exceptions.UserException;
import java.lang.IllegalArgumentException;
import javax.annotation.Nullable;
import java.math.BigDecimal;
import java.util.*;
@ -417,9 +416,35 @@ public class MathUtils {
return log10BinomialCoefficient(n, k) + (n * FAIR_BINOMIAL_PROB_LOG10_0_5);
}
/** A memoization container for {@link #binomialCumulativeProbability(int, int, int)}. Synchronized to accommodate multithreading. */
private static final Map<Long, Double> BINOMIAL_CUMULATIVE_PROBABILITY_MEMOIZATION_CACHE =
Collections.synchronizedMap(new LRUCache<Long, Double>(10_000));
/**
 * Bijectively packs three non-negative short-range integers into a single long.
 *
 * Returns null (in lieu of an exception) whenever the bijection cannot hold, i.e. when any
 * argument is negative or larger than a short. This method is optimized for speed; it is not
 * intended to serve as a general-purpose utility function.
 */
@Nullable
static Long fastGenerateUniqueHashFromThreeIntegers(final int one, final int two, final int three) {
    // Each value must fit in a non-negative short (15 bits) for the packing to be unique.
    final boolean inRange = 0 <= one && one <= Short.MAX_VALUE
            && 0 <= two && two <= Short.MAX_VALUE
            && 0 <= three && three <= Short.MAX_VALUE;
    if (!inRange) {
        return null;
    }
    // Place the three values into disjoint 16-bit fields of the long.
    return (((long) one) << 32) | (((long) two) << 16) | (long) three;
}
/**
* Performs the cumulative sum of binomial probabilities, where the probability calculation is done in log space.
* Assumes that the probability of a successful hit is fair (i.e. 0.5).
*
* This pure function is memoized because of its expensive BigDecimal calculations.
*
* @param n number of attempts for the number of hits
* @param k_start start (inclusive) of the cumulant sum (over hits)
@ -430,23 +455,41 @@ public class MathUtils {
if ( k_end > n )
throw new IllegalArgumentException(String.format("Value for k_end (%d) is greater than n (%d)", k_end, n));
double cumProb = 0.0;
double prevProb;
BigDecimal probCache = BigDecimal.ZERO;
for (int hits = k_start; hits <= k_end; hits++) {
prevProb = cumProb;
final double probability = binomialProbability(n, hits);
cumProb += probability;
if (probability > 0 && cumProb - prevProb < probability / 2) { // loss of precision
probCache = probCache.add(new BigDecimal(prevProb));
cumProb = 0.0;
hits--; // repeat loop
// prevProb changes at start of loop
}
// Fetch cached value, if applicable.
final Long memoizationKey = fastGenerateUniqueHashFromThreeIntegers(n, k_start, k_end);
final Double memoizationCacheResult;
if (memoizationKey != null) {
memoizationCacheResult = BINOMIAL_CUMULATIVE_PROBABILITY_MEMOIZATION_CACHE.get(memoizationKey);
} else {
memoizationCacheResult = null;
}
return probCache.add(new BigDecimal(cumProb)).doubleValue();
final double result;
if (memoizationCacheResult != null) {
result = memoizationCacheResult;
} else {
double cumProb = 0.0;
double prevProb;
BigDecimal probCache = BigDecimal.ZERO;
for (int hits = k_start; hits <= k_end; hits++) {
prevProb = cumProb;
final double probability = binomialProbability(n, hits);
cumProb += probability;
if (probability > 0 && cumProb - prevProb < probability / 2) { // loss of precision
probCache = probCache.add(new BigDecimal(prevProb));
cumProb = 0.0;
hits--; // repeat loop
// prevProb changes at start of loop
}
}
result = probCache.add(new BigDecimal(cumProb)).doubleValue();
if (memoizationKey != null) {
BINOMIAL_CUMULATIVE_PROBABILITY_MEMOIZATION_CACHE.put(memoizationKey, result);
}
}
return result;
}
/**

View File

@ -41,6 +41,35 @@ public class MathUtilsUnitTest extends BaseTest {
public void init() {
}
/**
 * Tests that {@link MathUtils#fastGenerateUniqueHashFromThreeIntegers(int, int, int)} yields a
 * unique value for every sampled point of its valid (non-null-producing) input space.
 */
@Test
public void testGenerateUniqueHashFromThreePositiveIntegers() {
    logger.warn("Executing testGenerateUniqueHashFromThreePositiveIntegers");
    final Set<Long> observedLongs = new HashSet<Long>();
    // Exhaustively cover the low end of the range: every triple in [0, 127)^3.
    for (int i = 0; i < Byte.MAX_VALUE; i++) {
        for (int j = 0; j < Byte.MAX_VALUE; j++) {
            for (int k = 0; k < Byte.MAX_VALUE; k++) {
                Assert.assertTrue(observedLongs.add(MathUtils.fastGenerateUniqueHashFromThreeIntegers(i, j, k)));
            }
        }
    }
    // Sample the high end sparsely: each axis runs over [127, 32767] in steps of 128.
    for (int i = Byte.MAX_VALUE; i <= Short.MAX_VALUE; i += 128) {
        for (int j = Byte.MAX_VALUE; j <= Short.MAX_VALUE; j += 128) {
            for (int k = Byte.MAX_VALUE; k <= Short.MAX_VALUE; k += 128) {
                Assert.assertTrue(observedLongs.add(MathUtils.fastGenerateUniqueHashFromThreeIntegers(i, j, k)));
            }
        }
    }
}
/**
* Tests that we get the right values from the binomial distribution
*/
@ -64,13 +93,15 @@ public class MathUtilsUnitTest extends BaseTest {
public void testCumulativeBinomialProbability() {
logger.warn("Executing testCumulativeBinomialProbability");
final int numTrials = 10;
for ( int i = 0; i < numTrials; i++ )
Assert.assertEquals(MathUtils.binomialCumulativeProbability(numTrials, i, i), MathUtils.binomialProbability(numTrials, i), 1e-10, String.format("k=%d, n=%d", i, numTrials));
Assert.assertEquals(MathUtils.binomialCumulativeProbability(10, 0, 2), 0.05468750, 1e-7);
Assert.assertEquals(MathUtils.binomialCumulativeProbability(10, 0, 5), 0.62304687, 1e-7);
Assert.assertEquals(MathUtils.binomialCumulativeProbability(10, 0, 10), 1.0, 1e-7);
for (int j = 0; j < 2; j++) { // Test memoizing functionality, as well.
final int numTrials = 10;
for ( int i = 0; i < numTrials; i++ )
Assert.assertEquals(MathUtils.binomialCumulativeProbability(numTrials, i, i), MathUtils.binomialProbability(numTrials, i), 1e-10, String.format("k=%d, n=%d", i, numTrials));
Assert.assertEquals(MathUtils.binomialCumulativeProbability(10, 0, 2), 0.05468750, 1e-7);
Assert.assertEquals(MathUtils.binomialCumulativeProbability(10, 0, 5), 0.62304687, 1e-7);
Assert.assertEquals(MathUtils.binomialCumulativeProbability(10, 0, 10), 1.0, 1e-7);
}
}
/**