scientific notation in model report and basic test
This commit is contained in:
parent
4df71a3ca1
commit
f5d133df87
|
|
@ -110,6 +110,11 @@ public class GaussianMixtureModel {
|
|||
this.shrinkage = shrinkage;
|
||||
this.dirichletParameter = dirichletParameter;
|
||||
this.priorCounts = priorCounts;
|
||||
for( final MultivariateGaussian gaussian : gaussians ) {
|
||||
gaussian.hyperParameter_a = priorCounts;
|
||||
gaussian.hyperParameter_b = shrinkage;
|
||||
gaussian.hyperParameter_lambda = dirichletParameter;
|
||||
}
|
||||
empiricalMu = new double[numAnnotations];
|
||||
empiricalSigma = new Matrix(numAnnotations, numAnnotations);
|
||||
isModelReadyForEvaluation = false;
|
||||
|
|
|
|||
|
|
@ -277,7 +277,7 @@ public class VariantDataManager {
|
|||
}
|
||||
}
|
||||
|
||||
logger.info( "Training with worst " + trainingData.size() + " scoring variants --> variants with LOD <= " + String.format("%.4f", VRAC.BAD_LOD_CUTOFF) + "." );
|
||||
logger.info( "Selected worst " + trainingData.size() + " scoring variants --> variants with LOD <= " + String.format("%.4f", VRAC.BAD_LOD_CUTOFF) + "." );
|
||||
|
||||
return trainingData;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -376,10 +376,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
final Set<VCFHeaderLine> hInfo = new HashSet<>();
|
||||
final Set<VCFHeaderLine> hInfo = new HashSet<>();
|
||||
ApplyRecalibration.addVQSRStandardHeaderLines(hInfo);
|
||||
recalWriter.writeHeader( new VCFHeader(hInfo) );
|
||||
|
||||
|
|
@ -513,12 +510,11 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
dataManager.setData(reduceSum);
|
||||
dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation
|
||||
|
||||
//final GaussianMixtureModel goodModel, badModel;
|
||||
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
|
||||
final List<VariantDatum> negativeTrainingData;
|
||||
|
||||
if (goodModel != null && badModel != null){ // GMMs were loaded from a file
|
||||
// Keeping this to maintain reproducibility between runs with and without serialized GMMs
|
||||
logger.info("Using serialized GMMs from file...");
|
||||
engine.evaluateData(dataManager.getData(), goodModel, false);
|
||||
negativeTrainingData = dataManager.selectWorstVariants();
|
||||
} else { // Generate the GMMs from scratch
|
||||
|
|
@ -527,12 +523,12 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
engine.evaluateData(dataManager.getData(), goodModel, false);
|
||||
// Generate the negative model using the worst performing data and evaluate each variant contrastively
|
||||
negativeTrainingData = dataManager.selectWorstVariants();
|
||||
|
||||
badModel = engine.generateModel(negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS));
|
||||
|
||||
if (badModel.failedToConverge || goodModel.failedToConverge) {
|
||||
throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --minNumBadVariants 5000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example)."));
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
|
||||
|
|
@ -589,7 +585,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
* @param numAnnotations Number of annotations, i.e. Dimension of the annotation space in which the Gaussians live
|
||||
* @return a GaussianMixtureModel whose state reflects the state recorded in the tables.
|
||||
*/
|
||||
private GaussianMixtureModel GMMFromTables(final GATKReportTable muTable, final GATKReportTable sigmaTable, final GATKReportTable pmixTable, final int numAnnotations){
|
||||
protected GaussianMixtureModel GMMFromTables(final GATKReportTable muTable, final GATKReportTable sigmaTable, final GATKReportTable pmixTable, final int numAnnotations){
|
||||
List<MultivariateGaussian> gaussianList = new ArrayList<>();
|
||||
|
||||
int curAnnotation = 0;
|
||||
|
|
@ -665,7 +661,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
}
|
||||
|
||||
protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List<String> annotationList) {
|
||||
final String formatString = "%.8f";
|
||||
final String formatString = "%.8E";
|
||||
final GATKReport report = new GATKReport();
|
||||
|
||||
if (dataManager != null) { //for unit test
|
||||
|
|
|
|||
|
|
@ -98,6 +98,7 @@ public class VariantRecalibratorEngine {
|
|||
try {
|
||||
model.precomputeDenominatorForEvaluation();
|
||||
} catch( Exception e ) {
|
||||
logger.warn("Model could not pre-compute denominators.");
|
||||
model.failedToConverge = true;
|
||||
return;
|
||||
}
|
||||
|
|
@ -107,6 +108,7 @@ public class VariantRecalibratorEngine {
|
|||
for( final VariantDatum datum : data ) {
|
||||
final double thisLod = evaluateDatum( datum, model );
|
||||
if( Double.isNaN(thisLod) ) {
|
||||
logger.warn("Evaluate datum returned a NaN.");
|
||||
model.failedToConverge = true;
|
||||
return;
|
||||
}
|
||||
|
|
@ -142,7 +144,7 @@ public class VariantRecalibratorEngine {
|
|||
// Private Methods used for generating a GaussianMixtureModel
|
||||
/////////////////////////////
|
||||
|
||||
private void variationalBayesExpectationMaximization( final GaussianMixtureModel model, final List<VariantDatum> data ) {
|
||||
protected void variationalBayesExpectationMaximization(final GaussianMixtureModel model, final List<VariantDatum> data) {
|
||||
|
||||
model.initializeRandomModel( data, VRAC.NUM_KMEANS_ITERATIONS );
|
||||
|
||||
|
|
|
|||
|
|
@ -170,8 +170,20 @@ public class VariantRecalibratorModelOutputUnitTest {
|
|||
Assert.assertEquals(badGaussian1.sigma.get(i,j), (Double)badSigma.get(i,annotationList.get(j)), epsilon);
|
||||
}
|
||||
}
|
||||
|
||||
// Now test model report reading
|
||||
// Read the gaussian weighting tables
|
||||
final GATKReportTable nPMixTable = report.getTable("BadGaussianPMix");
|
||||
final GATKReportTable pPMixTable = report.getTable("GoodGaussianPMix");
|
||||
|
||||
GaussianMixtureModel goodModelFromFile = vqsr.GMMFromTables(goodMus, goodSigma, pPMixTable, annotationList.size());
|
||||
GaussianMixtureModel badModelFromFile = vqsr.GMMFromTables(badMus, badSigma, nPMixTable, annotationList.size());
|
||||
|
||||
testGMMsForEquality(goodModel, goodModelFromFile, epsilon);
|
||||
testGMMsForEquality(badModel, badModelFromFile, epsilon);
|
||||
}
|
||||
|
||||
|
||||
@Test
|
||||
//This is tested separately to avoid setting up a VariantDataManager and populating it with fake data
|
||||
public void testAnnotationNormalizationOutput() {
|
||||
|
|
@ -211,4 +223,23 @@ public class VariantRecalibratorModelOutputUnitTest {
|
|||
return returnString;
|
||||
}
|
||||
|
||||
private void testGMMsForEquality(GaussianMixtureModel gmm1, GaussianMixtureModel gmm2, double epsilon){
|
||||
Assert.assertEquals(gmm1.getModelGaussians().size(), gmm2.getModelGaussians().size(), 0);
|
||||
|
||||
for(int k = 0; k < gmm1.getModelGaussians().size(); k++) {
|
||||
final MultivariateGaussian g = gmm1.getModelGaussians().get(k);
|
||||
final MultivariateGaussian gFile = gmm2.getModelGaussians().get(k);
|
||||
|
||||
for(int i = 0; i < g.mu.length; i++){
|
||||
Assert.assertEquals(g.mu[i], gFile.mu[i], epsilon);
|
||||
}
|
||||
|
||||
for(int i = 0; i < g.sigma.getRowDimension(); i++) {
|
||||
for (int j = 0; j < g.sigma.getColumnDimension(); j++) {
|
||||
Assert.assertEquals(g.sigma.get(i, j), gFile.sigma.get(i, j), epsilon);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
}
|
||||
Loading…
Reference in New Issue