Serialized GMM no longer depends on command line annotation order (#1632)

Order annotations by the order in the model report.
This commit is contained in:
Samuel Friedman 2017-10-24 11:50:03 -04:00 committed by GitHub
parent 09b4cf70f6
commit a2f45944f3
4 changed files with 168 additions and 10 deletions

View File

@ -110,7 +110,15 @@ public class VariantDataManager {
return data;
}
public void normalizeData(final boolean calculateMeans) {
/**
* Normalize annotations to mean 0 and standard deviation 1.
* Order the variant annotations by the provided list {@code theOrder} or standard deviation.
*
* @param calculateMeans Boolean indicating whether or not to calculate the means
* @param theOrder a list of integers specifying the desired annotation order. If this is null
* annotations will get sorted in decreasing size of their standard deviations.
*/
public void normalizeData(final boolean calculateMeans, List<Integer> theOrder) {
boolean foundZeroVarianceAnnotation = false;
for( int iii = 0; iii < meanVector.length; iii++ ) {
final double theMean, theSTD;
@ -150,7 +158,10 @@ public class VariantDataManager {
// re-order the data by increasing standard deviation so that the results don't depend on the order things were specified on the command line
// standard deviation over the training points is used as a simple proxy for information content, perhaps there is a better thing to use here
final List<Integer> theOrder = calculateSortOrder(meanVector);
// or use the serialized report's annotation order via the argument theOrder
if (theOrder == null){
theOrder = calculateSortOrder(meanVector);
}
annotationKeys = reorderList(annotationKeys, theOrder);
varianceVector = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(varianceVector), theOrder));
meanVector = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(meanVector), theOrder));
@ -158,7 +169,8 @@ public class VariantDataManager {
datum.annotations = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(datum.annotations), theOrder));
datum.isNull = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(datum.isNull), theOrder));
}
logger.info("Annotations are now ordered by their information content: " + annotationKeys.toString());
logger.info("Annotation order is: " + annotationKeys.toString());
}
public double[] getMeanVector() {

View File

@ -51,6 +51,7 @@
package org.broadinstitute.gatk.tools.walkers.variantrecalibration;
import com.google.common.annotations.VisibleForTesting;
import htsjdk.variant.variantcontext.Allele;
import org.broadinstitute.gatk.utils.commandline.*;
import org.broadinstitute.gatk.engine.CommandLineGATK;
@ -312,6 +313,9 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
@Argument(fullName = "trustAllPolymorphic", shortName = "allPoly", doc = "Trust that all the input training sets' unfiltered records contain only polymorphic sites to drastically speed up the computation.", required = false)
protected Boolean TRUST_ALL_POLYMORPHIC = false;
@VisibleForTesting
protected List<Integer> annotationOrder = null;
/////////////////////////////
// Private Member Variables
/////////////////////////////
@ -372,18 +376,15 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
final GATKReportTable pPMixTable = reportIn.getTable("GoodGaussianPMix");
final GATKReportTable anMeansTable = reportIn.getTable("AnnotationMeans");
final GATKReportTable anStDevsTable = reportIn.getTable("AnnotationStdevs");
final int numAnnotations = dataManager.annotationKeys.size();
if( numAnnotations != pmmTable.getNumColumns()-1 || numAnnotations != nmmTable.getNumColumns()-1 ) { // -1 because the first column is the gaussian number.
throw new UserException.CommandLineException( "Annotations specified on the command line do not match annotations in the model report." );
}
orderAndValidateAnnotations(anMeansTable, dataManager.annotationKeys);
final Map<String, Double> anMeans = getMapFromVectorTable(anMeansTable);
final Map<String, Double> anStdDevs = getMapFromVectorTable(anStDevsTable);
dataManager.setNormalization(anMeans, anStdDevs);
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations);
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations);
goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, annotationOrder.size());
badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, annotationOrder.size());
}
final Set<VCFHeaderLine> hInfo = new HashSet<>();
@ -401,6 +402,32 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
}
/**
 * Order and validate annotations according to the annotations in the serialized model.
 * Annotations on the command line must be the same as those in the model report or this will throw an exception.
 * On success, sets the {@code annotationOrder} list so that entry {@code i} is the index within
 * {@code annotationKeys} of the annotation named in row {@code i} of the model report.
 * On failure {@code annotationOrder} is left untouched, so a half-built mapping is never exposed.
 *
 * The mapping must be one-to-one: the report and the command line must have the same number of
 * annotations, every report row must name a command line annotation, and no command line annotation
 * may be matched twice. Comparing list sizes alone is not enough, because a repeated annotation on
 * one side can mask a missing annotation on the other.
 *
 * The lookup is a linear scan per report row; this is quadratic, which is acceptable because we
 * typically use 7 or fewer annotations.
 *
 * @param annotationTable GATKReportTable of annotations read from the serialized model file
 * @param annotationKeys  annotation names in the order they were given on the command line
 * @throws UserException.CommandLineException if the two sets of annotations do not match one-to-one
 */
