Added unit tests for the statistics calculated by the test runner, and bug-fixes to the calculations, so we have some assurance that the statistics coming out of the back-end are correct.

git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@5380 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
chartl 2011-03-06 16:54:02 +00:00
parent c40efe1dea
commit a40a8006b5
10 changed files with 230 additions and 40 deletions

View File

@ -41,6 +41,13 @@ public class AssociationTestRunner {
}
public static String runStudentT(TStatistic context) {
Pair<Double,Double> stats = testStudentT(context);
double t = stats.first;
double p = stats.second;
return String.format("T: %.2f\tP: %.2e",t,p);
}
public static Pair<Double,Double> testStudentT(TStatistic context) {
Map<CaseControl.Cohort,Collection<Number>> caseControlVectors = context.getCaseControl();
double meanCase = MathUtils.average(caseControlVectors.get(CaseControl.Cohort.CASE));
double varCase = MathUtils.variance(caseControlVectors.get(CaseControl.Cohort.CASE),meanCase);
@ -55,10 +62,18 @@ public class AssociationTestRunner {
StudentT studentT = new StudentT(df_num/df_denom,null);
double p = t < 0 ? 2*studentT.cdf(t) : 2*(1-studentT.cdf(t));
return String.format("T: %.2f\tP: %.2e",t,p);
return new Pair<Double,Double>(t,p);
}
public static String runZ(ZStatistic context) {
Pair<Double,Double> stats = testZ(context);
double z = stats.first;
double p = stats.second;
return String.format("Z: %.2f\tP: %.2e",z,p);
}
public static Pair<Double,Double> testZ(ZStatistic context) {
Map<CaseControl.Cohort,Pair<Number,Number>> caseControlCounts = context.getCaseControl();
double pCase = caseControlCounts.get(CaseControl.Cohort.CASE).first.doubleValue()/caseControlCounts.get(CaseControl.Cohort.CASE).second.doubleValue();
double pControl = caseControlCounts.get(CaseControl.Cohort.CONTROL).first.doubleValue()/caseControlCounts.get(CaseControl.Cohort.CONTROL).second.doubleValue();
@ -66,15 +81,21 @@ public class AssociationTestRunner {
double nControl = caseControlCounts.get(CaseControl.Cohort.CONTROL).second.doubleValue();
double p2 = (caseControlCounts.get(CaseControl.Cohort.CASE).first.doubleValue()+caseControlCounts.get(CaseControl.Cohort.CONTROL).first.doubleValue())/
(caseControlCounts.get(CaseControl.Cohort.CASE).second.doubleValue()+caseControlCounts.get(CaseControl.Cohort.CONTROL).first.doubleValue());
(caseControlCounts.get(CaseControl.Cohort.CASE).second.doubleValue()+caseControlCounts.get(CaseControl.Cohort.CONTROL).second.doubleValue());
double se = Math.sqrt(p2*(1-p2)*(1/nCase + 1/nControl));
double z = (pCase-pControl)/se;
double p = z < 0 ? 2*standardNormal.cdf(z) : 2*(1-standardNormal.cdf(z));
return String.format("Z: %.2f\tP: %.2e",z,p);
return new Pair<Double,Double>(z,p);
}
public static String runU(UStatistic context) {
Pair<Integer,Double> results = mannWhitneyUTest(context);
return String.format("U: %d\tP: %.2e",results.first,results.second);
}
public static Pair<Integer,Double> mannWhitneyUTest(UStatistic context) {
Map<CaseControl.Cohort,Collection<Number>> caseControlVectors = context.getCaseControl();
MannWhitneyU mwu = new MannWhitneyU();
for ( Number n : caseControlVectors.get(CaseControl.Cohort.CASE) ) {
@ -83,8 +104,7 @@ public class AssociationTestRunner {
for ( Number n : caseControlVectors.get(CaseControl.Cohort.CONTROL) ) {
mwu.add(n,MannWhitneyU.USet.SET2);
}
Pair<Integer,Double> results = mwu.runTwoSidedTest();
return String.format("U: %d\tP: %.2e",results.first,results.second);
return mwu.runTwoSidedTest();
}
public static String runFisherExact(AssociationContext context) {

View File

@ -61,16 +61,14 @@ public class RegionalAssociationWalker extends LocusWalker<MapHolder, RegionalAs
List<Class<? extends AssociationContext>> contexts = new PluginManager<AssociationContext>(AssociationContext.class).getPlugins();
Map<String,Class<? extends AssociationContext>> classNameToClass = new HashMap<String,Class<? extends AssociationContext>>(contexts.size());
for ( Class<? extends AssociationContext> clazz : contexts ) {
if (! Modifier.isAbstract(clazz.getModifiers())) {
classNameToClass.put(clazz.getSimpleName(),clazz);
}
classNameToClass.put(clazz.getSimpleName(),clazz);
}
Set<AssociationContext> validAssociations = new HashSet<AssociationContext>();
for ( String s : associationsToUse ) {
AssociationContext context;
try {
context = classNameToClass.get(s).getConstructor(new Class[]{}).newInstance(new Object[]{});
context = classNameToClass.get(s).newInstance();
} catch ( Exception e ) {
throw new StingException("The class "+s+" could not be instantiated.",e);
}

View File

@ -30,7 +30,7 @@ public class MateOtherContig extends ZStatistic {
}
}
return new Pair<Number,Number>(tot,otherCon);
return new Pair<Number,Number>(otherCon,tot);
}
}

View File

@ -18,7 +18,7 @@ public class MateUnmapped extends ZStatistic {
int numMatedReads = 0;
int numPairUnmapped = 0;
for (PileupElement e : pileup ) {
if ( e.getRead().getProperPairFlag() ) {
if (e.getRead().getReadPairedFlag() ) {
++numMatedReads;
if ( e.getRead().getMateUnmappedFlag() ) {
++numPairUnmapped;
@ -26,7 +26,7 @@ public class MateUnmapped extends ZStatistic {
}
}
return new Pair<Number,Number>(numMatedReads,numPairUnmapped);
return new Pair<Number,Number>(numPairUnmapped,numMatedReads);
}
public int getWindowSize() { return 100; }

View File

@ -3,6 +3,7 @@ package org.broadinstitute.sting.oneoffprojects.walkers.association.modules;
import org.broadinstitute.sting.gatk.datasources.sample.Sample;
import org.broadinstitute.sting.gatk.walkers.LocusWalker;
import org.broadinstitute.sting.gatk.walkers.Walker;
import org.broadinstitute.sting.oneoffprojects.walkers.association.MapExtender;
import org.broadinstitute.sting.oneoffprojects.walkers.association.RegionalAssociationWalker;
import org.broadinstitute.sting.oneoffprojects.walkers.association.statistics.casecontrol.UStatistic;
import org.broadinstitute.sting.utils.MathUtils;
@ -39,16 +40,18 @@ public class SampleDepth extends UStatistic {
}
}
public Collection<Number> map(ReadBackedPileup pileup) {
Collection<Sample> samples = pileup.getSamples();
Sample sample;
if ( samples.size() > 1 ) {
throw new StingException("Multiple samples inside a sample-specific pileup");
} else if ( samples.size() == 0 ) {
return Arrays.asList();
} else {
sample = samples.iterator().next();
@Override
public Map<Sample,Object> mapLocus(MapExtender extender) {
Map<Sample,ReadBackedPileup> pileups = extender.getReadFilteredPileup();
Map<Sample,Object> maps = new HashMap<Sample,Object>(pileups.size());
for ( Map.Entry<Sample,ReadBackedPileup> samPileup : pileups.entrySet() ) {
maps.put(samPileup.getKey(),map(samPileup.getKey(),samPileup.getValue()));
}
return maps;
}
public Collection<Number> map(Sample sample, ReadBackedPileup pileup) {
Object stats = sampleStats.get(sample);
double mn;
double std;
@ -68,6 +71,9 @@ public class SampleDepth extends UStatistic {
return Arrays.asList((Number)((pileup.size()-mn)/std));
}
// note: this is to satisfy the interface, and is never called due to override
public Collection<Number> map(ReadBackedPileup pileup) { return null; }
public int getWindowSize() { return 25; }
public int slideByValue() { return 5; }
public boolean usePreviouslySeenReads() { return true; }

View File

@ -24,9 +24,9 @@ public abstract class CaseControl<X> extends AssociationContext<X,X> {
for ( Map<Sample,X> sampleXMap : window ) {
for ( Map.Entry<Sample,X> entry : sampleXMap.entrySet() ) {
if ( entry.getKey().getProperty("cohort").equals("case") ) {
accum(accumCase, entry.getValue());
accumCase = accum(accumCase, entry.getValue());
} else if ( entry.getKey().getProperty("cohort").equals("control") ) {
accum(accumControl,entry.getValue());
accumControl = accum(accumControl,entry.getValue());
}
}
}

View File

@ -19,8 +19,8 @@ public abstract class UStatistic extends CaseControl<Collection<Number>> {
public abstract Collection<Number> map(ReadBackedPileup rbp );
public Collection<Number> add(Collection<Number> left, Collection<Number> right) {
if ( left instanceof List) {
((List) left).addAll(right);
if ( left instanceof ArrayList ) {
((ArrayList) left).addAll(right);
return left;
} else if ( left instanceof Set) {
((Set) left).addAll(right);

View File

@ -14,6 +14,8 @@ import java.util.TreeSet;
*/
public class MannWhitneyU {
private static Normal STANDARD_NORMAL = new Normal(0.0,1.0,null);
private TreeSet<Pair<Number,USet>> observations;
private int sizeSet1;
private int sizeSet2;
@ -64,7 +66,7 @@ public class MannWhitneyU {
double pval;
if ( n > 8 && m > 8 ) {
pval = calculatePNormalApproximation(n,m,u);
} else if ( n > 4 && m > 8 ) {
} else if ( n > 4 && m > 7 ) {
pval = calculatePUniformApproximation(n,m,u);
} else {
pval = calculatePRecursively(n,m,u);
@ -82,8 +84,9 @@ public class MannWhitneyU {
*/
public static double calculatePNormalApproximation(int n,int m,int u) {
double mean = ((double) m*n+1)/2;
Normal normal = new Normal( mean , ((double) n*m*(n+m+1))/12, null);
return u < mean ? normal.cdf(u) : 1.0-normal.cdf(u);
double var = (n*m*(n+m+1))/12;
double z = ( u - mean )/Math.sqrt(var);
return z < 0 ? STANDARD_NORMAL.cdf(z) : 1.0-STANDARD_NORMAL.cdf(z);
}
/**
@ -134,13 +137,6 @@ public class MannWhitneyU {
int uSet2DomSet1 = 0;
USet previous = null;
for ( Pair<Number,USet> dataPoint : observed ) {
if ( previous != null && previous != dataPoint.second ) {
if ( dataPoint.second == USet.SET1 ) {
uSet2DomSet1 += set2SeenSoFar;
} else {
uSet1DomSet2 += set1SeenSoFar;
}
}
if ( dataPoint.second == USet.SET1 ) {
++set1SeenSoFar;
@ -148,6 +144,14 @@ public class MannWhitneyU {
++set2SeenSoFar;
}
if ( previous != null ) {
if ( dataPoint.second == USet.SET1 ) {
uSet2DomSet1 += set2SeenSoFar;
} else {
uSet1DomSet2 += set1SeenSoFar;
}
}
previous = dataPoint.second;
}
@ -164,7 +168,7 @@ public class MannWhitneyU {
* @return the probability under the hypothesis that all sequences are equally likely of finding a set-two entry preceding a set-one entry "u" times.
*/
public static double calculatePRecursively(int n, int m, int u) {
if ( m > 6 && n > 4 || m + n > 16 ) { throw new StingException("Please use the appropriate (normal or sum of uniform) approximation"); }
if ( m + n > 16 ) { throw new StingException("Please use the appropriate (normal or sum of uniform) approximation"); }
return cpr(n,m,u);
}
@ -185,7 +189,7 @@ public class MannWhitneyU {
}
return (n/(n+m))*cpr(n-1,m,u-m) + (m/(n+m))*cpr(n,m-1,u);
return (((double)n)/(n+m))*cpr(n-1,m,u-m) + (((double)m)/(n+m))*cpr(n,m-1,u);
}
/**

View File

@ -0,0 +1,159 @@
package org.broadinstitute.sting.oneoffprojects.walkers;
import org.broadinstitute.sting.oneoffprojects.walkers.association.statistics.casecontrol.UStatistic;
import org.broadinstitute.sting.oneoffprojects.walkers.association.statistics.casecontrol.ZStatistic;
import org.broadinstitute.sting.utils.MathUtils;
import org.broadinstitute.sting.BaseTest;
import org.broadinstitute.sting.oneoffprojects.walkers.association.AssociationTestRunner;
import org.broadinstitute.sting.oneoffprojects.walkers.association.statistics.casecontrol.TStatistic;
import org.broadinstitute.sting.utils.collections.Pair;
import org.broadinstitute.sting.utils.pileup.ReadBackedPileup;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
import org.broadinstitute.sting.BaseTest;
import org.testng.Assert;
import java.util.Arrays;
import java.util.Collection;
import java.util.HashMap;
import java.util.Map;
/**
* Created by IntelliJ IDEA.
* User: Ghost
* Date: 3/5/11
* Time: 2:06 PM
* To change this template use File | Settings | File Templates.
*/
public class RegionalAssociationUnitTest extends BaseTest {
@BeforeClass
public void init() { }
@Test
private void testTStatistics() {
logger.warn("Testing T statistics");
TTest test1 = new TTest();
test1.setCaseData((Collection) Arrays.asList(1,1,2,3,4));
test1.setControlData((Collection) Arrays.asList(10, 10, 20, 30, 40));
Assert.assertEquals(AssociationTestRunner.testStudentT(test1).second,0.1702,1e-2);
TTest test2 = new TTest();
test2.setCaseData((Collection) Arrays.asList(5, 6, 5, 2, 3, 8, 7, 12, 10, 6, 4, 2, 8, 7, 3));
test2.setControlData((Collection) Arrays.asList(1, 6, 7, 2, 3, 3, 4, 1, 2, 5, 7, 3, 10, 3, 3, 2, 3));
Assert.assertEquals(AssociationTestRunner.testStudentT(test2).second, 0.5805, 1e-2);
TTest test3 = new TTest();
test3.setCaseData((Collection) Arrays.asList(94,25,68,4,27,51,9,10,91,61,61,37,39,44,36,27,86,33,3,38,5,6,28,93,30,56,81,8,40,44));
test3.setControlData((Collection) Arrays.asList(6,64,96,85,20,74,93,18,31,20,88,38,80,50,33,81,35,8,2,69,49,6,26,74,79,63,63,96,45,18));
Assert.assertEquals(AssociationTestRunner.testStudentT(test3).second,0.8229,1e-4);
TTest test4 = new TTest();
test4.setCaseData((Collection) Arrays.asList(14,8,8,17,8,12,10,10,13,9,13,9,9,12,12,11,16,12,13,16,10,13,11,16,13,16,11,13,9,16,16,14,9,14,17,10,15,15,9,15,17,15,17,12,10,13,11,14,8,14));
test4.setControlData((Collection) Arrays.asList(7,1,4,2,3,7,8,5,5,4,10,6,4,9,2,9,9,3,3,10,1,8,9,5,3,7,2,7,10,9,4,9,2,10,10,3,2,3,4,4,5,10,9,4,3,5,6,10,5,10));
Assert.assertEquals(AssociationTestRunner.testStudentT(test4).second,0.1006,1e-4);
Assert.assertEquals(AssociationTestRunner.testStudentT(test4).first,1.657989,1e-6);
}
@Test
private void testZStatistics() {
logger.warn("Testing Z statistics");
ZTest test1 = new ZTest();
test1.setCaseData(new Pair<Number,Number>(100,500));
test1.setControlData(new Pair<Number,Number>(55,300));
Assert.assertEquals(AssociationTestRunner.testZ(test1).first,0.57742362050306,2e-6);
Assert.assertEquals(AssociationTestRunner.testZ(test1).second,0.56367,2e-5);
ZTest test2 = new ZTest();
test1.setCaseData(new Pair<Number, Number>(1020, 1800));
test1.setControlData(new Pair<Number, Number>(680, 1670));
Assert.assertEquals(AssociationTestRunner.testZ(test1).first,9.3898178216531,2e-6);
ZTest test3 = new ZTest();
test3.setCaseData(new Pair<Number,Number>(20,60));
test3.setControlData(new Pair<Number,Number>(30,80));
Assert.assertEquals(AssociationTestRunner.testZ(test3).first,-0.50917511840392,2e-6);
Assert.assertEquals(AssociationTestRunner.testZ(test3).second,0.610643593,2e-4);
}
@Test
private void testUStatistic() {
logger.warn("Testing U statistics");
UTest test1 = new UTest();
test1.setCaseData((Collection) Arrays.asList(2,4,5,6,8));
test1.setControlData((Collection) Arrays.asList(1,3,7,9,10,11,12,13));
Assert.assertEquals(AssociationTestRunner.mannWhitneyUTest(test1).first,10,0);
Assert.assertEquals(AssociationTestRunner.mannWhitneyUTest(test1).second,0.092292,5e-2); // z-approximation, off by about 0.05
Assert.assertEquals(AssociationTestRunner.mannWhitneyUTest(test1).second,0.044444,1e-3); // recursive calculation
UTest test2 = new UTest();
test2.setCaseData((Collection) Arrays.asList(1,7,8,9,10,11,15,18));
test2.setControlData((Collection) Arrays.asList(2,3,4,5,6,12,13,14,16,17));
Assert.assertEquals(AssociationTestRunner.mannWhitneyUTest(test2).first,37,0);
UTest test3 = new UTest();
test3.setCaseData((Collection)Arrays.asList(13,14,7,18,5,2,9,17,8,10,3,15,19,6,20,16,11,4,12,1));
test3.setControlData((Collection) Arrays.asList(29,21,14,10,12,11,28,19,18,13,7,27,20,5,17,16,9,23,22,26));
Assert.assertEquals(AssociationTestRunner.mannWhitneyUTest(test3).first,93,0);
Assert.assertEquals(AssociationTestRunner.mannWhitneyUTest(test3).second,2*0.00302,1e-3);
UTest test4 = new UTest();
test4.setCaseData((Collection) Arrays.asList(1,2,4,5,6,9));
test4.setControlData((Collection) Arrays.asList(3,8,11,12,13));
Assert.assertEquals(AssociationTestRunner.mannWhitneyUTest(test4).first,5,0);
Assert.assertEquals(AssociationTestRunner.mannWhitneyUTest(test4).second,0.0303,1e-4);
}
private class TTest extends TStatistic {
Map<Cohort,Collection<Number>> toTest = new HashMap<Cohort,Collection<Number>>(2);
@Override
public Map<Cohort,Collection<Number>> getCaseControl() {
return toTest;
}
public void setCaseData(Collection<Number> data) {
toTest.put(Cohort.CASE,data);
}
public void setControlData(Collection<Number> data) {
toTest.put(Cohort.CONTROL,data);
}
public Collection<Number> map(ReadBackedPileup rbp) { return null; }
public int getWindowSize() { return 1; }
public int slideByValue() { return 1; }
public boolean usePreviouslySeenReads() { return false; }
}
private class ZTest extends ZStatistic {
Map<Cohort,Pair<Number,Number>> toTest = new HashMap<Cohort,Pair<Number,Number>>(2);
@Override
public Map<Cohort,Pair<Number,Number>> getCaseControl() {
return toTest;
}
public void setCaseData(Pair<Number,Number> data) {
toTest.put(Cohort.CASE,data);
}
public void setControlData(Pair<Number,Number> data) {
toTest.put(Cohort.CONTROL,data);
}
public Pair<Number,Number> map(ReadBackedPileup p) { return null; }
public int getWindowSize() { return 1; }
public int slideByValue() { return 1; }
public boolean usePreviouslySeenReads() { return true; }
}
private class UTest extends UStatistic {
TTest test = new TTest();
public boolean usePreviouslySeenReads() { return false; }
public int getWindowSize() { return 1; }
public int slideByValue() { return 1; }
public Collection<Number> map(ReadBackedPileup p ){ return null; }
@Override
public Map<Cohort,Collection<Number>> getCaseControl() {
return test.getCaseControl();
}
public void setCaseData(Collection<Number> data) { test.setCaseData(data);}
public void setControlData(Collection<Number> data) { test.setControlData(data); }
}
}

View File

@ -147,13 +147,14 @@ class BootstrapCalls extends QScript {
trait CombineArgs extends CombineVariants {
this.reference_sequence = reference
this.intervals :+= intervalFile
this.rodBind :+= new RodBind("hiCov","vcf",rm.noheadvcf)
this.rodBind :+= new RodBind("loCov","vcf",new File("/humgen/gsa-pipeline/PVQF4/all_batches_v001/batch_001/SnpCalls/ESPGO_Gabriel_NHLBI_EOMI_setone_EOMI_Project.cleaned.annotated.handfiltered.vcf"))
this.rodBind :+= new RodBind("loCov","vcf",rm.noheadvcf)
this.rodBind :+= new RodBind("hiCov","vcf",new File("/humgen/gsa-pipeline/PVQF4/all_batches_v001/batch_001/SnpCalls/ESPGO_Gabriel_NHLBI_EOMI_setone_EOMI_Project.cleaned.annotated.handfiltered.vcf"))
this.variantMergeOptions = Some(VariantMergeType.UNION)
this.genotypeMergeOptions = Some(GenotypeMergeType.PRIORITIZE)
this.priority = "hiCov,loCov"
this.out = swapExt(bootstrapMergedOut,".vcf",".merged.combined.vcf")
this.jarFile = sting
this.memoryLimit = Some(6)
}
var combine : CombineVariants = new CombineVariants with CombineArgs
@ -162,7 +163,9 @@ class BootstrapCalls extends QScript {
trait EvalArgs extends VariantEval {
this.reference_sequence = reference
this.intervals :+= intervalFile
this.rodBind :+= new RodBind("eval","vcf",combine.out)
this.rodBind :+= new RodBind("evalCombined","vcf",combine.out)
//this.rodBind :+= new RodBind("evalCut","vcf",rm.noheadvcf)
//this.rodBind :+= new RodBind("evalFCP","vcf",new File("/humgen/gsa-pipeline/PVQF4/all_batches_v001/batch_001/SnpCalls/ESPGO_Gabriel_NHLBI_EOMI_setone_EOMI_Project.cleaned.annotated.handfiltered.vcf"))
this.rodBind :+= new RodBind("dbsnp","vcf",dbsnp)
this.jarFile = sting
this.ST = List("Filter","Novelty","JexlExpression")