Add the ability to specify in VQSR a list of bad sites to use when training the negative model. Just add bad=true to the list of ROD tags for your bad-sites track.

This commit is contained in:
Ryan Poplin 2011-07-02 17:15:13 -04:00
parent 4532a84314
commit fdc2ebb321
4 changed files with 42 additions and 20 deletions

View File

@ -14,6 +14,7 @@ public class TrainingSet {
public String name; public String name;
public boolean isKnown = false; public boolean isKnown = false;
public boolean isTraining = false; public boolean isTraining = false;
public boolean isAntiTraining = false;
public boolean isTruth = false; public boolean isTruth = false;
public boolean isConsensus = false; public boolean isConsensus = false;
public double prior = 0.0; public double prior = 0.0;
@ -22,17 +23,24 @@ public class TrainingSet {
/**
 * Builds a TrainingSet from a named ROD track and its command-line tags.
 *
 * Recognized tags (each enabled by the literal value "true"): known, training,
 * bad (anti-training, i.e. sites used to train the negative model), truth, and
 * consensus; the "prior" tag, when present, overrides the default prior.
 *
 * @param name the name of the ROD track this training set was bound to
 * @param tags the tags attached to the track on the command line; may be null,
 *             in which case all properties keep their field defaults
 */
public TrainingSet( final String name, final Tags tags ) {
    this.name = name;

    // Parse the tags to decide which tracks have which properties
    if( tags != null ) {
        isKnown = tags.containsKey("known") && tags.getValue("known").equals("true");
        isTraining = tags.containsKey("training") && tags.getValue("training").equals("true");
        isAntiTraining = tags.containsKey("bad") && tags.getValue("bad").equals("true");
        isTruth = tags.containsKey("truth") && tags.getValue("truth").equals("true");
        isConsensus = tags.containsKey("consensus") && tags.getValue("consensus").equals("true");
        // Only override the default prior when the user explicitly supplied one
        prior = ( tags.containsKey("prior") ? Double.parseDouble(tags.getValue("prior")) : prior );
    }

    // Report back to the user which tracks were found and the properties that were detected.
    // Consensus and bad-sites tracks get short, dedicated messages; everything else gets the full property dump.
    if( !isConsensus && !isAntiTraining ) {
        logger.info( String.format( "Found %s track: \tKnown = %s \tTraining = %s \tTruth = %s \tPrior = Q%.1f", this.name, isKnown, isTraining, isTruth, prior) );
    } else if( isConsensus ) {
        logger.info( String.format( "Found consensus track: %s", this.name) );
    } else {
        logger.info( String.format( "Found bad sites training track: %s", this.name) );
    }
}

View File

@ -84,7 +84,6 @@ public class VariantDataManager {
remove = remove || (Math.abs(val) > VRAC.STD_THRESHOLD); remove = remove || (Math.abs(val) > VRAC.STD_THRESHOLD);
} }
datum.failingSTDThreshold = remove; datum.failingSTDThreshold = remove;
datum.usedForTraining = 0;
} }
} }
@ -118,38 +117,47 @@ public class VariantDataManager {
for( final VariantDatum datum : data ) { for( final VariantDatum datum : data ) {
if( datum.atTrainingSite && !datum.failingSTDThreshold && datum.originalQual > VRAC.QUAL_THRESHOLD ) { if( datum.atTrainingSite && !datum.failingSTDThreshold && datum.originalQual > VRAC.QUAL_THRESHOLD ) {
trainingData.add( datum ); trainingData.add( datum );
datum.usedForTraining = 1;
} }
} }
logger.info( "Training with " + trainingData.size() + " variants after standard deviation thresholding." ); logger.info( "Training with " + trainingData.size() + " variants after standard deviation thresholding." );
if( trainingData.size() < VRAC.MIN_NUM_BAD_VARIANTS ) { if( trainingData.size() < VRAC.MIN_NUM_BAD_VARIANTS ) {
logger.warn("WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable."); logger.warn( "WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable." );
} }
return trainingData; return trainingData;
} }
public ExpandingArrayList<VariantDatum> selectWorstVariants( double bottomPercentage, final int minimumNumber ) { public ExpandingArrayList<VariantDatum> selectWorstVariants( double bottomPercentage, final int minimumNumber ) {
Collections.sort( data ); // The return value is the list of training variants
final ExpandingArrayList<VariantDatum> trainingData = new ExpandingArrayList<VariantDatum>(); final ExpandingArrayList<VariantDatum> trainingData = new ExpandingArrayList<VariantDatum>();
final int numToAdd = Math.max( minimumNumber, Math.round((float)bottomPercentage * data.size()) );
if( numToAdd > data.size() ) { // First add to the training list all sites overlapping any bad sites training tracks
throw new UserException.BadInput("Error during negative model training. Minimum number of variants to use in training is larger than the whole call set. One can attempt to lower the --minNumBadVariants arugment but this is unsafe."); for( final VariantDatum datum : data ) {
if( datum.atAntiTrainingSite && !datum.failingSTDThreshold && !Double.isInfinite(datum.lod) ) {
trainingData.add( datum );
}
} }
if( numToAdd == minimumNumber ) { final int numBadSitesAdded = trainingData.size();
logger.warn("WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable."); logger.info( "Found " + numBadSitesAdded + " variants overlapping bad sites training tracks." );
// Next, sort the variants by the LOD coming from the positive model and add to the list the bottom X percent of variants
Collections.sort( data );
final int numToAdd = Math.max( minimumNumber - trainingData.size(), Math.round((float)bottomPercentage * data.size()) );
if( numToAdd > data.size() ) {
throw new UserException.BadInput( "Error during negative model training. Minimum number of variants to use in training is larger than the whole call set. One can attempt to lower the --minNumBadVariants arugment but this is unsafe." );
} else if( numToAdd == minimumNumber - trainingData.size() ) {
logger.warn( "WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable." );
bottomPercentage = ((float) numToAdd) / ((float) data.size()); bottomPercentage = ((float) numToAdd) / ((float) data.size());
} }
int index = 0; int index = 0, numAdded = 0;
int numAdded = 0;
while( numAdded < numToAdd ) { while( numAdded < numToAdd ) {
final VariantDatum datum = data.get(index++); final VariantDatum datum = data.get(index++);
if( !datum.failingSTDThreshold && !Double.isInfinite(datum.lod) ) { if( !datum.failingSTDThreshold && !Double.isInfinite(datum.lod) ) {
datum.atAntiTrainingSite = true;
trainingData.add( datum ); trainingData.add( datum );
datum.usedForTraining = -1;
numAdded++; numAdded++;
} }
} }
logger.info("Training with worst " + (float) bottomPercentage * 100.0f + "% of passing data --> " + trainingData.size() + " variants with LOD <= " + String.format("%.4f", data.get(index).lod) + "."); logger.info( "Additionally training with worst " + (float) bottomPercentage * 100.0f + "% of passing data --> " + (trainingData.size() - numBadSitesAdded) + " variants with LOD <= " + String.format("%.4f", data.get(index).lod) + "." );
return trainingData; return trainingData;
} }
@ -162,10 +170,11 @@ public class VariantDataManager {
returnData.add(datum); returnData.add(datum);
} }
} }
// add an extra 5% of points from bad training set, since that set is small but interesting
// Add an extra 5% of points from bad training set, since that set is small but interesting
for( int iii = 0; iii < Math.floor(0.05*numToAdd); iii++) { for( int iii = 0; iii < Math.floor(0.05*numToAdd); iii++) {
final VariantDatum datum = data.get(GenomeAnalysisEngine.getRandomGenerator().nextInt(data.size())); final VariantDatum datum = data.get(GenomeAnalysisEngine.getRandomGenerator().nextInt(data.size()));
if( datum.usedForTraining == -1 && !datum.failingSTDThreshold ) { returnData.add(datum); } if( datum.atAntiTrainingSite && !datum.failingSTDThreshold ) { returnData.add(datum); }
else { iii--; } else { iii--; }
} }
@ -236,6 +245,7 @@ public class VariantDataManager {
datum.atTrainingSite = false; datum.atTrainingSite = false;
datum.prior = 2.0; datum.prior = 2.0;
datum.consensusCount = 0; datum.consensusCount = 0;
for( final TrainingSet trainingSet : trainingSets ) { for( final TrainingSet trainingSet : trainingSets ) {
for( final VariantContext trainVC : tracker.getVariantContexts( ref, trainingSet.name, null, context.getLocation(), false, false ) ) { for( final VariantContext trainVC : tracker.getVariantContexts( ref, trainingSet.name, null, context.getLocation(), false, false ) ) {
if( trainVC != null && trainVC.isNotFiltered() && trainVC.isVariant() && if( trainVC != null && trainVC.isNotFiltered() && trainVC.isVariant() &&
@ -248,6 +258,10 @@ public class VariantDataManager {
datum.prior = Math.max( datum.prior, trainingSet.prior ); datum.prior = Math.max( datum.prior, trainingSet.prior );
datum.consensusCount += ( trainingSet.isConsensus ? 1 : 0 ); datum.consensusCount += ( trainingSet.isConsensus ? 1 : 0 );
} }
if( trainVC != null ) {
datum.atAntiTrainingSite = datum.atAntiTrainingSite || trainingSet.isAntiTraining;
}
} }
} }
} }

