diff --git a/java/src/org/broadinstitute/sting/utils/MannWhitneyU.java b/java/src/org/broadinstitute/sting/utils/MannWhitneyU.java index 9509cee40..8ac80b281 100755 --- a/java/src/org/broadinstitute/sting/utils/MannWhitneyU.java +++ b/java/src/org/broadinstitute/sting/utils/MannWhitneyU.java @@ -22,15 +22,23 @@ public class MannWhitneyU { private static Normal STANDARD_NORMAL = new Normal(0.0,1.0,null); private static NormalDistribution APACHE_NORMAL = new NormalDistributionImpl(0.0,1.0,1e-2); + private static double LNSQRT2PI = Math.log(Math.sqrt(2.0*Math.PI)); private TreeSet> observations; private int sizeSet1; private int sizeSet2; + private ExactMode exactMode; public MannWhitneyU() { observations = new TreeSet>(new DitheringComparator()); sizeSet1 = 0; sizeSet2 = 0; + exactMode = ExactMode.POINT; + } + + public MannWhitneyU(ExactMode mode) { + super(); + exactMode = mode; } /** @@ -75,7 +83,10 @@ public class MannWhitneyU { return new Pair(Double.NaN,Double.NaN); } - return calculateP(n, m, u, false); + // the null hypothesis is that {N} is stochastically less than {M}, so U has counted + // occurrences of {M}s before {N}s. We would expect that this should be less than (n*m+1)/2 under + // the null hypothesis, so we want to integrate from K=0 to K=U for cumulative cases. Always. + return calculateP(n, m, u, false, exactMode); } /** @@ -94,7 +105,7 @@ public class MannWhitneyU { // test is uninformative as one or both sets have no observations return new Pair(Double.NaN,Double.NaN); } - return calculateP(n, m, u, true); + return calculateP(n, m, u, true, exactMode); } /** @@ -108,7 +119,7 @@ public class MannWhitneyU { */ @Requires({"m > 0","n > 0"}) @Ensures({"result != null", "! Double.isInfinite(result.getFirst())", "! Double.isInfinite(result.getSecond())"}) - public static Pair calculateP(int n, int m, long u, boolean twoSided) { + protected static Pair calculateP(int n, int m, long u, boolean twoSided, ExactMode exactMode) { Pair zandP; if ( n > 8 && m > 8 ) { // large m and n - normal approx @@ -122,7 +133,7 @@ public class MannWhitneyU { zandP = calculatePFromTable(n, m, u, twoSided); } else { // small m and n - full approx - zandP = calculatePRecursively(n,m,u, twoSided); + zandP = calculatePRecursively(n,m,u,twoSided,exactMode); } return zandP; @@ -279,16 +290,29 @@ public class MannWhitneyU { * @param m: number of set-two entries * @param u: number of set-two entries that precede set-one entries (e.g. 0,1,0,1,0 -> 3 ) * @param twoSided: whether the test is two sided or not. The recursive formula is symmetric, multiply by two for two-sidedness. + * @param mode: whether the mode is a point probability, or a cumulative distribution * @return the probability under the hypothesis that all sequences are equally likely of finding a set-two entry preceding a set-one entry "u" times. */ @Requires({"m > 0","n > 0","u >= 0"}) @Ensures({"result != null","! Double.isInfinite(result.getFirst())", "! Double.isInfinite(result.getSecond())"}) - public static Pair calculatePRecursively(int n, int m, long u, boolean twoSided) { + public static Pair calculatePRecursively(int n, int m, long u, boolean twoSided, ExactMode mode) { if ( m > 8 && n > 5 ) { throw new StingException(String.format("Please use the appropriate (normal or sum of uniform) approximation. Values n: %d, m: %d",n,m)); } - double p = cpr(n,m,u); + double p = mode == ExactMode.POINT ? cpr(n,m,u) : cumulativeCPR(n,m,u); + p *= twoSided ? 2.0 : 1.0; double z; try { - z = APACHE_NORMAL.inverseCumulativeProbability(p); + + if ( mode == ExactMode.CUMULATIVE ) { + z = APACHE_NORMAL.inverseCumulativeProbability(p); + } else { + double sd = Math.sqrt((1.0/(n+m))*(n*m*(1+n+m))/12); // biased variance empirically better fit to distribution then asymptotic variance + if ( u >= n*m/2 ) { + z = (1.0/sd)*Math.sqrt(-2.0*sd*(Math.log(sd)+Math.log(p)+LNSQRT2PI)); + } else { + z = -(1.0/sd)*Math.sqrt(-2.0*sd*(Math.log(sd)+Math.log(p)+LNSQRT2PI)); + } + } + } catch (MathException me) { throw new StingException("A math exception occurred in inverting the probability",me); } @@ -303,12 +327,25 @@ public class MannWhitneyU { * @param m: number of set-two entries * @param u: number of set-two entries that precede set-one entries (e.g. 0,1,0,1,0 -> 3 ) * @return same as cpr - * @deprecated - for testing only (really) */ protected static double calculatePRecursivelyDoNotCheckValuesEvenThoughItIsSlow(int n, int m, long u) { return cpr(n,m,u); } + /** + * For testing + * + * @param n: number of set-one entries (hypothesis: set one is stochastically less than set two) + * @param m: number of set-two entries + * @param u: number of set-two entries that precede set-one entries (e.g. 0,1,0,1,0 -> 3 ) + */ + protected static long countSequences(int n, int m, long u) { + if ( u < 0 ) { return 0; } + if ( m == 0 || n == 0 ) { return u == 0 ? 1 : 0; } + + return countSequences(n-1,m,u-m) + countSequences(n,m-1,u); + } + /** * : just a shorter name for calculatePRecursively. See Mann, Whitney, [1947] * @param n: number of set-1 entries @@ -317,7 +354,7 @@ public class MannWhitneyU { * @return recursive p-value */ private static double cpr(int n, int m, long u) { - if ( u < 0 || n == 0 && m == 0 ) { + if ( u < 0 ) { return 0.0; } if ( m == 0 || n == 0 ) { @@ -330,10 +367,24 @@ public class MannWhitneyU { return (((double)n)/(n+m))*cpr(n-1,m,u-m) + (((double)m)/(n+m))*cpr(n,m-1,u); } + private static double cumulativeCPR(int n, int m, long u ) { + // from above: + // the null hypothesis is that {N} is stochastically less than {M}, so U has counted + // occurrences of {M}s before {N}s. We would expect that this should be less than (n*m+1)/2 under + // the null hypothesis, so we want to integrate from K=0 to K=U for cumulative cases. Always. + double p = 0.0; + // optimization using symmetry, use the least amount of sums possible + long uSym = ( u <= n*m/2 ) ? u : ((long)n)*m-u; + for ( long uu = 0; uu < uSym; uu++ ) { + p += cpr(n,m,uu); + } + // correct by 1.0-p if the optimization above was used (e.g. 1-right tail = left tail) + return (u <= n*m/2) ? p : 1.0-p; + } + /** * hook into the data tree, for testing purposes only * @return observations - * @deprecated - only for testing */ protected TreeSet> getObservations() { return observations; @@ -342,7 +393,6 @@ public class MannWhitneyU { /** * hook into the set sizes, for testing purposes only * @return size set 1, size set 2 - * @deprecated - only for testing */ protected Pair getSetSizes() { return new Pair(sizeSet1,sizeSet2); @@ -389,5 +439,6 @@ public class MannWhitneyU { } public enum USet { SET1, SET2 } + public enum ExactMode { POINT, CUMULATIVE } } diff --git a/java/test/org/broadinstitute/sting/utils/MWUnitTest.java b/java/test/org/broadinstitute/sting/utils/MWUnitTest.java index 78ce627ab..ec10700a9 100755 --- a/java/test/org/broadinstitute/sting/utils/MWUnitTest.java +++ b/java/test/org/broadinstitute/sting/utils/MWUnitTest.java @@ -1,8 +1,10 @@ package org.broadinstitute.sting.utils; +import cern.jet.math.Arithmetic; import org.broadinstitute.sting.BaseTest; import org.broadinstitute.sting.utils.collections.Pair; +import org.jgrapht.alg.StrongConnectivityInspector; import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; import org.testng.Assert; @@ -34,8 +36,8 @@ public class MWUnitTest extends BaseTest { mwu.add(9,MannWhitneyU.USet.SET1); mwu.add(10,MannWhitneyU.USet.SET1); mwu.add(11,MannWhitneyU.USet.SET2); - Assert.assertEquals(MannWhitneyU.calculateOneSidedU(mwu.getObservations(), MannWhitneyU.USet.SET1),25l); - Assert.assertEquals(MannWhitneyU.calculateOneSidedU(mwu.getObservations(),MannWhitneyU.USet.SET2),11l); + Assert.assertEquals(MannWhitneyU.calculateOneSidedU(mwu.getObservations(), MannWhitneyU.USet.SET1),25L); + Assert.assertEquals(MannWhitneyU.calculateOneSidedU(mwu.getObservations(),MannWhitneyU.USet.SET2),11L); MannWhitneyU mwu2 = new MannWhitneyU(); for ( int dp : new int[]{2,4,5,6,8} ) { @@ -46,20 +48,31 @@ public class MWUnitTest extends BaseTest { mwu2.add(dp,MannWhitneyU.USet.SET2); } + MannWhitneyU.ExactMode pm = MannWhitneyU.ExactMode.POINT; + MannWhitneyU.ExactMode cm = MannWhitneyU.ExactMode.CUMULATIVE; + // tests using the hypothesis that set 2 dominates set 1 (U value = 10) - Assert.assertEquals(MannWhitneyU.calculateOneSidedU(mwu2.getObservations(),MannWhitneyU.USet.SET1),10l); - Assert.assertEquals(MannWhitneyU.calculateOneSidedU(mwu2.getObservations(),MannWhitneyU.USet.SET2),30l); + Assert.assertEquals(MannWhitneyU.calculateOneSidedU(mwu2.getObservations(),MannWhitneyU.USet.SET1),10L); + Assert.assertEquals(MannWhitneyU.calculateOneSidedU(mwu2.getObservations(),MannWhitneyU.USet.SET2),30L); + Pair sizes = mwu2.getSetSizes(); - Assert.assertEquals(MannWhitneyU.calculatePUniformApproximation(sizes.first,sizes.second,10l),0.4180519701814064,1e-14); - Assert.assertEquals(MannWhitneyU.calculatePRecursively(sizes.first,sizes.second,10l,false).second,0.021756021756021756,1e-14); - Assert.assertEquals(MannWhitneyU.calculatePNormalApproximation(sizes.first,sizes.second,10l,false).second,0.06214143703127617,1e-14); + + Assert.assertEquals(MannWhitneyU.calculatePUniformApproximation(sizes.first,sizes.second,10L),0.4180519701814064,1e-14); + Assert.assertEquals(MannWhitneyU.calculatePRecursively(sizes.first,sizes.second,10L,false,pm).second,0.021756021756021756,1e-14); + Assert.assertEquals(MannWhitneyU.calculatePNormalApproximation(sizes.first,sizes.second,10L,false).second,0.06214143703127617,1e-14); logger.warn("Testing two-sided"); - Assert.assertEquals((double)mwu2.runTwoSidedTest().second,2*0.021756021756021756,1e-8); + Assert.assertEquals((double)mwu2.runTwoSidedTest().second,4*0.021756021756021756,1e-8); // tests using the hypothesis that set 1 dominates set 2 (U value = 30) -- empirical should be identical, normall approx close, uniform way off - Assert.assertEquals(MannWhitneyU.calculatePNormalApproximation(sizes.second,sizes.first,30l,true).second,2.0*0.08216463976903321,1e-14); - Assert.assertEquals(MannWhitneyU.calculatePUniformApproximation(sizes.second,sizes.first,30l),0.0023473625009328147,1e-14); - Assert.assertEquals(MannWhitneyU.calculatePRecursively(sizes.second,sizes.first,30l,false).second,0.021756021756021756,1e-14); // note -- exactly same value as above + Assert.assertEquals(MannWhitneyU.calculatePNormalApproximation(sizes.second,sizes.first,30L,true).second,2.0*0.08216463976903321,1e-14); + Assert.assertEquals(MannWhitneyU.calculatePUniformApproximation(sizes.second,sizes.first,30L),0.0023473625009328147,1e-14); + Assert.assertEquals(MannWhitneyU.calculatePRecursively(sizes.second,sizes.first,30L,false,pm).second,0.021756021756021756,1e-14); // note -- exactly same value as above + Assert.assertEquals(MannWhitneyU.calculatePRecursively(sizes.second,sizes.first,29L,false,cm).second,1.0-0.08547008547008,1e-14); // r does a correction, subtracting 1 from U + Assert.assertEquals(MannWhitneyU.calculatePRecursively(sizes.second,sizes.first,11L,false,cm).second,0.08547008547008,1e-14); // r does a correction, subtracting 1 from U + Assert.assertEquals(MannWhitneyU.calculatePRecursively(sizes.second,sizes.first,11L,false,cm).first,-1.36918910442,1e-2); // apache inversion set to be good only to 1e-2 + Assert.assertEquals(MannWhitneyU.calculatePRecursively(sizes.second,sizes.first,29L,false,cm).first,1.36918910442,1e-2); // apache inversion set to be good only to 1e-2 + Assert.assertEquals(MannWhitneyU.calculatePRecursively(sizes.second,sizes.first,29L,false,pm).first,1.4908546184916176,1e-8); // PDF should be similar + Assert.assertEquals(MannWhitneyU.calculatePRecursively(sizes.second,sizes.first,11L,false,pm).first,-1.4908546184916176,1e-8); // PDF should be similar logger.warn("Set 1"); Assert.assertEquals((double)mwu2.runOneSidedTest(MannWhitneyU.USet.SET1).second,0.021756021756021756,1e-8); @@ -74,8 +87,13 @@ public class MWUnitTest extends BaseTest { mwu3.add(dp,MannWhitneyU.USet.SET2); } long u = MannWhitneyU.calculateOneSidedU(mwu3.getObservations(),MannWhitneyU.USet.SET1); + //logger.warn(String.format("U is: %d",u)); Pair nums = mwu3.getSetSizes(); - Assert.assertEquals(MannWhitneyU.calculatePRecursivelyDoNotCheckValuesEvenThoughItIsSlow(nums.first,nums.second,u),3.665689149560116E-4,1e-14); + //logger.warn(String.format("Corrected p is: %.4e",MannWhitneyU.calculatePRecursivelyDoNotCheckValuesEvenThoughItIsSlow(nums.first,nums.second,u))); + //logger.warn(String.format("Counted sequences: %d",MannWhitneyU.countSequences(nums.first, nums.second, u))); + //logger.warn(String.format("Possible sequences: %d", (long) Arithmetic.binomial(nums.first+nums.second,nums.first))); + //logger.warn(String.format("Ratio: %.4e",MannWhitneyU.countSequences(nums.first,nums.second,u)/Arithmetic.binomial(nums.first+nums.second,nums.first))); + Assert.assertEquals(MannWhitneyU.calculatePRecursivelyDoNotCheckValuesEvenThoughItIsSlow(nums.first, nums.second, u), 3.665689149560116E-4, 1e-14); Assert.assertEquals(MannWhitneyU.calculatePNormalApproximation(nums.first,nums.second,u,false).second,0.0032240865760884696,1e-14); Assert.assertEquals(MannWhitneyU.calculatePUniformApproximation(nums.first,nums.second,u),0.0026195003025784036,1e-14);