After extensive testing of MannWhitneyU:

- Verified that exact calculations do agree with R's dwilcox()
 - Verified that exact calculations do not agree with R's wilcox.test
   + This is because R applies a correction and calculates CDFs rather than PDFs (i.e. it sums over dwilcox() values)
 - Can now specify MWU to calculate cumulative exact tests, rather than point probabilities
 - Z-scores are now calculated properly for exact tests
   + Previously, z-values were calculated by inverting the normal CDF from the U-statistic PDF
   + Now both inversions are done, with a smart heuristic (biased variance) to make the point-calculated Z-value more accurate
   + Additional tests



git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@5911 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
chartl 2011-06-01 15:51:27 +00:00
parent 2b5683909e
commit a79967d9af
2 changed files with 92 additions and 23 deletions

View File

@ -22,15 +22,23 @@ public class MannWhitneyU {
private static Normal STANDARD_NORMAL = new Normal(0.0,1.0,null);
private static NormalDistribution APACHE_NORMAL = new NormalDistributionImpl(0.0,1.0,1e-2);
private static double LNSQRT2PI = Math.log(Math.sqrt(2.0*Math.PI));
private TreeSet<Pair<Number,USet>> observations;
private int sizeSet1;
private int sizeSet2;
private ExactMode exactMode;
/**
 * Creates an empty test: an empty observation set ordered by a {@code DitheringComparator},
 * both set sizes at zero, and exact calculations in {@link ExactMode#POINT} mode.
 */
public MannWhitneyU() {
    this.observations = new TreeSet<Pair<Number,USet>>(new DitheringComparator());
    // nothing has been added yet, so both sets start out empty
    this.sizeSet2 = 0;
    this.sizeSet1 = 0;
    this.exactMode = ExactMode.POINT;
}
/**
 * Creates an empty test that uses the given mode for exact (small-sample) calculations.
 *
 * @param mode whether exact calculations report point or cumulative probabilities
 */
public MannWhitneyU(ExactMode mode) {
    // Delegate to the no-arg constructor so that observations and the set sizes are initialized.
    // Calling super() here would only run Object's constructor and leave observations null.
    this();
    exactMode = mode;
}
/**
@ -75,7 +83,10 @@ public class MannWhitneyU {
return new Pair<Double,Double>(Double.NaN,Double.NaN);
}
return calculateP(n, m, u, false);
// the null hypothesis is that {N} is stochastically less than {M}, so U has counted
// occurrences of {M}s before {N}s. We would expect that this should be less than (n*m+1)/2 under
// the null hypothesis, so we want to integrate from K=0 to K=U for cumulative cases. Always.
return calculateP(n, m, u, false, exactMode);
}
/**
@ -94,7 +105,7 @@ public class MannWhitneyU {
// test is uninformative as one or both sets have no observations
return new Pair<Double,Double>(Double.NaN,Double.NaN);
}
return calculateP(n, m, u, true);
return calculateP(n, m, u, true, exactMode);
}
/**
@ -108,7 +119,7 @@ public class MannWhitneyU {
*/
@Requires({"m > 0","n > 0"})
@Ensures({"result != null", "! Double.isInfinite(result.getFirst())", "! Double.isInfinite(result.getSecond())"})
public static Pair<Double,Double> calculateP(int n, int m, long u, boolean twoSided) {
protected static Pair<Double,Double> calculateP(int n, int m, long u, boolean twoSided, ExactMode exactMode) {
Pair<Double,Double> zandP;
if ( n > 8 && m > 8 ) {
// large m and n - normal approx
@ -122,7 +133,7 @@ public class MannWhitneyU {
zandP = calculatePFromTable(n, m, u, twoSided);
} else {
// small m and n - full approx
zandP = calculatePRecursively(n,m,u, twoSided);
zandP = calculatePRecursively(n,m,u,twoSided,exactMode);
}
return zandP;
@ -279,16 +290,29 @@ public class MannWhitneyU {
* @param m: number of set-two entries
* @param u: number of set-two entries that precede set-one entries (e.g. 0,1,0,1,0 -> 3 )
* @param twoSided: whether the test is two sided or not. The recursive formula is symmetric, multiply by two for two-sidedness.
* @param mode: whether the mode is a point probability, or a cumulative distribution
* @return the probability under the hypothesis that all sequences are equally likely of finding a set-two entry preceding a set-one entry "u" times.
*/
@Requires({"m > 0","n > 0","u >= 0"})
@Ensures({"result != null","! Double.isInfinite(result.getFirst())", "! Double.isInfinite(result.getSecond())"})
public static Pair<Double,Double> calculatePRecursively(int n, int m, long u, boolean twoSided) {
public static Pair<Double,Double> calculatePRecursively(int n, int m, long u, boolean twoSided, ExactMode mode) {
if ( m > 8 && n > 5 ) { throw new StingException(String.format("Please use the appropriate (normal or sum of uniform) approximation. Values n: %d, m: %d",n,m)); }
double p = cpr(n,m,u);
double p = mode == ExactMode.POINT ? cpr(n,m,u) : cumulativeCPR(n,m,u);
p *= twoSided ? 2.0 : 1.0;
double z;
try {
z = APACHE_NORMAL.inverseCumulativeProbability(p);
if ( mode == ExactMode.CUMULATIVE ) {
z = APACHE_NORMAL.inverseCumulativeProbability(p);
} else {
double sd = Math.sqrt((1.0/(n+m))*(n*m*(1+n+m))/12); // biased variance is empirically a better fit to the distribution than the asymptotic variance
if ( u >= n*m/2 ) {
z = (1.0/sd)*Math.sqrt(-2.0*sd*(Math.log(sd)+Math.log(p)+LNSQRT2PI));
} else {
z = -(1.0/sd)*Math.sqrt(-2.0*sd*(Math.log(sd)+Math.log(p)+LNSQRT2PI));
}
}
} catch (MathException me) {
throw new StingException("A math exception occurred in inverting the probability",me);
}
@ -303,12 +327,25 @@ public class MannWhitneyU {
* @param m: number of set-two entries
* @param u: number of set-two entries that precede set-one entries (e.g. 0,1,0,1,0 -> 3 )
* @return same as cpr
* @deprecated - for testing only (really)
*/
protected static double calculatePRecursivelyDoNotCheckValuesEvenThoughItIsSlow(int n, int m, long u) {
    // straight pass-through to cpr: the arguments are forwarded unchanged and no size checks are applied here
    final double recursiveP = cpr(n,m,u);
    return recursiveP;
}
/**
 * Recursively counts sequences of n set-one and m set-two entries for the statistic value u.
 * For testing.
 *
 * @param n: number of set-one entries (hypothesis: set one is stochastically less than set two)
 * @param m: number of set-two entries
 * @param u: number of set-two entries that precede set-one entries (e.g. 0,1,0,1,0 -> 3 )
 * @return the count produced by the recursion; 0 whenever u is negative
 */
protected static long countSequences(int n, int m, long u) {
    if ( u < 0 ) {
        return 0L;
    }
    final boolean eitherSetEmpty = ( n == 0 || m == 0 );
    if ( eitherSetEmpty ) {
        // with one set exhausted, only u == 0 contributes a sequence
        return ( u == 0 ) ? 1L : 0L;
    }
    // same two sub-problems as the original: drop one set-one entry (u reduced by m), or drop one set-two entry
    final long withoutOneSetOneEntry = countSequences(n - 1, m, u - m);
    final long withoutOneSetTwoEntry = countSequences(n, m - 1, u);
    return withoutOneSetOneEntry + withoutOneSetTwoEntry;
}
/**
* : just a shorter name for calculatePRecursively. See Mann, Whitney, [1947]
* @param n: number of set-1 entries
@ -317,7 +354,7 @@ public class MannWhitneyU {
* @return recursive p-value
*/
private static double cpr(int n, int m, long u) {
if ( u < 0 || n == 0 && m == 0 ) {
if ( u < 0 ) {
return 0.0;
}
if ( m == 0 || n == 0 ) {
@ -330,10 +367,24 @@ public class MannWhitneyU {
return (((double)n)/(n+m))*cpr(n-1,m,u-m) + (((double)m)/(n+m))*cpr(n,m-1,u);
}
/**
 * Builds a cumulative probability for u by summing point probabilities from cpr.
 *
 * Rationale carried over from the caller: the null hypothesis is that {N} is stochastically less
 * than {M}, so U has counted occurrences of {M}s before {N}s; the cumulative case therefore
 * integrates upward from K=0.
 *
 * @param n: number of set-1 entries
 * @param m: number of set-2 entries
 * @param u: the U statistic
 * @return the sum of cpr(n,m,k) for k = 0 .. u-1 when u <= n*m/2; otherwise one minus the sum of
 *         cpr(n,m,k) for k = 0 .. (n*m-u)-1
 */
private static double cumulativeCPR(int n, int m, long u ) {
    // symmetry optimization: always sum over the shorter side of the distribution
    final boolean sumDirectly = ( u <= n*m/2 );
    final long termCount = sumDirectly ? u : ((long)n)*m-u;
    double accumulated = 0.0;
    long k = 0;
    while ( k < termCount ) {   // upper bound is exclusive, exactly as in the original loop
        accumulated += cpr(n,m,k);
        k++;
    }
    if ( sumDirectly ) {
        return accumulated;
    }
    // the mirrored side was summed, so take the complement (1 - right tail = left tail)
    return 1.0-accumulated;
}
/**
* hook into the data tree, for testing purposes only
* @return observations
* @deprecated - only for testing
*/
protected TreeSet<Pair<Number,USet>> getObservations() {
return observations;
@ -342,7 +393,6 @@ public class MannWhitneyU {
/**
* hook into the set sizes, for testing purposes only
* @return size set 1, size set 2
* @deprecated - only for testing
*/
protected Pair<Integer,Integer> getSetSizes() {
return new Pair<Integer,Integer>(sizeSet1,sizeSet2);
@ -389,5 +439,6 @@ public class MannWhitneyU {
}
/** Labels which of the two sample sets an observation is added to. */
public enum USet { SET1, SET2 }
/**
 * Selects what the exact (recursive) calculation reports: POINT uses the point probability from cpr,
 * CUMULATIVE uses the summed probabilities from cumulativeCPR.
 */
public enum ExactMode { POINT, CUMULATIVE }
}

View File

@ -1,8 +1,10 @@
package org.broadinstitute.sting.utils;
import cern.jet.math.Arithmetic;
import org.broadinstitute.sting.BaseTest;
import org.broadinstitute.sting.utils.collections.Pair;
import org.jgrapht.alg.StrongConnectivityInspector;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
import org.testng.Assert;
@ -34,8 +36,8 @@ public class MWUnitTest extends BaseTest {
mwu.add(9,MannWhitneyU.USet.SET1);
mwu.add(10,MannWhitneyU.USet.SET1);
mwu.add(11,MannWhitneyU.USet.SET2);
Assert.assertEquals(MannWhitneyU.calculateOneSidedU(mwu.getObservations(), MannWhitneyU.USet.SET1),25l);
Assert.assertEquals(MannWhitneyU.calculateOneSidedU(mwu.getObservations(),MannWhitneyU.USet.SET2),11l);
Assert.assertEquals(MannWhitneyU.calculateOneSidedU(mwu.getObservations(), MannWhitneyU.USet.SET1),25L);
Assert.assertEquals(MannWhitneyU.calculateOneSidedU(mwu.getObservations(),MannWhitneyU.USet.SET2),11L);
MannWhitneyU mwu2 = new MannWhitneyU();
for ( int dp : new int[]{2,4,5,6,8} ) {
@ -46,20 +48,31 @@ public class MWUnitTest extends BaseTest {
mwu2.add(dp,MannWhitneyU.USet.SET2);
}
MannWhitneyU.ExactMode pm = MannWhitneyU.ExactMode.POINT;
MannWhitneyU.ExactMode cm = MannWhitneyU.ExactMode.CUMULATIVE;
// tests using the hypothesis that set 2 dominates set 1 (U value = 10)
Assert.assertEquals(MannWhitneyU.calculateOneSidedU(mwu2.getObservations(),MannWhitneyU.USet.SET1),10l);
Assert.assertEquals(MannWhitneyU.calculateOneSidedU(mwu2.getObservations(),MannWhitneyU.USet.SET2),30l);
Assert.assertEquals(MannWhitneyU.calculateOneSidedU(mwu2.getObservations(),MannWhitneyU.USet.SET1),10L);
Assert.assertEquals(MannWhitneyU.calculateOneSidedU(mwu2.getObservations(),MannWhitneyU.USet.SET2),30L);
Pair<Integer,Integer> sizes = mwu2.getSetSizes();
Assert.assertEquals(MannWhitneyU.calculatePUniformApproximation(sizes.first,sizes.second,10l),0.4180519701814064,1e-14);
Assert.assertEquals(MannWhitneyU.calculatePRecursively(sizes.first,sizes.second,10l,false).second,0.021756021756021756,1e-14);
Assert.assertEquals(MannWhitneyU.calculatePNormalApproximation(sizes.first,sizes.second,10l,false).second,0.06214143703127617,1e-14);
Assert.assertEquals(MannWhitneyU.calculatePUniformApproximation(sizes.first,sizes.second,10L),0.4180519701814064,1e-14);
Assert.assertEquals(MannWhitneyU.calculatePRecursively(sizes.first,sizes.second,10L,false,pm).second,0.021756021756021756,1e-14);
Assert.assertEquals(MannWhitneyU.calculatePNormalApproximation(sizes.first,sizes.second,10L,false).second,0.06214143703127617,1e-14);
logger.warn("Testing two-sided");
Assert.assertEquals((double)mwu2.runTwoSidedTest().second,2*0.021756021756021756,1e-8);
Assert.assertEquals((double)mwu2.runTwoSidedTest().second,4*0.021756021756021756,1e-8);
// tests using the hypothesis that set 1 dominates set 2 (U value = 30) -- empirical should be identical, normal approx close, uniform way off
Assert.assertEquals(MannWhitneyU.calculatePNormalApproximation(sizes.second,sizes.first,30l,true).second,2.0*0.08216463976903321,1e-14);
Assert.assertEquals(MannWhitneyU.calculatePUniformApproximation(sizes.second,sizes.first,30l),0.0023473625009328147,1e-14);
Assert.assertEquals(MannWhitneyU.calculatePRecursively(sizes.second,sizes.first,30l,false).second,0.021756021756021756,1e-14); // note -- exactly same value as above
Assert.assertEquals(MannWhitneyU.calculatePNormalApproximation(sizes.second,sizes.first,30L,true).second,2.0*0.08216463976903321,1e-14);
Assert.assertEquals(MannWhitneyU.calculatePUniformApproximation(sizes.second,sizes.first,30L),0.0023473625009328147,1e-14);
Assert.assertEquals(MannWhitneyU.calculatePRecursively(sizes.second,sizes.first,30L,false,pm).second,0.021756021756021756,1e-14); // note -- exactly same value as above
Assert.assertEquals(MannWhitneyU.calculatePRecursively(sizes.second,sizes.first,29L,false,cm).second,1.0-0.08547008547008,1e-14); // r does a correction, subtracting 1 from U
Assert.assertEquals(MannWhitneyU.calculatePRecursively(sizes.second,sizes.first,11L,false,cm).second,0.08547008547008,1e-14); // r does a correction, subtracting 1 from U
Assert.assertEquals(MannWhitneyU.calculatePRecursively(sizes.second,sizes.first,11L,false,cm).first,-1.36918910442,1e-2); // apache inversion set to be good only to 1e-2
Assert.assertEquals(MannWhitneyU.calculatePRecursively(sizes.second,sizes.first,29L,false,cm).first,1.36918910442,1e-2); // apache inversion set to be good only to 1e-2
Assert.assertEquals(MannWhitneyU.calculatePRecursively(sizes.second,sizes.first,29L,false,pm).first,1.4908546184916176,1e-8); // PDF should be similar
Assert.assertEquals(MannWhitneyU.calculatePRecursively(sizes.second,sizes.first,11L,false,pm).first,-1.4908546184916176,1e-8); // PDF should be similar
logger.warn("Set 1");
Assert.assertEquals((double)mwu2.runOneSidedTest(MannWhitneyU.USet.SET1).second,0.021756021756021756,1e-8);
@ -74,8 +87,13 @@ public class MWUnitTest extends BaseTest {
mwu3.add(dp,MannWhitneyU.USet.SET2);
}
long u = MannWhitneyU.calculateOneSidedU(mwu3.getObservations(),MannWhitneyU.USet.SET1);
//logger.warn(String.format("U is: %d",u));
Pair<Integer,Integer> nums = mwu3.getSetSizes();
Assert.assertEquals(MannWhitneyU.calculatePRecursivelyDoNotCheckValuesEvenThoughItIsSlow(nums.first,nums.second,u),3.665689149560116E-4,1e-14);
//logger.warn(String.format("Corrected p is: %.4e",MannWhitneyU.calculatePRecursivelyDoNotCheckValuesEvenThoughItIsSlow(nums.first,nums.second,u)));
//logger.warn(String.format("Counted sequences: %d",MannWhitneyU.countSequences(nums.first, nums.second, u)));
//logger.warn(String.format("Possible sequences: %d", (long) Arithmetic.binomial(nums.first+nums.second,nums.first)));
//logger.warn(String.format("Ratio: %.4e",MannWhitneyU.countSequences(nums.first,nums.second,u)/Arithmetic.binomial(nums.first+nums.second,nums.first)));
Assert.assertEquals(MannWhitneyU.calculatePRecursivelyDoNotCheckValuesEvenThoughItIsSlow(nums.first, nums.second, u), 3.665689149560116E-4, 1e-14);
Assert.assertEquals(MannWhitneyU.calculatePNormalApproximation(nums.first,nums.second,u,false).second,0.0032240865760884696,1e-14);
Assert.assertEquals(MannWhitneyU.calculatePUniformApproximation(nums.first,nums.second,u),0.0026195003025784036,1e-14);