View File

@ -14,13 +14,13 @@ public class VariantDatum implements Comparable<VariantDatum> {
public double lod; public double lod;
public boolean atTruthSite; public boolean atTruthSite;
public boolean atTrainingSite; public boolean atTrainingSite;
public boolean atAntiTrainingSite;
public boolean isTransition; public boolean isTransition;
public boolean isSNP; public boolean isSNP;
public boolean failingSTDThreshold; public boolean failingSTDThreshold;
public double originalQual; public double originalQual;
public double prior; public double prior;
public int consensusCount; public int consensusCount;
public int usedForTraining;
public String contig; public String contig;
public int start; public int start;
public int stop; public int stop;

View File

@ -175,7 +175,6 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
datum.originalQual = vc.getPhredScaledQual(); datum.originalQual = vc.getPhredScaledQual();
datum.isSNP = vc.isSNP() && vc.isBiallelic(); datum.isSNP = vc.isSNP() && vc.isBiallelic();
datum.isTransition = datum.isSNP && VariantContextUtils.isTransition(vc); datum.isTransition = datum.isSNP && VariantContextUtils.isTransition(vc);
datum.usedForTraining = 0;
// Loop through the training data sets and if they overlap this loci then update the prior and training status appropriately // Loop through the training data sets and if they overlap this loci then update the prior and training status appropriately
dataManager.parseTrainingSets( tracker, ref, context, vc, datum, TRUST_ALL_POLYMORPHIC ); dataManager.parseTrainingSets( tracker, ref, context, vc, datum, TRUST_ALL_POLYMORPHIC );
@ -328,7 +327,8 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
stream.print("data <- c("); stream.print("data <- c(");
for( final VariantDatum datum : randomData ) { for( final VariantDatum datum : randomData ) {
stream.print(String.format("%.3f, %.3f, %.3f, %d, %d,", datum.annotations[iii], datum.annotations[jjj], (datum.lod < lodCutoff ? -1.0 : 1.0), datum.usedForTraining, (datum.isKnown ? 1 : -1))); stream.print(String.format("%.3f, %.3f, %.3f, %d, %d,", datum.annotations[iii], datum.annotations[jjj], (datum.lod < lodCutoff ? -1.0 : 1.0),
(datum.atAntiTrainingSite ? -1 : (datum.atTrainingSite ? 1 : 0)), (datum.isKnown ? 1 : -1)));
} }
stream.println("NA,NA,NA,NA,1)"); stream.println("NA,NA,NA,NA,1)");
stream.println("d <- matrix(data,ncol=5,byrow=T)"); stream.println("d <- matrix(data,ncol=5,byrow=T)");