Merge pull request #1392 from broadinstitute/gvda_vqsr_retries_mdp
Ability to retry building VQSR model (contributed by mdp)
This commit is contained in:
commit
801aa49d25
|
|
@ -266,9 +266,11 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
private File RSCRIPT_FILE = null;
|
||||
|
||||
/**
|
||||
* This GATKReport gives information to describe the VQSR model fit. Normalized means for the positive model are concatenated as one table and negative model normalized means as another table.
|
||||
* Covariances are also concatenated for postive and negative models, respectively. Tables of annotation means and standard deviations are provided to help describe the normalization.
|
||||
* The model fit report can be read in with our R gsalib package. Individual model Gaussians can be subset by the value in the "Gaussian" column if desired.
|
||||
* This GATKReport gives information to describe the VQSR model fit. Normalized means for the positive model are
|
||||
* concatenated as one table and negative model normalized means as another table. Covariances are also concatenated
|
||||
* for positive and negative models, respectively. Tables of annotation means and standard deviations are provided
|
||||
* to help describe the normalization. The model fit report can be read in with our R gsalib package. Individual
|
||||
* model Gaussians can be subset by the value in the "Gaussian" column if desired.
|
||||
*/
|
||||
@Argument(fullName="output_model", shortName = "outputModel", doc="If specified, the variant recalibrator will output the VQSR model fit to the file specified by -modelFile or to stdout", required=false)
|
||||
private boolean outputModel = false;
|
||||
|
|
@ -280,6 +282,21 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
protected int REPLICATE = 200;
|
||||
private ArrayList<Double> replicate = new ArrayList<>();
|
||||
|
||||
/**
|
||||
* The statistical model being built by this tool may fail due to simple statistical sampling
|
||||
* issues. Rather than dying immediately when the initial model fails, this argument allows the
|
||||
* tool to restart with a different random seed and try to build the model again. The first
|
||||
* successfully built model will be kept.
|
||||
*
|
||||
* Note that the most common underlying cause of model building failure is that there is insufficient data to
|
||||
* build a really robust model. This argument provides a workaround for that issue but it is
|
||||
* preferable to provide this tool with more data (typically by including more samples or more territory)
|
||||
* in order to generate a more robust model.
|
||||
*/
|
||||
@Advanced
|
||||
@Argument(fullName="max_attempts", shortName = "max_attempts", doc="Number of attempts to build a model before failing", required=false)
|
||||
protected int max_attempts = 1;
|
||||
|
||||
/////////////////////////////
|
||||
// Debug Arguments
|
||||
/////////////////////////////
|
||||
|
|
@ -457,55 +474,66 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
|
||||
@Override
|
||||
public void onTraversalDone( final ExpandingArrayList<VariantDatum> reduceSum ) {
|
||||
dataManager.setData( reduceSum );
|
||||
dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation
|
||||
for (int i = 1; i <= max_attempts; i++) {
|
||||
try {
|
||||
dataManager.setData(reduceSum);
|
||||
dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation
|
||||
|
||||
// Generate the positive model using the training data and evaluate each variant
|
||||
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
|
||||
final GaussianMixtureModel goodModel = engine.generateModel( positiveTrainingData, VRAC.MAX_GAUSSIANS );
|
||||
engine.evaluateData( dataManager.getData(), goodModel, false );
|
||||
// Generate the positive model using the training data and evaluate each variant
|
||||
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
|
||||
final GaussianMixtureModel goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS);
|
||||
engine.evaluateData(dataManager.getData(), goodModel, false);
|
||||
|
||||
// Generate the negative model using the worst performing data and evaluate each variant contrastively
|
||||
final List<VariantDatum> negativeTrainingData = dataManager.selectWorstVariants();
|
||||
final GaussianMixtureModel badModel = engine.generateModel( negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS));
|
||||
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
|
||||
engine.evaluateData( dataManager.getData(), badModel, true );
|
||||
// Generate the negative model using the worst performing data and evaluate each variant contrastively
|
||||
final List<VariantDatum> negativeTrainingData = dataManager.selectWorstVariants();
|
||||
final GaussianMixtureModel badModel = engine.generateModel(negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS));
|
||||
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
|
||||
engine.evaluateData(dataManager.getData(), badModel, true);
|
||||
|
||||
if( badModel.failedToConverge || goodModel.failedToConverge ) {
|
||||
throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --minNumBadVariants 5000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example).") );
|
||||
}
|
||||
if (badModel.failedToConverge || goodModel.failedToConverge) {
|
||||
throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --minNumBadVariants 5000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example)."));
|
||||
}
|
||||
|
||||
if (outputModel) {
|
||||
GATKReport report = writeModelReport(goodModel, badModel, USE_ANNOTATIONS);
|
||||
report.print(modelReport);
|
||||
}
|
||||
if (outputModel) {
|
||||
GATKReport report = writeModelReport(goodModel, badModel, USE_ANNOTATIONS);
|
||||
report.print(modelReport);
|
||||
}
|
||||
|
||||
engine.calculateWorstPerformingAnnotation( dataManager.getData(), goodModel, badModel );
|
||||
engine.calculateWorstPerformingAnnotation(dataManager.getData(), goodModel, badModel);
|
||||
|
||||
// Find the VQSLOD cutoff values which correspond to the various tranches of calls requested by the user
|
||||
final int nCallsAtTruth = TrancheManager.countCallsAtTruth( dataManager.getData(), Double.NEGATIVE_INFINITY );
|
||||
final TrancheManager.SelectionMetric metric = new TrancheManager.TruthSensitivityMetric( nCallsAtTruth );
|
||||
final List<Tranche> tranches = TrancheManager.findTranches( dataManager.getData(), TS_TRANCHES, metric, VRAC.MODE );
|
||||
tranchesStream.print(Tranche.tranchesString( tranches ));
|
||||
// Find the VQSLOD cutoff values which correspond to the various tranches of calls requested by the user
|
||||
final int nCallsAtTruth = TrancheManager.countCallsAtTruth(dataManager.getData(), Double.NEGATIVE_INFINITY);
|
||||
final TrancheManager.SelectionMetric metric = new TrancheManager.TruthSensitivityMetric(nCallsAtTruth);
|
||||
final List<Tranche> tranches = TrancheManager.findTranches(dataManager.getData(), TS_TRANCHES, metric, VRAC.MODE);
|
||||
tranchesStream.print(Tranche.tranchesString(tranches));
|
||||
|
||||
logger.info( "Writing out recalibration table..." );
|
||||
dataManager.writeOutRecalibrationTable( recalWriter );
|
||||
if( RSCRIPT_FILE != null ) {
|
||||
logger.info( "Writing out visualization Rscript file...");
|
||||
createVisualizationScript( dataManager.getRandomDataForPlotting( 1000, positiveTrainingData, negativeTrainingData, dataManager.getEvaluationData() ), goodModel, badModel, 0.0, dataManager.getAnnotationKeys().toArray(new String[USE_ANNOTATIONS.size()]) );
|
||||
}
|
||||
logger.info("Writing out recalibration table...");
|
||||
dataManager.writeOutRecalibrationTable(recalWriter);
|
||||
if (RSCRIPT_FILE != null) {
|
||||
logger.info("Writing out visualization Rscript file...");
|
||||
createVisualizationScript(dataManager.getRandomDataForPlotting(1000, positiveTrainingData, negativeTrainingData, dataManager.getEvaluationData()), goodModel, badModel, 0.0, dataManager.getAnnotationKeys().toArray(new String[USE_ANNOTATIONS.size()]));
|
||||
}
|
||||
|
||||
if(VRAC.MODE == VariantRecalibratorArgumentCollection.Mode.INDEL) {
|
||||
// Print out an info message to make it clear why the tranches plot is not generated
|
||||
logger.info("Tranches plot will not be generated since we are running in INDEL mode");
|
||||
} else {
|
||||
// Execute the RScript command to plot the table of truth values
|
||||
RScriptExecutor executor = new RScriptExecutor();
|
||||
executor.addScript(new Resource(PLOT_TRANCHES_RSCRIPT, VariantRecalibrator.class));
|
||||
executor.addArgs(TRANCHES_FILE.getAbsoluteFile(), TARGET_TITV);
|
||||
// Print out the command line to make it clear to the user what is being executed and how one might modify it
|
||||
logger.info("Executing: " + executor.getApproximateCommandLine());
|
||||
executor.exec();
|
||||
if (VRAC.MODE == VariantRecalibratorArgumentCollection.Mode.INDEL) {
|
||||
// Print out an info message to make it clear why the tranches plot is not generated
|
||||
logger.info("Tranches plot will not be generated since we are running in INDEL mode");
|
||||
} else {
|
||||
// Execute the RScript command to plot the table of truth values
|
||||
RScriptExecutor executor = new RScriptExecutor();
|
||||
executor.addScript(new Resource(PLOT_TRANCHES_RSCRIPT, VariantRecalibrator.class));
|
||||
executor.addArgs(TRANCHES_FILE.getAbsoluteFile(), TARGET_TITV);
|
||||
// Print out the command line to make it clear to the user what is being executed and how one might modify it
|
||||
logger.info("Executing: " + executor.getApproximateCommandLine());
|
||||
executor.exec();
|
||||
}
|
||||
return;
|
||||
} catch (Exception e) {
|
||||
if (i == max_attempts) {
|
||||
throw e;
|
||||
} else {
|
||||
logger.info(String.format("Exception occurred on attempt %d of %d. Trying again. Message was: '%s'", i, max_attempts, e.getMessage()));
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue