From a2f45944f3eb4953e0ccf92c8661e34714465c24 Mon Sep 17 00:00:00 2001 From: Samuel Friedman Date: Tue, 24 Oct 2017 11:50:03 -0400 Subject: [PATCH] Serialized GMM no longer depends on command line annotation order (#1632) Order annotations by the order in the model report. --- .../VariantDataManager.java | 18 ++++- .../VariantRecalibrator.java | 41 +++++++++-- ...ntRecalibrationWalkersIntegrationTest.java | 71 +++++++++++++++++++ ...ariantRecalibratorModelOutputUnitTest.java | 48 +++++++++++++ 4 files changed, 168 insertions(+), 10 deletions(-) diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantDataManager.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantDataManager.java index 1df3bc321..166d145f0 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantDataManager.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantDataManager.java @@ -110,7 +110,15 @@ public class VariantDataManager { return data; } - public void normalizeData(final boolean calculateMeans) { + /** + * Normalize annotations to mean 0 and standard deviation 1. + * Order the variant annotations by the provided list {@code theOrder} or standard deviation. + * + * @param calculateMeans Boolean indicating whether or not to calculate the means + * @param theOrder a list of integers specifying the desired annotation order. If this is null + * annotations will get sorted in decreasing size of their standard deviations. + */ + public void normalizeData(final boolean calculateMeans, List theOrder) { boolean foundZeroVarianceAnnotation = false; for( int iii = 0; iii < meanVector.length; iii++ ) { final double theMean, theSTD; @@ -150,7 +158,10 @@ public class VariantDataManager { // re-order the data by increasing standard deviation so that the results don't depend on the order things were specified on the command line // standard deviation over the training points is used as a simple proxy for information content, perhaps there is a better thing to use here - final List theOrder = calculateSortOrder(meanVector); + // or use the serialized report's annotation order via the argument theOrder + if (theOrder == null){ + theOrder = calculateSortOrder(meanVector); + } annotationKeys = reorderList(annotationKeys, theOrder); varianceVector = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(varianceVector), theOrder)); meanVector = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(meanVector), theOrder)); @@ -158,7 +169,8 @@ public class VariantDataManager { datum.annotations = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(datum.annotations), theOrder)); datum.isNull = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(datum.isNull), theOrder)); } - logger.info("Annotations are now ordered by their information content: " + annotationKeys.toString()); + logger.info("Annotation order is: " + annotationKeys.toString()); + } public double[] getMeanVector() { diff --git a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java index 4a03b568d..5bd1fabf2 100644 --- a/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java +++ b/protected/gatk-tools-protected/src/main/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrator.java @@ -51,6 +51,7 @@ package org.broadinstitute.gatk.tools.walkers.variantrecalibration; +import com.google.common.annotations.VisibleForTesting; import htsjdk.variant.variantcontext.Allele; import org.broadinstitute.gatk.utils.commandline.*; import org.broadinstitute.gatk.engine.CommandLineGATK; @@ -312,6 +313,9 @@ public class VariantRecalibrator extends RodWalker annotationOrder = null; + ///////////////////////////// // Private Member Variables ///////////////////////////// @@ -372,18 +376,15 @@ public class VariantRecalibrator extends RodWalker anMeans = getMapFromVectorTable(anMeansTable); final Map anStdDevs = getMapFromVectorTable(anStDevsTable); dataManager.setNormalization(anMeans, anStdDevs); - goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, numAnnotations); - badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, numAnnotations); + goodModel = GMMFromTables(pmmTable, pmcTable, pPMixTable, annotationOrder.size()); + badModel = GMMFromTables(nmmTable, nmcTable, nPMixTable, annotationOrder.size()); } final Set hInfo = new HashSet<>(); @@ -401,6 +402,32 @@ public class VariantRecalibrator extends RodWalker annotationKeys){ + annotationOrder = new ArrayList(annotationKeys.size()); + + for (int i = 0; i < annotationTable.getNumRows(); i++){ + String serialAnno = (String)annotationTable.get(i, "Annotation"); + for (int j = 0; j < annotationKeys.size(); j++) { + if (serialAnno.equals( annotationKeys.get(j) )){ + annotationOrder.add(j); + } + } + } + + if(annotationOrder.size() != annotationTable.getNumRows() || annotationOrder.size() != annotationKeys.size()) { + final String errorMsg = "Annotations specified on the command line:"+annotationKeys.toString() +" do not match annotations in the model report:"+inputModel; + throw new UserException.CommandLineException(errorMsg); + } + + } + //--------------------------------------------------------------------------------------------------------------- // @@ -518,7 +545,7 @@ public class VariantRecalibrator extends RodWalker positiveTrainingData = dataManager.getTrainingData(); final List negativeTrainingData; diff --git a/protected/gatk-tools-protected/src/test/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrationWalkersIntegrationTest.java b/protected/gatk-tools-protected/src/test/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrationWalkersIntegrationTest.java index 5231b7e77..7870fcfdc 100644 --- a/protected/gatk-tools-protected/src/test/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrationWalkersIntegrationTest.java +++ b/protected/gatk-tools-protected/src/test/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibrationWalkersIntegrationTest.java @@ -51,6 +51,8 @@ package org.broadinstitute.gatk.tools.walkers.variantrecalibration; +import org.broadinstitute.gatk.utils.exceptions.UserException; +import org.broadinstitute.gatk.utils.exceptions.UserException.CommandLineException; import org.broadinstitute.gatk.utils.variant.VCIterable; import org.broadinstitute.gatk.engine.walkers.WalkerTest; import htsjdk.variant.variantcontext.VariantContext; @@ -60,6 +62,7 @@ import org.testng.annotations.DataProvider; import org.testng.annotations.Test; import java.io.File; +import java.io.IOException; import java.util.Arrays; import java.util.List; @@ -390,5 +393,73 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { new File(outputFile.getAbsolutePath() + ".pdf").deleteOnExit(); } } + + @Test + public void testVQSRAnnotationOrder() throws IOException { + final String inputFile = privateTestDir + "oneSNP.vcf"; + final String exacModelReportFilename = privateTestDir + "subsetExAC.snps_model.report"; + final String annoOrderRecal = privateTestDir + "anno_order.recal"; + final String annoOrderTranches = privateTestDir + "anno_order.tranches"; + final String goodMd5 = "d41d8cd98f00b204e9800998ecf8427e"; + final String base = "-R " + b37KGReference + + " -T VariantRecalibrator" + + " -input " + inputFile + + " -L 1:110201699" + + " -resource:truth=true,training=true,prior=15.0 " + inputFile + + " -an FS -an ReadPosRankSum -an MQ -an MQRankSum -an QD -an SOR"+ + " --recal_file " + annoOrderRecal + + " -tranchesFile " + annoOrderTranches + + " --input_model " + exacModelReportFilename + + " -ignoreAllFilters -mode SNP" + + " --no_cmdline_in_header" ; + + final WalkerTestSpec spec = new WalkerTestSpec(base, 1, Arrays.asList(goodMd5)); + spec.disableShadowBCF(); // TODO -- enable when we support symbolic alleles + + List outputFiles = executeTest("testVQSRAnnotationOrder", spec).getFirst(); + setPDFsForDeletion(outputFiles); + + + final String base2 = "-R " + b37KGReference + + " -T VariantRecalibrator" + + " -input " + inputFile + + " -L 1:110201699" + + " -resource:truth=true,training=true,prior=15.0 " + inputFile + + " -an ReadPosRankSum -an MQ -an MQRankSum -an QD -an SOR -an FS "+ + " --recal_file " + annoOrderRecal + + " -tranchesFile " + annoOrderTranches + + " --input_model " + exacModelReportFilename + + " -ignoreAllFilters -mode SNP" + + " --no_cmdline_in_header" ; + + final WalkerTestSpec spec2 = new WalkerTestSpec(base2, 1, Arrays.asList(goodMd5)); + spec2.disableShadowBCF(); // TODO -- enable when we support symbolic alleles + outputFiles = executeTest("testVQSRAnnotationOrder2", spec2).getFirst(); + setPDFsForDeletion(outputFiles); + } + + @Test(expectedExceptions={RuntimeException.class, CommandLineException.class}) + public void testVQSRAnnotationMismatch() throws IOException { + final String inputFile = privateTestDir + "oneSNP.vcf"; + final String exacModelReportFilename = privateTestDir + "subsetExAC.snps_model.report"; + final String annoOrderRecal = privateTestDir + "anno_order.recal"; + final String annoOrderTranches = privateTestDir + "anno_order.tranches"; + final String goodMd5 = "d41d8cd98f00b204e9800998ecf8427e"; + final String base = "-R " + b37KGReference + + " -T VariantRecalibrator" + + " -input " + inputFile + + " -L 1:110201699" + + " -resource:truth=true,training=true,prior=15.0 " + inputFile + + " -an FS -an ReadPosRankSum -an MQ -an MQRankSum -an QD -an SOR -an BaseQRankSum"+ + " --recal_file " + annoOrderRecal + + " -tranchesFile " + annoOrderTranches + + " --input_model " + exacModelReportFilename + + " -ignoreAllFilters -mode SNP" + + " --no_cmdline_in_header" ; + + final WalkerTestSpec spec = new WalkerTestSpec(base, 1, Arrays.asList(goodMd5)); + spec.disableShadowBCF(); // TODO -- enable when we support symbolic alleles + executeTest("testVQSRAnnotationMismatch", spec).getFirst(); + } } diff --git a/protected/gatk-tools-protected/src/test/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibratorModelOutputUnitTest.java b/protected/gatk-tools-protected/src/test/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibratorModelOutputUnitTest.java index e24fdac2b..b88c16c2b 100644 --- a/protected/gatk-tools-protected/src/test/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibratorModelOutputUnitTest.java +++ b/protected/gatk-tools-protected/src/test/java/org/broadinstitute/gatk/tools/walkers/variantrecalibration/VariantRecalibratorModelOutputUnitTest.java @@ -276,4 +276,52 @@ public class VariantRecalibratorModelOutputUnitTest extends BaseTest { return new GaussianMixtureModel(badGaussianList, shrinkage, dirichlet, priorCounts); } + @Test + public void testAnnotationOrderAndValidate() { + final VariantRecalibrator vqsr = new VariantRecalibrator(); + final List annotationList = new ArrayList<>(); + annotationList.add("QD"); + annotationList.add("FS"); + annotationList.add("ReadPosRankSum"); + annotationList.add("MQ"); + annotationList.add("MQRankSum"); + annotationList.add("SOR"); + + double[] meanVector = {16.13, 2.45, 0.37, 59.08, 0.14, 0.91}; + final String columnName = "Mean"; + final String formatString = "%.3f"; + GATKReportTable annotationTable = vqsr.makeVectorTable("AnnotationMeans", "Mean for each annotation, used to normalize data", annotationList, meanVector, columnName, formatString); + vqsr.orderAndValidateAnnotations(annotationTable, annotationList); + + for (int i = 0; i < vqsr.annotationOrder.size(); i++){ + Assert.assertEquals(i, (int)vqsr.annotationOrder.get(i)); + } + + annotationList.remove(0); + annotationList.add("QD"); + vqsr.orderAndValidateAnnotations(annotationTable, annotationList); + for (int i = 0; i < vqsr.annotationOrder.size(); i++) { + if (i == 0) { + Assert.assertEquals(annotationList.size()-1, (int)vqsr.annotationOrder.get(i)); + } else { + Assert.assertEquals(i - 1, (int)vqsr.annotationOrder.get(i)); + } + } + + final List annotationList2 = new ArrayList<>(); + annotationList2.add("ReadPosRankSum"); + annotationList2.add("MQRankSum"); + annotationList2.add("MQ"); + annotationList2.add("SOR"); + annotationList2.add("QD"); + annotationList2.add("FS"); + + final VariantRecalibrator vqsr2 = new VariantRecalibrator(); + vqsr2.orderAndValidateAnnotations(annotationTable, annotationList2); + for (int i = 0; i < vqsr2.annotationOrder.size(); i++){ + Assert.assertEquals(annotationList.get(vqsr.annotationOrder.get(i)), annotationList2.get(vqsr2.annotationOrder.get(i))); + } + + } + } \ No newline at end of file