From c21402d4af7c5b5549b28eb519cca2d5425cfb23 Mon Sep 17 00:00:00 2001 From: Mark DePristo Date: Mon, 5 Aug 2013 14:52:53 -0400 Subject: [PATCH] Separate num Gaussians for + and - GMM in VQSR -- The previous approach in VQSR was to build a GMM with the same maximum number of Gaussians for the positive and negative models. However, we usually have many more positive sites than negative, so we'd prefer to use a more detailed GMM for the positive model and a less well-defined model using few sites for the negative model. -- Now the maxGaussians argument only applies to the positive model -- This update builds a GMM for the negative model with a default maximum of 4 Gaussians (though this can be controlled via a command-line parameter) -- Removes the percentBadVariants argument. The only way to control how many variants are included in the negative model is with minNumBad -- Reduced the minNumBad argument default to 1000 from 2500 -- Update MD5s for VQSR. MD5s changed significantly due to underlying changes in the default GMM model. Only sites with NEGATIVE_TRAINING_LABELs and the resulting VQSLOD are different, as expected. 
-- minNumBad is now numBad -- Plot all negative training points as well, since this significantly changes our view of the GMM PDF --- .../VariantDataManager.java | 29 ++++++++++--------- .../VariantRecalibrator.java | 17 +++-------- ...VariantRecalibratorArgumentCollection.java | 14 +++++---- .../VariantRecalibratorEngine.java | 4 +-- ...ntRecalibrationWalkersIntegrationTest.java | 28 ++++++++---------- 5 files changed, 43 insertions(+), 49 deletions(-) diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManager.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManager.java index 3688efca2..358787e51 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManager.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManager.java @@ -241,13 +241,13 @@ public class VariantDataManager { } } logger.info( "Training with " + trainingData.size() + " variants after standard deviation thresholding." ); - if( trainingData.size() < VRAC.MIN_NUM_BAD_VARIANTS ) { + if( trainingData.size() < VRAC.NUM_BAD_VARIANTS) { logger.warn( "WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable." 
); } return trainingData; } - public ExpandingArrayList selectWorstVariants( double bottomPercentage, final int minimumNumber ) { + public ExpandingArrayList selectWorstVariants( final int minimumNumber ) { // The return value is the list of training variants final ExpandingArrayList trainingData = new ExpandingArrayList<>(); @@ -262,12 +262,9 @@ public class VariantDataManager { // Next sort the variants by the LOD coming from the positive model and add to the list the bottom X percent of variants Collections.sort( data, new VariantDatum.VariantDatumLODComparator() ); - final int numToAdd = Math.max( minimumNumber - trainingData.size(), Math.round((float)bottomPercentage * data.size()) ); + final int numToAdd = minimumNumber - trainingData.size(); if( numToAdd > data.size() ) { throw new UserException.BadInput( "Error during negative model training. Minimum number of variants to use in training is larger than the whole call set. One can attempt to lower the --minNumBadVariants arugment but this is unsafe." ); - } else if( numToAdd == minimumNumber - trainingData.size() ) { - logger.warn( "WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable." ); - bottomPercentage = ((float) numToAdd) / ((float) data.size()); } int index = 0, numAdded = 0; while( numAdded < numToAdd && index < data.size() ) { @@ -278,25 +275,31 @@ public class VariantDataManager { numAdded++; } } - logger.info( "Additionally training with worst " + String.format("%.3f", (float) bottomPercentage * 100.0f) + "% of passing data --> " + (trainingData.size() - numBadSitesAdded) + " variants with LOD <= " + String.format("%.4f", data.get(index).lod) + "." ); + logger.info( "Additionally training with worst " + numToAdd + "% of passing data --> " + (trainingData.size() - numBadSitesAdded) + " variants with LOD <= " + String.format("%.4f", data.get(index).lod) + "." 
); return trainingData; } public ExpandingArrayList getRandomDataForPlotting( int numToAdd ) { numToAdd = Math.min(numToAdd, data.size()); final ExpandingArrayList returnData = new ExpandingArrayList<>(); + // add numToAdd non-anti training sites to plot for( int iii = 0; iii < numToAdd; iii++) { final VariantDatum datum = data.get(GenomeAnalysisEngine.getRandomGenerator().nextInt(data.size())); - if( !datum.failingSTDThreshold ) { + if( ! datum.atAntiTrainingSite && !datum.failingSTDThreshold ) { returnData.add(datum); } } - // Add an extra 5% of points from bad training set, since that set is small but interesting - for( int iii = 0; iii < Math.floor(0.05*numToAdd); iii++) { - final VariantDatum datum = data.get(GenomeAnalysisEngine.getRandomGenerator().nextInt(data.size())); - if( datum.atAntiTrainingSite && !datum.failingSTDThreshold ) { returnData.add(datum); } - else { iii--; } + final int MAX_ANTI_TRAINING_SITES = 10000; + int nAntiTrainingAdded = 0; + // Add all anti-training sites to visual + for( final VariantDatum datum : data ) { + if ( nAntiTrainingAdded > MAX_ANTI_TRAINING_SITES ) + break; + else if ( datum.atAntiTrainingSite ) { + returnData.add(datum); + nAntiTrainingAdded++; + } } return returnData; diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java index 6813a0ed4..496d4fd2b 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java @@ -326,25 +326,16 @@ public class VariantRecalibrator extends RodWalker negativeTrainingData = dataManager.selectWorstVariants( VRAC.PERCENT_BAD_VARIANTS, VRAC.MIN_NUM_BAD_VARIANTS ); - GaussianMixtureModel badModel = engine.generateModel( negativeTrainingData ); + final 
ExpandingArrayList negativeTrainingData = dataManager.selectWorstVariants( VRAC.NUM_BAD_VARIANTS ); + final GaussianMixtureModel badModel = engine.generateModel( negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS)); engine.evaluateData( dataManager.getData(), badModel, true ); - // Detect if the negative model failed to converge because of too few points and/or too many Gaussians and try again - while( badModel.failedToConverge && VRAC.MAX_GAUSSIANS > 4 ) { - logger.info("Negative model failed to converge. Retrying..."); - VRAC.MAX_GAUSSIANS--; - badModel = engine.generateModel( negativeTrainingData ); - engine.evaluateData( dataManager.getData(), goodModel, false ); - engine.evaluateData( dataManager.getData(), badModel, true ); - } - if( badModel.failedToConverge || goodModel.failedToConverge ) { - throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider raising the number of variants used to train the negative model (via --percentBadVariants 0.05, for example) or lowering the maximum number of Gaussians to use in the model (via --maxGaussians 4, for example)"); + throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. 
Please consider raising the number of variants used to train the negative model (via --minNumBad, for example) or lowering the maximum number of Gaussians to use in the model (via --maxGaussians 4, for example)"); } engine.calculateWorstPerformingAnnotation( dataManager.getData(), goodModel, badModel ); diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibratorArgumentCollection.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibratorArgumentCollection.java index ae0b4a347..b376874fc 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibratorArgumentCollection.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibratorArgumentCollection.java @@ -46,6 +46,7 @@ package org.broadinstitute.sting.gatk.walkers.variantrecalibration; +import org.broadinstitute.sting.commandline.Advanced; import org.broadinstitute.sting.commandline.Argument; import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; @@ -72,8 +73,13 @@ public class VariantRecalibratorArgumentCollection { @Argument(fullName = "mode", shortName = "mode", doc = "Recalibration mode to employ: 1.) SNP for recalibrating only SNPs (emitting indels untouched in the output VCF); 2.) INDEL for indels (emitting SNPs untouched in the output VCF); and 3.) 
BOTH for recalibrating both SNPs and indels simultaneously (for testing purposes only, not recommended for general use).", required = false) public VariantRecalibratorArgumentCollection.Mode MODE = VariantRecalibratorArgumentCollection.Mode.SNP; - @Argument(fullName="maxGaussians", shortName="mG", doc="The maximum number of Gaussians to try during variational Bayes algorithm.", required=false) + @Argument(fullName="maxGaussians", shortName="mG", doc="The maximum number of Gaussians for the positive model to try during variational Bayes algorithm.", required=false) public int MAX_GAUSSIANS = 10; + + @Advanced + @Argument(fullName="maxNegativeGaussians", shortName="mNG", doc="The maximum number of Gaussians for the negative model to try during variational Bayes algorithm. The actual maximum used is the min of the mG and mNG arguments. Note that this number should be small (like 4) to achieve the best results", required=false) + public int MAX_GAUSSIANS_FOR_NEGATIVE_MODEL = 4; + @Argument(fullName="maxIterations", shortName="mI", doc="The maximum number of VBEM iterations to be performed in variational Bayes algorithm. Procedure will normally end when convergence is detected.", required=false) public int MAX_ITERATIONS = 100; @Argument(fullName="numKMeans", shortName="nKM", doc="The number of k-means iterations to perform in order to initialize the means of the Gaussians in the Gaussian mixture model.", required=false) @@ -88,8 +94,6 @@ public class VariantRecalibratorArgumentCollection { public double DIRICHLET_PARAMETER = 0.001; @Argument(fullName="priorCounts", shortName="priorCounts", doc="The number of prior counts to use in the variational Bayes algorithm.", required=false) public double PRIOR_COUNTS = 20.0; - @Argument(fullName="percentBadVariants", shortName="percentBad", doc="What percentage of the worst scoring variants to use when building the Gaussian mixture model of bad variants. 
0.07 means bottom 7 percent.", required=false) - public double PERCENT_BAD_VARIANTS = 0.03; - @Argument(fullName="minNumBadVariants", shortName="minNumBad", doc="The minimum amount of worst scoring variants to use when building the Gaussian mixture model of bad variants. Will override -percentBad argument if necessary.", required=false) - public int MIN_NUM_BAD_VARIANTS = 2500; + @Argument(fullName="numBadVariants", shortName="numBad", doc="The number of worst scoring variants to use when building the Gaussian mixture model of bad variants. Will override -percentBad argument if necessary.", required=false) + public int NUM_BAD_VARIANTS = 1000; } diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java index 6cebc82c1..ffde46394 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java @@ -79,8 +79,8 @@ public class VariantRecalibratorEngine { this.VRAC = VRAC; } - public GaussianMixtureModel generateModel( final List data ) { - final GaussianMixtureModel model = new GaussianMixtureModel( VRAC.MAX_GAUSSIANS, data.get(0).annotations.length, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER, VRAC.PRIOR_COUNTS ); + public GaussianMixtureModel generateModel( final List data, final int maxGaussians ) { + final GaussianMixtureModel model = new GaussianMixtureModel( maxGaussians, data.get(0).annotations.length, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER, VRAC.PRIOR_COUNTS ); variationalBayesExpectationMaximization( model, data ); return model; } diff --git a/protected/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersIntegrationTest.java 
b/protected/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersIntegrationTest.java index b5a541d53..3a6981bab 100644 --- a/protected/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersIntegrationTest.java +++ b/protected/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersIntegrationTest.java @@ -72,9 +72,9 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { } VRTest lowPass = new VRTest(validationDataLocation + "phase1.projectConsensus.chr20.raw.snps.vcf", - "583e8f63475dfd09a26bf11579075c8e", // tranches - "39a98f13b26c8c1f363f99ab8cead6ca", // recal file - "d235aefef741a6b2c352ef20af1ca790"); // cut VCF + "0f4ceeeb8e4a3c89f8591d5e531d8410", // tranches + "c979a102669498ef40dde47ca4133c42", // recal file + "8f60fd849537610b653b321869e94641"); // cut VCF @DataProvider(name = "VRTest") public Object[][] createData1() { @@ -95,8 +95,6 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { " -L 20:1,000,000-40,000,000" + " --no_cmdline_in_header" + " -an QD -an HaplotypeScore -an HRun" + - " -percentBad 0.07" + - " --minNumBadVariants 0" + " --trustAllPolymorphic" + // for speed " -recalFile %s" + " -tranchesFile %s", @@ -121,9 +119,9 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { } VRTest bcfTest = new VRTest(privateTestDir + "vqsr.bcf_test.snps.unfiltered.bcf", - "d29356849670aabcc12643a2b68dcc82", // tranches - "8abaf8142a6ee212b6dddc7053605512", // recal file - "d6cd4f61875ae09a030fd9f2d7328246"); // cut VCF + "6539e025997579cd0c7da12219cbc572", // tranches + "778e61f81ab3d468b75f684bef0478e5", // recal file + "21e96b0bb47e2976f53f11181f920e51"); // cut VCF @DataProvider(name = "VRBCFTest") public Object[][] createVRBCFTest() { @@ -173,15 +171,15 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { VRTest 
indelUnfiltered = new VRTest( validationDataLocation + "combined.phase1.chr20.raw.indels.unfiltered.sites.vcf", // all FILTERs as . - "99c3736dab836ae8b41e344062e01b5a", // tranches - "55d2f89980ea9c6c469314129dbac732", // recal file - "482039de04961876890e125055732450"); // cut VCF + "8906fdae8beca712f5ff2808d35ef02d", // tranches + "07ffea25e04f6ef53079bccb30bd6a7b", // recal file + "8b3ef71cad71e8eb48a856a27ae4f8d5"); // cut VCF VRTest indelFiltered = new VRTest( validationDataLocation + "combined.phase1.chr20.raw.indels.filtered.sites.vcf", // all FILTERs as PASS - "99c3736dab836ae8b41e344062e01b5a", // tranches - "55d2f89980ea9c6c469314129dbac732", // recal file - "e63e22ae05ad0bd32b943cde00b6e5a9"); // cut VCF + "8906fdae8beca712f5ff2808d35ef02d", // tranches + "07ffea25e04f6ef53079bccb30bd6a7b", // recal file + "3d69b280370cdd9611695e4893591306"); // cut VCF @DataProvider(name = "VRIndelTest") public Object[][] createTestVariantRecalibratorIndel() { @@ -200,9 +198,7 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { " -L 20:1,000,000-40,000,000" + " --no_cmdline_in_header" + " -an QD -an ReadPosRankSum -an HaplotypeScore" + - " -percentBad 0.08" + " -mode INDEL -mG 3" + - " --minNumBadVariants 0" + " --trustAllPolymorphic" + // for speed " -recalFile %s" + " -tranchesFile %s",