Ability to retry building VQSR model (contributed by mdp)

This commit is contained in:
Geraldine Van der Auwera 2016-05-31 18:57:55 -04:00
parent e2634e56a9
commit a76cb052e2
1 changed file with 72 additions and 44 deletions

View File

@ -266,9 +266,11 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
private File RSCRIPT_FILE = null; private File RSCRIPT_FILE = null;
/** /**
* This GATKReport gives information to describe the VQSR model fit. Normalized means for the positive model are concatenated as one table and negative model normalized means as another table. * This GATKReport gives information to describe the VQSR model fit. Normalized means for the positive model are
* Covariances are also concatenated for postive and negative models, respectively. Tables of annotation means and standard deviations are provided to help describe the normalization. * concatenated as one table and negative model normalized means as another table. Covariances are also concatenated
* The model fit report can be read in with our R gsalib package. Individual model Gaussians can be subset by the value in the "Gaussian" column if desired. * for positive and negative models, respectively. Tables of annotation means and standard deviations are provided
* to help describe the normalization. The model fit report can be read in with our R gsalib package. Individual
* model Gaussians can be subset by the value in the "Gaussian" column if desired.
*/ */
@Argument(fullName="output_model", shortName = "outputModel", doc="If specified, the variant recalibrator will output the VQSR model fit to the file specified by -modelFile or to stdout", required=false) @Argument(fullName="output_model", shortName = "outputModel", doc="If specified, the variant recalibrator will output the VQSR model fit to the file specified by -modelFile or to stdout", required=false)
private boolean outputModel = false; private boolean outputModel = false;
@ -280,6 +282,21 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
protected int REPLICATE = 200; protected int REPLICATE = 200;
private ArrayList<Double> replicate = new ArrayList<>(); private ArrayList<Double> replicate = new ArrayList<>();
/**
* The statistical model being built by this tool may fail due to simple statistical sampling
* issues. Rather than dying immediately when the initial model fails, this argument allows the
* tool to restart with a different random seed and try to build the model again. The first
* successfully built model will be kept.
*
* Note that the most common underlying cause of model building failure is that there is insufficient data to
* build a really robust model. This argument provides a workaround for that issue but it is
* preferable to provide this tool with more data (typically by including more samples or more territory)
* in order to generate a more robust model.
*/
@Advanced
@Argument(fullName="max_attempts", shortName = "max_attempts", doc="Number of attempts to build a model before failing", required=false)
protected int max_attempts = 1;
///////////////////////////// /////////////////////////////
// Debug Arguments // Debug Arguments
///////////////////////////// /////////////////////////////
@ -457,22 +474,24 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
@Override @Override
public void onTraversalDone( final ExpandingArrayList<VariantDatum> reduceSum ) { public void onTraversalDone( final ExpandingArrayList<VariantDatum> reduceSum ) {
dataManager.setData( reduceSum ); for (int i = 1; i <= max_attempts; i++) {
try {
dataManager.setData(reduceSum);
dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation
// Generate the positive model using the training data and evaluate each variant // Generate the positive model using the training data and evaluate each variant
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData(); final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
final GaussianMixtureModel goodModel = engine.generateModel( positiveTrainingData, VRAC.MAX_GAUSSIANS ); final GaussianMixtureModel goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS);
engine.evaluateData( dataManager.getData(), goodModel, false ); engine.evaluateData(dataManager.getData(), goodModel, false);
// Generate the negative model using the worst performing data and evaluate each variant contrastively // Generate the negative model using the worst performing data and evaluate each variant contrastively
final List<VariantDatum> negativeTrainingData = dataManager.selectWorstVariants(); final List<VariantDatum> negativeTrainingData = dataManager.selectWorstVariants();
final GaussianMixtureModel badModel = engine.generateModel( negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS)); final GaussianMixtureModel badModel = engine.generateModel(negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS));
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
engine.evaluateData( dataManager.getData(), badModel, true ); engine.evaluateData(dataManager.getData(), badModel, true);
if( badModel.failedToConverge || goodModel.failedToConverge ) { if (badModel.failedToConverge || goodModel.failedToConverge) {
throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --minNumBadVariants 5000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example).") ); throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --minNumBadVariants 5000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example)."));
} }
if (outputModel) { if (outputModel) {
@ -480,22 +499,22 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
report.print(modelReport); report.print(modelReport);
} }
engine.calculateWorstPerformingAnnotation( dataManager.getData(), goodModel, badModel ); engine.calculateWorstPerformingAnnotation(dataManager.getData(), goodModel, badModel);
// Find the VQSLOD cutoff values which correspond to the various tranches of calls requested by the user // Find the VQSLOD cutoff values which correspond to the various tranches of calls requested by the user
final int nCallsAtTruth = TrancheManager.countCallsAtTruth( dataManager.getData(), Double.NEGATIVE_INFINITY ); final int nCallsAtTruth = TrancheManager.countCallsAtTruth(dataManager.getData(), Double.NEGATIVE_INFINITY);
final TrancheManager.SelectionMetric metric = new TrancheManager.TruthSensitivityMetric( nCallsAtTruth ); final TrancheManager.SelectionMetric metric = new TrancheManager.TruthSensitivityMetric(nCallsAtTruth);
final List<Tranche> tranches = TrancheManager.findTranches( dataManager.getData(), TS_TRANCHES, metric, VRAC.MODE ); final List<Tranche> tranches = TrancheManager.findTranches(dataManager.getData(), TS_TRANCHES, metric, VRAC.MODE);
tranchesStream.print(Tranche.tranchesString( tranches )); tranchesStream.print(Tranche.tranchesString(tranches));
logger.info( "Writing out recalibration table..." ); logger.info("Writing out recalibration table...");
dataManager.writeOutRecalibrationTable( recalWriter ); dataManager.writeOutRecalibrationTable(recalWriter);
if( RSCRIPT_FILE != null ) { if (RSCRIPT_FILE != null) {
logger.info( "Writing out visualization Rscript file..."); logger.info("Writing out visualization Rscript file...");
createVisualizationScript( dataManager.getRandomDataForPlotting( 1000, positiveTrainingData, negativeTrainingData, dataManager.getEvaluationData() ), goodModel, badModel, 0.0, dataManager.getAnnotationKeys().toArray(new String[USE_ANNOTATIONS.size()]) ); createVisualizationScript(dataManager.getRandomDataForPlotting(1000, positiveTrainingData, negativeTrainingData, dataManager.getEvaluationData()), goodModel, badModel, 0.0, dataManager.getAnnotationKeys().toArray(new String[USE_ANNOTATIONS.size()]));
} }
if(VRAC.MODE == VariantRecalibratorArgumentCollection.Mode.INDEL) { if (VRAC.MODE == VariantRecalibratorArgumentCollection.Mode.INDEL) {
// Print out an info message to make it clear why the tranches plot is not generated // Print out an info message to make it clear why the tranches plot is not generated
logger.info("Tranches plot will not be generated since we are running in INDEL mode"); logger.info("Tranches plot will not be generated since we are running in INDEL mode");
} else { } else {
@ -507,6 +526,15 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
logger.info("Executing: " + executor.getApproximateCommandLine()); logger.info("Executing: " + executor.getApproximateCommandLine());
executor.exec(); executor.exec();
} }
return;
} catch (Exception e) {
if (i == max_attempts) {
throw e;
} else {
logger.info(String.format("Exception occurred on attempt %d of %d. Trying again. Message was: '%s'", i, max_attempts, e.getMessage()));
}
}
}
} }
protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List<String> annotationList) { protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List<String> annotationList) {