From 522dd7a5b2c26ab16234086027ab3dc9bd66e609 Mon Sep 17 00:00:00 2001 From: rpoplin Date: Fri, 28 May 2010 18:21:27 +0000 Subject: [PATCH] Adding the variantrecalibration classes. git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@3459 348d0f76-0448-11de-a6fe-93d51630548a --- .../GenerateVariantClustersWalker.java | 237 ++++++ .../VariantClusteringModel.java | 37 + .../VariantDataManager.java | 140 ++++ .../variantrecalibration/VariantDatum.java | 41 + .../VariantGaussianMixtureModel.java | 730 ++++++++++++++++++ .../VariantNearestNeighborsModel.java | 78 ++ .../VariantOptimizationInterface.java | 36 + .../VariantOptimizationModel.java | 70 ++ .../VariantRecalibrator.java | 244 ++++++ .../variantrecalibration/VariantTree.java | 144 ++++ .../variantrecalibration/VariantTreeNode.java | 104 +++ ...ntRecalibrationWalkersPerformanceTest.java | 65 ++ 12 files changed, 1926 insertions(+) create mode 100755 java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/GenerateVariantClustersWalker.java create mode 100755 java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantClusteringModel.java create mode 100755 java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManager.java create mode 100755 java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDatum.java create mode 100755 java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantGaussianMixtureModel.java create mode 100755 java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantNearestNeighborsModel.java create mode 100755 java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantOptimizationInterface.java create mode 100755 java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantOptimizationModel.java create mode 100755 java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java create mode 100755 
java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantTree.java create mode 100755 java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantTreeNode.java create mode 100755 java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersPerformanceTest.java diff --git a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/GenerateVariantClustersWalker.java b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/GenerateVariantClustersWalker.java new file mode 100755 index 000000000..dcb2a7388 --- /dev/null +++ b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/GenerateVariantClustersWalker.java @@ -0,0 +1,237 @@ +/* + * Copyright (c) 2010 The Broad Institute + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, + * copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR + * THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
package org.broadinstitute.sting.gatk.walkers.variantrecalibration;

import org.broad.tribble.dbsnp.DbSNPFeature;
import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
import org.broadinstitute.sting.gatk.contexts.variantcontext.VariantContext;
import org.broadinstitute.sting.gatk.datasources.simpleDataSources.ReferenceOrderedDataSource;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.refdata.tracks.RMDTrack;
import org.broadinstitute.sting.gatk.refdata.utils.helpers.DbSNPHelper;
import org.broadinstitute.sting.gatk.walkers.RodWalker;
import org.broadinstitute.sting.utils.BaseUtils;
import org.broadinstitute.sting.utils.collections.ExpandingArrayList;
import org.broadinstitute.sting.utils.StingException;
import org.broadinstitute.sting.commandline.Argument;

import java.io.IOException;
import java.util.*;

/**
 * Takes variant calls as .vcf files, learns a Gaussian mixture model over the variant annotations
 * producing calibrated variant cluster parameters which can be applied to other datasets.
 *
 * @author rpoplin
 * @since Feb 11, 2010
 *
 * @help.summary Takes variant calls as .vcf files, learns a Gaussian mixture model over the variant annotations producing calibrated variant cluster parameters which can be applied to other datasets
 */
public class GenerateVariantClustersWalker extends RodWalker<ExpandingArrayList<VariantDatum>, ExpandingArrayList<VariantDatum>> {

    /////////////////////////////
    // Command Line Arguments
    /////////////////////////////
    @Argument(fullName="ignore_all_input_filters", shortName="ignoreAllFilters", doc="If specified the optimizer will use variants even if the FILTER column is marked in the VCF file", required=false)
    private boolean IGNORE_ALL_INPUT_FILTERS = false;
    @Argument(fullName="ignore_filter", shortName="ignoreFilter", doc="If specified the optimizer will use variants even if the specified filter name is marked in the input VCF file", required=false)
    private String[] IGNORE_INPUT_FILTERS = null;
    @Argument(fullName="use_annotation", shortName="an", doc="The names of the annotations which should be used for calculations", required=true)
    private String[] USE_ANNOTATIONS = null;
    @Argument(fullName="clusterFile", shortName="clusterFile", doc="The output cluster file", required=true)
    private String CLUSTER_FILENAME = "optimizer.cluster";
    @Argument(fullName="numGaussians", shortName="nG", doc="The number of Gaussians to be used when clustering", required=false)
    private int NUM_GAUSSIANS = 6;
    @Argument(fullName="numIterations", shortName="nI", doc="The number of iterations to be performed when clustering", required=false)
    private int NUM_ITERATIONS = 10;
    @Argument(fullName="minVarInCluster", shortName="minVar", doc="The minimum number of variants in a cluster to be considered a valid cluster. It can be used to prevent overfitting.", required=false)
    private int MIN_VAR_IN_CLUSTER = 0;
    @Argument(fullName = "path_to_Rscript", shortName = "Rscript", doc = "The path to your implementation of Rscript. For Broad users this is probably /broad/tools/apps/R-2.6.0/bin/Rscript", required = false)
    private String PATH_TO_RSCRIPT = "/broad/tools/apps/R-2.6.0/bin/Rscript";
    @Argument(fullName = "path_to_resources", shortName = "resources", doc = "Path to resources folder holding the Sting R scripts.", required = false)
    private String PATH_TO_RESOURCES = "R/";
    @Argument(fullName="weightKnowns", shortName="weightKnowns", doc="The weight for known variants during clustering", required=false)
    private double WEIGHT_KNOWNS = 8.0;
    @Argument(fullName="weightHapMap", shortName="weightHapMap", doc="The weight for known HapMap variants during clustering", required=false)
    private double WEIGHT_HAPMAP = 120.0;
    @Argument(fullName="weight1000Genomes", shortName="weight1000Genomes", doc="The weight for known 1000 Genomes Project variants during clustering", required=false)
    private double WEIGHT_1000GENOMES = 12.0;
    @Argument(fullName="weightMQ1", shortName="weightMQ1", doc="The weight for MQ1 dbSNP variants during clustering", required=false)
    private double WEIGHT_MQ1 = 10.0;

    //@Argument(fullName="knn", shortName="knn", doc="The number of nearest neighbors to be used in the k-Nearest Neighbors model", required=false)
    //private int NUM_KNN = 2000;
    //@Argument(fullName = "optimization_model", shortName = "om", doc = "Optimization calculation model to employ -- GAUSSIAN_MIXTURE_MODEL is currently the default, while K_NEAREST_NEIGHBORS is also available for small callsets.", required = false)
    private VariantOptimizationModel.Model OPTIMIZATION_MODEL = VariantOptimizationModel.Model.GAUSSIAN_MIXTURE_MODEL;

    /////////////////////////////
    // Private Member Variables
    /////////////////////////////
    private ExpandingArrayList<String> annotationKeys; // names of the annotations used as clustering features
    private Set<String> ignoreInputFilterSet = null;   // filter names the user asked us to look past
    private int maxAC = 0;                             // largest alternate-allele count seen; sizes the AC prior in the model

    //---------------------------------------------------------------------------------------------------------------
    //
    // initialize
    //
    //---------------------------------------------------------------------------------------------------------------

    /**
     * Parses the annotation/filter arguments and verifies a dbSNP track is bound,
     * since known/novel status drives the per-variant training weights.
     */
    public void initialize() {
        annotationKeys = new ExpandingArrayList<String>(Arrays.asList(USE_ANNOTATIONS));

        if( IGNORE_INPUT_FILTERS != null ) {
            ignoreInputFilterSet = new TreeSet<String>(Arrays.asList(IGNORE_INPUT_FILTERS));
        }

        boolean foundDBSNP = false;
        final List<ReferenceOrderedDataSource> dataSources = this.getToolkit().getRodDataSources();
        for( final ReferenceOrderedDataSource source : dataSources ) {
            final RMDTrack rod = source.getReferenceOrderedData();
            if ( rod.getName().equals(DbSNPHelper.STANDARD_DBSNP_TRACK_NAME) ) {
                foundDBSNP = true;
            }
        }

        if(!foundDBSNP) {
            throw new StingException("dbSNP track is required. This calculation is critically dependent on being able to distinguish known and novel sites.");
        }
    }

    //---------------------------------------------------------------------------------------------------------------
    //
    // map
    //
    //---------------------------------------------------------------------------------------------------------------

