From 9d01670f62bee7cfbac2efbb1b3445d65a9f161d Mon Sep 17 00:00:00 2001 From: rpoplin Date: Sun, 2 May 2010 19:21:23 +0000 Subject: [PATCH] Major update to the Variant Optimizer. It now performs clustering for both the titv and titv-less models simultaneously, outputting the cluster files at every iteration. It makes use of the Jama matrix library to do full inverse and determinant calculation for the covariance matrix where before it was using only approximations. git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@3286 348d0f76-0448-11de-a6fe-93d51630548a --- .../ApplyVariantClustersWalker.java | 13 +- .../VariantClusteringModel.java | 2 +- .../variantoptimizer/VariantDataManager.java | 15 +- .../VariantGaussianMixtureModel.java | 896 +++++++++--------- .../variantoptimizer/VariantOptimizer.java | 37 +- 5 files changed, 493 insertions(+), 470 deletions(-) diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/ApplyVariantClustersWalker.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/ApplyVariantClustersWalker.java index f9b8780af..67e11bbe8 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/ApplyVariantClustersWalker.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/ApplyVariantClustersWalker.java @@ -159,12 +159,17 @@ public class ApplyVariantClustersWalker extends RodWalker vc.getHomVarCount(); // BUGBUG: what to do here for multi sample calls? - final double pTrue = theModel.evaluateVariant( vc.getAttributes(), vc.getPhredScaledQual(), variantDatum.isHet ); - final double recalQual = QualityUtils.phredScaleErrorRate( Math.max(1.0 - pTrue, 0.000000001) ); + final double pTrue = theModel.evaluateVariant( vc.getAttributes(), vc.getPhredScaledQual() ); + double recalQual = QualityUtils.phredScaleErrorRate( Math.max(1.0 - pTrue, 0.000000001) ); + if( !theModel.isUsingTiTvModel ) { + recalQual *= 30.0; + } else { + recalQual *= 3.0; + } + // BUGBUG: decide how to scale the quality score - if( variantDatum.isKnown && KNOWN_VAR_QUAL_PRIOR > 0.1 ) { + if( variantDatum.isKnown && KNOWN_VAR_QUAL_PRIOR > 0.1 ) { // only use the known prior if the value is specified (meaning not equal to zero) variantDatum.qual = 0.5 * recalQual + 0.5 * KNOWN_VAR_QUAL_PRIOR; } else { variantDatum.qual = recalQual; diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantClusteringModel.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantClusteringModel.java index 891f21443..103f29392 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantClusteringModel.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantClusteringModel.java @@ -32,6 +32,6 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantoptimizer; */ public interface VariantClusteringModel extends VariantOptimizationInterface { - public void createClusters( final VariantDatum[] data, final int startCluster, final int stopCluster ); + public void createClusters( final VariantDatum[] data, final int startCluster, final int stopCluster, final String clusterFilename ); //public void applyClusters( final VariantDatum[] data, final String outputPrefix ); } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDataManager.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDataManager.java index e67bc09fb..1eff46b9b 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDataManager.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDataManager.java @@ -25,6 +25,7 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantoptimizer; +import org.apache.log4j.Logger; import org.broadinstitute.sting.utils.collections.ExpandingArrayList; import org.broadinstitute.sting.utils.StingException; @@ -37,6 +38,9 @@ import java.io.PrintStream; */ public class VariantDataManager { + + protected final static Logger logger = Logger.getLogger(VariantDataManager.class); + public final VariantDatum[] data; public final int numVariants; public final int numAnnotations; @@ -56,7 +60,7 @@ public class VariantDataManager { meanVector = null; varianceVector = null; } else { - numAnnotations = _annotationKeys.size() + 1; // +1 for QUAL + numAnnotations = _annotationKeys.size(); if( numAnnotations <= 0 ) { throw new StingException( "There are zero annotations! (or possibly a problem with integer overflow)" ); } @@ -91,11 +95,12 @@ public class VariantDataManager { for( int jjj = 0; jjj < numAnnotations; jjj++ ) { final double theMean = mean(data, jjj); final double theSTD = standardDeviation(data, theMean, jjj); - System.out.println( (jjj == numAnnotations-1 ? "QUAL" : annotationKeys.get(jjj)) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD) ); + logger.info( annotationKeys.get(jjj) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD) ); if( theSTD < 1E-8 ) { foundZeroVarianceAnnotation = true; - System.out.println("Zero variance is a problem: standard deviation = " + theSTD); - System.out.println("User must -exclude annotations with zero variance. Annotation = " + (jjj == numAnnotations-1 ? "QUAL" : annotationKeys.get(jjj))); + logger.warn("Zero variance is a problem: standard deviation = " + theSTD + " User must -exclude annotations with zero variance. Annotation = " + (jjj == numAnnotations-1 ? "QUAL" : annotationKeys.get(jjj))); + } else if( theSTD < 1E-2 ) { + logger.warn("Warning! Tiny variance. It is strongly recommended that you -exclude " + annotationKeys.get(jjj)); } meanVector[jjj] = theMean; varianceVector[jjj] = theSTD; @@ -129,7 +134,7 @@ public class VariantDataManager { public void printClusterFileHeader( PrintStream outputFile ) { for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - outputFile.println("@!ANNOTATION," + (jjj == numAnnotations-1 ? "QUAL" : annotationKeys.get(jjj)) + "," + meanVector[jjj] + "," + varianceVector[jjj]); + outputFile.println("@!ANNOTATION," + annotationKeys.get(jjj) + "," + meanVector[jjj] + "," + varianceVector[jjj]); } } } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantGaussianMixtureModel.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantGaussianMixtureModel.java index 1ef1f1601..2b11b62d1 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantGaussianMixtureModel.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantGaussianMixtureModel.java @@ -25,10 +25,13 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantoptimizer; +import org.apache.log4j.Logger; import org.broadinstitute.sting.utils.collections.ExpandingArrayList; import org.broadinstitute.sting.utils.StingException; import org.broadinstitute.sting.utils.text.XReadLines; +import Jama.*; + import java.io.File; import java.io.FileNotFoundException; import java.io.PrintStream; @@ -42,47 +45,50 @@ import java.util.regex.Pattern; * Date: Feb 26, 2010 */ -public final class VariantGaussianMixtureModel extends VariantOptimizationModel implements VariantClusteringModel { +public final class VariantGaussianMixtureModel extends VariantOptimizationModel { + protected final static Logger logger = Logger.getLogger(VariantGaussianMixtureModel.class); + public final VariantDataManager dataManager; private final int numGaussians; private final int numIterations; private final long RANDOM_SEED = 91801305; private final Random rand = new Random( RANDOM_SEED ); - private final double MIN_PROB = 1E-30; - private final double MIN_SUM_PROB = 1E-20; + private final double MIN_PROB = 1E-7; + private final double MIN_SIGMA = 1E-5; + private final double MIN_DETERMINANT = 1E-5; private final double[][] mu; // The means for the clusters - private final double[][] sigma; // The variances for the clusters, sigma is really sigma^2 + private final Matrix[] sigma; // The variances for the clusters, sigma is really sigma^2 + private final Matrix[] sigmaInverse; + private final boolean[] deadCluster; private final double[] pCluster; - //private final boolean[] isHetCluster; - private final int[] numMaxClusterKnown; - private final int[] numMaxClusterNovel; + private final double[] determinant; private final double[] clusterTITV; private final double[] clusterTruePositiveRate; // The true positive rate implied by the cluster's Ti/Tv ratio private final int minVarInCluster; - private final double knownAlphaFactor; + public final boolean isUsingTiTvModel; - private static final double INFINITE_ANNOTATION_VALUE = 10000.0; + private static final double INFINITE_ANNOTATION_VALUE = 6000.0; private static final Pattern ANNOTATION_PATTERN = Pattern.compile("^@!ANNOTATION.*"); private static final Pattern CLUSTER_PATTERN = Pattern.compile("^@!CLUSTER.*"); - public VariantGaussianMixtureModel( final VariantDataManager _dataManager, final double _targetTITV, final int _numGaussians, final int _numIterations, final int _minVarInCluster, final double _knownAlphaFactor ) { + public VariantGaussianMixtureModel( final VariantDataManager _dataManager, final double _targetTITV, final int _numGaussians, final int _numIterations, final int _minVarInCluster ) { super( _targetTITV ); dataManager = _dataManager; - numGaussians = ( _numGaussians % 2 == 0 ? _numGaussians : _numGaussians + 1 ); + numGaussians = _numGaussians; numIterations = _numIterations; mu = new double[numGaussians][]; - sigma = new double[numGaussians][]; + sigma = new Matrix[numGaussians]; + determinant = new double[numGaussians]; + deadCluster = new boolean[numGaussians]; pCluster = new double[numGaussians]; - //isHetCluster = null; - numMaxClusterKnown = new int[numGaussians]; - numMaxClusterNovel = new int[numGaussians]; clusterTITV = new double[numGaussians]; clusterTruePositiveRate = new double[numGaussians]; minVarInCluster = _minVarInCluster; - knownAlphaFactor = _knownAlphaFactor; + sigmaInverse = null; + isUsingTiTvModel = false; // this field isn't used during VariantOptimizerWalker } public VariantGaussianMixtureModel( final double _targetTITV, final String clusterFileName, final double backOffGaussianFactor ) { @@ -105,128 +111,114 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } dataManager = new VariantDataManager( annotationLines ); + // Several of the clustering parameters aren't used the second time around in ApplyVariantClusters numIterations = 0; - numMaxClusterKnown = null; - numMaxClusterNovel = null; clusterTITV = null; + deadCluster = null; minVarInCluster = 0; - knownAlphaFactor = 0.0; - //BUGBUG: move this parsing out of the constructor + // BUGBUG: move this parsing out of the constructor numGaussians = clusterLines.size(); mu = new double[numGaussians][dataManager.numAnnotations]; - sigma = new double[numGaussians][dataManager.numAnnotations]; + double sigmaVals[][][] = new double[numGaussians][dataManager.numAnnotations][dataManager.numAnnotations]; + sigma = new Matrix[numGaussians]; + sigmaInverse = new Matrix[numGaussians]; pCluster = new double[numGaussians]; - //isHetCluster = new boolean[numGaussians]; + determinant = new double[numGaussians]; clusterTruePositiveRate = new double[numGaussians]; + boolean _isUsingTiTvModel = false; int kkk = 0; for( String line : clusterLines ) { final String[] vals = line.split(","); - //isHetCluster[kkk] = Integer.parseInt(vals[1]) == 1; - pCluster[kkk] = Double.parseDouble(vals[2]); - clusterTruePositiveRate[kkk] = Double.parseDouble(vals[6]); //BUGBUG: #define these magic index numbers, very easy to make a mistake here + pCluster[kkk] = Double.parseDouble(vals[1]); + clusterTruePositiveRate[kkk] = Double.parseDouble(vals[3]); // BUGBUG: #define these magic index numbers, very easy to make a mistake here + if( clusterTruePositiveRate[kkk] != 1.0 ) { _isUsingTiTvModel = true; } for( int jjj = 0; jjj < dataManager.numAnnotations; jjj++ ) { - mu[kkk][jjj] = Double.parseDouble(vals[7+jjj]); - sigma[kkk][jjj] = Double.parseDouble(vals[7+dataManager.numAnnotations+jjj]) * backOffGaussianFactor; //BUGBUG: *3, suggestion by Nick to prevent GMM from over fitting and producing low likelihoods for most points + mu[kkk][jjj] = Double.parseDouble(vals[4+jjj]); + for( int ppp = 0; ppp < dataManager.numAnnotations; ppp++ ) { + sigmaVals[kkk][jjj][ppp] = Double.parseDouble(vals[4+dataManager.numAnnotations+(jjj*dataManager.numAnnotations)+ppp]) * backOffGaussianFactor; // BUGBUG: *3, suggestion by Nick to prevent GMM from over fitting and producing low likelihoods for most points + } } + + sigma[kkk] = new Matrix(sigmaVals[kkk]); + sigmaInverse[kkk] = sigma[kkk].inverse(); // Precompute all the inverses and determinants for use later + determinant[kkk] = sigma[kkk].det(); + //if( determinant[kkk] < MIN_DETERMINANT ) { determinant[kkk] = MIN_DETERMINANT; } kkk++; } + isUsingTiTvModel = _isUsingTiTvModel; - System.out.println("Found " + numGaussians + " clusters and using " + dataManager.numAnnotations + " annotations: " + dataManager.annotationKeys); + logger.info("Found " + numGaussians + " clusters and using " + dataManager.numAnnotations + " annotations: " + dataManager.annotationKeys); } public final void run( final String clusterFileName ) { - final int MAX_VARS = 1000000; //BUGBUG: make this a command line argument + final int MAX_KNOWN_VARS = 5000000; // BUGBUG: make this a command line argument + final int MAX_NOVEL_VARS = 5000000; // BUGBUG: make this a command line argument + final double knownNovelMixture = 1.5; // BUGBUG: make this a command line argument // Create the subset of the data to cluster with int numNovel = 0; int numKnown = 0; - int numHet = 0; - int numHom = 0; for( final VariantDatum datum : dataManager.data ) { if( datum.isKnown ) { numKnown++; } else { numNovel++; } - if( datum.isHet ) { - numHet++; - } else { - numHom++; - } } - // This block of code is used to cluster with novels + 1.5x knowns mixed together + final int numNovelCluster = Math.min( numNovel, MAX_NOVEL_VARS ); + final int numKnownCluster = Math.min( numKnown, MAX_KNOWN_VARS ); + final int numKnownTogether = Math.min( numKnownCluster, (int) Math.floor(knownNovelMixture * numNovelCluster) ); - VariantDatum[] data; + final VariantDatum[] dataTogether = new VariantDatum[numNovelCluster + numKnownTogether]; + final VariantDatum[] dataKnown = new VariantDatum[numKnownCluster]; - // Grab a set of data that is all of the novel variants plus 1.5x as many known variants drawn at random - // If there are almost as many novels as known, simply use all the variants - // BUGBUG: allow downsampling and arbitrary mixtures of knowns and novels - final int numSubset = (int)Math.floor(numNovel*2.5); - if( numSubset * 1.3 < dataManager.numVariants ) { - data = new VariantDatum[numSubset]; - int iii = 0; + // Create the dataTogether array, which is all the novels and 1.5x as many knowns, downsampled if there are too many + int iii = 0; + if( numNovelCluster == numNovel ) { for( final VariantDatum datum : dataManager.data ) { if( !datum.isKnown ) { - data[iii++] = datum; + dataTogether[iii++] = datum; } } - while( iii < numSubset ) { // grab an equal number of known variants at random + } else { + logger.info("Capped at " + MAX_NOVEL_VARS + " novel variants."); + while( iii < numNovelCluster ) { + final VariantDatum datum = dataManager.data[rand.nextInt(dataManager.numVariants)]; + if( !datum.isKnown ) { + dataTogether[iii++] = datum; + } + } + } + if( numKnownTogether == numKnown ) { + for( final VariantDatum datum : dataManager.data ) { + if( datum.isKnown ) { + dataTogether[iii++] = datum; + } + } + } else { + while( iii < numNovelCluster + numKnownTogether ) { final VariantDatum datum = dataManager.data[rand.nextInt(dataManager.numVariants)]; if( datum.isKnown ) { - data[iii++] = datum; - } - } - } else { - data = dataManager.data; - } - - System.out.println("Clustering with " + numNovel + " novel variants and " + (data.length - numNovel) + " known variants..."); - if( data.length == dataManager.numVariants ) { System.out.println(" (used all variants since 2.5*numNovel is so large compared to the full set) "); } - createClusters( data, 0, numGaussians ); // Using a subset of the data - System.out.println("Outputting cluster parameters..."); - printClusters( clusterFileName ); - - - - - - // This block of code is to cluster knowns and novels separately - /* - final VariantDatum[] dataNovel = new VariantDatum[Math.min(numNovel,MAX_VARS)]; - final VariantDatum[] dataKnown = new VariantDatum[Math.min(numKnown,MAX_VARS)]; - - //BUGBUG: This is ugly - int jjj = 0; - if(numNovel <= MAX_VARS) { - for( final VariantDatum datum : dataManager.data ) { - if( !datum.isKnown ) { - dataNovel[jjj++] = datum; - } - } - } else { - System.out.println("Capped at " + MAX_VARS + " novel variants."); - while( jjj < MAX_VARS ) { - final VariantDatum datum = dataManager.data[rand.nextInt(dataManager.numVariants)]; - if( !datum.isKnown ) { - dataNovel[jjj++] = datum; + dataTogether[iii++] = datum; } } } - int iii = 0; - if(numKnown <= MAX_VARS) { + // Create the dataKnown array, which is simply all the known vars or downsampled if there are too many + iii = 0; + if( numKnownCluster == numKnown ) { for( final VariantDatum datum : dataManager.data ) { if( datum.isKnown ) { dataKnown[iii++] = datum; } } } else { - System.out.println("Capped at " + MAX_VARS + " known variants."); - while( iii < MAX_VARS ) { + logger.info("Capped at " + MAX_KNOWN_VARS + " known variants."); + while( iii < numKnownCluster ) { final VariantDatum datum = dataManager.data[rand.nextInt(dataManager.numVariants)]; if( datum.isKnown ) { dataKnown[iii++] = datum; @@ -234,73 +226,95 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } } + final boolean useTITV = true; + logger.info("First, cluster with novels and knowns together to use ti/tv based models:"); + logger.info("Clustering with " + numNovelCluster + " novel variants and " + numKnownTogether + " known variants."); + createClusters( dataTogether, 0, numGaussians, clusterFileName, useTITV ); - System.out.println("Clustering with " + Math.min(numNovel,MAX_VARS) + " novel variants."); - createClusters( dataNovel, 0, numGaussians / 2 ); - System.out.println("Clustering with " + Math.min(numKnown,MAX_VARS) + " known variants."); - createClusters( dataKnown, numGaussians / 2, numGaussians ); - System.out.println("Outputting cluster parameters..."); - printClusters( clusterFileName ); - - */ - - /* - // This block of code is to cluster het and hom calls separately, but mixing together knowns and novels - final VariantDatum[] dataHet = new VariantDatum[Math.min(numHet,MAX_VARS)]; - final VariantDatum[] dataHom = new VariantDatum[Math.min(numHom,MAX_VARS)]; - - //BUGBUG: This is ugly - int jjj = 0; - if(numHet <= MAX_VARS) { - for( final VariantDatum datum : dataManager.data ) { - if( datum.isHet ) { - dataHet[jjj++] = datum; - } - } - } else { - System.out.println("Found " + numHet + " het variants but capped at clustering with " + MAX_VARS + "."); - while( jjj < MAX_VARS ) { - final VariantDatum datum = dataManager.data[rand.nextInt(dataManager.numVariants)]; - if( datum.isHet ) { - dataHet[jjj++] = datum; - } - } - } - - int iii = 0; - if(numHom <= MAX_VARS) { - for( final VariantDatum datum : dataManager.data ) { - if( !datum.isHet ) { - dataHom[iii++] = datum; - } - } - } else { - System.out.println("Found " + numHom + " hom variants but capped at clustering with " + MAX_VARS + "."); - while( iii < MAX_VARS ) { - final VariantDatum datum = dataManager.data[rand.nextInt(dataManager.numVariants)]; - if( !datum.isHet ) { - dataHom[iii++] = datum; - } - } - } - - System.out.println("Clustering with " + Math.min(numHet,MAX_VARS) + " het variants."); - createClusters( dataHet, 0, numGaussians / 2 ); - System.out.println("Clustering with " + Math.min(numHom,MAX_VARS) + " hom variants."); - createClusters( dataHom, numGaussians / 2, numGaussians ); - System.out.println("Outputting cluster parameters..."); - printClusters( clusterFileName ); -*/ + logger.info("Finally, cluster with only knowns to use ti/tv-less models:"); + logger.info("Clustering with " + numKnownCluster + " known variants."); + createClusters( dataKnown, 0, numGaussians, clusterFileName, !useTITV ); } - -/* - public final void createClusters( final VariantDatum[] data, int startCluster, int stopCluster ) { + public final void createClusters( final VariantDatum[] data, final int startCluster, final int stopCluster, final String clusterFileName, final boolean useTITV ) { final int numVariants = data.length; final int numAnnotations = data[0].annotations.length; - final double[][] pVarInCluster = new double[numGaussians][numVariants]; + final double[][] pVarInCluster = new double[numGaussians][numVariants]; // Probability that the variant is in that cluster = simply evaluate the multivariate Gaussian + + // loop control variables: + // iii - loop over data points + // jjj - loop over annotations (features) + // ppp - loop over annotations again (full rank covariance matrix) + // kkk - loop over clusters + // ttt - loop over EM iterations + + // Set up the initial random Gaussians + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + deadCluster[kkk] = false; + pCluster[kkk] = 1.0 / ((double) (stopCluster - startCluster)); + //final double[] randMu = new double[numAnnotations]; + //for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + // randMu[jjj] = -1.5 + 3.0 * rand.nextDouble(); + //} + mu[kkk] = data[rand.nextInt(numVariants)].annotations; + final double[][] randSigma = new double[numAnnotations][numAnnotations]; + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + for( int ppp = jjj; ppp < numAnnotations; ppp++ ) { + randSigma[ppp][jjj] = 0.5 + 0.5 * rand.nextDouble(); // data has been normalized so sigmas are centered at 1.0 + if(jjj != ppp) { randSigma[jjj][ppp] = 0.0; } // Sigma is a symmetric, positive-definite matrix + } + } + Matrix tmp = new Matrix(randSigma); + tmp = tmp.times(tmp.transpose()); + sigma[kkk] = tmp; + determinant[kkk] = sigma[kkk].det(); + //if( determinant[kkk] < MIN_DETERMINANT ) { deadCluster[kkk] = true; } + } + + // The EM loop + for( int ttt = 0; ttt < numIterations; ttt++ ) { + + //int numValidClusters = 0; + //for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + // if( !deadCluster[kkk] ) { numValidClusters++; } + //} + //logger.info("Starting iteration " + (ttt+1) + " with " + numValidClusters + " clusters."); + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Expectation Step (calculate the probability that each data point is in each cluster) + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + evaluateGaussians( data, pVarInCluster, startCluster, stopCluster ); + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Maximization Step (move the clusters to maximize the sum probability of each data point) + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + maximizeGaussians( data, pVarInCluster, startCluster, stopCluster ); + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Estimate each cluster's p(true) and output cluster parameters + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + outputGaussians( data, pVarInCluster, ttt+1, startCluster, stopCluster, clusterFileName, useTITV ); + + logger.info("Finished iteration " + (ttt+1) ); + } + } + + private void outputGaussians( final VariantDatum[] data, final double[][] pVarInCluster, final int iterationNumber, + final int startCluster, final int stopCluster, final String clusterFileName, final boolean useTITV ) { + + if( !useTITV ) { + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + clusterTITV[kkk] = 0.0; + clusterTruePositiveRate[kkk] = 1.0; + } + printClusterParamters( clusterFileName + ".WithoutTiTv." + iterationNumber ); + return; + } + + final int numVariants = data.length; + final double[] probTi = new double[numGaussians]; final double[] probTv = new double[numGaussians]; final double[] probKnown = new double[numGaussians]; @@ -310,229 +324,102 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel final double[] probNovelTi = new double[numGaussians]; final double[] probNovelTv = new double[numGaussians]; - // loop control variables: - // iii - loop over data points - // jjj - loop over annotations (features) - // kkk - loop over clusters - // ttt - loop over EM iterations - - // Set up the initial random Gaussians - for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - numMaxClusterKnown[kkk] = 0; - numMaxClusterNovel[kkk] = 0; - pCluster[kkk] = 1.0 / ((double) (stopCluster - startCluster)); - mu[kkk] = data[rand.nextInt(numVariants)].annotations; - final double[] randSigma = new double[numAnnotations]; - if( dataManager.isNormalized ) { - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - randSigma[jjj] = 0.7 + 0.4 * rand.nextDouble(); - } - } else { // BUGBUG: if not normalized then the varianceVector hasn't been calculated --> null pointer - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - randSigma[jjj] = dataManager.varianceVector[jjj] + ((1.0 + rand.nextDouble()) * 0.01 * dataManager.varianceVector[jjj]); - } - } - sigma[kkk] = randSigma; - } - - for( int ttt = 0; ttt < numIterations; ttt++ ) { - - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - // Expectation Step (calculate the probability that each data point is in each cluster) - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - evaluateGaussians( data, pVarInCluster, startCluster, stopCluster ); - - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - // Maximization Step (move the clusters to maximize the sum probability of each data point) - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - maximizeGaussians( data, pVarInCluster, startCluster, stopCluster ); - - System.out.println("Finished iteration " + (ttt+1) ); - } - - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - // Evaluate the clusters using titv as an estimate of the true positive rate - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - evaluateGaussians( data, pVarInCluster, startCluster, stopCluster ); // One final evaluation because the Gaussians moved in the last maximization step - for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { probTi[kkk] = 0.0; probTv[kkk] = 0.0; probKnown[kkk] = 0.0; probNovel[kkk] = 0.0; + probKnownTi[kkk] = 0.0; + probKnownTv[kkk] = 0.0; + probNovelTi[kkk] = 0.0; + probNovelTv[kkk] = 0.0; } + + // Use the cluster's probabilistic Ti/Tv ratio as the indication of the cluster's true positive rate for( int iii = 0; iii < numVariants; iii++ ) { final boolean isTransition = data[iii].isTransition; final boolean isKnown = data[iii].isKnown; for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - final double prob = pVarInCluster[kkk][iii]; - if( isKnown ) { // known - probKnown[kkk] += prob; - if( isTransition ) { // transition - probKnownTi[kkk] += prob; - probTi[kkk] += prob; - } else { // transversion - probKnownTv[kkk] += prob; - probTv[kkk] += prob; - } - } else { //novel - probNovel[kkk] += prob; - if( isTransition ) { // transition - probNovelTi[kkk] += prob; - probTi[kkk] += prob; - } else { // transversion - probNovelTv[kkk] += prob; - probTv[kkk] += prob; + if( !deadCluster[kkk] ) { + final double prob = pVarInCluster[kkk][iii]; + if( isKnown ) { // known + probKnown[kkk] += prob; + if( isTransition ) { // transition + probKnownTi[kkk] += prob; + probTi[kkk] += prob; + } else { // transversion + probKnownTv[kkk] += prob; + probTv[kkk] += prob; + } + } else { //novel + probNovel[kkk] += prob; + if( isTransition ) { // transition + probNovelTi[kkk] += prob; + probTi[kkk] += prob; + } else { // transversion + probNovelTv[kkk] += prob; + probTv[kkk] += prob; + } } } } - - double maxProb = pVarInCluster[startCluster][iii]; - int maxCluster = startCluster; - for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - if( pVarInCluster[kkk][iii] > maxProb ) { - maxProb = pVarInCluster[kkk][iii]; - maxCluster = kkk; - } - } - if( isKnown ) { - numMaxClusterKnown[maxCluster]++; - } else { - numMaxClusterNovel[maxCluster]++; - } } - for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - clusterTITV[kkk] = probTi[kkk] / probTv[kkk]; - if( probKnown[kkk] > 600.0 ) { // BUGBUG: make this a command line argument, parameterize performance based on this important argument - clusterTruePositiveRate[kkk] = calcTruePositiveRateFromKnownTITV( probKnownTi[kkk] / probKnownTv[kkk], probNovelTi[kkk] / probNovelTv[kkk], clusterTITV[kkk], knownAlphaFactor ); - } else { - clusterTruePositiveRate[kkk] = calcTruePositiveRateFromTITV( clusterTITV[kkk] ); + for( int ttt = 0; ttt < 3; ttt++ ) { + double knownAlphaFactor = 0.0; + if( ttt == 0 ) { + knownAlphaFactor = 0.0; + } else if( ttt == 1 ) { + knownAlphaFactor = 1.0; + } else if( ttt == 2 ) { + knownAlphaFactor = 0.5; + } + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + if( !deadCluster[kkk] ) { + clusterTITV[kkk] = probTi[kkk] / probTv[kkk]; + if( probKnown[kkk] > 500.0 && probNovel[kkk] > 500.0 ) { + clusterTruePositiveRate[kkk] = calcTruePositiveRateFromKnownTITV( probKnownTi[kkk] / probKnownTv[kkk], probNovelTi[kkk] / probNovelTv[kkk], clusterTITV[kkk], knownAlphaFactor ); + } else { + clusterTruePositiveRate[kkk] = calcTruePositiveRateFromTITV( clusterTITV[kkk] ); + } + } + } + + if( ttt == 0 ) { + printClusterParamters( clusterFileName + ".TargetTiTv." + iterationNumber ); + } else if( ttt == 1 ) { + printClusterParamters( clusterFileName + ".KnownTiTv." + iterationNumber ); + } else if( ttt == 2 ) { + printClusterParamters( clusterFileName + ".BlendedTiTv." + iterationNumber ); } } } -*/ - - // This cluster method doesn't make use of the differences between known and novel Ti/Tv ratios - - - public final void createClusters( final VariantDatum[] data, final int startCluster, final int stopCluster ) { - - final int numVariants = data.length; - final int numAnnotations = data[0].annotations.length; - - final double[][] pVarInCluster = new double[numGaussians][numVariants]; // Probability that the variant is in that cluster = simply evaluate the multivariate Gaussian - final double[] probTi = new double[numGaussians]; - final double[] probTv = new double[numGaussians]; - - // loop control variables: - // iii - loop over data points - // jjj - loop over annotations (features) - // kkk - loop over clusters - // ttt - loop over EM iterations - - // Set up the initial random Gaussians - for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - numMaxClusterKnown[kkk] = 0; - numMaxClusterNovel[kkk] = 0; - pCluster[kkk] = 1.0 / ((double) (stopCluster - startCluster)); - mu[kkk] = data[rand.nextInt(numVariants)].annotations; - final double[] randSigma = new double[numAnnotations]; - if( dataManager.isNormalized ) { - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - randSigma[jjj] = 0.7 + 0.4 * rand.nextDouble(); - } - } else { // BUGBUG: if not normalized then the varianceVector hasn't been calculated --> null pointer - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - randSigma[jjj] = dataManager.varianceVector[jjj] + ((1.0 + rand.nextDouble()) * 0.01 * dataManager.varianceVector[jjj]); - } - } - sigma[kkk] = randSigma; - } - - // The EM loop - for( int ttt = 0; ttt < numIterations; ttt++ ) { - - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - // Expectation Step (calculate the probability that each data point is in each cluster) - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - evaluateGaussians( data, pVarInCluster, startCluster, stopCluster ); - - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - // Maximization Step (move the clusters to maximize the sum probability of each data point) - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - maximizeGaussians( data, pVarInCluster, startCluster, stopCluster ); - - System.out.println("Finished iteration " + (ttt+1) ); - } - - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - // Evaluate the clusters using titv as an estimate of the true positive rate - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - evaluateGaussians( data, pVarInCluster, startCluster, stopCluster ); // One final evaluation because the Gaussians moved in the last maximization step - - for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - probTi[kkk] = 0.0; - probTv[kkk] = 0.0; - } - // Use the cluster's probabilistic Ti/Tv ratio as the indication of the cluster's true positive rate - for( int iii = 0; iii < numVariants; iii++ ) { - if( data[iii].isTransition ) { // transition - for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - probTi[kkk] += pVarInCluster[kkk][iii]; - } - } else { // transversion - for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - probTv[kkk] += pVarInCluster[kkk][iii]; - } - } - - // Calculate which cluster has the maximum probability for this variant for use as a metric of how well clustered the data is - double maxProb = pVarInCluster[startCluster][iii]; - int maxCluster = startCluster; - for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - if( pVarInCluster[kkk][iii] > maxProb ) { - maxProb = pVarInCluster[kkk][iii]; - maxCluster = kkk; - } - } - if( data[iii].isKnown ) { - numMaxClusterKnown[maxCluster]++; - } else { - numMaxClusterNovel[maxCluster]++; - } - } - for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - clusterTITV[kkk] = probTi[kkk] / probTv[kkk]; - clusterTruePositiveRate[kkk] = calcTruePositiveRateFromTITV( clusterTITV[kkk] ); - } - } - - - - private void printClusters( final String clusterFileName ) { + private void printClusterParamters( final String clusterFileName ) { try { final PrintStream outputFile = new PrintStream( clusterFileName ); dataManager.printClusterFileHeader( outputFile ); final int numAnnotations = mu[0].length; + final int numVariants = dataManager.numVariants; for( int kkk = 0; kkk < numGaussians; kkk++ ) { - if( numMaxClusterKnown[kkk] + numMaxClusterNovel[kkk] >= minVarInCluster ) { - outputFile.print("@!CLUSTER,"); - outputFile.print( (kkk < numGaussians / 2 ? 1 : 0) + "," ); // is het cluster? - outputFile.print(pCluster[kkk] + ","); - outputFile.print(numMaxClusterKnown[kkk] + ","); - outputFile.print(numMaxClusterNovel[kkk] + ","); - outputFile.print(clusterTITV[kkk] + ","); - outputFile.print(clusterTruePositiveRate[kkk] + ","); - for(int jjj = 0; jjj < numAnnotations; jjj++ ) { - outputFile.print(mu[kkk][jjj] + ","); + if( !deadCluster[kkk] ) { + if( pCluster[kkk] * numVariants > minVarInCluster ) { + final double sigmaVals[][] = sigma[kkk].getArray(); + outputFile.print("@!CLUSTER,"); + outputFile.print(pCluster[kkk] + ","); + outputFile.print(clusterTITV[kkk] + ","); + outputFile.print(clusterTruePositiveRate[kkk] + ","); + for(int jjj = 0; jjj < numAnnotations; jjj++ ) { + outputFile.print(mu[kkk][jjj] + ","); + } + for(int jjj = 0; jjj < numAnnotations; jjj++ ) { + for(int ppp = 0; ppp < numAnnotations; ppp++ ) { + outputFile.print(sigmaVals[jjj][ppp] + ","); + } + } + outputFile.println(-1); } - for(int jjj = 0; jjj < numAnnotations; jjj++ ) { - outputFile.print(sigma[kkk][jjj] + ","); - } - outputFile.println(-1); } } outputFile.close(); @@ -541,7 +428,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } } - public final double evaluateVariant( final Map annotationMap, final double qualityScore, final boolean isHet ) { + public final double evaluateVariant( final Map annotationMap, final double qualityScore ) { final double[] pVarInCluster = new double[numGaussians]; final double[] annotations = new double[dataManager.numAnnotations]; @@ -550,6 +437,8 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel final String annotationKey = dataManager.annotationKeys.get(jjj); if( annotationKey.equals("QUAL") ) { value = qualityScore; + } else if( annotationKey.equals("AB") && !annotationMap.containsKey(annotationKey) ) { + value = (0.5 - 0.005) + (0.01 * Math.random()); // HomVar calls don't have an allele balance } else { try { final Object stringValue = annotationMap.get( annotationKey ); @@ -567,16 +456,54 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel annotations[jjj] = (value - dataManager.meanVector[jjj]) / dataManager.varianceVector[jjj]; } - evaluateGaussiansForSingleVariant( annotations, pVarInCluster, isHet ); + evaluateGaussiansForSingleVariant( annotations, pVarInCluster ); + //if( isUsingTiTvModel ) { + // Sum prob model + double sum = 0.0; + for( int kkk = 0; kkk < numGaussians; kkk++ ) { + sum += pVarInCluster[kkk] * clusterTruePositiveRate[kkk]; + } + return sum; + /* + } else { + // Max prob model + double maxProb = 0.0; + for( int kkk = 0; kkk < numGaussians; kkk++ ) { + if( pVarInCluster[kkk] > maxProb ) { + maxProb = pVarInCluster[kkk]; + } + } + return maxProb; + } + */ + + // Max prob model + /* + double maxProb = 0.0; + for( int kkk = 0; kkk < numGaussians; kkk++ ) { + if( pVarInCluster[kkk] > maxProb ) { + maxProb = pVarInCluster[kkk]; + } + } + return maxProb; + */ + + // Entropy model + /* double sum = 0.0; for( int kkk = 0; kkk < numGaussians; kkk++ ) { //if( isHetCluster[kkk] == isHet ) { - sum += pVarInCluster[kkk] * clusterTruePositiveRate[kkk]; + sum += pVarInCluster[kkk] * Math.log(pVarInCluster[kkk]); //} } - return sum; + double entropy = -1.0 * sum; + double maxEntropy = -1.0 * Math.log( 1.0 / ((double) numGaussians)); + + //System.out.println("H = " + entropy + ", pTrue = " + ( 1.0 - (entropy / maxEntropy) )); + return ( 1.0 - (entropy / maxEntropy) ); + */ } public final void outputOptimizationCurve( final VariantDatum[] data, final String outputPrefix, final int desiredNumVariants ) { @@ -614,7 +541,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel boolean foundDesiredNumVariants = false; int jjj = 0; outputFile.println("pCut,numKnown,numNovel,knownTITV,novelTITV"); - for( double qCut = MAX_QUAL; qCut >= 0.0; qCut -= QUAL_STEP ) { + for( double qCut = MAX_QUAL; qCut >= -0.001; qCut -= QUAL_STEP ) { for( int iii = 0; iii < numVariants; iii++ ) { if( !markedVariant[iii] ) { if( data[iii].qual >= qCut ) { @@ -638,12 +565,11 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } } if( desiredNumVariants != 0 && !foundDesiredNumVariants && (numKnown + numNovel) >= desiredNumVariants ) { - System.out.println( "Keeping variants with QUAL >= " + String.format("%.1f",qCut) + " results in a filtered set with: " ); - System.out.println("\t" + numKnown + " known variants"); - System.out.println("\t" + numNovel + " novel variants, (dbSNP rate = " + String.format("%.2f",((double) numKnown * 100.0) / ((double) numKnown + numNovel) ) + "%)"); - System.out.println("\t" + String.format("%.4f known Ti/Tv ratio", ((double)numKnownTi) / ((double)numKnownTv))); - System.out.println("\t" + String.format("%.4f novel Ti/Tv ratio", ((double)numNovelTi) / ((double)numNovelTv))); - System.out.println(); + logger.info( "Keeping variants with QUAL >= " + String.format("%.1f",qCut) + " results in a filtered set with: " ); + logger.info("\t" + numKnown + " known variants"); + logger.info("\t" + numNovel + " novel variants, (dbSNP rate = " + String.format("%.2f",((double) numKnown * 100.0) / ((double) numKnown + numNovel) ) + "%)"); + logger.info("\t" + String.format("%.4f known Ti/Tv ratio", ((double)numKnownTi) / ((double)numKnownTv))); + logger.info("\t" + String.format("%.4f novel Ti/Tv ratio", ((double)numNovelTi) / ((double)numNovelTv))); foundDesiredNumVariants = true; } outputFile.println( qCut + "," + numKnown + "," + numNovel + "," + @@ -687,13 +613,12 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } if( foundCut ) { - System.out.println( "Keeping variants with QUAL >= " + String.format("%.1f",theCut[jjj]) + " results in a filtered set with: " ); - System.out.println("\t" + numKnownAtCut[jjj] + " known variants"); - System.out.println("\t" + numNovelAtCut[jjj] + " novel variants, (dbSNP rate = " + + logger.info( "Keeping variants with QUAL >= " + String.format("%.1f",theCut[jjj]) + " results in a filtered set with: " ); + logger.info("\t" + numKnownAtCut[jjj] + " known variants"); + logger.info("\t" + numNovelAtCut[jjj] + " novel variants, (dbSNP rate = " + String.format("%.2f",((double) numKnownAtCut[jjj] * 100.0) / ((double) numKnownAtCut[jjj] + numNovelAtCut[jjj]) ) + "%)"); - System.out.println("\t" + String.format("%.4f known Ti/Tv ratio", knownTiTvAtCut[jjj])); - System.out.println("\t" + String.format("%.4f novel Ti/Tv ratio", novelTiTvAtCut[jjj])); - System.out.println(); + logger.info("\t" + String.format("%.4f known Ti/Tv ratio", knownTiTvAtCut[jjj])); + logger.info("\t" + String.format("%.4f novel Ti/Tv ratio", novelTiTvAtCut[jjj])); } } @@ -704,63 +629,115 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel private void evaluateGaussians( final VariantDatum[] data, final double[][] pVarInCluster, final int startCluster, final int stopCluster ) { final int numAnnotations = data[0].annotations.length; - + double likelihood = 0.0; + final double sigmaVals[][][] = new double[numGaussians][][]; + final double denom[] = new double[numGaussians]; + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + if( !deadCluster[kkk] ) { + sigmaVals[kkk] = sigma[kkk].inverse().getArray(); + denom[kkk] = Math.pow(2.0 * 3.14159, ((double)numAnnotations) / 2.0) * Math.pow(Math.abs(determinant[kkk]), 0.5); + } + } + final double mult[] = new double[numAnnotations]; for( int iii = 0; iii < data.length; iii++ ) { double sumProb = 0.0; for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - double sum = 0.0; - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - sum += ( (data[iii].annotations[jjj] - mu[kkk][jjj]) * (data[iii].annotations[jjj] - mu[kkk][jjj]) ) - / sigma[kkk][jjj]; - } - pVarInCluster[kkk][iii] = pCluster[kkk] * Math.exp( -0.5 * sum ); - - if( pVarInCluster[kkk][iii] < MIN_PROB) { // Very small numbers are a very big problem - pVarInCluster[kkk][iii] = MIN_PROB; - } + if( !deadCluster[kkk] ) { + double sum = 0.0; + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + mult[jjj] = 0.0; + for( int ppp = 0; ppp < numAnnotations; ppp++ ) { + mult[jjj] += (data[iii].annotations[ppp] - mu[kkk][ppp]) * sigmaVals[kkk][ppp][jjj]; + } + } + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + sum += mult[jjj] * (data[iii].annotations[jjj] - mu[kkk][jjj]); + } - sumProb += pVarInCluster[kkk][iii]; + pVarInCluster[kkk][iii] = pCluster[kkk] * (Math.exp( -0.5 * sum ) / denom[kkk]); + likelihood += pVarInCluster[kkk][iii]; + if(Double.isNaN(denom[kkk]) || determinant[kkk] < 0.5 * MIN_DETERMINANT) { + System.out.println("det = " + sigma[kkk].det()); + System.out.println("denom = " + denom[kkk]); + System.out.println("sumExp = " + sum); + System.out.println("pVar = " + pVarInCluster[kkk][iii]); + System.out.println("=-------="); + throw new StingException("Numerical Instability! determinant of covariance matrix <= 0. Try running with fewer clusters and then with better behaved annotation values."); + } + if(sum < 0.0) { + System.out.println("det = " + sigma[kkk].det()); + System.out.println("denom = " + denom[kkk]); + System.out.println("sumExp = " + sum); + System.out.println("pVar = " + pVarInCluster[kkk][iii]); + System.out.println("=-------="); + throw new StingException("Numerical Instability! covariance matrix no longer positive definite. Try running with fewer clusters and then with better behaved annotation values."); + } + if(pVarInCluster[kkk][iii] > 1.0) { + System.out.println("det = " + sigma[kkk].det()); + System.out.println("denom = " + denom[kkk]); + System.out.println("sumExp = " + sum); + System.out.println("pVar = " + pVarInCluster[kkk][iii]); + System.out.println("=-------="); + throw new StingException("Numerical Instability! probability distribution returns > 1.0. Try running with fewer clusters and then with better behaved annotation values."); + } + + if( pVarInCluster[kkk][iii] < MIN_PROB) { // Very small numbers are a very big problem + pVarInCluster[kkk][iii] = MIN_PROB;// + MIN_PROB * rand.nextDouble(); + } + + sumProb += pVarInCluster[kkk][iii]; + } } - if( sumProb > MIN_SUM_PROB ) { // Very small numbers are a very big problem - for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - pVarInCluster[kkk][iii] /= sumProb; - } + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + pVarInCluster[kkk][iii] /= sumProb; } + } + + logger.info("Explained likelihood = " + String.format("%.5f",likelihood / ((double) data.length))); } - private void evaluateGaussiansForSingleVariant( final double[] annotations, final double[] pVarInCluster, final boolean isHet ) { + private void evaluateGaussiansForSingleVariant( final double[] annotations, final double[] pVarInCluster ) { final int numAnnotations = annotations.length; double sumProb = 0.0; + final double mult[] = new double[numAnnotations]; for( int kkk = 0; kkk < numGaussians; kkk++ ) { - //if( isHetCluster[kkk] == isHet ) { - double sum = 0.0; - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - sum += ( (annotations[jjj] - mu[kkk][jjj]) * (annotations[jjj] - mu[kkk][jjj]) ) - / sigma[kkk][jjj]; + final double sigmaVals[][] = sigmaInverse[kkk].getArray(); + double sum = 0.0; + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + mult[jjj] = 0.0; + for( int ppp = 0; ppp < numAnnotations; ppp++ ) { + mult[jjj] += (annotations[ppp] - mu[kkk][ppp]) * sigmaVals[ppp][jjj]; } + } + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + sum += mult[jjj] * (annotations[jjj] - mu[kkk][jjj]); + } - // BUGBUG: reverting to old version that didn't have pCluster[kkk]* here, this meant that the overfitting parameters changed meanings - //pVarInCluster[kkk] = pCluster[kkk] * Math.exp( -0.5 * sum ); - pVarInCluster[kkk] = Math.exp( -0.5 * sum ); + final double denom = Math.pow(2.0 * 3.14159, ((double)numAnnotations) / 2.0) * Math.pow(determinant[kkk], 0.5); + pVarInCluster[kkk] = (1.0 / ((double) numGaussians)) * (Math.exp( -0.5 * sum )) / denom; + if( isUsingTiTvModel ) { + //pVarInCluster[kkk] = Math.exp( -0.5 * sum ); if( pVarInCluster[kkk] < MIN_PROB) { // Very small numbers are a very big problem pVarInCluster[kkk] = MIN_PROB; } - sumProb += pVarInCluster[kkk]; - //} + } else { + //final double denom = Math.pow(2.0 * 3.14159, ((double)numAnnotations) / 2.0) * Math.pow(determinant[kkk], 0.5); + //pVarInCluster[kkk] = pCluster[kkk] * (Math.exp( -0.5 * sum )) / denom; + //pVarInCluster[kkk] = Math.exp( -0.5 * sum ); + // BUGBUG: should pCluster be the distribution from the GMM or a uniform distribution here? + } } - if( sumProb > MIN_SUM_PROB ) { // Very small numbers are a very big problem + if( isUsingTiTvModel ) { for( int kkk = 0; kkk < numGaussians; kkk++ ) { - //if( isHetCluster[kkk] == isHet ) { - pVarInCluster[kkk] /= sumProb; - //} + pVarInCluster[kkk] /= sumProb; } } } @@ -770,58 +747,79 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel final int numVariants = data.length; final int numAnnotations = data[0].annotations.length; + final double sigmaVals[][][] = new double[numGaussians][numAnnotations][numAnnotations]; for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { for( int jjj = 0; jjj < numAnnotations; jjj++ ) { mu[kkk][jjj] = 0.0; - sigma[kkk][jjj] = 0.0; - } - } - for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - double sumProb = 0.0; - for( int iii = 0; iii < numVariants; iii++ ) { - final double prob = pVarInCluster[kkk][iii]; - sumProb += prob; - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - mu[kkk][jjj] += prob * data[iii].annotations[jjj]; + for( int ppp = 0; ppp < numAnnotations; ppp++ ) { + sigmaVals[kkk][jjj][ppp] = 0.0; } } + } + double sumPK = 0.0; + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + if( !deadCluster[kkk] ) { + double sumProb = 0.0; + for( int iii = 0; iii < numVariants; iii++ ) { + final double prob = pVarInCluster[kkk][iii]; + sumProb += prob; + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + mu[kkk][jjj] += prob * data[iii].annotations[jjj]; + } + } - if( sumProb > MIN_SUM_PROB ) { for( int jjj = 0; jjj < numAnnotations; jjj++ ) { mu[kkk][jjj] /= sumProb; } - } //BUGBUG: clean up dead clusters to speed up computation - for( int iii = 0; iii < numVariants; iii++ ) { - final double prob = pVarInCluster[kkk][iii]; + for( int iii = 0; iii < numVariants; iii++ ) { + final double prob = pVarInCluster[kkk][iii]; + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + for( int ppp = jjj; ppp < numAnnotations; ppp++ ) { + sigmaVals[kkk][jjj][ppp] += prob * (data[iii].annotations[jjj]-mu[kkk][jjj]) * (data[iii].annotations[ppp]-mu[kkk][ppp]); + } + } + } + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - sigma[kkk][jjj] += prob * (data[iii].annotations[jjj]-mu[kkk][jjj]) * (data[iii].annotations[jjj]-mu[kkk][jjj]); + for( int ppp = jjj; ppp < numAnnotations; ppp++ ) { + if( sigmaVals[kkk][jjj][ppp] < MIN_SIGMA ) { // Very small numbers are a very big problem + sigmaVals[kkk][jjj][ppp] = MIN_SIGMA;// + MIN_SIGMA * rand.nextDouble(); + } + sigmaVals[kkk][ppp][jjj] = sigmaVals[kkk][jjj][ppp]; // sigma must be a symmetric matrix + } } - } - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - if( sigma[kkk][jjj] < MIN_PROB) { // Very small numbers are a very big problem - sigma[kkk][jjj] = MIN_PROB; - } - } - - if( sumProb > MIN_SUM_PROB ) { for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - sigma[kkk][jjj] /= sumProb; + for( int ppp = 0; ppp < numAnnotations; ppp++ ) { + sigmaVals[kkk][jjj][ppp] /= sumProb; + } + } + + sigma[kkk] = new Matrix(sigmaVals[kkk]); + determinant[kkk] = sigma[kkk].det(); + //if( determinant[kkk] < MIN_DETERMINANT ) { deadCluster[kkk] = true; } + + if( !deadCluster[kkk] ) { + pCluster[kkk] = sumProb / numVariants; + sumPK += pCluster[kkk]; } } - - pCluster[kkk] = sumProb / numVariants; } + // ensure pCluster sums to one, it doesn't automatically due to very small numbers getting capped + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + pCluster[kkk] /= sumPK; + } + /* // Clean up extra big or extra small clusters - //BUGBUG: Is this a good idea? for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { if( pCluster[kkk] > 0.45 ) { // This is a very large cluster compared to all the others - final int numToReplace = 4; - final double[] savedSigma = sigma[kkk]; + System.out.println("!! Found very large cluster! Busting it up into smaller clusters."); + final int numToReplace = 3; + final Matrix savedSigma = sigma[kkk]; for( int rrr = 0; rrr < numToReplace; rrr++ ) { // Find an example variant in the large cluster, drawn randomly int randVarIndex = -1; @@ -854,38 +852,52 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } mu[minClusterIndex] = data[randVarIndex].annotations; sigma[minClusterIndex] = savedSigma; - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - sigma[minClusterIndex][jjj] += -0.06 + 0.12 * rand.nextDouble(); - if( sigma[minClusterIndex][jjj] < MIN_SUM_PROB ) { - sigma[minClusterIndex][jjj] = MIN_SUM_PROB; - } - } - pCluster[minClusterIndex] = 1.0 / ((double) (stopCluster-startCluster)); + //for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + // for( int ppp = 0; ppp < numAnnotations; ppp++ ) { + // sigma[minClusterIndex].set(jjj, ppp, sigma[minClusterIndex].get(jjj, ppp) - 0.06 + 0.12 * rand.nextDouble()); + // } + //} + pCluster[minClusterIndex] = 0.5 / ((double) (stopCluster-startCluster)); } } } } - + */ - // Replace small clusters with another random draw from the dataset + + // Replace extremely small clusters with another random draw from the dataset for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - if( pCluster[kkk] < 0.07 * (1.0 / ((double) (stopCluster-startCluster))) ) { // This is a very small cluster compared to all the others - pCluster[kkk] = 1.0 / ((double) (stopCluster-startCluster)); + //if(determinant[kkk] < MIN_DETERMINANT ) { + if( pCluster[kkk] < 0.0005 * (1.0 / ((double) (stopCluster-startCluster))) || + determinant[kkk] < MIN_DETERMINANT ) { // This is a very small cluster compared to all the others + logger.info("!! Found singular cluster! Initializing a new random cluster."); + pCluster[kkk] = 0.1 / ((double) (stopCluster-startCluster)); // 0.5 / + //final double[] randMu = new double[numAnnotations]; + //for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + // randMu[jjj] = -1.5 + 3.0 * rand.nextDouble(); + //} mu[kkk] = data[rand.nextInt(numVariants)].annotations; - final double[] randSigma = new double[numAnnotations]; - if( dataManager.isNormalized ) { - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - randSigma[jjj] = 0.7 + 0.4 * rand.nextDouble(); // BUGBUG: Explore a wider range of possible sigma values since we are tossing out clusters anyway? - } - } else { // BUGBUG: if not normalized then the varianceVector hasn't been calculated --> null pointer - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - randSigma[jjj] = dataManager.varianceVector[jjj] + ((1.0 + rand.nextDouble()) * 0.01 * dataManager.varianceVector[jjj]); + final double[][] randSigma = new double[numAnnotations][numAnnotations]; + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + for( int ppp = jjj; ppp < numAnnotations; ppp++ ) { + randSigma[ppp][jjj] = 0.50 + 0.5 * rand.nextDouble(); // data is normalized so this is centered at 1.0 + if(jjj != ppp) { randSigma[jjj][ppp] = 0.0; } // Sigma is a symmetric, positive-definite matrix } } - sigma[kkk] = randSigma; + Matrix tmp = new Matrix(randSigma); + tmp = tmp.times(tmp.transpose()); + sigma[kkk] = tmp; + determinant[kkk] = sigma[kkk].det(); } } - + // renormalize pCluster since things might have changed due to the previous step + sumPK = 0.0; + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + sumPK += pCluster[kkk]; + } + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + pCluster[kkk] /= sumPK; + } } } \ No newline at end of file diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizer.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizer.java index d2efbc1bc..637fe7daf 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizer.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizer.java @@ -55,8 +55,6 @@ public class VariantOptimizer extends RodWalker ///////////////////////////// @Argument(fullName="target_titv", shortName="titv", doc="The target Ti/Tv ratio towards which to optimize. (~~2.1 for whole genome experiments)", required=true) private double TARGET_TITV = 2.1; - @Argument(fullName="known_alpha_factor", shortName="kFactor", doc="Percentage of true positive rate that is due to difference between known and novel titv in a cluster as opposed to difference from target titv", required=false) - private double KNOWN_ALPHA_FACTOR = 0.0; @Argument(fullName="ignore_all_input_filters", shortName="ignoreAllFilters", doc="If specified the optimizer will use variants even if the FILTER column is marked in the VCF file", required=false) private boolean IGNORE_ALL_INPUT_FILTERS = false; @Argument(fullName="ignore_filter", shortName="ignoreFilter", doc="If specified the optimizer will use variants even if the specified filter name is marked in the input VCF file", required=false) @@ -68,11 +66,11 @@ public class VariantOptimizer extends RodWalker @Argument(fullName="clusterFile", shortName="clusterFile", doc="The output cluster file", required=true) private String CLUSTER_FILENAME = "optimizer.cluster"; @Argument(fullName="numGaussians", shortName="nG", doc="The number of Gaussians to be used in the Gaussian Mixture model", required=false) - private int NUM_GAUSSIANS = 32; + private int NUM_GAUSSIANS = 26; @Argument(fullName="numIterations", shortName="nI", doc="The number of iterations to be performed in the Gaussian Mixture model", required=false) private int NUM_ITERATIONS = 10; @Argument(fullName="minVarInCluster", shortName="minVar", doc="The minimum number of variants in a cluster to be considered a valid cluster. It can be used to prevent overfitting.", required=false) - private int MIN_VAR_IN_CLUSTER = 2000; + private int MIN_VAR_IN_CLUSTER = 1000; //@Argument(fullName="knn", shortName="knn", doc="The number of nearest neighbors to be used in the k-Nearest Neighbors model", required=false) //private int NUM_KNN = 2000; @@ -122,6 +120,7 @@ public class VariantOptimizer extends RodWalker if( !vc.isFiltered() || IGNORE_ALL_INPUT_FILTERS || (ignoreInputFilterSet != null && ignoreInputFilterSet.containsAll(vc.getFilters())) ) { if( firstVariant ) { // This is the first variant encountered so set up the list of annotations annotationKeys.addAll( vc.getAttributes().keySet() ); + annotationKeys.add("QUAL"); if( annotationKeys.contains("ID") ) { annotationKeys.remove("ID"); } // ID field is added to the vc's INFO field? if( annotationKeys.contains("DB") ) { annotationKeys.remove("DB"); } if( EXCLUDED_ANNOTATIONS != null ) { @@ -134,7 +133,7 @@ public class VariantOptimizer extends RodWalker if( !annotationKeys.contains( forcedAnnotation ) ) { annotationKeys.add( forcedAnnotation ); } } } - numAnnotations = annotationKeys.size() + 1; // +1 for variant quality ("QUAL") + numAnnotations = annotationKeys.size(); annotationValues = new double[numAnnotations]; firstVariant = false; } @@ -143,26 +142,28 @@ public class VariantOptimizer extends RodWalker for( final String key : annotationKeys ) { double value = 0.0; - try { - value = Double.parseDouble( (String)vc.getAttribute( key, "0.0" ) ); - if( Double.isInfinite(value) ) { - value = ( value > 0 ? 1.0 : -1.0 ) * INFINITE_ANNOTATION_VALUE; + if( key.equals("AB") && !vc.getAttributes().containsKey(key) ) { + value = (0.5 - 0.005) + (0.01 * Math.random()); // HomVar calls don't have an allele balance + } else if( key.equals("QUAL") ) { + value = vc.getPhredScaledQual(); + } else { + try { + value = Double.parseDouble( (String)vc.getAttribute( key, "0.0" ) ); + if( Double.isInfinite(value) ) { + value = ( value > 0 ? 1.0 : -1.0 ) * INFINITE_ANNOTATION_VALUE; + } + } catch( NumberFormatException e ) { + // do nothing, default value is 0.0 } - } catch( NumberFormatException e ) { - // do nothing, default value is 0.0 } annotationValues[iii++] = value; } - // Variant quality ("QUAL") is not in the list of annotations, but is useful so add it here. - annotationValues[iii] = vc.getPhredScaledQual(); - final VariantDatum variantDatum = new VariantDatum(); variantDatum.annotations = annotationValues; variantDatum.isTransition = vc.getSNPSubstitutionType().compareTo(BaseUtils.BaseSubstitutionType.TRANSITION) == 0; variantDatum.isKnown = !vc.getAttribute("ID").equals("."); - variantDatum.isHet = vc.getHetCount() > vc.getHomVarCount(); // BUGBUG: what to do here for multi sample calls? - + mapList.add( variantDatum ); } } @@ -192,7 +193,7 @@ public class VariantOptimizer extends RodWalker reduceSum.clear(); // Don't need this ever again, clean up some memory logger.info( "There are " + dataManager.numVariants + " variants and " + dataManager.numAnnotations + " annotations." ); - logger.info( "The annotations are: " + annotationKeys + " and QUAL." ); + logger.info( "The annotations are: " + annotationKeys ); dataManager.normalizeData(); // Each data point is now [ (x - mean) / standard deviation ] @@ -200,7 +201,7 @@ public class VariantOptimizer extends RodWalker VariantOptimizationModel theModel; switch (OPTIMIZATION_MODEL) { case GAUSSIAN_MIXTURE_MODEL: - theModel = new VariantGaussianMixtureModel( dataManager, TARGET_TITV, NUM_GAUSSIANS, NUM_ITERATIONS, MIN_VAR_IN_CLUSTER, KNOWN_ALPHA_FACTOR ); + theModel = new VariantGaussianMixtureModel( dataManager, TARGET_TITV, NUM_GAUSSIANS, NUM_ITERATIONS, MIN_VAR_IN_CLUSTER ); break; //case K_NEAREST_NEIGHBORS: // theModel = new VariantNearestNeighborsModel( dataManager, TARGET_TITV, NUM_KNN );