diff --git a/R/plot_variantROCCurve.R b/R/plot_variantROCCurve.R new file mode 100755 index 000000000..1b7ae9292 --- /dev/null +++ b/R/plot_variantROCCurve.R @@ -0,0 +1,21 @@ +#!/broad/tools/apps/R-2.6.0/bin/Rscript + +args <- commandArgs(TRUE) +verbose = TRUE + +input = args[1] + +data = read.table(input,sep=",",head=T) +numCurves = (length(data) - 1)/3 +maxSpec = max(data[,(1:numCurves)*3]) + +outfile = paste(input, ".variantROCCurve.pdf", sep="") +pdf(outfile, height=7, width=7) + +par(cex=1.3) +plot(data$specificity1,data$sensitivity1, type="n", xlim=c(0,maxSpec),ylim=c(0,1),xlab="1 - Specificity",ylab="Sensitivity") +for(iii in 1:numCurves) { + points(data[,iii*3],data[,(iii-1)*3+2],lwd=3,type="l",col=iii) +} +legend("bottomright", names(data)[(0:(numCurves-1))*3+1], col=1:numCurves,lwd=3) +dev.off() \ No newline at end of file diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/AnalyzeAnnotationsWalker.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/AnalyzeAnnotationsWalker.java index 4b3c5f3de..721b93195 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/AnalyzeAnnotationsWalker.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/AnalyzeAnnotationsWalker.java @@ -108,7 +108,7 @@ public class AnalyzeAnnotationsWalker extends RodWalker { // First find out if this variant is in the truth sets boolean isInTruthSet = false; boolean isTrueVariant = false; - for( ReferenceOrderedDatum rod : tracker.getAllRods() ) { + for( final ReferenceOrderedDatum rod : tracker.getAllRods() ) { if( rod != null && rod.getName().toUpperCase().startsWith("TRUTH") ) { isInTruthSet = true; @@ -132,7 +132,7 @@ public class AnalyzeAnnotationsWalker extends RodWalker { } // Add each annotation in this VCF Record to the dataManager - for( ReferenceOrderedDatum rod : tracker.getAllRods() ) { + for( final ReferenceOrderedDatum rod : tracker.getAllRods() ) { if( rod != null && rod instanceof RodVCF && !rod.getName().toUpperCase().startsWith("TRUTH") ) { final RodVCF variant = (RodVCF) rod; if( variant.isSNP() ) { diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/AnnotationDataManager.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/AnnotationDataManager.java index 64d2420d1..8d76b24a8 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/AnnotationDataManager.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/AnnotationDataManager.java @@ -63,7 +63,7 @@ public class AnnotationDataManager { // Loop over each annotation in the vcf record final Map infoField = variant.getInfoValues(); infoField.put("QUAL", ((Double)variant.getQual()).toString() ); // add QUAL field to annotations - for( String annotationKey : infoField.keySet() ) { + for( final String annotationKey : infoField.keySet() ) { float value; try { @@ -102,7 +102,7 @@ public class AnnotationDataManager { System.out.println( "\nFinished reading variants into memory. Executing RScript commands:" ); // For each annotation we've seen - for( String annotationKey : data.keySet() ) { + for( final String annotationKey : data.keySet() ) { PrintStream output; try { @@ -117,7 +117,7 @@ public class AnnotationDataManager { // Bin SNPs and calculate truth metrics for each bin thisAnnotationBin.clearBin(); - for( AnnotationDatum datum : data.get( annotationKey ) ) { + for( final AnnotationDatum datum : data.get( annotationKey ) ) { thisAnnotationBin.combine( datum ); if( thisAnnotationBin.numVariants( AnnotationDatum.FULL_SET ) >= MAX_VARIANTS_PER_BIN ) { // This annotation bin is full output.println( thisAnnotationBin.value + "\t" + thisAnnotationBin.calcTiTv( AnnotationDatum.FULL_SET ) + "\t" + thisAnnotationBin.calcDBsnpRate() + "\t" + thisAnnotationBin.calcTPrate() + diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/AnnotationDatum.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/AnnotationDatum.java index 5fb668ade..fb321b899 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/AnnotationDatum.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/AnnotationDatum.java @@ -59,7 +59,7 @@ public class AnnotationDatum implements Comparator { } } - public AnnotationDatum( float _value ) { + public AnnotationDatum( final float _value ) { value = _value; ti = new int[NUM_SETS]; diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/ApplyVariantClustersWalker.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/ApplyVariantClustersWalker.java index 41fad687d..e1668a67a 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/ApplyVariantClustersWalker.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/ApplyVariantClustersWalker.java @@ -12,10 +12,7 @@ import org.broadinstitute.sting.utils.genotype.vcf.*; import java.io.File; import java.io.IOException; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Set; -import java.util.TreeSet; +import java.util.*; /* * Copyright (c) 2010 The Broad Institute @@ -60,8 +57,12 @@ public class ApplyVariantClustersWalker extends RodWalker ignoreInputFilterSet = null; + //--------------------------------------------------------------------------------------------------------------- // @@ -88,6 +91,10 @@ public class ApplyVariantClustersWalker extends RodWalker(Arrays.asList(IGNORE_INPUT_FILTERS)); + } + switch (OPTIMIZATION_MODEL) { case GAUSSIAN_MIXTURE_MODEL: theModel = new VariantGaussianMixtureModel( CLUSTER_FILENAME ); @@ -101,14 +108,14 @@ public class ApplyVariantClustersWalker extends RodWalker hInfo = new HashSet(); + final Set hInfo = new HashSet(); hInfo.addAll(VCFUtils.getHeaderFields(getToolkit())); hInfo.add(new VCFInfoHeaderLine("OQ", 1, VCFInfoHeaderLine.INFO_TYPE.Float, "The original variant quality score")); hInfo.add(new VCFHeaderLine("source", "VariantOptimizer")); vcfWriter = new VCFWriter( new File(OUTPUT_PREFIX + ".vcf") ); - TreeSet samples = new TreeSet(); + final TreeSet samples = new TreeSet(); SampleUtils.getUniquifiedSamplesFromRods(getToolkit(), samples, new HashMap, String>()); - VCFHeader vcfHeader = new VCFHeader(hInfo, samples); + final VCFHeader vcfHeader = new VCFHeader(hInfo, samples); vcfWriter.writeHeader(vcfHeader); } @@ -126,22 +133,41 @@ public class ApplyVariantClustersWalker extends RodWalker numHom; //vc.getHetCount() > vc.getHomVarCount(); // BUGBUG: what to do here for multi sample calls? + + final double pTrue = theModel.evaluateVariant( rodVCF.getInfoValues(), rodVCF.getQual(), variantDatum.isHet ); + final double recalQual = QualityUtils.phredScaleErrorRate( Math.max( 1.0 - pTrue, 0.000000001) ); + + if( variantDatum.isKnown ) { + variantDatum.qual = 0.5 * recalQual + 0.5 * KNOWN_VAR_QUAL_PRIOR; + } else { + variantDatum.qual = recalQual; + } mapList.add( variantDatum ); + rodVCF.mCurrentRecord.addInfoField("OQ", ((Double)rodVCF.getQual()).toString() ); + rodVCF.mCurrentRecord.setQual( variantDatum.qual ); + rodVCF.mCurrentRecord.setFilterString(VCFRecord.UNFILTERED); + vcfWriter.addRecord( rodVCF.mCurrentRecord ); + + } else { // not a SNP or is filtered so just dump it out to the VCF file vcfWriter.addRecord( rodVCF.mCurrentRecord ); } @@ -169,6 +195,8 @@ public class ApplyVariantClustersWalker extends RodWalker reduceSum ) { + vcfWriter.close(); + final VariantDataManager dataManager = new VariantDataManager( reduceSum, theModel.dataManager.annotationKeys ); reduceSum.clear(); // Don't need this ever again, clean up some memory diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantClusteringModel.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantClusteringModel.java index 1f42a61de..891f21443 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantClusteringModel.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantClusteringModel.java @@ -32,6 +32,6 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantoptimizer; */ public interface VariantClusteringModel extends VariantOptimizationInterface { - public void createClusters( final VariantDatum[] data ); + public void createClusters( final VariantDatum[] data, final int startCluster, final int stopCluster ); //public void applyClusters( final VariantDatum[] data, final String outputPrefix ); } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantConcordanceROCCurveWalker.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantConcordanceROCCurveWalker.java new file mode 100755 index 000000000..652af1936 --- /dev/null +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantConcordanceROCCurveWalker.java @@ -0,0 +1,281 @@ +package org.broadinstitute.sting.playground.gatk.walkers.variantoptimizer; + +import org.broadinstitute.sting.gatk.contexts.AlignmentContext; +import org.broadinstitute.sting.gatk.contexts.ReferenceContext; +import org.broadinstitute.sting.gatk.contexts.variantcontext.VariantContext; +import org.broadinstitute.sting.gatk.datasources.simpleDataSources.ReferenceOrderedDataSource; +import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker; +import org.broadinstitute.sting.gatk.refdata.RodVCF; +import org.broadinstitute.sting.gatk.walkers.RodWalker; +import org.broadinstitute.sting.utils.ExpandingArrayList; +import org.broadinstitute.sting.utils.Pair; +import org.broadinstitute.sting.utils.StingException; +import org.broadinstitute.sting.utils.cmdLine.Argument; + +import java.io.IOException; +import java.io.PrintStream; +import java.util.HashMap; + +/* + * Copyright (c) 2010 The Broad Institute + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, + * copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ + +/** + * Calculates variant concordance with the given truth sets and plots ROC curves for each input call set. + * + * @author rpoplin + * @since Mar 18, 2010 + * + * @help.summary Calculates variant concordance with the given truth sets and plots ROC curves for each input call set. + */ + +public class VariantConcordanceROCCurveWalker extends RodWalker>, HashMap>> { + + ///////////////////////////// + // Command Line Arguments + ///////////////////////////// + @Argument(fullName="output_prefix", shortName="output", doc="The prefix added to output VCF file name and optimization curve pdf file name", required=false) + private String OUTPUT_PREFIX = "optimizer"; + @Argument(fullName = "path_to_Rscript", shortName = "Rscript", doc = "The path to your implementation of Rscript. For Broad users this is probably /broad/tools/apps/R-2.6.0/bin/Rscript", required = false) + private String PATH_TO_RSCRIPT = "/broad/tools/apps/R-2.6.0/bin/Rscript"; + @Argument(fullName = "path_to_resources", shortName = "resources", doc = "Path to resources folder holding the Sting R scripts.", required = false) + private String PATH_TO_RESOURCES = "R/"; + + ///////////////////////////// + // Private Member Variables + ///////////////////////////// + private final ExpandingArrayList inputRodNames = new ExpandingArrayList(); + private int numCurves; + private int[] trueNegGlobal; + private int[] falseNegGlobal; + private String sampleName = null; + + //--------------------------------------------------------------------------------------------------------------- + // + // initialize + // + //--------------------------------------------------------------------------------------------------------------- + + public void initialize() { + if( !PATH_TO_RESOURCES.endsWith("/") ) { PATH_TO_RESOURCES = PATH_TO_RESOURCES + "/"; } + + for( ReferenceOrderedDataSource rod : this.getToolkit().getRodDataSources() ) { + if( rod != null && !rod.getName().toUpperCase().startsWith("TRUTH") ) { + if( rod.getReferenceOrderedData().iterator().next().get(0) instanceof RodVCF ) { + inputRodNames.add(rod.getName()); + if( sampleName == null ) { + sampleName = ((RodVCF)rod.getReferenceOrderedData().iterator().next().get(0)).getSampleNames()[0]; // BUGBUG: single sample calls only for now + } + } + } + } + numCurves = inputRodNames.size(); + trueNegGlobal = new int[numCurves]; + falseNegGlobal = new int[numCurves]; + for( int kkk = 0; kkk < numCurves; kkk++ ) { + trueNegGlobal[kkk] = 0; + falseNegGlobal[kkk] = 0; + } + } + + //--------------------------------------------------------------------------------------------------------------- + // + // map + // + //--------------------------------------------------------------------------------------------------------------- + + public ExpandingArrayList> map( RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context ) { + + final ExpandingArrayList> mapList = new ExpandingArrayList>(); + + if( tracker == null ) { // For some reason RodWalkers get map calls with null trackers + return mapList; + } + + boolean isInTruthSet = false; + boolean isTrueVariant = false; + + for( final VariantContext vc : tracker.getAllVariantContexts(null, context.getLocation(), false, false) ) { + if( vc != null && vc.getName().toUpperCase().startsWith("TRUTH") ) { + if( !vc.getGenotype(sampleName).isNoCall() ) { + isInTruthSet = true; + + if( !vc.getGenotype(sampleName).isHomRef() ) { + isTrueVariant = true; + } + } + //if( vc.isPolymorphic() ) { //BUGBUG: I don't think this is the right thing to do here, there are many polymorphic sites in the truth data because there are many samples + // isTrueVariant = true; + //} + } + } + + if( !isInTruthSet ) { + return mapList; // jump out if this location isn't in the truth set + } + + int curveIndex = 0; + for( final String inputName : inputRodNames ) { + final VariantContext vc = tracker.getVariantContext( inputName, null, context.getLocation(), false ); // assuming single variant per track per location + + if( vc != null && vc.isPolymorphic() && !vc.isFiltered() ) { + final VariantDatum variantDatum = new VariantDatum(); + variantDatum.qual = vc.getPhredScaledQual(); + variantDatum.isTrueVariant = isTrueVariant; + mapList.add( new Pair(inputName, variantDatum) ); + } else { // Either not in the call set at all, is a monomorphic call, or was filtered out + if( isTrueVariant ) { // ... but it is a true variant so this is a false negative call + falseNegGlobal[curveIndex]++; + } else { // ... and it is not a variant site so this is a true negative call + trueNegGlobal[curveIndex]++; + } + } + curveIndex++; + } + + return mapList; + } + + //--------------------------------------------------------------------------------------------------------------- + // + // reduce + // + //--------------------------------------------------------------------------------------------------------------- + + public HashMap> reduceInit() { + final HashMap> init = new HashMap>(); + for( final String inputName : inputRodNames ) { + init.put( inputName, new ExpandingArrayList() ); + } + return init; + } + + public HashMap> reduce( final ExpandingArrayList> mapValue, final HashMap> reduceSum ) { + for( Pair value : mapValue ) { + final ExpandingArrayList list = reduceSum.get(value.getFirst()); + list.add(value.getSecond()); + reduceSum.put(value.getFirst(),list); + } + return reduceSum; + } + + public void onTraversalDone( HashMap> reduceSum ) { + + final int NUM_CURVES = numCurves; + final HashMap dataManagerMap = new HashMap(); + for( final String inputName : inputRodNames ) { + System.out.println("Creating data manager for: " + inputName); + dataManagerMap.put(inputName, new VariantDataManager( reduceSum.get(inputName), null )); + } + reduceSum.clear(); // Don't need this ever again, clean up some memory + + final double[] minQual = new double[NUM_CURVES]; + final double[] maxQual = new double[NUM_CURVES]; + final double[] incrementQual = new double[NUM_CURVES]; + final double[] qualCut = new double[NUM_CURVES]; + + final int NUM_STEPS = 200; + + int curveIndex = 0; + for( final String inputName : inputRodNames ) { + final VariantDataManager dataManager = dataManagerMap.get(inputName); + minQual[curveIndex] = dataManager.data[0].qual; + maxQual[curveIndex] = dataManager.data[0].qual; + for( int iii = 1; iii < dataManager.data.length; iii++ ) { + final double qual = dataManager.data[iii].qual; + if( qual < minQual[curveIndex] ) { minQual[curveIndex] = qual; } + else if( qual > maxQual[curveIndex] ) { maxQual[curveIndex] = qual; } + } + incrementQual[curveIndex] = (maxQual[curveIndex] - minQual[curveIndex]) / ((double)NUM_STEPS); + qualCut[curveIndex] = minQual[curveIndex]; + curveIndex++; + } + + final int[] truePos = new int[NUM_CURVES]; + final int[] falsePos = new int[NUM_CURVES]; + final int[] trueNeg = new int[NUM_CURVES]; + final int[] falseNeg = new int[NUM_CURVES]; + + PrintStream outputFile; + try { + outputFile = new PrintStream( OUTPUT_PREFIX + ".dat" ); + } catch (Exception e) { + throw new StingException( "Unable to create output file: " + OUTPUT_PREFIX + ".dat" ); + } + int jjj = 1; + for( final String inputName : inputRodNames ) { + outputFile.print(inputName + ",sensitivity" + jjj + ",specificity" + jjj + ","); + jjj++; + } + outputFile.println("sentinel"); + + for( int step = 0; step < NUM_STEPS; step++ ) { + curveIndex = 0; + for( final String inputName : inputRodNames ) { + final VariantDataManager dataManager = dataManagerMap.get(inputName); + truePos[curveIndex] = 0; + falsePos[curveIndex] = 0; + trueNeg[curveIndex] = 0; + falseNeg[curveIndex] = 0; + final int NUM_VARIANTS = dataManager.data.length; + for( int iii = 0; iii < NUM_VARIANTS; iii++ ) { + if( dataManager.data[iii].qual >= qualCut[curveIndex] ) { // this var is in this hypothetical call set + if( dataManager.data[iii].isTrueVariant ) { + truePos[curveIndex]++; + } else { + falsePos[curveIndex]++; + } + } else { // this var is out of this hypothetical call set + if( dataManager.data[iii].isTrueVariant ) { + falseNeg[curveIndex]++; + } else { + trueNeg[curveIndex]++; + } + } + } + final double sensitivity = ((double) truePos[curveIndex]) / ((double) truePos[curveIndex] + falseNegGlobal[curveIndex] + falseNeg[curveIndex]); + final double specificity = ((double) trueNegGlobal[curveIndex] + trueNeg[curveIndex]) / + ((double) falsePos[curveIndex] + trueNegGlobal[curveIndex] + trueNeg[curveIndex]); + outputFile.print( String.format("%.4f,%.4f,%.4f,", qualCut[curveIndex], sensitivity, 1.0 - specificity) ); + qualCut[curveIndex] += incrementQual[curveIndex]; + curveIndex++; + } + outputFile.println("-1"); + } + + outputFile.close(); + + // Execute Rscript command to plot the optimization curve + // Print out the command line to make it clear to the user what is being executed and how one might modify it + final String rScriptCommandLine = PATH_TO_RSCRIPT + " " + PATH_TO_RESOURCES + "plot_variantROCCurve.R" + " " + OUTPUT_PREFIX + ".dat"; + System.out.println( rScriptCommandLine ); + + // Execute the RScript command to plot the table of truth values + try { + Runtime.getRuntime().exec( rScriptCommandLine ); + } catch ( IOException e ) { + throw new StingException( "Unable to execute RScript command: " + rScriptCommandLine ); + } + } +} diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDataManager.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDataManager.java index 634987a0d..775d9a521 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDataManager.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDataManager.java @@ -77,8 +77,8 @@ public class VariantDataManager { annotationKeys = new ExpandingArrayList(); int jjj = 0; - for( String line : annotationLines ) { - String[] vals = line.split(","); + for( final String line : annotationLines ) { + final String[] vals = line.split(","); annotationKeys.add(vals[1]); meanVector[jjj] = Double.parseDouble(vals[2]); varianceVector[jjj] = Double.parseDouble(vals[3]); @@ -87,12 +87,15 @@ public class VariantDataManager { } public void normalizeData() { + boolean foundZeroVarianceAnnotation = false; for( int jjj = 0; jjj < numAnnotations; jjj++ ) { final double theMean = mean(data, jjj); final double theSTD = standardDeviation(data, theMean, jjj); System.out.println( (jjj == numAnnotations-1 ? "QUAL" : annotationKeys.get(jjj)) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD) ); if( theSTD < 1E-8 ) { - throw new StingException("Zero variance is a problem: standard deviation = " + theSTD); + foundZeroVarianceAnnotation = true; + System.out.println("Zero variance is a problem: standard deviation = " + theSTD); + System.out.println("User must -exclude annotations with zero variance. Annotation = " + (jjj == numAnnotations-1 ? "QUAL" : annotationKeys.get(jjj))); } meanVector[jjj] = theMean; varianceVector[jjj] = theSTD; @@ -101,6 +104,9 @@ public class VariantDataManager { } } isNormalized = true; // Each data point is now [ (x - mean) / standard deviation ] + if( foundZeroVarianceAnnotation ) { + throw new StingException("Found annotations with zero variance. They must be excluded before proceeding."); + } } private static double mean( final VariantDatum[] data, final int index ) { diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDatum.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDatum.java index 469fa90e5..663304d31 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDatum.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDatum.java @@ -36,5 +36,6 @@ public class VariantDatum { public boolean isTransition; public boolean isKnown; public boolean isTrueVariant; + public boolean isHet; public double qual; } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantGaussianMixtureModel.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantGaussianMixtureModel.java index 79d487db4..3ef776c45 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantGaussianMixtureModel.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantGaussianMixtureModel.java @@ -55,6 +55,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel private final double[][] mu; // The means for the clusters private final double[][] sigma; // The variances for the clusters, sigma is really sigma^2 private final double[] pCluster; + private final boolean[] isHetCluster; private final int[] numMaxClusterKnown; private final int[] numMaxClusterNovel; private final double[] clusterTITV; @@ -68,12 +69,13 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel public VariantGaussianMixtureModel( final VariantDataManager _dataManager, final double _targetTITV, final int _numGaussians, final int _numIterations, final int _minVarInCluster ) { super( _targetTITV ); dataManager = _dataManager; - numGaussians = _numGaussians; + numGaussians = ( _numGaussians % 2 == 0 ? _numGaussians : _numGaussians + 1 ); numIterations = _numIterations; mu = new double[numGaussians][]; sigma = new double[numGaussians][]; pCluster = new double[numGaussians]; + isHetCluster = null; numMaxClusterKnown = new int[numGaussians]; numMaxClusterNovel = new int[numGaussians]; clusterTITV = new double[numGaussians]; @@ -83,8 +85,8 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel public VariantGaussianMixtureModel( final String clusterFileName ) { super( 0.0 ); - ExpandingArrayList annotationLines = new ExpandingArrayList(); - ExpandingArrayList clusterLines = new ExpandingArrayList(); + final ExpandingArrayList annotationLines = new ExpandingArrayList(); + final ExpandingArrayList clusterLines = new ExpandingArrayList(); try { for ( String line : new xReadLines(new File( clusterFileName )) ) { @@ -102,24 +104,28 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel dataManager = new VariantDataManager( annotationLines ); numIterations = 0; - pCluster = null; numMaxClusterKnown = null; numMaxClusterNovel = null; clusterTITV = null; minVarInCluster = 0; + //BUGBUG: move this parsing out of the constructor numGaussians = clusterLines.size(); mu = new double[numGaussians][dataManager.numAnnotations]; sigma = new double[numGaussians][dataManager.numAnnotations]; + pCluster = new double[numGaussians]; + isHetCluster = new boolean[numGaussians]; clusterTruePositiveRate = new double[numGaussians]; int kkk = 0; for( String line : clusterLines ) { - String[] vals = line.split(","); - clusterTruePositiveRate[kkk] = Double.parseDouble(vals[4]); + final String[] vals = line.split(","); + isHetCluster[kkk] = Integer.parseInt(vals[1]) == 1; + pCluster[kkk] = Double.parseDouble(vals[2]); + clusterTruePositiveRate[kkk] = Double.parseDouble(vals[6]); //BUGBUG: #define these magic index numbers, very easy to make a mistake here for( int jjj = 0; jjj < dataManager.numAnnotations; jjj++ ) { - mu[kkk][jjj] = Double.parseDouble(vals[5+jjj]); - sigma[kkk][jjj] = Double.parseDouble(vals[5+dataManager.numAnnotations+jjj]); + mu[kkk][jjj] = Double.parseDouble(vals[7+jjj]); + sigma[kkk][jjj] = Double.parseDouble(vals[7+dataManager.numAnnotations+jjj]) * 3; //BUGBUG: *1.3, suggestion by Nick to prevent GMM from over fitting and producing low likelihoods for most points } kkk++; } @@ -129,13 +135,28 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel public final void run( final String clusterFileName ) { + final int MAX_VARS = 1000000; //BUGBUG: make this a command line argument + // Create the subset of the data to cluster with int numNovel = 0; + int numKnown = 0; + int numHet = 0; + int numHom = 0; for( final VariantDatum datum : dataManager.data ) { - if( !datum.isKnown ) { + if( datum.isKnown ) { + numKnown++; + } else { numNovel++; } + if( datum.isHet ) { + numHet++; + } else { + numHom++; + } } + + // This block of code is used to cluster with novels + 1.5x knowns mixed together + /* VariantDatum[] data; // Grab a set of data that is all of the novel variants plus 1.5x as many known variants drawn at random @@ -167,9 +188,235 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel printClusters( clusterFileName ); //System.out.println("Applying clusters to all variants..."); //applyClusters( dataManager.data, outputPrefix ); // Using all the data + */ + + + + + // This block of code is to cluster knowns and novels separately + /* + final VariantDatum[] dataNovel = new VariantDatum[Math.min(numNovel,MAX_VARS)]; + final VariantDatum[] dataKnown = new VariantDatum[Math.min(numKnown,MAX_VARS)]; + + //BUGBUG: This is ugly + int jjj = 0; + if(numNovel <= MAX_VARS) { + for( final VariantDatum datum : dataManager.data ) { + if( !datum.isKnown ) { + dataNovel[jjj++] = datum; + } + } + } else { + System.out.println("Capped at " + MAX_VARS + " novel variants."); + while( jjj < MAX_VARS ) { + final VariantDatum datum = dataManager.data[rand.nextInt(dataManager.numVariants)]; + if( !datum.isKnown ) { + dataNovel[jjj++] = datum; + } + } + } + + int iii = 0; + if(numKnown <= MAX_VARS) { + for( final VariantDatum datum : dataManager.data ) { + if( datum.isKnown ) { + dataKnown[iii++] = datum; + } + } + } else { + System.out.println("Capped at " + MAX_VARS + " known variants."); + while( iii < MAX_VARS ) { + final VariantDatum datum = dataManager.data[rand.nextInt(dataManager.numVariants)]; + if( datum.isKnown ) { + dataKnown[iii++] = datum; + } + } + } + + + System.out.println("Clustering with " + Math.min(numNovel,MAX_VARS) + " novel variants."); + createClusters( dataNovel, 0, numGaussians / 2 ); + System.out.println("Clustering with " + Math.min(numKnown,MAX_VARS) + " known variants."); + createClusters( dataKnown, numGaussians / 2, numGaussians ); + System.out.println("Outputting cluster parameters..."); + printClusters( clusterFileName ); + + */ + + // This block of code is to cluster het and hom calls separately, but mixing together knowns and novels + final VariantDatum[] dataHet = new VariantDatum[Math.min(numHet,MAX_VARS)]; + final VariantDatum[] dataHom = new VariantDatum[Math.min(numHom,MAX_VARS)]; + + //BUGBUG: This is ugly + int jjj = 0; + if(numHet <= MAX_VARS) { + for( final VariantDatum datum : dataManager.data ) { + if( datum.isHet ) { + dataHet[jjj++] = datum; + } + } + } else { + System.out.println("Found " + numHet + " het variants but capped at clustering with " + MAX_VARS + "."); + while( jjj < MAX_VARS ) { + final VariantDatum datum = dataManager.data[rand.nextInt(dataManager.numVariants)]; + if( datum.isHet ) { + dataHet[jjj++] = datum; + } + } + } + + int iii = 0; + if(numHom <= MAX_VARS) { + for( final VariantDatum datum : dataManager.data ) { + if( !datum.isHet ) { + dataHom[iii++] = datum; + } + } + } else { + System.out.println("Found " + numHom + " hom variants but capped at clustering with " + MAX_VARS + "."); + while( iii < MAX_VARS ) { + final VariantDatum datum = dataManager.data[rand.nextInt(dataManager.numVariants)]; + if( !datum.isHet ) { + dataHom[iii++] = datum; + } + } + } + + System.out.println("Clustering with " + Math.min(numHet,MAX_VARS) + " het variants."); + createClusters( dataHet, 0, numGaussians / 2 ); + System.out.println("Clustering with " + Math.min(numHom,MAX_VARS) + " hom variants."); + createClusters( dataHom, numGaussians / 2, numGaussians ); + System.out.println("Outputting cluster parameters..."); + printClusters( clusterFileName ); + } - public final void createClusters( final VariantDatum[] data ) { + + + public final void createClusters( final VariantDatum[] data, int startCluster, int stopCluster ) { + + final int numVariants = data.length; + final int numAnnotations = data[0].annotations.length; + + final double[][] pVarInCluster = new double[numGaussians][numVariants]; + final double[] probTi = new double[numGaussians]; + final double[] probTv = new double[numGaussians]; + final double[] probKnown = new double[numGaussians]; + final double[] probNovel = new double[numGaussians]; + final double[] probKnownTi = new double[numGaussians]; + final double[] probKnownTv = new double[numGaussians]; + final double[] probNovelTi = new double[numGaussians]; + final double[] probNovelTv = new double[numGaussians]; + + // loop control variables: + // iii - loop over data points + // jjj - loop over annotations (features) + // kkk - loop over clusters + // ttt - loop over EM iterations + + // Set up the initial random Gaussians + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + numMaxClusterKnown[kkk] = 0; + numMaxClusterNovel[kkk] = 0; + pCluster[kkk] = 1.0 / ((double) (stopCluster - startCluster)); + mu[kkk] = data[rand.nextInt(numVariants)].annotations; + final double[] randSigma = new double[numAnnotations]; + if( dataManager.isNormalized ) { + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + randSigma[jjj] = 0.7 + 0.4 * rand.nextDouble(); + } + } else { // BUGBUG: if not normalized then the varianceVector hasn't been calculated --> null pointer + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + randSigma[jjj] = dataManager.varianceVector[jjj] + ((1.0 + rand.nextDouble()) * 0.01 * dataManager.varianceVector[jjj]); + } + } + sigma[kkk] = randSigma; + } + + for( int ttt = 0; ttt < numIterations; ttt++ ) { + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Expectation Step (calculate the probability that each data point is in each cluster) + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + evaluateGaussians( data, pVarInCluster, startCluster, stopCluster ); + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Maximization Step (move the clusters to maximize the sum probability of each data point) + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + maximizeGaussians( data, pVarInCluster, startCluster, stopCluster ); + + System.out.println("Finished iteration " + (ttt+1) ); + } + + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + // Evaluate the clusters using titv as an estimate of the true positive rate + ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// + evaluateGaussians( data, pVarInCluster, startCluster, stopCluster ); // One final evaluation because the Gaussians moved in the last maximization step + + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + probTi[kkk] = 0.0; + probTv[kkk] = 0.0; + probKnown[kkk] = 0.0; + probNovel[kkk] = 0.0; + } + for( int iii = 0; iii < numVariants; iii++ ) { + final boolean isTransition = data[iii].isTransition; + final boolean isKnown = data[iii].isKnown; + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + final double prob = pVarInCluster[kkk][iii]; + if( isKnown ) { // known + probKnown[kkk] += prob; + if( isTransition ) { // transition + probKnownTi[kkk] += prob; + probTi[kkk] += prob; + } else { // transversion + probKnownTv[kkk] += prob; + probTv[kkk] += prob; + } + } else { //novel + probNovel[kkk] += prob; + if( isTransition ) { // transition + probNovelTi[kkk] += prob; + probTi[kkk] += prob; + } else { // transversion + probNovelTv[kkk] += prob; + probTv[kkk] += prob; + } + } + + } + + double maxProb = pVarInCluster[startCluster][iii]; + int maxCluster = startCluster; + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + if( pVarInCluster[kkk][iii] > maxProb ) { + maxProb = pVarInCluster[kkk][iii]; + maxCluster = kkk; + } + } + if( isKnown ) { + numMaxClusterKnown[maxCluster]++; + } else { + numMaxClusterNovel[maxCluster]++; + } + } + + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + clusterTITV[kkk] = probTi[kkk] / probTv[kkk]; + if( probKnown[kkk] > 2000.0 ) { // BUGBUG: make this a command line argument, parameterize performance based on this important argument + clusterTruePositiveRate[kkk] = calcTruePositiveRateFromKnownTITV( probKnownTi[kkk] / probKnownTv[kkk], probNovelTi[kkk] / probNovelTv[kkk] ); + } else { + clusterTruePositiveRate[kkk] = calcTruePositiveRateFromTITV( clusterTITV[kkk] ); + } + } + } + + + + // This cluster method doesn't make use of the differences between known and novel Ti/Tv ratios + + /* + public final void createClusters( final VariantDatum[] data, final int startCluster, final int stopCluster ) { final int numVariants = data.length; final int numAnnotations = data[0].annotations.length; @@ -185,15 +432,15 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel // ttt - loop over EM iterations // Set up the initial random Gaussians - for( int kkk = 0; kkk < numGaussians; kkk++ ) { + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { numMaxClusterKnown[kkk] = 0; numMaxClusterNovel[kkk] = 0; - pCluster[kkk] = 1.0 / ((double) numGaussians); + pCluster[kkk] = 1.0 / ((double) (stopCluster - startCluster)); mu[kkk] = data[rand.nextInt(numVariants)].annotations; final double[] randSigma = new double[numAnnotations]; if( dataManager.isNormalized ) { for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - randSigma[jjj] = 0.75 + 0.4 * rand.nextDouble(); + randSigma[jjj] = 0.7 + 0.4 * rand.nextDouble(); } } else { // BUGBUG: if not normalized then the varianceVector hasn't been calculated --> null pointer for( int jjj = 0; jjj < numAnnotations; jjj++ ) { @@ -209,12 +456,12 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// // Expectation Step (calculate the probability that each data point is in each cluster) ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - evaluateGaussians( data, pVarInCluster ); + evaluateGaussians( data, pVarInCluster, startCluster, stopCluster ); ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// // Maximization Step (move the clusters to maximize the sum probability of each data point) ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - maximizeGaussians( data, pVarInCluster ); + maximizeGaussians( data, pVarInCluster, startCluster, stopCluster ); System.out.println("Finished iteration " + (ttt+1) ); } @@ -222,28 +469,28 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// // Evaluate the clusters using titv as an estimate of the true positive rate ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// - evaluateGaussians( data, pVarInCluster ); // One final evaluation because the Gaussians moved in the last maximization step + evaluateGaussians( data, pVarInCluster, startCluster, stopCluster ); // One final evaluation because the Gaussians moved in the last maximization step - for( int kkk = 0; kkk < numGaussians; kkk++ ) { + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { probTi[kkk] = 0.0; probTv[kkk] = 0.0; } // Use the cluster's probabilistic Ti/Tv ratio as the indication of the cluster's true positive rate for( int iii = 0; iii < numVariants; iii++ ) { if( data[iii].isTransition ) { // transition - for( int kkk = 0; kkk < numGaussians; kkk++ ) { + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { probTi[kkk] += pVarInCluster[kkk][iii]; } } else { // transversion - for( int kkk = 0; kkk < numGaussians; kkk++ ) { + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { probTv[kkk] += pVarInCluster[kkk][iii]; } } // Calculate which cluster has the maximum probability for this variant for use as a metric of how well clustered the data is - double maxProb = pVarInCluster[0][iii]; - int maxCluster = 0; - for( int kkk = 1; kkk < numGaussians; kkk++ ) { + double maxProb = pVarInCluster[startCluster][iii]; + int maxCluster = startCluster; + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { if( pVarInCluster[kkk][iii] > maxProb ) { maxProb = pVarInCluster[kkk][iii]; maxCluster = kkk; @@ -255,11 +502,13 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel numMaxClusterNovel[maxCluster]++; } } - for( int kkk = 0; kkk < numGaussians; kkk++ ) { + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { clusterTITV[kkk] = probTi[kkk] / probTv[kkk]; clusterTruePositiveRate[kkk] = calcTruePositiveRateFromTITV( clusterTITV[kkk] ); } } + */ + private void printClusters( final String clusterFileName ) { try { @@ -269,6 +518,8 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel for( int kkk = 0; kkk < numGaussians; kkk++ ) { if( numMaxClusterKnown[kkk] + numMaxClusterNovel[kkk] >= minVarInCluster ) { outputFile.print("@!CLUSTER,"); + outputFile.print( (kkk < numGaussians / 2 ? 1 : 0) + "," ); // is het cluster? + outputFile.print(pCluster[kkk] + ","); outputFile.print(numMaxClusterKnown[kkk] + ","); outputFile.print(numMaxClusterNovel[kkk] + ","); outputFile.print(clusterTITV[kkk] + ","); @@ -284,12 +535,11 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } outputFile.close(); } catch (Exception e) { - e.printStackTrace(); - System.exit(-1); + throw new StingException( "Unable to create output file: " + clusterFileName ); } } - public final double evaluateVariant( final Map annotationMap, final double qualityScore ) { + public final double evaluateVariant( final Map annotationMap, final double qualityScore, final boolean isHet ) { final double[] pVarInCluster = new double[numGaussians]; final double[] annotations = new double[dataManager.numAnnotations]; @@ -314,11 +564,13 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel annotations[jjj] = (value - dataManager.meanVector[jjj]) / dataManager.varianceVector[jjj]; } - evaluateGaussiansForSingleVariant( annotations, pVarInCluster ); + evaluateGaussiansForSingleVariant( annotations, pVarInCluster, isHet ); - double sum = 0; + double sum = 0.0; for( int kkk = 0; kkk < numGaussians; kkk++ ) { - sum += pVarInCluster[kkk] * clusterTruePositiveRate[kkk]; + if( isHetCluster[kkk] == isHet ) { + sum += pVarInCluster[kkk] * clusterTruePositiveRate[kkk]; + } } return sum; @@ -333,12 +585,11 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel markedVariant[iii] = false; } - PrintStream outputFile = null; + PrintStream outputFile; try { outputFile = new PrintStream( outputPrefix + ".dat" ); } catch (Exception e) { - e.printStackTrace(); - System.exit(-1); + throw new StingException( "Unable to create output file: " + outputPrefix + ".dat" ); } int numKnown = 0; @@ -382,21 +633,21 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel foundDesiredNumVariants = true; } outputFile.println( pCut + "," + numKnown + "," + numNovel + "," + - ( numKnownTi ==0 || numKnownTv == 0 ? "NaN" : ( ((double)numKnownTi) / ((double)numKnownTv) ) ) + "," + - ( numNovelTi ==0 || numNovelTv == 0 ? "NaN" : ( ((double)numNovelTi) / ((double)numNovelTv) ))); + ( numKnownTi == 0 || numKnownTv == 0 ? "NaN" : ( ((double)numKnownTi) / ((double)numKnownTv) ) ) + "," + + ( numNovelTi == 0 || numNovelTv == 0 ? "NaN" : ( ((double)numNovelTi) / ((double)numNovelTv) ))); } outputFile.close(); } - private void evaluateGaussians( final VariantDatum[] data, final double[][] pVarInCluster ) { + private void evaluateGaussians( final VariantDatum[] data, final double[][] pVarInCluster, final int startCluster, final int stopCluster ) { final int numAnnotations = data[0].annotations.length; for( int iii = 0; iii < data.length; iii++ ) { double sumProb = 0.0; - for( int kkk = 0; kkk < numGaussians; kkk++ ) { + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { double sum = 0.0; for( int jjj = 0; jjj < numAnnotations; jjj++ ) { sum += ( (data[iii].annotations[jjj] - mu[kkk][jjj]) * (data[iii].annotations[jjj] - mu[kkk][jjj]) ) @@ -412,7 +663,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } if( sumProb > MIN_SUM_PROB ) { // Very small numbers are a very big problem - for( int kkk = 0; kkk < numGaussians; kkk++ ) { + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { pVarInCluster[kkk][iii] /= sumProb; } } @@ -420,48 +671,53 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } - private void evaluateGaussiansForSingleVariant( final double[] annotations, final double[] pVarInCluster ) { + private void evaluateGaussiansForSingleVariant( final double[] annotations, final double[] pVarInCluster, final boolean isHet ) { final int numAnnotations = annotations.length; double sumProb = 0.0; for( int kkk = 0; kkk < numGaussians; kkk++ ) { - double sum = 0.0; - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - sum += ( (annotations[jjj] - mu[kkk][jjj]) * (annotations[jjj] - mu[kkk][jjj]) ) - / sigma[kkk][jjj]; - } - //BUGBUG: removed pCluster[kkk]*, for the second pass - //pVarInCluster[kkk] = pCluster[kkk] * Math.exp( -0.5 * sum ); - pVarInCluster[kkk] = Math.exp( -0.5 * sum ); + if( isHetCluster[kkk] == isHet ) { + double sum = 0.0; + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + sum += ( (annotations[jjj] - mu[kkk][jjj]) * (annotations[jjj] - mu[kkk][jjj]) ) + / sigma[kkk][jjj]; + } - if( pVarInCluster[kkk] < MIN_PROB) { // Very small numbers are a very big problem - pVarInCluster[kkk] = MIN_PROB; - } + //BUGBUG: pCluster[kkk] should be removed here? + pVarInCluster[kkk] = pCluster[kkk] * Math.exp( -0.5 * sum ); + //pVarInCluster[kkk] = Math.exp( -0.5 * sum ); - sumProb += pVarInCluster[kkk]; + if( pVarInCluster[kkk] < MIN_PROB) { // Very small numbers are a very big problem + pVarInCluster[kkk] = MIN_PROB; + } + + sumProb += pVarInCluster[kkk]; + } } if( sumProb > MIN_SUM_PROB ) { // Very small numbers are a very big problem for( int kkk = 0; kkk < numGaussians; kkk++ ) { - pVarInCluster[kkk] /= sumProb; + if( isHetCluster[kkk] == isHet ) { + pVarInCluster[kkk] /= sumProb; + } } } } - private void maximizeGaussians( final VariantDatum[] data, final double[][] pVarInCluster ) { + private void maximizeGaussians( final VariantDatum[] data, final double[][] pVarInCluster, final int startCluster, final int stopCluster ) { final int numVariants = data.length; final int numAnnotations = data[0].annotations.length; - for( int kkk = 0; kkk < numGaussians; kkk++ ) { + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { for( int jjj = 0; jjj < numAnnotations; jjj++ ) { mu[kkk][jjj] = 0.0; sigma[kkk][jjj] = 0.0; } } - for( int kkk = 0; kkk < numGaussians; kkk++ ) { + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { double sumProb = 0.0; for( int iii = 0; iii < numVariants; iii++ ) { final double prob = pVarInCluster[kkk][iii]; @@ -475,7 +731,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel for( int jjj = 0; jjj < numAnnotations; jjj++ ) { mu[kkk][jjj] /= sumProb; } - } + } //BUGBUG: clean up dead clusters to speed up computation for( int iii = 0; iii < numVariants; iii++ ) { final double prob = pVarInCluster[kkk][iii]; @@ -499,9 +755,10 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel pCluster[kkk] = sumProb / numVariants; } + // Clean up extra big or extra small clusters //BUGBUG: Is this a good idea? - for( int kkk = 0; kkk < numGaussians; kkk++ ) { + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { if( pCluster[kkk] > 0.45 ) { // This is a very large cluster compared to all the others final int numToReplace = 4; final double[] savedSigma = sigma[kkk]; @@ -513,7 +770,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel randVarIndex = rand.nextInt( numVariants ); final double probK = pVarInCluster[kkk][randVarIndex]; boolean inClusterK = true; - for( int ccc = 0; ccc < numGaussians; ccc++ ) { + for( int ccc = startCluster; ccc < stopCluster; ccc++ ) { if( pVarInCluster[ccc][randVarIndex] > probK ) { inClusterK = false; break; @@ -525,11 +782,11 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel // Find a place to put the example variant if( rrr == 0 ) { // Replace the big cluster that kicked this process off mu[kkk] = data[randVarIndex].annotations; - pCluster[kkk] = 1.0 / ((double) numGaussians); + pCluster[kkk] = 1.0 / ((double) (stopCluster-startCluster)); } else { // Replace the cluster with the minimum prob - double minProb = pCluster[0]; - int minClusterIndex = 0; - for( int ccc = 1; ccc < numGaussians; ccc++ ) { + double minProb = pCluster[startCluster]; + int minClusterIndex = startCluster; + for( int ccc = startCluster; ccc < stopCluster; ccc++ ) { if( pCluster[ccc] < minProb ) { minProb = pCluster[ccc]; minClusterIndex = ccc; @@ -543,21 +800,22 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel sigma[minClusterIndex][jjj] = MIN_SUM_PROB; } } - pCluster[minClusterIndex] = 1.0 / ((double) numGaussians); + pCluster[minClusterIndex] = 1.0 / ((double) (stopCluster-startCluster)); } } } } + // Replace small clusters with another random draw from the dataset - for( int kkk = 0; kkk < numGaussians; kkk++ ) { - if( pCluster[kkk] < 0.05 * (1.0 / ((double) numGaussians)) ) { // This is a very small cluster compared to all the others - pCluster[kkk] = 1.0 / ((double) numGaussians); + for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { + if( pCluster[kkk] < 0.07 * (1.0 / ((double) (stopCluster-startCluster))) ) { // This is a very small cluster compared to all the others + pCluster[kkk] = 1.0 / ((double) (stopCluster-startCluster)); mu[kkk] = data[rand.nextInt(numVariants)].annotations; final double[] randSigma = new double[numAnnotations]; if( dataManager.isNormalized ) { for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - randSigma[jjj] = 0.6 + 0.4 * rand.nextDouble(); // Explore a wider range of possible sigma values since we are tossing out clusters anyway + randSigma[jjj] = 0.7 + 0.4 * rand.nextDouble(); // BUGBUG: Explore a wider range of possible sigma values since we are tossing out clusters anyway? } } else { // BUGBUG: if not normalized then the varianceVector hasn't been calculated --> null pointer for( int jjj = 0; jjj < numAnnotations; jjj++ ) { @@ -567,6 +825,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel sigma[kkk] = randSigma; } } + } } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantNearestNeighborsModel.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantNearestNeighborsModel.java index 1c79bf583..ab6b806c0 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantNearestNeighborsModel.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantNearestNeighborsModel.java @@ -62,12 +62,11 @@ public final class VariantNearestNeighborsModel extends VariantOptimizationModel pTrueVariant[iii] = calcTruePositiveRateFromTITV( vTree.calcNeighborhoodTITV( dataManager.data[iii] ) ); } - PrintStream outputFile = null; + PrintStream outputFile; try { outputFile = new PrintStream( outputPrefix + ".knn.optimize" ); } catch (Exception e) { - e.printStackTrace(); - System.exit(-1); + throw new StingException( "Unable to create output file: " + outputPrefix + ".knn.optimize" ); } for(int iii = 0; iii < numVariants; iii++) { outputFile.print(String.format("%.4f",pTrueVariant[iii]) + ","); diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizationModel.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizationModel.java index c52e29ed3..71b0027d5 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizationModel.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizationModel.java @@ -51,4 +51,10 @@ public abstract class VariantOptimizationModel implements VariantOptimizationInt //if( titv < 0.0 ) { titv = 0.0; } //return ( titv / targetTITV ); } + + public final double calcTruePositiveRateFromKnownTITV( final double knownTITV, double novelTITV ) { + if( novelTITV > knownTITV ) { novelTITV -= 2.0f*(novelTITV-knownTITV); } + if( novelTITV < 0.5 ) { novelTITV = 0.5; } + return ( (novelTITV - 0.5) / (knownTITV - 0.5) ); + } } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizer.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizer.java index 569914451..45feecba9 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizer.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizer.java @@ -10,7 +10,9 @@ import org.broadinstitute.sting.utils.ExpandingArrayList; import org.broadinstitute.sting.utils.StingException; import org.broadinstitute.sting.utils.cmdLine.Argument; -import java.io.IOException; +import java.util.Arrays; +import java.util.Set; +import java.util.TreeSet; /* * Copyright (c) 2010 The Broad Institute @@ -53,8 +55,10 @@ public class VariantOptimizer extends RodWalker ///////////////////////////// @Argument(fullName="target_titv", shortName="titv", doc="The target Ti/Tv ratio towards which to optimize. (~~2.1 for whole genome experiments)", required=true) private double TARGET_TITV = 2.1; - @Argument(fullName="ignore_input_filters", shortName="ignoreFilters", doc="If specified the optimizer will use variants even if the FILTER column is marked in the VCF file", required=false) - private boolean IGNORE_INPUT_FILTERS = false; + @Argument(fullName="ignore_all_input_filters", shortName="ignoreAllFilters", doc="If specified the optimizer will use variants even if the FILTER column is marked in the VCF file", required=false) + private boolean IGNORE_ALL_INPUT_FILTERS = false; + @Argument(fullName="ignore_filter", shortName="ignoreFilter", doc="If specified the optimizer will use variants even if the specified filter name is marked in the input VCF file", required=false) + private String[] IGNORE_INPUT_FILTERS = null; @Argument(fullName="exclude_annotation", shortName="exclude", doc="The names of the annotations which should be excluded from the calculations", required=false) private String[] EXCLUDED_ANNOTATIONS = null; @Argument(fullName="force_annotation", shortName="force", doc="The names of the annotations which should be forced into the calculations even if they aren't present in every variant", required=false) @@ -62,11 +66,11 @@ public class VariantOptimizer extends RodWalker @Argument(fullName="clusterFile", shortName="clusterFile", doc="The output cluster file", required=true) private String CLUSTER_FILENAME = "optimizer.cluster"; @Argument(fullName="numGaussians", shortName="nG", doc="The number of Gaussians to be used in the Gaussian Mixture model", required=false) - private int NUM_GAUSSIANS = 32; + private int NUM_GAUSSIANS = 100; @Argument(fullName="numIterations", shortName="nI", doc="The number of iterations to be performed in the Gaussian Mixture model", required=false) - private int NUM_ITERATIONS = 10; - @Argument(fullName="minVarInCluster", shortName="minVar", doc="The minimum number of variants in a cluster to be considered a valid cluster. Used to prevent overfitting. Default is 2000.", required=true) - private int MIN_VAR_IN_CLUSTER = 2000; + private int NUM_ITERATIONS = 7; + @Argument(fullName="minVarInCluster", shortName="minVar", doc="The minimum number of variants in a cluster to be considered a valid cluster. It can be used to prevent overfitting.", required=false) + private int MIN_VAR_IN_CLUSTER = 0; //@Argument(fullName="knn", shortName="knn", doc="The number of nearest neighbors to be used in the k-Nearest Neighbors model", required=false) //private int NUM_KNN = 2000; @@ -80,6 +84,7 @@ public class VariantOptimizer extends RodWalker private boolean firstVariant = true; private int numAnnotations = 0; private static final double INFINITE_ANNOTATION_VALUE = 10000.0; + private Set ignoreInputFilterSet = null; //--------------------------------------------------------------------------------------------------------------- // @@ -88,7 +93,9 @@ public class VariantOptimizer extends RodWalker //--------------------------------------------------------------------------------------------------------------- public void initialize() { - //if( !PATH_TO_RESOURCES.endsWith("/") ) { PATH_TO_RESOURCES = PATH_TO_RESOURCES + "/"; } + if( IGNORE_INPUT_FILTERS != null ) { + ignoreInputFilterSet = new TreeSet(Arrays.asList(IGNORE_INPUT_FILTERS)); + } } //--------------------------------------------------------------------------------------------------------------- @@ -109,49 +116,52 @@ public class VariantOptimizer extends RodWalker for( final VariantContext vc : tracker.getAllVariantContexts(null, context.getLocation(), false, false) ) { - if( vc != null && vc.isSNP() && (IGNORE_INPUT_FILTERS || !vc.isFiltered()) ) { - if( firstVariant ) { // This is the first variant encountered so set up the list of annotations - annotationKeys.addAll( vc.getAttributes().keySet() ); - if( annotationKeys.contains("ID") ) { annotationKeys.remove("ID"); } // ID field is added to the vc's INFO field? - if( annotationKeys.contains("DB") ) { annotationKeys.remove("DB"); } - if( EXCLUDED_ANNOTATIONS != null ) { - for( final String excludedAnnotation : EXCLUDED_ANNOTATIONS ) { - if( annotationKeys.contains( excludedAnnotation ) ) { annotationKeys.remove( excludedAnnotation ); } + if( vc != null && vc.isSNP() ) { + if( !vc.isFiltered() || IGNORE_ALL_INPUT_FILTERS || (ignoreInputFilterSet != null && ignoreInputFilterSet.containsAll(vc.getFilters())) ) { + if( firstVariant ) { // This is the first variant encountered so set up the list of annotations + annotationKeys.addAll( vc.getAttributes().keySet() ); + if( annotationKeys.contains("ID") ) { annotationKeys.remove("ID"); } // ID field is added to the vc's INFO field? + if( annotationKeys.contains("DB") ) { annotationKeys.remove("DB"); } + if( EXCLUDED_ANNOTATIONS != null ) { + for( final String excludedAnnotation : EXCLUDED_ANNOTATIONS ) { + if( annotationKeys.contains( excludedAnnotation ) ) { annotationKeys.remove( excludedAnnotation ); } + } } - } - if( FORCED_ANNOTATIONS != null ) { - for( final String forcedAnnotation : FORCED_ANNOTATIONS ) { - if( !annotationKeys.contains( forcedAnnotation ) ) { annotationKeys.add( forcedAnnotation ); } + if( FORCED_ANNOTATIONS != null ) { + for( final String forcedAnnotation : FORCED_ANNOTATIONS ) { + if( !annotationKeys.contains( forcedAnnotation ) ) { annotationKeys.add( forcedAnnotation ); } + } } + numAnnotations = annotationKeys.size() + 1; // +1 for variant quality ("QUAL") + annotationValues = new double[numAnnotations]; + firstVariant = false; } - numAnnotations = annotationKeys.size() + 1; // +1 for variant quality ("QUAL") - annotationValues = new double[numAnnotations]; - firstVariant = false; + + int iii = 0; + for( final String key : annotationKeys ) { + + double value = 0.0; + try { + value = Double.parseDouble( (String)vc.getAttribute( key, "0.0" ) ); + if( Double.isInfinite(value) ) { + value = ( value > 0 ? 1.0 : -1.0 ) * INFINITE_ANNOTATION_VALUE; + } + } catch( NumberFormatException e ) { + // do nothing, default value is 0.0 + } + annotationValues[iii++] = value; + } + + // Variant quality ("QUAL") is not in the list of annotations, but is useful so add it here. + annotationValues[iii] = vc.getPhredScaledQual(); + + final VariantDatum variantDatum = new VariantDatum(); + variantDatum.annotations = annotationValues; + variantDatum.isTransition = vc.getSNPSubstitutionType().compareTo(BaseUtils.BaseSubstitutionType.TRANSITION) == 0; + variantDatum.isKnown = !vc.getAttribute("ID").equals("."); + variantDatum.isHet = vc.getHetCount() > vc.getHomVarCount(); // BUGBUG: what to do here for multi sample calls? + mapList.add( variantDatum ); } - - int iii = 0; - for( final String key : annotationKeys ) { - - double value = 0.0; - try { - value = Double.parseDouble( (String)vc.getAttribute( key, "0.0" ) ); - if( Double.isInfinite(value) ) { - value = ( value > 0 ? 1.0 : -1.0 ) * INFINITE_ANNOTATION_VALUE; - } - } catch( NumberFormatException e ) { - // do nothing, default value is 0.0 - } - annotationValues[iii++] = value; - } - - // Variant quality ("QUAL") is not in the list of annotations, but is useful so add it here. - annotationValues[iii] = vc.getPhredScaledQual(); - - VariantDatum variantDatum = new VariantDatum(); - variantDatum.annotations = annotationValues; - variantDatum.isTransition = vc.getSNPSubstitutionType().compareTo(BaseUtils.BaseSubstitutionType.TRANSITION) == 0; - variantDatum.isKnown = !vc.getAttribute("ID").equals("."); - mapList.add( variantDatum ); } }