diff --git a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/GenerateVariantClustersWalker.java b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/GenerateVariantClustersWalker.java index e0b20e931..7a7343ebf 100755 --- a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/GenerateVariantClustersWalker.java +++ b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/GenerateVariantClustersWalker.java @@ -65,22 +65,20 @@ public class GenerateVariantClustersWalker extends RodWalker annotationKeys; private Set ignoreInputFilterSet = null; private int maxAC = 0; - private PrintStream outFile; - private final static boolean FXYZ_FILE = false; // Debug argument //--------------------------------------------------------------------------------------------------------------- // @@ -134,14 +134,6 @@ public class GenerateVariantClustersWalker extends RodWalker minVarInCluster ) { + for( int kkk = 0; kkk < maxGaussians; kkk++ ) { + if( Math.pow(10.0, pClusterLog10[kkk]) > 1E-4 ) { // BUGBUG: make this a command line argument final double sigmaVals[][] = sigma[kkk].getArray(); - outputFile.print("@!CLUSTER,"); - outputFile.print(Math.pow(10.0, pClusterLog10[kkk]) + ","); + outputFile.print("@!CLUSTER"); + outputFile.print("," + Math.pow(10.0, pClusterLog10[kkk])); for(int jjj = 0; jjj < numAnnotations; jjj++ ) { - outputFile.print(mu[kkk][jjj] + ","); + outputFile.print("," + mu[kkk][jjj]); } for(int jjj = 0; jjj < numAnnotations; jjj++ ) { for(int ppp = 0; ppp < numAnnotations; ppp++ ) { - outputFile.print(sigmaVals[jjj][ppp] + ","); + outputFile.print("," + (sigmaVals[jjj][ppp] / hyperParameter_a[kkk]) ); } } - outputFile.println(-1); + outputFile.println(); } } outputFile.close(); @@ -369,12 +492,15 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } } - public static double decodeAnnotation( final String annotationKey, final VariantContext vc ) { + public static double decodeAnnotation( final String annotationKey, final VariantContext vc, final boolean jitter ) { double value; //if( annotationKey.equals("AB") && !vc.getAttributes().containsKey(annotationKey) ) { // value = (0.5 - 0.005) + (0.01 * rand.nextDouble()); // HomVar calls don't have an allele balance //} - if( annotationKey.equals("QUAL") ) { + if( jitter && annotationKey.equalsIgnoreCase("HRUN") ) { // HRun values must be jittered a bit to work in this GMM + value = Double.parseDouble( (String)vc.getAttribute( annotationKey ) ); + value += -0.25 + 0.5 * rand.nextDouble(); + } else if( annotationKey.equals("QUAL") ) { value = vc.getPhredScaledQual(); } else { try { @@ -388,18 +514,18 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } public final double evaluateVariantLog10( final VariantContext vc ) { - final double[] pVarInCluster = new double[numGaussians]; + final double[] pVarInCluster = new double[maxGaussians]; final double[] annotations = new double[dataManager.numAnnotations]; for( int jjj = 0; jjj < dataManager.numAnnotations; jjj++ ) { - final double value = decodeAnnotation( dataManager.annotationKeys.get(jjj), vc ); + final double value = decodeAnnotation( dataManager.annotationKeys.get(jjj), vc, true ); annotations[jjj] = (value - dataManager.meanVector[jjj]) / dataManager.varianceVector[jjj]; } evaluateGaussiansForSingleVariant( annotations, pVarInCluster ); double sum = 0.0; - for( int kkk = 0; kkk < numGaussians; kkk++ ) { + for( int kkk = 0; kkk < maxGaussians; kkk++ ) { sum += pVarInCluster[kkk]; // * clusterTruePositiveRate[kkk]; } @@ -612,18 +738,32 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel final int numAnnotations = data[0].annotations.length; double likelihood = 0.0; - final double sigmaVals[][][] = new double[numGaussians][][]; - final double denomLog10[] = new double[numGaussians]; - final double pVarInClusterLog10[] = new double[numGaussians]; + final double sigmaVals[][][] = new double[maxGaussians][][]; + final double denomLog10[] = new double[maxGaussians]; + final double pVarInClusterLog10[] = new double[maxGaussians]; double pVarInClusterReals[]; for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { sigmaVals[kkk] = sigma[kkk].inverse().getArray(); - denomLog10[kkk] = Math.log10(Math.pow(2.0 * Math.PI, ((double)numAnnotations) / 2.0)) + Math.log10(Math.pow(determinant[kkk], 0.5)); - if( Double.isInfinite(denomLog10[kkk]) ) { - throw new StingException("Numerical Instability! Determinant value is too small: " + determinant[kkk] + - "Try running with fewer annotations and then with fewer Gaussians."); + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + for( int ppp = 0; ppp < numAnnotations; ppp++ ) { + sigmaVals[kkk][jjj][ppp] *= hyperParameter_a[kkk]; + } } + double sum = 0.0; + for(int jjj = 1; jjj < numAnnotations; jjj++) { + sum += diGamma((hyperParameter_a[kkk] + 1.0 - jjj) / 2.0); + } + sum -= Math.log(determinant[kkk]); + sum += Math.log(2.0) * numAnnotations; + final double gamma = 0.5 * sum; + sum = 0.0; + for(int ccc = 0; ccc < maxGaussians; ccc++) { + sum += hyperParameter_lambda[ccc]; + } + final double pi = diGamma(hyperParameter_lambda[kkk]) - diGamma(sum); + final double beta = (-1.0 * numAnnotations) / (2.0 * hyperParameter_b[kkk]); + denomLog10[kkk] = (pi / Math.log(10.0)) + (gamma / Math.log(10.0)) + (beta / Math.log(10.0)); } final double mult[] = new double[numAnnotations]; double sumWeight = 0.0; @@ -641,14 +781,15 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel sum += mult[jjj] * (data[iii].annotations[jjj] - mu[kkk][jjj]); } - pVarInClusterLog10[kkk] = pClusterLog10[kkk] + (( -0.5 * sum )/Math.log(10.0)) - denomLog10[kkk]; + pVarInClusterLog10[kkk] = (( -0.5 * sum )/Math.log(10.0)) + denomLog10[kkk]; final double pVar = Math.pow(10.0, pVarInClusterLog10[kkk]); likelihood += pVar * data[iii].weight; - if( pVarInClusterLog10[kkk] > 0.0 || Double.isNaN(pVarInClusterLog10[kkk]) || Double.isInfinite(pVarInClusterLog10[kkk]) ) { + if( Double.isNaN(pVarInClusterLog10[kkk]) || Double.isInfinite(pVarInClusterLog10[kkk]) ) { logger.warn("det = " + sigma[kkk].det()); logger.warn("denom = " + denomLog10[kkk]); logger.warn("sumExp = " + sum); + logger.warn("mixtureLog10 = " + pClusterLog10[kkk]); logger.warn("pVar = " + pVar); for( int jjj = 0; jjj < numAnnotations; jjj++ ) { for( int ppp = 0; ppp < numAnnotations; ppp++ ) { @@ -673,7 +814,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } } - logger.info("Explained likelihood = " + String.format("%.5f",likelihood / sumWeight)); + logger.info("explained likelihood = " + String.format("%.5f",likelihood / sumWeight)); return likelihood / sumWeight; } @@ -682,7 +823,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel final int numAnnotations = annotations.length; final double mult[] = new double[numAnnotations]; - for( int kkk = 0; kkk < numGaussians; kkk++ ) { + for( int kkk = 0; kkk < maxGaussians; kkk++ ) { final double sigmaVals[][] = sigmaInverse[kkk].getArray(); double sum = 0.0; for( int jjj = 0; jjj < numAnnotations; jjj++ ) { @@ -705,8 +846,9 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel final int numVariants = data.length; final int numAnnotations = data[0].annotations.length; - final double sigmaVals[][][] = new double[numGaussians][numAnnotations][numAnnotations]; - final double meanVals[][] = new double[numGaussians][numAnnotations]; + final double sigmaVals[][][] = new double[maxGaussians][numAnnotations][numAnnotations]; + final double wishartVals[][] = new double[numAnnotations][numAnnotations]; + final double meanVals[][] = new double[maxGaussians][numAnnotations]; for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { for( int jjj = 0; jjj < numAnnotations; jjj++ ) { @@ -716,6 +858,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } } } + double sumPK = 0.0; for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { double sumProb = 0.0; @@ -728,69 +871,60 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } sumProb += prob; for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - meanVals[kkk][jjj] += prob * data[iii].annotations[jjj]; + meanVals[kkk][jjj] += prob * data[iii].annotations[jjj]; } } for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - meanVals[kkk][jjj] /= sumProb; + meanVals[kkk][jjj] = (meanVals[kkk][jjj] + SHRINKAGE * empiricalMu[jjj]) / (sumProb + SHRINKAGE); + } + + final double shrinkageFactor = (SHRINKAGE * sumProb) / (SHRINKAGE + sumProb); + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + for( int ppp = 0; ppp < numAnnotations; ppp++ ) { + wishartVals[jjj][ppp] = shrinkageFactor * (meanVals[kkk][jjj] - empiricalMu[jjj]) * (meanVals[kkk][ppp] - empiricalMu[ppp]); + } } for( int iii = 0; iii < numVariants; iii++ ) { final double prob = pVarInCluster[kkk][iii]; - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - for( int ppp = jjj; ppp < numAnnotations; ppp++ ) { - sigmaVals[kkk][jjj][ppp] += prob * (data[iii].annotations[jjj]-meanVals[kkk][jjj]) * (data[iii].annotations[ppp]-meanVals[kkk][ppp]); - } - } - } - - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - for( int ppp = jjj; ppp < numAnnotations; ppp++ ) { - if( sigmaVals[kkk][jjj][ppp] < MIN_SIGMA && sigmaVals[kkk][jjj][ppp] > -MIN_SIGMA ) { // Very small numbers are a very big problem - logger.warn("The sigma values look exceptionally small.... Probably about to crash due to numeric instability."); - } - sigmaVals[kkk][ppp][jjj] = sigmaVals[kkk][jjj][ppp]; // sigma must be a symmetric matrix - } - } - - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - for( int ppp = 0; ppp < numAnnotations; ppp++ ) { - sigmaVals[kkk][jjj][ppp] /= sumProb; - } - } - - if( FORCE_INDEPENDENT_ANNOTATIONS ) { for( int jjj = 0; jjj < numAnnotations; jjj++ ) { for( int ppp = 0; ppp < numAnnotations; ppp++ ) { - if(jjj!=ppp) { - sigmaVals[kkk][jjj][ppp] = 0.0; - } + sigmaVals[kkk][jjj][ppp] += prob * (data[iii].annotations[jjj]-meanVals[kkk][jjj]) * (data[iii].annotations[ppp]-meanVals[kkk][ppp]); } } - } - - final Matrix tmpMatrix = new Matrix(sigmaVals[kkk]); - if( tmpMatrix.det() > MIN_DETERMINANT ) { - sigma[kkk] = new Matrix(sigmaVals[kkk]); - determinant[kkk] = sigma[kkk].det(); - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - mu[kkk][jjj] = meanVals[kkk][jjj]; - } - } else { - logger.warn("Tried to create a covariance matrix with exceptionally small determinant."); + final Matrix tmpMatrix = empiricalSigma.plus(new Matrix(wishartVals).plus(new Matrix(sigmaVals[kkk]))); + + sigma[kkk] = (Matrix)tmpMatrix.clone(); + determinant[kkk] = sigma[kkk].det(); + + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + mu[kkk][jjj] = meanVals[kkk][jjj]; } pClusterLog10[kkk] = sumProb; sumPK += sumProb; + + hyperParameter_a[kkk] = sumProb + numAnnotations; + hyperParameter_b[kkk] = sumProb + SHRINKAGE; + hyperParameter_lambda[kkk] = sumProb + DIRICHLET_PARAMETER; } for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - pClusterLog10[kkk] = Math.log10( pClusterLog10[kkk] / sumPK ); + pClusterLog10[kkk] = Math.log10( pClusterLog10[kkk] / sumPK ); } pClusterLog10 = MathUtils.normalizeFromLog10( pClusterLog10, true ); } + + // from http://en.wikipedia.org/wiki/Digamma_function + // According to J.M. Bernardo AS 103 algorithm the digamma function for x, a real number, can be approximated by: + private static double diGamma(final double x) { + return Math.log(x) - ( 1.0 / (2.0 * x) ) + - ( 1.0 / (12.0 * Math.pow(x, 2.0)) ) + + ( 1.0 / (120.0 * Math.pow(x, 4.0)) ) + - ( 1.0 / (252.0 * Math.pow(x, 6.0)) ); + } } \ No newline at end of file diff --git a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java index 231d2349b..782d37f2f 100755 --- a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java +++ b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java @@ -64,7 +64,7 @@ public class VariantRecalibrator extends RodWalker e = new HashMap(); - e.put( validationDataLocation + "yri.trio.gatk_glftrio.intersection.annotated.filtered.chr1.vcf", "9b7517ff1fd0fc23a2596acc82e2ed96" ); + e.put( validationDataLocation + "yri.trio.gatk_glftrio.intersection.annotated.filtered.chr1.vcf", "72558d2c49bb94dc59e9d4146fe0bc05" ); for ( Map.Entry entry : e.entrySet() ) { String vcf = entry.getKey(); @@ -24,7 +24,6 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { " -T GenerateVariantClusters" + " -B input,VCF," + vcf + " -L 1:1-100,000,000" + - " -nG 6" + " --ignore_filter GATK_STANDARD" + " -an QD -an HRun -an SB" + " -clusterFile %s", @@ -38,7 +37,7 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { @Test public void testVariantRecalibrator() { HashMap e = new HashMap(); - e.put( validationDataLocation + "yri.trio.gatk_glftrio.intersection.annotated.filtered.chr1.vcf", "e0e3d959929aa3940a81c9926d1406e2" ); + e.put( validationDataLocation + "yri.trio.gatk_glftrio.intersection.annotated.filtered.chr1.vcf", "d41c4326e589f1746278f1ed9815291a" ); for ( Map.Entry entry : e.entrySet() ) { String vcf = entry.getKey();