Merge pull request #1608 from broadinstitute/ldg_renomalizeVQSRdata
Normalize input data by same annotation normalization scheme as input…
This commit is contained in:
commit
385f06b5cb
|
|
@ -99,23 +99,37 @@ public class VariantDataManager {
|
||||||
this.data = data;
|
this.data = data;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public void setNormalization(final Map<String, Double> anMeans, final Map<String, Double> anStdDevs) {
|
||||||
|
for (int i = 0; i < this.annotationKeys.size(); i++) {
|
||||||
|
meanVector[i] = anMeans.get(annotationKeys.get(i));
|
||||||
|
varianceVector[i] = anStdDevs.get(annotationKeys.get(i));
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public List<VariantDatum> getData() {
|
public List<VariantDatum> getData() {
|
||||||
return data;
|
return data;
|
||||||
}
|
}
|
||||||
|
|
||||||
public void normalizeData() {
|
public void normalizeData(final boolean calculateMeans) {
|
||||||
boolean foundZeroVarianceAnnotation = false;
|
boolean foundZeroVarianceAnnotation = false;
|
||||||
for( int iii = 0; iii < meanVector.length; iii++ ) {
|
for( int iii = 0; iii < meanVector.length; iii++ ) {
|
||||||
final double theMean = mean(iii, true);
|
final double theMean, theSTD;
|
||||||
final double theSTD = standardDeviation(theMean, iii, true);
|
if (calculateMeans) {
|
||||||
logger.info( annotationKeys.get(iii) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD) );
|
theMean = mean(iii, true);
|
||||||
if( Double.isNaN(theMean) ) {
|
theSTD = standardDeviation(theMean, iii, true);
|
||||||
throw new UserException.BadInput("Values for " + annotationKeys.get(iii) + " annotation not detected for ANY training variant in the input callset. VariantAnnotator may be used to add these annotations.");
|
if (Double.isNaN(theMean)) {
|
||||||
}
|
throw new UserException.BadInput("Values for " + annotationKeys.get(iii) + " annotation not detected for ANY training variant in the input callset. VariantAnnotator may be used to add these annotations.");
|
||||||
|
}
|
||||||
|
|
||||||
foundZeroVarianceAnnotation = foundZeroVarianceAnnotation || (theSTD < 1E-5);
|
foundZeroVarianceAnnotation = foundZeroVarianceAnnotation || (theSTD < 1E-5);
|
||||||
meanVector[iii] = theMean;
|
meanVector[iii] = theMean;
|
||||||
varianceVector[iii] = theSTD;
|
varianceVector[iii] = theSTD;
|
||||||
|
}
|
||||||
|
else {
|
||||||
|
theMean = meanVector[iii];
|
||||||
|
theSTD = varianceVector[iii];
|
||||||
|
}
|
||||||
|
logger.info(annotationKeys.get(iii) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD));
|
||||||
for( final VariantDatum datum : data ) {
|
for( final VariantDatum datum : data ) {
|
||||||
// Transform each data point via: (x - mean) / standard deviation
|
// Transform each data point via: (x - mean) / standard deviation
|
||||||
datum.annotations[iii] = ( datum.isNull[iii] ? 0.1 * Utils.getRandomGenerator().nextGaussian() : ( datum.annotations[iii] - theMean ) / theSTD );
|
datum.annotations[iii] = ( datum.isNull[iii] ? 0.1 * Utils.getRandomGenerator().nextGaussian() : ( datum.annotations[iii] - theMean ) / theSTD );
|
||||||
|
|
|
||||||
|
|
@ -370,12 +370,18 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
final GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances");
|
final GATKReportTable pmcTable = reportIn.getTable("PositiveModelCovariances");
|
||||||
final GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans");
|
final GATKReportTable pmmTable = reportIn.getTable("PositiveModelMeans");
|
||||||
final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
|
final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
|
||||||
|
final GATKReportTable anMeansTable = reportIn.getTable("AnnotationMeans");
|
||||||
|
final GATKReportTable anStDevsTable = reportIn.getTable("AnnotationStdevs");
|
||||||
final int numAnnotations = dataManager.annotationKeys.size();
|
final int numAnnotations = dataManager.annotationKeys.size();
|
||||||
|
|
||||||
if( numAnnotations != pmmTable.getNumColumns()-1 || numAnnotations != nmmTable.getNumColumns()-1 ) { // -1 because the first column is the gaussian number.
|
if( numAnnotations != pmmTable.getNumColumns()-1 || numAnnotations != nmmTable.getNumColumns()-1 ) { // -1 because the first column is the gaussian number.
|
||||||
throw new UserException.CommandLineException( "Annotations specified on the command line do not match annotations in the model report." );
|
throw new UserException.CommandLineException( "Annotations specified on the command line do not match annotations in the model report." );
|
||||||
}
|
}
|
||||||
|
|
||||||
|
final Map<String, Double> anMeans = getMapFromVectorTable(anMeansTable);
|
||||||
|
final Map<String, Double> anStdDevs = getMapFromVectorTable(anStDevsTable);
|
||||||
|
dataManager.setNormalization(anMeans, anStdDevs);
|
||||||
|
|
||||||
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations);
|
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations);
|
||||||
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
|
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
|
||||||
}
|
}
|
||||||
|
|
@ -512,7 +518,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
for (int i = 1; i <= max_attempts; i++) {
|
for (int i = 1; i <= max_attempts; i++) {
|
||||||
try {
|
try {
|
||||||
dataManager.setData(reduceSum);
|
dataManager.setData(reduceSum);
|
||||||
dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation
|
dataManager.normalizeData(inputModel.isEmpty()); // Each data point is now (x - mean) / standard deviation
|
||||||
|
|
||||||
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
|
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
|
||||||
final List<VariantDatum> negativeTrainingData;
|
final List<VariantDatum> negativeTrainingData;
|
||||||
|
|
@ -638,38 +644,18 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
private double[] getStandardDeviationsFromTable(GATKReportTable astdTable){
|
private Map<String, Double> getMapFromVectorTable(GATKReportTable vectorTable){
|
||||||
double[] stdVector = {};
|
Map<String, Double> dataMap = new HashMap<>();
|
||||||
|
|
||||||
for (GATKReportColumn reportColumn : astdTable.getColumnInfo() ) {
|
//do a row-major traversal
|
||||||
if (reportColumn.getColumnName().equals("Standarddeviation")) {
|
for (int i = 0; i < vectorTable.getNumRows(); i++) {
|
||||||
stdVector = new double[astdTable.getNumRows()];
|
dataMap.put((String) vectorTable.get(i, 0), (Double) vectorTable.get(i, 1));
|
||||||
for (int row = 0; row < astdTable.getNumRows(); row++) {
|
|
||||||
stdVector[row] = (Double) astdTable.get(row, reportColumn.getColumnName());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
return dataMap;
|
||||||
return stdVector;
|
|
||||||
}
|
|
||||||
|
|
||||||
private double[] getMeansFromTable(GATKReportTable amTable){
|
|
||||||
double[] meanVector = {};
|
|
||||||
|
|
||||||
for (GATKReportColumn reportColumn : amTable.getColumnInfo() ) {
|
|
||||||
if (reportColumn.getColumnName().equals("Mean")) {
|
|
||||||
meanVector = new double[amTable.getNumRows()];
|
|
||||||
for (int row = 0; row < amTable.getNumRows(); row++) {
|
|
||||||
meanVector[row] = (Double) amTable.get(row, reportColumn.getColumnName());
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return meanVector;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List<String> annotationList) {
|
protected GATKReport writeModelReport(final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, List<String> annotationList) {
|
||||||
final String formatString = "%.8E";
|
final String formatString = "%.16E";
|
||||||
final GATKReport report = new GATKReport();
|
final GATKReport report = new GATKReport();
|
||||||
|
|
||||||
if (dataManager != null) { //for unit test
|
if (dataManager != null) { //for unit test
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue