More incremental updates to the variant optimizer.

git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@2939 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
rpoplin 2010-03-05 16:42:42 +00:00
parent 7a7e85188c
commit 95d560aa2f
9 changed files with 267 additions and 71 deletions

View File

@ -33,5 +33,5 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantoptimizer;
public interface VariantClusteringModel extends VariantOptimizationInterface {
public void createClusters( final VariantDatum[] data );
public double[] applyClusters( final VariantDatum[] data );
public void applyClusters( final VariantDatum[] data, final String outputPrefix );
}

View File

@ -58,13 +58,20 @@ public class VariantDataManager {
}
public void normalizeData() {
for( int iii = 0; iii < numAnnotations; iii++ ) {
final double theMean = mean(data, iii);
final double theSTD = standardDeviation(data, theMean, iii);
System.out.println( (iii == numAnnotations-1 ? "QUAL" : annotationKeys.get(iii)) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD) );
varianceVector[iii] = theSTD * theSTD;
for( int jjj=0; jjj<numVariants; jjj++ ) {
data[jjj].annotations[iii] = ( data[jjj].annotations[iii] - theMean ) / theSTD;
for( int iii = 0; iii < numVariants; iii++ ) {
data[iii].originalAnnotations = data[iii].annotations.clone();
}
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
final double theMean = mean(data, jjj);
final double theSTD = standardDeviation(data, theMean, jjj);
System.out.println( (jjj == numAnnotations-1 ? "QUAL" : annotationKeys.get(jjj)) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD) );
if( theSTD < 1E-8 ) {
throw new StingException("Zero variance is a problem: standard deviation = " + theSTD);
}
varianceVector[jjj] = theSTD * theSTD;
for( int iii = 0; iii < numVariants; iii++ ) {
data[iii].annotations[jjj] = ( data[iii].annotations[jjj] - theMean ) / theSTD;
}
}
isNormalized = true; // Each data point is now [ (x - mean) / standard deviation ]

View File

