diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java index 9ed65b87d..f305c780b 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java @@ -319,6 +319,8 @@ public class VariantRecalibrator extends RodWalker ignoreInputFilterSet = new TreeSet<>(); private final VariantRecalibratorEngine engine = new VariantRecalibratorEngine( VRAC ); + private GaussianMixtureModel goodModel = null; + private GaussianMixtureModel badModel = null; //--------------------------------------------------------------------------------------------------------------- // @@ -356,7 +358,28 @@ public class VariantRecalibrator extends RodWalker hInfo = new HashSet<>(); + final File inputFile = new File(inputModel); + if (inputFile.exists()) { // Load GMM from a file + logger.info("Loading model from:" + inputModel); + final GATKReport reportIn = new GATKReport(inputFile); + + // Read all the tables + final GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances"); + final GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans"); + final GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix"); + final GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances"); + final GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans"); + final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix"); + final int numAnnotations = dataManager.getMeanVector().length; + + goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations); + badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations); + } + + + + + final Set hInfo = new HashSet<>(); ApplyRecalibration.addVQSRStandardHeaderLines(hInfo); recalWriter.writeHeader( new VCFHeader(hInfo) ); @@ -490,29 +513,14 @@ public class VariantRecalibrator extends RodWalker positiveTrainingData = dataManager.getTrainingData(); final List negativeTrainingData; - final File inputFile = new File(inputModel); - if (inputFile.exists()){ // Load GMM from a file - logger.info("Loading model from:"+inputModel); - final GATKReport reportIn = new GATKReport(inputFile); - - // Read all the tables - final GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances"); - final GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans"); - final GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix"); - final GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances"); - final GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans"); - final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix"); - final int numAnnotations = dataManager.getMeanVector().length; - - goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations); + if (goodModel != null && badModel != null){ // GMMs were loaded from a file + // Keeping this to maintain reproducibility between runs with and without serialized GMMs engine.evaluateData(dataManager.getData(), goodModel, false); negativeTrainingData = dataManager.selectWorstVariants(); - badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations); - } else { // Generate the GMMs from scratch // Generate the positive model using the training data and evaluate each variant goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS); @@ -586,14 +594,13 @@ public class VariantRecalibrator extends RodWalker annotationList) { - final String formatString = "%.25f"; + final String formatString = "%.8f"; final GATKReport report = new GATKReport(); if (dataManager != null) { //for unit test