Merge pull request #433 from broadinstitute/eb_finish_chartl_likelihood_posteriors

Introducing the latest-and-greatest in genotyping: CalculatePosteriors.
This commit is contained in:
Eric Banks 2013-11-27 10:53:05 -08:00
commit 42bf83cdc8
7 changed files with 924 additions and 14 deletions

View File

@ -52,6 +52,7 @@ import net.sf.samtools.SAMUtils;
import org.apache.log4j.Logger;
import org.broadinstitute.sting.utils.MathUtils;
import org.broadinstitute.sting.utils.QualityUtils;
import org.broadinstitute.sting.utils.Utils;
import org.broadinstitute.sting.utils.exceptions.UserException;
import org.broadinstitute.sting.utils.genotyper.PerReadAlleleLikelihoodMap;
import org.broadinstitute.sting.utils.haplotype.Haplotype;
@ -60,7 +61,9 @@ import org.broadinstitute.sting.utils.recalibration.covariates.RepeatCovariate;
import org.broadinstitute.sting.utils.recalibration.covariates.RepeatLengthCovariate;
import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
import org.broadinstitute.sting.utils.sam.ReadUtils;
import org.broadinstitute.variant.variantcontext.Allele;
import org.broadinstitute.sting.utils.variant.GATKVariantContextUtils;
import org.broadinstitute.variant.variantcontext.*;
import org.broadinstitute.variant.vcf.VCFConstants;
import java.io.File;
import java.io.FileNotFoundException;
@ -572,4 +575,212 @@ public class LikelihoodCalculationEngine {
}
}
// --------------------------------------------------------------------------------
//
// Posterior GL calculations
//
// --------------------------------------------------------------------------------
/**
 * Re-genotypes every sample of {@code vc1} using posterior genotype likelihoods, where the prior
 * is derived from allele counts gathered from {@code resources} (and optionally from {@code vc1} itself).
 *
 * For every genotype of {@code vc1} that has likelihoods, the posterior vector is handed to
 * {@code GATKVariantContextUtils.updateGenotypeAfterSubsetting} with {@code USE_PLS_TO_ASSIGN} and is
 * also stored, PL-encoded, under {@code VCFConstants.GENOTYPE_POSTERIORS_KEY}. Genotypes without
 * likelihoods are copied through unchanged. The PL-encoded prior is attached to the site as "PG".
 *
 * @param vc1 the context whose genotypes are to be updated
 * @param resources contexts whose allele counts contribute to the prior
 * @param numRefSamplesFromMissingResources amount added to the reference-allele count
 * @param globalFrequencyPriorDirichlet pseudo-count added to the count of every allele of {@code vc1}
 * @param useInputSamples if true, the allele counts of {@code vc1} itself are included
 * @param useEM must be false; the EM loop is not implemented
 * @param useAC forwarded to {@code addAlleleCounts} (selects which count annotation is read)
 * @return a copy of {@code vc1} carrying the updated genotypes and the "PG" attribute
 * @throws IllegalArgumentException if {@code useEM} is true
 */
public static VariantContext calculatePosteriorGLs(final VariantContext vc1,
                                                   final Collection<VariantContext> resources,
                                                   final int numRefSamplesFromMissingResources,
                                                   final double globalFrequencyPriorDirichlet,
                                                   final boolean useInputSamples,
                                                   final boolean useEM,
                                                   final boolean useAC) {
    if ( useEM )
        throw new IllegalArgumentException("EM loop for posterior GLs not yet implemented");

    final Map<Allele,Integer> totalAlleleCounts = new HashMap<>();
    for ( final VariantContext resource : resources ) {
        addAlleleCounts(totalAlleleCounts,resource,useAC);
    }
    if ( useInputSamples ) {
        addAlleleCounts(totalAlleleCounts,vc1,useAC);
    }

    // The reference allele may have no entry yet (no resources and useInputSamples == false, or
    // resources that do not carry this reference allele); treat that as a count of zero instead of
    // unboxing null.
    final Integer knownRefCount = totalAlleleCounts.get(vc1.getReference());
    totalAlleleCounts.put(vc1.getReference(),
            ( knownRefCount == null ? 0 : knownRefCount ) + numRefSamplesFromMissingResources);

    // now extract the counts of the alleles present within vc1, and in order
    final double[] alleleCounts = new double[vc1.getNAlleles()];
    int alleleIndex = 0;
    for ( final Allele allele : vc1.getAlleles() ) {
        final Integer known = totalAlleleCounts.get(allele);
        alleleCounts[alleleIndex++] = globalFrequencyPriorDirichlet + ( known != null ? known : 0 );
    }

    // a null entry marks a sample without likelihoods; it stays null in the posterior list
    final List<double[]> likelihoods = new ArrayList<>(vc1.getNSamples());
    for ( final Genotype genotype : vc1.getGenotypes() ) {
        likelihoods.add(genotype.hasLikelihoods() ? genotype.getLikelihoods().getAsVector() : null );
    }

    final int ploidy = vc1.getMaxPloidy(2);
    final List<double[]> posteriors = calculatePosteriorGLs(likelihoods,alleleCounts,ploidy);

    final GenotypesContext newContext = GenotypesContext.create();
    for ( int genoIdx = 0; genoIdx < vc1.getNSamples(); genoIdx ++ ) {
        final GenotypeBuilder builder = new GenotypeBuilder(vc1.getGenotype(genoIdx));
        final double[] samplePosteriors = posteriors.get(genoIdx);
        if ( samplePosteriors != null ) {
            GATKVariantContextUtils.updateGenotypeAfterSubsetting(vc1.getAlleles(), builder,
                    GATKVariantContextUtils.GenotypeAssignmentMethod.USE_PLS_TO_ASSIGN, samplePosteriors, vc1.getAlleles());
            builder.attribute(VCFConstants.GENOTYPE_POSTERIORS_KEY,
                    Utils.listFromPrimitives(GenotypeLikelihoods.fromLog10Likelihoods(samplePosteriors).getAsPLs()));
        }
        newContext.add(builder.make());
    }

    final List<Integer> priors = Utils.listFromPrimitives(
            GenotypeLikelihoods.fromLog10Likelihoods(getDirichletPrior(alleleCounts, ploidy)).getAsPLs());
    return new VariantContextBuilder(vc1).genotypes(newContext).attribute("PG",priors).make();
}
/**
 * Combines per-sample genotype likelihoods with the genotype prior implied by known allele counts.
 *
 * Each non-null likelihood vector is added element-wise to the vector returned by
 * {@link #getDirichletPrior(double[], int)}, then passed through
 * {@code MathUtils.normalizeFromLog10} followed by {@code MathUtils.toLog10}.
 * A null input vector yields a null entry at the same position of the result.
 *
 * @param genotypeLikelihoods - the genotype likelihoods, one vector per individual (entries may be null)
 * @param knownAlleleCountsByAllele - the known allele counts in the population. For AC=2 AN=12 site, this is {10,2}
 * @param ploidy - the ploidy to assume; only 2 is supported
 * @return - the posterior genotype likelihoods, in the same order as the input
 * @throws IllegalStateException if ploidy != 2, or if a likelihood vector has a different length than the prior
 */
