scientific notation in model report and basic test

This commit is contained in:
Samuel Friedman 2017-05-03 18:37:52 -04:00
parent 4df71a3ca1
commit f5d133df87
5 changed files with 45 additions and 11 deletions

View File

@ -110,6 +110,11 @@ public class GaussianMixtureModel {
this.shrinkage = shrinkage;
this.dirichletParameter = dirichletParameter;
this.priorCounts = priorCounts;
for( final MultivariateGaussian gaussian : gaussians ) {
gaussian.hyperParameter_a = priorCounts;
gaussian.hyperParameter_b = shrinkage;
gaussian.hyperParameter_lambda = dirichletParameter;
}
empiricalMu = new double[numAnnotations];
empiricalSigma = new Matrix(numAnnotations, numAnnotations);
isModelReadyForEvaluation = false;

View File

@ -277,7 +277,7 @@ public class VariantDataManager {
}
}
logger.info( "Training with worst " + trainingData.size() + " scoring variants --> variants with LOD <= " + String.format("%.4f", VRAC.BAD_LOD_CUTOFF) + "." );
logger.info( "Selected worst " + trainingData.size() + " scoring variants --> variants with LOD <= " + String.format("%.4f", VRAC.BAD_LOD_CUTOFF) + "." );
return trainingData;
}

View File

@ -376,10 +376,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
}
final Set<VCFHeaderLine> hInfo = new HashSet<>();
final Set<VCFHeaderLine> hInfo = new HashSet<>();
ApplyRecalibration.addVQSRStandardHeaderLines(hInfo);
recalWriter.writeHeader( new VCFHeader(hInfo) );
@ -513,12 +510,11 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
dataManager.setData(reduceSum);
dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation
//final GaussianMixtureModel goodModel, badModel;
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
final List<VariantDatum> negativeTrainingData;
if (goodModel != null && badModel != null){ // GMMs were loaded from a file
// Keeping this to maintain reproducibility between runs with and without serialized GMMs
logger.info("Using serialized GMMs from file...");
engine.evaluateData(dataManager.getData(), goodModel, false);
negativeTrainingData = dataManager.selectWorstVariants();
} else { // Generate the GMMs from scratch
@ -527,12 +523,12 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
engine.evaluateData(dataManager.getData(), goodModel, false);
// Generate the negative model using the worst performing data and evaluate each variant contrastively
negativeTrainingData = dataManager.selectWorstVariants();
badModel = engine.generateModel(negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS));
if (badModel.failedToConverge || goodModel.failedToConverge) {
throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --minNumBadVariants 5000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example)."));
}
}
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
@ -589,7 +585,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
* @param numAnnotations Number of annotations, i.e. Dimension of the annotation space in which the Gaussians live
* @return a GaussianMixtureModel whose state reflects the state recorded in the tables.
*/
private GaussianMixtureModel GMMFromTables(final GATKReportTable muTable, final GATKReportTable sigmaTable, final GATKReportTable pmixTable, final int numAnnotations){
protected GaussianMixtureModel GMMFromTables(final GATKReportTable muTable, final GATKReportTable sigmaTable, final GATKReportTable pmixTable, final int numAnnotations){
List<MultivariateGaussian> gaussianList = new ArrayList<>();
int curAnnotation = 0;
@ -665,7 +661,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
}
protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List<String> annotationList) {
final String formatString = "%.8f";
final String formatString = "%.8E";
final GATKReport report = new GATKReport();
if (dataManager != null) { //for unit test

View File

@ -98,6 +98,7 @@ public class VariantRecalibratorEngine {
try {
model.precomputeDenominatorForEvaluation();
} catch( Exception e ) {
logger.warn("Model could not pre-compute denominators.");
model.failedToConverge = true;
return;
}
@ -107,6 +108,7 @@ public class VariantRecalibratorEngine {
for( final VariantDatum datum : data ) {
final double thisLod = evaluateDatum( datum, model );
if( Double.isNaN(thisLod) ) {
logger.warn("Evaluate datum returned a NaN.");
model.failedToConverge = true;
return;
}
@ -142,7 +144,7 @@ public class VariantRecalibratorEngine {
// Private Methods used for generating a GaussianMixtureModel
/////////////////////////////
private void variationalBayesExpectationMaximization( final GaussianMixtureModel model, final List<VariantDatum> data ) {
protected void variationalBayesExpectationMaximization(final GaussianMixtureModel model, final List<VariantDatum> data) {
model.initializeRandomModel( data, VRAC.NUM_KMEANS_ITERATIONS );

View File

@ -170,8 +170,20 @@ public class VariantRecalibratorModelOutputUnitTest {
Assert.assertEquals(badGaussian1.sigma.get(i,j), (Double)badSigma.get(i,annotationList.get(j)), epsilon);
}
}
// Now test model report reading
// Read the gaussian weighting tables
final GATKReportTable nPMixTable = report.getTable("BadGaussianPMix");
final GATKReportTable pPMixTable = report.getTable("GoodGaussianPMix");
GaussianMixtureModel goodModelFromFile = vqsr.GMMFromTables(goodMus, goodSigma, pPMixTable, annotationList.size());
GaussianMixtureModel badModelFromFile = vqsr.GMMFromTables(badMus, badSigma, nPMixTable, annotationList.size());
testGMMsForEquality(goodModel, goodModelFromFile, epsilon);
testGMMsForEquality(badModel, badModelFromFile, epsilon);
}
@Test
//This is tested separately to avoid setting up a VariantDataManager and populating it with fake data
public void testAnnotationNormalizationOutput() {
@ -211,4 +223,23 @@ public class VariantRecalibratorModelOutputUnitTest {
return returnString;
}
private void testGMMsForEquality(GaussianMixtureModel gmm1, GaussianMixtureModel gmm2, double epsilon){
Assert.assertEquals(gmm1.getModelGaussians().size(), gmm2.getModelGaussians().size(), 0);
for(int k = 0; k < gmm1.getModelGaussians().size(); k++) {
final MultivariateGaussian g = gmm1.getModelGaussians().get(k);
final MultivariateGaussian gFile = gmm2.getModelGaussians().get(k);
for(int i = 0; i < g.mu.length; i++){
Assert.assertEquals(g.mu[i], gFile.mu[i], epsilon);
}
for(int i = 0; i < g.sigma.getRowDimension(); i++) {
for (int j = 0; j < g.sigma.getColumnDimension(); j++) {
Assert.assertEquals(g.sigma.get(i, j), gFile.sigma.get(i, j), epsilon);
}
}
}
}
}