From 856c1f87c10ed1f44132006dfd74b42380097627 Mon Sep 17 00:00:00 2001 From: Ryan Poplin Date: Wed, 11 Dec 2013 14:43:07 -0500 Subject: [PATCH] Allow for additional input data to be used in the VQSR for clustering but don't carry it forward into the output VCF file. -- New -a argument in the VQSR for specifying additional data to be used in the clustering -- New NA12878KB walker which creates ROC curves by partitioning the data along VQSLOD and calculating how many KB TP/FP's are called. --- .../VariantDataManager.java | 18 +++++- .../variantrecalibration/VariantDatum.java | 3 +- .../VariantRecalibrator.java | 41 +++++++++++-- .../VariantRecalibratorEngine.java | 3 + .../VariantDataManagerUnitTest.java | 35 +++++++++++ ...ntRecalibrationWalkersIntegrationTest.java | 60 ++++++++++++++++++- 6 files changed, 150 insertions(+), 10 deletions(-) diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManager.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManager.java index ac4654f73..1f355359d 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManager.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManager.java @@ -46,6 +46,7 @@ package org.broadinstitute.sting.gatk.walkers.variantrecalibration; +import it.unimi.dsi.fastutil.booleans.BooleanLists; import org.apache.commons.lang.ArrayUtils; import org.apache.log4j.Logger; import org.broadinstitute.sting.gatk.GenomeAnalysisEngine; @@ -71,7 +72,7 @@ import java.util.*; */ public class VariantDataManager { - private List data; + private List data = Collections.emptyList(); private double[] meanVector; private double[] varianceVector; // this is really the standard deviation public List annotationKeys; @@ -80,7 +81,7 @@ public class VariantDataManager { protected final List trainingSets; public VariantDataManager( final List annotationKeys, final VariantRecalibratorArgumentCollection VRAC ) { - this.data = null; + this.data = Collections.emptyList(); this.annotationKeys = new ArrayList<>( annotationKeys ); this.VRAC = VRAC; meanVector = new double[this.annotationKeys.size()]; @@ -279,6 +280,19 @@ public class VariantDataManager { return evaluationData; } + /** + * Remove all VariantDatum's from the data list which are marked as aggregate data + */ + public void dropAggregateData() { + final Iterator iter = data.iterator(); + while (iter.hasNext()) { + final VariantDatum datum = iter.next(); + if( datum.isAggregate ) { + iter.remove(); + } + } + } + public List getRandomDataForPlotting( final int numToAdd, final List trainingData, final List antiTrainingData, final List evaluationData ) { final List returnData = new ExpandingArrayList<>(); Collections.shuffle(trainingData); diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDatum.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDatum.java index 905c97df4..41b27949d 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDatum.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDatum.java @@ -74,7 +74,8 @@ public class VariantDatum { public int consensusCount; public GenomeLoc loc; public int worstAnnotation; - public MultivariateGaussian assignment; // used in K-means implementation + public MultivariateGaussian assignment; // used in K-means implementation + public boolean isAggregate; // this datum was provided to aid in modeling but isn't part of the input callset public static class VariantDatumLODComparator implements Comparator, Serializable { @Override diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java index 1c32b852b..5da7b4219 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java @@ -152,11 +152,17 @@ public class VariantRecalibrator extends RodWalker> input; + /** + * These additional calls should be unfiltered and annotated with the error covariates that are intended to be used for modeling. + */ + @Input(fullName="aggregate", shortName = "aggregate", doc="Additional raw input variants to be used in building the model", required=false) + public List> aggregate; + /** * Any set of VCF files to use as lists of training, truth, or known sites. * Training - Input variants which are found to overlap with these training sites are used to build the Gaussian mixture model. @@ -290,29 +296,53 @@ public class VariantRecalibrator extends RodWalker addOverlappingVariants( final List> rods, final boolean isInput, final RefMetaDataTracker tracker, final AlignmentContext context ) { + if( rods == null ) { throw new IllegalArgumentException("rods cannot be null."); } + if( tracker == null ) { throw new IllegalArgumentException("tracker cannot be null."); } + if( context == null ) { throw new IllegalArgumentException("context cannot be null."); } + + final ExpandingArrayList variants = new ExpandingArrayList<>(); + + for( final VariantContext vc : tracker.getValues(rods, context.getLocation()) ) { if( vc != null && ( vc.isNotFiltered() || ignoreInputFilterSet.containsAll(vc.getFilters()) ) ) { if( VariantDataManager.checkVariationClass( vc, VRAC.MODE ) ) { final VariantDatum datum = new VariantDatum(); // Populate the datum with lots of fields from the VariantContext, unfortunately the VC is too big so we just pull in only the things we absolutely need. dataManager.decodeAnnotations( datum, vc, true ); //BUGBUG: when run with HierarchicalMicroScheduler this is non-deterministic because order of calls depends on load of machine - datum.loc = getToolkit().getGenomeLocParser().createGenomeLoc(vc); + datum.loc = ( isInput ? getToolkit().getGenomeLocParser().createGenomeLoc(vc) : null ); datum.originalQual = vc.getPhredScaledQual(); datum.isSNP = vc.isSNP() && vc.isBiallelic(); datum.isTransition = datum.isSNP && GATKVariantContextUtils.isTransition(vc); + datum.isAggregate = !isInput; // Loop through the training data sets and if they overlap this loci then update the prior and training status appropriately dataManager.parseTrainingSets( tracker, context.getLocation(), vc, datum, TRUST_ALL_POLYMORPHIC ); final double priorFactor = QualityUtils.qualToProb( datum.prior ); datum.prior = Math.log10( priorFactor ) - Math.log10( 1.0 - priorFactor ); - mapList.add( datum ); + variants.add( datum ); } } } - return mapList; + return variants; } //--------------------------------------------------------------------------------------------------------------- @@ -357,6 +387,7 @@ public class VariantRecalibrator extends RodWalker negativeTrainingData = dataManager.selectWorstVariants(); final GaussianMixtureModel badModel = engine.generateModel( negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS)); + dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory engine.evaluateData( dataManager.getData(), badModel, true ); if( badModel.failedToConverge || goodModel.failedToConverge ) { diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java index 3828e6e20..dae3bffa5 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java @@ -80,6 +80,9 @@ public class VariantRecalibratorEngine { } public GaussianMixtureModel generateModel( final List data, final int maxGaussians ) { + if( data == null || data.isEmpty() ) { throw new IllegalArgumentException("No data found."); } + if( maxGaussians <= 0 ) { throw new IllegalArgumentException("maxGaussians must be a positive integer but found: " + maxGaussians); } + final GaussianMixtureModel model = new GaussianMixtureModel( maxGaussians, data.get(0).annotations.length, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER, VRAC.PRIOR_COUNTS ); variationalBayesExpectationMaximization( model, data ); return model; diff --git a/protected/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManagerUnitTest.java b/protected/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManagerUnitTest.java index 754fe30a2..9a1422608 100644 --- a/protected/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManagerUnitTest.java +++ b/protected/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManagerUnitTest.java @@ -142,4 +142,39 @@ public class VariantDataManagerUnitTest extends BaseTest { Assert.assertTrue( trainingData.size() == MAX_NUM_TRAINING_DATA ); } + + @Test + public final void testDropAggregateData() { + final int MAX_NUM_TRAINING_DATA = 5000; + final double passingQual = 400.0; + final VariantRecalibratorArgumentCollection VRAC = new VariantRecalibratorArgumentCollection(); + VRAC.MAX_NUM_TRAINING_DATA = MAX_NUM_TRAINING_DATA; + + VariantDataManager vdm = new VariantDataManager(new ArrayList(), VRAC); + final List theData = new ArrayList<>(); + for( int iii = 0; iii < MAX_NUM_TRAINING_DATA * 10; iii++) { + final VariantDatum datum = new VariantDatum(); + datum.atTrainingSite = true; + datum.isAggregate = false; + datum.failingSTDThreshold = false; + datum.originalQual = passingQual; + theData.add(datum); + } + + for( int iii = 0; iii < MAX_NUM_TRAINING_DATA * 2; iii++) { + final VariantDatum datum = new VariantDatum(); + datum.atTrainingSite = false; + datum.isAggregate = true; + datum.failingSTDThreshold = false; + datum.originalQual = passingQual; + theData.add(datum); + } + + vdm.setData(theData); + vdm.dropAggregateData(); + + for( final VariantDatum datum : vdm.getData() ) { + Assert.assertFalse( datum.isAggregate ); + } + } } diff --git a/protected/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersIntegrationTest.java b/protected/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersIntegrationTest.java index f3e57b48a..225000775 100644 --- a/protected/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersIntegrationTest.java +++ b/protected/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersIntegrationTest.java @@ -62,9 +62,11 @@ import java.util.List; public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { private static class VRTest { String inVCF; + String aggregateVCF; String tranchesMD5; String recalMD5; String cutVCFMD5; + public VRTest(String inVCF, String tranchesMD5, String recalMD5, String cutVCFMD5) { this.inVCF = inVCF; this.tranchesMD5 = tranchesMD5; @@ -72,6 +74,14 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { this.cutVCFMD5 = cutVCFMD5; } + public VRTest(String inVCF, String aggregateVCF, String tranchesMD5, String recalMD5, String cutVCFMD5) { + this.inVCF = inVCF; + this.aggregateVCF = aggregateVCF; + this.tranchesMD5 = tranchesMD5; + this.recalMD5 = recalMD5; + this.cutVCFMD5 = cutVCFMD5; + } + @Override public String toString() { return "VRTest{inVCF='" + inVCF +"'}"; @@ -83,10 +93,20 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { "73c7897441622c9b37376eb4f071c560", // recal file "11a28df79b92229bd317ac49a3ed0fa1"); // cut VCF + VRTest lowPassPlusExomes = new VRTest(validationDataLocation + "phase1.projectConsensus.chr20.raw.snps.vcf", + validationDataLocation + "1kg_exomes_unfiltered.AFR.unfiltered.vcf", + "ce4bfc6619147fe7ce1f8331bbeb86ce", // tranches + "1b33c10be7d8bf8e9accd11113835262", // recal file + "4700d52a06f2ef3a5882719b86911e51"); // cut VCF + @DataProvider(name = "VRTest") public Object[][] createData1() { return new Object[][]{ {lowPass} }; - //return new Object[][]{ {yriTrio}, {lowPass} }; // Add hg19 chr20 trio calls here + } + + @DataProvider(name = "VRAggregateTest") + public Object[][] createData2() { + return new Object[][]{ {lowPassPlusExomes} }; } @Test(dataProvider = "VRTest") @@ -125,6 +145,43 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { executeTest("testApplyRecalibration-"+params.inVCF, spec); } + @Test(dataProvider = "VRAggregateTest") + public void testVariantRecalibratorAggregate(VRTest params) { + //System.out.printf("PARAMS FOR %s is %s%n", vcf, clusterFile); + WalkerTest.WalkerTestSpec spec = new WalkerTest.WalkerTestSpec( + "-R " + b37KGReference + + " -resource:known=true,prior=10.0 " + GATKDataLocation + "dbsnp_132_b37.leftAligned.vcf" + + " -resource:truth=true,training=true,prior=15.0 " + comparisonDataLocation + "Validated/HapMap/3.3/sites_r27_nr.b37_fwd.vcf" + + " -resource:training=true,truth=true,prior=12.0 " + comparisonDataLocation + "Validated/Omni2.5_chip/Omni25_sites_1525_samples.b37.vcf" + + " -T VariantRecalibrator" + + " -input " + params.inVCF + + " -aggregate " + params.aggregateVCF + + " -L 20:1,000,000-40,000,000" + + " --no_cmdline_in_header" + + " -an QD -an HaplotypeScore -an MQ" + + " --trustAllPolymorphic" + // for speed + " -recalFile %s" + + " -tranchesFile %s", + Arrays.asList(params.recalMD5, params.tranchesMD5)); + executeTest("testVariantRecalibratorAggregate-"+params.inVCF, spec).getFirst(); + } + + @Test(dataProvider = "VRAggregateTest",dependsOnMethods="testVariantRecalibratorAggregate") + public void testApplyRecalibrationAggregate(VRTest params) { + WalkerTest.WalkerTestSpec spec = new WalkerTest.WalkerTestSpec( + "-R " + b37KGReference + + " -T ApplyRecalibration" + + " -L 20:12,000,000-30,000,000" + + " --no_cmdline_in_header" + + " -input " + params.inVCF + + " -U LENIENT_VCF_PROCESSING -o %s" + + " -tranchesFile " + getMd5DB().getMD5FilePath(params.tranchesMD5, null) + + " -recalFile " + getMd5DB().getMD5FilePath(params.recalMD5, null), + Arrays.asList(params.cutVCFMD5)); + spec.disableShadowBCF(); // TODO -- enable when we support symbolic alleles + executeTest("testApplyRecalibrationAggregate-"+params.inVCF, spec); + } + VRTest bcfTest = new VRTest(privateTestDir + "vqsr.bcf_test.snps.unfiltered.bcf", "3ad7f55fb3b072f373cbce0b32b66df4", // tranches "e747c08131d58d9a4800720f6ca80e0c", // recal file @@ -133,7 +190,6 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { @DataProvider(name = "VRBCFTest") public Object[][] createVRBCFTest() { return new Object[][]{ {bcfTest} }; - //return new Object[][]{ {yriTrio}, {lowPass} }; // Add hg19 chr20 trio calls here } @Test(dataProvider = "VRBCFTest")