move model file parsing to initialize
This commit is contained in:
parent
d06a6c7318
commit
1b4ac51048
|
|
@ -319,6 +319,8 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
private PrintStream tranchesStream;
|
||||
private final Set<String> ignoreInputFilterSet = new TreeSet<>();
|
||||
private final VariantRecalibratorEngine engine = new VariantRecalibratorEngine( VRAC );
|
||||
private GaussianMixtureModel goodModel = null;
|
||||
private GaussianMixtureModel badModel = null;
|
||||
|
||||
//---------------------------------------------------------------------------------------------------------------
|
||||
//
|
||||
|
|
@ -356,7 +358,28 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
throw new UserException.CommandLineException( "No truth set found! Please provide sets of known polymorphic loci marked with the truth=true ROD binding tag. For example, -resource:hapmap,VCF,known=false,training=true,truth=true,prior=12.0 hapmapFile.vcf" );
|
||||
}
|
||||
|
||||
final Set<VCFHeaderLine> hInfo = new HashSet<>();
|
||||
final File inputFile = new File(inputModel);
|
||||
if (inputFile.exists()) { // Load GMM from a file
|
||||
logger.info("Loading model from:" + inputModel);
|
||||
final GATKReport reportIn = new GATKReport(inputFile);
|
||||
|
||||
// Read all the tables
|
||||
final GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances");
|
||||
final GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans");
|
||||
final GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix");
|
||||
final GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances");
|
||||
final GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans");
|
||||
final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
|
||||
final int numAnnotations = dataManager.getMeanVector().length;
|
||||
|
||||
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations);
|
||||
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
|
||||
}
|
||||
|
||||
|
||||
|
||||
|
||||
final Set<VCFHeaderLine> hInfo = new HashSet<>();
|
||||
ApplyRecalibration.addVQSRStandardHeaderLines(hInfo);
|
||||
recalWriter.writeHeader( new VCFHeader(hInfo) );
|
||||
|
||||
|
|
@ -490,29 +513,14 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
dataManager.setData(reduceSum);
|
||||
dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation
|
||||
|
||||
final GaussianMixtureModel goodModel, badModel;
|
||||
//final GaussianMixtureModel goodModel, badModel;
|
||||
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
|
||||
final List<VariantDatum> negativeTrainingData;
|
||||
|
||||
final File inputFile = new File(inputModel);
|
||||
if (inputFile.exists()){ // Load GMM from a file
|
||||
logger.info("Loading model from:"+inputModel);
|
||||
final GATKReport reportIn = new GATKReport(inputFile);
|
||||
|
||||
// Read all the tables
|
||||
final GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances");
|
||||
final GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans");
|
||||
final GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix");
|
||||
final GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances");
|
||||
final GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans");
|
||||
final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
|
||||
final int numAnnotations = dataManager.getMeanVector().length;
|
||||
|
||||
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations);
|
||||
if (goodModel != null && badModel != null){ // GMMs were loaded from a file
|
||||
// Keeping this to maintain reproducibility between runs with and without serialized GMMs
|
||||
engine.evaluateData(dataManager.getData(), goodModel, false);
|
||||
negativeTrainingData = dataManager.selectWorstVariants();
|
||||
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
|
||||
|
||||
} else { // Generate the GMMs from scratch
|
||||
// Generate the positive model using the training data and evaluate each variant
|
||||
goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS);
|
||||
|
|
@ -586,14 +594,13 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
|
||||
int curAnnotation = 0;
|
||||
for (GATKReportColumn reportColumn : muTable.getColumnInfo() ) {
|
||||
logger.info("Report column name is:" + reportColumn.getColumnName());
|
||||
if (!reportColumn.getColumnName().equals("Gaussian")) {
|
||||
for (int row = 0; row < muTable.getNumRows(); row++) {
|
||||
if (gaussianList.size() <= row){
|
||||
MultivariateGaussian mg = new MultivariateGaussian(numAnnotations);
|
||||
gaussianList.add(mg);
|
||||
}
|
||||
gaussianList.get(row).mu[curAnnotation] = reportObjectToDouble(muTable.get(row, reportColumn.getColumnName()));
|
||||
gaussianList.get(row).mu[curAnnotation] = Double.parseDouble( (String)muTable.get(row, reportColumn.getColumnName()));
|
||||
}
|
||||
curAnnotation++;
|
||||
}
|
||||
|
|
@ -602,7 +609,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
for (GATKReportColumn reportColumn : pmixTable.getColumnInfo() ) {
|
||||
if (reportColumn.getColumnName().equals("pMixLog10")) {
|
||||
for (int row = 0; row < pmixTable.getNumRows(); row++) {
|
||||
gaussianList.get(row).pMixtureLog10 = reportObjectToDouble(pmixTable.get(row, reportColumn.getColumnName()));
|
||||
gaussianList.get(row).pMixtureLog10 = Double.parseDouble( (String)pmixTable.get(row, reportColumn.getColumnName()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -615,7 +622,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
for (int row = 0; row < sigmaTable.getNumRows(); row++) {
|
||||
int curGaussian = row / numAnnotations;
|
||||
int curI = row % numAnnotations;
|
||||
double curVal = reportObjectToDouble(sigmaTable.get(row, reportColumn.getColumnName()));
|
||||
double curVal = Double.parseDouble( (String)sigmaTable.get(row, reportColumn.getColumnName()));
|
||||
gaussianList.get(curGaussian).sigma.set(curI, curJ, curVal);
|
||||
|
||||
}
|
||||
|
|
@ -627,14 +634,6 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
|
||||
}
|
||||
|
||||
private double reportObjectToDouble(Object obj){
|
||||
if (obj instanceof String){
|
||||
return Double.parseDouble((String)obj);
|
||||
} else {
|
||||
return (Double) obj;
|
||||
}
|
||||
}
|
||||
|
||||
private double[] getStandardDeviationsFromTable(GATKReportTable astdTable){
|
||||
double[] stdVector = {};
|
||||
|
||||
|
|
@ -642,7 +641,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
if (reportColumn.getColumnName().equals("Standarddeviation")) {
|
||||
stdVector = new double[astdTable.getNumRows()];
|
||||
for (int row = 0; row < astdTable.getNumRows(); row++) {
|
||||
stdVector[row] = reportObjectToDouble(astdTable.get(row, reportColumn.getColumnName()));
|
||||
stdVector[row] = Double.parseDouble( (String)astdTable.get(row, reportColumn.getColumnName()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -657,7 +656,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
if (reportColumn.getColumnName().equals("Mean")) {
|
||||
meanVector = new double[amTable.getNumRows()];
|
||||
for (int row = 0; row < amTable.getNumRows(); row++) {
|
||||
meanVector[row] = reportObjectToDouble(amTable.get(row, reportColumn.getColumnName()));
|
||||
meanVector[row] = Double.parseDouble( (String)amTable.get(row, reportColumn.getColumnName()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -666,7 +665,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
}
|
||||
|
||||
protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List<String> annotationList) {
|
||||
final String formatString = "%.25f";
|
||||
final String formatString = "%.8f";
|
||||
final GATKReport report = new GATKReport();
|
||||
|
||||
if (dataManager != null) { //for unit test
|
||||
|
|
|
|||
Loading…
Reference in New Issue