diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java index 17b3a63d7..2a9b2b5cc 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java @@ -320,5 +320,4 @@ public class GaussianMixtureModel { protected List getModelGaussians() {return Collections.unmodifiableList(gaussians);} protected int getNumAnnotations() {return empiricalMu.length;} - } \ No newline at end of file diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java index bb6f80835..b5145f76d 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java @@ -490,57 +490,46 @@ public class VariantRecalibrator extends RodWalker positiveTrainingData = dataManager.getTrainingData(); final List negativeTrainingData; - File inputFile = new File(inputModel); + final File inputFile = new File(inputModel); if (inputFile.exists()){ // Load GMM from a file logger.info("Loading model from:"+inputModel); - GATKReport reportIn = new GATKReport(inputFile); + final GATKReport reportIn = new GATKReport(inputFile); // Read all the tables - GATKReportTable amTable = reportIn.getTable("AnnotationMeans"); - GATKReportTable astdTable = reportIn.getTable("AnnotationStdevs"); + final GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances"); + final GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans"); + final GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix"); + final GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances"); + final GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans"); + final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix"); + final int numAnnotations = dataManager.getMeanVector().length; - // Should have same number of means and standard deviations. - assert(amTable.getNumRows() == astdTable.getNumRows() ); - - GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances"); - GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans"); - GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix"); - GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances"); - GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans"); - GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix"); - - int numAnnotations = amTable.getNumRows(); goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations); engine.evaluateData(dataManager.getData(), goodModel, false); negativeTrainingData = dataManager.selectWorstVariants(); badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations); - logger.info("Loaded GMM from file:" + inputModel); - dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory - engine.evaluateData(dataManager.getData(), badModel, true); } else { // Generate the GMMs from scratch + // Generate the positive model using the training data and evaluate each variant goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS); - //Utils.getRandomGenerator().setSeed(12878); engine.evaluateData(dataManager.getData(), goodModel, false); // Generate the negative model using the worst performing data and evaluate each variant contrastively negativeTrainingData = dataManager.selectWorstVariants(); badModel = engine.generateModel(negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS)); - dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory - //Utils.getRandomGenerator().setSeed(12878); - engine.evaluateData(dataManager.getData(), badModel, true); - if (badModel.failedToConverge || goodModel.failedToConverge) { throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --minNumBadVariants 5000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example).")); } } + dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory + engine.evaluateData(dataManager.getData(), badModel, true); + if (outputModel) { GATKReport report = writeModelReport(goodModel, badModel, USE_ANNOTATIONS); report.print(modelReport); @@ -592,7 +581,7 @@ public class VariantRecalibrator extends RodWalker gaussianList = new ArrayList<>(); int curAnnotation = 0; @@ -642,7 +631,6 @@ public class VariantRecalibrator extends RodWalker