Allow for additional input data to be used in the VQSR for clustering but don't carry it forward into the output VCF file.

-- New -aggregate argument in the VQSR for specifying additional data to be used in the clustering
-- New NA12878KB walker which creates ROC curves by partitioning the data along VQSLOD and calculating how many KB TP/FP's are called.
This commit is contained in:
Ryan Poplin 2013-12-11 14:43:07 -05:00
parent c82501ac35
commit 856c1f87c1
6 changed files with 150 additions and 10 deletions

View File

@ -46,6 +46,7 @@
package org.broadinstitute.sting.gatk.walkers.variantrecalibration;
import it.unimi.dsi.fastutil.booleans.BooleanLists;
import org.apache.commons.lang.ArrayUtils;
import org.apache.log4j.Logger;
import org.broadinstitute.sting.gatk.GenomeAnalysisEngine;
@ -71,7 +72,7 @@ import java.util.*;
*/
public class VariantDataManager {
private List<VariantDatum> data;
private List<VariantDatum> data = Collections.emptyList();
private double[] meanVector;
private double[] varianceVector; // this is really the standard deviation
public List<String> annotationKeys;
@ -80,7 +81,7 @@ public class VariantDataManager {
protected final List<TrainingSet> trainingSets;
public VariantDataManager( final List<String> annotationKeys, final VariantRecalibratorArgumentCollection VRAC ) {
this.data = null;
this.data = Collections.emptyList();
this.annotationKeys = new ArrayList<>( annotationKeys );
this.VRAC = VRAC;
meanVector = new double[this.annotationKeys.size()];
@ -279,6 +280,19 @@ public class VariantDataManager {
return evaluationData;
}
/**
 * Purge every VariantDatum flagged as aggregate from the managed data list.
 *
 * Aggregate datums are supplied only to improve model building and must not
 * be carried forward into the output, so they are removed in place once the
 * models no longer need them.
 */
public void dropAggregateData() {
    for( final Iterator<VariantDatum> removalIter = data.iterator(); removalIter.hasNext(); ) {
        if( removalIter.next().isAggregate ) {
            removalIter.remove();
        }
    }
}
public List<VariantDatum> getRandomDataForPlotting( final int numToAdd, final List<VariantDatum> trainingData, final List<VariantDatum> antiTrainingData, final List<VariantDatum> evaluationData ) {
final List<VariantDatum> returnData = new ExpandingArrayList<>();
Collections.shuffle(trainingData);

View File

@ -74,7 +74,8 @@ public class VariantDatum {
public int consensusCount;
public GenomeLoc loc;
public int worstAnnotation;
public MultivariateGaussian assignment; // used in K-means implementation
public MultivariateGaussian assignment; // used in K-means implementation
public boolean isAggregate; // this datum was provided to aid in modeling but isn't part of the input callset
public static class VariantDatumLODComparator implements Comparator<VariantDatum>, Serializable {
@Override

View File

@ -152,11 +152,17 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
// Inputs
/////////////////////////////
/**
* These calls should be unfiltered and annotated with the error covariates that are intended to use for modeling.
* These calls should be unfiltered and annotated with the error covariates that are intended to be used for modeling.
*/
@Input(fullName="input", shortName = "input", doc="The raw input variants to be recalibrated", required=true)
public List<RodBinding<VariantContext>> input;
/**
* These additional calls should be unfiltered and annotated with the error covariates that are intended to be used for modeling.
*/
@Input(fullName="aggregate", shortName = "aggregate", doc="Additional raw input variants to be used in building the model", required=false)
public List<RodBinding<VariantContext>> aggregate;
/**
* Any set of VCF files to use as lists of training, truth, or known sites.
* Training - Input variants which are found to overlap with these training sites are used to build the Gaussian mixture model.
@ -290,29 +296,53 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
return mapList;
}
for( final VariantContext vc : tracker.getValues(input, context.getLocation()) ) {
mapList.addAll( addOverlappingVariants(input, true, tracker, context) );
if( aggregate != null ) {
mapList.addAll( addOverlappingVariants(aggregate, false, tracker, context) );
}
return mapList;
}
/**
 * Using the RefMetaDataTracker find overlapping variants and pull out the necessary information to create the VariantDatum
 * @param rods the rods to search within
 * @param isInput is this rod an -input rod? (aggregate rods are modeled but never emitted)
 * @param tracker the RefMetaDataTracker from the RODWalker map call
 * @param context the AlignmentContext from the RODWalker map call
 * @return a list of VariantDatums, can be empty
 */
private List<VariantDatum> addOverlappingVariants( final List<RodBinding<VariantContext>> rods, final boolean isInput, final RefMetaDataTracker tracker, final AlignmentContext context ) {
    if( rods == null ) { throw new IllegalArgumentException("rods cannot be null."); }
    if( tracker == null ) { throw new IllegalArgumentException("tracker cannot be null."); }
    if( context == null ) { throw new IllegalArgumentException("context cannot be null."); }
    final ExpandingArrayList<VariantDatum> variants = new ExpandingArrayList<>();
    for( final VariantContext vc : tracker.getValues(rods, context.getLocation()) ) {
        // Skip null records and anything filtered, unless every filter present is explicitly ignored
        if( vc != null && ( vc.isNotFiltered() || ignoreInputFilterSet.containsAll(vc.getFilters()) ) ) {
            if( VariantDataManager.checkVariationClass( vc, VRAC.MODE ) ) {
                final VariantDatum datum = new VariantDatum();
                // Populate the datum with lots of fields from the VariantContext, unfortunately the VC is too big so we just pull in only the things we absolutely need.
                dataManager.decodeAnnotations( datum, vc, true ); //BUGBUG: when run with HierarchicalMicroScheduler this is non-deterministic because order of calls depends on load of machine
                // Aggregate (non -input) sites never reach the output, so no genome location is kept for them
                datum.loc = ( isInput ? getToolkit().getGenomeLocParser().createGenomeLoc(vc) : null );
                datum.originalQual = vc.getPhredScaledQual();
                datum.isSNP = vc.isSNP() && vc.isBiallelic();
                datum.isTransition = datum.isSNP && GATKVariantContextUtils.isTransition(vc);
                datum.isAggregate = !isInput;
                // Loop through the training data sets and if they overlap this loci then update the prior and training status appropriately
                dataManager.parseTrainingSets( tracker, context.getLocation(), vc, datum, TRUST_ALL_POLYMORPHIC );
                // Convert the Phred-scaled prior probability into LOD space
                final double priorFactor = QualityUtils.qualToProb( datum.prior );
                datum.prior = Math.log10( priorFactor ) - Math.log10( 1.0 - priorFactor );
                variants.add( datum );
            }
        }
    }
    return variants;
}
//---------------------------------------------------------------------------------------------------------------
@ -357,6 +387,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
// Generate the negative model using the worst performing data and evaluate each variant contrastively
final List<VariantDatum> negativeTrainingData = dataManager.selectWorstVariants();
final GaussianMixtureModel badModel = engine.generateModel( negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS));
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
engine.evaluateData( dataManager.getData(), badModel, true );
if( badModel.failedToConverge || goodModel.failedToConverge ) {

View File

@ -80,6 +80,9 @@ public class VariantRecalibratorEngine {
}
public GaussianMixtureModel generateModel( final List<VariantDatum> data, final int maxGaussians ) {
if( data == null || data.isEmpty() ) { throw new IllegalArgumentException("No data found."); }
if( maxGaussians <= 0 ) { throw new IllegalArgumentException("maxGaussians must be a positive integer but found: " + maxGaussians); }
final GaussianMixtureModel model = new GaussianMixtureModel( maxGaussians, data.get(0).annotations.length, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER, VRAC.PRIOR_COUNTS );
variationalBayesExpectationMaximization( model, data );
return model;

View File

@ -142,4 +142,39 @@ public class VariantDataManagerUnitTest extends BaseTest {
Assert.assertTrue( trainingData.size() == MAX_NUM_TRAINING_DATA );
}
@Test
public final void testDropAggregateData() {
    final int MAX_NUM_TRAINING_DATA = 5000;
    final double passingQual = 400.0;
    final VariantRecalibratorArgumentCollection VRAC = new VariantRecalibratorArgumentCollection();
    VRAC.MAX_NUM_TRAINING_DATA = MAX_NUM_TRAINING_DATA;
    VariantDataManager vdm = new VariantDataManager(new ArrayList<String>(), VRAC);

    // Build a mixed list: a large batch of regular training datums followed by aggregate-only datums
    final int numTraining = MAX_NUM_TRAINING_DATA * 10;
    final int numAggregate = MAX_NUM_TRAINING_DATA * 2;
    final List<VariantDatum> theData = new ArrayList<>();
    for( int count = 0; count < numTraining + numAggregate; count++ ) {
        final boolean aggregate = count >= numTraining;
        final VariantDatum datum = new VariantDatum();
        datum.atTrainingSite = !aggregate;
        datum.isAggregate = aggregate;
        datum.failingSTDThreshold = false;
        datum.originalQual = passingQual;
        theData.add(datum);
    }

    vdm.setData(theData);
    vdm.dropAggregateData();

    // Every aggregate datum must have been purged from the manager's data
    for( final VariantDatum datum : vdm.getData() ) {
        Assert.assertFalse( datum.isAggregate );
    }
}
}

View File

@ -62,9 +62,11 @@ import java.util.List;
public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
private static class VRTest {
String inVCF;
String aggregateVCF;
String tranchesMD5;
String recalMD5;
String cutVCFMD5;
public VRTest(String inVCF, String tranchesMD5, String recalMD5, String cutVCFMD5) {
this.inVCF = inVCF;
this.tranchesMD5 = tranchesMD5;
@ -72,6 +74,14 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
this.cutVCFMD5 = cutVCFMD5;
}
/**
 * Construct a recalibration test case that additionally supplies an aggregate
 * VCF used only for model building.
 */
public VRTest(String inVCF, String aggregateVCF, String tranchesMD5, String recalMD5, String cutVCFMD5) {
    // Delegate the common fields to the four-argument constructor
    this(inVCF, tranchesMD5, recalMD5, cutVCFMD5);
    this.aggregateVCF = aggregateVCF;
}
@Override
public String toString() {
return "VRTest{inVCF='" + inVCF +"'}";
@ -83,10 +93,20 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
"73c7897441622c9b37376eb4f071c560", // recal file
"11a28df79b92229bd317ac49a3ed0fa1"); // cut VCF
VRTest lowPassPlusExomes = new VRTest(validationDataLocation + "phase1.projectConsensus.chr20.raw.snps.vcf",
validationDataLocation + "1kg_exomes_unfiltered.AFR.unfiltered.vcf",
"ce4bfc6619147fe7ce1f8331bbeb86ce", // tranches
"1b33c10be7d8bf8e9accd11113835262", // recal file
"4700d52a06f2ef3a5882719b86911e51"); // cut VCF
@DataProvider(name = "VRTest")
public Object[][] createData1() {
    // Only the low-pass callset is exercised for now; hg19 chr20 trio calls could be added later
    final Object[][] cases = new Object[][]{ {lowPass} };
    return cases;
}
@DataProvider(name = "VRAggregateTest")
public Object[][] createData2() {
    // Single case: low-pass input augmented with exome aggregate data
    final Object[][] cases = new Object[][]{ {lowPassPlusExomes} };
    return cases;
}
@Test(dataProvider = "VRTest")
@ -125,6 +145,43 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
executeTest("testApplyRecalibration-"+params.inVCF, spec);
}
@Test(dataProvider = "VRAggregateTest")
public void testVariantRecalibratorAggregate(VRTest params) {
//System.out.printf("PARAMS FOR %s is %s%n", vcf, clusterFile);
// Runs VariantRecalibrator with both -input and the new -aggregate rod; the
// aggregate VCF contributes to model building but not to the output files.
// MD5s checked: recal file and tranches file.
WalkerTest.WalkerTestSpec spec = new WalkerTest.WalkerTestSpec(
"-R " + b37KGReference +
" -resource:known=true,prior=10.0 " + GATKDataLocation + "dbsnp_132_b37.leftAligned.vcf" +
" -resource:truth=true,training=true,prior=15.0 " + comparisonDataLocation + "Validated/HapMap/3.3/sites_r27_nr.b37_fwd.vcf" +
" -resource:training=true,truth=true,prior=12.0 " + comparisonDataLocation + "Validated/Omni2.5_chip/Omni25_sites_1525_samples.b37.vcf" +
" -T VariantRecalibrator" +
" -input " + params.inVCF +
" -aggregate " + params.aggregateVCF +
" -L 20:1,000,000-40,000,000" +
" --no_cmdline_in_header" +
" -an QD -an HaplotypeScore -an MQ" +
" --trustAllPolymorphic" + // for speed
" -recalFile %s" +
" -tranchesFile %s",
Arrays.asList(params.recalMD5, params.tranchesMD5));
executeTest("testVariantRecalibratorAggregate-"+params.inVCF, spec).getFirst();
}
@Test(dataProvider = "VRAggregateTest",dependsOnMethods="testVariantRecalibratorAggregate")
public void testApplyRecalibrationAggregate(VRTest params) {
// Applies the recalibration produced by testVariantRecalibratorAggregate
// (hence dependsOnMethods); note that only -input is passed here — the
// aggregate VCF must not appear in the final cut VCF.
WalkerTest.WalkerTestSpec spec = new WalkerTest.WalkerTestSpec(
"-R " + b37KGReference +
" -T ApplyRecalibration" +
" -L 20:12,000,000-30,000,000" +
" --no_cmdline_in_header" +
" -input " + params.inVCF +
" -U LENIENT_VCF_PROCESSING -o %s" +
// Reuse the recal/tranches files recorded in the MD5 database by the upstream test
" -tranchesFile " + getMd5DB().getMD5FilePath(params.tranchesMD5, null) +
" -recalFile " + getMd5DB().getMD5FilePath(params.recalMD5, null),
Arrays.asList(params.cutVCFMD5));
spec.disableShadowBCF(); // TODO -- enable when we support symbolic alleles
executeTest("testApplyRecalibrationAggregate-"+params.inVCF, spec);
}
VRTest bcfTest = new VRTest(privateTestDir + "vqsr.bcf_test.snps.unfiltered.bcf",
"3ad7f55fb3b072f373cbce0b32b66df4", // tranches
"e747c08131d58d9a4800720f6ca80e0c", // recal file
@ -133,7 +190,6 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
@DataProvider(name = "VRBCFTest")
public Object[][] createVRBCFTest() {
    // Single BCF-format input case
    final Object[][] cases = new Object[][]{ {bcfTest} };
    return cases;
}
@Test(dataProvider = "VRBCFTest")