From 3224bbe7505827f7781ff780023f41e0605fa5be Mon Sep 17 00:00:00 2001 From: rpoplin Date: Mon, 2 May 2011 19:14:42 +0000 Subject: [PATCH] New visualization output for VQSR. It creates the R script file on the fly and then runs Rscript on it. Adding 1000G Project consensus code. First pass of having VQSR work with missing data by marginalizing over the missing dimension for that data point (thanks Chris and Bob for ideas). Updated math functions to use Apache Commons Math instead of approximations from Wikipedia. New parameters available for the priors based on further reading in Bishop and looking at the new visualizations. Updated integration test to use more modern files. Updated MDCP to use new best practices w.r.t. annotations. git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@5723 348d0f76-0448-11de-a6fe-93d51630548a --- .../ContrastiveRecalibrator.java | 164 ++++++++++++++++-- .../GaussianMixtureModel.java | 88 +++++----- .../MultivariateGaussian.java | 33 ++-- .../variantrecalibration/TrainingSet.java | 12 +- .../VariantDataManager.java | 128 +++++++++----- .../variantrecalibration/VariantDatum.java | 4 + ...VariantRecalibratorArgumentCollection.java | 14 +- .../VariantRecalibratorEngine.java | 17 +- .../broadinstitute/sting/utils/MathUtils.java | 27 +-- ...RecalibrationWalkersV2IntegrationTest.java | 33 ++-- .../MethodsDevelopmentCallingPipeline.scala | 16 +- 11 files changed, 359 insertions(+), 177 deletions(-) diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/ContrastiveRecalibrator.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/ContrastiveRecalibrator.java index f16834153..692aae3a3 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/ContrastiveRecalibrator.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/ContrastiveRecalibrator.java @@ -37,10 +37,13 @@ import 
org.broadinstitute.sting.gatk.datasources.rmd.ReferenceOrderedDataSource; import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker; import org.broadinstitute.sting.gatk.walkers.RodWalker; import org.broadinstitute.sting.gatk.walkers.TreeReducible; +import org.broadinstitute.sting.utils.MathUtils; import org.broadinstitute.sting.utils.QualityUtils; +import org.broadinstitute.sting.utils.Utils; import org.broadinstitute.sting.utils.collections.ExpandingArrayList; import org.broadinstitute.sting.utils.exceptions.UserException; +import java.io.FileNotFoundException; import java.io.PrintStream; import java.util.*; @@ -76,6 +79,12 @@ public class ContrastiveRecalibrator extends RodWalker reduceSum ) { dataManager.setData( reduceSum ); dataManager.normalizeData(); - engine.evaluateData( dataManager.getData(), engine.generateModel( dataManager.getTrainingData() ), false ); - engine.evaluateData( dataManager.getData(), engine.generateModel( dataManager.selectWorstVariants( VRAC.PERCENT_BAD_VARIANTS ) ), true ); + final GaussianMixtureModel goodModel = engine.generateModel( dataManager.getTrainingData() ); + engine.evaluateData( dataManager.getData(), goodModel, false ); + final GaussianMixtureModel badModel = engine.generateModel( dataManager.selectWorstVariants( VRAC.PERCENT_BAD_VARIANTS ) ); + engine.evaluateData( dataManager.getData(), badModel, true ); + + final ExpandingArrayList randomData = dataManager.getRandomDataForPlotting( 6000 ); final int nCallsAtTruth = TrancheManager.countCallsAtTruth( dataManager.getData(), Double.NEGATIVE_INFINITY ); final TrancheManager.SelectionMetric metric = new TrancheManager.TruthSensitivityMetric( nCallsAtTruth ); final List tranches = TrancheManager.findTranches( dataManager.getData(), TS_TRANCHES, metric ); TRANCHES_FILE.print(Tranche.tranchesString( tranches )); + double lodCutoff = 0.0; + for(final Tranche tranche : tranches) { + if(MathUtils.compareDoubles(tranche.ts, TS_FILTER_LEVEL, 0.0001)==0) { + lodCutoff = 
tranche.minVQSLod; + } + } + logger.info( "Writing out recalibration table..." ); dataManager.writeOutRecalibrationTable( RECAL_FILE ); + if( RSCRIPT_FILE != null ) { + logger.info( "Writing out visualization Rscript file..."); + createVisualizationScript( randomData, goodModel, badModel, lodCutoff ); + } + } -} + + private void createVisualizationScript( final ExpandingArrayList randomData, final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, final double lodCutoff ) { + PrintStream stream; + try { + stream = new PrintStream(RSCRIPT_FILE); + } catch( FileNotFoundException e ) { + throw new UserException.CouldNotCreateOutputFile(RSCRIPT_FILE, "", e); + } + stream.println("library(ggplot2)"); + + createArrangeFunction( stream ); + + stream.println("pdf(\"" + RSCRIPT_FILE + ".pdf\")"); + + for(int iii = 0; iii < USE_ANNOTATIONS.length; iii++) { + for( int jjj = iii + 1; jjj < USE_ANNOTATIONS.length; jjj++) { + //stream.println("png(\"" + RSCRIPT_FILE + "." + USE_ANNOTATIONS[iii] + "." + USE_ANNOTATIONS[jjj] + ".png\", type=\"cairo\", width = 960, height = 960)"); + //stream.println("pdf(\"" + RSCRIPT_FILE + "." + USE_ANNOTATIONS[iii] + "." 
+ USE_ANNOTATIONS[jjj] + ".pdf\")"); + logger.info( "Building " + USE_ANNOTATIONS[iii] + " x " + USE_ANNOTATIONS[jjj] + " plot..."); + + final ExpandingArrayList fakeData = new ExpandingArrayList(); + double minAnn1 = 100.0, maxAnn1 = -100.0, minAnn2 = 100.0, maxAnn2 = -100.0; + for( final VariantDatum datum : randomData ) { + minAnn1 = Math.min(minAnn1, datum.annotations[iii]); + maxAnn1 = Math.max(maxAnn1, datum.annotations[iii]); + minAnn2 = Math.min(minAnn2, datum.annotations[jjj]); + maxAnn2 = Math.max(maxAnn2, datum.annotations[jjj]); + } + for(double ann1 = minAnn1; ann1 <= maxAnn1; ann1+=0.1) { + for(double ann2 = minAnn2; ann2 <= maxAnn2; ann2+=0.1) { + final VariantDatum datum = new VariantDatum(); + datum.prior = 0.0; + datum.annotations = new double[randomData.get(0).annotations.length]; + datum.isNull = new boolean[randomData.get(0).annotations.length]; + for(int ann=0; ann< datum.annotations.length; ann++) { + datum.annotations[ann] = 0.0; + datum.isNull[ann] = true; + } + datum.annotations[iii] = ann1; + datum.annotations[jjj] = ann2; + datum.isNull[iii] = false; + datum.isNull[jjj] = false; + fakeData.add(datum); + } + } + + engine.evaluateData( fakeData, goodModel, false ); + engine.evaluateData( fakeData, badModel, true ); + + stream.print("surface <- c("); + for( final VariantDatum datum : fakeData ) { + stream.print(String.format("%.3f, %.3f, %.3f, ", datum.annotations[iii], datum.annotations[jjj], Math.min(4.0, Math.max(-4.0, datum.lod)))); + } + stream.println("NA,NA,NA)"); + stream.println("s <- matrix(surface,ncol=3,byrow=T)"); + + stream.print("data <- c("); + for( final VariantDatum datum : randomData ) { + stream.print(String.format("%.3f, %.3f, %.3f, %d, %d,", datum.annotations[iii], datum.annotations[jjj], (datum.lod < lodCutoff ? -1.0 : 1.0), datum.usedForTraining, (datum.isKnown ? 1 : -1))); + } + stream.println("NA,NA,NA,NA,NA)"); + stream.println("d <- matrix(data,ncol=5,byrow=T)"); + + final String surfaceFrame = "sf." 
+ USE_ANNOTATIONS[iii] + "." + USE_ANNOTATIONS[jjj]; + final String dataFrame = "df." + USE_ANNOTATIONS[iii] + "." + USE_ANNOTATIONS[jjj]; + + stream.println(surfaceFrame + " <- data.frame(x=s[,1], y=s[,2], lod=s[,3])"); + stream.println(dataFrame + " <- data.frame(x=d[,1], y=d[,2], retained=d[,3], training=d[,4], novelty=d[,5])"); + stream.println("p <- ggplot(data=" + surfaceFrame + ", aes(x=x, y=y)) + opts(panel.background = theme_rect(colour = NA), panel.grid.minor = theme_line(colour = NA), panel.grid.major = theme_line(colour = NA))"); + stream.println("p1 = p + opts(title=\"model PDF\") + labs(x=\""+ USE_ANNOTATIONS[iii] +"\", y=\""+ USE_ANNOTATIONS[jjj] +"\") + geom_tile(aes(fill = lod)) + scale_fill_gradient(high=\"green\", low=\"red\")"); + stream.println("p <- ggplot(data=" + dataFrame + ", aes(x=x, y=y)) + opts(panel.background = theme_rect(colour = NA), panel.grid.minor = theme_line(colour = NA), panel.grid.major = theme_line(colour = NA))"); + stream.println("p2 = p + labs(x=\""+ USE_ANNOTATIONS[iii] +"\", y=\""+ USE_ANNOTATIONS[jjj] +"\") + geom_point(data="+ dataFrame + ", aes(x=x, y=y, colour = retained, alpha=0.3, size=1.5)) + scale_colour_gradient(name=\"\", high=\"black\", low=\"red\",breaks=c(-1,1),labels=c(\"filtered\",\"retained\"))"); + stream.println("p3 = p + labs(x=\""+ USE_ANNOTATIONS[iii] +"\", y=\""+ USE_ANNOTATIONS[jjj] +"\") + geom_point(data="+ dataFrame + "["+dataFrame+"$training==0,], aes(x=x, y=y, colour = training, alpha=0.3, size=1.5)) + geom_point(data="+ dataFrame + "["+dataFrame+"$training!=0,], aes(x=x, y=y, colour = training, alpha=0.3, size=1.5)) + scale_colour_gradient2(high=\"green\", mid=\"lightgrey\", low=\"purple\",breaks=c(-1,0,1), labels=c(\"bad\", \"\", \"good\"))"); + stream.println("p4 = p + labs(x=\""+ USE_ANNOTATIONS[iii] +"\", y=\""+ USE_ANNOTATIONS[jjj] +"\") + geom_point(data="+ dataFrame + ", aes(x=x, y=y, colour = novelty, alpha=0.3, size=1.5)) + scale_colour_gradient(name=\"\", high=\"blue\", 
low=\"red\",breaks=c(-1,1), labels=c(\"novel\",\"known\"))"); + stream.println("arrange(p1, p2, p3, p4, ncol=2)"); + } + } + stream.println("dev.off()"); + + stream.close(); + + // Execute Rscript command to generate the clustering plots + final String rScriptTranchesCommandLine = PATH_TO_RSCRIPT + " " + RSCRIPT_FILE; + logger.info( "Executing: " + rScriptTranchesCommandLine ); + try { + Process p; + p = Runtime.getRuntime().exec( rScriptTranchesCommandLine ); + p.waitFor(); + } catch ( Exception e ) { + Utils.warnUser("Unable to execute the RScript command. While not critical to the calculations themselves, the script outputs a report that is extremely useful for visualizing the recalibration results. We highly recommend trying to rerun the script manually if possible."); + } + } + + // from http://gettinggeneticsdone.blogspot.com/2010/03/arrange-multiple-ggplot2-plots-in-same.html + private void createArrangeFunction( final PrintStream stream ) { + stream.println("vp.layout <- function(x, y) viewport(layout.pos.row=x, layout.pos.col=y)"); + stream.println("arrange <- function(..., nrow=NULL, ncol=NULL, as.table=FALSE) {"); + stream.println("dots <- list(...)"); + stream.println("n <- length(dots)"); + stream.println("if(is.null(nrow) & is.null(ncol)) { nrow = floor(n/2) ; ncol = ceiling(n/nrow)}"); + stream.println("if(is.null(nrow)) { nrow = ceiling(n/ncol)}"); + stream.println("if(is.null(ncol)) { ncol = ceiling(n/nrow)}"); + stream.println("grid.newpage()"); + stream.println("pushViewport(viewport(layout=grid.layout(nrow,ncol) ) )"); + stream.println("ii.p <- 1"); + stream.println("for(ii.row in seq(1, nrow)){"); + stream.println("ii.table.row <- ii.row "); + stream.println("if(as.table) {ii.table.row <- nrow - ii.table.row + 1}"); + stream.println("for(ii.col in seq(1, ncol)){"); + stream.println("ii.table <- ii.p"); + stream.println("if(ii.p > n) break"); + stream.println("print(dots[[ii.table]], vp=vp.layout(ii.table.row, ii.col))"); + stream.println("ii.p 
<- ii.p + 1"); + stream.println("}"); + stream.println("}"); + stream.println("}"); + } + + } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/GaussianMixtureModel.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/GaussianMixtureModel.java index d1a452bec..64bb3755a 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/GaussianMixtureModel.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/GaussianMixtureModel.java @@ -1,11 +1,13 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantrecalibration; import Jama.Matrix; +import cern.jet.random.Normal; import org.apache.log4j.Logger; import org.broadinstitute.sting.gatk.GenomeAnalysisEngine; import org.broadinstitute.sting.utils.MathUtils; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; /** @@ -21,13 +23,13 @@ public class GaussianMixtureModel { private final ArrayList gaussians; private final double shrinkage; private final double dirichletParameter; - private final double degreesOfFreedom; + private final double priorCounts; private final double[] empiricalMu; private final Matrix empiricalSigma; public boolean isModelReadyForEvaluation; public GaussianMixtureModel( final int numGaussians, final int numAnnotations, - final double shrinkage, final double dirichletParameter ) { + final double shrinkage, final double dirichletParameter, final double priorCounts ) { gaussians = new ArrayList( numGaussians ); for( int iii = 0; iii < numGaussians; iii++ ) { @@ -36,39 +38,15 @@ public class GaussianMixtureModel { } this.shrinkage = shrinkage; this.dirichletParameter = dirichletParameter; - degreesOfFreedom = numAnnotations + 2; + this.priorCounts = priorCounts; empiricalMu = new double[numAnnotations]; empiricalSigma = new Matrix(numAnnotations, numAnnotations); isModelReadyForEvaluation = false; } - public void 
cacheEmpiricalStats( final List data ) { - final double[][] tmpSigmaVals = new double[empiricalMu.length][empiricalMu.length]; - for( int iii = 0; iii < empiricalMu.length; iii++ ) { - empiricalMu[iii] = 0.0; - for( int jjj = iii; jjj < empiricalMu.length; jjj++ ) { - tmpSigmaVals[iii][jjj] = 0.0; - } - } - - for( final VariantDatum datum : data ) { - for( int iii = 0; iii < empiricalMu.length; iii++ ) { - empiricalMu[iii] += datum.annotations[iii] / ((double) data.size()); - } - } - - //for( final VariantDatum datum : data ) { - // for( int iii = 0; iii < empiricalMu.length; iii++ ) { - // for( int jjj = 0; jjj < empiricalMu.length; jjj++ ) { - // tmpSigmaVals[iii][jjj] += (datum.annotations[iii]-empiricalMu[iii]) * (datum.annotations[jjj]-empiricalMu[jjj]); - // } - // } - //} - - //empiricalSigma.setMatrix(0, empiricalMu.length - 1, 0, empiricalMu.length - 1, new Matrix(tmpSigmaVals)); - //empiricalSigma.timesEquals( 1.0 / ((double) data.size()) ); - //empiricalSigma.timesEquals( 1.0 / (Math.pow(gaussians.size(), 2.0 / ((double) empiricalMu.length))) ); - empiricalSigma.setMatrix(0, empiricalMu.length - 1, 0, empiricalMu.length - 1, Matrix.identity(empiricalMu.length, empiricalMu.length)); + public void cacheEmpiricalStats() { + Arrays.fill(empiricalMu, 0.0); + empiricalSigma.setMatrix(0, empiricalMu.length - 1, 0, empiricalMu.length - 1, Matrix.identity(empiricalMu.length, empiricalMu.length).times(200.0).inverse()); } public void initializeRandomModel( final List data, final int numKMeansIterations ) { @@ -85,8 +63,9 @@ public class GaussianMixtureModel { // initialize uniform mixture coefficients, random covariance matrices, and initial hyperparameters for( final MultivariateGaussian gaussian : gaussians ) { gaussian.pMixtureLog10 = Math.log10( 1.0 / ((double) gaussians.size()) ); + gaussian.sumProb = 1.0 / ((double) gaussians.size()); gaussian.initializeRandomSigma( GenomeAnalysisEngine.getRandomGenerator() ); - gaussian.hyperParameter_a = degreesOfFreedom; 
+ gaussian.hyperParameter_a = priorCounts; gaussian.hyperParameter_b = shrinkage; gaussian.hyperParameter_lambda = dirichletParameter; } @@ -129,19 +108,17 @@ public class GaussianMixtureModel { } } - public double expectationStep( final List data ) { + public void expectationStep( final List data ) { for( final MultivariateGaussian gaussian : gaussians ) { gaussian.precomputeDenominatorForVariationalBayes( getSumHyperParameterLambda() ); } - double likelihood = 0.0; for( final VariantDatum datum : data ) { final ArrayList pVarInGaussianLog10 = new ArrayList( gaussians.size() ); for( final MultivariateGaussian gaussian : gaussians ) { final double pVarLog10 = gaussian.evaluateDatumLog10( datum ); pVarInGaussianLog10.add( pVarLog10 ); - likelihood += pVarLog10; } final double[] pVarInGaussianNormalized = MathUtils.normalizeFromLog10( pVarInGaussianLog10 ); int iii = 0; @@ -149,15 +126,11 @@ public class GaussianMixtureModel { gaussian.assignPVarInGaussian( pVarInGaussianNormalized[iii++] ); //BUGBUG: to clean up } } - - final double scaledTotalLikelihoodLog10 = likelihood / ((double) data.size()); - logger.info( "sum Log10 likelihood = " + String.format("%.5f", scaledTotalLikelihoodLog10) ); - return scaledTotalLikelihoodLog10; } public void maximizationStep( final List data ) { for( final MultivariateGaussian gaussian : gaussians ) { - gaussian.maximizeGaussian( data, empiricalMu, empiricalSigma, shrinkage, dirichletParameter, degreesOfFreedom ); + gaussian.maximizeGaussian( data, empiricalMu, empiricalSigma, shrinkage, dirichletParameter, priorCounts); } } @@ -171,28 +144,31 @@ public class GaussianMixtureModel { public void evaluateFinalModelParameters( final List data ) { for( final MultivariateGaussian gaussian : gaussians ) { - gaussian.evaluateFinalModelParameters( data ); + gaussian.evaluateFinalModelParameters(data); } normalizePMixtureLog10(); } - private void normalizePMixtureLog10() { + public double normalizePMixtureLog10() { + double sumDiff = 0.0; 
double sumPK = 0.0; for( final MultivariateGaussian gaussian : gaussians ) { - sumPK += gaussian.pMixtureLog10; + sumPK += gaussian.sumProb; } int gaussianIndex = 0; double[] pGaussianLog10 = new double[gaussians.size()]; for( final MultivariateGaussian gaussian : gaussians ) { - pGaussianLog10[gaussianIndex++] = Math.log10( gaussian.pMixtureLog10 / sumPK ); //BUGBUG: to clean up + pGaussianLog10[gaussianIndex++] = Math.log10( gaussian.sumProb / sumPK ); } pGaussianLog10 = MathUtils.normalizeFromLog10( pGaussianLog10, true ); gaussianIndex = 0; for( final MultivariateGaussian gaussian : gaussians ) { + sumDiff += Math.abs( pGaussianLog10[gaussianIndex] - gaussian.pMixtureLog10 ); gaussian.pMixtureLog10 = pGaussianLog10[gaussianIndex++]; } + return sumDiff; } public void precomputeDenominatorForEvaluation() { @@ -204,6 +180,9 @@ public class GaussianMixtureModel { } public double evaluateDatum( final VariantDatum datum ) { + for( final boolean isNull : datum.isNull ) { + if( isNull ) { return evaluateDatumMarginalized( datum ); } + } final double[] pVarInGaussianLog10 = new double[gaussians.size()]; int gaussianIndex = 0; for( final MultivariateGaussian gaussian : gaussians ) { @@ -211,4 +190,27 @@ public class GaussianMixtureModel { } return MathUtils.log10sumLog10(pVarInGaussianLog10); } + + public double evaluateDatumMarginalized( final VariantDatum datum ) { + int numVals = 0; + double sumPVarInGaussian = 0.0; + int numIter = 10; + final double[] pVarInGaussianLog10 = new double[gaussians.size()]; + for( int iii = 0; iii < datum.annotations.length; iii++ ) { + if( datum.isNull[iii] ) { + for( int ttt = 0; ttt < numIter; ttt++ ) { + datum.annotations[iii] = Normal.staticNextDouble(0.0, 1.0); + + int gaussianIndex = 0; + for( final MultivariateGaussian gaussian : gaussians ) { + pVarInGaussianLog10[gaussianIndex++] = gaussian.pMixtureLog10 + gaussian.evaluateDatumLog10( datum ); + } + + sumPVarInGaussian += Math.pow(10.0, 
MathUtils.log10sumLog10(pVarInGaussianLog10)); + numVals++; + } + } + } + return Math.log10( sumPVarInGaussian / ((double) numVals) ); + } } \ No newline at end of file diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/MultivariateGaussian.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/MultivariateGaussian.java index a3c909bf6..b8e8fd28a 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/MultivariateGaussian.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/MultivariateGaussian.java @@ -1,6 +1,7 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantrecalibration; import Jama.Matrix; +import org.apache.commons.math.special.Gamma; import org.broadinstitute.sting.utils.MathUtils; import org.broadinstitute.sting.utils.collections.ExpandingArrayList; @@ -16,6 +17,7 @@ import java.util.Random; public class MultivariateGaussian { public double pMixtureLog10; + public double sumProb; final public double[] mu; final public Matrix sigma; public double hyperParameter_a; @@ -52,15 +54,15 @@ public class MultivariateGaussian { public void initializeRandomSigma( final Random rand ) { final double[][] randSigma = new double[mu.length][mu.length]; - for( int iii = 0; iii < mu.length; iii++ ) { - for( int jjj = iii; jjj < mu.length; jjj++ ) { - randSigma[jjj][iii] = 0.55 + 1.25 * rand.nextDouble(); - if( rand.nextBoolean() ) { - randSigma[jjj][iii] *= -1.0; - } - if( iii != jjj ) { randSigma[iii][jjj] = 0.0; } // Sigma is a symmetric, positive-definite matrix created by taking a lower diagonal matrix and multiplying it by its transpose + for( int iii = 0; iii < mu.length; iii++ ) { + for( int jjj = iii; jjj < mu.length; jjj++ ) { + randSigma[jjj][iii] = 0.55 + 1.25 * rand.nextDouble(); + if( rand.nextBoolean() ) { + randSigma[jjj][iii] *= -1.0; } + if( iii != jjj ) { randSigma[iii][jjj] = 0.0; } // Sigma is a 
symmetric, positive-definite matrix created by taking a lower diagonal matrix and multiplying it by its transpose } + } Matrix tmp = new Matrix( randSigma ); tmp = tmp.times(tmp.transpose()); sigma.setMatrix(0, mu.length - 1, 0, mu.length - 1, tmp); @@ -95,15 +97,15 @@ public class MultivariateGaussian { cachedSigmaInverse = sigma.inverse(); cachedSigmaInverse.timesEquals( hyperParameter_a ); double sum = 0.0; - for(int jjj = 1; jjj < mu.length; jjj++) { - sum += MathUtils.diGamma( (hyperParameter_a + 1.0 - jjj) / 2.0 ); + for(int jjj = 1; jjj <= mu.length; jjj++) { + sum += Gamma.digamma( (hyperParameter_a + 1.0 - jjj) / 2.0 ); } sum -= Math.log( sigma.det() ); sum += Math.log(2.0) * mu.length; - final double gamma = 0.5 * sum; - final double pi = MathUtils.diGamma( hyperParameter_lambda ) - MathUtils.diGamma( sumHyperParameterLambda ); + final double lambda = 0.5 * sum; + final double pi = Gamma.digamma( hyperParameter_lambda ) - Gamma.digamma( sumHyperParameterLambda ); final double beta = (-1.0 * mu.length) / (2.0 * hyperParameter_b); - cachedDenomLog10 = (pi / Math.log(10.0)) + (gamma / Math.log(10.0)) + (beta / Math.log(10.0)); + cachedDenomLog10 = (pi / Math.log(10.0)) + (lambda / Math.log(10.0)) + (beta / Math.log(10.0)); } public double evaluateDatumLog10( final VariantDatum datum ) { @@ -132,7 +134,7 @@ public class MultivariateGaussian { public void maximizeGaussian( final List data, final double[] empiricalMu, final Matrix empiricalSigma, final double SHRINKAGE, final double DIRICHLET_PARAMETER, final double DEGREES_OF_FREEDOM ) { - double sumProb = 0.0; + sumProb = 1E-10; final Matrix wishart = new Matrix(mu.length, mu.length); zeroOutMu(); zeroOutSigma(); @@ -171,8 +173,6 @@ public class MultivariateGaussian { mu[iii] = (sumProb * mu[iii] + SHRINKAGE * empiricalMu[iii]) / (sumProb + SHRINKAGE); } - pMixtureLog10 = sumProb; // will be normalized later by GaussianMixtureModel so no need to do it every iteration - hyperParameter_a = sumProb + 
DEGREES_OF_FREEDOM; hyperParameter_b = sumProb + SHRINKAGE; hyperParameter_lambda = sumProb + DIRICHLET_PARAMETER; @@ -181,7 +181,7 @@ public class MultivariateGaussian { } public void evaluateFinalModelParameters( final List data ) { - double sumProb = 0.0; + sumProb = 0.0; zeroOutMu(); zeroOutSigma(); @@ -206,7 +206,6 @@ public class MultivariateGaussian { } sigma.timesEquals( 1.0 / sumProb ); - pMixtureLog10 = sumProb; // will be normalized later by GaussianMixtureModel so no need to do it here resetPVarInGaussian(); // clean up some memory } } \ No newline at end of file diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/TrainingSet.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/TrainingSet.java index 8e12cf544..e9c67a21e 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/TrainingSet.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/TrainingSet.java @@ -15,7 +15,8 @@ public class TrainingSet { public boolean isKnown = false; public boolean isTraining = false; public boolean isTruth = false; - public double prior = 3.0; + public boolean isConsensus = false; + public double prior = 0.0; protected final static Logger logger = Logger.getLogger(TrainingSet.class); @@ -25,8 +26,13 @@ public class TrainingSet { isKnown = tags.containsKey("known") && tags.getValue("known").equals("true"); isTraining = tags.containsKey("training") && tags.getValue("training").equals("true"); isTruth = tags.containsKey("truth") && tags.getValue("truth").equals("true"); - prior = ( tags.containsKey("known") ? Double.parseDouble(tags.getValue("prior")) : prior ); + isConsensus = tags.containsKey("consensus") && tags.getValue("consensus").equals("true"); + prior = ( tags.containsKey("prior") ? 
Double.parseDouble(tags.getValue("prior")) : prior ); + } + if( !isConsensus ) { + logger.info( String.format( "Found %s track: \tKnown = %s \tTraining = %s \tTruth = %s \tPrior = Q%.1f", this.name, isKnown, isTraining, isTruth, prior) ); + } else { + logger.info( String.format( "Found consensus track: %s", this.name) ); } - logger.info(String.format( "Found %s track: \tKnown = %s \tTraining = %s \tTruth = %s \tPrior = Q%.1f", this.name, isKnown, isTraining, isTruth, prior) ); } } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantDataManager.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantDataManager.java index 6787cc211..b46fe536f 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantDataManager.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantDataManager.java @@ -6,9 +6,7 @@ import org.broad.tribble.util.variantcontext.VariantContext; import org.broadinstitute.sting.gatk.GenomeAnalysisEngine; import org.broadinstitute.sting.gatk.contexts.AlignmentContext; import org.broadinstitute.sting.gatk.contexts.ReferenceContext; -import org.broadinstitute.sting.gatk.contexts.variantcontext.VariantContextUtils; import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker; -import org.broadinstitute.sting.utils.GenomeLocParser; import org.broadinstitute.sting.utils.MathUtils; import org.broadinstitute.sting.utils.collections.ExpandingArrayList; import org.broadinstitute.sting.utils.exceptions.UserException; @@ -51,20 +49,39 @@ public class VariantDataManager { public void normalizeData() { boolean foundZeroVarianceAnnotation = false; - for( int jjj = 0; jjj < meanVector.length; jjj++ ) { //BUGBUG: to clean up - final double theMean = mean(jjj); //BUGBUG: to clean up - final double theSTD = standardDeviation(theMean, jjj); //BUGBUG: to clean up - logger.info( annotationKeys.get(jjj) + 
String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD) ); - foundZeroVarianceAnnotation = foundZeroVarianceAnnotation || (theSTD < 1E-8); - meanVector[jjj] = theMean; - varianceVector[jjj] = theSTD; + for( int iii = 0; iii < meanVector.length; iii++ ) { //BUGBUG: to clean up + final double theMean = mean(iii); //BUGBUG: to clean up + final double theSTD = standardDeviation(theMean, iii); //BUGBUG: to clean up + logger.info( annotationKeys.get(iii) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD) ); + foundZeroVarianceAnnotation = foundZeroVarianceAnnotation || (theSTD < 1E-6); + if( annotationKeys.get(iii).toLowerCase().contains("ranksum") ) { // BUGBUG: to clean up + for( final VariantDatum datum : data ) { + if( datum.annotations[iii] > 0.0 ) { datum.annotations[iii] /= 3.0; } + } + } + meanVector[iii] = theMean; + varianceVector[iii] = theSTD; for( final VariantDatum datum : data ) { - datum.annotations[jjj] = ( datum.annotations[jjj] - theMean ) / theSTD; // Each data point is now [ (x - mean) / standard deviation ] + datum.annotations[iii] = ( datum.isNull[iii] ? Normal.staticNextDouble(0.0, 1.0) : ( datum.annotations[iii] - theMean ) / theSTD ); + // Each data point is now [ (x - mean) / standard deviation ] + if( annotationKeys.get(iii).toLowerCase().contains("ranksum") && datum.isNull[iii] && datum.annotations[iii] > 0.0 ) { + datum.annotations[iii] /= 3.0; + } } } if( foundZeroVarianceAnnotation ) { throw new UserException.BadInput( "Found annotations with zero variance. They must be excluded before proceeding." 
); } + + // trim data by standard deviation threshold and place into two sets: data and failingData + for( final VariantDatum datum : data ) { + boolean remove = false; + for( final double val : datum.annotations ) { + remove = remove || (Math.abs(val) > VRAC.STD_THRESHOLD); + } + datum.failingSTDThreshold = remove; + datum.usedForTraining = 0; + } } public void addTrainingSet( final TrainingSet trainingSet ) { @@ -95,63 +112,85 @@ public class VariantDataManager { public ExpandingArrayList getTrainingData() { final ExpandingArrayList trainingData = new ExpandingArrayList(); for( final VariantDatum datum : data ) { - if( datum.atTrainingSite && datum.originalQual > VRAC.QUAL_THRESHOLD ) { + if( datum.atTrainingSite && !datum.failingSTDThreshold && datum.originalQual > VRAC.QUAL_THRESHOLD ) { trainingData.add( datum ); + datum.usedForTraining = 1; } } - trimDataBySTD( trainingData, VRAC.STD_THRESHOLD ); - logger.info( "Training with " + trainingData.size() + " variants found in the training set(s)." ); + logger.info( "Training with " + trainingData.size() + " variants after standard deviation thresholding." ); return trainingData; } public ExpandingArrayList selectWorstVariants( final double bottomPercentage ) { Collections.sort( data ); final ExpandingArrayList trainingData = new ExpandingArrayList(); - trainingData.addAll( data.subList(0, Math.round((float)bottomPercentage * data.size())) ); - logger.info( "Training with worst " + (float)bottomPercentage * 100.0f + "% of data --> " + trainingData.size() + " variants with LOD <= " + String.format("%.4f", data.get(Math.round((float)bottomPercentage * data.size())).lod) + "." 
); + final int numToAdd = Math.round((float)bottomPercentage * data.size()); + int index = 0; + int numAdded = 0; + while( numAdded < numToAdd ) { + final VariantDatum datum = data.get(index++); + if( !datum.failingSTDThreshold ) { + trainingData.add( datum ); + datum.usedForTraining = -1; + numAdded++; + } + } + logger.info("Training with worst " + (float) bottomPercentage * 100.0f + "% of passing data --> " + trainingData.size() + " variants with LOD <= " + String.format("%.4f", data.get(index).lod) + "."); return trainingData; } + public ExpandingArrayList getRandomDataForPlotting( int numToAdd ) { + numToAdd = Math.min(numToAdd, data.size()); + final ExpandingArrayList returnData = new ExpandingArrayList(); + for( int iii = 0; iii < numToAdd; iii++) { + final VariantDatum datum = data.get(GenomeAnalysisEngine.getRandomGenerator().nextInt(data.size())); + if( !datum.failingSTDThreshold ) { + returnData.add(datum); + } + } + // add an extra 5% of points from bad training set, since that set is small but interesting + for( int iii = 0; iii < Math.floor(0.05*numToAdd); iii++) { + final VariantDatum datum = data.get(GenomeAnalysisEngine.getRandomGenerator().nextInt(data.size())); + if( datum.usedForTraining == -1 && !datum.failingSTDThreshold ) { returnData.add(datum); } + else { iii--; } + } + + return returnData; + } + private double mean( final int index ) { double sum = 0.0; - final int numVars = data.size(); + int numNonNull = 0; for( final VariantDatum datum : data ) { - sum += (datum.annotations[index] / ((double) numVars)); + if( datum.atTrainingSite && !datum.isNull[index] ) { sum += datum.annotations[index]; numNonNull++; } } - return sum; + return sum / ((double) numNonNull); } private double standardDeviation( final double mean, final int index ) { double sum = 0.0; - final int numVars = data.size(); + int numNonNull = 0; for( final VariantDatum datum : data ) { - sum += ( ((datum.annotations[index] - mean)*(datum.annotations[index] - mean)) / ((double) 
numVars)); + if( datum.atTrainingSite && !datum.isNull[index] ) { sum += ((datum.annotations[index] - mean)*(datum.annotations[index] - mean)); numNonNull++; } } - return Math.sqrt( sum ); + return Math.sqrt( sum / ((double) numNonNull) ); } - public static void trimDataBySTD( final ExpandingArrayList listData, final double STD_THRESHOLD ) { - final ExpandingArrayList dataToRemove = new ExpandingArrayList(); - for( final VariantDatum datum : listData ) { - boolean remove = false; - for( final double val : datum.annotations ) { - remove = remove || (Math.abs(val) > STD_THRESHOLD); - } - if( remove ) { dataToRemove.add( datum ); } - } - listData.removeAll( dataToRemove ); - } - - public double[] decodeAnnotations( final GenomeLocParser genomeLocParser, final VariantContext vc, final boolean jitter ) { + public void decodeAnnotations( final VariantDatum datum, final VariantContext vc, final boolean jitter ) { final double[] annotations = new double[annotationKeys.size()]; + final boolean[] isNull = new boolean[annotationKeys.size()]; int iii = 0; for( final String key : annotationKeys ) { - annotations[iii++] = decodeAnnotation( genomeLocParser, key, vc, jitter ); + isNull[iii] = false; + annotations[iii] = decodeAnnotation( key, vc, jitter ); + if( Double.isNaN(annotations[iii]) ) { isNull[iii] = true; } + iii++; } - return annotations; + datum.annotations = annotations; + datum.isNull = isNull; } - private static double decodeAnnotation( final GenomeLocParser genomeLocParser, final String annotationKey, final VariantContext vc, final boolean jitter ) { + private static double decodeAnnotation( final String annotationKey, final VariantContext vc, final boolean jitter ) { double value; if( jitter && annotationKey.equalsIgnoreCase("HRUN") ) { // HRun values must be jittered a bit to work in this GMM value = Double.parseDouble( (String)vc.getAttribute( annotationKey ) ); @@ -161,30 +200,29 @@ public class VariantDataManager { } else { try { value = Double.parseDouble( 
(String)vc.getAttribute( annotationKey ) ); - if(Double.isNaN(value)) { throw new NumberFormatException(); } - if( annotationKey.toLowerCase().contains("ranksum") ) { //BUGBUG: temporary hack - if(MathUtils.compareDoubles(value, 0.0, 0.01) == 0) { value = Normal.staticNextDouble(2.0, 2.0); } - else if(MathUtils.compareDoubles(value, 200.0, 0.01) == 0) { value = Normal.staticNextDouble(162.0, 20.0); } - } + if(Double.isInfinite(value)) { value = Double.NaN; } + if(annotationKey.equals("HaplotypeScore") && MathUtils.compareDoubles(value, 0.0, 0.0001) == 0 ) { value = -0.2 + 0.4*GenomeAnalysisEngine.getRandomGenerator().nextDouble(); } } catch( final Exception e ) { - throw new UserException.MalformedFile( vc.getSource(), "No double value detected for annotation = " + annotationKey + " in variant at " + VariantContextUtils.getLocation(genomeLocParser,vc) + ", reported annotation value = " + vc.getAttribute( annotationKey ), e ); + value = Double.NaN; // The VQSR works with missing data now by marginalizing over the missing dimension when evaluating clusters. 
} } return value; } - public void parseTrainingSets( final RefMetaDataTracker tracker, final ReferenceContext ref, final AlignmentContext context, final VariantContext evalVC, final VariantDatum datum, final boolean TRUST_ALL_POLYMORPHIC, final boolean FIX_OMNI ) { + public void parseTrainingSets( final RefMetaDataTracker tracker, final ReferenceContext ref, final AlignmentContext context, final VariantContext evalVC, final VariantDatum datum, final boolean TRUST_ALL_POLYMORPHIC ) { datum.isKnown = false; datum.atTruthSite = false; datum.atTrainingSite = false; datum.prior = 2.0; + datum.consensusCount = 0; for( final TrainingSet trainingSet : trainingSets ) { for( final VariantContext trainVC : tracker.getVariantContexts( ref, trainingSet.name, null, context.getLocation(), false, false ) ) { - if( trainVC != null && trainVC.isVariant() && (trainVC.isNotFiltered() || (FIX_OMNI && trainVC.getFilters().size()==1 && trainVC.getFilters().contains("NOT_POLY_IN_1000G"))) && ((evalVC.isSNP() && trainVC.isSNP()) || (evalVC.isIndel() && trainVC.isIndel())) && (TRUST_ALL_POLYMORPHIC || !trainVC.hasGenotypes() || trainVC.isPolymorphic()) ) { + if( trainVC != null && trainVC.isNotFiltered() && trainVC.isVariant() && ((evalVC.isSNP() && trainVC.isSNP()) || (evalVC.isIndel() && trainVC.isIndel())) && (TRUST_ALL_POLYMORPHIC || !trainVC.hasGenotypes() || trainVC.isPolymorphic()) ) { datum.isKnown = datum.isKnown || trainingSet.isKnown; datum.atTruthSite = datum.atTruthSite || trainingSet.isTruth; datum.atTrainingSite = datum.atTrainingSite || trainingSet.isTraining; datum.prior = Math.max( datum.prior, trainingSet.prior ); + datum.consensusCount += ( trainingSet.isConsensus ? 
1 : 0 ); } } } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantDatum.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantDatum.java index 1526672ed..2509bba1e 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantDatum.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantDatum.java @@ -11,15 +11,19 @@ import org.broadinstitute.sting.utils.GenomeLoc; public class VariantDatum implements Comparable { public double[] annotations; + public boolean[] isNull; public boolean isKnown; public double lod; public boolean atTruthSite; public boolean atTrainingSite; public boolean isTransition; public boolean isSNP; + public boolean failingSTDThreshold; public double originalQual; public double prior; + public int consensusCount; public GenomeLoc pos; + public int usedForTraining; public MultivariateGaussian assignment; // used in K-means implementation public int compareTo( final VariantDatum other ) { diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorArgumentCollection.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorArgumentCollection.java index 170e399b9..e12fc20b5 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorArgumentCollection.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorArgumentCollection.java @@ -19,19 +19,21 @@ public class VariantRecalibratorArgumentCollection { @Argument(fullName = "mode", shortName = "mode", doc = "Recalibration mode to employ: 1.) SNP for recalibrating only SNPs (emitting indels untouched in the output VCF); 2.) INDEL for indels; and 3.) 
BOTH for recalibrating both SNPs and indels simultaneously.", required = false) public VariantRecalibratorArgumentCollection.Mode MODE = VariantRecalibratorArgumentCollection.Mode.SNP; @Argument(fullName="maxGaussians", shortName="mG", doc="The maximum number of Gaussians to try during variational Bayes algorithm", required=false) - public int MAX_GAUSSIANS = 32; + public int MAX_GAUSSIANS = 10; @Argument(fullName="maxIterations", shortName="mI", doc="The maximum number of VBEM iterations to be performed in variational Bayes algorithm. Procedure will normally end when convergence is detected.", required=false) public int MAX_ITERATIONS = 100; @Argument(fullName="numKMeans", shortName="nKM", doc="The number of k-means iterations to perform in order to initialize the means of the Gaussians in the Gaussian mixture model.", required=false) - public int NUM_KMEANS_ITERATIONS = 10; + public int NUM_KMEANS_ITERATIONS = 30; @Argument(fullName="stdThreshold", shortName="std", doc="If a variant has annotations more than -std standard deviations away from mean then don't use it for building the Gaussian mixture model.", required=false) - public double STD_THRESHOLD = 4.5; + public double STD_THRESHOLD = 8.0; @Argument(fullName="qualThreshold", shortName="qual", doc="If a known variant has raw QUAL value less than -qual then don't use it for building the Gaussian mixture model.", required=false) public double QUAL_THRESHOLD = 80.0; @Argument(fullName="shrinkage", shortName="shrinkage", doc="The shrinkage parameter in variational Bayes algorithm.", required=false) - public double SHRINKAGE = 0.0001; + public double SHRINKAGE = 1.0; @Argument(fullName="dirichlet", shortName="dirichlet", doc="The dirichlet parameter in variational Bayes algorithm.", required=false) - public double DIRICHLET_PARAMETER = 0.0001; + public double DIRICHLET_PARAMETER = 0.001; + @Argument(fullName="priorCounts", shortName="priorCounts", doc="The number of prior counts to use in variational Bayes 
algorithm.", required=false) + public double PRIOR_COUNTS = 20.0; @Argument(fullName="percentBadVariants", shortName="percentBad", doc="What percentage of the worst scoring variants to use when building the Gaussian mixture model of bad variants. 0.07 means bottom 7 percent.", required=false) - public double PERCENT_BAD_VARIANTS = 0.07; + public double PERCENT_BAD_VARIANTS = 0.015; } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java index b4ce0040a..2a2565df8 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java @@ -21,7 +21,7 @@ public class VariantRecalibratorEngine { // the unified argument collection final private VariantRecalibratorArgumentCollection VRAC; - private final static double MIN_PROB_CONVERGENCE_LOG10 = 1.0; + private final static double MIN_PROB_CONVERGENCE = 2E-2; ///////////////////////////// // Public Methods to interface with the Engine @@ -33,7 +33,7 @@ public class VariantRecalibratorEngine { } public GaussianMixtureModel generateModel( final List data ) { - final GaussianMixtureModel model = new GaussianMixtureModel( VRAC.MAX_GAUSSIANS, data.get(0).annotations.length, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER ); + final GaussianMixtureModel model = new GaussianMixtureModel( VRAC.MAX_GAUSSIANS, data.get(0).annotations.length, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER, VRAC.PRIOR_COUNTS ); variationalBayesExpectationMaximization( model, data ); return model; } @@ -63,25 +63,26 @@ public class VariantRecalibratorEngine { private void variationalBayesExpectationMaximization( final GaussianMixtureModel model, final List data ) { - model.cacheEmpiricalStats( data ); + 
model.cacheEmpiricalStats(); model.initializeRandomModel( data, VRAC.NUM_KMEANS_ITERATIONS ); // The VBEM loop - double previousLikelihood = model.expectationStep( data ); + model.normalizePMixtureLog10(); + model.expectationStep( data ); double currentLikelihood; int iteration = 0; logger.info("Finished iteration " + iteration ); while( iteration < VRAC.MAX_ITERATIONS ) { iteration++; model.maximizationStep( data ); - currentLikelihood = model.expectationStep( data ); - + currentLikelihood = model.normalizePMixtureLog10(); + model.expectationStep( data ); + logger.info("Current change in mixture coefficients = " + String.format("%.5f", currentLikelihood)); logger.info("Finished iteration " + iteration ); - if( Math.abs(currentLikelihood - previousLikelihood) < MIN_PROB_CONVERGENCE_LOG10 ) { + if( iteration > 2 && currentLikelihood < MIN_PROB_CONVERGENCE ) { logger.info("Convergence!"); break; } - previousLikelihood = currentLikelihood; } model.evaluateFinalModelParameters( data ); diff --git a/java/src/org/broadinstitute/sting/utils/MathUtils.java b/java/src/org/broadinstitute/sting/utils/MathUtils.java index 01b697607..69d031190 100755 --- a/java/src/org/broadinstitute/sting/utils/MathUtils.java +++ b/java/src/org/broadinstitute/sting/utils/MathUtils.java @@ -420,6 +420,14 @@ public class MathUtils { return dist; } + public static double round(double num, int digits) { + double result = num * Math.pow(10.0, (double)digits); + result = Math.round(result); + result = result / Math.pow(10.0, (double)digits); + return result; + } + + /** * normalizes the log10-based array. ASSUMES THAT ALL ARRAY ENTRIES ARE <= 0 (<= 1 IN REAL-SPACE). 
* @@ -717,16 +725,6 @@ public class MathUtils { return ans; } - // lifted from the internet - // http://www.cs.princeton.edu/introcs/91float/Gamma.java.html - public static double logGamma(double x) { - double tmp = (x - 0.5) * Math.log(x + 4.5) - (x + 4.5); - double ser = 1.0 + 76.18009173 / (x + 0) - 86.50532033 / (x + 1) - + 24.01409822 / (x + 2) - 1.231739516 / (x + 3) - + 0.00120858003 / (x + 4) - 0.00000536382 / (x + 5); - return tmp + Math.log(ser * Math.sqrt(2 * Math.PI)); - } - public static double percentage(double x, double base) { return (base > 0 ? (x / base) * 100.0 : 0); } @@ -892,15 +890,6 @@ public class MathUtils { return getQScoreOrderStatistic(reads, offsets, (int)Math.floor(reads.size()/2.)); } - // from http://en.wikipedia.org/wiki/Digamma_function - // According to J.M. Bernardo AS 103 algorithm the digamma function for x, a real number, can be approximated by: - public static double diGamma(final double x) { - return Math.log(x) - ( 1.0 / (2.0 * x) ) - - ( 1.0 / (12.0 * Math.pow(x, 2.0)) ) - + ( 1.0 / (120.0 * Math.pow(x, 4.0)) ) - - ( 1.0 / (252.0 * Math.pow(x, 6.0)) ); - } - /** A utility class that computes on the fly average and standard deviation for a stream of numbers. * The number of observations does not have to be known in advance, and can be also very big (so that * it could overflow any naive summation-based scheme or cause loss of precision). 
diff --git a/java/test/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersV2IntegrationTest.java b/java/test/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersV2IntegrationTest.java index dea584949..7d923e21a 100755 --- a/java/test/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersV2IntegrationTest.java +++ b/java/test/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersV2IntegrationTest.java @@ -24,33 +24,30 @@ public class VariantRecalibrationWalkersV2IntegrationTest extends WalkerTest { } } - VRTest yriTrio = new VRTest("yri.trio.gatk_glftrio.intersection.annotated.filtered.chr1.vcf", - "74127253b3e551ae602b5957eef8e37e", // tranches - "38fe940269ce56449b2e4e3493f89cd9", // recal file - "067197ba4212a38c7fd736c1f66b294d"); // cut VCF - - VRTest lowPass = new VRTest("lowpass.N3.chr1.raw.vcf", - "40242ee5ce97ffde50cc7a12cfa0d841", // tranches - "eb8df1fbf3d71095bdb3a52713dd2a64", // recal file - "13aa221241a974883072bfad14ad8afc"); // cut VCF + VRTest lowPass = new VRTest("phase1.projectConsensus.chr20.raw.snps.vcf", + "920b12d7765eb4f6f4a1bab045679b31", // tranches + "41bbc5f07c8a9573d5bb638f01808bba", // recal file + "2b8c5e884bf5a739b782f5b3bf17f19c"); // cut VCF @DataProvider(name = "VRTest") public Object[][] createData1() { - return new Object[][]{ {yriTrio}, {lowPass} }; + return new Object[][]{ {lowPass} }; + //return new Object[][]{ {yriTrio}, {lowPass} }; // Add hg19 chr20 trio calls here } @Test(dataProvider = "VRTest") public void testVariantRecalibrator(VRTest params) { //System.out.printf("PARAMS FOR %s is %s%n", vcf, clusterFile); WalkerTest.WalkerTestSpec spec = new WalkerTest.WalkerTestSpec( - "-R " + b36KGReference + - " -B:dbsnp,DBSNP,known=true,training=false,truth=false,prior=10.0 " + GATKDataLocation + "dbsnp_129_b36.rod" + - " 
-B:hapmap,VCF,known=false,training=true,truth=true,prior=15.0 " + comparisonDataLocation + "Validated/HapMap/3.2/sites_r27_nr.b36_fwd.vcf" + - " -B:omni,VCF,known=false,training=true,truth=true,prior=12.0 " + comparisonDataLocation + "Validated/Omni2.5_chip/1212samples.b36.vcf" + + "-R " + b37KGReference + + " -B:dbsnp,VCF,known=true,training=false,truth=false,prior=10.0 " + GATKDataLocation + "dbsnp_132_b37.leftAligned.vcf" + + " -B:hapmap,VCF,known=false,training=true,truth=true,prior=15.0 " + comparisonDataLocation + "Validated/HapMap/3.3/sites_r27_nr.b37_fwd.vcf" + + " -B:omni,VCF,known=false,training=true,truth=true,prior=12.0 " + comparisonDataLocation + "Validated/Omni2.5_chip/Omni25_sites_1525_samples.b37.vcf" + " -T ContrastiveRecalibrator" + " -B:input,VCF " + params.inVCF + - " -L 1:50,000,000-120,000,000" + - " -an QD -an MQ -an SB" + + " -L 20:1,000,000-40,000,000" + + " -an QD -an HaplotypeScore -an HRun" + + " -percentBad 0.07" + " --trustAllPolymorphic" + // for speed " -recalFile %s" + " -tranchesFile %s", @@ -61,9 +58,9 @@ public class VariantRecalibrationWalkersV2IntegrationTest extends WalkerTest { @Test(dataProvider = "VRTest",dependsOnMethods="testVariantRecalibrator") public void testApplyRecalibration(VRTest params) { WalkerTest.WalkerTestSpec spec = new WalkerTest.WalkerTestSpec( - "-R " + b36KGReference + + "-R " + b37KGReference + " -T ApplyRecalibration" + - " -L 1:60,000,000-115,000,000" + + " -L 20:12,000,000-30,000,000" + " -NO_HEADER" + " -B:input,VCF " + params.inVCF + " -o %s" + diff --git a/scala/qscript/core/MethodsDevelopmentCallingPipeline.scala b/scala/qscript/core/MethodsDevelopmentCallingPipeline.scala index fe31b6bf2..15a7be90a 100755 --- a/scala/qscript/core/MethodsDevelopmentCallingPipeline.scala +++ b/scala/qscript/core/MethodsDevelopmentCallingPipeline.scala @@ -96,8 +96,8 @@ class MethodsDevelopmentCallingPipeline extends QScript { val hapmap_b36 = 
"/humgen/gsa-hpprojects/GATK/data/Comparisons/Validated/HapMap/3.3/sites_r27_nr.b36_fwd.vcf" val hapmap_b37 = "/humgen/gsa-hpprojects/GATK/data/Comparisons/Validated/HapMap/3.3/sites_r27_nr.b37_fwd.vcf" val training_hapmap_b37 = "/humgen/1kg/processing/pipeline_test_bams/hapmap3.3_training_chr20.vcf" - val omni_b36 = "/humgen/gsa-hpprojects/GATK/data/Comparisons/Validated/Omni2.5_chip/1212samples.b36.vcf" - val omni_b37 = "/humgen/gsa-hpprojects/GATK/data/Comparisons/Validated/Omni2.5_chip/1212samples.b37.vcf" + val omni_b36 = "/humgen/gsa-hpprojects/GATK/data/Comparisons/Validated/Omni2.5_chip/Omni25_sites_1525_samples.b36.vcf" + val omni_b37 = "/humgen/gsa-hpprojects/GATK/data/Comparisons/Validated/Omni2.5_chip/Omni25_sites_1525_samples.b37.vcf" val indelMask_b36 = "/humgen/1kg/processing/pipeline_test_bams/pilot1.dindel.mask.b36.bed" val indelMask_b37 = "/humgen/1kg/processing/pipeline_test_bams/pilot1.dindel.mask.b37.bed" @@ -105,7 +105,7 @@ class MethodsDevelopmentCallingPipeline extends QScript { val indels: Boolean = true val queueLogDir = ".qlog/" - + val targetDataSets: Map[String, Target] = Map( "HiSeq" -> new Target("NA12878.HiSeq", hg18, dbSNP_hg18_129, hapmap_hg18, "/humgen/gsa-hpprojects/dev/depristo/oneOffProjects/1000GenomesProcessingPaper/wgs.v13/HiSeq.WGS.cleaned.indels.10.mask", @@ -115,7 +115,7 @@ class MethodsDevelopmentCallingPipeline extends QScript { "HiSeq19" -> new Target("NA12878.HiSeq19", hg19, dbSNP_b37_129, hapmap_b37, indelMask_b37, new File("/humgen/gsa-hpprojects/NA12878Collection/bams/NA12878.HiSeq.WGS.bwa.cleaned.recal.hg19.bam"), new File("/humgen/gsa-hpprojects/dev/carneiro/hiseq19/analysis/snps/NA12878.HiSeq19.filtered.vcf"), - "/humgen/1kg/processing/pipeline_test_bams/whole_genome_chunked.hg19.intervals", 2.3, 98.0, !lowPass), + "/humgen/1kg/processing/pipeline_test_bams/whole_genome_chunked.hg19.intervals", 2.3, 99.0, !lowPass), "GA2hg19" -> new Target("NA12878.GA2.hg19", hg19, dbSNP_b37_129, hapmap_b37, indelMask_b37, new 
File("/humgen/gsa-hpprojects/NA12878Collection/bams/NA12878.GA2.WGS.bwa.cleaned.hg19.bam"), new File("/humgen/gsa-hpprojects/dev/carneiro/hiseq19/analysis/snps/NA12878.GA2.hg19.filtered.vcf"), @@ -251,17 +251,17 @@ class MethodsDevelopmentCallingPipeline extends QScript { this.rodBind :+= RodBind("input", "VCF", if ( goldStandard ) { t.goldStandard_VCF } else { t.rawVCF } ) this.rodBind :+= RodBind("hapmap", "VCF", t.hapmapFile, "known=false,training=true,truth=true,prior=15.0") if( t.hapmapFile.contains("b37") ) - this.rodBind :+= RodBind("omni", "VCF", omni_b37, "known=false,training=true,truth=false,prior=12.0") + this.rodBind :+= RodBind("omni", "VCF", omni_b37, "known=false,training=true,truth=true,prior=12.0") else if( t.hapmapFile.contains("b36") ) - this.rodBind :+= RodBind("omni", "VCF", omni_b36, "known=false,training=true,truth=false,prior=12.0") + this.rodBind :+= RodBind("omni", "VCF", omni_b36, "known=false,training=true,truth=true,prior=12.0") if (t.dbsnpFile.endsWith(".rod")) this.rodBind :+= RodBind("dbsnp", "DBSNP", t.dbsnpFile, "known=true,training=false,truth=false,prior=10.0") else if (t.dbsnpFile.endsWith(".vcf")) this.rodBind :+= RodBind("dbsnp", "VCF", t.dbsnpFile, "known=true,training=false,truth=false,prior=10.0") - this.use_annotation ++= List("QD", "SB", "HaplotypeScore", "HRun") + this.use_annotation ++= List("QD", "HaplotypeScore", "MQRankSum", "ReadPosRankSum", "HRun") this.tranches_file = if ( goldStandard ) { t.goldStandardTranchesFile } else { t.tranchesFile } this.recal_file = if ( goldStandard ) { t.goldStandardRecalFile } else { t.recalFile } - this.fixOmni = true // temporary argument until new Omni file is released + this.allPoly = true this.tranche ++= List("100.0", "99.9", "99.5", "99.3", "99.0", "98.9", "98.8", "98.5", "98.4", "98.3", "98.2", "98.1", "98.0", "97.9", "97.8", "97.5", "97.0", "95.0", "90.0") this.analysisName = t.name + "_VQSR" this.jobName = queueLogDir + t.name + ".VQSR"