Normalize input data using the same annotation normalization scheme as the input model
This commit is contained in:
parent
741c544fbf
commit
ee39f9141f
|
|
@ -99,23 +99,37 @@ public class VariantDataManager {
|
|||
this.data = data;
|
||||
}
|
||||
|
||||
public void setNormalization(final Map<String, Double> anMeans, final Map<String, Double> anStdDevs) {
|
||||
for (int i = 0; i < this.annotationKeys.size(); i++) {
|
||||
meanVector[i] = anMeans.get(annotationKeys.get(i));
|
||||
varianceVector[i] = anStdDevs.get(annotationKeys.get(i));
|
||||
}
|
||||
}
|
||||
|
||||
public List<VariantDatum> getData() {
|
||||
return data;
|
||||
}
|
||||
|
||||
public void normalizeData() {
|
||||
public void normalizeData(final boolean calculateMeans) {
|
||||
boolean foundZeroVarianceAnnotation = false;
|
||||
for( int iii = 0; iii < meanVector.length; iii++ ) {
|
||||
final double theMean = mean(iii, true);
|
||||
final double theSTD = standardDeviation(theMean, iii, true);
|
||||
logger.info( annotationKeys.get(iii) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD) );
|
||||
if( Double.isNaN(theMean) ) {
|
||||
throw new UserException.BadInput("Values for " + annotationKeys.get(iii) + " annotation not detected for ANY training variant in the input callset. VariantAnnotator may be used to add these annotations.");
|
||||
}
|
||||
final double theMean, theSTD;
|
||||
if (calculateMeans) {
|
||||
theMean = mean(iii, true);
|
||||
theSTD = standardDeviation(theMean, iii, true);
|
||||
if (Double.isNaN(theMean)) {
|
||||
throw new UserException.BadInput("Values for " + annotationKeys.get(iii) + " annotation not detected for ANY training variant in the input callset. VariantAnnotator may be used to add these annotations.");
|
||||
}
|
||||
|
||||
foundZeroVarianceAnnotation = foundZeroVarianceAnnotation || (theSTD < 1E-5);
|
||||
meanVector[iii] = theMean;
|
||||
varianceVector[iii] = theSTD;
|
||||
foundZeroVarianceAnnotation = foundZeroVarianceAnnotation || (theSTD < 1E-5);
|
||||
meanVector[iii] = theMean;
|
||||
varianceVector[iii] = theSTD;
|
||||
}
|
||||
else {
|
||||
theMean = meanVector[iii];
|
||||
theSTD = varianceVector[iii];
|
||||
}
|
||||
logger.info(annotationKeys.get(iii) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD));
|
||||
for( final VariantDatum datum : data ) {
|
||||
// Transform each data point via: (x - mean) / standard deviation
|
||||
datum.annotations[iii] = ( datum.isNull[iii] ? 0.1 * Utils.getRandomGenerator().nextGaussian() : ( datum.annotations[iii] - theMean ) / theSTD );
|
||||
|
|
|
|||
|
|
@ -370,12 +370,18 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
final GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances");
|
||||
final GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans");
|
||||
final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
|
||||
final GATKReportTable anMeansTable = reportIn.getTable("AnnotationMeans");
|
||||
final GATKReportTable anStDevsTable = reportIn.getTable("AnnotationStdevs");
|
||||
final int numAnnotations = dataManager.annotationKeys.size();
|
||||
|
||||
if( numAnnotations != pmmTable.getNumColumns()-1 || numAnnotations != nmmTable.getNumColumns()-1 ) { // -1 because the first column is the gaussian number.
|
||||
throw new UserException.CommandLineException( "Annotations specified on the command line do not match annotations in the model report." );
|
||||
}
|
||||
|
||||
final Map<String, Double> anMeans = getMapFromVectorTable(anMeansTable);
|
||||
final Map<String, Double> anStdDevs = getMapFromVectorTable(anStDevsTable);
|
||||
dataManager.setNormalization(anMeans, anStdDevs);
|
||||
|
||||
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations);
|
||||
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
|
||||
}
|
||||
|
|
@ -512,7 +518,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
for (int i = 1; i <= max_attempts; i++) {
|
||||
try {
|
||||
dataManager.setData(reduceSum);
|
||||
dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation
|
||||
dataManager.normalizeData(inputModel.isEmpty()); // Each data point is now (x - mean) / standard deviation
|
||||
|
||||
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
|
||||
final List<VariantDatum> negativeTrainingData;
|
||||
|
|
@ -638,38 +644,18 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
|
||||
}
|
||||
|
||||
private double[] getStandardDeviationsFromTable(GATKReportTable astdTable){
|
||||
double[] stdVector = {};
|
||||
private Map<String, Double> getMapFromVectorTable(GATKReportTable vectorTable){
|
||||
Map<String, Double> dataMap = new HashMap<>();
|
||||
|
||||
for (GATKReportColumn reportColumn : astdTable.getColumnInfo() ) {
|
||||
if (reportColumn.getColumnName().equals("Standarddeviation")) {
|
||||
stdVector = new double[astdTable.getNumRows()];
|
||||
for (int row = 0; row < astdTable.getNumRows(); row++) {
|
||||
stdVector[row] = (Double) astdTable.get(row, reportColumn.getColumnName());
|
||||
}
|
||||
}
|
||||
//do a row-major traversal
|
||||
for (int i = 0; i < vectorTable.getNumRows(); i++) {
|
||||
dataMap.put((String) vectorTable.get(i, 0), (Double) vectorTable.get(i, 1));
|
||||
}
|
||||
|
||||
return stdVector;
|
||||
}
|
||||
|
||||
private double[] getMeansFromTable(GATKReportTable amTable){
|
||||
double[] meanVector = {};
|
||||
|
||||
for (GATKReportColumn reportColumn : amTable.getColumnInfo() ) {
|
||||
if (reportColumn.getColumnName().equals("Mean")) {
|
||||
meanVector = new double[amTable.getNumRows()];
|
||||
for (int row = 0; row < amTable.getNumRows(); row++) {
|
||||
meanVector[row] = (Double) amTable.get(row, reportColumn.getColumnName());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return meanVector;
|
||||
return dataMap;
|
||||
}
|
||||
|
||||
protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List<String> annotationList) {
|
||||
final String formatString = "%.8E";
|
||||
final String formatString = "%.16E";
|
||||
final GATKReport report = new GATKReport();
|
||||
|
||||
if (dataManager != null) { //for unit test
|
||||
|
|
|
|||
Loading…
Reference in New Issue