Create GMM from model reports in VQSR

This commit is contained in:
Samuel Friedman 2016-11-04 15:05:23 -04:00
parent 4fe4ace232
commit a8f70c891f
4 changed files with 216 additions and 13 deletions

View File

@ -86,6 +86,11 @@ public class GaussianMixtureModel {
gaussians = new ArrayList<>( numGaussians );
for( int iii = 0; iii < numGaussians; iii++ ) {
final MultivariateGaussian gaussian = new MultivariateGaussian( numAnnotations );
gaussian.pMixtureLog10 = Math.log10( 1.0 / ((double)numGaussians) );
gaussian.sumProb = 1.0 / ((double) numGaussians);
gaussian.hyperParameter_a = priorCounts;
gaussian.hyperParameter_b = shrinkage;
gaussian.hyperParameter_lambda = dirichletParameter;
gaussians.add( gaussian );
}
this.shrinkage = shrinkage;
@ -190,6 +195,9 @@ public class GaussianMixtureModel {
final double[] pVarInGaussianNormalized = MathUtils.normalizeFromLog10( pVarInGaussianLog10, false );
gaussianIndex = 0;
for( final MultivariateGaussian gaussian : gaussians ) {
if (Double.isNaN(pVarInGaussianNormalized[gaussianIndex])){
logger.info(" Got a NaN at gaussian:" + Integer.toString(gaussianIndex) + " datum:" + datum.toString());
}
gaussian.assignPVarInGaussian( pVarInGaussianNormalized[gaussianIndex++] );
}
}
@ -315,4 +323,5 @@ public class GaussianMixtureModel {
/**
 * Exposes the Gaussians held by this model without allowing callers to modify the list.
 *
 * @return an unmodifiable view wrapping the {@code gaussians} list
 */
protected List<MultivariateGaussian> getModelGaussians() {
    final List<MultivariateGaussian> readOnlyView = Collections.unmodifiableList(gaussians);
    return readOnlyView;
}
/**
 * Reports how many annotations this model covers.
 *
 * @return the length of the {@code empiricalMu} array
 */
protected int getNumAnnotations() {
    final int annotationCount = empiricalMu.length;
    return annotationCount;
}
}

View File

@ -271,4 +271,14 @@ public class MultivariateGaussian {
resetPVarInGaussian(); // clean up some memory
}
/**
 * Recomputes {@code sumProb} from scratch as the sum of the first {@code data.size()} entries of
 * {@code pVarInGaussian}, ignoring any entry that is NaN.
 *
 * @param data the variants being summed over; only the size of this list is read here
 */
public void setSumProb( final List<VariantDatum> data ) {
    sumProb = 0.0;
    int index = 0;
    while ( index < data.size() ) {
        final double contribution = pVarInGaussian.get(index++);
        if ( Double.isNaN(contribution) ) {
            continue; // NaN entries contribute nothing to the total
        }
        sumProb += contribution;
    }
}
}

View File

