addressed review comments
This commit is contained in:
parent
85cb1c2810
commit
f55b932cfc
|
|
@ -320,5 +320,4 @@ public class GaussianMixtureModel {
|
||||||
protected List<MultivariateGaussian> getModelGaussians() {return Collections.unmodifiableList(gaussians);}
|
protected List<MultivariateGaussian> getModelGaussians() {return Collections.unmodifiableList(gaussians);}
|
||||||
|
|
||||||
protected int getNumAnnotations() {return empiricalMu.length;}
|
protected int getNumAnnotations() {return empiricalMu.length;}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
@ -490,57 +490,46 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
dataManager.setData(reduceSum);
|
dataManager.setData(reduceSum);
|
||||||
dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation
|
dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation
|
||||||
|
|
||||||
// Generate the positive model using the training data and evaluate each variant
|
|
||||||
final GaussianMixtureModel goodModel, badModel;
|
final GaussianMixtureModel goodModel, badModel;
|
||||||
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
|
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
|
||||||
final List<VariantDatum> negativeTrainingData;
|
final List<VariantDatum> negativeTrainingData;
|
||||||
|
|
||||||
File inputFile = new File(inputModel);
|
final File inputFile = new File(inputModel);
|
||||||
if (inputFile.exists()){ // Load GMM from a file
|
if (inputFile.exists()){ // Load GMM from a file
|
||||||
logger.info("Loading model from:"+inputModel);
|
logger.info("Loading model from:"+inputModel);
|
||||||
GATKReport reportIn = new GATKReport(inputFile);
|
final GATKReport reportIn = new GATKReport(inputFile);
|
||||||
|
|
||||||
// Read all the tables
|
// Read all the tables
|
||||||
GATKReportTable amTable = reportIn.getTable("AnnotationMeans");
|
final GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances");
|
||||||
GATKReportTable astdTable = reportIn.getTable("AnnotationStdevs");
|
final GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans");
|
||||||
|
final GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix");
|
||||||
|
final GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances");
|
||||||
|
final GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans");
|
||||||
|
final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
|
||||||
|
final int numAnnotations = dataManager.getMeanVector().length;
|
||||||
|
|
||||||
// Should have same number of means and standard deviations.
|
|
||||||
assert(amTable.getNumRows() == astdTable.getNumRows() );
|
|
||||||
|
|
||||||
GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances");
|
|
||||||
GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans");
|
|
||||||
GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix");
|
|
||||||
GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances");
|
|
||||||
GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans");
|
|
||||||
GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
|
|
||||||
|
|
||||||
int numAnnotations = amTable.getNumRows();
|
|
||||||
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations);
|
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations);
|
||||||
engine.evaluateData(dataManager.getData(), goodModel, false);
|
engine.evaluateData(dataManager.getData(), goodModel, false);
|
||||||
negativeTrainingData = dataManager.selectWorstVariants();
|
negativeTrainingData = dataManager.selectWorstVariants();
|
||||||
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
|
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
|
||||||
logger.info("Loaded GMM from file:" + inputModel);
|
|
||||||
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
|
|
||||||
engine.evaluateData(dataManager.getData(), badModel, true);
|
|
||||||
|
|
||||||
} else { // Generate the GMMs from scratch
|
} else { // Generate the GMMs from scratch
|
||||||
|
// Generate the positive model using the training data and evaluate each variant
|
||||||
goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS);
|
goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS);
|
||||||
//Utils.getRandomGenerator().setSeed(12878);
|
|
||||||
engine.evaluateData(dataManager.getData(), goodModel, false);
|
engine.evaluateData(dataManager.getData(), goodModel, false);
|
||||||
// Generate the negative model using the worst performing data and evaluate each variant contrastively
|
// Generate the negative model using the worst performing data and evaluate each variant contrastively
|
||||||
negativeTrainingData = dataManager.selectWorstVariants();
|
negativeTrainingData = dataManager.selectWorstVariants();
|
||||||
|
|
||||||
badModel = engine.generateModel(negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS));
|
badModel = engine.generateModel(negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS));
|
||||||
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
|
|
||||||
//Utils.getRandomGenerator().setSeed(12878);
|
|
||||||
engine.evaluateData(dataManager.getData(), badModel, true);
|
|
||||||
|
|
||||||
|
|
||||||
if (badModel.failedToConverge || goodModel.failedToConverge) {
|
if (badModel.failedToConverge || goodModel.failedToConverge) {
|
||||||
throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --minNumBadVariants 5000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example)."));
|
throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --minNumBadVariants 5000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example)."));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
|
||||||
|
engine.evaluateData(dataManager.getData(), badModel, true);
|
||||||
|
|
||||||
if (outputModel) {
|
if (outputModel) {
|
||||||
GATKReport report = writeModelReport(goodModel, badModel, USE_ANNOTATIONS);
|
GATKReport report = writeModelReport(goodModel, badModel, USE_ANNOTATIONS);
|
||||||
report.print(modelReport);
|
report.print(modelReport);
|
||||||
|
|
@ -592,7 +581,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
* @param numAnnotations Number of annotations, i.e. Dimension of the annotation space in which the Gaussians live
|
* @param numAnnotations Number of annotations, i.e. Dimension of the annotation space in which the Gaussians live
|
||||||
* @return a GaussianMixtureModel whose state reflects the state recorded in the tables.
|
* @return a GaussianMixtureModel whose state reflects the state recorded in the tables.
|
||||||
*/
|
*/
|
||||||
private GaussianMixtureModel GMMFromTables(GATKReportTable muTable, GATKReportTable sigmaTable, GATKReportTable pmixTable, int numAnnotations){
|
private GaussianMixtureModel GMMFromTables(final GATKReportTable muTable, final GATKReportTable sigmaTable, final GATKReportTable pmixTable, final int numAnnotations){
|
||||||
List<MultivariateGaussian> gaussianList = new ArrayList<>();
|
List<MultivariateGaussian> gaussianList = new ArrayList<>();
|
||||||
|
|
||||||
int curAnnotation = 0;
|
int curAnnotation = 0;
|
||||||
|
|
@ -642,7 +631,6 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
double[] stdVector = {};
|
double[] stdVector = {};
|
||||||
|
|
||||||
for (GATKReportColumn reportColumn : astdTable.getColumnInfo() ) {
|
for (GATKReportColumn reportColumn : astdTable.getColumnInfo() ) {
|
||||||
logger.info("Report column name is:" + reportColumn.getColumnName());
|
|
||||||
if (reportColumn.getColumnName().equals("Standarddeviation")) {
|
if (reportColumn.getColumnName().equals("Standarddeviation")) {
|
||||||
stdVector = new double[astdTable.getNumRows()];
|
stdVector = new double[astdTable.getNumRows()];
|
||||||
for (int row = 0; row < astdTable.getNumRows(); row++) {
|
for (int row = 0; row < astdTable.getNumRows(); row++) {
|
||||||
|
|
@ -663,7 +651,6 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
for (int row = 0; row < amTable.getNumRows(); row++) {
|
for (int row = 0; row < amTable.getNumRows(); row++) {
|
||||||
meanVector[row] = Double.parseDouble((String) amTable.get(row, reportColumn.getColumnName()));
|
meanVector[row] = Double.parseDouble((String) amTable.get(row, reportColumn.getColumnName()));
|
||||||
}
|
}
|
||||||
logger.info("Got mean Vector:" + Arrays.toString(meanVector));
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -690,7 +677,6 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
|
|
||||||
for( final MultivariateGaussian gaussian : goodModel.getModelGaussians() ) {
|
for( final MultivariateGaussian gaussian : goodModel.getModelGaussians() ) {
|
||||||
pMixtureLog10s[idx] = gaussian.pMixtureLog10;
|
pMixtureLog10s[idx] = gaussian.pMixtureLog10;
|
||||||
logger.info("Good normalize PMix log 10 is:" + Double.toString(gaussian.pMixtureLog10) );
|
|
||||||
gaussianStrings.add(Integer.toString(idx++) );
|
gaussianStrings.add(Integer.toString(idx++) );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -703,7 +689,6 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
|
|
||||||
for( final MultivariateGaussian gaussian : badModel.getModelGaussians() ) {
|
for( final MultivariateGaussian gaussian : badModel.getModelGaussians() ) {
|
||||||
pMixtureLog10sBad[idx] = gaussian.pMixtureLog10;
|
pMixtureLog10sBad[idx] = gaussian.pMixtureLog10;
|
||||||
logger.info("Bad normalize PMix log 10 is:" + Double.toString(gaussian.pMixtureLog10));
|
|
||||||
gaussianStrings.add(Integer.toString(idx++));
|
gaussianStrings.add(Integer.toString(idx++));
|
||||||
}
|
}
|
||||||
GATKReportTable badPMix = makeVectorTable("BadGaussianPMix", "Pmixture log 10 used to evaluate model", gaussianStrings, pMixtureLog10sBad, "pMixLog10", formatString, "Gaussian");
|
GATKReportTable badPMix = makeVectorTable("BadGaussianPMix", "Pmixture log 10 used to evaluate model", gaussianStrings, pMixtureLog10sBad, "pMixLog10", formatString, "Gaussian");
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue