Refactored binomial probability code in MathUtils.

* Moved redundant code out of UGEngine
  * Added overloaded methods that assume p=0.5 for speed efficiency
  * Added unit test for the binomialCumulativeProbability method
This commit is contained in:
Eric Banks 2013-04-09 15:37:08 -04:00
parent df189293ce
commit 5bce0e086e
5 changed files with 106 additions and 77 deletions

View File

@ -253,7 +253,7 @@ public class HeaderElement {
final int totalCount = consensusBaseCounts.totalCountWithoutIndels();
final BaseIndex mostCommon = consensusBaseCounts.baseIndexWithMostProbabilityWithoutIndels();
final int countOfOtherBases = totalCount - consensusBaseCounts.countOfBase(mostCommon);
final double pvalue = countOfOtherBases == 0 ? 0.0 : MathUtils.binomialCumulativeProbability(0, countOfOtherBases+1, totalCount, 0.5);
final double pvalue = countOfOtherBases == 0 ? 0.0 : MathUtils.binomialCumulativeProbability(totalCount, 0, countOfOtherBases);
return pvalue > minVariantPvalue;
}
@ -301,7 +301,7 @@ public class HeaderElement {
if ( baseCount == 0 )
continue;
final double pvalue = MathUtils.binomialCumulativeProbability(0, baseCount+1, totalBaseCount, 0.5);
final double pvalue = MathUtils.binomialCumulativeProbability(totalBaseCount, 0, baseCount);
if ( pvalue > minVariantPvalue ) {
if ( base == BaseIndex.D )
@ -334,7 +334,8 @@ public class HeaderElement {
if ( count == 0 || totalBaseCount == 0 )
return false;
final double pvalue = MathUtils.binomialCumulativeProbability(0, count+1, totalBaseCount, 0.5);
// technically, count can be greater than totalBaseCount (because of the way insertions are counted) so we need to account for that
final double pvalue = MathUtils.binomialCumulativeProbability(totalBaseCount, 0, Math.min(count, totalBaseCount));
return pvalue > minVariantPvalue;
}
}

View File

@ -610,20 +610,8 @@ public class UnifiedGenotyperEngine {
return stratifiedContexts;
}
private final static double[] binomialProbabilityDepthCache = new double[10000];
private final static double REF_BINOMIAL_PROB_LOG10_0_5 = Math.log10(0.5);
static {
for ( int i = 1; i < binomialProbabilityDepthCache.length; i++ ) {
binomialProbabilityDepthCache[i] = MathUtils.log10BinomialProbability(i, 0, REF_BINOMIAL_PROB_LOG10_0_5);
}
}
private final double getRefBinomialProbLog10(final int depth) {
if ( depth < binomialProbabilityDepthCache.length )
return binomialProbabilityDepthCache[depth];
else
return MathUtils.log10BinomialProbability(depth, 0, REF_BINOMIAL_PROB_LOG10_0_5);
return MathUtils.log10BinomialProbability(depth, 0);
}
private VariantCallContext estimateReferenceConfidence(VariantContext vc, Map<String, AlignmentContext> contexts, double theta, boolean ignoreCoveredSamples, double initialPofRef) {

View File

@ -195,7 +195,7 @@ public class HeaderElementUnitTest extends BaseTest {
if ( count == 0 )
continue;
final double pvalue = MathUtils.binomialCumulativeProbability(0, count + 1, total, 0.5);
final double pvalue = MathUtils.binomialCumulativeProbability(total, 0, count);
if ( pvalue > targetPvalue ) {
if ( index == BaseIndex.D.index )

View File

@ -27,7 +27,6 @@ package org.broadinstitute.sting.utils;
import com.google.java.contract.Ensures;
import com.google.java.contract.Requires;
import net.sf.samtools.SAMRecord;
import org.broadinstitute.sting.gatk.GenomeAnalysisEngine;
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
import org.broadinstitute.sting.utils.exceptions.UserException;
@ -63,6 +62,7 @@ public class MathUtils {
* where the real-space value is 0.0.
*/
public final static double LOG10_P_OF_ZERO = -1000000.0;
public final static double FAIR_BINOMIAL_PROB_LOG10_0_5 = Math.log10(0.5);
static {
log10Cache = new double[LOG10_CACHE_SIZE];
@ -70,6 +70,7 @@ public class MathUtils {
jacobianLogTable = new double[JACOBIAN_LOG_TABLE_SIZE];
log10Cache[0] = Double.NEGATIVE_INFINITY;
log10FactorialCache[0] = 0.0;
for (int k = 1; k < LOG10_CACHE_SIZE; k++) {
log10Cache[k] = Math.log10(k);
log10FactorialCache[k] = log10FactorialCache[k-1] + log10Cache[k];
@ -306,10 +307,25 @@ public class MathUtils {
return a * b;
}
/**
* Calculates the log10 of the binomial coefficient. Designed to prevent
* overflows even with very large numbers.
*
* @param n total number of trials
* @param k number of successes
* @return the log10 of the binomial coefficient
*/
public static double binomialCoefficient(final int n, final int k) {
return Math.pow(10, log10BinomialCoefficient(n, k));
}
/**
* @see #binomialCoefficient(int, int) with log10 applied to result
*/
public static double log10BinomialCoefficient(final int n, final int k) {
return log10Factorial(n) - log10Factorial(k) - log10Factorial(n - k);
}
/**
* Computes a binomial probability. This is computed using the formula
* <p/>
@ -326,23 +342,48 @@ public class MathUtils {
return Math.pow(10, log10BinomialProbability(n, k, Math.log10(p)));
}
/**
* @see #binomialProbability(int, int, double) with log10 applied to result
*/
public static double log10BinomialProbability(final int n, final int k, final double log10p) {
double log10OneMinusP = Math.log10(1 - Math.pow(10, log10p));
return log10BinomialCoefficient(n, k) + log10p * k + log10OneMinusP * (n - k);
}
/**
* @see #binomialProbability(int, int, double) with p=0.5
*/
public static double binomialProbability(final int n, final int k) {
return Math.pow(10, log10BinomialProbability(n, k));
}
/**
* @see #binomialProbability(int, int, double) with p=0.5 and log10 applied to result
*/
public static double log10BinomialProbability(final int n, final int k) {
return log10BinomialCoefficient(n, k) + (n * FAIR_BINOMIAL_PROB_LOG10_0_5);
}
/**
* Performs the cumulative sum of binomial probabilities, where the probability calculation is done in log space.
* Assumes that the probability of a successful hit is fair (i.e. 0.5).
*
* @param start - start (inclusive) of the cumulant sum (over hits)
* @param end - end (exclusive) of the cumulant sum (over hits)
* @param total - number of attempts for the number of hits
* @param probHit - probability of a successful hit
* @param n number of attempts for the number of hits
* @param k_start start (inclusive) of the cumulant sum (over hits)
* @param k_end end (inclusive) of the cumulant sum (over hits)
* @return - returns the cumulative probability
*/
public static double binomialCumulativeProbability(final int start, final int end, final int total, final double probHit) {
public static double binomialCumulativeProbability(final int n, final int k_start, final int k_end) {
if ( k_end > n )
throw new IllegalArgumentException(String.format("Value for k_end (%d) is greater than n (%d)", k_end, n));
double cumProb = 0.0;
double prevProb;
BigDecimal probCache = BigDecimal.ZERO;
for (int hits = start; hits < end; hits++) {
for (int hits = k_start; hits <= k_end; hits++) {
prevProb = cumProb;
double probability = binomialProbability(total, hits, probHit);
double probability = binomialProbability(n, hits);
cumProb += probability;
if (probability > 0 && cumProb - prevProb < probability / 2) { // loss of precision
probCache = probCache.add(new BigDecimal(prevProb));
@ -355,6 +396,41 @@ public class MathUtils {
return probCache.add(new BigDecimal(cumProb)).doubleValue();
}
/**
* Calculates the log10 of the multinomial coefficient. Designed to prevent
* overflows even with very large numbers.
*
* @param n total number of trials
* @param k array of any size with the number of successes for each grouping (k1, k2, k3, ..., km)
* @return
*/
public static double log10MultinomialCoefficient(final int n, final int[] k) {
double denominator = 0.0;
for (int x : k) {
denominator += log10Factorial(x);
}
return log10Factorial(n) - denominator;
}
/**
* Computes the log10 of the multinomial distribution probability given a vector
* of log10 probabilities. Designed to prevent overflows even with very large numbers.
*
* @param n number of trials
* @param k array of number of successes for each possibility
* @param log10p array of log10 probabilities
* @return
*/
public static double log10MultinomialProbability(final int n, final int[] k, final double[] log10p) {
if (log10p.length != k.length)
throw new UserException.BadArgumentValue("p and k", "Array of log10 probabilities must have the same size as the array of number of sucesses: " + log10p.length + ", " + k.length);
double log10Prod = 0.0;
for (int i = 0; i < log10p.length; i++) {
log10Prod += log10p[i] * k[i];
}
return log10MultinomialCoefficient(n, k) + log10Prod;
}
/**
* Computes a multinomial coefficient efficiently avoiding overflow even for large numbers.
* This is computed using the formula:
@ -1120,58 +1196,6 @@ public class MathUtils {
return lnToLog10(lnGamma(x));
}
/**
* Calculates the log10 of the binomial coefficient. Designed to prevent
* overflows even with very large numbers.
*
* @param n total number of trials
* @param k number of successes
* @return the log10 of the binomial coefficient
*/
public static double log10BinomialCoefficient(final int n, final int k) {
return log10Factorial(n) - log10Factorial(k) - log10Factorial(n - k);
}
public static double log10BinomialProbability(final int n, final int k, final double log10p) {
double log10OneMinusP = Math.log10(1 - Math.pow(10, log10p));
return log10BinomialCoefficient(n, k) + log10p * k + log10OneMinusP * (n - k);
}
/**
* Calculates the log10 of the multinomial coefficient. Designed to prevent
* overflows even with very large numbers.
*
* @param n total number of trials
* @param k array of any size with the number of successes for each grouping (k1, k2, k3, ..., km)
* @return
*/
public static double log10MultinomialCoefficient(final int n, final int[] k) {
double denominator = 0.0;
for (int x : k) {
denominator += log10Factorial(x);
}
return log10Factorial(n) - denominator;
}
/**
* Computes the log10 of the multinomial distribution probability given a vector
* of log10 probabilities. Designed to prevent overflows even with very large numbers.
*
* @param n number of trials
* @param k array of number of successes for each possibility
* @param log10p array of log10 probabilities
* @return
*/
public static double log10MultinomialProbability(final int n, final int[] k, final double[] log10p) {
if (log10p.length != k.length)
throw new UserException.BadArgumentValue("p and k", "Array of log10 probabilities must have the same size as the array of number of sucesses: " + log10p.length + ", " + k.length);
double log10Prod = 0.0;
for (int i = 0; i < log10p.length; i++) {
log10Prod += log10p[i] * k[i];
}
return log10MultinomialCoefficient(n, k) + log10Prod;
}
public static double factorial(final int x) {
// avoid rounding errors caused by fact that 10^log(x) might be slightly lower than x and flooring may produce 1 less than real value
return (double)Math.round(Math.pow(10, log10Factorial(x)));

View File

@ -56,6 +56,22 @@ public class MathUtilsUnitTest extends BaseTest {
Assert.assertEquals(MathUtils.binomialProbability(300, 112, 0.98), 2.34763e-236, 1e-237);
}
/**
* Tests that we get the right values from the binomial distribution
*/
@Test
public void testCumulativeBinomialProbability() {
logger.warn("Executing testCumulativeBinomialProbability");
final int numTrials = 10;
for ( int i = 0; i < numTrials; i++ )
Assert.assertEquals(MathUtils.binomialCumulativeProbability(numTrials, i, i), MathUtils.binomialProbability(numTrials, i), 1e-10, String.format("k=%d, n=%d", i, numTrials));
Assert.assertEquals(MathUtils.binomialCumulativeProbability(10, 0, 2), 0.05468750, 1e-7);
Assert.assertEquals(MathUtils.binomialCumulativeProbability(10, 0, 5), 0.62304687, 1e-7);
Assert.assertEquals(MathUtils.binomialCumulativeProbability(10, 0, 10), 1.0, 1e-7);
}
/**
* Tests that we get the right values from the multinomial distribution
*/