New visualization output for VQSR. It creates the R script file on the fly and then runs Rscript on it. Adding 1000G Project consensus code. First pass of having VQSR work with missing data by marginalizing over the missing dimension for that data point (thanks Chris and Bob for ideas). Updated math functions to use apache math commons instead of approximations from wikipedia. New parameters available for the priors based on further reading in Bishop and looking at the new visualizations. Updated integration test to use more modern files. Updated MDCP to use new best practices w.r.t. annotations.

git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@5723 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
rpoplin 2011-05-02 19:14:42 +00:00
parent fcf8cff64a
commit 3224bbe750
11 changed files with 359 additions and 177 deletions

View File

@ -37,10 +37,13 @@ import org.broadinstitute.sting.gatk.datasources.rmd.ReferenceOrderedDataSource;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.walkers.RodWalker;
import org.broadinstitute.sting.gatk.walkers.TreeReducible;
import org.broadinstitute.sting.utils.MathUtils;
import org.broadinstitute.sting.utils.QualityUtils;
import org.broadinstitute.sting.utils.Utils;
import org.broadinstitute.sting.utils.collections.ExpandingArrayList;
import org.broadinstitute.sting.utils.exceptions.UserException;
import java.io.FileNotFoundException;
import java.io.PrintStream;
import java.util.*;
@ -76,6 +79,12 @@ public class ContrastiveRecalibrator extends RodWalker<ExpandingArrayList<Varian
private double[] TS_TRANCHES = new double[] {100.0, 99.9, 99.0, 90.0};
@Argument(fullName="ignore_filter", shortName="ignoreFilter", doc="If specified the optimizer will use variants even if the specified filter name is marked in the input VCF file", required=false)
private String[] IGNORE_INPUT_FILTERS = null;
@Argument(fullName="path_to_Rscript", shortName = "Rscript", doc = "The path to your implementation of Rscript. For Broad users this is maybe /broad/tools/apps/R-2.6.0/bin/Rscript", required=false)
private String PATH_TO_RSCRIPT = "Rscript";
@Argument(fullName="rscript_file", shortName="rscriptFile", doc="The output rscript file generated by the VQSR to aid in visualization of the input data and learned model", required=false)
private String RSCRIPT_FILE = null;
@Argument(fullName="ts_filter_level", shortName="ts_filter_level", doc="The truth sensitivity level at which to start filtering, used here to indicate filtered variants in plots", required=false)
private double TS_FILTER_LEVEL = 99.0;
/////////////////////////////
// Debug Arguments
@ -83,10 +92,6 @@ public class ContrastiveRecalibrator extends RodWalker<ExpandingArrayList<Varian
@Hidden
@Argument(fullName = "trustAllPolymorphic", shortName = "allPoly", doc = "Trust that all the input training sets' unfiltered records contain only polymorphic sites to drastically speed up the computation.", required = false)
protected Boolean TRUST_ALL_POLYMORPHIC = false;
@Hidden
@Argument(fullName = "fixOmni", shortName = "fixOmni", doc = "Ignore the NOT_POLY_IN_1000G filter for the omni file because it is broken.", required = false)
protected Boolean FIX_OMNI = false; //BUGBUG: remove me very soon!
/////////////////////////////
// Private Member Variables
@ -150,13 +155,18 @@ public class ContrastiveRecalibrator extends RodWalker<ExpandingArrayList<Varian
if( vc != null && ( vc.isNotFiltered() || ignoreInputFilterSet.containsAll(vc.getFilters()) ) ) {
if( checkRecalibrationMode( vc, VRAC.MODE ) ) {
final VariantDatum datum = new VariantDatum();
datum.annotations = dataManager.decodeAnnotations( ref.getGenomeLocParser(), vc, true ); //BUGBUG: when run with HierarchicalMicroScheduler this is non-deterministic because order of calls depends on load of machine
dataManager.decodeAnnotations( datum, vc, true ); //BUGBUG: when run with HierarchicalMicroScheduler this is non-deterministic because order of calls depends on load of machine
datum.pos = context.getLocation();
datum.originalQual = vc.getPhredScaledQual();
datum.isSNP = vc.isSNP() && vc.isBiallelic();
datum.isTransition = datum.isSNP && VariantContextUtils.isTransition(vc);
dataManager.parseTrainingSets( tracker, ref, context, vc, datum, TRUST_ALL_POLYMORPHIC, FIX_OMNI );
final double priorFactor = QualityUtils.qualToProb( datum.prior );
datum.usedForTraining = 0;
dataManager.parseTrainingSets( tracker, ref, context, vc, datum, TRUST_ALL_POLYMORPHIC );
double priorFactor = QualityUtils.qualToProb( datum.prior );
if( datum.consensusCount != 0 ) {
final double consensusPrior = QualityUtils.qualToProb( 1.0 + 5.0 * datum.consensusCount );
priorFactor = 1.0 - ((1.0 - priorFactor) * (1.0 - consensusPrior));
}
datum.prior = Math.log10( priorFactor ) - Math.log10( 1.0 - priorFactor );
mapList.add( datum );
}
@ -201,15 +211,149 @@ public class ContrastiveRecalibrator extends RodWalker<ExpandingArrayList<Varian
public void onTraversalDone( final ExpandingArrayList<VariantDatum> reduceSum ) {
dataManager.setData( reduceSum );
dataManager.normalizeData();
engine.evaluateData( dataManager.getData(), engine.generateModel( dataManager.getTrainingData() ), false );
engine.evaluateData( dataManager.getData(), engine.generateModel( dataManager.selectWorstVariants( VRAC.PERCENT_BAD_VARIANTS ) ), true );
final GaussianMixtureModel goodModel = engine.generateModel( dataManager.getTrainingData() );
engine.evaluateData( dataManager.getData(), goodModel, false );
final GaussianMixtureModel badModel = engine.generateModel( dataManager.selectWorstVariants( VRAC.PERCENT_BAD_VARIANTS ) );
engine.evaluateData( dataManager.getData(), badModel, true );
final ExpandingArrayList<VariantDatum> randomData = dataManager.getRandomDataForPlotting( 6000 );
final int nCallsAtTruth = TrancheManager.countCallsAtTruth( dataManager.getData(), Double.NEGATIVE_INFINITY );
final TrancheManager.SelectionMetric metric = new TrancheManager.TruthSensitivityMetric( nCallsAtTruth );
final List<Tranche> tranches = TrancheManager.findTranches( dataManager.getData(), TS_TRANCHES, metric );
TRANCHES_FILE.print(Tranche.tranchesString( tranches ));
double lodCutoff = 0.0;
for(final Tranche tranche : tranches) {
if(MathUtils.compareDoubles(tranche.ts, TS_FILTER_LEVEL, 0.0001)==0) {
lodCutoff = tranche.minVQSLod;
}
}
logger.info( "Writing out recalibration table..." );
dataManager.writeOutRecalibrationTable( RECAL_FILE );
if( RSCRIPT_FILE != null ) {
logger.info( "Writing out visualization Rscript file...");
createVisualizationScript( randomData, goodModel, badModel, lodCutoff );
}
}
}
private void createVisualizationScript( final ExpandingArrayList<VariantDatum> randomData, final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, final double lodCutoff ) {
PrintStream stream;
try {
stream = new PrintStream(RSCRIPT_FILE);
} catch( FileNotFoundException e ) {
throw new UserException.CouldNotCreateOutputFile(RSCRIPT_FILE, "", e);
}
stream.println("library(ggplot2)");
createArrangeFunction( stream );
stream.println("pdf(\"" + RSCRIPT_FILE + ".pdf\")");
for(int iii = 0; iii < USE_ANNOTATIONS.length; iii++) {
for( int jjj = iii + 1; jjj < USE_ANNOTATIONS.length; jjj++) {
//stream.println("png(\"" + RSCRIPT_FILE + "." + USE_ANNOTATIONS[iii] + "." + USE_ANNOTATIONS[jjj] + ".png\", type=\"cairo\", width = 960, height = 960)");
//stream.println("pdf(\"" + RSCRIPT_FILE + "." + USE_ANNOTATIONS[iii] + "." + USE_ANNOTATIONS[jjj] + ".pdf\")");
logger.info( "Building " + USE_ANNOTATIONS[iii] + " x " + USE_ANNOTATIONS[jjj] + " plot...");
final ExpandingArrayList<VariantDatum> fakeData = new ExpandingArrayList<VariantDatum>();
double minAnn1 = 100.0, maxAnn1 = -100.0, minAnn2 = 100.0, maxAnn2 = -100.0;
for( final VariantDatum datum : randomData ) {
minAnn1 = Math.min(minAnn1, datum.annotations[iii]);
maxAnn1 = Math.max(maxAnn1, datum.annotations[iii]);
minAnn2 = Math.min(minAnn2, datum.annotations[jjj]);
maxAnn2 = Math.max(maxAnn2, datum.annotations[jjj]);
}
for(double ann1 = minAnn1; ann1 <= maxAnn1; ann1+=0.1) {
for(double ann2 = minAnn2; ann2 <= maxAnn2; ann2+=0.1) {
final VariantDatum datum = new VariantDatum();
datum.prior = 0.0;
datum.annotations = new double[randomData.get(0).annotations.length];
datum.isNull = new boolean[randomData.get(0).annotations.length];
for(int ann=0; ann< datum.annotations.length; ann++) {
datum.annotations[ann] = 0.0;
datum.isNull[ann] = true;
}
datum.annotations[iii] = ann1;
datum.annotations[jjj] = ann2;
datum.isNull[iii] = false;
datum.isNull[jjj] = false;
fakeData.add(datum);
}
}
engine.evaluateData( fakeData, goodModel, false );
engine.evaluateData( fakeData, badModel, true );
stream.print("surface <- c(");
for( final VariantDatum datum : fakeData ) {
stream.print(String.format("%.3f, %.3f, %.3f, ", datum.annotations[iii], datum.annotations[jjj], Math.min(4.0, Math.max(-4.0, datum.lod))));
}
stream.println("NA,NA,NA)");
stream.println("s <- matrix(surface,ncol=3,byrow=T)");
stream.print("data <- c(");
for( final VariantDatum datum : randomData ) {
stream.print(String.format("%.3f, %.3f, %.3f, %d, %d,", datum.annotations[iii], datum.annotations[jjj], (datum.lod < lodCutoff ? -1.0 : 1.0), datum.usedForTraining, (datum.isKnown ? 1 : -1)));
}
stream.println("NA,NA,NA,NA,NA)");
stream.println("d <- matrix(data,ncol=5,byrow=T)");
final String surfaceFrame = "sf." + USE_ANNOTATIONS[iii] + "." + USE_ANNOTATIONS[jjj];
final String dataFrame = "df." + USE_ANNOTATIONS[iii] + "." + USE_ANNOTATIONS[jjj];
stream.println(surfaceFrame + " <- data.frame(x=s[,1], y=s[,2], lod=s[,3])");
stream.println(dataFrame + " <- data.frame(x=d[,1], y=d[,2], retained=d[,3], training=d[,4], novelty=d[,5])");
stream.println("p <- ggplot(data=" + surfaceFrame + ", aes(x=x, y=y)) + opts(panel.background = theme_rect(colour = NA), panel.grid.minor = theme_line(colour = NA), panel.grid.major = theme_line(colour = NA))");
stream.println("p1 = p + opts(title=\"model PDF\") + labs(x=\""+ USE_ANNOTATIONS[iii] +"\", y=\""+ USE_ANNOTATIONS[jjj] +"\") + geom_tile(aes(fill = lod)) + scale_fill_gradient(high=\"green\", low=\"red\")");
stream.println("p <- ggplot(data=" + dataFrame + ", aes(x=x, y=y)) + opts(panel.background = theme_rect(colour = NA), panel.grid.minor = theme_line(colour = NA), panel.grid.major = theme_line(colour = NA))");
stream.println("p2 = p + labs(x=\""+ USE_ANNOTATIONS[iii] +"\", y=\""+ USE_ANNOTATIONS[jjj] +"\") + geom_point(data="+ dataFrame + ", aes(x=x, y=y, colour = retained, alpha=0.3, size=1.5)) + scale_colour_gradient(name=\"\", high=\"black\", low=\"red\",breaks=c(-1,1),labels=c(\"filtered\",\"retained\"))");
stream.println("p3 = p + labs(x=\""+ USE_ANNOTATIONS[iii] +"\", y=\""+ USE_ANNOTATIONS[jjj] +"\") + geom_point(data="+ dataFrame + "["+dataFrame+"$training==0,], aes(x=x, y=y, colour = training, alpha=0.3, size=1.5)) + geom_point(data="+ dataFrame + "["+dataFrame+"$training!=0,], aes(x=x, y=y, colour = training, alpha=0.3, size=1.5)) + scale_colour_gradient2(high=\"green\", mid=\"lightgrey\", low=\"purple\",breaks=c(-1,0,1), labels=c(\"bad\", \"\", \"good\"))");
stream.println("p4 = p + labs(x=\""+ USE_ANNOTATIONS[iii] +"\", y=\""+ USE_ANNOTATIONS[jjj] +"\") + geom_point(data="+ dataFrame + ", aes(x=x, y=y, colour = novelty, alpha=0.3, size=1.5)) + scale_colour_gradient(name=\"\", high=\"blue\", low=\"red\",breaks=c(-1,1), labels=c(\"novel\",\"known\"))");
stream.println("arrange(p1, p2, p3, p4, ncol=2)");
}
}
stream.println("dev.off()");
stream.close();
// Execute Rscript command to generate the clustering plots
final String rScriptTranchesCommandLine = PATH_TO_RSCRIPT + " " + RSCRIPT_FILE;
logger.info( "Executing: " + rScriptTranchesCommandLine );
try {
Process p;
p = Runtime.getRuntime().exec( rScriptTranchesCommandLine );
p.waitFor();
} catch ( Exception e ) {
Utils.warnUser("Unable to execute the RScript command. While not critical to the calculations themselves, the script outputs a report that is extremely useful for visualizing the recalibration results. We highly recommend trying to rerun the script manually if possible.");
}
}
// from http://gettinggeneticsdone.blogspot.com/2010/03/arrange-multiple-ggplot2-plots-in-same.html
private void createArrangeFunction( final PrintStream stream ) {
stream.println("vp.layout <- function(x, y) viewport(layout.pos.row=x, layout.pos.col=y)");
stream.println("arrange <- function(..., nrow=NULL, ncol=NULL, as.table=FALSE) {");
stream.println("dots <- list(...)");
stream.println("n <- length(dots)");
stream.println("if(is.null(nrow) & is.null(ncol)) { nrow = floor(n/2) ; ncol = ceiling(n/nrow)}");
stream.println("if(is.null(nrow)) { nrow = ceiling(n/ncol)}");
stream.println("if(is.null(ncol)) { ncol = ceiling(n/nrow)}");
stream.println("grid.newpage()");
stream.println("pushViewport(viewport(layout=grid.layout(nrow,ncol) ) )");
stream.println("ii.p <- 1");
stream.println("for(ii.row in seq(1, nrow)){");
stream.println("ii.table.row <- ii.row ");
stream.println("if(as.table) {ii.table.row <- nrow - ii.table.row + 1}");
stream.println("for(ii.col in seq(1, ncol)){");
stream.println("ii.table <- ii.p");
stream.println("if(ii.p > n) break");
stream.println("print(dots[[ii.table]], vp=vp.layout(ii.table.row, ii.col))");
stream.println("ii.p <- ii.p + 1");
stream.println("}");
stream.println("}");
stream.println("}");
}
}

