diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantDataManager.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantDataManager.java index d4304d147..1df3bc321 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantDataManager.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantDataManager.java @@ -99,23 +99,37 @@ public class VariantDataManager { this.data = data; } + public void setNormalization(final Map anMeans, final Map anStdDevs) { + for (int i = 0; i < this.annotationKeys.size(); i++) { + meanVector[i] = anMeans.get(annotationKeys.get(i)); + varianceVector[i] = anStdDevs.get(annotationKeys.get(i)); + } + } + public List getData() { return data; } - public void normalizeData() { + public void normalizeData(final boolean calculateMeans) { boolean foundZeroVarianceAnnotation = false; for( int iii = 0; iii < meanVector.length; iii++ ) { - final double theMean = mean(iii, true); - final double theSTD = standardDeviation(theMean, iii, true); - logger.info( annotationKeys.get(iii) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD) ); - if( Double.isNaN(theMean) ) { - throw new UserException.BadInput("Values for " + annotationKeys.get(iii) + " annotation not detected for ANY training variant in the input callset. VariantAnnotator may be used to add these annotations."); - } + final double theMean, theSTD; + if (calculateMeans) { + theMean = mean(iii, true); + theSTD = standardDeviation(theMean, iii, true); + if (Double.isNaN(theMean)) { + throw new UserException.BadInput("Values for " + annotationKeys.get(iii) + " annotation not detected for ANY training variant in the input callset. VariantAnnotator may be used to add these annotations."); + } - foundZeroVarianceAnnotation = foundZeroVarianceAnnotation || (theSTD < 1E-5); - meanVector[iii] = theMean; - varianceVector[iii] = theSTD; + foundZeroVarianceAnnotation = foundZeroVarianceAnnotation || (theSTD < 1E-5); + meanVector[iii] = theMean; + varianceVector[iii] = theSTD; + } + else { + theMean = meanVector[iii]; + theSTD = varianceVector[iii]; + } + logger.info(annotationKeys.get(iii) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD)); for( final VariantDatum datum : data ) { // Transform each data point via: (x - mean) / standard deviation datum.annotations[iii] = ( datum.isNull[iii] ? 0.1 * Utils.getRandomGenerator().nextGaussian() : ( datum.annotations[iii] - theMean ) / theSTD ); diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java index 546c0f05b..4a03b568d 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java @@ -370,12 +370,18 @@ public class VariantRecalibrator extends RodWalker anMeans = getMapFromVectorTable(anMeansTable); + final Map anStdDevs = getMapFromVectorTable(anStDevsTable); + dataManager.setNormalization(anMeans, anStdDevs); + goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations); badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations); } @@ -512,7 +518,7 @@ public class VariantRecalibrator extends RodWalker positiveTrainingData = dataManager.getTrainingData(); final List negativeTrainingData; @@ -638,38 +644,18 @@ public class VariantRecalibrator extends RodWalker getMapFromVectorTable(GATKReportTable vectorTable){ + Map dataMap = new HashMap<>(); - for (GATKReportColumn reportColumn : astdTable.getColumnInfo() ) { - if (reportColumn.getColumnName().equals("Standarddeviation")) { - stdVector = new double[astdTable.getNumRows()]; - for (int row = 0; row < astdTable.getNumRows(); row++) { - stdVector[row] = (Double) astdTable.get(row, reportColumn.getColumnName()); - } - } + //do a row-major traversal + for (int i = 0; i < vectorTable.getNumRows(); i++) { + dataMap.put((String) vectorTable.get(i, 0), (Double) vectorTable.get(i, 1)); } - - return stdVector; - } - - private double[] getMeansFromTable(GATKReportTable amTable){ - double[] meanVector = {}; - - for (GATKReportColumn reportColumn : amTable.getColumnInfo() ) { - if (reportColumn.getColumnName().equals("Mean")) { - meanVector = new double[amTable.getNumRows()]; - for (int row = 0; row < amTable.getNumRows(); row++) { - meanVector[row] = (Double) amTable.get(row, reportColumn.getColumnName()); - } - } - } - - return meanVector; + return dataMap; } protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List annotationList) { - final String formatString = "%.8E"; + final String formatString = "%.16E"; final GATKReport report = new GATKReport(); if (dataManager != null) { //for unit test