Add the ability to specify in VQSR a list of bad sites to use when training the negative model. Just add bad=true to the list of ROD tags for your bad-sites track.

This commit is contained in:
Ryan Poplin 2011-07-02 17:15:13 -04:00
parent 4532a84314
commit fdc2ebb321
4 changed files with 42 additions and 20 deletions

View File

@ -14,6 +14,7 @@ public class TrainingSet {
public String name; public String name;
public boolean isKnown = false; public boolean isKnown = false;
public boolean isTraining = false; public boolean isTraining = false;
public boolean isAntiTraining = false;
public boolean isTruth = false; public boolean isTruth = false;
public boolean isConsensus = false; public boolean isConsensus = false;
public double prior = 0.0; public double prior = 0.0;
@ -22,17 +23,24 @@ public class TrainingSet {
public TrainingSet( final String name, final Tags tags ) { public TrainingSet( final String name, final Tags tags ) {
this.name = name; this.name = name;
// Parse the tags to decide which tracks have which properties
if( tags != null ) { if( tags != null ) {
isKnown = tags.containsKey("known") && tags.getValue("known").equals("true"); isKnown = tags.containsKey("known") && tags.getValue("known").equals("true");
isTraining = tags.containsKey("training") && tags.getValue("training").equals("true"); isTraining = tags.containsKey("training") && tags.getValue("training").equals("true");
isAntiTraining = tags.containsKey("bad") && tags.getValue("bad").equals("true");
isTruth = tags.containsKey("truth") && tags.getValue("truth").equals("true"); isTruth = tags.containsKey("truth") && tags.getValue("truth").equals("true");
isConsensus = tags.containsKey("consensus") && tags.getValue("consensus").equals("true"); isConsensus = tags.containsKey("consensus") && tags.getValue("consensus").equals("true");
prior = ( tags.containsKey("prior") ? Double.parseDouble(tags.getValue("prior")) : prior ); prior = ( tags.containsKey("prior") ? Double.parseDouble(tags.getValue("prior")) : prior );
} }
if( !isConsensus ) {
// Report back to the user which tracks were found and the properties that were detected
if( !isConsensus && !isAntiTraining ) {
logger.info( String.format( "Found %s track: \tKnown = %s \tTraining = %s \tTruth = %s \tPrior = Q%.1f", this.name, isKnown, isTraining, isTruth, prior) ); logger.info( String.format( "Found %s track: \tKnown = %s \tTraining = %s \tTruth = %s \tPrior = Q%.1f", this.name, isKnown, isTraining, isTruth, prior) );
} else { } else if( isConsensus ) {
logger.info( String.format( "Found consensus track: %s", this.name) ); logger.info( String.format( "Found consensus track: %s", this.name) );
} else {
logger.info( String.format( "Found bad sites training track: %s", this.name) );
} }
} }
} }

View File