@ -66,6 +66,7 @@ import org.broadinstitute.gatk.utils.R.RScriptExecutor;
import org.broadinstitute.gatk.utils.Utils;
import org.broadinstitute.gatk.utils.help.HelpConstants;
import org.broadinstitute.gatk.utils.report.GATKReport;
import org.broadinstitute.gatk.utils.report.GATKReportColumn;
import org.broadinstitute.gatk.utils.report.GATKReportTable;
import org.broadinstitute.gatk.utils.variant.GATKVariantContextUtils;
import htsjdk.variant.vcf.VCFHeader;
@ -80,10 +81,15 @@ import htsjdk.variant.variantcontext.writer.VariantContextWriter;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.PrintStream;
import java.nio.file.Files;
import java.util.*;
import Jama.Matrix;
import java.io.FileWriter;
import java.io.BufferedWriter;
import java.io.IOException;
/**
* Build a recalibration model to score variant quality for filtering purposes
*
@ -274,6 +280,8 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
*/
@Argument(fullName="output_model", shortName = "outputModel", doc="If specified, the variant recalibrator will output the VQSR model fit to the file specified by -modelFile or to stdout", required=false)
private boolean outputModel = false;
@Argument(fullName="input_model", shortName = "inputModel", doc="If specified, the variant recalibrator will read the VQSR model from the file specified by -modelFile", required=false)
private String inputModel = "";
@Output(fullName="model_file", shortName = "modelFile", doc="A GATKReport containing the positive and negative model fits", required=false)
protected PrintStream modelReport = null;
@ -359,8 +367,11 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
// collect the actual rod bindings into a list for use later
for ( final RodBindingCollection<VariantContext> inputCollection : inputCollections )
input.addAll(inputCollection.getRodBindings());
}
//---------------------------------------------------------------------------------------------------------------
//
// map
@ -480,25 +491,86 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation
// Generate the positive model using the training data and evaluate each variant
final GaussianMixtureModel goodModel, badModel;
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
final GaussianMixtureModel goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS);
engine.evaluateData(dataManager.getData(), goodModel, false);
final List<VariantDatum> negativeTrainingData;
File inputFile = new File(inputModel);
if (inputFile.exists()){ // Load GMM from a file
GATKReport reportIn = new GATKReport(inputFile);
GATKReportTable amTable = reportIn.getTable("AnnotationMeans");
GATKReportTable astdTable = reportIn.getTable("AnnotationStdevs");
GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances");
GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans");
GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix");
GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances");
GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans");
GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
double[] meanVector;
double[] stdVector;
int numAnnotations = 0;
for (GATKReportColumn reportColumn : amTable.getColumnInfo() ) {
if (reportColumn.getColumnName().equals("Mean")) {
meanVector = new double[amTable.getNumRows()];
numAnnotations = amTable.getNumRows();
for (int row = 0; row < amTable.getNumRows(); row++) {
meanVector[row] = (double) amTable.get(row, reportColumn.getColumnName());
}
logger.info("Got mean Vector:" + Arrays.toString(meanVector));
}
}
for (GATKReportColumn reportColumn : astdTable.getColumnInfo() ) {
logger.info("Report column name is:" + reportColumn.getColumnName());
if (reportColumn.getColumnName().equals("Standarddeviation")) {
stdVector = new double[astdTable.getNumRows()];
for (int row = 0; row < astdTable.getNumRows(); row++) {
stdVector[row] = (double) astdTable.get(row, reportColumn.getColumnName());
}
}
}
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations);
//Utils.getRandomGenerator().setSeed(12878);
engine.evaluateData(dataManager.getData(), goodModel, false);
negativeTrainingData = dataManager.selectWorstVariants();
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
logger.info("Loaded GMM from file:" + inputModel);
// Generate the negative model using the worst performing data and evaluate each variant contrastively
final List<VariantDatum> negativeTrainingData = dataManager.selectWorstVariants();
final GaussianMixtureModel badModel = engine.generateModel(negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS));
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
//Utils.getRandomGenerator().setSeed(12878);
engine.evaluateData(dataManager.getData(), badModel, true);
} else { // Generate the GMMs from scratch
goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS);
//Utils.getRandomGenerator().setSeed(12878);
engine.evaluateData(dataManager.getData(), goodModel, false);
// Generate the negative model using the worst performing data and evaluate each variant contrastively
negativeTrainingData = dataManager.selectWorstVariants();
badModel = engine.generateModel(negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS));
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
//Utils.getRandomGenerator().setSeed(12878);
engine.evaluateData(dataManager.getData(), badModel, true);
if (badModel.failedToConverge || goodModel.failedToConverge) {
throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --minNumBadVariants 5000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example)."));
}
}
if (outputModel) {
GATKReport report = writeModelReport(goodModel, badModel, USE_ANNOTATIONS);
report.print(modelReport);
}
boolean writeFeatures = true;
if (writeFeatures) writeFeaturesFiles(positiveTrainingData, negativeTrainingData);
engine.calculateWorstPerformingAnnotation(dataManager.getData(), goodModel, badModel);
// Find the VQSLOD cutoff values which correspond to the various tranches of calls requested by the user
@ -537,8 +609,91 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
}
}
/**
 * Rebuilds a Gaussian mixture model from the three report tables that describe it.
 *
 * <p>Table layout, as read by this method:
 * <ul>
 *   <li>{@code muTable}: one row per Gaussian; every column other than "Gaussian" holds one annotation's mean,
 *       taken in column order.</li>
 *   <li>{@code pmixTable}: one row per Gaussian; the "pMixLog10" column holds that Gaussian's log10 mixture weight.</li>
 *   <li>{@code sigmaTable}: every column other than "Gaussian" and "Annotation" is one covariance-matrix column
 *       (in column order); table row {@code r} maps to Gaussian {@code r / numAnnotations} and matrix row
 *       {@code r % numAnnotations}.</li>
 * </ul>
 *
 * @param muTable        table of per-Gaussian annotation means
 * @param sigmaTable     table of per-Gaussian covariance entries
 * @param pmixTable      table of per-Gaussian log10 mixture weights
 * @param numAnnotations number of annotations per Gaussian; must be positive
 * @return a model built from the loaded Gaussians and the VRAC shrinkage, Dirichlet and prior-count settings
 * @throws IllegalArgumentException if {@code numAnnotations} is not positive
 */