protected static List<double[]> calculatePosteriorGLs(final List<double[]> genotypeLikelihoods,
                                                      final double[] knownAlleleCountsByAllele,
                                                      final int ploidy) {
    if ( ploidy != 2 ) {
        throw new IllegalStateException("Genotype posteriors not yet implemented for ploidy != 2");
    }

    final double[] priorByGenotype = getDirichletPrior(knownAlleleCountsByAllele,ploidy);
    final int numGenotypes = priorByGenotype.length;
    final List<double[]> result = new ArrayList<>(genotypeLikelihoods.size());

    for ( final double[] sampleGLs : genotypeLikelihoods ) {
        // samples without likelihoods keep their slot, but carry no posterior
        if ( sampleGLs == null ) {
            result.add(null);
            continue;
        }
        if ( sampleGLs.length != numGenotypes ) {
            throw new IllegalStateException(String.format("Likelihoods not of correct size: expected %d, observed %d",
                    knownAlleleCountsByAllele.length*(knownAlleleCountsByAllele.length+1)/2,sampleGLs.length));
        }

        final double[] unnormalized = new double[numGenotypes];
        for ( int g = 0; g < numGenotypes; g++ ) {
            unnormalized[g] = sampleGLs[g] + priorByGenotype[g];
        }
        result.add(MathUtils.toLog10(MathUtils.normalizeFromLog10(unnormalized)));
    }
    return result;
}
// Single-vector convenience overload: wraps the array in a one-element list, delegates to the
// list-based method and unwraps the only result.
protected static double[] calculatePosteriorGLs(final double[] genotypeLikelihoods,
                                                final double[] knownAlleleCountsByAllele,
                                                final int ploidy) {
    final List<double[]> single = Arrays.asList(genotypeLikelihoods);
    final List<double[]> posteriors = calculatePosteriorGLs(single,knownAlleleCountsByAllele,ploidy);
    return posteriors.get(0);
}
/**
 * Builds the prior over diploid genotype states from known allele counts (whether external,
 * from the sample, or both). The stated modelling assumptions are
 *  1) Random sampling of alleles (known counts are unbiased, and frequency estimate is Dirichlet)
 *  2) Genotype states are independent (Hardy-Weinberg)
 * which give a Dirichlet-Multinomial distribution over genotype states, with the ploidy as the
 * "number of trials" of the multinomial.
 *
 * One value is produced per unordered allele pair, each obtained from
 * {@code MathUtils.dirichletMultinomial}.
 *
 * @param knownCountsByAllele - the known counts per allele. For an AC=2, AN=12 site this is {10,2}
 * @param ploidy - the number of chromosomes in the sample. For now restricted to 2.
 * @return - the Dirichlet-Multinomial distribution over genotype states
 * @throws IllegalStateException if ploidy != 2
 */
protected static double[] getDirichletPrior(final double[] knownCountsByAllele, final int ploidy) {
    if ( ploidy != 2 ) {
        throw new IllegalStateException("Genotype priors not yet implemented for ploidy != 2");
    }

    final int numAlleles = knownCountsByAllele.length;
    final double totalKnownCount = MathUtils.sum(knownCountsByAllele);
    final double[] priorByGenotype = new double[numAlleles*(numAlleles+1)/2];

    // genotype ordering for the multi-allelic case:
    // AA AB BB AC BC CC AD BD CD DD ...
    int genotypeIndex = 0;
    for ( int second = 0; second < numAlleles; second++ ) {
        for ( int first = 0; first <= second; first++ ) {
            // per-allele chromosome counts of this genotype; a homozygote ends up with a single 2
            final int[] genotypeAlleleCounts = new int[numAlleles];
            genotypeAlleleCounts[first]++;
            genotypeAlleleCounts[second]++;
            priorByGenotype[genotypeIndex] =
                    MathUtils.dirichletMultinomial(knownCountsByAllele,totalKnownCount,genotypeAlleleCounts,ploidy);
            genotypeIndex++;
        }
    }
    return priorByGenotype;
}
/**
 * Adds the allele counts of one context to the running per-allele totals.
 *
 * Alternate-allele counts come from, in order of preference: the MLE allele-count attribute
 * (only when {@code useAC} is false), the allele-count attribute, or the called chromosome counts
 * of the context. The reference count is the allele-number attribute (or, when absent, the
 * called chromosome count) minus the sum of the alternate counts.
 *
 * @param counts running totals, updated in place
 * @param context the context whose counts are added
 * @param useAC if true, the MLE allele-count attribute is ignored
 */
private static void addAlleleCounts(final Map<Allele,Integer> counts, final VariantContext context, final boolean useAC) {
    final int[] altCounts;
    if ( ! useAC && context.hasAttribute(VCFConstants.MLE_ALLELE_COUNT_KEY) ) {
        altCounts = extractInts(context.getAttribute(VCFConstants.MLE_ALLELE_COUNT_KEY));
    } else if ( context.hasAttribute(VCFConstants.ALLELE_COUNT_KEY) ) {
        altCounts = extractInts(context.getAttribute(VCFConstants.ALLELE_COUNT_KEY));
    } else {
        final List<Allele> altAlleles = context.getAlternateAlleles();
        altCounts = new int[altAlleles.size()];
        for ( int altIdx = 0; altIdx < altCounts.length; altIdx++ ) {
            altCounts[altIdx] = context.getCalledChrCount(altAlleles.get(altIdx));
        }
    }

    for ( final Allele allele : context.getAlleles() ) {
        final int observed;
        if ( ! allele.isReference() ) {
            observed = altCounts[context.getAlternateAlleles().indexOf(allele)];
        } else {
            final int totalChromosomes = context.hasAttribute(VCFConstants.ALLELE_NUMBER_KEY)
                    ? context.getAttributeAsInt(VCFConstants.ALLELE_NUMBER_KEY,-1)
                    : context.getCalledChrCount();
            observed = totalChromosomes - (int) MathUtils.sum(altCounts);
        }

        final Integer runningTotal = counts.get(allele);
        counts.put(allele, runningTotal == null ? observed : observed + runningTotal);
    }
}
/**
 * Converts the value of an integer-valued VCF attribute into an {@code int[]}.
 *
 * Accepted inputs are a single {@link Integer}, a single {@link String} holding an integer, or a
 * {@link List} whose elements are each an Integer or such a String. An empty list yields an
 * empty array.
 *
 * @param integerListContainingVCField the raw attribute value
 * @return the parsed values, in input order
 * @throws IllegalArgumentException if the value is null or of an unsupported type
 * @throws IllegalStateException if a list element is neither an Integer nor a String
 * @throws NumberFormatException if a String value cannot be parsed as an integer
 */
