diff --git a/java/src/org/broadinstitute/sting/utils/MathUtils.java b/java/src/org/broadinstitute/sting/utils/MathUtils.java index c14511f7f..7804e0664 100755 --- a/java/src/org/broadinstitute/sting/utils/MathUtils.java +++ b/java/src/org/broadinstitute/sting/utils/MathUtils.java @@ -67,7 +67,11 @@ public class MathUtils { } /** - * Computes a binomial probability + * Computes a binomial probability. This is computed using the formula + * + * B(k; n; p) = [ n! / ( k! (n - k)! ) ] (p^k)( (1-p)^k ) + * + * where n is the number of trials, k is the number of successes, and p is the probability of success * * @param k number of successes * @param n number of Bernoulli trials @@ -75,7 +79,42 @@ public class MathUtils { * * @return the binomial probability of the specified configuration. Computes values down to about 1e-237. */ - public static double binomialProbability(long k, long n, double p) { + public static double binomialProbability(int k, int n, double p) { return Arithmetic.binomial(n, k)*Math.pow(p, k)*Math.pow(1.0 - p, n - k); + //return (new cern.jet.random.Binomial(n, p, cern.jet.random.engine.RandomEngine.makeDefault())).pdf(k); } + + /** + * Computes a multinomial probability. This is computed using the formula + * + * M(x1,x2,...,xk; n; p1,p2,...,pk) = [ n! / (x1! x2! ... xk!) ] (p1^x1)(p2^x2)(...)(pk^xk) + * + * where xi represents the number of times outcome i was observed, n is the number of total observations, and + * pi represents the probability of the i-th outcome to occur. In this implementation, the value of n is + * inferred as the sum over i of xi; + * + * @param x an int[] of counts, where each element represents the number of times a certain outcome was observed + * @param p a double[] of probabilities, where each element represents the probability a given outcome can occur + * @return the multinomial probability of the specified configuration. + */ + public static double multinomialProbability(int[] x, double[] p) { + int n = 0; + for ( int obsCount : x ) { n += obsCount; } + double nfact = Arithmetic.factorial(n); + + double obsfact = 1.0, probs = 1.0, totalprob = 0.0; + for (int obsCountsIndex = 0; obsCountsIndex < x.length; obsCountsIndex++) { + double ofact = Arithmetic.factorial(x[obsCountsIndex]); + obsfact *= ofact; + probs *= Math.pow(p[obsCountsIndex], x[obsCountsIndex]); + + totalprob += p[obsCountsIndex]; + } + + assert(MathUtils.compareDoubles(totalprob, 1.0, 0.01) == 0); + + return (nfact/obsfact)*probs; + } + + }