small code cleanup
This commit is contained in:
parent
a8f70c891f
commit
57c064eaa3
|
|
@ -280,7 +280,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
*/
|
*/
|
||||||
@Argument(fullName="output_model", shortName = "outputModel", doc="If specified, the variant recalibrator will output the VQSR model fit to the file specified by -modelFile or to stdout", required=false)
|
@Argument(fullName="output_model", shortName = "outputModel", doc="If specified, the variant recalibrator will output the VQSR model fit to the file specified by -modelFile or to stdout", required=false)
|
||||||
private boolean outputModel = false;
|
private boolean outputModel = false;
|
||||||
@Argument(fullName="input_model", shortName = "inputModel", doc="If specified, the variant recalibrator will read the VQSR model from the file specified by -modelFile", required=false)
|
@Argument(fullName="input_model", shortName = "inputModel", doc="If specified, the variant recalibrator will read the VQSR model from this file path.", required=false)
|
||||||
private String inputModel = "";
|
private String inputModel = "";
|
||||||
@Output(fullName="model_file", shortName = "modelFile", doc="A GATKReport containing the positive and negative model fits", required=false)
|
@Output(fullName="model_file", shortName = "modelFile", doc="A GATKReport containing the positive and negative model fits", required=false)
|
||||||
protected PrintStream modelReport = null;
|
protected PrintStream modelReport = null;
|
||||||
|
|
@ -497,10 +497,16 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
|
|
||||||
File inputFile = new File(inputModel);
|
File inputFile = new File(inputModel);
|
||||||
if (inputFile.exists()){ // Load GMM from a file
|
if (inputFile.exists()){ // Load GMM from a file
|
||||||
|
logger.info("Loading model from:"+inputModel);
|
||||||
GATKReport reportIn = new GATKReport(inputFile);
|
GATKReport reportIn = new GATKReport(inputFile);
|
||||||
|
|
||||||
|
// Read all the tables
|
||||||
GATKReportTable amTable = reportIn.getTable("AnnotationMeans");
|
GATKReportTable amTable = reportIn.getTable("AnnotationMeans");
|
||||||
GATKReportTable astdTable = reportIn.getTable("AnnotationStdevs");
|
GATKReportTable astdTable = reportIn.getTable("AnnotationStdevs");
|
||||||
|
|
||||||
|
// Should have same number of means and standard deviations.
|
||||||
|
assert(amTable.getNumRows() == astdTable.getNumRows() );
|
||||||
|
|
||||||
GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances");
|
GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances");
|
||||||
GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans");
|
GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans");
|
||||||
GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix");
|
GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix");
|
||||||
|
|
@ -508,41 +514,13 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans");
|
GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans");
|
||||||
GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
|
GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
|
||||||
|
|
||||||
double[] meanVector;
|
int numAnnotations = amTable.getNumRows();
|
||||||
double[] stdVector;
|
|
||||||
int numAnnotations = 0;
|
|
||||||
|
|
||||||
for (GATKReportColumn reportColumn : amTable.getColumnInfo() ) {
|
|
||||||
if (reportColumn.getColumnName().equals("Mean")) {
|
|
||||||
meanVector = new double[amTable.getNumRows()];
|
|
||||||
numAnnotations = amTable.getNumRows();
|
|
||||||
for (int row = 0; row < amTable.getNumRows(); row++) {
|
|
||||||
meanVector[row] = (double) amTable.get(row, reportColumn.getColumnName());
|
|
||||||
}
|
|
||||||
logger.info("Got mean Vector:" + Arrays.toString(meanVector));
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for (GATKReportColumn reportColumn : astdTable.getColumnInfo() ) {
|
|
||||||
logger.info("Report column name is:" + reportColumn.getColumnName());
|
|
||||||
if (reportColumn.getColumnName().equals("Standarddeviation")) {
|
|
||||||
stdVector = new double[astdTable.getNumRows()];
|
|
||||||
|
|
||||||
for (int row = 0; row < astdTable.getNumRows(); row++) {
|
|
||||||
stdVector[row] = (double) astdTable.get(row, reportColumn.getColumnName());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations);
|
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations);
|
||||||
//Utils.getRandomGenerator().setSeed(12878);
|
|
||||||
engine.evaluateData(dataManager.getData(), goodModel, false);
|
engine.evaluateData(dataManager.getData(), goodModel, false);
|
||||||
negativeTrainingData = dataManager.selectWorstVariants();
|
negativeTrainingData = dataManager.selectWorstVariants();
|
||||||
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
|
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
|
||||||
logger.info("Loaded GMM from file:" + inputModel);
|
logger.info("Loaded GMM from file:" + inputModel);
|
||||||
|
|
||||||
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
|
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
|
||||||
//Utils.getRandomGenerator().setSeed(12878);
|
|
||||||
engine.evaluateData(dataManager.getData(), badModel, true);
|
engine.evaluateData(dataManager.getData(), badModel, true);
|
||||||
|
|
||||||
} else { // Generate the GMMs from scratch
|
} else { // Generate the GMMs from scratch
|
||||||
|
|
@ -568,9 +546,6 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
report.print(modelReport);
|
report.print(modelReport);
|
||||||
}
|
}
|
||||||
|
|
||||||
boolean writeFeatures = true;
|
|
||||||
if (writeFeatures) writeFeaturesFiles(positiveTrainingData, negativeTrainingData);
|
|
||||||
|
|
||||||
engine.calculateWorstPerformingAnnotation(dataManager.getData(), goodModel, badModel);
|
engine.calculateWorstPerformingAnnotation(dataManager.getData(), goodModel, badModel);
|
||||||
|
|
||||||
// Find the VQSLOD cutoff values which correspond to the various tranches of calls requested by the user
|
// Find the VQSLOD cutoff values which correspond to the various tranches of calls requested by the user
|
||||||
|
|
@ -609,7 +584,15 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
public GaussianMixtureModel GMMFromTables(GATKReportTable muTable, GATKReportTable sigmaTable, GATKReportTable pmixTable, int numAnnotations){
|
/**
|
||||||
|
* Rebuild a Gaussian Mixture Model from gaussian means and co-variates stored in a GATKReportTables
|
||||||
|
* @param muTable Table of Gaussian means
|
||||||
|
* @param sigmaTable Table of Gaussian co-variates
|
||||||
|
* @param pmixTable Table of PMixLog10 values
|
||||||
|
* @param numAnnotations Number of annotations, i.e. Dimension of the annotation space in which the Gaussians live
|
||||||
|
* @return a GaussianMixtureModel whose state reflects the state recorded in the tables.
|
||||||
|
*/
|
||||||
|
private GaussianMixtureModel GMMFromTables(GATKReportTable muTable, GATKReportTable sigmaTable, GATKReportTable pmixTable, int numAnnotations){
|
||||||
List<MultivariateGaussian> gaussianList = new ArrayList<>();
|
List<MultivariateGaussian> gaussianList = new ArrayList<>();
|
||||||
|
|
||||||
int curAnnotation = 0;
|
int curAnnotation = 0;
|
||||||
|
|
@ -621,7 +604,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
MultivariateGaussian mg = new MultivariateGaussian(numAnnotations);
|
MultivariateGaussian mg = new MultivariateGaussian(numAnnotations);
|
||||||
gaussianList.add(mg);
|
gaussianList.add(mg);
|
||||||
}
|
}
|
||||||
gaussianList.get(row).mu[curAnnotation] = (double) muTable.get(row, reportColumn.getColumnName());
|
gaussianList.get(row).mu[curAnnotation] = Double.parseDouble((String)muTable.get(row, reportColumn.getColumnName()));
|
||||||
}
|
}
|
||||||
curAnnotation++;
|
curAnnotation++;
|
||||||
}
|
}
|
||||||
|
|
@ -630,7 +613,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
for (GATKReportColumn reportColumn : pmixTable.getColumnInfo() ) {
|
for (GATKReportColumn reportColumn : pmixTable.getColumnInfo() ) {
|
||||||
if (reportColumn.getColumnName().equals("pMixLog10")) {
|
if (reportColumn.getColumnName().equals("pMixLog10")) {
|
||||||
for (int row = 0; row < pmixTable.getNumRows(); row++) {
|
for (int row = 0; row < pmixTable.getNumRows(); row++) {
|
||||||
gaussianList.get(row).pMixtureLog10 = (double) pmixTable.get(row, reportColumn.getColumnName());
|
gaussianList.get(row).pMixtureLog10 = Double.parseDouble((String)pmixTable.get(row, reportColumn.getColumnName()));
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -643,7 +626,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
for (int row = 0; row < sigmaTable.getNumRows(); row++) {
|
for (int row = 0; row < sigmaTable.getNumRows(); row++) {
|
||||||
int curGaussian = row / numAnnotations;
|
int curGaussian = row / numAnnotations;
|
||||||
int curI = row % numAnnotations;
|
int curI = row % numAnnotations;
|
||||||
double curVal = (double) sigmaTable.get(row, reportColumn.getColumnName());
|
double curVal = Double.parseDouble((String)sigmaTable.get(row, reportColumn.getColumnName()));
|
||||||
gaussianList.get(curGaussian).sigma.set(curI, curJ, curVal);
|
gaussianList.get(curGaussian).sigma.set(curI, curJ, curVal);
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
@ -655,41 +638,36 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private void writeFeaturesFiles(List<VariantDatum> positiveTrainingData, List<VariantDatum> negativeTrainingData){
|
private double[] getStandardDeviationsFromTable(GATKReportTable astdTable){
|
||||||
//Begin Sam Hacking
|
double[] stdVector = {};
|
||||||
try {
|
|
||||||
File file = new File("/Users/sam/data/haploid_features.txt");
|
|
||||||
file.createNewFile();
|
|
||||||
File badFile = new File("/Users/sam/data/haploid_bad_features.txt");
|
|
||||||
badFile.createNewFile();
|
|
||||||
|
|
||||||
FileWriter fw = new FileWriter(file.getAbsoluteFile());
|
for (GATKReportColumn reportColumn : astdTable.getColumnInfo() ) {
|
||||||
BufferedWriter bw = new BufferedWriter(fw);
|
logger.info("Report column name is:" + reportColumn.getColumnName());
|
||||||
for(int jj = 0; jj < positiveTrainingData.size(); jj++){
|
if (reportColumn.getColumnName().equals("Standarddeviation")) {
|
||||||
VariantDatum v = positiveTrainingData.get(jj);
|
stdVector = new double[astdTable.getNumRows()];
|
||||||
for(int kk = 0; kk < v.annotations.length; kk++){
|
for (int row = 0; row < astdTable.getNumRows(); row++) {
|
||||||
bw.write(Double.toString(v.annotations[kk]));
|
stdVector[row] = Double.parseDouble((String) astdTable.get(row, reportColumn.getColumnName()));
|
||||||
bw.write(" ");
|
|
||||||
}
|
}
|
||||||
bw.write("\n");
|
|
||||||
}
|
}
|
||||||
bw.close();
|
|
||||||
|
|
||||||
fw = new FileWriter(badFile.getAbsoluteFile());
|
|
||||||
bw = new BufferedWriter(fw);
|
|
||||||
for(int jj = 0; jj < negativeTrainingData.size(); jj++){
|
|
||||||
VariantDatum v = negativeTrainingData.get(jj);
|
|
||||||
for(int kk = 0; kk < v.annotations.length; kk++){
|
|
||||||
bw.write(Double.toString(v.annotations[kk]));
|
|
||||||
bw.write(" ");
|
|
||||||
}
|
|
||||||
bw.write("\n");
|
|
||||||
}
|
|
||||||
bw.close();
|
|
||||||
}catch(IOException e){
|
|
||||||
e.printStackTrace();
|
|
||||||
}
|
}
|
||||||
// End Sam Hacking
|
|
||||||
|
return stdVector;
|
||||||
|
}
|
||||||
|
|
||||||
|
private double[] getMeansFromTable(GATKReportTable amTable){
|
||||||
|
double[] meanVector = {};
|
||||||
|
|
||||||
|
for (GATKReportColumn reportColumn : amTable.getColumnInfo() ) {
|
||||||
|
if (reportColumn.getColumnName().equals("Mean")) {
|
||||||
|
meanVector = new double[amTable.getNumRows()];
|
||||||
|
for (int row = 0; row < amTable.getNumRows(); row++) {
|
||||||
|
meanVector[row] = Double.parseDouble((String) amTable.get(row, reportColumn.getColumnName()));
|
||||||
|
}
|
||||||
|
logger.info("Got mean Vector:" + Arrays.toString(meanVector));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return meanVector;
|
||||||
}
|
}
|
||||||
|
|
||||||
protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List<String> annotationList) {
|
protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List<String> annotationList) {
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue