diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java index 0995fc10c..37dc1ae43 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java @@ -86,6 +86,11 @@ public class GaussianMixtureModel { gaussians = new ArrayList<>( numGaussians ); for( int iii = 0; iii < numGaussians; iii++ ) { final MultivariateGaussian gaussian = new MultivariateGaussian( numAnnotations ); + gaussian.pMixtureLog10 = Math.log10( 1.0 / ((double)numGaussians) ); + gaussian.sumProb = 1.0 / ((double) numGaussians); + gaussian.hyperParameter_a = priorCounts; + gaussian.hyperParameter_b = shrinkage; + gaussian.hyperParameter_lambda = dirichletParameter; gaussians.add( gaussian ); } this.shrinkage = shrinkage; @@ -105,6 +110,11 @@ public class GaussianMixtureModel { this.shrinkage = shrinkage; this.dirichletParameter = dirichletParameter; this.priorCounts = priorCounts; + for( final MultivariateGaussian gaussian : gaussians ) { + gaussian.hyperParameter_a = priorCounts; + gaussian.hyperParameter_b = shrinkage; + gaussian.hyperParameter_lambda = dirichletParameter; + } empiricalMu = new double[numAnnotations]; empiricalSigma = new Matrix(numAnnotations, numAnnotations); isModelReadyForEvaluation = false; diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/MultivariateGaussian.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/MultivariateGaussian.java index db46b0c33..51662dc92 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/MultivariateGaussian.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/MultivariateGaussian.java @@ -271,4 +271,5 @@ public class MultivariateGaussian { resetPVarInGaussian(); // clean up some memory } + } \ No newline at end of file diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantDataManager.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantDataManager.java index b1b19433c..d4304d147 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantDataManager.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantDataManager.java @@ -277,7 +277,7 @@ public class VariantDataManager { } } - logger.info( "Training with worst " + trainingData.size() + " scoring variants --> variants with LOD <= " + String.format("%.4f", VRAC.BAD_LOD_CUTOFF) + "." ); + logger.info( "Selected worst " + trainingData.size() + " scoring variants --> variants with LOD <= " + String.format("%.4f", VRAC.BAD_LOD_CUTOFF) + "." ); return trainingData; } diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java index 416d72c67..546c0f05b 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java @@ -66,6 +66,7 @@ import org.broadinstitute.gatk.utils.R.RScriptExecutor; import org.broadinstitute.gatk.utils.Utils; import org.broadinstitute.gatk.utils.help.HelpConstants; import org.broadinstitute.gatk.utils.report.GATKReport; +import org.broadinstitute.gatk.utils.report.GATKReportColumn; import org.broadinstitute.gatk.utils.report.GATKReportTable; import org.broadinstitute.gatk.utils.variant.GATKVariantContextUtils; import htsjdk.variant.vcf.VCFHeader; @@ -80,10 +81,15 @@ import htsjdk.variant.variantcontext.writer.VariantContextWriter; import java.io.File; import java.io.FileNotFoundException; import java.io.PrintStream; +import java.nio.file.Files; import java.util.*; import Jama.Matrix; + +import java.io.FileWriter; +import java.io.BufferedWriter; +import java.io.IOException; /** * Build a recalibration model to score variant quality for filtering purposes * @@ -272,10 +278,12 @@ public class VariantRecalibrator extends RodWalker ignoreInputFilterSet = new TreeSet<>(); private final VariantRecalibratorEngine engine = new VariantRecalibratorEngine( VRAC ); + private GaussianMixtureModel goodModel = null; + private GaussianMixtureModel badModel = null; //--------------------------------------------------------------------------------------------------------------- // @@ -348,6 +358,28 @@ public class VariantRecalibrator extends RodWalker hInfo = new HashSet<>(); ApplyRecalibration.addVQSRStandardHeaderLines(hInfo); recalWriter.writeHeader( new VCFHeader(hInfo) ); @@ -359,8 +391,11 @@ public class VariantRecalibrator extends RodWalker inputCollection : inputCollections ) input.addAll(inputCollection.getRodBindings()); + + } + //--------------------------------------------------------------------------------------------------------------- // // map @@ -479,24 +514,37 @@ public class VariantRecalibrator extends RodWalker positiveTrainingData = dataManager.getTrainingData(); - final GaussianMixtureModel goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS); - engine.evaluateData(dataManager.getData(), goodModel, false); + final List negativeTrainingData; + + if (goodModel != null && badModel != null){ // GMMs were loaded from a file + logger.info("Using serialized GMMs from file..."); + engine.evaluateData(dataManager.getData(), goodModel, false); + negativeTrainingData = dataManager.selectWorstVariants(); + } else { // Generate the GMMs from scratch + // Generate the positive model using the training data and evaluate each variant + goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS); + engine.evaluateData(dataManager.getData(), goodModel, false); + // Generate the negative model using the worst performing data and evaluate each variant contrastively + negativeTrainingData = dataManager.selectWorstVariants(); + badModel = engine.generateModel(negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS)); + + if (badModel.failedToConverge || goodModel.failedToConverge) { + throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --minNumBadVariants 5000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example).")); + } + + } - // Generate the negative model using the worst performing data and evaluate each variant contrastively - final List negativeTrainingData = dataManager.selectWorstVariants(); - final GaussianMixtureModel badModel = engine.generateModel(negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS)); dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory engine.evaluateData(dataManager.getData(), badModel, true); - if (badModel.failedToConverge || goodModel.failedToConverge) { - throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --minNumBadVariants 5000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example).")); - } - - if (outputModel) { - GATKReport report = writeModelReport(goodModel, badModel, USE_ANNOTATIONS); - report.print(modelReport); + if (outputModel != null) { + try (PrintStream modelReporter = new PrintStream(outputModel)) { + GATKReport report = writeModelReport(goodModel, badModel, USE_ANNOTATIONS); + report.print(modelReporter); + } catch (FileNotFoundException e){ + throw new UserException("Could not open output model file:" + outputModel); + } } engine.calculateWorstPerformingAnnotation(dataManager.getData(), goodModel, badModel); @@ -537,8 +585,91 @@ public class VariantRecalibrator extends RodWalker gaussianList = new ArrayList<>(); + + int curAnnotation = 0; + for (GATKReportColumn reportColumn : muTable.getColumnInfo() ) { + if (!reportColumn.getColumnName().equals("Gaussian")) { + for (int row = 0; row < muTable.getNumRows(); row++) { + if (gaussianList.size() <= row){ + MultivariateGaussian mg = new MultivariateGaussian(numAnnotations); + gaussianList.add(mg); + } + gaussianList.get(row).mu[curAnnotation] = (Double) muTable.get(row, reportColumn.getColumnName()); + } + curAnnotation++; + } + } + + for (GATKReportColumn reportColumn : pmixTable.getColumnInfo() ) { + if (reportColumn.getColumnName().equals("pMixLog10")) { + for (int row = 0; row < pmixTable.getNumRows(); row++) { + gaussianList.get(row).pMixtureLog10 = (Double) pmixTable.get(row, reportColumn.getColumnName()); + } + } + } + + int curJ = 0; + for (GATKReportColumn reportColumn : sigmaTable.getColumnInfo() ) { + if (reportColumn.getColumnName().equals("Gaussian")) continue; + if (reportColumn.getColumnName().equals("Annotation")) continue; + + for (int row = 0; row < sigmaTable.getNumRows(); row++) { + int curGaussian = row / numAnnotations; + int curI = row % numAnnotations; + double curVal = (Double) sigmaTable.get(row, reportColumn.getColumnName()); + gaussianList.get(curGaussian).sigma.set(curI, curJ, curVal); + + } + curJ++; + + } + + return new GaussianMixtureModel(gaussianList, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER, VRAC.PRIOR_COUNTS); + + } + + private double[] getStandardDeviationsFromTable(GATKReportTable astdTable){ + double[] stdVector = {}; + + for (GATKReportColumn reportColumn : astdTable.getColumnInfo() ) { + if (reportColumn.getColumnName().equals("Standarddeviation")) { + stdVector = new double[astdTable.getNumRows()]; + for (int row = 0; row < astdTable.getNumRows(); row++) { + stdVector[row] = (Double) astdTable.get(row, reportColumn.getColumnName()); + } + } + } + + return stdVector; + } + + private double[] getMeansFromTable(GATKReportTable amTable){ + double[] meanVector = {}; + + for (GATKReportColumn reportColumn : amTable.getColumnInfo() ) { + if (reportColumn.getColumnName().equals("Mean")) { + meanVector = new double[amTable.getNumRows()]; + for (int row = 0; row < amTable.getNumRows(); row++) { + meanVector[row] = (Double) amTable.get(row, reportColumn.getColumnName()); + } + } + } + + return meanVector; + } + protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List annotationList) { - final String formatString = "%.3f"; + final String formatString = "%.8E"; final GATKReport report = new GATKReport(); if (dataManager != null) { //for unit test @@ -547,10 +678,34 @@ public class VariantRecalibrator extends RodWalker gaussianStrings = new ArrayList<>(); + final double[] pMixtureLog10s = new double[goodModel.getModelGaussians().size()]; + int idx = 0; + + for( final MultivariateGaussian gaussian : goodModel.getModelGaussians() ) { + pMixtureLog10s[idx] = gaussian.pMixtureLog10; + gaussianStrings.add(Integer.toString(idx++) ); + } + + GATKReportTable goodPMix = makeVectorTable("GoodGaussianPMix", "Pmixture log 10 used to evaluate model", gaussianStrings, pMixtureLog10s, "pMixLog10", formatString, "Gaussian"); + report.addTable(goodPMix); + + gaussianStrings.clear(); + final double[] pMixtureLog10sBad = new double[badModel.getModelGaussians().size()]; + idx = 0; + + for( final MultivariateGaussian gaussian : badModel.getModelGaussians() ) { + pMixtureLog10sBad[idx] = gaussian.pMixtureLog10; + gaussianStrings.add(Integer.toString(idx++)); + } + GATKReportTable badPMix = makeVectorTable("BadGaussianPMix", "Pmixture log 10 used to evaluate model", gaussianStrings, pMixtureLog10sBad, "pMixLog10", formatString, "Gaussian"); + report.addTable(badPMix); + + //The model and Gaussians don't know what the annotations are, so get them from this class //VariantDataManager keeps the annotation in the same order as the argument list GATKReportTable positiveMeans = makeMeansTable("PositiveModelMeans", "Vector of annotation values to describe the (normalized) mean for each Gaussian in the positive model", annotationList, goodModel, formatString); @@ -570,8 +725,12 @@ public class VariantRecalibrator extends RodWalker annotationList, final double[] perAnnotationValues, final String columnName, final String formatString) { + return makeVectorTable(tableName, tableDescription, annotationList, perAnnotationValues, columnName, formatString, "Annotation"); + } + + protected GATKReportTable makeVectorTable(final String tableName, final String tableDescription, final List annotationList, final double[] perAnnotationValues, final String columnName, final String formatString, final String firstColumn) { GATKReportTable vectorTable = new GATKReportTable(tableName, tableDescription, annotationList.size(), GATKReportTable.TableSortingWay.DO_NOT_SORT); - vectorTable.addColumn("Annotation"); + vectorTable.addColumn(firstColumn); vectorTable.addColumn(columnName, formatString); for (int i = 0; i < perAnnotationValues.length; i++) { vectorTable.addRowIDMapping(annotationList.get(i), i, true); diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibratorEngine.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibratorEngine.java index f86099113..d8a8bb653 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibratorEngine.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibratorEngine.java @@ -98,6 +98,7 @@ public class VariantRecalibratorEngine { try { model.precomputeDenominatorForEvaluation(); } catch( Exception e ) { + logger.warn("Model could not pre-compute denominators."); model.failedToConverge = true; return; } @@ -107,6 +108,7 @@ public class VariantRecalibratorEngine { for( final VariantDatum datum : data ) { final double thisLod = evaluateDatum( datum, model ); if( Double.isNaN(thisLod) ) { + logger.warn("Evaluate datum returned a NaN."); model.failedToConverge = true; return; } @@ -142,7 +144,7 @@ public class VariantRecalibratorEngine { // Private Methods used for generating a GaussianMixtureModel ///////////////////////////// - private void variationalBayesExpectationMaximization( final GaussianMixtureModel model, final List data ) { + protected void variationalBayesExpectationMaximization(final GaussianMixtureModel model, final List data) { model.initializeRandomModel( data, VRAC.NUM_KMEANS_ITERATIONS ); diff --git a/protected/gatk-tools-protected/src/test/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibratorModelOutputUnitTest.java b/protected/gatk-tools-protected/src/test/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibratorModelOutputUnitTest.java index 56e515029..e24fdac2b 100644 --- a/protected/gatk-tools-protected/src/test/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibratorModelOutputUnitTest.java +++ b/protected/gatk-tools-protected/src/test/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibratorModelOutputUnitTest.java @@ -51,127 +51,130 @@ package org.broadinstitute.gatk.tools.walkers.variantrecalibration; -import static org.testng.Assert.*; - -import Jama.Matrix; -import org.apache.commons.lang.StringUtils; import org.apache.log4j.Logger; +import org.broadinstitute.gatk.utils.BaseTest; import org.broadinstitute.gatk.utils.report.GATKReport; import org.broadinstitute.gatk.utils.report.GATKReportTable; import org.testng.Assert; import org.testng.annotations.Test; +import java.io.File; +import java.io.FileNotFoundException; +import java.io.PrintStream; import java.util.ArrayList; import java.util.List; import java.util.Random; -public class VariantRecalibratorModelOutputUnitTest { +public class VariantRecalibratorModelOutputUnitTest extends BaseTest { protected final static Logger logger = Logger.getLogger(VariantRecalibratorModelOutputUnitTest.class); private final boolean printTables = true; + private final int numAnnotations = 6; + private final double shrinkage = 1.0; + private final double dirichlet = 0.001; + private final double priorCounts = 20.0; + private final double epsilon = 1e-6; + private final String modelReportName = "vqsr_model.report"; @Test public void testVQSRModelOutput() { - final int numAnnotations = 6; - final double shrinkage = 1.0; - final double dirichlet = 0.001; - final double priorCounts = 20.0; - final int numGoodGaussians = 2; - final int numBadGaussians = 1; - final double epsilon = 1e-6; - - Random rand = new Random(12878); - MultivariateGaussian goodGaussian1 = new MultivariateGaussian(numAnnotations); - goodGaussian1.initializeRandomMu(rand); - goodGaussian1.initializeRandomSigma(rand); - - MultivariateGaussian goodGaussian2 = new MultivariateGaussian(numAnnotations); - goodGaussian2.initializeRandomMu(rand); - goodGaussian2.initializeRandomSigma(rand); - - MultivariateGaussian badGaussian1 = new MultivariateGaussian(numAnnotations); - badGaussian1.initializeRandomMu(rand); - badGaussian1.initializeRandomSigma(rand); - - List goodGaussianList = new ArrayList<>(); - goodGaussianList.add(goodGaussian1); - goodGaussianList.add(goodGaussian2); - - List badGaussianList = new ArrayList<>(); - badGaussianList.add(badGaussian1); - - GaussianMixtureModel goodModel = new GaussianMixtureModel(goodGaussianList, shrinkage, dirichlet, priorCounts); - GaussianMixtureModel badModel = new GaussianMixtureModel(badGaussianList, shrinkage, dirichlet, priorCounts); + GaussianMixtureModel goodModel = getGoodGMM(); + GaussianMixtureModel badModel = getBadGMM(); if (printTables) { System.out.println("Good model mean matrix:"); - System.out.println(vectorToString(goodGaussian1.mu)); - System.out.println(vectorToString(goodGaussian2.mu)); + System.out.println(vectorToString(goodModel.getModelGaussians().get(0).mu)); + System.out.println(vectorToString(goodModel.getModelGaussians().get(1).mu)); System.out.println("\n\n"); System.out.println("Good model covariance matrices:"); - goodGaussian1.sigma.print(10, 3); - goodGaussian2.sigma.print(10, 3); + goodModel.getModelGaussians().get(0).sigma.print(10, 3); + goodModel.getModelGaussians().get(1).sigma.print(10, 3); System.out.println("\n\n"); System.out.println("Bad model mean matrix:\n"); - System.out.println(vectorToString(badGaussian1.mu)); + System.out.println(vectorToString(badModel.getModelGaussians().get(0).mu)); System.out.println("\n\n"); System.out.println("Bad model covariance matrix:"); - badGaussian1.sigma.print(10, 3); + badModel.getModelGaussians().get(0).sigma.print(10, 3); } VariantRecalibrator vqsr = new VariantRecalibrator(); - List annotationList = new ArrayList<>(); - annotationList.add("QD"); - annotationList.add("MQ"); - annotationList.add("FS"); - annotationList.add("SOR"); - annotationList.add("ReadPosRankSum"); - annotationList.add("MQRankSum"); - + List annotationList = getAnnotationList(); GATKReport report = vqsr.writeModelReport(goodModel, badModel, annotationList); - if(printTables) - report.print(System.out); + if(printTables) { + try { + PrintStream modelReporter = new PrintStream(this.privateTestDir+this.modelReportName); + report.print(modelReporter); + } catch (FileNotFoundException e) { + e.printStackTrace(); + } + } //Check values for Gaussian means GATKReportTable goodMus = report.getTable("PositiveModelMeans"); for(int i = 0; i < annotationList.size(); i++) { - Assert.assertEquals(goodGaussian1.mu[i], (Double)goodMus.get(0,annotationList.get(i)), epsilon); + Assert.assertEquals(goodModel.getModelGaussians().get(0).mu[i], (Double)goodMus.get(0,annotationList.get(i)), epsilon); } for(int i = 0; i < annotationList.size(); i++) { - Assert.assertEquals(goodGaussian2.mu[i], (Double)goodMus.get(1,annotationList.get(i)), epsilon); + Assert.assertEquals(goodModel.getModelGaussians().get(1).mu[i], (Double)goodMus.get(1,annotationList.get(i)), epsilon); } GATKReportTable badMus = report.getTable("NegativeModelMeans"); for(int i = 0; i < annotationList.size(); i++) { - Assert.assertEquals(badGaussian1.mu[i], (Double)badMus.get(0,annotationList.get(i)), epsilon); + Assert.assertEquals(badModel.getModelGaussians().get(0).mu[i], (Double)badMus.get(0,annotationList.get(i)), epsilon); } //Check values for Gaussian covariances GATKReportTable goodSigma = report.getTable("PositiveModelCovariances"); for(int i = 0; i < annotationList.size(); i++) { for(int j = 0; j < annotationList.size(); j++) { - Assert.assertEquals(goodGaussian1.sigma.get(i,j), (Double)goodSigma.get(i,annotationList.get(j)), epsilon); + Assert.assertEquals(goodModel.getModelGaussians().get(0).sigma.get(i,j), (Double)goodSigma.get(i,annotationList.get(j)), epsilon); } } //add annotationList.size() to row indexes for second Gaussian because the matrices are concatenated by row in the report for(int i = 0; i < annotationList.size(); i++) { for(int j = 0; j < annotationList.size(); j++) { - Assert.assertEquals(goodGaussian2.sigma.get(i,j), (Double)goodSigma.get(annotationList.size()+i,annotationList.get(j)), epsilon); + Assert.assertEquals(goodModel.getModelGaussians().get(1).sigma.get(i,j), (Double)goodSigma.get(annotationList.size()+i,annotationList.get(j)), epsilon); } } GATKReportTable badSigma = report.getTable("NegativeModelCovariances"); for(int i = 0; i < annotationList.size(); i++) { for(int j = 0; j < annotationList.size(); j++) { - Assert.assertEquals(badGaussian1.sigma.get(i,j), (Double)badSigma.get(i,annotationList.get(j)), epsilon); + Assert.assertEquals(badModel.getModelGaussians().get(0).sigma.get(i,j), (Double)badSigma.get(i,annotationList.get(j)), epsilon); } } } + + @Test + public void testVQSRModelInput(){ + final File inputFile = new File(this.privateTestDir + this.modelReportName); + final GATKReport report = new GATKReport(inputFile); + + // Now test model report reading + // Read all the tables + final GATKReportTable badMus = report.getTable("NegativeModelMeans"); + final GATKReportTable badSigma = report.getTable("NegativeModelCovariances"); + final GATKReportTable nPMixTable = report.getTable("BadGaussianPMix"); + + final GATKReportTable goodMus = report.getTable("PositiveModelMeans"); + final GATKReportTable goodSigma = report.getTable("PositiveModelCovariances"); + final GATKReportTable pPMixTable = report.getTable("GoodGaussianPMix"); + + List annotationList = getAnnotationList(); + VariantRecalibrator vqsr = new VariantRecalibrator(); + + GaussianMixtureModel goodModelFromFile = vqsr.GMMFromTables(goodMus, goodSigma, pPMixTable, annotationList.size()); + GaussianMixtureModel badModelFromFile = vqsr.GMMFromTables(badMus, badSigma, nPMixTable, annotationList.size()); + + testGMMsForEquality(getGoodGMM(), goodModelFromFile, epsilon); + testGMMsForEquality(getBadGMM(), badModelFromFile, epsilon); + } + @Test //This is tested separately to avoid setting up a VariantDataManager and populating it with fake data public void testAnnotationNormalizationOutput() { @@ -211,4 +214,66 @@ public class VariantRecalibratorModelOutputUnitTest { return returnString; } + private void testGMMsForEquality(GaussianMixtureModel gmm1, GaussianMixtureModel gmm2, double epsilon){ + Assert.assertEquals(gmm1.getModelGaussians().size(), gmm2.getModelGaussians().size(), 0); + + for(int k = 0; k < gmm1.getModelGaussians().size(); k++) { + final MultivariateGaussian g = gmm1.getModelGaussians().get(k); + final MultivariateGaussian gFile = gmm2.getModelGaussians().get(k); + + Assert.assertEquals(g.pMixtureLog10, gFile.pMixtureLog10); + + for(int i = 0; i < g.mu.length; i++){ + Assert.assertEquals(g.mu[i], gFile.mu[i], epsilon); + } + + for(int i = 0; i < g.sigma.getRowDimension(); i++) { + for (int j = 0; j < g.sigma.getColumnDimension(); j++) { + Assert.assertEquals(g.sigma.get(i, j), gFile.sigma.get(i, j), epsilon); + } + } + } + } + + private List getAnnotationList(){ + List annotationList = new ArrayList<>(); + annotationList.add("QD"); + annotationList.add("MQ"); + annotationList.add("FS"); + annotationList.add("SOR"); + annotationList.add("ReadPosRankSum"); + annotationList.add("MQRankSum"); + return annotationList; + } + + private GaussianMixtureModel getGoodGMM(){ + Random rand = new Random(12878); + MultivariateGaussian goodGaussian1 = new MultivariateGaussian(numAnnotations); + goodGaussian1.initializeRandomMu(rand); + goodGaussian1.initializeRandomSigma(rand); + + MultivariateGaussian goodGaussian2 = new MultivariateGaussian(numAnnotations); + goodGaussian2.initializeRandomMu(rand); + goodGaussian2.initializeRandomSigma(rand); + + List goodGaussianList = new ArrayList<>(); + goodGaussianList.add(goodGaussian1); + goodGaussianList.add(goodGaussian2); + + return new GaussianMixtureModel(goodGaussianList, shrinkage, dirichlet, priorCounts); + } + + private GaussianMixtureModel getBadGMM(){ + Random rand = new Random(12878); + MultivariateGaussian badGaussian1 = new MultivariateGaussian(numAnnotations); + + badGaussian1.initializeRandomMu(rand); + badGaussian1.initializeRandomSigma(rand); + + List badGaussianList = new ArrayList<>(); + badGaussianList.add(badGaussian1); + + return new GaussianMixtureModel(badGaussianList, shrinkage, dirichlet, priorCounts); + } + } \ No newline at end of file diff --git a/public/gatk-utils/src/main/java/org/broadinstitute/gatk/utils/report/GATKReportTable.java b/public/gatk-utils/src/main/java/org/broadinstitute/gatk/utils/report/GATKReportTable.java index 85d2386a0..ff7aed473 100644 --- a/public/gatk-utils/src/main/java/org/broadinstitute/gatk/utils/report/GATKReportTable.java +++ b/public/gatk-utils/src/main/java/org/broadinstitute/gatk/utils/report/GATKReportTable.java @@ -150,10 +150,8 @@ public class GATKReportTable { // read a data line final String dataLine = reader.readLine(); final List lineSplits = Arrays.asList(TextFormattingUtils.splitFixedWidth(dataLine, columnStarts)); - underlyingData.add(new Object[nColumns]); for ( int columnIndex = 0; columnIndex < nColumns; columnIndex++ ) { - final GATKReportDataType type = columnInfo.get(columnIndex).getDataType(); final String columnName = columnNames[columnIndex]; set(i, columnName, type.Parse(lineSplits.get(columnIndex)));