New visualization output for VQSR: it creates the R script file on the fly and then runs Rscript on it. Added 1000 Genomes Project consensus code. First pass of having VQSR work with missing data, by marginalizing over the missing dimension for that data point (thanks Chris and Bob for the ideas). Updated math functions to use Apache Commons Math instead of approximations from Wikipedia. New parameters available for the priors, based on further reading in Bishop and looking at the new visualizations. Updated the integration test to use more modern files. Updated MDCP to use the new best practices with respect to annotations.
git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@5723 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
parent fcf8cff64a
commit 3224bbe750
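Note on the missing-data change: the idea, sketched below in illustrative (not committed) Java, is to score a variant whose annotation vector has missing dimensions by Monte Carlo integration, drawing the missing entries from a standard normal (the annotations are z-scored first) and averaging the model probability over the draws. The committed version is GaussianMixtureModel.evaluateDatumMarginalized further down, which loops per missing dimension; this sketch resamples all missing dimensions jointly.

import java.util.Random;

// Hedged sketch of marginalizing a GMM score over missing annotation dimensions.
class MarginalizationSketch {
    interface Log10Score { double at(double[] annotations); } // stand-in for the GMM density

    static double marginalizedLog10(final double[] annotations, final boolean[] isMissing, final Log10Score model) {
        final Random rand = new Random(42);
        final int numSamples = 10; // same small sample count the commit uses
        final double[] x = annotations.clone();
        double sumProb = 0.0;
        for (int t = 0; t < numSamples; t++) {
            for (int i = 0; i < x.length; i++) {
                // annotations are standardized, so N(0,1) draws integrate the
                // missing dimension against the empirical distribution
                if (isMissing[i]) { x[i] = rand.nextGaussian(); }
            }
            sumProb += Math.pow(10.0, model.at(x)); // average in probability space, not log space
        }
        return Math.log10(sumProb / numSamples);
    }
}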
@@ -37,10 +37,13 @@ import org.broadinstitute.sting.gatk.datasources.rmd.ReferenceOrderedDataSource;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.walkers.RodWalker;
import org.broadinstitute.sting.gatk.walkers.TreeReducible;
import org.broadinstitute.sting.utils.MathUtils;
import org.broadinstitute.sting.utils.QualityUtils;
import org.broadinstitute.sting.utils.Utils;
import org.broadinstitute.sting.utils.collections.ExpandingArrayList;
import org.broadinstitute.sting.utils.exceptions.UserException;

import java.io.FileNotFoundException;
import java.io.PrintStream;
import java.util.*;
@@ -76,6 +79,12 @@ public class ContrastiveRecalibrator extends RodWalker<ExpandingArrayList<Varian
private double[] TS_TRANCHES = new double[] {100.0, 99.9, 99.0, 90.0};
@Argument(fullName="ignore_filter", shortName="ignoreFilter", doc="If specified the optimizer will use variants even if the specified filter name is marked in the input VCF file", required=false)
private String[] IGNORE_INPUT_FILTERS = null;
@Argument(fullName="path_to_Rscript", shortName = "Rscript", doc = "The path to your implementation of Rscript. For Broad users this is maybe /broad/tools/apps/R-2.6.0/bin/Rscript", required=false)
private String PATH_TO_RSCRIPT = "Rscript";
@Argument(fullName="rscript_file", shortName="rscriptFile", doc="The output rscript file generated by the VQSR to aid in visualization of the input data and learned model", required=false)
private String RSCRIPT_FILE = null;
@Argument(fullName="ts_filter_level", shortName="ts_filter_level", doc="The truth sensitivity level at which to start filtering, used here to indicate filtered variants in plots", required=false)
private double TS_FILTER_LEVEL = 99.0;

/////////////////////////////
// Debug Arguments
@@ -83,10 +92,6 @@ public class ContrastiveRecalibrator extends RodWalker<ExpandingArrayList<Varian
@Hidden
@Argument(fullName = "trustAllPolymorphic", shortName = "allPoly", doc = "Trust that all the input training sets' unfiltered records contain only polymorphic sites to drastically speed up the computation.", required = false)
protected Boolean TRUST_ALL_POLYMORPHIC = false;
@Hidden
@Argument(fullName = "fixOmni", shortName = "fixOmni", doc = "Ignore the NOT_POLY_IN_1000G filter for the omni file because it is broken.", required = false)
protected Boolean FIX_OMNI = false; //BUGBUG: remove me very soon!


/////////////////////////////
// Private Member Variables
@@ -150,13 +155,18 @@ public class ContrastiveRecalibrator extends RodWalker<ExpandingArrayList<Varian
if( vc != null && ( vc.isNotFiltered() || ignoreInputFilterSet.containsAll(vc.getFilters()) ) ) {
if( checkRecalibrationMode( vc, VRAC.MODE ) ) {
final VariantDatum datum = new VariantDatum();
datum.annotations = dataManager.decodeAnnotations( ref.getGenomeLocParser(), vc, true ); //BUGBUG: when run with HierarchicalMicroScheduler this is non-deterministic because order of calls depends on load of machine
dataManager.decodeAnnotations( datum, vc, true ); //BUGBUG: when run with HierarchicalMicroScheduler this is non-deterministic because order of calls depends on load of machine
datum.pos = context.getLocation();
datum.originalQual = vc.getPhredScaledQual();
datum.isSNP = vc.isSNP() && vc.isBiallelic();
datum.isTransition = datum.isSNP && VariantContextUtils.isTransition(vc);
dataManager.parseTrainingSets( tracker, ref, context, vc, datum, TRUST_ALL_POLYMORPHIC, FIX_OMNI );
final double priorFactor = QualityUtils.qualToProb( datum.prior );
datum.usedForTraining = 0;
dataManager.parseTrainingSets( tracker, ref, context, vc, datum, TRUST_ALL_POLYMORPHIC );
double priorFactor = QualityUtils.qualToProb( datum.prior );
if( datum.consensusCount != 0 ) {
final double consensusPrior = QualityUtils.qualToProb( 1.0 + 5.0 * datum.consensusCount );
priorFactor = 1.0 - ((1.0 - priorFactor) * (1.0 - consensusPrior));
}
datum.prior = Math.log10( priorFactor ) - Math.log10( 1.0 - priorFactor );
mapList.add( datum );
}
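The consensus block above treats the site's existing prior and the 1000G consensus evidence as independent: the combined probability of being a true variant is p = 1 - (1 - p1)(1 - p2), stored as a log10 odds ratio. A worked sketch with made-up numbers (the qualToProb conversion p = 1 - 10^(-Q/10) is an assumption about QualityUtils):

// Q20 site prior and a site seen in 2 consensus tracks (Q = 1.0 + 5.0 * 2 = 11)
double priorFactor    = 1.0 - Math.pow(10.0, -20.0 / 10.0);         // 0.99
double consensusPrior = 1.0 - Math.pow(10.0, -11.0 / 10.0);         // ~0.921
// independent-evidence combination: wrong only if both sources are wrong
priorFactor = 1.0 - ((1.0 - priorFactor) * (1.0 - consensusPrior)); // ~0.99921
// stored as a log10 odds ratio, matching datum.prior above
double prior = Math.log10(priorFactor) - Math.log10(1.0 - priorFactor); // ~3.10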
@@ -201,15 +211,149 @@ public class ContrastiveRecalibrator extends RodWalker<ExpandingArrayList<Varian
public void onTraversalDone( final ExpandingArrayList<VariantDatum> reduceSum ) {
dataManager.setData( reduceSum );
dataManager.normalizeData();
engine.evaluateData( dataManager.getData(), engine.generateModel( dataManager.getTrainingData() ), false );
engine.evaluateData( dataManager.getData(), engine.generateModel( dataManager.selectWorstVariants( VRAC.PERCENT_BAD_VARIANTS ) ), true );
final GaussianMixtureModel goodModel = engine.generateModel( dataManager.getTrainingData() );
engine.evaluateData( dataManager.getData(), goodModel, false );
final GaussianMixtureModel badModel = engine.generateModel( dataManager.selectWorstVariants( VRAC.PERCENT_BAD_VARIANTS ) );
engine.evaluateData( dataManager.getData(), badModel, true );

final ExpandingArrayList<VariantDatum> randomData = dataManager.getRandomDataForPlotting( 6000 );

final int nCallsAtTruth = TrancheManager.countCallsAtTruth( dataManager.getData(), Double.NEGATIVE_INFINITY );
final TrancheManager.SelectionMetric metric = new TrancheManager.TruthSensitivityMetric( nCallsAtTruth );
final List<Tranche> tranches = TrancheManager.findTranches( dataManager.getData(), TS_TRANCHES, metric );
TRANCHES_FILE.print(Tranche.tranchesString( tranches ));

double lodCutoff = 0.0;
for(final Tranche tranche : tranches) {
if(MathUtils.compareDoubles(tranche.ts, TS_FILTER_LEVEL, 0.0001)==0) {
lodCutoff = tranche.minVQSLod;
}
}

logger.info( "Writing out recalibration table..." );
dataManager.writeOutRecalibrationTable( RECAL_FILE );
if( RSCRIPT_FILE != null ) {
logger.info( "Writing out visualization Rscript file...");
createVisualizationScript( randomData, goodModel, badModel, lodCutoff );
}

}
}

private void createVisualizationScript( final ExpandingArrayList<VariantDatum> randomData, final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, final double lodCutoff ) {
PrintStream stream;
try {
stream = new PrintStream(RSCRIPT_FILE);
} catch( FileNotFoundException e ) {
throw new UserException.CouldNotCreateOutputFile(RSCRIPT_FILE, "", e);
}
stream.println("library(ggplot2)");

createArrangeFunction( stream );

stream.println("pdf(\"" + RSCRIPT_FILE + ".pdf\")");

for(int iii = 0; iii < USE_ANNOTATIONS.length; iii++) {
for( int jjj = iii + 1; jjj < USE_ANNOTATIONS.length; jjj++) {
//stream.println("png(\"" + RSCRIPT_FILE + "." + USE_ANNOTATIONS[iii] + "." + USE_ANNOTATIONS[jjj] + ".png\", type=\"cairo\", width = 960, height = 960)");
//stream.println("pdf(\"" + RSCRIPT_FILE + "." + USE_ANNOTATIONS[iii] + "." + USE_ANNOTATIONS[jjj] + ".pdf\")");
logger.info( "Building " + USE_ANNOTATIONS[iii] + " x " + USE_ANNOTATIONS[jjj] + " plot...");

final ExpandingArrayList<VariantDatum> fakeData = new ExpandingArrayList<VariantDatum>();
double minAnn1 = 100.0, maxAnn1 = -100.0, minAnn2 = 100.0, maxAnn2 = -100.0;
for( final VariantDatum datum : randomData ) {
minAnn1 = Math.min(minAnn1, datum.annotations[iii]);
maxAnn1 = Math.max(maxAnn1, datum.annotations[iii]);
minAnn2 = Math.min(minAnn2, datum.annotations[jjj]);
maxAnn2 = Math.max(maxAnn2, datum.annotations[jjj]);
}
for(double ann1 = minAnn1; ann1 <= maxAnn1; ann1+=0.1) {
for(double ann2 = minAnn2; ann2 <= maxAnn2; ann2+=0.1) {
final VariantDatum datum = new VariantDatum();
datum.prior = 0.0;
datum.annotations = new double[randomData.get(0).annotations.length];
datum.isNull = new boolean[randomData.get(0).annotations.length];
for(int ann=0; ann< datum.annotations.length; ann++) {
datum.annotations[ann] = 0.0;
datum.isNull[ann] = true;
}
datum.annotations[iii] = ann1;
datum.annotations[jjj] = ann2;
datum.isNull[iii] = false;
datum.isNull[jjj] = false;
fakeData.add(datum);
}
}

engine.evaluateData( fakeData, goodModel, false );
engine.evaluateData( fakeData, badModel, true );

stream.print("surface <- c(");
for( final VariantDatum datum : fakeData ) {
stream.print(String.format("%.3f, %.3f, %.3f, ", datum.annotations[iii], datum.annotations[jjj], Math.min(4.0, Math.max(-4.0, datum.lod))));
}
stream.println("NA,NA,NA)");
stream.println("s <- matrix(surface,ncol=3,byrow=T)");

stream.print("data <- c(");
for( final VariantDatum datum : randomData ) {
stream.print(String.format("%.3f, %.3f, %.3f, %d, %d,", datum.annotations[iii], datum.annotations[jjj], (datum.lod < lodCutoff ? -1.0 : 1.0), datum.usedForTraining, (datum.isKnown ? 1 : -1)));
}
stream.println("NA,NA,NA,NA,NA)");
stream.println("d <- matrix(data,ncol=5,byrow=T)");

final String surfaceFrame = "sf." + USE_ANNOTATIONS[iii] + "." + USE_ANNOTATIONS[jjj];
final String dataFrame = "df." + USE_ANNOTATIONS[iii] + "." + USE_ANNOTATIONS[jjj];

stream.println(surfaceFrame + " <- data.frame(x=s[,1], y=s[,2], lod=s[,3])");
stream.println(dataFrame + " <- data.frame(x=d[,1], y=d[,2], retained=d[,3], training=d[,4], novelty=d[,5])");
stream.println("p <- ggplot(data=" + surfaceFrame + ", aes(x=x, y=y)) + opts(panel.background = theme_rect(colour = NA), panel.grid.minor = theme_line(colour = NA), panel.grid.major = theme_line(colour = NA))");
stream.println("p1 = p + opts(title=\"model PDF\") + labs(x=\""+ USE_ANNOTATIONS[iii] +"\", y=\""+ USE_ANNOTATIONS[jjj] +"\") + geom_tile(aes(fill = lod)) + scale_fill_gradient(high=\"green\", low=\"red\")");
stream.println("p <- ggplot(data=" + dataFrame + ", aes(x=x, y=y)) + opts(panel.background = theme_rect(colour = NA), panel.grid.minor = theme_line(colour = NA), panel.grid.major = theme_line(colour = NA))");
stream.println("p2 = p + labs(x=\""+ USE_ANNOTATIONS[iii] +"\", y=\""+ USE_ANNOTATIONS[jjj] +"\") + geom_point(data="+ dataFrame + ", aes(x=x, y=y, colour = retained, alpha=0.3, size=1.5)) + scale_colour_gradient(name=\"\", high=\"black\", low=\"red\",breaks=c(-1,1),labels=c(\"filtered\",\"retained\"))");
stream.println("p3 = p + labs(x=\""+ USE_ANNOTATIONS[iii] +"\", y=\""+ USE_ANNOTATIONS[jjj] +"\") + geom_point(data="+ dataFrame + "["+dataFrame+"$training==0,], aes(x=x, y=y, colour = training, alpha=0.3, size=1.5)) + geom_point(data="+ dataFrame + "["+dataFrame+"$training!=0,], aes(x=x, y=y, colour = training, alpha=0.3, size=1.5)) + scale_colour_gradient2(high=\"green\", mid=\"lightgrey\", low=\"purple\",breaks=c(-1,0,1), labels=c(\"bad\", \"\", \"good\"))");
stream.println("p4 = p + labs(x=\""+ USE_ANNOTATIONS[iii] +"\", y=\""+ USE_ANNOTATIONS[jjj] +"\") + geom_point(data="+ dataFrame + ", aes(x=x, y=y, colour = novelty, alpha=0.3, size=1.5)) + scale_colour_gradient(name=\"\", high=\"blue\", low=\"red\",breaks=c(-1,1), labels=c(\"novel\",\"known\"))");
stream.println("arrange(p1, p2, p3, p4, ncol=2)");
}
}
stream.println("dev.off()");

stream.close();

// Execute Rscript command to generate the clustering plots
final String rScriptTranchesCommandLine = PATH_TO_RSCRIPT + " " + RSCRIPT_FILE;
logger.info( "Executing: " + rScriptTranchesCommandLine );
try {
Process p;
p = Runtime.getRuntime().exec( rScriptTranchesCommandLine );
p.waitFor();
} catch ( Exception e ) {
Utils.warnUser("Unable to execute the RScript command. While not critical to the calculations themselves, the script outputs a report that is extremely useful for visualizing the recalibration results. We highly recommend trying to rerun the script manually if possible.");
}
}

// from http://gettinggeneticsdone.blogspot.com/2010/03/arrange-multiple-ggplot2-plots-in-same.html
private void createArrangeFunction( final PrintStream stream ) {
stream.println("vp.layout <- function(x, y) viewport(layout.pos.row=x, layout.pos.col=y)");
stream.println("arrange <- function(..., nrow=NULL, ncol=NULL, as.table=FALSE) {");
stream.println("dots <- list(...)");
stream.println("n <- length(dots)");
stream.println("if(is.null(nrow) & is.null(ncol)) { nrow = floor(n/2) ; ncol = ceiling(n/nrow)}");
stream.println("if(is.null(nrow)) { nrow = ceiling(n/ncol)}");
stream.println("if(is.null(ncol)) { ncol = ceiling(n/nrow)}");
stream.println("grid.newpage()");
stream.println("pushViewport(viewport(layout=grid.layout(nrow,ncol) ) )");
stream.println("ii.p <- 1");
stream.println("for(ii.row in seq(1, nrow)){");
stream.println("ii.table.row <- ii.row ");
stream.println("if(as.table) {ii.table.row <- nrow - ii.table.row + 1}");
stream.println("for(ii.col in seq(1, ncol)){");
stream.println("ii.table <- ii.p");
stream.println("if(ii.p > n) break");
stream.println("print(dots[[ii.table]], vp=vp.layout(ii.table.row, ii.col))");
stream.println("ii.p <- ii.p + 1");
stream.println("}");
stream.println("}");
stream.println("}");
}

}

@@ -1,11 +1,13 @@
package org.broadinstitute.sting.playground.gatk.walkers.variantrecalibration;

import Jama.Matrix;
import cern.jet.random.Normal;
import org.apache.log4j.Logger;
import org.broadinstitute.sting.gatk.GenomeAnalysisEngine;
import org.broadinstitute.sting.utils.MathUtils;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

/**
@@ -21,13 +23,13 @@ public class GaussianMixtureModel {
private final ArrayList<MultivariateGaussian> gaussians;
private final double shrinkage;
private final double dirichletParameter;
private final double degreesOfFreedom;
private final double priorCounts;
private final double[] empiricalMu;
private final Matrix empiricalSigma;
public boolean isModelReadyForEvaluation;

public GaussianMixtureModel( final int numGaussians, final int numAnnotations,
final double shrinkage, final double dirichletParameter ) {
final double shrinkage, final double dirichletParameter, final double priorCounts ) {

gaussians = new ArrayList<MultivariateGaussian>( numGaussians );
for( int iii = 0; iii < numGaussians; iii++ ) {
@@ -36,39 +38,15 @@ public class GaussianMixtureModel {
}
this.shrinkage = shrinkage;
this.dirichletParameter = dirichletParameter;
degreesOfFreedom = numAnnotations + 2;
this.priorCounts = priorCounts;
empiricalMu = new double[numAnnotations];
empiricalSigma = new Matrix(numAnnotations, numAnnotations);
isModelReadyForEvaluation = false;
}

public void cacheEmpiricalStats( final List<VariantDatum> data ) {
final double[][] tmpSigmaVals = new double[empiricalMu.length][empiricalMu.length];
for( int iii = 0; iii < empiricalMu.length; iii++ ) {
empiricalMu[iii] = 0.0;
for( int jjj = iii; jjj < empiricalMu.length; jjj++ ) {
tmpSigmaVals[iii][jjj] = 0.0;
}
}

for( final VariantDatum datum : data ) {
for( int iii = 0; iii < empiricalMu.length; iii++ ) {
empiricalMu[iii] += datum.annotations[iii] / ((double) data.size());
}
}

//for( final VariantDatum datum : data ) {
// for( int iii = 0; iii < empiricalMu.length; iii++ ) {
// for( int jjj = 0; jjj < empiricalMu.length; jjj++ ) {
// tmpSigmaVals[iii][jjj] += (datum.annotations[iii]-empiricalMu[iii]) * (datum.annotations[jjj]-empiricalMu[jjj]);
// }
// }
//}

//empiricalSigma.setMatrix(0, empiricalMu.length - 1, 0, empiricalMu.length - 1, new Matrix(tmpSigmaVals));
//empiricalSigma.timesEquals( 1.0 / ((double) data.size()) );
//empiricalSigma.timesEquals( 1.0 / (Math.pow(gaussians.size(), 2.0 / ((double) empiricalMu.length))) );
empiricalSigma.setMatrix(0, empiricalMu.length - 1, 0, empiricalMu.length - 1, Matrix.identity(empiricalMu.length, empiricalMu.length));
public void cacheEmpiricalStats() {
Arrays.fill(empiricalMu, 0.0);
empiricalSigma.setMatrix(0, empiricalMu.length - 1, 0, empiricalMu.length - 1, Matrix.identity(empiricalMu.length, empiricalMu.length).times(200.0).inverse());
}

public void initializeRandomModel( final List<VariantDatum> data, final int numKMeansIterations ) {
@@ -85,8 +63,9 @@ public class GaussianMixtureModel {
// initialize uniform mixture coefficients, random covariance matrices, and initial hyperparameters
for( final MultivariateGaussian gaussian : gaussians ) {
gaussian.pMixtureLog10 = Math.log10( 1.0 / ((double) gaussians.size()) );
gaussian.sumProb = 1.0 / ((double) gaussians.size());
gaussian.initializeRandomSigma( GenomeAnalysisEngine.getRandomGenerator() );
gaussian.hyperParameter_a = degreesOfFreedom;
gaussian.hyperParameter_a = priorCounts;
gaussian.hyperParameter_b = shrinkage;
gaussian.hyperParameter_lambda = dirichletParameter;
}
@@ -129,19 +108,17 @@ public class GaussianMixtureModel {
}
}

public double expectationStep( final List<VariantDatum> data ) {
public void expectationStep( final List<VariantDatum> data ) {

for( final MultivariateGaussian gaussian : gaussians ) {
gaussian.precomputeDenominatorForVariationalBayes( getSumHyperParameterLambda() );
}

double likelihood = 0.0;
for( final VariantDatum datum : data ) {
final ArrayList<Double> pVarInGaussianLog10 = new ArrayList<Double>( gaussians.size() );
for( final MultivariateGaussian gaussian : gaussians ) {
final double pVarLog10 = gaussian.evaluateDatumLog10( datum );
pVarInGaussianLog10.add( pVarLog10 );
likelihood += pVarLog10;
}
final double[] pVarInGaussianNormalized = MathUtils.normalizeFromLog10( pVarInGaussianLog10 );
int iii = 0;
@@ -149,15 +126,11 @@ public class GaussianMixtureModel {
gaussian.assignPVarInGaussian( pVarInGaussianNormalized[iii++] ); //BUGBUG: to clean up
}
}

final double scaledTotalLikelihoodLog10 = likelihood / ((double) data.size());
logger.info( "sum Log10 likelihood = " + String.format("%.5f", scaledTotalLikelihoodLog10) );
return scaledTotalLikelihoodLog10;
}

public void maximizationStep( final List<VariantDatum> data ) {
for( final MultivariateGaussian gaussian : gaussians ) {
gaussian.maximizeGaussian( data, empiricalMu, empiricalSigma, shrinkage, dirichletParameter, degreesOfFreedom );
gaussian.maximizeGaussian( data, empiricalMu, empiricalSigma, shrinkage, dirichletParameter, priorCounts);
}
}

@@ -171,28 +144,31 @@ public class GaussianMixtureModel {

public void evaluateFinalModelParameters( final List<VariantDatum> data ) {
for( final MultivariateGaussian gaussian : gaussians ) {
gaussian.evaluateFinalModelParameters( data );
gaussian.evaluateFinalModelParameters(data);
}
normalizePMixtureLog10();
}

private void normalizePMixtureLog10() {
public double normalizePMixtureLog10() {
double sumDiff = 0.0;
double sumPK = 0.0;
for( final MultivariateGaussian gaussian : gaussians ) {
sumPK += gaussian.pMixtureLog10;
sumPK += gaussian.sumProb;
}

int gaussianIndex = 0;
double[] pGaussianLog10 = new double[gaussians.size()];
for( final MultivariateGaussian gaussian : gaussians ) {
pGaussianLog10[gaussianIndex++] = Math.log10( gaussian.pMixtureLog10 / sumPK ); //BUGBUG: to clean up
pGaussianLog10[gaussianIndex++] = Math.log10( gaussian.sumProb / sumPK );
}
pGaussianLog10 = MathUtils.normalizeFromLog10( pGaussianLog10, true );

gaussianIndex = 0;
for( final MultivariateGaussian gaussian : gaussians ) {
sumDiff += Math.abs( pGaussianLog10[gaussianIndex] - gaussian.pMixtureLog10 );
gaussian.pMixtureLog10 = pGaussianLog10[gaussianIndex++];
}
return sumDiff;
}

public void precomputeDenominatorForEvaluation() {
@@ -204,6 +180,9 @@ public class GaussianMixtureModel {
}

public double evaluateDatum( final VariantDatum datum ) {
for( final boolean isNull : datum.isNull ) {
if( isNull ) { return evaluateDatumMarginalized( datum ); }
}
final double[] pVarInGaussianLog10 = new double[gaussians.size()];
int gaussianIndex = 0;
for( final MultivariateGaussian gaussian : gaussians ) {
@@ -211,4 +190,27 @@ public class GaussianMixtureModel {
}
return MathUtils.log10sumLog10(pVarInGaussianLog10);
}

public double evaluateDatumMarginalized( final VariantDatum datum ) {
int numVals = 0;
double sumPVarInGaussian = 0.0;
int numIter = 10;
final double[] pVarInGaussianLog10 = new double[gaussians.size()];
for( int iii = 0; iii < datum.annotations.length; iii++ ) {
if( datum.isNull[iii] ) {
for( int ttt = 0; ttt < numIter; ttt++ ) {
datum.annotations[iii] = Normal.staticNextDouble(0.0, 1.0);

int gaussianIndex = 0;
for( final MultivariateGaussian gaussian : gaussians ) {
pVarInGaussianLog10[gaussianIndex++] = gaussian.pMixtureLog10 + gaussian.evaluateDatumLog10( datum );
}

sumPVarInGaussian += Math.pow(10.0, MathUtils.log10sumLog10(pVarInGaussianLog10));
numVals++;
}
}
}
return Math.log10( sumPVarInGaussian / ((double) numVals) );
}
}

@@ -1,6 +1,7 @@
package org.broadinstitute.sting.playground.gatk.walkers.variantrecalibration;

import Jama.Matrix;
import org.apache.commons.math.special.Gamma;
import org.broadinstitute.sting.utils.MathUtils;
import org.broadinstitute.sting.utils.collections.ExpandingArrayList;

@@ -16,6 +17,7 @@ import java.util.Random;

public class MultivariateGaussian {
public double pMixtureLog10;
public double sumProb;
final public double[] mu;
final public Matrix sigma;
public double hyperParameter_a;
@@ -52,15 +54,15 @@ public class MultivariateGaussian {

public void initializeRandomSigma( final Random rand ) {
final double[][] randSigma = new double[mu.length][mu.length];
for( int iii = 0; iii < mu.length; iii++ ) {
for( int jjj = iii; jjj < mu.length; jjj++ ) {
randSigma[jjj][iii] = 0.55 + 1.25 * rand.nextDouble();
if( rand.nextBoolean() ) {
randSigma[jjj][iii] *= -1.0;
}
if( iii != jjj ) { randSigma[iii][jjj] = 0.0; } // Sigma is a symmetric, positive-definite matrix created by taking a lower diagonal matrix and multiplying it by its transpose
for( int iii = 0; iii < mu.length; iii++ ) {
for( int jjj = iii; jjj < mu.length; jjj++ ) {
randSigma[jjj][iii] = 0.55 + 1.25 * rand.nextDouble();
if( rand.nextBoolean() ) {
randSigma[jjj][iii] *= -1.0;
}
if( iii != jjj ) { randSigma[iii][jjj] = 0.0; } // Sigma is a symmetric, positive-definite matrix created by taking a lower diagonal matrix and multiplying it by its transpose
}
}
Matrix tmp = new Matrix( randSigma );
tmp = tmp.times(tmp.transpose());
sigma.setMatrix(0, mu.length - 1, 0, mu.length - 1, tmp);
@@ -95,15 +97,15 @@ public class MultivariateGaussian {
cachedSigmaInverse = sigma.inverse();
cachedSigmaInverse.timesEquals( hyperParameter_a );
double sum = 0.0;
for(int jjj = 1; jjj < mu.length; jjj++) {
sum += MathUtils.diGamma( (hyperParameter_a + 1.0 - jjj) / 2.0 );
for(int jjj = 1; jjj <= mu.length; jjj++) {
sum += Gamma.digamma( (hyperParameter_a + 1.0 - jjj) / 2.0 );
}
sum -= Math.log( sigma.det() );
sum += Math.log(2.0) * mu.length;
final double gamma = 0.5 * sum;
final double pi = MathUtils.diGamma( hyperParameter_lambda ) - MathUtils.diGamma( sumHyperParameterLambda );
final double lambda = 0.5 * sum;
final double pi = Gamma.digamma( hyperParameter_lambda ) - Gamma.digamma( sumHyperParameterLambda );
final double beta = (-1.0 * mu.length) / (2.0 * hyperParameter_b);
cachedDenomLog10 = (pi / Math.log(10.0)) + (gamma / Math.log(10.0)) + (beta / Math.log(10.0));
cachedDenomLog10 = (pi / Math.log(10.0)) + (lambda / Math.log(10.0)) + (beta / Math.log(10.0));
}

public double evaluateDatumLog10( final VariantDatum datum ) {
@@ -132,7 +134,7 @@ public class MultivariateGaussian {

public void maximizeGaussian( final List<VariantDatum> data, final double[] empiricalMu, final Matrix empiricalSigma,
final double SHRINKAGE, final double DIRICHLET_PARAMETER, final double DEGREES_OF_FREEDOM ) {
double sumProb = 0.0;
sumProb = 1E-10;
final Matrix wishart = new Matrix(mu.length, mu.length);
zeroOutMu();
zeroOutSigma();
@@ -171,8 +173,6 @@ public class MultivariateGaussian {
mu[iii] = (sumProb * mu[iii] + SHRINKAGE * empiricalMu[iii]) / (sumProb + SHRINKAGE);
}

pMixtureLog10 = sumProb; // will be normalized later by GaussianMixtureModel so no need to do it every iteration

hyperParameter_a = sumProb + DEGREES_OF_FREEDOM;
hyperParameter_b = sumProb + SHRINKAGE;
hyperParameter_lambda = sumProb + DIRICHLET_PARAMETER;
@@ -181,7 +181,7 @@ public class MultivariateGaussian {
}

public void evaluateFinalModelParameters( final List<VariantDatum> data ) {
double sumProb = 0.0;
sumProb = 0.0;
zeroOutMu();
zeroOutSigma();

@@ -206,7 +206,6 @@ public class MultivariateGaussian {
}
sigma.timesEquals( 1.0 / sumProb );

pMixtureLog10 = sumProb; // will be normalized later by GaussianMixtureModel so no need to do it here
resetPVarInGaussian(); // clean up some memory
}
}

@@ -15,7 +15,8 @@ public class TrainingSet {
public boolean isKnown = false;
public boolean isTraining = false;
public boolean isTruth = false;
public double prior = 3.0;
public boolean isConsensus = false;
public double prior = 0.0;

protected final static Logger logger = Logger.getLogger(TrainingSet.class);

@@ -25,8 +26,13 @@ public class TrainingSet {
isKnown = tags.containsKey("known") && tags.getValue("known").equals("true");
isTraining = tags.containsKey("training") && tags.getValue("training").equals("true");
isTruth = tags.containsKey("truth") && tags.getValue("truth").equals("true");
prior = ( tags.containsKey("known") ? Double.parseDouble(tags.getValue("prior")) : prior );
isConsensus = tags.containsKey("consensus") && tags.getValue("consensus").equals("true");
prior = ( tags.containsKey("prior") ? Double.parseDouble(tags.getValue("prior")) : prior );
}
if( !isConsensus ) {
logger.info( String.format( "Found %s track: \tKnown = %s \tTraining = %s \tTruth = %s \tPrior = Q%.1f", this.name, isKnown, isTraining, isTruth, prior) );
} else {
logger.info( String.format( "Found consensus track: %s", this.name) );
}
logger.info(String.format( "Found %s track: \tKnown = %s \tTraining = %s \tTruth = %s \tPrior = Q%.1f", this.name, isKnown, isTraining, isTruth, prior) );
}
}

@@ -6,9 +6,7 @@ import org.broad.tribble.util.variantcontext.VariantContext;
import org.broadinstitute.sting.gatk.GenomeAnalysisEngine;
import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
import org.broadinstitute.sting.gatk.contexts.variantcontext.VariantContextUtils;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.utils.GenomeLocParser;
import org.broadinstitute.sting.utils.MathUtils;
import org.broadinstitute.sting.utils.collections.ExpandingArrayList;
import org.broadinstitute.sting.utils.exceptions.UserException;
@@ -51,20 +49,39 @@ public class VariantDataManager {

public void normalizeData() {
boolean foundZeroVarianceAnnotation = false;
for( int jjj = 0; jjj < meanVector.length; jjj++ ) { //BUGBUG: to clean up
final double theMean = mean(jjj); //BUGBUG: to clean up
final double theSTD = standardDeviation(theMean, jjj); //BUGBUG: to clean up
logger.info( annotationKeys.get(jjj) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD) );
foundZeroVarianceAnnotation = foundZeroVarianceAnnotation || (theSTD < 1E-8);
meanVector[jjj] = theMean;
varianceVector[jjj] = theSTD;
for( int iii = 0; iii < meanVector.length; iii++ ) { //BUGBUG: to clean up
final double theMean = mean(iii); //BUGBUG: to clean up
final double theSTD = standardDeviation(theMean, iii); //BUGBUG: to clean up
logger.info( annotationKeys.get(iii) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD) );
foundZeroVarianceAnnotation = foundZeroVarianceAnnotation || (theSTD < 1E-6);
if( annotationKeys.get(iii).toLowerCase().contains("ranksum") ) { // BUGBUG: to clean up
for( final VariantDatum datum : data ) {
if( datum.annotations[iii] > 0.0 ) { datum.annotations[iii] /= 3.0; }
}
}
meanVector[iii] = theMean;
varianceVector[iii] = theSTD;
for( final VariantDatum datum : data ) {
datum.annotations[jjj] = ( datum.annotations[jjj] - theMean ) / theSTD; // Each data point is now [ (x - mean) / standard deviation ]
datum.annotations[iii] = ( datum.isNull[iii] ? Normal.staticNextDouble(0.0, 1.0) : ( datum.annotations[iii] - theMean ) / theSTD );
// Each data point is now [ (x - mean) / standard deviation ]
if( annotationKeys.get(iii).toLowerCase().contains("ranksum") && datum.isNull[iii] && datum.annotations[iii] > 0.0 ) {
datum.annotations[iii] /= 3.0;
}
}
}
if( foundZeroVarianceAnnotation ) {
throw new UserException.BadInput( "Found annotations with zero variance. They must be excluded before proceeding." );
}

// trim data by standard deviation threshold and place into two sets: data and failingData
for( final VariantDatum datum : data ) {
boolean remove = false;
for( final double val : datum.annotations ) {
remove = remove || (Math.abs(val) > VRAC.STD_THRESHOLD);
}
datum.failingSTDThreshold = remove;
datum.usedForTraining = 0;
}
}

public void addTrainingSet( final TrainingSet trainingSet ) {
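The rewritten normalizeData above z-scores each annotation using a mean and standard deviation computed only from training-site records with non-missing values, and replaces missing entries with standard-normal draws so they look like average observations after standardization. A simplified standalone sketch (illustrative names; the committed version also applies the ranksum tweak and the STD_THRESHOLD flagging shown above, and restricts the statistics to training-site records):

import java.util.Random;

// Hedged sketch: z-score one annotation column, filling missing values from N(0,1).
class NormalizeSketch {
    static void zScoreColumn(final double[] values, final boolean[] isMissing, final Random rand) {
        double sum = 0.0;
        int n = 0;
        for (int i = 0; i < values.length; i++) {
            if (!isMissing[i]) { sum += values[i]; n++; }
        }
        final double mean = sum / n;
        double ss = 0.0;
        for (int i = 0; i < values.length; i++) {
            if (!isMissing[i]) { ss += (values[i] - mean) * (values[i] - mean); }
        }
        final double std = Math.sqrt(ss / n);
        for (int i = 0; i < values.length; i++) {
            // missing values become standard-normal draws, i.e. uninformative
            // after standardization; observed values are centered and scaled
            values[i] = isMissing[i] ? rand.nextGaussian() : (values[i] - mean) / std;
        }
    }
}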
@@ -95,63 +112,85 @@ public class VariantDataManager {
public ExpandingArrayList<VariantDatum> getTrainingData() {
final ExpandingArrayList<VariantDatum> trainingData = new ExpandingArrayList<VariantDatum>();
for( final VariantDatum datum : data ) {
if( datum.atTrainingSite && datum.originalQual > VRAC.QUAL_THRESHOLD ) {
if( datum.atTrainingSite && !datum.failingSTDThreshold && datum.originalQual > VRAC.QUAL_THRESHOLD ) {
trainingData.add( datum );
datum.usedForTraining = 1;
}
}
trimDataBySTD( trainingData, VRAC.STD_THRESHOLD );
logger.info( "Training with " + trainingData.size() + " variants found in the training set(s)." );
logger.info( "Training with " + trainingData.size() + " variants after standard deviation thresholding." );
return trainingData;
}

public ExpandingArrayList<VariantDatum> selectWorstVariants( final double bottomPercentage ) {
Collections.sort( data );
final ExpandingArrayList<VariantDatum> trainingData = new ExpandingArrayList<VariantDatum>();
trainingData.addAll( data.subList(0, Math.round((float)bottomPercentage * data.size())) );
logger.info( "Training with worst " + (float)bottomPercentage * 100.0f + "% of data --> " + trainingData.size() + " variants with LOD <= " + String.format("%.4f", data.get(Math.round((float)bottomPercentage * data.size())).lod) + "." );
final int numToAdd = Math.round((float)bottomPercentage * data.size());
int index = 0;
int numAdded = 0;
while( numAdded < numToAdd ) {
final VariantDatum datum = data.get(index++);
if( !datum.failingSTDThreshold ) {
trainingData.add( datum );
datum.usedForTraining = -1;
numAdded++;
}
}
logger.info("Training with worst " + (float) bottomPercentage * 100.0f + "% of passing data --> " + trainingData.size() + " variants with LOD <= " + String.format("%.4f", data.get(index).lod) + ".");
return trainingData;
}

public ExpandingArrayList<VariantDatum> getRandomDataForPlotting( int numToAdd ) {
numToAdd = Math.min(numToAdd, data.size());
final ExpandingArrayList<VariantDatum> returnData = new ExpandingArrayList<VariantDatum>();
for( int iii = 0; iii < numToAdd; iii++) {
final VariantDatum datum = data.get(GenomeAnalysisEngine.getRandomGenerator().nextInt(data.size()));
if( !datum.failingSTDThreshold ) {
returnData.add(datum);
}
}
// add an extra 5% of points from bad training set, since that set is small but interesting
for( int iii = 0; iii < Math.floor(0.05*numToAdd); iii++) {
final VariantDatum datum = data.get(GenomeAnalysisEngine.getRandomGenerator().nextInt(data.size()));
if( datum.usedForTraining == -1 && !datum.failingSTDThreshold ) { returnData.add(datum); }
else { iii--; }
}

return returnData;
}

private double mean( final int index ) {
double sum = 0.0;
final int numVars = data.size();
int numNonNull = 0;
for( final VariantDatum datum : data ) {
sum += (datum.annotations[index] / ((double) numVars));
if( datum.atTrainingSite && !datum.isNull[index] ) { sum += datum.annotations[index]; numNonNull++; }
}
return sum;
return sum / ((double) numNonNull);
}

private double standardDeviation( final double mean, final int index ) {
double sum = 0.0;
final int numVars = data.size();
int numNonNull = 0;
for( final VariantDatum datum : data ) {
sum += ( ((datum.annotations[index] - mean)*(datum.annotations[index] - mean)) / ((double) numVars));
if( datum.atTrainingSite && !datum.isNull[index] ) { sum += ((datum.annotations[index] - mean)*(datum.annotations[index] - mean)); numNonNull++; }
}
return Math.sqrt( sum );
return Math.sqrt( sum / ((double) numNonNull) );
}

public static void trimDataBySTD( final ExpandingArrayList<VariantDatum> listData, final double STD_THRESHOLD ) {
final ExpandingArrayList<VariantDatum> dataToRemove = new ExpandingArrayList<VariantDatum>();
for( final VariantDatum datum : listData ) {
boolean remove = false;
for( final double val : datum.annotations ) {
remove = remove || (Math.abs(val) > STD_THRESHOLD);
}
if( remove ) { dataToRemove.add( datum ); }
}
listData.removeAll( dataToRemove );
}

public double[] decodeAnnotations( final GenomeLocParser genomeLocParser, final VariantContext vc, final boolean jitter ) {
public void decodeAnnotations( final VariantDatum datum, final VariantContext vc, final boolean jitter ) {
final double[] annotations = new double[annotationKeys.size()];
final boolean[] isNull = new boolean[annotationKeys.size()];
int iii = 0;
for( final String key : annotationKeys ) {
annotations[iii++] = decodeAnnotation( genomeLocParser, key, vc, jitter );
isNull[iii] = false;
annotations[iii] = decodeAnnotation( key, vc, jitter );
if( Double.isNaN(annotations[iii]) ) { isNull[iii] = true; }
iii++;
}
return annotations;
datum.annotations = annotations;
datum.isNull = isNull;
}

private static double decodeAnnotation( final GenomeLocParser genomeLocParser, final String annotationKey, final VariantContext vc, final boolean jitter ) {
private static double decodeAnnotation( final String annotationKey, final VariantContext vc, final boolean jitter ) {
double value;
if( jitter && annotationKey.equalsIgnoreCase("HRUN") ) { // HRun values must be jittered a bit to work in this GMM
value = Double.parseDouble( (String)vc.getAttribute( annotationKey ) );
@@ -161,30 +200,29 @@ public class VariantDataManager {
} else {
try {
value = Double.parseDouble( (String)vc.getAttribute( annotationKey ) );
if(Double.isNaN(value)) { throw new NumberFormatException(); }
if( annotationKey.toLowerCase().contains("ranksum") ) { //BUGBUG: temporary hack
if(MathUtils.compareDoubles(value, 0.0, 0.01) == 0) { value = Normal.staticNextDouble(2.0, 2.0); }
else if(MathUtils.compareDoubles(value, 200.0, 0.01) == 0) { value = Normal.staticNextDouble(162.0, 20.0); }
}
if(Double.isInfinite(value)) { value = Double.NaN; }
if(annotationKey.equals("HaplotypeScore") && MathUtils.compareDoubles(value, 0.0, 0.0001) == 0 ) { value = -0.2 + 0.4*GenomeAnalysisEngine.getRandomGenerator().nextDouble(); }
} catch( final Exception e ) {
throw new UserException.MalformedFile( vc.getSource(), "No double value detected for annotation = " + annotationKey + " in variant at " + VariantContextUtils.getLocation(genomeLocParser,vc) + ", reported annotation value = " + vc.getAttribute( annotationKey ), e );
value = Double.NaN; // The VQSR works with missing data now by marginalizing over the missing dimension when evaluating clusters.
}
}
return value;
}

public void parseTrainingSets( final RefMetaDataTracker tracker, final ReferenceContext ref, final AlignmentContext context, final VariantContext evalVC, final VariantDatum datum, final boolean TRUST_ALL_POLYMORPHIC, final boolean FIX_OMNI ) {
public void parseTrainingSets( final RefMetaDataTracker tracker, final ReferenceContext ref, final AlignmentContext context, final VariantContext evalVC, final VariantDatum datum, final boolean TRUST_ALL_POLYMORPHIC ) {
datum.isKnown = false;
datum.atTruthSite = false;
datum.atTrainingSite = false;
datum.prior = 2.0;
datum.consensusCount = 0;
for( final TrainingSet trainingSet : trainingSets ) {
for( final VariantContext trainVC : tracker.getVariantContexts( ref, trainingSet.name, null, context.getLocation(), false, false ) ) {
if( trainVC != null && trainVC.isVariant() && (trainVC.isNotFiltered() || (FIX_OMNI && trainVC.getFilters().size()==1 && trainVC.getFilters().contains("NOT_POLY_IN_1000G"))) && ((evalVC.isSNP() && trainVC.isSNP()) || (evalVC.isIndel() && trainVC.isIndel())) && (TRUST_ALL_POLYMORPHIC || !trainVC.hasGenotypes() || trainVC.isPolymorphic()) ) {
if( trainVC != null && trainVC.isNotFiltered() && trainVC.isVariant() && ((evalVC.isSNP() && trainVC.isSNP()) || (evalVC.isIndel() && trainVC.isIndel())) && (TRUST_ALL_POLYMORPHIC || !trainVC.hasGenotypes() || trainVC.isPolymorphic()) ) {
datum.isKnown = datum.isKnown || trainingSet.isKnown;
datum.atTruthSite = datum.atTruthSite || trainingSet.isTruth;
datum.atTrainingSite = datum.atTrainingSite || trainingSet.isTraining;
datum.prior = Math.max( datum.prior, trainingSet.prior );
datum.consensusCount += ( trainingSet.isConsensus ? 1 : 0 );
}
}
}

@@ -11,15 +11,19 @@ import org.broadinstitute.sting.utils.GenomeLoc;
public class VariantDatum implements Comparable<VariantDatum> {

public double[] annotations;
public boolean[] isNull;
public boolean isKnown;
public double lod;
public boolean atTruthSite;
public boolean atTrainingSite;
public boolean isTransition;
public boolean isSNP;
public boolean failingSTDThreshold;
public double originalQual;
public double prior;
public int consensusCount;
public GenomeLoc pos;
public int usedForTraining;
public MultivariateGaussian assignment; // used in K-means implementation

public int compareTo( final VariantDatum other ) {

@@ -19,19 +19,21 @@ public class VariantRecalibratorArgumentCollection {
@Argument(fullName = "mode", shortName = "mode", doc = "Recalibration mode to employ: 1.) SNP for recalibrating only SNPs (emitting indels untouched in the output VCF); 2.) INDEL for indels; and 3.) BOTH for recalibrating both SNPs and indels simultaneously.", required = false)
public VariantRecalibratorArgumentCollection.Mode MODE = VariantRecalibratorArgumentCollection.Mode.SNP;
@Argument(fullName="maxGaussians", shortName="mG", doc="The maximum number of Gaussians to try during variational Bayes algorithm", required=false)
public int MAX_GAUSSIANS = 32;
public int MAX_GAUSSIANS = 10;
@Argument(fullName="maxIterations", shortName="mI", doc="The maximum number of VBEM iterations to be performed in variational Bayes algorithm. Procedure will normally end when convergence is detected.", required=false)
public int MAX_ITERATIONS = 100;
@Argument(fullName="numKMeans", shortName="nKM", doc="The number of k-means iterations to perform in order to initialize the means of the Gaussians in the Gaussian mixture model.", required=false)
public int NUM_KMEANS_ITERATIONS = 10;
public int NUM_KMEANS_ITERATIONS = 30;
@Argument(fullName="stdThreshold", shortName="std", doc="If a variant has annotations more than -std standard deviations away from mean then don't use it for building the Gaussian mixture model.", required=false)
public double STD_THRESHOLD = 4.5;
public double STD_THRESHOLD = 8.0;
@Argument(fullName="qualThreshold", shortName="qual", doc="If a known variant has raw QUAL value less than -qual then don't use it for building the Gaussian mixture model.", required=false)
public double QUAL_THRESHOLD = 80.0;
@Argument(fullName="shrinkage", shortName="shrinkage", doc="The shrinkage parameter in variational Bayes algorithm.", required=false)
public double SHRINKAGE = 0.0001;
public double SHRINKAGE = 1.0;
@Argument(fullName="dirichlet", shortName="dirichlet", doc="The dirichlet parameter in variational Bayes algorithm.", required=false)
public double DIRICHLET_PARAMETER = 0.0001;
public double DIRICHLET_PARAMETER = 0.001;
@Argument(fullName="priorCounts", shortName="priorCounts", doc="The number of prior counts to use in variational Bayes algorithm.", required=false)
public double PRIOR_COUNTS = 20.0;
@Argument(fullName="percentBadVariants", shortName="percentBad", doc="What percentage of the worst scoring variants to use when building the Gaussian mixture model of bad variants. 0.07 means bottom 7 percent.", required=false)
public double PERCENT_BAD_VARIANTS = 0.07;
public double PERCENT_BAD_VARIANTS = 0.015;
}

@@ -21,7 +21,7 @@ public class VariantRecalibratorEngine {
// the unified argument collection
final private VariantRecalibratorArgumentCollection VRAC;

private final static double MIN_PROB_CONVERGENCE_LOG10 = 1.0;
private final static double MIN_PROB_CONVERGENCE = 2E-2;

/////////////////////////////
// Public Methods to interface with the Engine
@@ -33,7 +33,7 @@ public class VariantRecalibratorEngine {
}

public GaussianMixtureModel generateModel( final List<VariantDatum> data ) {
final GaussianMixtureModel model = new GaussianMixtureModel( VRAC.MAX_GAUSSIANS, data.get(0).annotations.length, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER );
final GaussianMixtureModel model = new GaussianMixtureModel( VRAC.MAX_GAUSSIANS, data.get(0).annotations.length, VRAC.SHRINKAGE, VRAC.DIRICHLET_PARAMETER, VRAC.PRIOR_COUNTS );
variationalBayesExpectationMaximization( model, data );
return model;
}
@@ -63,25 +63,26 @@ public class VariantRecalibratorEngine {

private void variationalBayesExpectationMaximization( final GaussianMixtureModel model, final List<VariantDatum> data ) {

model.cacheEmpiricalStats( data );
model.cacheEmpiricalStats();
model.initializeRandomModel( data, VRAC.NUM_KMEANS_ITERATIONS );

// The VBEM loop
double previousLikelihood = model.expectationStep( data );
model.normalizePMixtureLog10();
model.expectationStep( data );
double currentLikelihood;
int iteration = 0;
logger.info("Finished iteration " + iteration );
while( iteration < VRAC.MAX_ITERATIONS ) {
iteration++;
model.maximizationStep( data );
currentLikelihood = model.expectationStep( data );

currentLikelihood = model.normalizePMixtureLog10();
model.expectationStep( data );
logger.info("Current change in mixture coefficients = " + String.format("%.5f", currentLikelihood));
logger.info("Finished iteration " + iteration );
if( Math.abs(currentLikelihood - previousLikelihood) < MIN_PROB_CONVERGENCE_LOG10 ) {
if( iteration > 2 && currentLikelihood < MIN_PROB_CONVERGENCE ) {
logger.info("Convergence!");
break;
}
previousLikelihood = currentLikelihood;
}

model.evaluateFinalModelParameters( data );

@@ -420,6 +420,14 @@ public class MathUtils {
return dist;
}

public static double round(double num, int digits) {
double result = num * Math.pow(10.0, (double)digits);
result = Math.round(result);
result = result / Math.pow(10.0, (double)digits);
return result;
}


/**
* normalizes the log10-based array. ASSUMES THAT ALL ARRAY ENTRIES ARE <= 0 (<= 1 IN REAL-SPACE).
*
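The new MathUtils.round rounds to a fixed number of decimal digits by scaling, rounding, and unscaling; for example:

MathUtils.round(3.14159, 2); // == 3.14
MathUtils.round(2.5, 0);     // == 3.0, since Math.round rounds half up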
@@ -717,16 +725,6 @@ public class MathUtils {
return ans;
}

// lifted from the internet
// http://www.cs.princeton.edu/introcs/91float/Gamma.java.html
public static double logGamma(double x) {
double tmp = (x - 0.5) * Math.log(x + 4.5) - (x + 4.5);
double ser = 1.0 + 76.18009173 / (x + 0) - 86.50532033 / (x + 1)
+ 24.01409822 / (x + 2) - 1.231739516 / (x + 3)
+ 0.00120858003 / (x + 4) - 0.00000536382 / (x + 5);
return tmp + Math.log(ser * Math.sqrt(2 * Math.PI));
}

public static double percentage(double x, double base) {
return (base > 0 ? (x / base) * 100.0 : 0);
}
@@ -892,15 +890,6 @@ public class MathUtils {
return getQScoreOrderStatistic(reads, offsets, (int)Math.floor(reads.size()/2.));
}

// from http://en.wikipedia.org/wiki/Digamma_function
// According to J.M. Bernardo AS 103 algorithm the digamma function for x, a real number, can be approximated by:
public static double diGamma(final double x) {
return Math.log(x) - ( 1.0 / (2.0 * x) )
- ( 1.0 / (12.0 * Math.pow(x, 2.0)) )
+ ( 1.0 / (120.0 * Math.pow(x, 4.0)) )
- ( 1.0 / (252.0 * Math.pow(x, 6.0)) );
}

/** A utility class that computes on the fly average and standard deviation for a stream of numbers.
* The number of observations does not have to be known in advance, and can be also very big (so that
* it could overflow any naive summation-based scheme or cause loss of precision).

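The diGamma series removed above (and logGamma in the previous hunk) are replaced by Apache Commons Math, which the MultivariateGaussian changes already import as org.apache.commons.math.special.Gamma. A small self-contained comparison, assuming commons-math is on the classpath:

import org.apache.commons.math.special.Gamma;

// Hedged check: the removed asymptotic series vs. the library digamma.
class DigammaCheck {
    static double diGammaApprox(final double x) { // the removed approximation
        return Math.log(x) - (1.0 / (2.0 * x))
                - (1.0 / (12.0 * Math.pow(x, 2.0)))
                + (1.0 / (120.0 * Math.pow(x, 4.0)))
                - (1.0 / (252.0 * Math.pow(x, 6.0)));
    }

    public static void main(final String[] args) {
        // the asymptotic series degrades for small x, which is one reason to
        // prefer the library routine the commit switches to
        for (final double x : new double[]{0.5, 1.0, 5.0, 50.0}) {
            System.out.printf("x=%5.1f approx=%.10f commons=%.10f%n", x, diGammaApprox(x), Gamma.digamma(x));
        }
    }
}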
@@ -24,33 +24,30 @@ public class VariantRecalibrationWalkersV2IntegrationTest extends WalkerTest {
}
}

VRTest yriTrio = new VRTest("yri.trio.gatk_glftrio.intersection.annotated.filtered.chr1.vcf",
"74127253b3e551ae602b5957eef8e37e", // tranches
"38fe940269ce56449b2e4e3493f89cd9", // recal file
"067197ba4212a38c7fd736c1f66b294d"); // cut VCF

VRTest lowPass = new VRTest("lowpass.N3.chr1.raw.vcf",
"40242ee5ce97ffde50cc7a12cfa0d841", // tranches
"eb8df1fbf3d71095bdb3a52713dd2a64", // recal file
"13aa221241a974883072bfad14ad8afc"); // cut VCF
VRTest lowPass = new VRTest("phase1.projectConsensus.chr20.raw.snps.vcf",
"920b12d7765eb4f6f4a1bab045679b31", // tranches
"41bbc5f07c8a9573d5bb638f01808bba", // recal file
"2b8c5e884bf5a739b782f5b3bf17f19c"); // cut VCF

@DataProvider(name = "VRTest")
public Object[][] createData1() {
return new Object[][]{ {yriTrio}, {lowPass} };
return new Object[][]{ {lowPass} };
//return new Object[][]{ {yriTrio}, {lowPass} }; // Add hg19 chr20 trio calls here
}

@Test(dataProvider = "VRTest")
public void testVariantRecalibrator(VRTest params) {
//System.out.printf("PARAMS FOR %s is %s%n", vcf, clusterFile);
WalkerTest.WalkerTestSpec spec = new WalkerTest.WalkerTestSpec(
"-R " + b36KGReference +
" -B:dbsnp,DBSNP,known=true,training=false,truth=false,prior=10.0 " + GATKDataLocation + "dbsnp_129_b36.rod" +
" -B:hapmap,VCF,known=false,training=true,truth=true,prior=15.0 " + comparisonDataLocation + "Validated/HapMap/3.2/sites_r27_nr.b36_fwd.vcf" +
" -B:omni,VCF,known=false,training=true,truth=true,prior=12.0 " + comparisonDataLocation + "Validated/Omni2.5_chip/1212samples.b36.vcf" +
"-R " + b37KGReference +
" -B:dbsnp,VCF,known=true,training=false,truth=false,prior=10.0 " + GATKDataLocation + "dbsnp_132_b37.leftAligned.vcf" +
" -B:hapmap,VCF,known=false,training=true,truth=true,prior=15.0 " + comparisonDataLocation + "Validated/HapMap/3.3/sites_r27_nr.b37_fwd.vcf" +
" -B:omni,VCF,known=false,training=true,truth=true,prior=12.0 " + comparisonDataLocation + "Validated/Omni2.5_chip/Omni25_sites_1525_samples.b37.vcf" +
" -T ContrastiveRecalibrator" +
" -B:input,VCF " + params.inVCF +
" -L 1:50,000,000-120,000,000" +
" -an QD -an MQ -an SB" +
" -L 20:1,000,000-40,000,000" +
" -an QD -an HaplotypeScore -an HRun" +
" -percentBad 0.07" +
" --trustAllPolymorphic" + // for speed
" -recalFile %s" +
" -tranchesFile %s",
@@ -61,9 +58,9 @@ public class VariantRecalibrationWalkersV2IntegrationTest extends WalkerTest {
@Test(dataProvider = "VRTest",dependsOnMethods="testVariantRecalibrator")
public void testApplyRecalibration(VRTest params) {
WalkerTest.WalkerTestSpec spec = new WalkerTest.WalkerTestSpec(
"-R " + b36KGReference +
"-R " + b37KGReference +
" -T ApplyRecalibration" +
" -L 1:60,000,000-115,000,000" +
" -L 20:12,000,000-30,000,000" +
" -NO_HEADER" +
" -B:input,VCF " + params.inVCF +
" -o %s" +

@@ -96,8 +96,8 @@ class MethodsDevelopmentCallingPipeline extends QScript {
val hapmap_b36 = "/humgen/gsa-hpprojects/GATK/data/Comparisons/Validated/HapMap/3.3/sites_r27_nr.b36_fwd.vcf"
val hapmap_b37 = "/humgen/gsa-hpprojects/GATK/data/Comparisons/Validated/HapMap/3.3/sites_r27_nr.b37_fwd.vcf"
val training_hapmap_b37 = "/humgen/1kg/processing/pipeline_test_bams/hapmap3.3_training_chr20.vcf"
val omni_b36 = "/humgen/gsa-hpprojects/GATK/data/Comparisons/Validated/Omni2.5_chip/1212samples.b36.vcf"
val omni_b37 = "/humgen/gsa-hpprojects/GATK/data/Comparisons/Validated/Omni2.5_chip/1212samples.b37.vcf"
val omni_b36 = "/humgen/gsa-hpprojects/GATK/data/Comparisons/Validated/Omni2.5_chip/Omni25_sites_1525_samples.b36.vcf"
val omni_b37 = "/humgen/gsa-hpprojects/GATK/data/Comparisons/Validated/Omni2.5_chip/Omni25_sites_1525_samples.b37.vcf"
val indelMask_b36 = "/humgen/1kg/processing/pipeline_test_bams/pilot1.dindel.mask.b36.bed"
val indelMask_b37 = "/humgen/1kg/processing/pipeline_test_bams/pilot1.dindel.mask.b37.bed"

@@ -105,7 +105,7 @@ class MethodsDevelopmentCallingPipeline extends QScript {
val indels: Boolean = true

val queueLogDir = ".qlog/"


val targetDataSets: Map[String, Target] = Map(
"HiSeq" -> new Target("NA12878.HiSeq", hg18, dbSNP_hg18_129, hapmap_hg18,
"/humgen/gsa-hpprojects/dev/depristo/oneOffProjects/1000GenomesProcessingPaper/wgs.v13/HiSeq.WGS.cleaned.indels.10.mask",
@@ -115,7 +115,7 @@ class MethodsDevelopmentCallingPipeline extends QScript {
"HiSeq19" -> new Target("NA12878.HiSeq19", hg19, dbSNP_b37_129, hapmap_b37, indelMask_b37,
new File("/humgen/gsa-hpprojects/NA12878Collection/bams/NA12878.HiSeq.WGS.bwa.cleaned.recal.hg19.bam"),
new File("/humgen/gsa-hpprojects/dev/carneiro/hiseq19/analysis/snps/NA12878.HiSeq19.filtered.vcf"),
"/humgen/1kg/processing/pipeline_test_bams/whole_genome_chunked.hg19.intervals", 2.3, 98.0, !lowPass),
"/humgen/1kg/processing/pipeline_test_bams/whole_genome_chunked.hg19.intervals", 2.3, 99.0, !lowPass),
"GA2hg19" -> new Target("NA12878.GA2.hg19", hg19, dbSNP_b37_129, hapmap_b37, indelMask_b37,
new File("/humgen/gsa-hpprojects/NA12878Collection/bams/NA12878.GA2.WGS.bwa.cleaned.hg19.bam"),
new File("/humgen/gsa-hpprojects/dev/carneiro/hiseq19/analysis/snps/NA12878.GA2.hg19.filtered.vcf"),
@@ -251,17 +251,17 @@ class MethodsDevelopmentCallingPipeline extends QScript {
this.rodBind :+= RodBind("input", "VCF", if ( goldStandard ) { t.goldStandard_VCF } else { t.rawVCF } )
this.rodBind :+= RodBind("hapmap", "VCF", t.hapmapFile, "known=false,training=true,truth=true,prior=15.0")
if( t.hapmapFile.contains("b37") )
this.rodBind :+= RodBind("omni", "VCF", omni_b37, "known=false,training=true,truth=false,prior=12.0")
this.rodBind :+= RodBind("omni", "VCF", omni_b37, "known=false,training=true,truth=true,prior=12.0")
else if( t.hapmapFile.contains("b36") )
this.rodBind :+= RodBind("omni", "VCF", omni_b36, "known=false,training=true,truth=false,prior=12.0")
this.rodBind :+= RodBind("omni", "VCF", omni_b36, "known=false,training=true,truth=true,prior=12.0")
if (t.dbsnpFile.endsWith(".rod"))
this.rodBind :+= RodBind("dbsnp", "DBSNP", t.dbsnpFile, "known=true,training=false,truth=false,prior=10.0")
else if (t.dbsnpFile.endsWith(".vcf"))
this.rodBind :+= RodBind("dbsnp", "VCF", t.dbsnpFile, "known=true,training=false,truth=false,prior=10.0")
this.use_annotation ++= List("QD", "SB", "HaplotypeScore", "HRun")
this.use_annotation ++= List("QD", "HaplotypeScore", "MQRankSum", "ReadPosRankSum", "HRun")
this.tranches_file = if ( goldStandard ) { t.goldStandardTranchesFile } else { t.tranchesFile }
this.recal_file = if ( goldStandard ) { t.goldStandardRecalFile } else { t.recalFile }
this.fixOmni = true // temporary argument until new Omni file is released
this.allPoly = true
this.tranche ++= List("100.0", "99.9", "99.5", "99.3", "99.0", "98.9", "98.8", "98.5", "98.4", "98.3", "98.2", "98.1", "98.0", "97.9", "97.8", "97.5", "97.0", "95.0", "90.0")
this.analysisName = t.name + "_VQSR"
this.jobName = queueLogDir + t.name + ".VQSR"