diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/genotyper/ExactAFCalculationModel.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/genotyper/ExactAFCalculationModel.java index 5d0b6f0a7..ae7c2f5c1 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/genotyper/ExactAFCalculationModel.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/genotyper/ExactAFCalculationModel.java @@ -28,6 +28,7 @@ package org.broadinstitute.sting.gatk.walkers.genotyper; import org.apache.log4j.Logger; import org.broadinstitute.sting.utils.MathUtils; import org.broadinstitute.sting.utils.Utils; +import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; import org.broadinstitute.sting.utils.exceptions.UserException; import org.broadinstitute.sting.utils.variantcontext.*; @@ -44,8 +45,12 @@ public class ExactAFCalculationModel extends AlleleFrequencyCalculationModel { private final static double SUM_GL_THRESH_NOCALL = -0.001; // if sum(gl) is bigger than this threshold, we treat GL's as non-informative and will force a no-call. private final List NO_CALL_ALLELES = Arrays.asList(Allele.NO_CALL, Allele.NO_CALL); + private final boolean USE_MULTI_ALLELIC_CALCULATION; + + protected ExactAFCalculationModel(UnifiedArgumentCollection UAC, int N, Logger logger, PrintStream verboseWriter) { super(UAC, N, logger, verboseWriter); + USE_MULTI_ALLELIC_CALCULATION = UAC.MULTI_ALLELIC; } public void getLog10PNonRef(GenotypesContext GLs, List alleles, @@ -60,9 +65,9 @@ public class ExactAFCalculationModel extends AlleleFrequencyCalculationModel { for (int k=1; k < numAlleles; k++) { // multi-allelic approximation, part 1: Ideally // for each alt allele compute marginal (suboptimal) posteriors - - // compute indices for AA,AB,BB for current allele - genotype likelihoods are a linear vector that can be thought of - // as a row-wise upper triangular matrix of likelihoods. - // So, for example, with 2 alt alleles, likelihoods have AA,AB,AC,BB,BC,CC. + // compute indices for AA,AB,BB for current allele - genotype log10Likelihoods are a linear vector that can be thought of + // as a row-wise upper triangular matrix of log10Likelihoods. + // So, for example, with 2 alt alleles, log10Likelihoods have AA,AB,AC,BB,BC,CC. // 3 alt alleles: AA,AB,AC,AD BB BC BD CC CD DD final int idxAA = 0; @@ -74,7 +79,9 @@ public class ExactAFCalculationModel extends AlleleFrequencyCalculationModel { final int idxBB = idxDiag; idxDiag += incr--; - final int lastK = linearExact(GLs, log10AlleleFrequencyPriors, log10AlleleFrequencyPosteriors, idxAA, idxAB, idxBB); + final int lastK = USE_MULTI_ALLELIC_CALCULATION ? + linearExactMultiAllelic(GLs, numAlleles - 1, log10AlleleFrequencyPriors, log10AlleleFrequencyPosteriors, false) : + linearExact(GLs, log10AlleleFrequencyPriors, log10AlleleFrequencyPosteriors, idxAA, idxAB, idxBB); if (numAlleles > 2) { posteriorCache[k-1] = log10AlleleFrequencyPosteriors.clone(); @@ -221,6 +228,16 @@ public class ExactAFCalculationModel extends AlleleFrequencyCalculationModel { return lastK; } + final static double approximateLog10SumLog10(double[] vals) { + if ( vals.length < 2 ) + throw new ReviewedStingException("Passing array with fewer than 2 values when computing approximateLog10SumLog10"); + + double approx = approximateLog10SumLog10(vals[0], vals[1]); + for ( int i = 2; i < vals.length; i++ ) + approx = approximateLog10SumLog10(approx, vals[i]); + return approx; + } + final static double approximateLog10SumLog10(double a, double b, double c) { //return softMax(new double[]{a, b, c}); return approximateLog10SumLog10(approximateLog10SumLog10(a, b), c); @@ -256,6 +273,237 @@ public class ExactAFCalculationModel extends AlleleFrequencyCalculationModel { } + // ------------------------------------------------------------------------------------- + // + // Multi-allelic implementation. + // + // ------------------------------------------------------------------------------------- + + private static final int HOM_REF_INDEX = 0; // AA likelihoods are always first + private static final int AC_ZERO_INDEX = 0; // ExactACset index for k=0 over all k + + // This class represents a column in the Exact AC calculation matrix + private static final class ExactACset { + final int[] ACcounts; + final double[] log10Likelihoods; + final HashMap ACsetIndexToPLIndex = new HashMap(); + final ArrayList dependentACsetsToDelete = new ArrayList(); + + private int index = -1; + + public ExactACset(int size, int[] ACcounts) { + this.ACcounts = ACcounts; + log10Likelihoods = new double[size]; + } + + public int getIndex() { + if ( index == -1 ) + index = generateIndex(ACcounts, log10Likelihoods.length); + return index; + } + + public static int generateIndex(int[] ACcounts, int multiplier) { + int index = 0; + for ( int i = 0; i < ACcounts.length; i++ ) + index += Math.pow(multiplier, i) * ACcounts[i]; + return index; + } + + public int getACsum() { + int sum = 0; + for ( int count : ACcounts ) + sum += count; + return sum; + } + } + + public int linearExactMultiAllelic(GenotypesContext GLs, + int numAlternateAlleles, + double[] log10AlleleFrequencyPriors, + double[] log10AlleleFrequencyPosteriors, + boolean preserveData) { + + final ArrayList genotypeLikelihoods = getGLs(GLs); + final int numSamples = genotypeLikelihoods.size()-1; + final int numChr = 2*numSamples; + + // queue of AC conformations to process + final Queue ACqueue = new LinkedList(); + + // mapping of ExactACset indexes to the objects + final HashMap indexesToACset = new HashMap(numChr+1); + + // add AC=0 to the queue + int[] zeroCounts = new int[numAlternateAlleles]; + ExactACset zeroSet = new ExactACset(numSamples+1, zeroCounts); + ACqueue.add(zeroSet); + indexesToACset.put(0, zeroSet); + + // keep processing while we have AC conformations that need to be calculated + double maxLog10L = Double.NEGATIVE_INFINITY; + while ( !ACqueue.isEmpty() ) { + // compute log10Likelihoods + final ExactACset set = ACqueue.remove(); + final double log10LofKs = calculateAlleleCountConformation(set, genotypeLikelihoods, maxLog10L, numChr, preserveData, ACqueue, indexesToACset, log10AlleleFrequencyPosteriors, log10AlleleFrequencyPriors); + + // adjust max likelihood seen if needed + maxLog10L = Math.max(maxLog10L, log10LofKs); + } + + // TODO -- finish me + + return 0; + } + + private static double calculateAlleleCountConformation(final ExactACset set, + final ArrayList genotypeLikelihoods, + final double maxLog10L, + final int numChr, + final boolean preserveData, + final Queue ACqueue, + final HashMap indexesToACset, + double[] log10AlleleFrequencyPriors, + double[] log10AlleleFrequencyPosteriors) { + + // compute the log10Likelihoods + computeLofK(set, genotypeLikelihoods, indexesToACset, log10AlleleFrequencyPosteriors, log10AlleleFrequencyPriors); + + // clean up memory + if ( !preserveData ) { + for ( int index : set.dependentACsetsToDelete ) + indexesToACset.put(index, null); + } + + final double log10LofK = set.log10Likelihoods[set.log10Likelihoods.length-1]; + + // can we abort early because the log10Likelihoods are so small? + if ( log10LofK < maxLog10L - MAX_LOG10_ERROR_TO_STOP_EARLY ) { + if ( DEBUG ) System.out.printf(" *** breaking early ks=%d log10L=%.2f maxLog10L=%.2f%n", set.index, log10LofK, maxLog10L); + return log10LofK; + } + + // iterate over higher frequencies if possible + int ACwiggle = numChr - set.getACsum(); + if ( ACwiggle == 0 ) // all alternate alleles already sum to 2N + return log10LofK; + + ExactACset lastSet = null; + int numAltAlleles = set.ACcounts.length; + + // genotype log10Likelihoods are a linear vector that can be thought of as a row-wise upper triangular matrix of log10Likelihoods. + // So e.g. with 2 alt alleles the log10Likelihoods are AA,AB,AC,BB,BC,CC and with 3 alt alleles they are AA,AB,AC,AD,BB,BC,BD,CC,CD,DD. + + // do it for the k+1 case + int PLindex = 0; + for ( int allele = 0; allele < numAltAlleles; allele++ ) { + int[] ACcountsClone = set.ACcounts.clone(); + ACcountsClone[allele]++; + lastSet = updateACset(ACcountsClone, numChr, set.getIndex(), ++PLindex, ACqueue, indexesToACset); + } + + // do it for the k+2 case if it makes sense; note that the 2 alleles may be the same or different + if ( ACwiggle > 1 ) { + for ( int allele_i = 0; allele_i < numAltAlleles; allele_i++ ) { + for ( int allele_j = allele_i; allele_j < numAltAlleles; allele_j++ ) { + int[] ACcountsClone = set.ACcounts.clone(); + ACcountsClone[allele_i]++; + ACcountsClone[allele_j]++; + lastSet = updateACset(ACcountsClone, numChr,set.getIndex(), ++PLindex , ACqueue, indexesToACset); + } + } + } + + if ( lastSet == null ) + throw new ReviewedStingException("No new AC sets were added or updated but the AC still hasn't reached 2N"); + lastSet.dependentACsetsToDelete.add(set.index); + + return log10LofK; + } + + private static ExactACset updateACset(int[] ACcounts, + int numChr, + final int callingSetIndex, + final int PLsetIndex, + final Queue ACqueue, + final HashMap indexesToACset) { + final int index = ExactACset.generateIndex(ACcounts, numChr+1); + if ( !indexesToACset.containsKey(index) ) { + ExactACset set = new ExactACset(numChr/2 +1, ACcounts); + indexesToACset.put(index, set); + ACqueue.add(set); + } + + // add the given dependency to the set + ExactACset set = indexesToACset.get(index); + set.ACsetIndexToPLIndex.put(callingSetIndex, PLsetIndex); + return set; + } + + private static void computeLofK(ExactACset set, + ArrayList genotypeLikelihoods, + final HashMap indexesToACset, + double[] log10AlleleFrequencyPriors, + double[] log10AlleleFrequencyPosteriors) { + + set.log10Likelihoods[0] = 0.0; // the zero case + int totalK = set.getACsum(); + + // special case for k = 0 over all k + if ( set.getIndex() == AC_ZERO_INDEX ) { + for ( int j = 1; j < set.log10Likelihoods.length; j++ ) + set.log10Likelihoods[j] = set.log10Likelihoods[j-1] + genotypeLikelihoods.get(j)[HOM_REF_INDEX]; + } + // k > 0 for at least one k + else { + // all possible likelihoods for a given cell from which to choose the max + final int numPaths = set.ACsetIndexToPLIndex.size() + 1; + final double[] log10ConformationLikelihoods = new double[numPaths]; + + for ( int j = 1; j < set.log10Likelihoods.length; j++ ) { + final double[] gl = genotypeLikelihoods.get(j); + final double logDenominator = MathUtils.log10Cache[2*j] + MathUtils.log10Cache[2*j-1]; + + for ( int i = 0; i < numPaths; i++ ) + log10ConformationLikelihoods[i] = Double.NEGATIVE_INFINITY; + + // deal with the AA case first + if ( totalK < 2*j-1 ) + log10ConformationLikelihoods[0] = MathUtils.log10Cache[2*j-totalK] + MathUtils.log10Cache[2*j-totalK-1] + set.log10Likelihoods[j-1] + gl[HOM_REF_INDEX]; + + // deal with the other possible conformations now + if ( totalK < 2*j ) { + int conformationIndex = 1; + for ( Map.Entry mapping : set.ACsetIndexToPLIndex.entrySet() ) + log10ConformationLikelihoods[conformationIndex++] = + determineCoefficient(mapping.getValue(), j, totalK) + indexesToACset.get(mapping.getKey()).log10Likelihoods[j-1] + gl[mapping.getValue()]; + } + + double log10Max = approximateLog10SumLog10(log10ConformationLikelihoods); + + // finally, update the L(j,k) value + set.log10Likelihoods[j] = log10Max - logDenominator; + } + } + + // update the posteriors vector + final double log10LofK = set.log10Likelihoods[set.log10Likelihoods.length-1]; + + // TODO -- this needs to be fixed; hard-coding in the biallelic case + log10AlleleFrequencyPosteriors[totalK] = log10LofK + log10AlleleFrequencyPriors[totalK]; + } + + private static double determineCoefficient(int PLindex, int j, int totalK) { + + // TODO -- the math here needs to be fixed and checked; hard-coding in the biallelic case + //AA,AB,AC,AD,BB,BC,BD,CC,CD,DD. + + double coeff; + if ( PLindex == 1 ) + coeff = MathUtils.log10Cache[2*totalK] + MathUtils.log10Cache[2*j-totalK]; + else + coeff = MathUtils.log10Cache[totalK] + MathUtils.log10Cache[totalK-1]; + return coeff; + } /** * Can be overridden by concrete subclasses diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/genotyper/UnifiedArgumentCollection.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/genotyper/UnifiedArgumentCollection.java index 62218416d..d7101da6b 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/genotyper/UnifiedArgumentCollection.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/genotyper/UnifiedArgumentCollection.java @@ -153,6 +153,10 @@ public class UnifiedArgumentCollection { @Argument(fullName = "ignoreSNPAlleles", shortName = "ignoreSNPAlleles", doc = "expt", required = false) public boolean IGNORE_SNP_ALLELES = false; + @Hidden + @Argument(fullName = "multiallelic", shortName = "multiallelic", doc = "Allow multiple alleles in discovery", required = false) + public boolean MULTI_ALLELIC = false; + // Developers must remember to add any newly added arguments to the list here as well otherwise they won't get changed from their default value! public UnifiedArgumentCollection clone() { @@ -180,6 +184,7 @@ public class UnifiedArgumentCollection { // todo- arguments to remove uac.IGNORE_SNP_ALLELES = IGNORE_SNP_ALLELES; uac.BANDED_INDEL_COMPUTATION = BANDED_INDEL_COMPUTATION; + uac.MULTI_ALLELIC = MULTI_ALLELIC; return uac; }