@ -33,6 +33,7 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantoptimizer;
public class VariantDatum {
public double[] annotations;
public double[] originalAnnotations;
public boolean isTransition;
public boolean isKnown;
public boolean isFiltered;

View File

@ -1,5 +1,6 @@
package org.broadinstitute.sting.playground.gatk.walkers.variantoptimizer;
import java.io.PrintStream;
import java.util.Random;
/*
@ -33,23 +34,38 @@ import java.util.Random;
* Date: Feb 26, 2010
*/
public class VariantGaussianMixtureModel extends VariantOptimizationModel implements VariantClusteringModel {
public final class VariantGaussianMixtureModel extends VariantOptimizationModel implements VariantClusteringModel {
private final int numGaussians = 128;
private final int numGaussians;
private final int numIterations;
private final long RANDOM_SEED = 91801305;
private final Random rand = new Random( RANDOM_SEED );
private final double MIN_PROB = 1E-30;
private final double MIN_SUM_PROB = 1E-20;
private final double[][] mu = new double[numGaussians][];
private final double[][] sigma = new double[numGaussians][];
private final double[] pCluster = new double[numGaussians];
final double[] clusterTruePositiveRate = new double[numGaussians];
private final double[][] mu;
private final double[][] sigma;
private final double[] pCluster;
private final int[] numMaxClusterKnown;
private final int[] numMaxClusterNovel;
private final double[] clusterTITV;
private final double[] clusterTruePositiveRate;
public VariantGaussianMixtureModel( VariantDataManager _dataManager, final double _targetTITV ) {
public VariantGaussianMixtureModel( VariantDataManager _dataManager, final double _targetTITV, final int _numGaussians, final int _numIterations ) {
super( _dataManager, _targetTITV );
numGaussians = _numGaussians;
numIterations = _numIterations;
mu = new double[numGaussians][];
sigma = new double[numGaussians][];
pCluster = new double[numGaussians];
numMaxClusterKnown = new int[numGaussians];
numMaxClusterNovel = new int[numGaussians];
clusterTITV = new double[numGaussians];
clusterTruePositiveRate = new double[numGaussians];
}
public double[] run() {
public final void run( final String outputPrefix ) {
// Create the subset of the data to cluster with
int numSubset = 0;
@ -58,31 +74,36 @@ public class VariantGaussianMixtureModel extends VariantOptimizationModel implem
numSubset++;
}
}
final VariantDatum[] data = new VariantDatum[numSubset];
final VariantDatum[] data = new VariantDatum[numSubset*2];
int iii = 0;
for( final VariantDatum datum : dataManager.data ) {
if( !datum.isKnown ) {
data[iii++] = datum;
}
}
while( iii < numSubset*2 ) { // grab an equal number of known variants at random
final VariantDatum datum = dataManager.data[rand.nextInt(dataManager.numVariants)];
if( datum.isKnown ) {
data[iii++] = datum;
}
}
System.out.println("Clustering with " + numSubset + " variants...");
System.out.println("Clustering with " + numSubset*2 + " variants...");
createClusters( data ); // Using a subset of the data
System.out.println("Applying clusters to all variants...");
return applyClusters( dataManager.data ); // Using all the data
System.out.println("Printing out cluster parameters...");
printClusters( outputPrefix );
System.out.println("Applying clusters to all the variants...");
applyClusters( dataManager.data, outputPrefix ); // Using all the data
}
public final void createClusters( final VariantDatum[] data ) {
final int numVariants = data.length;
final int numAnnotations = data[0].annotations.length;
final int numIterations = 3;
final double[][] pVarInCluster = new double[numGaussians][numVariants];
final double[] probTi = new double[numGaussians];
final double[] probTv = new double[numGaussians];
final Random rand = new Random( RANDOM_SEED );
// loop control variables:
// iii - loop over data points
@ -92,7 +113,9 @@ public class VariantGaussianMixtureModel extends VariantOptimizationModel implem
// Set up the initial random Gaussians
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
pCluster[kkk] = 1.0;
numMaxClusterKnown[kkk] = 0;
numMaxClusterNovel[kkk] = 0;
pCluster[kkk] = 1.0 / ((double) numGaussians);
//final double[] randMu = new double[numAnnotations];
//for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
// randMu[jjj] = data[rand.nextInt(numVariants)].annotations[jjj];
@ -145,35 +168,116 @@ public class VariantGaussianMixtureModel extends VariantOptimizationModel implem
probTv[kkk] += pVarInCluster[kkk][iii];
}
}
double maxProb = pVarInCluster[0][iii];
int maxCluster = 0;
for( int kkk = 1; kkk < numGaussians; kkk++ ) {
if( pVarInCluster[kkk][iii] > maxProb ) {
maxProb = pVarInCluster[kkk][iii];
maxCluster = kkk;
}
}
if( data[iii].isKnown ) {
numMaxClusterKnown[maxCluster]++;
} else {
numMaxClusterNovel[maxCluster]++;
}
}
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
clusterTruePositiveRate[kkk] = calcTruePositiveRateFromTITV( probTi[kkk] / probTv[kkk] );
clusterTITV[kkk] = probTi[kkk] / probTv[kkk];
clusterTruePositiveRate[kkk] = calcTruePositiveRateFromTITV( clusterTITV[kkk] );
}
}
public final double[] applyClusters( final VariantDatum[] data ) {
/**
 * Writes the parameters of the sufficiently populated clusters to the file
 * {@code outputPrefix + ".clusters"}, one comma-separated line per cluster.
 *
 * Each line holds, in order: a running cluster number (counting only the clusters
 * that are written), the number of known and of novel variants counted for the
 * cluster in numMaxClusterKnown / numMaxClusterNovel, the cluster's Ti/Tv value,
 * its true positive rate, every entry of mu[kkk], every entry of sigma[kkk], and
 * a terminating -1.
 *
 * Clusters whose known + novel count is below 2000 are skipped.
 *
 * On any exception the stack trace is printed and the JVM exits with status -1.
 *
 * @param outputPrefix prefix of the output file name; ".clusters" is appended to it
 */
private void printClusters( final String outputPrefix ) {
    PrintStream outputFile = null;
    try {
        outputFile = new PrintStream( outputPrefix + ".clusters" );
        int clusterNumber = 0;
        final int numAnnotations = mu[0].length;
        for( int kkk = 0; kkk < numGaussians; kkk++ ) {
            // Only report clusters that were the best match for at least 2000 variants
            if( numMaxClusterKnown[kkk] + numMaxClusterNovel[kkk] >= 2000 ) {
                outputFile.print(clusterNumber + ",");
                outputFile.print(numMaxClusterKnown[kkk] + ",");
                outputFile.print(numMaxClusterNovel[kkk] + ",");
                outputFile.print(clusterTITV[kkk] + ",");
                outputFile.print(clusterTruePositiveRate[kkk] + ",");
                for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
                    outputFile.print(mu[kkk][jjj] + ",");
                }
                for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
                    outputFile.print(sigma[kkk][jjj] + ",");
                }
                outputFile.println(-1); // every line ends with a -1 field
                clusterNumber++;
            }
        }
    } catch (Exception e) {
        e.printStackTrace();
        System.exit(-1);
    } finally {
        // Release the file handle on the normal path; on failure System.exit above
        // terminates the JVM before this runs, which releases it as well.
        if( outputFile != null ) {
            outputFile.close();
        }
    }
}
public final void applyClusters( final VariantDatum[] data, final String outputPrefix ) {
final int numVariants = data.length;
final int numAnnotations = data[0].annotations.length;
final double[] pTrueVariant = new double[numVariants];
final double[][] pVarInCluster = new double[numGaussians][numVariants];
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Expectation Step (calculate the probability that each data point is in each cluster)
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
evaluateGaussians( data, pVarInCluster );
final double[] pVarInCluster = new double[numGaussians];
final int[] clusterAssignment = new int[numVariants];
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Evaluate each variant using the probability of being in each cluster and that cluster's true positive rate
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
for( int iii = 0; iii < numVariants; iii++ ) {
evaluateGaussiansForSingleVariant( data[iii], pVarInCluster );
pTrueVariant[iii] = 0.0;
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
pTrueVariant[iii] += pVarInCluster[kkk][iii] * clusterTruePositiveRate[kkk];
pTrueVariant[iii] += pVarInCluster[kkk] * clusterTruePositiveRate[kkk];
}
double maxProb = -1.0;//pVarInCluster[0][iii];
int maxCluster = -1;
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
if( numMaxClusterKnown[kkk] + numMaxClusterNovel[kkk] >= 2000 && pVarInCluster[kkk] > maxProb ) {
maxProb = pVarInCluster[kkk];
maxCluster = kkk;
}
}
clusterAssignment[iii] = maxCluster;
}
return pTrueVariant;
PrintStream outputFile = null;
try {
outputFile = new PrintStream( outputPrefix + ".data" );
} catch (Exception e) {
e.printStackTrace();
System.exit(-1);
}
for(int iii = 0; iii < numVariants; iii++) {
outputFile.print(pTrueVariant[iii] + ",");
outputFile.print(clusterAssignment[iii] + ",");
for(int jjj = 0; jjj < numAnnotations; jjj++ ) {
outputFile.print(data[iii].originalAnnotations[jjj] + ",");
}
outputFile.println( (data[iii].isTransition ? 1 : 0)
+ "," + (data[iii].isKnown? 1 : 0)
+ "," + (data[iii].isFiltered ? 1 : 0));
}
try {
outputFile = new PrintStream( outputPrefix + ".optimize" );
} catch (Exception e) {
e.printStackTrace();
System.exit(-1);
}
for(int iii = 0; iii < numVariants; iii++) {
outputFile.print(String.format("%.4f",pTrueVariant[iii]) + ",");
outputFile.println( (data[iii].isTransition ? 1 : 0)
+ "," + (data[iii].isKnown? 1 : 0)
+ "," + (data[iii].isFiltered ? 1 : 0));
}
}
@ -207,6 +311,34 @@ public class VariantGaussianMixtureModel extends VariantOptimizationModel implem
}
/**
 * Fills {@code pVarInCluster} with one value per Gaussian for a single variant.
 *
 * For each cluster the squared difference between the variant's annotation and the
 * cluster's mu is divided by the matching sigma entry and summed over all annotations;
 * the cluster value is pCluster * exp(-0.5 * sum), raised to MIN_PROB if it falls below it.
 * When the total over all clusters exceeds MIN_SUM_PROB, every entry is divided by that
 * total; otherwise the entries are left unscaled.
 *
 * @param datum         the variant whose annotations are evaluated
 * @param pVarInCluster output array, written at indices 0 .. numGaussians-1
 */
