From fdc2ebb321abe0da2ea20b13b5250d016740213e Mon Sep 17 00:00:00 2001 From: Ryan Poplin Date: Sat, 2 Jul 2011 17:15:13 -0400 Subject: [PATCH] Adding ability to specify in VQSR a list of bad sites to use when training the negative model. Just add bad=true to the list of rod tags for your bad sites track. --- .../variantrecalibration/TrainingSet.java | 12 ++++- .../VariantDataManager.java | 44 ++++++++++++------- .../variantrecalibration/VariantDatum.java | 2 +- .../VariantRecalibrator.java | 4 +- 4 files changed, 42 insertions(+), 20 deletions(-) diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/TrainingSet.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/TrainingSet.java index f3677421e..9bbcf395a 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/TrainingSet.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/TrainingSet.java @@ -14,6 +14,7 @@ public class TrainingSet { public String name; public boolean isKnown = false; public boolean isTraining = false; + public boolean isAntiTraining = false; public boolean isTruth = false; public boolean isConsensus = false; public double prior = 0.0; @@ -22,17 +23,24 @@ public class TrainingSet { public TrainingSet( final String name, final Tags tags ) { this.name = name; + + // Parse the tags to decide which tracks have which properties if( tags != null ) { isKnown = tags.containsKey("known") && tags.getValue("known").equals("true"); isTraining = tags.containsKey("training") && tags.getValue("training").equals("true"); + isAntiTraining = tags.containsKey("bad") && tags.getValue("bad").equals("true"); isTruth = tags.containsKey("truth") && tags.getValue("truth").equals("true"); isConsensus = tags.containsKey("consensus") && tags.getValue("consensus").equals("true"); prior = ( tags.containsKey("prior") ? Double.parseDouble(tags.getValue("prior")) : prior ); } - if( !isConsensus ) { + + // Report back to the user which tracks were found and the properties that were detected + if( !isConsensus && !isAntiTraining ) { logger.info( String.format( "Found %s track: \tKnown = %s \tTraining = %s \tTruth = %s \tPrior = Q%.1f", this.name, isKnown, isTraining, isTruth, prior) ); - } else { + } else if( isConsensus ) { logger.info( String.format( "Found consensus track: %s", this.name) ); + } else { + logger.info( String.format( "Found bad sites training track: %s", this.name) ); } } } diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManager.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManager.java index efff5b780..cd739f9fc 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManager.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManager.java @@ -84,7 +84,6 @@ public class VariantDataManager { remove = remove || (Math.abs(val) > VRAC.STD_THRESHOLD); } datum.failingSTDThreshold = remove; - datum.usedForTraining = 0; } } @@ -118,38 +117,47 @@ public class VariantDataManager { for( final VariantDatum datum : data ) { if( datum.atTrainingSite && !datum.failingSTDThreshold && datum.originalQual > VRAC.QUAL_THRESHOLD ) { trainingData.add( datum ); - datum.usedForTraining = 1; } } logger.info( "Training with " + trainingData.size() + " variants after standard deviation thresholding." ); if( trainingData.size() < VRAC.MIN_NUM_BAD_VARIANTS ) { - logger.warn("WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable."); + logger.warn( "WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable." ); } return trainingData; } public ExpandingArrayList selectWorstVariants( double bottomPercentage, final int minimumNumber ) { - Collections.sort( data ); + // The return value is the list of training variants final ExpandingArrayList trainingData = new ExpandingArrayList(); - final int numToAdd = Math.max( minimumNumber, Math.round((float)bottomPercentage * data.size()) ); - if( numToAdd > data.size() ) { - throw new UserException.BadInput("Error during negative model training. Minimum number of variants to use in training is larger than the whole call set. One can attempt to lower the --minNumBadVariants arugment but this is unsafe."); + + // First add to the training list all sites overlapping any bad sites training tracks + for( final VariantDatum datum : data ) { + if( datum.atAntiTrainingSite && !datum.failingSTDThreshold && !Double.isInfinite(datum.lod) ) { + trainingData.add( datum ); + } } - if( numToAdd == minimumNumber ) { - logger.warn("WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable."); + final int numBadSitesAdded = trainingData.size(); + logger.info( "Found " + numBadSitesAdded + " variants overlapping bad sites training tracks." ); + + // Next, sort the variants by the LOD coming from the positive model and add to the list the bottom X percent of variants + Collections.sort( data ); + final int numToAdd = Math.max( minimumNumber - trainingData.size(), Math.round((float)bottomPercentage * data.size()) ); + if( numToAdd > data.size() ) { + throw new UserException.BadInput( "Error during negative model training. Minimum number of variants to use in training is larger than the whole call set. One can attempt to lower the --minNumBadVariants arugment but this is unsafe." ); + } else if( numToAdd == minimumNumber - trainingData.size() ) { + logger.warn( "WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable." ); bottomPercentage = ((float) numToAdd) / ((float) data.size()); } - int index = 0; - int numAdded = 0; + int index = 0, numAdded = 0; while( numAdded < numToAdd ) { final VariantDatum datum = data.get(index++); if( !datum.failingSTDThreshold && !Double.isInfinite(datum.lod) ) { + datum.atAntiTrainingSite = true; trainingData.add( datum ); - datum.usedForTraining = -1; numAdded++; } } - logger.info("Training with worst " + (float) bottomPercentage * 100.0f + "% of passing data --> " + trainingData.size() + " variants with LOD <= " + String.format("%.4f", data.get(index).lod) + "."); + logger.info( "Additionally training with worst " + (float) bottomPercentage * 100.0f + "% of passing data --> " + (trainingData.size() - numBadSitesAdded) + " variants with LOD <= " + String.format("%.4f", data.get(index).lod) + "." ); return trainingData; } @@ -162,10 +170,11 @@ public class VariantDataManager { returnData.add(datum); } } - // add an extra 5% of points from bad training set, since that set is small but interesting + + // Add an extra 5% of points from bad training set, since that set is small but interesting for( int iii = 0; iii < Math.floor(0.05*numToAdd); iii++) { final VariantDatum datum = data.get(GenomeAnalysisEngine.getRandomGenerator().nextInt(data.size())); - if( datum.usedForTraining == -1 && !datum.failingSTDThreshold ) { returnData.add(datum); } + if( datum.atAntiTrainingSite && !datum.failingSTDThreshold ) { returnData.add(datum); } else { iii--; } } @@ -236,6 +245,7 @@ public class VariantDataManager { datum.atTrainingSite = false; datum.prior = 2.0; datum.consensusCount = 0; + for( final TrainingSet trainingSet : trainingSets ) { for( final VariantContext trainVC : tracker.getVariantContexts( ref, trainingSet.name, null, context.getLocation(), false, false ) ) { if( trainVC != null && trainVC.isNotFiltered() && trainVC.isVariant() && @@ -248,6 +258,10 @@ public class VariantDataManager { datum.prior = Math.max( datum.prior, trainingSet.prior ); datum.consensusCount += ( trainingSet.isConsensus ? 1 : 0 ); } + if( trainVC != null ) { + datum.atAntiTrainingSite = datum.atAntiTrainingSite || trainingSet.isAntiTraining; + } + } } } diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDatum.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDatum.java index 38e2affb6..04a5a9d3e 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDatum.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDatum.java @@ -14,13 +14,13 @@ public class VariantDatum implements Comparable { public double lod; public boolean atTruthSite; public boolean atTrainingSite; + public boolean atAntiTrainingSite; public boolean isTransition; public boolean isSNP; public boolean failingSTDThreshold; public double originalQual; public double prior; public int consensusCount; - public int usedForTraining; public String contig; public int start; public int stop; diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java index b5fda4443..b903d20af 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java @@ -175,7 +175,6 @@ public class VariantRecalibrator extends RodWalker