From a76cb052e2f5d2513707c133880ec5f140cc4043 Mon Sep 17 00:00:00 2001 From: Geraldine Van der Auwera Date: Tue, 31 May 2016 18:57:55 -0400 Subject: [PATCH] Ability to retry building VQSR model (contributed by mdp) --- .../VariantRecalibrator.java | 116 +++++++++++------- 1 file changed, 72 insertions(+), 44 deletions(-) diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java index 1c22f8fe1..416d72c67 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java @@ -266,9 +266,11 @@ public class VariantRecalibrator extends RodWalker replicate = new ArrayList<>(); + /** + * The statistical model being built by this tool may fail due to simple statistical sampling + * issues. Rather than dying immediately when the initial model fails, this argument allows the + * tool to restart with a different random seed and try to build the model again. The first + * successfully built model will be kept. + * + * Note that the most common underlying cause of model building failure is that there is insufficient data to + * build a really robust model. This argument provides a workaround for that issue but it is + * preferable to provide this tool with more data (typically by including more samples or more territory) + * in order to generate a more robust model. + */ + @Advanced + @Argument(fullName="max_attempts", shortName = "max_attempts", doc="Number of attempts to build a model before failing", required=false) + protected int max_attempts = 1; + ///////////////////////////// // Debug Arguments ///////////////////////////// @@ -457,55 +474,66 @@ public class VariantRecalibrator extends RodWalker reduceSum ) { - dataManager.setData( reduceSum ); - dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation + for (int i = 1; i <= max_attempts; i++) { + try { + dataManager.setData(reduceSum); + dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation - // Generate the positive model using the training data and evaluate each variant - final List positiveTrainingData = dataManager.getTrainingData(); - final GaussianMixtureModel goodModel = engine.generateModel( positiveTrainingData, VRAC.MAX_GAUSSIANS ); - engine.evaluateData( dataManager.getData(), goodModel, false ); + // Generate the positive model using the training data and evaluate each variant + final List positiveTrainingData = dataManager.getTrainingData(); + final GaussianMixtureModel goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS); + engine.evaluateData(dataManager.getData(), goodModel, false); - // Generate the negative model using the worst performing data and evaluate each variant contrastively - final List negativeTrainingData = dataManager.selectWorstVariants(); - final GaussianMixtureModel badModel = engine.generateModel( negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS)); - dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory - engine.evaluateData( dataManager.getData(), badModel, true ); + // Generate the negative model using the worst performing data and evaluate each variant contrastively + final List negativeTrainingData = dataManager.selectWorstVariants(); + final GaussianMixtureModel badModel = engine.generateModel(negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS)); + dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory + engine.evaluateData(dataManager.getData(), badModel, true); - if( badModel.failedToConverge || goodModel.failedToConverge ) { - throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --minNumBadVariants 5000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example).") ); - } + if (badModel.failedToConverge || goodModel.failedToConverge) { + throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --minNumBadVariants 5000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example).")); + } - if (outputModel) { - GATKReport report = writeModelReport(goodModel, badModel, USE_ANNOTATIONS); - report.print(modelReport); - } + if (outputModel) { + GATKReport report = writeModelReport(goodModel, badModel, USE_ANNOTATIONS); + report.print(modelReport); + } - engine.calculateWorstPerformingAnnotation( dataManager.getData(), goodModel, badModel ); + engine.calculateWorstPerformingAnnotation(dataManager.getData(), goodModel, badModel); - // Find the VQSLOD cutoff values which correspond to the various tranches of calls requested by the user - final int nCallsAtTruth = TrancheManager.countCallsAtTruth( dataManager.getData(), Double.NEGATIVE_INFINITY ); - final TrancheManager.SelectionMetric metric = new TrancheManager.TruthSensitivityMetric( nCallsAtTruth ); - final List tranches = TrancheManager.findTranches( dataManager.getData(), TS_TRANCHES, metric, VRAC.MODE ); - tranchesStream.print(Tranche.tranchesString( tranches )); + // Find the VQSLOD cutoff values which correspond to the various tranches of calls requested by the user + final int nCallsAtTruth = TrancheManager.countCallsAtTruth(dataManager.getData(), Double.NEGATIVE_INFINITY); + final TrancheManager.SelectionMetric metric = new TrancheManager.TruthSensitivityMetric(nCallsAtTruth); + final List tranches = TrancheManager.findTranches(dataManager.getData(), TS_TRANCHES, metric, VRAC.MODE); + tranchesStream.print(Tranche.tranchesString(tranches)); - logger.info( "Writing out recalibration table..." ); - dataManager.writeOutRecalibrationTable( recalWriter ); - if( RSCRIPT_FILE != null ) { - logger.info( "Writing out visualization Rscript file..."); - createVisualizationScript( dataManager.getRandomDataForPlotting( 1000, positiveTrainingData, negativeTrainingData, dataManager.getEvaluationData() ), goodModel, badModel, 0.0, dataManager.getAnnotationKeys().toArray(new String[USE_ANNOTATIONS.size()]) ); - } + logger.info("Writing out recalibration table..."); + dataManager.writeOutRecalibrationTable(recalWriter); + if (RSCRIPT_FILE != null) { + logger.info("Writing out visualization Rscript file..."); + createVisualizationScript(dataManager.getRandomDataForPlotting(1000, positiveTrainingData, negativeTrainingData, dataManager.getEvaluationData()), goodModel, badModel, 0.0, dataManager.getAnnotationKeys().toArray(new String[USE_ANNOTATIONS.size()])); + } - if(VRAC.MODE == VariantRecalibratorArgumentCollection.Mode.INDEL) { - // Print out an info message to make it clear why the tranches plot is not generated - logger.info("Tranches plot will not be generated since we are running in INDEL mode"); - } else { - // Execute the RScript command to plot the table of truth values - RScriptExecutor executor = new RScriptExecutor(); - executor.addScript(new Resource(PLOT_TRANCHES_RSCRIPT, VariantRecalibrator.class)); - executor.addArgs(TRANCHES_FILE.getAbsoluteFile(), TARGET_TITV); - // Print out the command line to make it clear to the user what is being executed and how one might modify it - logger.info("Executing: " + executor.getApproximateCommandLine()); - executor.exec(); + if (VRAC.MODE == VariantRecalibratorArgumentCollection.Mode.INDEL) { + // Print out an info message to make it clear why the tranches plot is not generated + logger.info("Tranches plot will not be generated since we are running in INDEL mode"); + } else { + // Execute the RScript command to plot the table of truth values + RScriptExecutor executor = new RScriptExecutor(); + executor.addScript(new Resource(PLOT_TRANCHES_RSCRIPT, VariantRecalibrator.class)); + executor.addArgs(TRANCHES_FILE.getAbsoluteFile(), TARGET_TITV); + // Print out the command line to make it clear to the user what is being executed and how one might modify it + logger.info("Executing: " + executor.getApproximateCommandLine()); + executor.exec(); + } + return; + } catch (Exception e) { + if (i == max_attempts) { + throw e; + } else { + logger.info(String.format("Exception occurred on attempt %d of %d. Trying again. Message was: '%s'", i, max_attempts, e.getMessage())); + } + } } }