diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngine.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngine.java index 4eb728390..9cd0f5f49 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngine.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngine.java @@ -52,6 +52,7 @@ import net.sf.samtools.SAMUtils; import org.apache.log4j.Logger; import org.broadinstitute.sting.utils.MathUtils; import org.broadinstitute.sting.utils.QualityUtils; +import org.broadinstitute.sting.utils.Utils; import org.broadinstitute.sting.utils.exceptions.UserException; import org.broadinstitute.sting.utils.genotyper.PerReadAlleleLikelihoodMap; import org.broadinstitute.sting.utils.haplotype.Haplotype; @@ -60,7 +61,9 @@ import org.broadinstitute.sting.utils.recalibration.covariates.RepeatCovariate; import org.broadinstitute.sting.utils.recalibration.covariates.RepeatLengthCovariate; import org.broadinstitute.sting.utils.sam.GATKSAMRecord; import org.broadinstitute.sting.utils.sam.ReadUtils; -import org.broadinstitute.variant.variantcontext.Allele; +import org.broadinstitute.sting.utils.variant.GATKVariantContextUtils; +import org.broadinstitute.variant.variantcontext.*; +import org.broadinstitute.variant.vcf.VCFConstants; import java.io.File; import java.io.FileNotFoundException; @@ -572,4 +575,212 @@ public class LikelihoodCalculationEngine { } }
+    // --------------------------------------------------------------------------------
+    //
+    // Posterior GL calculations
+    //
+    // --------------------------------------------------------------------------------
+
+    /**
+     * Re-genotype the samples of a variant context using a prior built from known allele counts.
+     *
+     * Allele counts are accumulated (via addAlleleCounts) over every resource context and, optionally,
+     * over vc1 itself. The counts of the alleles present in vc1 -- each offset by
+     * globalFrequencyPriorDirichlet -- parameterize the genotype prior from getDirichletPrior. Every
+     * genotype that carries likelihoods is reassigned from its posterior and receives the PL-scaled
+     * posterior under VCFConstants.GENOTYPE_POSTERIORS_KEY; genotypes without likelihoods are rebuilt
+     * from the original genotype with no posterior attached.
+     *
+     * @param vc1 the variant context whose genotypes are updated
+     * @param resources additional variant contexts contributing allele counts to the prior
+     * @param numRefSamplesFromMissingResources value added to the reference allele count of vc1
+     * @param globalFrequencyPriorDirichlet value added to the known count of every allele of vc1
+     * @param useInputSamples if true, the allele counts of vc1 itself are included in the prior
+     * @param useEM must be false; the EM loop is not implemented
+     * @param useAC forwarded to addAlleleCounts; when true the MLE allele count attribute is not used
+     * @return a copy of vc1 with rebuilt genotypes and a "PG" attribute holding the PL-scaled prior
+     * @throws IllegalArgumentException if useEM is true
+     * @throws IllegalStateException if vc1.getMaxPloidy(2) is not 2 (raised by the helper methods)
+     */
+    public static VariantContext calculatePosteriorGLs(final VariantContext vc1,
+                                                       final Collection<VariantContext> resources,
+                                                       final int numRefSamplesFromMissingResources,
+                                                       final double globalFrequencyPriorDirichlet,
+                                                       final boolean useInputSamples,
+                                                       final boolean useEM,
+                                                       final boolean useAC) {
+        if ( useEM )
+            throw new IllegalArgumentException("EM loop for posterior GLs not yet implemented");
+
+        final Map<Allele,Integer> totalAlleleCounts = new HashMap<>();
+        for ( final VariantContext resource : resources ) {
+            addAlleleCounts(totalAlleleCounts,resource,useAC);
+        }
+
+        if ( useInputSamples ) {
+            addAlleleCounts(totalAlleleCounts,vc1,useAC);
+        }
+
+        // The reference allele of vc1 has no entry when nothing was counted for it (no resources and
+        // useInputSamples == false, or resources with a different reference allele). Treat a missing
+        // entry as zero instead of unboxing null.
+        final Allele refAllele = vc1.getReference();
+        final Integer knownRefCount = totalAlleleCounts.get(refAllele);
+        totalAlleleCounts.put(refAllele,(knownRefCount == null ? 0 : knownRefCount)+numRefSamplesFromMissingResources);
+
+        // now extract the counts of the alleles present within vc1, and in order
+        final double[] alleleCounts = new double[vc1.getNAlleles()];
+        int alleleIndex = 0;
+        for ( final Allele allele : vc1.getAlleles() ) {
+            final Integer knownCount = totalAlleleCounts.get(allele);
+            alleleCounts[alleleIndex++] = globalFrequencyPriorDirichlet + ( knownCount == null ? 0 : knownCount );
+        }
+
+        // a null entry marks a genotype without likelihoods; it stays null in the posteriors
+        final List<double[]> likelihoods = new ArrayList<>(vc1.getNSamples());
+        for ( final Genotype genotype : vc1.getGenotypes() ) {
+            likelihoods.add(genotype.hasLikelihoods() ? genotype.getLikelihoods().getAsVector() : null );
+        }
+
+        final int ploidy = vc1.getMaxPloidy(2);
+        final List<double[]> posteriors = calculatePosteriorGLs(likelihoods,alleleCounts,ploidy);
+
+        final GenotypesContext newContext = GenotypesContext.create();
+        for ( int genoIdx = 0; genoIdx < vc1.getNSamples(); genoIdx ++ ) {
+            final GenotypeBuilder builder = new GenotypeBuilder(vc1.getGenotype(genoIdx));
+            final double[] posterior = posteriors.get(genoIdx);
+            if ( posterior != null ) {
+                GATKVariantContextUtils.updateGenotypeAfterSubsetting(vc1.getAlleles(), builder,
+                        GATKVariantContextUtils.GenotypeAssignmentMethod.USE_PLS_TO_ASSIGN, posterior, vc1.getAlleles());
+                builder.attribute(VCFConstants.GENOTYPE_POSTERIORS_KEY,
+                        Utils.listFromPrimitives(GenotypeLikelihoods.fromLog10Likelihoods(posterior).getAsPLs()));
+
+            }
+            newContext.add(builder.make());
+        }
+
+        final List<Integer> priors = Utils.listFromPrimitives(
+                GenotypeLikelihoods.fromLog10Likelihoods(getDirichletPrior(alleleCounts, ploidy)).getAsPLs());
+
+        return new VariantContextBuilder(vc1).genotypes(newContext).attribute("PG",priors).make();
+    }
+
+    /**
+     * Given genotype likelihoods and known allele counts, calculate the posterior likelihoods
+     * over the genotype states
+     * @param genotypeLikelihoods - the genotype likelihoods for the individual
+     * @param knownAlleleCountsByAllele - the known allele counts in the population.
For AC=2 AN=12 site, this is {10,2}
+     * @param ploidy - the ploidy to assume; only 2 is supported
+     * @return - one posterior vector per input entry, in input order, normalized and log10-scaled;
+     *           a null input entry yields a null output entry
+     * @throws IllegalStateException if ploidy != 2, or if a likelihood vector does not have one
+     *           entry per genotype state
+     */
+    protected static List<double[]> calculatePosteriorGLs(final List<double[]> genotypeLikelihoods,
+                                                          final double[] knownAlleleCountsByAllele,
+                                                          final int ploidy) {
+        if ( ploidy != 2 ) {
+            throw new IllegalStateException("Genotype posteriors not yet implemented for ploidy != 2");
+        }
+
+        // one prior entry per genotype state, shared by every sample
+        final double[] genotypePriorByAllele = getDirichletPrior(knownAlleleCountsByAllele,ploidy);
+        final List<double[]> posteriors = new ArrayList<>(genotypeLikelihoods.size());
+        for ( final double[] likelihoods : genotypeLikelihoods ) {
+            double[] posteriorLikelihoods = null;
+
+            if ( likelihoods != null ) {
+                if ( likelihoods.length != genotypePriorByAllele.length ) {
+                    throw new IllegalStateException(String.format("Likelihoods not of correct size: expected %d, observed %d",
+                            genotypePriorByAllele.length,likelihoods.length));
+                }
+
+                // likelihood and prior are both in log space, so the unnormalized posterior is their sum
+                posteriorLikelihoods = new double[genotypePriorByAllele.length];
+                for ( int genoIdx = 0; genoIdx < likelihoods.length; genoIdx ++ ) {
+                    posteriorLikelihoods[genoIdx] = likelihoods[genoIdx] + genotypePriorByAllele[genoIdx];
+                }
+
+                posteriorLikelihoods = MathUtils.toLog10(MathUtils.normalizeFromLog10(posteriorLikelihoods));
+
+            }
+
+            posteriors.add(posteriorLikelihoods);
+        }
+
+        return posteriors;
+    }
+
+    // convenience function for a single genotypelikelihoods array. Just wraps.
+    protected static double[] calculatePosteriorGLs(final double[] genotypeLikelihoods,
+                                                    final double[] knownAlleleCountsByAllele,
+                                                    final int ploidy) {
+        // Arrays.asList on a double[] yields a one-element List<double[]>
+        return calculatePosteriorGLs(Arrays.asList(genotypeLikelihoods),knownAlleleCountsByAllele,ploidy).get(0);
+    }
+
+
+    /**
+     * Given known allele counts (whether external, from the sample, or both), calculate the prior distribution
+     * over genotype states.
This assumes
+     * 1) Random sampling of alleles (known counts are unbiased, and frequency estimate is Dirichlet)
+     * 2) Genotype states are independent (Hardy-Weinberg)
+     * These assumptions give rise to a Dirichlet-Multinomial distribution of genotype states as a prior
+     * (the "number of trials" for the multinomial is simply the ploidy)
+     * @param knownCountsByAllele - the known counts per allele. For an AC=2, AN=12 site this is {10,2}
+     * @param ploidy - the number of chromosomes in the sample. For now restricted to 2.
+     * @return - the Dirichlet-Multinomial distribution over genotype states, one entry per unordered
+     *           pair of alleles
+     * @throws IllegalStateException if ploidy != 2
+     */
+    protected static double[] getDirichletPrior(final double[] knownCountsByAllele, final int ploidy) {
+        if ( ploidy != 2 ) {
+            throw new IllegalStateException("Genotype priors not yet implemented for ploidy != 2");
+        }
+
+        final int numAlleles = knownCountsByAllele.length;
+        final double totalKnownCounts = MathUtils.sum(knownCountsByAllele);
+        final double[] genotypePriors = new double[numAlleles*(numAlleles+1)/2];
+
+        // genotype states are laid out in the multi-allelic order
+        // AA AB BB AC BC CC AD BD CD DD ...
+        // i.e. the higher allele index varies slowest
+        int nextGenotype = 0;
+        for ( int higherAllele = 0; higherAllele < numAlleles; higherAllele++ ) {
+            for ( int lowerAllele = 0; lowerAllele <= higherAllele; lowerAllele++ ) {
+                // how often each allele is drawn in this genotype (a homozygote counts its allele twice)
+                final int[] alleleDraws = new int[numAlleles];
+                alleleDraws[lowerAllele]++;
+                alleleDraws[higherAllele]++;
+                genotypePriors[nextGenotype++] =
+                        MathUtils.dirichletMultinomial(knownCountsByAllele,totalKnownCounts,alleleDraws,ploidy);
+            }
+        }
+
+        return genotypePriors;
+    }
+
+    private static void addAlleleCounts(final Map counts, final VariantContext context, final boolean useAC) {
+        final int[] ac;
+        if ( context.hasAttribute(VCFConstants.MLE_ALLELE_COUNT_KEY) && !
useAC ) {
+            ac = extractInts(context.getAttribute(VCFConstants.MLE_ALLELE_COUNT_KEY));
+        } else if ( context.hasAttribute(VCFConstants.ALLELE_COUNT_KEY) ) {
+            ac = extractInts(context.getAttribute(VCFConstants.ALLELE_COUNT_KEY));
+        } else {
+            ac = new int[context.getAlternateAlleles().size()];
+            int idx = 0;
+            for ( final Allele allele : context.getAlternateAlleles() ) {
+                ac[idx++] = context.getCalledChrCount(allele);
+            }
+        }
+
+        for ( final Allele allele : context.getAlleles() ) {
+            final int count;
+            if ( allele.isReference() ) {
+                if ( context.hasAttribute(VCFConstants.ALLELE_NUMBER_KEY) ) {
+                    count = context.getAttributeAsInt(VCFConstants.ALLELE_NUMBER_KEY,-1) - (int) MathUtils.sum(ac);
+                } else {
+                    count = context.getCalledChrCount() - (int) MathUtils.sum(ac);
+                }
+            } else {
+                count = ac[context.getAlternateAlleles().indexOf(allele)];
+            }
+            if ( ! counts.containsKey(allele) ) {
+                counts.put(allele,0);
+            }
+            counts.put(allele,count + counts.get(allele));
+        }
+    }
+
+    /**
+     * Convert the value of an integer-valued VCF attribute into an int[].
+     *
+     * Accepted shapes are a single Integer, a single String holding a decimal integer, or a List whose
+     * elements are each an Integer or such a String. An empty List yields an empty array.
+     *
+     * @param integerListContainingVCField the raw attribute value
+     * @return the attribute's values, in order
+     * @throws IllegalArgumentException if the value is null or is neither a List, an Integer nor a String
+     * @throws IllegalStateException if a list element is neither an Integer nor a String
+     * @throws NumberFormatException if a String value is not a parseable integer
+     */
+    public static int[] extractInts(final Object integerListContainingVCField) {
+        final List<?> values;
+        if ( integerListContainingVCField instanceof List ) {
+            values = (List<?>) integerListContainingVCField;
+        } else if ( integerListContainingVCField instanceof Integer || integerListContainingVCField instanceof String ) {
+            // a scalar attribute is handled as a one-element list
+            values = Arrays.<Object>asList(integerListContainingVCField);
+        } else {
+            throw new IllegalArgumentException("VCF does not have properly formatted "+
+                    VCFConstants.MLE_ALLELE_COUNT_KEY+" or "+VCFConstants.ALLELE_COUNT_KEY);
+        }
+
+        // check every element, not just the first, so that mixed or malformed lists fail with a clear message
+        final int[] ints = new int[values.size()];
+        int idx = 0;
+        for ( final Object value : values ) {
+            if ( value instanceof Integer ) {
+                ints[idx++] = (Integer) value;
+            } else if ( value instanceof String ) {
+                ints[idx++] = Integer.parseInt((String) value);
+            } else {
+                throw new IllegalStateException("BUG: The AC values should be an Integer, but was "+
+                        (value == null ? "null" : value.getClass().getCanonicalName()));
+            }
+        }
+
+        return ints;
+    }
 }
