diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/ApplyVariantClustersWalker.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/ApplyVariantClustersWalker.java index f9b8780af..67e11bbe8 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/ApplyVariantClustersWalker.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/ApplyVariantClustersWalker.java @@ -159,12 +159,17 @@ public class ApplyVariantClustersWalker extends RodWalker vc.getHomVarCount(); // BUGBUG: what to do here for multi sample calls? - final double pTrue = theModel.evaluateVariant( vc.getAttributes(), vc.getPhredScaledQual(), variantDatum.isHet ); - final double recalQual = QualityUtils.phredScaleErrorRate( Math.max(1.0 - pTrue, 0.000000001) ); + final double pTrue = theModel.evaluateVariant( vc.getAttributes(), vc.getPhredScaledQual() ); + double recalQual = QualityUtils.phredScaleErrorRate( Math.max(1.0 - pTrue, 0.000000001) ); + if( !theModel.isUsingTiTvModel ) { + recalQual *= 30.0; + } else { + recalQual *= 3.0; + } + // BUGBUG: decide how to scale the quality score - if( variantDatum.isKnown && KNOWN_VAR_QUAL_PRIOR > 0.1 ) { + if( variantDatum.isKnown && KNOWN_VAR_QUAL_PRIOR > 0.1 ) { // only use the known prior if the value is specified (meaning not equal to zero) variantDatum.qual = 0.5 * recalQual + 0.5 * KNOWN_VAR_QUAL_PRIOR; } else { variantDatum.qual = recalQual; diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantClusteringModel.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantClusteringModel.java index 891f21443..103f29392 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantClusteringModel.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantClusteringModel.java @@ -32,6 +32,6 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantoptimizer; */ public interface VariantClusteringModel extends VariantOptimizationInterface { - public void createClusters( final VariantDatum[] data, final int startCluster, final int stopCluster ); + public void createClusters( final VariantDatum[] data, final int startCluster, final int stopCluster, final String clusterFilename ); //public void applyClusters( final VariantDatum[] data, final String outputPrefix ); } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDataManager.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDataManager.java index e67bc09fb..1eff46b9b 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDataManager.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDataManager.java @@ -25,6 +25,7 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantoptimizer; +import org.apache.log4j.Logger; import org.broadinstitute.sting.utils.collections.ExpandingArrayList; import org.broadinstitute.sting.utils.StingException; @@ -37,6 +38,9 @@ import java.io.PrintStream; */ public class VariantDataManager { + + protected final static Logger logger = Logger.getLogger(VariantDataManager.class); + public final VariantDatum[] data; public final int numVariants; public final int numAnnotations; @@ -56,7 +60,7 @@ public class VariantDataManager { meanVector = null; varianceVector = null; } else { - numAnnotations = _annotationKeys.size() + 1; // +1 for QUAL + numAnnotations = _annotationKeys.size(); if( numAnnotations <= 0 ) { throw new StingException( "There are zero annotations! (or possibly a problem with integer overflow)" ); } @@ -91,11 +95,12 @@ public class VariantDataManager { for( int jjj = 0; jjj < numAnnotations; jjj++ ) { final double theMean = mean(data, jjj); final double theSTD = standardDeviation(data, theMean, jjj); - System.out.println( (jjj == numAnnotations-1 ? "QUAL" : annotationKeys.get(jjj)) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD) ); + logger.info( annotationKeys.get(jjj) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD) ); if( theSTD < 1E-8 ) { foundZeroVarianceAnnotation = true; - System.out.println("Zero variance is a problem: standard deviation = " + theSTD); - System.out.println("User must -exclude annotations with zero variance. Annotation = " + (jjj == numAnnotations-1 ? "QUAL" : annotationKeys.get(jjj))); + logger.warn("Zero variance is a problem: standard deviation = " + theSTD + " User must -exclude annotations with zero variance. Annotation = " + (jjj == numAnnotations-1 ? "QUAL" : annotationKeys.get(jjj))); + } else if( theSTD < 1E-2 ) { + logger.warn("Warning! Tiny variance. It is strongly recommended that you -exclude " + annotationKeys.get(jjj)); } meanVector[jjj] = theMean; varianceVector[jjj] = theSTD; @@ -129,7 +134,7 @@ public class VariantDataManager { public void printClusterFileHeader( PrintStream outputFile ) { for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - outputFile.println("@!ANNOTATION," + (jjj == numAnnotations-1 ? "QUAL" : annotationKeys.get(jjj)) + "," + meanVector[jjj] + "," + varianceVector[jjj]); + outputFile.println("@!ANNOTATION," + annotationKeys.get(jjj) + "," + meanVector[jjj] + "," + varianceVector[jjj]); } } } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantGaussianMixtureModel.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantGaussianMixtureModel.java index 1ef1f1601..2b11b62d1 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantGaussianMixtureModel.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantGaussianMixtureModel.java @@ -25,10 +25,13 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantoptimizer; +import org.apache.log4j.Logger; import org.broadinstitute.sting.utils.collections.ExpandingArrayList; import org.broadinstitute.sting.utils.StingException; import org.broadinstitute.sting.utils.text.XReadLines; +import Jama.*; + import java.io.File; import java.io.FileNotFoundException; import java.io.PrintStream; @@ -42,47 +45,50 @@ import java.util.regex.Pattern; * Date: Feb 26, 2010 */ -public final class VariantGaussianMixtureModel extends VariantOptimizationModel implements VariantClusteringModel { +public final class VariantGaussianMixtureModel extends VariantOptimizationModel { + protected final static Logger logger = Logger.getLogger(VariantGaussianMixtureModel.class); + public final VariantDataManager dataManager; private final int numGaussians; private final int numIterations; private final long RANDOM_SEED = 91801305; private final Random rand = new Random( RANDOM_SEED ); - private final double MIN_PROB = 1E-30; - private final double MIN_SUM_PROB = 1E-20; + private final double MIN_PROB = 1E-7; + private final double MIN_SIGMA = 1E-5; + private final double MIN_DETERMINANT = 1E-5; private final double[][] mu; // The means for the clusters - private final double[][] sigma; // The variances for the clusters, sigma is really sigma^2 + private final Matrix[] sigma; // The variances for the clusters, sigma is really sigma^2 + private final Matrix[] sigmaInverse; + private final boolean[] deadCluster; private final double[] pCluster; - //private final boolean[] isHetCluster; - private final int[] numMaxClusterKnown; - private final int[] numMaxClusterNovel; + private final double[] determinant; private final double[] clusterTITV; private final double[] clusterTruePositiveRate; // The true positive rate implied by the cluster's Ti/Tv ratio private final int minVarInCluster; - private final double knownAlphaFactor; + public final boolean isUsingTiTvModel; - private static final double INFINITE_ANNOTATION_VALUE = 10000.0; + private static final double INFINITE_ANNOTATION_VALUE = 6000.0; private static final Pattern ANNOTATION_PATTERN = Pattern.compile("^@!ANNOTATION.*"); private static final Pattern CLUSTER_PATTERN = Pattern.compile("^@!CLUSTER.*"); - public VariantGaussianMixtureModel( final VariantDataManager _dataManager, final double _targetTITV, final int _numGaussians, final int _numIterations, final int _minVarInCluster, final double _knownAlphaFactor ) { + public VariantGaussianMixtureModel( final VariantDataManager _dataManager, final double _targetTITV, final int _numGaussians, final int _numIterations, final int _minVarInCluster ) { super( _targetTITV ); dataManager = _dataManager; - numGaussians = ( _numGaussians % 2 == 0 ? _numGaussians : _numGaussians + 1 ); + numGaussians = _numGaussians; numIterations = _numIterations; mu = new double[numGaussians][]; - sigma = new double[numGaussians][]; + sigma = new Matrix[numGaussians]; + determinant = new double[numGaussians]; + deadCluster = new boolean[numGaussians]; pCluster = new double[numGaussians]; - //isHetCluster = null; - numMaxClusterKnown = new int[numGaussians]; - numMaxClusterNovel = new int[numGaussians]; clusterTITV = new double[numGaussians]; clusterTruePositiveRate = new double[numGaussians]; minVarInCluster = _minVarInCluster; - knownAlphaFactor = _knownAlphaFactor; + sigmaInverse = null; + isUsingTiTvModel = false; // this field isn't used during VariantOptimizerWalker } public VariantGaussianMixtureModel( final double _targetTITV, final String clusterFileName, final double backOffGaussianFactor ) { @@ -105,128 +111,114 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } dataManager = new VariantDataManager( annotationLines ); + // Several of the clustering parameters aren't used the second time around in ApplyVariantClusters numIterations = 0; - numMaxClusterKnown = null; - numMaxClusterNovel = null; clusterTITV = null; + deadCluster = null; minVarInCluster = 0; - knownAlphaFactor = 0.0; - //BUGBUG: move this parsing out of the constructor + // BUGBUG: move this parsing out of the constructor numGaussians = clusterLines.size(); mu = new double[numGaussians][dataManager.numAnnotations]; - sigma = new double[numGaussians][dataManager.numAnnotations]; + double sigmaVals[][][] = new double[numGaussians][dataManager.numAnnotations][dataManager.numAnnotations]; + sigma = new Matrix[numGaussians]; + sigmaInverse = new Matrix[numGaussians]; pCluster = new double[numGaussians]; - //isHetCluster = new boolean[numGaussians]; + determinant = new double[numGaussians]; clusterTruePositiveRate = new double[numGaussians]; + boolean _isUsingTiTvModel = false; int kkk = 0; for( String line : clusterLines ) { final String[] vals = line.split(","); - //isHetCluster[kkk] = Integer.parseInt(vals[1]) == 1; - pCluster[kkk] = Double.parseDouble(vals[2]); - clusterTruePositiveRate[kkk] = Double.parseDouble(vals[6]); //BUGBUG: #define these magic index numbers, very easy to make a mistake here + pCluster[kkk] = Double.parseDouble(vals[1]); + clusterTruePositiveRate[kkk] = Double.parseDouble(vals[3]); // BUGBUG: #define these magic index numbers, very easy to make a mistake here + if( clusterTruePositiveRate[kkk] != 1.0 ) { _isUsingTiTvModel = true; } for( int jjj = 0; jjj < dataManager.numAnnotations; jjj++ ) { - mu[kkk][jjj] = Double.parseDouble(vals[7+jjj]); - sigma[kkk][jjj] = Double.parseDouble(vals[7+dataManager.numAnnotations+jjj]) * backOffGaussianFactor; //BUGBUG: *3, suggestion by Nick to prevent GMM from over fitting and producing low likelihoods for most points + mu[kkk][jjj] = Double.parseDouble(vals[4+jjj]); + for( int ppp = 0; ppp < dataManager.numAnnotations; ppp++ ) { + sigmaVals[kkk][jjj][ppp] = Double.parseDouble(vals[4+dataManager.numAnnotations+(jjj*dataManager.numAnnotations)+ppp]) * backOffGaussianFactor; // BUGBUG: *3, suggestion by Nick to prevent GMM from over fitting and producing low likelihoods for most points + } } + + sigma[kkk] = new Matrix(sigmaVals[kkk]); + sigmaInverse[kkk] = sigma[kkk].inverse(); // Precompute all the inverses and determinants for use later + determinant[kkk] = sigma[kkk].det(); + //if( determinant[kkk] < MIN_DETERMINANT ) { determinant[kkk] = MIN_DETERMINANT; } kkk++; } + isUsingTiTvModel = _isUsingTiTvModel; - System.out.println("Found " + numGaussians + " clusters and using " + dataManager.numAnnotations + " annotations: " + dataManager.annotationKeys); + logger.info("Found " + numGaussians + " clusters and using " + dataManager.numAnnotations + " annotations: " + dataManager.annotationKeys); } public final void run( final String clusterFileName ) { - final int MAX_VARS = 1000000; //BUGBUG: make this a command line argument + final int MAX_KNOWN_VARS = 5000000; // BUGBUG: make this a command line argument + final int MAX_NOVEL_VARS = 5000000; // BUGBUG: make this a command line argument + final double knownNovelMixture = 1.5; // BUGBUG: make this a command line argument // Create the subset of the data to cluster with int numNovel = 0; int numKnown = 0; - int numHet = 0; - int numHom = 0; for( final VariantDatum datum : dataManager.data ) { if( datum.isKnown ) { numKnown++; } else { numNovel++; } - if( datum.isHet ) { - numHet++; - } else { - numHom++; - } } - // This block of code is used to cluster with novels + 1.5x knowns mixed together + final int numNovelCluster = Math.min( numNovel, MAX_NOVEL_VARS ); + final int numKnownCluster = Math.min( numKnown, MAX_KNOWN_VARS ); + final int numKnownTogether = Math.min( numKnownCluster, (int) Math.floor(knownNovelMixture * numNovelCluster) ); - VariantDatum[] data; + final VariantDatum[] dataTogether = new VariantDatum[numNovelCluster + numKnownTogether]; + final VariantDatum[] dataKnown = new VariantDatum[numKnownCluster]; - // Grab a set of data that is all of the novel variants plus 1.5x as many known variants drawn at random - // If there are almost as many novels as known, simply use all the variants - // BUGBUG: allow downsampling and arbitrary mixtures of knowns and novels - final int numSubset = (int)Math.floor(numNovel*2.5); - if( numSubset * 1.3 < dataManager.numVariants ) { - data = new VariantDatum[numSubset]; - int iii = 0; + // Create the dataTogether array, which is all the novels and 1.5x as many knowns, downsampled if there are too many + int iii = 0; + if( numNovelCluster == numNovel ) { for( final VariantDatum datum : dataManager.data ) { if( !datum.isKnown ) { - data[iii++] = datum; + dataTogether[iii++] = datum; } } - while( iii < numSubset ) { // grab an equal number of known variants at random + } else { + logger.info("Capped at " + MAX_NOVEL_VARS + " novel variants."); + while( iii < numNovelCluster ) { + final VariantDatum datum = dataManager.data[rand.nextInt(dataManager.numVariants)]; + if( !datum.isKnown ) { + dataTogether[iii++] = datum; + } + } + } + if( numKnownTogether == numKnown ) { + for( final VariantDatum datum : dataManager.data ) { + if( datum.isKnown ) { + dataTogether[iii++] = datum; + } + } + } else { + while( iii < numNovelCluster + numKnownTogether ) { final VariantDatum datum = dataManager.data[rand.nextInt(dataManager.numVariants)]; if( datum.isKnown ) { - data[iii++] = datum; - } - } - } else { - data = dataManager.data; - } - - System.out.println("Clustering with " + numNovel + " novel variants and " + (data.length - numNovel) + " known variants..."); - if( data.length == dataManager.numVariants ) { System.out.println(" (used all variants since 2.5*numNovel is so large compared to the full set) "); } - createClusters( data, 0, numGaussians ); // Using a subset of the data - System.out.println("Outputting cluster parameters..."); - printClusters( clusterFileName ); - - - - - - // This block of code is to cluster knowns and novels separately - /* - final VariantDatum[] dataNovel = new VariantDatum[Math.min(numNovel,MAX_VARS)]; - final VariantDatum[] dataKnown = new VariantDatum[Math.min(numKnown,MAX_VARS)]; - - //BUGBUG: This is ugly - int jjj = 0; - if(numNovel <= MAX_VARS) { - for( final VariantDatum datum : dataManager.data ) { - if( !datum.isKnown ) { - dataNovel[jjj++] = datum; - } - } - } else { - System.out.println("Capped at " + MAX_VARS + " novel variants."); - while( jjj < MAX_VARS ) { - final VariantDatum datum = dataManager.data[rand.nextInt(dataManager.numVariants)]; - if( !datum.isKnown ) { - dataNovel[jjj++] = datum; + dataTogether[iii++] = datum; } } } - int iii = 0; - if(numKnown <= MAX_VARS) { + // Create the dataKnown array, which is simply all the known vars or downsampled if there are too many + iii = 0; + if( numKnownCluster == numKnown ) { for( final VariantDatum datum : dataManager.data ) { if( datum.isKnown ) { dataKnown[iii++] = datum; } } } else { - System.out.println("Capped at " + MAX_VARS + " known variants."); - while( iii < MAX_VARS ) { + logger.info("Capped at " + MAX_KNOWN_VARS + " known variants."); + while( iii < numKnownCluster ) { final VariantDatum datum = dataManager.data[rand.nextInt(dataManager.numVariants)]; if( datum.isKnown ) { dataKnown[iii++] = datum; @@ -234,73 +226,95 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } } + final boolean useTITV = true; + logger.info("First, cluster with novels and knowns together to use ti/tv based models:"); + logger.info("Clustering with " + numNovelCluster + " novel variants and " + numKnownTogether + " known variants."); + createClusters( dataTogether, 0, numGaussians, clusterFileName, useTITV ); - System.out.println("Clustering with " + Math.min(numNovel,MAX_VARS) + " novel variants."); - createClusters( dataNovel, 0, numGaussians / 2 ); - System.out.println("Clustering with " + Math.min(numKnown,MAX_VARS) + " known variants."); - createClusters( dataKnown, numGaussians / 2, numGaussians ); - System.out.println("Outputting cluster parameters..."); - printClusters( clusterFileName ); - - */ - - /* - // This block of code is to cluster het and hom calls separately, but mixing together knowns and novels - final VariantDatum[] dataHet = new VariantDatum[Math.min(numHet,MAX_VARS)]; - final VariantDatum[] dataHom = new VariantDatum[Math.min(numHom,MAX_VARS)]; - - //BUGBUG: This is ugly - int jjj = 0; - if(numHet <= MAX_VARS) { - for( final VariantDatum datum : dataManager.data ) { - if( datum.isHet ) { - dataHet[jjj++] = datum; - } - } - } else { - System.out.println("Found " + numHet + " het variants but capped at clustering with " + MAX_VARS + "."); - while( jjj < MAX_VARS ) { - final VariantDatum datum = dataManager.data[rand.nextInt(dataManager.numVariants)]; - if( datum.isHet ) { - dataHet[jjj++] = datum; - } - } - } - - int iii = 0; - if(numHom <= MAX_VARS) { - for( final VariantDatum datum : dataManager.data ) { - if( !datum.isHet ) { - dataHom[iii++] = datum; - } - } - } else { - System.out.println("Found " + numHom + " hom variants but capped at clustering with " + MAX_VARS + "."); - while( iii < MAX_VARS ) { - final VariantDatum datum = dataManager.data[rand.nextInt(dataManager.numVariants)]; - if( !datum.isHet ) { - dataHom[iii++] = datum; - } - } - } - - System.out.println("Clustering with " + Math.min(numHet,MAX_VARS) + " het variants."); - createClusters( dataHet, 0, numGaussians / 2 ); - System.out.println("Clustering with " + Math.min(numHom,MAX_VARS) + " hom variants."); - createClusters( dataHom, numGaussians / 2, numGaussians ); - System.out.println("Outputting cluster parameters..."); - printClusters( clusterFileName ); -*/ + logger.info("Finally, cluster with only knowns to use ti/tv-less models:"); + logger.info("Clustering with " + numKnownCluster + " known variants."); + createClusters( dataKnown, 0, numGaussians, clusterFileName, !useTITV ); } - -/* - public final void createClusters( final VariantDatum[] data, int startCluster, int stopCluster ) { + public final void createClusters( final VariantDatum[] data, final int startCluster, final int stopCluster, final String clusterFileName, final boolean useTITV ) { final int numVariants = data.length; final int numAnnotations = data[0].annotations.length; - final double[][] pVarInCluster = new double[numGaussians][numVariants]; + final double[][] pVarInCluster = new double[numGaussians][numVariants]; // Probability that the variant is in that cluster = simply evaluate the multivariate Gaussian + + // loop control variables: + // iii - loop over data points + // jjj - loop over annotations (features) + // ppp - loop over annotations again (full rank covariance matrix) + // kkk - loop over clusters + // ttt - loop over EM iterations + + // Set up the initial random Gaussians + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + deadCluster[kkk] = false; + pCluster[kkk] = 1.0 / ((double) (stopCluster - startCluster)); + //final double[] randMu = new double[numAnnotations]; + //for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + // randMu[jjj] = -1.5 + 3.0 * rand.nextDouble(); + //} + mu[kkk] = data[rand.nextInt(numVariants)].annotations; + final double[][] randSigma = new double[numAnnotations][numAnnotations]; + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + for( int ppp = jjj; ppp < numAnnotations; ppp++ ) { + randSigma[ppp][jjj] = 0.5 + 0.5 * rand.nextDouble(); // data has been normalized so sigmas are centered at 1.0 + if(jjj != ppp) { randSigma[jjj][ppp] = 0.0; } // Sigma is a symmetric, positive-definite matrix + } + } + Matrix tmp = new Matrix(randSigma); + tmp = tmp.times(tmp.transpose()); + sigma[kkk] = tmp; + determinant[kkk] = sigma[kkk].det(); + //if( determinant[kkk] < MIN_DETERMINANT ) { deadCluster[kkk] = true; } + } + + // The EM loop + for( int ttt = 0; ttt < numIterations; ttt++ ) { + + //int numValidClusters = 0; + //for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + // if( !deadCluster[kkk] ) { numValidClusters++; } + //} + //logger.info("Starting iteration " + (ttt+1) + " with " + numValidClusters + " clusters."); + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Expectation Step (calculate the probability that each data point is in each cluster) + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + evaluateGaussians( data, pVarInCluster, startCluster, stopCluster ); + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Maximization Step (move the clusters to maximize the sum probability of each data point) + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + maximizeGaussians( data, pVarInCluster, startCluster, stopCluster ); + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Estimate each cluster's p(true) and output cluster parameters + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + outputGaussians( data, pVarInCluster, ttt+1, startCluster, stopCluster, clusterFileName, useTITV ); + + logger.info("Finished iteration " + (ttt+1) ); + } + } + + private void outputGaussians( final VariantDatum[] data, final double[][] pVarInCluster, final int iterationNumber, + final int startCluster, final int stopCluster, final String clusterFileName, final boolean useTITV ) { + + if( !useTITV ) { + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + clusterTITV[kkk] = 0.0; + clusterTruePositiveRate[kkk] = 1.0; + } + printClusterParamters( clusterFileName + ".WithoutTiTv." + iterationNumber ); + return; + } + + final int numVariants = data.length; + final double[] probTi = new double[numGaussians]; final double[] probTv = new double[numGaussians]; final double[] probKnown = new double[numGaussians]; @@ -310,229 +324,102 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel final double[] probNovelTi = new double[numGaussians]; final double[] probNovelTv = new double[numGaussians]; - // loop control variables: - // iii - loop over data points - // jjj - loop over annotations (features) - // kkk - loop over clusters - // ttt - loop over EM iterations - - // Set up the initial random Gaussians - for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - numMaxClusterKnown[kkk] = 0; - numMaxClusterNovel[kkk] = 0; - pCluster[kkk] = 1.0 / ((double) (stopCluster - startCluster)); - mu[kkk] = data[rand.nextInt(numVariants)].annotations; - final double[] randSigma = new double[numAnnotations]; - if( dataManager.isNormalized ) { - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - randSigma[jjj] = 0.7 + 0.4 * rand.nextDouble(); - } - } else { // BUGBUG: if not normalized then the varianceVector hasn't been calculated --> null pointer - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - randSigma[jjj] = dataManager.varianceVector[jjj] + ((1.0 + rand.nextDouble()) * 0.01 * dataManager.varianceVector[jjj]); - } - } - sigma[kkk] = randSigma; - } - - for( int ttt = 0; ttt < numIterations; ttt++ ) { - - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - // Expectation Step (calculate the probability that each data point is in each cluster) - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - evaluateGaussians( data, pVarInCluster, startCluster, stopCluster ); - - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - // Maximization Step (move the clusters to maximize the sum probability of each data point) - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - maximizeGaussians( data, pVarInCluster, startCluster, stopCluster ); - - System.out.println("Finished iteration " + (ttt+1) ); - } - - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - // Evaluate the clusters using titv as an estimate of the true positive rate - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - evaluateGaussians( data, pVarInCluster, startCluster, stopCluster ); // One final evaluation because the Gaussians moved in the last maximization step - for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { probTi[kkk] = 0.0; probTv[kkk] = 0.0; probKnown[kkk] = 0.0; probNovel[kkk] = 0.0; + probKnownTi[kkk] = 0.0; + probKnownTv[kkk] = 0.0; + probNovelTi[kkk] = 0.0; + probNovelTv[kkk] = 0.0; } + + // Use the cluster's probabilistic Ti/Tv ratio as the indication of the cluster's true positive rate for( int iii = 0; iii < numVariants; iii++ ) { final boolean isTransition = data[iii].isTransition; final boolean isKnown = data[iii].isKnown; for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - final double prob = pVarInCluster[kkk][iii]; - if( isKnown ) { // known - probKnown[kkk] += prob; - if( isTransition ) { // transition - probKnownTi[kkk] += prob; - probTi[kkk] += prob; - } else { // transversion - probKnownTv[kkk] += prob; - probTv[kkk] += prob; - } - } else { //novel - probNovel[kkk] += prob; - if( isTransition ) { // transition - probNovelTi[kkk] += prob; - probTi[kkk] += prob; - } else { // transversion - probNovelTv[kkk] += prob; - probTv[kkk] += prob; + if( !deadCluster[kkk] ) { + final double prob = pVarInCluster[kkk][iii]; + if( isKnown ) { // known + probKnown[kkk] += prob; + if( isTransition ) { // transition + probKnownTi[kkk] += prob; + probTi[kkk] += prob; + } else { // transversion + probKnownTv[kkk] += prob; + probTv[kkk] += prob; + } + } else { //novel + probNovel[kkk] += prob; + if( isTransition ) { // transition + probNovelTi[kkk] += prob; + probTi[kkk] += prob; + } else { // transversion + probNovelTv[kkk] += prob; + probTv[kkk] += prob; + } } } } - - double maxProb = pVarInCluster[startCluster][iii]; - int maxCluster = startCluster; - for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - if( pVarInCluster[kkk][iii] > maxProb ) { - maxProb = pVarInCluster[kkk][iii]; - maxCluster = kkk; - } - } - if( isKnown ) { - numMaxClusterKnown[maxCluster]++; - } else { - numMaxClusterNovel[maxCluster]++; - } } - for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - clusterTITV[kkk] = probTi[kkk] / probTv[kkk]; - if( probKnown[kkk] > 600.0 ) { // BUGBUG: make this a command line argument, parameterize performance based on this important argument - clusterTruePositiveRate[kkk] = calcTruePositiveRateFromKnownTITV( probKnownTi[kkk] / probKnownTv[kkk], probNovelTi[kkk] / probNovelTv[kkk], clusterTITV[kkk], knownAlphaFactor ); - } else { - clusterTruePositiveRate[kkk] = calcTruePositiveRateFromTITV( clusterTITV[kkk] ); + for( int ttt = 0; ttt < 3; ttt++ ) { + double knownAlphaFactor = 0.0; + if( ttt == 0 ) { + knownAlphaFactor = 0.0; + } else if( ttt == 1 ) { + knownAlphaFactor = 1.0; + } else if( ttt == 2 ) { + knownAlphaFactor = 0.5; + } + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + if( !deadCluster[kkk] ) { + clusterTITV[kkk] = probTi[kkk] / probTv[kkk]; + if( probKnown[kkk] > 500.0 && probNovel[kkk] > 500.0 ) { + clusterTruePositiveRate[kkk] = calcTruePositiveRateFromKnownTITV( probKnownTi[kkk] / probKnownTv[kkk], probNovelTi[kkk] / probNovelTv[kkk], clusterTITV[kkk], knownAlphaFactor ); + } else { + clusterTruePositiveRate[kkk] = calcTruePositiveRateFromTITV( clusterTITV[kkk] ); + } + } + } + + if( ttt == 0 ) { + printClusterParamters( clusterFileName + ".TargetTiTv." + iterationNumber ); + } else if( ttt == 1 ) { + printClusterParamters( clusterFileName + ".KnownTiTv." + iterationNumber ); + } else if( ttt == 2 ) { + printClusterParamters( clusterFileName + ".BlendedTiTv." + iterationNumber ); } } } -*/ - - // This cluster method doesn't make use of the differences between known and novel Ti/Tv ratios - - - public final void createClusters( final VariantDatum[] data, final int startCluster, final int stopCluster ) { - - final int numVariants = data.length; - final int numAnnotations = data[0].annotations.length; - - final double[][] pVarInCluster = new double[numGaussians][numVariants]; // Probability that the variant is in that cluster = simply evaluate the multivariate Gaussian - final double[] probTi = new double[numGaussians]; - final double[] probTv = new double[numGaussians]; - - // loop control variables: - // iii - loop over data points - // jjj - loop over annotations (features) - // kkk - loop over clusters - // ttt - loop over EM iterations - - // Set up the initial random Gaussians - for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - numMaxClusterKnown[kkk] = 0; - numMaxClusterNovel[kkk] = 0; - pCluster[kkk] = 1.0 / ((double) (stopCluster - startCluster)); - mu[kkk] = data[rand.nextInt(numVariants)].annotations; - final double[] randSigma = new double[numAnnotations]; - if( dataManager.isNormalized ) { - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - randSigma[jjj] = 0.7 + 0.4 * rand.nextDouble(); - } - } else { // BUGBUG: if not normalized then the varianceVector hasn't been calculated --> null pointer - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - randSigma[jjj] = dataManager.varianceVector[jjj] + ((1.0 + rand.nextDouble()) * 0.01 * dataManager.varianceVector[jjj]); - } - } - sigma[kkk] = randSigma; - } - - // The EM loop - for( int ttt = 0; ttt < numIterations; ttt++ ) { - - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - // Expectation Step (calculate the probability that each data point is in each cluster) - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - evaluateGaussians( data, pVarInCluster, startCluster, stopCluster ); - - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - // Maximization Step (move the clusters to maximize the sum probability of each data point) - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - maximizeGaussians( data, pVarInCluster, startCluster, stopCluster ); - - System.out.println("Finished iteration " + (ttt+1) ); - } - - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - // Evaluate the clusters using titv as an estimate of the true positive rate - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - evaluateGaussians( data, pVarInCluster, startCluster, stopCluster ); // One final evaluation because the Gaussians moved in the last maximization step - - for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - probTi[kkk] = 0.0; - probTv[kkk] = 0.0; - } - // Use the cluster's probabilistic Ti/Tv ratio as the indication of the cluster's true positive rate - for( int iii = 0; iii < numVariants; iii++ ) { - if( data[iii].isTransition ) { // transition - for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - probTi[kkk] += pVarInCluster[kkk][iii]; - } - } else { // transversion - for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - probTv[kkk] += pVarInCluster[kkk][iii]; - } - } - - // Calculate which cluster has the maximum probability for this variant for use as a metric of how well clustered the data is - double maxProb = pVarInCluster[startCluster][iii]; - int maxCluster = startCluster; - for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - if( pVarInCluster[kkk][iii] > maxProb ) { - maxProb = pVarInCluster[kkk][iii]; - maxCluster = kkk; - } - } - if( data[iii].isKnown ) { - numMaxClusterKnown[maxCluster]++; - } else { - numMaxClusterNovel[maxCluster]++; - } - } - for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - clusterTITV[kkk] = probTi[kkk] / probTv[kkk]; - clusterTruePositiveRate[kkk] = calcTruePositiveRateFromTITV( clusterTITV[kkk] ); - } - } - - - - private void printClusters( final String clusterFileName ) { + private void printClusterParamters( final String clusterFileName ) { try { final PrintStream outputFile = new PrintStream( clusterFileName ); dataManager.printClusterFileHeader( outputFile ); final int numAnnotations = mu[0].length; + final int numVariants = dataManager.numVariants; for( int kkk = 0; kkk < numGaussians; kkk++ ) { - if( numMaxClusterKnown[kkk] + numMaxClusterNovel[kkk] >= minVarInCluster ) { - outputFile.print("@!CLUSTER,"); - outputFile.print( (kkk < numGaussians / 2 ? 1 : 0) + "," ); // is het cluster? - outputFile.print(pCluster[kkk] + ","); - outputFile.print(numMaxClusterKnown[kkk] + ","); - outputFile.print(numMaxClusterNovel[kkk] + ","); - outputFile.print(clusterTITV[kkk] + ","); - outputFile.print(clusterTruePositiveRate[kkk] + ","); - for(int jjj = 0; jjj < numAnnotations; jjj++ ) { - outputFile.print(mu[kkk][jjj] + ","); + if( !deadCluster[kkk] ) { + if( pCluster[kkk] * numVariants > minVarInCluster ) { + final double sigmaVals[][] = sigma[kkk].getArray(); + outputFile.print("@!CLUSTER,"); + outputFile.print(pCluster[kkk] + ","); + outputFile.print(clusterTITV[kkk] + ","); + outputFile.print(clusterTruePositiveRate[kkk] + ","); + for(int jjj = 0; jjj < numAnnotations; jjj++ ) { + outputFile.print(mu[kkk][jjj] + ","); + } + for(int jjj = 0; jjj < numAnnotations; jjj++ ) { + for(int ppp = 0; ppp < numAnnotations; ppp++ ) { + outputFile.print(sigmaVals[jjj][ppp] + ","); + } + } + outputFile.println(-1); } - for(int jjj = 0; jjj < numAnnotations; jjj++ ) { - outputFile.print(sigma[kkk][jjj] + ","); - } - outputFile.println(-1); } } outputFile.close(); @@ -541,7 +428,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } } - public final double evaluateVariant( final Map annotationMap, final double qualityScore, final boolean isHet ) { + public final double evaluateVariant( final Map annotationMap, final double qualityScore ) { final double[] pVarInCluster = new double[numGaussians]; final double[] annotations = new double[dataManager.numAnnotations]; @@ -550,6 +437,8 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel final String annotationKey = dataManager.annotationKeys.get(jjj); if( annotationKey.equals("QUAL") ) { value = qualityScore; + } else if( annotationKey.equals("AB") && !annotationMap.containsKey(annotationKey) ) { + value = (0.5 - 0.005) + (0.01 * Math.random()); // HomVar calls don't have an allele balance } else { try { final Object stringValue = annotationMap.get( annotationKey ); @@ -567,16 +456,54 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel annotations[jjj] = (value - dataManager.meanVector[jjj]) / dataManager.varianceVector[jjj]; } - evaluateGaussiansForSingleVariant( annotations, pVarInCluster, isHet ); + evaluateGaussiansForSingleVariant( annotations, pVarInCluster ); + //if( isUsingTiTvModel ) { + // Sum prob model + double sum = 0.0; + for( int kkk = 0; kkk < numGaussians; kkk++ ) { + sum += pVarInCluster[kkk] * clusterTruePositiveRate[kkk]; + } + return sum; + /* + } else { + // Max prob model + double maxProb = 0.0; + for( int kkk = 0; kkk < numGaussians; kkk++ ) { + if( pVarInCluster[kkk] > maxProb ) { + maxProb = pVarInCluster[kkk]; + } + } + return maxProb; + } + */ + + // Max prob model + /* + double maxProb = 0.0; + for( int kkk = 0; kkk < numGaussians; kkk++ ) { + if( pVarInCluster[kkk] > maxProb ) { + maxProb = pVarInCluster[kkk]; + } + } + return maxProb; + */ + + // Entropy model + /* double sum = 0.0; for( int kkk = 0; kkk < numGaussians; kkk++ ) { //if( isHetCluster[kkk] == isHet ) { - sum += pVarInCluster[kkk] * clusterTruePositiveRate[kkk]; + sum += pVarInCluster[kkk] * Math.log(pVarInCluster[kkk]); //} } - return sum; + double entropy = -1.0 * sum; + double maxEntropy = -1.0 * Math.log( 1.0 / ((double) numGaussians)); + + //System.out.println("H = " + entropy + ", pTrue = " + ( 1.0 - (entropy / maxEntropy) )); + return ( 1.0 - (entropy / maxEntropy) ); + */ } public final void outputOptimizationCurve( final VariantDatum[] data, final String outputPrefix, final int desiredNumVariants ) { @@ -614,7 +541,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel boolean foundDesiredNumVariants = false; int jjj = 0; outputFile.println("pCut,numKnown,numNovel,knownTITV,novelTITV"); - for( double qCut = MAX_QUAL; qCut >= 0.0; qCut -= QUAL_STEP ) { + for( double qCut = MAX_QUAL; qCut >= -0.001; qCut -= QUAL_STEP ) { for( int iii = 0; iii < numVariants; iii++ ) { if( !markedVariant[iii] ) { if( data[iii].qual >= qCut ) { @@ -638,12 +565,11 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } } if( desiredNumVariants != 0 && !foundDesiredNumVariants && (numKnown + numNovel) >= desiredNumVariants ) { - System.out.println( "Keeping variants with QUAL >= " + String.format("%.1f",qCut) + " results in a filtered set with: " ); - System.out.println("\t" + numKnown + " known variants"); - System.out.println("\t" + numNovel + " novel variants, (dbSNP rate = " + String.format("%.2f",((double) numKnown * 100.0) / ((double) numKnown + numNovel) ) + "%)"); - System.out.println("\t" + String.format("%.4f known Ti/Tv ratio", ((double)numKnownTi) / ((double)numKnownTv))); - System.out.println("\t" + String.format("%.4f novel Ti/Tv ratio", ((double)numNovelTi) / ((double)numNovelTv))); - System.out.println(); + logger.info( "Keeping variants with QUAL >= " + String.format("%.1f",qCut) + " results in a filtered set with: " ); + logger.info("\t" + numKnown + " known variants"); + logger.info("\t" + numNovel + " novel variants, (dbSNP rate = " + String.format("%.2f",((double) numKnown * 100.0) / ((double) numKnown + numNovel) ) + "%)"); + logger.info("\t" + String.format("%.4f known Ti/Tv ratio", ((double)numKnownTi) / ((double)numKnownTv))); + logger.info("\t" + String.format("%.4f novel Ti/Tv ratio", ((double)numNovelTi) / ((double)numNovelTv))); foundDesiredNumVariants = true; } outputFile.println( qCut + "," + numKnown + "," + numNovel + "," + @@ -687,13 +613,12 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } if( foundCut ) { - System.out.println( "Keeping variants with QUAL >= " + String.format("%.1f",theCut[jjj]) + " results in a filtered set with: " ); - System.out.println("\t" + numKnownAtCut[jjj] + " known variants"); - System.out.println("\t" + numNovelAtCut[jjj] + " novel variants, (dbSNP rate = " + + logger.info( "Keeping variants with QUAL >= " + String.format("%.1f",theCut[jjj]) + " results in a filtered set with: " ); + logger.info("\t" + numKnownAtCut[jjj] + " known variants"); + logger.info("\t" + numNovelAtCut[jjj] + " novel variants, (dbSNP rate = " + String.format("%.2f",((double) numKnownAtCut[jjj] * 100.0) / ((double) numKnownAtCut[jjj] + numNovelAtCut[jjj]) ) + "%)"); - System.out.println("\t" + String.format("%.4f known Ti/Tv ratio", knownTiTvAtCut[jjj])); - System.out.println("\t" + String.format("%.4f novel Ti/Tv ratio", novelTiTvAtCut[jjj])); - System.out.println(); + logger.info("\t" + String.format("%.4f known Ti/Tv ratio", knownTiTvAtCut[jjj])); + logger.info("\t" + String.format("%.4f novel Ti/Tv ratio", novelTiTvAtCut[jjj])); } } @@ -704,63 +629,115 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel private void evaluateGaussians( final VariantDatum[] data, final double[][] pVarInCluster, final int startCluster, final int stopCluster ) { final int numAnnotations = data[0].annotations.length; - + double likelihood = 0.0; + final double sigmaVals[][][] = new double[numGaussians][][]; + final double denom[] = new double[numGaussians]; + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + if( !deadCluster[kkk] ) { + sigmaVals[kkk] = sigma[kkk].inverse().getArray(); + denom[kkk] = Math.pow(2.0 * 3.14159, ((double)numAnnotations) / 2.0) * Math.pow(Math.abs(determinant[kkk]), 0.5); + } + } + final double mult[] = new double[numAnnotations]; for( int iii = 0; iii < data.length; iii++ ) { double sumProb = 0.0; for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - double sum = 0.0; - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - sum += ( (data[iii].annotations[jjj] - mu[kkk][jjj]) * (data[iii].annotations[jjj] - mu[kkk][jjj]) ) - / sigma[kkk][jjj]; - } - pVarInCluster[kkk][iii] = pCluster[kkk] * Math.exp( -0.5 * sum ); - - if( pVarInCluster[kkk][iii] < MIN_PROB) { // Very small numbers are a very big problem - pVarInCluster[kkk][iii] = MIN_PROB; - } + if( !deadCluster[kkk] ) { + double sum = 0.0; + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + mult[jjj] = 0.0; + for( int ppp = 0; ppp < numAnnotations; ppp++ ) { + mult[jjj] += (data[iii].annotations[ppp] - mu[kkk][ppp]) * sigmaVals[kkk][ppp][jjj]; + } + } + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + sum += mult[jjj] * (data[iii].annotations[jjj] - mu[kkk][jjj]); + } - sumProb += pVarInCluster[kkk][iii]; + pVarInCluster[kkk][iii] = pCluster[kkk] * (Math.exp( -0.5 * sum ) / denom[kkk]); + likelihood += pVarInCluster[kkk][iii]; + if(Double.isNaN(denom[kkk]) || determinant[kkk] < 0.5 * MIN_DETERMINANT) { + System.out.println("det = " + sigma[kkk].det()); + System.out.println("denom = " + denom[kkk]); + System.out.println("sumExp = " + sum); + System.out.println("pVar = " + pVarInCluster[kkk][iii]); + System.out.println("=-------="); + throw new StingException("Numerical Instability! determinant of covariance matrix <= 0. Try running with fewer clusters and then with better behaved annotation values."); + } + if(sum < 0.0) { + System.out.println("det = " + sigma[kkk].det()); + System.out.println("denom = " + denom[kkk]); + System.out.println("sumExp = " + sum); + System.out.println("pVar = " + pVarInCluster[kkk][iii]); + System.out.println("=-------="); + throw new StingException("Numerical Instability! covariance matrix no longer positive definite. Try running with fewer clusters and then with better behaved annotation values."); + } + if(pVarInCluster[kkk][iii] > 1.0) { + System.out.println("det = " + sigma[kkk].det()); + System.out.println("denom = " + denom[kkk]); + System.out.println("sumExp = " + sum); + System.out.println("pVar = " + pVarInCluster[kkk][iii]); + System.out.println("=-------="); + throw new StingException("Numerical Instability! probability distribution returns > 1.0. Try running with fewer clusters and then with better behaved annotation values."); + } + + if( pVarInCluster[kkk][iii] < MIN_PROB) { // Very small numbers are a very big problem + pVarInCluster[kkk][iii] = MIN_PROB;// + MIN_PROB * rand.nextDouble(); + } + + sumProb += pVarInCluster[kkk][iii]; + } } - if( sumProb > MIN_SUM_PROB ) { // Very small numbers are a very big problem - for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - pVarInCluster[kkk][iii] /= sumProb; - } + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + pVarInCluster[kkk][iii] /= sumProb; } + } + + logger.info("Explained likelihood = " + String.format("%.5f",likelihood / ((double) data.length))); } - private void evaluateGaussiansForSingleVariant( final double[] annotations, final double[] pVarInCluster, final boolean isHet ) { + private void evaluateGaussiansForSingleVariant( final double[] annotations, final double[] pVarInCluster ) { final int numAnnotations = annotations.length; double sumProb = 0.0; + final double mult[] = new double[numAnnotations]; for( int kkk = 0; kkk < numGaussians; kkk++ ) { - //if( isHetCluster[kkk] == isHet ) { - double sum = 0.0; - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - sum += ( (annotations[jjj] - mu[kkk][jjj]) * (annotations[jjj] - mu[kkk][jjj]) ) - / sigma[kkk][jjj]; + final double sigmaVals[][] = sigmaInverse[kkk].getArray(); + double sum = 0.0; + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + mult[jjj] = 0.0; + for( int ppp = 0; ppp < numAnnotations; ppp++ ) { + mult[jjj] += (annotations[ppp] - mu[kkk][ppp]) * sigmaVals[ppp][jjj]; } + } + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + sum += mult[jjj] * (annotations[jjj] - mu[kkk][jjj]); + } - // BUGBUG: reverting to old version that didn't have pCluster[kkk]* here, this meant that the overfitting parameters changed meanings - //pVarInCluster[kkk] = pCluster[kkk] * Math.exp( -0.5 * sum ); - pVarInCluster[kkk] = Math.exp( -0.5 * sum ); + final double denom = Math.pow(2.0 * 3.14159, ((double)numAnnotations) / 2.0) * Math.pow(determinant[kkk], 0.5); + pVarInCluster[kkk] = (1.0 / ((double) numGaussians)) * (Math.exp( -0.5 * sum )) / denom; + if( isUsingTiTvModel ) { + //pVarInCluster[kkk] = Math.exp( -0.5 * sum ); if( pVarInCluster[kkk] < MIN_PROB) { // Very small numbers are a very big problem pVarInCluster[kkk] = MIN_PROB; } - sumProb += pVarInCluster[kkk]; - //} + } else { + //final double denom = Math.pow(2.0 * 3.14159, ((double)numAnnotations) / 2.0) * Math.pow(determinant[kkk], 0.5); + //pVarInCluster[kkk] = pCluster[kkk] * (Math.exp( -0.5 * sum )) / denom; + //pVarInCluster[kkk] = Math.exp( -0.5 * sum ); + // BUGBUG: should pCluster be the distribution from the GMM or a uniform distribution here? + } } - if( sumProb > MIN_SUM_PROB ) { // Very small numbers are a very big problem + if( isUsingTiTvModel ) { for( int kkk = 0; kkk < numGaussians; kkk++ ) { - //if( isHetCluster[kkk] == isHet ) { - pVarInCluster[kkk] /= sumProb; - //} + pVarInCluster[kkk] /= sumProb; } } } @@ -770,58 +747,79 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel final int numVariants = data.length; final int numAnnotations = data[0].annotations.length; + final double sigmaVals[][][] = new double[numGaussians][numAnnotations][numAnnotations]; for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { for( int jjj = 0; jjj < numAnnotations; jjj++ ) { mu[kkk][jjj] = 0.0; - sigma[kkk][jjj] = 0.0; - } - } - for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - double sumProb = 0.0; - for( int iii = 0; iii < numVariants; iii++ ) { - final double prob = pVarInCluster[kkk][iii]; - sumProb += prob; - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - mu[kkk][jjj] += prob * data[iii].annotations[jjj]; + for( int ppp = 0; ppp < numAnnotations; ppp++ ) { + sigmaVals[kkk][jjj][ppp] = 0.0; } } + } + double sumPK = 0.0; + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + if( !deadCluster[kkk] ) { + double sumProb = 0.0; + for( int iii = 0; iii < numVariants; iii++ ) { + final double prob = pVarInCluster[kkk][iii]; + sumProb += prob; + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + mu[kkk][jjj] += prob * data[iii].annotations[jjj]; + } + } - if( sumProb > MIN_SUM_PROB ) { for( int jjj = 0; jjj < numAnnotations; jjj++ ) { mu[kkk][jjj] /= sumProb; } - } //BUGBUG: clean up dead clusters to speed up computation - for( int iii = 0; iii < numVariants; iii++ ) { - final double prob = pVarInCluster[kkk][iii]; + for( int iii = 0; iii < numVariants; iii++ ) { + final double prob = pVarInCluster[kkk][iii]; + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + for( int ppp = jjj; ppp < numAnnotations; ppp++ ) { + sigmaVals[kkk][jjj][ppp] += prob * (data[iii].annotations[jjj]-mu[kkk][jjj]) * (data[iii].annotations[ppp]-mu[kkk][ppp]); + } + } + } + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - sigma[kkk][jjj] += prob * (data[iii].annotations[jjj]-mu[kkk][jjj]) * (data[iii].annotations[jjj]-mu[kkk][jjj]); + for( int ppp = jjj; ppp < numAnnotations; ppp++ ) { + if( sigmaVals[kkk][jjj][ppp] < MIN_SIGMA ) { // Very small numbers are a very big problem + sigmaVals[kkk][jjj][ppp] = MIN_SIGMA;// + MIN_SIGMA * rand.nextDouble(); + } + sigmaVals[kkk][ppp][jjj] = sigmaVals[kkk][jjj][ppp]; // sigma must be a symmetric matrix + } } - } - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - if( sigma[kkk][jjj] < MIN_PROB) { // Very small numbers are a very big problem - sigma[kkk][jjj] = MIN_PROB; - } - } - - if( sumProb > MIN_SUM_PROB ) { for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - sigma[kkk][jjj] /= sumProb; + for( int ppp = 0; ppp < numAnnotations; ppp++ ) { + sigmaVals[kkk][jjj][ppp] /= sumProb; + } + } + + sigma[kkk] = new Matrix(sigmaVals[kkk]); + determinant[kkk] = sigma[kkk].det(); + //if( determinant[kkk] < MIN_DETERMINANT ) { deadCluster[kkk] = true; } + + if( !deadCluster[kkk] ) { + pCluster[kkk] = sumProb / numVariants; + sumPK += pCluster[kkk]; } } - - pCluster[kkk] = sumProb / numVariants; } + // ensure pCluster sums to one, it doesn't automatically due to very small numbers getting capped + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + pCluster[kkk] /= sumPK; + } + /* // Clean up extra big or extra small clusters - //BUGBUG: Is this a good idea? for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { if( pCluster[kkk] > 0.45 ) { // This is a very large cluster compared to all the others - final int numToReplace = 4; - final double[] savedSigma = sigma[kkk]; + System.out.println("!! Found very large cluster! Busting it up into smaller clusters."); + final int numToReplace = 3; + final Matrix savedSigma = sigma[kkk]; for( int rrr = 0; rrr < numToReplace; rrr++ ) { // Find an example variant in the large cluster, drawn randomly int randVarIndex = -1; @@ -854,38 +852,52 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } mu[minClusterIndex] = data[randVarIndex].annotations; sigma[minClusterIndex] = savedSigma; - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - sigma[minClusterIndex][jjj] += -0.06 + 0.12 * rand.nextDouble(); - if( sigma[minClusterIndex][jjj] < MIN_SUM_PROB ) { - sigma[minClusterIndex][jjj] = MIN_SUM_PROB; - } - } - pCluster[minClusterIndex] = 1.0 / ((double) (stopCluster-startCluster)); + //for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + // for( int ppp = 0; ppp < numAnnotations; ppp++ ) { + // sigma[minClusterIndex].set(jjj, ppp, sigma[minClusterIndex].get(jjj, ppp) - 0.06 + 0.12 * rand.nextDouble()); + // } + //} + pCluster[minClusterIndex] = 0.5 / ((double) (stopCluster-startCluster)); } } } } - + */ - // Replace small clusters with another random draw from the dataset + + // Replace extremely small clusters with another random draw from the dataset for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - if( pCluster[kkk] < 0.07 * (1.0 / ((double) (stopCluster-startCluster))) ) { // This is a very small cluster compared to all the others - pCluster[kkk] = 1.0 / ((double) (stopCluster-startCluster)); + //if(determinant[kkk] < MIN_DETERMINANT ) { + if( pCluster[kkk] < 0.0005 * (1.0 / ((double) (stopCluster-startCluster))) || + determinant[kkk] < MIN_DETERMINANT ) { // This is a very small cluster compared to all the others + logger.info("!! Found singular cluster! Initializing a new random cluster."); + pCluster[kkk] = 0.1 / ((double) (stopCluster-startCluster)); // 0.5 / + //final double[] randMu = new double[numAnnotations]; + //for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + // randMu[jjj] = -1.5 + 3.0 * rand.nextDouble(); + //} mu[kkk] = data[rand.nextInt(numVariants)].annotations; - final double[] randSigma = new double[numAnnotations]; - if( dataManager.isNormalized ) { - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - randSigma[jjj] = 0.7 + 0.4 * rand.nextDouble(); // BUGBUG: Explore a wider range of possible sigma values since we are tossing out clusters anyway? - } - } else { // BUGBUG: if not normalized then the varianceVector hasn't been calculated --> null pointer - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - randSigma[jjj] = dataManager.varianceVector[jjj] + ((1.0 + rand.nextDouble()) * 0.01 * dataManager.varianceVector[jjj]); + final double[][] randSigma = new double[numAnnotations][numAnnotations]; + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + for( int ppp = jjj; ppp < numAnnotations; ppp++ ) { + randSigma[ppp][jjj] = 0.50 + 0.5 * rand.nextDouble(); // data is normalized so this is centered at 1.0 + if(jjj != ppp) { randSigma[jjj][ppp] = 0.0; } // Sigma is a symmetric, positive-definite matrix } } - sigma[kkk] = randSigma; + Matrix tmp = new Matrix(randSigma); + tmp = tmp.times(tmp.transpose()); + sigma[kkk] = tmp; + determinant[kkk] = sigma[kkk].det(); } } - + // renormalize pCluster since things might have changed due to the previous step + sumPK = 0.0; + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + sumPK += pCluster[kkk]; + } + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + pCluster[kkk] /= sumPK; + } } } \ No newline at end of file diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizer.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizer.java index d2efbc1bc..637fe7daf 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizer.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizer.java @@ -55,8 +55,6 @@ public class VariantOptimizer extends RodWalker ///////////////////////////// @Argument(fullName="target_titv", shortName="titv", doc="The target Ti/Tv ratio towards which to optimize. (~~2.1 for whole genome experiments)", required=true) private double TARGET_TITV = 2.1; - @Argument(fullName="known_alpha_factor", shortName="kFactor", doc="Percentage of true positive rate that is due to difference between known and novel titv in a cluster as opposed to difference from target titv", required=false) - private double KNOWN_ALPHA_FACTOR = 0.0; @Argument(fullName="ignore_all_input_filters", shortName="ignoreAllFilters", doc="If specified the optimizer will use variants even if the FILTER column is marked in the VCF file", required=false) private boolean IGNORE_ALL_INPUT_FILTERS = false; @Argument(fullName="ignore_filter", shortName="ignoreFilter", doc="If specified the optimizer will use variants even if the specified filter name is marked in the input VCF file", required=false) @@ -68,11 +66,11 @@ public class VariantOptimizer extends RodWalker @Argument(fullName="clusterFile", shortName="clusterFile", doc="The output cluster file", required=true) private String CLUSTER_FILENAME = "optimizer.cluster"; @Argument(fullName="numGaussians", shortName="nG", doc="The number of Gaussians to be used in the Gaussian Mixture model", required=false) - private int NUM_GAUSSIANS = 32; + private int NUM_GAUSSIANS = 26; @Argument(fullName="numIterations", shortName="nI", doc="The number of iterations to be performed in the Gaussian Mixture model", required=false) private int NUM_ITERATIONS = 10; @Argument(fullName="minVarInCluster", shortName="minVar", doc="The minimum number of variants in a cluster to be considered a valid cluster. It can be used to prevent overfitting.", required=false) - private int MIN_VAR_IN_CLUSTER = 2000; + private int MIN_VAR_IN_CLUSTER = 1000; //@Argument(fullName="knn", shortName="knn", doc="The number of nearest neighbors to be used in the k-Nearest Neighbors model", required=false) //private int NUM_KNN = 2000; @@ -122,6 +120,7 @@ public class VariantOptimizer extends RodWalker if( !vc.isFiltered() || IGNORE_ALL_INPUT_FILTERS || (ignoreInputFilterSet != null && ignoreInputFilterSet.containsAll(vc.getFilters())) ) { if( firstVariant ) { // This is the first variant encountered so set up the list of annotations annotationKeys.addAll( vc.getAttributes().keySet() ); + annotationKeys.add("QUAL"); if( annotationKeys.contains("ID") ) { annotationKeys.remove("ID"); } // ID field is added to the vc's INFO field? if( annotationKeys.contains("DB") ) { annotationKeys.remove("DB"); } if( EXCLUDED_ANNOTATIONS != null ) { @@ -134,7 +133,7 @@ public class VariantOptimizer extends RodWalker if( !annotationKeys.contains( forcedAnnotation ) ) { annotationKeys.add( forcedAnnotation ); } } } - numAnnotations = annotationKeys.size() + 1; // +1 for variant quality ("QUAL") + numAnnotations = annotationKeys.size(); annotationValues = new double[numAnnotations]; firstVariant = false; } @@ -143,26 +142,28 @@ public class VariantOptimizer extends RodWalker for( final String key : annotationKeys ) { double value = 0.0; - try { - value = Double.parseDouble( (String)vc.getAttribute( key, "0.0" ) ); - if( Double.isInfinite(value) ) { - value = ( value > 0 ? 1.0 : -1.0 ) * INFINITE_ANNOTATION_VALUE; + if( key.equals("AB") && !vc.getAttributes().containsKey(key) ) { + value = (0.5 - 0.005) + (0.01 * Math.random()); // HomVar calls don't have an allele balance + } else if( key.equals("QUAL") ) { + value = vc.getPhredScaledQual(); + } else { + try { + value = Double.parseDouble( (String)vc.getAttribute( key, "0.0" ) ); + if( Double.isInfinite(value) ) { + value = ( value > 0 ? 1.0 : -1.0 ) * INFINITE_ANNOTATION_VALUE; + } + } catch( NumberFormatException e ) { + // do nothing, default value is 0.0 } - } catch( NumberFormatException e ) { - // do nothing, default value is 0.0 } annotationValues[iii++] = value; } - // Variant quality ("QUAL") is not in the list of annotations, but is useful so add it here. - annotationValues[iii] = vc.getPhredScaledQual(); - final VariantDatum variantDatum = new VariantDatum(); variantDatum.annotations = annotationValues; variantDatum.isTransition = vc.getSNPSubstitutionType().compareTo(BaseUtils.BaseSubstitutionType.TRANSITION) == 0; variantDatum.isKnown = !vc.getAttribute("ID").equals("."); - variantDatum.isHet = vc.getHetCount() > vc.getHomVarCount(); // BUGBUG: what to do here for multi sample calls? - + mapList.add( variantDatum ); } } @@ -192,7 +193,7 @@ public class VariantOptimizer extends RodWalker reduceSum.clear(); // Don't need this ever again, clean up some memory logger.info( "There are " + dataManager.numVariants + " variants and " + dataManager.numAnnotations + " annotations." ); - logger.info( "The annotations are: " + annotationKeys + " and QUAL." ); + logger.info( "The annotations are: " + annotationKeys ); dataManager.normalizeData(); // Each data point is now [ (x - mean) / standard deviation ] @@ -200,7 +201,7 @@ public class VariantOptimizer extends RodWalker VariantOptimizationModel theModel; switch (OPTIMIZATION_MODEL) { case GAUSSIAN_MIXTURE_MODEL: - theModel = new VariantGaussianMixtureModel( dataManager, TARGET_TITV, NUM_GAUSSIANS, NUM_ITERATIONS, MIN_VAR_IN_CLUSTER, KNOWN_ALPHA_FACTOR ); + theModel = new VariantGaussianMixtureModel( dataManager, TARGET_TITV, NUM_GAUSSIANS, NUM_ITERATIONS, MIN_VAR_IN_CLUSTER ); break; //case K_NEAREST_NEIGHBORS: // theModel = new VariantNearestNeighborsModel( dataManager, TARGET_TITV, NUM_KNN );