Pushing hard coded arguments into VariantRecalibratorArgumentCollection

git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@5566 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
rpoplin 2011-04-03 19:55:09 +00:00
parent 79c43845ad
commit b2a0331e2d
6 changed files with 42 additions and 34 deletions

View File

@ -27,6 +27,7 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantrecalibration;
import org.broad.tribble.util.variantcontext.VariantContext;
import org.broadinstitute.sting.commandline.Argument;
import org.broadinstitute.sting.commandline.ArgumentCollection;
import org.broadinstitute.sting.commandline.Hidden;
import org.broadinstitute.sting.commandline.Output;
import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
@ -40,7 +41,6 @@ import org.broadinstitute.sting.utils.QualityUtils;
import org.broadinstitute.sting.utils.collections.ExpandingArrayList;
import org.broadinstitute.sting.utils.exceptions.UserException;
import java.io.File;
import java.io.PrintStream;
import java.util.*;
@ -57,6 +57,8 @@ public class ContrastiveRecalibrator extends RodWalker<ExpandingArrayList<Varian
public static final String VQS_LOD_KEY = "VQSLOD";
@ArgumentCollection private VariantRecalibratorArgumentCollection VRAC = new VariantRecalibratorArgumentCollection();
/////////////////////////////
// Outputs
/////////////////////////////
@ -66,9 +68,8 @@ public class ContrastiveRecalibrator extends RodWalker<ExpandingArrayList<Varian
private PrintStream TRANCHES_FILE;
/////////////////////////////
// Command Line Arguments
// Additional Command Line Arguments
/////////////////////////////
//BUGBUG: use VariantRecalibrationArgumentCollection
@Argument(fullName="use_annotation", shortName="an", doc="The names of the annotations which should used for calculations", required=true)
private String[] USE_ANNOTATIONS = null;
@Argument(fullName="TStranche", shortName="tranche", doc="The levels of novel false discovery rate (FDR, implied by ti/tv) at which to slice the data. (in percent, that is 1.0 for 1 percent)", required=false)
@ -80,9 +81,6 @@ public class ContrastiveRecalibrator extends RodWalker<ExpandingArrayList<Varian
// Debug Arguments
/////////////////////////////
@Hidden
@Argument(fullName = "debugFile", shortName = "debugFile", doc = "Print debugging information here", required=false)
private File DEBUG_FILE = null;
@Hidden
@Argument(fullName = "trustAllPolymorphic", shortName = "allPoly", doc = "Trust that all the input training sets' unfiltered records contain only polymorphic sites to drastically speed up the computation.", required = false)
protected Boolean TRUST_ALL_POLYMORPHIC = false;
@ -92,7 +90,7 @@ public class ContrastiveRecalibrator extends RodWalker<ExpandingArrayList<Varian
private VariantDataManager dataManager;
private final Set<String> ignoreInputFilterSet = new TreeSet<String>();
private final Set<String> inputNames = new HashSet<String>();
private final VariantRecalibratorEngine engine = new VariantRecalibratorEngine(new VariantRecalibratorArgumentCollection()); //BUGBUG: doesn't do anything with the args at the moment
private final VariantRecalibratorEngine engine = new VariantRecalibratorEngine( VRAC );
//---------------------------------------------------------------------------------------------------------------
//
@ -101,7 +99,7 @@ public class ContrastiveRecalibrator extends RodWalker<ExpandingArrayList<Varian
//---------------------------------------------------------------------------------------------------------------
public void initialize() {
dataManager = new VariantDataManager( new ArrayList<String>(Arrays.asList(USE_ANNOTATIONS)) );
dataManager = new VariantDataManager( new ArrayList<String>(Arrays.asList(USE_ANNOTATIONS)), VRAC );
if( IGNORE_INPUT_FILTERS != null ) {
ignoreInputFilterSet.addAll( Arrays.asList(IGNORE_INPUT_FILTERS) );
@ -191,14 +189,11 @@ public class ContrastiveRecalibrator extends RodWalker<ExpandingArrayList<Varian
dataManager.setData( reduceSum );
dataManager.normalizeData();
engine.evaluateData( dataManager.getData(), engine.generateModel( dataManager.getTrainingData() ), false );
engine.evaluateData( dataManager.getData(), engine.generateModel( dataManager.selectWorstVariants( 0.07f ) ), true );
engine.evaluateData( dataManager.getData(), engine.generateModel( dataManager.selectWorstVariants( VRAC.PERCENT_BAD_VARIANTS ) ), true );
// old tranches stuff
int nCallsAtTruth = TrancheManager.countCallsAtTruth( dataManager.getData(), Double.NEGATIVE_INFINITY );
//logger.info(String.format("Truth set size is %d, raw calls at these sites %d, maximum sensitivity of %.2f",
// nTruthSites, nCallsAtTruth, (100.0*nCallsAtTruth / Math.max(nTruthSites, nCallsAtTruth))));
TrancheManager.SelectionMetric metric = new TrancheManager.TruthSensitivityMetric( nCallsAtTruth );
List<Tranche> tranches = TrancheManager.findTranches( dataManager.getData(), TS_TRANCHES, metric, DEBUG_FILE ); //BUGBUG: recreated here to match the integration tests
final int nCallsAtTruth = TrancheManager.countCallsAtTruth( dataManager.getData(), Double.NEGATIVE_INFINITY );
final TrancheManager.SelectionMetric metric = new TrancheManager.TruthSensitivityMetric( nCallsAtTruth );
final List<Tranche> tranches = TrancheManager.findTranches( dataManager.getData(), TS_TRANCHES, metric );
TRANCHES_FILE.print(Tranche.tranchesString( tranches ));
logger.info( "Writing out recalibration table..." );

View File

@ -71,7 +71,7 @@ public class GaussianMixtureModel {
empiricalSigma.setMatrix(0, empiricalMu.length - 1, 0, empiricalMu.length - 1, Matrix.identity(empiricalMu.length, empiricalMu.length));
}
public void initializeRandomModel( final List<VariantDatum> data, final Random rand ) {
public void initializeRandomModel( final List<VariantDatum> data, final Random rand, final int numKMeansIterations ) {
// initialize random Gaussian means // BUGBUG: this is broken up this way to match the order of calls to rand.nextDouble() in the old code
for( final MultivariateGaussian gaussian : gaussians ) {
@ -79,7 +79,6 @@ public class GaussianMixtureModel {
}
// initialize means using K-means algorithm
final int numKMeansIterations = 10; // BUGBUG: VRAC argument
logger.info( "Initializing model with " + numKMeansIterations + " k-means iterations..." );
initializeMeansUsingKMeans( data, numKMeansIterations, rand );

View File

@ -7,7 +7,6 @@ import java.io.File;
import java.io.FileNotFoundException;
import java.io.PrintStream;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
@ -211,7 +210,7 @@ public class TrancheManager {
/**
 * Counts the data points that sit at truth sites and whose LOD score meets the threshold.
 *
 * @param data   variant data points to scan (each datum exposes atTruthSite and lod)
 * @param minLOD minimum LOD score (inclusive) for a call to be counted
 * @return the number of datums with atTruthSite set and lod >= minLOD
 */
public static int countCallsAtTruth(final List<VariantDatum> data, double minLOD ) {
    int n = 0;
    // Note: the diff rendering had duplicated this loop (old and new line both present);
    // only one pass over the data is intended.
    for ( VariantDatum d : data ) { n += (d.atTruthSite && d.lod >= minLOD ? 1 : 0); }
    return n;
}

View File

@ -25,15 +25,17 @@ public class VariantDataManager {
private final double[] varianceVector; // this is really the standard deviation
public final ArrayList<String> annotationKeys;
private final ExpandingArrayList<TrainingSet> trainingSets;
private final VariantRecalibratorArgumentCollection VRAC;
private final static long RANDOM_SEED = 83409701;
private final static Random rand = new Random( RANDOM_SEED ); // this is going to cause problems if it is ever used in an integration test, planning to get rid of HRun anyway
protected final static Logger logger = Logger.getLogger(VariantDataManager.class);
public VariantDataManager( final List<String> annotationKeys ) {
public VariantDataManager( final List<String> annotationKeys, final VariantRecalibratorArgumentCollection VRAC ) {
this.data = null;
this.annotationKeys = new ArrayList<String>( annotationKeys );
this.VRAC = VRAC;
meanVector = new double[this.annotationKeys.size()];
varianceVector = new double[this.annotationKeys.size()];
trainingSets = new ExpandingArrayList<TrainingSet>();
@ -93,20 +95,20 @@ public class VariantDataManager {
public ExpandingArrayList<VariantDatum> getTrainingData() {
final ExpandingArrayList<VariantDatum> trainingData = new ExpandingArrayList<VariantDatum>();
for( final VariantDatum datum : data ) {
if( datum.atTrainingSite && datum.originalQual > 80.0 ) { //BUGBUG: VRAC argument
if( datum.atTrainingSite && datum.originalQual > VRAC.QUAL_THRESHOLD ) {
trainingData.add( datum );
}
}
trimDataBySTD(trainingData, 4.5); //BUGBUG: VRAC argument
trimDataBySTD( trainingData, VRAC.STD_THRESHOLD );
logger.info( "Training with " + trainingData.size() + " variants found in the training set(s)." );
return trainingData;
}
public ExpandingArrayList<VariantDatum> selectWorstVariants( final float bottomPercentage ) {
public ExpandingArrayList<VariantDatum> selectWorstVariants( final double bottomPercentage ) {
Collections.sort( data );
final ExpandingArrayList<VariantDatum> trainingData = new ExpandingArrayList<VariantDatum>();
trainingData.addAll( data.subList(0, Math.round(bottomPercentage * data.size())) );
logger.info( "Training with worst " + bottomPercentage * 100.0f + "% of data --> " + trainingData.size() + " variants with LOD <= " + String.format("%.4f", data.get(Math.round(bottomPercentage * data.size())).lod) + "." );
trainingData.addAll( data.subList(0, Math.round((float)bottomPercentage * data.size())) );
logger.info( "Training with worst " + bottomPercentage * 100.0f + "% of data --> " + trainingData.size() + " variants with LOD <= " + String.format("%.4f", data.get(Math.round((float)bottomPercentage * data.size())).lod) + "." );
return trainingData;
}
@ -174,7 +176,7 @@ public class VariantDataManager {
for( final TrainingSet trainingSet : trainingSets ) {
final Collection<VariantContext> vcs = tracker.getVariantContexts( ref, trainingSet.name, null, context.getLocation(), false, true );
final VariantContext trainVC = ( vcs.size() != 0 ? vcs.iterator().next() : null );
if( trainVC != null && trainVC.isVariant() && !trainVC.isFiltered() && ((evalVC.isSNP() && trainVC.isSNP()) || (evalVC.isIndel() && trainVC.isIndel())) && (TRUST_ALL_POLYMORPHIC || !trainVC.hasGenotypes() || trainVC.isPolymorphic()) ) {
if( trainVC != null && trainVC.isVariant() && trainVC.isNotFiltered() && ((evalVC.isSNP() && trainVC.isSNP()) || (evalVC.isIndel() && trainVC.isIndel())) && (TRUST_ALL_POLYMORPHIC || !trainVC.hasGenotypes() || trainVC.isPolymorphic()) ) {
datum.isKnown = datum.isKnown || trainingSet.isKnown;
datum.atTruthSite = datum.atTruthSite || trainingSet.isTruth;
datum.atTrainingSite = datum.atTrainingSite || trainingSet.isTraining;

View File

@ -1,5 +1,7 @@
package org.broadinstitute.sting.playground.gatk.walkers.variantrecalibration;
import org.broadinstitute.sting.commandline.Argument;
/**
* Created by IntelliJ IDEA.
* User: rpoplin
@ -8,9 +10,20 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantrecalibration;
/**
 * Collects the tunable command-line parameters shared by the variant recalibration
 * walkers: Gaussian mixture model construction (maxGaussians, shrinkage, dirichlet),
 * VBEM training (maxIterations, numKMeans), and training-data selection thresholds
 * (stdThreshold, qualThreshold, percentBadVariants).
 */
public class VariantRecalibratorArgumentCollection {
    /**
     * Returns a field-by-field copy of this argument collection.
     *
     * @return a new collection carrying the same argument values as this one
     */
    public VariantRecalibratorArgumentCollection clone() {
        final VariantRecalibratorArgumentCollection vrac = new VariantRecalibratorArgumentCollection();
        // Copy every argument value explicitly; previously the clone was returned
        // with all fields still at their defaults, silently discarding any
        // user-supplied settings.
        vrac.MAX_GAUSSIANS = MAX_GAUSSIANS;
        vrac.MAX_ITERATIONS = MAX_ITERATIONS;
        vrac.NUM_KMEANS_ITERATIONS = NUM_KMEANS_ITERATIONS;
        vrac.STD_THRESHOLD = STD_THRESHOLD;
        vrac.QUAL_THRESHOLD = QUAL_THRESHOLD;
        vrac.SHRINKAGE = SHRINKAGE;
        vrac.DIRICHLET_PARAMETER = DIRICHLET_PARAMETER;
        vrac.PERCENT_BAD_VARIANTS = PERCENT_BAD_VARIANTS;
        return vrac;
    }

    @Argument(fullName="maxGaussians", shortName="mG", doc="The maximum number of Gaussians to try during variational Bayes algorithm", required=false)
    public int MAX_GAUSSIANS = 32;

    @Argument(fullName="maxIterations", shortName="mI", doc="The maximum number of VBEM iterations to be performed in variational Bayes algorithm. Procedure will normally end when convergence is detected.", required=false)
    public int MAX_ITERATIONS = 100;

    @Argument(fullName="numKMeans", shortName="nKM", doc="The number of k-means iterations to perform in order to initialize the means of the Gaussians in the Gaussian mixture model.", required=false)
    public int NUM_KMEANS_ITERATIONS = 10;

    @Argument(fullName="stdThreshold", shortName="std", doc="If a variant has annotations more than -std standard deviations away from mean then don't use it for building the Gaussian mixture model.", required=false)
    public double STD_THRESHOLD = 4.5;

    @Argument(fullName="qualThreshold", shortName="qual", doc="If a known variant has raw QUAL value less than -qual then don't use it for building the Gaussian mixture model.", required=false)
    public double QUAL_THRESHOLD = 80.0;

    @Argument(fullName="shrinkage", shortName="shrinkage", doc="The shrinkage parameter in variational Bayes algorithm.", required=false)
    public double SHRINKAGE = 0.0001;

    @Argument(fullName="dirichlet", shortName="dirichlet", doc="The dirichlet parameter in variational Bayes algorithm.", required=false)
    public double DIRICHLET_PARAMETER = 0.0001;

    @Argument(fullName="percentBadVariants", shortName="percentBad", doc="What percentage of the worst scoring variants to use when building the Gaussian mixture model of bad variants. 0.07 means bottom 7 percent.", required=false)
    public double PERCENT_BAD_VARIANTS = 0.07;
}

View File

@ -36,7 +36,7 @@ public class VariantRecalibratorEngine {
}
public GaussianMixtureModel generateModel( final List<VariantDatum> data ) {
final GaussianMixtureModel model = new GaussianMixtureModel( 32, 4, 0.0001, 0.0001 ); //BUGBUG: VRAC arguments
final GaussianMixtureModel model = new GaussianMixtureModel( VRAC.MAX_GAUSSIANS, data.get(0).annotations.length, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER );
variationalBayesExpectationMaximization( model, data );
return model;
}
@ -67,14 +67,14 @@ public class VariantRecalibratorEngine {
private void variationalBayesExpectationMaximization( final GaussianMixtureModel model, final List<VariantDatum> data ) {
model.cacheEmpiricalStats( data );
model.initializeRandomModel( data, rand );
model.initializeRandomModel( data, rand, VRAC.NUM_KMEANS_ITERATIONS );
// The VBEM loop
double previousLikelihood = model.expectationStep( data );
double currentLikelihood;
int iteration = 0;
logger.info("Finished iteration " + iteration );
while( iteration < 100 ) { //BUGBUG: VRAC.maxIterations
while( iteration < VRAC.MAX_ITERATIONS ) {
iteration++;
model.maximizationStep( data );
currentLikelihood = model.expectationStep( data );