From 255b036fb50b136ccba421f90e07f434c703239d Mon Sep 17 00:00:00 2001 From: rpoplin Date: Thu, 1 Jul 2010 18:51:07 +0000 Subject: [PATCH] Variant Recalibrator MLE EM algorithm is moved over to variational Bayes EM in order to eliminate problems with singularities when clustering in higher than two dimensions. Because of this there is no longer a number of Gaussians parameter. Wiki will be updated shortly with new recommended command. git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@3704 348d0f76-0448-11de-a6fe-93d51630548a --- .../GenerateVariantClustersWalker.java | 44 +-- .../variantrecalibration/VariantDatum.java | 7 - .../VariantGaussianMixtureModel.java | 330 ++++++++++++------ .../VariantRecalibrator.java | 8 +- ...ntRecalibrationWalkersIntegrationTest.java | 5 +- 5 files changed, 252 insertions(+), 142 deletions(-) diff --git a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/GenerateVariantClustersWalker.java b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/GenerateVariantClustersWalker.java index e0b20e931..7a7343ebf 100755 --- a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/GenerateVariantClustersWalker.java +++ b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/GenerateVariantClustersWalker.java @@ -65,22 +65,20 @@ public class GenerateVariantClustersWalker extends RodWalker annotationKeys; private Set ignoreInputFilterSet = null; private int maxAC = 0; - private PrintStream outFile; - private final static boolean FXYZ_FILE = false; // Debug argument //--------------------------------------------------------------------------------------------------------------- // @@ -134,14 +134,6 @@ public class GenerateVariantClustersWalker extends RodWalker minVarInCluster ) { + for( int kkk = 0; kkk < maxGaussians; kkk++ ) { + if( Math.pow(10.0, pClusterLog10[kkk]) > 1E-4 ) { // BUGBUG: make this a command line argument final double sigmaVals[][] = sigma[kkk].getArray(); - outputFile.print("@!CLUSTER,"); - outputFile.print(Math.pow(10.0, pClusterLog10[kkk]) + ","); + outputFile.print("@!CLUSTER"); + outputFile.print("," + Math.pow(10.0, pClusterLog10[kkk])); for(int jjj = 0; jjj < numAnnotations; jjj++ ) { - outputFile.print(mu[kkk][jjj] + ","); + outputFile.print("," + mu[kkk][jjj]); } for(int jjj = 0; jjj < numAnnotations; jjj++ ) { for(int ppp = 0; ppp < numAnnotations; ppp++ ) { - outputFile.print(sigmaVals[jjj][ppp] + ","); + outputFile.print("," + (sigmaVals[jjj][ppp] / hyperParameter_a[kkk]) ); } } - outputFile.println(-1); + outputFile.println(); } } outputFile.close(); @@ -369,12 +492,15 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } } - public static double decodeAnnotation( final String annotationKey, final VariantContext vc ) { + public static double decodeAnnotation( final String annotationKey, final VariantContext vc, final boolean jitter ) { double value; //if( annotationKey.equals("AB") && !vc.getAttributes().containsKey(annotationKey) ) { // value = (0.5 - 0.005) + (0.01 * rand.nextDouble()); // HomVar calls don't have an allele balance //} - if( annotationKey.equals("QUAL") ) { + if( jitter && annotationKey.equalsIgnoreCase("HRUN") ) { // HRun values must be jittered a bit to work in this GMM + value = Double.parseDouble( (String)vc.getAttribute( annotationKey ) ); + value += -0.25 + 0.5 * rand.nextDouble(); + } else if( annotationKey.equals("QUAL") ) { value = vc.getPhredScaledQual(); } else { try { @@ -388,18 +514,18 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } public final double evaluateVariantLog10( final VariantContext vc ) { - final double[] pVarInCluster = new double[numGaussians]; + final double[] pVarInCluster = new double[maxGaussians]; final double[] annotations = new double[dataManager.numAnnotations]; for( int jjj = 0; jjj < dataManager.numAnnotations; jjj++ ) { - final double value = decodeAnnotation( dataManager.annotationKeys.get(jjj), vc ); + final double value = decodeAnnotation( dataManager.annotationKeys.get(jjj), vc, true ); annotations[jjj] = (value - dataManager.meanVector[jjj]) / dataManager.varianceVector[jjj]; } evaluateGaussiansForSingleVariant( annotations, pVarInCluster ); double sum = 0.0; - for( int kkk = 0; kkk < numGaussians; kkk++ ) { + for( int kkk = 0; kkk < maxGaussians; kkk++ ) { sum += pVarInCluster[kkk]; // * clusterTruePositiveRate[kkk]; } @@ -612,18 +738,32 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel final int numAnnotations = data[0].annotations.length; double likelihood = 0.0; - final double sigmaVals[][][] = new double[numGaussians][][]; - final double denomLog10[] = new double[numGaussians]; - final double pVarInClusterLog10[] = new double[numGaussians]; + final double sigmaVals[][][] = new double[maxGaussians][][]; + final double denomLog10[] = new double[maxGaussians]; + final double pVarInClusterLog10[] = new double[maxGaussians]; double pVarInClusterReals[]; for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { sigmaVals[kkk] = sigma[kkk].inverse().getArray(); - denomLog10[kkk] = Math.log10(Math.pow(2.0 * Math.PI, ((double)numAnnotations) / 2.0)) + Math.log10(Math.pow(determinant[kkk], 0.5)); - if( Double.isInfinite(denomLog10[kkk]) ) { - throw new StingException("Numerical Instability! Determinant value is too small: " + determinant[kkk] + - "Try running with fewer annotations and then with fewer Gaussians."); + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + for( int ppp = 0; ppp < numAnnotations; ppp++ ) { + sigmaVals[kkk][jjj][ppp] *= hyperParameter_a[kkk]; + } } + double sum = 0.0; + for(int jjj = 1; jjj < numAnnotations; jjj++) { + sum += diGamma((hyperParameter_a[kkk] + 1.0 - jjj) / 2.0); + } + sum -= Math.log(determinant[kkk]); + sum += Math.log(2.0) * numAnnotations; + final double gamma = 0.5 * sum; + sum = 0.0; + for(int ccc = 0; ccc < maxGaussians; ccc++) { + sum += hyperParameter_lambda[ccc]; + } + final double pi = diGamma(hyperParameter_lambda[kkk]) - diGamma(sum); + final double beta = (-1.0 * numAnnotations) / (2.0 * hyperParameter_b[kkk]); + denomLog10[kkk] = (pi / Math.log(10.0)) + (gamma / Math.log(10.0)) + (beta / Math.log(10.0)); } final double mult[] = new double[numAnnotations]; double sumWeight = 0.0; @@ -641,14 +781,15 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel sum += mult[jjj] * (data[iii].annotations[jjj] - mu[kkk][jjj]); } - pVarInClusterLog10[kkk] = pClusterLog10[kkk] + (( -0.5 * sum )/Math.log(10.0)) - denomLog10[kkk]; + pVarInClusterLog10[kkk] = (( -0.5 * sum )/Math.log(10.0)) + denomLog10[kkk]; final double pVar = Math.pow(10.0, pVarInClusterLog10[kkk]); likelihood += pVar * data[iii].weight; - if( pVarInClusterLog10[kkk] > 0.0 || Double.isNaN(pVarInClusterLog10[kkk]) || Double.isInfinite(pVarInClusterLog10[kkk]) ) { + if( Double.isNaN(pVarInClusterLog10[kkk]) || Double.isInfinite(pVarInClusterLog10[kkk]) ) { logger.warn("det = " + sigma[kkk].det()); logger.warn("denom = " + denomLog10[kkk]); logger.warn("sumExp = " + sum); + logger.warn("mixtureLog10 = " + pClusterLog10[kkk]); logger.warn("pVar = " + pVar); for( int jjj = 0; jjj < numAnnotations; jjj++ ) { for( int ppp = 0; ppp < numAnnotations; ppp++ ) { @@ -673,7 +814,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } } - logger.info("Explained likelihood = " + String.format("%.5f",likelihood / sumWeight)); + logger.info("explained likelihood = " + String.format("%.5f",likelihood / sumWeight)); return likelihood / sumWeight; } @@ -682,7 +823,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel final int numAnnotations = annotations.length; final double mult[] = new double[numAnnotations]; - for( int kkk = 0; kkk < numGaussians; kkk++ ) { + for( int kkk = 0; kkk < maxGaussians; kkk++ ) { final double sigmaVals[][] = sigmaInverse[kkk].getArray(); double sum = 0.0; for( int jjj = 0; jjj < numAnnotations; jjj++ ) { @@ -705,8 +846,9 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel final int numVariants = data.length; final int numAnnotations = data[0].annotations.length; - final double sigmaVals[][][] = new double[numGaussians][numAnnotations][numAnnotations]; - final double meanVals[][] = new double[numGaussians][numAnnotations]; + final double sigmaVals[][][] = new double[maxGaussians][numAnnotations][numAnnotations]; + final double wishartVals[][] = new double[numAnnotations][numAnnotations]; + final double meanVals[][] = new double[maxGaussians][numAnnotations]; for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { for( int jjj = 0; jjj < numAnnotations; jjj++ ) { @@ -716,6 +858,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } } } + double sumPK = 0.0; for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { double sumProb = 0.0; @@ -728,69 +871,60 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } sumProb += prob; for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - meanVals[kkk][jjj] += prob * data[iii].annotations[jjj]; + meanVals[kkk][jjj] += prob * data[iii].annotations[jjj]; } } for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - meanVals[kkk][jjj] /= sumProb; + meanVals[kkk][jjj] = (meanVals[kkk][jjj] + SHRINKAGE * empiricalMu[jjj]) / (sumProb + SHRINKAGE); + } + + final double shrinkageFactor = (SHRINKAGE * sumProb) / (SHRINKAGE + sumProb); + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + for( int ppp = 0; ppp < numAnnotations; ppp++ ) { + wishartVals[jjj][ppp] = shrinkageFactor * (meanVals[kkk][jjj] - empiricalMu[jjj]) * (meanVals[kkk][ppp] - empiricalMu[ppp]); + } } for( int iii = 0; iii < numVariants; iii++ ) { final double prob = pVarInCluster[kkk][iii]; - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - for( int ppp = jjj; ppp < numAnnotations; ppp++ ) { - sigmaVals[kkk][jjj][ppp] += prob * (data[iii].annotations[jjj]-meanVals[kkk][jjj]) * (data[iii].annotations[ppp]-meanVals[kkk][ppp]); - } - } - } - - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - for( int ppp = jjj; ppp < numAnnotations; ppp++ ) { - if( sigmaVals[kkk][jjj][ppp] < MIN_SIGMA && sigmaVals[kkk][jjj][ppp] > -MIN_SIGMA ) { // Very small numbers are a very big problem - logger.warn("The sigma values look exceptionally small.... Probably about to crash due to numeric instability."); - } - sigmaVals[kkk][ppp][jjj] = sigmaVals[kkk][jjj][ppp]; // sigma must be a symmetric matrix - } - } - - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - for( int ppp = 0; ppp < numAnnotations; ppp++ ) { - sigmaVals[kkk][jjj][ppp] /= sumProb; - } - } - - if( FORCE_INDEPENDENT_ANNOTATIONS ) { for( int jjj = 0; jjj < numAnnotations; jjj++ ) { for( int ppp = 0; ppp < numAnnotations; ppp++ ) { - if(jjj!=ppp) { - sigmaVals[kkk][jjj][ppp] = 0.0; - } + sigmaVals[kkk][jjj][ppp] += prob * (data[iii].annotations[jjj]-meanVals[kkk][jjj]) * (data[iii].annotations[ppp]-meanVals[kkk][ppp]); } } - } - - final Matrix tmpMatrix = new Matrix(sigmaVals[kkk]); - if( tmpMatrix.det() > MIN_DETERMINANT ) { - sigma[kkk] = new Matrix(sigmaVals[kkk]); - determinant[kkk] = sigma[kkk].det(); - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - mu[kkk][jjj] = meanVals[kkk][jjj]; - } - } else { - logger.warn("Tried to create a covariance matrix with exceptionally small determinant."); + final Matrix tmpMatrix = empiricalSigma.plus(new Matrix(wishartVals).plus(new Matrix(sigmaVals[kkk]))); + + sigma[kkk] = (Matrix)tmpMatrix.clone(); + determinant[kkk] = sigma[kkk].det(); + + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + mu[kkk][jjj] = meanVals[kkk][jjj]; } pClusterLog10[kkk] = sumProb; sumPK += sumProb; + + hyperParameter_a[kkk] = sumProb + numAnnotations; + hyperParameter_b[kkk] = sumProb + SHRINKAGE; + hyperParameter_lambda[kkk] = sumProb + DIRICHLET_PARAMETER; } for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - pClusterLog10[kkk] = Math.log10( pClusterLog10[kkk] / sumPK ); + pClusterLog10[kkk] = Math.log10( pClusterLog10[kkk] / sumPK ); } pClusterLog10 = MathUtils.normalizeFromLog10( pClusterLog10, true ); } + + // from http://en.wikipedia.org/wiki/Digamma_function + // According to J.M. Bernardo AS 103 algorithm the digamma function for x, a real number, can be approximated by: + private static double diGamma(final double x) { + return Math.log(x) - ( 1.0 / (2.0 * x) ) + - ( 1.0 / (12.0 * Math.pow(x, 2.0)) ) + + ( 1.0 / (120.0 * Math.pow(x, 4.0)) ) + - ( 1.0 / (252.0 * Math.pow(x, 6.0)) ); + } } \ No newline at end of file diff --git a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java index 231d2349b..782d37f2f 100755 --- a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java +++ b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java @@ -64,7 +64,7 @@ public class VariantRecalibrator extends RodWalker e = new HashMap(); - e.put( validationDataLocation + "yri.trio.gatk_glftrio.intersection.annotated.filtered.chr1.vcf", "9b7517ff1fd0fc23a2596acc82e2ed96" ); + e.put( validationDataLocation + "yri.trio.gatk_glftrio.intersection.annotated.filtered.chr1.vcf", "72558d2c49bb94dc59e9d4146fe0bc05" ); for ( Map.Entry entry : e.entrySet() ) { String vcf = entry.getKey(); @@ -24,7 +24,6 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { " -T GenerateVariantClusters" + " -B input,VCF," + vcf + " -L 1:1-100,000,000" + - " -nG 6" + " --ignore_filter GATK_STANDARD" + " -an QD -an HRun -an SB" + " -clusterFile %s", @@ -38,7 +37,7 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { @Test public void testVariantRecalibrator() { HashMap e = new HashMap(); - e.put( validationDataLocation + "yri.trio.gatk_glftrio.intersection.annotated.filtered.chr1.vcf", "e0e3d959929aa3940a81c9926d1406e2" ); + e.put( validationDataLocation + "yri.trio.gatk_glftrio.intersection.annotated.filtered.chr1.vcf", "d41c4326e589f1746278f1ed9815291a" ); for ( Map.Entry entry : e.entrySet() ) { String vcf = entry.getKey();