public static int[] extractInts(final Object integerListContainingVCField) {
    final List<?> values;
    if ( integerListContainingVCField instanceof List ) {
        values = (List<?>) integerListContainingVCField;
    } else if ( integerListContainingVCField instanceof Integer || integerListContainingVCField instanceof String ) {
        values = Arrays.asList(integerListContainingVCField);
    } else {
        throw new IllegalArgumentException("VCF does not have properly formatted "+
                VCFConstants.MLE_ALLELE_COUNT_KEY+" or "+VCFConstants.ALLELE_COUNT_KEY);
    }

    // Every element is checked on its own, so an empty list no longer fails on get(0) and a list
    // mixing Integers and Strings no longer dies with a ClassCastException.
    final int[] parsed = new int[values.size()];
    int idx = 0;
    for ( final Object value : values ) {
        if ( value instanceof Integer ) {
            parsed[idx++] = (Integer) value;
        } else if ( value instanceof String ) {
            parsed[idx++] = Integer.parseInt((String) value);
        } else {
            throw new IllegalStateException("BUG: The AC values should be an Integer, but was "+
                    ( value == null ? "null" : value.getClass().getCanonicalName() ));
        }
    }
    return parsed;
}
}

View File

@ -58,19 +58,34 @@ import org.broadinstitute.sting.utils.Utils;
import org.broadinstitute.sting.utils.pairhmm.PairHMM;
import org.broadinstitute.sting.utils.recalibration.covariates.RepeatCovariate;
import org.broadinstitute.sting.utils.recalibration.covariates.RepeatLengthCovariate;
import org.broadinstitute.variant.variantcontext.*;
import org.broadinstitute.variant.vcf.VCFConstants;
import org.testng.Assert;
import org.testng.annotations.BeforeSuite;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.*;
/**
* Unit tests for LikelihoodCalculationEngine
*/
public class LikelihoodCalculationEngineUnitTest extends BaseTest {
Allele Aref, T, C, G, Cref, ATC, ATCATC;
@BeforeSuite
public void setup() {
    // multi-base alleles
    ATCATC = Allele.create("ATCATC");
    ATC = Allele.create("ATC");

    // single-base alleles created with the one-argument factory
    G = Allele.create("G");
    C = Allele.create("C");
    T = Allele.create("T");

    // alleles created with the second argument set to true (the *ref fields)
    Cref = Allele.create("C", true);
    Aref = Allele.create("A", true);
}
@Test
public void testNormalizeDiploidLikelihoodMatrixFromLog10() {
double[][] likelihoodMatrix = {
@ -263,4 +278,313 @@ public class LikelihoodCalculationEngineUnitTest extends BaseTest {
}
return true;
}
/**
 * Compares two int arrays, tolerating a difference of at most 1 at each position.
 *
 * @return the empty string when the arrays have equal length and agree within the tolerance,
 *         otherwise a "NEQ: ..." message showing both arrays
 */
private String arraysEq(int[] a, int[] b) {
    boolean matches = a.length == b.length;
    for ( int i = 0; matches && i < a.length; i++ ) {
        matches = Math.abs(a[i] - b[i]) <= 1;
    }
    return matches ? "" : String.format("NEQ: %s | %s",Arrays.toString(a),Arrays.toString(b));
}
// Unboxes a List<Integer> into an int[] of the same length and order.
private int[] _mleparse(List<Integer> s) {
    final int[] unboxed = new int[s.size()];
    int pos = 0;
    for ( final Integer value : s ) {
        unboxed[pos++] = value;
    }
    return unboxed;
}
/**
 * Builds a two-allele genotype whose PLs are set from {@code pls} and, when {@code pls} is
 * non-empty, sanity-checks the PLs read back from the built genotype (present, non-empty,
 * non-negative, and not rendered as "[0]").
 */
private Genotype makeGwithPLs(String sample, Allele a1, Allele a2, double[] pls) {
    final Genotype gt = new GenotypeBuilder(sample, Arrays.asList(a1, a2)).PL(pls).make();

    final boolean likelihoodsSupplied = pls != null && pls.length > 0;
    if ( ! likelihoodsSupplied ) {
        return gt;
    }

    Assert.assertNotNull(gt.getPL());
    Assert.assertTrue(gt.getPL().length > 0);
    for ( final int pl : gt.getPL() ) {
        Assert.assertTrue(pl >= 0);
    }
    Assert.assertNotEquals(Arrays.toString(gt.getPL()),"[0]");
    return gt;
}
// Builds a genotype for the sample from the two given alleles, without likelihoods.
private Genotype makeG(String sample, Allele a1, Allele a2) {
    final List<Allele> alleles = Arrays.asList(a1, a2);
    return GenotypeBuilder.create(sample, alleles);
}
// Builds a genotype for the sample from the two given alleles, with the given integer PLs.
private Genotype makeG(String sample, Allele a1, Allele a2, int... pls) {
    final List<Allele> alleles = Arrays.asList(a1, a2);
    final GenotypeBuilder builder = new GenotypeBuilder(sample, alleles);
    return builder.PL(pls).make();
}
// Builds a context on contig "1" with start == stop == 10, carrying the given alleles and
// genotypes; filters(null) is passed through to the builder as in every fixture here.
private VariantContext makeVC(String source, List<Allele> alleles, Genotype... genotypes) {
    final int position = 10;
    VariantContextBuilder builder = new VariantContextBuilder(source, "1", position, position, alleles);
    builder = builder.genotypes(Arrays.asList(genotypes));
    builder = builder.filters(null);
    return builder.make();
}
/**
 * Posteriors computed with an empty resource list and useInputSamples == true, first for a
 * two-allele site and then for a three-allele site. Expected values are compared as PLs via
 * arraysEq, i.e. with a tolerance of 1 per entry.
 */
@Test
private void testCalculatePosteriorNoExternalData() {
    // ---- two alleles (A ref, T) ----
    final VariantContext biallelicInput = new VariantContextBuilder(
            makeVC("1",Arrays.asList(Aref,T),
                    makeG("s1",Aref,T,20,0,10),
                    makeG("s2",T,T,60,40,0),
                    makeG("s3",Aref,Aref,0,30,90)))
            .attribute(VCFConstants.MLE_ALLELE_COUNT_KEY,3).make();
    final VariantContext biallelicResult = LikelihoodCalculationEngine.calculatePosteriorGLs(
            biallelicInput, new ArrayList<VariantContext>(), 0, 0.001, true, false, false);

    final Genotype expectedS1 = makeGwithPLs("s1",Aref,T,new double[]{-2.20686, -0.03073215, -1.20686});
    Assert.assertTrue(expectedS1.hasPL());
    final Genotype[] biallelicExpected = {
            expectedS1,
            makeGwithPLs("s2",T,T,new double[]{-6.000066, -3.823938, -6.557894e-05}),
            makeGwithPLs("s3",Aref,Aref,new double[]{-0.0006510083, -2.824524, -9.000651})
    };

    // the posterior attribute is expected to be stored as an ArrayList
    Assert.assertEquals("java.util.ArrayList",biallelicResult.getGenotype(0).getAnyAttribute(VCFConstants.GENOTYPE_POSTERIORS_KEY).getClass().getCanonicalName());
    for ( int sample = 0; sample < biallelicExpected.length; sample++ ) {
        final int[] observed = _mleparse((List<Integer>) biallelicResult.getGenotype(sample).getAnyAttribute(VCFConstants.GENOTYPE_POSTERIORS_KEY));
        Assert.assertEquals(arraysEq(biallelicExpected[sample].getPL(), observed), "");
    }

    // ---- three alleles (A ref, C, T) ----
    // generic PL order: AA AB BB AC BC CC
    // for this site:    AA AC CC AT CT TT
    final VariantContext triallelicInput = new VariantContextBuilder(
            makeVC("2",Arrays.asList(Aref,C,T),
                    makeG("s1",Aref,T,30,10,60,0,15,90),
                    makeG("s2",Aref,C,40,0,10,30,40,80),
                    makeG("s3",Aref,Aref,0,5,8,15,20,40),
                    makeG("s4",C,T,80,40,12,20,0,10)))
            .attribute(VCFConstants.MLE_ALLELE_COUNT_KEY,new ArrayList<Integer>(Arrays.asList(2,2))).make();
    final VariantContext triallelicResult = LikelihoodCalculationEngine.calculatePosteriorGLs(
            triallelicInput,new ArrayList<VariantContext>(),5,0.001,true,false,false);

    final Genotype[] triallelicExpected = {
            makeGwithPLs("s1",Aref,T,new double[]{-2.647372, -1.045139, -6.823193, -0.04513873, -2.198182, -9.823193}),
            makeGwithPLs("s2",Aref,C,new double[]{-3.609957, -0.007723248, -1.785778, -3.007723, -4.660767, -8.785778}),
            makeGwithPLs("s3",Aref,Aref,new double[] {-0.06094877, -0.9587151, -2.03677,-1.958715, -3.111759, -5.23677}),
            makeGwithPLs("s4",C,T,new double[]{-7.016534, -3.4143, -1.392355, -1.4143, -0.06734388, -1.192355})
    };
    for ( int sample = 0; sample < triallelicExpected.length; sample++ ) {
        final int[] observed = _mleparse((List<Integer>) triallelicResult.getGenotype(sample).getAnyAttribute(VCFConstants.GENOTYPE_POSTERIORS_KEY));
        Assert.assertEquals(arraysEq(triallelicExpected[sample].getPL(), observed), "");
    }
}
/**
 * Posteriors computed from the input samples together with external contexts: first with three
 * supplementary contexts that share the input's alleles, then with a single external context
 * whose alternate allele differs from the input's.
 */
@Test
private void testCalculatePosteriorSamplePlusExternal() {
    final VariantContext testOverlappingBase = makeVC("1", Arrays.asList(Aref,T),
            makeG("s1",T,T,40,20,0),
            makeG("s2",Aref,T,18,0,24),
            makeG("s3",Aref,T,22,0,12));

    // one supplement annotated with MLEAC+AN, one with AC+AN, one carrying genotypes only
    final List<VariantContext> supplTest1 = new ArrayList<>(3);
    supplTest1.add(new VariantContextBuilder(makeVC("2",Arrays.asList(Aref,T)))
            .attribute(VCFConstants.MLE_ALLELE_COUNT_KEY,2).attribute(VCFConstants.ALLELE_NUMBER_KEY,10).make());
    supplTest1.add(new VariantContextBuilder(makeVC("3",Arrays.asList(Aref,T)))
            .attribute(VCFConstants.ALLELE_COUNT_KEY,4).attribute(VCFConstants.ALLELE_NUMBER_KEY,22).make());
    supplTest1.add(makeVC("4",Arrays.asList(Aref,T),
            makeG("s_1",T,T),
            makeG("s_2",Aref,T)));

    final VariantContext test1result = LikelihoodCalculationEngine.calculatePosteriorGLs(testOverlappingBase,supplTest1,0,0.001,true,false,false);

    // NOTE(review): the original comment gave the counts as ref=30, alt=14; a tally of the inputs
    // above looks more like ref=29, alt=13 before the pseudo-count -- TODO confirm
    final Genotype[] overlappingExpected = {
            makeGwithPLs("t1",T,T,new double[]{-3.370985, -1.415172, -0.01721766}),
            makeGwithPLs("t2",Aref,T,new double[]{-1.763792, -0.007978791, -3.010024}),
            makeGwithPLs("t3",Aref,T,new double[]{-2.165587, -0.009773643, -1.811819})
    };
    for ( int sample = 0; sample < overlappingExpected.length; sample++ ) {
        final int[] observed = _mleparse((List<Integer>) test1result.getGenotype(sample).getAnyAttribute(VCFConstants.GENOTYPE_POSTERIORS_KEY));
        Assert.assertEquals(arraysEq(overlappingExpected[sample].getPL(), observed), "");
    }

    // external context with a different alternate allele (C) than the input (T)
    final VariantContext testNonOverlapping = makeVC("1", Arrays.asList(Aref,T), makeG("s1",T,T,3,1,0));
    final List<VariantContext> other = Arrays.asList(makeVC("2",Arrays.asList(Aref,C),makeG("s2",C,C,10,2,0)));
    final VariantContext test2result = LikelihoodCalculationEngine.calculatePosteriorGLs(testNonOverlapping,other,0,0.001,true,false,false);

    final Genotype nonOverlappingExpected = makeGwithPLs("SGV",T,T,new double[]{-4.078345, -3.276502, -0.0002661066});
    final int[] nonOverlappingObserved = _mleparse((List<Integer>) test2result.getGenotype(0).getAnyAttribute(VCFConstants.GENOTYPE_POSTERIORS_KEY));
    Assert.assertEquals(arraysEq(nonOverlappingExpected.getPL(), nonOverlappingObserved), "");
}
// Converts integer PLs to doubles by dividing each by -10.0, then returns the result of
// MathUtils.normalizeFromLog10(gl, true).
private double[] pl2gl(int[] pl) {
    final double[] scaled = new double[pl.length];
    int pos = 0;
    for ( final int phred : pl ) {
        scaled[pos++] = phred/(-10.0);
    }
    return MathUtils.normalizeFromLog10(scaled,true);
}
/**
 * Checks calculatePosteriorGLs for six bi-allelic PL vectors against tabulated posteriors, for
 * every combination of three values of altCounts and three values of altAlleleNum.
 * Each expectation row holds 6 x 3 values: one posterior triple per PL vector, in order.
 * Posteriors are compared in linear space with a relative error bound of 1e-4.
 */
@Test
private void testCalculatePosterior() {
    final int[][] likelihood_PLs = {
            {3,0,3},
            {99,0,99},
            {50,20,0},
            {10,0,50},
            {80,60,0},
            {0,42,44}
    };
    final int[] altCounts = {10,40,90};
    final int[] altAlleleNum = {100,500,1000};

    // rows are ordered with altCounts as the outer index and altAlleleNum as the inner index
    final double[][] expectations = {
            // 10 / 100
            {
                    9.250326e-03, 3.020208e-01, 6.887289e-01,
                    7.693433e-12, 1.000000e+00, 5.728111e-10,
                    1.340156e-07, 2.192982e-03, 9.978069e-01,
                    6.073718e-03, 9.938811e-01, 4.522159e-05,
                    1.343101e-10, 2.197802e-07, 9.999998e-01,
                    9.960193e-01, 1.028366e-03, 2.952290e-03
            },
            // 10 / 500
            {
                    4.226647e-04, 7.513277e-02, 9.244446e-01,
                    1.413080e-12, 1.000000e+00, 3.090662e-09,
                    4.570232e-09, 4.071661e-04, 9.995928e-01,
                    1.120916e-03, 9.986339e-01, 2.451646e-04,
                    4.572093e-12, 4.073320e-08, 1.000000e+00,
                    9.151689e-01, 5.144399e-03, 7.968675e-02
            },
            // 10 / 1000
            {
                    1.077685e-04, 3.870477e-02, 9.611875e-01,
                    6.994030e-13, 1.000000e+00, 6.237975e-09,
                    1.120976e-09, 2.017756e-04, 9.997982e-01,
                    5.549722e-04, 9.989500e-01, 4.949797e-04,
                    1.121202e-12, 2.018163e-08, 1.000000e+00,
                    7.318346e-01, 8.311615e-03, 2.598538e-01
            },
            // 40 / 100
            {
                    1.102354e-01, 6.437516e-01, 2.460131e-01,
                    4.301328e-11, 1.000000e+00, 9.599306e-11,
                    4.422850e-06, 1.294493e-02, 9.870507e-01,
                    3.303763e-02, 9.669550e-01, 7.373032e-06,
                    4.480868e-09, 1.311474e-06, 9.999987e-01,
                    9.997266e-01, 1.846199e-04, 8.882157e-05
            },
            // 40 / 500
            {
                    5.711785e-03, 2.557266e-01, 7.385617e-01,
                    5.610428e-12, 1.000000e+00, 7.254558e-10,
                    7.720262e-08, 1.732352e-03, 9.982676e-01,
                    4.436495e-03, 9.955061e-01, 5.736604e-05,
                    7.733659e-11, 1.735358e-07, 9.999998e-01,
                    9.934793e-01, 1.406575e-03, 5.114153e-03
            },
            // 40 / 1000
            {
                    1.522132e-03, 1.422229e-01, 8.562549e-01,
                    2.688330e-12, 1.000000e+00, 1.512284e-09,
                    1.776184e-08, 8.317737e-04, 9.991682e-01,
                    2.130611e-03, 9.977495e-01, 1.198547e-04,
                    1.777662e-11, 8.324661e-08, 9.999999e-01,
                    9.752770e-01, 2.881677e-03, 2.184131e-02
            },
            // 90 / 100
            {
                    6.887289e-01, 3.020208e-01, 9.250326e-03,
                    5.728111e-10, 1.000000e+00, 7.693433e-12,
                    6.394346e-04, 1.405351e-01, 8.588255e-01,
                    3.127146e-01, 6.872849e-01, 4.200075e-07,
                    7.445327e-07, 1.636336e-05, 9.999829e-01,
                    9.999856e-01, 1.386699e-05, 5.346906e-07
            },
            // 90 / 500
            {
                    2.528165e-02, 4.545461e-01, 5.201723e-01,
                    1.397100e-11, 1.000000e+00, 2.874546e-10,
                    4.839050e-07, 4.360463e-03, 9.956391e-01,
                    1.097551e-02, 9.890019e-01, 2.258221e-05,
                    4.860244e-10, 4.379560e-07, 9.999996e-01,
                    9.986143e-01, 5.677671e-04, 8.179741e-04
            },
            // 90 / 1000
            {
                    7.035938e-03, 2.807708e-01, 7.121932e-01,
                    6.294627e-12, 1.000000e+00, 6.371561e-10,
                    9.859771e-08, 1.971954e-03, 9.980279e-01,
                    4.974874e-03, 9.949748e-01, 5.035678e-05,
                    9.879252e-11, 1.975850e-07, 9.999998e-01,
                    9.947362e-01, 1.255272e-03, 4.008518e-03
            }
    };

    for ( int countIdx = 0; countIdx < altCounts.length; countIdx++ ) {
        for ( int numIdx = 0; numIdx < altAlleleNum.length; numIdx++ ) {
            // first entry is the value from altCounts, second is the remainder of altAlleleNum
            final double[] knownCounts = { altCounts[countIdx], altAlleleNum[numIdx] - altCounts[countIdx] };
            final double[] expectedRow = expectations[countIdx * altAlleleNum.length + numIdx];

            int expected_index = 0;
            for ( final int[] pls : likelihood_PLs ) {
                final double[] post = LikelihoodCalculationEngine.calculatePosteriorGLs(pl2gl(pls), knownCounts, 2);
                for ( final double log10Posterior : post ) {
                    final double expected = expectedRow[expected_index++];
                    final double observed = Math.pow(10.0,log10Posterior);
                    final double err = Math.abs( (expected-observed)/expected );
                    Assert.assertTrue(err < 1e-4, String.format("Counts: %s | Expected: %e | Observed: %e | pre %s | prior %s | post %s",
                            Arrays.toString(knownCounts), expected,observed, Arrays.toString(pl2gl(pls)),
                            Arrays.toString(LikelihoodCalculationEngine.getDirichletPrior(knownCounts,2)),Arrays.toString(post)));
                }
            }
        }
    }
}
// True when both arrays have the same length and no pair of corresponding entries differs by
// more than tol in absolute value.
private boolean arraysApproxEqual(double[] a, double[] b, double tol) {
    boolean withinTolerance = a.length == b.length;
    for ( int i = 0; withinTolerance && i < a.length; i++ ) {
        withinTolerance = ! ( Math.abs(a[i]-b[i]) > tol );
    }
    return withinTolerance;
}
// Formats an assertion message showing the first array as expected and the second as observed.
private String errMsgArray(double[] a, double[] b) {
    final String expectedText = Arrays.toString(a);
    final String observedText = Arrays.toString(b);
    return String.format("Expected %s, Observed %s", expectedText, observedText);
}
/**
 * Checks calculatePosteriorGLs for sites with three, four and five alleles against tabulated
 * log10 posteriors, and checks getDirichletPrior for the five-allele counts.
 */
@Test
private void testPosteriorMultiAllelic() {
    // genotype order: AA AB BB AC BC CC AD BD CD DD
    final int[][] inputPLs = {
            {40,20,30,0,15,25},
            {0,20,10,99,99,99},
            {50,40,0,30,30,10,20,40,80,50},
            {99,90,85,10,5,30,40,20,40,30,0,12,20,14,5},
            {60,20,30,0,40,10,8,12,18,22,40,12,80,60,20}
    };
    final double[][] knownCounts = {
            {100.001,40.001,2.001},
            {2504.001,16.001,218.001},
            {10000.001,500.001,25.001,0.001},
            {4140.001,812.001,32.001,104.001,12.001},
            {80.001,40.001,8970.001,200.001,1922.001}
    };
    final double[][] expectedPosteriors = {
            { -2.684035, -0.7852596, -2.4735, -0.08608339, -1.984017, -4.409852 },
            { -5.736189e-05, -3.893688, -5.362878, -10.65938, -12.85386, -12.0186},
            {-2.403234, -2.403276, -0.004467802, -2.70429, -4.005319, -3.59033, -6.102247, -9.403276, -14.70429, -13.40284},
            {-7.828677, -7.335196, -7.843136, -0.7395892, -0.947033, -5.139092, -3.227715,
                    -1.935159, -5.339552, -4.124552, -0.1655353, -2.072979, -4.277372, -3.165498, -3.469589 },
            { -9.170334, -5.175724, -6.767055, -0.8250021, -5.126027, -0.07628661, -3.276762,
                    -3.977787, -2.227065, -4.57769, -5.494041, -2.995066, -7.444344, -7.096104, -2.414187}
    };
    // the first case is held to a tighter tolerance than the rest
    final double[] tolerances = { 1e-6, 1e-5, 1e-5, 1e-5, 1e-5 };

    final double[][] observedPosteriors = new double[inputPLs.length][];
    for ( int testCase = 0; testCase < inputPLs.length; testCase++ ) {
        observedPosteriors[testCase] =
                LikelihoodCalculationEngine.calculatePosteriorGLs(pl2gl(inputPLs[testCase]),knownCounts[testCase],2);
    }

    final double[] counts_five = knownCounts[4];
    final double[] expecPrior5 = new double[] {-4.2878195, -4.2932090, -4.8845400, -1.9424874, -2.2435120, -0.1937719, -3.5942477,
            -3.8952723, -1.5445506, -3.4951749, -2.6115263, -2.9125508, -0.5618292, -2.2135895,
            -1.5316722};
    Assert.assertTrue(arraysApproxEqual(expecPrior5, LikelihoodCalculationEngine.getDirichletPrior(counts_five,2),1e-5),errMsgArray(expecPrior5,LikelihoodCalculationEngine.getDirichletPrior(counts_five,2)));

    for ( int testCase = 0; testCase < expectedPosteriors.length; testCase++ ) {
        Assert.assertTrue(arraysApproxEqual(expectedPosteriors[testCase],observedPosteriors[testCase],tolerances[testCase]),
                errMsgArray(expectedPosteriors[testCase],observedPosteriors[testCase]));
    }
}
}

View File

@ -1472,4 +1472,34 @@ public class MathUtils {
return sliceListByIndices(sampleIndicesWithoutReplacement(list.size(),N),list);
}
/**
 * Return the likelihood of observing the counts of categories having sampled a population
 * whose categorial frequencies are distributed according to a Dirichlet distribution.
 *
 * The value is accumulated from log10MultinomialCoefficient and log10Gamma terms, so it is
 * presumably on the log10 scale -- confirm against those helpers.
 *
 * @param dirichletParams - params of the prior dirichlet distribution
 * @param dirichletSum - the sum of those parameters
 * @param counts - the counts of observation in each category
 * @param countSum - the sum of counts (number of trials)
 * @return - associated likelihood
 * @throws IllegalStateException if dirichletParams and counts differ in length
 */
public static double dirichletMultinomial(final double[] dirichletParams, final double dirichletSum,
                                          final int[] counts, final int countSum) {
    final int numCategories = counts.length;
    if ( dirichletParams.length != numCategories ) {
        throw new IllegalStateException("The number of dirichlet parameters must match the number of categories");
    }

    // todo -- lots of lnGammas here. At some point we can safely switch to x * ( ln(x) - 1)
    // terms that do not depend on the individual categories
    double result = log10MultinomialCoefficient(countSum,counts)
            + log10Gamma(dirichletSum)
            - log10Gamma(dirichletSum+countSum);

    // per-category terms, accumulated in category order
    for ( int category = 0; category < numCategories; category++ ) {
        result = result
                + log10Gamma(counts[category] + dirichletParams[category])
                - log10Gamma(dirichletParams[category]);
    }
    return result;
}
// Convenience overload: derives the parameter sum and the number of trials from the arrays
// themselves and delegates to the four-argument version.
public static double dirichletMultinomial(double[] params, int[] counts) {
    final double paramTotal = sum(params);
    final int numTrials = (int) sum(counts);
    return dirichletMultinomial(params,paramTotal,counts,numTrials);
}
}

