addressed review comments

This commit is contained in:
Samuel Friedman 2017-04-27 17:16:57 -04:00
parent 85cb1c2810
commit f55b932cfc
2 changed files with 14 additions and 30 deletions

View File

@ -320,5 +320,4 @@ public class GaussianMixtureModel {
protected List<MultivariateGaussian> getModelGaussians() {return Collections.unmodifiableList(gaussians);}
protected int getNumAnnotations() {return empiricalMu.length;}
}

View File

@ -490,57 +490,46 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
dataManager.setData(reduceSum);
dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation
// Generate the positive model using the training data and evaluate each variant
final GaussianMixtureModel goodModel, badModel;
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
final List<VariantDatum> negativeTrainingData;
File inputFile = new File(inputModel);
final File inputFile = new File(inputModel);
if (inputFile.exists()){ // Load GMM from a file
logger.info("Loading model from:"+inputModel);
GATKReport reportIn = new GATKReport(inputFile);
final GATKReport reportIn = new GATKReport(inputFile);
// Read all the tables
GATKReportTable amTable = reportIn.getTable("AnnotationMeans");
GATKReportTable astdTable = reportIn.getTable("AnnotationStdevs");
final GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances");
final GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans");
final GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix");
final GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances");
final GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans");
final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
final int numAnnotations = dataManager.getMeanVector().length;
// Should have same number of means and standard deviations.
assert(amTable.getNumRows() == astdTable.getNumRows() );
GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances");
GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans");
GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix");
GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances");
GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans");
GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
int numAnnotations = amTable.getNumRows();
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations);
engine.evaluateData(dataManager.getData(), goodModel, false);
negativeTrainingData = dataManager.selectWorstVariants();
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
logger.info("Loaded GMM from file:" + inputModel);
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
engine.evaluateData(dataManager.getData(), badModel, true);
} else { // Generate the GMMs from scratch
// Generate the positive model using the training data and evaluate each variant
goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS);
//Utils.getRandomGenerator().setSeed(12878);
engine.evaluateData(dataManager.getData(), goodModel, false);
// Generate the negative model using the worst performing data and evaluate each variant contrastively
negativeTrainingData = dataManager.selectWorstVariants();
badModel = engine.generateModel(negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS));
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
//Utils.getRandomGenerator().setSeed(12878);
engine.evaluateData(dataManager.getData(), badModel, true);
if (badModel.failedToConverge || goodModel.failedToConverge) {
throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --minNumBadVariants 5000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example)."));
}
}
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
engine.evaluateData(dataManager.getData(), badModel, true);
if (outputModel) {
GATKReport report = writeModelReport(goodModel, badModel, USE_ANNOTATIONS);
report.print(modelReport);
@ -592,7 +581,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
* @param numAnnotations Number of annotations, i.e. Dimension of the annotation space in which the Gaussians live
* @return a GaussianMixtureModel whose state reflects the state recorded in the tables.
*/
private GaussianMixtureModel GMMFromTables(GATKReportTable muTable, GATKReportTable sigmaTable, GATKReportTable pmixTable, int numAnnotations){
private GaussianMixtureModel GMMFromTables(final GATKReportTable muTable, final GATKReportTable sigmaTable, final GATKReportTable pmixTable, final int numAnnotations){
List<MultivariateGaussian> gaussianList = new ArrayList<>();
int curAnnotation = 0;
@ -642,7 +631,6 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
double[] stdVector = {};
for (GATKReportColumn reportColumn : astdTable.getColumnInfo() ) {
logger.info("Report column name is:" + reportColumn.getColumnName());
if (reportColumn.getColumnName().equals("Standarddeviation")) {
stdVector = new double[astdTable.getNumRows()];
for (int row = 0; row < astdTable.getNumRows(); row++) {
@ -663,7 +651,6 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
for (int row = 0; row < amTable.getNumRows(); row++) {
meanVector[row] = Double.parseDouble((String) amTable.get(row, reportColumn.getColumnName()));
}
logger.info("Got mean Vector:" + Arrays.toString(meanVector));
}
}
@ -690,7 +677,6 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
for( final MultivariateGaussian gaussian : goodModel.getModelGaussians() ) {
pMixtureLog10s[idx] = gaussian.pMixtureLog10;
logger.info("Good normalize PMix log 10 is:" + Double.toString(gaussian.pMixtureLog10) );
gaussianStrings.add(Integer.toString(idx++) );
}
@ -703,7 +689,6 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
for( final MultivariateGaussian gaussian : badModel.getModelGaussians() ) {
pMixtureLog10sBad[idx] = gaussian.pMixtureLog10;
logger.info("Bad normalize PMix log 10 is:" + Double.toString(gaussian.pMixtureLog10));
gaussianStrings.add(Integer.toString(idx++));
}
GATKReportTable badPMix = makeVectorTable("BadGaussianPMix", "Pmixture log 10 used to evaluate model", gaussianStrings, pMixtureLog10sBad, "pMixLog10", formatString, "Gaussian");