public GaussianMixtureModel GMMFromTables(GATKReportTable muTable, GATKReportTable sigmaTable, GATKReportTable pmixTable, int numAnnotations){
    // Fail fast with a clear message: a non-positive count would otherwise surface later as an
    // ArithmeticException (row / numAnnotations) or an ArrayIndexOutOfBoundsException on mu[].
    if (numAnnotations <= 0) {
        throw new IllegalArgumentException("Cannot build a Gaussian mixture model from report tables: numAnnotations must be positive but was " + numAnnotations);
    }

    List<MultivariateGaussian> gaussianList = new ArrayList<>();

    // Means: each non-"Gaussian" column is one annotation, each row is one Gaussian.
    int curAnnotation = 0;
    for (GATKReportColumn reportColumn : muTable.getColumnInfo() ) {
        final String columnName = reportColumn.getColumnName();
        logger.info("Report column name is:" + columnName);
        if (!columnName.equals("Gaussian")) {
            for (int row = 0; row < muTable.getNumRows(); row++) {
                // Gaussians are created lazily while walking the first annotation column.
                if (gaussianList.size() <= row){
                    MultivariateGaussian mg = new MultivariateGaussian(numAnnotations);
                    gaussianList.add(mg);
                }
                gaussianList.get(row).mu[curAnnotation] = (double) muTable.get(row, columnName);
            }
            curAnnotation++;
        }
    }

    // Mixture weights: one log10 weight per Gaussian, read from the "pMixLog10" column.
    for (GATKReportColumn reportColumn : pmixTable.getColumnInfo() ) {
        final String columnName = reportColumn.getColumnName();
        if (columnName.equals("pMixLog10")) {
            for (int row = 0; row < pmixTable.getNumRows(); row++) {
                gaussianList.get(row).pMixtureLog10 = (double) pmixTable.get(row, columnName);
            }
        }
    }

    // Covariances: the matrices are stacked vertically, numAnnotations rows per Gaussian.
    int curJ = 0;
    for (GATKReportColumn reportColumn : sigmaTable.getColumnInfo() ) {
        final String columnName = reportColumn.getColumnName();
        if (columnName.equals("Gaussian")) continue;
        if (columnName.equals("Annotation")) continue;
        for (int row = 0; row < sigmaTable.getNumRows(); row++) {
            int curGaussian = row / numAnnotations;
            int curI = row % numAnnotations;
            double curVal = (double) sigmaTable.get(row, columnName);
            gaussianList.get(curGaussian).sigma.set(curI, curJ, curVal);
        }
        curJ++;
    }

    return new GaussianMixtureModel(gaussianList, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER, VRAC.PRIOR_COUNTS);
}
/**
 * Dumps the annotation vectors of the positive and negative training sets to two text files, one variant
 * per line with each annotation value followed by a single space.
 *
 * <p>This is a best-effort debugging aid: an I/O failure is reported on stderr and does not abort the run.
 *
 * @param positiveTrainingData variants whose annotations go to the "features" file
 * @param negativeTrainingData variants whose annotations go to the "bad features" file
 */
private void writeFeaturesFiles(List<VariantDatum> positiveTrainingData, List<VariantDatum> negativeTrainingData){
    // NOTE(review): these are developer-local absolute paths; on any other machine the write fails and is
    // only reported below. They should become a tool argument before this leaves debugging use.
    final File positiveFile = new File("/Users/sam/data/haploid_features.txt");
    final File negativeFile = new File("/Users/sam/data/haploid_bad_features.txt");
    try {
        writeAnnotationRows(positiveFile, positiveTrainingData);
        writeAnnotationRows(negativeFile, negativeTrainingData);
    } catch (IOException e) {
        // Deliberately non-fatal: the feature dump is auxiliary output.
        e.printStackTrace();
    }
}

