Merge pull request #364 from broadinstitute/md_vqsr_improvements
Separate num Gaussians for + and - GMM in VQSR
This commit is contained in:
commit
00f4d767e4
|
|
@ -241,13 +241,13 @@ public class VariantDataManager {
|
|||
}
|
||||
}
|
||||
logger.info( "Training with " + trainingData.size() + " variants after standard deviation thresholding." );
|
||||
if( trainingData.size() < VRAC.MIN_NUM_BAD_VARIANTS ) {
|
||||
if( trainingData.size() < VRAC.NUM_BAD_VARIANTS) {
|
||||
logger.warn( "WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable." );
|
||||
}
|
||||
return trainingData;
|
||||
}
|
||||
|
||||
public ExpandingArrayList<VariantDatum> selectWorstVariants( double bottomPercentage, final int minimumNumber ) {
|
||||
public ExpandingArrayList<VariantDatum> selectWorstVariants( final int minimumNumber ) {
|
||||
// The return value is the list of training variants
|
||||
final ExpandingArrayList<VariantDatum> trainingData = new ExpandingArrayList<>();
|
||||
|
||||
|
|
@ -262,12 +262,9 @@ public class VariantDataManager {
|
|||
|
||||
// Next sort the variants by the LOD coming from the positive model and add to the list the bottom X percent of variants
|
||||
Collections.sort( data, new VariantDatum.VariantDatumLODComparator() );
|
||||
final int numToAdd = Math.max( minimumNumber - trainingData.size(), Math.round((float)bottomPercentage * data.size()) );
|
||||
final int numToAdd = minimumNumber - trainingData.size();
|
||||
if( numToAdd > data.size() ) {
|
||||
throw new UserException.BadInput( "Error during negative model training. Minimum number of variants to use in training is larger than the whole call set. One can attempt to lower the --minNumBadVariants argument but this is unsafe." );
|
||||
} else if( numToAdd == minimumNumber - trainingData.size() ) {
|
||||
logger.warn( "WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable." );
|
||||
bottomPercentage = ((float) numToAdd) / ((float) data.size());
|
||||
}
|
||||
int index = 0, numAdded = 0;
|
||||
while( numAdded < numToAdd && index < data.size() ) {
|
||||
|
|
@ -278,25 +275,31 @@ public class VariantDataManager {
|
|||
numAdded++;
|
||||
}
|
||||
}
|
||||
logger.info( "Additionally training with worst " + String.format("%.3f", (float) bottomPercentage * 100.0f) + "% of passing data --> " + (trainingData.size() - numBadSitesAdded) + " variants with LOD <= " + String.format("%.4f", data.get(index).lod) + "." );
|
||||
logger.info( "Additionally training with worst " + numToAdd + "% of passing data --> " + (trainingData.size() - numBadSitesAdded) + " variants with LOD <= " + String.format("%.4f", data.get(index).lod) + "." );
|
||||
return trainingData;
|
||||
}
|
||||
|
||||
public ExpandingArrayList<VariantDatum> getRandomDataForPlotting( int numToAdd ) {
|
||||
numToAdd = Math.min(numToAdd, data.size());
|
||||
final ExpandingArrayList<VariantDatum> returnData = new ExpandingArrayList<>();
|
||||
// add numToAdd non-anti training sites to plot
|
||||
for( int iii = 0; iii < numToAdd; iii++) {
|
||||
final VariantDatum datum = data.get(GenomeAnalysisEngine.getRandomGenerator().nextInt(data.size()));
|
||||
if( !datum.failingSTDThreshold ) {
|
||||
if( ! datum.atAntiTrainingSite && !datum.failingSTDThreshold ) {
|
||||
returnData.add(datum);
|
||||
}
|
||||
}
|
||||
|
||||
// Add an extra 5% of points from bad training set, since that set is small but interesting
|
||||
for( int iii = 0; iii < Math.floor(0.05*numToAdd); iii++) {
|
||||
final VariantDatum datum = data.get(GenomeAnalysisEngine.getRandomGenerator().nextInt(data.size()));
|
||||
if( datum.atAntiTrainingSite && !datum.failingSTDThreshold ) { returnData.add(datum); }
|
||||
else { iii--; }
|
||||
final int MAX_ANTI_TRAINING_SITES = 10000;
|
||||
int nAntiTrainingAdded = 0;
|
||||
// Add all anti-training sites to visual
|
||||
for( final VariantDatum datum : data ) {
|
||||
if ( nAntiTrainingAdded > MAX_ANTI_TRAINING_SITES )
|
||||
break;
|
||||
else if ( datum.atAntiTrainingSite ) {
|
||||
returnData.add(datum);
|
||||
nAntiTrainingAdded++;
|
||||
}
|
||||
}
|
||||
|
||||
return returnData;
|
||||
|
|
|
|||
|
|
@ -326,25 +326,16 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation
|
||||
|
||||
// Generate the positive model using the training data and evaluate each variant
|
||||
final GaussianMixtureModel goodModel = engine.generateModel( dataManager.getTrainingData() );
|
||||
final GaussianMixtureModel goodModel = engine.generateModel( dataManager.getTrainingData(), VRAC.MAX_GAUSSIANS );
|
||||
engine.evaluateData( dataManager.getData(), goodModel, false );
|
||||
|
||||
// Generate the negative model using the worst performing data and evaluate each variant contrastively
|
||||
final ExpandingArrayList<VariantDatum> negativeTrainingData = dataManager.selectWorstVariants( VRAC.PERCENT_BAD_VARIANTS, VRAC.MIN_NUM_BAD_VARIANTS );
|
||||
GaussianMixtureModel badModel = engine.generateModel( negativeTrainingData );
|
||||
final ExpandingArrayList<VariantDatum> negativeTrainingData = dataManager.selectWorstVariants( VRAC.NUM_BAD_VARIANTS );
|
||||
final GaussianMixtureModel badModel = engine.generateModel( negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS));
|
||||
engine.evaluateData( dataManager.getData(), badModel, true );
|
||||
|
||||
// Detect if the negative model failed to converge because of too few points and/or too many Gaussians and try again
|
||||
while( badModel.failedToConverge && VRAC.MAX_GAUSSIANS > 4 ) {
|
||||
logger.info("Negative model failed to converge. Retrying...");
|
||||
VRAC.MAX_GAUSSIANS--;
|
||||
badModel = engine.generateModel( negativeTrainingData );
|
||||
engine.evaluateData( dataManager.getData(), goodModel, false );
|
||||
engine.evaluateData( dataManager.getData(), badModel, true );
|
||||
}
|
||||
|
||||
if( badModel.failedToConverge || goodModel.failedToConverge ) {
|
||||
throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider raising the number of variants used to train the negative model (via --percentBadVariants 0.05, for example) or lowering the maximum number of Gaussians to use in the model (via --maxGaussians 4, for example)");
|
||||
throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider raising the number of variants used to train the negative model (via --numBadVariants, for example) or lowering the maximum number of Gaussians to use in the model (via --maxGaussians 4, for example)");
|
||||
}
|
||||
|
||||
engine.calculateWorstPerformingAnnotation( dataManager.getData(), goodModel, badModel );
|
||||
|
|
|
|||
|
|
@ -46,6 +46,7 @@
|
|||
|
||||
package org.broadinstitute.sting.gatk.walkers.variantrecalibration;
|
||||
|
||||
import org.broadinstitute.sting.commandline.Advanced;
|
||||
import org.broadinstitute.sting.commandline.Argument;
|
||||
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
|
||||
|
||||
|
|
@ -72,8 +73,13 @@ public class VariantRecalibratorArgumentCollection {
|
|||
|
||||
@Argument(fullName = "mode", shortName = "mode", doc = "Recalibration mode to employ: 1.) SNP for recalibrating only SNPs (emitting indels untouched in the output VCF); 2.) INDEL for indels (emitting SNPs untouched in the output VCF); and 3.) BOTH for recalibrating both SNPs and indels simultaneously (for testing purposes only, not recommended for general use).", required = false)
|
||||
public VariantRecalibratorArgumentCollection.Mode MODE = VariantRecalibratorArgumentCollection.Mode.SNP;
|
||||
@Argument(fullName="maxGaussians", shortName="mG", doc="The maximum number of Gaussians to try during variational Bayes algorithm.", required=false)
|
||||
@Argument(fullName="maxGaussians", shortName="mG", doc="The maximum number of Gaussians for the positive model to try during variational Bayes algorithm.", required=false)
|
||||
public int MAX_GAUSSIANS = 10;
|
||||
|
||||
@Advanced
|
||||
@Argument(fullName="maxNegativeGaussians", shortName="mNG", doc="The maximum number of Gaussians for the negative model to try during variational Bayes algorithm. The actual maximum used is the min of the mG and mNG arguments. Note that this number should be small (like 4) to achieve the best results.", required=false)
|
||||
public int MAX_GAUSSIANS_FOR_NEGATIVE_MODEL = 4;
|
||||
|
||||
@Argument(fullName="maxIterations", shortName="mI", doc="The maximum number of VBEM iterations to be performed in variational Bayes algorithm. Procedure will normally end when convergence is detected.", required=false)
|
||||
public int MAX_ITERATIONS = 100;
|
||||
@Argument(fullName="numKMeans", shortName="nKM", doc="The number of k-means iterations to perform in order to initialize the means of the Gaussians in the Gaussian mixture model.", required=false)
|
||||
|
|
@ -88,8 +94,6 @@ public class VariantRecalibratorArgumentCollection {
|
|||
public double DIRICHLET_PARAMETER = 0.001;
|
||||
@Argument(fullName="priorCounts", shortName="priorCounts", doc="The number of prior counts to use in the variational Bayes algorithm.", required=false)
|
||||
public double PRIOR_COUNTS = 20.0;
|
||||
@Argument(fullName="percentBadVariants", shortName="percentBad", doc="What percentage of the worst scoring variants to use when building the Gaussian mixture model of bad variants. 0.07 means bottom 7 percent.", required=false)
|
||||
public double PERCENT_BAD_VARIANTS = 0.03;
|
||||
@Argument(fullName="minNumBadVariants", shortName="minNumBad", doc="The minimum amount of worst scoring variants to use when building the Gaussian mixture model of bad variants. Will override -percentBad argument if necessary.", required=false)
|
||||
public int MIN_NUM_BAD_VARIANTS = 2500;
|
||||
@Argument(fullName="numBadVariants", shortName="numBad", doc="The number of worst scoring variants to use when building the Gaussian mixture model of bad variants. Will override -percentBad argument if necessary.", required=false)
|
||||
public int NUM_BAD_VARIANTS = 1000;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -79,8 +79,8 @@ public class VariantRecalibratorEngine {
|
|||
this.VRAC = VRAC;
|
||||
}
|
||||
|
||||
public GaussianMixtureModel generateModel( final List<VariantDatum> data ) {
|
||||
final GaussianMixtureModel model = new GaussianMixtureModel( VRAC.MAX_GAUSSIANS, data.get(0).annotations.length, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER, VRAC.PRIOR_COUNTS );
|
||||
public GaussianMixtureModel generateModel( final List<VariantDatum> data, final int maxGaussians ) {
|
||||
final GaussianMixtureModel model = new GaussianMixtureModel( maxGaussians, data.get(0).annotations.length, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER, VRAC.PRIOR_COUNTS );
|
||||
variationalBayesExpectationMaximization( model, data );
|
||||
return model;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -72,9 +72,9 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
|
|||
}
|
||||
|
||||
VRTest lowPass = new VRTest(validationDataLocation + "phase1.projectConsensus.chr20.raw.snps.vcf",
|
||||
"583e8f63475dfd09a26bf11579075c8e", // tranches
|
||||
"39a98f13b26c8c1f363f99ab8cead6ca", // recal file
|
||||
"d235aefef741a6b2c352ef20af1ca790"); // cut VCF
|
||||
"0f4ceeeb8e4a3c89f8591d5e531d8410", // tranches
|
||||
"c979a102669498ef40dde47ca4133c42", // recal file
|
||||
"8f60fd849537610b653b321869e94641"); // cut VCF
|
||||
|
||||
@DataProvider(name = "VRTest")
|
||||
public Object[][] createData1() {
|
||||
|
|
@ -95,8 +95,6 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
|
|||
" -L 20:1,000,000-40,000,000" +
|
||||
" --no_cmdline_in_header" +
|
||||
" -an QD -an HaplotypeScore -an HRun" +
|
||||
" -percentBad 0.07" +
|
||||
" --minNumBadVariants 0" +
|
||||
" --trustAllPolymorphic" + // for speed
|
||||
" -recalFile %s" +
|
||||
" -tranchesFile %s",
|
||||
|
|
@ -121,9 +119,9 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
|
|||
}
|
||||
|
||||
VRTest bcfTest = new VRTest(privateTestDir + "vqsr.bcf_test.snps.unfiltered.bcf",
|
||||
"d29356849670aabcc12643a2b68dcc82", // tranches
|
||||
"8abaf8142a6ee212b6dddc7053605512", // recal file
|
||||
"d6cd4f61875ae09a030fd9f2d7328246"); // cut VCF
|
||||
"6539e025997579cd0c7da12219cbc572", // tranches
|
||||
"778e61f81ab3d468b75f684bef0478e5", // recal file
|
||||
"21e96b0bb47e2976f53f11181f920e51"); // cut VCF
|
||||
|
||||
@DataProvider(name = "VRBCFTest")
|
||||
public Object[][] createVRBCFTest() {
|
||||
|
|
@ -173,15 +171,15 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
|
|||
|
||||
VRTest indelUnfiltered = new VRTest(
|
||||
validationDataLocation + "combined.phase1.chr20.raw.indels.unfiltered.sites.vcf", // all FILTERs as .
|
||||
"99c3736dab836ae8b41e344062e01b5a", // tranches
|
||||
"55d2f89980ea9c6c469314129dbac732", // recal file
|
||||
"482039de04961876890e125055732450"); // cut VCF
|
||||
"8906fdae8beca712f5ff2808d35ef02d", // tranches
|
||||
"07ffea25e04f6ef53079bccb30bd6a7b", // recal file
|
||||
"8b3ef71cad71e8eb48a856a27ae4f8d5"); // cut VCF
|
||||
|
||||
VRTest indelFiltered = new VRTest(
|
||||
validationDataLocation + "combined.phase1.chr20.raw.indels.filtered.sites.vcf", // all FILTERs as PASS
|
||||
"99c3736dab836ae8b41e344062e01b5a", // tranches
|
||||
"55d2f89980ea9c6c469314129dbac732", // recal file
|
||||
"e63e22ae05ad0bd32b943cde00b6e5a9"); // cut VCF
|
||||
"8906fdae8beca712f5ff2808d35ef02d", // tranches
|
||||
"07ffea25e04f6ef53079bccb30bd6a7b", // recal file
|
||||
"3d69b280370cdd9611695e4893591306"); // cut VCF
|
||||
|
||||
@DataProvider(name = "VRIndelTest")
|
||||
public Object[][] createTestVariantRecalibratorIndel() {
|
||||
|
|
@ -200,9 +198,7 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
|
|||
" -L 20:1,000,000-40,000,000" +
|
||||
" --no_cmdline_in_header" +
|
||||
" -an QD -an ReadPosRankSum -an HaplotypeScore" +
|
||||
" -percentBad 0.08" +
|
||||
" -mode INDEL -mG 3" +
|
||||
" --minNumBadVariants 0" +
|
||||
" --trustAllPolymorphic" + // for speed
|
||||
" -recalFile %s" +
|
||||
" -tranchesFile %s",
|
||||
|
|
|
|||
Loading…
Reference in New Issue