Adding ability to specify in VQSR a list of bad sites to use when training the negative model. Just add bad=true to the list of rod tags for your bad sites track.

This commit is contained in:
Ryan Poplin 2011-07-02 17:15:13 -04:00
parent 4532a84314
commit fdc2ebb321
4 changed files with 42 additions and 20 deletions

View File

@ -14,6 +14,7 @@ public class TrainingSet {
public String name;
public boolean isKnown = false;
public boolean isTraining = false;
public boolean isAntiTraining = false;
public boolean isTruth = false;
public boolean isConsensus = false;
public double prior = 0.0;
@ -22,17 +23,24 @@ public class TrainingSet {
public TrainingSet( final String name, final Tags tags ) {
this.name = name;
// Parse the tags to decide which tracks have which properties
if( tags != null ) {
isKnown = tags.containsKey("known") && tags.getValue("known").equals("true");
isTraining = tags.containsKey("training") && tags.getValue("training").equals("true");
isAntiTraining = tags.containsKey("bad") && tags.getValue("bad").equals("true");
isTruth = tags.containsKey("truth") && tags.getValue("truth").equals("true");
isConsensus = tags.containsKey("consensus") && tags.getValue("consensus").equals("true");
prior = ( tags.containsKey("prior") ? Double.parseDouble(tags.getValue("prior")) : prior );
}
if( !isConsensus ) {
// Report back to the user which tracks were found and the properties that were detected
if( !isConsensus && !isAntiTraining ) {
logger.info( String.format( "Found %s track: \tKnown = %s \tTraining = %s \tTruth = %s \tPrior = Q%.1f", this.name, isKnown, isTraining, isTruth, prior) );
} else {
} else if( isConsensus ) {
logger.info( String.format( "Found consensus track: %s", this.name) );
} else {
logger.info( String.format( "Found bad sites training track: %s", this.name) );
}
}
}

View File

@ -84,7 +84,6 @@ public class VariantDataManager {
remove = remove || (Math.abs(val) > VRAC.STD_THRESHOLD);
}
datum.failingSTDThreshold = remove;
datum.usedForTraining = 0;
}
}
@ -118,38 +117,47 @@ public class VariantDataManager {
for( final VariantDatum datum : data ) {
if( datum.atTrainingSite && !datum.failingSTDThreshold && datum.originalQual > VRAC.QUAL_THRESHOLD ) {
trainingData.add( datum );
datum.usedForTraining = 1;
}
}
logger.info( "Training with " + trainingData.size() + " variants after standard deviation thresholding." );
if( trainingData.size() < VRAC.MIN_NUM_BAD_VARIANTS ) {
logger.warn("WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable.");
logger.warn( "WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable." );
}
return trainingData;
}
public ExpandingArrayList<VariantDatum> selectWorstVariants( double bottomPercentage, final int minimumNumber ) {
Collections.sort( data );
// The return value is the list of training variants
final ExpandingArrayList<VariantDatum> trainingData = new ExpandingArrayList<VariantDatum>();
final int numToAdd = Math.max( minimumNumber, Math.round((float)bottomPercentage * data.size()) );
if( numToAdd > data.size() ) {
throw new UserException.BadInput("Error during negative model training. Minimum number of variants to use in training is larger than the whole call set. One can attempt to lower the --minNumBadVariants argument but this is unsafe.");
// First add to the training list all sites overlapping any bad sites training tracks
for( final VariantDatum datum : data ) {
if( datum.atAntiTrainingSite && !datum.failingSTDThreshold && !Double.isInfinite(datum.lod) ) {
trainingData.add( datum );
}
}
if( numToAdd == minimumNumber ) {
logger.warn("WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable.");
final int numBadSitesAdded = trainingData.size();
logger.info( "Found " + numBadSitesAdded + " variants overlapping bad sites training tracks." );
// Next, sort the variants by the LOD coming from the positive model and add to the list the bottom X percent of variants
Collections.sort( data );
final int numToAdd = Math.max( minimumNumber - trainingData.size(), Math.round((float)bottomPercentage * data.size()) );
if( numToAdd > data.size() ) {
throw new UserException.BadInput( "Error during negative model training. Minimum number of variants to use in training is larger than the whole call set. One can attempt to lower the --minNumBadVariants argument but this is unsafe." );
} else if( numToAdd == minimumNumber - trainingData.size() ) {
logger.warn( "WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable." );
bottomPercentage = ((float) numToAdd) / ((float) data.size());
}
int index = 0;
int numAdded = 0;
int index = 0, numAdded = 0;
while( numAdded < numToAdd ) {
final VariantDatum datum = data.get(index++);
if( !datum.failingSTDThreshold && !Double.isInfinite(datum.lod) ) {
datum.atAntiTrainingSite = true;
trainingData.add( datum );
datum.usedForTraining = -1;
numAdded++;
}
}
logger.info("Training with worst " + (float) bottomPercentage * 100.0f + "% of passing data --> " + trainingData.size() + " variants with LOD <= " + String.format("%.4f", data.get(index).lod) + ".");
logger.info( "Additionally training with worst " + (float) bottomPercentage * 100.0f + "% of passing data --> " + (trainingData.size() - numBadSitesAdded) + " variants with LOD <= " + String.format("%.4f", data.get(index).lod) + "." );
return trainingData;
}
@ -162,10 +170,11 @@ public class VariantDataManager {
returnData.add(datum);
}
}
// add an extra 5% of points from bad training set, since that set is small but interesting
// Add an extra 5% of points from bad training set, since that set is small but interesting
for( int iii = 0; iii < Math.floor(0.05*numToAdd); iii++) {
final VariantDatum datum = data.get(GenomeAnalysisEngine.getRandomGenerator().nextInt(data.size()));
if( datum.usedForTraining == -1 && !datum.failingSTDThreshold ) { returnData.add(datum); }
if( datum.atAntiTrainingSite && !datum.failingSTDThreshold ) { returnData.add(datum); }
else { iii--; }
}
@ -236,6 +245,7 @@ public class VariantDataManager {
datum.atTrainingSite = false;
datum.prior = 2.0;
datum.consensusCount = 0;
for( final TrainingSet trainingSet : trainingSets ) {
for( final VariantContext trainVC : tracker.getVariantContexts( ref, trainingSet.name, null, context.getLocation(), false, false ) ) {
if( trainVC != null && trainVC.isNotFiltered() && trainVC.isVariant() &&
@ -248,6 +258,10 @@ public class VariantDataManager {
datum.prior = Math.max( datum.prior, trainingSet.prior );
datum.consensusCount += ( trainingSet.isConsensus ? 1 : 0 );
}
if( trainVC != null ) {
datum.atAntiTrainingSite = datum.atAntiTrainingSite || trainingSet.isAntiTraining;
}
}
}
}

View File

@ -14,13 +14,13 @@ public class VariantDatum implements Comparable<VariantDatum> {
public double lod;
public boolean atTruthSite;
public boolean atTrainingSite;
public boolean atAntiTrainingSite;
public boolean isTransition;
public boolean isSNP;
public boolean failingSTDThreshold;
public double originalQual;
public double prior;
public int consensusCount;
public int usedForTraining;
public String contig;
public int start;
public int stop;

View File

@ -175,7 +175,6 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
datum.originalQual = vc.getPhredScaledQual();
datum.isSNP = vc.isSNP() && vc.isBiallelic();
datum.isTransition = datum.isSNP && VariantContextUtils.isTransition(vc);
datum.usedForTraining = 0;
// Loop through the training data sets and if they overlap this locus then update the prior and training status appropriately
dataManager.parseTrainingSets( tracker, ref, context, vc, datum, TRUST_ALL_POLYMORPHIC );
@ -328,7 +327,8 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
stream.print("data <- c(");
for( final VariantDatum datum : randomData ) {
stream.print(String.format("%.3f, %.3f, %.3f, %d, %d,", datum.annotations[iii], datum.annotations[jjj], (datum.lod < lodCutoff ? -1.0 : 1.0), datum.usedForTraining, (datum.isKnown ? 1 : -1)));
stream.print(String.format("%.3f, %.3f, %.3f, %d, %d,", datum.annotations[iii], datum.annotations[jjj], (datum.lod < lodCutoff ? -1.0 : 1.0),
(datum.atAntiTrainingSite ? -1 : (datum.atTrainingSite ? 1 : 0)), (datum.isKnown ? 1 : -1)));
}
stream.println("NA,NA,NA,NA,1)");
stream.println("d <- matrix(data,ncol=5,byrow=T)");