From b2a0331e2d08f3c07de8559e0f2583ef29885b1a Mon Sep 17 00:00:00 2001 From: rpoplin Date: Sun, 3 Apr 2011 19:55:09 +0000 Subject: [PATCH] Pushing hard coded arguments into VariantRecalibratorArgumentCollection git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@5566 348d0f76-0448-11de-a6fe-93d51630548a --- .../ContrastiveRecalibrator.java | 25 ++++++++----------- .../GaussianMixtureModel.java | 3 +-- .../variantrecalibration/TrancheManager.java | 3 +-- .../VariantDataManager.java | 16 ++++++------ ...VariantRecalibratorArgumentCollection.java | 23 +++++++++++++---- .../VariantRecalibratorEngine.java | 6 ++--- 6 files changed, 42 insertions(+), 34 deletions(-) diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/ContrastiveRecalibrator.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/ContrastiveRecalibrator.java index 98f7ad473..ac9d0602d 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/ContrastiveRecalibrator.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/ContrastiveRecalibrator.java @@ -27,6 +27,7 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantrecalibration; import org.broad.tribble.util.variantcontext.VariantContext; import org.broadinstitute.sting.commandline.Argument; +import org.broadinstitute.sting.commandline.ArgumentCollection; import org.broadinstitute.sting.commandline.Hidden; import org.broadinstitute.sting.commandline.Output; import org.broadinstitute.sting.gatk.contexts.AlignmentContext; @@ -40,7 +41,6 @@ import org.broadinstitute.sting.utils.QualityUtils; import org.broadinstitute.sting.utils.collections.ExpandingArrayList; import org.broadinstitute.sting.utils.exceptions.UserException; -import java.io.File; import java.io.PrintStream; import java.util.*; @@ -57,6 +57,8 @@ public class ContrastiveRecalibrator extends RodWalker ignoreInputFilterSet = new TreeSet(); private final Set inputNames = new HashSet(); - private final VariantRecalibratorEngine engine = new VariantRecalibratorEngine(new VariantRecalibratorArgumentCollection()); //BUGBUG: doesn't do anything with the args at the moment + private final VariantRecalibratorEngine engine = new VariantRecalibratorEngine( VRAC ); //--------------------------------------------------------------------------------------------------------------- // @@ -101,7 +99,7 @@ public class ContrastiveRecalibrator extends RodWalker(Arrays.asList(USE_ANNOTATIONS)) ); + dataManager = new VariantDataManager( new ArrayList(Arrays.asList(USE_ANNOTATIONS)), VRAC ); if( IGNORE_INPUT_FILTERS != null ) { ignoreInputFilterSet.addAll( Arrays.asList(IGNORE_INPUT_FILTERS) ); @@ -191,14 +189,11 @@ public class ContrastiveRecalibrator extends RodWalker tranches = TrancheManager.findTranches( dataManager.getData(), TS_TRANCHES, metric, DEBUG_FILE ); //BUGBUG: recreated here to match the integration tests + final int nCallsAtTruth = TrancheManager.countCallsAtTruth( dataManager.getData(), Double.NEGATIVE_INFINITY ); + final TrancheManager.SelectionMetric metric = new TrancheManager.TruthSensitivityMetric( nCallsAtTruth ); + final List tranches = TrancheManager.findTranches( dataManager.getData(), TS_TRANCHES, metric ); TRANCHES_FILE.print(Tranche.tranchesString( tranches )); logger.info( "Writing out recalibration table..." ); diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/GaussianMixtureModel.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/GaussianMixtureModel.java index ddc12269f..96e081fe7 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/GaussianMixtureModel.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/GaussianMixtureModel.java @@ -71,7 +71,7 @@ public class GaussianMixtureModel { empiricalSigma.setMatrix(0, empiricalMu.length - 1, 0, empiricalMu.length - 1, Matrix.identity(empiricalMu.length, empiricalMu.length)); } - public void initializeRandomModel( final List data, final Random rand ) { + public void initializeRandomModel( final List data, final Random rand, final int numKMeansIterations ) { // initialize random Gaussian means // BUGBUG: this is broken up this way to match the order of calls to rand.nextDouble() in the old code for( final MultivariateGaussian gaussian : gaussians ) { @@ -79,7 +79,6 @@ public class GaussianMixtureModel { } // initialize means using K-means algorithm - final int numKMeansIterations = 10; // BUGBUG: VRAC argument logger.info( "Initializing model with " + numKMeansIterations + " k-means iterations..." ); initializeMeansUsingKMeans( data, numKMeansIterations, rand ); diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/TrancheManager.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/TrancheManager.java index 49b464eda..9b8fabd63 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/TrancheManager.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/TrancheManager.java @@ -7,7 +7,6 @@ import java.io.File; import java.io.FileNotFoundException; import java.io.PrintStream; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -211,7 +210,7 @@ public class TrancheManager { public static int countCallsAtTruth(final List data, double minLOD ) { int n = 0; - for ( VariantDatum d : data) { n += (d.atTruthSite && d.lod >= minLOD ? 1 : 0); } + for ( VariantDatum d : data ) { n += (d.atTruthSite && d.lod >= minLOD ? 1 : 0); } return n; } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantDataManager.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantDataManager.java index 507cda88a..12f1f5da8 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantDataManager.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantDataManager.java @@ -25,15 +25,17 @@ public class VariantDataManager { private final double[] varianceVector; // this is really the standard deviation public final ArrayList annotationKeys; private final ExpandingArrayList trainingSets; + private final VariantRecalibratorArgumentCollection VRAC; private final static long RANDOM_SEED = 83409701; private final static Random rand = new Random( RANDOM_SEED ); // this is going to cause problems if it is ever used in an integration test, planning to get rid of HRun anyway protected final static Logger logger = Logger.getLogger(VariantDataManager.class); - public VariantDataManager( final List annotationKeys ) { + public VariantDataManager( final List annotationKeys, final VariantRecalibratorArgumentCollection VRAC ) { this.data = null; this.annotationKeys = new ArrayList( annotationKeys ); + this.VRAC = VRAC; meanVector = new double[this.annotationKeys.size()]; varianceVector = new double[this.annotationKeys.size()]; trainingSets = new ExpandingArrayList(); @@ -93,20 +95,20 @@ public class VariantDataManager { public ExpandingArrayList getTrainingData() { final ExpandingArrayList trainingData = new ExpandingArrayList(); for( final VariantDatum datum : data ) { - if( datum.atTrainingSite && datum.originalQual > 80.0 ) { //BUGBUG: VRAC argument + if( datum.atTrainingSite && datum.originalQual > VRAC.QUAL_THRESHOLD ) { trainingData.add( datum ); } } - trimDataBySTD(trainingData, 4.5); //BUGBUG: VRAC argument + trimDataBySTD( trainingData, VRAC.STD_THRESHOLD ); logger.info( "Training with " + trainingData.size() + " variants found in the training set(s)." ); return trainingData; } - public ExpandingArrayList selectWorstVariants( final float bottomPercentage ) { + public ExpandingArrayList selectWorstVariants( final double bottomPercentage ) { Collections.sort( data ); final ExpandingArrayList trainingData = new ExpandingArrayList(); - trainingData.addAll( data.subList(0, Math.round(bottomPercentage * data.size())) ); - logger.info( "Training with worst " + bottomPercentage * 100.0f + "% of data --> " + trainingData.size() + " variants with LOD <= " + String.format("%.4f", data.get(Math.round(bottomPercentage * data.size())).lod) + "." ); + trainingData.addAll( data.subList(0, Math.round((float)bottomPercentage * data.size())) ); + logger.info( "Training with worst " + bottomPercentage * 100.0f + "% of data --> " + trainingData.size() + " variants with LOD <= " + String.format("%.4f", data.get(Math.round((float)bottomPercentage * data.size())).lod) + "." ); return trainingData; } @@ -174,7 +176,7 @@ public class VariantDataManager { for( final TrainingSet trainingSet : trainingSets ) { final Collection vcs = tracker.getVariantContexts( ref, trainingSet.name, null, context.getLocation(), false, true ); final VariantContext trainVC = ( vcs.size() != 0 ? vcs.iterator().next() : null ); - if( trainVC != null && trainVC.isVariant() && !trainVC.isFiltered() && ((evalVC.isSNP() && trainVC.isSNP()) || (evalVC.isIndel() && trainVC.isIndel())) && (TRUST_ALL_POLYMORPHIC || !trainVC.hasGenotypes() || trainVC.isPolymorphic()) ) { + if( trainVC != null && trainVC.isVariant() && trainVC.isNotFiltered() && ((evalVC.isSNP() && trainVC.isSNP()) || (evalVC.isIndel() && trainVC.isIndel())) && (TRUST_ALL_POLYMORPHIC || !trainVC.hasGenotypes() || trainVC.isPolymorphic()) ) { datum.isKnown = datum.isKnown || trainingSet.isKnown; datum.atTruthSite = datum.atTruthSite || trainingSet.isTruth; datum.atTrainingSite = datum.atTrainingSite || trainingSet.isTraining; diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorArgumentCollection.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorArgumentCollection.java index bfe3a2d7c..65e610cde 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorArgumentCollection.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorArgumentCollection.java @@ -1,5 +1,7 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantrecalibration; +import org.broadinstitute.sting.commandline.Argument; + /** * Created by IntelliJ IDEA. * User: rpoplin @@ -8,9 +10,20 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantrecalibration; public class VariantRecalibratorArgumentCollection { - public VariantRecalibratorArgumentCollection clone() { - final VariantRecalibratorArgumentCollection vrac = new VariantRecalibratorArgumentCollection(); - - return vrac; - } + @Argument(fullName="maxGaussians", shortName="mG", doc="The maximum number of Gaussians to try during variational Bayes algorithm", required=false) + public int MAX_GAUSSIANS = 32; + @Argument(fullName="maxIterations", shortName="mI", doc="The maximum number of VBEM iterations to be performed in variational Bayes algorithm. Procedure will normally end when convergence is detected.", required=false) + public int MAX_ITERATIONS = 100; + @Argument(fullName="numKMeans", shortName="nKM", doc="The number of k-means iterations to perform in order to initialize the means of the Gaussians in the Gaussian mixture model.", required=false) + public int NUM_KMEANS_ITERATIONS = 10; + @Argument(fullName="stdThreshold", shortName="std", doc="If a variant has annotations more than -std standard deviations away from mean then don't use it for building the Gaussian mixture model.", required=false) + public double STD_THRESHOLD = 4.5; + @Argument(fullName="qualThreshold", shortName="qual", doc="If a known variant has raw QUAL value less than -qual then don't use it for building the Gaussian mixture model.", required=false) + public double QUAL_THRESHOLD = 80.0; + @Argument(fullName="shrinkage", shortName="shrinkage", doc="The shrinkage parameter in variational Bayes algorithm.", required=false) + public double SHRINKAGE = 0.0001; + @Argument(fullName="dirichlet", shortName="dirichlet", doc="The dirichlet parameter in variational Bayes algorithm.", required=false) + public double DIRICHLET_PARAMETER = 0.0001; + @Argument(fullName="percentBadVariants", shortName="percentBad", doc="What percentage of the worst scoring variants to use when building the Gaussian mixture model of bad variants. 0.07 means bottom 7 percent.", required=false) + public double PERCENT_BAD_VARIANTS = 0.07; } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java index 6678f850f..a1ffd23ea 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java @@ -36,7 +36,7 @@ public class VariantRecalibratorEngine { } public GaussianMixtureModel generateModel( final List data ) { - final GaussianMixtureModel model = new GaussianMixtureModel( 32, 4, 0.0001, 0.0001 ); //BUGBUG: VRAC arguments + final GaussianMixtureModel model = new GaussianMixtureModel( VRAC.MAX_GAUSSIANS, data.get(0).annotations.length, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER ); variationalBayesExpectationMaximization( model, data ); return model; } @@ -67,14 +67,14 @@ public class VariantRecalibratorEngine { private void variationalBayesExpectationMaximization( final GaussianMixtureModel model, final List data ) { model.cacheEmpiricalStats( data ); - model.initializeRandomModel( data, rand ); + model.initializeRandomModel( data, rand, VRAC.NUM_KMEANS_ITERATIONS ); // The VBEM loop double previousLikelihood = model.expectationStep( data ); double currentLikelihood; int iteration = 0; logger.info("Finished iteration " + iteration ); - while( iteration < 100 ) { //BUGBUG: VRAC.maxIterations + while( iteration < VRAC.MAX_ITERATIONS ) { iteration++; model.maximizationStep( data ); currentLikelihood = model.expectationStep( data );