Merge pull request #1575 from broadinstitute/snf_ReadGMMReport
Read model report and reconstruct GMM from it in VQSR
This commit is contained in:
commit
00b7135afe
|
|
@ -86,6 +86,11 @@ public class GaussianMixtureModel {
|
||||||
gaussians = new ArrayList<>( numGaussians );
|
gaussians = new ArrayList<>( numGaussians );
|
||||||
for( int iii = 0; iii < numGaussians; iii++ ) {
|
for( int iii = 0; iii < numGaussians; iii++ ) {
|
||||||
final MultivariateGaussian gaussian = new MultivariateGaussian( numAnnotations );
|
final MultivariateGaussian gaussian = new MultivariateGaussian( numAnnotations );
|
||||||
|
gaussian.pMixtureLog10 = Math.log10( 1.0 / ((double)numGaussians) );
|
||||||
|
gaussian.sumProb = 1.0 / ((double) numGaussians);
|
||||||
|
gaussian.hyperParameter_a = priorCounts;
|
||||||
|
gaussian.hyperParameter_b = shrinkage;
|
||||||
|
gaussian.hyperParameter_lambda = dirichletParameter;
|
||||||
gaussians.add( gaussian );
|
gaussians.add( gaussian );
|
||||||
}
|
}
|
||||||
this.shrinkage = shrinkage;
|
this.shrinkage = shrinkage;
|
||||||
|
|
@ -105,6 +110,11 @@ public class GaussianMixtureModel {
|
||||||
this.shrinkage = shrinkage;
|
this.shrinkage = shrinkage;
|
||||||
this.dirichletParameter = dirichletParameter;
|
this.dirichletParameter = dirichletParameter;
|
||||||
this.priorCounts = priorCounts;
|
this.priorCounts = priorCounts;
|
||||||
|
for( final MultivariateGaussian gaussian : gaussians ) {
|
||||||
|
gaussian.hyperParameter_a = priorCounts;
|
||||||
|
gaussian.hyperParameter_b = shrinkage;
|
||||||
|
gaussian.hyperParameter_lambda = dirichletParameter;
|
||||||
|
}
|
||||||
empiricalMu = new double[numAnnotations];
|
empiricalMu = new double[numAnnotations];
|
||||||
empiricalSigma = new Matrix(numAnnotations, numAnnotations);
|
empiricalSigma = new Matrix(numAnnotations, numAnnotations);
|
||||||
isModelReadyForEvaluation = false;
|
isModelReadyForEvaluation = false;
|
||||||
|
|
|
||||||
|
|
@ -271,4 +271,5 @@ public class MultivariateGaussian {
|
||||||
|
|
||||||
resetPVarInGaussian(); // clean up some memory
|
resetPVarInGaussian(); // clean up some memory
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
@ -277,7 +277,7 @@ public class VariantDataManager {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
logger.info( "Training with worst " + trainingData.size() + " scoring variants --> variants with LOD <= " + String.format("%.4f", VRAC.BAD_LOD_CUTOFF) + "." );
|
logger.info( "Selected worst " + trainingData.size() + " scoring variants --> variants with LOD <= " + String.format("%.4f", VRAC.BAD_LOD_CUTOFF) + "." );
|
||||||
|
|
||||||
return trainingData;
|
return trainingData;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -66,6 +66,7 @@ import org.broadinstitute.gatk.utils.R.RScriptExecutor;
|
||||||
import org.broadinstitute.gatk.utils.Utils;
|
import org.broadinstitute.gatk.utils.Utils;
|
||||||
import org.broadinstitute.gatk.utils.help.HelpConstants;
|
import org.broadinstitute.gatk.utils.help.HelpConstants;
|
||||||
import org.broadinstitute.gatk.utils.report.GATKReport;
|
import org.broadinstitute.gatk.utils.report.GATKReport;
|
||||||
|
import org.broadinstitute.gatk.utils.report.GATKReportColumn;
|
||||||
import org.broadinstitute.gatk.utils.report.GATKReportTable;
|
import org.broadinstitute.gatk.utils.report.GATKReportTable;
|
||||||
import org.broadinstitute.gatk.utils.variant.GATKVariantContextUtils;
|
import org.broadinstitute.gatk.utils.variant.GATKVariantContextUtils;
|
||||||
import htsjdk.variant.vcf.VCFHeader;
|
import htsjdk.variant.vcf.VCFHeader;
|
||||||
|
|
@ -80,10 +81,15 @@ import htsjdk.variant.variantcontext.writer.VariantContextWriter;
|
||||||
import java.io.File;
|
import java.io.File;
|
||||||
import java.io.FileNotFoundException;
|
import java.io.FileNotFoundException;
|
||||||
import java.io.PrintStream;
|
import java.io.PrintStream;
|
||||||
|
import java.nio.file.Files;
|
||||||
import java.util.*;
|
import java.util.*;
|
||||||
|
|
||||||
import Jama.Matrix;
|
import Jama.Matrix;
|
||||||
|
|
||||||
|
|
||||||
|
import java.io.FileWriter;
|
||||||
|
import java.io.BufferedWriter;
|
||||||
|
import java.io.IOException;
|
||||||
/**
|
/**
|
||||||
* Build a recalibration model to score variant quality for filtering purposes
|
* Build a recalibration model to score variant quality for filtering purposes
|
||||||
*
|
*
|
||||||
|
|
@ -272,10 +278,12 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
* to help describe the normalization. The model fit report can be read in with our R gsalib package. Individual
|
* to help describe the normalization. The model fit report can be read in with our R gsalib package. Individual
|
||||||
* model Gaussians can be subset by the value in the "Gaussian" column if desired.
|
* model Gaussians can be subset by the value in the "Gaussian" column if desired.
|
||||||
*/
|
*/
|
||||||
@Argument(fullName="output_model", shortName = "outputModel", doc="If specified, the variant recalibrator will output the VQSR model fit to the file specified by -modelFile or to stdout", required=false)
|
@Argument(fullName="output_model", shortName = "outputModel", doc="If specified, the variant recalibrator will output the VQSR model to this file path.", required=false)
|
||||||
private boolean outputModel = false;
|
private String outputModel = null;
|
||||||
@Output(fullName="model_file", shortName = "modelFile", doc="A GATKReport containing the positive and negative model fits", required=false)
|
@Argument(fullName="input_model", shortName = "inputModel", doc="If specified, the variant recalibrator will read the VQSR model from this file path.", required=false)
|
||||||
protected PrintStream modelReport = null;
|
private String inputModel = "";
|
||||||
|
//@Output(fullName="model_file", shortName = "modelFile", doc="A GATKReport containing the positive and negative model fits", required=false)
|
||||||
|
//protected PrintStream modelReport = null;
|
||||||
|
|
||||||
@Hidden
|
@Hidden
|
||||||
@Argument(fullName="replicate", shortName="replicate", doc="Used to debug the random number generation inside the VQSR. Do not use.", required=false)
|
@Argument(fullName="replicate", shortName="replicate", doc="Used to debug the random number generation inside the VQSR. Do not use.", required=false)
|
||||||
|
|
@ -311,6 +319,8 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
private PrintStream tranchesStream;
|
private PrintStream tranchesStream;
|
||||||
private final Set<String> ignoreInputFilterSet = new TreeSet<>();
|
private final Set<String> ignoreInputFilterSet = new TreeSet<>();
|
||||||
private final VariantRecalibratorEngine engine = new VariantRecalibratorEngine( VRAC );
|
private final VariantRecalibratorEngine engine = new VariantRecalibratorEngine( VRAC );
|
||||||
|
private GaussianMixtureModel goodModel = null;
|
||||||
|
private GaussianMixtureModel badModel = null;
|
||||||
|
|
||||||
//---------------------------------------------------------------------------------------------------------------
|
//---------------------------------------------------------------------------------------------------------------
|
||||||
//
|
//
|
||||||
|
|
@ -348,6 +358,28 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
throw new UserException.CommandLineException( "No truth set found! Please provide sets of known polymorphic loci marked with the truth=true ROD binding tag. For example, -resource:hapmap,VCF,known=false,training=true,truth=true,prior=12.0 hapmapFile.vcf" );
|
throw new UserException.CommandLineException( "No truth set found! Please provide sets of known polymorphic loci marked with the truth=true ROD binding tag. For example, -resource:hapmap,VCF,known=false,training=true,truth=true,prior=12.0 hapmapFile.vcf" );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
final File inputFile = new File(inputModel);
|
||||||
|
if (inputFile.exists()) { // Load GMM from a file
|
||||||
|
logger.info("Loading model from:" + inputModel);
|
||||||
|
final GATKReport reportIn = new GATKReport(inputFile);
|
||||||
|
|
||||||
|
// Read all the tables
|
||||||
|
final GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances");
|
||||||
|
final GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans");
|
||||||
|
final GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix");
|
||||||
|
final GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances");
|
||||||
|
final GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans");
|
||||||
|
final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
|
||||||
|
final int numAnnotations = dataManager.annotationKeys.size();
|
||||||
|
|
||||||
|
if( numAnnotations != pmmTable.getNumColumns()-1 || numAnnotations != nmmTable.getNumColumns()-1 ) { // -1 because the first column is the gaussian number.
|
||||||
|
throw new UserException.CommandLineException( "Annotations specified on the command line do not match annotations in the model report." );
|
||||||
|
}
|
||||||
|
|
||||||
|
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations);
|
||||||
|
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
|
||||||
|
}
|
||||||
|
|
||||||
final Set<VCFHeaderLine> hInfo = new HashSet<>();
|
final Set<VCFHeaderLine> hInfo = new HashSet<>();
|
||||||
ApplyRecalibration.addVQSRStandardHeaderLines(hInfo);
|
ApplyRecalibration.addVQSRStandardHeaderLines(hInfo);
|
||||||
recalWriter.writeHeader( new VCFHeader(hInfo) );
|
recalWriter.writeHeader( new VCFHeader(hInfo) );
|
||||||
|
|
@ -359,8 +391,11 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
// collect the actual rod bindings into a list for use later
|
// collect the actual rod bindings into a list for use later
|
||||||
for ( final RodBindingCollection<VariantContext> inputCollection : inputCollections )
|
for ( final RodBindingCollection<VariantContext> inputCollection : inputCollections )
|
||||||
input.addAll(inputCollection.getRodBindings());
|
input.addAll(inputCollection.getRodBindings());
|
||||||
|
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
//---------------------------------------------------------------------------------------------------------------
|
//---------------------------------------------------------------------------------------------------------------
|
||||||
//
|
//
|
||||||
// map
|
// map
|
||||||
|
|
@ -479,24 +514,37 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
dataManager.setData(reduceSum);
|
dataManager.setData(reduceSum);
|
||||||
dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation
|
dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation
|
||||||
|
|
||||||
// Generate the positive model using the training data and evaluate each variant
|
|
||||||
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
|
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
|
||||||
final GaussianMixtureModel goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS);
|
final List<VariantDatum> negativeTrainingData;
|
||||||
engine.evaluateData(dataManager.getData(), goodModel, false);
|
|
||||||
|
if (goodModel != null && badModel != null){ // GMMs were loaded from a file
|
||||||
|
logger.info("Using serialized GMMs from file...");
|
||||||
|
engine.evaluateData(dataManager.getData(), goodModel, false);
|
||||||
|
negativeTrainingData = dataManager.selectWorstVariants();
|
||||||
|
} else { // Generate the GMMs from scratch
|
||||||
|
// Generate the positive model using the training data and evaluate each variant
|
||||||
|
goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS);
|
||||||
|
engine.evaluateData(dataManager.getData(), goodModel, false);
|
||||||
|
// Generate the negative model using the worst performing data and evaluate each variant contrastively
|
||||||
|
negativeTrainingData = dataManager.selectWorstVariants();
|
||||||
|
badModel = engine.generateModel(negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS));
|
||||||
|
|
||||||
|
if (badModel.failedToConverge || goodModel.failedToConverge) {
|
||||||
|
throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --minNumBadVariants 5000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example)."));
|
||||||
|
}
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
// Generate the negative model using the worst performing data and evaluate each variant contrastively
|
|
||||||
final List<VariantDatum> negativeTrainingData = dataManager.selectWorstVariants();
|
|
||||||
final GaussianMixtureModel badModel = engine.generateModel(negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS));
|
|
||||||
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
|
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
|
||||||
engine.evaluateData(dataManager.getData(), badModel, true);
|
engine.evaluateData(dataManager.getData(), badModel, true);
|
||||||
|
|
||||||
if (badModel.failedToConverge || goodModel.failedToConverge) {
|
if (outputModel != null) {
|
||||||
throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --minNumBadVariants 5000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example)."));
|
try (PrintStream modelReporter = new PrintStream(outputModel)) {
|
||||||
}
|
GATKReport report = writeModelReport(goodModel, badModel, USE_ANNOTATIONS);
|
||||||
|
report.print(modelReporter);
|
||||||
if (outputModel) {
|
} catch (FileNotFoundException e){
|
||||||
GATKReport report = writeModelReport(goodModel, badModel, USE_ANNOTATIONS);
|
throw new UserException("Could not open output model file:" + outputModel);
|
||||||
report.print(modelReport);
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
engine.calculateWorstPerformingAnnotation(dataManager.getData(), goodModel, badModel);
|
engine.calculateWorstPerformingAnnotation(dataManager.getData(), goodModel, badModel);
|
||||||
|
|
@ -537,8 +585,91 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Rebuild a Gaussian Mixture Model from gaussian means and co-variates stored in a GATKReportTables
|
||||||
|
* @param muTable Table of Gaussian means
|
||||||
|
* @param sigmaTable Table of Gaussian co-variates
|
||||||
|
* @param pmixTable Table of PMixLog10 values
|
||||||
|
* @param numAnnotations Number of annotations, i.e. Dimension of the annotation space in which the Gaussians live
|
||||||
|
* @return a GaussianMixtureModel whose state reflects the state recorded in the tables.
|
||||||
|
*/
|
||||||
|
protected GaussianMixtureModel GMMFromTables(final GATKReportTable muTable, final GATKReportTable sigmaTable, final GATKReportTable pmixTable, final int numAnnotations){
|
||||||
|
List<MultivariateGaussian> gaussianList = new ArrayList<>();
|
||||||
|
|
||||||
|
int curAnnotation = 0;
|
||||||
|
for (GATKReportColumn reportColumn : muTable.getColumnInfo() ) {
|
||||||
|
if (!reportColumn.getColumnName().equals("Gaussian")) {
|
||||||
|
for (int row = 0; row < muTable.getNumRows(); row++) {
|
||||||
|
if (gaussianList.size() <= row){
|
||||||
|
MultivariateGaussian mg = new MultivariateGaussian(numAnnotations);
|
||||||
|
gaussianList.add(mg);
|
||||||
|
}
|
||||||
|
gaussianList.get(row).mu[curAnnotation] = (Double) muTable.get(row, reportColumn.getColumnName());
|
||||||
|
}
|
||||||
|
curAnnotation++;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for (GATKReportColumn reportColumn : pmixTable.getColumnInfo() ) {
|
||||||
|
if (reportColumn.getColumnName().equals("pMixLog10")) {
|
||||||
|
for (int row = 0; row < pmixTable.getNumRows(); row++) {
|
||||||
|
gaussianList.get(row).pMixtureLog10 = (Double) pmixTable.get(row, reportColumn.getColumnName());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
int curJ = 0;
|
||||||
|
for (GATKReportColumn reportColumn : sigmaTable.getColumnInfo() ) {
|
||||||
|
if (reportColumn.getColumnName().equals("Gaussian")) continue;
|
||||||
|
if (reportColumn.getColumnName().equals("Annotation")) continue;
|
||||||
|
|
||||||
|
for (int row = 0; row < sigmaTable.getNumRows(); row++) {
|
||||||
|
int curGaussian = row / numAnnotations;
|
||||||
|
int curI = row % numAnnotations;
|
||||||
|
double curVal = (Double) sigmaTable.get(row, reportColumn.getColumnName());
|
||||||
|
gaussianList.get(curGaussian).sigma.set(curI, curJ, curVal);
|
||||||
|
|
||||||
|
}
|
||||||
|
curJ++;
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
return new GaussianMixtureModel(gaussianList, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER, VRAC.PRIOR_COUNTS);
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
private double[] getStandardDeviationsFromTable(GATKReportTable astdTable){
|
||||||
|
double[] stdVector = {};
|
||||||
|
|
||||||
|
for (GATKReportColumn reportColumn : astdTable.getColumnInfo() ) {
|
||||||
|
if (reportColumn.getColumnName().equals("Standarddeviation")) {
|
||||||
|
stdVector = new double[astdTable.getNumRows()];
|
||||||
|
for (int row = 0; row < astdTable.getNumRows(); row++) {
|
||||||
|
stdVector[row] = (Double) astdTable.get(row, reportColumn.getColumnName());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return stdVector;
|
||||||
|
}
|
||||||
|
|
||||||
|
private double[] getMeansFromTable(GATKReportTable amTable){
|
||||||
|
double[] meanVector = {};
|
||||||
|
|
||||||
|
for (GATKReportColumn reportColumn : amTable.getColumnInfo() ) {
|
||||||
|
if (reportColumn.getColumnName().equals("Mean")) {
|
||||||
|
meanVector = new double[amTable.getNumRows()];
|
||||||
|
for (int row = 0; row < amTable.getNumRows(); row++) {
|
||||||
|
meanVector[row] = (Double) amTable.get(row, reportColumn.getColumnName());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return meanVector;
|
||||||
|
}
|
||||||
|
|
||||||
protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List<String> annotationList) {
|
protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List<String> annotationList) {
|
||||||
final String formatString = "%.3f";
|
final String formatString = "%.8E";
|
||||||
final GATKReport report = new GATKReport();
|
final GATKReport report = new GATKReport();
|
||||||
|
|
||||||
if (dataManager != null) { //for unit test
|
if (dataManager != null) { //for unit test
|
||||||
|
|
@ -547,10 +678,34 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
report.addTable(annotationMeans);
|
report.addTable(annotationMeans);
|
||||||
|
|
||||||
final double[] varianceVector = dataManager.getVarianceVector(); //"varianceVector" is actually stdev
|
final double[] varianceVector = dataManager.getVarianceVector(); //"varianceVector" is actually stdev
|
||||||
GATKReportTable annotationVariances = makeVectorTable("AnnotationStdevs", "Standard deviation for each annotation, used to normalize data", dataManager.annotationKeys, varianceVector, "Standard deviation", formatString);
|
GATKReportTable annotationVariances = makeVectorTable("AnnotationStdevs", "Standard deviation for each annotation, used to normalize data", dataManager.annotationKeys, varianceVector, "Standarddeviation", formatString);
|
||||||
report.addTable(annotationVariances);
|
report.addTable(annotationVariances);
|
||||||
}
|
}
|
||||||
|
|
||||||
|
List<String> gaussianStrings = new ArrayList<>();
|
||||||
|
final double[] pMixtureLog10s = new double[goodModel.getModelGaussians().size()];
|
||||||
|
int idx = 0;
|
||||||
|
|
||||||
|
for( final MultivariateGaussian gaussian : goodModel.getModelGaussians() ) {
|
||||||
|
pMixtureLog10s[idx] = gaussian.pMixtureLog10;
|
||||||
|
gaussianStrings.add(Integer.toString(idx++) );
|
||||||
|
}
|
||||||
|
|
||||||
|
GATKReportTable goodPMix = makeVectorTable("GoodGaussianPMix", "Pmixture log 10 used to evaluate model", gaussianStrings, pMixtureLog10s, "pMixLog10", formatString, "Gaussian");
|
||||||
|
report.addTable(goodPMix);
|
||||||
|
|
||||||
|
gaussianStrings.clear();
|
||||||
|
final double[] pMixtureLog10sBad = new double[badModel.getModelGaussians().size()];
|
||||||
|
idx = 0;
|
||||||
|
|
||||||
|
for( final MultivariateGaussian gaussian : badModel.getModelGaussians() ) {
|
||||||
|
pMixtureLog10sBad[idx] = gaussian.pMixtureLog10;
|
||||||
|
gaussianStrings.add(Integer.toString(idx++));
|
||||||
|
}
|
||||||
|
GATKReportTable badPMix = makeVectorTable("BadGaussianPMix", "Pmixture log 10 used to evaluate model", gaussianStrings, pMixtureLog10sBad, "pMixLog10", formatString, "Gaussian");
|
||||||
|
report.addTable(badPMix);
|
||||||
|
|
||||||
|
|
||||||
//The model and Gaussians don't know what the annotations are, so get them from this class
|
//The model and Gaussians don't know what the annotations are, so get them from this class
|
||||||
//VariantDataManager keeps the annotation in the same order as the argument list
|
//VariantDataManager keeps the annotation in the same order as the argument list
|
||||||
GATKReportTable positiveMeans = makeMeansTable("PositiveModelMeans", "Vector of annotation values to describe the (normalized) mean for each Gaussian in the positive model", annotationList, goodModel, formatString);
|
GATKReportTable positiveMeans = makeMeansTable("PositiveModelMeans", "Vector of annotation values to describe the (normalized) mean for each Gaussian in the positive model", annotationList, goodModel, formatString);
|
||||||
|
|
@ -570,8 +725,12 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
}
|
}
|
||||||
|
|
||||||
protected GATKReportTable makeVectorTable(final String tableName, final String tableDescription, final List<String> annotationList, final double[] perAnnotationValues, final String columnName, final String formatString) {
|
/**
 * Convenience overload: delegates to the seven-argument makeVectorTable, passing "Annotation" as the
 * name of the first column.
 */
protected GATKReportTable makeVectorTable(final String tableName, final String tableDescription, final List<String> annotationList, final double[] perAnnotationValues, final String columnName, final String formatString) {
|
||||||
|
return makeVectorTable(tableName, tableDescription, annotationList, perAnnotationValues, columnName, formatString, "Annotation");
|
||||||
|
}
|
||||||
|
|
||||||
|
protected GATKReportTable makeVectorTable(final String tableName, final String tableDescription, final List<String> annotationList, final double[] perAnnotationValues, final String columnName, final String formatString, final String firstColumn) {
|
||||||
GATKReportTable vectorTable = new GATKReportTable(tableName, tableDescription, annotationList.size(), GATKReportTable.TableSortingWay.DO_NOT_SORT);
|
GATKReportTable vectorTable = new GATKReportTable(tableName, tableDescription, annotationList.size(), GATKReportTable.TableSortingWay.DO_NOT_SORT);
|
||||||
vectorTable.addColumn("Annotation");
|
vectorTable.addColumn(firstColumn);
|
||||||
vectorTable.addColumn(columnName, formatString);
|
vectorTable.addColumn(columnName, formatString);
|
||||||
for (int i = 0; i < perAnnotationValues.length; i++) {
|
for (int i = 0; i < perAnnotationValues.length; i++) {
|
||||||
vectorTable.addRowIDMapping(annotationList.get(i), i, true);
|
vectorTable.addRowIDMapping(annotationList.get(i), i, true);
|
||||||
|
|
|
||||||
|
|
@ -98,6 +98,7 @@ public class VariantRecalibratorEngine {
|
||||||
try {
|
try {
|
||||||
model.precomputeDenominatorForEvaluation();
|
model.precomputeDenominatorForEvaluation();
|
||||||
} catch( Exception e ) {
|
} catch( Exception e ) {
|
||||||
|
logger.warn("Model could not pre-compute denominators.");
|
||||||
model.failedToConverge = true;
|
model.failedToConverge = true;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
@ -107,6 +108,7 @@ public class VariantRecalibratorEngine {
|
||||||
for( final VariantDatum datum : data ) {
|
for( final VariantDatum datum : data ) {
|
||||||
final double thisLod = evaluateDatum( datum, model );
|
final double thisLod = evaluateDatum( datum, model );
|
||||||
if( Double.isNaN(thisLod) ) {
|
if( Double.isNaN(thisLod) ) {
|
||||||
|
logger.warn("Evaluate datum returned a NaN.");
|
||||||
model.failedToConverge = true;
|
model.failedToConverge = true;
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
@ -142,7 +144,7 @@ public class VariantRecalibratorEngine {
|
||||||
// Private Methods used for generating a GaussianMixtureModel
|
// Private Methods used for generating a GaussianMixtureModel
|
||||||
/////////////////////////////
|
/////////////////////////////
|
||||||
|
|
||||||
private void variationalBayesExpectationMaximization( final GaussianMixtureModel model, final List<VariantDatum> data ) {
|
protected void variationalBayesExpectationMaximization(final GaussianMixtureModel model, final List<VariantDatum> data) {
|
||||||
|
|
||||||
model.initializeRandomModel( data, VRAC.NUM_KMEANS_ITERATIONS );
|
model.initializeRandomModel( data, VRAC.NUM_KMEANS_ITERATIONS );
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -51,127 +51,130 @@
|
||||||
|
|
||||||
package org.broadinstitute.gatk.tools.walkers.variantrecalibration;
|
package org.broadinstitute.gatk.tools.walkers.variantrecalibration;
|
||||||
|
|
||||||
import static org.testng.Assert.*;
|
|
||||||
|
|
||||||
import Jama.Matrix;
|
|
||||||
import org.apache.commons.lang.StringUtils;
|
|
||||||
import org.apache.log4j.Logger;
|
import org.apache.log4j.Logger;
|
||||||
|
import org.broadinstitute.gatk.utils.BaseTest;
|
||||||
import org.broadinstitute.gatk.utils.report.GATKReport;
|
import org.broadinstitute.gatk.utils.report.GATKReport;
|
||||||
import org.broadinstitute.gatk.utils.report.GATKReportTable;
|
import org.broadinstitute.gatk.utils.report.GATKReportTable;
|
||||||
import org.testng.Assert;
|
import org.testng.Assert;
|
||||||
import org.testng.annotations.Test;
|
import org.testng.annotations.Test;
|
||||||
|
|
||||||
|
import java.io.File;
|
||||||
|
import java.io.FileNotFoundException;
|
||||||
|
import java.io.PrintStream;
|
||||||
import java.util.ArrayList;
|
import java.util.ArrayList;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Random;
|
import java.util.Random;
|
||||||
|
|
||||||
public class VariantRecalibratorModelOutputUnitTest {
|
public class VariantRecalibratorModelOutputUnitTest extends BaseTest {
|
||||||
protected final static Logger logger = Logger.getLogger(VariantRecalibratorModelOutputUnitTest.class);
|
protected final static Logger logger = Logger.getLogger(VariantRecalibratorModelOutputUnitTest.class);
|
||||||
private final boolean printTables = true;
|
private final boolean printTables = true;
|
||||||
|
private final int numAnnotations = 6;
|
||||||
|
private final double shrinkage = 1.0;
|
||||||
|
private final double dirichlet = 0.001;
|
||||||
|
private final double priorCounts = 20.0;
|
||||||
|
private final double epsilon = 1e-6;
|
||||||
|
private final String modelReportName = "vqsr_model.report";
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
public void testVQSRModelOutput() {
|
public void testVQSRModelOutput() {
|
||||||
final int numAnnotations = 6;
|
GaussianMixtureModel goodModel = getGoodGMM();
|
||||||
final double shrinkage = 1.0;
|
GaussianMixtureModel badModel = getBadGMM();
|
||||||
final double dirichlet = 0.001;
|
|
||||||
final double priorCounts = 20.0;
|
|
||||||
final int numGoodGaussians = 2;
|
|
||||||
final int numBadGaussians = 1;
|
|
||||||
final double epsilon = 1e-6;
|
|
||||||
|
|
||||||
Random rand = new Random(12878);
|
|
||||||
MultivariateGaussian goodGaussian1 = new MultivariateGaussian(numAnnotations);
|
|
||||||
goodGaussian1.initializeRandomMu(rand);
|
|
||||||
goodGaussian1.initializeRandomSigma(rand);
|
|
||||||
|
|
||||||
MultivariateGaussian goodGaussian2 = new MultivariateGaussian(numAnnotations);
|
|
||||||
goodGaussian2.initializeRandomMu(rand);
|
|
||||||
goodGaussian2.initializeRandomSigma(rand);
|
|
||||||
|
|
||||||
MultivariateGaussian badGaussian1 = new MultivariateGaussian(numAnnotations);
|
|
||||||
badGaussian1.initializeRandomMu(rand);
|
|
||||||
badGaussian1.initializeRandomSigma(rand);
|
|
||||||
|
|
||||||
List<MultivariateGaussian> goodGaussianList = new ArrayList<>();
|
|
||||||
goodGaussianList.add(goodGaussian1);
|
|
||||||
goodGaussianList.add(goodGaussian2);
|
|
||||||
|
|
||||||
List<MultivariateGaussian> badGaussianList = new ArrayList<>();
|
|
||||||
badGaussianList.add(badGaussian1);
|
|
||||||
|
|
||||||
GaussianMixtureModel goodModel = new GaussianMixtureModel(goodGaussianList, shrinkage, dirichlet, priorCounts);
|
|
||||||
GaussianMixtureModel badModel = new GaussianMixtureModel(badGaussianList, shrinkage, dirichlet, priorCounts);
|
|
||||||
|
|
||||||
if (printTables) {
|
if (printTables) {
|
||||||
System.out.println("Good model mean matrix:");
|
System.out.println("Good model mean matrix:");
|
||||||
System.out.println(vectorToString(goodGaussian1.mu));
|
System.out.println(vectorToString(goodModel.getModelGaussians().get(0).mu));
|
||||||
System.out.println(vectorToString(goodGaussian2.mu));
|
System.out.println(vectorToString(goodModel.getModelGaussians().get(1).mu));
|
||||||
System.out.println("\n\n");
|
System.out.println("\n\n");
|
||||||
|
|
||||||
System.out.println("Good model covariance matrices:");
|
System.out.println("Good model covariance matrices:");
|
||||||
goodGaussian1.sigma.print(10, 3);
|
goodModel.getModelGaussians().get(0).sigma.print(10, 3);
|
||||||
goodGaussian2.sigma.print(10, 3);
|
goodModel.getModelGaussians().get(1).sigma.print(10, 3);
|
||||||
System.out.println("\n\n");
|
System.out.println("\n\n");
|
||||||
|
|
||||||
System.out.println("Bad model mean matrix:\n");
|
System.out.println("Bad model mean matrix:\n");
|
||||||
System.out.println(vectorToString(badGaussian1.mu));
|
System.out.println(vectorToString(badModel.getModelGaussians().get(0).mu));
|
||||||
System.out.println("\n\n");
|
System.out.println("\n\n");
|
||||||
|
|
||||||
System.out.println("Bad model covariance matrix:");
|
System.out.println("Bad model covariance matrix:");
|
||||||
badGaussian1.sigma.print(10, 3);
|
badModel.getModelGaussians().get(0).sigma.print(10, 3);
|
||||||
}
|
}
|
||||||
|
|
||||||
VariantRecalibrator vqsr = new VariantRecalibrator();
|
VariantRecalibrator vqsr = new VariantRecalibrator();
|
||||||
List<String> annotationList = new ArrayList<>();
|
List<String> annotationList = getAnnotationList();
|
||||||
annotationList.add("QD");
|
|
||||||
annotationList.add("MQ");
|
|
||||||
annotationList.add("FS");
|
|
||||||
annotationList.add("SOR");
|
|
||||||
annotationList.add("ReadPosRankSum");
|
|
||||||
annotationList.add("MQRankSum");
|
|
||||||
|
|
||||||
|
|
||||||
GATKReport report = vqsr.writeModelReport(goodModel, badModel, annotationList);
|
GATKReport report = vqsr.writeModelReport(goodModel, badModel, annotationList);
|
||||||
if(printTables)
|
if(printTables) {
|
||||||
report.print(System.out);
|
try {
|
||||||
|
PrintStream modelReporter = new PrintStream(this.privateTestDir+this.modelReportName);
|
||||||
|
report.print(modelReporter);
|
||||||
|
} catch (FileNotFoundException e) {
|
||||||
|
e.printStackTrace();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
//Check values for Gaussian means
|
//Check values for Gaussian means
|
||||||
GATKReportTable goodMus = report.getTable("PositiveModelMeans");
|
GATKReportTable goodMus = report.getTable("PositiveModelMeans");
|
||||||
for(int i = 0; i < annotationList.size(); i++) {
|
for(int i = 0; i < annotationList.size(); i++) {
|
||||||
Assert.assertEquals(goodGaussian1.mu[i], (Double)goodMus.get(0,annotationList.get(i)), epsilon);
|
Assert.assertEquals(goodModel.getModelGaussians().get(0).mu[i], (Double)goodMus.get(0,annotationList.get(i)), epsilon);
|
||||||
}
|
}
|
||||||
for(int i = 0; i < annotationList.size(); i++) {
|
for(int i = 0; i < annotationList.size(); i++) {
|
||||||
Assert.assertEquals(goodGaussian2.mu[i], (Double)goodMus.get(1,annotationList.get(i)), epsilon);
|
Assert.assertEquals(goodModel.getModelGaussians().get(1).mu[i], (Double)goodMus.get(1,annotationList.get(i)), epsilon);
|
||||||
}
|
}
|
||||||
|
|
||||||
GATKReportTable badMus = report.getTable("NegativeModelMeans");
|
GATKReportTable badMus = report.getTable("NegativeModelMeans");
|
||||||
for(int i = 0; i < annotationList.size(); i++) {
|
for(int i = 0; i < annotationList.size(); i++) {
|
||||||
Assert.assertEquals(badGaussian1.mu[i], (Double)badMus.get(0,annotationList.get(i)), epsilon);
|
Assert.assertEquals(badModel.getModelGaussians().get(0).mu[i], (Double)badMus.get(0,annotationList.get(i)), epsilon);
|
||||||
}
|
}
|
||||||
|
|
||||||
//Check values for Gaussian covariances
|
//Check values for Gaussian covariances
|
||||||
GATKReportTable goodSigma = report.getTable("PositiveModelCovariances");
|
GATKReportTable goodSigma = report.getTable("PositiveModelCovariances");
|
||||||
for(int i = 0; i < annotationList.size(); i++) {
|
for(int i = 0; i < annotationList.size(); i++) {
|
||||||
for(int j = 0; j < annotationList.size(); j++) {
|
for(int j = 0; j < annotationList.size(); j++) {
|
||||||
Assert.assertEquals(goodGaussian1.sigma.get(i,j), (Double)goodSigma.get(i,annotationList.get(j)), epsilon);
|
Assert.assertEquals(goodModel.getModelGaussians().get(0).sigma.get(i,j), (Double)goodSigma.get(i,annotationList.get(j)), epsilon);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
//add annotationList.size() to row indexes for second Gaussian because the matrices are concatenated by row in the report
|
//add annotationList.size() to row indexes for second Gaussian because the matrices are concatenated by row in the report
|
||||||
for(int i = 0; i < annotationList.size(); i++) {
|
for(int i = 0; i < annotationList.size(); i++) {
|
||||||
for(int j = 0; j < annotationList.size(); j++) {
|
for(int j = 0; j < annotationList.size(); j++) {
|
||||||
Assert.assertEquals(goodGaussian2.sigma.get(i,j), (Double)goodSigma.get(annotationList.size()+i,annotationList.get(j)), epsilon);
|
Assert.assertEquals(goodModel.getModelGaussians().get(1).sigma.get(i,j), (Double)goodSigma.get(annotationList.size()+i,annotationList.get(j)), epsilon);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
GATKReportTable badSigma = report.getTable("NegativeModelCovariances");
|
GATKReportTable badSigma = report.getTable("NegativeModelCovariances");
|
||||||
for(int i = 0; i < annotationList.size(); i++) {
|
for(int i = 0; i < annotationList.size(); i++) {
|
||||||
for(int j = 0; j < annotationList.size(); j++) {
|
for(int j = 0; j < annotationList.size(); j++) {
|
||||||
Assert.assertEquals(badGaussian1.sigma.get(i,j), (Double)badSigma.get(i,annotationList.get(j)), epsilon);
|
Assert.assertEquals(badModel.getModelGaussians().get(0).sigma.get(i,j), (Double)badSigma.get(i,annotationList.get(j)), epsilon);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
@Test
|
||||||
|
public void testVQSRModelInput(){
|
||||||
|
final File inputFile = new File(this.privateTestDir + this.modelReportName);
|
||||||
|
final GATKReport report = new GATKReport(inputFile);
|
||||||
|
|
||||||
|
// Now test model report reading
|
||||||
|
// Read all the tables
|
||||||
|
final GATKReportTable badMus = report.getTable("NegativeModelMeans");
|
||||||
|
final GATKReportTable badSigma = report.getTable("NegativeModelCovariances");
|
||||||
|
final GATKReportTable nPMixTable = report.getTable("BadGaussianPMix");
|
||||||
|
|
||||||
|
final GATKReportTable goodMus = report.getTable("PositiveModelMeans");
|
||||||
|
final GATKReportTable goodSigma = report.getTable("PositiveModelCovariances");
|
||||||
|
final GATKReportTable pPMixTable = report.getTable("GoodGaussianPMix");
|
||||||
|
|
||||||
|
List<String> annotationList = getAnnotationList();
|
||||||
|
VariantRecalibrator vqsr = new VariantRecalibrator();
|
||||||
|
|
||||||
|
GaussianMixtureModel goodModelFromFile = vqsr.GMMFromTables(goodMus, goodSigma, pPMixTable, annotationList.size());
|
||||||
|
GaussianMixtureModel badModelFromFile = vqsr.GMMFromTables(badMus, badSigma, nPMixTable, annotationList.size());
|
||||||
|
|
||||||
|
testGMMsForEquality(getGoodGMM(), goodModelFromFile, epsilon);
|
||||||
|
testGMMsForEquality(getBadGMM(), badModelFromFile, epsilon);
|
||||||
|
}
|
||||||
|
|
||||||
@Test
|
@Test
|
||||||
//This is tested separately to avoid setting up a VariantDataManager and populating it with fake data
|
//This is tested separately to avoid setting up a VariantDataManager and populating it with fake data
|
||||||
public void testAnnotationNormalizationOutput() {
|
public void testAnnotationNormalizationOutput() {
|
||||||
|
|
@ -211,4 +214,66 @@ public class VariantRecalibratorModelOutputUnitTest {
|
||||||
return returnString;
|
return returnString;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
private void testGMMsForEquality(GaussianMixtureModel gmm1, GaussianMixtureModel gmm2, double epsilon){
|
||||||
|
Assert.assertEquals(gmm1.getModelGaussians().size(), gmm2.getModelGaussians().size(), 0);
|
||||||
|
|
||||||
|
for(int k = 0; k < gmm1.getModelGaussians().size(); k++) {
|
||||||
|
final MultivariateGaussian g = gmm1.getModelGaussians().get(k);
|
||||||
|
final MultivariateGaussian gFile = gmm2.getModelGaussians().get(k);
|
||||||
|
|
||||||
|
Assert.assertEquals(g.pMixtureLog10, gFile.pMixtureLog10);
|
||||||
|
|
||||||
|
for(int i = 0; i < g.mu.length; i++){
|
||||||
|
Assert.assertEquals(g.mu[i], gFile.mu[i], epsilon);
|
||||||
|
}
|
||||||
|
|
||||||
|
for(int i = 0; i < g.sigma.getRowDimension(); i++) {
|
||||||
|
for (int j = 0; j < g.sigma.getColumnDimension(); j++) {
|
||||||
|
Assert.assertEquals(g.sigma.get(i, j), gFile.sigma.get(i, j), epsilon);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private List<String> getAnnotationList(){
|
||||||
|
List<String> annotationList = new ArrayList<>();
|
||||||
|
annotationList.add("QD");
|
||||||
|
annotationList.add("MQ");
|
||||||
|
annotationList.add("FS");
|
||||||
|
annotationList.add("SOR");
|
||||||
|
annotationList.add("ReadPosRankSum");
|
||||||
|
annotationList.add("MQRankSum");
|
||||||
|
return annotationList;
|
||||||
|
}
|
||||||
|
|
||||||
|
private GaussianMixtureModel getGoodGMM(){
|
||||||
|
Random rand = new Random(12878);
|
||||||
|
MultivariateGaussian goodGaussian1 = new MultivariateGaussian(numAnnotations);
|
||||||
|
goodGaussian1.initializeRandomMu(rand);
|
||||||
|
goodGaussian1.initializeRandomSigma(rand);
|
||||||
|
|
||||||
|
MultivariateGaussian goodGaussian2 = new MultivariateGaussian(numAnnotations);
|
||||||
|
goodGaussian2.initializeRandomMu(rand);
|
||||||
|
goodGaussian2.initializeRandomSigma(rand);
|
||||||
|
|
||||||
|
List<MultivariateGaussian> goodGaussianList = new ArrayList<>();
|
||||||
|
goodGaussianList.add(goodGaussian1);
|
||||||
|
goodGaussianList.add(goodGaussian2);
|
||||||
|
|
||||||
|
return new GaussianMixtureModel(goodGaussianList, shrinkage, dirichlet, priorCounts);
|
||||||
|
}
|
||||||
|
|
||||||
|
private GaussianMixtureModel getBadGMM(){
|
||||||
|
Random rand = new Random(12878);
|
||||||
|
MultivariateGaussian badGaussian1 = new MultivariateGaussian(numAnnotations);
|
||||||
|
|
||||||
|
badGaussian1.initializeRandomMu(rand);
|
||||||
|
badGaussian1.initializeRandomSigma(rand);
|
||||||
|
|
||||||
|
List<MultivariateGaussian> badGaussianList = new ArrayList<>();
|
||||||
|
badGaussianList.add(badGaussian1);
|
||||||
|
|
||||||
|
return new GaussianMixtureModel(badGaussianList, shrinkage, dirichlet, priorCounts);
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
@ -150,10 +150,8 @@ public class GATKReportTable {
|
||||||
// read a data line
|
// read a data line
|
||||||
final String dataLine = reader.readLine();
|
final String dataLine = reader.readLine();
|
||||||
final List<String> lineSplits = Arrays.asList(TextFormattingUtils.splitFixedWidth(dataLine, columnStarts));
|
final List<String> lineSplits = Arrays.asList(TextFormattingUtils.splitFixedWidth(dataLine, columnStarts));
|
||||||
|
|
||||||
underlyingData.add(new Object[nColumns]);
|
underlyingData.add(new Object[nColumns]);
|
||||||
for ( int columnIndex = 0; columnIndex < nColumns; columnIndex++ ) {
|
for ( int columnIndex = 0; columnIndex < nColumns; columnIndex++ ) {
|
||||||
|
|
||||||
final GATKReportDataType type = columnInfo.get(columnIndex).getDataType();
|
final GATKReportDataType type = columnInfo.get(columnIndex).getDataType();
|
||||||
final String columnName = columnNames[columnIndex];
|
final String columnName = columnNames[columnIndex];
|
||||||
set(i, columnName, type.Parse(lineSplits.get(columnIndex)));
|
set(i, columnName, type.Parse(lineSplits.get(columnIndex)));
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue