Serialized GMM no longer depends on command line annotation order (#1632)

Order annotations by the order in the model report.
This commit is contained in:
Samuel Friedman 2017-10-24 11:50:03 -04:00 committed by GitHub
parent 09b4cf70f6
commit a2f45944f3
4 changed files with 168 additions and 10 deletions

View File

@ -110,7 +110,15 @@ public class VariantDataManager {
return data; return data;
} }
public void normalizeData(final boolean calculateMeans) { /**
* Normalize annotations to mean 0 and standard deviation 1.
* Order the variant annotations by the provided list {@code theOrder} or, when it is null, by standard deviation.
*
* @param calculateMeans Boolean indicating whether or not to calculate the means
* @param theOrder a list of integers specifying the desired annotation order. If this is null
* annotations will get sorted in decreasing size of their standard deviations.
*/
public void normalizeData(final boolean calculateMeans, List<Integer> theOrder) {
boolean foundZeroVarianceAnnotation = false; boolean foundZeroVarianceAnnotation = false;
for( int iii = 0; iii < meanVector.length; iii++ ) { for( int iii = 0; iii < meanVector.length; iii++ ) {
final double theMean, theSTD; final double theMean, theSTD;
@ -150,7 +158,10 @@ public class VariantDataManager {
// re-order the data by increasing standard deviation so that the results don't depend on the order things were specified on the command line // re-order the data by increasing standard deviation so that the results don't depend on the order things were specified on the command line
// standard deviation over the training points is used as a simple proxy for information content, perhaps there is a better thing to use here // standard deviation over the training points is used as a simple proxy for information content, perhaps there is a better thing to use here
final List<Integer> theOrder = calculateSortOrder(meanVector); // or use the serialized report's annotation order via the argument theOrder
if (theOrder == null){
theOrder = calculateSortOrder(meanVector);
}
annotationKeys = reorderList(annotationKeys, theOrder); annotationKeys = reorderList(annotationKeys, theOrder);
varianceVector = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(varianceVector), theOrder)); varianceVector = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(varianceVector), theOrder));
meanVector = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(meanVector), theOrder)); meanVector = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(meanVector), theOrder));
@ -158,7 +169,8 @@ public class VariantDataManager {
datum.annotations = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(datum.annotations), theOrder)); datum.annotations = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(datum.annotations), theOrder));
datum.isNull = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(datum.isNull), theOrder)); datum.isNull = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(datum.isNull), theOrder));
} }
logger.info("Annotations are now ordered by their information content: " + annotationKeys.toString()); logger.info("Annotation order is: " + annotationKeys.toString());
} }
public double[] getMeanVector() { public double[] getMeanVector() {

View File

@ -51,6 +51,7 @@
package org.broadinstitute.gatk.tools.walkers.variantrecalibration; package org.broadinstitute.gatk.tools.walkers.variantrecalibration;
import com.google.common.annotations.VisibleForTesting;
import htsjdk.variant.variantcontext.Allele; import htsjdk.variant.variantcontext.Allele;
import org.broadinstitute.gatk.utils.commandline.*; import org.broadinstitute.gatk.utils.commandline.*;
import org.broadinstitute.gatk.engine.CommandLineGATK; import org.broadinstitute.gatk.engine.CommandLineGATK;
@ -312,6 +313,9 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
@Argument(fullName = "trustAllPolymorphic", shortName = "allPoly", doc = "Trust that all the input training sets' unfiltered records contain only polymorphic sites to drastically speed up the computation.", required = false) @Argument(fullName = "trustAllPolymorphic", shortName = "allPoly", doc = "Trust that all the input training sets' unfiltered records contain only polymorphic sites to drastically speed up the computation.", required = false)
protected Boolean TRUST_ALL_POLYMORPHIC = false; protected Boolean TRUST_ALL_POLYMORPHIC = false;
/**
 * Annotation ordering taken from a serialized model report: entry {@code i} is the index, within the
 * command line annotation keys, of the annotation found in row {@code i} of the report's annotation table.
 * Stays {@code null} until {@code orderAndValidateAnnotations} is called; the value (null or not) is handed
 * to {@code dataManager.normalizeData}, which computes its own sort order when it receives {@code null}.
 */
@VisibleForTesting
protected List<Integer> annotationOrder = null;
///////////////////////////// /////////////////////////////
// Private Member Variables // Private Member Variables
///////////////////////////// /////////////////////////////
@ -372,18 +376,15 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix"); final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
final GATKReportTable anMeansTable = reportIn.getTable("AnnotationMeans"); final GATKReportTable anMeansTable = reportIn.getTable("AnnotationMeans");
final GATKReportTable anStDevsTable = reportIn.getTable("AnnotationStdevs"); final GATKReportTable anStDevsTable = reportIn.getTable("AnnotationStdevs");
final int numAnnotations = dataManager.annotationKeys.size();
if( numAnnotations != pmmTable.getNumColumns()-1 || numAnnotations != nmmTable.getNumColumns()-1 ) { // -1 because the first column is the gaussian number. orderAndValidateAnnotations(anMeansTable, dataManager.annotationKeys);
throw new UserException.CommandLineException( "Annotations specified on the command line do not match annotations in the model report." );
}
final Map<String, Double> anMeans = getMapFromVectorTable(anMeansTable); final Map<String, Double> anMeans = getMapFromVectorTable(anMeansTable);
final Map<String, Double> anStdDevs = getMapFromVectorTable(anStDevsTable); final Map<String, Double> anStdDevs = getMapFromVectorTable(anStDevsTable);
dataManager.setNormalization(anMeans, anStdDevs); dataManager.setNormalization(anMeans, anStdDevs);
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations); goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, annotationOrder.size());
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations); badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, annotationOrder.size());
} }
final Set<VCFHeaderLine> hInfo = new HashSet<>(); final Set<VCFHeaderLine> hInfo = new HashSet<>();
@ -401,6 +402,32 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
} }
/**
 * Order and validate the command line annotations against the annotations stored in the serialized model.
 *
 * The command line annotations must be exactly the annotations listed in the model report: every report
 * row has to match a distinct command line annotation and no annotation may be left over on either side.
 * On success {@code annotationOrder} maps report row {@code i} to the index of the same annotation in
 * {@code annotationKeys}.
 *
 * Comparing list sizes alone is not sufficient, because a repeated name on one side can mask a missing
 * name on the other (e.g. command line {@code [A, A]} against report {@code [A, B]}), so each command
 * line index is allowed to be matched only once.
 *
 * The lookup is a linear scan per report row (quadratic overall), which is acceptable because we
 * typically use 7 or fewer annotations.
 *
 * @param annotationTable GATKReportTable of annotations read from the serialized model file
 * @param annotationKeys  annotation names in the order they were given on the command line
 * @throws UserException.CommandLineException if the command line annotations are not a one-to-one match
 *                                            for the annotations in the model report
 */
protected void orderAndValidateAnnotations(final GATKReportTable annotationTable, final List<String> annotationKeys){
    final int numRows = annotationTable.getNumRows();
    final int numKeys = annotationKeys.size();
    annotationOrder = new ArrayList<Integer>(numKeys);

    // alreadyMatched[j] becomes true once command line annotation j has been paired with a report row
    final boolean[] alreadyMatched = new boolean[numKeys];
    boolean annotationsMatch = (numRows == numKeys);

    for (int row = 0; annotationsMatch && row < numRows; row++){
        // indexOf tolerates a null cell (it simply finds no match) instead of throwing a NullPointerException
        final int keyIndex = annotationKeys.indexOf(annotationTable.get(row, "Annotation"));
        if (keyIndex < 0 || alreadyMatched[keyIndex]) {
            // unknown annotation in the report, or the same annotation listed twice
            annotationsMatch = false;
        } else {
            alreadyMatched[keyIndex] = true;
            annotationOrder.add(keyIndex);
        }
    }

    if (!annotationsMatch) {
        final String errorMsg = "Annotations specified on the command line:"+annotationKeys.toString() +" do not match annotations in the model report:"+inputModel;
        throw new UserException.CommandLineException(errorMsg);
    }
}
//--------------------------------------------------------------------------------------------------------------- //---------------------------------------------------------------------------------------------------------------
// //
@ -518,7 +545,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
for (int i = 1; i <= max_attempts; i++) { for (int i = 1; i <= max_attempts; i++) {
try { try {
dataManager.setData(reduceSum); dataManager.setData(reduceSum);
dataManager.normalizeData(inputModel.isEmpty()); // Each data point is now (x - mean) / standard deviation dataManager.normalizeData(inputModel.isEmpty(), annotationOrder); // Each data point is now (x - mean) / standard deviation
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData(); final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
final List<VariantDatum> negativeTrainingData; final List<VariantDatum> negativeTrainingData;

View File

@ -51,6 +51,8 @@
package org.broadinstitute.gatk.tools.walkers.variantrecalibration; package org.broadinstitute.gatk.tools.walkers.variantrecalibration;
import org.broadinstitute.gatk.utils.exceptions.UserException;
import org.broadinstitute.gatk.utils.exceptions.UserException.CommandLineException;
import org.broadinstitute.gatk.utils.variant.VCIterable; import org.broadinstitute.gatk.utils.variant.VCIterable;
import org.broadinstitute.gatk.engine.walkers.WalkerTest; import org.broadinstitute.gatk.engine.walkers.WalkerTest;
import htsjdk.variant.variantcontext.VariantContext; import htsjdk.variant.variantcontext.VariantContext;
@ -60,6 +62,7 @@ import org.testng.annotations.DataProvider;
import org.testng.annotations.Test; import org.testng.annotations.Test;
import java.io.File; import java.io.File;
import java.io.IOException;
import java.util.Arrays; import java.util.Arrays;
import java.util.List; import java.util.List;
@ -390,5 +393,73 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
new File(outputFile.getAbsolutePath() + ".pdf").deleteOnExit(); new File(outputFile.getAbsolutePath() + ".pdf").deleteOnExit();
} }
} }
/**
 * Runs VariantRecalibrator twice against the same serialized model, giving the {@code -an} annotations
 * in two different orders, and expects both runs to produce output with the same MD5.
 */