View File

@ -1,11 +1,13 @@
package org.broadinstitute.sting.playground.gatk.walkers.variantrecalibration;
import Jama.Matrix;
import cern.jet.random.Normal;
import org.apache.log4j.Logger;
import org.broadinstitute.sting.gatk.GenomeAnalysisEngine;
import org.broadinstitute.sting.utils.MathUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
@ -21,13 +23,13 @@ public class GaussianMixtureModel {
private final ArrayList<MultivariateGaussian> gaussians;
private final double shrinkage;
private final double dirichletParameter;
private final double degreesOfFreedom;
private final double priorCounts;
private final double[] empiricalMu;
private final Matrix empiricalSigma;
public boolean isModelReadyForEvaluation;
public GaussianMixtureModel( final int numGaussians, final int numAnnotations,
final double shrinkage, final double dirichletParameter ) {
final double shrinkage, final double dirichletParameter, final double priorCounts ) {
gaussians = new ArrayList<MultivariateGaussian>( numGaussians );
for( int iii = 0; iii < numGaussians; iii++ ) {
@ -36,39 +38,15 @@ public class GaussianMixtureModel {
}
this.shrinkage = shrinkage;
this.dirichletParameter = dirichletParameter;
degreesOfFreedom = numAnnotations + 2;
this.priorCounts = priorCounts;
empiricalMu = new double[numAnnotations];
empiricalSigma = new Matrix(numAnnotations, numAnnotations);
isModelReadyForEvaluation = false;
}
public void cacheEmpiricalStats( final List<VariantDatum> data ) {
final double[][] tmpSigmaVals = new double[empiricalMu.length][empiricalMu.length];
for( int iii = 0; iii < empiricalMu.length; iii++ ) {
empiricalMu[iii] = 0.0;
for( int jjj = iii; jjj < empiricalMu.length; jjj++ ) {
tmpSigmaVals[iii][jjj] = 0.0;
}
}
for( final VariantDatum datum : data ) {
for( int iii = 0; iii < empiricalMu.length; iii++ ) {
empiricalMu[iii] += datum.annotations[iii] / ((double) data.size());
}
}
//for( final VariantDatum datum : data ) {
// for( int iii = 0; iii < empiricalMu.length; iii++ ) {
// for( int jjj = 0; jjj < empiricalMu.length; jjj++ ) {
// tmpSigmaVals[iii][jjj] += (datum.annotations[iii]-empiricalMu[iii]) * (datum.annotations[jjj]-empiricalMu[jjj]);
// }
// }
//}
//empiricalSigma.setMatrix(0, empiricalMu.length - 1, 0, empiricalMu.length - 1, new Matrix(tmpSigmaVals));
//empiricalSigma.timesEquals( 1.0 / ((double) data.size()) );
//empiricalSigma.timesEquals( 1.0 / (Math.pow(gaussians.size(), 2.0 / ((double) empiricalMu.length))) );
empiricalSigma.setMatrix(0, empiricalMu.length - 1, 0, empiricalMu.length - 1, Matrix.identity(empiricalMu.length, empiricalMu.length));
public void cacheEmpiricalStats() {
Arrays.fill(empiricalMu, 0.0);
empiricalSigma.setMatrix(0, empiricalMu.length - 1, 0, empiricalMu.length - 1, Matrix.identity(empiricalMu.length, empiricalMu.length).times(200.0).inverse());
}
public void initializeRandomModel( final List<VariantDatum> data, final int numKMeansIterations ) {
@ -85,8 +63,9 @@ public class GaussianMixtureModel {
// initialize uniform mixture coefficients, random covariance matrices, and initial hyperparameters
for( final MultivariateGaussian gaussian : gaussians ) {
gaussian.pMixtureLog10 = Math.log10( 1.0 / ((double) gaussians.size()) );
gaussian.sumProb = 1.0 / ((double) gaussians.size());
gaussian.initializeRandomSigma( GenomeAnalysisEngine.getRandomGenerator() );
gaussian.hyperParameter_a = degreesOfFreedom;
gaussian.hyperParameter_a = priorCounts;
gaussian.hyperParameter_b = shrinkage;
gaussian.hyperParameter_lambda = dirichletParameter;
}
@ -129,19 +108,17 @@ public class GaussianMixtureModel {
}
}
public double expectationStep( final List<VariantDatum> data ) {
public void expectationStep( final List<VariantDatum> data ) {
for( final MultivariateGaussian gaussian : gaussians ) {
gaussian.precomputeDenominatorForVariationalBayes( getSumHyperParameterLambda() );
}
double likelihood = 0.0;
for( final VariantDatum datum : data ) {
final ArrayList<Double> pVarInGaussianLog10 = new ArrayList<Double>( gaussians.size() );
for( final MultivariateGaussian gaussian : gaussians ) {
final double pVarLog10 = gaussian.evaluateDatumLog10( datum );
pVarInGaussianLog10.add( pVarLog10 );
likelihood += pVarLog10;
}
final double[] pVarInGaussianNormalized = MathUtils.normalizeFromLog10( pVarInGaussianLog10 );
int iii = 0;
@ -149,15 +126,11 @@ public class GaussianMixtureModel {
gaussian.assignPVarInGaussian( pVarInGaussianNormalized[iii++] ); //BUGBUG: to clean up
}
}
final double scaledTotalLikelihoodLog10 = likelihood / ((double) data.size());
logger.info( "sum Log10 likelihood = " + String.format("%.5f", scaledTotalLikelihoodLog10) );
return scaledTotalLikelihoodLog10;
}
public void maximizationStep( final List<VariantDatum> data ) {
for( final MultivariateGaussian gaussian : gaussians ) {
gaussian.maximizeGaussian( data, empiricalMu, empiricalSigma, shrinkage, dirichletParameter, degreesOfFreedom );
gaussian.maximizeGaussian( data, empiricalMu, empiricalSigma, shrinkage, dirichletParameter, priorCounts);
}
}
@ -171,28 +144,31 @@ public class GaussianMixtureModel {
public void evaluateFinalModelParameters( final List<VariantDatum> data ) {
for( final MultivariateGaussian gaussian : gaussians ) {
gaussian.evaluateFinalModelParameters( data );
gaussian.evaluateFinalModelParameters(data);
}
normalizePMixtureLog10();
}
private void normalizePMixtureLog10() {
public double normalizePMixtureLog10() {
double sumDiff = 0.0;
double sumPK = 0.0;
for( final MultivariateGaussian gaussian : gaussians ) {
sumPK += gaussian.pMixtureLog10;
sumPK += gaussian.sumProb;
}
int gaussianIndex = 0;
double[] pGaussianLog10 = new double[gaussians.size()];
for( final MultivariateGaussian gaussian : gaussians ) {
pGaussianLog10[gaussianIndex++] = Math.log10( gaussian.pMixtureLog10 / sumPK ); //BUGBUG: to clean up
pGaussianLog10[gaussianIndex++] = Math.log10( gaussian.sumProb / sumPK );
}
pGaussianLog10 = MathUtils.normalizeFromLog10( pGaussianLog10, true );
gaussianIndex = 0;
for( final MultivariateGaussian gaussian : gaussians ) {
sumDiff += Math.abs( pGaussianLog10[gaussianIndex] - gaussian.pMixtureLog10 );
gaussian.pMixtureLog10 = pGaussianLog10[gaussianIndex++];
}
return sumDiff;
}
public void precomputeDenominatorForEvaluation() {
@ -204,6 +180,9 @@ public class GaussianMixtureModel {
}
public double evaluateDatum( final VariantDatum datum ) {
for( final boolean isNull : datum.isNull ) {
if( isNull ) { return evaluateDatumMarginalized( datum ); }
}
final double[] pVarInGaussianLog10 = new double[gaussians.size()];
int gaussianIndex = 0;
for( final MultivariateGaussian gaussian : gaussians ) {
@ -211,4 +190,27 @@ public class GaussianMixtureModel {
}
return MathUtils.log10sumLog10(pVarInGaussianLog10);
}
public double evaluateDatumMarginalized( final VariantDatum datum ) {
int numVals = 0;
double sumPVarInGaussian = 0.0;
int numIter = 10;
final double[] pVarInGaussianLog10 = new double[gaussians.size()];
for( int iii = 0; iii < datum.annotations.length; iii++ ) {
if( datum.isNull[iii] ) {
for( int ttt = 0; ttt < numIter; ttt++ ) {
datum.annotations[iii] = Normal.staticNextDouble(0.0, 1.0);
int gaussianIndex = 0;
for( final MultivariateGaussian gaussian : gaussians ) {
pVarInGaussianLog10[gaussianIndex++] = gaussian.pMixtureLog10 + gaussian.evaluateDatumLog10( datum );
}
sumPVarInGaussian += Math.pow(10.0, MathUtils.log10sumLog10(pVarInGaussianLog10));
numVals++;
}
}
}
return Math.log10( sumPVarInGaussian / ((double) numVals) );
}
}

View File

@ -1,6 +1,7 @@
package org.broadinstitute.sting.playground.gatk.walkers.variantrecalibration;
import Jama.Matrix;
import org.apache.commons.math.special.Gamma;
import org.broadinstitute.sting.utils.MathUtils;
import org.broadinstitute.sting.utils.collections.ExpandingArrayList;
@ -16,6 +17,7 @@ import java.util.Random;
public class MultivariateGaussian {
public double pMixtureLog10;
public double sumProb;
final public double[] mu;
final public Matrix sigma;
public double hyperParameter_a;
@ -52,15 +54,15 @@ public class MultivariateGaussian {
public void initializeRandomSigma( final Random rand ) {
final double[][] randSigma = new double[mu.length][mu.length];
for( int iii = 0; iii < mu.length; iii++ ) {
for( int jjj = iii; jjj < mu.length; jjj++ ) {
randSigma[jjj][iii] = 0.55 + 1.25 * rand.nextDouble();
if( rand.nextBoolean() ) {
randSigma[jjj][iii] *= -1.0;
}
if( iii != jjj ) { randSigma[iii][jjj] = 0.0; } // Sigma is a symmetric, positive-definite matrix created by taking a lower diagonal matrix and multiplying it by its transpose
for( int iii = 0; iii < mu.length; iii++ ) {
for( int jjj = iii; jjj < mu.length; jjj++ ) {
randSigma[jjj][iii] = 0.55 + 1.25 * rand.nextDouble();
if( rand.nextBoolean() ) {
randSigma[jjj][iii] *= -1.0;
}
if( iii != jjj ) { randSigma[iii][jjj] = 0.0; } // Sigma is a symmetric, positive-definite matrix created by taking a lower diagonal matrix and multiplying it by its transpose
}
}
Matrix tmp = new Matrix( randSigma );
tmp = tmp.times(tmp.transpose());
sigma.setMatrix(0, mu.length - 1, 0, mu.length - 1, tmp);
@ -95,15 +97,15 @@ public class MultivariateGaussian {
cachedSigmaInverse = sigma.inverse();
cachedSigmaInverse.timesEquals( hyperParameter_a );
double sum = 0.0;
for(int jjj = 1; jjj < mu.length; jjj++) {
sum += MathUtils.diGamma( (hyperParameter_a + 1.0 - jjj) / 2.0 );
for(int jjj = 1; jjj <= mu.length; jjj++) {
sum += Gamma.digamma( (hyperParameter_a + 1.0 - jjj) / 2.0 );
}
sum -= Math.log( sigma.det() );
sum += Math.log(2.0) * mu.length;
final double gamma = 0.5 * sum;
final double pi = MathUtils.diGamma( hyperParameter_lambda ) - MathUtils.diGamma( sumHyperParameterLambda );
final double lambda = 0.5 * sum;
final double pi = Gamma.digamma( hyperParameter_lambda ) - Gamma.digamma( sumHyperParameterLambda );
final double beta = (-1.0 * mu.length) / (2.0 * hyperParameter_b);
cachedDenomLog10 = (pi / Math.log(10.0)) + (gamma / Math.log(10.0)) + (beta / Math.log(10.0));
cachedDenomLog10 = (pi / Math.log(10.0)) + (lambda / Math.log(10.0)) + (beta / Math.log(10.0));
}
public double evaluateDatumLog10( final VariantDatum datum ) {
@ -132,7 +134,7 @@ public class MultivariateGaussian {
public void maximizeGaussian( final List<VariantDatum> data, final double[] empiricalMu, final Matrix empiricalSigma,
final double SHRINKAGE, final double DIRICHLET_PARAMETER, final double DEGREES_OF_FREEDOM ) {
double sumProb = 0.0;
sumProb = 1E-10;
final Matrix wishart = new Matrix(mu.length, mu.length);
zeroOutMu();
zeroOutSigma();
@ -171,8 +173,6 @@ public class MultivariateGaussian {
mu[iii] = (sumProb * mu[iii] + SHRINKAGE * empiricalMu[iii]) / (sumProb + SHRINKAGE);
}
pMixtureLog10 = sumProb; // will be normalized later by GaussianMixtureModel so no need to do it every iteration
hyperParameter_a = sumProb + DEGREES_OF_FREEDOM;
hyperParameter_b = sumProb + SHRINKAGE;
hyperParameter_lambda = sumProb + DIRICHLET_PARAMETER;
@ -181,7 +181,7 @@ public class MultivariateGaussian {
}
public void evaluateFinalModelParameters( final List<VariantDatum> data ) {
double sumProb = 0.0;
sumProb = 0.0;
zeroOutMu();
zeroOutSigma();
@ -206,7 +206,6 @@ public class MultivariateGaussian {
}
sigma.timesEquals( 1.0 / sumProb );
pMixtureLog10 = sumProb; // will be normalized later by GaussianMixtureModel so no need to do it here
resetPVarInGaussian(); // clean up some memory
}
}

View File

@ -15,7 +15,8 @@ public class TrainingSet {
public boolean isKnown = false;
public boolean isTraining = false;
public boolean isTruth = false;
public double prior = 3.0;
public boolean isConsensus = false;
public double prior = 0.0;
protected final static Logger logger = Logger.getLogger(TrainingSet.class);
@ -25,8 +26,13 @@ public class TrainingSet {
isKnown = tags.containsKey("known") && tags.getValue("known").equals("true");
isTraining = tags.containsKey("training") && tags.getValue("training").equals("true");
isTruth = tags.containsKey("truth") && tags.getValue("truth").equals("true");
prior = ( tags.containsKey("known") ? Double.parseDouble(tags.getValue("prior")) : prior );
isConsensus = tags.containsKey("consensus") && tags.getValue("consensus").equals("true");
prior = ( tags.containsKey("prior") ? Double.parseDouble(tags.getValue("prior")) : prior );
}
if( !isConsensus ) {
logger.info( String.format( "Found %s track: \tKnown = %s \tTraining = %s \tTruth = %s \tPrior = Q%.1f", this.name, isKnown, isTraining, isTruth, prior) );
} else {
logger.info( String.format( "Found consensus track: %s", this.name) );
}
logger.info(String.format( "Found %s track: \tKnown = %s \tTraining = %s \tTruth = %s \tPrior = Q%.1f", this.name, isKnown, isTraining, isTruth, prior) );
}
}

View File

@ -6,9 +6,7 @@ import org.broad.tribble.util.variantcontext.VariantContext;
import org.broadinstitute.sting.gatk.GenomeAnalysisEngine;
import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
import org.broadinstitute.sting.gatk.contexts.variantcontext.VariantContextUtils;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.utils.GenomeLocParser;
import org.broadinstitute.sting.utils.MathUtils;
import org.broadinstitute.sting.utils.collections.ExpandingArrayList;
import org.broadinstitute.sting.utils.exceptions.UserException;
@ -51,20 +49,39 @@ public class VariantDataManager {
public void normalizeData() {
boolean foundZeroVarianceAnnotation = false;
for( int jjj = 0; jjj < meanVector.length; jjj++ ) { //BUGBUG: to clean up
final double theMean = mean(jjj); //BUGBUG: to clean up
final double theSTD = standardDeviation(theMean, jjj); //BUGBUG: to clean up
logger.info( annotationKeys.get(jjj) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD) );
foundZeroVarianceAnnotation = foundZeroVarianceAnnotation || (theSTD < 1E-8);
meanVector[jjj] = theMean;
varianceVector[jjj] = theSTD;
for( int iii = 0; iii < meanVector.length; iii++ ) { //BUGBUG: to clean up
final double theMean = mean(iii); //BUGBUG: to clean up
final double theSTD = standardDeviation(theMean, iii); //BUGBUG: to clean up
logger.info( annotationKeys.get(iii) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD) );
foundZeroVarianceAnnotation = foundZeroVarianceAnnotation || (theSTD < 1E-6);
if( annotationKeys.get(iii).toLowerCase().contains("ranksum") ) { // BUGBUG: to clean up
for( final VariantDatum datum : data ) {
if( datum.annotations[iii] > 0.0 ) { datum.annotations[iii] /= 3.0; }
}
}
meanVector[iii] = theMean;
varianceVector[iii] = theSTD;
for( final VariantDatum datum : data ) {
datum.annotations[jjj] = ( datum.annotations[jjj] - theMean ) / theSTD; // Each data point is now [ (x - mean) / standard deviation ]
datum.annotations[iii] = ( datum.isNull[iii] ? Normal.staticNextDouble(0.0, 1.0) : ( datum.annotations[iii] - theMean ) / theSTD );
// Each data point is now [ (x - mean) / standard deviation ]
if( annotationKeys.get(iii).toLowerCase().contains("ranksum") && datum.isNull[iii] && datum.annotations[iii] > 0.0 ) {
datum.annotations[iii] /= 3.0;
}
}
}
if( foundZeroVarianceAnnotation ) {
throw new UserException.BadInput( "Found annotations with zero variance. They must be excluded before proceeding." );
}
// trim data by standard deviation threshold and place into two sets: data and failingData
for( final VariantDatum datum : data ) {
boolean remove = false;
for( final double val : datum.annotations ) {
remove = remove || (Math.abs(val) > VRAC.STD_THRESHOLD);
}
datum.failingSTDThreshold = remove;
datum.usedForTraining = 0;
}
}
public void addTrainingSet( final TrainingSet trainingSet ) {
@ -95,63 +112,85 @@ public class VariantDataManager {
public ExpandingArrayList<VariantDatum> getTrainingData() {
final ExpandingArrayList<VariantDatum> trainingData = new ExpandingArrayList<VariantDatum>();
for( final VariantDatum datum : data ) {
if( datum.atTrainingSite && datum.originalQual > VRAC.QUAL_THRESHOLD ) {
if( datum.atTrainingSite && !datum.failingSTDThreshold && datum.originalQual > VRAC.QUAL_THRESHOLD ) {
trainingData.add( datum );
datum.usedForTraining = 1;
}
}
trimDataBySTD( trainingData, VRAC.STD_THRESHOLD );
logger.info( "Training with " + trainingData.size() + " variants found in the training set(s)." );
logger.info( "Training with " + trainingData.size() + " variants after standard deviation thresholding." );
return trainingData;
}
public ExpandingArrayList<VariantDatum> selectWorstVariants( final double bottomPercentage ) {
Collections.sort( data );
final ExpandingArrayList<VariantDatum> trainingData = new ExpandingArrayList<VariantDatum>();
trainingData.addAll( data.subList(0, Math.round((float)bottomPercentage * data.size())) );
logger.info( "Training with worst " + (float)bottomPercentage * 100.0f + "% of data --> " + trainingData.size() + " variants with LOD <= " + String.format("%.4f", data.get(Math.round((float)bottomPercentage * data.size())).lod) + "." );
final int numToAdd = Math.round((float)bottomPercentage * data.size());
int index = 0;
int numAdded = 0;
while( numAdded < numToAdd ) {
final VariantDatum datum = data.get(index++);
if( !datum.failingSTDThreshold ) {
trainingData.add( datum );
datum.usedForTraining = -1;
numAdded++;
}
}
logger.info("Training with worst " + (float) bottomPercentage * 100.0f + "% of passing data --> " + trainingData.size() + " variants with LOD <= " + String.format("%.4f", data.get(index).lod) + ".");
return trainingData;
}
public ExpandingArrayList<VariantDatum> getRandomDataForPlotting( int numToAdd ) {
numToAdd = Math.min(numToAdd, data.size());
final ExpandingArrayList<VariantDatum> returnData = new ExpandingArrayList<VariantDatum>();
for( int iii = 0; iii < numToAdd; iii++) {
final VariantDatum datum = data.get(GenomeAnalysisEngine.getRandomGenerator().nextInt(data.size()));
if( !datum.failingSTDThreshold ) {
returnData.add(datum);
}
}
// add an extra 5% of points from bad training set, since that set is small but interesting
for( int iii = 0; iii < Math.floor(0.05*numToAdd); iii++) {
final VariantDatum datum = data.get(GenomeAnalysisEngine.getRandomGenerator().nextInt(data.size()));
if( datum.usedForTraining == -1 && !datum.failingSTDThreshold ) { returnData.add(datum); }
else { iii--; }
}
return returnData;
}
private double mean( final int index ) {
double sum = 0.0;
final int numVars = data.size();
int numNonNull = 0;
for( final VariantDatum datum : data ) {
sum += (datum.annotations[index] / ((double) numVars));
if( datum.atTrainingSite && !datum.isNull[index] ) { sum += datum.annotations[index]; numNonNull++; }
}
return sum;
return sum / ((double) numNonNull);
}
private double standardDeviation( final double mean, final int index ) {
double sum = 0.0;
final int numVars = data.size();
int numNonNull = 0;
for( final VariantDatum datum : data ) {
sum += ( ((datum.annotations[index] - mean)*(datum.annotations[index] - mean)) / ((double) numVars));
if( datum.atTrainingSite && !datum.isNull[index] ) { sum += ((datum.annotations[index] - mean)*(datum.annotations[index] - mean)); numNonNull++; }
}
return Math.sqrt( sum );
return Math.sqrt( sum / ((double) numNonNull) );
}
public static void trimDataBySTD( final ExpandingArrayList<VariantDatum> listData, final double STD_THRESHOLD ) {
final ExpandingArrayList<VariantDatum> dataToRemove = new ExpandingArrayList<VariantDatum>();
for( final VariantDatum datum : listData ) {
boolean remove = false;
for( final double val : datum.annotations ) {
remove = remove || (Math.abs(val) > STD_THRESHOLD);
}
if( remove ) { dataToRemove.add( datum ); }
}
listData.removeAll( dataToRemove );
}
public double[] decodeAnnotations( final GenomeLocParser genomeLocParser, final VariantContext vc, final boolean jitter ) {
public void decodeAnnotations( final VariantDatum datum, final VariantContext vc, final boolean jitter ) {
final double[] annotations = new double[annotationKeys.size()];
final boolean[] isNull = new boolean[annotationKeys.size()];
int iii = 0;
for( final String key : annotationKeys ) {
annotations[iii++] = decodeAnnotation( genomeLocParser, key, vc, jitter );
isNull[iii] = false;
annotations[iii] = decodeAnnotation( key, vc, jitter );
if( Double.isNaN(annotations[iii]) ) { isNull[iii] = true; }
iii++;
}
return annotations;
datum.annotations = annotations;
datum.isNull = isNull;
}
private static double decodeAnnotation( final GenomeLocParser genomeLocParser, final String annotationKey, final VariantContext vc, final boolean jitter ) {
private static double decodeAnnotation( final String annotationKey, final VariantContext vc, final boolean jitter ) {
double value;
if( jitter && annotationKey.equalsIgnoreCase("HRUN") ) { // HRun values must be jittered a bit to work in this GMM
value = Double.parseDouble( (String)vc.getAttribute( annotationKey ) );
@ -161,30 +200,29 @@ public class VariantDataManager {
} else {
try {
value = Double.parseDouble( (String)vc.getAttribute( annotationKey ) );
if(Double.isNaN(value)) { throw new NumberFormatException(); }
if( annotationKey.toLowerCase().contains("ranksum") ) { //BUGBUG: temporary hack
if(MathUtils.compareDoubles(value, 0.0, 0.01) == 0) { value = Normal.staticNextDouble(2.0, 2.0); }
else if(MathUtils.compareDoubles(value, 200.0, 0.01) == 0) { value = Normal.staticNextDouble(162.0, 20.0); }
}
if(Double.isInfinite(value)) { value = Double.NaN; }
if(annotationKey.equals("HaplotypeScore") && MathUtils.compareDoubles(value, 0.0, 0.0001) == 0 ) { value = -0.2 + 0.4*GenomeAnalysisEngine.getRandomGenerator().nextDouble(); }
} catch( final Exception e ) {
throw new UserException.MalformedFile( vc.getSource(), "No double value detected for annotation = " + annotationKey + " in variant at " + VariantContextUtils.getLocation(genomeLocParser,vc) + ", reported annotation value = " + vc.getAttribute( annotationKey ), e );
value = Double.NaN; // The VQSR works with missing data now by marginalizing over the missing dimension when evaluating clusters.
}
}
return value;
}
public void parseTrainingSets( final RefMetaDataTracker tracker, final ReferenceContext ref, final AlignmentContext context, final VariantContext evalVC, final VariantDatum datum, final boolean TRUST_ALL_POLYMORPHIC, final boolean FIX_OMNI ) {
public void parseTrainingSets( final RefMetaDataTracker tracker, final ReferenceContext ref, final AlignmentContext context, final VariantContext evalVC, final VariantDatum datum, final boolean TRUST_ALL_POLYMORPHIC ) {
datum.isKnown = false;
datum.atTruthSite = false;
datum.atTrainingSite = false;
datum.prior = 2.0;
datum.consensusCount = 0;
for( final TrainingSet trainingSet : trainingSets ) {
for( final VariantContext trainVC : tracker.getVariantContexts( ref, trainingSet.name, null, context.getLocation(), false, false ) ) {
if( trainVC != null && trainVC.isVariant() && (trainVC.isNotFiltered() || (FIX_OMNI && trainVC.getFilters().size()==1 && trainVC.getFilters().contains("NOT_POLY_IN_1000G"))) && ((evalVC.isSNP() && trainVC.isSNP()) || (evalVC.isIndel() && trainVC.isIndel())) && (TRUST_ALL_POLYMORPHIC || !trainVC.hasGenotypes() || trainVC.isPolymorphic()) ) {
if( trainVC != null && trainVC.isNotFiltered() && trainVC.isVariant() && ((evalVC.isSNP() && trainVC.isSNP()) || (evalVC.isIndel() && trainVC.isIndel())) && (TRUST_ALL_POLYMORPHIC || !trainVC.hasGenotypes() || trainVC.isPolymorphic()) ) {
datum.isKnown = datum.isKnown || trainingSet.isKnown;
datum.atTruthSite = datum.atTruthSite || trainingSet.isTruth;
datum.atTrainingSite = datum.atTrainingSite || trainingSet.isTraining;
datum.prior = Math.max( datum.prior, trainingSet.prior );
datum.consensusCount += ( trainingSet.isConsensus ? 1 : 0 );
}
}
}

View File

@ -11,15 +11,19 @@ import org.broadinstitute.sting.utils.GenomeLoc;
public class VariantDatum implements Comparable<VariantDatum> {
public double[] annotations;
public boolean[] isNull;
public boolean isKnown;
public double lod;
public boolean atTruthSite;
public boolean atTrainingSite;
public boolean isTransition;
public boolean isSNP;
public boolean failingSTDThreshold;
public double originalQual;
public double prior;
public int consensusCount;
public GenomeLoc pos;
public int usedForTraining;
public MultivariateGaussian assignment; // used in K-means implementation
public int compareTo( final VariantDatum other ) {

View File

@ -19,19 +19,21 @@ public class VariantRecalibratorArgumentCollection {
@Argument(fullName = "mode", shortName = "mode", doc = "Recalibration mode to employ: 1.) SNP for recalibrating only SNPs (emitting indels untouched in the output VCF); 2.) INDEL for indels; and 3.) BOTH for recalibrating both SNPs and indels simultaneously.", required = false)
public VariantRecalibratorArgumentCollection.Mode MODE = VariantRecalibratorArgumentCollection.Mode.SNP;
@Argument(fullName="maxGaussians", shortName="mG", doc="The maximum number of Gaussians to try during variational Bayes algorithm", required=false)
public int MAX_GAUSSIANS = 32;
public int MAX_GAUSSIANS = 10;
@Argument(fullName="maxIterations", shortName="mI", doc="The maximum number of VBEM iterations to be performed in variational Bayes algorithm. Procedure will normally end when convergence is detected.", required=false)
public int MAX_ITERATIONS = 100;
@Argument(fullName="numKMeans", shortName="nKM", doc="The number of k-means iterations to perform in order to initialize the means of the Gaussians in the Gaussian mixture model.", required=false)
public int NUM_KMEANS_ITERATIONS = 10;
public int NUM_KMEANS_ITERATIONS = 30;
@Argument(fullName="stdThreshold", shortName="std", doc="If a variant has annotations more than -std standard deviations away from mean then don't use it for building the Gaussian mixture model.", required=false)
public double STD_THRESHOLD = 4.5;
public double STD_THRESHOLD = 8.0;
@Argument(fullName="qualThreshold", shortName="qual", doc="If a known variant has raw QUAL value less than -qual then don't use it for building the Gaussian mixture model.", required=false)
public double QUAL_THRESHOLD = 80.0;
@Argument(fullName="shrinkage", shortName="shrinkage", doc="The shrinkage parameter in variational Bayes algorithm.", required=false)
public double SHRINKAGE = 0.0001;
public double SHRINKAGE = 1.0;
@Argument(fullName="dirichlet", shortName="dirichlet", doc="The dirichlet parameter in variational Bayes algorithm.", required=false)
public double DIRICHLET_PARAMETER = 0.0001;
public double DIRICHLET_PARAMETER = 0.001;
@Argument(fullName="priorCounts", shortName="priorCounts", doc="The number of prior counts to use in variational Bayes algorithm.", required=false)
public double PRIOR_COUNTS = 20.0;
@Argument(fullName="percentBadVariants", shortName="percentBad", doc="What percentage of the worst scoring variants to use when building the Gaussian mixture model of bad variants. 0.07 means bottom 7 percent.", required=false)
public double PERCENT_BAD_VARIANTS = 0.07;
public double PERCENT_BAD_VARIANTS = 0.015;
}

View File

@ -21,7 +21,7 @@ public class VariantRecalibratorEngine {
// the unified argument collection
final private VariantRecalibratorArgumentCollection VRAC;
private final static double MIN_PROB_CONVERGENCE_LOG10 = 1.0;
private final static double MIN_PROB_CONVERGENCE = 2E-2;
/////////////////////////////
// Public Methods to interface with the Engine
@ -33,7 +33,7 @@ public class VariantRecalibratorEngine {
}
public GaussianMixtureModel generateModel( final List<VariantDatum> data ) {
final GaussianMixtureModel model = new GaussianMixtureModel( VRAC.MAX_GAUSSIANS, data.get(0).annotations.length, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER );
final GaussianMixtureModel model = new GaussianMixtureModel( VRAC.MAX_GAUSSIANS, data.get(0).annotations.length, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER, VRAC.PRIOR_COUNTS );
variationalBayesExpectationMaximization( model, data );
return model;
}
@ -63,25 +63,26 @@ public class VariantRecalibratorEngine {
private void variationalBayesExpectationMaximization( final GaussianMixtureModel model, final List<VariantDatum> data ) {
model.cacheEmpiricalStats( data );
model.cacheEmpiricalStats();
model.initializeRandomModel( data, VRAC.NUM_KMEANS_ITERATIONS );
// The VBEM loop
double previousLikelihood = model.expectationStep( data );
model.normalizePMixtureLog10();
model.expectationStep( data );
double currentLikelihood;
int iteration = 0;
logger.info("Finished iteration " + iteration );
while( iteration < VRAC.MAX_ITERATIONS ) {
iteration++;
model.maximizationStep( data );
currentLikelihood = model.expectationStep( data );
currentLikelihood = model.normalizePMixtureLog10();
model.expectationStep( data );
logger.info("Current change in mixture coefficients = " + String.format("%.5f", currentLikelihood));
logger.info("Finished iteration " + iteration );
if( Math.abs(currentLikelihood - previousLikelihood) < MIN_PROB_CONVERGENCE_LOG10 ) {
if( iteration > 2 && currentLikelihood < MIN_PROB_CONVERGENCE ) {
logger.info("Convergence!");
break;
}
previousLikelihood = currentLikelihood;
}
model.evaluateFinalModelParameters( data );

View File

@ -420,6 +420,14 @@ public class MathUtils {
return dist;
}
public static double round(double num, int digits) {
double result = num * Math.pow(10.0, (double)digits);
result = Math.round(result);
result = result / Math.pow(10.0, (double)digits);
return result;
}
/**
* normalizes the log10-based array. ASSUMES THAT ALL ARRAY ENTRIES ARE <= 0 (<= 1 IN REAL-SPACE).
*
@ -717,16 +725,6 @@ public class MathUtils {
return ans;
}
// lifted from the internet
// http://www.cs.princeton.edu/introcs/91float/Gamma.java.html
public static double logGamma(double x) {
double tmp = (x - 0.5) * Math.log(x + 4.5) - (x + 4.5);
double ser = 1.0 + 76.18009173 / (x + 0) - 86.50532033 / (x + 1)
+ 24.01409822 / (x + 2) - 1.231739516 / (x + 3)
+ 0.00120858003 / (x + 4) - 0.00000536382 / (x + 5);
return tmp + Math.log(ser * Math.sqrt(2 * Math.PI));
}
public static double percentage(double x, double base) {
return (base > 0 ? (x / base) * 100.0 : 0);
}
@ -892,15 +890,6 @@ public class MathUtils {
return getQScoreOrderStatistic(reads, offsets, (int)Math.floor(reads.size()/2.));
}
// from http://en.wikipedia.org/wiki/Digamma_function
// According to J.M. Bernardo AS 103 algorithm the digamma function for x, a real number, can be approximated by:
public static double diGamma(final double x) {
return Math.log(x) - ( 1.0 / (2.0 * x) )
- ( 1.0 / (12.0 * Math.pow(x, 2.0)) )
+ ( 1.0 / (120.0 * Math.pow(x, 4.0)) )
- ( 1.0 / (252.0 * Math.pow(x, 6.0)) );
}
/** A utility class that computes on the fly average and standard deviation for a stream of numbers.
* The number of observations does not have to be known in advance, and can be also very big (so that
* it could overflow any naive summation-based scheme or cause loss of precision).

View File

@ -24,33 +24,30 @@ public class VariantRecalibrationWalkersV2IntegrationTest extends WalkerTest {
}
}
VRTest yriTrio = new VRTest("yri.trio.gatk_glftrio.intersection.annotated.filtered.chr1.vcf",
"74127253b3e551ae602b5957eef8e37e", // tranches
"38fe940269ce56449b2e4e3493f89cd9", // recal file
"067197ba4212a38c7fd736c1f66b294d"); // cut VCF
VRTest lowPass = new VRTest("lowpass.N3.chr1.raw.vcf",
"40242ee5ce97ffde50cc7a12cfa0d841", // tranches
"eb8df1fbf3d71095bdb3a52713dd2a64", // recal file
"13aa221241a974883072bfad14ad8afc"); // cut VCF
VRTest lowPass = new VRTest("phase1.projectConsensus.chr20.raw.snps.vcf",
"920b12d7765eb4f6f4a1bab045679b31", // tranches
"41bbc5f07c8a9573d5bb638f01808bba", // recal file
"2b8c5e884bf5a739b782f5b3bf17f19c"); // cut VCF
@DataProvider(name = "VRTest")
public Object[][] createData1() {
return new Object[][]{ {yriTrio}, {lowPass} };
return new Object[][]{ {lowPass} };
//return new Object[][]{ {yriTrio}, {lowPass} }; // Add hg19 chr20 trio calls here
}
@Test(dataProvider = "VRTest")
public void testVariantRecalibrator(VRTest params) {
//System.out.printf("PARAMS FOR %s is %s%n", vcf, clusterFile);
WalkerTest.WalkerTestSpec spec = new WalkerTest.WalkerTestSpec(
"-R " + b36KGReference +
" -B:dbsnp,DBSNP,known=true,training=false,truth=false,prior=10.0 " + GATKDataLocation + "dbsnp_129_b36.rod" +
" -B:hapmap,VCF,known=false,training=true,truth=true,prior=15.0 " + comparisonDataLocation + "Validated/HapMap/3.2/sites_r27_nr.b36_fwd.vcf" +
" -B:omni,VCF,known=false,training=true,truth=true,prior=12.0 " + comparisonDataLocation + "Validated/Omni2.5_chip/1212samples.b36.vcf" +
"-R " + b37KGReference +
" -B:dbsnp,VCF,known=true,training=false,truth=false,prior=10.0 " + GATKDataLocation + "dbsnp_132_b37.leftAligned.vcf" +
" -B:hapmap,VCF,known=false,training=true,truth=true,prior=15.0 " + comparisonDataLocation + "Validated/HapMap/3.3/sites_r27_nr.b37_fwd.vcf" +
" -B:omni,VCF,known=false,training=true,truth=true,prior=12.0 " + comparisonDataLocation + "Validated/Omni2.5_chip/Omni25_sites_1525_samples.b37.vcf" +
" -T ContrastiveRecalibrator" +
" -B:input,VCF " + params.inVCF +
" -L 1:50,000,000-120,000,000" +
" -an QD -an MQ -an SB" +
" -L 20:1,000,000-40,000,000" +
" -an QD -an HaplotypeScore -an HRun" +
" -percentBad 0.07" +
" --trustAllPolymorphic" + // for speed
" -recalFile %s" +
" -tranchesFile %s",
@ -61,9 +58,9 @@ public class VariantRecalibrationWalkersV2IntegrationTest extends WalkerTest {
@Test(dataProvider = "VRTest",dependsOnMethods="testVariantRecalibrator")
public void testApplyRecalibration(VRTest params) {
WalkerTest.WalkerTestSpec spec = new WalkerTest.WalkerTestSpec(
"-R " + b36KGReference +
"-R " + b37KGReference +
" -T ApplyRecalibration" +
" -L 1:60,000,000-115,000,000" +
" -L 20:12,000,000-30,000,000" +
" -NO_HEADER" +
" -B:input,VCF " + params.inVCF +
" -o %s" +

View File

@ -96,8 +96,8 @@ class MethodsDevelopmentCallingPipeline extends QScript {
val hapmap_b36 = "/humgen/gsa-hpprojects/GATK/data/Comparisons/Validated/HapMap/3.3/sites_r27_nr.b36_fwd.vcf"
val hapmap_b37 = "/humgen/gsa-hpprojects/GATK/data/Comparisons/Validated/HapMap/3.3/sites_r27_nr.b37_fwd.vcf"
val training_hapmap_b37 = "/humgen/1kg/processing/pipeline_test_bams/hapmap3.3_training_chr20.vcf"
val omni_b36 = "/humgen/gsa-hpprojects/GATK/data/Comparisons/Validated/Omni2.5_chip/1212samples.b36.vcf"
val omni_b37 = "/humgen/gsa-hpprojects/GATK/data/Comparisons/Validated/Omni2.5_chip/1212samples.b37.vcf"
val omni_b36 = "/humgen/gsa-hpprojects/GATK/data/Comparisons/Validated/Omni2.5_chip/Omni25_sites_1525_samples.b36.vcf"
val omni_b37 = "/humgen/gsa-hpprojects/GATK/data/Comparisons/Validated/Omni2.5_chip/Omni25_sites_1525_samples.b37.vcf"
val indelMask_b36 = "/humgen/1kg/processing/pipeline_test_bams/pilot1.dindel.mask.b36.bed"
val indelMask_b37 = "/humgen/1kg/processing/pipeline_test_bams/pilot1.dindel.mask.b37.bed"
@ -105,7 +105,7 @@ class MethodsDevelopmentCallingPipeline extends QScript {
val indels: Boolean = true
val queueLogDir = ".qlog/"
val targetDataSets: Map[String, Target] = Map(
"HiSeq" -> new Target("NA12878.HiSeq", hg18, dbSNP_hg18_129, hapmap_hg18,
"/humgen/gsa-hpprojects/dev/depristo/oneOffProjects/1000GenomesProcessingPaper/wgs.v13/HiSeq.WGS.cleaned.indels.10.mask",
@ -115,7 +115,7 @@ class MethodsDevelopmentCallingPipeline extends QScript {
"HiSeq19" -> new Target("NA12878.HiSeq19", hg19, dbSNP_b37_129, hapmap_b37, indelMask_b37,
new File("/humgen/gsa-hpprojects/NA12878Collection/bams/NA12878.HiSeq.WGS.bwa.cleaned.recal.hg19.bam"),
new File("/humgen/gsa-hpprojects/dev/carneiro/hiseq19/analysis/snps/NA12878.HiSeq19.filtered.vcf"),
"/humgen/1kg/processing/pipeline_test_bams/whole_genome_chunked.hg19.intervals", 2.3, 98.0, !lowPass),
"/humgen/1kg/processing/pipeline_test_bams/whole_genome_chunked.hg19.intervals", 2.3, 99.0, !lowPass),
"GA2hg19" -> new Target("NA12878.GA2.hg19", hg19, dbSNP_b37_129, hapmap_b37, indelMask_b37,
new File("/humgen/gsa-hpprojects/NA12878Collection/bams/NA12878.GA2.WGS.bwa.cleaned.hg19.bam"),
new File("/humgen/gsa-hpprojects/dev/carneiro/hiseq19/analysis/snps/NA12878.GA2.hg19.filtered.vcf"),
@ -251,17 +251,17 @@ class MethodsDevelopmentCallingPipeline extends QScript {
this.rodBind :+= RodBind("input", "VCF", if ( goldStandard ) { t.goldStandard_VCF } else { t.rawVCF } )
this.rodBind :+= RodBind("hapmap", "VCF", t.hapmapFile, "known=false,training=true,truth=true,prior=15.0")
if( t.hapmapFile.contains("b37") )
this.rodBind :+= RodBind("omni", "VCF", omni_b37, "known=false,training=true,truth=false,prior=12.0")
this.rodBind :+= RodBind("omni", "VCF", omni_b37, "known=false,training=true,truth=true,prior=12.0")
else if( t.hapmapFile.contains("b36") )
this.rodBind :+= RodBind("omni", "VCF", omni_b36, "known=false,training=true,truth=false,prior=12.0")
this.rodBind :+= RodBind("omni", "VCF", omni_b36, "known=false,training=true,truth=true,prior=12.0")
if (t.dbsnpFile.endsWith(".rod"))
this.rodBind :+= RodBind("dbsnp", "DBSNP", t.dbsnpFile, "known=true,training=false,truth=false,prior=10.0")
else if (t.dbsnpFile.endsWith(".vcf"))
this.rodBind :+= RodBind("dbsnp", "VCF", t.dbsnpFile, "known=true,training=false,truth=false,prior=10.0")
this.use_annotation ++= List("QD", "SB", "HaplotypeScore", "HRun")
this.use_annotation ++= List("QD", "HaplotypeScore", "MQRankSum", "ReadPosRankSum", "HRun")
this.tranches_file = if ( goldStandard ) { t.goldStandardTranchesFile } else { t.tranchesFile }
this.recal_file = if ( goldStandard ) { t.goldStandardRecalFile } else { t.recalFile }
this.fixOmni = true // temporary argument until new Omni file is released
this.allPoly = true
this.tranche ++= List("100.0", "99.9", "99.5", "99.3", "99.0", "98.9", "98.8", "98.5", "98.4", "98.3", "98.2", "98.1", "98.0", "97.9", "97.8", "97.5", "97.0", "95.0", "90.0")
this.analysisName = t.name + "_VQSR"
this.jobName = queueLogDir + t.name + ".VQSR"