Allow for additional input data to be used in the VQSR for clustering but don't carry it forward into the output VCF file.

-- New -aggregate argument in the VQSR for specifying additional data to be used in the clustering
-- New NA12878KB walker which creates ROC curves by partitioning the data along VQSLOD and calculating how many NA12878 knowledge-base (KB) TPs/FPs are called.
This commit is contained in:
Ryan Poplin 2013-12-11 14:43:07 -05:00
parent c82501ac35
commit 856c1f87c1
6 changed files with 150 additions and 10 deletions

View File

@ -46,6 +46,7 @@
package org.broadinstitute.sting.gatk.walkers.variantrecalibration; package org.broadinstitute.sting.gatk.walkers.variantrecalibration;
import it.unimi.dsi.fastutil.booleans.BooleanLists;
import org.apache.commons.lang.ArrayUtils; import org.apache.commons.lang.ArrayUtils;
import org.apache.log4j.Logger; import org.apache.log4j.Logger;
import org.broadinstitute.sting.gatk.GenomeAnalysisEngine; import org.broadinstitute.sting.gatk.GenomeAnalysisEngine;
@ -71,7 +72,7 @@ import java.util.*;
*/ */
public class VariantDataManager { public class VariantDataManager {
private List<VariantDatum> data; private List<VariantDatum> data = Collections.emptyList();
private double[] meanVector; private double[] meanVector;
private double[] varianceVector; // this is really the standard deviation private double[] varianceVector; // this is really the standard deviation
public List<String> annotationKeys; public List<String> annotationKeys;
@ -80,7 +81,7 @@ public class VariantDataManager {
protected final List<TrainingSet> trainingSets; protected final List<TrainingSet> trainingSets;
public VariantDataManager( final List<String> annotationKeys, final VariantRecalibratorArgumentCollection VRAC ) { public VariantDataManager( final List<String> annotationKeys, final VariantRecalibratorArgumentCollection VRAC ) {
this.data = null; this.data = Collections.emptyList();
this.annotationKeys = new ArrayList<>( annotationKeys ); this.annotationKeys = new ArrayList<>( annotationKeys );
this.VRAC = VRAC; this.VRAC = VRAC;
meanVector = new double[this.annotationKeys.size()]; meanVector = new double[this.annotationKeys.size()];
@ -279,6 +280,19 @@ public class VariantDataManager {
return evaluationData; return evaluationData;
} }
/**
 * Remove every VariantDatum flagged as aggregate from the managed data list.
 * Aggregate sites are used only for building the model and must not appear in the
 * recalibrated output, so they can be discarded once modeling is done.
 */
public void dropAggregateData() {
    for( final Iterator<VariantDatum> dataIterator = data.iterator(); dataIterator.hasNext(); ) {
        if( dataIterator.next().isAggregate ) {
            dataIterator.remove();
        }
    }
}
public List<VariantDatum> getRandomDataForPlotting( final int numToAdd, final List<VariantDatum> trainingData, final List<VariantDatum> antiTrainingData, final List<VariantDatum> evaluationData ) { public List<VariantDatum> getRandomDataForPlotting( final int numToAdd, final List<VariantDatum> trainingData, final List<VariantDatum> antiTrainingData, final List<VariantDatum> evaluationData ) {
final List<VariantDatum> returnData = new ExpandingArrayList<>(); final List<VariantDatum> returnData = new ExpandingArrayList<>();
Collections.shuffle(trainingData); Collections.shuffle(trainingData);

View File

@ -74,7 +74,8 @@ public class VariantDatum {
public int consensusCount; public int consensusCount;
public GenomeLoc loc; public GenomeLoc loc;
public int worstAnnotation; public int worstAnnotation;
public MultivariateGaussian assignment; // used in K-means implementation public MultivariateGaussian assignment; // used in K-means implementation
public boolean isAggregate; // this datum was provided to aid in modeling but isn't part of the input callset
public static class VariantDatumLODComparator implements Comparator<VariantDatum>, Serializable { public static class VariantDatumLODComparator implements Comparator<VariantDatum>, Serializable {
@Override @Override

View File

@ -152,11 +152,17 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
// Inputs // Inputs
///////////////////////////// /////////////////////////////
/** /**
* These calls should be unfiltered and annotated with the error covariates that are intended to use for modeling. * These calls should be unfiltered and annotated with the error covariates that are intended to be used for modeling.
*/ */
@Input(fullName="input", shortName = "input", doc="The raw input variants to be recalibrated", required=true) @Input(fullName="input", shortName = "input", doc="The raw input variants to be recalibrated", required=true)
public List<RodBinding<VariantContext>> input; public List<RodBinding<VariantContext>> input;
/**
* These additional calls should be unfiltered and annotated with the error covariates that are intended to be used for modeling.
*/
@Input(fullName="aggregate", shortName = "aggregate", doc="Additional raw input variants to be used in building the model", required=false)
public List<RodBinding<VariantContext>> aggregate;
/** /**
* Any set of VCF files to use as lists of training, truth, or known sites. * Any set of VCF files to use as lists of training, truth, or known sites.
* Training - Input variants which are found to overlap with these training sites are used to build the Gaussian mixture model. * Training - Input variants which are found to overlap with these training sites are used to build the Gaussian mixture model.
@ -290,29 +296,53 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
return mapList; return mapList;
} }
for( final VariantContext vc : tracker.getValues(input, context.getLocation()) ) { mapList.addAll( addOverlappingVariants(input, true, tracker, context) );
if( aggregate != null ) {
mapList.addAll( addOverlappingVariants(aggregate, false, tracker, context) );
}
return mapList;
}
/**
* Using the RefMetaDataTracker find overlapping variants and pull out the necessary information to create the VariantDatum
* @param rods the rods to search within
* @param isInput is this rod an -input rod?
* @param tracker the RefMetaDataTracker from the RODWalker map call
* @param context the AlignmentContext from the RODWalker map call
* @return a list of VariantDatums, can be empty
*/
private List<VariantDatum> addOverlappingVariants( final List<RodBinding<VariantContext>> rods, final boolean isInput, final RefMetaDataTracker tracker, final AlignmentContext context ) {
if( rods == null ) { throw new IllegalArgumentException("rods cannot be null."); }
if( tracker == null ) { throw new IllegalArgumentException("tracker cannot be null."); }
if( context == null ) { throw new IllegalArgumentException("context cannot be null."); }
final ExpandingArrayList<VariantDatum> variants = new ExpandingArrayList<>();
for( final VariantContext vc : tracker.getValues(rods, context.getLocation()) ) {
if( vc != null && ( vc.isNotFiltered() || ignoreInputFilterSet.containsAll(vc.getFilters()) ) ) { if( vc != null && ( vc.isNotFiltered() || ignoreInputFilterSet.containsAll(vc.getFilters()) ) ) {
if( VariantDataManager.checkVariationClass( vc, VRAC.MODE ) ) { if( VariantDataManager.checkVariationClass( vc, VRAC.MODE ) ) {
final VariantDatum datum = new VariantDatum(); final VariantDatum datum = new VariantDatum();
// Populate the datum with lots of fields from the VariantContext, unfortunately the VC is too big so we just pull in only the things we absolutely need. // Populate the datum with lots of fields from the VariantContext, unfortunately the VC is too big so we just pull in only the things we absolutely need.
dataManager.decodeAnnotations( datum, vc, true ); //BUGBUG: when run with HierarchicalMicroScheduler this is non-deterministic because order of calls depends on load of machine dataManager.decodeAnnotations( datum, vc, true ); //BUGBUG: when run with HierarchicalMicroScheduler this is non-deterministic because order of calls depends on load of machine
datum.loc = getToolkit().getGenomeLocParser().createGenomeLoc(vc); datum.loc = ( isInput ? getToolkit().getGenomeLocParser().createGenomeLoc(vc) : null );
datum.originalQual = vc.getPhredScaledQual(); datum.originalQual = vc.getPhredScaledQual();
datum.isSNP = vc.isSNP() && vc.isBiallelic(); datum.isSNP = vc.isSNP() && vc.isBiallelic();
datum.isTransition = datum.isSNP && GATKVariantContextUtils.isTransition(vc); datum.isTransition = datum.isSNP && GATKVariantContextUtils.isTransition(vc);
datum.isAggregate = !isInput;
// Loop through the training data sets and if they overlap this loci then update the prior and training status appropriately // Loop through the training data sets and if they overlap this loci then update the prior and training status appropriately
dataManager.parseTrainingSets( tracker, context.getLocation(), vc, datum, TRUST_ALL_POLYMORPHIC ); dataManager.parseTrainingSets( tracker, context.getLocation(), vc, datum, TRUST_ALL_POLYMORPHIC );
final double priorFactor = QualityUtils.qualToProb( datum.prior ); final double priorFactor = QualityUtils.qualToProb( datum.prior );
datum.prior = Math.log10( priorFactor ) - Math.log10( 1.0 - priorFactor ); datum.prior = Math.log10( priorFactor ) - Math.log10( 1.0 - priorFactor );
mapList.add( datum ); variants.add( datum );
} }
} }
} }
return mapList; return variants;
} }
//--------------------------------------------------------------------------------------------------------------- //---------------------------------------------------------------------------------------------------------------
@ -357,6 +387,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
// Generate the negative model using the worst performing data and evaluate each variant contrastively // Generate the negative model using the worst performing data and evaluate each variant contrastively
final List<VariantDatum> negativeTrainingData = dataManager.selectWorstVariants(); final List<VariantDatum> negativeTrainingData = dataManager.selectWorstVariants();
final GaussianMixtureModel badModel = engine.generateModel( negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS)); final GaussianMixtureModel badModel = engine.generateModel( negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS));
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
engine.evaluateData( dataManager.getData(), badModel, true ); engine.evaluateData( dataManager.getData(), badModel, true );
if( badModel.failedToConverge || goodModel.failedToConverge ) { if( badModel.failedToConverge || goodModel.failedToConverge ) {

View File

@ -80,6 +80,9 @@ public class VariantRecalibratorEngine {
} }
public GaussianMixtureModel generateModel( final List<VariantDatum> data, final int maxGaussians ) { public GaussianMixtureModel generateModel( final List<VariantDatum> data, final int maxGaussians ) {
if( data == null || data.isEmpty() ) { throw new IllegalArgumentException("No data found."); }
if( maxGaussians <= 0 ) { throw new IllegalArgumentException("maxGaussians must be a positive integer but found: " + maxGaussians); }
final GaussianMixtureModel model = new GaussianMixtureModel( maxGaussians, data.get(0).annotations.length, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER, VRAC.PRIOR_COUNTS ); final GaussianMixtureModel model = new GaussianMixtureModel( maxGaussians, data.get(0).annotations.length, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER, VRAC.PRIOR_COUNTS );
variationalBayesExpectationMaximization( model, data ); variationalBayesExpectationMaximization( model, data );
return model; return model;

View File

@ -142,4 +142,39 @@ public class VariantDataManagerUnitTest extends BaseTest {
Assert.assertTrue( trainingData.size() == MAX_NUM_TRAINING_DATA ); Assert.assertTrue( trainingData.size() == MAX_NUM_TRAINING_DATA );
} }
@Test
public final void testDropAggregateData() {
    final int MAX_NUM_TRAINING_DATA = 5000;
    final double passingQual = 400.0;

    final VariantRecalibratorArgumentCollection VRAC = new VariantRecalibratorArgumentCollection();
    VRAC.MAX_NUM_TRAINING_DATA = MAX_NUM_TRAINING_DATA;
    final VariantDataManager vdm = new VariantDataManager(new ArrayList<String>(), VRAC);

    // Build a mixed list: 10x MAX input (non-aggregate) datums followed by 2x MAX aggregate datums
    final List<VariantDatum> mixedData = new ArrayList<>();
    for( int count = 0; count < MAX_NUM_TRAINING_DATA * 10; count++ ) {
        final VariantDatum inputDatum = new VariantDatum();
        inputDatum.atTrainingSite = true;
        inputDatum.isAggregate = false;
        inputDatum.failingSTDThreshold = false;
        inputDatum.originalQual = passingQual;
        mixedData.add(inputDatum);
    }
    for( int count = 0; count < MAX_NUM_TRAINING_DATA * 2; count++ ) {
        final VariantDatum aggregateDatum = new VariantDatum();
        aggregateDatum.atTrainingSite = false;
        aggregateDatum.isAggregate = true;
        aggregateDatum.failingSTDThreshold = false;
        aggregateDatum.originalQual = passingQual;
        mixedData.add(aggregateDatum);
    }

    vdm.setData(mixedData);
    vdm.dropAggregateData();

    // After dropping, no aggregate datum may remain in the manager
    for( final VariantDatum remaining : vdm.getData() ) {
        Assert.assertFalse( remaining.isAggregate );
    }
}
} }

View File

@ -62,9 +62,11 @@ import java.util.List;
public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
private static class VRTest { private static class VRTest {
String inVCF; String inVCF;
String aggregateVCF;
String tranchesMD5; String tranchesMD5;
String recalMD5; String recalMD5;
String cutVCFMD5; String cutVCFMD5;
public VRTest(String inVCF, String tranchesMD5, String recalMD5, String cutVCFMD5) { public VRTest(String inVCF, String tranchesMD5, String recalMD5, String cutVCFMD5) {
this.inVCF = inVCF; this.inVCF = inVCF;
this.tranchesMD5 = tranchesMD5; this.tranchesMD5 = tranchesMD5;
@ -72,6 +74,14 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
this.cutVCFMD5 = cutVCFMD5; this.cutVCFMD5 = cutVCFMD5;
} }
/**
 * Construct a test fixture that also supplies an -aggregate VCF alongside the usual -input VCF.
 */
public VRTest(String inVCF, String aggregateVCF, String tranchesMD5, String recalMD5, String cutVCFMD5) {
    this.cutVCFMD5 = cutVCFMD5;
    this.recalMD5 = recalMD5;
    this.tranchesMD5 = tranchesMD5;
    this.aggregateVCF = aggregateVCF;
    this.inVCF = inVCF;
}
@Override @Override
public String toString() { public String toString() {
return "VRTest{inVCF='" + inVCF +"'}"; return "VRTest{inVCF='" + inVCF +"'}";
@ -83,10 +93,20 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
"73c7897441622c9b37376eb4f071c560", // recal file "73c7897441622c9b37376eb4f071c560", // recal file
"11a28df79b92229bd317ac49a3ed0fa1"); // cut VCF "11a28df79b92229bd317ac49a3ed0fa1"); // cut VCF
// Fixture for the -aggregate tests: low-pass chr20 project-consensus calls as -input plus
// 1000G AFR unfiltered exome calls as -aggregate, with the expected MD5s for each output file.
VRTest lowPassPlusExomes = new VRTest(validationDataLocation + "phase1.projectConsensus.chr20.raw.snps.vcf",
validationDataLocation + "1kg_exomes_unfiltered.AFR.unfiltered.vcf",
"ce4bfc6619147fe7ce1f8331bbeb86ce", // tranches
"1b33c10be7d8bf8e9accd11113835262", // recal file
"4700d52a06f2ef3a5882719b86911e51"); // cut VCF
@DataProvider(name = "VRTest") @DataProvider(name = "VRTest")
public Object[][] createData1() { public Object[][] createData1() {
return new Object[][]{ {lowPass} }; return new Object[][]{ {lowPass} };
//return new Object[][]{ {yriTrio}, {lowPass} }; // Add hg19 chr20 trio calls here }
@DataProvider(name = "VRAggregateTest")
public Object[][] createData2() {
return new Object[][]{ {lowPassPlusExomes} };
} }
@Test(dataProvider = "VRTest") @Test(dataProvider = "VRTest")
@ -125,6 +145,43 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
executeTest("testApplyRecalibration-"+params.inVCF, spec); executeTest("testApplyRecalibration-"+params.inVCF, spec);
} }
@Test(dataProvider = "VRAggregateTest")
public void testVariantRecalibratorAggregate(VRTest params) {
    // Build the VariantRecalibrator command line, feeding params.aggregateVCF in via the new -aggregate argument
    final String cmdLine =
            "-R " + b37KGReference +
            " -resource:known=true,prior=10.0 " + GATKDataLocation + "dbsnp_132_b37.leftAligned.vcf" +
            " -resource:truth=true,training=true,prior=15.0 " + comparisonDataLocation + "Validated/HapMap/3.3/sites_r27_nr.b37_fwd.vcf" +
            " -resource:training=true,truth=true,prior=12.0 " + comparisonDataLocation + "Validated/Omni2.5_chip/Omni25_sites_1525_samples.b37.vcf" +
            " -T VariantRecalibrator" +
            " -input " + params.inVCF +
            " -aggregate " + params.aggregateVCF +
            " -L 20:1,000,000-40,000,000" +
            " --no_cmdline_in_header" +
            " -an QD -an HaplotypeScore -an MQ" +
            " --trustAllPolymorphic" + // for speed
            " -recalFile %s" +
            " -tranchesFile %s";
    final WalkerTest.WalkerTestSpec spec = new WalkerTest.WalkerTestSpec(cmdLine, Arrays.asList(params.recalMD5, params.tranchesMD5));
    executeTest("testVariantRecalibratorAggregate-" + params.inVCF, spec).getFirst();
}
@Test(dataProvider = "VRAggregateTest", dependsOnMethods = "testVariantRecalibratorAggregate")
public void testApplyRecalibrationAggregate(VRTest params) {
    // Cut the callset with the tranches/recal files produced by testVariantRecalibratorAggregate,
    // looked up through the MD5 database; only -input (never -aggregate) sites appear in the output.
    final String cmdLine =
            "-R " + b37KGReference +
            " -T ApplyRecalibration" +
            " -L 20:12,000,000-30,000,000" +
            " --no_cmdline_in_header" +
            " -input " + params.inVCF +
            " -U LENIENT_VCF_PROCESSING -o %s" +
            " -tranchesFile " + getMd5DB().getMD5FilePath(params.tranchesMD5, null) +
            " -recalFile " + getMd5DB().getMD5FilePath(params.recalMD5, null);
    final WalkerTest.WalkerTestSpec spec = new WalkerTest.WalkerTestSpec(cmdLine, Arrays.asList(params.cutVCFMD5));
    spec.disableShadowBCF(); // TODO -- enable when we support symbolic alleles
    executeTest("testApplyRecalibrationAggregate-" + params.inVCF, spec);
}
VRTest bcfTest = new VRTest(privateTestDir + "vqsr.bcf_test.snps.unfiltered.bcf", VRTest bcfTest = new VRTest(privateTestDir + "vqsr.bcf_test.snps.unfiltered.bcf",
"3ad7f55fb3b072f373cbce0b32b66df4", // tranches "3ad7f55fb3b072f373cbce0b32b66df4", // tranches
"e747c08131d58d9a4800720f6ca80e0c", // recal file "e747c08131d58d9a4800720f6ca80e0c", // recal file
@ -133,7 +190,6 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
@DataProvider(name = "VRBCFTest") @DataProvider(name = "VRBCFTest")
public Object[][] createVRBCFTest() { public Object[][] createVRBCFTest() {
return new Object[][]{ {bcfTest} }; return new Object[][]{ {bcfTest} };
//return new Object[][]{ {yriTrio}, {lowPass} }; // Add hg19 chr20 trio calls here
} }
@Test(dataProvider = "VRBCFTest") @Test(dataProvider = "VRBCFTest")