move model file parsing to initialize

This commit is contained in:
Samuel Friedman 2017-05-01 15:52:45 -04:00
parent d06a6c7318
commit 1b4ac51048
1 changed files with 33 additions and 34 deletions

View File

@ -319,6 +319,8 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
private PrintStream tranchesStream; private PrintStream tranchesStream;
private final Set<String> ignoreInputFilterSet = new TreeSet<>(); private final Set<String> ignoreInputFilterSet = new TreeSet<>();
private final VariantRecalibratorEngine engine = new VariantRecalibratorEngine( VRAC ); private final VariantRecalibratorEngine engine = new VariantRecalibratorEngine( VRAC );
private GaussianMixtureModel goodModel = null;
private GaussianMixtureModel badModel = null;
//--------------------------------------------------------------------------------------------------------------- //---------------------------------------------------------------------------------------------------------------
// //
@ -356,6 +358,27 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
throw new UserException.CommandLineException( "No truth set found! Please provide sets of known polymorphic loci marked with the truth=true ROD binding tag. For example, -resource:hapmap,VCF,known=false,training=true,truth=true,prior=12.0 hapmapFile.vcf" ); throw new UserException.CommandLineException( "No truth set found! Please provide sets of known polymorphic loci marked with the truth=true ROD binding tag. For example, -resource:hapmap,VCF,known=false,training=true,truth=true,prior=12.0 hapmapFile.vcf" );
} }
final File inputFile = new File(inputModel);
if (inputFile.exists()) { // Load GMM from a file
logger.info("Loading model from:" + inputModel);
final GATKReport reportIn = new GATKReport(inputFile);
// Read all the tables
final GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances");
final GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans");
final GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix");
final GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances");
final GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans");
final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
final int numAnnotations = dataManager.getMeanVector().length;
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations);
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
}
final Set<VCFHeaderLine> hInfo = new HashSet<>(); final Set<VCFHeaderLine> hInfo = new HashSet<>();
ApplyRecalibration.addVQSRStandardHeaderLines(hInfo); ApplyRecalibration.addVQSRStandardHeaderLines(hInfo);
recalWriter.writeHeader( new VCFHeader(hInfo) ); recalWriter.writeHeader( new VCFHeader(hInfo) );
@ -490,29 +513,14 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
dataManager.setData(reduceSum); dataManager.setData(reduceSum);
dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation
final GaussianMixtureModel goodModel, badModel; //final GaussianMixtureModel goodModel, badModel;
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData(); final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
final List<VariantDatum> negativeTrainingData; final List<VariantDatum> negativeTrainingData;
final File inputFile = new File(inputModel); if (goodModel != null && badModel != null){ // GMMs were loaded from a file
if (inputFile.exists()){ // Load GMM from a file // Keeping this to maintain reproducibility between runs with and without serialized GMMs
logger.info("Loading model from:"+inputModel);
final GATKReport reportIn = new GATKReport(inputFile);
// Read all the tables
final GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances");
final GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans");
final GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix");
final GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances");
final GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans");
final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
final int numAnnotations = dataManager.getMeanVector().length;
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations);
engine.evaluateData(dataManager.getData(), goodModel, false); engine.evaluateData(dataManager.getData(), goodModel, false);
negativeTrainingData = dataManager.selectWorstVariants(); negativeTrainingData = dataManager.selectWorstVariants();
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
} else { // Generate the GMMs from scratch } else { // Generate the GMMs from scratch
// Generate the positive model using the training data and evaluate each variant // Generate the positive model using the training data and evaluate each variant
goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS); goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS);
@ -586,14 +594,13 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
int curAnnotation = 0; int curAnnotation = 0;
for (GATKReportColumn reportColumn : muTable.getColumnInfo() ) { for (GATKReportColumn reportColumn : muTable.getColumnInfo() ) {
logger.info("Report column name is:" + reportColumn.getColumnName());
if (!reportColumn.getColumnName().equals("Gaussian")) { if (!reportColumn.getColumnName().equals("Gaussian")) {
for (int row = 0; row < muTable.getNumRows(); row++) { for (int row = 0; row < muTable.getNumRows(); row++) {
if (gaussianList.size() <= row){ if (gaussianList.size() <= row){
MultivariateGaussian mg = new MultivariateGaussian(numAnnotations); MultivariateGaussian mg = new MultivariateGaussian(numAnnotations);
gaussianList.add(mg); gaussianList.add(mg);
} }
gaussianList.get(row).mu[curAnnotation] = reportObjectToDouble(muTable.get(row, reportColumn.getColumnName())); gaussianList.get(row).mu[curAnnotation] = Double.parseDouble( (String)muTable.get(row, reportColumn.getColumnName()));
} }
curAnnotation++; curAnnotation++;
} }
@ -602,7 +609,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
for (GATKReportColumn reportColumn : pmixTable.getColumnInfo() ) { for (GATKReportColumn reportColumn : pmixTable.getColumnInfo() ) {
if (reportColumn.getColumnName().equals("pMixLog10")) { if (reportColumn.getColumnName().equals("pMixLog10")) {
for (int row = 0; row < pmixTable.getNumRows(); row++) { for (int row = 0; row < pmixTable.getNumRows(); row++) {
gaussianList.get(row).pMixtureLog10 = reportObjectToDouble(pmixTable.get(row, reportColumn.getColumnName())); gaussianList.get(row).pMixtureLog10 = Double.parseDouble( (String)pmixTable.get(row, reportColumn.getColumnName()));
} }
} }
} }
@ -615,7 +622,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
for (int row = 0; row < sigmaTable.getNumRows(); row++) { for (int row = 0; row < sigmaTable.getNumRows(); row++) {
int curGaussian = row / numAnnotations; int curGaussian = row / numAnnotations;
int curI = row % numAnnotations; int curI = row % numAnnotations;
double curVal = reportObjectToDouble(sigmaTable.get(row, reportColumn.getColumnName())); double curVal = Double.parseDouble( (String)sigmaTable.get(row, reportColumn.getColumnName()));
gaussianList.get(curGaussian).sigma.set(curI, curJ, curVal); gaussianList.get(curGaussian).sigma.set(curI, curJ, curVal);
} }
@ -627,14 +634,6 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
} }
/**
 * Converts a single report-table cell value to a primitive {@code double}.
 *
 * <p>String cells are parsed with {@link Double#parseDouble(String)}. Every other cell is
 * treated as a boxed number and widened with {@link Number#doubleValue()}, so {@code Double},
 * {@code Integer}, {@code Long} and {@code Float} cells are all accepted.
 *
 * @param obj the cell value; expected to be a {@code String} or a {@code Number}
 * @return the cell value as a {@code double}
 * @throws NumberFormatException if {@code obj} is a String that does not hold a parsable double
 * @throws ClassCastException if {@code obj} is neither a String nor a Number
 * @throws NullPointerException if {@code obj} is null
 */
private double reportObjectToDouble(Object obj){
    if (obj instanceof String){
        return Double.parseDouble((String)obj);
    }
    // Cast to Number rather than Double: a hard (Double) cast rejects boxed Integer/Long/Float
    // cells with a ClassCastException even though they convert to double without loss of intent.
    return ((Number) obj).doubleValue();
}
private double[] getStandardDeviationsFromTable(GATKReportTable astdTable){ private double[] getStandardDeviationsFromTable(GATKReportTable astdTable){
double[] stdVector = {}; double[] stdVector = {};
@ -642,7 +641,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
if (reportColumn.getColumnName().equals("Standarddeviation")) { if (reportColumn.getColumnName().equals("Standarddeviation")) {
stdVector = new double[astdTable.getNumRows()]; stdVector = new double[astdTable.getNumRows()];
for (int row = 0; row < astdTable.getNumRows(); row++) { for (int row = 0; row < astdTable.getNumRows(); row++) {
stdVector[row] = reportObjectToDouble(astdTable.get(row, reportColumn.getColumnName())); stdVector[row] = Double.parseDouble( (String)astdTable.get(row, reportColumn.getColumnName()));
} }
} }
} }
@ -657,7 +656,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
if (reportColumn.getColumnName().equals("Mean")) { if (reportColumn.getColumnName().equals("Mean")) {
meanVector = new double[amTable.getNumRows()]; meanVector = new double[amTable.getNumRows()];
for (int row = 0; row < amTable.getNumRows(); row++) { for (int row = 0; row < amTable.getNumRows(); row++) {
meanVector[row] = reportObjectToDouble(amTable.get(row, reportColumn.getColumnName())); meanVector[row] = Double.parseDouble( (String)amTable.get(row, reportColumn.getColumnName()));
} }
} }
} }
@ -666,7 +665,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
} }
protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List<String> annotationList) { protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List<String> annotationList) {
final String formatString = "%.25f"; final String formatString = "%.8f";
final GATKReport report = new GATKReport(); final GATKReport report = new GATKReport();
if (dataManager != null) { //for unit test if (dataManager != null) { //for unit test