respond to review comments
This commit is contained in:
parent
68bdb93c8c
commit
ed440f1684
|
|
@ -278,7 +278,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
* to help describe the normalization. The model fit report can be read in with our R gsalib package. Individual
|
||||
* model Gaussians can be subset by the value in the "Gaussian" column if desired.
|
||||
*/
|
||||
@Argument(fullName="output_model", shortName = "outputModel", doc="If specified, the variant recalibrator will output the VQSR model fit to the file specified by -modelFile or to stdout", required=false)
|
||||
@Argument(fullName="output_model", shortName = "outputModel", doc="If specified, the variant recalibrator will output the VQSR model to this file path.", required=false)
|
||||
private String outputModel = null;
|
||||
@Argument(fullName="input_model", shortName = "inputModel", doc="If specified, the variant recalibrator will read the VQSR model from this file path.", required=false)
|
||||
private String inputModel = "";
|
||||
|
|
|
|||
|
|
@ -51,21 +51,21 @@
|
|||
|
||||
package org.broadinstitute.gatk.tools.walkers.variantrecalibration;
|
||||
|
||||
import static org.testng.Assert.*;
|
||||
|
||||
import Jama.Matrix;
|
||||
import org.apache.commons.lang.StringUtils;
|
||||
import org.apache.log4j.Logger;
|
||||
import org.broadinstitute.gatk.utils.BaseTest;
|
||||
import org.broadinstitute.gatk.utils.report.GATKReport;
|
||||
import org.broadinstitute.gatk.utils.report.GATKReportTable;
|
||||
import org.testng.Assert;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.FileNotFoundException;
|
||||
import java.io.PrintStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.Random;
|
||||
|
||||
public class VariantRecalibratorModelOutputUnitTest {
|
||||
public class VariantRecalibratorModelOutputUnitTest extends BaseTest {
|
||||
protected final static Logger logger = Logger.getLogger(VariantRecalibratorModelOutputUnitTest.class);
|
||||
private final boolean printTables = true;
|
||||
private final int numAnnotations = 6;
|
||||
|
|
@ -73,98 +73,78 @@ public class VariantRecalibratorModelOutputUnitTest {
|
|||
private final double dirichlet = 0.001;
|
||||
private final double priorCounts = 20.0;
|
||||
private final double epsilon = 1e-6;
|
||||
private final String modelReportName = "vqsr_model.report";
|
||||
|
||||
@Test
|
||||
public void testVQSRModelOutput() {
|
||||
Random rand = new Random(12878);
|
||||
MultivariateGaussian goodGaussian1 = new MultivariateGaussian(numAnnotations);
|
||||
goodGaussian1.initializeRandomMu(rand);
|
||||
goodGaussian1.initializeRandomSigma(rand);
|
||||
|
||||
MultivariateGaussian goodGaussian2 = new MultivariateGaussian(numAnnotations);
|
||||
goodGaussian2.initializeRandomMu(rand);
|
||||
goodGaussian2.initializeRandomSigma(rand);
|
||||
|
||||
MultivariateGaussian badGaussian1 = new MultivariateGaussian(numAnnotations);
|
||||
badGaussian1.initializeRandomMu(rand);
|
||||
badGaussian1.initializeRandomSigma(rand);
|
||||
|
||||
List<MultivariateGaussian> goodGaussianList = new ArrayList<>();
|
||||
goodGaussianList.add(goodGaussian1);
|
||||
goodGaussianList.add(goodGaussian2);
|
||||
|
||||
List<MultivariateGaussian> badGaussianList = new ArrayList<>();
|
||||
badGaussianList.add(badGaussian1);
|
||||
|
||||
GaussianMixtureModel goodModel = new GaussianMixtureModel(goodGaussianList, shrinkage, dirichlet, priorCounts);
|
||||
GaussianMixtureModel badModel = new GaussianMixtureModel(badGaussianList, shrinkage, dirichlet, priorCounts);
|
||||
GaussianMixtureModel goodModel = getGoodGMM();
|
||||
GaussianMixtureModel badModel = getBadGMM();
|
||||
|
||||
if (printTables) {
|
||||
System.out.println("Good model mean matrix:");
|
||||
System.out.println(vectorToString(goodGaussian1.mu));
|
||||
System.out.println(vectorToString(goodGaussian2.mu));
|
||||
System.out.println(vectorToString(goodModel.getModelGaussians().get(0).mu));
|
||||
System.out.println(vectorToString(goodModel.getModelGaussians().get(1).mu));
|
||||
System.out.println("\n\n");
|
||||
|
||||
System.out.println("Good model covariance matrices:");
|
||||
goodGaussian1.sigma.print(10, 3);
|
||||
goodGaussian2.sigma.print(10, 3);
|
||||
goodModel.getModelGaussians().get(0).sigma.print(10, 3);
|
||||
goodModel.getModelGaussians().get(1).sigma.print(10, 3);
|
||||
System.out.println("\n\n");
|
||||
|
||||
System.out.println("Bad model mean matrix:\n");
|
||||
System.out.println(vectorToString(badGaussian1.mu));
|
||||
System.out.println(vectorToString(badModel.getModelGaussians().get(0).mu));
|
||||
System.out.println("\n\n");
|
||||
|
||||
System.out.println("Bad model covariance matrix:");
|
||||
badGaussian1.sigma.print(10, 3);
|
||||
badModel.getModelGaussians().get(0).sigma.print(10, 3);
|
||||
}
|
||||
|
||||
VariantRecalibrator vqsr = new VariantRecalibrator();
|
||||
List<String> annotationList = new ArrayList<>();
|
||||
annotationList.add("QD");
|
||||
annotationList.add("MQ");
|
||||
annotationList.add("FS");
|
||||
annotationList.add("SOR");
|
||||
annotationList.add("ReadPosRankSum");
|
||||
annotationList.add("MQRankSum");
|
||||
|
||||
List<String> annotationList = getAnnotationList();
|
||||
|
||||
GATKReport report = vqsr.writeModelReport(goodModel, badModel, annotationList);
|
||||
if(printTables)
|
||||
report.print(System.out);
|
||||
if(printTables) {
|
||||
try {
|
||||
PrintStream modelReporter = new PrintStream(this.privateTestDir+this.modelReportName);
|
||||
report.print(modelReporter);
|
||||
} catch (FileNotFoundException e) {
|
||||
e.printStackTrace();
|
||||
}
|
||||
}
|
||||
|
||||
//Check values for Gaussian means
|
||||
GATKReportTable goodMus = report.getTable("PositiveModelMeans");
|
||||
for(int i = 0; i < annotationList.size(); i++) {
|
||||
Assert.assertEquals(goodGaussian1.mu[i], (Double)goodMus.get(0,annotationList.get(i)), epsilon);
|
||||
Assert.assertEquals(goodModel.getModelGaussians().get(0).mu[i], (Double)goodMus.get(0,annotationList.get(i)), epsilon);
|
||||
}
|
||||
for(int i = 0; i < annotationList.size(); i++) {
|
||||
Assert.assertEquals(goodGaussian2.mu[i], (Double)goodMus.get(1,annotationList.get(i)), epsilon);
|
||||
Assert.assertEquals(goodModel.getModelGaussians().get(1).mu[i], (Double)goodMus.get(1,annotationList.get(i)), epsilon);
|
||||
}
|
||||
|
||||
GATKReportTable badMus = report.getTable("NegativeModelMeans");
|
||||
for(int i = 0; i < annotationList.size(); i++) {
|
||||
Assert.assertEquals(badGaussian1.mu[i], (Double)badMus.get(0,annotationList.get(i)), epsilon);
|
||||
Assert.assertEquals(badModel.getModelGaussians().get(0).mu[i], (Double)badMus.get(0,annotationList.get(i)), epsilon);
|
||||
}
|
||||
|
||||
//Check values for Gaussian covariances
|
||||
GATKReportTable goodSigma = report.getTable("PositiveModelCovariances");
|
||||
for(int i = 0; i < annotationList.size(); i++) {
|
||||
for(int j = 0; j < annotationList.size(); j++) {
|
||||
Assert.assertEquals(goodGaussian1.sigma.get(i,j), (Double)goodSigma.get(i,annotationList.get(j)), epsilon);
|
||||
Assert.assertEquals(goodModel.getModelGaussians().get(0).sigma.get(i,j), (Double)goodSigma.get(i,annotationList.get(j)), epsilon);
|
||||
}
|
||||
}
|
||||
|
||||
//add annotationList.size() to row indexes for second Gaussian because the matrices are concatenated by row in the report
|
||||
for(int i = 0; i < annotationList.size(); i++) {
|
||||
for(int j = 0; j < annotationList.size(); j++) {
|
||||
Assert.assertEquals(goodGaussian2.sigma.get(i,j), (Double)goodSigma.get(annotationList.size()+i,annotationList.get(j)), epsilon);
|
||||
Assert.assertEquals(goodModel.getModelGaussians().get(1).sigma.get(i,j), (Double)goodSigma.get(annotationList.size()+i,annotationList.get(j)), epsilon);
|
||||
}
|
||||
}
|
||||
|
||||
GATKReportTable badSigma = report.getTable("NegativeModelCovariances");
|
||||
for(int i = 0; i < annotationList.size(); i++) {
|
||||
for(int j = 0; j < annotationList.size(); j++) {
|
||||
Assert.assertEquals(badGaussian1.sigma.get(i,j), (Double)badSigma.get(i,annotationList.get(j)), epsilon);
|
||||
Assert.assertEquals(badModel.getModelGaussians().get(0).sigma.get(i,j), (Double)badSigma.get(i,annotationList.get(j)), epsilon);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -172,39 +152,8 @@ public class VariantRecalibratorModelOutputUnitTest {
|
|||
|
||||
@Test
|
||||
public void testVQSRModelInput(){
|
||||
Random rand = new Random(12878);
|
||||
MultivariateGaussian goodGaussian1 = new MultivariateGaussian(numAnnotations);
|
||||
goodGaussian1.initializeRandomMu(rand);
|
||||
goodGaussian1.initializeRandomSigma(rand);
|
||||
|
||||
MultivariateGaussian goodGaussian2 = new MultivariateGaussian(numAnnotations);
|
||||
goodGaussian2.initializeRandomMu(rand);
|
||||
goodGaussian2.initializeRandomSigma(rand);
|
||||
|
||||
MultivariateGaussian badGaussian1 = new MultivariateGaussian(numAnnotations);
|
||||
badGaussian1.initializeRandomMu(rand);
|
||||
badGaussian1.initializeRandomSigma(rand);
|
||||
|
||||
List<MultivariateGaussian> goodGaussianList = new ArrayList<>();
|
||||
goodGaussianList.add(goodGaussian1);
|
||||
goodGaussianList.add(goodGaussian2);
|
||||
|
||||
List<MultivariateGaussian> badGaussianList = new ArrayList<>();
|
||||
badGaussianList.add(badGaussian1);
|
||||
|
||||
GaussianMixtureModel goodModel = new GaussianMixtureModel(goodGaussianList, shrinkage, dirichlet, priorCounts);
|
||||
GaussianMixtureModel badModel = new GaussianMixtureModel(badGaussianList, shrinkage, dirichlet, priorCounts);
|
||||
|
||||
VariantRecalibrator vqsr = new VariantRecalibrator();
|
||||
List<String> annotationList = new ArrayList<>();
|
||||
annotationList.add("QD");
|
||||
annotationList.add("MQ");
|
||||
annotationList.add("FS");
|
||||
annotationList.add("SOR");
|
||||
annotationList.add("ReadPosRankSum");
|
||||
annotationList.add("MQRankSum");
|
||||
|
||||
GATKReport report = vqsr.writeModelReport(goodModel, badModel, annotationList);
|
||||
final File inputFile = new File(this.privateTestDir + this.modelReportName);
|
||||
final GATKReport report = new GATKReport(inputFile);
|
||||
|
||||
// Now test model report reading
|
||||
// Read all the tables
|
||||
|
|
@ -216,11 +165,14 @@ public class VariantRecalibratorModelOutputUnitTest {
|
|||
final GATKReportTable goodSigma = report.getTable("PositiveModelCovariances");
|
||||
final GATKReportTable pPMixTable = report.getTable("GoodGaussianPMix");
|
||||
|
||||
List<String> annotationList = getAnnotationList();
|
||||
VariantRecalibrator vqsr = new VariantRecalibrator();
|
||||
|
||||
GaussianMixtureModel goodModelFromFile = vqsr.GMMFromTables(goodMus, goodSigma, pPMixTable, annotationList.size());
|
||||
GaussianMixtureModel badModelFromFile = vqsr.GMMFromTables(badMus, badSigma, nPMixTable, annotationList.size());
|
||||
|
||||
testGMMsForEquality(goodModel, goodModelFromFile, epsilon);
|
||||
testGMMsForEquality(badModel, badModelFromFile, epsilon);
|
||||
testGMMsForEquality(getGoodGMM(), goodModelFromFile, epsilon);
|
||||
testGMMsForEquality(getBadGMM(), badModelFromFile, epsilon);
|
||||
}
|
||||
|
||||
@Test
|
||||
|
|
@ -269,6 +221,8 @@ public class VariantRecalibratorModelOutputUnitTest {
|
|||
final MultivariateGaussian g = gmm1.getModelGaussians().get(k);
|
||||
final MultivariateGaussian gFile = gmm2.getModelGaussians().get(k);
|
||||
|
||||
Assert.assertEquals(g.pMixtureLog10, gFile.pMixtureLog10);
|
||||
|
||||
for(int i = 0; i < g.mu.length; i++){
|
||||
Assert.assertEquals(g.mu[i], gFile.mu[i], epsilon);
|
||||
}
|
||||
|
|
@ -281,4 +235,45 @@ public class VariantRecalibratorModelOutputUnitTest {
|
|||
}
|
||||
}
|
||||
|
||||
private List<String> getAnnotationList(){
|
||||
List<String> annotationList = new ArrayList<>();
|
||||
annotationList.add("QD");
|
||||
annotationList.add("MQ");
|
||||
annotationList.add("FS");
|
||||
annotationList.add("SOR");
|
||||
annotationList.add("ReadPosRankSum");
|
||||
annotationList.add("MQRankSum");
|
||||
return annotationList;
|
||||
}
|
||||
|
||||
private GaussianMixtureModel getGoodGMM(){
|
||||
Random rand = new Random(12878);
|
||||
MultivariateGaussian goodGaussian1 = new MultivariateGaussian(numAnnotations);
|
||||
goodGaussian1.initializeRandomMu(rand);
|
||||
goodGaussian1.initializeRandomSigma(rand);
|
||||
|
||||
MultivariateGaussian goodGaussian2 = new MultivariateGaussian(numAnnotations);
|
||||
goodGaussian2.initializeRandomMu(rand);
|
||||
goodGaussian2.initializeRandomSigma(rand);
|
||||
|
||||
List<MultivariateGaussian> goodGaussianList = new ArrayList<>();
|
||||
goodGaussianList.add(goodGaussian1);
|
||||
goodGaussianList.add(goodGaussian2);
|
||||
|
||||
return new GaussianMixtureModel(goodGaussianList, shrinkage, dirichlet, priorCounts);
|
||||
}
|
||||
|
||||
private GaussianMixtureModel getBadGMM(){
|
||||
Random rand = new Random(12878);
|
||||
MultivariateGaussian badGaussian1 = new MultivariateGaussian(numAnnotations);
|
||||
|
||||
badGaussian1.initializeRandomMu(rand);
|
||||
badGaussian1.initializeRandomSigma(rand);
|
||||
|
||||
List<MultivariateGaussian> badGaussianList = new ArrayList<>();
|
||||
badGaussianList.add(badGaussian1);
|
||||
|
||||
return new GaussianMixtureModel(badGaussianList, shrinkage, dirichlet, priorCounts);
|
||||
}
|
||||
|
||||
}
|
||||
|
|
@ -152,7 +152,6 @@ public class GATKReportTable {
|
|||
final List<String> lineSplits = Arrays.asList(TextFormattingUtils.splitFixedWidth(dataLine, columnStarts));
|
||||
underlyingData.add(new Object[nColumns]);
|
||||
for ( int columnIndex = 0; columnIndex < nColumns; columnIndex++ ) {
|
||||
|
||||
final GATKReportDataType type = columnInfo.get(columnIndex).getDataType();
|
||||
final String columnName = columnNames[columnIndex];
|
||||
set(i, columnName, type.Parse(lineSplits.get(columnIndex)));
|
||||
|
|
|
|||
Loading…
Reference in New Issue