private void evaluateGaussiansForSingleVariant( final VariantDatum datum, final double[] pVarInCluster ) {
    final int dimensions = datum.annotations.length;
    double total = 0.0;

    for( int cluster = 0; cluster < numGaussians; cluster++ ) {
        double scaledDistance = 0.0;
        for( int dim = 0; dim < dimensions; dim++ ) {
            final double delta = datum.annotations[dim] - mu[cluster][dim];
            scaledDistance += ( delta * delta ) / sigma[cluster][dim];
        }

        // Very small numbers are a very big problem, so floor the value at MIN_PROB
        final double weighted = pCluster[cluster] * Math.exp( -0.5 * scaledDistance );
        pVarInCluster[cluster] = Math.max( weighted, MIN_PROB );
        total += pVarInCluster[cluster];
    }

    // Skip the rescaling unless the total is strictly above the threshold
    if( !( total > MIN_SUM_PROB ) ) {
        return;
    }
    for( int cluster = 0; cluster < numGaussians; cluster++ ) {
        pVarInCluster[cluster] /= total;
    }
}
private void maximizeGaussians( final VariantDatum[] data, final double[][] pVarInCluster ) {
final int numVariants = data.length;
@ -253,9 +385,75 @@ public class VariantGaussianMixtureModel extends VariantOptimizationModel implem
}
}
pCluster[kkk] = sumProb / numVariants; // BUGBUG: Experiment with this, want to keep many clusters alive
// Perhaps replace the cluster with a new random draw once pCluster gets too small
// and break up a large cluster with examples drawn from that cluster
pCluster[kkk] = sumProb / numVariants;
}
// Clean up extra big or extra small clusters
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
if( pCluster[kkk] > 0.45 ) { // This is a very large cluster compared to all the others
System.out.println("Very large cluster!");
final int numToReplace = 4;
final double[] savedSigma = sigma[kkk];
for( int rrr = 0; rrr < numToReplace; rrr++ ) {
// Find an example variant in the large cluster, drawn randomly
int randVarIndex = -1;
boolean foundVar = false;
while( !foundVar ) {
randVarIndex = rand.nextInt( numVariants );
final double probK = pVarInCluster[kkk][randVarIndex];
boolean inClusterK = true;
for( int ccc = 0; ccc < numGaussians; ccc++ ) {
if( pVarInCluster[ccc][randVarIndex] > probK ) {
inClusterK = false;
break;
}
}
if( inClusterK ) { foundVar = true; }
}
// Find a place to put the example variant
if( rrr == 0 ) { // Replace the big cluster that kicked this process off
mu[kkk] = data[randVarIndex].annotations;
//sigma[kkk] = savedSigma;
pCluster[kkk] = 1.0 / ((double) numGaussians);
} else { // Replace the cluster with the minimum prob
double minProb = pCluster[0];
int minClusterIndex = 0;
for( int ccc = 1; ccc < numGaussians; ccc++ ) {
if( pCluster[ccc] < minProb ) {
minProb = pCluster[ccc];
minClusterIndex = ccc;
}
}
mu[minClusterIndex] = data[randVarIndex].annotations;
sigma[minClusterIndex] = savedSigma;
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
sigma[minClusterIndex][jjj] += -0.02 + 0.04 * rand.nextDouble();
}
pCluster[minClusterIndex] = 1.0 / ((double) numGaussians);
}
}
}
}
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
if( pCluster[kkk] < 0.05 * (1.0 / ((double) numGaussians)) ) { // This is a very small cluster compared to all the others
System.out.println("Very small cluster!");
pCluster[kkk] = 1.0 / ((double) numGaussians);
mu[kkk] = data[rand.nextInt(numVariants)].annotations;
final double[] randSigma = new double[numAnnotations];
if( dataManager.isNormalized ) {
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
randSigma[jjj] = 0.9 + 0.2 * rand.nextDouble();
}
} else { // BUGBUG: if not normalized then the varianceVector hasn't been calculated --> null pointer
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
randSigma[jjj] = dataManager.varianceVector[jjj] + ((1.0 + rand.nextDouble()) * 0.01 * dataManager.varianceVector[jjj]);
}
}
sigma[kkk] = randSigma;
}
}
}
}