diff --git a/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngineUnitTest.java b/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngineUnitTest.java index 38b00ab07..2bbb7c725 100644 --- a/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngineUnitTest.java +++ b/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngineUnitTest.java @@ -58,19 +58,34 @@ import org.broadinstitute.sting.utils.Utils; import org.broadinstitute.sting.utils.pairhmm.PairHMM; import org.broadinstitute.sting.utils.recalibration.covariates.RepeatCovariate; import org.broadinstitute.sting.utils.recalibration.covariates.RepeatLengthCovariate; +import org.broadinstitute.variant.variantcontext.*; +import org.broadinstitute.variant.vcf.VCFConstants; import org.testng.Assert; +import org.testng.annotations.BeforeSuite; import org.testng.annotations.DataProvider; import org.testng.annotations.Test; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.List; +import java.util.*; /** * Unit tests for LikelihoodCalculationEngine */ public class LikelihoodCalculationEngineUnitTest extends BaseTest { + Allele Aref, T, C, G, Cref, ATC, ATCATC; + + @BeforeSuite + public void setup() { + // alleles + Aref = Allele.create("A", true); + Cref = Allele.create("C", true); + T = Allele.create("T"); + C = Allele.create("C"); + G = Allele.create("G"); + ATC = Allele.create("ATC"); + ATCATC = Allele.create("ATCATC"); + } + @Test public void testNormalizeDiploidLikelihoodMatrixFromLog10() { double[][] likelihoodMatrix = { @@ -263,4
+278,313 @@ public class LikelihoodCalculationEngineUnitTest extends BaseTest { } return true; } + + + private String arraysEq(int[] a, int[] b) { + if ( a.length != b.length ) { + return String.format("NEQ: %s | %s",Arrays.toString(a),Arrays.toString(b)); + } + for ( int idx = 0; idx < a.length; idx++) { + if ( a[idx] - b[idx] > 1 || b[idx] - a[idx] > 1) { + return String.format("NEQ: %s | %s",Arrays.toString(a),Arrays.toString(b)); + } + } + + return ""; + } + + private int[] _mleparse(List s) { + int[] mle = new int[s.size()]; + for ( int idx = 0; idx < mle.length; idx ++) { + mle[idx] = s.get(idx); + } + + return mle; + } + + private Genotype makeGwithPLs(String sample, Allele a1, Allele a2, double[] pls) { + Genotype gt = new GenotypeBuilder(sample, Arrays.asList(a1, a2)).PL(pls).make(); + if ( pls != null && pls.length > 0 ) { + Assert.assertNotNull(gt.getPL()); + Assert.assertTrue(gt.getPL().length > 0); + for ( int i : gt.getPL() ) { + Assert.assertTrue(i >= 0); + } + Assert.assertNotEquals(Arrays.toString(gt.getPL()),"[0]"); + } + return gt; + } + + private Genotype makeG(String sample, Allele a1, Allele a2) { + return GenotypeBuilder.create(sample, Arrays.asList(a1, a2)); + } + + private Genotype makeG(String sample, Allele a1, Allele a2, int... pls) { + return new GenotypeBuilder(sample, Arrays.asList(a1, a2)).PL(pls).make(); + } + + private VariantContext makeVC(String source, List alleles, Genotype... genotypes) { + int start = 10; + int stop = start; // alleles.contains(ATC) ? 
start + 3 : start; + return new VariantContextBuilder(source, "1", start, stop, alleles).genotypes(Arrays.asList(genotypes)).filters(null).make(); + } + + @Test + private void testCalculatePosteriorNoExternalData() { + VariantContext test1 = makeVC("1",Arrays.asList(Aref,T), makeG("s1",Aref,T,20,0,10), + makeG("s2",T,T,60,40,0), + makeG("s3",Aref,Aref,0,30,90)); + test1 = new VariantContextBuilder(test1).attribute(VCFConstants.MLE_ALLELE_COUNT_KEY,3).make(); + VariantContext test1result = LikelihoodCalculationEngine.calculatePosteriorGLs(test1, new ArrayList(), 0, 0.001, true, false, false); + Genotype test1exp1 = makeGwithPLs("s1",Aref,T,new double[]{-2.20686, -0.03073215, -1.20686}); + Assert.assertTrue(test1exp1.hasPL()); + Genotype test1exp2 = makeGwithPLs("s2",T,T,new double[]{-6.000066, -3.823938, -6.557894e-05}); + Genotype test1exp3 = makeGwithPLs("s3",Aref,Aref,new double[]{-0.0006510083, -2.824524, -9.000651}); + Assert.assertEquals("java.util.ArrayList",test1result.getGenotype(0).getAnyAttribute(VCFConstants.GENOTYPE_POSTERIORS_KEY).getClass().getCanonicalName()); + Assert.assertEquals(arraysEq(test1exp1.getPL(), _mleparse((List)test1result.getGenotype(0).getAnyAttribute(VCFConstants.GENOTYPE_POSTERIORS_KEY))), ""); + Assert.assertEquals(arraysEq(test1exp2.getPL(),_mleparse((List)test1result.getGenotype(1).getAnyAttribute(VCFConstants.GENOTYPE_POSTERIORS_KEY))), ""); + Assert.assertEquals(arraysEq(test1exp3.getPL(),_mleparse((List)test1result.getGenotype(2).getAnyAttribute(VCFConstants.GENOTYPE_POSTERIORS_KEY))), ""); + + // AA AB BB AC BC CC + // AA AC CC AT CT TT + VariantContext test2 = makeVC("2",Arrays.asList(Aref,C,T), + makeG("s1",Aref,T,30,10,60,0,15,90), + makeG("s2",Aref,C,40,0,10,30,40,80), + makeG("s3",Aref,Aref,0,5,8,15,20,40), + makeG("s4",C,T,80,40,12,20,0,10)); + test2 = new VariantContextBuilder(test2).attribute(VCFConstants.MLE_ALLELE_COUNT_KEY,new ArrayList(Arrays.asList(2,2))).make(); + VariantContext test2result = 
LikelihoodCalculationEngine.calculatePosteriorGLs(test2,new ArrayList(),5,0.001,true,false,false); + Genotype test2exp1 = makeGwithPLs("s1",Aref,T,new double[]{-2.647372, -1.045139, -6.823193, -0.04513873, -2.198182, -9.823193}); + Genotype test2exp2 = makeGwithPLs("s2",Aref,C,new double[]{-3.609957, -0.007723248, -1.785778, -3.007723, -4.660767, -8.785778}); + Genotype test2exp3 = makeGwithPLs("s3",Aref,Aref,new double[] {-0.06094877, -0.9587151, -2.03677,-1.958715, -3.111759, -5.23677}); + Genotype test2exp4 = makeGwithPLs("s4",C,T,new double[]{-7.016534, -3.4143, -1.392355, -1.4143, -0.06734388, -1.192355}); + Assert.assertEquals(arraysEq(test2exp1.getPL(),(int[]) _mleparse((List)test2result.getGenotype(0).getAnyAttribute(VCFConstants.GENOTYPE_POSTERIORS_KEY))), ""); + Assert.assertEquals(arraysEq(test2exp2.getPL(),(int[]) _mleparse((List)test2result.getGenotype(1).getAnyAttribute(VCFConstants.GENOTYPE_POSTERIORS_KEY))), ""); + Assert.assertEquals(arraysEq(test2exp3.getPL(),(int[]) _mleparse((List)test2result.getGenotype(2).getAnyAttribute(VCFConstants.GENOTYPE_POSTERIORS_KEY))), ""); + Assert.assertEquals(arraysEq(test2exp4.getPL(),(int[]) _mleparse((List)test2result.getGenotype(3).getAnyAttribute(VCFConstants.GENOTYPE_POSTERIORS_KEY))), ""); + } + + @Test + private void testCalculatePosteriorSamplePlusExternal() { + VariantContext testOverlappingBase = makeVC("1", Arrays.asList(Aref,T), makeG("s1",T,T,40,20,0), + makeG("s2",Aref,T,18,0,24), + makeG("s3",Aref,T,22,0,12)); + List supplTest1 = new ArrayList<>(3); + supplTest1.add(new VariantContextBuilder(makeVC("2",Arrays.asList(Aref,T))).attribute(VCFConstants.MLE_ALLELE_COUNT_KEY,2).attribute(VCFConstants.ALLELE_NUMBER_KEY,10).make()); + supplTest1.add(new VariantContextBuilder(makeVC("3",Arrays.asList(Aref,T))).attribute(VCFConstants.ALLELE_COUNT_KEY,4).attribute(VCFConstants.ALLELE_NUMBER_KEY,22).make()); + supplTest1.add(makeVC("4",Arrays.asList(Aref,T), + makeG("s_1",T,T), + makeG("s_2",Aref,T))); + 
VariantContext test1result = LikelihoodCalculationEngine.calculatePosteriorGLs(testOverlappingBase,supplTest1,0,0.001,true,false,false); + // the counts here are ref=30, alt=14 + Genotype test1exp1 = makeGwithPLs("t1",T,T,new double[]{-3.370985, -1.415172, -0.01721766}); + Genotype test1exp2 = makeGwithPLs("t2",Aref,T,new double[]{-1.763792, -0.007978791, -3.010024}); + Genotype test1exp3 = makeGwithPLs("t3",Aref,T,new double[]{-2.165587, -0.009773643, -1.811819}); + Assert.assertEquals(arraysEq(test1exp1.getPL(),_mleparse((List) test1result.getGenotype(0).getAnyAttribute(VCFConstants.GENOTYPE_POSTERIORS_KEY))), ""); + Assert.assertEquals(arraysEq(test1exp2.getPL(),_mleparse((List) test1result.getGenotype(1).getAnyAttribute(VCFConstants.GENOTYPE_POSTERIORS_KEY))), ""); + Assert.assertEquals(arraysEq(test1exp3.getPL(),_mleparse((List) test1result.getGenotype(2).getAnyAttribute(VCFConstants.GENOTYPE_POSTERIORS_KEY))), ""); + + VariantContext testNonOverlapping = makeVC("1", Arrays.asList(Aref,T), makeG("s1",T,T,3,1,0)); + List other = Arrays.asList(makeVC("2",Arrays.asList(Aref,C),makeG("s2",C,C,10,2,0))); + VariantContext test2result = LikelihoodCalculationEngine.calculatePosteriorGLs(testNonOverlapping,other,0,0.001,true,false,false); + Genotype test2exp1 = makeGwithPLs("SGV",T,T,new double[]{-4.078345, -3.276502, -0.0002661066}); + Assert.assertEquals(arraysEq(test2exp1.getPL(),_mleparse((List) test2result.getGenotype(0).getAnyAttribute(VCFConstants.GENOTYPE_POSTERIORS_KEY))), ""); + } + + private double[] pl2gl(int[] pl) { + double[] gl = new double[pl.length]; + for ( int idx = 0; idx < gl.length; idx++ ) { + gl[idx] = pl[idx]/(-10.0); + } + + return MathUtils.normalizeFromLog10(gl,true); + } + + @Test + private void testCalculatePosterior() { + int[][] likelihood_PLs = new int[][]{ + new int[]{3,0,3}, + new int[]{99,0,99}, + new int[]{50,20,0}, + new int[]{10,0,50}, + new int[]{80,60,0}, + new int[]{0,42,44}}; + + int[] altCounts = new int[]{10,40,90}; + int[] 
altAlleleNum = new int[]{100,500,1000}; + + double[] expected_post_10_100 = new double[] { + 9.250326e-03, 3.020208e-01, 6.887289e-01, + 7.693433e-12, 1.000000e+00, 5.728111e-10, + 1.340156e-07, 2.192982e-03, 9.978069e-01, + 6.073718e-03, 9.938811e-01, 4.522159e-05, + 1.343101e-10, 2.197802e-07, 9.999998e-01, + 9.960193e-01, 1.028366e-03, 2.952290e-03 + }; + + double[] expected_post_10_500 = new double[] { + 4.226647e-04, 7.513277e-02, 9.244446e-01, + 1.413080e-12, 1.000000e+00, 3.090662e-09, + 4.570232e-09, 4.071661e-04, 9.995928e-01, + 1.120916e-03, 9.986339e-01, 2.451646e-04, + 4.572093e-12, 4.073320e-08, 1.000000e+00, + 9.151689e-01, 5.144399e-03, 7.968675e-02 + }; + + double[] expected_post_10_1000 = new double[] { + 1.077685e-04, 3.870477e-02, 9.611875e-01, + 6.994030e-13, 1.000000e+00, 6.237975e-09, + 1.120976e-09, 2.017756e-04, 9.997982e-01, + 5.549722e-04, 9.989500e-01, 4.949797e-04, + 1.121202e-12, 2.018163e-08, 1.000000e+00, + 7.318346e-01, 8.311615e-03, 2.598538e-01 + }; + + double[] expected_post_40_100 = new double[] { + 1.102354e-01, 6.437516e-01, 2.460131e-01, + 4.301328e-11, 1.000000e+00, 9.599306e-11, + 4.422850e-06, 1.294493e-02, 9.870507e-01, + 3.303763e-02, 9.669550e-01, 7.373032e-06, + 4.480868e-09, 1.311474e-06, 9.999987e-01, + 9.997266e-01, 1.846199e-04, 8.882157e-05 + }; + + double[] expected_post_40_500 = new double[] { + 5.711785e-03, 2.557266e-01, 7.385617e-01, + 5.610428e-12, 1.000000e+00, 7.254558e-10, + 7.720262e-08, 1.732352e-03, 9.982676e-01, + 4.436495e-03, 9.955061e-01, 5.736604e-05, + 7.733659e-11, 1.735358e-07, 9.999998e-01, + 9.934793e-01, 1.406575e-03, 5.114153e-03 + }; + + double[] expected_post_40_1000 = new double[] { + 1.522132e-03, 1.422229e-01, 8.562549e-01, + 2.688330e-12, 1.000000e+00, 1.512284e-09, + 1.776184e-08, 8.317737e-04, 9.991682e-01, + 2.130611e-03, 9.977495e-01, 1.198547e-04, + 1.777662e-11, 8.324661e-08, 9.999999e-01, + 9.752770e-01, 2.881677e-03, 2.184131e-02 + }; + + double[] expected_post_90_100 = new 
double[] { + 6.887289e-01, 3.020208e-01, 9.250326e-03, + 5.728111e-10, 1.000000e+00, 7.693433e-12, + 6.394346e-04, 1.405351e-01, 8.588255e-01, + 3.127146e-01, 6.872849e-01, 4.200075e-07, + 7.445327e-07, 1.636336e-05, 9.999829e-01, + 9.999856e-01, 1.386699e-05, 5.346906e-07 + }; + + double[] expected_post_90_500 = new double[] { + 2.528165e-02, 4.545461e-01, 5.201723e-01, + 1.397100e-11, 1.000000e+00, 2.874546e-10, + 4.839050e-07, 4.360463e-03, 9.956391e-01, + 1.097551e-02, 9.890019e-01, 2.258221e-05, + 4.860244e-10, 4.379560e-07, 9.999996e-01, + 9.986143e-01, 5.677671e-04, 8.179741e-04 + }; + + double[] expected_post_90_1000 = new double[] { + 7.035938e-03, 2.807708e-01, 7.121932e-01, + 6.294627e-12, 1.000000e+00, 6.371561e-10, + 9.859771e-08, 1.971954e-03, 9.980279e-01, + 4.974874e-03, 9.949748e-01, 5.035678e-05, + 9.879252e-11, 1.975850e-07, 9.999998e-01, + 9.947362e-01, 1.255272e-03, 4.008518e-03 + }; + + double[][] expectations = new double[][] { + expected_post_10_100, + expected_post_10_500, + expected_post_10_1000, + expected_post_40_100, + expected_post_40_500, + expected_post_40_1000, + expected_post_90_100, + expected_post_90_500, + expected_post_90_1000 + }; + + int testIndex = 0; + for ( int altCount : altCounts ) { + for ( int numAlt : altAlleleNum ) { + double[] knownCounts = new double[2]; + knownCounts[0] = altCount; + knownCounts[1] = numAlt-altCount; + int expected_index = 0; + for ( int gl_index = 0; gl_index < likelihood_PLs.length; gl_index++ ) { + double[] post = LikelihoodCalculationEngine.calculatePosteriorGLs(pl2gl(likelihood_PLs[gl_index]), knownCounts, 2); + for ( int i = 0; i < post.length; i++ ) { + double expected = expectations[testIndex][expected_index++]; + double observed = Math.pow(10.0,post[i]); + double err = Math.abs( (expected-observed)/expected ); + Assert.assertTrue(err < 1e-4, String.format("Counts: %s | Expected: %e | Observed: %e | pre %s | prior %s | post %s", + Arrays.toString(knownCounts), expected,observed, 
Arrays.toString(pl2gl(likelihood_PLs[gl_index])), + Arrays.toString(LikelihoodCalculationEngine.getDirichletPrior(knownCounts,2)),Arrays.toString(post))); + } + } + testIndex++; + } + } + } + + private boolean arraysApproxEqual(double[] a, double[] b, double tol) { + if ( a.length != b.length ) { + return false; + } + + for ( int idx = 0; idx < a.length; idx++ ) { + if ( Math.abs(a[idx]-b[idx]) > tol ) { + return false; + } + } + + return true; + } + + private String errMsgArray(double[] a, double[] b) { + return String.format("Expected %s, Observed %s", Arrays.toString(a), Arrays.toString(b)); + } + + @Test + private void testPosteriorMultiAllelic() { + // AA AB BB AC BC CC AD BD CD DD + int[] PL_one = new int[] {40,20,30,0,15,25}; + int[] PL_two = new int[] {0,20,10,99,99,99}; + int[] PL_three = new int[] {50,40,0,30,30,10,20,40,80,50}; + int[] PL_four = new int[] {99,90,85,10,5,30,40,20,40,30,0,12,20,14,5}; + int[] PL_five = new int[] {60,20,30,0,40,10,8,12,18,22,40,12,80,60,20}; + double[] counts_one = new double[]{100.001,40.001,2.001}; + double[] counts_two = new double[]{2504.001,16.001,218.001}; + double[] counts_three = new double[]{10000.001,500.001,25.001,0.001}; + double[] counts_four = new double[]{4140.001,812.001,32.001,104.001,12.001}; + double[] counts_five = new double[]{80.001,40.001,8970.001,200.001,1922.001}; + + double expected_one[] = new double[] { -2.684035, -0.7852596, -2.4735, -0.08608339, -1.984017, -4.409852 }; + double expected_two[] = new double[] { -5.736189e-05, -3.893688, -5.362878, -10.65938, -12.85386, -12.0186}; + double expected_three[] = new double[] {-2.403234, -2.403276, -0.004467802, -2.70429, -4.005319, -3.59033, -6.102247, -9.403276, -14.70429, -13.40284}; + double expected_four[] = new double[] {-7.828677, -7.335196, -7.843136, -0.7395892, -0.947033, -5.139092, -3.227715, + -1.935159, -5.339552, -4.124552, -0.1655353, -2.072979, -4.277372, -3.165498, -3.469589 }; + double expected_five[] = new double[] { -9.170334, 
-5.175724, -6.767055, -0.8250021, -5.126027, -0.07628661, -3.276762, + -3.977787, -2.227065, -4.57769, -5.494041, -2.995066, -7.444344, -7.096104, -2.414187}; + + double[] post1 = LikelihoodCalculationEngine.calculatePosteriorGLs(pl2gl(PL_one),counts_one,2); + double[] post2 = LikelihoodCalculationEngine.calculatePosteriorGLs(pl2gl(PL_two),counts_two,2); + double[] post3 = LikelihoodCalculationEngine.calculatePosteriorGLs(pl2gl(PL_three),counts_three,2); + double[] post4 = LikelihoodCalculationEngine.calculatePosteriorGLs(pl2gl(PL_four),counts_four,2); + double[] post5 = LikelihoodCalculationEngine.calculatePosteriorGLs(pl2gl(PL_five),counts_five,2); + + double[] expecPrior5 = new double[] {-4.2878195, -4.2932090, -4.8845400, -1.9424874, -2.2435120, -0.1937719, -3.5942477, + -3.8952723, -1.5445506, -3.4951749, -2.6115263, -2.9125508, -0.5618292, -2.2135895, + -1.5316722}; + + Assert.assertTrue(arraysApproxEqual(expecPrior5, LikelihoodCalculationEngine.getDirichletPrior(counts_five,2),1e-5),errMsgArray(expecPrior5,LikelihoodCalculationEngine.getDirichletPrior(counts_five,2))); + + Assert.assertTrue(arraysApproxEqual(expected_one,post1,1e-6),errMsgArray(expected_one,post1)); + Assert.assertTrue(arraysApproxEqual(expected_two,post2,1e-5),errMsgArray(expected_two,post2)); + Assert.assertTrue(arraysApproxEqual(expected_three,post3,1e-5),errMsgArray(expected_three,post3)); + Assert.assertTrue(arraysApproxEqual(expected_four,post4,1e-5),errMsgArray(expected_four,post4)); + Assert.assertTrue(arraysApproxEqual(expected_five,post5,1e-5),errMsgArray(expected_five,post5)); + } } diff --git a/public/java/src/org/broadinstitute/sting/utils/MathUtils.java b/public/java/src/org/broadinstitute/sting/utils/MathUtils.java index bfae7e94c..8697fcab6 100644 --- a/public/java/src/org/broadinstitute/sting/utils/MathUtils.java +++ b/public/java/src/org/broadinstitute/sting/utils/MathUtils.java @@ -1472,4 +1472,34 @@ public class MathUtils { return 
sliceListByIndices(sampleIndicesWithoutReplacement(list.size(),N),list); }
+    /**
+     * Dirichlet-multinomial likelihood of a vector of category counts: the likelihood of observing
+     * the counts when the categorical frequencies of the sampled population follow a Dirichlet
+     * distribution. The value is assembled from log10Gamma and log10MultinomialCoefficient terms,
+     * so it is log10-scaled.
+     * @param dirichletParams - params of the prior dirichlet distribution
+     * @param dirichletSum - the sum of those parameters
+     * @param counts - the counts of observation in each category
+     * @param countSum - the sum of counts (number of trials)
+     * @return - associated likelihood (log10)
+     * @throws IllegalStateException if dirichletParams and counts differ in length
+     */
+    public static double dirichletMultinomial(final double[] dirichletParams, final double dirichletSum,
+                                              final int[] counts, final int countSum) {
+        if ( dirichletParams.length != counts.length ) {
+            throw new IllegalStateException("The number of dirichlet parameters must match the number of categories");
+        }
+        // todo -- lots of lnGammas here. At some point we can safely switch to x * ( ln(x) - 1)
+        double log10Likelihood = log10MultinomialCoefficient(countSum,counts)
+                + log10Gamma(dirichletSum)
+                - log10Gamma(dirichletSum+countSum);
+        final int numCategories = counts.length;
+        for ( int category = 0; category < numCategories; category++ ) {
+            final double alpha = dirichletParams[category];
+            log10Likelihood = log10Likelihood + log10Gamma(counts[category] + alpha) - log10Gamma(alpha);
+        }
+
+        return log10Likelihood;
+    }
+
+    /**
+     * Convenience overload that derives both sums from the arrays themselves.
+     * @param params - params of the prior dirichlet distribution
+     * @param counts - the counts of observation in each category
+     * @return - associated likelihood (log10)
+     */
+    public static double dirichletMultinomial(double[] params, int[] counts) {
+        final double paramSum = sum(params);
+        final int totalCount = (int) sum(counts);
+        return dirichletMultinomial(params,paramSum,counts,totalCount);
+    }
+
 }
