Merge pull request #452 from broadinstitute/rp_vqsr_aggregate_model
Allow for additional input data to be used in the VQSR for clustering bu...
This commit is contained in:
commit
5c32ad174a
|
|
@ -46,6 +46,7 @@
|
|||
|
||||
package org.broadinstitute.sting.gatk.walkers.variantrecalibration;
|
||||
|
||||
import it.unimi.dsi.fastutil.booleans.BooleanLists;
|
||||
import org.apache.commons.lang.ArrayUtils;
|
||||
import org.apache.log4j.Logger;
|
||||
import org.broadinstitute.sting.gatk.GenomeAnalysisEngine;
|
||||
|
|
@ -71,7 +72,7 @@ import java.util.*;
|
|||
*/
|
||||
|
||||
public class VariantDataManager {
|
||||
private List<VariantDatum> data;
|
||||
private List<VariantDatum> data = Collections.emptyList();
|
||||
private double[] meanVector;
|
||||
private double[] varianceVector; // this is really the standard deviation
|
||||
public List<String> annotationKeys;
|
||||
|
|
@ -80,7 +81,7 @@ public class VariantDataManager {
|
|||
protected final List<TrainingSet> trainingSets;
|
||||
|
||||
public VariantDataManager( final List<String> annotationKeys, final VariantRecalibratorArgumentCollection VRAC ) {
|
||||
this.data = null;
|
||||
this.data = Collections.emptyList();
|
||||
this.annotationKeys = new ArrayList<>( annotationKeys );
|
||||
this.VRAC = VRAC;
|
||||
meanVector = new double[this.annotationKeys.size()];
|
||||
|
|
@ -279,6 +280,19 @@ public class VariantDataManager {
|
|||
return evaluationData;
|
||||
}
|
||||
|
||||
/**
|
||||
* Remove all VariantDatum's from the data list which are marked as aggregate data
|
||||
*/
|
||||
public void dropAggregateData() {
|
||||
final Iterator<VariantDatum> iter = data.iterator();
|
||||
while (iter.hasNext()) {
|
||||
final VariantDatum datum = iter.next();
|
||||
if( datum.isAggregate ) {
|
||||
iter.remove();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public List<VariantDatum> getRandomDataForPlotting( final int numToAdd, final List<VariantDatum> trainingData, final List<VariantDatum> antiTrainingData, final List<VariantDatum> evaluationData ) {
|
||||
final List<VariantDatum> returnData = new ExpandingArrayList<>();
|
||||
Collections.shuffle(trainingData);
|
||||
|
|
|
|||
|
|
@ -74,7 +74,8 @@ public class VariantDatum {
|
|||
public int consensusCount;
|
||||
public GenomeLoc loc;
|
||||
public int worstAnnotation;
|
||||
public MultivariateGaussian assignment; // used in K-means implementation
|
||||
public MultivariateGaussian assignment; // used in K-means implementation
|
||||
public boolean isAggregate; // this datum was provided to aid in modeling but isn't part of the input callset
|
||||
|
||||
public static class VariantDatumLODComparator implements Comparator<VariantDatum>, Serializable {
|
||||
@Override
|
||||
|
|
|
|||
|
|
@ -152,11 +152,17 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
// Inputs
|
||||
/////////////////////////////
|
||||
/**
|
||||
* These calls should be unfiltered and annotated with the error covariates that are intended to use for modeling.
|
||||
* These calls should be unfiltered and annotated with the error covariates that are intended to be used for modeling.
|
||||
*/
|
||||
@Input(fullName="input", shortName = "input", doc="The raw input variants to be recalibrated", required=true)
|
||||
public List<RodBinding<VariantContext>> input;
|
||||
|
||||
/**
|
||||
* These additional calls should be unfiltered and annotated with the error covariates that are intended to be used for modeling.
|
||||
*/
|
||||
@Input(fullName="aggregate", shortName = "aggregate", doc="Additional raw input variants to be used in building the model", required=false)
|
||||
public List<RodBinding<VariantContext>> aggregate;
|
||||
|
||||
/**
|
||||
* Any set of VCF files to use as lists of training, truth, or known sites.
|
||||
* Training - Input variants which are found to overlap with these training sites are used to build the Gaussian mixture model.
|
||||
|
|
@ -290,29 +296,53 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
return mapList;
|
||||
}
|
||||
|
||||
for( final VariantContext vc : tracker.getValues(input, context.getLocation()) ) {
|
||||
mapList.addAll( addOverlappingVariants(input, true, tracker, context) );
|
||||
if( aggregate != null ) {
|
||||
mapList.addAll( addOverlappingVariants(aggregate, false, tracker, context) );
|
||||
}
|
||||
|
||||
return mapList;
|
||||
}
|
||||
|
||||
/**
|
||||
* Using the RefMetaDataTracker find overlapping variants and pull out the necessary information to create the VariantDatum
|
||||
* @param rods the rods to search within
|
||||
* @param isInput is this rod an -input rod?
|
||||
* @param tracker the RefMetaDataTracker from the RODWalker map call
|
||||
* @param context the AlignmentContext from the RODWalker map call
|
||||
* @return a list of VariantDatums, can be empty
|
||||
*/
|
||||
private List<VariantDatum> addOverlappingVariants( final List<RodBinding<VariantContext>> rods, final boolean isInput, final RefMetaDataTracker tracker, final AlignmentContext context ) {
|
||||
if( rods == null ) { throw new IllegalArgumentException("rods cannot be null."); }
|
||||
if( tracker == null ) { throw new IllegalArgumentException("tracker cannot be null."); }
|
||||
if( context == null ) { throw new IllegalArgumentException("context cannot be null."); }
|
||||
|
||||
final ExpandingArrayList<VariantDatum> variants = new ExpandingArrayList<>();
|
||||
|
||||
for( final VariantContext vc : tracker.getValues(rods, context.getLocation()) ) {
|
||||
if( vc != null && ( vc.isNotFiltered() || ignoreInputFilterSet.containsAll(vc.getFilters()) ) ) {
|
||||
if( VariantDataManager.checkVariationClass( vc, VRAC.MODE ) ) {
|
||||
final VariantDatum datum = new VariantDatum();
|
||||
|
||||
// Populate the datum with lots of fields from the VariantContext, unfortunately the VC is too big so we just pull in only the things we absolutely need.
|
||||
dataManager.decodeAnnotations( datum, vc, true ); //BUGBUG: when run with HierarchicalMicroScheduler this is non-deterministic because order of calls depends on load of machine
|
||||
datum.loc = getToolkit().getGenomeLocParser().createGenomeLoc(vc);
|
||||
datum.loc = ( isInput ? getToolkit().getGenomeLocParser().createGenomeLoc(vc) : null );
|
||||
datum.originalQual = vc.getPhredScaledQual();
|
||||
datum.isSNP = vc.isSNP() && vc.isBiallelic();
|
||||
datum.isTransition = datum.isSNP && GATKVariantContextUtils.isTransition(vc);
|
||||
datum.isAggregate = !isInput;
|
||||
|
||||
// Loop through the training data sets and if they overlap this loci then update the prior and training status appropriately
|
||||
dataManager.parseTrainingSets( tracker, context.getLocation(), vc, datum, TRUST_ALL_POLYMORPHIC );
|
||||
final double priorFactor = QualityUtils.qualToProb( datum.prior );
|
||||
datum.prior = Math.log10( priorFactor ) - Math.log10( 1.0 - priorFactor );
|
||||
|
||||
mapList.add( datum );
|
||||
variants.add( datum );
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return mapList;
|
||||
return variants;
|
||||
}
|
||||
|
||||
//---------------------------------------------------------------------------------------------------------------
|
||||
|
|
@ -357,6 +387,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
|
|||
// Generate the negative model using the worst performing data and evaluate each variant contrastively
|
||||
final List<VariantDatum> negativeTrainingData = dataManager.selectWorstVariants();
|
||||
final GaussianMixtureModel badModel = engine.generateModel( negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS));
|
||||
dataManager.dropAggregateData(); // Don't need the aggregate data anymore so let's free up the memory
|
||||
engine.evaluateData( dataManager.getData(), badModel, true );
|
||||
|
||||
if( badModel.failedToConverge || goodModel.failedToConverge ) {
|
||||
|
|
|
|||
|
|
@ -80,6 +80,9 @@ public class VariantRecalibratorEngine {
|
|||
}
|
||||
|
||||
public GaussianMixtureModel generateModel( final List<VariantDatum> data, final int maxGaussians ) {
|
||||
if( data == null || data.isEmpty() ) { throw new IllegalArgumentException("No data found."); }
|
||||
if( maxGaussians <= 0 ) { throw new IllegalArgumentException("maxGaussians must be a positive integer but found: " + maxGaussians); }
|
||||
|
||||
final GaussianMixtureModel model = new GaussianMixtureModel( maxGaussians, data.get(0).annotations.length, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER, VRAC.PRIOR_COUNTS );
|
||||
variationalBayesExpectationMaximization( model, data );
|
||||
return model;
|
||||
|
|
|
|||
|
|
@ -142,4 +142,39 @@ public class VariantDataManagerUnitTest extends BaseTest {
|
|||
|
||||
Assert.assertTrue( trainingData.size() == MAX_NUM_TRAINING_DATA );
|
||||
}
|
||||
|
||||
@Test
|
||||
public final void testDropAggregateData() {
|
||||
final int MAX_NUM_TRAINING_DATA = 5000;
|
||||
final double passingQual = 400.0;
|
||||
final VariantRecalibratorArgumentCollection VRAC = new VariantRecalibratorArgumentCollection();
|
||||
VRAC.MAX_NUM_TRAINING_DATA = MAX_NUM_TRAINING_DATA;
|
||||
|
||||
VariantDataManager vdm = new VariantDataManager(new ArrayList<String>(), VRAC);
|
||||
final List<VariantDatum> theData = new ArrayList<>();
|
||||
for( int iii = 0; iii < MAX_NUM_TRAINING_DATA * 10; iii++) {
|
||||
final VariantDatum datum = new VariantDatum();
|
||||
datum.atTrainingSite = true;
|
||||
datum.isAggregate = false;
|
||||
datum.failingSTDThreshold = false;
|
||||
datum.originalQual = passingQual;
|
||||
theData.add(datum);
|
||||
}
|
||||
|
||||
for( int iii = 0; iii < MAX_NUM_TRAINING_DATA * 2; iii++) {
|
||||
final VariantDatum datum = new VariantDatum();
|
||||
datum.atTrainingSite = false;
|
||||
datum.isAggregate = true;
|
||||
datum.failingSTDThreshold = false;
|
||||
datum.originalQual = passingQual;
|
||||
theData.add(datum);
|
||||
}
|
||||
|
||||
vdm.setData(theData);
|
||||
vdm.dropAggregateData();
|
||||
|
||||
for( final VariantDatum datum : vdm.getData() ) {
|
||||
Assert.assertFalse( datum.isAggregate );
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
|||
|
|
@ -62,9 +62,11 @@ import java.util.List;
|
|||
public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
|
||||
private static class VRTest {
|
||||
String inVCF;
|
||||
String aggregateVCF;
|
||||
String tranchesMD5;
|
||||
String recalMD5;
|
||||
String cutVCFMD5;
|
||||
|
||||
public VRTest(String inVCF, String tranchesMD5, String recalMD5, String cutVCFMD5) {
|
||||
this.inVCF = inVCF;
|
||||
this.tranchesMD5 = tranchesMD5;
|
||||
|
|
@ -72,6 +74,14 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
|
|||
this.cutVCFMD5 = cutVCFMD5;
|
||||
}
|
||||
|
||||
/**
 * Creates a test case that has an aggregate callset next to the regular input callset.
 *
 * @param inVCF         callset passed to the walkers with -input
 * @param aggregateVCF  callset passed to VariantRecalibrator with -aggregate
 * @param tranchesMD5   expected MD5 of the tranches file
 * @param recalMD5      expected MD5 of the recal file
 * @param cutVCFMD5     expected MD5 of the VCF written by ApplyRecalibration
 */
public VRTest(String inVCF, String aggregateVCF, String tranchesMD5, String recalMD5, String cutVCFMD5) {
    // callsets
    this.aggregateVCF = aggregateVCF;
    this.inVCF = inVCF;

    // expected output checksums
    this.cutVCFMD5 = cutVCFMD5;
    this.recalMD5 = recalMD5;
    this.tranchesMD5 = tranchesMD5;
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "VRTest{inVCF='" + inVCF +"'}";
|
||||
|
|
@ -83,10 +93,20 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
|
|||
"73c7897441622c9b37376eb4f071c560", // recal file
|
||||
"11a28df79b92229bd317ac49a3ed0fa1"); // cut VCF
|
||||
|
||||
VRTest lowPassPlusExomes = new VRTest(validationDataLocation + "phase1.projectConsensus.chr20.raw.snps.vcf",
|
||||
validationDataLocation + "1kg_exomes_unfiltered.AFR.unfiltered.vcf",
|
||||
"ce4bfc6619147fe7ce1f8331bbeb86ce", // tranches
|
||||
"1b33c10be7d8bf8e9accd11113835262", // recal file
|
||||
"4700d52a06f2ef3a5882719b86911e51"); // cut VCF
|
||||
|
||||
@DataProvider(name = "VRTest")
|
||||
public Object[][] createData1() {
|
||||
return new Object[][]{ {lowPass} };
|
||||
//return new Object[][]{ {yriTrio}, {lowPass} }; // Add hg19 chr20 trio calls here
|
||||
}
|
||||
|
||||
@DataProvider(name = "VRAggregateTest")
|
||||
public Object[][] createData2() {
|
||||
return new Object[][]{ {lowPassPlusExomes} };
|
||||
}
|
||||
|
||||
@Test(dataProvider = "VRTest")
|
||||
|
|
@ -125,6 +145,43 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
|
|||
executeTest("testApplyRecalibration-"+params.inVCF, spec);
|
||||
}
|
||||
|
||||
@Test(dataProvider = "VRAggregateTest")
|
||||
public void testVariantRecalibratorAggregate(VRTest params) {
|
||||
//System.out.printf("PARAMS FOR %s is %s%n", vcf, clusterFile);
|
||||
WalkerTest.WalkerTestSpec spec = new WalkerTest.WalkerTestSpec(
|
||||
"-R " + b37KGReference +
|
||||
" -resource:known=true,prior=10.0 " + GATKDataLocation + "dbsnp_132_b37.leftAligned.vcf" +
|
||||
" -resource:truth=true,training=true,prior=15.0 " + comparisonDataLocation + "Validated/HapMap/3.3/sites_r27_nr.b37_fwd.vcf" +
|
||||
" -resource:training=true,truth=true,prior=12.0 " + comparisonDataLocation + "Validated/Omni2.5_chip/Omni25_sites_1525_samples.b37.vcf" +
|
||||
" -T VariantRecalibrator" +
|
||||
" -input " + params.inVCF +
|
||||
" -aggregate " + params.aggregateVCF +
|
||||
" -L 20:1,000,000-40,000,000" +
|
||||
" --no_cmdline_in_header" +
|
||||
" -an QD -an HaplotypeScore -an MQ" +
|
||||
" --trustAllPolymorphic" + // for speed
|
||||
" -recalFile %s" +
|
||||
" -tranchesFile %s",
|
||||
Arrays.asList(params.recalMD5, params.tranchesMD5));
|
||||
executeTest("testVariantRecalibratorAggregate-"+params.inVCF, spec).getFirst();
|
||||
}
|
||||
|
||||
@Test(dataProvider = "VRAggregateTest",dependsOnMethods="testVariantRecalibratorAggregate")
|
||||
public void testApplyRecalibrationAggregate(VRTest params) {
|
||||
WalkerTest.WalkerTestSpec spec = new WalkerTest.WalkerTestSpec(
|
||||
"-R " + b37KGReference +
|
||||
" -T ApplyRecalibration" +
|
||||
" -L 20:12,000,000-30,000,000" +
|
||||
" --no_cmdline_in_header" +
|
||||
" -input " + params.inVCF +
|
||||
" -U LENIENT_VCF_PROCESSING -o %s" +
|
||||
" -tranchesFile " + getMd5DB().getMD5FilePath(params.tranchesMD5, null) +
|
||||
" -recalFile " + getMd5DB().getMD5FilePath(params.recalMD5, null),
|
||||
Arrays.asList(params.cutVCFMD5));
|
||||
spec.disableShadowBCF(); // TODO -- enable when we support symbolic alleles
|
||||
executeTest("testApplyRecalibrationAggregate-"+params.inVCF, spec);
|
||||
}
|
||||
|
||||
VRTest bcfTest = new VRTest(privateTestDir + "vqsr.bcf_test.snps.unfiltered.bcf",
|
||||
"3ad7f55fb3b072f373cbce0b32b66df4", // tranches
|
||||
"e747c08131d58d9a4800720f6ca80e0c", // recal file
|
||||
|
|
@ -133,7 +190,6 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
|
|||
@DataProvider(name = "VRBCFTest")
|
||||
public Object[][] createVRBCFTest() {
|
||||
return new Object[][]{ {bcfTest} };
|
||||
//return new Object[][]{ {yriTrio}, {lowPass} }; // Add hg19 chr20 trio calls here
|
||||
}
|
||||
|
||||
@Test(dataProvider = "VRBCFTest")
|
||||
|
|
|
|||
Loading…
Reference in New Issue