small code cleanup

This commit is contained in:
Samuel Friedman 2017-04-26 16:01:16 -04:00
parent a8f70c891f
commit 57c064eaa3
1 changed file with 46 additions and 68 deletions

View File

@ -280,7 +280,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
*/
@Argument(fullName="output_model", shortName = "outputModel", doc="If specified, the variant recalibrator will output the VQSR model fit to the file specified by -modelFile or to stdout", required=false)
private boolean outputModel = false;
@Argument(fullName="input_model", shortName = "inputModel", doc="If specified, the variant recalibrator will read the VQSR model from the file specified by -modelFile", required=false)
@Argument(fullName="input_model", shortName = "inputModel", doc="If specified, the variant recalibrator will read the VQSR model from this file path.", required=false)
private String inputModel = "";
@Output(fullName="model_file", shortName = "modelFile", doc="A GATKReport containing the positive and negative model fits", required=false)
protected PrintStream modelReport = null;
@ -497,10 +497,16 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
File inputFile = new File(inputModel);
if (inputFile.exists()){ // Load GMM from a file
logger.info("Loading model from:"+inputModel);
GATKReport reportIn = new GATKReport(inputFile);
// Read all the tables
GATKReportTable amTable = reportIn.getTable("AnnotationMeans");
GATKReportTable astdTable = reportIn.getTable("AnnotationStdevs");
// Should have same number of means and standard deviations.
assert(amTable.getNumRows() == astdTable.getNumRows() );
GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances");
GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans");
GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix");
@ -508,41 +514,13 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans");
GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
double[] meanVector;
double[] stdVector;
int numAnnotations = 0;
for (GATKReportColumn reportColumn : amTable.getColumnInfo() ) {
if (reportColumn.getColumnName().equals("Mean")) {
meanVector = new double[amTable.getNumRows()];
numAnnotations = amTable.getNumRows();
for (int row = 0; row < amTable.getNumRows(); row++) {
meanVector[row] = (double) amTable.get(row, reportColumn.getColumnName());
}
logger.info("Got mean Vector:" + Arrays.toString(meanVector));
}
}
for (GATKReportColumn reportColumn : astdTable.getColumnInfo() ) {
logger.info("Report column name is:" + reportColumn.getColumnName());
if (reportColumn.getColumnName().equals("Standarddeviation")) {
stdVector = new double[astdTable.getNumRows()];
for (int row = 0; row < astdTable.getNumRows(); row++) {
stdVector[row] = (double) astdTable.get(row, reportColumn.getColumnName());
}
}
}
int numAnnotations = amTable.getNumRows();
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations);
//Utils.getRandomGenerator().setSeed(12878);
engine.evaluateData(dataManager.getData(), goodModel, false);
negativeTrainingData = dataManager.selectWorstVariants();
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
logger.info("Loaded GMM from file:" + inputModel);
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
//Utils.getRandomGenerator().setSeed(12878);
engine.evaluateData(dataManager.getData(), badModel, true);
} else { // Generate the GMMs from scratch
@ -568,9 +546,6 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
report.print(modelReport);
}
boolean writeFeatures = true;
if (writeFeatures) writeFeaturesFiles(positiveTrainingData, negativeTrainingData);
engine.calculateWorstPerformingAnnotation(dataManager.getData(), goodModel, badModel);
// Find the VQSLOD cutoff values which correspond to the various tranches of calls requested by the user
@ -609,7 +584,15 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
}
}
public GaussianMixtureModel GMMFromTables(GATKReportTable muTable, GATKReportTable sigmaTable, GATKReportTable pmixTable, int numAnnotations){
/**
* Rebuild a Gaussian Mixture Model from the Gaussian means and covariances stored in GATKReportTables.
* @param muTable Table of Gaussian means
* @param sigmaTable Table of Gaussian covariances
* @param pmixTable Table of pMixLog10 values
* @param numAnnotations Number of annotations, i.e. the dimension of the annotation space in which the Gaussians live
* @return a GaussianMixtureModel whose state reflects the state recorded in the tables.
*/
private GaussianMixtureModel GMMFromTables(GATKReportTable muTable, GATKReportTable sigmaTable, GATKReportTable pmixTable, int numAnnotations){
List<MultivariateGaussian> gaussianList = new ArrayList<>();
int curAnnotation = 0;
@ -621,7 +604,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
MultivariateGaussian mg = new MultivariateGaussian(numAnnotations);
gaussianList.add(mg);
}
gaussianList.get(row).mu[curAnnotation] = (double) muTable.get(row, reportColumn.getColumnName());
gaussianList.get(row).mu[curAnnotation] = Double.parseDouble((String)muTable.get(row, reportColumn.getColumnName()));
}
curAnnotation++;
}
@ -630,7 +613,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
for (GATKReportColumn reportColumn : pmixTable.getColumnInfo() ) {
if (reportColumn.getColumnName().equals("pMixLog10")) {
for (int row = 0; row < pmixTable.getNumRows(); row++) {
gaussianList.get(row).pMixtureLog10 = (double) pmixTable.get(row, reportColumn.getColumnName());
gaussianList.get(row).pMixtureLog10 = Double.parseDouble((String)pmixTable.get(row, reportColumn.getColumnName()));
}
}
}
@ -643,7 +626,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
for (int row = 0; row < sigmaTable.getNumRows(); row++) {
int curGaussian = row / numAnnotations;
int curI = row % numAnnotations;
double curVal = (double) sigmaTable.get(row, reportColumn.getColumnName());
double curVal = Double.parseDouble((String)sigmaTable.get(row, reportColumn.getColumnName()));
gaussianList.get(curGaussian).sigma.set(curI, curJ, curVal);
}
@ -655,41 +638,36 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
}
private void writeFeaturesFiles(List<VariantDatum> positiveTrainingData, List<VariantDatum> negativeTrainingData){
//Begin Sam Hacking
try {
File file = new File("/Users/sam/data/haploid_features.txt");
file.createNewFile();
File badFile = new File("/Users/sam/data/haploid_bad_features.txt");
badFile.createNewFile();
private double[] getStandardDeviationsFromTable(GATKReportTable astdTable){
double[] stdVector = {};
FileWriter fw = new FileWriter(file.getAbsoluteFile());
BufferedWriter bw = new BufferedWriter(fw);
for(int jj = 0; jj < positiveTrainingData.size(); jj++){
VariantDatum v = positiveTrainingData.get(jj);
for(int kk = 0; kk < v.annotations.length; kk++){
bw.write(Double.toString(v.annotations[kk]));
bw.write(" ");
for (GATKReportColumn reportColumn : astdTable.getColumnInfo() ) {
logger.info("Report column name is:" + reportColumn.getColumnName());
if (reportColumn.getColumnName().equals("Standarddeviation")) {
stdVector = new double[astdTable.getNumRows()];
for (int row = 0; row < astdTable.getNumRows(); row++) {
stdVector[row] = Double.parseDouble((String) astdTable.get(row, reportColumn.getColumnName()));
}
bw.write("\n");
}
bw.close();
fw = new FileWriter(badFile.getAbsoluteFile());
bw = new BufferedWriter(fw);
for(int jj = 0; jj < negativeTrainingData.size(); jj++){
VariantDatum v = negativeTrainingData.get(jj);
for(int kk = 0; kk < v.annotations.length; kk++){
bw.write(Double.toString(v.annotations[kk]));
bw.write(" ");
}
bw.write("\n");
}
bw.close();
}catch(IOException e){
e.printStackTrace();
}
// End Sam Hacking
return stdVector;
}
/**
 * Reads the values of the "Mean" column of a report table into an array.
 *
 * Each cell may be either a {@link Number} or a {@link String} holding a decimal value; both are
 * accepted so the method does not depend on whether the table's cells were parsed into numeric
 * objects or kept as text.
 *
 * @param amTable table of annotation means; its "Mean" column supplies one value per row
 * @return one value per table row, in row order, or an empty array if no column is named "Mean"
 * @throws NumberFormatException if a textual cell does not hold a parsable double
 * @throws ClassCastException if a cell is neither a Number nor a String
 */
private double[] getMeansFromTable(GATKReportTable amTable){
    double[] meanVector = {};

    for (GATKReportColumn reportColumn : amTable.getColumnInfo() ) {
        final String columnName = reportColumn.getColumnName();
        if (columnName.equals("Mean")) {
            final int numRows = amTable.getNumRows();
            meanVector = new double[numRows];
            for (int row = 0; row < numRows; row++) {
                final Object cell = amTable.get(row, columnName);
                // Numeric cells are used as-is; anything else takes the original String path,
                // so String and null cells behave exactly as before.
                meanVector[row] = (cell instanceof Number)
                        ? ((Number) cell).doubleValue()
                        : Double.parseDouble((String) cell);
            }
            logger.info("Got mean Vector:" + Arrays.toString(meanVector));
        }
    }
    return meanVector;
}
protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List<String> annotationList) {