@Test
public void testVQSRAnnotationOrder() throws IOException {
    final String snpVcf = privateTestDir + "oneSNP.vcf";
    final String modelReport = privateTestDir + "subsetExAC.snps_model.report";
    final String recalOut = privateTestDir + "anno_order.recal";
    final String tranchesOut = privateTestDir + "anno_order.tranches";
    final String expectedMd5 = "d41d8cd98f00b204e9800998ecf8427e";

    // Everything before the annotation arguments is common to both command lines ...
    final String commandHead = "-R " + b37KGReference +
            " -T VariantRecalibrator" +
            " -input " + snpVcf +
            " -L 1:110201699" +
            " -resource:truth=true,training=true,prior=15.0 " + snpVcf;
    // ... and so is everything after them.
    final String commandTail = " --recal_file " + recalOut +
            " -tranchesFile " + tranchesOut +
            " --input_model " + modelReport +
            " -ignoreAllFilters -mode SNP" +
            " --no_cmdline_in_header";

    // Same six annotations, two different command line orders.
    final String[] testNames = {
            "testVQSRAnnotationOrder",
            "testVQSRAnnotationOrder2"
    };
    final String[] annotationArgs = {
            " -an FS -an ReadPosRankSum -an MQ -an MQRankSum -an QD -an SOR",
            " -an ReadPosRankSum -an MQ -an MQRankSum -an QD -an SOR -an FS "
    };

    for (int run = 0; run < annotationArgs.length; run++) {
        final String command = commandHead + annotationArgs[run] + commandTail;
        final WalkerTestSpec spec = new WalkerTestSpec(command, 1, Arrays.asList(expectedMd5));
        spec.disableShadowBCF(); // TODO -- enable when we support symbolic alleles
        final List<File> outputFiles = executeTest(testNames[run], spec).getFirst();
        setPDFsForDeletion(outputFiles);
    }
}
/**
 * Supplies one annotation more on the command line (BaseQRankSum) than the six used in
 * {@code testVQSRAnnotationOrder} against the same model report, and expects the run to fail with an exception.
 */
