Ability to retry building VQSR model (contributed by mdp)
This commit is contained in:
parent
e2634e56a9
commit
a76cb052e2
|
|
@ -266,9 +266,11 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
private File RSCRIPT_FILE = null;
|
private File RSCRIPT_FILE = null;
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* This GATKReport gives information to describe the VQSR model fit. Normalized means for the positive model are concatenated as one table and negative model normalized means as another table.
|
* This GATKReport gives information to describe the VQSR model fit. Normalized means for the positive model are
|
||||||
* Covariances are also concatenated for postive and negative models, respectively. Tables of annotation means and standard deviations are provided to help describe the normalization.
|
* concatenated as one table and negative model normalized means as another table. Covariances are also concatenated
|
||||||
* The model fit report can be read in with our R gsalib package. Individual model Gaussians can be subset by the value in the "Gaussian" column if desired.
|
* for positive and negative models, respectively. Tables of annotation means and standard deviations are provided
|
||||||
|
* to help describe the normalization. The model fit report can be read in with our R gsalib package. Individual
|
||||||
|
* model Gaussians can be subset by the value in the "Gaussian" column if desired.
|
||||||
*/
|
*/
|
||||||
@Argument(fullName="output_model", shortName = "outputModel", doc="If specified, the variant recalibrator will output the VQSR model fit to the file specified by -modelFile or to stdout", required=false)
|
@Argument(fullName="output_model", shortName = "outputModel", doc="If specified, the variant recalibrator will output the VQSR model fit to the file specified by -modelFile or to stdout", required=false)
|
||||||
private boolean outputModel = false;
|
private boolean outputModel = false;
|
||||||
|
|
@ -280,6 +282,21 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
protected int REPLICATE = 200;
|
protected int REPLICATE = 200;
|
||||||
private ArrayList<Double> replicate = new ArrayList<>();
|
private ArrayList<Double> replicate = new ArrayList<>();
|
||||||
|
|
||||||
|
/**
|
||||||
|
* The statistical model being built by this tool may fail due to simple statistical sampling
|
||||||
|
* issues. Rather than dying immediately when the initial model fails, this argument allows the
|
||||||
|
* tool to restart with a different random seed and try to build the model again. The first
|
||||||
|
* successfully built model will be kept.
|
||||||
|
*
|
||||||
|
* Note that the most common underlying cause of model building failure is that there is insufficient data to
|
||||||
|
* build a really robust model. This argument provides a workaround for that issue but it is
|
||||||
|
* preferable to provide this tool with more data (typically by including more samples or more territory)
|
||||||
|
* in order to generate a more robust model.
|
||||||
|
*/
|
||||||
|
@Advanced
|
||||||
|
@Argument(fullName="max_attempts", shortName = "max_attempts", doc="Number of attempts to build a model before failing", required=false)
|
||||||
|
protected int max_attempts = 1;
|
||||||
|
|
||||||
/////////////////////////////
|
/////////////////////////////
|
||||||
// Debug Arguments
|
// Debug Arguments
|
||||||
/////////////////////////////
|
/////////////////////////////
|
||||||
|
|
@ -457,55 +474,66 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
|
|
||||||
@Override
|
@Override
|
||||||
public void onTraversalDone( final ExpandingArrayList<VariantDatum> reduceSum ) {
|
public void onTraversalDone( final ExpandingArrayList<VariantDatum> reduceSum ) {
|
||||||
dataManager.setData( reduceSum );
|
for (int i = 1; i <= max_attempts; i++) {
|
||||||
dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation
|
try {
|
||||||
|
dataManager.setData(reduceSum);
|
||||||
|
dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation
|
||||||
|
|
||||||
// Generate the positive model using the training data and evaluate each variant
|
// Generate the positive model using the training data and evaluate each variant
|
||||||
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
|
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
|
||||||
final GaussianMixtureModel goodModel = engine.generateModel( positiveTrainingData, VRAC.MAX_GAUSSIANS );
|
final GaussianMixtureModel goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS);
|
||||||
engine.evaluateData( dataManager.getData(), goodModel, false );
|
engine.evaluateData(dataManager.getData(), goodModel, false);
|
||||||
|
|
||||||
// Generate the negative model using the worst performing data and evaluate each variant contrastively
|
// Generate the negative model using the worst performing data and evaluate each variant contrastively
|
||||||
final List<VariantDatum> negativeTrainingData = dataManager.selectWorstVariants();
|
final List<VariantDatum> negativeTrainingData = dataManager.selectWorstVariants();
|
||||||
final GaussianMixtureModel badModel = engine.generateModel( negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS));
|
final GaussianMixtureModel badModel = engine.generateModel(negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS));
|
||||||
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
|
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
|
||||||
engine.evaluateData( dataManager.getData(), badModel, true );
|
engine.evaluateData(dataManager.getData(), badModel, true);
|
||||||
|
|
||||||
if( badModel.failedToConverge || goodModel.failedToConverge ) {
|
if (badModel.failedToConverge || goodModel.failedToConverge) {
|
||||||
throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --minNumBadVariants 5000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example).") );
|
throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --minNumBadVariants 5000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example)."));
|
||||||
}
|
}
|
||||||
|
|
||||||
if (outputModel) {
|
if (outputModel) {
|
||||||
GATKReport report = writeModelReport(goodModel, badModel, USE_ANNOTATIONS);
|
GATKReport report = writeModelReport(goodModel, badModel, USE_ANNOTATIONS);
|
||||||
report.print(modelReport);
|
report.print(modelReport);
|
||||||
}
|
}
|
||||||
|
|
||||||
engine.calculateWorstPerformingAnnotation( dataManager.getData(), goodModel, badModel );
|
engine.calculateWorstPerformingAnnotation(dataManager.getData(), goodModel, badModel);
|
||||||
|
|
||||||
// Find the VQSLOD cutoff values which correspond to the various tranches of calls requested by the user
|
// Find the VQSLOD cutoff values which correspond to the various tranches of calls requested by the user
|
||||||
final int nCallsAtTruth = TrancheManager.countCallsAtTruth( dataManager.getData(), Double.NEGATIVE_INFINITY );
|
final int nCallsAtTruth = TrancheManager.countCallsAtTruth(dataManager.getData(), Double.NEGATIVE_INFINITY);
|
||||||
final TrancheManager.SelectionMetric metric = new TrancheManager.TruthSensitivityMetric( nCallsAtTruth );
|
final TrancheManager.SelectionMetric metric = new TrancheManager.TruthSensitivityMetric(nCallsAtTruth);
|
||||||
final List<Tranche> tranches = TrancheManager.findTranches( dataManager.getData(), TS_TRANCHES, metric, VRAC.MODE );
|
final List<Tranche> tranches = TrancheManager.findTranches(dataManager.getData(), TS_TRANCHES, metric, VRAC.MODE);
|
||||||
tranchesStream.print(Tranche.tranchesString( tranches ));
|
tranchesStream.print(Tranche.tranchesString(tranches));
|
||||||
|
|
||||||
logger.info( "Writing out recalibration table..." );
|
logger.info("Writing out recalibration table...");
|
||||||
dataManager.writeOutRecalibrationTable( recalWriter );
|
dataManager.writeOutRecalibrationTable(recalWriter);
|
||||||
if( RSCRIPT_FILE != null ) {
|
if (RSCRIPT_FILE != null) {
|
||||||
logger.info( "Writing out visualization Rscript file...");
|
logger.info("Writing out visualization Rscript file...");
|
||||||
createVisualizationScript( dataManager.getRandomDataForPlotting( 1000, positiveTrainingData, negativeTrainingData, dataManager.getEvaluationData() ), goodModel, badModel, 0.0, dataManager.getAnnotationKeys().toArray(new String[USE_ANNOTATIONS.size()]) );
|
createVisualizationScript(dataManager.getRandomDataForPlotting(1000, positiveTrainingData, negativeTrainingData, dataManager.getEvaluationData()), goodModel, badModel, 0.0, dataManager.getAnnotationKeys().toArray(new String[USE_ANNOTATIONS.size()]));
|
||||||
}
|
}
|
||||||
|
|
||||||
if(VRAC.MODE == VariantRecalibratorArgumentCollection.Mode.INDEL) {
|
if (VRAC.MODE == VariantRecalibratorArgumentCollection.Mode.INDEL) {
|
||||||
// Print out an info message to make it clear why the tranches plot is not generated
|
// Print out an info message to make it clear why the tranches plot is not generated
|
||||||
logger.info("Tranches plot will not be generated since we are running in INDEL mode");
|
logger.info("Tranches plot will not be generated since we are running in INDEL mode");
|
||||||
} else {
|
} else {
|
||||||
// Execute the RScript command to plot the table of truth values
|
// Execute the RScript command to plot the table of truth values
|
||||||
RScriptExecutor executor = new RScriptExecutor();
|
RScriptExecutor executor = new RScriptExecutor();
|
||||||
executor.addScript(new Resource(PLOT_TRANCHES_RSCRIPT, VariantRecalibrator.class));
|
executor.addScript(new Resource(PLOT_TRANCHES_RSCRIPT, VariantRecalibrator.class));
|
||||||
executor.addArgs(TRANCHES_FILE.getAbsoluteFile(), TARGET_TITV);
|
executor.addArgs(TRANCHES_FILE.getAbsoluteFile(), TARGET_TITV);
|
||||||
// Print out the command line to make it clear to the user what is being executed and how one might modify it
|
// Print out the command line to make it clear to the user what is being executed and how one might modify it
|
||||||
logger.info("Executing: " + executor.getApproximateCommandLine());
|
logger.info("Executing: " + executor.getApproximateCommandLine());
|
||||||
executor.exec();
|
executor.exec();
|
||||||
|
}
|
||||||
|
return;
|
||||||
|
} catch (Exception e) {
|
||||||
|
if (i == max_attempts) {
|
||||||
|
throw e;
|
||||||
|
} else {
|
||||||
|
logger.info(String.format("Exception occurred on attempt %d of %d. Trying again. Message was: '%s'", i, max_attempts, e.getMessage()));
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue