Pushing hard-coded arguments into VariantRecalibratorArgumentCollection
git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@5566 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
parent
79c43845ad
commit
b2a0331e2d
|
|
@ -27,6 +27,7 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantrecalibration;
|
|||
|
||||
import org.broad.tribble.util.variantcontext.VariantContext;
|
||||
import org.broadinstitute.sting.commandline.Argument;
|
||||
import org.broadinstitute.sting.commandline.ArgumentCollection;
|
||||
import org.broadinstitute.sting.commandline.Hidden;
|
||||
import org.broadinstitute.sting.commandline.Output;
|
||||
import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
|
||||
|
|
@ -40,7 +41,6 @@ import org.broadinstitute.sting.utils.QualityUtils;
|
|||
import org.broadinstitute.sting.utils.collections.ExpandingArrayList;
|
||||
import org.broadinstitute.sting.utils.exceptions.UserException;
|
||||
|
||||
import java.io.File;
|
||||
import java.io.PrintStream;
|
||||
import java.util.*;
|
||||
|
||||
|
|
@ -57,6 +57,8 @@ public class ContrastiveRecalibrator extends RodWalker<ExpandingArrayList<Varian
|
|||
|
||||
public static final String VQS_LOD_KEY = "VQSLOD";
|
||||
|
||||
@ArgumentCollection private VariantRecalibratorArgumentCollection VRAC = new VariantRecalibratorArgumentCollection();
|
||||
|
||||
/////////////////////////////
|
||||
// Outputs
|
||||
/////////////////////////////
|
||||
|
|
@ -66,9 +68,8 @@ public class ContrastiveRecalibrator extends RodWalker<ExpandingArrayList<Varian
|
|||
private PrintStream TRANCHES_FILE;
|
||||
|
||||
/////////////////////////////
|
||||
// Command Line Arguments
|
||||
// Additional Command Line Arguments
|
||||
/////////////////////////////
|
||||
//BUGBUG: use VariantRecalibrationArgumentCollection
|
||||
@Argument(fullName="use_annotation", shortName="an", doc="The names of the annotations which should used for calculations", required=true)
|
||||
private String[] USE_ANNOTATIONS = null;
|
||||
@Argument(fullName="TStranche", shortName="tranche", doc="The levels of novel false discovery rate (FDR, implied by ti/tv) at which to slice the data. (in percent, that is 1.0 for 1 percent)", required=false)
|
||||
|
|
@ -80,9 +81,6 @@ public class ContrastiveRecalibrator extends RodWalker<ExpandingArrayList<Varian
|
|||
// Debug Arguments
|
||||
/////////////////////////////
|
||||
@Hidden
|
||||
@Argument(fullName = "debugFile", shortName = "debugFile", doc = "Print debugging information here", required=false)
|
||||
private File DEBUG_FILE = null;
|
||||
@Hidden
|
||||
@Argument(fullName = "trustAllPolymorphic", shortName = "allPoly", doc = "Trust that all the input training sets' unfiltered records contain only polymorphic sites to drastically speed up the computation.", required = false)
|
||||
protected Boolean TRUST_ALL_POLYMORPHIC = false;
|
||||
|
||||
|
|
@ -92,7 +90,7 @@ public class ContrastiveRecalibrator extends RodWalker<ExpandingArrayList<Varian
|
|||
private VariantDataManager dataManager;
|
||||
private final Set<String> ignoreInputFilterSet = new TreeSet<String>();
|
||||
private final Set<String> inputNames = new HashSet<String>();
|
||||
private final VariantRecalibratorEngine engine = new VariantRecalibratorEngine(new VariantRecalibratorArgumentCollection()); //BUGBUG: doesn't do anything with the args at the moment
|
||||
private final VariantRecalibratorEngine engine = new VariantRecalibratorEngine( VRAC );
|
||||
|
||||
//---------------------------------------------------------------------------------------------------------------
|
||||
//
|
||||
|
|
@ -101,7 +99,7 @@ public class ContrastiveRecalibrator extends RodWalker<ExpandingArrayList<Varian
|
|||
//---------------------------------------------------------------------------------------------------------------
|
||||
|
||||
public void initialize() {
|
||||
dataManager = new VariantDataManager( new ArrayList<String>(Arrays.asList(USE_ANNOTATIONS)) );
|
||||
dataManager = new VariantDataManager( new ArrayList<String>(Arrays.asList(USE_ANNOTATIONS)), VRAC );
|
||||
|
||||
if( IGNORE_INPUT_FILTERS != null ) {
|
||||
ignoreInputFilterSet.addAll( Arrays.asList(IGNORE_INPUT_FILTERS) );
|
||||
|
|
@ -191,14 +189,11 @@ public class ContrastiveRecalibrator extends RodWalker<ExpandingArrayList<Varian
|
|||
dataManager.setData( reduceSum );
|
||||
dataManager.normalizeData();
|
||||
engine.evaluateData( dataManager.getData(), engine.generateModel( dataManager.getTrainingData() ), false );
|
||||
engine.evaluateData( dataManager.getData(), engine.generateModel( dataManager.selectWorstVariants( 0.07f ) ), true );
|
||||
engine.evaluateData( dataManager.getData(), engine.generateModel( dataManager.selectWorstVariants( VRAC.PERCENT_BAD_VARIANTS ) ), true );
|
||||
|
||||
// old tranches stuff
|
||||
int nCallsAtTruth = TrancheManager.countCallsAtTruth( dataManager.getData(), Double.NEGATIVE_INFINITY );
|
||||
//logger.info(String.format("Truth set size is %d, raw calls at these sites %d, maximum sensitivity of %.2f",
|
||||
// nTruthSites, nCallsAtTruth, (100.0*nCallsAtTruth / Math.max(nTruthSites, nCallsAtTruth))));
|
||||
TrancheManager.SelectionMetric metric = new TrancheManager.TruthSensitivityMetric( nCallsAtTruth );
|
||||
List<Tranche> tranches = TrancheManager.findTranches( dataManager.getData(), TS_TRANCHES, metric, DEBUG_FILE ); //BUGBUG: recreated here to match the integration tests
|
||||
final int nCallsAtTruth = TrancheManager.countCallsAtTruth( dataManager.getData(), Double.NEGATIVE_INFINITY );
|
||||
final TrancheManager.SelectionMetric metric = new TrancheManager.TruthSensitivityMetric( nCallsAtTruth );
|
||||
final List<Tranche> tranches = TrancheManager.findTranches( dataManager.getData(), TS_TRANCHES, metric );
|
||||
TRANCHES_FILE.print(Tranche.tranchesString( tranches ));
|
||||
|
||||
logger.info( "Writing out recalibration table..." );
|
||||
|
|
|
|||
|
|
@ -71,7 +71,7 @@ public class GaussianMixtureModel {
|
|||
empiricalSigma.setMatrix(0, empiricalMu.length - 1, 0, empiricalMu.length - 1, Matrix.identity(empiricalMu.length, empiricalMu.length));
|
||||
}
|
||||
|
||||
public void initializeRandomModel( final List<VariantDatum> data, final Random rand ) {
|
||||
public void initializeRandomModel( final List<VariantDatum> data, final Random rand, final int numKMeansIterations ) {
|
||||
|
||||
// initialize random Gaussian means // BUGBUG: this is broken up this way to match the order of calls to rand.nextDouble() in the old code
|
||||
for( final MultivariateGaussian gaussian : gaussians ) {
|
||||
|
|
@ -79,7 +79,6 @@ public class GaussianMixtureModel {
|
|||
}
|
||||
|
||||
// initialize means using K-means algorithm
|
||||
final int numKMeansIterations = 10; // BUGBUG: VRAC argument
|
||||
logger.info( "Initializing model with " + numKMeansIterations + " k-means iterations..." );
|
||||
initializeMeansUsingKMeans( data, numKMeansIterations, rand );
|
||||
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ import java.io.File;
|
|||
import java.io.FileNotFoundException;
|
||||
import java.io.PrintStream;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
|
|
@ -211,7 +210,7 @@ public class TrancheManager {
|
|||
|
||||
public static int countCallsAtTruth(final List<VariantDatum> data, double minLOD ) {
|
||||
int n = 0;
|
||||
for ( VariantDatum d : data) { n += (d.atTruthSite && d.lod >= minLOD ? 1 : 0); }
|
||||
for ( VariantDatum d : data ) { n += (d.atTruthSite && d.lod >= minLOD ? 1 : 0); }
|
||||
return n;
|
||||
}
|
||||
|
||||
|
|
|
|||
|
|
@ -25,15 +25,17 @@ public class VariantDataManager {
|
|||
private final double[] varianceVector; // this is really the standard deviation
|
||||
public final ArrayList<String> annotationKeys;
|
||||
private final ExpandingArrayList<TrainingSet> trainingSets;
|
||||
private final VariantRecalibratorArgumentCollection VRAC;
|
||||
|
||||
private final static long RANDOM_SEED = 83409701;
|
||||
private final static Random rand = new Random( RANDOM_SEED ); // this is going to cause problems if it is ever used in an integration test, planning to get rid of HRun anyway
|
||||
|
||||
protected final static Logger logger = Logger.getLogger(VariantDataManager.class);
|
||||
|
||||
public VariantDataManager( final List<String> annotationKeys ) {
|
||||
public VariantDataManager( final List<String> annotationKeys, final VariantRecalibratorArgumentCollection VRAC ) {
|
||||
this.data = null;
|
||||
this.annotationKeys = new ArrayList<String>( annotationKeys );
|
||||
this.VRAC = VRAC;
|
||||
meanVector = new double[this.annotationKeys.size()];
|
||||
varianceVector = new double[this.annotationKeys.size()];
|
||||
trainingSets = new ExpandingArrayList<TrainingSet>();
|
||||
|
|
@ -93,20 +95,20 @@ public class VariantDataManager {
|
|||
public ExpandingArrayList<VariantDatum> getTrainingData() {
|
||||
final ExpandingArrayList<VariantDatum> trainingData = new ExpandingArrayList<VariantDatum>();
|
||||
for( final VariantDatum datum : data ) {
|
||||
if( datum.atTrainingSite && datum.originalQual > 80.0 ) { //BUGBUG: VRAC argument
|
||||
if( datum.atTrainingSite && datum.originalQual > VRAC.QUAL_THRESHOLD ) {
|
||||
trainingData.add( datum );
|
||||
}
|
||||
}
|
||||
trimDataBySTD(trainingData, 4.5); //BUGBUG: VRAC argument
|
||||
trimDataBySTD( trainingData, VRAC.STD_THRESHOLD );
|
||||
logger.info( "Training with " + trainingData.size() + " variants found in the training set(s)." );
|
||||
return trainingData;
|
||||
}
|
||||
|
||||
public ExpandingArrayList<VariantDatum> selectWorstVariants( final float bottomPercentage ) {
|
||||
public ExpandingArrayList<VariantDatum> selectWorstVariants( final double bottomPercentage ) {
|
||||
Collections.sort( data );
|
||||
final ExpandingArrayList<VariantDatum> trainingData = new ExpandingArrayList<VariantDatum>();
|
||||
trainingData.addAll( data.subList(0, Math.round(bottomPercentage * data.size())) );
|
||||
logger.info( "Training with worst " + bottomPercentage * 100.0f + "% of data --> " + trainingData.size() + " variants with LOD <= " + String.format("%.4f", data.get(Math.round(bottomPercentage * data.size())).lod) + "." );
|
||||
trainingData.addAll( data.subList(0, Math.round((float)bottomPercentage * data.size())) );
|
||||
logger.info( "Training with worst " + bottomPercentage * 100.0f + "% of data --> " + trainingData.size() + " variants with LOD <= " + String.format("%.4f", data.get(Math.round((float)bottomPercentage * data.size())).lod) + "." );
|
||||
return trainingData;
|
||||
}
|
||||
|
||||
|
|
@ -174,7 +176,7 @@ public class VariantDataManager {
|
|||
for( final TrainingSet trainingSet : trainingSets ) {
|
||||
final Collection<VariantContext> vcs = tracker.getVariantContexts( ref, trainingSet.name, null, context.getLocation(), false, true );
|
||||
final VariantContext trainVC = ( vcs.size() != 0 ? vcs.iterator().next() : null );
|
||||
if( trainVC != null && trainVC.isVariant() && !trainVC.isFiltered() && ((evalVC.isSNP() && trainVC.isSNP()) || (evalVC.isIndel() && trainVC.isIndel())) && (TRUST_ALL_POLYMORPHIC || !trainVC.hasGenotypes() || trainVC.isPolymorphic()) ) {
|
||||
if( trainVC != null && trainVC.isVariant() && trainVC.isNotFiltered() && ((evalVC.isSNP() && trainVC.isSNP()) || (evalVC.isIndel() && trainVC.isIndel())) && (TRUST_ALL_POLYMORPHIC || !trainVC.hasGenotypes() || trainVC.isPolymorphic()) ) {
|
||||
datum.isKnown = datum.isKnown || trainingSet.isKnown;
|
||||
datum.atTruthSite = datum.atTruthSite || trainingSet.isTruth;
|
||||
datum.atTrainingSite = datum.atTrainingSite || trainingSet.isTraining;
|
||||
|
|
|
|||
|
|
@ -1,5 +1,7 @@
|
|||
package org.broadinstitute.sting.playground.gatk.walkers.variantrecalibration;
|
||||
|
||||
import org.broadinstitute.sting.commandline.Argument;
|
||||
|
||||
/**
|
||||
* Created by IntelliJ IDEA.
|
||||
* User: rpoplin
|
||||
|
|
@ -8,9 +10,20 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantrecalibration;
|
|||
|
||||
public class VariantRecalibratorArgumentCollection {
|
||||
|
||||
public VariantRecalibratorArgumentCollection clone() {
|
||||
final VariantRecalibratorArgumentCollection vrac = new VariantRecalibratorArgumentCollection();
|
||||
|
||||
return vrac;
|
||||
}
|
||||
@Argument(fullName="maxGaussians", shortName="mG", doc="The maximum number of Gaussians to try during variational Bayes algorithm", required=false)
|
||||
public int MAX_GAUSSIANS = 32;
|
||||
@Argument(fullName="maxIterations", shortName="mI", doc="The maximum number of VBEM iterations to be performed in variational Bayes algorithm. Procedure will normally end when convergence is detected.", required=false)
|
||||
public int MAX_ITERATIONS = 100;
|
||||
@Argument(fullName="numKMeans", shortName="nKM", doc="The number of k-means iterations to perform in order to initialize the means of the Gaussians in the Gaussian mixture model.", required=false)
|
||||
public int NUM_KMEANS_ITERATIONS = 10;
|
||||
@Argument(fullName="stdThreshold", shortName="std", doc="If a variant has annotations more than -std standard deviations away from mean then don't use it for building the Gaussian mixture model.", required=false)
|
||||
public double STD_THRESHOLD = 4.5;
|
||||
@Argument(fullName="qualThreshold", shortName="qual", doc="If a known variant has raw QUAL value less than -qual then don't use it for building the Gaussian mixture model.", required=false)
|
||||
public double QUAL_THRESHOLD = 80.0;
|
||||
@Argument(fullName="shrinkage", shortName="shrinkage", doc="The shrinkage parameter in variational Bayes algorithm.", required=false)
|
||||
public double SHRINKAGE = 0.0001;
|
||||
@Argument(fullName="dirichlet", shortName="dirichlet", doc="The dirichlet parameter in variational Bayes algorithm.", required=false)
|
||||
public double DIRICHLET_PARAMETER = 0.0001;
|
||||
@Argument(fullName="percentBadVariants", shortName="percentBad", doc="What percentage of the worst scoring variants to use when building the Gaussian mixture model of bad variants. 0.07 means bottom 7 percent.", required=false)
|
||||
public double PERCENT_BAD_VARIANTS = 0.07;
|
||||
}
|
||||
|
|
|
|||
|
|
@ -36,7 +36,7 @@ public class VariantRecalibratorEngine {
|
|||
}
|
||||
|
||||
public GaussianMixtureModel generateModel( final List<VariantDatum> data ) {
|
||||
final GaussianMixtureModel model = new GaussianMixtureModel( 32, 4, 0.0001, 0.0001 ); //BUGBUG: VRAC arguments
|
||||
final GaussianMixtureModel model = new GaussianMixtureModel( VRAC.MAX_GAUSSIANS, data.get(0).annotations.length, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER );
|
||||
variationalBayesExpectationMaximization( model, data );
|
||||
return model;
|
||||
}
|
||||
|
|
@ -67,14 +67,14 @@ public class VariantRecalibratorEngine {
|
|||
private void variationalBayesExpectationMaximization( final GaussianMixtureModel model, final List<VariantDatum> data ) {
|
||||
|
||||
model.cacheEmpiricalStats( data );
|
||||
model.initializeRandomModel( data, rand );
|
||||
model.initializeRandomModel( data, rand, VRAC.NUM_KMEANS_ITERATIONS );
|
||||
|
||||
// The VBEM loop
|
||||
double previousLikelihood = model.expectationStep( data );
|
||||
double currentLikelihood;
|
||||
int iteration = 0;
|
||||
logger.info("Finished iteration " + iteration );
|
||||
while( iteration < 100 ) { //BUGBUG: VRAC.maxIterations
|
||||
while( iteration < VRAC.MAX_ITERATIONS ) {
|
||||
iteration++;
|
||||
model.maximizationStep( data );
|
||||
currentLikelihood = model.expectationStep( data );
|
||||
|
|
|
|||
Loading…
Reference in New Issue