Normalize input data using the same annotation normalization scheme as the input model

This commit is contained in:
Laura Gauthier 2017-05-05 13:09:49 -04:00
parent 741c544fbf
commit ee39f9141f
2 changed files with 38 additions and 38 deletions

View File

@ -99,23 +99,37 @@ public class VariantDataManager {
this.data = data; this.data = data;
} }
/**
 * Installs externally supplied normalization parameters for every annotation in
 * {@code annotationKeys}, so that a later {@code normalizeData(false)} call reuses them
 * instead of recomputing them from the current data.
 *
 * Note: {@code varianceVector} stores standard deviations here, despite its name;
 * {@code normalizeData} divides by the stored value directly.
 *
 * @param anMeans   annotation name -> mean used for normalization
 * @param anStdDevs annotation name -> standard deviation used for normalization
 * @throws UserException.BadInput if either map has no value for one of the annotation keys
 */
public void setNormalization(final Map<String, Double> anMeans, final Map<String, Double> anStdDevs) {
    for (int i = 0; i < this.annotationKeys.size(); i++) {
        final String annotationKey = annotationKeys.get(i);
        final Double mean = anMeans.get(annotationKey);
        final Double stdDev = anStdDevs.get(annotationKey);
        // Map.get returns null for an absent key; unboxing that into the double[] would
        // surface as an uninformative NullPointerException, so fail with a clear message.
        if (mean == null || stdDev == null) {
            throw new UserException.BadInput("No normalization mean and/or standard deviation was provided for annotation " + annotationKey + ". The annotations in the input model must match the annotations requested for this run.");
        }
        meanVector[i] = mean;
        varianceVector[i] = stdDev;
    }
}
public List<VariantDatum> getData() { public List<VariantDatum> getData() {
return data; return data;
} }
public void normalizeData() { public void normalizeData(final boolean calculateMeans) {
boolean foundZeroVarianceAnnotation = false; boolean foundZeroVarianceAnnotation = false;
for( int iii = 0; iii < meanVector.length; iii++ ) { for( int iii = 0; iii < meanVector.length; iii++ ) {
final double theMean = mean(iii, true); final double theMean, theSTD;
final double theSTD = standardDeviation(theMean, iii, true); if (calculateMeans) {
logger.info( annotationKeys.get(iii) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD) ); theMean = mean(iii, true);
if( Double.isNaN(theMean) ) { theSTD = standardDeviation(theMean, iii, true);
if (Double.isNaN(theMean)) {
throw new UserException.BadInput("Values for " + annotationKeys.get(iii) + " annotation not detected for ANY training variant in the input callset. VariantAnnotator may be used to add these annotations."); throw new UserException.BadInput("Values for " + annotationKeys.get(iii) + " annotation not detected for ANY training variant in the input callset. VariantAnnotator may be used to add these annotations.");
} }
foundZeroVarianceAnnotation = foundZeroVarianceAnnotation || (theSTD < 1E-5); foundZeroVarianceAnnotation = foundZeroVarianceAnnotation || (theSTD < 1E-5);
meanVector[iii] = theMean; meanVector[iii] = theMean;
varianceVector[iii] = theSTD; varianceVector[iii] = theSTD;
}
else {
theMean = meanVector[iii];
theSTD = varianceVector[iii];
}
logger.info(annotationKeys.get(iii) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD));
for( final VariantDatum datum : data ) { for( final VariantDatum datum : data ) {
// Transform each data point via: (x - mean) / standard deviation // Transform each data point via: (x - mean) / standard deviation
datum.annotations[iii] = ( datum.isNull[iii] ? 0.1 * Utils.getRandomGenerator().nextGaussian() : ( datum.annotations[iii] - theMean ) / theSTD ); datum.annotations[iii] = ( datum.isNull[iii] ? 0.1 * Utils.getRandomGenerator().nextGaussian() : ( datum.annotations[iii] - theMean ) / theSTD );

View File

@ -370,12 +370,18 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
final GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances"); final GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances");
final GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans"); final GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans");
final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix"); final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
final GATKReportTable anMeansTable = reportIn.getTable("AnnotationMeans");
final GATKReportTable anStDevsTable = reportIn.getTable("AnnotationStdevs");
final int numAnnotations = dataManager.annotationKeys.size(); final int numAnnotations = dataManager.annotationKeys.size();
if( numAnnotations != pmmTable.getNumColumns()-1 || numAnnotations != nmmTable.getNumColumns()-1 ) { // -1 because the first column is the gaussian number. if( numAnnotations != pmmTable.getNumColumns()-1 || numAnnotations != nmmTable.getNumColumns()-1 ) { // -1 because the first column is the gaussian number.
throw new UserException.CommandLineException( "Annotations specified on the command line do not match annotations in the model report." ); throw new UserException.CommandLineException( "Annotations specified on the command line do not match annotations in the model report." );
} }
final Map<String, Double> anMeans = getMapFromVectorTable(anMeansTable);
final Map<String, Double> anStdDevs = getMapFromVectorTable(anStDevsTable);
dataManager.setNormalization(anMeans, anStdDevs);
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations); goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations);
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations); badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
} }
@ -512,7 +518,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
for (int i = 1; i <= max_attempts; i++) { for (int i = 1; i <= max_attempts; i++) {
try { try {
dataManager.setData(reduceSum); dataManager.setData(reduceSum);
dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation dataManager.normalizeData(inputModel.isEmpty()); // Each data point is now (x - mean) / standard deviation
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData(); final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
final List<VariantDatum> negativeTrainingData; final List<VariantDatum> negativeTrainingData;
@ -638,38 +644,18 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
} }
private double[] getStandardDeviationsFromTable(GATKReportTable astdTable){ private Map<String, Double> getMapFromVectorTable(GATKReportTable vectorTable){
double[] stdVector = {}; Map<String, Double> dataMap = new HashMap<>();
for (GATKReportColumn reportColumn : astdTable.getColumnInfo() ) { //do a row-major traversal
if (reportColumn.getColumnName().equals("Standarddeviation")) { for (int i = 0; i < vectorTable.getNumRows(); i++) {
stdVector = new double[astdTable.getNumRows()]; dataMap.put((String) vectorTable.get(i, 0), (Double) vectorTable.get(i, 1));
for (int row = 0; row < astdTable.getNumRows(); row++) {
stdVector[row] = (Double) astdTable.get(row, reportColumn.getColumnName());
} }
} return dataMap;
}
return stdVector;
}
private double[] getMeansFromTable(GATKReportTable amTable){
double[] meanVector = {};
for (GATKReportColumn reportColumn : amTable.getColumnInfo() ) {
if (reportColumn.getColumnName().equals("Mean")) {
meanVector = new double[amTable.getNumRows()];
for (int row = 0; row < amTable.getNumRows(); row++) {
meanVector[row] = (Double) amTable.get(row, reportColumn.getColumnName());
}
}
}
return meanVector;
} }
protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List<String> annotationList) { protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List<String> annotationList) {
final String formatString = "%.8E"; final String formatString = "%.16E";
final GATKReport report = new GATKReport(); final GATKReport report = new GATKReport();
if (dataManager != null) { //for unit test if (dataManager != null) { //for unit test