move model file parsing to initialize
This commit is contained in:
parent
d06a6c7318
commit
1b4ac51048
|
|
@ -319,6 +319,8 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
private PrintStream tranchesStream;
|
private PrintStream tranchesStream;
|
||||||
private final Set<String> ignoreInputFilterSet = new TreeSet<>();
|
private final Set<String> ignoreInputFilterSet = new TreeSet<>();
|
||||||
private final VariantRecalibratorEngine engine = new VariantRecalibratorEngine( VRAC );
|
private final VariantRecalibratorEngine engine = new VariantRecalibratorEngine( VRAC );
|
||||||
|
private GaussianMixtureModel goodModel = null;
|
||||||
|
private GaussianMixtureModel badModel = null;
|
||||||
|
|
||||||
//---------------------------------------------------------------------------------------------------------------
|
//---------------------------------------------------------------------------------------------------------------
|
||||||
//
|
//
|
||||||
|
|
@ -356,6 +358,27 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
throw new UserException.CommandLineException( "No truth set found! Please provide sets of known polymorphic loci marked with the truth=true ROD binding tag. For example, -resource:hapmap,VCF,known=false,training=true,truth=true,prior=12.0 hapmapFile.vcf" );
|
throw new UserException.CommandLineException( "No truth set found! Please provide sets of known polymorphic loci marked with the truth=true ROD binding tag. For example, -resource:hapmap,VCF,known=false,training=true,truth=true,prior=12.0 hapmapFile.vcf" );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
final File inputFile = new File(inputModel);
|
||||||
|
if (inputFile.exists()) { // Load GMM from a file
|
||||||
|
logger.info("Loading model from:" + inputModel);
|
||||||
|
final GATKReport reportIn = new GATKReport(inputFile);
|
||||||
|
|
||||||
|
// Read all the tables
|
||||||
|
final GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances");
|
||||||
|
final GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans");
|
||||||
|
final GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix");
|
||||||
|
final GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances");
|
||||||
|
final GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans");
|
||||||
|
final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
|
||||||
|
final int numAnnotations = dataManager.getMeanVector().length;
|
||||||
|
|
||||||
|
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations);
|
||||||
|
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
final Set<VCFHeaderLine> hInfo = new HashSet<>();
|
final Set<VCFHeaderLine> hInfo = new HashSet<>();
|
||||||
ApplyRecalibration.addVQSRStandardHeaderLines(hInfo);
|
ApplyRecalibration.addVQSRStandardHeaderLines(hInfo);
|
||||||
recalWriter.writeHeader( new VCFHeader(hInfo) );
|
recalWriter.writeHeader( new VCFHeader(hInfo) );
|
||||||
|
|
@ -490,29 +513,14 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
dataManager.setData(reduceSum);
|
dataManager.setData(reduceSum);
|
||||||
dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation
|
dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation
|
||||||
|
|
||||||
final GaussianMixtureModel goodModel, badModel;
|
//final GaussianMixtureModel goodModel, badModel;
|
||||||
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
|
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
|
||||||
final List<VariantDatum> negativeTrainingData;
|
final List<VariantDatum> negativeTrainingData;
|
||||||
|
|
||||||
final File inputFile = new File(inputModel);
|
if (goodModel != null && badModel != null){ // GMMs were loaded from a file
|
||||||
if (inputFile.exists()){ // Load GMM from a file
|
// Keeping this to maintain reproducibility between runs with and without serialized GMMs
|
||||||
logger.info("Loading model from:"+inputModel);
|
|
||||||
final GATKReport reportIn = new GATKReport(inputFile);
|
|
||||||
|
|
||||||
// Read all the tables
|
|
||||||
final GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances");
|
|
||||||
final GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans");
|
|
||||||
final GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix");
|
|
||||||
final GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances");
|
|
||||||
final GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans");
|
|
||||||
final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
|
|
||||||
final int numAnnotations = dataManager.getMeanVector().length;
|
|
||||||
|
|
||||||
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations);
|
|
||||||
engine.evaluateData(dataManager.getData(), goodModel, false);
|
engine.evaluateData(dataManager.getData(), goodModel, false);
|
||||||
negativeTrainingData = dataManager.selectWorstVariants();
|
negativeTrainingData = dataManager.selectWorstVariants();
|
||||||
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
|
|
||||||
|
|
||||||
} else { // Generate the GMMs from scratch
|
} else { // Generate the GMMs from scratch
|
||||||
// Generate the positive model using the training data and evaluate each variant
|
// Generate the positive model using the training data and evaluate each variant
|
||||||
goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS);
|
goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS);
|
||||||
|
|
@ -586,14 +594,13 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
|
|
||||||
int curAnnotation = 0;
|
int curAnnotation = 0;
|
||||||
for (GATKReportColumn reportColumn : muTable.getColumnInfo() ) {
|
for (GATKReportColumn reportColumn : muTable.getColumnInfo() ) {
|
||||||
logger.info("Report column name is:" + reportColumn.getColumnName());
|
|
||||||
if (!reportColumn.getColumnName().equals("Gaussian")) {
|
if (!reportColumn.getColumnName().equals("Gaussian")) {
|
||||||
for (int row = 0; row < muTable.getNumRows(); row++) {
|
for (int row = 0; row < muTable.getNumRows(); row++) {
|
||||||
if (gaussianList.size() <= row){
|
if (gaussianList.size() <= row){
|
||||||
MultivariateGaussian mg = new MultivariateGaussian(numAnnotations);
|
MultivariateGaussian mg = new MultivariateGaussian(numAnnotations);
|
||||||
gaussianList.add(mg);
|
gaussianList.add(mg);
|
||||||
}
|
}
|
||||||
gaussianList.get(row).mu[curAnnotation] = reportObjectToDouble(muTable.get(row, reportColumn.getColumnName()));
|
gaussianList.get(row).mu[curAnnotation] = Double.parseDouble( (String)muTable.get(row, reportColumn.getColumnName()));
|
||||||
}
|
}
|
||||||
curAnnotation++;
|
curAnnotation++;
|
||||||
}
|
}
|
||||||
|
|
@ -602,7 +609,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
for (GATKReportColumn reportColumn : pmixTable.getColumnInfo() ) {
|
for (GATKReportColumn reportColumn : pmixTable.getColumnInfo() ) {
|
||||||
if (reportColumn.getColumnName().equals("pMixLog10")) {
|
if (reportColumn.getColumnName().equals("pMixLog10")) {
|
||||||
for (int row = 0; row < pmixTable.getNumRows(); row++) {
|
for (int row = 0; row < pmixTable.getNumRows(); row++) {
|
||||||
gaussianList.get(row).pMixtureLog10 = reportObjectToDouble(pmixTable.get(row, reportColumn.getColumnName()));
|
gaussianList.get(row).pMixtureLog10 = Double.parseDouble( (String)pmixTable.get(row, reportColumn.getColumnName()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -615,7 +622,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
for (int row = 0; row < sigmaTable.getNumRows(); row++) {
|
for (int row = 0; row < sigmaTable.getNumRows(); row++) {
|
||||||
int curGaussian = row / numAnnotations;
|
int curGaussian = row / numAnnotations;
|
||||||
int curI = row % numAnnotations;
|
int curI = row % numAnnotations;
|
||||||
double curVal = reportObjectToDouble(sigmaTable.get(row, reportColumn.getColumnName()));
|
double curVal = Double.parseDouble( (String)sigmaTable.get(row, reportColumn.getColumnName()));
|
||||||
gaussianList.get(curGaussian).sigma.set(curI, curJ, curVal);
|
gaussianList.get(curGaussian).sigma.set(curI, curJ, curVal);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
@ -627,14 +634,6 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private double reportObjectToDouble(Object obj){
|
|
||||||
if (obj instanceof String){
|
|
||||||
return Double.parseDouble((String)obj);
|
|
||||||
} else {
|
|
||||||
return (Double) obj;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
private double[] getStandardDeviationsFromTable(GATKReportTable astdTable){
|
private double[] getStandardDeviationsFromTable(GATKReportTable astdTable){
|
||||||
double[] stdVector = {};
|
double[] stdVector = {};
|
||||||
|
|
||||||
|
|
@ -642,7 +641,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
if (reportColumn.getColumnName().equals("Standarddeviation")) {
|
if (reportColumn.getColumnName().equals("Standarddeviation")) {
|
||||||
stdVector = new double[astdTable.getNumRows()];
|
stdVector = new double[astdTable.getNumRows()];
|
||||||
for (int row = 0; row < astdTable.getNumRows(); row++) {
|
for (int row = 0; row < astdTable.getNumRows(); row++) {
|
||||||
stdVector[row] = reportObjectToDouble(astdTable.get(row, reportColumn.getColumnName()));
|
stdVector[row] = Double.parseDouble( (String)astdTable.get(row, reportColumn.getColumnName()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -657,7 +656,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
if (reportColumn.getColumnName().equals("Mean")) {
|
if (reportColumn.getColumnName().equals("Mean")) {
|
||||||
meanVector = new double[amTable.getNumRows()];
|
meanVector = new double[amTable.getNumRows()];
|
||||||
for (int row = 0; row < amTable.getNumRows(); row++) {
|
for (int row = 0; row < amTable.getNumRows(); row++) {
|
||||||
meanVector[row] = reportObjectToDouble(amTable.get(row, reportColumn.getColumnName()));
|
meanVector[row] = Double.parseDouble( (String)amTable.get(row, reportColumn.getColumnName()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -666,7 +665,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
}
|
}
|
||||||
|
|
||||||
protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List<String> annotationList) {
|
protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List<String> annotationList) {
|
||||||
final String formatString = "%.25f";
|
final String formatString = "%.8f";
|
||||||
final GATKReport report = new GATKReport();
|
final GATKReport report = new GATKReport();
|
||||||
|
|
||||||
if (dataManager != null) { //for unit test
|
if (dataManager != null) { //for unit test
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue