diff --git a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantGaussianMixtureModel.java b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantGaussianMixtureModel.java index 5c9a86fea..c210ef699 100755 --- a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantGaussianMixtureModel.java +++ b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantGaussianMixtureModel.java @@ -65,9 +65,10 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel private final double[][] mu; // The means for each cluster private final Matrix[] sigma; // The covariance matrix for each cluster - private final Matrix[] sigmaInverse; + private final double[][][] sigmaInverse; private double[] pClusterLog10; private final double[] determinant; + private final double[] sqrtDeterminantLog10; private final double stdThreshold; private double singletonFPRate = -1; // Estimated FP rate for singleton calls. Used to estimate FP rate as a function of AC @@ -78,6 +79,8 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel private final double[] hyperParameter_b; private final double[] hyperParameter_lambda; + private final double CONSTANT_GAUSSIAN_DENOM_LOG10; + private static final Pattern COMMENT_PATTERN = Pattern.compile("^##.*"); private static final Pattern ANNOTATION_PATTERN = Pattern.compile("^@!ANNOTATION.*"); private static final Pattern CLUSTER_PATTERN = Pattern.compile("^@!CLUSTER.*"); @@ -93,10 +96,12 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel sigma = null; sigmaInverse = null; determinant = null; + sqrtDeterminantLog10 = null; stdThreshold = 0; hyperParameter_a = null; hyperParameter_b = null; hyperParameter_lambda = null; + CONSTANT_GAUSSIAN_DENOM_LOG10 = 0.0; } public VariantGaussianMixtureModel( final VariantDataManager _dataManager, final int _maxGaussians, final int _maxIterations, @@ -108,6 +113,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel mu = new double[maxGaussians][]; sigma = new Matrix[maxGaussians]; determinant = new double[maxGaussians]; + sqrtDeterminantLog10 = null; pClusterLog10 = new double[maxGaussians]; stdThreshold = _stdThreshold; FORCE_INDEPENDENT_ANNOTATIONS = _forceIndependent; @@ -116,6 +122,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel hyperParameter_lambda = new double[maxGaussians]; sigmaInverse = null; // This field isn't used during GenerateVariantClusters pass + CONSTANT_GAUSSIAN_DENOM_LOG10 = Math.log10(Math.pow(2.0 * Math.PI, ((double)dataManager.numAnnotations) / 2.0)); SHRINKAGE = _shrinkage; DIRICHLET_PARAMETER = _dirichlet; } @@ -149,15 +156,17 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel hyperParameter_a = null; hyperParameter_b = null; hyperParameter_lambda = null; + determinant = null; // BUGBUG: move this parsing out of the constructor + CONSTANT_GAUSSIAN_DENOM_LOG10 = Math.log10(Math.pow(2.0 * Math.PI, ((double)dataManager.numAnnotations) / 2.0)); maxGaussians = clusterLines.size(); mu = new double[maxGaussians][dataManager.numAnnotations]; final double sigmaVals[][][] = new double[maxGaussians][dataManager.numAnnotations][dataManager.numAnnotations]; sigma = new Matrix[maxGaussians]; - sigmaInverse = new Matrix[maxGaussians]; + sigmaInverse = new double[maxGaussians][dataManager.numAnnotations][dataManager.numAnnotations]; pClusterLog10 = new double[maxGaussians]; - determinant = new double[maxGaussians]; + sqrtDeterminantLog10 = new double[maxGaussians]; int kkk = 0; for( final String line : clusterLines ) { @@ -171,8 +180,8 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } sigma[kkk] = new Matrix(sigmaVals[kkk]); - sigmaInverse[kkk] = sigma[kkk].inverse(); // Precompute all the inverses and determinants for use later - determinant[kkk] = sigma[kkk].det(); + sigmaInverse[kkk] = sigma[kkk].inverse().getArray(); // Precompute all the inverses and determinants for use later + sqrtDeterminantLog10[kkk] = Math.log10(Math.pow(sigma[kkk].det(), 0.5)); // Precompute for use later kkk++; } @@ -381,7 +390,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel logger.info("Finished iteration " + ttt ); ttt++; - if( Math.abs(currentLikelihood - previousLikelihood) < MIN_PROB_CONVERGENCE) { + if( Math.abs(currentLikelihood - previousLikelihood) < MIN_PROB_CONVERGENCE ) { logger.info("Convergence!"); break; } @@ -452,59 +461,6 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel return sum; } - public final void outputClusterReports( final String outputPrefix ) { - final double STD_STEP = 0.2; - final double MAX_STD = 4.0; - final double MIN_STD = -4.0; - final int NUM_BINS = (int)Math.floor((Math.abs(MIN_STD) + Math.abs(MAX_STD)) / STD_STEP); - final int numAnnotations = dataManager.numAnnotations; - int totalCountsKnown = 0; - int totalCountsNovel = 0; - - final int counts[][][] = new int[numAnnotations][NUM_BINS][2]; - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - for( int iii = 0; iii < NUM_BINS; iii++ ) { - counts[jjj][iii][0] = 0; - counts[jjj][iii][1] = 0; - } - } - - for( final VariantDatum datum : dataManager.data ) { - final int isKnown = ( datum.isKnown ? 1 : 0 ); - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - int histBin = (int)Math.round((datum.annotations[jjj]-MIN_STD) * (1.0 / STD_STEP)); - if(histBin < 0) { histBin = 0; } - if(histBin > NUM_BINS-1) { histBin = NUM_BINS-1; } - if(histBin >= 0 && histBin <= NUM_BINS-1) { - counts[jjj][histBin][isKnown]++; - } - } - if( isKnown == 1 ) { totalCountsKnown++; } - else { totalCountsNovel++; } - } - - int annIndex = 0; - for( final String annotation : dataManager.annotationKeys ) { - PrintStream outputFile; - File file = new File(outputPrefix + "." + annotation + ".dat"); - try { - outputFile = new PrintStream( file ); - } catch (FileNotFoundException e) { - throw new UserException.CouldNotCreateOutputFile( file, e ); - } - - outputFile.println("annotationValue,knownDist,novelDist"); - - for( int iii = 0; iii < NUM_BINS; iii++ ) { - final double annotationValue = (((double)iii * STD_STEP)+MIN_STD) * dataManager.varianceVector[annIndex] + dataManager.meanVector[annIndex]; - outputFile.println( annotationValue + "," + ( ((double)counts[annIndex][iii][1])/((double)totalCountsKnown) ) + - "," + ( ((double)counts[annIndex][iii][0])/((double)totalCountsNovel) )); - } - - annIndex++; - } - } - public final void outputOptimizationCurve( final VariantDatum[] data, final PrintStream outputReportDatFile, final PrintStream tranchesOutputFile, final int desiredNumVariants, final Double[] FDRtranches, final double QUAL_STEP ) { @@ -721,9 +677,8 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel final int numAnnotations = annotations.length; - final double evalGaussianPDFLog10[] = new double[maxGaussians]; for( int kkk = 0; kkk < maxGaussians; kkk++ ) { - final double sigmaVals[][] = sigmaInverse[kkk].getArray(); + final double sigmaVals[][] = sigmaInverse[kkk]; double sum = 0.0; for( int jjj = 0; jjj < numAnnotations; jjj++ ) { double value = 0.0; @@ -733,15 +688,14 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel final double mySigma = sigmaVals[ppp][jjj]; value += (myAnn - myMu) * mySigma; } - double jNorm = annotations[jjj] - mu[kkk][jjj]; - double prod = value * jNorm; + final double jNorm = annotations[jjj] - mu[kkk][jjj]; + final double prod = value * jNorm; sum += prod; } - final double log10SqrtDet = Math.log10(Math.pow(determinant[kkk], 0.5)); - final double denomLog10 = Math.log10(Math.pow(2.0 * Math.PI, ((double)numAnnotations) / 2.0)) + log10SqrtDet; - evalGaussianPDFLog10[kkk] = (( -0.5 * sum ) / Math.log(10.0)) - denomLog10; - double pVar1 = Math.pow(10.0, pClusterLog10[kkk] + evalGaussianPDFLog10[kkk]); + final double denomLog10 = CONSTANT_GAUSSIAN_DENOM_LOG10 + sqrtDeterminantLog10[kkk]; + final double evalGaussianPDFLog10 = (( -0.5 * sum ) / Math.log(10.0)) - denomLog10; + final double pVar1 = Math.pow(10.0, pClusterLog10[kkk] + evalGaussianPDFLog10); pVarInCluster[kkk] = pVar1; } diff --git a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java index 227880ef3..1fc8d7bc5 100755 --- a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java +++ b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java @@ -38,6 +38,7 @@ import org.broadinstitute.sting.gatk.refdata.utils.helpers.DbSNPHelper; import org.broadinstitute.sting.gatk.walkers.RodWalker; import org.broadinstitute.sting.utils.*; import org.broadinstitute.sting.utils.collections.ExpandingArrayList; +import org.broadinstitute.sting.utils.collections.NestedHashMap; import org.broadinstitute.sting.utils.exceptions.UserException; import org.broadinstitute.sting.utils.vcf.VCFUtils; @@ -70,7 +71,7 @@ public class VariantRecalibrator extends RodWalker ignoreInputFilterSet = null; private Set inputNames = new HashSet(); + private NestedHashMap priorCache = new NestedHashMap(); + private boolean trustACField = false; //--------------------------------------------------------------------------------------------------------------- // @@ -228,26 +233,56 @@ public class VariantRecalibrator extends RodWalker attrs = new HashMap(vc.getAttributes()); + attrs.put("OQ", String.format("%.2f", vc.getPhredScaledQual())); attrs.put("LOD", String.format("%.4f", lod)); VariantContext newVC = VariantContext.modifyPErrorFiltersAndAttributes(vc, variantDatum.qual / 10.0, new HashSet(), attrs); diff --git a/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersIntegrationTest.java b/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersIntegrationTest.java index a0a6d4d49..49e50f43c 100755 --- a/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersIntegrationTest.java +++ b/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersIntegrationTest.java @@ -46,7 +46,7 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { public void testVariantRecalibrator() { HashMap> e = new HashMap>(); e.put( validationDataLocation + "yri.trio.gatk_glftrio.intersection.annotated.filtered.chr1.vcf", - Arrays.asList("0c6b5085a678b6312ab4bc8ce7b4eee4", "038c31c5bb46a4df89b8ee69ec740812","7d42bbdfb69fdfb18cbda13a63d92602")); // Each test checks the md5 of three output files + Arrays.asList("e94b02016e6f7936999f02979b801c30", "038c31c5bb46a4df89b8ee69ec740812","7d42bbdfb69fdfb18cbda13a63d92602")); // Each test checks the md5 of three output files e.put( validationDataLocation + "lowpass.N3.chr1.raw.vcf", Arrays.asList("bbdffb7fa611f4ae80e919cdf86b9bc6", "661360e85392af9c97e386399871854a","371e5a70a4006420737c5ab259e0e23e")); // Each test checks the md5 of three output files @@ -84,7 +84,7 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { @Test public void testApplyVariantCuts() { HashMap e = new HashMap(); - e.put( validationDataLocation + "yri.trio.gatk_glftrio.intersection.annotated.filtered.chr1.vcf", "7429ed494131eb1dab5a9169cf65d1f0" ); + e.put( validationDataLocation + "yri.trio.gatk_glftrio.intersection.annotated.filtered.chr1.vcf", "e06aa6b734cc3c881d95cf5ee9315664" ); e.put( validationDataLocation + "lowpass.N3.chr1.raw.vcf", "ad8661cba3b04a7977c97a541fd8a668" ); for ( Map.Entry entry : e.entrySet() ) {