View File

@ -1,5 +1,7 @@
package org.broadinstitute.sting.playground.gatk.walkers.variantoptimizer;
import java.io.PrintStream;
/*
* Copyright (c) 2010 The Broad Institute
*
@ -31,13 +33,13 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantoptimizer;
* Date: Mar 1, 2010
*/
public class VariantNearestNeighborsModel extends VariantOptimizationModel {
public final class VariantNearestNeighborsModel extends VariantOptimizationModel {
public VariantNearestNeighborsModel( VariantDataManager _dataManager, final double _targetTITV ) {
super( _dataManager, _targetTITV );
}
public double[] run() {
public void run( final String outputPrefix ) {
final int numVariants = dataManager.numVariants;
@ -52,6 +54,7 @@ public class VariantNearestNeighborsModel extends VariantOptimizationModel {
pTrueVariant[iii] = calcTruePositiveRateFromTITV( vTree.calcNeighborhoodTITV( dataManager.data[iii] ) );
}
return pTrueVariant;
//BUGBUG: need to output pTrueVariant and other metrics in this method
//return pTrueVariant;
}
}

View File

@ -32,5 +32,5 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantoptimizer;
*/
public interface VariantOptimizationInterface {
public double[] run();
public void run( String outputPrefix );
}

View File

@ -40,7 +40,7 @@ public abstract class VariantOptimizationModel implements VariantOptimizationInt
targetTITV = _targetTITV;
}
public double calcTruePositiveRateFromTITV( double titv ) {
public final double calcTruePositiveRateFromTITV( double titv ) {
if( titv > targetTITV ) { titv -= 2.0f*(titv-targetTITV); }
if( titv < 0.5 ) { titv = 0.5; }
return ( (titv - 0.5) / (targetTITV - 0.5) );

View File

@ -9,8 +9,6 @@ import org.broadinstitute.sting.utils.BaseUtils;
import org.broadinstitute.sting.utils.ExpandingArrayList;
import org.broadinstitute.sting.utils.cmdLine.Argument;
import java.io.PrintStream;
/*
* Copyright (c) 2010 The Broad Institute
*
@ -51,14 +49,21 @@ public class VariantOptimizer extends RodWalker<ExpandingArrayList<VariantDatum>
/////////////////////////////
// Command Line Arguments
/////////////////////////////
@Argument(fullName = "target_titv", shortName="titv", doc="The target Ti/Tv ratio towards which to optimize. (~~2.2 for whole genome experiments)", required=true)
@Argument(fullName="target_titv", shortName="titv", doc="The target Ti/Tv ratio towards which to optimize. (~~2.2 for whole genome experiments)", required=true)
private double TARGET_TITV = 2.12;
//@Argument(fullName = "filter_output", shortName="filter", doc="If specified the optimizer will not only update the QUAL field of the output VCF file but will also filter the variants", required=false)
//@Argument(fullName="filter_output", shortName="filter", doc="If specified the optimizer will not only update the QUAL field of the output VCF file but will also filter the variants", required=false)
//private boolean FILTER_OUTPUT = false;
@Argument(fullName = "ignore_input_filters", shortName="ignoreFilters", doc="If specified the optimizer will use variants even if the FILTER column is marked in the VCF file", required=false)
@Argument(fullName="ignore_input_filters", shortName="ignoreFilters", doc="If specified the optimizer will use variants even if the FILTER column is marked in the VCF file", required=false)
private boolean IGNORE_INPUT_FILTERS = false;
@Argument(fullName = "exclude_annotation", shortName = "exclude", doc = "The names of the annotations which should be excluded from the calculations", required = false)
@Argument(fullName="exclude_annotation", shortName="exclude", doc="The names of the annotations which should be excluded from the calculations", required=false)
private String[] EXCLUDED_ANNOTATIONS = null;
@Argument(fullName="output", shortName="output", doc="The output file name", required=false)
private String OUTPUT_FILE = "optimizer.data";
@Argument(fullName="numGaussians", shortName="nG", doc="The number of Gaussians to be used in the Gaussian mixture model", required=false)
private int NUM_GAUSSIANS = 32;
@Argument(fullName="numIterations", shortName="nI", doc="The number of iterations to be performed in the Gaussian mixture model", required=false)
private int NUM_ITERATIONS = 5; //BUGBUG: should automatically decided when to stop by looking at how entropy changes with each iteration
/////////////////////////////
// Private Member Variables
@ -118,7 +123,6 @@ public class VariantOptimizer extends RodWalker<ExpandingArrayList<VariantDatum>
value = Double.parseDouble( (String)vc.getAttribute( key, "0.0" ) );
} catch( NumberFormatException e ) {
// do nothing, default value is 0.0,
// BUGBUG: annotations with zero variance should be ignored
}
annotationValues[iii++] = value;
}
@ -155,7 +159,7 @@ public class VariantOptimizer extends RodWalker<ExpandingArrayList<VariantDatum>
public void onTraversalDone( ExpandingArrayList<VariantDatum> reduceSum ) {
final VariantDataManager dataManager = new VariantDataManager( reduceSum, annotationKeys );
final VariantDataManager dataManager = new VariantDataManager( reduceSum, annotationKeys);
reduceSum.clear(); // Don't need this ever again, clean up some memory
logger.info( "There are " + dataManager.numVariants + " variants and " + dataManager.numAnnotations + " annotations.");
@ -163,26 +167,9 @@ public class VariantOptimizer extends RodWalker<ExpandingArrayList<VariantDatum>
dataManager.normalizeData(); // Each data point is now [ (x - mean) / standard deviation ]
final VariantOptimizationModel gmm = new VariantGaussianMixtureModel( dataManager, TARGET_TITV );
final double[] p = gmm.run();
// BUGBUG: Change to call a second ROD walker to output the new VCF file with new qual fields and filters
// Intermediate cluster ROD can be analyzed to assess clustering performance
try {
final PrintStream out = new PrintStream("gmm128clusterNovel.data"); // Parse in Matlab to create performance plots
for(int iii = 0; iii < dataManager.numVariants; iii++) {
out.print(p[iii] + "\t");
out.println( (dataManager.data[iii].isTransition ? 1 : 0)
+ "\t" + (dataManager.data[iii].isKnown? 1 : 0)
+ "\t" + (dataManager.data[iii].isFiltered ? 1 : 0) );
}
} catch (Exception e) {
e.printStackTrace();
System.exit(-1);
}
// Create either the Gaussian Mixture Model or the Nearest Neighbors model and run it
final VariantOptimizationModel gmm = new VariantGaussianMixtureModel( dataManager, TARGET_TITV, NUM_GAUSSIANS, NUM_ITERATIONS );
gmm.run( OUTPUT_FILE );
}
}

View File

@ -49,7 +49,7 @@ public class VariantTreeNode {
cutValue = -1;
}
public void cutData( final VariantDatum[] data, final int depth, final int lastCutDepth, final int numAnnotations ) {
public final void cutData( final VariantDatum[] data, final int depth, final int lastCutDepth, final int numAnnotations ) {
cutDim = depth % numAnnotations;