From 06a212e612d055838175c537e233a3e33e9c2cda Mon Sep 17 00:00:00 2001 From: rpoplin Date: Wed, 24 Mar 2010 19:43:10 +0000 Subject: [PATCH] Adding VariantConcordanceROCCurveWalker to create ROC curves comparing concordance between optimized call sets and validation truth sets in VCF format in order to evaluate performance of variant optimizer independently of achieving a particular novel ti/tv ratio. Added option to ignore only the specified filters in the input call sets via --ignore_filter . Added option to provide a prior estimate of error for known snps via --known_prior . The het and hom calls are clustered independently. Infrastructure in place to use titv of known snps to inform p(true) of novel snps. Tweaked protection against overfitting based on suggestions from several people. Minor edits to AnalyzeAnnotations. git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@3071 348d0f76-0448-11de-a6fe-93d51630548a --- R/plot_variantROCCurve.R | 21 + .../AnalyzeAnnotationsWalker.java | 4 +- .../AnnotationDataManager.java | 6 +- .../variantoptimizer/AnnotationDatum.java | 2 +- .../ApplyVariantClustersWalker.java | 66 ++- .../VariantClusteringModel.java | 2 +- .../VariantConcordanceROCCurveWalker.java | 281 +++++++++++++ .../variantoptimizer/VariantDataManager.java | 12 +- .../variantoptimizer/VariantDatum.java | 1 + .../VariantGaussianMixtureModel.java | 391 +++++++++++++++--- .../VariantNearestNeighborsModel.java | 5 +- .../VariantOptimizationModel.java | 6 + .../variantoptimizer/VariantOptimizer.java | 104 ++--- 13 files changed, 756 insertions(+), 145 deletions(-) create mode 100755 R/plot_variantROCCurve.R create mode 100755 java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantConcordanceROCCurveWalker.java diff --git a/R/plot_variantROCCurve.R b/R/plot_variantROCCurve.R new file mode 100755 index 000000000..1b7ae9292 --- /dev/null +++ b/R/plot_variantROCCurve.R @@ -0,0 +1,21 @@ +#!/broad/tools/apps/R-2.6.0/bin/Rscript + +args <- commandArgs(TRUE) +verbose = TRUE + +input = args[1] + +data = read.table(input,sep=",",head=T) +numCurves = (length(data) - 1)/3 +maxSpec = max(data[,(1:numCurves)*3]) + +outfile = paste(input, ".variantROCCurve.pdf", sep="") +pdf(outfile, height=7, width=7) + +par(cex=1.3) +plot(data$specificity1,data$sensitivity1, type="n", xlim=c(0,maxSpec),ylim=c(0,1),xlab="1 - Specificity",ylab="Sensitivity") +for(iii in 1:numCurves) { + points(data[,iii*3],data[,(iii-1)*3+2],lwd=3,type="l",col=iii) +} +legend("bottomright", names(data)[(0:(numCurves-1))*3+1], col=1:numCurves,lwd=3) +dev.off() \ No newline at end of file diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/AnalyzeAnnotationsWalker.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/AnalyzeAnnotationsWalker.java index 4b3c5f3de..721b93195 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/AnalyzeAnnotationsWalker.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/AnalyzeAnnotationsWalker.java @@ -108,7 +108,7 @@ public class AnalyzeAnnotationsWalker extends RodWalker { // First find out if this variant is in the truth sets boolean isInTruthSet = false; boolean isTrueVariant = false; - for( ReferenceOrderedDatum rod : tracker.getAllRods() ) { + for( final ReferenceOrderedDatum rod : tracker.getAllRods() ) { if( rod != null && rod.getName().toUpperCase().startsWith("TRUTH") ) { isInTruthSet = true; @@ -132,7 +132,7 @@ public class AnalyzeAnnotationsWalker extends RodWalker { } // Add each annotation in this VCF Record to the dataManager - for( ReferenceOrderedDatum rod : tracker.getAllRods() ) { + for( final ReferenceOrderedDatum rod : tracker.getAllRods() ) { if( rod != null && rod instanceof RodVCF && !rod.getName().toUpperCase().startsWith("TRUTH") ) { final RodVCF variant = (RodVCF) rod; if( variant.isSNP() ) { diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/AnnotationDataManager.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/AnnotationDataManager.java index 64d2420d1..8d76b24a8 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/AnnotationDataManager.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/AnnotationDataManager.java @@ -63,7 +63,7 @@ public class AnnotationDataManager { // Loop over each annotation in the vcf record final Map infoField = variant.getInfoValues(); infoField.put("QUAL", ((Double)variant.getQual()).toString() ); // add QUAL field to annotations - for( String annotationKey : infoField.keySet() ) { + for( final String annotationKey : infoField.keySet() ) { float value; try { @@ -102,7 +102,7 @@ public class AnnotationDataManager { System.out.println( "\nFinished reading variants into memory. Executing RScript commands:" ); // For each annotation we've seen - for( String annotationKey : data.keySet() ) { + for( final String annotationKey : data.keySet() ) { PrintStream output; try { @@ -117,7 +117,7 @@ public class AnnotationDataManager { // Bin SNPs and calculate truth metrics for each bin thisAnnotationBin.clearBin(); - for( AnnotationDatum datum : data.get( annotationKey ) ) { + for( final AnnotationDatum datum : data.get( annotationKey ) ) { thisAnnotationBin.combine( datum ); if( thisAnnotationBin.numVariants( AnnotationDatum.FULL_SET ) >= MAX_VARIANTS_PER_BIN ) { // This annotation bin is full output.println( thisAnnotationBin.value + "\t" + thisAnnotationBin.calcTiTv( AnnotationDatum.FULL_SET ) + "\t" + thisAnnotationBin.calcDBsnpRate() + "\t" + thisAnnotationBin.calcTPrate() + diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/AnnotationDatum.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/AnnotationDatum.java index 5fb668ade..fb321b899 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/AnnotationDatum.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/AnnotationDatum.java @@ -59,7 +59,7 @@ public class AnnotationDatum implements Comparator { } } - public AnnotationDatum( float _value ) { + public AnnotationDatum( final float _value ) { value = _value; ti = new int[NUM_SETS]; diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/ApplyVariantClustersWalker.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/ApplyVariantClustersWalker.java index 41fad687d..e1668a67a 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/ApplyVariantClustersWalker.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/ApplyVariantClustersWalker.java @@ -12,10 +12,7 @@ import org.broadinstitute.sting.utils.genotype.vcf.*; import java.io.File; import java.io.IOException; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Set; -import java.util.TreeSet; +import java.util.*; /* * Copyright (c) 2010 The Broad Institute @@ -60,8 +57,12 @@ public class ApplyVariantClustersWalker extends RodWalker ignoreInputFilterSet = null; + //--------------------------------------------------------------------------------------------------------------- // @@ -88,6 +91,10 @@ public class ApplyVariantClustersWalker extends RodWalker(Arrays.asList(IGNORE_INPUT_FILTERS)); + } + switch (OPTIMIZATION_MODEL) { case GAUSSIAN_MIXTURE_MODEL: theModel = new VariantGaussianMixtureModel( CLUSTER_FILENAME ); @@ -101,14 +108,14 @@ public class ApplyVariantClustersWalker extends RodWalker hInfo = new HashSet(); + final Set hInfo = new HashSet(); hInfo.addAll(VCFUtils.getHeaderFields(getToolkit())); hInfo.add(new VCFInfoHeaderLine("OQ", 1, VCFInfoHeaderLine.INFO_TYPE.Float, "The original variant quality score")); hInfo.add(new VCFHeaderLine("source", "VariantOptimizer")); vcfWriter = new VCFWriter( new File(OUTPUT_PREFIX + ".vcf") ); - TreeSet samples = new TreeSet(); + final TreeSet samples = new TreeSet(); SampleUtils.getUniquifiedSamplesFromRods(getToolkit(), samples, new HashMap, String>()); - VCFHeader vcfHeader = new VCFHeader(hInfo, samples); + final VCFHeader vcfHeader = new VCFHeader(hInfo, samples); vcfWriter.writeHeader(vcfHeader); } @@ -126,22 +133,41 @@ public class ApplyVariantClustersWalker extends RodWalker numHom; //vc.getHetCount() > vc.getHomVarCount(); // BUGBUG: what to do here for multi sample calls? + + final double pTrue = theModel.evaluateVariant( rodVCF.getInfoValues(), rodVCF.getQual(), variantDatum.isHet ); + final double recalQual = QualityUtils.phredScaleErrorRate( Math.max( 1.0 - pTrue, 0.000000001) ); + + if( variantDatum.isKnown ) { + variantDatum.qual = 0.5 * recalQual + 0.5 * KNOWN_VAR_QUAL_PRIOR; + } else { + variantDatum.qual = recalQual; + } mapList.add( variantDatum ); + rodVCF.mCurrentRecord.addInfoField("OQ", ((Double)rodVCF.getQual()).toString() ); + rodVCF.mCurrentRecord.setQual( variantDatum.qual ); + rodVCF.mCurrentRecord.setFilterString(VCFRecord.UNFILTERED); + vcfWriter.addRecord( rodVCF.mCurrentRecord ); + + } else { // not a SNP or is filtered so just dump it out to the VCF file vcfWriter.addRecord( rodVCF.mCurrentRecord ); } @@ -169,6 +195,8 @@ public class ApplyVariantClustersWalker extends RodWalker reduceSum ) { + vcfWriter.close(); + final VariantDataManager dataManager = new VariantDataManager( reduceSum, theModel.dataManager.annotationKeys ); reduceSum.clear(); // Don't need this ever again, clean up some memory diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantClusteringModel.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantClusteringModel.java index 1f42a61de..891f21443 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantClusteringModel.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantClusteringModel.java @@ -32,6 +32,6 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantoptimizer; */ public interface VariantClusteringModel extends VariantOptimizationInterface { - public void createClusters( final VariantDatum[] data ); + public void createClusters( final VariantDatum[] data, final int startCluster, final int stopCluster ); //public void applyClusters( final VariantDatum[] data, final String outputPrefix ); } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantConcordanceROCCurveWalker.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantConcordanceROCCurveWalker.java new file mode 100755 index 000000000..652af1936 --- /dev/null +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantConcordanceROCCurveWalker.java @@ -0,0 +1,281 @@ +package org.broadinstitute.sting.playground.gatk.walkers.variantoptimizer; + +import org.broadinstitute.sting.gatk.contexts.AlignmentContext; +import org.broadinstitute.sting.gatk.contexts.ReferenceContext; +import org.broadinstitute.sting.gatk.contexts.variantcontext.VariantContext; +import org.broadinstitute.sting.gatk.datasources.simpleDataSources.ReferenceOrderedDataSource; +import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker; +import org.broadinstitute.sting.gatk.refdata.RodVCF; +import org.broadinstitute.sting.gatk.walkers.RodWalker; +import org.broadinstitute.sting.utils.ExpandingArrayList; +import org.broadinstitute.sting.utils.Pair; +import org.broadinstitute.sting.utils.StingException; +import org.broadinstitute.sting.utils.cmdLine.Argument; + +import java.io.IOException; +import java.io.PrintStream; +import java.util.HashMap; + +/* + * Copyright (c) 2010 The Broad Institute + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, + * copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ + +/** + * Calculates variant concordance with the given truth sets and plots ROC curves for each input call set. + * + * @author rpoplin + * @since Mar 18, 2010 + * + * @help.summary Calculates variant concordance with the given truth sets and plots ROC curves for each input call set. + */ + +public class VariantConcordanceROCCurveWalker extends RodWalker>, HashMap>> { + + ///////////////////////////// + // Command Line Arguments + ///////////////////////////// + @Argument(fullName="output_prefix", shortName="output", doc="The prefix added to output VCF file name and optimization curve pdf file name", required=false) + private String OUTPUT_PREFIX = "optimizer"; + @Argument(fullName = "path_to_Rscript", shortName = "Rscript", doc = "The path to your implementation of Rscript. For Broad users this is probably /broad/tools/apps/R-2.6.0/bin/Rscript", required = false) + private String PATH_TO_RSCRIPT = "/broad/tools/apps/R-2.6.0/bin/Rscript"; + @Argument(fullName = "path_to_resources", shortName = "resources", doc = "Path to resources folder holding the Sting R scripts.", required = false) + private String PATH_TO_RESOURCES = "R/"; + + ///////////////////////////// + // Private Member Variables + ///////////////////////////// + private final ExpandingArrayList inputRodNames = new ExpandingArrayList(); + private int numCurves; + private int[] trueNegGlobal; + private int[] falseNegGlobal; + private String sampleName = null; + + //--------------------------------------------------------------------------------------------------------------- + // + // initialize + // + //--------------------------------------------------------------------------------------------------------------- + + public void initialize() { + if( !PATH_TO_RESOURCES.endsWith("/") ) { PATH_TO_RESOURCES = PATH_TO_RESOURCES + "/"; } + + for( ReferenceOrderedDataSource rod : this.getToolkit().getRodDataSources() ) { + if( rod != null && !rod.getName().toUpperCase().startsWith("TRUTH") ) { + if( rod.getReferenceOrderedData().iterator().next().get(0) instanceof RodVCF ) { + inputRodNames.add(rod.getName()); + if( sampleName == null ) { + sampleName = ((RodVCF)rod.getReferenceOrderedData().iterator().next().get(0)).getSampleNames()[0]; // BUGBUG: single sample calls only for now + } + } + } + } + numCurves = inputRodNames.size(); + trueNegGlobal = new int[numCurves]; + falseNegGlobal = new int[numCurves]; + for( int kkk = 0; kkk < numCurves; kkk++ ) { + trueNegGlobal[kkk] = 0; + falseNegGlobal[kkk] = 0; + } + } + + //--------------------------------------------------------------------------------------------------------------- + // + // map + // + //--------------------------------------------------------------------------------------------------------------- + + public ExpandingArrayList> map( RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context ) { + + final ExpandingArrayList> mapList = new ExpandingArrayList>(); + + if( tracker == null ) { // For some reason RodWalkers get map calls with null trackers + return mapList; + } + + boolean isInTruthSet = false; + boolean isTrueVariant = false; + + for( final VariantContext vc : tracker.getAllVariantContexts(null, context.getLocation(), false, false) ) { + if( vc != null && vc.getName().toUpperCase().startsWith("TRUTH") ) { + if( !vc.getGenotype(sampleName).isNoCall() ) { + isInTruthSet = true; + + if( !vc.getGenotype(sampleName).isHomRef() ) { + isTrueVariant = true; + } + } + //if( vc.isPolymorphic() ) { //BUGBUG: I don't think this is the right thing to do here, there are many polymorphic sites in the truth data because there are many samples + // isTrueVariant = true; + //} + } + } + + if( !isInTruthSet ) { + return mapList; // jump out if this location isn't in the truth set + } + + int curveIndex = 0; + for( final String inputName : inputRodNames ) { + final VariantContext vc = tracker.getVariantContext( inputName, null, context.getLocation(), false ); // assuming single variant per track per location + + if( vc != null && vc.isPolymorphic() && !vc.isFiltered() ) { + final VariantDatum variantDatum = new VariantDatum(); + variantDatum.qual = vc.getPhredScaledQual(); + variantDatum.isTrueVariant = isTrueVariant; + mapList.add( new Pair(inputName, variantDatum) ); + } else { // Either not in the call set at all, is a monomorphic call, or was filtered out + if( isTrueVariant ) { // ... but it is a true variant so this is a false negative call + falseNegGlobal[curveIndex]++; + } else { // ... and it is not a variant site so this is a true negative call + trueNegGlobal[curveIndex]++; + } + } + curveIndex++; + } + + return mapList; + } + + //--------------------------------------------------------------------------------------------------------------- + // + // reduce + // + //--------------------------------------------------------------------------------------------------------------- + + public HashMap> reduceInit() { + final HashMap> init = new HashMap>(); + for( final String inputName : inputRodNames ) { + init.put( inputName, new ExpandingArrayList() ); + } + return init; + } + + public HashMap> reduce( final ExpandingArrayList> mapValue, final HashMap> reduceSum ) { + for( Pair value : mapValue ) { + final ExpandingArrayList list = reduceSum.get(value.getFirst()); + list.add(value.getSecond()); + reduceSum.put(value.getFirst(),list); + } + return reduceSum; + } + + public void onTraversalDone( HashMap> reduceSum ) { + + final int NUM_CURVES = numCurves; + final HashMap dataManagerMap = new HashMap(); + for( final String inputName : inputRodNames ) { + System.out.println("Creating data manager for: " + inputName); + dataManagerMap.put(inputName, new VariantDataManager( reduceSum.get(inputName), null )); + } + reduceSum.clear(); // Don't need this ever again, clean up some memory + + final double[] minQual = new double[NUM_CURVES]; + final double[] maxQual = new double[NUM_CURVES]; + final double[] incrementQual = new double[NUM_CURVES]; + final double[] qualCut = new double[NUM_CURVES]; + + final int NUM_STEPS = 200; + + int curveIndex = 0; + for( final String inputName : inputRodNames ) { + final VariantDataManager dataManager = dataManagerMap.get(inputName); + minQual[curveIndex] = dataManager.data[0].qual; + maxQual[curveIndex] = dataManager.data[0].qual; + for( int iii = 1; iii < dataManager.data.length; iii++ ) { + final double qual = dataManager.data[iii].qual; + if( qual < minQual[curveIndex] ) { minQual[curveIndex] = qual; } + else if( qual > maxQual[curveIndex] ) { maxQual[curveIndex] = qual; } + } + incrementQual[curveIndex] = (maxQual[curveIndex] - minQual[curveIndex]) / ((double)NUM_STEPS); + qualCut[curveIndex] = minQual[curveIndex]; + curveIndex++; + } + + final int[] truePos = new int[NUM_CURVES]; + final int[] falsePos = new int[NUM_CURVES]; + final int[] trueNeg = new int[NUM_CURVES]; + final int[] falseNeg = new int[NUM_CURVES]; + + PrintStream outputFile; + try { + outputFile = new PrintStream( OUTPUT_PREFIX + ".dat" ); + } catch (Exception e) { + throw new StingException( "Unable to create output file: " + OUTPUT_PREFIX + ".dat" ); + } + int jjj = 1; + for( final String inputName : inputRodNames ) { + outputFile.print(inputName + ",sensitivity" + jjj + ",specificity" + jjj + ","); + jjj++; + } + outputFile.println("sentinel"); + + for( int step = 0; step < NUM_STEPS; step++ ) { + curveIndex = 0; + for( final String inputName : inputRodNames ) { + final VariantDataManager dataManager = dataManagerMap.get(inputName); + truePos[curveIndex] = 0; + falsePos[curveIndex] = 0; + trueNeg[curveIndex] = 0; + falseNeg[curveIndex] = 0; + final int NUM_VARIANTS = dataManager.data.length; + for( int iii = 0; iii < NUM_VARIANTS; iii++ ) { + if( dataManager.data[iii].qual >= qualCut[curveIndex] ) { // this var is in this hypothetical call set + if( dataManager.data[iii].isTrueVariant ) { + truePos[curveIndex]++; + } else { + falsePos[curveIndex]++; + } + } else { // this var is out of this hypothetical call set + if( dataManager.data[iii].isTrueVariant ) { + falseNeg[curveIndex]++; + } else { + trueNeg[curveIndex]++; + } + } + } + final double sensitivity = ((double) truePos[curveIndex]) / ((double) truePos[curveIndex] + falseNegGlobal[curveIndex] + falseNeg[curveIndex]); + final double specificity = ((double) trueNegGlobal[curveIndex] + trueNeg[curveIndex]) / + ((double) falsePos[curveIndex] + trueNegGlobal[curveIndex] + trueNeg[curveIndex]); + outputFile.print( String.format("%.4f,%.4f,%.4f,", qualCut[curveIndex], sensitivity, 1.0 - specificity) ); + qualCut[curveIndex] += incrementQual[curveIndex]; + curveIndex++; + } + outputFile.println("-1"); + } + + outputFile.close(); + + // Execute Rscript command to plot the optimization curve + // Print out the command line to make it clear to the user what is being executed and how one might modify it + final String rScriptCommandLine = PATH_TO_RSCRIPT + " " + PATH_TO_RESOURCES + "plot_variantROCCurve.R" + " " + OUTPUT_PREFIX + ".dat"; + System.out.println( rScriptCommandLine ); + + // Execute the RScript command to plot the table of truth values + try { + Runtime.getRuntime().exec( rScriptCommandLine ); + } catch ( IOException e ) { + throw new StingException( "Unable to execute RScript command: " + rScriptCommandLine ); + } + } +} diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDataManager.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDataManager.java index 634987a0d..775d9a521 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDataManager.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDataManager.java @@ -77,8 +77,8 @@ public class VariantDataManager { annotationKeys = new ExpandingArrayList(); int jjj = 0; - for( String line : annotationLines ) { - String[] vals = line.split(","); + for( final String line : annotationLines ) { + final String[] vals = line.split(","); annotationKeys.add(vals[1]); meanVector[jjj] = Double.parseDouble(vals[2]); varianceVector[jjj] = Double.parseDouble(vals[3]); @@ -87,12 +87,15 @@ public class VariantDataManager { } public void normalizeData() { + boolean foundZeroVarianceAnnotation = false; for( int jjj = 0; jjj < numAnnotations; jjj++ ) { final double theMean = mean(data, jjj); final double theSTD = standardDeviation(data, theMean, jjj); System.out.println( (jjj == numAnnotations-1 ? "QUAL" : annotationKeys.get(jjj)) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD) ); if( theSTD < 1E-8 ) { - throw new StingException("Zero variance is a problem: standard deviation = " + theSTD); + foundZeroVarianceAnnotation = true; + System.out.println("Zero variance is a problem: standard deviation = " + theSTD); + System.out.println("User must -exclude annotations with zero variance. Annotation = " + (jjj == numAnnotations-1 ? "QUAL" : annotationKeys.get(jjj))); } meanVector[jjj] = theMean; varianceVector[jjj] = theSTD; @@ -101,6 +104,9 @@ public class VariantDataManager { } } isNormalized = true; // Each data point is now [ (x - mean) / standard deviation ] + if( foundZeroVarianceAnnotation ) { + throw new StingException("Found annotations with zero variance. They must be excluded before proceeding."); + } } private static double mean( final VariantDatum[] data, final int index ) { diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDatum.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDatum.java index 469fa90e5..663304d31 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDatum.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDatum.java @@ -36,5 +36,6 @@ public class VariantDatum { public boolean isTransition; public boolean isKnown; public boolean isTrueVariant; + public boolean isHet; public double qual; } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantGaussianMixtureModel.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantGaussianMixtureModel.java index 79d487db4..3ef776c45 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantGaussianMixtureModel.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantGaussianMixtureModel.java @@ -55,6 +55,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel private final double[][] mu; // The means for the clusters private final double[][] sigma; // The variances for the clusters, sigma is really sigma^2 private final double[] pCluster; + private final boolean[] isHetCluster; private final int[] numMaxClusterKnown; private final int[] numMaxClusterNovel; private final double[] clusterTITV; @@ -68,12 +69,13 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel public VariantGaussianMixtureModel( final VariantDataManager _dataManager, final double _targetTITV, final int _numGaussians, final int _numIterations, final int _minVarInCluster ) { super( _targetTITV ); dataManager = _dataManager; - numGaussians = _numGaussians; + numGaussians = ( _numGaussians % 2 == 0 ? _numGaussians : _numGaussians + 1 ); numIterations = _numIterations; mu = new double[numGaussians][]; sigma = new double[numGaussians][]; pCluster = new double[numGaussians]; + isHetCluster = null; numMaxClusterKnown = new int[numGaussians]; numMaxClusterNovel = new int[numGaussians]; clusterTITV = new double[numGaussians]; @@ -83,8 +85,8 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel public VariantGaussianMixtureModel( final String clusterFileName ) { super( 0.0 ); - ExpandingArrayList annotationLines = new ExpandingArrayList(); - ExpandingArrayList clusterLines = new ExpandingArrayList(); + final ExpandingArrayList annotationLines = new ExpandingArrayList(); + final ExpandingArrayList clusterLines = new ExpandingArrayList(); try { for ( String line : new xReadLines(new File( clusterFileName )) ) { @@ -102,24 +104,28 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel dataManager = new VariantDataManager( annotationLines ); numIterations = 0; - pCluster = null; numMaxClusterKnown = null; numMaxClusterNovel = null; clusterTITV = null; minVarInCluster = 0; + //BUGBUG: move this parsing out of the constructor numGaussians = clusterLines.size(); mu = new double[numGaussians][dataManager.numAnnotations]; sigma = new double[numGaussians][dataManager.numAnnotations]; + pCluster = new double[numGaussians]; + isHetCluster = new boolean[numGaussians]; clusterTruePositiveRate = new double[numGaussians]; int kkk = 0; for( String line : clusterLines ) { - String[] vals = line.split(","); - clusterTruePositiveRate[kkk] = Double.parseDouble(vals[4]); + final String[] vals = line.split(","); + isHetCluster[kkk] = Integer.parseInt(vals[1]) == 1; + pCluster[kkk] = Double.parseDouble(vals[2]); + clusterTruePositiveRate[kkk] = Double.parseDouble(vals[6]); //BUGBUG: #define these magic index numbers, very easy to make a mistake here for( int jjj = 0; jjj < dataManager.numAnnotations; jjj++ ) { - mu[kkk][jjj] = Double.parseDouble(vals[5+jjj]); - sigma[kkk][jjj] = Double.parseDouble(vals[5+dataManager.numAnnotations+jjj]); + mu[kkk][jjj] = Double.parseDouble(vals[7+jjj]); + sigma[kkk][jjj] = Double.parseDouble(vals[7+dataManager.numAnnotations+jjj]) * 3; //BUGBUG: *1.3, suggestion by Nick to prevent GMM from over fitting and producing low likelihoods for most points } kkk++; } @@ -129,13 +135,28 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel public final void run( final String clusterFileName ) { + final int MAX_VARS = 1000000; //BUGBUG: make this a command line argument + // Create the subset of the data to cluster with int numNovel = 0; + int numKnown = 0; + int numHet = 0; + int numHom = 0; for( final VariantDatum datum : dataManager.data ) { - if( !datum.isKnown ) { + if( datum.isKnown ) { + numKnown++; + } else { numNovel++; } + if( datum.isHet ) { + numHet++; + } else { + numHom++; + } } + + // This block of code is used to cluster with novels + 1.5x knowns mixed together + /* VariantDatum[] data; // Grab a set of data that is all of the novel variants plus 1.5x as many known variants drawn at random @@ -167,9 +188,235 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel printClusters( clusterFileName ); //System.out.println("Applying clusters to all variants..."); //applyClusters( dataManager.data, outputPrefix ); // Using all the data + */ + + + + + // This block of code is to cluster knowns and novels separately + /* + final VariantDatum[] dataNovel = new VariantDatum[Math.min(numNovel,MAX_VARS)]; + final VariantDatum[] dataKnown = new VariantDatum[Math.min(numKnown,MAX_VARS)]; + + //BUGBUG: This is ugly + int jjj = 0; + if(numNovel <= MAX_VARS) { + for( final VariantDatum datum : dataManager.data ) { + if( !datum.isKnown ) { + dataNovel[jjj++] = datum; + } + } + } else { + System.out.println("Capped at " + MAX_VARS + " novel variants."); + while( jjj < MAX_VARS ) { + final VariantDatum datum = dataManager.data[rand.nextInt(dataManager.numVariants)]; + if( !datum.isKnown ) { + dataNovel[jjj++] = datum; + } + } + } + + int iii = 0; + if(numKnown <= MAX_VARS) { + for( final VariantDatum datum : dataManager.data ) { + if( datum.isKnown ) { + dataKnown[iii++] = datum; + } + } + } else { + System.out.println("Capped at " + MAX_VARS + " known variants."); + while( iii < MAX_VARS ) { + final VariantDatum datum = dataManager.data[rand.nextInt(dataManager.numVariants)]; + if( datum.isKnown ) { + dataKnown[iii++] = datum; + } + } + } + + + System.out.println("Clustering with " + Math.min(numNovel,MAX_VARS) + " novel variants."); + createClusters( dataNovel, 0, numGaussians / 2 ); + System.out.println("Clustering with " + Math.min(numKnown,MAX_VARS) + " known variants."); + createClusters( dataKnown, numGaussians / 2, numGaussians ); + System.out.println("Outputting cluster parameters..."); + printClusters( clusterFileName ); + + */ + + // This block of code is to cluster het and hom calls separately, but mixing together knowns and novels + final VariantDatum[] dataHet = new VariantDatum[Math.min(numHet,MAX_VARS)]; + final VariantDatum[] dataHom = new VariantDatum[Math.min(numHom,MAX_VARS)]; + + //BUGBUG: This is ugly + int jjj = 0; + if(numHet <= MAX_VARS) { + for( final VariantDatum datum : dataManager.data ) { + if( datum.isHet ) { + dataHet[jjj++] = datum; + } + } + } else { + System.out.println("Found " + numHet + " het variants but capped at clustering with " + MAX_VARS + "."); + while( jjj < MAX_VARS ) { + final VariantDatum datum = dataManager.data[rand.nextInt(dataManager.numVariants)]; + if( datum.isHet ) { + dataHet[jjj++] = datum; + } + } + } + + int iii = 0; + if(numHom <= MAX_VARS) { + for( final VariantDatum datum : dataManager.data ) { + if( !datum.isHet ) { + dataHom[iii++] = datum; + } + } + } else { + System.out.println("Found " + numHom + " hom variants but capped at clustering with " + MAX_VARS + "."); + while( iii < MAX_VARS ) { + final VariantDatum datum = dataManager.data[rand.nextInt(dataManager.numVariants)]; + if( !datum.isHet ) { + dataHom[iii++] = datum; + } + } + } + + System.out.println("Clustering with " + Math.min(numHet,MAX_VARS) + " het variants."); + createClusters( dataHet, 0, numGaussians / 2 ); + System.out.println("Clustering with " + Math.min(numHom,MAX_VARS) + " hom variants."); + createClusters( dataHom, numGaussians / 2, numGaussians ); + System.out.println("Outputting cluster parameters..."); + printClusters( clusterFileName ); + } - public final void createClusters( final VariantDatum[] data ) { + + + public final void createClusters( final VariantDatum[] data, int startCluster, int stopCluster ) { + + final int numVariants = data.length; + final int numAnnotations = data[0].annotations.length; + + final double[][] pVarInCluster = new double[numGaussians][numVariants]; + final double[] probTi = new double[numGaussians]; + final double[] probTv = new double[numGaussians]; + final double[] probKnown = new double[numGaussians]; + final double[] probNovel = new double[numGaussians]; + final double[] probKnownTi = new double[numGaussians]; + final double[] probKnownTv = new double[numGaussians]; + final double[] probNovelTi = new double[numGaussians]; + final double[] probNovelTv = new double[numGaussians]; + + // loop control variables: + // iii - loop over data points + // jjj - loop over annotations (features) + // kkk - loop over clusters + // ttt - loop over EM iterations + + // Set up the initial random Gaussians + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + numMaxClusterKnown[kkk] = 0; + numMaxClusterNovel[kkk] = 0; + pCluster[kkk] = 1.0 / ((double) (stopCluster - startCluster)); + mu[kkk] = data[rand.nextInt(numVariants)].annotations; + final double[] randSigma = new double[numAnnotations]; + if( dataManager.isNormalized ) { + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + randSigma[jjj] = 0.7 + 0.4 * rand.nextDouble(); + } + } else { // BUGBUG: if not normalized then the varianceVector hasn't been calculated --> null pointer + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + randSigma[jjj] = dataManager.varianceVector[jjj] + ((1.0 + rand.nextDouble()) * 0.01 * dataManager.varianceVector[jjj]); + } + } + sigma[kkk] = randSigma; + } + + for( int ttt = 0; ttt < numIterations; ttt++ ) { + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Expectation Step (calculate the probability that each data point is in each cluster) + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + evaluateGaussians( data, pVarInCluster, startCluster, stopCluster ); + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Maximization Step (move the clusters to maximize the sum probability of each data point) + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + maximizeGaussians( data, pVarInCluster, startCluster, stopCluster ); + + System.out.println("Finished iteration " + (ttt+1) ); + } + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Evaluate the clusters using titv as an estimate of the true positive rate + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + evaluateGaussians( data, pVarInCluster, startCluster, stopCluster ); // One final evaluation because the Gaussians moved in the last maximization step + + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + probTi[kkk] = 0.0; + probTv[kkk] = 0.0; + probKnown[kkk] = 0.0; + probNovel[kkk] = 0.0; + } + for( int iii = 0; iii < numVariants; iii++ ) { + final boolean isTransition = data[iii].isTransition; + final boolean isKnown = data[iii].isKnown; + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + final double prob = pVarInCluster[kkk][iii]; + if( isKnown ) { // known + probKnown[kkk] += prob; + if( isTransition ) { // transition + probKnownTi[kkk] += prob; + probTi[kkk] += prob; + } else { // transversion + probKnownTv[kkk] += prob; + probTv[kkk] += prob; + } + } else { //novel + probNovel[kkk] += prob; + if( isTransition ) { // transition + probNovelTi[kkk] += prob; + probTi[kkk] += prob; + } else { // transversion + probNovelTv[kkk] += prob; + probTv[kkk] += prob; + } + } + + } + + double maxProb = pVarInCluster[startCluster][iii]; + int maxCluster = startCluster; + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + if( pVarInCluster[kkk][iii] > maxProb ) { + maxProb = pVarInCluster[kkk][iii]; + maxCluster = kkk; + } + } + if( isKnown ) { + numMaxClusterKnown[maxCluster]++; + } else { + numMaxClusterNovel[maxCluster]++; + } + } + + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + clusterTITV[kkk] = probTi[kkk] / probTv[kkk]; + if( probKnown[kkk] > 2000.0 ) { // BUGBUG: make this a command line argument, parameterize performance based on this important argument + clusterTruePositiveRate[kkk] = calcTruePositiveRateFromKnownTITV( probKnownTi[kkk] / probKnownTv[kkk], probNovelTi[kkk] / probNovelTv[kkk] ); + } else { + clusterTruePositiveRate[kkk] = calcTruePositiveRateFromTITV( clusterTITV[kkk] ); + } + } + } + + + + // This cluster method doesn't make use of the differences between known and novel Ti/Tv ratios + + /* + public final void createClusters( final VariantDatum[] data, final int startCluster, final int stopCluster ) { final int numVariants = data.length; final int numAnnotations = data[0].annotations.length; @@ -185,15 +432,15 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel // ttt - loop over EM iterations // Set up the initial random Gaussians - for( int kkk = 0; kkk < numGaussians; kkk++ ) { + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { numMaxClusterKnown[kkk] = 0; numMaxClusterNovel[kkk] = 0; - pCluster[kkk] = 1.0 / ((double) numGaussians); + pCluster[kkk] = 1.0 / ((double) (stopCluster - startCluster)); mu[kkk] = data[rand.nextInt(numVariants)].annotations; final double[] randSigma = new double[numAnnotations]; if( dataManager.isNormalized ) { for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - randSigma[jjj] = 0.75 + 0.4 * rand.nextDouble(); + randSigma[jjj] = 0.7 + 0.4 * rand.nextDouble(); } } else { // BUGBUG: if not normalized then the varianceVector hasn't been calculated --> null pointer for( int jjj = 0; jjj < numAnnotations; jjj++ ) { @@ -209,12 +456,12 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// // Expectation Step (calculate the probability that each data point is in each cluster) ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - evaluateGaussians( data, pVarInCluster ); + evaluateGaussians( data, pVarInCluster, startCluster, stopCluster ); ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// // Maximization Step (move the clusters to maximize the sum probability of each data point) ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - maximizeGaussians( data, pVarInCluster ); + maximizeGaussians( data, pVarInCluster, startCluster, stopCluster ); System.out.println("Finished iteration " + (ttt+1) ); } @@ -222,28 +469,28 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// // Evaluate the clusters using titv as an estimate of the true positive rate ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - evaluateGaussians( data, pVarInCluster ); // One final evaluation because the Gaussians moved in the last maximization step + evaluateGaussians( data, pVarInCluster, startCluster, stopCluster ); // One final evaluation because the Gaussians moved in the last maximization step - for( int kkk = 0; kkk < numGaussians; kkk++ ) { + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { probTi[kkk] = 0.0; probTv[kkk] = 0.0; } // Use the cluster's probabilistic Ti/Tv ratio as the indication of the cluster's true positive rate for( int iii = 0; iii < numVariants; iii++ ) { if( data[iii].isTransition ) { // transition - for( int kkk = 0; kkk < numGaussians; kkk++ ) { + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { probTi[kkk] += pVarInCluster[kkk][iii]; } } else { // transversion - for( int kkk = 0; kkk < numGaussians; kkk++ ) { + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { probTv[kkk] += pVarInCluster[kkk][iii]; } } // Calculate which cluster has the maximum probability for this variant for use as a metric of how well clustered the data is - double maxProb = pVarInCluster[0][iii]; - int maxCluster = 0; - for( int kkk = 1; kkk < numGaussians; kkk++ ) { + double maxProb = pVarInCluster[startCluster][iii]; + int maxCluster = startCluster; + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { if( pVarInCluster[kkk][iii] > maxProb ) { maxProb = pVarInCluster[kkk][iii]; maxCluster = kkk; @@ -255,11 +502,13 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel numMaxClusterNovel[maxCluster]++; } } - for( int kkk = 0; kkk < numGaussians; kkk++ ) { + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { clusterTITV[kkk] = probTi[kkk] / probTv[kkk]; clusterTruePositiveRate[kkk] = calcTruePositiveRateFromTITV( clusterTITV[kkk] ); } } + */ + private void printClusters( final String clusterFileName ) { try { @@ -269,6 +518,8 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel for( int kkk = 0; kkk < numGaussians; kkk++ ) { if( numMaxClusterKnown[kkk] + numMaxClusterNovel[kkk] >= minVarInCluster ) { outputFile.print("@!CLUSTER,"); + outputFile.print( (kkk < numGaussians / 2 ? 1 : 0) + "," ); // is het cluster? + outputFile.print(pCluster[kkk] + ","); outputFile.print(numMaxClusterKnown[kkk] + ","); outputFile.print(numMaxClusterNovel[kkk] + ","); outputFile.print(clusterTITV[kkk] + ","); @@ -284,12 +535,11 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } outputFile.close(); } catch (Exception e) { - e.printStackTrace(); - System.exit(-1); + throw new StingException( "Unable to create output file: " + clusterFileName ); } } - public final double evaluateVariant( final Map annotationMap, final double qualityScore ) { + public final double evaluateVariant( final Map annotationMap, final double qualityScore, final boolean isHet ) { final double[] pVarInCluster = new double[numGaussians]; final double[] annotations = new double[dataManager.numAnnotations]; @@ -314,11 +564,13 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel annotations[jjj] = (value - dataManager.meanVector[jjj]) / dataManager.varianceVector[jjj]; } - evaluateGaussiansForSingleVariant( annotations, pVarInCluster ); + evaluateGaussiansForSingleVariant( annotations, pVarInCluster, isHet ); - double sum = 0; + double sum = 0.0; for( int kkk = 0; kkk < numGaussians; kkk++ ) { - sum += pVarInCluster[kkk] * clusterTruePositiveRate[kkk]; + if( isHetCluster[kkk] == isHet ) { + sum += pVarInCluster[kkk] * clusterTruePositiveRate[kkk]; + } } return sum; @@ -333,12 +585,11 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel markedVariant[iii] = false; } - PrintStream outputFile = null; + PrintStream outputFile; try { outputFile = new PrintStream( outputPrefix + ".dat" ); } catch (Exception e) { - e.printStackTrace(); - System.exit(-1); + throw new StingException( "Unable to create output file: " + outputPrefix + ".dat" ); } int numKnown = 0; @@ -382,21 +633,21 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel foundDesiredNumVariants = true; } outputFile.println( pCut + "," + numKnown + "," + numNovel + "," + - ( numKnownTi ==0 || numKnownTv == 0 ? "NaN" : ( ((double)numKnownTi) / ((double)numKnownTv) ) ) + "," + - ( numNovelTi ==0 || numNovelTv == 0 ? "NaN" : ( ((double)numNovelTi) / ((double)numNovelTv) ))); + ( numKnownTi == 0 || numKnownTv == 0 ? "NaN" : ( ((double)numKnownTi) / ((double)numKnownTv) ) ) + "," + + ( numNovelTi == 0 || numNovelTv == 0 ? "NaN" : ( ((double)numNovelTi) / ((double)numNovelTv) ))); } outputFile.close(); } - private void evaluateGaussians( final VariantDatum[] data, final double[][] pVarInCluster ) { + private void evaluateGaussians( final VariantDatum[] data, final double[][] pVarInCluster, final int startCluster, final int stopCluster ) { final int numAnnotations = data[0].annotations.length; for( int iii = 0; iii < data.length; iii++ ) { double sumProb = 0.0; - for( int kkk = 0; kkk < numGaussians; kkk++ ) { + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { double sum = 0.0; for( int jjj = 0; jjj < numAnnotations; jjj++ ) { sum += ( (data[iii].annotations[jjj] - mu[kkk][jjj]) * (data[iii].annotations[jjj] - mu[kkk][jjj]) ) @@ -412,7 +663,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } if( sumProb > MIN_SUM_PROB ) { // Very small numbers are a very big problem - for( int kkk = 0; kkk < numGaussians; kkk++ ) { + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { pVarInCluster[kkk][iii] /= sumProb; } } @@ -420,48 +671,53 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } - private void evaluateGaussiansForSingleVariant( final double[] annotations, final double[] pVarInCluster ) { + private void evaluateGaussiansForSingleVariant( final double[] annotations, final double[] pVarInCluster, final boolean isHet ) { final int numAnnotations = annotations.length; double sumProb = 0.0; for( int kkk = 0; kkk < numGaussians; kkk++ ) { - double sum = 0.0; - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - sum += ( (annotations[jjj] - mu[kkk][jjj]) * (annotations[jjj] - mu[kkk][jjj]) ) - / sigma[kkk][jjj]; - } - //BUGBUG: removed pCluster[kkk]*, for the second pass - //pVarInCluster[kkk] = pCluster[kkk] * Math.exp( -0.5 * sum ); - pVarInCluster[kkk] = Math.exp( -0.5 * sum ); + if( isHetCluster[kkk] == isHet ) { + double sum = 0.0; + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + sum += ( (annotations[jjj] - mu[kkk][jjj]) * (annotations[jjj] - mu[kkk][jjj]) ) + / sigma[kkk][jjj]; + } - if( pVarInCluster[kkk] < MIN_PROB) { // Very small numbers are a very big problem - pVarInCluster[kkk] = MIN_PROB; - } + //BUGBUG: pCluster[kkk] should be removed here? + pVarInCluster[kkk] = pCluster[kkk] * Math.exp( -0.5 * sum ); + //pVarInCluster[kkk] = Math.exp( -0.5 * sum ); - sumProb += pVarInCluster[kkk]; + if( pVarInCluster[kkk] < MIN_PROB) { // Very small numbers are a very big problem + pVarInCluster[kkk] = MIN_PROB; + } + + sumProb += pVarInCluster[kkk]; + } } if( sumProb > MIN_SUM_PROB ) { // Very small numbers are a very big problem for( int kkk = 0; kkk < numGaussians; kkk++ ) { - pVarInCluster[kkk] /= sumProb; + if( isHetCluster[kkk] == isHet ) { + pVarInCluster[kkk] /= sumProb; + } } } } - private void maximizeGaussians( final VariantDatum[] data, final double[][] pVarInCluster ) { + private void maximizeGaussians( final VariantDatum[] data, final double[][] pVarInCluster, final int startCluster, final int stopCluster ) { final int numVariants = data.length; final int numAnnotations = data[0].annotations.length; - for( int kkk = 0; kkk < numGaussians; kkk++ ) { + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { for( int jjj = 0; jjj < numAnnotations; jjj++ ) { mu[kkk][jjj] = 0.0; sigma[kkk][jjj] = 0.0; } } - for( int kkk = 0; kkk < numGaussians; kkk++ ) { + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { double sumProb = 0.0; for( int iii = 0; iii < numVariants; iii++ ) { final double prob = pVarInCluster[kkk][iii]; @@ -475,7 +731,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel for( int jjj = 0; jjj < numAnnotations; jjj++ ) { mu[kkk][jjj] /= sumProb; } - } + } //BUGBUG: clean up dead clusters to speed up computation for( int iii = 0; iii < numVariants; iii++ ) { final double prob = pVarInCluster[kkk][iii]; @@ -499,9 +755,10 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel pCluster[kkk] = sumProb / numVariants; } + // Clean up extra big or extra small clusters //BUGBUG: Is this a good idea? - for( int kkk = 0; kkk < numGaussians; kkk++ ) { + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { if( pCluster[kkk] > 0.45 ) { // This is a very large cluster compared to all the others final int numToReplace = 4; final double[] savedSigma = sigma[kkk]; @@ -513,7 +770,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel randVarIndex = rand.nextInt( numVariants ); final double probK = pVarInCluster[kkk][randVarIndex]; boolean inClusterK = true; - for( int ccc = 0; ccc < numGaussians; ccc++ ) { + for( int ccc = startCluster; ccc < stopCluster; ccc++ ) { if( pVarInCluster[ccc][randVarIndex] > probK ) { inClusterK = false; break; @@ -525,11 +782,11 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel // Find a place to put the example variant if( rrr == 0 ) { // Replace the big cluster that kicked this process off mu[kkk] = data[randVarIndex].annotations; - pCluster[kkk] = 1.0 / ((double) numGaussians); + pCluster[kkk] = 1.0 / ((double) (stopCluster-startCluster)); } else { // Replace the cluster with the minimum prob - double minProb = pCluster[0]; - int minClusterIndex = 0; - for( int ccc = 1; ccc < numGaussians; ccc++ ) { + double minProb = pCluster[startCluster]; + int minClusterIndex = startCluster; + for( int ccc = startCluster; ccc < stopCluster; ccc++ ) { if( pCluster[ccc] < minProb ) { minProb = pCluster[ccc]; minClusterIndex = ccc; @@ -543,21 +800,22 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel sigma[minClusterIndex][jjj] = MIN_SUM_PROB; } } - pCluster[minClusterIndex] = 1.0 / ((double) numGaussians); + pCluster[minClusterIndex] = 1.0 / ((double) (stopCluster-startCluster)); } } } } + // Replace small clusters with another random draw from the dataset - for( int kkk = 0; kkk < numGaussians; kkk++ ) { - if( pCluster[kkk] < 0.05 * (1.0 / ((double) numGaussians)) ) { // This is a very small cluster compared to all the others - pCluster[kkk] = 1.0 / ((double) numGaussians); + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + if( pCluster[kkk] < 0.07 * (1.0 / ((double) (stopCluster-startCluster))) ) { // This is a very small cluster compared to all the others + pCluster[kkk] = 1.0 / ((double) (stopCluster-startCluster)); mu[kkk] = data[rand.nextInt(numVariants)].annotations; final double[] randSigma = new double[numAnnotations]; if( dataManager.isNormalized ) { for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - randSigma[jjj] = 0.6 + 0.4 * rand.nextDouble(); // Explore a wider range of possible sigma values since we are tossing out clusters anyway + randSigma[jjj] = 0.7 + 0.4 * rand.nextDouble(); // BUGBUG: Explore a wider range of possible sigma values since we are tossing out clusters anyway? } } else { // BUGBUG: if not normalized then the varianceVector hasn't been calculated --> null pointer for( int jjj = 0; jjj < numAnnotations; jjj++ ) { @@ -567,6 +825,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel sigma[kkk] = randSigma; } } + } } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantNearestNeighborsModel.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantNearestNeighborsModel.java index 1c79bf583..ab6b806c0 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantNearestNeighborsModel.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantNearestNeighborsModel.java @@ -62,12 +62,11 @@ public final class VariantNearestNeighborsModel extends VariantOptimizationModel pTrueVariant[iii] = calcTruePositiveRateFromTITV( vTree.calcNeighborhoodTITV( dataManager.data[iii] ) ); } - PrintStream outputFile = null; + PrintStream outputFile; try { outputFile = new PrintStream( outputPrefix + ".knn.optimize" ); } catch (Exception e) { - e.printStackTrace(); - System.exit(-1); + throw new StingException( "Unable to create output file: " + outputPrefix + ".knn.optimize" ); } for(int iii = 0; iii < numVariants; iii++) { outputFile.print(String.format("%.4f",pTrueVariant[iii]) + ","); diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizationModel.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizationModel.java index c52e29ed3..71b0027d5 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizationModel.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizationModel.java @@ -51,4 +51,10 @@ public abstract class VariantOptimizationModel implements VariantOptimizationInt //if( titv < 0.0 ) { titv = 0.0; } //return ( titv / targetTITV ); } + + public final double calcTruePositiveRateFromKnownTITV( final double knownTITV, double novelTITV ) { + if( novelTITV > knownTITV ) { novelTITV -= 2.0f*(novelTITV-knownTITV); } + if( novelTITV < 0.5 ) { novelTITV = 0.5; } + return ( (novelTITV - 0.5) / (knownTITV - 0.5) ); + } } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizer.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizer.java index 569914451..45feecba9 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizer.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizer.java @@ -10,7 +10,9 @@ import org.broadinstitute.sting.utils.ExpandingArrayList; import org.broadinstitute.sting.utils.StingException; import org.broadinstitute.sting.utils.cmdLine.Argument; -import java.io.IOException; +import java.util.Arrays; +import java.util.Set; +import java.util.TreeSet; /* * Copyright (c) 2010 The Broad Institute @@ -53,8 +55,10 @@ public class VariantOptimizer extends RodWalker ///////////////////////////// @Argument(fullName="target_titv", shortName="titv", doc="The target Ti/Tv ratio towards which to optimize. (~~2.1 for whole genome experiments)", required=true) private double TARGET_TITV = 2.1; - @Argument(fullName="ignore_input_filters", shortName="ignoreFilters", doc="If specified the optimizer will use variants even if the FILTER column is marked in the VCF file", required=false) - private boolean IGNORE_INPUT_FILTERS = false; + @Argument(fullName="ignore_all_input_filters", shortName="ignoreAllFilters", doc="If specified the optimizer will use variants even if the FILTER column is marked in the VCF file", required=false) + private boolean IGNORE_ALL_INPUT_FILTERS = false; + @Argument(fullName="ignore_filter", shortName="ignoreFilter", doc="If specified the optimizer will use variants even if the specified filter name is marked in the input VCF file", required=false) + private String[] IGNORE_INPUT_FILTERS = null; @Argument(fullName="exclude_annotation", shortName="exclude", doc="The names of the annotations which should be excluded from the calculations", required=false) private String[] EXCLUDED_ANNOTATIONS = null; @Argument(fullName="force_annotation", shortName="force", doc="The names of the annotations which should be forced into the calculations even if they aren't present in every variant", required=false) @@ -62,11 +66,11 @@ public class VariantOptimizer extends RodWalker @Argument(fullName="clusterFile", shortName="clusterFile", doc="The output cluster file", required=true) private String CLUSTER_FILENAME = "optimizer.cluster"; @Argument(fullName="numGaussians", shortName="nG", doc="The number of Gaussians to be used in the Gaussian Mixture model", required=false) - private int NUM_GAUSSIANS = 32; + private int NUM_GAUSSIANS = 100; @Argument(fullName="numIterations", shortName="nI", doc="The number of iterations to be performed in the Gaussian Mixture model", required=false) - private int NUM_ITERATIONS = 10; - @Argument(fullName="minVarInCluster", shortName="minVar", doc="The minimum number of variants in a cluster to be considered a valid cluster. Used to prevent overfitting. Default is 2000.", required=true) - private int MIN_VAR_IN_CLUSTER = 2000; + private int NUM_ITERATIONS = 7; + @Argument(fullName="minVarInCluster", shortName="minVar", doc="The minimum number of variants in a cluster to be considered a valid cluster. It can be used to prevent overfitting.", required=false) + private int MIN_VAR_IN_CLUSTER = 0; //@Argument(fullName="knn", shortName="knn", doc="The number of nearest neighbors to be used in the k-Nearest Neighbors model", required=false) //private int NUM_KNN = 2000; @@ -80,6 +84,7 @@ public class VariantOptimizer extends RodWalker private boolean firstVariant = true; private int numAnnotations = 0; private static final double INFINITE_ANNOTATION_VALUE = 10000.0; + private Set ignoreInputFilterSet = null; //--------------------------------------------------------------------------------------------------------------- // @@ -88,7 +93,9 @@ public class VariantOptimizer extends RodWalker //--------------------------------------------------------------------------------------------------------------- public void initialize() { - //if( !PATH_TO_RESOURCES.endsWith("/") ) { PATH_TO_RESOURCES = PATH_TO_RESOURCES + "/"; } + if( IGNORE_INPUT_FILTERS != null ) { + ignoreInputFilterSet = new TreeSet(Arrays.asList(IGNORE_INPUT_FILTERS)); + } } //--------------------------------------------------------------------------------------------------------------- @@ -109,49 +116,52 @@ public class VariantOptimizer extends RodWalker for( final VariantContext vc : tracker.getAllVariantContexts(null, context.getLocation(), false, false) ) { - if( vc != null && vc.isSNP() && (IGNORE_INPUT_FILTERS || !vc.isFiltered()) ) { - if( firstVariant ) { // This is the first variant encountered so set up the list of annotations - annotationKeys.addAll( vc.getAttributes().keySet() ); - if( annotationKeys.contains("ID") ) { annotationKeys.remove("ID"); } // ID field is added to the vc's INFO field? - if( annotationKeys.contains("DB") ) { annotationKeys.remove("DB"); } - if( EXCLUDED_ANNOTATIONS != null ) { - for( final String excludedAnnotation : EXCLUDED_ANNOTATIONS ) { - if( annotationKeys.contains( excludedAnnotation ) ) { annotationKeys.remove( excludedAnnotation ); } + if( vc != null && vc.isSNP() ) { + if( !vc.isFiltered() || IGNORE_ALL_INPUT_FILTERS || (ignoreInputFilterSet != null && ignoreInputFilterSet.containsAll(vc.getFilters())) ) { + if( firstVariant ) { // This is the first variant encountered so set up the list of annotations + annotationKeys.addAll( vc.getAttributes().keySet() ); + if( annotationKeys.contains("ID") ) { annotationKeys.remove("ID"); } // ID field is added to the vc's INFO field? + if( annotationKeys.contains("DB") ) { annotationKeys.remove("DB"); } + if( EXCLUDED_ANNOTATIONS != null ) { + for( final String excludedAnnotation : EXCLUDED_ANNOTATIONS ) { + if( annotationKeys.contains( excludedAnnotation ) ) { annotationKeys.remove( excludedAnnotation ); } + } } - } - if( FORCED_ANNOTATIONS != null ) { - for( final String forcedAnnotation : FORCED_ANNOTATIONS ) { - if( !annotationKeys.contains( forcedAnnotation ) ) { annotationKeys.add( forcedAnnotation ); } + if( FORCED_ANNOTATIONS != null ) { + for( final String forcedAnnotation : FORCED_ANNOTATIONS ) { + if( !annotationKeys.contains( forcedAnnotation ) ) { annotationKeys.add( forcedAnnotation ); } + } } + numAnnotations = annotationKeys.size() + 1; // +1 for variant quality ("QUAL") + annotationValues = new double[numAnnotations]; + firstVariant = false; } - numAnnotations = annotationKeys.size() + 1; // +1 for variant quality ("QUAL") - annotationValues = new double[numAnnotations]; - firstVariant = false; + + int iii = 0; + for( final String key : annotationKeys ) { + + double value = 0.0; + try { + value = Double.parseDouble( (String)vc.getAttribute( key, "0.0" ) ); + if( Double.isInfinite(value) ) { + value = ( value > 0 ? 1.0 : -1.0 ) * INFINITE_ANNOTATION_VALUE; + } + } catch( NumberFormatException e ) { + // do nothing, default value is 0.0 + } + annotationValues[iii++] = value; + } + + // Variant quality ("QUAL") is not in the list of annotations, but is useful so add it here. + annotationValues[iii] = vc.getPhredScaledQual(); + + final VariantDatum variantDatum = new VariantDatum(); + variantDatum.annotations = annotationValues; + variantDatum.isTransition = vc.getSNPSubstitutionType().compareTo(BaseUtils.BaseSubstitutionType.TRANSITION) == 0; + variantDatum.isKnown = !vc.getAttribute("ID").equals("."); + variantDatum.isHet = vc.getHetCount() > vc.getHomVarCount(); // BUGBUG: what to do here for multi sample calls? + mapList.add( variantDatum ); } - - int iii = 0; - for( final String key : annotationKeys ) { - - double value = 0.0; - try { - value = Double.parseDouble( (String)vc.getAttribute( key, "0.0" ) ); - if( Double.isInfinite(value) ) { - value = ( value > 0 ? 1.0 : -1.0 ) * INFINITE_ANNOTATION_VALUE; - } - } catch( NumberFormatException e ) { - // do nothing, default value is 0.0 - } - annotationValues[iii++] = value; - } - - // Variant quality ("QUAL") is not in the list of annotations, but is useful so add it here. - annotationValues[iii] = vc.getPhredScaledQual(); - - VariantDatum variantDatum = new VariantDatum(); - variantDatum.annotations = annotationValues; - variantDatum.isTransition = vc.getSNPSubstitutionType().compareTo(BaseUtils.BaseSubstitutionType.TRANSITION) == 0; - variantDatum.isKnown = !vc.getAttribute("ID").equals("."); - mapList.add( variantDatum ); } }