Merge pull request #364 from broadinstitute/md_vqsr_improvements
Separate num Gaussians for + and - GMM in VQSR
This commit is contained in:
commit
00f4d767e4
|
|
@ -241,13 +241,13 @@ public class VariantDataManager {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
logger.info( "Training with " + trainingData.size() + " variants after standard deviation thresholding." );
|
logger.info( "Training with " + trainingData.size() + " variants after standard deviation thresholding." );
|
||||||
if( trainingData.size() < VRAC.MIN_NUM_BAD_VARIANTS ) {
|
if( trainingData.size() < VRAC.NUM_BAD_VARIANTS) {
|
||||||
logger.warn( "WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable." );
|
logger.warn( "WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable." );
|
||||||
}
|
}
|
||||||
return trainingData;
|
return trainingData;
|
||||||
}
|
}
|
||||||
|
|
||||||
public ExpandingArrayList<VariantDatum> selectWorstVariants( double bottomPercentage, final int minimumNumber ) {
|
public ExpandingArrayList<VariantDatum> selectWorstVariants( final int minimumNumber ) {
|
||||||
// The return value is the list of training variants
|
// The return value is the list of training variants
|
||||||
final ExpandingArrayList<VariantDatum> trainingData = new ExpandingArrayList<>();
|
final ExpandingArrayList<VariantDatum> trainingData = new ExpandingArrayList<>();
|
||||||
|
|
||||||
|
|
@ -262,12 +262,9 @@ public class VariantDataManager {
|
||||||
|
|
||||||
// Next sort the variants by the LOD coming from the positive model and add to the list the bottom X percent of variants
|
// Next sort the variants by the LOD coming from the positive model and add to the list the bottom X percent of variants
|
||||||
Collections.sort( data, new VariantDatum.VariantDatumLODComparator() );
|
Collections.sort( data, new VariantDatum.VariantDatumLODComparator() );
|
||||||
final int numToAdd = Math.max( minimumNumber - trainingData.size(), Math.round((float)bottomPercentage * data.size()) );
|
final int numToAdd = minimumNumber - trainingData.size();
|
||||||
if( numToAdd > data.size() ) {
|
if( numToAdd > data.size() ) {
|
||||||
throw new UserException.BadInput( "Error during negative model training. Minimum number of variants to use in training is larger than the whole call set. One can attempt to lower the --minNumBadVariants arugment but this is unsafe." );
|
throw new UserException.BadInput( "Error during negative model training. Minimum number of variants to use in training is larger than the whole call set. One can attempt to lower the --minNumBadVariants arugment but this is unsafe." );
|
||||||
} else if( numToAdd == minimumNumber - trainingData.size() ) {
|
|
||||||
logger.warn( "WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable." );
|
|
||||||
bottomPercentage = ((float) numToAdd) / ((float) data.size());
|
|
||||||
}
|
}
|
||||||
int index = 0, numAdded = 0;
|
int index = 0, numAdded = 0;
|
||||||
while( numAdded < numToAdd && index < data.size() ) {
|
while( numAdded < numToAdd && index < data.size() ) {
|
||||||
|
|
@ -278,25 +275,31 @@ public class VariantDataManager {
|
||||||
numAdded++;
|
numAdded++;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
logger.info( "Additionally training with worst " + String.format("%.3f", (float) bottomPercentage * 100.0f) + "% of passing data --> " + (trainingData.size() - numBadSitesAdded) + " variants with LOD <= " + String.format("%.4f", data.get(index).lod) + "." );
|
logger.info( "Additionally training with worst " + numToAdd + "% of passing data --> " + (trainingData.size() - numBadSitesAdded) + " variants with LOD <= " + String.format("%.4f", data.get(index).lod) + "." );
|
||||||
return trainingData;
|
return trainingData;
|
||||||
}
|
}
|
||||||
|
|
||||||
public ExpandingArrayList<VariantDatum> getRandomDataForPlotting( int numToAdd ) {
|
public ExpandingArrayList<VariantDatum> getRandomDataForPlotting( int numToAdd ) {
|
||||||
numToAdd = Math.min(numToAdd, data.size());
|
numToAdd = Math.min(numToAdd, data.size());
|
||||||
final ExpandingArrayList<VariantDatum> returnData = new ExpandingArrayList<>();
|
final ExpandingArrayList<VariantDatum> returnData = new ExpandingArrayList<>();
|
||||||
|
// add numToAdd non-anti training sites to plot
|
||||||
for( int iii = 0; iii < numToAdd; iii++) {
|
for( int iii = 0; iii < numToAdd; iii++) {
|
||||||
final VariantDatum datum = data.get(GenomeAnalysisEngine.getRandomGenerator().nextInt(data.size()));
|
final VariantDatum datum = data.get(GenomeAnalysisEngine.getRandomGenerator().nextInt(data.size()));
|
||||||
if( !datum.failingSTDThreshold ) {
|
if( ! datum.atAntiTrainingSite && !datum.failingSTDThreshold ) {
|
||||||
returnData.add(datum);
|
returnData.add(datum);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add an extra 5% of points from bad training set, since that set is small but interesting
|
final int MAX_ANTI_TRAINING_SITES = 10000;
|
||||||
for( int iii = 0; iii < Math.floor(0.05*numToAdd); iii++) {
|
int nAntiTrainingAdded = 0;
|
||||||
final VariantDatum datum = data.get(GenomeAnalysisEngine.getRandomGenerator().nextInt(data.size()));
|
// Add all anti-training sites to visual
|
||||||
if( datum.atAntiTrainingSite && !datum.failingSTDThreshold ) { returnData.add(datum); }
|
for( final VariantDatum datum : data ) {
|
||||||
else { iii--; }
|
if ( nAntiTrainingAdded > MAX_ANTI_TRAINING_SITES )
|
||||||
|
break;
|
||||||
|
else if ( datum.atAntiTrainingSite ) {
|
||||||
|
returnData.add(datum);
|
||||||
|
nAntiTrainingAdded++;
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return returnData;
|
return returnData;
|
||||||
|
|
|
||||||
|
|
@ -326,25 +326,16 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation
|
dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation
|
||||||
|
|
||||||
// Generate the positive model using the training data and evaluate each variant
|
// Generate the positive model using the training data and evaluate each variant
|
||||||
final GaussianMixtureModel goodModel = engine.generateModel( dataManager.getTrainingData() );
|
final GaussianMixtureModel goodModel = engine.generateModel( dataManager.getTrainingData(), VRAC.MAX_GAUSSIANS );
|
||||||
engine.evaluateData( dataManager.getData(), goodModel, false );
|
engine.evaluateData( dataManager.getData(), goodModel, false );
|
||||||
|
|
||||||
// Generate the negative model using the worst performing data and evaluate each variant contrastively
|
// Generate the negative model using the worst performing data and evaluate each variant contrastively
|
||||||
final ExpandingArrayList<VariantDatum> negativeTrainingData = dataManager.selectWorstVariants( VRAC.PERCENT_BAD_VARIANTS, VRAC.MIN_NUM_BAD_VARIANTS );
|
final ExpandingArrayList<VariantDatum> negativeTrainingData = dataManager.selectWorstVariants( VRAC.NUM_BAD_VARIANTS );
|
||||||
GaussianMixtureModel badModel = engine.generateModel( negativeTrainingData );
|
final GaussianMixtureModel badModel = engine.generateModel( negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS));
|
||||||
engine.evaluateData( dataManager.getData(), badModel, true );
|
engine.evaluateData( dataManager.getData(), badModel, true );
|
||||||
|
|
||||||
// Detect if the negative model failed to converge because of too few points and/or too many Gaussians and try again
|
|
||||||
while( badModel.failedToConverge && VRAC.MAX_GAUSSIANS > 4 ) {
|
|
||||||
logger.info("Negative model failed to converge. Retrying...");
|
|
||||||
VRAC.MAX_GAUSSIANS--;
|
|
||||||
badModel = engine.generateModel( negativeTrainingData );
|
|
||||||
engine.evaluateData( dataManager.getData(), goodModel, false );
|
|
||||||
engine.evaluateData( dataManager.getData(), badModel, true );
|
|
||||||
}
|
|
||||||
|
|
||||||
if( badModel.failedToConverge || goodModel.failedToConverge ) {
|
if( badModel.failedToConverge || goodModel.failedToConverge ) {
|
||||||
throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider raising the number of variants used to train the negative model (via --percentBadVariants 0.05, for example) or lowering the maximum number of Gaussians to use in the model (via --maxGaussians 4, for example)");
|
throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider raising the number of variants used to train the negative model (via --minNumBad, for example) or lowering the maximum number of Gaussians to use in the model (via --maxGaussians 4, for example)");
|
||||||
}
|
}
|
||||||
|
|
||||||
engine.calculateWorstPerformingAnnotation( dataManager.getData(), goodModel, badModel );
|
engine.calculateWorstPerformingAnnotation( dataManager.getData(), goodModel, badModel );
|
||||||
|
|
|
||||||
|
|
@ -46,6 +46,7 @@
|
||||||
|
|
||||||
package org.broadinstitute.sting.gatk.walkers.variantrecalibration;
|
package org.broadinstitute.sting.gatk.walkers.variantrecalibration;
|
||||||
|
|
||||||
|
import org.broadinstitute.sting.commandline.Advanced;
|
||||||
import org.broadinstitute.sting.commandline.Argument;
|
import org.broadinstitute.sting.commandline.Argument;
|
||||||
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
|
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
|
||||||
|
|
||||||
|
|
@ -72,8 +73,13 @@ public class VariantRecalibratorArgumentCollection {
|
||||||
|
|
||||||
@Argument(fullName = "mode", shortName = "mode", doc = "Recalibration mode to employ: 1.) SNP for recalibrating only SNPs (emitting indels untouched in the output VCF); 2.) INDEL for indels (emitting SNPs untouched in the output VCF); and 3.) BOTH for recalibrating both SNPs and indels simultaneously (for testing purposes only, not recommended for general use).", required = false)
|
@Argument(fullName = "mode", shortName = "mode", doc = "Recalibration mode to employ: 1.) SNP for recalibrating only SNPs (emitting indels untouched in the output VCF); 2.) INDEL for indels (emitting SNPs untouched in the output VCF); and 3.) BOTH for recalibrating both SNPs and indels simultaneously (for testing purposes only, not recommended for general use).", required = false)
|
||||||
public VariantRecalibratorArgumentCollection.Mode MODE = VariantRecalibratorArgumentCollection.Mode.SNP;
|
public VariantRecalibratorArgumentCollection.Mode MODE = VariantRecalibratorArgumentCollection.Mode.SNP;
|
||||||
@Argument(fullName="maxGaussians", shortName="mG", doc="The maximum number of Gaussians to try during variational Bayes algorithm.", required=false)
|
@Argument(fullName="maxGaussians", shortName="mG", doc="The maximum number of Gaussians for the positive model to try during variational Bayes algorithm.", required=false)
|
||||||
public int MAX_GAUSSIANS = 10;
|
public int MAX_GAUSSIANS = 10;
|
||||||
|
|
||||||
|
@Advanced
|
||||||
|
@Argument(fullName="maxNegativeGaussians", shortName="mNG", doc="The maximum number of Gaussians for the negative model to try during variational Bayes algorithm. The actual maximum used is the min of the mG and mNG arguments. Note that this number should be small (like 4) to achieve the best results", required=false)
|
||||||
|
public int MAX_GAUSSIANS_FOR_NEGATIVE_MODEL = 4;
|
||||||
|
|
||||||
@Argument(fullName="maxIterations", shortName="mI", doc="The maximum number of VBEM iterations to be performed in variational Bayes algorithm. Procedure will normally end when convergence is detected.", required=false)
|
@Argument(fullName="maxIterations", shortName="mI", doc="The maximum number of VBEM iterations to be performed in variational Bayes algorithm. Procedure will normally end when convergence is detected.", required=false)
|
||||||
public int MAX_ITERATIONS = 100;
|
public int MAX_ITERATIONS = 100;
|
||||||
@Argument(fullName="numKMeans", shortName="nKM", doc="The number of k-means iterations to perform in order to initialize the means of the Gaussians in the Gaussian mixture model.", required=false)
|
@Argument(fullName="numKMeans", shortName="nKM", doc="The number of k-means iterations to perform in order to initialize the means of the Gaussians in the Gaussian mixture model.", required=false)
|
||||||
|
|
@ -88,8 +94,6 @@ public class VariantRecalibratorArgumentCollection {
|
||||||
public double DIRICHLET_PARAMETER = 0.001;
|
public double DIRICHLET_PARAMETER = 0.001;
|
||||||
@Argument(fullName="priorCounts", shortName="priorCounts", doc="The number of prior counts to use in the variational Bayes algorithm.", required=false)
|
@Argument(fullName="priorCounts", shortName="priorCounts", doc="The number of prior counts to use in the variational Bayes algorithm.", required=false)
|
||||||
public double PRIOR_COUNTS = 20.0;
|
public double PRIOR_COUNTS = 20.0;
|
||||||
@Argument(fullName="percentBadVariants", shortName="percentBad", doc="What percentage of the worst scoring variants to use when building the Gaussian mixture model of bad variants. 0.07 means bottom 7 percent.", required=false)
|
@Argument(fullName="numBadVariants", shortName="numBad", doc="The number of worst scoring variants to use when building the Gaussian mixture model of bad variants. Will override -percentBad argument if necessary.", required=false)
|
||||||
public double PERCENT_BAD_VARIANTS = 0.03;
|
public int NUM_BAD_VARIANTS = 1000;
|
||||||
@Argument(fullName="minNumBadVariants", shortName="minNumBad", doc="The minimum amount of worst scoring variants to use when building the Gaussian mixture model of bad variants. Will override -percentBad argument if necessary.", required=false)
|
|
||||||
public int MIN_NUM_BAD_VARIANTS = 2500;
|
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -79,8 +79,8 @@ public class VariantRecalibratorEngine {
|
||||||
this.VRAC = VRAC;
|
this.VRAC = VRAC;
|
||||||
}
|
}
|
||||||
|
|
||||||
public GaussianMixtureModel generateModel( final List<VariantDatum> data ) {
|
public GaussianMixtureModel generateModel( final List<VariantDatum> data, final int maxGaussians ) {
|
||||||
final GaussianMixtureModel model = new GaussianMixtureModel( VRAC.MAX_GAUSSIANS, data.get(0).annotations.length, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER, VRAC.PRIOR_COUNTS );
|
final GaussianMixtureModel model = new GaussianMixtureModel( maxGaussians, data.get(0).annotations.length, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER, VRAC.PRIOR_COUNTS );
|
||||||
variationalBayesExpectationMaximization( model, data );
|
variationalBayesExpectationMaximization( model, data );
|
||||||
return model;
|
return model;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -72,9 +72,9 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
VRTest lowPass = new VRTest(validationDataLocation + "phase1.projectConsensus.chr20.raw.snps.vcf",
|
VRTest lowPass = new VRTest(validationDataLocation + "phase1.projectConsensus.chr20.raw.snps.vcf",
|
||||||
"583e8f63475dfd09a26bf11579075c8e", // tranches
|
"0f4ceeeb8e4a3c89f8591d5e531d8410", // tranches
|
||||||
"39a98f13b26c8c1f363f99ab8cead6ca", // recal file
|
"c979a102669498ef40dde47ca4133c42", // recal file
|
||||||
"d235aefef741a6b2c352ef20af1ca790"); // cut VCF
|
"8f60fd849537610b653b321869e94641"); // cut VCF
|
||||||
|
|
||||||
@DataProvider(name = "VRTest")
|
@DataProvider(name = "VRTest")
|
||||||
public Object[][] createData1() {
|
public Object[][] createData1() {
|
||||||
|
|
@ -95,8 +95,6 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
|
||||||
" -L 20:1,000,000-40,000,000" +
|
" -L 20:1,000,000-40,000,000" +
|
||||||
" --no_cmdline_in_header" +
|
" --no_cmdline_in_header" +
|
||||||
" -an QD -an HaplotypeScore -an HRun" +
|
" -an QD -an HaplotypeScore -an HRun" +
|
||||||
" -percentBad 0.07" +
|
|
||||||
" --minNumBadVariants 0" +
|
|
||||||
" --trustAllPolymorphic" + // for speed
|
" --trustAllPolymorphic" + // for speed
|
||||||
" -recalFile %s" +
|
" -recalFile %s" +
|
||||||
" -tranchesFile %s",
|
" -tranchesFile %s",
|
||||||
|
|
@ -121,9 +119,9 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
|
||||||
}
|
}
|
||||||
|
|
||||||
VRTest bcfTest = new VRTest(privateTestDir + "vqsr.bcf_test.snps.unfiltered.bcf",
|
VRTest bcfTest = new VRTest(privateTestDir + "vqsr.bcf_test.snps.unfiltered.bcf",
|
||||||
"d29356849670aabcc12643a2b68dcc82", // tranches
|
"6539e025997579cd0c7da12219cbc572", // tranches
|
||||||
"8abaf8142a6ee212b6dddc7053605512", // recal file
|
"778e61f81ab3d468b75f684bef0478e5", // recal file
|
||||||
"d6cd4f61875ae09a030fd9f2d7328246"); // cut VCF
|
"21e96b0bb47e2976f53f11181f920e51"); // cut VCF
|
||||||
|
|
||||||
@DataProvider(name = "VRBCFTest")
|
@DataProvider(name = "VRBCFTest")
|
||||||
public Object[][] createVRBCFTest() {
|
public Object[][] createVRBCFTest() {
|
||||||
|
|
@ -173,15 +171,15 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
|
||||||
|
|
||||||
VRTest indelUnfiltered = new VRTest(
|
VRTest indelUnfiltered = new VRTest(
|
||||||
validationDataLocation + "combined.phase1.chr20.raw.indels.unfiltered.sites.vcf", // all FILTERs as .
|
validationDataLocation + "combined.phase1.chr20.raw.indels.unfiltered.sites.vcf", // all FILTERs as .
|
||||||
"99c3736dab836ae8b41e344062e01b5a", // tranches
|
"8906fdae8beca712f5ff2808d35ef02d", // tranches
|
||||||
"55d2f89980ea9c6c469314129dbac732", // recal file
|
"07ffea25e04f6ef53079bccb30bd6a7b", // recal file
|
||||||
"482039de04961876890e125055732450"); // cut VCF
|
"8b3ef71cad71e8eb48a856a27ae4f8d5"); // cut VCF
|
||||||
|
|
||||||
VRTest indelFiltered = new VRTest(
|
VRTest indelFiltered = new VRTest(
|
||||||
validationDataLocation + "combined.phase1.chr20.raw.indels.filtered.sites.vcf", // all FILTERs as PASS
|
validationDataLocation + "combined.phase1.chr20.raw.indels.filtered.sites.vcf", // all FILTERs as PASS
|
||||||
"99c3736dab836ae8b41e344062e01b5a", // tranches
|
"8906fdae8beca712f5ff2808d35ef02d", // tranches
|
||||||
"55d2f89980ea9c6c469314129dbac732", // recal file
|
"07ffea25e04f6ef53079bccb30bd6a7b", // recal file
|
||||||
"e63e22ae05ad0bd32b943cde00b6e5a9"); // cut VCF
|
"3d69b280370cdd9611695e4893591306"); // cut VCF
|
||||||
|
|
||||||
@DataProvider(name = "VRIndelTest")
|
@DataProvider(name = "VRIndelTest")
|
||||||
public Object[][] createTestVariantRecalibratorIndel() {
|
public Object[][] createTestVariantRecalibratorIndel() {
|
||||||
|
|
@ -200,9 +198,7 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
|
||||||
" -L 20:1,000,000-40,000,000" +
|
" -L 20:1,000,000-40,000,000" +
|
||||||
" --no_cmdline_in_header" +
|
" --no_cmdline_in_header" +
|
||||||
" -an QD -an ReadPosRankSum -an HaplotypeScore" +
|
" -an QD -an ReadPosRankSum -an HaplotypeScore" +
|
||||||
" -percentBad 0.08" +
|
|
||||||
" -mode INDEL -mG 3" +
|
" -mode INDEL -mG 3" +
|
||||||
" --minNumBadVariants 0" +
|
|
||||||
" --trustAllPolymorphic" + // for speed
|
" --trustAllPolymorphic" + // for speed
|
||||||
" -recalFile %s" +
|
" -recalFile %s" +
|
||||||
" -tranchesFile %s",
|
" -tranchesFile %s",
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue