diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/ContrastiveRecalibrator.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/ContrastiveRecalibrator.java index f16834153..692aae3a3 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/ContrastiveRecalibrator.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/ContrastiveRecalibrator.java @@ -37,10 +37,13 @@ import org.broadinstitute.sting.gatk.datasources.rmd.ReferenceOrderedDataSource; import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker; import org.broadinstitute.sting.gatk.walkers.RodWalker; import org.broadinstitute.sting.gatk.walkers.TreeReducible; +import org.broadinstitute.sting.utils.MathUtils; import org.broadinstitute.sting.utils.QualityUtils; +import org.broadinstitute.sting.utils.Utils; import org.broadinstitute.sting.utils.collections.ExpandingArrayList; import org.broadinstitute.sting.utils.exceptions.UserException; +import java.io.FileNotFoundException; import java.io.PrintStream; import java.util.*; @@ -76,6 +79,12 @@ public class ContrastiveRecalibrator extends RodWalker reduceSum ) { dataManager.setData( reduceSum ); dataManager.normalizeData(); - engine.evaluateData( dataManager.getData(), engine.generateModel( dataManager.getTrainingData() ), false ); - engine.evaluateData( dataManager.getData(), engine.generateModel( dataManager.selectWorstVariants( VRAC.PERCENT_BAD_VARIANTS ) ), true ); + final GaussianMixtureModel goodModel = engine.generateModel( dataManager.getTrainingData() ); + engine.evaluateData( dataManager.getData(), goodModel, false ); + final GaussianMixtureModel badModel = engine.generateModel( dataManager.selectWorstVariants( VRAC.PERCENT_BAD_VARIANTS ) ); + engine.evaluateData( dataManager.getData(), badModel, true ); + + final ExpandingArrayList randomData = dataManager.getRandomDataForPlotting( 6000 ); final int nCallsAtTruth = TrancheManager.countCallsAtTruth( dataManager.getData(), Double.NEGATIVE_INFINITY ); final TrancheManager.SelectionMetric metric = new TrancheManager.TruthSensitivityMetric( nCallsAtTruth ); final List tranches = TrancheManager.findTranches( dataManager.getData(), TS_TRANCHES, metric ); TRANCHES_FILE.print(Tranche.tranchesString( tranches )); + double lodCutoff = 0.0; + for(final Tranche tranche : tranches) { + if(MathUtils.compareDoubles(tranche.ts, TS_FILTER_LEVEL, 0.0001)==0) { + lodCutoff = tranche.minVQSLod; + } + } + logger.info( "Writing out recalibration table..." ); dataManager.writeOutRecalibrationTable( RECAL_FILE ); + if( RSCRIPT_FILE != null ) { + logger.info( "Writing out visualization Rscript file..."); + createVisualizationScript( randomData, goodModel, badModel, lodCutoff ); + } + } -} + + private void createVisualizationScript( final ExpandingArrayList randomData, final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, final double lodCutoff ) { + PrintStream stream; + try { + stream = new PrintStream(RSCRIPT_FILE); + } catch( FileNotFoundException e ) { + throw new UserException.CouldNotCreateOutputFile(RSCRIPT_FILE, "", e); + } + stream.println("library(ggplot2)"); + + createArrangeFunction( stream ); + + stream.println("pdf(\"" + RSCRIPT_FILE + ".pdf\")"); + + for(int iii = 0; iii < USE_ANNOTATIONS.length; iii++) { + for( int jjj = iii + 1; jjj < USE_ANNOTATIONS.length; jjj++) { + //stream.println("png(\"" + RSCRIPT_FILE + "." + USE_ANNOTATIONS[iii] + "." + USE_ANNOTATIONS[jjj] + ".png\", type=\"cairo\", width = 960, height = 960)"); + //stream.println("pdf(\"" + RSCRIPT_FILE + "." + USE_ANNOTATIONS[iii] + "." + USE_ANNOTATIONS[jjj] + ".pdf\")"); + logger.info( "Building " + USE_ANNOTATIONS[iii] + " x " + USE_ANNOTATIONS[jjj] + " plot..."); + + final ExpandingArrayList fakeData = new ExpandingArrayList(); + double minAnn1 = 100.0, maxAnn1 = -100.0, minAnn2 = 100.0, maxAnn2 = -100.0; + for( final VariantDatum datum : randomData ) { + minAnn1 = Math.min(minAnn1, datum.annotations[iii]); + maxAnn1 = Math.max(maxAnn1, datum.annotations[iii]); + minAnn2 = Math.min(minAnn2, datum.annotations[jjj]); + maxAnn2 = Math.max(maxAnn2, datum.annotations[jjj]); + } + for(double ann1 = minAnn1; ann1 <= maxAnn1; ann1+=0.1) { + for(double ann2 = minAnn2; ann2 <= maxAnn2; ann2+=0.1) { + final VariantDatum datum = new VariantDatum(); + datum.prior = 0.0; + datum.annotations = new double[randomData.get(0).annotations.length]; + datum.isNull = new boolean[randomData.get(0).annotations.length]; + for(int ann=0; ann< datum.annotations.length; ann++) { + datum.annotations[ann] = 0.0; + datum.isNull[ann] = true; + } + datum.annotations[iii] = ann1; + datum.annotations[jjj] = ann2; + datum.isNull[iii] = false; + datum.isNull[jjj] = false; + fakeData.add(datum); + } + } + + engine.evaluateData( fakeData, goodModel, false ); + engine.evaluateData( fakeData, badModel, true ); + + stream.print("surface <- c("); + for( final VariantDatum datum : fakeData ) { + stream.print(String.format("%.3f, %.3f, %.3f, ", datum.annotations[iii], datum.annotations[jjj], Math.min(4.0, Math.max(-4.0, datum.lod)))); + } + stream.println("NA,NA,NA)"); + stream.println("s <- matrix(surface,ncol=3,byrow=T)"); + + stream.print("data <- c("); + for( final VariantDatum datum : randomData ) { + stream.print(String.format("%.3f, %.3f, %.3f, %d, %d,", datum.annotations[iii], datum.annotations[jjj], (datum.lod < lodCutoff ? -1.0 : 1.0), datum.usedForTraining, (datum.isKnown ? 1 : -1))); + } + stream.println("NA,NA,NA,NA,NA)"); + stream.println("d <- matrix(data,ncol=5,byrow=T)"); + + final String surfaceFrame = "sf." + USE_ANNOTATIONS[iii] + "." + USE_ANNOTATIONS[jjj]; + final String dataFrame = "df." + USE_ANNOTATIONS[iii] + "." + USE_ANNOTATIONS[jjj]; + + stream.println(surfaceFrame + " <- data.frame(x=s[,1], y=s[,2], lod=s[,3])"); + stream.println(dataFrame + " <- data.frame(x=d[,1], y=d[,2], retained=d[,3], training=d[,4], novelty=d[,5])"); + stream.println("p <- ggplot(data=" + surfaceFrame + ", aes(x=x, y=y)) + opts(panel.background = theme_rect(colour = NA), panel.grid.minor = theme_line(colour = NA), panel.grid.major = theme_line(colour = NA))"); + stream.println("p1 = p + opts(title=\"model PDF\") + labs(x=\""+ USE_ANNOTATIONS[iii] +"\", y=\""+ USE_ANNOTATIONS[jjj] +"\") + geom_tile(aes(fill = lod)) + scale_fill_gradient(high=\"green\", low=\"red\")"); + stream.println("p <- ggplot(data=" + dataFrame + ", aes(x=x, y=y)) + opts(panel.background = theme_rect(colour = NA), panel.grid.minor = theme_line(colour = NA), panel.grid.major = theme_line(colour = NA))"); + stream.println("p2 = p + labs(x=\""+ USE_ANNOTATIONS[iii] +"\", y=\""+ USE_ANNOTATIONS[jjj] +"\") + geom_point(data="+ dataFrame + ", aes(x=x, y=y, colour = retained, alpha=0.3, size=1.5)) + scale_colour_gradient(name=\"\", high=\"black\", low=\"red\",breaks=c(-1,1),labels=c(\"filtered\",\"retained\"))"); + stream.println("p3 = p + labs(x=\""+ USE_ANNOTATIONS[iii] +"\", y=\""+ USE_ANNOTATIONS[jjj] +"\") + geom_point(data="+ dataFrame + "["+dataFrame+"$training==0,], aes(x=x, y=y, colour = training, alpha=0.3, size=1.5)) + geom_point(data="+ dataFrame + "["+dataFrame+"$training!=0,], aes(x=x, y=y, colour = training, alpha=0.3, size=1.5)) + scale_colour_gradient2(high=\"green\", mid=\"lightgrey\", low=\"purple\",breaks=c(-1,0,1), labels=c(\"bad\", \"\", \"good\"))"); + stream.println("p4 = p + labs(x=\""+ USE_ANNOTATIONS[iii] +"\", y=\""+ USE_ANNOTATIONS[jjj] +"\") + geom_point(data="+ dataFrame + ", aes(x=x, y=y, colour = novelty, alpha=0.3, size=1.5)) + scale_colour_gradient(name=\"\", high=\"blue\", low=\"red\",breaks=c(-1,1), labels=c(\"novel\",\"known\"))"); + stream.println("arrange(p1, p2, p3, p4, ncol=2)"); + } + } + stream.println("dev.off()"); + + stream.close(); + + // Execute Rscript command to generate the clustering plots + final String rScriptTranchesCommandLine = PATH_TO_RSCRIPT + " " + RSCRIPT_FILE; + logger.info( "Executing: " + rScriptTranchesCommandLine ); + try { + Process p; + p = Runtime.getRuntime().exec( rScriptTranchesCommandLine ); + p.waitFor(); + } catch ( Exception e ) { + Utils.warnUser("Unable to execute the RScript command. While not critical to the calculations themselves, the script outputs a report that is extremely useful for visualizing the recalibration results. We highly recommend trying to rerun the script manually if possible."); + } + } + + // from http://gettinggeneticsdone.blogspot.com/2010/03/arrange-multiple-ggplot2-plots-in-same.html + private void createArrangeFunction( final PrintStream stream ) { + stream.println("vp.layout <- function(x, y) viewport(layout.pos.row=x, layout.pos.col=y)"); + stream.println("arrange <- function(..., nrow=NULL, ncol=NULL, as.table=FALSE) {"); + stream.println("dots <- list(...)"); + stream.println("n <- length(dots)"); + stream.println("if(is.null(nrow) & is.null(ncol)) { nrow = floor(n/2) ; ncol = ceiling(n/nrow)}"); + stream.println("if(is.null(nrow)) { nrow = ceiling(n/ncol)}"); + stream.println("if(is.null(ncol)) { ncol = ceiling(n/nrow)}"); + stream.println("grid.newpage()"); + stream.println("pushViewport(viewport(layout=grid.layout(nrow,ncol) ) )"); + stream.println("ii.p <- 1"); + stream.println("for(ii.row in seq(1, nrow)){"); + stream.println("ii.table.row <- ii.row "); + stream.println("if(as.table) {ii.table.row <- nrow - ii.table.row + 1}"); + stream.println("for(ii.col in seq(1, ncol)){"); + stream.println("ii.table <- ii.p"); + stream.println("if(ii.p > n) break"); + stream.println("print(dots[[ii.table]], vp=vp.layout(ii.table.row, ii.col))"); + stream.println("ii.p <- ii.p + 1"); + stream.println("}"); + stream.println("}"); + stream.println("}"); + } + + } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/GaussianMixtureModel.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/GaussianMixtureModel.java index d1a452bec..64bb3755a 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/GaussianMixtureModel.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/GaussianMixtureModel.java @@ -1,11 +1,13 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantrecalibration; import Jama.Matrix; +import cern.jet.random.Normal; import org.apache.log4j.Logger; import org.broadinstitute.sting.gatk.GenomeAnalysisEngine; import org.broadinstitute.sting.utils.MathUtils; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; /** @@ -21,13 +23,13 @@ public class GaussianMixtureModel { private final ArrayList gaussians; private final double shrinkage; private final double dirichletParameter; - private final double degreesOfFreedom; + private final double priorCounts; private final double[] empiricalMu; private final Matrix empiricalSigma; public boolean isModelReadyForEvaluation; public GaussianMixtureModel( final int numGaussians, final int numAnnotations, - final double shrinkage, final double dirichletParameter ) { + final double shrinkage, final double dirichletParameter, final double priorCounts ) { gaussians = new ArrayList( numGaussians ); for( int iii = 0; iii < numGaussians; iii++ ) { @@ -36,39 +38,15 @@ public class GaussianMixtureModel { } this.shrinkage = shrinkage; this.dirichletParameter = dirichletParameter; - degreesOfFreedom = numAnnotations + 2; + this.priorCounts = priorCounts; empiricalMu = new double[numAnnotations]; empiricalSigma = new Matrix(numAnnotations, numAnnotations); isModelReadyForEvaluation = false; } - public void cacheEmpiricalStats( final List data ) { - final double[][] tmpSigmaVals = new double[empiricalMu.length][empiricalMu.length]; - for( int iii = 0; iii < empiricalMu.length; iii++ ) { - empiricalMu[iii] = 0.0; - for( int jjj = iii; jjj < empiricalMu.length; jjj++ ) { - tmpSigmaVals[iii][jjj] = 0.0; - } - } - - for( final VariantDatum datum : data ) { - for( int iii = 0; iii < empiricalMu.length; iii++ ) { - empiricalMu[iii] += datum.annotations[iii] / ((double) data.size()); - } - } - - //for( final VariantDatum datum : data ) { - // for( int iii = 0; iii < empiricalMu.length; iii++ ) { - // for( int jjj = 0; jjj < empiricalMu.length; jjj++ ) { - // tmpSigmaVals[iii][jjj] += (datum.annotations[iii]-empiricalMu[iii]) * (datum.annotations[jjj]-empiricalMu[jjj]); - // } - // } - //} - - //empiricalSigma.setMatrix(0, empiricalMu.length - 1, 0, empiricalMu.length - 1, new Matrix(tmpSigmaVals)); - //empiricalSigma.timesEquals( 1.0 / ((double) data.size()) ); - //empiricalSigma.timesEquals( 1.0 / (Math.pow(gaussians.size(), 2.0 / ((double) empiricalMu.length))) ); - empiricalSigma.setMatrix(0, empiricalMu.length - 1, 0, empiricalMu.length - 1, Matrix.identity(empiricalMu.length, empiricalMu.length)); + public void cacheEmpiricalStats() { + Arrays.fill(empiricalMu, 0.0); + empiricalSigma.setMatrix(0, empiricalMu.length - 1, 0, empiricalMu.length - 1, Matrix.identity(empiricalMu.length, empiricalMu.length).times(200.0).inverse()); } public void initializeRandomModel( final List data, final int numKMeansIterations ) { @@ -85,8 +63,9 @@ public class GaussianMixtureModel { // initialize uniform mixture coefficients, random covariance matrices, and initial hyperparameters for( final MultivariateGaussian gaussian : gaussians ) { gaussian.pMixtureLog10 = Math.log10( 1.0 / ((double) gaussians.size()) ); + gaussian.sumProb = 1.0 / ((double) gaussians.size()); gaussian.initializeRandomSigma( GenomeAnalysisEngine.getRandomGenerator() ); - gaussian.hyperParameter_a = degreesOfFreedom; + gaussian.hyperParameter_a = priorCounts; gaussian.hyperParameter_b = shrinkage; gaussian.hyperParameter_lambda = dirichletParameter; } @@ -129,19 +108,17 @@ public class GaussianMixtureModel { } } - public double expectationStep( final List data ) { + public void expectationStep( final List data ) { for( final MultivariateGaussian gaussian : gaussians ) { gaussian.precomputeDenominatorForVariationalBayes( getSumHyperParameterLambda() ); } - double likelihood = 0.0; for( final VariantDatum datum : data ) { final ArrayList pVarInGaussianLog10 = new ArrayList( gaussians.size() ); for( final MultivariateGaussian gaussian : gaussians ) { final double pVarLog10 = gaussian.evaluateDatumLog10( datum ); pVarInGaussianLog10.add( pVarLog10 ); - likelihood += pVarLog10; } final double[] pVarInGaussianNormalized = MathUtils.normalizeFromLog10( pVarInGaussianLog10 ); int iii = 0; @@ -149,15 +126,11 @@ public class GaussianMixtureModel { gaussian.assignPVarInGaussian( pVarInGaussianNormalized[iii++] ); //BUGBUG: to clean up } } - - final double scaledTotalLikelihoodLog10 = likelihood / ((double) data.size()); - logger.info( "sum Log10 likelihood = " + String.format("%.5f", scaledTotalLikelihoodLog10) ); - return scaledTotalLikelihoodLog10; } public void maximizationStep( final List data ) { for( final MultivariateGaussian gaussian : gaussians ) { - gaussian.maximizeGaussian( data, empiricalMu, empiricalSigma, shrinkage, dirichletParameter, degreesOfFreedom ); + gaussian.maximizeGaussian( data, empiricalMu, empiricalSigma, shrinkage, dirichletParameter, priorCounts); } } @@ -171,28 +144,31 @@ public class GaussianMixtureModel { public void evaluateFinalModelParameters( final List data ) { for( final MultivariateGaussian gaussian : gaussians ) { - gaussian.evaluateFinalModelParameters( data ); + gaussian.evaluateFinalModelParameters(data); } normalizePMixtureLog10(); } - private void normalizePMixtureLog10() { + public double normalizePMixtureLog10() { + double sumDiff = 0.0; double sumPK = 0.0; for( final MultivariateGaussian gaussian : gaussians ) { - sumPK += gaussian.pMixtureLog10; + sumPK += gaussian.sumProb; } int gaussianIndex = 0; double[] pGaussianLog10 = new double[gaussians.size()]; for( final MultivariateGaussian gaussian : gaussians ) { - pGaussianLog10[gaussianIndex++] = Math.log10( gaussian.pMixtureLog10 / sumPK ); //BUGBUG: to clean up + pGaussianLog10[gaussianIndex++] = Math.log10( gaussian.sumProb / sumPK ); } pGaussianLog10 = MathUtils.normalizeFromLog10( pGaussianLog10, true ); gaussianIndex = 0; for( final MultivariateGaussian gaussian : gaussians ) { + sumDiff += Math.abs( pGaussianLog10[gaussianIndex] - gaussian.pMixtureLog10 ); gaussian.pMixtureLog10 = pGaussianLog10[gaussianIndex++]; } + return sumDiff; } public void precomputeDenominatorForEvaluation() { @@ -204,6 +180,9 @@ public class GaussianMixtureModel { } public double evaluateDatum( final VariantDatum datum ) { + for( final boolean isNull : datum.isNull ) { + if( isNull ) { return evaluateDatumMarginalized( datum ); } + } final double[] pVarInGaussianLog10 = new double[gaussians.size()]; int gaussianIndex = 0; for( final MultivariateGaussian gaussian : gaussians ) { @@ -211,4 +190,27 @@ public class GaussianMixtureModel { } return MathUtils.log10sumLog10(pVarInGaussianLog10); } + + public double evaluateDatumMarginalized( final VariantDatum datum ) { + int numVals = 0; + double sumPVarInGaussian = 0.0; + int numIter = 10; + final double[] pVarInGaussianLog10 = new double[gaussians.size()]; + for( int iii = 0; iii < datum.annotations.length; iii++ ) { + if( datum.isNull[iii] ) { + for( int ttt = 0; ttt < numIter; ttt++ ) { + datum.annotations[iii] = Normal.staticNextDouble(0.0, 1.0); + + int gaussianIndex = 0; + for( final MultivariateGaussian gaussian : gaussians ) { + pVarInGaussianLog10[gaussianIndex++] = gaussian.pMixtureLog10 + gaussian.evaluateDatumLog10( datum ); + } + + sumPVarInGaussian += Math.pow(10.0, MathUtils.log10sumLog10(pVarInGaussianLog10)); + numVals++; + } + } + } + return Math.log10( sumPVarInGaussian / ((double) numVals) ); + } } \ No newline at end of file diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/MultivariateGaussian.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/MultivariateGaussian.java index a3c909bf6..b8e8fd28a 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/MultivariateGaussian.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/MultivariateGaussian.java @@ -1,6 +1,7 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantrecalibration; import Jama.Matrix; +import org.apache.commons.math.special.Gamma; import org.broadinstitute.sting.utils.MathUtils; import org.broadinstitute.sting.utils.collections.ExpandingArrayList; @@ -16,6 +17,7 @@ import java.util.Random; public class MultivariateGaussian { public double pMixtureLog10; + public double sumProb; final public double[] mu; final public Matrix sigma; public double hyperParameter_a; @@ -52,15 +54,15 @@ public class MultivariateGaussian { public void initializeRandomSigma( final Random rand ) { final double[][] randSigma = new double[mu.length][mu.length]; - for( int iii = 0; iii < mu.length; iii++ ) { - for( int jjj = iii; jjj < mu.length; jjj++ ) { - randSigma[jjj][iii] = 0.55 + 1.25 * rand.nextDouble(); - if( rand.nextBoolean() ) { - randSigma[jjj][iii] *= -1.0; - } - if( iii != jjj ) { randSigma[iii][jjj] = 0.0; } // Sigma is a symmetric, positive-definite matrix created by taking a lower diagonal matrix and multiplying it by its transpose + for( int iii = 0; iii < mu.length; iii++ ) { + for( int jjj = iii; jjj < mu.length; jjj++ ) { + randSigma[jjj][iii] = 0.55 + 1.25 * rand.nextDouble(); + if( rand.nextBoolean() ) { + randSigma[jjj][iii] *= -1.0; } + if( iii != jjj ) { randSigma[iii][jjj] = 0.0; } // Sigma is a symmetric, positive-definite matrix created by taking a lower diagonal matrix and multiplying it by its transpose } + } Matrix tmp = new Matrix( randSigma ); tmp = tmp.times(tmp.transpose()); sigma.setMatrix(0, mu.length - 1, 0, mu.length - 1, tmp); @@ -95,15 +97,15 @@ public class MultivariateGaussian { cachedSigmaInverse = sigma.inverse(); cachedSigmaInverse.timesEquals( hyperParameter_a ); double sum = 0.0; - for(int jjj = 1; jjj < mu.length; jjj++) { - sum += MathUtils.diGamma( (hyperParameter_a + 1.0 - jjj) / 2.0 ); + for(int jjj = 1; jjj <= mu.length; jjj++) { + sum += Gamma.digamma( (hyperParameter_a + 1.0 - jjj) / 2.0 ); } sum -= Math.log( sigma.det() ); sum += Math.log(2.0) * mu.length; - final double gamma = 0.5 * sum; - final double pi = MathUtils.diGamma( hyperParameter_lambda ) - MathUtils.diGamma( sumHyperParameterLambda ); + final double lambda = 0.5 * sum; + final double pi = Gamma.digamma( hyperParameter_lambda ) - Gamma.digamma( sumHyperParameterLambda ); final double beta = (-1.0 * mu.length) / (2.0 * hyperParameter_b); - cachedDenomLog10 = (pi / Math.log(10.0)) + (gamma / Math.log(10.0)) + (beta / Math.log(10.0)); + cachedDenomLog10 = (pi / Math.log(10.0)) + (lambda / Math.log(10.0)) + (beta / Math.log(10.0)); } public double evaluateDatumLog10( final VariantDatum datum ) { @@ -132,7 +134,7 @@ public class MultivariateGaussian { public void maximizeGaussian( final List data, final double[] empiricalMu, final Matrix empiricalSigma, final double SHRINKAGE, final double DIRICHLET_PARAMETER, final double DEGREES_OF_FREEDOM ) { - double sumProb = 0.0; + sumProb = 1E-10; final Matrix wishart = new Matrix(mu.length, mu.length); zeroOutMu(); zeroOutSigma(); @@ -171,8 +173,6 @@ public class MultivariateGaussian { mu[iii] = (sumProb * mu[iii] + SHRINKAGE * empiricalMu[iii]) / (sumProb + SHRINKAGE); } - pMixtureLog10 = sumProb; // will be normalized later by GaussianMixtureModel so no need to do it every iteration - hyperParameter_a = sumProb + DEGREES_OF_FREEDOM; hyperParameter_b = sumProb + SHRINKAGE; hyperParameter_lambda = sumProb + DIRICHLET_PARAMETER; @@ -181,7 +181,7 @@ public class MultivariateGaussian { } public void evaluateFinalModelParameters( final List data ) { - double sumProb = 0.0; + sumProb = 0.0; zeroOutMu(); zeroOutSigma(); @@ -206,7 +206,6 @@ public class MultivariateGaussian { } sigma.timesEquals( 1.0 / sumProb ); - pMixtureLog10 = sumProb; // will be normalized later by GaussianMixtureModel so no need to do it here resetPVarInGaussian(); // clean up some memory } } \ No newline at end of file diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/TrainingSet.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/TrainingSet.java index 8e12cf544..e9c67a21e 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/TrainingSet.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/TrainingSet.java @@ -15,7 +15,8 @@ public class TrainingSet { public boolean isKnown = false; public boolean isTraining = false; public boolean isTruth = false; - public double prior = 3.0; + public boolean isConsensus = false; + public double prior = 0.0; protected final static Logger logger = Logger.getLogger(TrainingSet.class); @@ -25,8 +26,13 @@ public class TrainingSet { isKnown = tags.containsKey("known") && tags.getValue("known").equals("true"); isTraining = tags.containsKey("training") && tags.getValue("training").equals("true"); isTruth = tags.containsKey("truth") && tags.getValue("truth").equals("true"); - prior = ( tags.containsKey("known") ? Double.parseDouble(tags.getValue("prior")) : prior ); + isConsensus = tags.containsKey("consensus") && tags.getValue("consensus").equals("true"); + prior = ( tags.containsKey("prior") ? Double.parseDouble(tags.getValue("prior")) : prior ); + } + if( !isConsensus ) { + logger.info( String.format( "Found %s track: \tKnown = %s \tTraining = %s \tTruth = %s \tPrior = Q%.1f", this.name, isKnown, isTraining, isTruth, prior) ); + } else { + logger.info( String.format( "Found consensus track: %s", this.name) ); } - logger.info(String.format( "Found %s track: \tKnown = %s \tTraining = %s \tTruth = %s \tPrior = Q%.1f", this.name, isKnown, isTraining, isTruth, prior) ); } } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantDataManager.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantDataManager.java index 6787cc211..b46fe536f 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantDataManager.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantDataManager.java @@ -6,9 +6,7 @@ import org.broad.tribble.util.variantcontext.VariantContext; import org.broadinstitute.sting.gatk.GenomeAnalysisEngine; import org.broadinstitute.sting.gatk.contexts.AlignmentContext; import org.broadinstitute.sting.gatk.contexts.ReferenceContext; -import org.broadinstitute.sting.gatk.contexts.variantcontext.VariantContextUtils; import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker; -import org.broadinstitute.sting.utils.GenomeLocParser; import org.broadinstitute.sting.utils.MathUtils; import org.broadinstitute.sting.utils.collections.ExpandingArrayList; import org.broadinstitute.sting.utils.exceptions.UserException; @@ -51,20 +49,39 @@ public class VariantDataManager { public void normalizeData() { boolean foundZeroVarianceAnnotation = false; - for( int jjj = 0; jjj < meanVector.length; jjj++ ) { //BUGBUG: to clean up - final double theMean = mean(jjj); //BUGBUG: to clean up - final double theSTD = standardDeviation(theMean, jjj); //BUGBUG: to clean up - logger.info( annotationKeys.get(jjj) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD) ); - foundZeroVarianceAnnotation = foundZeroVarianceAnnotation || (theSTD < 1E-8); - meanVector[jjj] = theMean; - varianceVector[jjj] = theSTD; + for( int iii = 0; iii < meanVector.length; iii++ ) { //BUGBUG: to clean up + final double theMean = mean(iii); //BUGBUG: to clean up + final double theSTD = standardDeviation(theMean, iii); //BUGBUG: to clean up + logger.info( annotationKeys.get(iii) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD) ); + foundZeroVarianceAnnotation = foundZeroVarianceAnnotation || (theSTD < 1E-6); + if( annotationKeys.get(iii).toLowerCase().contains("ranksum") ) { // BUGBUG: to clean up + for( final VariantDatum datum : data ) { + if( datum.annotations[iii] > 0.0 ) { datum.annotations[iii] /= 3.0; } + } + } + meanVector[iii] = theMean; + varianceVector[iii] = theSTD; for( final VariantDatum datum : data ) { - datum.annotations[jjj] = ( datum.annotations[jjj] - theMean ) / theSTD; // Each data point is now [ (x - mean) / standard deviation ] + datum.annotations[iii] = ( datum.isNull[iii] ? Normal.staticNextDouble(0.0, 1.0) : ( datum.annotations[iii] - theMean ) / theSTD ); + // Each data point is now [ (x - mean) / standard deviation ] + if( annotationKeys.get(iii).toLowerCase().contains("ranksum") && datum.isNull[iii] && datum.annotations[iii] > 0.0 ) { + datum.annotations[iii] /= 3.0; + } } } if( foundZeroVarianceAnnotation ) { throw new UserException.BadInput( "Found annotations with zero variance. They must be excluded before proceeding." ); } + + // trim data by standard deviation threshold and place into two sets: data and failingData + for( final VariantDatum datum : data ) { + boolean remove = false; + for( final double val : datum.annotations ) { + remove = remove || (Math.abs(val) > VRAC.STD_THRESHOLD); + } + datum.failingSTDThreshold = remove; + datum.usedForTraining = 0; + } } public void addTrainingSet( final TrainingSet trainingSet ) { @@ -95,63 +112,85 @@ public class VariantDataManager { public ExpandingArrayList getTrainingData() { final ExpandingArrayList trainingData = new ExpandingArrayList(); for( final VariantDatum datum : data ) { - if( datum.atTrainingSite && datum.originalQual > VRAC.QUAL_THRESHOLD ) { + if( datum.atTrainingSite && !datum.failingSTDThreshold && datum.originalQual > VRAC.QUAL_THRESHOLD ) { trainingData.add( datum ); + datum.usedForTraining = 1; } } - trimDataBySTD( trainingData, VRAC.STD_THRESHOLD ); - logger.info( "Training with " + trainingData.size() + " variants found in the training set(s)." ); + logger.info( "Training with " + trainingData.size() + " variants after standard deviation thresholding." ); return trainingData; } public ExpandingArrayList selectWorstVariants( final double bottomPercentage ) { Collections.sort( data ); final ExpandingArrayList trainingData = new ExpandingArrayList(); - trainingData.addAll( data.subList(0, Math.round((float)bottomPercentage * data.size())) ); - logger.info( "Training with worst " + (float)bottomPercentage * 100.0f + "% of data --> " + trainingData.size() + " variants with LOD <= " + String.format("%.4f", data.get(Math.round((float)bottomPercentage * data.size())).lod) + "." ); + final int numToAdd = Math.round((float)bottomPercentage * data.size()); + int index = 0; + int numAdded = 0; + while( numAdded < numToAdd ) { + final VariantDatum datum = data.get(index++); + if( !datum.failingSTDThreshold ) { + trainingData.add( datum ); + datum.usedForTraining = -1; + numAdded++; + } + } + logger.info("Training with worst " + (float) bottomPercentage * 100.0f + "% of passing data --> " + trainingData.size() + " variants with LOD <= " + String.format("%.4f", data.get(index).lod) + "."); return trainingData; } + public ExpandingArrayList getRandomDataForPlotting( int numToAdd ) { + numToAdd = Math.min(numToAdd, data.size()); + final ExpandingArrayList returnData = new ExpandingArrayList(); + for( int iii = 0; iii < numToAdd; iii++) { + final VariantDatum datum = data.get(GenomeAnalysisEngine.getRandomGenerator().nextInt(data.size())); + if( !datum.failingSTDThreshold ) { + returnData.add(datum); + } + } + // add an extra 5% of points from bad training set, since that set is small but interesting + for( int iii = 0; iii < Math.floor(0.05*numToAdd); iii++) { + final VariantDatum datum = data.get(GenomeAnalysisEngine.getRandomGenerator().nextInt(data.size())); + if( datum.usedForTraining == -1 && !datum.failingSTDThreshold ) { returnData.add(datum); } + else { iii--; } + } + + return returnData; + } + private double mean( final int index ) { double sum = 0.0; - final int numVars = data.size(); + int numNonNull = 0; for( final VariantDatum datum : data ) { - sum += (datum.annotations[index] / ((double) numVars)); + if( datum.atTrainingSite && !datum.isNull[index] ) { sum += datum.annotations[index]; numNonNull++; } } - return sum; + return sum / ((double) numNonNull); } private double standardDeviation( final double mean, final int index ) { double sum = 0.0; - final int numVars = data.size(); + int numNonNull = 0; for( final VariantDatum datum : data ) { - sum += ( ((datum.annotations[index] - mean)*(datum.annotations[index] - mean)) / ((double) numVars)); + if( datum.atTrainingSite && !datum.isNull[index] ) { sum += ((datum.annotations[index] - mean)*(datum.annotations[index] - mean)); numNonNull++; } } - return Math.sqrt( sum ); + return Math.sqrt( sum / ((double) numNonNull) ); } - public static void trimDataBySTD( final ExpandingArrayList listData, final double STD_THRESHOLD ) { - final ExpandingArrayList dataToRemove = new ExpandingArrayList(); - for( final VariantDatum datum : listData ) { - boolean remove = false; - for( final double val : datum.annotations ) { - remove = remove || (Math.abs(val) > STD_THRESHOLD); - } - if( remove ) { dataToRemove.add( datum ); } - } - listData.removeAll( dataToRemove ); - } - - public double[] decodeAnnotations( final GenomeLocParser genomeLocParser, final VariantContext vc, final boolean jitter ) { + public void decodeAnnotations( final VariantDatum datum, final VariantContext vc, final boolean jitter ) { final double[] annotations = new double[annotationKeys.size()]; + final boolean[] isNull = new boolean[annotationKeys.size()]; int iii = 0; for( final String key : annotationKeys ) { - annotations[iii++] = decodeAnnotation( genomeLocParser, key, vc, jitter ); + isNull[iii] = false; + annotations[iii] = decodeAnnotation( key, vc, jitter ); + if( Double.isNaN(annotations[iii]) ) { isNull[iii] = true; } + iii++; } - return annotations; + datum.annotations = annotations; + datum.isNull = isNull; } - private static double decodeAnnotation( final GenomeLocParser genomeLocParser, final String annotationKey, final VariantContext vc, final boolean jitter ) { + private static double decodeAnnotation( final String annotationKey, final VariantContext vc, final boolean jitter ) { double value; if( jitter && annotationKey.equalsIgnoreCase("HRUN") ) { // HRun values must be jittered a bit to work in this GMM value = Double.parseDouble( (String)vc.getAttribute( annotationKey ) ); @@ -161,30 +200,29 @@ public class VariantDataManager { } else { try { value = Double.parseDouble( (String)vc.getAttribute( annotationKey ) ); - if(Double.isNaN(value)) { throw new NumberFormatException(); } - if( annotationKey.toLowerCase().contains("ranksum") ) { //BUGBUG: temporary hack - if(MathUtils.compareDoubles(value, 0.0, 0.01) == 0) { value = Normal.staticNextDouble(2.0, 2.0); } - else if(MathUtils.compareDoubles(value, 200.0, 0.01) == 0) { value = Normal.staticNextDouble(162.0, 20.0); } - } + if(Double.isInfinite(value)) { value = Double.NaN; } + if(annotationKey.equals("HaplotypeScore") && MathUtils.compareDoubles(value, 0.0, 0.0001) == 0 ) { value = -0.2 + 0.4*GenomeAnalysisEngine.getRandomGenerator().nextDouble(); } } catch( final Exception e ) { - throw new UserException.MalformedFile( vc.getSource(), "No double value detected for annotation = " + annotationKey + " in variant at " + VariantContextUtils.getLocation(genomeLocParser,vc) + ", reported annotation value = " + vc.getAttribute( annotationKey ), e ); + value = Double.NaN; // The VQSR works with missing data now by marginalizing over the missing dimension when evaluating clusters. } } return value; } - public void parseTrainingSets( final RefMetaDataTracker tracker, final ReferenceContext ref, final AlignmentContext context, final VariantContext evalVC, final VariantDatum datum, final boolean TRUST_ALL_POLYMORPHIC, final boolean FIX_OMNI ) { + public void parseTrainingSets( final RefMetaDataTracker tracker, final ReferenceContext ref, final AlignmentContext context, final VariantContext evalVC, final VariantDatum datum, final boolean TRUST_ALL_POLYMORPHIC ) { datum.isKnown = false; datum.atTruthSite = false; datum.atTrainingSite = false; datum.prior = 2.0; + datum.consensusCount = 0; for( final TrainingSet trainingSet : trainingSets ) { for( final VariantContext trainVC : tracker.getVariantContexts( ref, trainingSet.name, null, context.getLocation(), false, false ) ) { - if( trainVC != null && trainVC.isVariant() && (trainVC.isNotFiltered() || (FIX_OMNI && trainVC.getFilters().size()==1 && trainVC.getFilters().contains("NOT_POLY_IN_1000G"))) && ((evalVC.isSNP() && trainVC.isSNP()) || (evalVC.isIndel() && trainVC.isIndel())) && (TRUST_ALL_POLYMORPHIC || !trainVC.hasGenotypes() || trainVC.isPolymorphic()) ) { + if( trainVC != null && trainVC.isNotFiltered() && trainVC.isVariant() && ((evalVC.isSNP() && trainVC.isSNP()) || (evalVC.isIndel() && trainVC.isIndel())) && (TRUST_ALL_POLYMORPHIC || !trainVC.hasGenotypes() || trainVC.isPolymorphic()) ) { datum.isKnown = datum.isKnown || trainingSet.isKnown; datum.atTruthSite = datum.atTruthSite || trainingSet.isTruth; datum.atTrainingSite = datum.atTrainingSite || trainingSet.isTraining; datum.prior = Math.max( datum.prior, trainingSet.prior ); + datum.consensusCount += ( trainingSet.isConsensus ? 1 : 0 ); } } } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantDatum.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantDatum.java index 1526672ed..2509bba1e 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantDatum.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantDatum.java @@ -11,15 +11,19 @@ import org.broadinstitute.sting.utils.GenomeLoc; public class VariantDatum implements Comparable { public double[] annotations; + public boolean[] isNull; public boolean isKnown; public double lod; public boolean atTruthSite; public boolean atTrainingSite; public boolean isTransition; public boolean isSNP; + public boolean failingSTDThreshold; public double originalQual; public double prior; + public int consensusCount; public GenomeLoc pos; + public int usedForTraining; public MultivariateGaussian assignment; // used in K-means implementation public int compareTo( final VariantDatum other ) { diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorArgumentCollection.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorArgumentCollection.java index 170e399b9..e12fc20b5 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorArgumentCollection.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorArgumentCollection.java @@ -19,19 +19,21 @@ public class VariantRecalibratorArgumentCollection { @Argument(fullName = "mode", shortName = "mode", doc = "Recalibration mode to employ: 1.) SNP for recalibrating only SNPs (emitting indels untouched in the output VCF); 2.) INDEL for indels; and 3.) BOTH for recalibrating both SNPs and indels simultaneously.", required = false) public VariantRecalibratorArgumentCollection.Mode MODE = VariantRecalibratorArgumentCollection.Mode.SNP; @Argument(fullName="maxGaussians", shortName="mG", doc="The maximum number of Gaussians to try during variational Bayes algorithm", required=false) - public int MAX_GAUSSIANS = 32; + public int MAX_GAUSSIANS = 10; @Argument(fullName="maxIterations", shortName="mI", doc="The maximum number of VBEM iterations to be performed in variational Bayes algorithm. Procedure will normally end when convergence is detected.", required=false) public int MAX_ITERATIONS = 100; @Argument(fullName="numKMeans", shortName="nKM", doc="The number of k-means iterations to perform in order to initialize the means of the Gaussians in the Gaussian mixture model.", required=false) - public int NUM_KMEANS_ITERATIONS = 10; + public int NUM_KMEANS_ITERATIONS = 30; @Argument(fullName="stdThreshold", shortName="std", doc="If a variant has annotations more than -std standard deviations away from mean then don't use it for building the Gaussian mixture model.", required=false) - public double STD_THRESHOLD = 4.5; + public double STD_THRESHOLD = 8.0; @Argument(fullName="qualThreshold", shortName="qual", doc="If a known variant has raw QUAL value less than -qual then don't use it for building the Gaussian mixture model.", required=false) public double QUAL_THRESHOLD = 80.0; @Argument(fullName="shrinkage", shortName="shrinkage", doc="The shrinkage parameter in variational Bayes algorithm.", required=false) - public double SHRINKAGE = 0.0001; + public double SHRINKAGE = 1.0; @Argument(fullName="dirichlet", shortName="dirichlet", doc="The dirichlet parameter in variational Bayes algorithm.", required=false) - public double DIRICHLET_PARAMETER = 0.0001; + public double DIRICHLET_PARAMETER = 0.001; + @Argument(fullName="priorCounts", shortName="priorCounts", doc="The number of prior counts to use in variational Bayes algorithm.", required=false) + public double PRIOR_COUNTS = 20.0; @Argument(fullName="percentBadVariants", shortName="percentBad", doc="What percentage of the worst scoring variants to use when building the Gaussian mixture model of bad variants. 0.07 means bottom 7 percent.", required=false) - public double PERCENT_BAD_VARIANTS = 0.07; + public double PERCENT_BAD_VARIANTS = 0.015; } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java index b4ce0040a..2a2565df8 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java @@ -21,7 +21,7 @@ public class VariantRecalibratorEngine { // the unified argument collection final private VariantRecalibratorArgumentCollection VRAC; - private final static double MIN_PROB_CONVERGENCE_LOG10 = 1.0; + private final static double MIN_PROB_CONVERGENCE = 2E-2; ///////////////////////////// // Public Methods to interface with the Engine @@ -33,7 +33,7 @@ public class VariantRecalibratorEngine { } public GaussianMixtureModel generateModel( final List data ) { - final GaussianMixtureModel model = new GaussianMixtureModel( VRAC.MAX_GAUSSIANS, data.get(0).annotations.length, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER ); + final GaussianMixtureModel model = new GaussianMixtureModel( VRAC.MAX_GAUSSIANS, data.get(0).annotations.length, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER, VRAC.PRIOR_COUNTS ); variationalBayesExpectationMaximization( model, data ); return model; } @@ -63,25 +63,26 @@ public class VariantRecalibratorEngine { private void variationalBayesExpectationMaximization( final GaussianMixtureModel model, final List data ) { - model.cacheEmpiricalStats( data ); + model.cacheEmpiricalStats(); model.initializeRandomModel( data, VRAC.NUM_KMEANS_ITERATIONS ); // The VBEM loop - double previousLikelihood = model.expectationStep( data ); + model.normalizePMixtureLog10(); + model.expectationStep( data ); double currentLikelihood; int iteration = 0; logger.info("Finished iteration " + iteration ); while( iteration < VRAC.MAX_ITERATIONS ) { iteration++; model.maximizationStep( data ); - currentLikelihood = model.expectationStep( data ); - + currentLikelihood = model.normalizePMixtureLog10(); + model.expectationStep( data ); + logger.info("Current change in mixture coefficients = " + String.format("%.5f", currentLikelihood)); logger.info("Finished iteration " + iteration ); - if( Math.abs(currentLikelihood - previousLikelihood) < MIN_PROB_CONVERGENCE_LOG10 ) { + if( iteration > 2 && currentLikelihood < MIN_PROB_CONVERGENCE ) { logger.info("Convergence!"); break; } - previousLikelihood = currentLikelihood; } model.evaluateFinalModelParameters( data ); diff --git a/java/src/org/broadinstitute/sting/utils/MathUtils.java b/java/src/org/broadinstitute/sting/utils/MathUtils.java index 01b697607..69d031190 100755 --- a/java/src/org/broadinstitute/sting/utils/MathUtils.java +++ b/java/src/org/broadinstitute/sting/utils/MathUtils.java @@ -420,6 +420,14 @@ public class MathUtils { return dist; } + public static double round(double num, int digits) { + double result = num * Math.pow(10.0, (double)digits); + result = Math.round(result); + result = result / Math.pow(10.0, (double)digits); + return result; + } + + /** * normalizes the log10-based array. ASSUMES THAT ALL ARRAY ENTRIES ARE <= 0 (<= 1 IN REAL-SPACE). * @@ -717,16 +725,6 @@ public class MathUtils { return ans; } - // lifted from the internet - // http://www.cs.princeton.edu/introcs/91float/Gamma.java.html - public static double logGamma(double x) { - double tmp = (x - 0.5) * Math.log(x + 4.5) - (x + 4.5); - double ser = 1.0 + 76.18009173 / (x + 0) - 86.50532033 / (x + 1) - + 24.01409822 / (x + 2) - 1.231739516 / (x + 3) - + 0.00120858003 / (x + 4) - 0.00000536382 / (x + 5); - return tmp + Math.log(ser * Math.sqrt(2 * Math.PI)); - } - public static double percentage(double x, double base) { return (base > 0 ? (x / base) * 100.0 : 0); } @@ -892,15 +890,6 @@ public class MathUtils { return getQScoreOrderStatistic(reads, offsets, (int)Math.floor(reads.size()/2.)); } - // from http://en.wikipedia.org/wiki/Digamma_function - // According to J.M. Bernardo AS 103 algorithm the digamma function for x, a real number, can be approximated by: - public static double diGamma(final double x) { - return Math.log(x) - ( 1.0 / (2.0 * x) ) - - ( 1.0 / (12.0 * Math.pow(x, 2.0)) ) - + ( 1.0 / (120.0 * Math.pow(x, 4.0)) ) - - ( 1.0 / (252.0 * Math.pow(x, 6.0)) ); - } - /** A utility class that computes on the fly average and standard deviation for a stream of numbers. * The number of observations does not have to be known in advance, and can be also very big (so that * it could overflow any naive summation-based scheme or cause loss of precision). diff --git a/java/test/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersV2IntegrationTest.java b/java/test/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersV2IntegrationTest.java index dea584949..7d923e21a 100755 --- a/java/test/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersV2IntegrationTest.java +++ b/java/test/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersV2IntegrationTest.java @@ -24,33 +24,30 @@ public class VariantRecalibrationWalkersV2IntegrationTest extends WalkerTest { } } - VRTest yriTrio = new VRTest("yri.trio.gatk_glftrio.intersection.annotated.filtered.chr1.vcf", - "74127253b3e551ae602b5957eef8e37e", // tranches - "38fe940269ce56449b2e4e3493f89cd9", // recal file - "067197ba4212a38c7fd736c1f66b294d"); // cut VCF - - VRTest lowPass = new VRTest("lowpass.N3.chr1.raw.vcf", - "40242ee5ce97ffde50cc7a12cfa0d841", // tranches - "eb8df1fbf3d71095bdb3a52713dd2a64", // recal file - "13aa221241a974883072bfad14ad8afc"); // cut VCF + VRTest lowPass = new VRTest("phase1.projectConsensus.chr20.raw.snps.vcf", + "920b12d7765eb4f6f4a1bab045679b31", // tranches + "41bbc5f07c8a9573d5bb638f01808bba", // recal file + "2b8c5e884bf5a739b782f5b3bf17f19c"); // cut VCF @DataProvider(name = "VRTest") public Object[][] createData1() { - return new Object[][]{ {yriTrio}, {lowPass} }; + return new Object[][]{ {lowPass} }; + //return new Object[][]{ {yriTrio}, {lowPass} }; // Add hg19 chr20 trio calls here } @Test(dataProvider = "VRTest") public void testVariantRecalibrator(VRTest params) { //System.out.printf("PARAMS FOR %s is %s%n", vcf, clusterFile); WalkerTest.WalkerTestSpec spec = new WalkerTest.WalkerTestSpec( - "-R " + b36KGReference + - " -B:dbsnp,DBSNP,known=true,training=false,truth=false,prior=10.0 " + GATKDataLocation + "dbsnp_129_b36.rod" + - " -B:hapmap,VCF,known=false,training=true,truth=true,prior=15.0 " + comparisonDataLocation + "Validated/HapMap/3.2/sites_r27_nr.b36_fwd.vcf" + - " -B:omni,VCF,known=false,training=true,truth=true,prior=12.0 " + comparisonDataLocation + "Validated/Omni2.5_chip/1212samples.b36.vcf" + + "-R " + b37KGReference + + " -B:dbsnp,VCF,known=true,training=false,truth=false,prior=10.0 " + GATKDataLocation + "dbsnp_132_b37.leftAligned.vcf" + + " -B:hapmap,VCF,known=false,training=true,truth=true,prior=15.0 " + comparisonDataLocation + "Validated/HapMap/3.3/sites_r27_nr.b37_fwd.vcf" + + " -B:omni,VCF,known=false,training=true,truth=true,prior=12.0 " + comparisonDataLocation + "Validated/Omni2.5_chip/Omni25_sites_1525_samples.b37.vcf" + " -T ContrastiveRecalibrator" + " -B:input,VCF " + params.inVCF + - " -L 1:50,000,000-120,000,000" + - " -an QD -an MQ -an SB" + + " -L 20:1,000,000-40,000,000" + + " -an QD -an HaplotypeScore -an HRun" + + " -percentBad 0.07" + " --trustAllPolymorphic" + // for speed " -recalFile %s" + " -tranchesFile %s", @@ -61,9 +58,9 @@ public class VariantRecalibrationWalkersV2IntegrationTest extends WalkerTest { @Test(dataProvider = "VRTest",dependsOnMethods="testVariantRecalibrator") public void testApplyRecalibration(VRTest params) { WalkerTest.WalkerTestSpec spec = new WalkerTest.WalkerTestSpec( - "-R " + b36KGReference + + "-R " + b37KGReference + " -T ApplyRecalibration" + - " -L 1:60,000,000-115,000,000" + + " -L 20:12,000,000-30,000,000" + " -NO_HEADER" + " -B:input,VCF " + params.inVCF + " -o %s" + diff --git a/scala/qscript/core/MethodsDevelopmentCallingPipeline.scala b/scala/qscript/core/MethodsDevelopmentCallingPipeline.scala index fe31b6bf2..15a7be90a 100755 --- a/scala/qscript/core/MethodsDevelopmentCallingPipeline.scala +++ b/scala/qscript/core/MethodsDevelopmentCallingPipeline.scala @@ -96,8 +96,8 @@ class MethodsDevelopmentCallingPipeline extends QScript { val hapmap_b36 = "/humgen/gsa-hpprojects/GATK/data/Comparisons/Validated/HapMap/3.3/sites_r27_nr.b36_fwd.vcf" val hapmap_b37 = "/humgen/gsa-hpprojects/GATK/data/Comparisons/Validated/HapMap/3.3/sites_r27_nr.b37_fwd.vcf" val training_hapmap_b37 = "/humgen/1kg/processing/pipeline_test_bams/hapmap3.3_training_chr20.vcf" - val omni_b36 = "/humgen/gsa-hpprojects/GATK/data/Comparisons/Validated/Omni2.5_chip/1212samples.b36.vcf" - val omni_b37 = "/humgen/gsa-hpprojects/GATK/data/Comparisons/Validated/Omni2.5_chip/1212samples.b37.vcf" + val omni_b36 = "/humgen/gsa-hpprojects/GATK/data/Comparisons/Validated/Omni2.5_chip/Omni25_sites_1525_samples.b36.vcf" + val omni_b37 = "/humgen/gsa-hpprojects/GATK/data/Comparisons/Validated/Omni2.5_chip/Omni25_sites_1525_samples.b37.vcf" val indelMask_b36 = "/humgen/1kg/processing/pipeline_test_bams/pilot1.dindel.mask.b36.bed" val indelMask_b37 = "/humgen/1kg/processing/pipeline_test_bams/pilot1.dindel.mask.b37.bed" @@ -105,7 +105,7 @@ class MethodsDevelopmentCallingPipeline extends QScript { val indels: Boolean = true val queueLogDir = ".qlog/" - + val targetDataSets: Map[String, Target] = Map( "HiSeq" -> new Target("NA12878.HiSeq", hg18, dbSNP_hg18_129, hapmap_hg18, "/humgen/gsa-hpprojects/dev/depristo/oneOffProjects/1000GenomesProcessingPaper/wgs.v13/HiSeq.WGS.cleaned.indels.10.mask", @@ -115,7 +115,7 @@ class MethodsDevelopmentCallingPipeline extends QScript { "HiSeq19" -> new Target("NA12878.HiSeq19", hg19, dbSNP_b37_129, hapmap_b37, indelMask_b37, new File("/humgen/gsa-hpprojects/NA12878Collection/bams/NA12878.HiSeq.WGS.bwa.cleaned.recal.hg19.bam"), new File("/humgen/gsa-hpprojects/dev/carneiro/hiseq19/analysis/snps/NA12878.HiSeq19.filtered.vcf"), - "/humgen/1kg/processing/pipeline_test_bams/whole_genome_chunked.hg19.intervals", 2.3, 98.0, !lowPass), + "/humgen/1kg/processing/pipeline_test_bams/whole_genome_chunked.hg19.intervals", 2.3, 99.0, !lowPass), "GA2hg19" -> new Target("NA12878.GA2.hg19", hg19, dbSNP_b37_129, hapmap_b37, indelMask_b37, new File("/humgen/gsa-hpprojects/NA12878Collection/bams/NA12878.GA2.WGS.bwa.cleaned.hg19.bam"), new File("/humgen/gsa-hpprojects/dev/carneiro/hiseq19/analysis/snps/NA12878.GA2.hg19.filtered.vcf"), @@ -251,17 +251,17 @@ class MethodsDevelopmentCallingPipeline extends QScript { this.rodBind :+= RodBind("input", "VCF", if ( goldStandard ) { t.goldStandard_VCF } else { t.rawVCF } ) this.rodBind :+= RodBind("hapmap", "VCF", t.hapmapFile, "known=false,training=true,truth=true,prior=15.0") if( t.hapmapFile.contains("b37") ) - this.rodBind :+= RodBind("omni", "VCF", omni_b37, "known=false,training=true,truth=false,prior=12.0") + this.rodBind :+= RodBind("omni", "VCF", omni_b37, "known=false,training=true,truth=true,prior=12.0") else if( t.hapmapFile.contains("b36") ) - this.rodBind :+= RodBind("omni", "VCF", omni_b36, "known=false,training=true,truth=false,prior=12.0") + this.rodBind :+= RodBind("omni", "VCF", omni_b36, "known=false,training=true,truth=true,prior=12.0") if (t.dbsnpFile.endsWith(".rod")) this.rodBind :+= RodBind("dbsnp", "DBSNP", t.dbsnpFile, "known=true,training=false,truth=false,prior=10.0") else if (t.dbsnpFile.endsWith(".vcf")) this.rodBind :+= RodBind("dbsnp", "VCF", t.dbsnpFile, "known=true,training=false,truth=false,prior=10.0") - this.use_annotation ++= List("QD", "SB", "HaplotypeScore", "HRun") + this.use_annotation ++= List("QD", "HaplotypeScore", "MQRankSum", "ReadPosRankSum", "HRun") this.tranches_file = if ( goldStandard ) { t.goldStandardTranchesFile } else { t.tranchesFile } this.recal_file = if ( goldStandard ) { t.goldStandardRecalFile } else { t.recalFile } - this.fixOmni = true // temporary argument until new Omni file is released + this.allPoly = true this.tranche ++= List("100.0", "99.9", "99.5", "99.3", "99.0", "98.9", "98.8", "98.5", "98.4", "98.3", "98.2", "98.1", "98.0", "97.9", "97.8", "97.5", "97.0", "95.0", "90.0") this.analysisName = t.name + "_VQSR" this.jobName = queueLogDir + t.name + ".VQSR"