A better (less overflow-y) implementation of multinomialProbability().
git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@579 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
parent
4f818f5c1c
commit
16467ae7cf
|
|
@ -91,30 +91,35 @@ public class MathUtils {
|
|||
*
|
||||
* where xi represents the number of times outcome i was observed, n is the number of total observations, and
|
||||
* pi represents the probability of the i-th outcome to occur. In this implementation, the value of n is
|
||||
* inferred as the sum over i of xi;
|
||||
* inferred as the sum over i of xi.
|
||||
*
|
||||
* @param x an int[] of counts, where each element represents the number of times a certain outcome was observed
|
||||
* @param p a double[] of probabilities, where each element represents the probability a given outcome can occur
|
||||
* @return the multinomial probability of the specified configuration.
|
||||
*/
|
||||
public static double multinomialProbability(int[] x, double[] p) {
|
||||
int n = 0;
|
||||
for ( int obsCount : x ) { n += obsCount; }
|
||||
double nfact = Arithmetic.factorial(n);
|
||||
// In order to avoid overflow in computing large factorials in the multinomial
|
||||
// coefficient, we split the calculation up into the product of a bunch of
|
||||
// binomial coefficients.
|
||||
|
||||
double multinomialCoefficient = 1.0;
|
||||
|
||||
double obsfact = 1.0, probs = 1.0, totalprob = 0.0;
|
||||
for (int i = 0; i < x.length; i++) {
|
||||
int n = 0;
|
||||
for (int j = 0; j <= i; j++) { n += x[j]; }
|
||||
|
||||
double multinomialTerm = Arithmetic.binomial(n, x[i]);
|
||||
multinomialCoefficient *= multinomialTerm;
|
||||
}
|
||||
|
||||
double probs = 1.0, totalprob = 0.0;
|
||||
for (int obsCountsIndex = 0; obsCountsIndex < x.length; obsCountsIndex++) {
|
||||
double ofact = Arithmetic.factorial(x[obsCountsIndex]);
|
||||
obsfact *= ofact;
|
||||
probs *= Math.pow(p[obsCountsIndex], x[obsCountsIndex]);
|
||||
|
||||
totalprob += p[obsCountsIndex];
|
||||
}
|
||||
|
||||
assert(MathUtils.compareDoubles(totalprob, 1.0, 0.01) == 0);
|
||||
|
||||
return (nfact/obsfact)*probs;
|
||||
return multinomialCoefficient*probs;
|
||||
}
|
||||
|
||||
|
||||
}
|
||||
|
|
|
|||
Loading…
Reference in New Issue