small code cleanup
This commit is contained in:
parent
a8f70c891f
commit
57c064eaa3
|
|
@ -280,7 +280,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
*/
|
||||
@Argument(fullName="output_model", shortName = "outputModel", doc="If specified, the variant recalibrator will output the VQSR model fit to the file specified by -modelFile or to stdout", required=false)
|
||||
private boolean outputModel = false;
|
||||
@Argument(fullName="input_model", shortName = "inputModel", doc="If specified, the variant recalibrator will read the VQSR model from the file specified by -modelFile", required=false)
|
||||
@Argument(fullName="input_model", shortName = "inputModel", doc="If specified, the variant recalibrator will read the VQSR model from this file path.", required=false)
|
||||
private String inputModel = "";
|
||||
@Output(fullName="model_file", shortName = "modelFile", doc="A GATKReport containing the positive and negative model fits", required=false)
|
||||
protected PrintStream modelReport = null;
|
||||
|
|
@ -497,10 +497,16 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
|
||||
File inputFile = new File(inputModel);
|
||||
if (inputFile.exists()){ // Load GMM from a file
|
||||
logger.info("Loading model from:"+inputModel);
|
||||
GATKReport reportIn = new GATKReport(inputFile);
|
||||
|
||||
// Read all the tables
|
||||
GATKReportTable amTable = reportIn.getTable("AnnotationMeans");
|
||||
GATKReportTable astdTable = reportIn.getTable("AnnotationStdevs");
|
||||
|
||||
// Should have same number of means and standard deviations.
|
||||
assert(amTable.getNumRows() == astdTable.getNumRows() );
|
||||
|
||||
GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances");
|
||||
GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans");
|
||||
GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix");
|
||||
|
|
@ -508,41 +514,13 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans");
|
||||
GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
|
||||
|
||||
double[] meanVector;
|
||||
double[] stdVector;
|
||||
int numAnnotations = 0;
|
||||
|
||||
for (GATKReportColumn reportColumn : amTable.getColumnInfo() ) {
|
||||
if (reportColumn.getColumnName().equals("Mean")) {
|
||||
meanVector = new double[amTable.getNumRows()];
|
||||
numAnnotations = amTable.getNumRows();
|
||||
for (int row = 0; row < amTable.getNumRows(); row++) {
|
||||
meanVector[row] = (double) amTable.get(row, reportColumn.getColumnName());
|
||||
}
|
||||
logger.info("Got mean Vector:" + Arrays.toString(meanVector));
|
||||
}
|
||||
}
|
||||
|
||||
for (GATKReportColumn reportColumn : astdTable.getColumnInfo() ) {
|
||||
logger.info("Report column name is:" + reportColumn.getColumnName());
|
||||
if (reportColumn.getColumnName().equals("Standarddeviation")) {
|
||||
stdVector = new double[astdTable.getNumRows()];
|
||||
|
||||
for (int row = 0; row < astdTable.getNumRows(); row++) {
|
||||
stdVector[row] = (double) astdTable.get(row, reportColumn.getColumnName());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
int numAnnotations = amTable.getNumRows();
|
||||
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations);
|
||||
//Utils.getRandomGenerator().setSeed(12878);
|
||||
engine.evaluateData(dataManager.getData(), goodModel, false);
|
||||
negativeTrainingData = dataManager.selectWorstVariants();
|
||||
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
|
||||
logger.info("Loaded GMM from file:" + inputModel);
|
||||
|
||||
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
|
||||
//Utils.getRandomGenerator().setSeed(12878);
|
||||
engine.evaluateData(dataManager.getData(), badModel, true);
|
||||
|
||||
} else { // Generate the GMMs from scratch
|
||||
|
|
@ -568,9 +546,6 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
report.print(modelReport);
|
||||
}
|
||||
|
||||
boolean writeFeatures = true;
|
||||
if (writeFeatures) writeFeaturesFiles(positiveTrainingData, negativeTrainingData);
|
||||
|
||||
engine.calculateWorstPerformingAnnotation(dataManager.getData(), goodModel, badModel);
|
||||
|
||||
// Find the VQSLOD cutoff values which correspond to the various tranches of calls requested by the user
|
||||
|
|
@ -609,7 +584,15 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
}
|
||||
}
|
||||
|
||||
public GaussianMixtureModel GMMFromTables(GATKReportTable muTable, GATKReportTable sigmaTable, GATKReportTable pmixTable, int numAnnotations){
|
||||
/**
|
||||
* Rebuild a Gaussian Mixture Model from gaussian means and co-variates stored in a GATKReportTables
|
||||
* @param muTable Table of Gaussian means
|
||||
* @param sigmaTable Table of Gaussian co-variates
|
||||
* @param pmixTable Table of PMixLog10 values
|
||||
* @param numAnnotations Number of annotations, i.e. Dimension of the annotation space in which the Gaussians live
|
||||
* @return a GaussianMixtureModel whose state reflects the state recorded in the tables.
|
||||
*/
|
||||
private GaussianMixtureModel GMMFromTables(GATKReportTable muTable, GATKReportTable sigmaTable, GATKReportTable pmixTable, int numAnnotations){
|
||||
List<MultivariateGaussian> gaussianList = new ArrayList<>();
|
||||
|
||||
int curAnnotation = 0;
|
||||
|
|
@ -621,7 +604,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
MultivariateGaussian mg = new MultivariateGaussian(numAnnotations);
|
||||
gaussianList.add(mg);
|
||||
}
|
||||
gaussianList.get(row).mu[curAnnotation] = (double) muTable.get(row, reportColumn.getColumnName());
|
||||
gaussianList.get(row).mu[curAnnotation] = Double.parseDouble((String)muTable.get(row, reportColumn.getColumnName()));
|
||||
}
|
||||
curAnnotation++;
|
||||
}
|
||||
|
|
@ -630,7 +613,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
for (GATKReportColumn reportColumn : pmixTable.getColumnInfo() ) {
|
||||
if (reportColumn.getColumnName().equals("pMixLog10")) {
|
||||
for (int row = 0; row < pmixTable.getNumRows(); row++) {
|
||||
gaussianList.get(row).pMixtureLog10 = (double) pmixTable.get(row, reportColumn.getColumnName());
|
||||
gaussianList.get(row).pMixtureLog10 = Double.parseDouble((String)pmixTable.get(row, reportColumn.getColumnName()));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
@ -643,7 +626,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
for (int row = 0; row < sigmaTable.getNumRows(); row++) {
|
||||
int curGaussian = row / numAnnotations;
|
||||
int curI = row % numAnnotations;
|
||||
double curVal = (double) sigmaTable.get(row, reportColumn.getColumnName());
|
||||
double curVal = Double.parseDouble((String)sigmaTable.get(row, reportColumn.getColumnName()));
|
||||
gaussianList.get(curGaussian).sigma.set(curI, curJ, curVal);
|
||||
|
||||
}
|
||||
|
|
@ -655,41 +638,36 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
|
||||
}
|
||||
|
||||
private void writeFeaturesFiles(List<VariantDatum> positiveTrainingData, List<VariantDatum> negativeTrainingData){
|
||||
//Begin Sam Hacking
|
||||
try {
|
||||
File file = new File("/Users/sam/data/haploid_features.txt");
|
||||
file.createNewFile();
|
||||
File badFile = new File("/Users/sam/data/haploid_bad_features.txt");
|
||||
badFile.createNewFile();
|
||||
private double[] getStandardDeviationsFromTable(GATKReportTable astdTable){
|
||||
double[] stdVector = {};
|
||||
|
||||
FileWriter fw = new FileWriter(file.getAbsoluteFile());
|
||||
BufferedWriter bw = new BufferedWriter(fw);
|
||||
for(int jj = 0; jj < positiveTrainingData.size(); jj++){
|
||||
VariantDatum v = positiveTrainingData.get(jj);
|
||||
for(int kk = 0; kk < v.annotations.length; kk++){
|
||||
bw.write(Double.toString(v.annotations[kk]));
|
||||
bw.write(" ");
|
||||
for (GATKReportColumn reportColumn : astdTable.getColumnInfo() ) {
|
||||
logger.info("Report column name is:" + reportColumn.getColumnName());
|
||||
if (reportColumn.getColumnName().equals("Standarddeviation")) {
|
||||
stdVector = new double[astdTable.getNumRows()];
|
||||
for (int row = 0; row < astdTable.getNumRows(); row++) {
|
||||
stdVector[row] = Double.parseDouble((String) astdTable.get(row, reportColumn.getColumnName()));
|
||||
}
|
||||
bw.write("\n");
|
||||
}
|
||||
bw.close();
|
||||
|
||||
fw = new FileWriter(badFile.getAbsoluteFile());
|
||||
bw = new BufferedWriter(fw);
|
||||
for(int jj = 0; jj < negativeTrainingData.size(); jj++){
|
||||
VariantDatum v = negativeTrainingData.get(jj);
|
||||
for(int kk = 0; kk < v.annotations.length; kk++){
|
||||
bw.write(Double.toString(v.annotations[kk]));
|
||||
bw.write(" ");
|
||||
}
|
||||
bw.write("\n");
|
||||
}
|
||||
bw.close();
|
||||
}catch(IOException e){
|
||||
e.printStackTrace();
|
||||
}
|
||||
// End Sam Hacking
|
||||
|
||||
return stdVector;
|
||||
}
|
||||
|
||||
private double[] getMeansFromTable(GATKReportTable amTable){
|
||||
double[] meanVector = {};
|
||||
|
||||
for (GATKReportColumn reportColumn : amTable.getColumnInfo() ) {
|
||||
if (reportColumn.getColumnName().equals("Mean")) {
|
||||
meanVector = new double[amTable.getNumRows()];
|
||||
for (int row = 0; row < amTable.getNumRows(); row++) {
|
||||
meanVector[row] = Double.parseDouble((String) amTable.get(row, reportColumn.getColumnName()));
|
||||
}
|
||||
logger.info("Got mean Vector:" + Arrays.toString(meanVector));
|
||||
}
|
||||
}
|
||||
|
||||
return meanVector;
|
||||
}
|
||||
|
||||
protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List<String> annotationList) {
|
||||
|
|
|
|||
Loading…
Reference in New Issue