    /**
     * Converts each unfiltered SNP at this locus into a VariantDatum carrying its annotation
     * vector and a training weight derived from its dbSNP/HapMap/1000G membership.
     *
     * @return the (possibly empty) list of data points collected at this locus
     */
    public ExpandingArrayList<VariantDatum> map( RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context ) {

        final ExpandingArrayList<VariantDatum> mapList = new ExpandingArrayList<VariantDatum>();

        if( tracker == null ) { // For some reason RodWalkers get map calls with null trackers
            return mapList;
        }

        // todo -- do we really need to support multiple tracks -- logic is cleaner without this case -- what's the use case?
        for( final VariantContext vc : tracker.getAllVariantContexts(ref, null, context.getLocation(), false, false) ) {
            if( vc != null && !vc.getName().equals(DbSNPHelper.STANDARD_DBSNP_TRACK_NAME) && vc.isSNP() ) {
                if( !vc.isFiltered() || IGNORE_ALL_INPUT_FILTERS || (ignoreInputFilterSet != null && ignoreInputFilterSet.containsAll(vc.getFilters())) ) {
                    // BUG FIX: allocate a fresh array per variant. Previously one array, created
                    // outside this loop, was shared by every VariantDatum produced in this map
                    // call, so later variants silently overwrote the annotations of earlier ones.
                    final double[] annotationValues = new double[annotationKeys.size()];
                    int iii = 0;
                    for( final String key : annotationKeys ) {
                        annotationValues[iii++] = VariantGaussianMixtureModel.decodeAnnotation( key, vc );
                    }

                    final VariantDatum variantDatum = new VariantDatum();
                    variantDatum.annotations = annotationValues;
                    variantDatum.isTransition = vc.getSNPSubstitutionType().compareTo(BaseUtils.BaseSubstitutionType.TRANSITION) == 0;
                    variantDatum.alleleCount = vc.getChromosomeCount(vc.getAlternateAllele(0)); // BUGBUG: assumes file has genotypes
                    if( variantDatum.alleleCount > maxAC ) {
                        maxAC = variantDatum.alleleCount;
                    }

                    variantDatum.isKnown = false;
                    variantDatum.weight = 1.0;

                    // Known sites get heavier training weights, scaled by how trusted the source is
                    final DbSNPFeature dbsnp = DbSNPHelper.getFirstRealSNP(tracker.getReferenceMetaData(DbSNPHelper.STANDARD_DBSNP_TRACK_NAME));
                    if( dbsnp != null ) {
                        variantDatum.isKnown = true;
                        variantDatum.weight = WEIGHT_KNOWNS;
                        if( DbSNPHelper.isHapmap( dbsnp ) ) { variantDatum.weight = WEIGHT_HAPMAP; }
                        else if( DbSNPHelper.is1000genomes( dbsnp ) ) { variantDatum.weight = WEIGHT_1000GENOMES; }
                        else if( DbSNPHelper.isMQ1( dbsnp ) ) { variantDatum.weight = WEIGHT_MQ1; }
                    }

                    mapList.add( variantDatum );
                }
            }
        }

        return mapList;
    }

    //---------------------------------------------------------------------------------------------------------------
    //
    // reduce
    //
    //---------------------------------------------------------------------------------------------------------------

    public ExpandingArrayList<VariantDatum> reduceInit() {
        return new ExpandingArrayList<VariantDatum>();
    }

    /** Accumulates each locus's data points into the running list. */
    public ExpandingArrayList<VariantDatum> reduce( final ExpandingArrayList<VariantDatum> mapValue, final ExpandingArrayList<VariantDatum> reduceSum ) {
        reduceSum.addAll( mapValue );
        return reduceSum;
    }

    /**
     * Normalizes the collected data, fits the Gaussian mixture model, writes the cluster file,
     * and invokes Rscript once per annotation to plot the cluster reports.
     */
    public void onTraversalDone( ExpandingArrayList<VariantDatum> reduceSum ) {

        final VariantDataManager dataManager = new VariantDataManager( reduceSum, annotationKeys );
        reduceSum.clear(); // Don't need this ever again, clean up some memory

        logger.info( "There are " + dataManager.numVariants + " variants and " + dataManager.numAnnotations + " annotations." );
        logger.info( "The annotations are: " + annotationKeys );

        dataManager.normalizeData(); // Each data point is now [ (x - mean) / standard deviation ]

        // Create either the Gaussian Mixture Model or the Nearest Neighbors model and run it
        VariantGaussianMixtureModel theModel;
        switch (OPTIMIZATION_MODEL) {
            case GAUSSIAN_MIXTURE_MODEL:
                theModel = new VariantGaussianMixtureModel( dataManager, NUM_GAUSSIANS, NUM_ITERATIONS, MIN_VAR_IN_CLUSTER, maxAC );
                break;
            //case K_NEAREST_NEIGHBORS:
            //    theModel = new VariantNearestNeighborsModel( dataManager, TARGET_TITV, NUM_KNN );
            //    break;
            default:
                throw new StingException( "Variant Optimization Model is unrecognized. Implemented options are GAUSSIAN_MIXTURE_MODEL and K_NEAREST_NEIGHBORS" );
        }

        theModel.run( CLUSTER_FILENAME );
        theModel.outputClusterReports( CLUSTER_FILENAME );

        for( final String annotation : annotationKeys ) {
            // Execute Rscript command to plot the optimization curve
            // Print out the command line to make it clear to the user what is being executed and how one might modify it
            final String rScriptCommandLine = PATH_TO_RSCRIPT + " " + PATH_TO_RESOURCES + "plot_ClusterReport.R" + " " + CLUSTER_FILENAME + "." + annotation + ".dat " + annotation;
            System.out.println( rScriptCommandLine );

            // Execute the RScript command to plot the table of truth values
            // NOTE(review): Runtime.exec(String) tokenizes on whitespace, so paths with spaces
            // would break here; ProcessBuilder would be more robust if that ever matters.
            try {
                final Process p = Runtime.getRuntime().exec( rScriptCommandLine );
                p.waitFor();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt(); // BUG FIX: restore the interrupt status before propagating
                throw new StingException(e.getMessage());
            } catch ( IOException e ) {
                throw new StingException( "Unable to execute RScript command: " + rScriptCommandLine );
            }
        }
    }

}
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ + +/** + * Created by IntelliJ IDEA. + * User: rpoplin + * Date: Mar 2, 2010 + */ + +public interface VariantClusteringModel extends VariantOptimizationInterface { + public void createClusters( final VariantDatum[] data, final int startCluster, final int stopCluster, final String clusterFilename ); + //public void applyClusters( final VariantDatum[] data, final String outputPrefix ); +} diff --git a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManager.java b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManager.java new file mode 100755 index 000000000..b00279dad --- /dev/null +++ b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManager.java @@ -0,0 +1,140 @@ +/* + * Copyright (c) 2010 The Broad Institute + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, + * copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. 
package org.broadinstitute.sting.gatk.walkers.variantrecalibration;

import org.apache.log4j.Logger;
import org.broadinstitute.sting.utils.collections.ExpandingArrayList;
import org.broadinstitute.sting.utils.StingException;

import java.io.PrintStream;

/**
 * Owns the array of VariantDatum points and the per-annotation normalization statistics
 * (mean and standard deviation), and knows how to write/read the "@!ANNOTATION" header
 * lines of the cluster file.
 *
 * User: rpoplin
 * Date: Feb 26, 2010
 */
public class VariantDataManager {

    protected final static Logger logger = Logger.getLogger(VariantDataManager.class);

    public final VariantDatum[] data;
    public final int numVariants;
    public final int numAnnotations;
    public final double[] meanVector;     // per-annotation mean, filled in by normalizeData()
    public final double[] varianceVector; // This is really the standard deviation
    public boolean isNormalized;
    public final ExpandingArrayList<String> annotationKeys;

    /**
     * Clustering-pass constructor: wraps the collected data points.
     * (FIX: generic type arguments restored — they were stripped to raw types.)
     *
     * @throws StingException when there are zero variants or zero annotations
     */
    public VariantDataManager( final ExpandingArrayList<VariantDatum> dataList, final ExpandingArrayList<String> _annotationKeys ) {
        numVariants = dataList.size();
        data = dataList.toArray( new VariantDatum[numVariants] );
        if( numVariants <= 0 ) {
            throw new StingException( "There are zero variants! (or possibly a problem with integer overflow)" );
        }
        if( _annotationKeys == null ) {
            numAnnotations = 0;
            meanVector = null;
            varianceVector = null;
        } else {
            numAnnotations = _annotationKeys.size();
            if( numAnnotations <= 0 ) {
                throw new StingException( "There are zero annotations! (or possibly a problem with integer overflow)" );
            }
            meanVector = new double[numAnnotations];
            varianceVector = new double[numAnnotations];
            isNormalized = false;
        }
        annotationKeys = _annotationKeys;
    }

    /**
     * Apply-pass constructor: rebuilds the normalization statistics from the "@!ANNOTATION"
     * lines of a previously written cluster file. Each line is
     * "@!ANNOTATION,<key>,<mean>,<standardDeviation>" (see printClusterFileHeader).
     */
    public VariantDataManager( final ExpandingArrayList<String> annotationLines ) {
        data = null;
        numVariants = 0;
        numAnnotations = annotationLines.size();
        meanVector = new double[numAnnotations];
        varianceVector = new double[numAnnotations];
        isNormalized = true; // statistics come pre-computed from the file
        annotationKeys = new ExpandingArrayList<String>();

        int jjj = 0;
        for( final String line : annotationLines ) {
            final String[] vals = line.split(",");
            annotationKeys.add(vals[1]);
            meanVector[jjj] = Double.parseDouble(vals[2]);
            varianceVector[jjj] = Double.parseDouble(vals[3]);
            jjj++;
        }
    }

    /**
     * Z-score normalizes every annotation in place: x -> (x - mean) / standardDeviation.
     * Records the statistics in meanVector/varianceVector for later persistence.
     *
     * @throws StingException after scanning all annotations if any had (near) zero variance,
     *                        since those cannot be normalized
     */
    public void normalizeData() {
        boolean foundZeroVarianceAnnotation = false;
        for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
            final double theMean = mean(data, jjj);
            final double theSTD = standardDeviation(data, theMean, jjj);
            logger.info( annotationKeys.get(jjj) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD) );
            if( theSTD < 1E-8 ) {
                foundZeroVarianceAnnotation = true;
                logger.warn("Zero variance is a problem: standard deviation = " + theSTD + " User must -exclude annotations with zero variance. Annotation = " + (jjj == numAnnotations-1 ? "QUAL" : annotationKeys.get(jjj)));
            } else if( theSTD < 1E-2 ) {
                logger.warn("Warning! Tiny variance. It is strongly recommended that you -exclude " + annotationKeys.get(jjj));
            }
            meanVector[jjj] = theMean;
            varianceVector[jjj] = theSTD;
            for( int iii = 0; iii < numVariants; iii++ ) {
                data[iii].annotations[jjj] = ( data[iii].annotations[jjj] - theMean ) / theSTD;
            }
        }
        isNormalized = true; // Each data point is now [ (x - mean) / standard deviation ]
        if( foundZeroVarianceAnnotation ) {
            throw new StingException("Found annotations with zero variance. They must be excluded before proceeding.");
        }
    }

    /** Mean of annotation {@code index} over all data points. */
    private static double mean( final VariantDatum[] data, final int index ) {
        double sum = 0.0;
        final int numVars = data.length;
        for( int iii = 0; iii < numVars; iii++ ) {
            sum += (data[iii].annotations[index] / ((double) numVars));
        }
        return sum;
    }

    /** Population standard deviation of annotation {@code index} (divides by N, not N-1). */
    private static double standardDeviation( final VariantDatum[] data, final double mean, final int index ) {
        double sum = 0.0;
        final int numVars = data.length;
        for( int iii = 0; iii < numVars; iii++ ) {
            sum += ( ((data[iii].annotations[index] - mean)*(data[iii].annotations[index] - mean)) / ((double) numVars));
        }
        return Math.sqrt( sum );
    }

    /** Writes one "@!ANNOTATION,key,mean,standardDeviation" line per annotation. */
    public void printClusterFileHeader( PrintStream outputFile ) {
        for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
            outputFile.println("@!ANNOTATION," + annotationKeys.get(jjj) + "," + meanVector[jjj] + "," + varianceVector[jjj]);
        }
    }
}
/*
 * Copyright (c) 2010 The Broad Institute. MIT-style license; see the notice
 * reproduced at the top of the other files in this package.
 */