View File

@ -835,4 +835,18 @@ public class Utils {
// don't perform array copies if we need to copy everything anyways
return ( trimFromFront == 0 && trimFromBack == 0 ) ? seq : Arrays.copyOfRange(seq, trimFromFront, seq.length - trimFromBack);
}
/**
 * Boxes the elements of an int[] into a newly allocated List&lt;Integer&gt;.
 *
 * @param ar - the array whose elements should be listified
 * @return - a List&lt;Integer&gt; where each element has the same value as the element at the
 *           corresponding index of {@code ar}
 */
public static List<Integer> listFromPrimitives(final int[] ar) {
    final int size = ar.length;
    final List<Integer> boxed = new ArrayList<>(size);
    for ( int i = 0; i < size; i++ ) {
        boxed.add(ar[i]);
    }
    return boxed;
}
}

View File

@ -565,11 +565,11 @@ public class GATKVariantContextUtils {
* @param newLikelihoods a vector of likelihoods to use if the method requires PLs, should be log10 likelihoods, cannot be null
* @param allelesToUse the alleles we are using for our subsetting
*/
protected static void updateGenotypeAfterSubsetting(final List<Allele> originalGT,
final GenotypeBuilder gb,
final GenotypeAssignmentMethod assignmentMethod,
final double[] newLikelihoods,
final List<Allele> allelesToUse) {
public static void updateGenotypeAfterSubsetting(final List<Allele> originalGT,
final GenotypeBuilder gb,
final GenotypeAssignmentMethod assignmentMethod,
final double[] newLikelihoods,
final List<Allele> allelesToUse) {
gb.noAD();
switch ( assignmentMethod ) {
case SET_TO_NO_CALL:

View File

@ -522,4 +522,338 @@ public class MathUtilsUnitTest extends BaseTest {
final Comparable actual = MathUtils.median(values);
Assert.assertEquals(actual, expected, "Failed with " + values);
}
// man. All this to test dirichlet.
// Unboxes a List<Double> into a double[]; null entries become 0.0.
private double[] unwrap(List<Double> stuff) {
    final double[] values = new double[stuff.size()];
    for ( int i = 0; i < values.length; i++ ) {
        final Double boxed = stuff.get(i);
        values[i] = boxed != null ? boxed : 0.0;
    }
    return values;
}
/**
 * The PartitionGenerator generates all of the partitions of a number n, e.g. for n = 5
 * 5
 * 4 + 1
 * 3 + 2
 * 3 + 1 + 1
 * 2 + 2 + 1
 * 2 + 1 + 1 + 1
 * 1 + 1 + 1 + 1 + 1
 *
 * This is used to help enumerate the state space over which the Dirichlet-Multinomial is defined,
 * to ensure that the distribution function is properly implemented
 *
 * The generator is written as a resumable state machine: next() picks up in the state recorded
 * by the previous call (field "state", values 0-3) and returns as soon as one partition is ready.
 * NOTE(review): the statement sequence looks like Kelleher's "accelerated ascending composition"
 * algorithm unrolled into an iterator -- confirm before relying on that name.
 *
 * Each list returned by next() is a subList view of the internal buffer "a"; later calls to
 * next() overwrite that buffer, so callers that keep a partition must copy it first.
 */
class PartitionGenerator implements Iterator<List<Integer>> {
// generate the partitions of an integer, each partition sorted numerically
// the number being partitioned
int n;
// working buffer; the prefix handed out by next() is the current partition
List<Integer> a;
// y, k, x and l are the working variables of the algorithm; k and l index into a
int y;
int k;
// which step of the algorithm the next call to next() resumes in (0..3)
int state;
int x;
int l;
// Sets up the buffer with n entries (initially 0..n-1) and the starting values y = n-1, k = 1.
public PartitionGenerator(int n) {
this.n = n;
this.y = n - 1;
this.k = 1;
this.a = new ArrayList<Integer>();
for ( int i = 0; i < n; i++ ) {
this.a.add(i);
}
this.state = 0;
}
// removal is not supported; the call is silently ignored
public void remove() { /* do nothing */ }
// the enumeration is finished once the machine is back in state 0 with k == 0
public boolean hasNext() { return ! ( this.k == 0 && state == 0 ); }
// renders the full internal state; used for debug logging
private String dataStr() {
return String.format("a = [%s] k = %d y = %d state = %d x = %d l = %d",
Utils.join(",",a), k, y, state, x, l);
}
/**
 * Advances the state machine until the next partition is available.
 * @return a view onto the prefix of the internal buffer holding the partition
 */
public List<Integer> next() {
// state 0: start a new round from the entry before position k
if ( this.state == 0 ) {
this.x = a.get(k-1)+1;
k -= 1;
this.state = 1;
}
// state 1: keep writing x into the buffer while 2*x still fits into y
if ( this.state == 1 ) {
while ( 2*x <= y ) {
this.a.set(k,x);
this.y -= x;
this.k++;
}
this.l = 1+this.k;
this.state = 2;
}
// state 2: emit one partition ending in (x, y) per call while x <= y; the state is left at 2
// so the next call re-enters here with the updated x and y
if ( this.state == 2 ) {
if ( x <= y ) {
this.a.set(k,x);
this.a.set(l,y);
x += 1;
y -= 1;
return this.a.subList(0, this.k + 2);
} else {
this.state =3;
}
}
// state 3: emit the partition whose last part is x+y, then go back to state 0
if ( this.state == 3 ) {
this.a.set(k,x+y);
this.y = x + y - 1;
this.state = 0;
return a.subList(0, k + 1);
}
throw new IllegalStateException("Cannot get here");
}
// NOTE: this drains the generator -- every remaining partition is consumed while building the
// string, so hasNext() is false afterwards
public String toString() {
StringBuffer buf = new StringBuffer();
buf.append("{ ");
while ( hasNext() ) {
buf.append("[");
buf.append(Utils.join(",",next()));
buf.append("],");
}
// drop the comma left behind by the last appended partition
buf.deleteCharAt(buf.lastIndexOf(","));
buf.append(" }");
return buf.toString();
}
}
/**
 * NextCounts is the enumerator over the state space of the multinomial dirichlet.
 *
 * It takes partitions of the total count from a PartitionGenerator, skipping those with more
 * terms than there are categories, pads each accepted partition with zeros up to the number of
 * categories, and then walks through the permutations of that count vector via
 * MathUtilsUnitTest.nextPermutation before moving on to the next partition.
 *
 * In so doing it enumerates over the full state space.
 */
class NextCounts implements Iterator<int[]> {
    private PartitionGenerator partitioner;
    private int numCategories;
    // the vector the following call to next() will hand out (as a copy); null once exhausted
    private int[] next;

    public NextCounts(int numCategories, int totalCounts) {
        this.partitioner = new PartitionGenerator(totalCounts);
        this.numCategories = numCategories;
        this.next = nextFromPartitioner();
    }

    public void remove() { /* do nothing */ }

    public boolean hasNext() {
        return next != null;
    }

    public int[] next() {
        // hand out a copy, because nextPermutation rearranges the array in place
        final int[] current = clone(next);
        final int[] permuted = nextPermutation();
        next = permuted != null ? permuted : nextFromPartitioner();
        return current;
    }

    private int[] clone(int[] arr) {
        return Arrays.copyOf(arr, arr.length);
    }

    // Fetches the next partition that fits into numCategories terms, zero-padded and sorted
    // ascending; returns null when the partitioner has nothing suitable left.
    private int[] nextFromPartitioner() {
        if ( ! partitioner.hasNext() ) {
            return null;
        }

        List<Integer> candidate = partitioner.next();
        while ( candidate.size() > numCategories && partitioner.hasNext() ) {
            candidate = partitioner.next();
        }
        if ( candidate.size() > numCategories ) {
            return null;
        }

        final int[] padded = new int[numCategories];
        int pos = 0;
        for ( final Integer term : candidate ) {
            padded[pos++] = term;
        }
        Arrays.sort(padded);
        return padded;
    }

    public int[] nextPermutation() {
        return MathUtilsUnitTest.nextPermutation(next);
    }
}
/**
 * Advances the given array, in place, to the permutation that follows it in
 * ascending lexicographic order.
 *
 * Starting from an array sorted in ascending order and calling this repeatedly
 * visits every distinct permutation once, ending with the array in
 * non-increasing order.
 *
 * @param next the current permutation; it is modified in place
 * @return the same array instance, now holding the following permutation, or
 *         null (array left untouched) if it was already in non-increasing order
 */
public static int[] nextPermutation(int[] next) {
    // pivot: right-most position whose element is smaller than its right-hand neighbour
    int pivot = next.length - 2;
    while ( pivot >= 0 && next[pivot] >= next[pivot + 1] ) {
        pivot--;
    }
    if ( pivot < 0 ) {
        // no ascent anywhere: this was the last permutation
        return null;
    }

    // successor: right-most element strictly greater than the pivot element;
    // next[pivot + 1] qualifies, so this scan always stops to the right of the pivot
    int successor = next.length - 1;
    while ( next[successor] <= next[pivot] ) {
        successor--;
    }
    final int pivotValue = next[pivot];
    next[pivot] = next[successor];
    next[successor] = pivotValue;

    // reverse everything to the right of the pivot
    for ( int lo = pivot + 1, hi = next.length - 1; lo < hi; lo++, hi-- ) {
        final int held = next[lo];
        next[lo] = next[hi];
        next[hi] = held;
    }
    return next;
}
// before testing the dirichlet multinomial, we need to test the
// classes used to test the dirichlet multinomial
/**
 * Counts the partitions PartitionGenerator yields for every total from 1 to 20
 * and compares each count with the known number of integer partitions.
 */
@Test
public void testPartitioner() {
    // expectedCounts[i] is the number of integer partitions of (i + 1)
    final int[] expectedCounts = {1, 2, 3, 5, 7, 11, 15, 22, 30, 42, 56, 77, 101, 135, 176, 231, 297, 385, 490, 627};
    for ( int total = 1; total <= expectedCounts.length; total++ ) {
        final int expected = expectedCounts[total - 1];
        final PartitionGenerator generator = new PartitionGenerator(total);
        int observed = 0;
        for ( ; generator.hasNext(); generator.next() ) {
            logger.debug(generator.dataStr());
            observed++;
        }
        // a fresh generator is rendered for the message because the one above has been consumed
        final String failureMessage = String.format("Expected %d partitions, observed %s",
                expected, new PartitionGenerator(total).toString());
        Assert.assertEquals(observed, expected, failureMessage);
    }
}
/**
 * Walks nextPermutation through all 24 orderings of {1,2,3,4}, checking every
 * position of every generation, and then checks that the sequence terminates
 * with null once the last ordering has been produced.
 */
@Test
public void testNextPermutation() {
    int[] arr = new int[]{1,2,3,4};
    // every permutation of {1,2,3,4}, in ascending lexicographic order
    final int[][] gens = new int[][] {
        new int[]{1,2,3,4},
        new int[]{1,2,4,3},
        new int[]{1,3,2,4},
        new int[]{1,3,4,2},
        new int[]{1,4,2,3},
        new int[]{1,4,3,2},
        new int[]{2,1,3,4},
        new int[]{2,1,4,3},
        new int[]{2,3,1,4},
        new int[]{2,3,4,1},
        new int[]{2,4,1,3},
        new int[]{2,4,3,1},
        new int[]{3,1,2,4},
        new int[]{3,1,4,2},
        new int[]{3,2,1,4},
        new int[]{3,2,4,1},
        new int[]{3,4,1,2},
        new int[]{3,4,2,1},
        new int[]{4,1,2,3},
        new int[]{4,1,3,2},
        new int[]{4,2,1,3},
        new int[]{4,2,3,1},
        new int[]{4,3,1,2},
        new int[]{4,3,2,1} };
    for ( int gen = 0; gen < gens.length; gen++ ) {
        // compare the whole array rather than a hard-coded prefix; Arrays.equals also
        // copes with a premature null instead of failing with a NullPointerException
        Assert.assertTrue(Arrays.equals(arr, gens[gen]),
                String.format("Error at generation %d, expected %s, observed %s",gen,Arrays.toString(gens[gen]),Arrays.toString(arr)));
        arr = nextPermutation(arr);
    }
    // the call made on the final ordering must signal exhaustion
    Assert.assertNull(arr, "nextPermutation should return null after the last permutation");
}
/**
 * Returns a new array in which every entry of counts has been increased by 1e-3.
 * The input array is left untouched.
 *
 * @param counts the values to shift
 * @return a copy of counts with 1e-3 added to each entry
 */
private double[] addEpsilon(double[] counts) {
    final double[] shifted = counts.clone();
    for ( int pos = 0; pos < shifted.length; pos++ ) {
        shifted[pos] += 1e-3;
    }
    return shifted;
}
/**
 * For several allele-count vectors and several sample sizes, enumerates every
 * count vector NextCounts produces, evaluates MathUtils.dirichletMultinomial on
 * each (with a small epsilon added to the allele counts), and checks that
 * every value is finite and that MathUtils.sumLog10 over all of them is 1.0
 * to within 1e-7.
 *
 * NOTE(review): the values are presumably log10 likelihoods (they are combined
 * with sumLog10) -- confirm against MathUtils.
 */
@Test
public void testDirichletMultinomial() {
    List<double[]> testAlleles = Arrays.asList(
            new double[]{80,240},
            new double[]{1,10000},
            new double[]{0,500},
            new double[]{5140,20480},
            new double[]{5000,800,200},
            new double[]{6,3,1000},
            new double[]{100,400,300,800},
            new double[]{8000,100,20,80,2},
            new double[]{90,20000,400,20,4,1280,720,1}
    );
    // sanity check: log10Gamma must be finite at the epsilon that addEpsilon introduces
    Assert.assertTrue(! Double.isInfinite(MathUtils.log10Gamma(1e-3)) && ! Double.isNaN(MathUtils.log10Gamma(1e-3)));
    int[] numAlleleSampled = new int[]{2,5,10,20,25};
    for ( double[] alleles : testAlleles ) {
        for ( int count : numAlleleSampled ) {
            // test that everything sums to one. Generate all multinomial draws
            List<Double> likelihoods = new ArrayList<Double>(100000);
            NextCounts generator = new NextCounts(alleles.length,count);
            // running maximum, kept for the debug print below. It must start at negative infinity:
            // Double.MIN_VALUE is the smallest *positive* double, so no value <= 0 could ever exceed it
            double maxLog = Double.NEGATIVE_INFINITY;
            //List<String> countLog = new ArrayList<String>(200);
            while ( generator.hasNext() ) {
                int[] thisCount = generator.next();
                //countLog.add(Arrays.toString(thisCount));
                final double likelihood = MathUtils.dirichletMultinomial(addEpsilon(alleles),thisCount);
                Assert.assertTrue(! Double.isNaN(likelihood) && ! Double.isInfinite(likelihood),
                        String.format("Likelihood for counts %s and nAlleles %d was %s",
                                Arrays.toString(thisCount),alleles.length,Double.toString(likelihood)));
                if ( likelihood > maxLog )
                    maxLog = likelihood;
                likelihoods.add(likelihood);
            }
            //System.out.printf("%d likelihoods and max is (probability) %e\n",likelihoods.size(),Math.pow(10,maxLog));
            Assert.assertEquals(MathUtils.sumLog10(unwrap(likelihoods)),1.0,1e-7,
                    String.format("Counts %d and alleles %d have nLikelihoods %d. \n Counts: %s",
                            count,alleles.length,likelihoods.size(), "NODEBUG"/*,countLog*/));
        }
    }
}
}

View File

@ -32,6 +32,7 @@ import org.broadinstitute.sting.utils.MathUtils;
import org.broadinstitute.sting.utils.Utils;
import org.broadinstitute.sting.utils.collections.Pair;
import org.broadinstitute.variant.variantcontext.*;
import org.broadinstitute.variant.vcf.VCFConstants;
import org.testng.Assert;
import org.testng.annotations.BeforeSuite;
import org.testng.annotations.DataProvider;
@ -56,11 +57,7 @@ public class GATKVariantContextUtilsUnitTest extends BaseTest {
ATCATC = Allele.create("ATCATC");
}
/**
 * Builds a genotype for the given sample from two alleles via GenotypeBuilder.create.
 *
 * @param sample sample name passed to the builder
 * @param a1     first allele
 * @param a2     second allele
 * @return the genotype returned by GenotypeBuilder.create
 */
private Genotype makeG(String sample, Allele a1, Allele a2) {
    final java.util.List<Allele> alleles = Arrays.asList(a1, a2);
    return GenotypeBuilder.create(sample, alleles);
}
private Genotype makeG(String sample, Allele a1, Allele a2, double log10pError, double... pls) {
/**
 * Builds a genotype for the given sample from two alleles, setting the
 * log10 error value and the PL values on the GenotypeBuilder before making it.
 *
 * @param sample      sample name passed to the builder
 * @param a1          first allele
 * @param a2          second allele
 * @param log10pError value handed to GenotypeBuilder.log10PError
 * @param pls         values handed to GenotypeBuilder.PL
 * @return the genotype produced by the builder
 */
private Genotype makeG(String sample, Allele a1, Allele a2, double log10pError, int... pls) {
    final java.util.List<Allele> alleles = Arrays.asList(a1, a2);
    return new GenotypeBuilder(sample, alleles)
            .log10PError(log10pError)
            .PL(pls)
            .make();
}