@ -84,7 +84,6 @@ public class VariantDataManager {
remove = remove || (Math.abs(val) > VRAC.STD_THRESHOLD); remove = remove || (Math.abs(val) > VRAC.STD_THRESHOLD);
} }
datum.failingSTDThreshold = remove; datum.failingSTDThreshold = remove;
datum.usedForTraining = 0;
} }
} }
@ -118,7 +117,6 @@ public class VariantDataManager {
for( final VariantDatum datum : data ) { for( final VariantDatum datum : data ) {
if( datum.atTrainingSite && !datum.failingSTDThreshold && datum.originalQual > VRAC.QUAL_THRESHOLD ) { if( datum.atTrainingSite && !datum.failingSTDThreshold && datum.originalQual > VRAC.QUAL_THRESHOLD ) {
trainingData.add( datum ); trainingData.add( datum );
datum.usedForTraining = 1;
} }
} }
logger.info( "Training with " + trainingData.size() + " variants after standard deviation thresholding." ); logger.info( "Training with " + trainingData.size() + " variants after standard deviation thresholding." );
@ -129,27 +127,37 @@ public class VariantDataManager {
} }
public ExpandingArrayList<VariantDatum> selectWorstVariants( double bottomPercentage, final int minimumNumber ) { public ExpandingArrayList<VariantDatum> selectWorstVariants( double bottomPercentage, final int minimumNumber ) {
Collections.sort( data ); // The return value is the list of training variants
final ExpandingArrayList<VariantDatum> trainingData = new ExpandingArrayList<VariantDatum>(); final ExpandingArrayList<VariantDatum> trainingData = new ExpandingArrayList<VariantDatum>();
final int numToAdd = Math.max( minimumNumber, Math.round((float)bottomPercentage * data.size()) );
// First add to the training list all sites overlapping any bad sites training tracks
for( final VariantDatum datum : data ) {
if( datum.atAntiTrainingSite && !datum.failingSTDThreshold && !Double.isInfinite(datum.lod) ) {
trainingData.add( datum );
}
}
final int numBadSitesAdded = trainingData.size();
logger.info( "Found " + numBadSitesAdded + " variants overlapping bad sites training tracks." );
// Next, sort the variants by the LOD coming from the positive model and add to the list the bottom X percent of variants
Collections.sort( data );
final int numToAdd = Math.max( minimumNumber - trainingData.size(), Math.round((float)bottomPercentage * data.size()) );
if( numToAdd > data.size() ) { if( numToAdd > data.size() ) {
throw new UserException.BadInput( "Error during negative model training. Minimum number of variants to use in training is larger than the whole call set. One can attempt to lower the --minNumBadVariants arugment but this is unsafe." ); throw new UserException.BadInput( "Error during negative model training. Minimum number of variants to use in training is larger than the whole call set. One can attempt to lower the --minNumBadVariants arugment but this is unsafe." );
} } else if( numToAdd == minimumNumber - trainingData.size() ) {
if( numToAdd == minimumNumber ) {
logger.warn( "WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable." ); logger.warn( "WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable." );
bottomPercentage = ((float) numToAdd) / ((float) data.size()); bottomPercentage = ((float) numToAdd) / ((float) data.size());
} }
int index = 0; int index = 0, numAdded = 0;
int numAdded = 0;
while( numAdded < numToAdd ) { while( numAdded < numToAdd ) {
final VariantDatum datum = data.get(index++); final VariantDatum datum = data.get(index++);
if( !datum.failingSTDThreshold && !Double.isInfinite(datum.lod) ) { if( !datum.failingSTDThreshold && !Double.isInfinite(datum.lod) ) {
datum.atAntiTrainingSite = true;
trainingData.add( datum ); trainingData.add( datum );
datum.usedForTraining = -1;
numAdded++; numAdded++;
} }
} }
logger.info("Training with worst " + (float) bottomPercentage * 100.0f + "% of passing data --> " + trainingData.size() + " variants with LOD <= " + String.format("%.4f", data.get(index).lod) + "."); logger.info( "Additionally training with worst " + (float) bottomPercentage * 100.0f + "% of passing data --> " + (trainingData.size() - numBadSitesAdded) + " variants with LOD <= " + String.format("%.4f", data.get(index).lod) + "." );
return trainingData; return trainingData;
} }
@ -162,10 +170,11 @@ public class VariantDataManager {
returnData.add(datum); returnData.add(datum);
} }
} }
// add an extra 5% of points from bad training set, since that set is small but interesting
// Add an extra 5% of points from bad training set, since that set is small but interesting
for( int iii = 0; iii < Math.floor(0.05*numToAdd); iii++) { for( int iii = 0; iii < Math.floor(0.05*numToAdd); iii++) {
final VariantDatum datum = data.get(GenomeAnalysisEngine.getRandomGenerator().nextInt(data.size())); final VariantDatum datum = data.get(GenomeAnalysisEngine.getRandomGenerator().nextInt(data.size()));
if( datum.usedForTraining == -1 && !datum.failingSTDThreshold ) { returnData.add(datum); } if( datum.atAntiTrainingSite && !datum.failingSTDThreshold ) { returnData.add(datum); }
else { iii--; } else { iii--; }
} }
@ -236,6 +245,7 @@ public class VariantDataManager {
datum.atTrainingSite = false; datum.atTrainingSite = false;
datum.prior = 2.0; datum.prior = 2.0;
datum.consensusCount = 0; datum.consensusCount = 0;
for( final TrainingSet trainingSet : trainingSets ) { for( final TrainingSet trainingSet : trainingSets ) {
for( final VariantContext trainVC : tracker.getVariantContexts( ref, trainingSet.name, null, context.getLocation(), false, false ) ) { for( final VariantContext trainVC : tracker.getVariantContexts( ref, trainingSet.name, null, context.getLocation(), false, false ) ) {
if( trainVC != null && trainVC.isNotFiltered() && trainVC.isVariant() && if( trainVC != null && trainVC.isNotFiltered() && trainVC.isVariant() &&
@ -248,6 +258,10 @@ public class VariantDataManager {
datum.prior = Math.max( datum.prior, trainingSet.prior ); datum.prior = Math.max( datum.prior, trainingSet.prior );
datum.consensusCount += ( trainingSet.isConsensus ? 1 : 0 ); datum.consensusCount += ( trainingSet.isConsensus ? 1 : 0 );
} }
if( trainVC != null ) {
datum.atAntiTrainingSite = datum.atAntiTrainingSite || trainingSet.isAntiTraining;
}
} }
} }
} }

View File

@ -14,13 +14,13 @@ public class VariantDatum implements Comparable<VariantDatum> {
public double lod; public double lod;
public boolean atTruthSite; public boolean atTruthSite;
public boolean atTrainingSite; public boolean atTrainingSite;
public boolean atAntiTrainingSite;
public boolean isTransition; public boolean isTransition;
public boolean isSNP; public boolean isSNP;
public boolean failingSTDThreshold; public boolean failingSTDThreshold;
public double originalQual; public double originalQual;
public double prior; public double prior;
public int consensusCount; public int consensusCount;
public int usedForTraining;
public String contig; public String contig;
public int start; public int start;
public int stop; public int stop;

View File

@ -175,7 +175,6 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
datum.originalQual = vc.getPhredScaledQual(); datum.originalQual = vc.getPhredScaledQual();
datum.isSNP = vc.isSNP() && vc.isBiallelic(); datum.isSNP = vc.isSNP() && vc.isBiallelic();
datum.isTransition = datum.isSNP && VariantContextUtils.isTransition(vc); datum.isTransition = datum.isSNP && VariantContextUtils.isTransition(vc);
datum.usedForTraining = 0;
// Loop through the training data sets and if they overlap this loci then update the prior and training status appropriately // Loop through the training data sets and if they overlap this loci then update the prior and training status appropriately
dataManager.parseTrainingSets( tracker, ref, context, vc, datum, TRUST_ALL_POLYMORPHIC ); dataManager.parseTrainingSets( tracker, ref, context, vc, datum, TRUST_ALL_POLYMORPHIC );
@ -328,7 +327,8 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
stream.print("data <- c("); stream.print("data <- c(");
for( final VariantDatum datum : randomData ) { for( final VariantDatum datum : randomData ) {
stream.print(String.format("%.3f, %.3f, %.3f, %d, %d,", datum.annotations[iii], datum.annotations[jjj], (datum.lod < lodCutoff ? -1.0 : 1.0), datum.usedForTraining, (datum.isKnown ? 1 : -1))); stream.print(String.format("%.3f, %.3f, %.3f, %d, %d,", datum.annotations[iii], datum.annotations[jjj], (datum.lod < lodCutoff ? -1.0 : 1.0),
(datum.atAntiTrainingSite ? -1 : (datum.atTrainingSite ? 1 : 0)), (datum.isKnown ? 1 : -1)));
} }
stream.println("NA,NA,NA,NA,1)"); stream.println("NA,NA,NA,NA,1)");
stream.println("d <- matrix(data,ncol=5,byrow=T)"); stream.println("d <- matrix(data,ncol=5,byrow=T)");