Merge pull request #1608 from broadinstitute/ldg_renomalizeVQSRdata

Normalize input data by same annotation normalization scheme as input…
This commit is contained in:
ldgauthier 2017-06-30 16:44:40 -04:00 committed by GitHub
commit 385f06b5cb
2 changed files with 38 additions and 38 deletions

View File

@ -99,23 +99,37 @@ public class VariantDataManager {
this.data = data;
}
public void setNormalization(final Map<String, Double> anMeans, final Map<String, Double> anStdDevs) {
for (int i = 0; i < this.annotationKeys.size(); i++) {
meanVector[i] = anMeans.get(annotationKeys.get(i));
varianceVector[i] = anStdDevs.get(annotationKeys.get(i));
}
}
public List<VariantDatum> getData() {
return data;
}
public void normalizeData() {
public void normalizeData(final boolean calculateMeans) {
boolean foundZeroVarianceAnnotation = false;
for( int iii = 0; iii < meanVector.length; iii++ ) {
final double theMean = mean(iii, true);
final double theSTD = standardDeviation(theMean, iii, true);
logger.info( annotationKeys.get(iii) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD) );
if( Double.isNaN(theMean) ) {
final double theMean, theSTD;
if (calculateMeans) {
theMean = mean(iii, true);
theSTD = standardDeviation(theMean, iii, true);
if (Double.isNaN(theMean)) {
throw new UserException.BadInput("Values for " + annotationKeys.get(iii) + " annotation not detected for ANY training variant in the input callset. VariantAnnotator may be used to add these annotations.");
}
foundZeroVarianceAnnotation = foundZeroVarianceAnnotation || (theSTD < 1E-5);
meanVector[iii] = theMean;
varianceVector[iii] = theSTD;
}
else {
theMean = meanVector[iii];
theSTD = varianceVector[iii];
}
logger.info(annotationKeys.get(iii) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD));
for( final VariantDatum datum : data ) {
// Transform each data point via: (x - mean) / standard deviation
datum.annotations[iii] = ( datum.isNull[iii] ? 0.1 * Utils.getRandomGenerator().nextGaussian() : ( datum.annotations[iii] - theMean ) / theSTD );

View File

@ -370,12 +370,18 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
final GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances");
final GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans");
final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
final GATKReportTable anMeansTable = reportIn.getTable("AnnotationMeans");
final GATKReportTable anStDevsTable = reportIn.getTable("AnnotationStdevs");
final int numAnnotations = dataManager.annotationKeys.size();
if( numAnnotations != pmmTable.getNumColumns()-1 || numAnnotations != nmmTable.getNumColumns()-1 ) { // -1 because the first column is the gaussian number.
throw new UserException.CommandLineException( "Annotations specified on the command line do not match annotations in the model report." );
}
final Map<String, Double> anMeans = getMapFromVectorTable(anMeansTable);
final Map<String, Double> anStdDevs = getMapFromVectorTable(anStDevsTable);
dataManager.setNormalization(anMeans, anStdDevs);
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations);
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
}
@ -512,7 +518,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
for (int i = 1; i <= max_attempts; i++) {
try {
dataManager.setData(reduceSum);
dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation
dataManager.normalizeData(inputModel.isEmpty()); // Each data point is now (x - mean) / standard deviation
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
final List<VariantDatum> negativeTrainingData;
@ -638,38 +644,18 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
}
private double[] getStandardDeviationsFromTable(GATKReportTable astdTable){
double[] stdVector = {};
private Map<String, Double> getMapFromVectorTable(GATKReportTable vectorTable){
Map<String, Double> dataMap = new HashMap<>();
for (GATKReportColumn reportColumn : astdTable.getColumnInfo() ) {
if (reportColumn.getColumnName().equals("Standarddeviation")) {
stdVector = new double[astdTable.getNumRows()];
for (int row = 0; row < astdTable.getNumRows(); row++) {
stdVector[row] = (Double) astdTable.get(row, reportColumn.getColumnName());
//do a row-major traversal
for (int i = 0; i < vectorTable.getNumRows(); i++) {
dataMap.put((String) vectorTable.get(i, 0), (Double) vectorTable.get(i, 1));
}
}
}
return stdVector;
}
private double[] getMeansFromTable(GATKReportTable amTable){
double[] meanVector = {};
for (GATKReportColumn reportColumn : amTable.getColumnInfo() ) {
if (reportColumn.getColumnName().equals("Mean")) {
meanVector = new double[amTable.getNumRows()];
for (int row = 0; row < amTable.getNumRows(); row++) {
meanVector[row] = (Double) amTable.get(row, reportColumn.getColumnName());
}
}
}
return meanVector;
return dataMap;
}
protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List<String> annotationList) {
final String formatString = "%.8E";
final String formatString = "%.16E";
final GATKReport report = new GATKReport();
if (dataManager != null) { //for unit test