From a8f70c891f3ea1fe4f44b935913d32624128854b Mon Sep 17 00:00:00 2001 From: Samuel Friedman Date: Fri, 4 Nov 2016 15:05:23 -0400 Subject: [PATCH 01/10] Create GMM from model reports in VQSR --- .../GaussianMixtureModel.java | 9 + .../MultivariateGaussian.java | 10 + .../VariantRecalibrator.java | 209 +++++++++++++++++- .../gatk/utils/report/GATKReportTable.java | 1 - 4 files changed, 216 insertions(+), 13 deletions(-) diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java index 0995fc10c..59a5af92d 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java @@ -86,6 +86,11 @@ public class GaussianMixtureModel { gaussians = new ArrayList<>( numGaussians ); for( int iii = 0; iii < numGaussians; iii++ ) { final MultivariateGaussian gaussian = new MultivariateGaussian( numAnnotations ); + gaussian.pMixtureLog10 = Math.log10( 1.0 / ((double)numGaussians) ); + gaussian.sumProb = 1.0 / ((double) numGaussians); + gaussian.hyperParameter_a = priorCounts; + gaussian.hyperParameter_b = shrinkage; + gaussian.hyperParameter_lambda = dirichletParameter; gaussians.add( gaussian ); } this.shrinkage = shrinkage; @@ -190,6 +195,9 @@ public class GaussianMixtureModel { final double[] pVarInGaussianNormalized = MathUtils.normalizeFromLog10( pVarInGaussianLog10, false ); gaussianIndex = 0; for( final MultivariateGaussian gaussian : gaussians ) { + if (Double.isNaN(pVarInGaussianNormalized[gaussianIndex])){ + logger.info(" Got a NaN at gaussian:" + Integer.toString(gaussianIndex) + " datum:" + datum.toString()); + } gaussian.assignPVarInGaussian( pVarInGaussianNormalized[gaussianIndex++] ); } } @@ -315,4 +323,5 @@ public class GaussianMixtureModel { protected List getModelGaussians() {return Collections.unmodifiableList(gaussians);} protected int getNumAnnotations() {return empiricalMu.length;} + } \ No newline at end of file diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/MultivariateGaussian.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/MultivariateGaussian.java index db46b0c33..5e9d18ffd 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/MultivariateGaussian.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/MultivariateGaussian.java @@ -271,4 +271,14 @@ public class MultivariateGaussian { resetPVarInGaussian(); // clean up some memory } + + + public void setSumProb( final List data ) { + sumProb = 0.0; + + for( int datumIndex = 0; datumIndex < data.size(); datumIndex++ ) { + final double prob = pVarInGaussian.get(datumIndex); + if(!Double.isNaN(prob)) sumProb += prob; + } + } } \ No newline at end of file diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java index 416d72c67..470ce3cdb 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java @@ -66,6 +66,7 @@ import org.broadinstitute.gatk.utils.R.RScriptExecutor; import org.broadinstitute.gatk.utils.Utils; import org.broadinstitute.gatk.utils.help.HelpConstants; import org.broadinstitute.gatk.utils.report.GATKReport; +import org.broadinstitute.gatk.utils.report.GATKReportColumn; import org.broadinstitute.gatk.utils.report.GATKReportTable; import org.broadinstitute.gatk.utils.variant.GATKVariantContextUtils; import htsjdk.variant.vcf.VCFHeader; @@ -80,10 +81,15 @@ import htsjdk.variant.variantcontext.writer.VariantContextWriter; import java.io.File; import java.io.FileNotFoundException; import java.io.PrintStream; +import java.nio.file.Files; import java.util.*; import Jama.Matrix; + +import java.io.FileWriter; +import java.io.BufferedWriter; +import java.io.IOException; /** * Build a recalibration model to score variant quality for filtering purposes * @@ -274,6 +280,8 @@ public class VariantRecalibrator extends RodWalker inputCollection : inputCollections ) input.addAll(inputCollection.getRodBindings()); + + } + //--------------------------------------------------------------------------------------------------------------- // // map @@ -480,18 +491,76 @@ public class VariantRecalibrator extends RodWalker positiveTrainingData = dataManager.getTrainingData(); - final GaussianMixtureModel goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS); - engine.evaluateData(dataManager.getData(), goodModel, false); + final List negativeTrainingData; - // Generate the negative model using the worst performing data and evaluate each variant contrastively - final List negativeTrainingData = dataManager.selectWorstVariants(); - final GaussianMixtureModel badModel = engine.generateModel(negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS)); - dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory - engine.evaluateData(dataManager.getData(), badModel, true); + File inputFile = new File(inputModel); + if (inputFile.exists()){ // Load GMM from a file + GATKReport reportIn = new GATKReport(inputFile); + GATKReportTable amTable = reportIn.getTable("AnnotationMeans"); + GATKReportTable astdTable = reportIn.getTable("AnnotationStdevs"); - if (badModel.failedToConverge || goodModel.failedToConverge) { - throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --minNumBadVariants 5000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example).")); + GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances"); + GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans"); + GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix"); + GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances"); + GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans"); + GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix"); + + double[] meanVector; + double[] stdVector; + int numAnnotations = 0; + + for (GATKReportColumn reportColumn : amTable.getColumnInfo() ) { + if (reportColumn.getColumnName().equals("Mean")) { + meanVector = new double[amTable.getNumRows()]; + numAnnotations = amTable.getNumRows(); + for (int row = 0; row < amTable.getNumRows(); row++) { + meanVector[row] = (double) amTable.get(row, reportColumn.getColumnName()); + } + logger.info("Got mean Vector:" + Arrays.toString(meanVector)); + } + } + + for (GATKReportColumn reportColumn : astdTable.getColumnInfo() ) { + logger.info("Report column name is:" + reportColumn.getColumnName()); + if (reportColumn.getColumnName().equals("Standarddeviation")) { + stdVector = new double[astdTable.getNumRows()]; + + for (int row = 0; row < astdTable.getNumRows(); row++) { + stdVector[row] = (double) astdTable.get(row, reportColumn.getColumnName()); + } + } + } + + goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations); + //Utils.getRandomGenerator().setSeed(12878); + engine.evaluateData(dataManager.getData(), goodModel, false); + negativeTrainingData = dataManager.selectWorstVariants(); + badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations); + logger.info("Loaded GMM from file:" + inputModel); + + dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory + //Utils.getRandomGenerator().setSeed(12878); + engine.evaluateData(dataManager.getData(), badModel, true); + + } else { // Generate the GMMs from scratch + goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS); + //Utils.getRandomGenerator().setSeed(12878); + engine.evaluateData(dataManager.getData(), goodModel, false); + // Generate the negative model using the worst performing data and evaluate each variant contrastively + negativeTrainingData = dataManager.selectWorstVariants(); + + badModel = engine.generateModel(negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS)); + dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory + //Utils.getRandomGenerator().setSeed(12878); + engine.evaluateData(dataManager.getData(), badModel, true); + + + if (badModel.failedToConverge || goodModel.failedToConverge) { + throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --minNumBadVariants 5000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example).")); + } } if (outputModel) { @@ -499,6 +568,9 @@ public class VariantRecalibrator extends RodWalker gaussianList = new ArrayList<>(); + + int curAnnotation = 0; + for (GATKReportColumn reportColumn : muTable.getColumnInfo() ) { + logger.info("Report column name is:" + reportColumn.getColumnName()); + if (!reportColumn.getColumnName().equals("Gaussian")) { + for (int row = 0; row < muTable.getNumRows(); row++) { + if (gaussianList.size() <= row){ + MultivariateGaussian mg = new MultivariateGaussian(numAnnotations); + gaussianList.add(mg); + } + gaussianList.get(row).mu[curAnnotation] = (double) muTable.get(row, reportColumn.getColumnName()); + } + curAnnotation++; + } + } + + for (GATKReportColumn reportColumn : pmixTable.getColumnInfo() ) { + if (reportColumn.getColumnName().equals("pMixLog10")) { + for (int row = 0; row < pmixTable.getNumRows(); row++) { + gaussianList.get(row).pMixtureLog10 = (double) pmixTable.get(row, reportColumn.getColumnName()); + } + } + } + + int curJ = 0; + for (GATKReportColumn reportColumn : sigmaTable.getColumnInfo() ) { + if (reportColumn.getColumnName().equals("Gaussian")) continue; + if (reportColumn.getColumnName().equals("Annotation")) continue; + + for (int row = 0; row < sigmaTable.getNumRows(); row++) { + int curGaussian = row / numAnnotations; + int curI = row % numAnnotations; + double curVal = (double) sigmaTable.get(row, reportColumn.getColumnName()); + gaussianList.get(curGaussian).sigma.set(curI, curJ, curVal); + + } + curJ++; + + } + + return new GaussianMixtureModel(gaussianList, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER, VRAC.PRIOR_COUNTS); + + } + + private void writeFeaturesFiles(List positiveTrainingData, List negativeTrainingData){ + //Begin Sam Hacking + try { + File file = new File("/Users/sam/data/haploid_features.txt"); + file.createNewFile(); + File badFile = new File("/Users/sam/data/haploid_bad_features.txt"); + badFile.createNewFile(); + + FileWriter fw = new FileWriter(file.getAbsoluteFile()); + BufferedWriter bw = new BufferedWriter(fw); + for(int jj = 0; jj < positiveTrainingData.size(); jj++){ + VariantDatum v = positiveTrainingData.get(jj); + for(int kk = 0; kk < v.annotations.length; kk++){ + bw.write(Double.toString(v.annotations[kk])); + bw.write(" "); + } + bw.write("\n"); + } + bw.close(); + + fw = new FileWriter(badFile.getAbsoluteFile()); + bw = new BufferedWriter(fw); + for(int jj = 0; jj < negativeTrainingData.size(); jj++){ + VariantDatum v = negativeTrainingData.get(jj); + for(int kk = 0; kk < v.annotations.length; kk++){ + bw.write(Double.toString(v.annotations[kk])); + bw.write(" "); + } + bw.write("\n"); + } + bw.close(); + }catch(IOException e){ + e.printStackTrace(); + } + // End Sam Hacking + } + protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List annotationList) { - final String formatString = "%.3f"; + final String formatString = "%.25f"; final GATKReport report = new GATKReport(); if (dataManager != null) { //for unit test @@ -547,10 +702,36 @@ public class VariantRecalibrator extends RodWalker gaussianStrings = new ArrayList<>(); + final double[] pMixtureLog10s = new double[goodModel.getModelGaussians().size()]; + int idx = 0; + + for( final MultivariateGaussian gaussian : goodModel.getModelGaussians() ) { + pMixtureLog10s[idx] = gaussian.pMixtureLog10; + logger.info("Good normalize PMix log 10 is:" + Double.toString(gaussian.pMixtureLog10) ); + gaussianStrings.add(Integer.toString(idx++) ); + } + + GATKReportTable goodPMix = makeVectorTable("GoodGaussianPMix", "Pmixture log 10 used to evaluate model", gaussianStrings, pMixtureLog10s, "pMixLog10", formatString, "Gaussian"); + report.addTable(goodPMix); + + gaussianStrings.clear(); + final double[] pMixtureLog10sBad = new double[badModel.getModelGaussians().size()]; + idx = 0; + + for( final MultivariateGaussian gaussian : badModel.getModelGaussians() ) { + pMixtureLog10sBad[idx] = gaussian.pMixtureLog10; + logger.info("Bad normalize PMix log 10 is:" + Double.toString(gaussian.pMixtureLog10)); + gaussianStrings.add(Integer.toString(idx++)); + } + GATKReportTable badPMix = makeVectorTable("BadGaussianPMix", "Pmixture log 10 used to evaluate model", gaussianStrings, pMixtureLog10sBad, "pMixLog10", formatString, "Gaussian"); + report.addTable(badPMix); + + //The model and Gaussians don't know what the annotations are, so get them from this class //VariantDataManager keeps the annotation in the same order as the argument list GATKReportTable positiveMeans = makeMeansTable("PositiveModelMeans", "Vector of annotation values to describe the (normalized) mean for each Gaussian in the positive model", annotationList, goodModel, formatString); @@ -570,8 +751,12 @@ public class VariantRecalibrator extends RodWalker annotationList, final double[] perAnnotationValues, final String columnName, final String formatString) { + return makeVectorTable(tableName, tableDescription, annotationList, perAnnotationValues, columnName, formatString, "Annotation"); + } + + protected GATKReportTable makeVectorTable(final String tableName, final String tableDescription, final List annotationList, final double[] perAnnotationValues, final String columnName, final String formatString, final String firstColumn) { GATKReportTable vectorTable = new GATKReportTable(tableName, tableDescription, annotationList.size(), GATKReportTable.TableSortingWay.DO_NOT_SORT); - vectorTable.addColumn("Annotation"); + vectorTable.addColumn(firstColumn); vectorTable.addColumn(columnName, formatString); for (int i = 0; i < perAnnotationValues.length; i++) { vectorTable.addRowIDMapping(annotationList.get(i), i, true); diff --git a/public/gatk-utils/src/main/java/org/broadinstitute/gatk/utils/report/GATKReportTable.java b/public/gatk-utils/src/main/java/org/broadinstitute/gatk/utils/report/GATKReportTable.java index 85d2386a0..4d87bab84 100644 --- a/public/gatk-utils/src/main/java/org/broadinstitute/gatk/utils/report/GATKReportTable.java +++ b/public/gatk-utils/src/main/java/org/broadinstitute/gatk/utils/report/GATKReportTable.java @@ -150,7 +150,6 @@ public class GATKReportTable { // read a data line final String dataLine = reader.readLine(); final List lineSplits = Arrays.asList(TextFormattingUtils.splitFixedWidth(dataLine, columnStarts)); - underlyingData.add(new Object[nColumns]); for ( int columnIndex = 0; columnIndex < nColumns; columnIndex++ ) { From 57c064eaa3d3a287f95b29ca701d6d9ae7a0cc7b Mon Sep 17 00:00:00 2001 From: Samuel Friedman Date: Wed, 26 Apr 2017 16:01:16 -0400 Subject: [PATCH 02/10] small code cleanup --- .../VariantRecalibrator.java | 114 +++++++----------- 1 file changed, 46 insertions(+), 68 deletions(-) diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java index 470ce3cdb..bb6f80835 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java @@ -280,7 +280,7 @@ public class VariantRecalibrator extends RodWalker gaussianList = new ArrayList<>(); int curAnnotation = 0; @@ -621,7 +604,7 @@ public class VariantRecalibrator extends RodWalker positiveTrainingData, List negativeTrainingData){ - //Begin Sam Hacking - try { - File file = new File("/Users/sam/data/haploid_features.txt"); - file.createNewFile(); - File badFile = new File("/Users/sam/data/haploid_bad_features.txt"); - badFile.createNewFile(); + private double[] getStandardDeviationsFromTable(GATKReportTable astdTable){ + double[] stdVector = {}; - FileWriter fw = new FileWriter(file.getAbsoluteFile()); - BufferedWriter bw = new BufferedWriter(fw); - for(int jj = 0; jj < positiveTrainingData.size(); jj++){ - VariantDatum v = positiveTrainingData.get(jj); - for(int kk = 0; kk < v.annotations.length; kk++){ - bw.write(Double.toString(v.annotations[kk])); - bw.write(" "); + for (GATKReportColumn reportColumn : astdTable.getColumnInfo() ) { + logger.info("Report column name is:" + reportColumn.getColumnName()); + if (reportColumn.getColumnName().equals("Standarddeviation")) { + stdVector = new double[astdTable.getNumRows()]; + for (int row = 0; row < astdTable.getNumRows(); row++) { + stdVector[row] = Double.parseDouble((String) astdTable.get(row, reportColumn.getColumnName())); } - bw.write("\n"); } - bw.close(); - - fw = new FileWriter(badFile.getAbsoluteFile()); - bw = new BufferedWriter(fw); - for(int jj = 0; jj < negativeTrainingData.size(); jj++){ - VariantDatum v = negativeTrainingData.get(jj); - for(int kk = 0; kk < v.annotations.length; kk++){ - bw.write(Double.toString(v.annotations[kk])); - bw.write(" "); - } - bw.write("\n"); - } - bw.close(); - }catch(IOException e){ - e.printStackTrace(); } - // End Sam Hacking + + return stdVector; + } + + private double[] getMeansFromTable(GATKReportTable amTable){ + double[] meanVector = {}; + + for (GATKReportColumn reportColumn : amTable.getColumnInfo() ) { + if (reportColumn.getColumnName().equals("Mean")) { + meanVector = new double[amTable.getNumRows()]; + for (int row = 0; row < amTable.getNumRows(); row++) { + meanVector[row] = Double.parseDouble((String) amTable.get(row, reportColumn.getColumnName())); + } + logger.info("Got mean Vector:" + Arrays.toString(meanVector)); + } + } + + return meanVector; } protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List annotationList) { From 85cb1c281068e563c31e8bce00012586a40d5df9 Mon Sep 17 00:00:00 2001 From: Samuel Friedman Date: Wed, 26 Apr 2017 17:33:06 -0400 Subject: [PATCH 03/10] dont spam on NaNs --- .../variantrecalibration/GaussianMixtureModel.java | 3 --- .../variantrecalibration/MultivariateGaussian.java | 9 --------- 2 files changed, 12 deletions(-) diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java index 59a5af92d..17b3a63d7 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java @@ -195,9 +195,6 @@ public class GaussianMixtureModel { final double[] pVarInGaussianNormalized = MathUtils.normalizeFromLog10( pVarInGaussianLog10, false ); gaussianIndex = 0; for( final MultivariateGaussian gaussian : gaussians ) { - if (Double.isNaN(pVarInGaussianNormalized[gaussianIndex])){ - logger.info(" Got a NaN at gaussian:" + Integer.toString(gaussianIndex) + " datum:" + datum.toString()); - } gaussian.assignPVarInGaussian( pVarInGaussianNormalized[gaussianIndex++] ); } } diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/MultivariateGaussian.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/MultivariateGaussian.java index 5e9d18ffd..51662dc92 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/MultivariateGaussian.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/MultivariateGaussian.java @@ -272,13 +272,4 @@ public class MultivariateGaussian { resetPVarInGaussian(); // clean up some memory } - - public void setSumProb( final List data ) { - sumProb = 0.0; - - for( int datumIndex = 0; datumIndex < data.size(); datumIndex++ ) { - final double prob = pVarInGaussian.get(datumIndex); - if(!Double.isNaN(prob)) sumProb += prob; - } - } } \ No newline at end of file From f55b932cfc04359b32907f8f38be7a5bd3740002 Mon Sep 17 00:00:00 2001 From: Samuel Friedman Date: Thu, 27 Apr 2017 17:16:57 -0400 Subject: [PATCH 04/10] addressed review comments --- .../GaussianMixtureModel.java | 1 - .../VariantRecalibrator.java | 43 ++++++------------- 2 files changed, 14 insertions(+), 30 deletions(-) diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java index 17b3a63d7..2a9b2b5cc 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java @@ -320,5 +320,4 @@ public class GaussianMixtureModel { protected List getModelGaussians() {return Collections.unmodifiableList(gaussians);} protected int getNumAnnotations() {return empiricalMu.length;} - } \ No newline at end of file diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java index bb6f80835..b5145f76d 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java @@ -490,57 +490,46 @@ public class VariantRecalibrator extends RodWalker positiveTrainingData = dataManager.getTrainingData(); final List negativeTrainingData; - File inputFile = new File(inputModel); + final File inputFile = new File(inputModel); if (inputFile.exists()){ // Load GMM from a file logger.info("Loading model from:"+inputModel); - GATKReport reportIn = new GATKReport(inputFile); + final GATKReport reportIn = new GATKReport(inputFile); // Read all the tables - GATKReportTable amTable = reportIn.getTable("AnnotationMeans"); - GATKReportTable astdTable = reportIn.getTable("AnnotationStdevs"); + final GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances"); + final GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans"); + final GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix"); + final GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances"); + final GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans"); + final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix"); + final int numAnnotations = dataManager.getMeanVector().length; - // Should have same number of means and standard deviations. - assert(amTable.getNumRows() == astdTable.getNumRows() ); - - GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances"); - GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans"); - GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix"); - GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances"); - GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans"); - GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix"); - - int numAnnotations = amTable.getNumRows(); goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations); engine.evaluateData(dataManager.getData(), goodModel, false); negativeTrainingData = dataManager.selectWorstVariants(); badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations); - logger.info("Loaded GMM from file:" + inputModel); - dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory - engine.evaluateData(dataManager.getData(), badModel, true); } else { // Generate the GMMs from scratch + // Generate the positive model using the training data and evaluate each variant goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS); - //Utils.getRandomGenerator().setSeed(12878); engine.evaluateData(dataManager.getData(), goodModel, false); // Generate the negative model using the worst performing data and evaluate each variant contrastively negativeTrainingData = dataManager.selectWorstVariants(); badModel = engine.generateModel(negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS)); - dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory - //Utils.getRandomGenerator().setSeed(12878); - engine.evaluateData(dataManager.getData(), badModel, true); - if (badModel.failedToConverge || goodModel.failedToConverge) { throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --minNumBadVariants 5000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example).")); } } + dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory + engine.evaluateData(dataManager.getData(), badModel, true); + if (outputModel) { GATKReport report = writeModelReport(goodModel, badModel, USE_ANNOTATIONS); report.print(modelReport); @@ -592,7 +581,7 @@ public class VariantRecalibrator extends RodWalker gaussianList = new ArrayList<>(); int curAnnotation = 0; @@ -642,7 +631,6 @@ public class VariantRecalibrator extends RodWalker Date: Fri, 28 Apr 2017 17:32:36 -0400 Subject: [PATCH 05/10] string cast bug --- .../VariantRecalibrator.java | 18 +++++++++++++----- 1 file changed, 13 insertions(+), 5 deletions(-) diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java index b5145f76d..9ed65b87d 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java @@ -593,7 +593,7 @@ public class VariantRecalibrator extends RodWalker Date: Mon, 1 May 2017 15:52:45 -0400 Subject: [PATCH 06/10] move model file parsing to initialize --- .../VariantRecalibrator.java | 67 +++++++++---------- 1 file changed, 33 insertions(+), 34 deletions(-) diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java index 9ed65b87d..f305c780b 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java @@ -319,6 +319,8 @@ public class VariantRecalibrator extends RodWalker ignoreInputFilterSet = new TreeSet<>(); private final VariantRecalibratorEngine engine = new VariantRecalibratorEngine( VRAC ); + private GaussianMixtureModel goodModel = null; + private GaussianMixtureModel badModel = null; //--------------------------------------------------------------------------------------------------------------- // @@ -356,7 +358,28 @@ public class VariantRecalibrator extends RodWalker hInfo = new HashSet<>(); + final File inputFile = new File(inputModel); + if (inputFile.exists()) { // Load GMM from a file + logger.info("Loading model from:" + inputModel); + final GATKReport reportIn = new GATKReport(inputFile); + + // Read all the tables + final GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances"); + final GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans"); + final GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix"); + final GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances"); + final GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans"); + final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix"); + final int numAnnotations = dataManager.getMeanVector().length; + + goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations); + badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations); + } + + + + + final Set hInfo = new HashSet<>(); ApplyRecalibration.addVQSRStandardHeaderLines(hInfo); recalWriter.writeHeader( new VCFHeader(hInfo) ); @@ -490,29 +513,14 @@ public class VariantRecalibrator extends RodWalker positiveTrainingData = dataManager.getTrainingData(); final List negativeTrainingData; - final File inputFile = new File(inputModel); - if (inputFile.exists()){ // Load GMM from a file - logger.info("Loading model from:"+inputModel); - final GATKReport reportIn = new GATKReport(inputFile); - - // Read all the tables - final GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances"); - final GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans"); - final GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix"); - final GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances"); - final GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans"); - final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix"); - final int numAnnotations = dataManager.getMeanVector().length; - - goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations); + if (goodModel != null && badModel != null){ // GMMs were loaded from a file + // Keeping this to maintain reproducibility between runs with and without serialized GMMs engine.evaluateData(dataManager.getData(), goodModel, false); negativeTrainingData = dataManager.selectWorstVariants(); - badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations); - } else { // Generate the GMMs from scratch // Generate the positive model using the training data and evaluate each variant goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS); @@ -586,14 +594,13 @@ public class VariantRecalibrator extends RodWalker annotationList) { - final String formatString = "%.25f"; + final String formatString = "%.8f"; final GATKReport report = new GATKReport(); if (dataManager != null) { //for unit test From 4df71a3ca1ef38420fd91b0b10804d1449db2c7d Mon Sep 17 00:00:00 2001 From: Samuel Friedman Date: Tue, 2 May 2017 12:05:10 -0400 Subject: [PATCH 07/10] expect floats in the report --- .../variantrecalibration/VariantRecalibrator.java | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java index f305c780b..8e473b5d5 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java @@ -600,7 +600,7 @@ public class VariantRecalibrator extends RodWalker Date: Wed, 3 May 2017 18:37:52 -0400 Subject: [PATCH 08/10] scientific notation in model report and basic test --- .../GaussianMixtureModel.java | 5 +++ .../VariantDataManager.java | 2 +- .../VariantRecalibrator.java | 14 +++------ .../VariantRecalibratorEngine.java | 4 ++- ...ariantRecalibratorModelOutputUnitTest.java | 31 +++++++++++++++++++ 5 files changed, 45 insertions(+), 11 deletions(-) diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java index 2a9b2b5cc..37dc1ae43 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java @@ -110,6 +110,11 @@ public class GaussianMixtureModel { this.shrinkage = shrinkage; this.dirichletParameter = dirichletParameter; this.priorCounts = priorCounts; + for( final MultivariateGaussian gaussian : gaussians ) { + gaussian.hyperParameter_a = priorCounts; + gaussian.hyperParameter_b = shrinkage; + gaussian.hyperParameter_lambda = dirichletParameter; + } empiricalMu = new double[numAnnotations]; empiricalSigma = new Matrix(numAnnotations, numAnnotations); isModelReadyForEvaluation = false; diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantDataManager.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantDataManager.java index b1b19433c..d4304d147 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantDataManager.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantDataManager.java @@ -277,7 +277,7 @@ public class VariantDataManager { } } - logger.info( "Training with worst " + trainingData.size() + " scoring variants --> variants with LOD <= " + String.format("%.4f", VRAC.BAD_LOD_CUTOFF) + "." ); + logger.info( "Selected worst " + trainingData.size() + " scoring variants --> variants with LOD <= " + String.format("%.4f", VRAC.BAD_LOD_CUTOFF) + "." ); return trainingData; } diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java index 8e473b5d5..b4792665b 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java @@ -376,10 +376,7 @@ public class VariantRecalibrator extends RodWalker hInfo = new HashSet<>(); + final Set hInfo = new HashSet<>(); ApplyRecalibration.addVQSRStandardHeaderLines(hInfo); recalWriter.writeHeader( new VCFHeader(hInfo) ); @@ -513,12 +510,11 @@ public class VariantRecalibrator extends RodWalker positiveTrainingData = dataManager.getTrainingData(); final List negativeTrainingData; if (goodModel != null && badModel != null){ // GMMs were loaded from a file - // Keeping this to maintain reproducibility between runs with and without serialized GMMs + logger.info("Using serialized GMMs from file..."); engine.evaluateData(dataManager.getData(), goodModel, false); negativeTrainingData = dataManager.selectWorstVariants(); } else { // Generate the GMMs from scratch @@ -527,12 +523,12 @@ public class VariantRecalibrator extends RodWalker gaussianList = new ArrayList<>(); int curAnnotation = 0; @@ -665,7 +661,7 @@ public class VariantRecalibrator extends RodWalker annotationList) { - final String formatString = "%.8f"; + final String formatString = "%.8E"; final GATKReport report = new GATKReport(); if (dataManager != null) { //for unit test diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibratorEngine.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibratorEngine.java index f86099113..d8a8bb653 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibratorEngine.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibratorEngine.java @@ -98,6 +98,7 @@ public class VariantRecalibratorEngine { try { model.precomputeDenominatorForEvaluation(); } catch( Exception e ) { + logger.warn("Model could not pre-compute denominators."); model.failedToConverge = true; return; } @@ -107,6 +108,7 @@ public class VariantRecalibratorEngine { for( final VariantDatum datum : data ) { final double thisLod = evaluateDatum( datum, model ); if( Double.isNaN(thisLod) ) { + logger.warn("Evaluate datum returned a NaN."); model.failedToConverge = true; return; } @@ -142,7 +144,7 @@ public class VariantRecalibratorEngine { // Private Methods used for generating a GaussianMixtureModel ///////////////////////////// - private void variationalBayesExpectationMaximization( final GaussianMixtureModel model, final List data ) { + protected void variationalBayesExpectationMaximization(final GaussianMixtureModel model, final List data) { model.initializeRandomModel( data, VRAC.NUM_KMEANS_ITERATIONS ); diff --git a/protected/gatk-tools-protected/src/test/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibratorModelOutputUnitTest.java b/protected/gatk-tools-protected/src/test/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibratorModelOutputUnitTest.java index 56e515029..7b6f708a9 100644 --- a/protected/gatk-tools-protected/src/test/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibratorModelOutputUnitTest.java +++ b/protected/gatk-tools-protected/src/test/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibratorModelOutputUnitTest.java @@ -170,8 +170,20 @@ public class VariantRecalibratorModelOutputUnitTest { Assert.assertEquals(badGaussian1.sigma.get(i,j), (Double)badSigma.get(i,annotationList.get(j)), epsilon); } } + + // Now test model report reading + // Read the gaussian weighting tables + final GATKReportTable nPMixTable = report.getTable("BadGaussianPMix"); + final GATKReportTable pPMixTable = report.getTable("GoodGaussianPMix"); + + GaussianMixtureModel goodModelFromFile = vqsr.GMMFromTables(goodMus, goodSigma, pPMixTable, annotationList.size()); + GaussianMixtureModel badModelFromFile = vqsr.GMMFromTables(badMus, badSigma, nPMixTable, annotationList.size()); + + testGMMsForEquality(goodModel, goodModelFromFile, epsilon); + testGMMsForEquality(badModel, badModelFromFile, epsilon); } + @Test //This is tested separately to avoid setting up a VariantDataManager and populating it with fake data public void testAnnotationNormalizationOutput() { @@ -211,4 +223,23 @@ public class VariantRecalibratorModelOutputUnitTest { return returnString; } + private void testGMMsForEquality(GaussianMixtureModel gmm1, GaussianMixtureModel gmm2, double epsilon){ + Assert.assertEquals(gmm1.getModelGaussians().size(), gmm2.getModelGaussians().size(), 0); + + for(int k = 0; k < gmm1.getModelGaussians().size(); k++) { + final MultivariateGaussian g = gmm1.getModelGaussians().get(k); + final MultivariateGaussian gFile = gmm2.getModelGaussians().get(k); + + for(int i = 0; i < g.mu.length; i++){ + Assert.assertEquals(g.mu[i], gFile.mu[i], epsilon); + } + + for(int i = 0; i < g.sigma.getRowDimension(); i++) { + for (int j = 0; j < g.sigma.getColumnDimension(); j++) { + Assert.assertEquals(g.sigma.get(i, j), gFile.sigma.get(i, j), epsilon); + } + } + } + } + } \ No newline at end of file From 68bdb93c8c795d5efeea46f04898540d3c46bb61 Mon Sep 17 00:00:00 2001 From: Samuel Friedman Date: Thu, 4 May 2017 16:48:24 -0400 Subject: [PATCH 09/10] add annotation mismatch warning and refactor tests --- .../VariantRecalibrator.java | 22 ++++--- ...ariantRecalibratorModelOutputUnitTest.java | 59 +++++++++++++++---- 2 files changed, 64 insertions(+), 17 deletions(-) diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java index b4792665b..d44e53dba 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java @@ -279,11 +279,11 @@ public class VariantRecalibrator extends RodWalker goodGaussianList = new ArrayList<>(); + goodGaussianList.add(goodGaussian1); + goodGaussianList.add(goodGaussian2); + + List badGaussianList = new ArrayList<>(); + badGaussianList.add(badGaussian1); + + GaussianMixtureModel goodModel = new GaussianMixtureModel(goodGaussianList, shrinkage, dirichlet, priorCounts); + GaussianMixtureModel badModel = new GaussianMixtureModel(badGaussianList, shrinkage, dirichlet, priorCounts); + + VariantRecalibrator vqsr = new VariantRecalibrator(); + List annotationList = new ArrayList<>(); + annotationList.add("QD"); + annotationList.add("MQ"); + annotationList.add("FS"); + annotationList.add("SOR"); + annotationList.add("ReadPosRankSum"); + annotationList.add("MQRankSum"); + + GATKReport report = vqsr.writeModelReport(goodModel, badModel, annotationList); // Now test model report reading - // Read the gaussian weighting tables + // Read all the tables + final GATKReportTable badMus = report.getTable("NegativeModelMeans"); + final GATKReportTable badSigma = report.getTable("NegativeModelCovariances"); final GATKReportTable nPMixTable = report.getTable("BadGaussianPMix"); + + final GATKReportTable goodMus = report.getTable("PositiveModelMeans"); + final GATKReportTable goodSigma = report.getTable("PositiveModelCovariances"); final GATKReportTable pPMixTable = report.getTable("GoodGaussianPMix"); GaussianMixtureModel goodModelFromFile = vqsr.GMMFromTables(goodMus, goodSigma, pPMixTable, annotationList.size()); @@ -183,7 +223,6 @@ public class VariantRecalibratorModelOutputUnitTest { testGMMsForEquality(badModel, badModelFromFile, epsilon); } - @Test //This is tested separately to avoid setting up a VariantDataManager and populating it with fake data public void testAnnotationNormalizationOutput() { From ed440f1684bdbfedb98e6e3219a06fb6fbd3963c Mon Sep 17 00:00:00 2001 From: Samuel Friedman Date: Fri, 5 May 2017 16:29:01 -0400 Subject: [PATCH 10/10] respond to review comments --- .../VariantRecalibrator.java | 2 +- ...ariantRecalibratorModelOutputUnitTest.java | 163 +++++++++--------- .../gatk/utils/report/GATKReportTable.java | 1 - 3 files changed, 80 insertions(+), 86 deletions(-) diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java index d44e53dba..546c0f05b 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java @@ -278,7 +278,7 @@ public class VariantRecalibrator extends RodWalker goodGaussianList = new ArrayList<>(); - goodGaussianList.add(goodGaussian1); - goodGaussianList.add(goodGaussian2); - - List badGaussianList = new ArrayList<>(); - badGaussianList.add(badGaussian1); - - GaussianMixtureModel goodModel = new GaussianMixtureModel(goodGaussianList, shrinkage, dirichlet, priorCounts); - GaussianMixtureModel badModel = new GaussianMixtureModel(badGaussianList, shrinkage, dirichlet, priorCounts); + GaussianMixtureModel goodModel = getGoodGMM(); + GaussianMixtureModel badModel = getBadGMM(); if (printTables) { System.out.println("Good model mean matrix:"); - System.out.println(vectorToString(goodGaussian1.mu)); - System.out.println(vectorToString(goodGaussian2.mu)); + System.out.println(vectorToString(goodModel.getModelGaussians().get(0).mu)); + System.out.println(vectorToString(goodModel.getModelGaussians().get(1).mu)); System.out.println("\n\n"); System.out.println("Good model covariance matrices:"); - goodGaussian1.sigma.print(10, 3); - goodGaussian2.sigma.print(10, 3); + goodModel.getModelGaussians().get(0).sigma.print(10, 3); + goodModel.getModelGaussians().get(1).sigma.print(10, 3); System.out.println("\n\n"); System.out.println("Bad model mean matrix:\n"); - System.out.println(vectorToString(badGaussian1.mu)); + System.out.println(vectorToString(badModel.getModelGaussians().get(0).mu)); System.out.println("\n\n"); System.out.println("Bad model covariance matrix:"); - badGaussian1.sigma.print(10, 3); + badModel.getModelGaussians().get(0).sigma.print(10, 3); } VariantRecalibrator vqsr = new VariantRecalibrator(); - List annotationList = new ArrayList<>(); - annotationList.add("QD"); - annotationList.add("MQ"); - annotationList.add("FS"); - annotationList.add("SOR"); - annotationList.add("ReadPosRankSum"); - annotationList.add("MQRankSum"); - + List annotationList = getAnnotationList(); GATKReport report = vqsr.writeModelReport(goodModel, badModel, annotationList); - if(printTables) - report.print(System.out); + if(printTables) { + try { + PrintStream modelReporter = new PrintStream(this.privateTestDir+this.modelReportName); + report.print(modelReporter); + } catch (FileNotFoundException e) { + e.printStackTrace(); + } + } //Check values for Gaussian means GATKReportTable goodMus = report.getTable("PositiveModelMeans"); for(int i = 0; i < annotationList.size(); i++) { - Assert.assertEquals(goodGaussian1.mu[i], (Double)goodMus.get(0,annotationList.get(i)), epsilon); + Assert.assertEquals(goodModel.getModelGaussians().get(0).mu[i], (Double)goodMus.get(0,annotationList.get(i)), epsilon); } for(int i = 0; i < annotationList.size(); i++) { - Assert.assertEquals(goodGaussian2.mu[i], (Double)goodMus.get(1,annotationList.get(i)), epsilon); + Assert.assertEquals(goodModel.getModelGaussians().get(1).mu[i], (Double)goodMus.get(1,annotationList.get(i)), epsilon); } GATKReportTable badMus = report.getTable("NegativeModelMeans"); for(int i = 0; i < annotationList.size(); i++) { - Assert.assertEquals(badGaussian1.mu[i], (Double)badMus.get(0,annotationList.get(i)), epsilon); + Assert.assertEquals(badModel.getModelGaussians().get(0).mu[i], (Double)badMus.get(0,annotationList.get(i)), epsilon); } //Check values for Gaussian covariances GATKReportTable goodSigma = report.getTable("PositiveModelCovariances"); for(int i = 0; i < annotationList.size(); i++) { for(int j = 0; j < annotationList.size(); j++) { - Assert.assertEquals(goodGaussian1.sigma.get(i,j), (Double)goodSigma.get(i,annotationList.get(j)), epsilon); + Assert.assertEquals(goodModel.getModelGaussians().get(0).sigma.get(i,j), (Double)goodSigma.get(i,annotationList.get(j)), epsilon); } } //add annotationList.size() to row indexes for second Gaussian because the matrices are concatenated by row in the report for(int i = 0; i < annotationList.size(); i++) { for(int j = 0; j < annotationList.size(); j++) { - Assert.assertEquals(goodGaussian2.sigma.get(i,j), (Double)goodSigma.get(annotationList.size()+i,annotationList.get(j)), epsilon); + Assert.assertEquals(goodModel.getModelGaussians().get(1).sigma.get(i,j), (Double)goodSigma.get(annotationList.size()+i,annotationList.get(j)), epsilon); } } GATKReportTable badSigma = report.getTable("NegativeModelCovariances"); for(int i = 0; i < annotationList.size(); i++) { for(int j = 0; j < annotationList.size(); j++) { - Assert.assertEquals(badGaussian1.sigma.get(i,j), (Double)badSigma.get(i,annotationList.get(j)), epsilon); + Assert.assertEquals(badModel.getModelGaussians().get(0).sigma.get(i,j), (Double)badSigma.get(i,annotationList.get(j)), epsilon); } } } @@ -172,39 +152,8 @@ public class VariantRecalibratorModelOutputUnitTest { @Test public void testVQSRModelInput(){ - Random rand = new Random(12878); - MultivariateGaussian goodGaussian1 = new MultivariateGaussian(numAnnotations); - goodGaussian1.initializeRandomMu(rand); - goodGaussian1.initializeRandomSigma(rand); - - MultivariateGaussian goodGaussian2 = new MultivariateGaussian(numAnnotations); - goodGaussian2.initializeRandomMu(rand); - goodGaussian2.initializeRandomSigma(rand); - - MultivariateGaussian badGaussian1 = new MultivariateGaussian(numAnnotations); - badGaussian1.initializeRandomMu(rand); - badGaussian1.initializeRandomSigma(rand); - - List goodGaussianList = new ArrayList<>(); - goodGaussianList.add(goodGaussian1); - goodGaussianList.add(goodGaussian2); - - List badGaussianList = new ArrayList<>(); - badGaussianList.add(badGaussian1); - - GaussianMixtureModel goodModel = new GaussianMixtureModel(goodGaussianList, shrinkage, dirichlet, priorCounts); - GaussianMixtureModel badModel = new GaussianMixtureModel(badGaussianList, shrinkage, dirichlet, priorCounts); - - VariantRecalibrator vqsr = new VariantRecalibrator(); - List annotationList = new ArrayList<>(); - annotationList.add("QD"); - annotationList.add("MQ"); - annotationList.add("FS"); - annotationList.add("SOR"); - annotationList.add("ReadPosRankSum"); - annotationList.add("MQRankSum"); - - GATKReport report = vqsr.writeModelReport(goodModel, badModel, annotationList); + final File inputFile = new File(this.privateTestDir + this.modelReportName); + final GATKReport report = new GATKReport(inputFile); // Now test model report reading // Read all the tables @@ -216,11 +165,14 @@ public class VariantRecalibratorModelOutputUnitTest { final GATKReportTable goodSigma = report.getTable("PositiveModelCovariances"); final GATKReportTable pPMixTable = report.getTable("GoodGaussianPMix"); + List annotationList = getAnnotationList(); + VariantRecalibrator vqsr = new VariantRecalibrator(); + GaussianMixtureModel goodModelFromFile = vqsr.GMMFromTables(goodMus, goodSigma, pPMixTable, annotationList.size()); GaussianMixtureModel badModelFromFile = vqsr.GMMFromTables(badMus, badSigma, nPMixTable, annotationList.size()); - testGMMsForEquality(goodModel, goodModelFromFile, epsilon); - testGMMsForEquality(badModel, badModelFromFile, epsilon); + testGMMsForEquality(getGoodGMM(), goodModelFromFile, epsilon); + testGMMsForEquality(getBadGMM(), badModelFromFile, epsilon); } @Test @@ -269,6 +221,8 @@ public class VariantRecalibratorModelOutputUnitTest { final MultivariateGaussian g = gmm1.getModelGaussians().get(k); final MultivariateGaussian gFile = gmm2.getModelGaussians().get(k); + Assert.assertEquals(g.pMixtureLog10, gFile.pMixtureLog10); + for(int i = 0; i < g.mu.length; i++){ Assert.assertEquals(g.mu[i], gFile.mu[i], epsilon); } @@ -281,4 +235,45 @@ public class VariantRecalibratorModelOutputUnitTest { } } + private List getAnnotationList(){ + List annotationList = new ArrayList<>(); + annotationList.add("QD"); + annotationList.add("MQ"); + annotationList.add("FS"); + annotationList.add("SOR"); + annotationList.add("ReadPosRankSum"); + annotationList.add("MQRankSum"); + return annotationList; + } + + private GaussianMixtureModel getGoodGMM(){ + Random rand = new Random(12878); + MultivariateGaussian goodGaussian1 = new MultivariateGaussian(numAnnotations); + goodGaussian1.initializeRandomMu(rand); + goodGaussian1.initializeRandomSigma(rand); + + MultivariateGaussian goodGaussian2 = new MultivariateGaussian(numAnnotations); + goodGaussian2.initializeRandomMu(rand); + goodGaussian2.initializeRandomSigma(rand); + + List goodGaussianList = new ArrayList<>(); + goodGaussianList.add(goodGaussian1); + goodGaussianList.add(goodGaussian2); + + return new GaussianMixtureModel(goodGaussianList, shrinkage, dirichlet, priorCounts); + } + + private GaussianMixtureModel getBadGMM(){ + Random rand = new Random(12878); + MultivariateGaussian badGaussian1 = new MultivariateGaussian(numAnnotations); + + badGaussian1.initializeRandomMu(rand); + badGaussian1.initializeRandomSigma(rand); + + List badGaussianList = new ArrayList<>(); + badGaussianList.add(badGaussian1); + + return new GaussianMixtureModel(badGaussianList, shrinkage, dirichlet, priorCounts); + } + } \ No newline at end of file diff --git a/public/gatk-utils/src/main/java/org/broadinstitute/gatk/utils/report/GATKReportTable.java b/public/gatk-utils/src/main/java/org/broadinstitute/gatk/utils/report/GATKReportTable.java index 4d87bab84..ff7aed473 100644 --- a/public/gatk-utils/src/main/java/org/broadinstitute/gatk/utils/report/GATKReportTable.java +++ b/public/gatk-utils/src/main/java/org/broadinstitute/gatk/utils/report/GATKReportTable.java @@ -152,7 +152,6 @@ public class GATKReportTable { final List lineSplits = Arrays.asList(TextFormattingUtils.splitFixedWidth(dataLine, columnStarts)); underlyingData.add(new Object[nColumns]); for ( int columnIndex = 0; columnIndex < nColumns; columnIndex++ ) { - final GATKReportDataType type = columnInfo.get(columnIndex).getDataType(); final String columnName = columnNames[columnIndex]; set(i, columnName, type.Parse(lineSplits.get(columnIndex)));