/**
 * Plain mutable container for a single variant observation used by the
 * variant-recalibration clustering code.
 *
 * User: rpoplin
 * Date: Feb 24, 2010
 */
public class VariantDatum {
    /** Annotation feature vector; normalized in place by VariantDataManager.normalizeData(). */
    public double[] annotations;
    /** True when the SNP's substitution type is a transition. */
    public boolean isTransition;
    /** True when the site was found in the dbSNP track. */
    public boolean isKnown;
    /** Variant quality score. */
    public double qual;
    /** Training weight used during clustering (larger for trusted known sets). */
    public double weight;
    /** Chromosome count of the first alternate allele at this site. */
    public int alleleCount;
}
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR + * THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ + +package org.broadinstitute.sting.gatk.walkers.variantrecalibration; + +import org.apache.log4j.Logger; +import org.broadinstitute.sting.gatk.contexts.variantcontext.VariantContext; +import org.broadinstitute.sting.utils.collections.ExpandingArrayList; +import org.broadinstitute.sting.utils.StingException; +import org.broadinstitute.sting.utils.text.XReadLines; + +import Jama.*; + +import java.io.File; +import java.io.FileNotFoundException; +import java.io.PrintStream; +import java.util.Random; +import java.util.regex.Pattern; + +/** + * Created by IntelliJ IDEA. 
+ * User: rpoplin + * Date: Feb 26, 2010 + */ + +public final class VariantGaussianMixtureModel extends VariantOptimizationModel { + + protected final static Logger logger = Logger.getLogger(VariantGaussianMixtureModel.class); + + public final VariantDataManager dataManager; + private final int numGaussians; + private final int numIterations; + private final static long RANDOM_SEED = 91801305; + private final static Random rand = new Random( RANDOM_SEED ); + private final double MIN_PROB = 1E-7; + private final double MIN_SIGMA = 1E-5; + private final double MIN_DETERMINANT = 1E-5; + + private final double[][] mu; // The means for each cluster + private final Matrix[] sigma; // The covariance matrix for each cluster + private final Matrix[] sigmaInverse; + private final double[] pCluster; + private final double[] determinant; + private final double[] alleleCountFactorArray; + private final int minVarInCluster; + + private static final Pattern ANNOTATION_PATTERN = Pattern.compile("^@!ANNOTATION.*"); + private static final Pattern ALLELECOUNT_PATTERN = Pattern.compile("^@!ALLELECOUNT.*"); + private static final Pattern CLUSTER_PATTERN = Pattern.compile("^@!CLUSTER.*"); + + public VariantGaussianMixtureModel( final VariantDataManager _dataManager, final int _numGaussians, final int _numIterations, final int _minVarInCluster, final int maxAC ) { + dataManager = _dataManager; + numGaussians = _numGaussians; + numIterations = _numIterations; + + mu = new double[numGaussians][]; + sigma = new Matrix[numGaussians]; + determinant = new double[numGaussians]; + pCluster = new double[numGaussians]; + alleleCountFactorArray = new double[maxAC + 1]; + minVarInCluster = _minVarInCluster; + sigmaInverse = null; // This field isn't used during VariantOptimizer pass + } + + public VariantGaussianMixtureModel( final double _targetTITV, final String clusterFileName, final double backOffGaussianFactor ) { + super( _targetTITV ); + final ExpandingArrayList annotationLines = new 
ExpandingArrayList(); + final ExpandingArrayList alleleCountLines = new ExpandingArrayList(); + final ExpandingArrayList clusterLines = new ExpandingArrayList(); + + try { + for ( final String line : new XReadLines(new File( clusterFileName )) ) { + if( ANNOTATION_PATTERN.matcher(line).matches() ) { + annotationLines.add(line); + } else if( ALLELECOUNT_PATTERN.matcher(line).matches() ) { + alleleCountLines.add(line); + } else if( CLUSTER_PATTERN.matcher(line).matches() ) { + clusterLines.add(line); + } else { + throw new StingException("Malformed input file: " + clusterFileName); + } + } + } catch ( FileNotFoundException e ) { + throw new StingException("Can not find input file: " + clusterFileName); + } + + dataManager = new VariantDataManager( annotationLines ); + // Several of the clustering parameters aren't used the second time around in ApplyVariantClusters + numIterations = 0; + minVarInCluster = 0; + + // BUGBUG: move this parsing out of the constructor + numGaussians = clusterLines.size(); + mu = new double[numGaussians][dataManager.numAnnotations]; + double sigmaVals[][][] = new double[numGaussians][dataManager.numAnnotations][dataManager.numAnnotations]; + sigma = new Matrix[numGaussians]; + sigmaInverse = new Matrix[numGaussians]; + pCluster = new double[numGaussians]; + determinant = new double[numGaussians]; + + alleleCountFactorArray = new double[alleleCountLines.size() + 1]; + for( final String line : alleleCountLines ) { + final String[] vals = line.split(","); + alleleCountFactorArray[Integer.parseInt(vals[1])] = Double.parseDouble(vals[2]); + } + + int kkk = 0; + for( final String line : clusterLines ) { + final String[] vals = line.split(","); + pCluster[kkk] = Double.parseDouble(vals[1]); // BUGBUG: #define these magic index numbers, very easy to make a mistake here + for( int jjj = 0; jjj < dataManager.numAnnotations; jjj++ ) { + mu[kkk][jjj] = Double.parseDouble(vals[2+jjj]); + for( int ppp = 0; ppp < dataManager.numAnnotations; ppp++ ) { + 
sigmaVals[kkk][jjj][ppp] = Double.parseDouble(vals[2+dataManager.numAnnotations+(jjj*dataManager.numAnnotations)+ppp]) * backOffGaussianFactor; + } + } + + sigma[kkk] = new Matrix(sigmaVals[kkk]); + sigmaInverse[kkk] = sigma[kkk].inverse(); // Precompute all the inverses and determinants for use later + determinant[kkk] = sigma[kkk].det(); + kkk++; + } + + logger.info("Found " + numGaussians + " clusters and using " + dataManager.numAnnotations + " annotations: " + dataManager.annotationKeys); + } + + public final void run( final String clusterFileName ) { + + // Initialize the Allele Count prior + generateAlleleCountPrior(); + + // Simply cluster with all the variants. The knowns have been given more weight than the novels + logger.info("Clustering with " + dataManager.data.length + " variants."); + createClusters( dataManager.data, 0, numGaussians, clusterFileName, false ); + } + + private void generateAlleleCountPrior() { + + final double[] acExpectation = new double[alleleCountFactorArray.length]; + final double[] acActual = new double[alleleCountFactorArray.length]; + final int[] alleleCount = new int[alleleCountFactorArray.length]; + + double sumExpectation = 0.0; + for( int iii = 1; iii < alleleCountFactorArray.length; iii++ ) { + acExpectation[iii] = 1.0 / ((double) iii); + sumExpectation += acExpectation[iii]; + } + for( int iii = 1; iii < alleleCountFactorArray.length; iii++ ) { + acExpectation[iii] /= sumExpectation; // Turn acExpectation into a probability distribution + alleleCount[iii] = 0; + } + for( final VariantDatum datum : dataManager.data ) { + alleleCount[datum.alleleCount]++; + } + for( int iii = 1; iii < alleleCountFactorArray.length; iii++ ) { + acActual[iii] = ((double)alleleCount[iii]) / ((double)dataManager.data.length); // Turn acActual into a probability distribution + } + for( int iii = 1; iii < alleleCountFactorArray.length; iii++ ) { + alleleCountFactorArray[iii] = acExpectation[iii] / acActual[iii]; // Prior is (expected / observed) 
+ } + } + + public final double getAlleleCountPrior( final int alleleCount ) { + return alleleCountFactorArray[alleleCount]; + } + + public final void createClusters( final VariantDatum[] data, final int startCluster, final int stopCluster, final String clusterFileName, final boolean useTITV ) { + + final int numVariants = data.length; + final int numAnnotations = data[0].annotations.length; + + final double[][] pVarInCluster = new double[numGaussians][numVariants]; // Probability that the variant is in that cluster = simply evaluate the multivariate Gaussian + + // loop control variables: + // iii - loop over data points + // jjj - loop over annotations (features) + // ppp - loop over annotations again (full rank covariance matrix) + // kkk - loop over clusters + // ttt - loop over EM iterations + + // Set up the initial random Gaussians + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + pCluster[kkk] = 1.0 / ((double) (stopCluster - startCluster)); + mu[kkk] = data[rand.nextInt(numVariants)].annotations; + final double[][] randSigma = new double[numAnnotations][numAnnotations]; + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + for( int ppp = jjj; ppp < numAnnotations; ppp++ ) { + randSigma[ppp][jjj] = 0.5 + 0.5 * rand.nextDouble(); // data has been normalized so sigmas are centered at 1.0 + if(jjj != ppp) { randSigma[jjj][ppp] = 0.0; } // Sigma is a symmetric, positive-definite matrix + } + } + Matrix tmp = new Matrix(randSigma); + tmp = tmp.times(tmp.transpose()); + sigma[kkk] = tmp; + determinant[kkk] = sigma[kkk].det(); + } + + // The EM loop + for( int ttt = 0; ttt < numIterations; ttt++ ) { + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Expectation Step (calculate the probability that each data point is in each cluster) + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + evaluateGaussians( data, 
pVarInCluster, startCluster, stopCluster ); + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Maximization Step (move the clusters to maximize the sum probability of each data point) + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + maximizeGaussians( data, pVarInCluster, startCluster, stopCluster ); + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Output cluster parameters at each iteration + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + printClusterParameters( clusterFileName + "." + (ttt+1) ); + + logger.info("Finished iteration " + (ttt+1) ); + } + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Output the final cluster parameters + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + printClusterParameters( clusterFileName ); + } + + private void printClusterParameters( final String clusterFileName ) { + try { + final PrintStream outputFile = new PrintStream( clusterFileName ); + dataManager.printClusterFileHeader( outputFile ); + for( int iii = 1; iii < alleleCountFactorArray.length; iii++ ) { + outputFile.print("@!ALLELECOUNT,"); + outputFile.println(iii + "," + alleleCountFactorArray[iii]); + } + + final int numAnnotations = mu[0].length; + final int numVariants = dataManager.numVariants; + for( int kkk = 0; kkk < numGaussians; kkk++ ) { + if( pCluster[kkk] * numVariants > minVarInCluster ) { + final double sigmaVals[][] = sigma[kkk].getArray(); + outputFile.print("@!CLUSTER,"); + outputFile.print(pCluster[kkk] + ","); + for(int jjj = 0; jjj < numAnnotations; jjj++ ) { + outputFile.print(mu[kkk][jjj] + ","); + } + for(int jjj = 0; jjj < 
numAnnotations; jjj++ ) { + for(int ppp = 0; ppp < numAnnotations; ppp++ ) { + outputFile.print(sigmaVals[jjj][ppp] + ","); + } + } + outputFile.println(-1); + } + } + outputFile.close(); + } catch (Exception e) { + throw new StingException( "Unable to create output file: " + clusterFileName ); + } + } + + public static double decodeAnnotation( final String annotationKey, final VariantContext vc ) { + double value; + //if( annotationKey.equals("AB") && !vc.getAttributes().containsKey(annotationKey) ) { + // value = (0.5 - 0.005) + (0.01 * rand.nextDouble()); // HomVar calls don't have an allele balance + //} + if( annotationKey.equals("QUAL") ) { + value = vc.getPhredScaledQual(); + } else { + try { + value = Double.parseDouble( (String)vc.getAttribute( annotationKey ) ); + } catch( Exception e ) { + throw new StingException("No double value detected for annotation = " + annotationKey + + " in variant at " + vc.getLocation() + ", reported annotation value = " + vc.getAttribute( annotationKey ) ); + } + } + return value; + } + + public final double evaluateVariant( final VariantContext vc ) { + final double[] pVarInCluster = new double[numGaussians]; + final double[] annotations = new double[dataManager.numAnnotations]; + + for( int jjj = 0; jjj < dataManager.numAnnotations; jjj++ ) { + final double value = decodeAnnotation( dataManager.annotationKeys.get(jjj), vc ); + annotations[jjj] = (value - dataManager.meanVector[jjj]) / dataManager.varianceVector[jjj]; + } + + evaluateGaussiansForSingleVariant( annotations, pVarInCluster ); + + double sum = 0.0; + for( int kkk = 0; kkk < numGaussians; kkk++ ) { + sum += pVarInCluster[kkk]; // * clusterTruePositiveRate[kkk]; + } + return sum; + } + + public final void outputClusterReports( final String outputPrefix ) { + final double STD_STEP = 0.2; + final double MAX_STD = 4.0; + final double MIN_STD = -4.0; + final int NUM_BINS = (int)Math.floor((Math.abs(MIN_STD) + Math.abs(MAX_STD)) / STD_STEP); + final int numAnnotations 
= dataManager.numAnnotations; + int totalCountsKnown = 0; + int totalCountsNovel = 0; + + final int counts[][][] = new int[numAnnotations][NUM_BINS][2]; + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + for( int iii = 0; iii < NUM_BINS; iii++ ) { + counts[jjj][iii][0] = 0; + counts[jjj][iii][1] = 0; + } + } + + for( VariantDatum datum : dataManager.data ) { + final int isKnown = ( datum.isKnown ? 1 : 0 ); + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + int histBin = (int)Math.round((datum.annotations[jjj]-MIN_STD) * (1.0 / STD_STEP)); + if(histBin < 0) { histBin = 0; } + if(histBin > NUM_BINS-1) { histBin = NUM_BINS-1; } + if(histBin >= 0 && histBin <= NUM_BINS-1) { + counts[jjj][histBin][isKnown]++; + } + } + if( isKnown == 1 ) { totalCountsKnown++; } + else { totalCountsNovel++; } + } + + int annIndex = 0; + for( final String annotation : dataManager.annotationKeys ) { + PrintStream outputFile; + try { + outputFile = new PrintStream( outputPrefix + "." + annotation + ".dat" ); + } catch (Exception e) { + throw new StingException( "Unable to create output file: " + outputPrefix + ".dat" ); + } + + outputFile.println("annotationValue,knownDist,novelDist"); + + for( int iii = 0; iii < NUM_BINS; iii++ ) { + final double annotationValue = (((double)iii * STD_STEP)+MIN_STD) * dataManager.varianceVector[annIndex] + dataManager.meanVector[annIndex]; + outputFile.println( annotationValue + "," + ( ((double)counts[annIndex][iii][1])/((double)totalCountsKnown) ) + + "," + ( ((double)counts[annIndex][iii][0])/((double)totalCountsNovel) )); + } + + annIndex++; + } + + // BUGBUG: next output the actual cluster on top by integrating out every other annotation + } + + public final void outputOptimizationCurve( final VariantDatum[] data, final String outputPrefix, final int desiredNumVariants ) { + + final int numVariants = data.length; + final boolean[] markedVariant = new boolean[numVariants]; + + final double MAX_QUAL = 100.0; + final double QUAL_STEP = 0.1; + final 
int NUM_BINS = (int) ((MAX_QUAL / QUAL_STEP) + 1); + + final int numKnownAtCut[] = new int[NUM_BINS]; + final int numNovelAtCut[] = new int[NUM_BINS]; + final double knownTiTvAtCut[] = new double[NUM_BINS]; + final double novelTiTvAtCut[] = new double[NUM_BINS]; + final double theCut[] = new double[NUM_BINS]; + + for( int iii = 0; iii < numVariants; iii++ ) { + markedVariant[iii] = false; + } + + PrintStream outputFile; + try { + outputFile = new PrintStream( outputPrefix + ".dat" ); + } catch (Exception e) { + throw new StingException( "Unable to create output file: " + outputPrefix + ".dat" ); + } + + int numKnown = 0; + int numNovel = 0; + int numKnownTi = 0; + int numKnownTv = 0; + int numNovelTi = 0; + int numNovelTv = 0; + boolean foundDesiredNumVariants = false; + int jjj = 0; + outputFile.println("pCut,numKnown,numNovel,knownTITV,novelTITV"); + for( double qCut = MAX_QUAL; qCut >= -0.001; qCut -= QUAL_STEP ) { + for( int iii = 0; iii < numVariants; iii++ ) { + if( !markedVariant[iii] ) { + if( data[iii].qual >= qCut ) { + markedVariant[iii] = true; + if( data[iii].isKnown ) { // known + numKnown++; + if( data[iii].isTransition ) { // transition + numKnownTi++; + } else { // transversion + numKnownTv++; + } + } else { // novel + numNovel++; + if( data[iii].isTransition ) { // transition + numNovelTi++; + } else { // transversion + numNovelTv++; + } + } + } + } + } + if( desiredNumVariants != 0 && !foundDesiredNumVariants && (numKnown + numNovel) >= desiredNumVariants ) { + logger.info( "Keeping variants with QUAL >= " + String.format("%.1f",qCut) + " results in a filtered set with: " ); + logger.info("\t" + numKnown + " known variants"); + logger.info("\t" + numNovel + " novel variants, (dbSNP rate = " + String.format("%.2f",((double) numKnown * 100.0) / ((double) numKnown + numNovel) ) + "%)"); + logger.info("\t" + String.format("%.4f known Ti/Tv ratio", ((double)numKnownTi) / ((double)numKnownTv))); + logger.info("\t" + String.format("%.4f novel Ti/Tv 
ratio", ((double)numNovelTi) / ((double)numNovelTv))); + foundDesiredNumVariants = true; + } + outputFile.println( qCut + "," + numKnown + "," + numNovel + "," + + ( numKnownTi == 0 || numKnownTv == 0 ? "NaN" : ( ((double)numKnownTi) / ((double)numKnownTv) ) ) + "," + + ( numNovelTi == 0 || numNovelTv == 0 ? "NaN" : ( ((double)numNovelTi) / ((double)numNovelTv) ) )); + + numKnownAtCut[jjj] = numKnown; + numNovelAtCut[jjj] = numNovel; + knownTiTvAtCut[jjj] = ( numKnownTi == 0 || numKnownTv == 0 ? 0.0 : ( ((double)numKnownTi) / ((double)numKnownTv) ) ); + novelTiTvAtCut[jjj] = ( numNovelTi == 0 || numNovelTv == 0 ? 0.0 : ( ((double)numNovelTi) / ((double)numNovelTv) ) ); + theCut[jjj] = qCut; + jjj++; + } + + // loop back through the data points looking for appropriate places to cut the data to get the target novel titv ratio + int checkQuantile = 0; + for( jjj = NUM_BINS-1; jjj >= 0; jjj-- ) { + boolean foundCut = false; + if( checkQuantile == 0 ) { + if( novelTiTvAtCut[jjj] >= 0.9 * targetTITV ) { + foundCut = true; + checkQuantile++; + } + } else if( checkQuantile == 1 ) { + if( novelTiTvAtCut[jjj] >= 0.95 * targetTITV ) { + foundCut = true; + checkQuantile++; + } + } else if( checkQuantile == 2 ) { + if( novelTiTvAtCut[jjj] >= 0.98 * targetTITV ) { + foundCut = true; + checkQuantile++; + } + } else if( checkQuantile == 3 ) { + if( novelTiTvAtCut[jjj] >= targetTITV ) { + foundCut = true; + checkQuantile++; + } + } else if( checkQuantile == 4 ) { + break; // break out + } + + if( foundCut ) { + logger.info( "Keeping variants with QUAL >= " + String.format("%.1f",theCut[jjj]) + " results in a filtered set with: " ); + logger.info("\t" + numKnownAtCut[jjj] + " known variants"); + logger.info("\t" + numNovelAtCut[jjj] + " novel variants, (dbSNP rate = " + + String.format("%.2f",((double) numKnownAtCut[jjj] * 100.0) / ((double) numKnownAtCut[jjj] + numNovelAtCut[jjj]) ) + "%)"); + logger.info("\t" + String.format("%.4f known Ti/Tv ratio", knownTiTvAtCut[jjj])); + 
logger.info("\t" + String.format("%.4f novel Ti/Tv ratio", novelTiTvAtCut[jjj])); + } + } + + outputFile.close(); + } + + + private void evaluateGaussians( final VariantDatum[] data, final double[][] pVarInCluster, final int startCluster, final int stopCluster ) { + + final int numAnnotations = data[0].annotations.length; + double likelihood = 0.0; + final double sigmaVals[][][] = new double[numGaussians][][]; + final double denom[] = new double[numGaussians]; + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + sigmaVals[kkk] = sigma[kkk].inverse().getArray(); + denom[kkk] = Math.pow(2.0 * 3.14159, ((double)numAnnotations) / 2.0) * Math.pow(Math.abs(determinant[kkk]), 0.5); + } + final double mult[] = new double[numAnnotations]; + for( int iii = 0; iii < data.length; iii++ ) { + double sumProb = 0.0; + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + double sum = 0.0; + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + mult[jjj] = 0.0; + for( int ppp = 0; ppp < numAnnotations; ppp++ ) { + mult[jjj] += (data[iii].annotations[ppp] - mu[kkk][ppp]) * sigmaVals[kkk][ppp][jjj]; + } + } + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + sum += mult[jjj] * (data[iii].annotations[jjj] - mu[kkk][jjj]); + } + + pVarInCluster[kkk][iii] = pCluster[kkk] * (Math.exp( -0.5 * sum ) / denom[kkk]); + likelihood += pVarInCluster[kkk][iii]; + if(Double.isNaN(denom[kkk]) || determinant[kkk] < 0.5 * MIN_DETERMINANT) { + System.out.println("det = " + sigma[kkk].det()); + System.out.println("denom = " + denom[kkk]); + System.out.println("sumExp = " + sum); + System.out.println("pVar = " + pVarInCluster[kkk][iii]); + System.out.println("=-------="); + throw new StingException("Numerical Instability! determinant of covariance matrix <= 0. 
Try running with fewer clusters and then with better behaved annotation values."); + } + if(sum < 0.0) { + System.out.println("det = " + sigma[kkk].det()); + System.out.println("denom = " + denom[kkk]); + System.out.println("sumExp = " + sum); + System.out.println("pVar = " + pVarInCluster[kkk][iii]); + System.out.println("=-------="); + throw new StingException("Numerical Instability! covariance matrix no longer positive definite. Try running with fewer clusters and then with better behaved annotation values."); + } + if(pVarInCluster[kkk][iii] > 1.0) { + System.out.println("det = " + sigma[kkk].det()); + System.out.println("denom = " + denom[kkk]); + System.out.println("sumExp = " + sum); + System.out.println("pVar = " + pVarInCluster[kkk][iii]); + System.out.println("=-------="); + throw new StingException("Numerical Instability! probability distribution returns > 1.0. Try running with fewer clusters and then with better behaved annotation values."); + } + + if( pVarInCluster[kkk][iii] < MIN_PROB ) { // Very small numbers are a very big problem + pVarInCluster[kkk][iii] = MIN_PROB; // + MIN_PROB * rand.nextDouble(); + } + + sumProb += pVarInCluster[kkk][iii]; + } + + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + pVarInCluster[kkk][iii] /= sumProb; + pVarInCluster[kkk][iii] *= data[iii].weight; + } + } + + logger.info("Explained likelihood = " + String.format("%.5f",likelihood / data.length)); + } + + + private void evaluateGaussiansForSingleVariant( final double[] annotations, final double[] pVarInCluster ) { + + final int numAnnotations = annotations.length; + final double mult[] = new double[numAnnotations]; + for( int kkk = 0; kkk < numGaussians; kkk++ ) { + final double sigmaVals[][] = sigmaInverse[kkk].getArray(); + double sum = 0.0; + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + mult[jjj] = 0.0; + for( int ppp = 0; ppp < numAnnotations; ppp++ ) { + mult[jjj] += (annotations[ppp] - mu[kkk][ppp]) * sigmaVals[ppp][jjj]; + } + } + for( int 
jjj = 0; jjj < numAnnotations; jjj++ ) { + sum += mult[jjj] * (annotations[jjj] - mu[kkk][jjj]); + } + + final double denom = Math.pow(2.0 * 3.14159, ((double)numAnnotations) / 2.0) * Math.pow(determinant[kkk], 0.5); + pVarInCluster[kkk] = pCluster[kkk] * (Math.exp( -0.5 * sum )) / denom; + } + } + + + private void maximizeGaussians( final VariantDatum[] data, final double[][] pVarInCluster, final int startCluster, final int stopCluster ) { + + final int numVariants = data.length; + final int numAnnotations = data[0].annotations.length; + final double sigmaVals[][][] = new double[numGaussians][numAnnotations][numAnnotations]; + + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + mu[kkk][jjj] = 0.0; + for( int ppp = 0; ppp < numAnnotations; ppp++ ) { + sigmaVals[kkk][jjj][ppp] = 0.0; + } + } + } + double sumPK = 0.0; + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + double sumProb = 0.0; + for( int iii = 0; iii < numVariants; iii++ ) { + final double prob = pVarInCluster[kkk][iii]; + sumProb += prob; + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + mu[kkk][jjj] += prob * data[iii].annotations[jjj]; + } + } + + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + mu[kkk][jjj] /= sumProb; + } + + for( int iii = 0; iii < numVariants; iii++ ) { + final double prob = pVarInCluster[kkk][iii]; + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + for( int ppp = jjj; ppp < numAnnotations; ppp++ ) { + sigmaVals[kkk][jjj][ppp] += prob * (data[iii].annotations[jjj]-mu[kkk][jjj]) * (data[iii].annotations[ppp]-mu[kkk][ppp]); + } + } + } + + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + for( int ppp = jjj; ppp < numAnnotations; ppp++ ) { + if( sigmaVals[kkk][jjj][ppp] < MIN_SIGMA ) { // Very small numbers are a very big problem + sigmaVals[kkk][jjj][ppp] = MIN_SIGMA;// + MIN_SIGMA * rand.nextDouble(); + } + sigmaVals[kkk][ppp][jjj] = sigmaVals[kkk][jjj][ppp]; // sigma must be a symmetric matrix + 
} + } + + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + for( int ppp = 0; ppp < numAnnotations; ppp++ ) { + sigmaVals[kkk][jjj][ppp] /= sumProb; + } + } + + sigma[kkk] = new Matrix(sigmaVals[kkk]); + determinant[kkk] = sigma[kkk].det(); + + pCluster[kkk] = sumProb / numVariants; + sumPK += pCluster[kkk]; + } + + // ensure pCluster sums to one, it doesn't automatically due to very small numbers getting capped + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + pCluster[kkk] /= sumPK; + } + + /* + // Clean up extra big or extra small clusters + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + if( pCluster[kkk] > 0.45 ) { // This is a very large cluster compared to all the others + System.out.println("!! Found very large cluster! Busting it up into smaller clusters."); + final int numToReplace = 3; + final Matrix savedSigma = sigma[kkk]; + for( int rrr = 0; rrr < numToReplace; rrr++ ) { + // Find an example variant in the large cluster, drawn randomly + int randVarIndex = -1; + boolean foundVar = false; + while( !foundVar ) { + randVarIndex = rand.nextInt( numVariants ); + final double probK = pVarInCluster[kkk][randVarIndex]; + boolean inClusterK = true; + for( int ccc = startCluster; ccc < stopCluster; ccc++ ) { + if( pVarInCluster[ccc][randVarIndex] > probK ) { + inClusterK = false; + break; + } + } + if( inClusterK ) { foundVar = true; } + } + + // Find a place to put the example variant + if( rrr == 0 ) { // Replace the big cluster that kicked this process off + mu[kkk] = data[randVarIndex].annotations; + pCluster[kkk] = 1.0 / ((double) (stopCluster-startCluster)); + } else { // Replace the cluster with the minimum prob + double minProb = pCluster[startCluster]; + int minClusterIndex = startCluster; + for( int ccc = startCluster; ccc < stopCluster; ccc++ ) { + if( pCluster[ccc] < minProb ) { + minProb = pCluster[ccc]; + minClusterIndex = ccc; + } + } + mu[minClusterIndex] = data[randVarIndex].annotations; + sigma[minClusterIndex] = 
savedSigma; + //for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + // for( int ppp = 0; ppp < numAnnotations; ppp++ ) { + // sigma[minClusterIndex].set(jjj, ppp, sigma[minClusterIndex].get(jjj, ppp) - 0.06 + 0.12 * rand.nextDouble()); + // } + //} + pCluster[minClusterIndex] = 0.5 / ((double) (stopCluster-startCluster)); + } + } + } + } + */ + + /* + // Replace extremely small clusters with another random draw from the dataset + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + if( pCluster[kkk] < 0.0005 * (1.0 / ((double) (stopCluster-startCluster))) || + determinant[kkk] < MIN_DETERMINANT ) { // This is a very small cluster compared to all the others + logger.info("!! Found singular cluster! Initializing a new random cluster."); + pCluster[kkk] = 0.1 / ((double) (stopCluster-startCluster)); + mu[kkk] = data[rand.nextInt(numVariants)].annotations; + final double[][] randSigma = new double[numAnnotations][numAnnotations]; + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + for( int ppp = jjj; ppp < numAnnotations; ppp++ ) { + randSigma[ppp][jjj] = 0.50 + 0.5 * rand.nextDouble(); // data is normalized so this is centered at 1.0 + if(jjj != ppp) { randSigma[jjj][ppp] = 0.0; } // Sigma is a symmetric, positive-definite matrix + } + } + Matrix tmp = new Matrix(randSigma); + tmp = tmp.times(tmp.transpose()); + sigma[kkk] = tmp; + determinant[kkk] = sigma[kkk].det(); + } + } + + // renormalize pCluster since things might have changed due to the previous step + sumPK = 0.0; + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + sumPK += pCluster[kkk]; + } + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + pCluster[kkk] /= sumPK; + } + */ + } +} \ No newline at end of file diff --git a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantNearestNeighborsModel.java b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantNearestNeighborsModel.java new file mode 100755 index 000000000..f7867502c --- 
/dev/null +++ b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantNearestNeighborsModel.java @@ -0,0 +1,78 @@ +package org.broadinstitute.sting.gatk.walkers.variantrecalibration; + +import org.broadinstitute.sting.utils.StingException; + +import java.io.PrintStream; + +/* + * Copyright (c) 2010 The Broad Institute + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, + * copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ + +/** + * Created by IntelliJ IDEA. + * User: rpoplin + * Date: Mar 1, 2010 + */ + +public final class VariantNearestNeighborsModel extends VariantOptimizationModel { + + private final int numKNN; + + public VariantNearestNeighborsModel( VariantDataManager _dataManager, final double _targetTITV, final int _numKNN ) { + super( _targetTITV ); + //dataManager = _dataManager; + numKNN = _numKNN; + } + + public void run( final String outputPrefix ) { + + throw new StingException( "Nearest Neighbors model hasn't been updated yet." 
); + /* + final int numVariants = dataManager.numVariants; + + final double[] pTrueVariant = new double[numVariants]; + + final VariantTree vTree = new VariantTree( numKNN ); + vTree.createTreeFromData( dataManager.data ); + + System.out.println("Finished creating the kd-tree."); + + for(int iii = 0; iii < numVariants; iii++) { + pTrueVariant[iii] = calcTruePositiveRateFromTITV( vTree.calcNeighborhoodTITV( dataManager.data[iii] ) ); + } + + PrintStream outputFile; + try { + outputFile = new PrintStream( outputPrefix + ".knn.optimize" ); + } catch (Exception e) { + throw new StingException( "Unable to create output file: " + outputPrefix + ".knn.optimize" ); + } + for(int iii = 0; iii < numVariants; iii++) { + outputFile.print(String.format("%.4f",pTrueVariant[iii]) + ","); + outputFile.println( (dataManager.data[iii].isTransition ? 1 : 0) + + "," + (dataManager.data[iii].isKnown? 1 : 0)); + } + */ + } +} diff --git a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantOptimizationInterface.java b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantOptimizationInterface.java new file mode 100755 index 000000000..15ad69810 --- /dev/null +++ b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantOptimizationInterface.java @@ -0,0 +1,36 @@ +package org.broadinstitute.sting.gatk.walkers.variantrecalibration; + +/* + * Copyright (c) 2010 The Broad Institute + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, + * copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be + * included in 
// ==== Tail of the VariantOptimizationInterface.java license header (MIT-style, Broad Institute
// ==== 2010) falls here in the original patch. ====

/**
 * Common entry point for all variant optimization models (Gaussian mixture, k-nearest-neighbors).
 *
 * @author rpoplin
 * @since Feb 26, 2010
 */
