From ed440f1684bdbfedb98e6e3219a06fb6fbd3963c Mon Sep 17 00:00:00 2001 From: Samuel Friedman Date: Fri, 5 May 2017 16:29:01 -0400 Subject: [PATCH] respond to review comments --- .../VariantRecalibrator.java | 2 +- ...ariantRecalibratorModelOutputUnitTest.java | 163 +++++++++--------- .../gatk/utils/report/GATKReportTable.java | 1 - 3 files changed, 80 insertions(+), 86 deletions(-) diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java index d44e53dba..546c0f05b 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java @@ -278,7 +278,7 @@ public class VariantRecalibrator extends RodWalker goodGaussianList = new ArrayList<>(); - goodGaussianList.add(goodGaussian1); - goodGaussianList.add(goodGaussian2); - - List badGaussianList = new ArrayList<>(); - badGaussianList.add(badGaussian1); - - GaussianMixtureModel goodModel = new GaussianMixtureModel(goodGaussianList, shrinkage, dirichlet, priorCounts); - GaussianMixtureModel badModel = new GaussianMixtureModel(badGaussianList, shrinkage, dirichlet, priorCounts); + GaussianMixtureModel goodModel = getGoodGMM(); + GaussianMixtureModel badModel = getBadGMM(); if (printTables) { System.out.println("Good model mean matrix:"); - System.out.println(vectorToString(goodGaussian1.mu)); - System.out.println(vectorToString(goodGaussian2.mu)); + System.out.println(vectorToString(goodModel.getModelGaussians().get(0).mu)); + System.out.println(vectorToString(goodModel.getModelGaussians().get(1).mu)); System.out.println("\n\n"); System.out.println("Good model covariance matrices:"); - goodGaussian1.sigma.print(10, 3); - goodGaussian2.sigma.print(10, 3); + goodModel.getModelGaussians().get(0).sigma.print(10, 3); + goodModel.getModelGaussians().get(1).sigma.print(10, 3); System.out.println("\n\n"); System.out.println("Bad model mean matrix:\n"); - System.out.println(vectorToString(badGaussian1.mu)); + System.out.println(vectorToString(badModel.getModelGaussians().get(0).mu)); System.out.println("\n\n"); System.out.println("Bad model covariance matrix:"); - badGaussian1.sigma.print(10, 3); + badModel.getModelGaussians().get(0).sigma.print(10, 3); } VariantRecalibrator vqsr = new VariantRecalibrator(); - List annotationList = new ArrayList<>(); - annotationList.add("QD"); - annotationList.add("MQ"); - annotationList.add("FS"); - annotationList.add("SOR"); - annotationList.add("ReadPosRankSum"); - annotationList.add("MQRankSum"); - + List annotationList = getAnnotationList(); GATKReport report = vqsr.writeModelReport(goodModel, badModel, annotationList); - if(printTables) - report.print(System.out); + if(printTables) { + try { + PrintStream modelReporter = new PrintStream(this.privateTestDir+this.modelReportName); + report.print(modelReporter); + } catch (FileNotFoundException e) { + e.printStackTrace(); + } + } //Check values for Gaussian means GATKReportTable goodMus = report.getTable("PositiveModelMeans"); for(int i = 0; i < annotationList.size(); i++) { - Assert.assertEquals(goodGaussian1.mu[i], (Double)goodMus.get(0,annotationList.get(i)), epsilon); + Assert.assertEquals(goodModel.getModelGaussians().get(0).mu[i], (Double)goodMus.get(0,annotationList.get(i)), epsilon); } for(int i = 0; i < annotationList.size(); i++) { - Assert.assertEquals(goodGaussian2.mu[i], (Double)goodMus.get(1,annotationList.get(i)), epsilon); + Assert.assertEquals(goodModel.getModelGaussians().get(1).mu[i], (Double)goodMus.get(1,annotationList.get(i)), epsilon); } GATKReportTable badMus = report.getTable("NegativeModelMeans"); for(int i = 0; i < annotationList.size(); i++) { - Assert.assertEquals(badGaussian1.mu[i], (Double)badMus.get(0,annotationList.get(i)), epsilon); + Assert.assertEquals(badModel.getModelGaussians().get(0).mu[i], (Double)badMus.get(0,annotationList.get(i)), epsilon); } //Check values for Gaussian covariances GATKReportTable goodSigma = report.getTable("PositiveModelCovariances"); for(int i = 0; i < annotationList.size(); i++) { for(int j = 0; j < annotationList.size(); j++) { - Assert.assertEquals(goodGaussian1.sigma.get(i,j), (Double)goodSigma.get(i,annotationList.get(j)), epsilon); + Assert.assertEquals(goodModel.getModelGaussians().get(0).sigma.get(i,j), (Double)goodSigma.get(i,annotationList.get(j)), epsilon); } } //add annotationList.size() to row indexes for second Gaussian because the matrices are concatenated by row in the report for(int i = 0; i < annotationList.size(); i++) { for(int j = 0; j < annotationList.size(); j++) { - Assert.assertEquals(goodGaussian2.sigma.get(i,j), (Double)goodSigma.get(annotationList.size()+i,annotationList.get(j)), epsilon); + Assert.assertEquals(goodModel.getModelGaussians().get(1).sigma.get(i,j), (Double)goodSigma.get(annotationList.size()+i,annotationList.get(j)), epsilon); } } GATKReportTable badSigma = report.getTable("NegativeModelCovariances"); for(int i = 0; i < annotationList.size(); i++) { for(int j = 0; j < annotationList.size(); j++) { - Assert.assertEquals(badGaussian1.sigma.get(i,j), (Double)badSigma.get(i,annotationList.get(j)), epsilon); + Assert.assertEquals(badModel.getModelGaussians().get(0).sigma.get(i,j), (Double)badSigma.get(i,annotationList.get(j)), epsilon); } } } @@ -172,39 +152,8 @@ public class VariantRecalibratorModelOutputUnitTest { @Test public void testVQSRModelInput(){ - Random rand = new Random(12878); - MultivariateGaussian goodGaussian1 = new MultivariateGaussian(numAnnotations); - goodGaussian1.initializeRandomMu(rand); - goodGaussian1.initializeRandomSigma(rand); - - MultivariateGaussian goodGaussian2 = new MultivariateGaussian(numAnnotations); - goodGaussian2.initializeRandomMu(rand); - goodGaussian2.initializeRandomSigma(rand); - - MultivariateGaussian badGaussian1 = new MultivariateGaussian(numAnnotations); - badGaussian1.initializeRandomMu(rand); - badGaussian1.initializeRandomSigma(rand); - - List goodGaussianList = new ArrayList<>(); - goodGaussianList.add(goodGaussian1); - goodGaussianList.add(goodGaussian2); - - List badGaussianList = new ArrayList<>(); - badGaussianList.add(badGaussian1); - - GaussianMixtureModel goodModel = new GaussianMixtureModel(goodGaussianList, shrinkage, dirichlet, priorCounts); - GaussianMixtureModel badModel = new GaussianMixtureModel(badGaussianList, shrinkage, dirichlet, priorCounts); - - VariantRecalibrator vqsr = new VariantRecalibrator(); - List annotationList = new ArrayList<>(); - annotationList.add("QD"); - annotationList.add("MQ"); - annotationList.add("FS"); - annotationList.add("SOR"); - annotationList.add("ReadPosRankSum"); - annotationList.add("MQRankSum"); - - GATKReport report = vqsr.writeModelReport(goodModel, badModel, annotationList); + final File inputFile = new File(this.privateTestDir + this.modelReportName); + final GATKReport report = new GATKReport(inputFile); // Now test model report reading // Read all the tables @@ -216,11 +165,14 @@ public class VariantRecalibratorModelOutputUnitTest { final GATKReportTable goodSigma = report.getTable("PositiveModelCovariances"); final GATKReportTable pPMixTable = report.getTable("GoodGaussianPMix"); + List annotationList = getAnnotationList(); + VariantRecalibrator vqsr = new VariantRecalibrator(); + GaussianMixtureModel goodModelFromFile = vqsr.GMMFromTables(goodMus, goodSigma, pPMixTable, annotationList.size()); GaussianMixtureModel badModelFromFile = vqsr.GMMFromTables(badMus, badSigma, nPMixTable, annotationList.size()); - testGMMsForEquality(goodModel, goodModelFromFile, epsilon); - testGMMsForEquality(badModel, badModelFromFile, epsilon); + testGMMsForEquality(getGoodGMM(), goodModelFromFile, epsilon); + testGMMsForEquality(getBadGMM(), badModelFromFile, epsilon); } @Test @@ -269,6 +221,8 @@ public class VariantRecalibratorModelOutputUnitTest { final MultivariateGaussian g = gmm1.getModelGaussians().get(k); final MultivariateGaussian gFile = gmm2.getModelGaussians().get(k); + Assert.assertEquals(g.pMixtureLog10, gFile.pMixtureLog10); + for(int i = 0; i < g.mu.length; i++){ Assert.assertEquals(g.mu[i], gFile.mu[i], epsilon); } @@ -281,4 +235,45 @@ public class VariantRecalibratorModelOutputUnitTest { } } + private List getAnnotationList(){ + List annotationList = new ArrayList<>(); + annotationList.add("QD"); + annotationList.add("MQ"); + annotationList.add("FS"); + annotationList.add("SOR"); + annotationList.add("ReadPosRankSum"); + annotationList.add("MQRankSum"); + return annotationList; + } + + private GaussianMixtureModel getGoodGMM(){ + Random rand = new Random(12878); + MultivariateGaussian goodGaussian1 = new MultivariateGaussian(numAnnotations); + goodGaussian1.initializeRandomMu(rand); + goodGaussian1.initializeRandomSigma(rand); + + MultivariateGaussian goodGaussian2 = new MultivariateGaussian(numAnnotations); + goodGaussian2.initializeRandomMu(rand); + goodGaussian2.initializeRandomSigma(rand); + + List goodGaussianList = new ArrayList<>(); + goodGaussianList.add(goodGaussian1); + goodGaussianList.add(goodGaussian2); + + return new GaussianMixtureModel(goodGaussianList, shrinkage, dirichlet, priorCounts); + } + + private GaussianMixtureModel getBadGMM(){ + Random rand = new Random(12878); + MultivariateGaussian badGaussian1 = new MultivariateGaussian(numAnnotations); + + badGaussian1.initializeRandomMu(rand); + badGaussian1.initializeRandomSigma(rand); + + List badGaussianList = new ArrayList<>(); + badGaussianList.add(badGaussian1); + + return new GaussianMixtureModel(badGaussianList, shrinkage, dirichlet, priorCounts); + } + } \ No newline at end of file diff --git a/public/gatk-utils/src/main/java/org/broadinstitute/gatk/utils/report/GATKReportTable.java b/public/gatk-utils/src/main/java/org/broadinstitute/gatk/utils/report/GATKReportTable.java index 4d87bab84..ff7aed473 100644 --- a/public/gatk-utils/src/main/java/org/broadinstitute/gatk/utils/report/GATKReportTable.java +++ b/public/gatk-utils/src/main/java/org/broadinstitute/gatk/utils/report/GATKReportTable.java @@ -152,7 +152,6 @@ public class GATKReportTable { final List lineSplits = Arrays.asList(TextFormattingUtils.splitFixedWidth(dataLine, columnStarts)); underlyingData.add(new Object[nColumns]); for ( int columnIndex = 0; columnIndex < nColumns; columnIndex++ ) { - final GATKReportDataType type = columnInfo.get(columnIndex).getDataType(); final String columnName = columnNames[columnIndex]; set(i, columnName, type.Parse(lineSplits.get(columnIndex)));