diff --git a/public/java/src/org/broadinstitute/sting/utils/Utils.java b/public/java/src/org/broadinstitute/sting/utils/Utils.java index 75bd6a3d1..5cb141074 100644 --- a/public/java/src/org/broadinstitute/sting/utils/Utils.java +++ b/public/java/src/org/broadinstitute/sting/utils/Utils.java @@ -835,4 +835,18 @@ public class Utils { // don't perform array copies if we need to copy everything anyways return ( trimFromFront == 0 && trimFromBack == 0 ) ?
seq : Arrays.copyOfRange(seq, trimFromFront, seq.length - trimFromBack); }
+
+    /**
+     * Simple wrapper for sticking elements of a int[] array into a List
+     * @param ar - the array whose elements should be listified; must not be null
+     * @return - a new, mutable List where each element has the same value as the corresponding index in @ar
+     * @throws NullPointerException if ar is null
+     */
+    public static List<Integer> listFromPrimitives(final int[] ar) {
+        // presized to the array length; each int is boxed on insertion
+        final List<Integer> lst = new ArrayList<>(ar.length);
+        for ( final int d : ar ) {
+            lst.add(d);
+        }
+
+        return lst;
+    }
 }
diff --git a/public/java/src/org/broadinstitute/sting/utils/variant/GATKVariantContextUtils.java b/public/java/src/org/broadinstitute/sting/utils/variant/GATKVariantContextUtils.java index 11cd27a9f..03bb9763c 100644 --- a/public/java/src/org/broadinstitute/sting/utils/variant/GATKVariantContextUtils.java +++ b/public/java/src/org/broadinstitute/sting/utils/variant/GATKVariantContextUtils.java @@ -565,11 +565,11 @@ public class GATKVariantContextUtils { * @param newLikelihoods a vector of likelihoods to use if the method requires PLs, should be log10 likelihoods, cannot be null * @param allelesToUse the alleles we are using for our subsetting */ - protected static void updateGenotypeAfterSubsetting(final List originalGT, - final GenotypeBuilder gb, - final GenotypeAssignmentMethod assignmentMethod, - final double[] newLikelihoods, - final List allelesToUse) { + public static void updateGenotypeAfterSubsetting(final List originalGT, + final GenotypeBuilder gb, + final GenotypeAssignmentMethod assignmentMethod, + final double[] newLikelihoods, + final List allelesToUse) { gb.noAD(); switch ( assignmentMethod ) { case SET_TO_NO_CALL: diff --git a/public/java/test/org/broadinstitute/sting/utils/MathUtilsUnitTest.java b/public/java/test/org/broadinstitute/sting/utils/MathUtilsUnitTest.java index a13797523..de049fe89 100644 --- a/public/java/test/org/broadinstitute/sting/utils/MathUtilsUnitTest.java +++ b/public/java/test/org/broadinstitute/sting/utils/MathUtilsUnitTest.java @@ -522,4
+522,338 @@ public class MathUtilsUnitTest extends BaseTest { final Comparable actual = MathUtils.median(values); Assert.assertEquals(actual, expected, "Failed with " + values); } + + + + // man. All this to test dirichlet. + + private double[] unwrap(List stuff) { + double[] unwrapped = new double[stuff.size()]; + int idx = 0; + for ( Double d : stuff ) { + unwrapped[idx++] = d == null ? 0.0 : d; + } + + return unwrapped; + } + + /** + * The PartitionGenerator generates all of the partitions of a number n, e.g. + * 5 + 0 + * 4 + 1 + * 3 + 2 + * 3 + 1 + 1 + * 2 + 2 + 1 + * 2 + 1 + 1 + 1 + * 1 + 1 + 1 + 1 + 1 + * + * This is used to help enumerate the state space over which the Dirichlet-Multinomial is defined, + * to ensure that the distribution function is properly implemented + */ + class PartitionGenerator implements Iterator> { + // generate the partitions of an integer, each partition sorted numerically + int n; + List a; + int y; + int k; + int state; + int x; + int l; + + public PartitionGenerator(int n) { + this.n = n; + this.y = n - 1; + this.k = 1; + this.a = new ArrayList(); + for ( int i = 0; i < n; i++ ) { + this.a.add(i); + } + this.state = 0; + } + + public void remove() { /* do nothing */ } + + public boolean hasNext() { return ! 
( this.k == 0 && state == 0 ); } + + private String dataStr() { + return String.format("a = [%s] k = %d y = %d state = %d x = %d l = %d", + Utils.join(",",a), k, y, state, x, l); + } + + public List next() { + if ( this.state == 0 ) { + this.x = a.get(k-1)+1; + k -= 1; + this.state = 1; + } + + if ( this.state == 1 ) { + while ( 2*x <= y ) { + this.a.set(k,x); + this.y -= x; + this.k++; + } + this.l = 1+this.k; + this.state = 2; + } + + if ( this.state == 2 ) { + if ( x <= y ) { + this.a.set(k,x); + this.a.set(l,y); + x += 1; + y -= 1; + return this.a.subList(0, this.k + 2); + } else { + this.state =3; + } + } + + if ( this.state == 3 ) { + this.a.set(k,x+y); + this.y = x + y - 1; + this.state = 0; + return a.subList(0, k + 1); + } + + throw new IllegalStateException("Cannot get here"); + } + + public String toString() { + StringBuffer buf = new StringBuffer(); + buf.append("{ "); + while ( hasNext() ) { + buf.append("["); + buf.append(Utils.join(",",next())); + buf.append("],"); + } + buf.deleteCharAt(buf.lastIndexOf(",")); + buf.append(" }"); + return buf.toString(); + } + + } + + /** + * NextCounts is the enumerator over the state space of the multinomial dirichlet. + * + * It filters the partition of the total sum to only those with a number of terms + * equal to the number of categories. + * + * It then generates all permutations of that partition. + * + * In so doing it enumerates over the full state space. 
+ */ + class NextCounts implements Iterator { + + private PartitionGenerator partitioner; + private int numCategories; + private int[] next; + + public NextCounts(int numCategories, int totalCounts) { + partitioner = new PartitionGenerator(totalCounts); + this.numCategories = numCategories; + next = nextFromPartitioner(); + } + + public void remove() { /* do nothing */ } + + public boolean hasNext() { return next != null; } + + public int[] next() { + int[] toReturn = clone(next); + next = nextPermutation(); + if ( next == null ) { + next = nextFromPartitioner(); + } + + return toReturn; + } + + private int[] clone(int[] arr) { + int[] a = new int[arr.length]; + for ( int idx = 0; idx < a.length ; idx ++) { + a[idx] = arr[idx]; + } + + return a; + } + + private int[] nextFromPartitioner() { + if ( partitioner.hasNext() ) { + List nxt = partitioner.next(); + while ( partitioner.hasNext() && nxt.size() > numCategories ) { + nxt = partitioner.next(); + } + + if ( nxt.size() > numCategories ) { + return null; + } else { + int[] buf = new int[numCategories]; + for ( int idx = 0; idx < nxt.size(); idx++ ) { + buf[idx] = nxt.get(idx); + } + Arrays.sort(buf); + return buf; + } + } + + return null; + } + + public int[] nextPermutation() { + return MathUtilsUnitTest.nextPermutation(next); + } + + } + + public static int[] nextPermutation(int[] next) { + // the counts can swap among each other. 
The int[] is originally in ascending order + // this generates the next array in lexicographic order descending + + // locate the last occurrence where next[k] < next[k+1] + int gt = -1; + for ( int idx = 0; idx < next.length-1; idx++) { + if ( next[idx] < next[idx+1] ) { + gt = idx; + } + } + + if ( gt == -1 ) { + return null; + } + + int largestLessThan = gt+1; + for ( int idx = 1 + largestLessThan; idx < next.length; idx++) { + if ( next[gt] < next[idx] ) { + largestLessThan = idx; + } + } + + int val = next[gt]; + next[gt] = next[largestLessThan]; + next[largestLessThan] = val; + + // reverse the tail of the array + int[] newTail = new int[next.length-gt-1]; + int ctr = 0; + for ( int idx = next.length-1; idx > gt; idx-- ) { + newTail[ctr++] = next[idx]; + } + + for ( int idx = 0; idx < newTail.length; idx++) { + next[gt+idx+1] = newTail[idx]; + } + + return next; + } + + + // before testing the dirichlet multinomial, we need to test the + // classes used to test the dirichlet multinomial + + @Test + public void testPartitioner() { + int[] numsToTest = new int[]{1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20}; + int[] expectedSizes = new int[]{1, 2, 3, 5, 7, 11, 15, 22, 30, 42, 56, 77, 101, 135, 176, 231, 297, 385, 490, 627}; + for ( int testNum = 0; testNum < numsToTest.length; testNum++ ) { + PartitionGenerator gen = new PartitionGenerator(numsToTest[testNum]); + int size = 0; + while ( gen.hasNext() ) { + logger.debug(gen.dataStr()); + size += 1; + gen.next(); + } + Assert.assertEquals(size,expectedSizes[testNum], + String.format("Expected %d partitions, observed %s",expectedSizes[testNum],new PartitionGenerator(numsToTest[testNum]).toString())); + } + } + + @Test + public void testNextPermutation() { + int[] arr = new int[]{1,2,3,4}; + int[][] gens = new int[][] { + new int[]{1,2,3,4}, + new int[]{1,2,4,3}, + new int[]{1,3,2,4}, + new int[]{1,3,4,2}, + new int[]{1,4,2,3}, + new int[]{1,4,3,2}, + new int[]{2,1,3,4}, + new int[]{2,1,4,3}, + new 
int[]{2,3,1,4}, + new int[]{2,3,4,1}, + new int[]{2,4,1,3}, + new int[]{2,4,3,1}, + new int[]{3,1,2,4}, + new int[]{3,1,4,2}, + new int[]{3,2,1,4}, + new int[]{3,2,4,1}, + new int[]{3,4,1,2}, + new int[]{3,4,2,1}, + new int[]{4,1,2,3}, + new int[]{4,1,3,2}, + new int[]{4,2,1,3}, + new int[]{4,2,3,1}, + new int[]{4,3,1,2}, + new int[]{4,3,2,1} }; + for ( int gen = 0; gen < gens.length; gen ++ ) { + for ( int idx = 0; idx < 3; idx++ ) { + Assert.assertEquals(arr[idx],gens[gen][idx], + String.format("Error at generation %d, expected %s, observed %s",gen,Arrays.toString(gens[gen]),Arrays.toString(arr))); + } + arr = nextPermutation(arr); + } + } + + private double[] addEpsilon(double[] counts) { + double[] d = new double[counts.length]; + for ( int i = 0; i < counts.length; i ++ ) { + d[i] = counts[i] + 1e-3; + } + return d; + } + + @Test + public void testDirichletMultinomial() { + List testAlleles = Arrays.asList( + new double[]{80,240}, + new double[]{1,10000}, + new double[]{0,500}, + new double[]{5140,20480}, + new double[]{5000,800,200}, + new double[]{6,3,1000}, + new double[]{100,400,300,800}, + new double[]{8000,100,20,80,2}, + new double[]{90,20000,400,20,4,1280,720,1} + ); + + Assert.assertTrue(! Double.isInfinite(MathUtils.log10Gamma(1e-3)) && ! Double.isNaN(MathUtils.log10Gamma(1e-3))); + + int[] numAlleleSampled = new int[]{2,5,10,20,25}; + for ( double[] alleles : testAlleles ) { + for ( int count : numAlleleSampled ) { + // test that everything sums to one. Generate all multinomial draws + List likelihoods = new ArrayList(100000); + NextCounts generator = new NextCounts(alleles.length,count); + double maxLog = Double.MIN_VALUE; + //List countLog = new ArrayList(200); + while ( generator.hasNext() ) { + int[] thisCount = generator.next(); + //countLog.add(Arrays.toString(thisCount)); + Double likelihood = MathUtils.dirichletMultinomial(addEpsilon(alleles),thisCount); + Assert.assertTrue(! Double.isNaN(likelihood) && ! 
Double.isInfinite(likelihood), + String.format("Likelihood for counts %s and nAlleles %d was %s", + Arrays.toString(thisCount),alleles.length,Double.toString(likelihood))); + if ( likelihood > maxLog ) + maxLog = likelihood; + likelihoods.add(likelihood); + } + //System.out.printf("%d likelihoods and max is (probability) %e\n",likelihoods.size(),Math.pow(10,maxLog)); + Assert.assertEquals(MathUtils.sumLog10(unwrap(likelihoods)),1.0,1e-7, + String.format("Counts %d and alleles %d have nLikelihoods %d. \n Counts: %s", + count,alleles.length,likelihoods.size(), "NODEBUG"/*,countLog*/)); + } + } + } } diff --git a/public/java/test/org/broadinstitute/sting/utils/variant/GATKVariantContextUtilsUnitTest.java b/public/java/test/org/broadinstitute/sting/utils/variant/GATKVariantContextUtilsUnitTest.java index 220e64f7d..575fe4936 100644 --- a/public/java/test/org/broadinstitute/sting/utils/variant/GATKVariantContextUtilsUnitTest.java +++ b/public/java/test/org/broadinstitute/sting/utils/variant/GATKVariantContextUtilsUnitTest.java @@ -32,6 +32,7 @@ import org.broadinstitute.sting.utils.MathUtils; import org.broadinstitute.sting.utils.Utils; import org.broadinstitute.sting.utils.collections.Pair; import org.broadinstitute.variant.variantcontext.*; +import org.broadinstitute.variant.vcf.VCFConstants; import org.testng.Assert; import org.testng.annotations.BeforeSuite; import org.testng.annotations.DataProvider; @@ -56,11 +57,7 @@ public class GATKVariantContextUtilsUnitTest extends BaseTest { ATCATC = Allele.create("ATCATC"); } - private Genotype makeG(String sample, Allele a1, Allele a2) { - return GenotypeBuilder.create(sample, Arrays.asList(a1, a2)); - } - - private Genotype makeG(String sample, Allele a1, Allele a2, double log10pError, double... pls) { + private Genotype makeG(String sample, Allele a1, Allele a2, double log10pError, int... pls) { return new GenotypeBuilder(sample, Arrays.asList(a1, a2)).log10PError(log10pError).PL(pls).make(); }