diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManager.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManager.java index 3688efca2..358787e51 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManager.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManager.java @@ -241,13 +241,13 @@ public class VariantDataManager { } } logger.info( "Training with " + trainingData.size() + " variants after standard deviation thresholding." ); - if( trainingData.size() < VRAC.MIN_NUM_BAD_VARIANTS ) { + if( trainingData.size() < VRAC.NUM_BAD_VARIANTS) { logger.warn( "WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable." ); } return trainingData; } - public ExpandingArrayList selectWorstVariants( double bottomPercentage, final int minimumNumber ) { + public ExpandingArrayList selectWorstVariants( final int minimumNumber ) { // The return value is the list of training variants final ExpandingArrayList trainingData = new ExpandingArrayList<>(); @@ -262,12 +262,9 @@ public class VariantDataManager { // Next sort the variants by the LOD coming from the positive model and add to the list the bottom X percent of variants Collections.sort( data, new VariantDatum.VariantDatumLODComparator() ); - final int numToAdd = Math.max( minimumNumber - trainingData.size(), Math.round((float)bottomPercentage * data.size()) ); + final int numToAdd = minimumNumber - trainingData.size(); if( numToAdd > data.size() ) { throw new UserException.BadInput( "Error during negative model training. Minimum number of variants to use in training is larger than the whole call set. One can attempt to lower the --minNumBadVariants arugment but this is unsafe." 
); - } else if( numToAdd == minimumNumber - trainingData.size() ) { - logger.warn( "WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable." ); - bottomPercentage = ((float) numToAdd) / ((float) data.size()); } int index = 0, numAdded = 0; while( numAdded < numToAdd && index < data.size() ) { @@ -278,25 +275,31 @@ public class VariantDataManager { numAdded++; } } - logger.info( "Additionally training with worst " + String.format("%.3f", (float) bottomPercentage * 100.0f) + "% of passing data --> " + (trainingData.size() - numBadSitesAdded) + " variants with LOD <= " + String.format("%.4f", data.get(index).lod) + "." ); + logger.info( "Additionally training with worst " + numToAdd + " scoring variants --> " + (trainingData.size() - numBadSitesAdded) + " variants with LOD <= " + String.format("%.4f", data.get(Math.min(index, data.size() - 1)).lod) + "." ); /* numToAdd is an absolute count, not a percentage; index can equal data.size() after the loop, hence the clamp */ return trainingData; } public ExpandingArrayList getRandomDataForPlotting( int numToAdd ) { numToAdd = Math.min(numToAdd, data.size()); final ExpandingArrayList returnData = new ExpandingArrayList<>(); + // add numToAdd non-anti training sites to plot for( int iii = 0; iii < numToAdd; iii++) { final VariantDatum datum = data.get(GenomeAnalysisEngine.getRandomGenerator().nextInt(data.size())); - if( !datum.failingSTDThreshold ) { + if( ! 
datum.atAntiTrainingSite && !datum.failingSTDThreshold ) { returnData.add(datum); } } - // Add an extra 5% of points from bad training set, since that set is small but interesting - for( int iii = 0; iii < Math.floor(0.05*numToAdd); iii++) { - final VariantDatum datum = data.get(GenomeAnalysisEngine.getRandomGenerator().nextInt(data.size())); - if( datum.atAntiTrainingSite && !datum.failingSTDThreshold ) { returnData.add(datum); } - else { iii--; } + final int MAX_ANTI_TRAINING_SITES = 10000; + int nAntiTrainingAdded = 0; + // Add anti-training sites to the plot, capped at MAX_ANTI_TRAINING_SITES + for( final VariantDatum datum : data ) { + if ( nAntiTrainingAdded >= MAX_ANTI_TRAINING_SITES ) + break; + else if ( datum.atAntiTrainingSite ) { + returnData.add(datum); + nAntiTrainingAdded++; + } } return returnData; diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java index 6813a0ed4..496d4fd2b 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java @@ -326,25 +326,16 @@ public class VariantRecalibrator extends RodWalker negativeTrainingData = dataManager.selectWorstVariants( VRAC.PERCENT_BAD_VARIANTS, VRAC.MIN_NUM_BAD_VARIANTS ); - GaussianMixtureModel badModel = engine.generateModel( negativeTrainingData ); + final ExpandingArrayList negativeTrainingData = dataManager.selectWorstVariants( VRAC.NUM_BAD_VARIANTS ); + final GaussianMixtureModel badModel = engine.generateModel( negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS)); engine.evaluateData( dataManager.getData(), badModel, true ); - // Detect if the negative model failed to converge because of too few points and/or too many Gaussians and try again - while( 
badModel.failedToConverge && VRAC.MAX_GAUSSIANS > 4 ) { - logger.info("Negative model failed to converge. Retrying..."); - VRAC.MAX_GAUSSIANS--; - badModel = engine.generateModel( negativeTrainingData ); - engine.evaluateData( dataManager.getData(), goodModel, false ); - engine.evaluateData( dataManager.getData(), badModel, true ); - } - if( badModel.failedToConverge || goodModel.failedToConverge ) { - throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider raising the number of variants used to train the negative model (via --percentBadVariants 0.05, for example) or lowering the maximum number of Gaussians to use in the model (via --maxGaussians 4, for example)"); + throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider raising the number of variants used to train the negative model (via --numBadVariants, for example) or lowering the maximum number of Gaussians to use in the model (via --maxGaussians 4, for example)"); } engine.calculateWorstPerformingAnnotation( dataManager.getData(), goodModel, badModel ); diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibratorArgumentCollection.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibratorArgumentCollection.java index ae0b4a347..b376874fc 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibratorArgumentCollection.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibratorArgumentCollection.java @@ -46,6 +46,7 @@ package org.broadinstitute.sting.gatk.walkers.variantrecalibration; +import org.broadinstitute.sting.commandline.Advanced; import org.broadinstitute.sting.commandline.Argument; import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; @@ -72,8 
+73,13 @@ public class VariantRecalibratorArgumentCollection { @Argument(fullName = "mode", shortName = "mode", doc = "Recalibration mode to employ: 1.) SNP for recalibrating only SNPs (emitting indels untouched in the output VCF); 2.) INDEL for indels (emitting SNPs untouched in the output VCF); and 3.) BOTH for recalibrating both SNPs and indels simultaneously (for testing purposes only, not recommended for general use).", required = false) public VariantRecalibratorArgumentCollection.Mode MODE = VariantRecalibratorArgumentCollection.Mode.SNP; - @Argument(fullName="maxGaussians", shortName="mG", doc="The maximum number of Gaussians to try during variational Bayes algorithm.", required=false) + @Argument(fullName="maxGaussians", shortName="mG", doc="The maximum number of Gaussians for the positive model to try during variational Bayes algorithm.", required=false) public int MAX_GAUSSIANS = 10; + + @Advanced + @Argument(fullName="maxNegativeGaussians", shortName="mNG", doc="The maximum number of Gaussians for the negative model to try during variational Bayes algorithm. The actual maximum used is the min of the mG and mNG arguments. Note that this number should be small (like 4) to achieve the best results", required=false) + public int MAX_GAUSSIANS_FOR_NEGATIVE_MODEL = 4; + @Argument(fullName="maxIterations", shortName="mI", doc="The maximum number of VBEM iterations to be performed in variational Bayes algorithm. 
Procedure will normally end when convergence is detected.", required=false) public int MAX_ITERATIONS = 100; @Argument(fullName="numKMeans", shortName="nKM", doc="The number of k-means iterations to perform in order to initialize the means of the Gaussians in the Gaussian mixture model.", required=false) @@ -88,8 +94,6 @@ public class VariantRecalibratorArgumentCollection { public double DIRICHLET_PARAMETER = 0.001; @Argument(fullName="priorCounts", shortName="priorCounts", doc="The number of prior counts to use in the variational Bayes algorithm.", required=false) public double PRIOR_COUNTS = 20.0; - @Argument(fullName="percentBadVariants", shortName="percentBad", doc="What percentage of the worst scoring variants to use when building the Gaussian mixture model of bad variants. 0.07 means bottom 7 percent.", required=false) - public double PERCENT_BAD_VARIANTS = 0.03; - @Argument(fullName="minNumBadVariants", shortName="minNumBad", doc="The minimum amount of worst scoring variants to use when building the Gaussian mixture model of bad variants. Will override -percentBad argument if necessary.", required=false) - public int MIN_NUM_BAD_VARIANTS = 2500; + @Argument(fullName="numBadVariants", shortName="numBad", doc="The number of worst scoring variants to use when building the Gaussian mixture model of bad variants. 
This fixed count replaces the former -percentBad and -minNumBad arguments.", required=false) + public int NUM_BAD_VARIANTS = 1000; } diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java index 6cebc82c1..ffde46394 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java @@ -79,8 +79,8 @@ public class VariantRecalibratorEngine { this.VRAC = VRAC; } - public GaussianMixtureModel generateModel( final List data ) { - final GaussianMixtureModel model = new GaussianMixtureModel( VRAC.MAX_GAUSSIANS, data.get(0).annotations.length, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER, VRAC.PRIOR_COUNTS ); + public GaussianMixtureModel generateModel( final List data, final int maxGaussians ) { + final GaussianMixtureModel model = new GaussianMixtureModel( maxGaussians, data.get(0).annotations.length, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER, VRAC.PRIOR_COUNTS ); variationalBayesExpectationMaximization( model, data ); return model; } diff --git a/protected/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersIntegrationTest.java b/protected/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersIntegrationTest.java index b5a541d53..3a6981bab 100644 --- a/protected/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersIntegrationTest.java +++ b/protected/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersIntegrationTest.java @@ -72,9 +72,9 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { } VRTest lowPass = new VRTest(validationDataLocation + 
"phase1.projectConsensus.chr20.raw.snps.vcf", - "583e8f63475dfd09a26bf11579075c8e", // tranches - "39a98f13b26c8c1f363f99ab8cead6ca", // recal file - "d235aefef741a6b2c352ef20af1ca790"); // cut VCF + "0f4ceeeb8e4a3c89f8591d5e531d8410", // tranches + "c979a102669498ef40dde47ca4133c42", // recal file + "8f60fd849537610b653b321869e94641"); // cut VCF @DataProvider(name = "VRTest") public Object[][] createData1() { @@ -95,8 +95,6 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { " -L 20:1,000,000-40,000,000" + " --no_cmdline_in_header" + " -an QD -an HaplotypeScore -an HRun" + - " -percentBad 0.07" + - " --minNumBadVariants 0" + " --trustAllPolymorphic" + // for speed " -recalFile %s" + " -tranchesFile %s", @@ -121,9 +119,9 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { } VRTest bcfTest = new VRTest(privateTestDir + "vqsr.bcf_test.snps.unfiltered.bcf", - "d29356849670aabcc12643a2b68dcc82", // tranches - "8abaf8142a6ee212b6dddc7053605512", // recal file - "d6cd4f61875ae09a030fd9f2d7328246"); // cut VCF + "6539e025997579cd0c7da12219cbc572", // tranches + "778e61f81ab3d468b75f684bef0478e5", // recal file + "21e96b0bb47e2976f53f11181f920e51"); // cut VCF @DataProvider(name = "VRBCFTest") public Object[][] createVRBCFTest() { @@ -173,15 +171,15 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { VRTest indelUnfiltered = new VRTest( validationDataLocation + "combined.phase1.chr20.raw.indels.unfiltered.sites.vcf", // all FILTERs as . 
- "99c3736dab836ae8b41e344062e01b5a", // tranches - "55d2f89980ea9c6c469314129dbac732", // recal file - "482039de04961876890e125055732450"); // cut VCF + "8906fdae8beca712f5ff2808d35ef02d", // tranches + "07ffea25e04f6ef53079bccb30bd6a7b", // recal file + "8b3ef71cad71e8eb48a856a27ae4f8d5"); // cut VCF VRTest indelFiltered = new VRTest( validationDataLocation + "combined.phase1.chr20.raw.indels.filtered.sites.vcf", // all FILTERs as PASS - "99c3736dab836ae8b41e344062e01b5a", // tranches - "55d2f89980ea9c6c469314129dbac732", // recal file - "e63e22ae05ad0bd32b943cde00b6e5a9"); // cut VCF + "8906fdae8beca712f5ff2808d35ef02d", // tranches + "07ffea25e04f6ef53079bccb30bd6a7b", // recal file + "3d69b280370cdd9611695e4893591306"); // cut VCF @DataProvider(name = "VRIndelTest") public Object[][] createTestVariantRecalibratorIndel() { @@ -200,9 +198,7 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { " -L 20:1,000,000-40,000,000" + " --no_cmdline_in_header" + " -an QD -an ReadPosRankSum -an HaplotypeScore" + - " -percentBad 0.08" + " -mode INDEL -mG 3" + - " --minNumBadVariants 0" + " --trustAllPolymorphic" + // for speed " -recalFile %s" + " -tranchesFile %s",