@Test(expectedExceptions={RuntimeException.class, CommandLineException.class})
public void testVQSRAnnotationMismatch() throws IOException {
    final String snpVcf = privateTestDir + "oneSNP.vcf";
    final String modelReport = privateTestDir + "subsetExAC.snps_model.report";
    final String recalOut = privateTestDir + "anno_order.recal";
    final String tranchesOut = privateTestDir + "anno_order.tranches";
    final String expectedMd5 = "d41d8cd98f00b204e9800998ecf8427e";

    final StringBuilder command = new StringBuilder();
    command.append("-R ").append(b37KGReference);
    command.append(" -T VariantRecalibrator");
    command.append(" -input ").append(snpVcf);
    command.append(" -L 1:110201699");
    command.append(" -resource:truth=true,training=true,prior=15.0 ").append(snpVcf);
    // seven annotations: the six from the ordering test plus BaseQRankSum
    command.append(" -an FS -an ReadPosRankSum -an MQ -an MQRankSum -an QD -an SOR -an BaseQRankSum");
    command.append(" --recal_file ").append(recalOut);
    command.append(" -tranchesFile ").append(tranchesOut);
    command.append(" --input_model ").append(modelReport);
    command.append(" -ignoreAllFilters -mode SNP");
    command.append(" --no_cmdline_in_header");

    final WalkerTestSpec spec = new WalkerTestSpec(command.toString(), 1, Arrays.asList(expectedMd5));
    spec.disableShadowBCF(); // TODO -- enable when we support symbolic alleles
    executeTest("testVQSRAnnotationMismatch", spec).getFirst();
}
} }

View File

@ -276,4 +276,52 @@ public class VariantRecalibratorModelOutputUnitTest extends BaseTest {
return new GaussianMixtureModel(badGaussianList, shrinkage, dirichlet, priorCounts); return new GaussianMixtureModel(badGaussianList, shrinkage, dirichlet, priorCounts);
} }
/**
 * Checks that {@code orderAndValidateAnnotations} maps the annotation table's row order onto the
 * command line order for three cases: identical order, a rotated order, and a shuffled order.
 *
 * Assertions follow the TestNG {@code (actual, expected)} argument convention so that failure
 * messages report the right values, and the size of {@code annotationOrder} is asserted explicitly
 * so that an empty result cannot pass the element-wise loops vacuously.
 */
@Test
public void testAnnotationOrderAndValidate() {
    final VariantRecalibrator vqsr = new VariantRecalibrator();
    final List<String> annotationList = new ArrayList<>();
    annotationList.add("QD");
    annotationList.add("FS");
    annotationList.add("ReadPosRankSum");
    annotationList.add("MQ");
    annotationList.add("MQRankSum");
    annotationList.add("SOR");

    double[] meanVector = {16.13, 2.45, 0.37, 59.08, 0.14, 0.91};
    final String columnName = "Mean";
    final String formatString = "%.3f";
    GATKReportTable annotationTable = vqsr.makeVectorTable("AnnotationMeans", "Mean for each annotation, used to normalize data", annotationList, meanVector, columnName, formatString);

    // Case 1: command line order equals the table order, so the mapping is the identity.
    vqsr.orderAndValidateAnnotations(annotationTable, annotationList);
    Assert.assertEquals(vqsr.annotationOrder.size(), annotationList.size());
    for (int i = 0; i < vqsr.annotationOrder.size(); i++){
        Assert.assertEquals((int)vqsr.annotationOrder.get(i), i);
    }

    // Case 2: move QD from the front to the back of the command line list.
    // Table row 0 (QD) must now map to the last index, and every other row shifts down by one.
    annotationList.remove(0);
    annotationList.add("QD");
    vqsr.orderAndValidateAnnotations(annotationTable, annotationList);
    Assert.assertEquals(vqsr.annotationOrder.size(), annotationList.size());
    for (int i = 0; i < vqsr.annotationOrder.size(); i++) {
        if (i == 0) {
            Assert.assertEquals((int)vqsr.annotationOrder.get(i), annotationList.size()-1);
        } else {
            Assert.assertEquals((int)vqsr.annotationOrder.get(i), i - 1);
        }
    }

    // Case 3: a differently shuffled command line must resolve to the same annotation names, row by row.
    final List<String> annotationList2 = new ArrayList<>();
    annotationList2.add("ReadPosRankSum");
    annotationList2.add("MQRankSum");
    annotationList2.add("MQ");
    annotationList2.add("SOR");
    annotationList2.add("QD");
    annotationList2.add("FS");

    final VariantRecalibrator vqsr2 = new VariantRecalibrator();
    vqsr2.orderAndValidateAnnotations(annotationTable, annotationList2);
    Assert.assertEquals(vqsr2.annotationOrder.size(), vqsr.annotationOrder.size());
    for (int i = 0; i < vqsr2.annotationOrder.size(); i++){
        Assert.assertEquals(annotationList2.get(vqsr2.annotationOrder.get(i)), annotationList.get(vqsr.annotationOrder.get(i)));
    }
}
} }