diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java index 0995fc10c..59a5af92d 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/GaussianMixtureModel.java @@ -86,6 +86,11 @@ public class GaussianMixtureModel { gaussians = new ArrayList<>( numGaussians ); for( int iii = 0; iii < numGaussians; iii++ ) { final MultivariateGaussian gaussian = new MultivariateGaussian( numAnnotations ); + gaussian.pMixtureLog10 = Math.log10( 1.0 / ((double)numGaussians) ); + gaussian.sumProb = 1.0 / ((double) numGaussians); + gaussian.hyperParameter_a = priorCounts; + gaussian.hyperParameter_b = shrinkage; + gaussian.hyperParameter_lambda = dirichletParameter; gaussians.add( gaussian ); } this.shrinkage = shrinkage; @@ -190,6 +195,9 @@ public class GaussianMixtureModel { final double[] pVarInGaussianNormalized = MathUtils.normalizeFromLog10( pVarInGaussianLog10, false ); gaussianIndex = 0; for( final MultivariateGaussian gaussian : gaussians ) { + if (Double.isNaN(pVarInGaussianNormalized[gaussianIndex])){ + logger.info(" Got a NaN at gaussian:" + Integer.toString(gaussianIndex) + " datum:" + datum.toString()); + } gaussian.assignPVarInGaussian( pVarInGaussianNormalized[gaussianIndex++] ); } } @@ -315,4 +323,5 @@ public class GaussianMixtureModel { protected List getModelGaussians() {return Collections.unmodifiableList(gaussians);} protected int getNumAnnotations() {return empiricalMu.length;} + } \ No newline at end of file diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/MultivariateGaussian.java 
b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/MultivariateGaussian.java index db46b0c33..5e9d18ffd 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/MultivariateGaussian.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/MultivariateGaussian.java @@ -271,4 +271,14 @@ public class MultivariateGaussian { resetPVarInGaussian(); // clean up some memory } + + + public void setSumProb( final List data ) { + sumProb = 0.0; + + for( int datumIndex = 0; datumIndex < data.size(); datumIndex++ ) { + final double prob = pVarInGaussian.get(datumIndex); + if(!Double.isNaN(prob)) sumProb += prob; + } + } } \ No newline at end of file diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java index 416d72c67..470ce3cdb 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java @@ -66,6 +66,7 @@ import org.broadinstitute.gatk.utils.R.RScriptExecutor; import org.broadinstitute.gatk.utils.Utils; import org.broadinstitute.gatk.utils.help.HelpConstants; import org.broadinstitute.gatk.utils.report.GATKReport; +import org.broadinstitute.gatk.utils.report.GATKReportColumn; import org.broadinstitute.gatk.utils.report.GATKReportTable; import org.broadinstitute.gatk.utils.variant.GATKVariantContextUtils; import htsjdk.variant.vcf.VCFHeader; @@ -80,10 +81,15 @@ import htsjdk.variant.variantcontext.writer.VariantContextWriter; import java.io.File; import 
java.io.FileNotFoundException; import java.io.PrintStream; +import java.nio.file.Files; import java.util.*; import Jama.Matrix; + +import java.io.FileWriter; +import java.io.BufferedWriter; +import java.io.IOException; /** * Build a recalibration model to score variant quality for filtering purposes * @@ -274,6 +280,8 @@ public class VariantRecalibrator extends RodWalker inputCollection : inputCollections ) input.addAll(inputCollection.getRodBindings()); + + } + //--------------------------------------------------------------------------------------------------------------- // // map @@ -480,18 +491,76 @@ public class VariantRecalibrator extends RodWalker positiveTrainingData = dataManager.getTrainingData(); - final GaussianMixtureModel goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS); - engine.evaluateData(dataManager.getData(), goodModel, false); + final List negativeTrainingData; - // Generate the negative model using the worst performing data and evaluate each variant contrastively - final List negativeTrainingData = dataManager.selectWorstVariants(); - final GaussianMixtureModel badModel = engine.generateModel(negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS)); - dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory - engine.evaluateData(dataManager.getData(), badModel, true); + File inputFile = new File(inputModel); + if (inputFile.exists()){ // Load GMM from a file + GATKReport reportIn = new GATKReport(inputFile); + GATKReportTable amTable = reportIn.getTable("AnnotationMeans"); + GATKReportTable astdTable = reportIn.getTable("AnnotationStdevs"); - if (badModel.failedToConverge || goodModel.failedToConverge) { - throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? 
"raising the number of variants used to train the negative model (via --minNumBadVariants 5000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example).")); + GATKReportTable nmcTable = reportIn.getTable("NegativeModelCovariances"); + GATKReportTable nmmTable = reportIn.getTable("NegativeModelMeans"); + GATKReportTable nPMixTable = reportIn.getTable("BadGaussianPMix"); + GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances"); + GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans"); + GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix"); + + double[] meanVector; + double[] stdVector; + int numAnnotations = 0; + + for (GATKReportColumn reportColumn : amTable.getColumnInfo() ) { + if (reportColumn.getColumnName().equals("Mean")) { + meanVector = new double[amTable.getNumRows()]; + numAnnotations = amTable.getNumRows(); + for (int row = 0; row < amTable.getNumRows(); row++) { + meanVector[row] = (double) amTable.get(row, reportColumn.getColumnName()); + } + logger.info("Got mean Vector:" + Arrays.toString(meanVector)); + } + } + + for (GATKReportColumn reportColumn : astdTable.getColumnInfo() ) { + logger.info("Report column name is:" + reportColumn.getColumnName()); + if (reportColumn.getColumnName().equals("Standarddeviation")) { + stdVector = new double[astdTable.getNumRows()]; + + for (int row = 0; row < astdTable.getNumRows(); row++) { + stdVector[row] = (double) astdTable.get(row, reportColumn.getColumnName()); + } + } + } + + goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations); + //Utils.getRandomGenerator().setSeed(12878); + engine.evaluateData(dataManager.getData(), goodModel, false); + negativeTrainingData = dataManager.selectWorstVariants(); + badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations); + logger.info("Loaded GMM from file:" + inputModel); + + dataManager.dropAggregateData(); // Don't need the 
aggregate data anymore so let's free up the memory + //Utils.getRandomGenerator().setSeed(12878); + engine.evaluateData(dataManager.getData(), badModel, true); + + } else { // Generate the GMMs from scratch + goodModel = engine.generateModel(positiveTrainingData, VRAC.MAX_GAUSSIANS); + //Utils.getRandomGenerator().setSeed(12878); + engine.evaluateData(dataManager.getData(), goodModel, false); + // Generate the negative model using the worst performing data and evaluate each variant contrastively + negativeTrainingData = dataManager.selectWorstVariants(); + + badModel = engine.generateModel(negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS)); + dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory + //Utils.getRandomGenerator().setSeed(12878); + engine.evaluateData(dataManager.getData(), badModel, true); + + + if (badModel.failedToConverge || goodModel.failedToConverge) { + throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --minNumBadVariants 5000, for example)." 
: "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example).")); + } } if (outputModel) { @@ -499,6 +568,9 @@ public class VariantRecalibrator extends RodWalker gaussianList = new ArrayList<>(); + + int curAnnotation = 0; + for (GATKReportColumn reportColumn : muTable.getColumnInfo() ) { + logger.info("Report column name is:" + reportColumn.getColumnName()); + if (!reportColumn.getColumnName().equals("Gaussian")) { + for (int row = 0; row < muTable.getNumRows(); row++) { + if (gaussianList.size() <= row){ + MultivariateGaussian mg = new MultivariateGaussian(numAnnotations); + gaussianList.add(mg); + } + gaussianList.get(row).mu[curAnnotation] = (double) muTable.get(row, reportColumn.getColumnName()); + } + curAnnotation++; + } + } + + for (GATKReportColumn reportColumn : pmixTable.getColumnInfo() ) { + if (reportColumn.getColumnName().equals("pMixLog10")) { + for (int row = 0; row < pmixTable.getNumRows(); row++) { + gaussianList.get(row).pMixtureLog10 = (double) pmixTable.get(row, reportColumn.getColumnName()); + } + } + } + + int curJ = 0; + for (GATKReportColumn reportColumn : sigmaTable.getColumnInfo() ) { + if (reportColumn.getColumnName().equals("Gaussian")) continue; + if (reportColumn.getColumnName().equals("Annotation")) continue; + + for (int row = 0; row < sigmaTable.getNumRows(); row++) { + int curGaussian = row / numAnnotations; + int curI = row % numAnnotations; + double curVal = (double) sigmaTable.get(row, reportColumn.getColumnName()); + gaussianList.get(curGaussian).sigma.set(curI, curJ, curVal); + + } + curJ++; + + } + + return new GaussianMixtureModel(gaussianList, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER, VRAC.PRIOR_COUNTS); + + } +
+    // NOTE(review): these default locations are developer-specific absolute paths kept only so the
+    // two-argument writeFeaturesFiles behaves as before; they should come from a command-line
+    // argument (or the call should be removed) before this code is merged.
+    private static final String DEFAULT_POSITIVE_FEATURES_PATH = "/Users/sam/data/haploid_features.txt";
+    private static final String DEFAULT_NEGATIVE_FEATURES_PATH = "/Users/sam/data/haploid_bad_features.txt";
+
+    /**
+     * Debugging aid: dumps the annotation vectors of both training sets to the default feature files.
+     *
+     * Equivalent to calling the four-argument overload with {@code DEFAULT_POSITIVE_FEATURES_PATH} and
+     * {@code DEFAULT_NEGATIVE_FEATURES_PATH}.
+     *
+     * @param positiveTrainingData data whose annotations are written to the positive features file
+     * @param negativeTrainingData data whose annotations are written to the negative features file
+     */
+    private void writeFeaturesFiles(final List<VariantDatum> positiveTrainingData, final List<VariantDatum> negativeTrainingData) {
+        writeFeaturesFiles(positiveTrainingData, new File(DEFAULT_POSITIVE_FEATURES_PATH),
+                negativeTrainingData, new File(DEFAULT_NEGATIVE_FEATURES_PATH));
+    }
+
+    /**
+     * Writes the annotation vectors of both training sets to the given files, positive set first.
+     *
+     * This is best-effort: an I/O failure is logged as a warning and not rethrown, and a failure on the
+     * positive file means the negative file is not attempted.
+     *
+     * @param positiveTrainingData data written to {@code positiveFile}
+     * @param positiveFile         destination for the positive training annotations (overwritten)
+     * @param negativeTrainingData data written to {@code negativeFile}
+     * @param negativeFile         destination for the negative training annotations (overwritten)
+     */
+    private void writeFeaturesFiles(final List<VariantDatum> positiveTrainingData, final File positiveFile,
+                                    final List<VariantDatum> negativeTrainingData, final File negativeFile) {
+        try {
+            writeAnnotationRows(positiveTrainingData, positiveFile);
+            writeAnnotationRows(negativeTrainingData, negativeFile);
+        } catch (final IOException e) {
+            // Report through the walker's logger, keeping the cause, instead of printing a bare stack trace.
+            logger.warn("Unable to write training feature files " + positiveFile + " and " + negativeFile, e);
+        }
+    }
+
+    /**
+     * Writes one line per datum: every annotation value followed by a single space, then {@code "\n"}.
+     *
+     * @param data        the data whose {@code annotations} are written, in list order
+     * @param destination the file to create or overwrite
+     * @throws IOException if the file cannot be opened or written; the writer is closed in either case
+     */
+    private static void writeAnnotationRows(final List<VariantDatum> data, final File destination) throws IOException {
+        // try-with-resources closes (and flushes) the writer even when a write fails part-way through.
+        try (BufferedWriter writer = new BufferedWriter(new FileWriter(destination.getAbsoluteFile()))) {
+            for (final VariantDatum datum : data) {
+                for (final double annotation : datum.annotations) {
+                    writer.write(Double.toString(annotation));
+                    writer.write(" ");
+                }
+                writer.write("\n"); // literal "\n" rather than newLine() so the output stays platform-independent
+            }
+        }
+    }
+
    protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List annotationList) { - final String formatString = "%.3f"; + final String formatString = "%.25f"; final GATKReport report = new GATKReport(); if (dataManager != null) { //for unit test @@ -547,10 +702,36 @@ public class VariantRecalibrator extends RodWalker gaussianStrings = new ArrayList<>(); + final double[] pMixtureLog10s = new double[goodModel.getModelGaussians().size()]; + int idx = 0; + + for( final MultivariateGaussian gaussian : goodModel.getModelGaussians() ) { + pMixtureLog10s[idx] = gaussian.pMixtureLog10; + logger.info("Good normalize PMix log 10 is:" + Double.toString(gaussian.pMixtureLog10) ); + gaussianStrings.add(Integer.toString(idx++) ); + } + + GATKReportTable goodPMix = makeVectorTable("GoodGaussianPMix", "Pmixture log 10 used to evaluate model", gaussianStrings, pMixtureLog10s, "pMixLog10", formatString, "Gaussian"); + report.addTable(goodPMix); + + gaussianStrings.clear(); + final double[] pMixtureLog10sBad = new double[badModel.getModelGaussians().size()]; + idx = 0; + +
for( final MultivariateGaussian gaussian : badModel.getModelGaussians() ) { + pMixtureLog10sBad[idx] = gaussian.pMixtureLog10; + logger.info("Bad normalize PMix log 10 is:" + Double.toString(gaussian.pMixtureLog10)); + gaussianStrings.add(Integer.toString(idx++)); + } + GATKReportTable badPMix = makeVectorTable("BadGaussianPMix", "Pmixture log 10 used to evaluate model", gaussianStrings, pMixtureLog10sBad, "pMixLog10", formatString, "Gaussian"); + report.addTable(badPMix); + + //The model and Gaussians don't know what the annotations are, so get them from this class //VariantDataManager keeps the annotation in the same order as the argument list GATKReportTable positiveMeans = makeMeansTable("PositiveModelMeans", "Vector of annotation values to describe the (normalized) mean for each Gaussian in the positive model", annotationList, goodModel, formatString); @@ -570,8 +751,12 @@ public class VariantRecalibrator extends RodWalker annotationList, final double[] perAnnotationValues, final String columnName, final String formatString) { + return makeVectorTable(tableName, tableDescription, annotationList, perAnnotationValues, columnName, formatString, "Annotation"); + } + + protected GATKReportTable makeVectorTable(final String tableName, final String tableDescription, final List annotationList, final double[] perAnnotationValues, final String columnName, final String formatString, final String firstColumn) { GATKReportTable vectorTable = new GATKReportTable(tableName, tableDescription, annotationList.size(), GATKReportTable.TableSortingWay.DO_NOT_SORT); - vectorTable.addColumn("Annotation"); + vectorTable.addColumn(firstColumn); vectorTable.addColumn(columnName, formatString); for (int i = 0; i < perAnnotationValues.length; i++) { vectorTable.addRowIDMapping(annotationList.get(i), i, true); diff --git a/public/gatk-utils/src/main/java/org/broadinstitute/gatk/utils/report/GATKReportTable.java 
b/public/gatk-utils/src/main/java/org/broadinstitute/gatk/utils/report/GATKReportTable.java index 85d2386a0..4d87bab84 100644 --- a/public/gatk-utils/src/main/java/org/broadinstitute/gatk/utils/report/GATKReportTable.java +++ b/public/gatk-utils/src/main/java/org/broadinstitute/gatk/utils/report/GATKReportTable.java @@ -150,7 +150,6 @@ public class GATKReportTable { // read a data line final String dataLine = reader.readLine(); final List lineSplits = Arrays.asList(TextFormattingUtils.splitFixedWidth(dataLine, columnStarts)); - underlyingData.add(new Object[nColumns]); for ( int columnIndex = 0; columnIndex < nColumns; columnIndex++ ) {