addressed review comments
This commit is contained in:
parent
85cb1c2810
commit
f55b932cfc
|
|
@ -320,5 +320,4 @@ public class GaussianMixtureModel {
|
|||
protected List<MultivariateGaussian> getModelGaussians() {return Collections.unmodifiableList(gaussians);}
|
||||
|
||||
protected int getNumAnnotations() {return empiricalMu.length;}
|
||||
|
||||
}
|
||||
|
|
@ -490,57 +490,46 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
dataManager.setData(reduceSum);
|
||||
dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation
|
||||
|
||||
// Generate the positive model using the training data and evaluate each variant
|
||||
final GaussianMixtureModel goodModel, badModel;
|
||||
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
|
||||
final List<VariantDatum> negativeTrainingData;
|
||||
|
||||
File inputFile = new File(inputModel);
|
||||
final File inputFile = new File(inputModel);
|
||||
if (inputFile.exists()){ // Load GMM from a file
|
||||
logger.info("Loading model from:"+inputModel);
|
||||
GATKReport reportIn = new GATKReport(inputFile);
|
||||
final GATKReport reportIn = new GATKReport(inputFile);
|
||||
|
||||
// Read all the tables
|
||||
GATKReportTable amTable = reportIn.getTable("AnnotationMeans");
|
||||
GATKReportTable astdTable = reportIn.getTable("AnnotationStdevs");
|
||||
final GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances");
|
||||
final GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans");
|
||||
final GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix");
|
||||
final GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances");
|
||||
final GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans");
|
||||
final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
|
||||
final int numAnnotations = dataManager.getMeanVector().length;
|
||||
|
||||
// Should have same number of means and standard deviations.
|
||||
assert(amTable.getNumRows() == astdTable.getNumRows() );
|
||||
|
||||
GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances");
|
||||
GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans");
|
||||
GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix");
|
||||
GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances");
|
||||
GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans");
|
||||
GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
|
||||
|
||||
int numAnnotations = amTable.getNumRows();
|
||||
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations);
|
||||
engine.evaluateData(dataManager.getData(), goodModel, false);
|
||||
negativeTrainingData = dataManager.selectWorstVariants();
|
||||
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
|
||||
logger.info("Loaded GMM from file:" + inputModel);
|
||||
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
|
||||
engine.evaluateData(dataManager.getData(), badModel, true);
|
||||
|
||||
} else { // Generate the GMMs from scratch
|
||||
// Generate the positive model using the training data and evaluate each variant
|
||||
goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS);
|
||||
//Utils.getRandomGenerator().setSeed(12878);
|
||||
engine.evaluateData(dataManager.getData(), goodModel, false);
|
||||
// Generate the negative model using the worst performing data and evaluate each variant contrastively
|
||||
negativeTrainingData = dataManager.selectWorstVariants();
|
||||
|
||||
badModel = engine.generateModel(negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS));
|
||||
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
|
||||
//Utils.getRandomGenerator().setSeed(12878);
|
||||
engine.evaluateData(dataManager.getData(), badModel, true);
|
||||
|
||||
|
||||
if (badModel.failedToConverge || goodModel.failedToConverge) {
|
||||
throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --minNumBadVariants 5000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example)."));
|
||||
}
|
||||
}
|
||||
|
||||
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
|
||||
engine.evaluateData(dataManager.getData(), badModel, true);
|
||||
|
||||
if (outputModel) {
|
||||
GATKReport report = writeModelReport(goodModel, badModel, USE_ANNOTATIONS);
|
||||
report.print(modelReport);
|
||||
|
|
@ -592,7 +581,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
* @param numAnnotations Number of annotations, i.e. Dimension of the annotation space in which the Gaussians live
|
||||
* @return a GaussianMixtureModel whose state reflects the state recorded in the tables.
|
||||
*/
|
||||
private GaussianMixtureModel GMMFromTables(GATKReportTable muTable, GATKReportTable sigmaTable, GATKReportTable pmixTable, int numAnnotations){
|
||||
private GaussianMixtureModel GMMFromTables(final GATKReportTable muTable, final GATKReportTable sigmaTable, final GATKReportTable pmixTable, final int numAnnotations){
|
||||
List<MultivariateGaussian> gaussianList = new ArrayList<>();
|
||||
|
||||
int curAnnotation = 0;
|
||||
|
|
@ -642,7 +631,6 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
double[] stdVector = {};
|
||||
|
||||
for (GATKReportColumn reportColumn : astdTable.getColumnInfo() ) {
|
||||
logger.info("Report column name is:" + reportColumn.getColumnName());
|
||||
if (reportColumn.getColumnName().equals("Standarddeviation")) {
|
||||
stdVector = new double[astdTable.getNumRows()];
|
||||
for (int row = 0; row < astdTable.getNumRows(); row++) {
|
||||
|
|
@ -663,7 +651,6 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
for (int row = 0; row < amTable.getNumRows(); row++) {
|
||||
meanVector[row] = Double.parseDouble((String) amTable.get(row, reportColumn.getColumnName()));
|
||||
}
|
||||
logger.info("Got mean Vector:" + Arrays.toString(meanVector));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -690,7 +677,6 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
|
||||
for( final MultivariateGaussian gaussian : goodModel.getModelGaussians() ) {
|
||||
pMixtureLog10s[idx] = gaussian.pMixtureLog10;
|
||||
logger.info("Good normalize PMix log 10 is:" + Double.toString(gaussian.pMixtureLog10) );
|
||||
gaussianStrings.add(Integer.toString(idx++) );
|
||||
}
|
||||
|
||||
|
|
@ -703,7 +689,6 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
|
||||
for( final MultivariateGaussian gaussian : badModel.getModelGaussians() ) {
|
||||
pMixtureLog10sBad[idx] = gaussian.pMixtureLog10;
|
||||
logger.info("Bad normalize PMix log 10 is:" + Double.toString(gaussian.pMixtureLog10));
|
||||
gaussianStrings.add(Integer.toString(idx++));
|
||||
}
|
||||
GATKReportTable badPMix = makeVectorTable("BadGaussianPMix", "Pmixture log 10 used to evaluate model", gaussianStrings, pMixtureLog10sBad, "pMixLog10", formatString, "Gaussian");
|
||||
|
|
|
|||
Loading…
Reference in New Issue