public interface VariantOptimizationInterface {

    /**
     * Execute the model, writing any result files under the given prefix.
     *
     * @param outputPrefix  prefix used to name the model's output files
     */
    public void run( String outputPrefix );
}

// ==== The patch continues with VariantOptimizationModel.java: its diff header, package
// ==== declaration, and license header occupy the remainder of this span. ====
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ + +/** + * Created by IntelliJ IDEA. + * User: rpoplin + * Date: Feb 26, 2010 + */ + +public abstract class VariantOptimizationModel implements VariantOptimizationInterface { + + public enum Model { + GAUSSIAN_MIXTURE_MODEL, + K_NEAREST_NEIGHBORS + } + + protected final double targetTITV; + + public VariantOptimizationModel() { + targetTITV = 0.0; + } + + public VariantOptimizationModel( final double _targetTITV ) { + targetTITV = _targetTITV; + } + + public final double calcTruePositiveRateFromTITV( final double _titv ) { + double titv = _titv; + if( titv > targetTITV ) { titv -= 2.0f*(titv-targetTITV); } + if( titv < 0.5 ) { titv = 0.5; } + return ( (titv - 0.5) / (targetTITV - 0.5) ); + //if( titv < 0.0 ) { titv = 0.0; } + //return ( titv / targetTITV ); + } + + public final double calcTruePositiveRateFromKnownTITV( final double knownTITV, final double _novelTITV, final double overallTITV, final double knownAlphaFactor ) { + + final double tprTarget = calcTruePositiveRateFromTITV( overallTITV ); + double novelTITV = _novelTITV; + if( novelTITV > knownTITV ) { novelTITV -= 2.0f*(novelTITV-knownTITV); } + if( novelTITV < 0.5 ) { novelTITV = 0.5; } + final double tprKnown = ( (novelTITV - 0.5) / (knownTITV - 0.5) ); + + return ( knownAlphaFactor * tprKnown ) + ( (1.0 - knownAlphaFactor) * tprTarget ); + } +} diff --git a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java 
b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java new file mode 100755 index 000000000..0617edf16 --- /dev/null +++ b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java @@ -0,0 +1,244 @@ +/* + * Copyright (c) 2010 The Broad Institute + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, + * copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR + * THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
+ */ + +package org.broadinstitute.sting.gatk.walkers.variantrecalibration; + +import org.broad.tribble.dbsnp.DbSNPFeature; +import org.broad.tribble.vcf.*; +import org.broadinstitute.sting.gatk.contexts.AlignmentContext; +import org.broadinstitute.sting.gatk.contexts.ReferenceContext; +import org.broadinstitute.sting.gatk.contexts.variantcontext.VariantContext; +import org.broadinstitute.sting.gatk.datasources.simpleDataSources.ReferenceOrderedDataSource; +import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker; +import org.broadinstitute.sting.gatk.refdata.VariantContextAdaptors; +import org.broadinstitute.sting.gatk.refdata.tracks.RMDTrack; +import org.broadinstitute.sting.gatk.refdata.utils.helpers.DbSNPHelper; +import org.broadinstitute.sting.gatk.walkers.RodWalker; +import org.broadinstitute.sting.utils.*; +import org.broadinstitute.sting.utils.collections.ExpandingArrayList; +import org.broadinstitute.sting.commandline.Argument; +import org.broadinstitute.sting.utils.genotype.vcf.VCFReader; +import org.broadinstitute.sting.utils.genotype.vcf.VCFUtils; +import org.broadinstitute.sting.utils.genotype.vcf.VCFWriter; + +import java.io.File; +import java.io.IOException; +import java.util.*; + +/** + * Applies calibrated variant cluster parameters to variant calls to produce an accurate and informative variant quality score + * + * @author rpoplin + * @since Mar 17, 2010 + * + * @help.summary Applies calibrated variant cluster parameters to variant calls to produce an accurate and informative variant quality score + */ + +public class VariantRecalibrator extends RodWalker, ExpandingArrayList> { + + ///////////////////////////// + // Command Line Arguments + ///////////////////////////// + @Argument(fullName="target_titv", shortName="titv", doc="The expected Ti/Tv ratio to display on optimization curve output figures. 
(~~2.1 for whole genome experiments)", required=false) + private double TARGET_TITV = 2.1; + @Argument(fullName="backOff", shortName="backOff", doc="The Gaussian back off factor, used to prevent overfitting by spreading out the Gaussians.", required=false) + private double BACKOFF_FACTOR = 1.0; + @Argument(fullName="desired_num_variants", shortName="dV", doc="The desired number of variants to keep in a theoretically filtered set", required=false) + private int DESIRED_NUM_VARIANTS = 0; + @Argument(fullName="ignore_all_input_filters", shortName="ignoreAllFilters", doc="If specified the optimizer will use variants even if the FILTER column is marked in the VCF file", required=false) + private boolean IGNORE_ALL_INPUT_FILTERS = false; + @Argument(fullName="ignore_filter", shortName="ignoreFilter", doc="If specified the optimizer will use variants even if the specified filter name is marked in the input VCF file", required=false) + private String[] IGNORE_INPUT_FILTERS = null; + @Argument(fullName="known_prior", shortName="knownPrior", doc="A prior on the quality of known variants, a phred scaled probability of being true.", required=false) + private int KNOWN_QUAL_PRIOR = 9; + @Argument(fullName="novel_prior", shortName="novelPrior", doc="A prior on the quality of novel variants, a phred scaled probability of being true.", required=false) + private int NOVEL_QUAL_PRIOR = 2; + @Argument(fullName="quality_scale_factor", shortName="qScale", doc="Multiply all final quality scores by this value. 
Needed to normalize the quality scores.", required=false) + private double QUALITY_SCALE_FACTOR = 50.0; + @Argument(fullName="output_prefix", shortName="output", doc="The prefix added to output VCF file name and optimization curve pdf file name", required=false) + private String OUTPUT_PREFIX = "optimizer"; + @Argument(fullName="clusterFile", shortName="clusterFile", doc="The output cluster file", required=true) + private String CLUSTER_FILENAME = "optimizer.cluster"; + //@Argument(fullName = "optimization_model", shortName = "om", doc = "Optimization calculation model to employ -- GAUSSIAN_MIXTURE_MODEL is currently the default, while K_NEAREST_NEIGHBORS is also available for small callsets.", required = false) + private VariantOptimizationModel.Model OPTIMIZATION_MODEL = VariantOptimizationModel.Model.GAUSSIAN_MIXTURE_MODEL; + @Argument(fullName = "path_to_Rscript", shortName = "Rscript", doc = "The path to your implementation of Rscript. For Broad users this is probably /broad/tools/apps/R-2.6.0/bin/Rscript", required = false) + private String PATH_TO_RSCRIPT = "/broad/tools/apps/R-2.6.0/bin/Rscript"; + @Argument(fullName = "path_to_resources", shortName = "resources", doc = "Path to resources folder holding the Sting R scripts.", required = false) + private String PATH_TO_RESOURCES = "R/"; + + ///////////////////////////// + // Private Member Variables + ///////////////////////////// + private VariantGaussianMixtureModel theModel = null; + private VCFWriter vcfWriter; + private Set ignoreInputFilterSet = null; + private final ArrayList ALLOWED_FORMAT_FIELDS = new ArrayList(); + + //--------------------------------------------------------------------------------------------------------------- + // + // initialize + // + //--------------------------------------------------------------------------------------------------------------- + + public void initialize() { + if( !PATH_TO_RESOURCES.endsWith("/") ) { PATH_TO_RESOURCES = PATH_TO_RESOURCES + "/"; } + + if( 
IGNORE_INPUT_FILTERS != null ) { + ignoreInputFilterSet = new TreeSet(Arrays.asList(IGNORE_INPUT_FILTERS)); + } + + switch (OPTIMIZATION_MODEL) { + case GAUSSIAN_MIXTURE_MODEL: + theModel = new VariantGaussianMixtureModel( TARGET_TITV, CLUSTER_FILENAME, BACKOFF_FACTOR ); + break; + //case K_NEAREST_NEIGHBORS: + // theModel = new VariantNearestNeighborsModel( dataManager, TARGET_TITV, NUM_KNN ); + // break; + default: + throw new StingException( "Variant Optimization Model is unrecognized. Implemented options are GAUSSIAN_MIXTURE_MODEL and K_NEAREST_NEIGHBORS" ); + } + + ALLOWED_FORMAT_FIELDS.add(VCFGenotypeRecord.GENOTYPE_KEY); // copied from VariantsToVCF + ALLOWED_FORMAT_FIELDS.add(VCFGenotypeRecord.GENOTYPE_QUALITY_KEY); + ALLOWED_FORMAT_FIELDS.add(VCFGenotypeRecord.DEPTH_KEY); + ALLOWED_FORMAT_FIELDS.add(VCFGenotypeRecord.GENOTYPE_POSTERIORS_TRIPLET_KEY); + + // setup the header fields + final Set hInfo = new HashSet(); + hInfo.addAll(VCFUtils.getHeaderFields(getToolkit())); + hInfo.add(new VCFInfoHeaderLine("OQ", 1, VCFInfoHeaderLine.INFO_TYPE.Float, "The original variant quality score")); + hInfo.add(new VCFHeaderLine("source", "VariantOptimizer")); + vcfWriter = new VCFWriter( new File(OUTPUT_PREFIX + ".vcf") ); + final TreeSet samples = new TreeSet(); + final List dataSources = this.getToolkit().getRodDataSources(); + for( final ReferenceOrderedDataSource source : dataSources ) { + final RMDTrack rod = source.getReferenceOrderedData(); + if( rod.getRecordType().equals(VCFRecord.class) ) { + final VCFReader reader = new VCFReader(rod.getFile()); + final Set vcfSamples = reader.getHeader().getGenotypeSamples(); + samples.addAll(vcfSamples); + reader.close(); + } + } + final VCFHeader vcfHeader = new VCFHeader(hInfo, samples); + vcfWriter.writeHeader(vcfHeader); + + boolean foundDBSNP = false; + for( final ReferenceOrderedDataSource source : dataSources ) { + final RMDTrack rod = source.getReferenceOrderedData(); + if( 
rod.getName().equals(DbSNPHelper.STANDARD_DBSNP_TRACK_NAME) ) { + foundDBSNP = true; + } + } + + if(!foundDBSNP) { + throw new StingException("dbSNP track is required. This calculation is critically dependent on being able to distinguish known and novel sites."); + } + } + + //--------------------------------------------------------------------------------------------------------------- + // + // map + // + //--------------------------------------------------------------------------------------------------------------- + + public ExpandingArrayList map( RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context ) { + final ExpandingArrayList mapList = new ExpandingArrayList(); + + if( tracker == null ) { // For some reason RodWalkers get map calls with null trackers + return mapList; + } + + for( final VariantContext vc : tracker.getAllVariantContexts(ref, null, context.getLocation(), false, false) ) { + if( vc != null && !vc.getName().equals(DbSNPHelper.STANDARD_DBSNP_TRACK_NAME) && vc.isSNP() ) { + final VCFRecord vcf = VariantContextAdaptors.toVCF(vc, ref.getBase(), ALLOWED_FORMAT_FIELDS, false, false); + if( !vc.isFiltered() || IGNORE_ALL_INPUT_FILTERS || (ignoreInputFilterSet != null && ignoreInputFilterSet.containsAll(vc.getFilters())) ) { + final VariantDatum variantDatum = new VariantDatum(); + variantDatum.isTransition = vc.getSNPSubstitutionType().compareTo(BaseUtils.BaseSubstitutionType.TRANSITION) == 0; + + final DbSNPFeature dbsnp = DbSNPHelper.getFirstRealSNP(tracker.getReferenceMetaData(DbSNPHelper.STANDARD_DBSNP_TRACK_NAME)); + variantDatum.isKnown = dbsnp != null; + variantDatum.alleleCount = vc.getChromosomeCount(vc.getAlternateAllele(0)); // BUGBUG: assumes file has genotypes + + final double acPrior = theModel.getAlleleCountPrior( variantDatum.alleleCount ); + final double knownPrior = ( variantDatum.isKnown ? 
QualityUtils.qualToProb(KNOWN_QUAL_PRIOR) : QualityUtils.qualToProb(NOVEL_QUAL_PRIOR) ); + final double pTrue = theModel.evaluateVariant( vc ) * acPrior * knownPrior; + + variantDatum.qual = QUALITY_SCALE_FACTOR * QualityUtils.phredScaleErrorRate( Math.max(1.0 - pTrue, 0.000000001) ); // BUGBUG: don't have a normalizing constant, so need to scale up qual scores arbitrarily + mapList.add( variantDatum ); + + vcf.addInfoField("OQ", ((Double)vc.getPhredScaledQual()).toString() ); + vcf.setQual( variantDatum.qual ); + vcf.setFilterString(VCFRecord.PASSES_FILTERS); + vcfWriter.addRecord( vcf ); + + } else { // not a SNP or is filtered so just dump it out to the VCF file + vcfWriter.addRecord( vcf ); + } + } + + } + + return mapList; + } + + //--------------------------------------------------------------------------------------------------------------- + // + // reduce + // + //--------------------------------------------------------------------------------------------------------------- + + public ExpandingArrayList reduceInit() { + return new ExpandingArrayList(); + } + + public ExpandingArrayList reduce( final ExpandingArrayList mapValue, final ExpandingArrayList reduceSum ) { + reduceSum.addAll( mapValue ); + return reduceSum; + } + + public void onTraversalDone( ExpandingArrayList reduceSum ) { + + vcfWriter.close(); + + final VariantDataManager dataManager = new VariantDataManager( reduceSum, theModel.dataManager.annotationKeys ); + reduceSum.clear(); // Don't need this ever again, clean up some memory + + theModel.outputOptimizationCurve( dataManager.data, OUTPUT_PREFIX, DESIRED_NUM_VARIANTS ); + + // Execute Rscript command to plot the optimization curve + // Print out the command line to make it clear to the user what is being executed and how one might modify it + final String rScriptCommandLine = PATH_TO_RSCRIPT + " " + PATH_TO_RESOURCES + "plot_OptimizationCurve.R" + " " + OUTPUT_PREFIX + ".dat" + " " + TARGET_TITV; + System.out.println( rScriptCommandLine 
); + + // Execute the RScript command to plot the table of truth values + try { + Runtime.getRuntime().exec( rScriptCommandLine ); + } catch ( IOException e ) { + throw new StingException( "Unable to execute RScript command: " + rScriptCommandLine ); + } + } +} + diff --git a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantTree.java b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantTree.java new file mode 100755 index 000000000..54a4606a4 --- /dev/null +++ b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantTree.java @@ -0,0 +1,144 @@ +package org.broadinstitute.sting.gatk.walkers.variantrecalibration; + +import org.broadinstitute.sting.utils.StingException; + +/* + * Copyright (c) 2010 The Broad Institute + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, + * copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ + +/** + * Created by IntelliJ IDEA. 
+ * User: rpoplin + * Date: Feb 24, 2010 + */ + +public class VariantTree { + private final VariantTreeNode root; + private final int numKNN; + + public VariantTree( final int _numKNN ) { + root = new VariantTreeNode(); + numKNN = _numKNN; + } + + public void createTreeFromData( VariantDatum[] data ) { + root.cutData( data, 0, 0, data[0].annotations.length ); + } + + public double calcNeighborhoodTITV( final VariantDatum variant ) { + + double[] distances; + + // Grab the subset of points that are approximately near this point + final VariantDatum[] data = getBin( variant.annotations, root ); + if( data.length < numKNN ) { + throw new StingException( "Bin is too small. Should be > " + numKNN ); + } + + // Find the X nearest points in the subset + final double[] originalDistances = calcDistances( variant.annotations, data ); + distances = originalDistances.clone(); + quickSort( distances, 0, distances.length-1 ); // BUGBUG: distances.length or distances.length-1 + + final double minDistance = distances[numKNN - 1]; + + // Calculate probability of being true based on this set of SNPs + int numTi = 0; + int numTv = 0; + for( int iii = 0; iii < distances.length; iii++ ) { + if( originalDistances[iii] <= minDistance ) { + if( data[iii].isTransition ) { numTi++; } + else { numTv++; } + } + } + + return ((double) numTi) / ((double) numTv); + } + + private VariantDatum[] getBin( final double[] variant, final VariantTreeNode node ) { + if( node.variants != null ) { + return node.variants; + } else { + if( variant[node.cutDim] < node.cutValue ) { + return getBin( variant, node.left ); + } else { + return getBin( variant, node.right ); + } + } + } + + private double[] calcDistances( final double[] variant, final VariantDatum[] data ) { + final double[] distSquared = new double[data.length]; + int iii = 0; + for( final VariantDatum variantDatum : data ) { + distSquared[iii] = 0.0; + int jjj = 0; + for( final double value : variantDatum.annotations) { + final double diff = 
variant[jjj] - value; + distSquared[iii] += ( diff * diff ); + jjj++; + } + iii++; + } + + return distSquared; + } + + public static int partition(final double arr[], final int left, final int right) + { + int i = left, j = right; + double tmp; + final double pivot = arr[(left + right) / 2]; + + while (i <= j) { + while (arr[i] < pivot) { + i++; + } + while (arr[j] > pivot) { + j--; + } + if (i <= j) { + tmp = arr[i]; + arr[i] = arr[j]; + arr[j] = tmp; + i++; + j--; + } + } + + return i; + } + + public static void quickSort(final double arr[], final int left, final int right) { + final int index = partition(arr, left, right); + if (left < index - 1) { + quickSort(arr, left, index - 1); + } + if (index < right) { + quickSort(arr, index, right); + } + } + + +} diff --git a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantTreeNode.java b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantTreeNode.java new file mode 100755 index 000000000..e3257bc25 --- /dev/null +++ b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantTreeNode.java @@ -0,0 +1,104 @@ +package org.broadinstitute.sting.gatk.walkers.variantrecalibration; + +/* + * Copyright (c) 2010 The Broad Institute + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, + * copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. 
+ * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ + +/** + * Created by IntelliJ IDEA. + * User: rpoplin + * Date: Feb 24, 2010 + */ + +public class VariantTreeNode { + + public VariantTreeNode left; + public VariantTreeNode right; + public int cutDim; + public double cutValue; + public VariantDatum[] variants; + + private final int minBinSize = 8000; // BUGBUG: must be larger than number of kNN + + public VariantTreeNode() { + left = null; + right = null; + variants = null; + cutDim = -1; + cutValue = -1; + } + + public final void cutData( final VariantDatum[] data, final int depth, final int lastCutDepth, final int numAnnotations ) { + + cutDim = depth % numAnnotations; + + if( depth != lastCutDepth && (cutDim == (lastCutDepth % numAnnotations)) ) { // Base case: we've tried to cut on all the annotations + variants = data; + return; + } + + final double[] values = new double[data.length]; + for( int iii = 0; iii < data.length; iii++ ) { + values[iii] = data[iii].annotations[cutDim]; + } + + final double[] sortedValues = values.clone(); + VariantTree.quickSort( sortedValues, 0, values.length-1 ); // BUGBUG: values.length or values.length-1 + + final int lowPivotIndex = Math.round(0.40f * sortedValues.length); + final int highPivotIndex = Math.round(0.60f * sortedValues.length); + final double lowPivot = sortedValues[lowPivotIndex]; + final double highPivot = sortedValues[highPivotIndex]; + cutValue = highPivot; + + int numLow = 0; + int numHigh = 0; + for( int iii = 0; iii < data.length; iii++ ) { + if( values[iii] < 
highPivot ) { numLow++; } + if( values[iii] >= lowPivot ) { numHigh++; } + } + + // If cutting here makes the bin too small then don't cut + if( numLow < minBinSize || numHigh < minBinSize || (numLow == numHigh && numLow == data.length) ) { + cutValue = sortedValues[0]; + right = new VariantTreeNode(); + right.cutData(data, depth+1, lastCutDepth, numAnnotations); + } else { + final VariantDatum[] leftData = new VariantDatum[numLow]; + final VariantDatum[] rightData = new VariantDatum[numHigh]; + int leftIndex = 0; + int rightIndex = 0; + for( int iii = 0; iii < data.length; iii++ ) { + if( values[iii] < highPivot ) { leftData[leftIndex++] = data[iii]; } + if( values[iii] >= lowPivot ) { rightData[rightIndex++] = data[iii]; } + } + + left = new VariantTreeNode(); + right = new VariantTreeNode(); + left.cutData(leftData, depth+1, depth, numAnnotations); + right.cutData(rightData, depth+1, depth, numAnnotations); + } + } + +} diff --git a/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersPerformanceTest.java b/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersPerformanceTest.java new file mode 100755 index 000000000..13baa958c --- /dev/null +++ b/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersPerformanceTest.java @@ -0,0 +1,65 @@ +package org.broadinstitute.sting.gatk.walkers.variantrecalibration; + +import org.broadinstitute.sting.WalkerTest; +import org.junit.Test; + +import java.util.*; +import java.io.File; + +public class VariantRecalibrationWalkersPerformanceTest extends WalkerTest { + static HashMap paramsFiles = new HashMap(); + + @Test + public void testGenerateVariantClusters() { + HashMap e = new HashMap(); + e.put( validationDataLocation + "yri.trio.gatk_glftrio.intersection.annotated.filtered.chr1.vcf", "b39bde091254e90490b0d7bf9f0ef24a" ); + + for ( Map.Entry entry : e.entrySet() ) { + String vcf = entry.getKey(); + 
String md5 = entry.getValue(); + + WalkerTest.WalkerTestSpec spec = new WalkerTest.WalkerTestSpec( + "-R " + oneKGLocation + "reference/human_b36_both.fasta" + + " --DBSNP /humgen/gsa-scr1/GATK_Data/dbsnp_129_b36.rod" + + " -T GenerateVariantClusters" + + " -B input,VCF," + vcf + + " -nG 6" + + " -nI 5" + + " --ignore_filter GATK_STANDARD" + + " -an QD -an HRun -an SB" + + " -clusterFile /dev/null", + 0, // just one output file + new ArrayList(0)); + List result = executeTest("testGenerateVariantClusters", spec).getFirst(); + paramsFiles.put(vcf, result.get(0).getAbsolutePath()); + } + } + + @Test + public void testVariantRecalibrator() { + HashMap e = new HashMap(); + e.put( validationDataLocation + "yri.trio.gatk_glftrio.intersection.annotated.filtered.chr1.vcf", "412bdb2eb4ca8f7ee9dfb39cda676c95" ); + + for ( Map.Entry entry : e.entrySet() ) { + String vcf = entry.getKey(); + String md5 = entry.getValue(); + String paramsFile = paramsFiles.get(vcf); + System.out.printf("PARAMS FOR %s is %s%n", vcf, paramsFile); + if ( paramsFile != null ) { + File file = createTempFile("cluster",".vcf"); + WalkerTestSpec spec = new WalkerTestSpec( + "-R " + oneKGLocation + "reference/human_b36_both.fasta" + + " --DBSNP /humgen/gsa-scr1/GATK_Data/dbsnp_129_b36.rod" + + " -T VariantRecalibrator" + + " -B input,VCF," + vcf + + " --ignore_filter GATK_STANDARD" + + " -output /dev/null" + + " -clusterFile " + validationDataLocation + "clusterFile", + 0, + new ArrayList(0)); + + executeTest("testVariantRecalibrator", spec); + } + } + } +}