protected void orderAndValidateAnnotations(final GATKReportTable annotationTable, final List<String> annotationKeys){
    final int numKeys = annotationKeys.size();
    final int numRows = annotationTable.getNumRows();
    final List<Integer> order = new ArrayList<>(numKeys);
    // claimed[j] becomes true once command line annotation j has been paired with a report row
    final boolean[] claimed = new boolean[numKeys];
    boolean matches = (numRows == numKeys);
    for (int i = 0; matches && i < numRows; i++){
        final Object serialAnno = annotationTable.get(i, "Annotation");
        final int j = annotationKeys.indexOf(serialAnno);
        if (j < 0 || claimed[j]) {
            // report annotation is absent from the command line, or is listed more than once in the report
            matches = false;
        } else {
            claimed[j] = true;
            order.add(j);
        }
    }
    if (!matches) {
        final String errorMsg = "Annotations specified on the command line:"+annotationKeys.toString() +" do not match annotations in the model report:"+inputModel;
        throw new UserException.CommandLineException(errorMsg);
    }
    annotationOrder = order;
}
//---------------------------------------------------------------------------------------------------------------
//
@ -518,7 +545,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
for (int i = 1; i <= max_attempts; i++) {
try {
dataManager.setData(reduceSum);
dataManager.normalizeData(inputModel.isEmpty()); // Each data point is now (x - mean) / standard deviation
dataManager.normalizeData(inputModel.isEmpty(), annotationOrder); // Each data point is now (x - mean) / standard deviation
final List<VariantDatum> positiveTrainingData = dataManager.getTrainingData();
final List<VariantDatum> negativeTrainingData;

View File

@ -51,6 +51,8 @@
package org.broadinstitute.gatk.tools.walkers.variantrecalibration;
import org.broadinstitute.gatk.utils.exceptions.UserException;
import org.broadinstitute.gatk.utils.exceptions.UserException.CommandLineException;
import org.broadinstitute.gatk.utils.variant.VCIterable;
import org.broadinstitute.gatk.engine.walkers.WalkerTest;
import htsjdk.variant.variantcontext.VariantContext;
@ -60,6 +62,7 @@ import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import java.io.File;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
@ -390,5 +393,73 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
new File(outputFile.getAbsolutePath() + ".pdf").deleteOnExit();
}
}
/**
 * Runs VariantRecalibrator twice with the same input, the same {@code --input_model} report and the
 * same six {@code -an} annotations, changing only the order in which the annotations are listed.
 * Both runs are checked against the same expected MD5.
 */
@Test
public void testVQSRAnnotationOrder() throws IOException {
    final String inputFile = privateTestDir + "oneSNP.vcf";
    final String exacModelReportFilename = privateTestDir + "subsetExAC.snps_model.report";
    final String annoOrderRecal = privateTestDir + "anno_order.recal";
    final String annoOrderTranches = privateTestDir + "anno_order.tranches";
    final String goodMd5 = "d41d8cd98f00b204e9800998ecf8427e";

    // Arguments that come before the -an list; identical for both runs.
    final String argsBeforeAnnotations = "-R " + b37KGReference +
            " -T VariantRecalibrator" +
            " -input " + inputFile +
            " -L 1:110201699" +
            " -resource:truth=true,training=true,prior=15.0 " + inputFile;
    // Arguments that come after the -an list; identical for both runs.
    final String argsAfterAnnotations = " --recal_file " + annoOrderRecal +
            " -tranchesFile " + annoOrderTranches +
            " --input_model " + exacModelReportFilename +
            " -ignoreAllFilters -mode SNP" +
            " --no_cmdline_in_header";

    // Each entry is { test name, annotation arguments }; the second entry moves FS from first to last.
    final String[][] runs = {
            { "testVQSRAnnotationOrder", " -an FS -an ReadPosRankSum -an MQ -an MQRankSum -an QD -an SOR" },
            { "testVQSRAnnotationOrder2", " -an ReadPosRankSum -an MQ -an MQRankSum -an QD -an SOR -an FS " }
    };

    for (final String[] run : runs) {
        final String commandLine = argsBeforeAnnotations + run[1] + argsAfterAnnotations;
        final WalkerTestSpec spec = new WalkerTestSpec(commandLine, 1, Arrays.asList(goodMd5));
        spec.disableShadowBCF(); // TODO -- enable when we support symbolic alleles
        final List<File> outputFiles = executeTest(run[0], spec).getFirst();
        setPDFsForDeletion(outputFiles);
    }
}
/**
 * Runs VariantRecalibrator with an {@code --input_model} report while passing seven {@code -an}
 * annotations (the six used by {@code testVQSRAnnotationOrder} plus BaseQRankSum) and expects the
 * run to fail with one of the declared exception types.
 */