/** Writes one line per datum to {@code target}; the writer is closed even if a write fails. */
private static void writeAnnotationRows(final File target, final List<VariantDatum> data) throws IOException {
    // FileWriter creates (or truncates) the file itself, so no separate createNewFile() call is needed.
    try (BufferedWriter writer = new BufferedWriter(new FileWriter(target.getAbsoluteFile()))) {
        for (final VariantDatum datum : data) {
            for (final double annotation : datum.annotations) {
                writer.write(Double.toString(annotation));
                writer.write(" ");
            }
            writer.write("\n");
        }
    }
}
protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List<String> annotationList) {
final String formatString = "%.3f";
final String formatString = "%.25f";
final GATKReport report = new GATKReport();
if (dataManager != null) { //for unit test
@ -551,6 +706,32 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
report.addTable(annotationVariances);
}
List<String> gaussianStrings = new ArrayList<>();
final double[] pMixtureLog10s = new double[goodModel.getModelGaussians().size()];
int idx = 0;
for( final MultivariateGaussian gaussian : goodModel.getModelGaussians() ) {
pMixtureLog10s[idx] = gaussian.pMixtureLog10;
logger.info("Good normalize PMix log 10 is:" + Double.toString(gaussian.pMixtureLog10) );
gaussianStrings.add(Integer.toString(idx++) );
}
GATKReportTable goodPMix = makeVectorTable("GoodGaussianPMix", "Pmixture log 10 used to evaluate model", gaussianStrings, pMixtureLog10s, "pMixLog10", formatString, "Gaussian");
report.addTable(goodPMix);
gaussianStrings.clear();
final double[] pMixtureLog10sBad = new double[badModel.getModelGaussians().size()];
idx = 0;
for( final MultivariateGaussian gaussian : badModel.getModelGaussians() ) {
pMixtureLog10sBad[idx] = gaussian.pMixtureLog10;
logger.info("Bad normalize PMix log 10 is:" + Double.toString(gaussian.pMixtureLog10));
gaussianStrings.add(Integer.toString(idx++));
}
GATKReportTable badPMix = makeVectorTable("BadGaussianPMix", "Pmixture log 10 used to evaluate model", gaussianStrings, pMixtureLog10sBad, "pMixLog10", formatString, "Gaussian");
report.addTable(badPMix);
//The model and Gaussians don't know what the annotations are, so get them from this class
//VariantDataManager keeps the annotation in the same order as the argument list
GATKReportTable positiveMeans = makeMeansTable("PositiveModelMeans", "Vector of annotation values to describe the (normalized) mean for each Gaussian in the positive model", annotationList, goodModel, formatString);
@ -570,8 +751,12 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
}
/**
 * Convenience overload that builds a vector table whose first column is named "Annotation".
 * All other arguments are passed through unchanged to the seven-argument overload.
 */
protected GATKReportTable makeVectorTable(final String tableName, final String tableDescription, final List<String> annotationList, final double[] perAnnotationValues, final String columnName, final String formatString) {
    final String defaultFirstColumn = "Annotation";
    return makeVectorTable(
            tableName,
            tableDescription,
            annotationList,
            perAnnotationValues,
            columnName,
            formatString,
            defaultFirstColumn);
}
protected GATKReportTable makeVectorTable(final String tableName, final String tableDescription, final List<String> annotationList, final double[] perAnnotationValues, final String columnName, final String formatString, final String firstColumn) {
GATKReportTable vectorTable = new GATKReportTable(tableName, tableDescription, annotationList.size(), GATKReportTable.TableSortingWay.DO_NOT_SORT);
vectorTable.addColumn("Annotation");
vectorTable.addColumn(firstColumn);
vectorTable.addColumn(columnName, formatString);
for (int i = 0; i < perAnnotationValues.length; i++) {
vectorTable.addRowIDMapping(annotationList.get(i), i, true);

View File

@ -150,7 +150,6 @@ public class GATKReportTable {
// read a data line
final String dataLine = reader.readLine();
final List<String> lineSplits = Arrays.asList(TextFormattingUtils.splitFixedWidth(dataLine, columnStarts));
underlyingData.add(new Object[nColumns]);
for ( int columnIndex = 0; columnIndex < nColumns; columnIndex++ ) {