diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantClusteringModel.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantClusteringModel.java index 86dd1854f..c5afc5861 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantClusteringModel.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantClusteringModel.java @@ -33,5 +33,5 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantoptimizer; public interface VariantClusteringModel extends VariantOptimizationInterface { public void createClusters( final VariantDatum[] data ); - public double[] applyClusters( final VariantDatum[] data ); + public void applyClusters( final VariantDatum[] data, final String outputPrefix ); } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDataManager.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDataManager.java index 4673b8b41..76da26164 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDataManager.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDataManager.java @@ -58,13 +58,20 @@ public class VariantDataManager { } public void normalizeData() { - for( int iii = 0; iii < numAnnotations; iii++ ) { - final double theMean = mean(data, iii); - final double theSTD = standardDeviation(data, theMean, iii); - System.out.println( (iii == numAnnotations-1 ? "QUAL" : annotationKeys.get(iii)) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD) ); - varianceVector[iii] = theSTD * theSTD; - for( int jjj=0; jjj maxProb ) { + maxProb = pVarInCluster[kkk][iii]; + maxCluster = kkk; + } + } + if( data[iii].isKnown ) { + numMaxClusterKnown[maxCluster]++; + } else { + numMaxClusterNovel[maxCluster]++; + } } for( int kkk = 0; kkk < numGaussians; kkk++ ) { - clusterTruePositiveRate[kkk] = calcTruePositiveRateFromTITV( probTi[kkk] / probTv[kkk] ); + clusterTITV[kkk] = probTi[kkk] / probTv[kkk]; + clusterTruePositiveRate[kkk] = calcTruePositiveRateFromTITV( clusterTITV[kkk] ); } } - public final double[] applyClusters( final VariantDatum[] data ) { + private void printClusters( final String outputPrefix ) { + try { + final PrintStream outputFile = new PrintStream( outputPrefix + ".clusters" ); + int clusterNumber = 0; + final int numAnnotations = mu[0].length; + for( int kkk = 0; kkk < numGaussians; kkk++ ) { + if( numMaxClusterKnown[kkk] + numMaxClusterNovel[kkk] >= 2000 ) { + outputFile.print(clusterNumber + ","); + outputFile.print(numMaxClusterKnown[kkk] + ","); + outputFile.print(numMaxClusterNovel[kkk] + ","); + outputFile.print(clusterTITV[kkk] + ","); + outputFile.print(clusterTruePositiveRate[kkk] + ","); + for(int jjj = 0; jjj < numAnnotations; jjj++ ) { + outputFile.print(mu[kkk][jjj] + ","); + } + for(int jjj = 0; jjj < numAnnotations; jjj++ ) { + outputFile.print(sigma[kkk][jjj] + ","); + } + outputFile.println(-1); + clusterNumber++; + } + } + } catch (Exception e) { + e.printStackTrace(); + System.exit(-1); + } + } + + public final void applyClusters( final VariantDatum[] data, final String outputPrefix ) { final int numVariants = data.length; + final int numAnnotations = data[0].annotations.length; final double[] pTrueVariant = new double[numVariants]; - final double[][] pVarInCluster = new double[numGaussians][numVariants]; - - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - // Expectation Step (calculate the probability that each data point is in each cluster) - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - evaluateGaussians( data, pVarInCluster ); + final double[] pVarInCluster = new double[numGaussians]; + final int[] clusterAssignment = new int[numVariants]; ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// // Evaluate each variant using the probability of being in each cluster and that cluster's true positive rate ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// for( int iii = 0; iii < numVariants; iii++ ) { + evaluateGaussiansForSingleVariant( data[iii], pVarInCluster ); + pTrueVariant[iii] = 0.0; for( int kkk = 0; kkk < numGaussians; kkk++ ) { - pTrueVariant[iii] += pVarInCluster[kkk][iii] * clusterTruePositiveRate[kkk]; + pTrueVariant[iii] += pVarInCluster[kkk] * clusterTruePositiveRate[kkk]; } + + double maxProb = -1.0;//pVarInCluster[0][iii]; + int maxCluster = -1; + for( int kkk = 0; kkk < numGaussians; kkk++ ) { + if( numMaxClusterKnown[kkk] + numMaxClusterNovel[kkk] >= 2000 && pVarInCluster[kkk] > maxProb ) { + maxProb = pVarInCluster[kkk]; + maxCluster = kkk; + } + } + clusterAssignment[iii] = maxCluster; } - return pTrueVariant; + PrintStream outputFile = null; + try { + outputFile = new PrintStream( outputPrefix + ".data" ); + } catch (Exception e) { + e.printStackTrace(); + System.exit(-1); + } + for(int iii = 0; iii < numVariants; iii++) { + outputFile.print(pTrueVariant[iii] + ","); + outputFile.print(clusterAssignment[iii] + ","); + for(int jjj = 0; jjj < numAnnotations; jjj++ ) { + outputFile.print(data[iii].originalAnnotations[jjj] + ","); + } + outputFile.println( (data[iii].isTransition ? 1 : 0) + + "," + (data[iii].isKnown? 1 : 0) + + "," + (data[iii].isFiltered ? 1 : 0)); + } + + try { + outputFile = new PrintStream( outputPrefix + ".optimize" ); + } catch (Exception e) { + e.printStackTrace(); + System.exit(-1); + } + for(int iii = 0; iii < numVariants; iii++) { + outputFile.print(String.format("%.4f",pTrueVariant[iii]) + ","); + outputFile.println( (data[iii].isTransition ? 1 : 0) + + "," + (data[iii].isKnown? 1 : 0) + + "," + (data[iii].isFiltered ? 1 : 0)); + } } @@ -207,6 +311,34 @@ public class VariantGaussianMixtureModel extends VariantOptimizationModel implem } + private void evaluateGaussiansForSingleVariant( final VariantDatum datum, final double[] pVarInCluster ) { + + final int numAnnotations = datum.annotations.length; + + double sumProb = 0.0; + for( int kkk = 0; kkk < numGaussians; kkk++ ) { + double sum = 0.0; + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + sum += ( (datum.annotations[jjj] - mu[kkk][jjj]) * (datum.annotations[jjj] - mu[kkk][jjj]) ) + / sigma[kkk][jjj]; + } + pVarInCluster[kkk] = pCluster[kkk] * Math.exp( -0.5 * sum ); + + if( pVarInCluster[kkk] < MIN_PROB) { // Very small numbers are a very big problem + pVarInCluster[kkk] = MIN_PROB; + } + + sumProb += pVarInCluster[kkk]; + } + + if( sumProb > MIN_SUM_PROB ) { // Very small numbers are a very big problem + for( int kkk = 0; kkk < numGaussians; kkk++ ) { + pVarInCluster[kkk] /= sumProb; + } + } + } + + private void maximizeGaussians( final VariantDatum[] data, final double[][] pVarInCluster ) { final int numVariants = data.length; @@ -253,9 +385,75 @@ public class VariantGaussianMixtureModel extends VariantOptimizationModel implem } } - pCluster[kkk] = sumProb / numVariants; // BUGBUG: Experiment with this, want to keep many clusters alive - // Perhaps replace the cluster with a new random draw once pCluster gets too small - // and break up a large cluster with examples drawn from that cluster + pCluster[kkk] = sumProb / numVariants; } + + // Clean up extra big or extra small clusters + for( int kkk = 0; kkk < numGaussians; kkk++ ) { + if( pCluster[kkk] > 0.45 ) { // This is a very large cluster compared to all the others + System.out.println("Very large cluster!"); + final int numToReplace = 4; + final double[] savedSigma = sigma[kkk]; + for( int rrr = 0; rrr < numToReplace; rrr++ ) { + // Find an example variant in the large cluster, drawn randomly + int randVarIndex = -1; + boolean foundVar = false; + while( !foundVar ) { + randVarIndex = rand.nextInt( numVariants ); + final double probK = pVarInCluster[kkk][randVarIndex]; + boolean inClusterK = true; + for( int ccc = 0; ccc < numGaussians; ccc++ ) { + if( pVarInCluster[ccc][randVarIndex] > probK ) { + inClusterK = false; + break; + } + } + if( inClusterK ) { foundVar = true; } + } + + // Find a place to put the example variant + if( rrr == 0 ) { // Replace the big cluster that kicked this process off + mu[kkk] = data[randVarIndex].annotations; + //sigma[kkk] = savedSigma; + pCluster[kkk] = 1.0 / ((double) numGaussians); + } else { // Replace the cluster with the minimum prob + double minProb = pCluster[0]; + int minClusterIndex = 0; + for( int ccc = 1; ccc < numGaussians; ccc++ ) { + if( pCluster[ccc] < minProb ) { + minProb = pCluster[ccc]; + minClusterIndex = ccc; + } + } + mu[minClusterIndex] = data[randVarIndex].annotations; + sigma[minClusterIndex] = savedSigma; + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + sigma[minClusterIndex][jjj] += -0.02 + 0.04 * rand.nextDouble(); + } + pCluster[minClusterIndex] = 1.0 / ((double) numGaussians); + } + } + } + } + + for( int kkk = 0; kkk < numGaussians; kkk++ ) { + if( pCluster[kkk] < 0.05 * (1.0 / ((double) numGaussians)) ) { // This is a very small cluster compared to all the others + System.out.println("Very small cluster!"); + pCluster[kkk] = 1.0 / ((double) numGaussians); + mu[kkk] = data[rand.nextInt(numVariants)].annotations; + final double[] randSigma = new double[numAnnotations]; + if( dataManager.isNormalized ) { + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + randSigma[jjj] = 0.9 + 0.2 * rand.nextDouble(); + } + } else { // BUGBUG: if not normalized then the varianceVector hasn't been calculated --> null pointer + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + randSigma[jjj] = dataManager.varianceVector[jjj] + ((1.0 + rand.nextDouble()) * 0.01 * dataManager.varianceVector[jjj]); + } + } + sigma[kkk] = randSigma; + } + } + } } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantNearestNeighborsModel.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantNearestNeighborsModel.java index 7386f7025..45214b47e 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantNearestNeighborsModel.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantNearestNeighborsModel.java @@ -1,5 +1,7 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantoptimizer; +import java.io.PrintStream; + /* * Copyright (c) 2010 The Broad Institute * @@ -31,13 +33,13 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantoptimizer; * Date: Mar 1, 2010 */ -public class VariantNearestNeighborsModel extends VariantOptimizationModel { +public final class VariantNearestNeighborsModel extends VariantOptimizationModel { public VariantNearestNeighborsModel( VariantDataManager _dataManager, final double _targetTITV ) { super( _dataManager, _targetTITV ); } - public double[] run() { + public void run( final String outputPrefix ) { final int numVariants = dataManager.numVariants; @@ -52,6 +54,7 @@ public class VariantNearestNeighborsModel extends VariantOptimizationModel { pTrueVariant[iii] = calcTruePositiveRateFromTITV( vTree.calcNeighborhoodTITV( dataManager.data[iii] ) ); } - return pTrueVariant; + //BUGBUG: need to output pTrueVariant and other metrics in this method + //return pTrueVariant; } } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizationInterface.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizationInterface.java index 33fcb52ca..9724d7907 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizationInterface.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizationInterface.java @@ -32,5 +32,5 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantoptimizer; */ public interface VariantOptimizationInterface { - public double[] run(); + public void run( String outputPrefix ); } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizationModel.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizationModel.java index ad904d54d..aee28491e 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizationModel.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizationModel.java @@ -40,7 +40,7 @@ public abstract class VariantOptimizationModel implements VariantOptimizationInt targetTITV = _targetTITV; } - public double calcTruePositiveRateFromTITV( double titv ) { + public final double calcTruePositiveRateFromTITV( double titv ) { if( titv > targetTITV ) { titv -= 2.0f*(titv-targetTITV); } if( titv < 0.5 ) { titv = 0.5; } return ( (titv - 0.5) / (targetTITV - 0.5) ); diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizer.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizer.java index 5268059a1..ad8b57093 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizer.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizer.java @@ -9,8 +9,6 @@ import org.broadinstitute.sting.utils.BaseUtils; import org.broadinstitute.sting.utils.ExpandingArrayList; import org.broadinstitute.sting.utils.cmdLine.Argument; -import java.io.PrintStream; - /* * Copyright (c) 2010 The Broad Institute * @@ -51,14 +49,21 @@ public class VariantOptimizer extends RodWalker ///////////////////////////// // Command Line Arguments ///////////////////////////// - @Argument(fullName = "target_titv", shortName="titv", doc="The target Ti/Tv ratio towards which to optimize. (~~2.2 for whole genome experiments)", required=true) + @Argument(fullName="target_titv", shortName="titv", doc="The target Ti/Tv ratio towards which to optimize. (~~2.2 for whole genome experiments)", required=true) private double TARGET_TITV = 2.12; - //@Argument(fullName = "filter_output", shortName="filter", doc="If specified the optimizer will not only update the QUAL field of the output VCF file but will also filter the variants", required=false) + //@Argument(fullName="filter_output", shortName="filter", doc="If specified the optimizer will not only update the QUAL field of the output VCF file but will also filter the variants", required=false) //private boolean FILTER_OUTPUT = false; - @Argument(fullName = "ignore_input_filters", shortName="ignoreFilters", doc="If specified the optimizer will use variants even if the FILTER column is marked in the VCF file", required=false) + @Argument(fullName="ignore_input_filters", shortName="ignoreFilters", doc="If specified the optimizer will use variants even if the FILTER column is marked in the VCF file", required=false) private boolean IGNORE_INPUT_FILTERS = false; - @Argument(fullName = "exclude_annotation", shortName = "exclude", doc = "The names of the annotations which should be excluded from the calculations", required = false) + @Argument(fullName="exclude_annotation", shortName="exclude", doc="The names of the annotations which should be excluded from the calculations", required=false) private String[] EXCLUDED_ANNOTATIONS = null; + @Argument(fullName="output", shortName="output", doc="The output file name", required=false) + private String OUTPUT_FILE = "optimizer.data"; + @Argument(fullName="numGaussians", shortName="nG", doc="The number of Gaussians to be used in the Gaussian mixture model", required=false) + private int NUM_GAUSSIANS = 32; + @Argument(fullName="numIterations", shortName="nI", doc="The number of iterations to be performed in the Gaussian mixture model", required=false) + private int NUM_ITERATIONS = 5; //BUGBUG: should automatically decided when to stop by looking at how entropy changes with each iteration + ///////////////////////////// // Private Member Variables @@ -118,7 +123,6 @@ public class VariantOptimizer extends RodWalker value = Double.parseDouble( (String)vc.getAttribute( key, "0.0" ) ); } catch( NumberFormatException e ) { // do nothing, default value is 0.0, - // BUGBUG: annotations with zero variance should be ignored } annotationValues[iii++] = value; } @@ -155,7 +159,7 @@ public class VariantOptimizer extends RodWalker public void onTraversalDone( ExpandingArrayList reduceSum ) { - final VariantDataManager dataManager = new VariantDataManager( reduceSum, annotationKeys ); + final VariantDataManager dataManager = new VariantDataManager( reduceSum, annotationKeys); reduceSum.clear(); // Don't need this ever again, clean up some memory logger.info( "There are " + dataManager.numVariants + " variants and " + dataManager.numAnnotations + " annotations."); @@ -163,26 +167,9 @@ public class VariantOptimizer extends RodWalker dataManager.normalizeData(); // Each data point is now [ (x - mean) / standard deviation ] - final VariantOptimizationModel gmm = new VariantGaussianMixtureModel( dataManager, TARGET_TITV ); - final double[] p = gmm.run(); - - // BUGBUG: Change to call a second ROD walker to output the new VCF file with new qual fields and filters - // Intermediate cluster ROD can be analyzed to assess clustering performance - try { - final PrintStream out = new PrintStream("gmm128clusterNovel.data"); // Parse in Matlab to create performance plots - for(int iii = 0; iii < dataManager.numVariants; iii++) { - out.print(p[iii] + "\t"); - out.println( (dataManager.data[iii].isTransition ? 1 : 0) - + "\t" + (dataManager.data[iii].isKnown? 1 : 0) - + "\t" + (dataManager.data[iii].isFiltered ? 1 : 0) ); - } - - - } catch (Exception e) { - e.printStackTrace(); - System.exit(-1); - } - + // Create either the Gaussian Mixture Model or the Nearest Neighbors model and run it + final VariantOptimizationModel gmm = new VariantGaussianMixtureModel( dataManager, TARGET_TITV, NUM_GAUSSIANS, NUM_ITERATIONS ); + gmm.run( OUTPUT_FILE ); } } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantTreeNode.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantTreeNode.java index 02b4c35c3..fa0e19c72 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantTreeNode.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantTreeNode.java @@ -49,7 +49,7 @@ public class VariantTreeNode { cutValue = -1; } - public void cutData( final VariantDatum[] data, final int depth, final int lastCutDepth, final int numAnnotations ) { + public final void cutData( final VariantDatum[] data, final int depth, final int lastCutDepth, final int numAnnotations ) { cutDim = depth % numAnnotations;