@Test(expectedExceptions={RuntimeException.class, CommandLineException.class})
public void testVQSRAnnotationMismatch() throws IOException {
    final String inputFile = privateTestDir + "oneSNP.vcf";
    final String exacModelReportFilename = privateTestDir + "subsetExAC.snps_model.report";
    final String annoOrderRecal = privateTestDir + "anno_order.recal";
    final String annoOrderTranches = privateTestDir + "anno_order.tranches";
    final String goodMd5 = "d41d8cd98f00b204e9800998ecf8427e";

    final StringBuilder commandLine = new StringBuilder();
    commandLine.append("-R ").append(b37KGReference);
    commandLine.append(" -T VariantRecalibrator");
    commandLine.append(" -input ").append(inputFile);
    commandLine.append(" -L 1:110201699");
    commandLine.append(" -resource:truth=true,training=true,prior=15.0 ").append(inputFile);
    // BaseQRankSum is the extra annotation relative to the six listed by testVQSRAnnotationOrder
    commandLine.append(" -an FS -an ReadPosRankSum -an MQ -an MQRankSum -an QD -an SOR -an BaseQRankSum");
    commandLine.append(" --recal_file ").append(annoOrderRecal);
    commandLine.append(" -tranchesFile ").append(annoOrderTranches);
    commandLine.append(" --input_model ").append(exacModelReportFilename);
    commandLine.append(" -ignoreAllFilters -mode SNP");
    commandLine.append(" --no_cmdline_in_header");

    final WalkerTestSpec mismatchSpec = new WalkerTestSpec(commandLine.toString(), 1, Arrays.asList(goodMd5));
    mismatchSpec.disableShadowBCF(); // TODO -- enable when we support symbolic alleles
    executeTest("testVQSRAnnotationMismatch", mismatchSpec).getFirst();
}
}

View File

@ -276,4 +276,52 @@ public class VariantRecalibratorModelOutputUnitTest extends BaseTest {
return new GaussianMixtureModel(badGaussianList, shrinkage, dirichlet, priorCounts);
}
/**
 * Checks the index mapping produced by {@code orderAndValidateAnnotations} for three command line
 * orderings against one annotation table:
 * the same order as the table (expects the identity mapping), the same list with its first entry
 * moved to the end, and an unrelated permutation (expects it to resolve to the same annotation
 * names, position by position, as the second ordering).
 */
@Test
public void testAnnotationOrderAndValidate() {
    final VariantRecalibrator vqsr = new VariantRecalibrator();

    final List<String> annotationList = new ArrayList<>();
    annotationList.add("QD");
    annotationList.add("FS");
    annotationList.add("ReadPosRankSum");
    annotationList.add("MQ");
    annotationList.add("MQRankSum");
    annotationList.add("SOR");

    // The table is built from the list in its initial order, before the list is rotated below.
    final double[] meanVector = {16.13, 2.45, 0.37, 59.08, 0.14, 0.91};
    final GATKReportTable annotationTable = vqsr.makeVectorTable(
            "AnnotationMeans",
            "Mean for each annotation, used to normalize data",
            annotationList,
            meanVector,
            "Mean",
            "%.3f");

    // Case 1: command line order equals table order.
    vqsr.orderAndValidateAnnotations(annotationTable, annotationList);
    for (int pos = 0; pos < vqsr.annotationOrder.size(); pos++) {
        final int mapped = vqsr.annotationOrder.get(pos);
        Assert.assertEquals(pos, mapped);
    }

    // Case 2: move the first annotation (QD) to the end of the command line list.
    annotationList.remove(0);
    annotationList.add("QD");
    vqsr.orderAndValidateAnnotations(annotationTable, annotationList);
    for (int pos = 0; pos < vqsr.annotationOrder.size(); pos++) {
        // the first table row now maps to the last list slot; every other row shifts down by one
        final int shifted = (pos == 0) ? annotationList.size() - 1 : pos - 1;
        final int mapped = vqsr.annotationOrder.get(pos);
        Assert.assertEquals(shifted, mapped);
    }

    // Case 3: a different permutation, resolved by a fresh VariantRecalibrator.
    final List<String> annotationList2 = new ArrayList<>();
    annotationList2.add("ReadPosRankSum");
    annotationList2.add("MQRankSum");
    annotationList2.add("MQ");
    annotationList2.add("SOR");
    annotationList2.add("QD");
    annotationList2.add("FS");
    final VariantRecalibrator vqsr2 = new VariantRecalibrator();
    vqsr2.orderAndValidateAnnotations(annotationTable, annotationList2);
    for (int pos = 0; pos < vqsr2.annotationOrder.size(); pos++) {
        final String viaRotatedList = annotationList.get(vqsr.annotationOrder.get(pos));
        final String viaPermutedList = annotationList2.get(vqsr2.annotationOrder.get(pos));
        Assert.assertEquals(viaRotatedList, viaPermutedList);
    }
}
}