From 9e152994753509e5c568e30ebb68ad1d46dc2895 Mon Sep 17 00:00:00 2001 From: rpoplin Date: Tue, 18 May 2010 17:37:01 +0000 Subject: [PATCH] Misc cleanup in variant recalibrator. git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@3380 348d0f76-0448-11de-a6fe-93d51630548a --- .../ApplyVariantClustersWalker.java | 2 +- .../VariantConcordanceROCCurveWalker.java | 303 --------------- .../variantoptimizer/VariantDatum.java | 2 - .../VariantGaussianMixtureModel.java | 355 +++++++----------- .../variantoptimizer/VariantOptimizer.java | 8 +- 5 files changed, 131 insertions(+), 539 deletions(-) delete mode 100755 java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantConcordanceROCCurveWalker.java diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/ApplyVariantClustersWalker.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/ApplyVariantClustersWalker.java index c561d7325..6b74c070d 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/ApplyVariantClustersWalker.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/ApplyVariantClustersWalker.java @@ -193,7 +193,7 @@ public class ApplyVariantClustersWalker extends RodWalker>, HashMap>> { - - ///////////////////////////// - // Command Line Arguments - ///////////////////////////// - @Argument(fullName="output_prefix", shortName="output", doc="The prefix added to output VCF file name and optimization curve pdf file name", required=false) - private String OUTPUT_PREFIX = "optimizer"; - @Argument(fullName = "path_to_Rscript", shortName = "Rscript", doc = "The path to your implementation of Rscript. For Broad users this is probably /broad/tools/apps/R-2.6.0/bin/Rscript", required = false) - private String PATH_TO_RSCRIPT = "/broad/tools/apps/R-2.6.0/bin/Rscript"; - @Argument(fullName = "path_to_resources", shortName = "resources", doc = "Path to resources folder holding the Sting R scripts.", required = false) - private String PATH_TO_RESOURCES = "R/"; - - ///////////////////////////// - // Private Member Variables - ///////////////////////////// - private final ExpandingArrayList inputRodNames = new ExpandingArrayList(); - private int numCurves; - private int[] trueNegGlobal; - private int[] falseNegGlobal; - private String sampleName = null; - private boolean multiSampleCalls = false; - - //--------------------------------------------------------------------------------------------------------------- - // - // initialize - // - //--------------------------------------------------------------------------------------------------------------- - - public void initialize() { - if( !PATH_TO_RESOURCES.endsWith("/") ) { PATH_TO_RESOURCES = PATH_TO_RESOURCES + "/"; } - - for( ReferenceOrderedDataSource rod : this.getToolkit().getRodDataSources() ) { - if( rod != null && !rod.getName().toUpperCase().startsWith("TRUTH") ) { - if( rod.getReferenceOrderedData().getIterator().hasNext() && rod.getReferenceOrderedData().getIterator().next().getUnderlyingObject() instanceof VCFRecord ) { - inputRodNames.add(rod.getName()); - System.out.println("Adding " + rod.getName() + " to input RodVCF list."); - if( sampleName == null && !multiSampleCalls ) { - final String[] samples = ((VCFRecord)rod.getReferenceOrderedData().getIterator().next().getUnderlyingObject()).getSampleNames(); - - if( samples.length > 1 ) { - multiSampleCalls = true; - System.out.println("Found multi sample calls."); - } else { - sampleName = samples[0]; - System.out.println("Found single sample calls."); - } - } - } - } - } - numCurves = inputRodNames.size(); - trueNegGlobal = new int[numCurves]; - falseNegGlobal = new int[numCurves]; - for( int kkk = 0; kkk < numCurves; kkk++ ) { - trueNegGlobal[kkk] = 0; - falseNegGlobal[kkk] = 0; - } - - - } - - //--------------------------------------------------------------------------------------------------------------- - // - // map - // - //--------------------------------------------------------------------------------------------------------------- - - public ExpandingArrayList> map( RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context ) { - - final ExpandingArrayList> mapList = new ExpandingArrayList>(); - - if( tracker == null ) { // For some reason RodWalkers get map calls with null trackers - return mapList; - } - - boolean isInTruthSet = false; - boolean isTrueVariant = false; - - for( final VariantContext vc : tracker.getAllVariantContexts(ref, null, context.getLocation(), false, false) ) { - if( vc != null && vc.getName().toUpperCase().startsWith("TRUTH") ) { - if( vc.isSNP() && !vc.isFiltered() ) { - if( multiSampleCalls ) { - isInTruthSet = true; - if( vc.isPolymorphic() ) { - isTrueVariant = true; - } - } else { - if( !vc.getGenotype(sampleName).isNoCall() ) { - isInTruthSet = true; - - if( !vc.getGenotype(sampleName).isHomRef() ) { - isTrueVariant = true; - } - } - } - } - } - } - - if( !isInTruthSet ) { - return mapList; // jump out if this location isn't in the truth set - } - - int curveIndex = 0; - for( final String inputName : inputRodNames ) { - final VariantContext vc = tracker.getVariantContext( inputName, null, context.getLocation(), false ); // assuming single variant per track per location - - if( vc != null && vc.isPolymorphic() && !vc.isFiltered() ) { - final VariantDatum variantDatum = new VariantDatum(); - variantDatum.qual = vc.getPhredScaledQual(); - variantDatum.isTrueVariant = isTrueVariant; - mapList.add( new Pair(inputName, variantDatum) ); - } - else if ( vc != null && vc.isFiltered() && vc.getFilters().contains("HARD_TO_VALIDATE") ) { // hard to validate sites are hard to validate so don't count them - - } - else { // Either not in the call set at all, is a monomorphic call, or was filtered out - if( isTrueVariant ) { // ... but it is a true variant so this is a false negative call - falseNegGlobal[curveIndex]++; - } else { // ... and it is not a variant site so this is a true negative call - trueNegGlobal[curveIndex]++; - } - } - curveIndex++; - } - - return mapList; - } - - //--------------------------------------------------------------------------------------------------------------- - // - // reduce - // - //--------------------------------------------------------------------------------------------------------------- - - public HashMap> reduceInit() { - final HashMap> init = new HashMap>(); - for( final String inputName : inputRodNames ) { - init.put( inputName, new ExpandingArrayList() ); - } - return init; - } - - public HashMap> reduce( final ExpandingArrayList> mapValue, final HashMap> reduceSum ) { - for( Pair value : mapValue ) { - final ExpandingArrayList list = reduceSum.get(value.getFirst()); - list.add(value.getSecond()); - reduceSum.put(value.getFirst(),list); - } - return reduceSum; - } - - public void onTraversalDone( HashMap> reduceSum ) { - - final int NUM_CURVES = numCurves; - final HashMap dataManagerMap = new HashMap(); - for( final String inputName : inputRodNames ) { - System.out.println("Creating data manager for: " + inputName); - dataManagerMap.put(inputName, new VariantDataManager( reduceSum.get(inputName), null )); - } - reduceSum.clear(); // Don't need this ever again, clean up some memory - - final double[] minQual = new double[NUM_CURVES]; - final double[] maxQual = new double[NUM_CURVES]; - final double[] incrementQual = new double[NUM_CURVES]; - final double[] qualCut = new double[NUM_CURVES]; - - final int NUM_STEPS = 200; - - int curveIndex = 0; - for( final String inputName : inputRodNames ) { - final VariantDataManager dataManager = dataManagerMap.get(inputName); - minQual[curveIndex] = dataManager.data[0].qual; - maxQual[curveIndex] = dataManager.data[0].qual; - for( int iii = 1; iii < dataManager.data.length; iii++ ) { - final double qual = dataManager.data[iii].qual; - if( qual < minQual[curveIndex] ) { minQual[curveIndex] = qual; } - else if( qual > maxQual[curveIndex] ) { maxQual[curveIndex] = qual; } - } - incrementQual[curveIndex] = (maxQual[curveIndex] - minQual[curveIndex]) / ((double)NUM_STEPS); - qualCut[curveIndex] = minQual[curveIndex]; - curveIndex++; - } - - final int[] truePos = new int[NUM_CURVES]; - final int[] falsePos = new int[NUM_CURVES]; - final int[] trueNeg = new int[NUM_CURVES]; - final int[] falseNeg = new int[NUM_CURVES]; - - PrintStream outputFile; - try { - outputFile = new PrintStream( OUTPUT_PREFIX + ".dat" ); - } catch (Exception e) { - throw new StingException( "Unable to create output file: " + OUTPUT_PREFIX + ".dat" ); - } - int jjj = 1; - for( final String inputName : inputRodNames ) { - outputFile.print(inputName + ",sensitivity" + jjj + ",specificity" + jjj + ","); - jjj++; - } - outputFile.println("sentinel"); - - for( int step = 0; step < NUM_STEPS; step++ ) { - curveIndex = 0; - for( final String inputName : inputRodNames ) { - final VariantDataManager dataManager = dataManagerMap.get(inputName); - truePos[curveIndex] = 0; - falsePos[curveIndex] = 0; - trueNeg[curveIndex] = 0; - falseNeg[curveIndex] = 0; - final int NUM_VARIANTS = dataManager.data.length; - for( int iii = 0; iii < NUM_VARIANTS; iii++ ) { - if( dataManager.data[iii].qual >= qualCut[curveIndex] ) { // this var is in this hypothetical call set - if( dataManager.data[iii].isTrueVariant ) { - truePos[curveIndex]++; - } else { - falsePos[curveIndex]++; - } - } else { // this var is out of this hypothetical call set - if( dataManager.data[iii].isTrueVariant ) { - falseNeg[curveIndex]++; - } else { - trueNeg[curveIndex]++; - } - } - } - final double sensitivity = ((double) truePos[curveIndex]) / ((double) truePos[curveIndex] + falseNegGlobal[curveIndex] + falseNeg[curveIndex]); - final double specificity = ((double) trueNegGlobal[curveIndex] + trueNeg[curveIndex]) / - ((double) falsePos[curveIndex] + trueNegGlobal[curveIndex] + trueNeg[curveIndex]); - outputFile.print( String.format("%.8f,%.8f,%.8f,", qualCut[curveIndex], sensitivity, 1.0 - specificity) ); - qualCut[curveIndex] += incrementQual[curveIndex]; - curveIndex++; - } - outputFile.println("-1"); - } - - outputFile.close(); - - // Execute Rscript command to plot the optimization curve - // Print out the command line to make it clear to the user what is being executed and how one might modify it - final String rScriptCommandLine = PATH_TO_RSCRIPT + " " + PATH_TO_RESOURCES + "plot_variantROCCurve.R" + " " + OUTPUT_PREFIX + ".dat"; - System.out.println( rScriptCommandLine ); - - // Execute the RScript command to plot the table of truth values - try { - Runtime.getRuntime().exec( rScriptCommandLine ); - } catch ( IOException e ) { - throw new StingException( "Unable to execute RScript command: " + rScriptCommandLine ); - } - } -} diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDatum.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDatum.java index 663304d31..647ddea52 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDatum.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantDatum.java @@ -35,7 +35,5 @@ public class VariantDatum { public double[] annotations; public boolean isTransition; public boolean isKnown; - public boolean isTrueVariant; - public boolean isHet; public double qual; } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantGaussianMixtureModel.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantGaussianMixtureModel.java index e3106a179..ad01d3fc2 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantGaussianMixtureModel.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantGaussianMixtureModel.java @@ -61,7 +61,6 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel private final double[][] mu; // The means for the clusters private final Matrix[] sigma; // The variances for the clusters, sigma is really sigma^2 private final Matrix[] sigmaInverse; - private final boolean[] deadCluster; private final double[] pCluster; private final double[] determinant; private final double[] clusterTITV; @@ -82,7 +81,6 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel mu = new double[numGaussians][]; sigma = new Matrix[numGaussians]; determinant = new double[numGaussians]; - deadCluster = new boolean[numGaussians]; pCluster = new double[numGaussians]; clusterTITV = new double[numGaussians]; clusterTruePositiveRate = new double[numGaussians]; @@ -114,7 +112,6 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel // Several of the clustering parameters aren't used the second time around in ApplyVariantClusters numIterations = 0; clusterTITV = null; - deadCluster = null; minVarInCluster = 0; // BUGBUG: move this parsing out of the constructor @@ -144,7 +141,6 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel sigma[kkk] = new Matrix(sigmaVals[kkk]); sigmaInverse[kkk] = sigma[kkk].inverse(); // Precompute all the inverses and determinants for use later determinant[kkk] = sigma[kkk].det(); - //if( determinant[kkk] < MIN_DETERMINANT ) { determinant[kkk] = MIN_DETERMINANT; } kkk++; } isUsingTiTvModel = _isUsingTiTvModel; @@ -252,12 +248,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel // Set up the initial random Gaussians for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - deadCluster[kkk] = false; pCluster[kkk] = 1.0 / ((double) (stopCluster - startCluster)); - //final double[] randMu = new double[numAnnotations]; - //for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - // randMu[jjj] = -1.5 + 3.0 * rand.nextDouble(); - //} mu[kkk] = data[rand.nextInt(numVariants)].annotations; final double[][] randSigma = new double[numAnnotations][numAnnotations]; for( int jjj = 0; jjj < numAnnotations; jjj++ ) { @@ -270,18 +261,11 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel tmp = tmp.times(tmp.transpose()); sigma[kkk] = tmp; determinant[kkk] = sigma[kkk].det(); - //if( determinant[kkk] < MIN_DETERMINANT ) { deadCluster[kkk] = true; } } // The EM loop for( int ttt = 0; ttt < numIterations; ttt++ ) { - //int numValidClusters = 0; - //for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - // if( !deadCluster[kkk] ) { numValidClusters++; } - //} - //logger.info("Starting iteration " + (ttt+1) + " with " + numValidClusters + " clusters."); - ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// // Expectation Step (calculate the probability that each data point is in each cluster) ///////////////////////////////////////////////////////////////////////////////////////////////////////////////// @@ -309,7 +293,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel clusterTITV[kkk] = 0.0; clusterTruePositiveRate[kkk] = 1.0; } - printClusterParamters( clusterFileName + ".WithoutTiTv." + iterationNumber ); + printClusterParameters( clusterFileName + ".WithoutTiTv." + iterationNumber ); return; } @@ -340,29 +324,26 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel final boolean isTransition = data[iii].isTransition; final boolean isKnown = data[iii].isKnown; for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - if( !deadCluster[kkk] ) { - final double prob = pVarInCluster[kkk][iii]; - if( isKnown ) { // known - probKnown[kkk] += prob; - if( isTransition ) { // transition - probKnownTi[kkk] += prob; - probTi[kkk] += prob; - } else { // transversion - probKnownTv[kkk] += prob; - probTv[kkk] += prob; - } - } else { //novel - probNovel[kkk] += prob; - if( isTransition ) { // transition - probNovelTi[kkk] += prob; - probTi[kkk] += prob; - } else { // transversion - probNovelTv[kkk] += prob; - probTv[kkk] += prob; - } + final double prob = pVarInCluster[kkk][iii]; + if( isKnown ) { // known + probKnown[kkk] += prob; + if( isTransition ) { // transition + probKnownTi[kkk] += prob; + probTi[kkk] += prob; + } else { // transversion + probKnownTv[kkk] += prob; + probTv[kkk] += prob; + } + } else { //novel + probNovel[kkk] += prob; + if( isTransition ) { // transition + probNovelTi[kkk] += prob; + probTi[kkk] += prob; + } else { // transversion + probNovelTv[kkk] += prob; + probTv[kkk] += prob; } } - } } @@ -376,50 +357,46 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel knownAlphaFactor = 0.5; } for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - if( !deadCluster[kkk] ) { - clusterTITV[kkk] = probTi[kkk] / probTv[kkk]; - if( probKnown[kkk] > 500.0 && probNovel[kkk] > 500.0 ) { - clusterTruePositiveRate[kkk] = calcTruePositiveRateFromKnownTITV( probKnownTi[kkk] / probKnownTv[kkk], probNovelTi[kkk] / probNovelTv[kkk], clusterTITV[kkk], knownAlphaFactor ); - } else { - clusterTruePositiveRate[kkk] = calcTruePositiveRateFromTITV( clusterTITV[kkk] ); - } + clusterTITV[kkk] = probTi[kkk] / probTv[kkk]; + if( probKnown[kkk] > 500.0 && probNovel[kkk] > 500.0 ) { + clusterTruePositiveRate[kkk] = calcTruePositiveRateFromKnownTITV( probKnownTi[kkk] / probKnownTv[kkk], probNovelTi[kkk] / probNovelTv[kkk], clusterTITV[kkk], knownAlphaFactor ); + } else { + clusterTruePositiveRate[kkk] = calcTruePositiveRateFromTITV( clusterTITV[kkk] ); } } if( ttt == 0 ) { - printClusterParamters( clusterFileName + ".TargetTiTv." + iterationNumber ); + printClusterParameters( clusterFileName + ".TargetTiTv." + iterationNumber ); } else if( ttt == 1 ) { - printClusterParamters( clusterFileName + ".KnownTiTv." + iterationNumber ); + printClusterParameters( clusterFileName + ".KnownTiTv." + iterationNumber ); } else if( ttt == 2 ) { - printClusterParamters( clusterFileName + ".BlendedTiTv." + iterationNumber ); + printClusterParameters( clusterFileName + ".BlendedTiTv." + iterationNumber ); } } } - private void printClusterParamters( final String clusterFileName ) { + private void printClusterParameters( final String clusterFileName ) { try { final PrintStream outputFile = new PrintStream( clusterFileName ); dataManager.printClusterFileHeader( outputFile ); final int numAnnotations = mu[0].length; final int numVariants = dataManager.numVariants; for( int kkk = 0; kkk < numGaussians; kkk++ ) { - if( !deadCluster[kkk] ) { - if( pCluster[kkk] * numVariants > minVarInCluster ) { - final double sigmaVals[][] = sigma[kkk].getArray(); - outputFile.print("@!CLUSTER,"); - outputFile.print(pCluster[kkk] + ","); - outputFile.print(clusterTITV[kkk] + ","); - outputFile.print(clusterTruePositiveRate[kkk] + ","); - for(int jjj = 0; jjj < numAnnotations; jjj++ ) { - outputFile.print(mu[kkk][jjj] + ","); - } - for(int jjj = 0; jjj < numAnnotations; jjj++ ) { - for(int ppp = 0; ppp < numAnnotations; ppp++ ) { - outputFile.print(sigmaVals[jjj][ppp] + ","); - } - } - outputFile.println(-1); + if( pCluster[kkk] * numVariants > minVarInCluster ) { + final double sigmaVals[][] = sigma[kkk].getArray(); + outputFile.print("@!CLUSTER,"); + outputFile.print(pCluster[kkk] + ","); + outputFile.print(clusterTITV[kkk] + ","); + outputFile.print(clusterTruePositiveRate[kkk] + ","); + for(int jjj = 0; jjj < numAnnotations; jjj++ ) { + outputFile.print(mu[kkk][jjj] + ","); } + for(int jjj = 0; jjj < numAnnotations; jjj++ ) { + for(int ppp = 0; ppp < numAnnotations; ppp++ ) { + outputFile.print(sigmaVals[jjj][ppp] + ","); + } + } + outputFile.println(-1); } } outputFile.close(); @@ -458,52 +435,11 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel evaluateGaussiansForSingleVariant( annotations, pVarInCluster ); - //if( isUsingTiTvModel ) { - // Sum prob model - double sum = 0.0; - for( int kkk = 0; kkk < numGaussians; kkk++ ) { - sum += pVarInCluster[kkk] * clusterTruePositiveRate[kkk]; - } - return sum; - /* - } else { - // Max prob model - double maxProb = 0.0; - for( int kkk = 0; kkk < numGaussians; kkk++ ) { - if( pVarInCluster[kkk] > maxProb ) { - maxProb = pVarInCluster[kkk]; - } - } - return maxProb; - } - */ - - // Max prob model - /* - double maxProb = 0.0; - for( int kkk = 0; kkk < numGaussians; kkk++ ) { - if( pVarInCluster[kkk] > maxProb ) { - maxProb = pVarInCluster[kkk]; - } - } - return maxProb; - */ - - // Entropy model - /* double sum = 0.0; for( int kkk = 0; kkk < numGaussians; kkk++ ) { - //if( isHetCluster[kkk] == isHet ) { - sum += pVarInCluster[kkk] * Math.log(pVarInCluster[kkk]); - //} + sum += pVarInCluster[kkk] * clusterTruePositiveRate[kkk]; } - - double entropy = -1.0 * sum; - double maxEntropy = -1.0 * Math.log( 1.0 / ((double) numGaussians)); - - //System.out.println("H = " + entropy + ", pTrue = " + ( 1.0 - (entropy / maxEntropy) )); - return ( 1.0 - (entropy / maxEntropy) ); - */ + return sum; } public final void outputClusterReports( final String outputPrefix ) { @@ -688,60 +624,56 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel final double sigmaVals[][][] = new double[numGaussians][][]; final double denom[] = new double[numGaussians]; for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - if( !deadCluster[kkk] ) { - sigmaVals[kkk] = sigma[kkk].inverse().getArray(); - denom[kkk] = Math.pow(2.0 * 3.14159, ((double)numAnnotations) / 2.0) * Math.pow(Math.abs(determinant[kkk]), 0.5); - } + sigmaVals[kkk] = sigma[kkk].inverse().getArray(); + denom[kkk] = Math.pow(2.0 * 3.14159, ((double)numAnnotations) / 2.0) * Math.pow(Math.abs(determinant[kkk]), 0.5); } final double mult[] = new double[numAnnotations]; for( int iii = 0; iii < data.length; iii++ ) { double sumProb = 0.0; for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - if( !deadCluster[kkk] ) { - double sum = 0.0; - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - mult[jjj] = 0.0; - for( int ppp = 0; ppp < numAnnotations; ppp++ ) { - mult[jjj] += (data[iii].annotations[ppp] - mu[kkk][ppp]) * sigmaVals[kkk][ppp][jjj]; - } + double sum = 0.0; + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + mult[jjj] = 0.0; + for( int ppp = 0; ppp < numAnnotations; ppp++ ) { + mult[jjj] += (data[iii].annotations[ppp] - mu[kkk][ppp]) * sigmaVals[kkk][ppp][jjj]; } - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - sum += mult[jjj] * (data[iii].annotations[jjj] - mu[kkk][jjj]); - } - - pVarInCluster[kkk][iii] = pCluster[kkk] * (Math.exp( -0.5 * sum ) / denom[kkk]); - likelihood += pVarInCluster[kkk][iii]; - if(Double.isNaN(denom[kkk]) || determinant[kkk] < 0.5 * MIN_DETERMINANT) { - System.out.println("det = " + sigma[kkk].det()); - System.out.println("denom = " + denom[kkk]); - System.out.println("sumExp = " + sum); - System.out.println("pVar = " + pVarInCluster[kkk][iii]); - System.out.println("=-------="); - throw new StingException("Numerical Instability! determinant of covariance matrix <= 0. Try running with fewer clusters and then with better behaved annotation values."); - } - if(sum < 0.0) { - System.out.println("det = " + sigma[kkk].det()); - System.out.println("denom = " + denom[kkk]); - System.out.println("sumExp = " + sum); - System.out.println("pVar = " + pVarInCluster[kkk][iii]); - System.out.println("=-------="); - throw new StingException("Numerical Instability! covariance matrix no longer positive definite. Try running with fewer clusters and then with better behaved annotation values."); - } - if(pVarInCluster[kkk][iii] > 1.0) { - System.out.println("det = " + sigma[kkk].det()); - System.out.println("denom = " + denom[kkk]); - System.out.println("sumExp = " + sum); - System.out.println("pVar = " + pVarInCluster[kkk][iii]); - System.out.println("=-------="); - throw new StingException("Numerical Instability! probability distribution returns > 1.0. Try running with fewer clusters and then with better behaved annotation values."); - } - - if( pVarInCluster[kkk][iii] < MIN_PROB) { // Very small numbers are a very big problem - pVarInCluster[kkk][iii] = MIN_PROB;// + MIN_PROB * rand.nextDouble(); - } - - sumProb += pVarInCluster[kkk][iii]; } + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + sum += mult[jjj] * (data[iii].annotations[jjj] - mu[kkk][jjj]); + } + + pVarInCluster[kkk][iii] = pCluster[kkk] * (Math.exp( -0.5 * sum ) / denom[kkk]); + likelihood += pVarInCluster[kkk][iii]; + if(Double.isNaN(denom[kkk]) || determinant[kkk] < 0.5 * MIN_DETERMINANT) { + System.out.println("det = " + sigma[kkk].det()); + System.out.println("denom = " + denom[kkk]); + System.out.println("sumExp = " + sum); + System.out.println("pVar = " + pVarInCluster[kkk][iii]); + System.out.println("=-------="); + throw new StingException("Numerical Instability! determinant of covariance matrix <= 0. Try running with fewer clusters and then with better behaved annotation values."); + } + if(sum < 0.0) { + System.out.println("det = " + sigma[kkk].det()); + System.out.println("denom = " + denom[kkk]); + System.out.println("sumExp = " + sum); + System.out.println("pVar = " + pVarInCluster[kkk][iii]); + System.out.println("=-------="); + throw new StingException("Numerical Instability! covariance matrix no longer positive definite. Try running with fewer clusters and then with better behaved annotation values."); + } + if(pVarInCluster[kkk][iii] > 1.0) { + System.out.println("det = " + sigma[kkk].det()); + System.out.println("denom = " + denom[kkk]); + System.out.println("sumExp = " + sum); + System.out.println("pVar = " + pVarInCluster[kkk][iii]); + System.out.println("=-------="); + throw new StingException("Numerical Instability! probability distribution returns > 1.0. Try running with fewer clusters and then with better behaved annotation values."); + } + + if( pVarInCluster[kkk][iii] < MIN_PROB) { // Very small numbers are a very big problem + pVarInCluster[kkk][iii] = MIN_PROB;// + MIN_PROB * rand.nextDouble(); + } + + sumProb += pVarInCluster[kkk][iii]; } for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { @@ -757,8 +689,6 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel private void evaluateGaussiansForSingleVariant( final double[] annotations, final double[] pVarInCluster ) { final int numAnnotations = annotations.length; - - double sumProb = 0.0; final double mult[] = new double[numAnnotations]; for( int kkk = 0; kkk < numGaussians; kkk++ ) { final double sigmaVals[][] = sigmaInverse[kkk].getArray(); @@ -775,28 +705,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel final double denom = Math.pow(2.0 * 3.14159, ((double)numAnnotations) / 2.0) * Math.pow(determinant[kkk], 0.5); pVarInCluster[kkk] = pCluster[kkk] * (Math.exp( -0.5 * sum )) / denom; - - /* - if( isUsingTiTvModel ) { - //pVarInCluster[kkk] = Math.exp( -0.5 * sum ); - if( pVarInCluster[kkk] < MIN_PROB) { // Very small numbers are a very big problem - pVarInCluster[kkk] = MIN_PROB; - } - sumProb += pVarInCluster[kkk]; - } else { - //final double denom = Math.pow(2.0 * 3.14159, ((double)numAnnotations) / 2.0) * Math.pow(determinant[kkk], 0.5); - //pVarInCluster[kkk] = pCluster[kkk] * (Math.exp( -0.5 * sum )) / denom; - //pVarInCluster[kkk] = Math.exp( -0.5 * sum ); - // BUGBUG: should pCluster be the distribution from the GMM or a uniform distribution here? - } - */ } - - //if( isUsingTiTvModel ) { - // for( int kkk = 0; kkk < numGaussians; kkk++ ) { - // pVarInCluster[kkk] /= sumProb; - // } - //} } @@ -816,53 +725,48 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel } double sumPK = 0.0; for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - if( !deadCluster[kkk] ) { - double sumProb = 0.0; - for( int iii = 0; iii < numVariants; iii++ ) { - final double prob = pVarInCluster[kkk][iii]; - sumProb += prob; - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - mu[kkk][jjj] += prob * data[iii].annotations[jjj]; - } - } - + double sumProb = 0.0; + for( int iii = 0; iii < numVariants; iii++ ) { + final double prob = pVarInCluster[kkk][iii]; + sumProb += prob; for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - mu[kkk][jjj] /= sumProb; - } - - for( int iii = 0; iii < numVariants; iii++ ) { - final double prob = pVarInCluster[kkk][iii]; - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - for( int ppp = jjj; ppp < numAnnotations; ppp++ ) { - sigmaVals[kkk][jjj][ppp] += prob * (data[iii].annotations[jjj]-mu[kkk][jjj]) * (data[iii].annotations[ppp]-mu[kkk][ppp]); - } - } - } - - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - for( int ppp = jjj; ppp < numAnnotations; ppp++ ) { - if( sigmaVals[kkk][jjj][ppp] < MIN_SIGMA ) { // Very small numbers are a very big problem - sigmaVals[kkk][jjj][ppp] = MIN_SIGMA;// + MIN_SIGMA * rand.nextDouble(); - } - sigmaVals[kkk][ppp][jjj] = sigmaVals[kkk][jjj][ppp]; // sigma must be a symmetric matrix - } - } - - for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - for( int ppp = 0; ppp < numAnnotations; ppp++ ) { - sigmaVals[kkk][jjj][ppp] /= sumProb; - } - } - - sigma[kkk] = new Matrix(sigmaVals[kkk]); - determinant[kkk] = sigma[kkk].det(); - //if( determinant[kkk] < MIN_DETERMINANT ) { deadCluster[kkk] = true; } - - if( !deadCluster[kkk] ) { - pCluster[kkk] = sumProb / numVariants; - sumPK += pCluster[kkk]; + mu[kkk][jjj] += prob * data[iii].annotations[jjj]; } } + + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + mu[kkk][jjj] /= sumProb; + } + + for( int iii = 0; iii < numVariants; iii++ ) { + final double prob = pVarInCluster[kkk][iii]; + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + for( int ppp = jjj; ppp < numAnnotations; ppp++ ) { + sigmaVals[kkk][jjj][ppp] += prob * (data[iii].annotations[jjj]-mu[kkk][jjj]) * (data[iii].annotations[ppp]-mu[kkk][ppp]); + } + } + } + + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + for( int ppp = jjj; ppp < numAnnotations; ppp++ ) { + if( sigmaVals[kkk][jjj][ppp] < MIN_SIGMA ) { // Very small numbers are a very big problem + sigmaVals[kkk][jjj][ppp] = MIN_SIGMA;// + MIN_SIGMA * rand.nextDouble(); + } + sigmaVals[kkk][ppp][jjj] = sigmaVals[kkk][jjj][ppp]; // sigma must be a symmetric matrix + } + } + + for( int jjj = 0; jjj < numAnnotations; jjj++ ) { + for( int ppp = 0; ppp < numAnnotations; ppp++ ) { + sigmaVals[kkk][jjj][ppp] /= sumProb; + } + } + + sigma[kkk] = new Matrix(sigmaVals[kkk]); + determinant[kkk] = sigma[kkk].det(); + + pCluster[kkk] = sumProb / numVariants; + sumPK += pCluster[kkk]; } // ensure pCluster sums to one, it doesn't automatically due to very small numbers getting capped @@ -924,15 +828,10 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel // Replace extremely small clusters with another random draw from the dataset for( int kkk = startCluster; kkk < stopCluster; kkk++ ) { - //if(determinant[kkk] < MIN_DETERMINANT ) { if( pCluster[kkk] < 0.0005 * (1.0 / ((double) (stopCluster-startCluster))) || determinant[kkk] < MIN_DETERMINANT ) { // This is a very small cluster compared to all the others logger.info("!! Found singular cluster! Initializing a new random cluster."); - pCluster[kkk] = 0.1 / ((double) (stopCluster-startCluster)); // 0.5 / - //final double[] randMu = new double[numAnnotations]; - //for( int jjj = 0; jjj < numAnnotations; jjj++ ) { - // randMu[jjj] = -1.5 + 3.0 * rand.nextDouble(); - //} + pCluster[kkk] = 0.1 / ((double) (stopCluster-startCluster)); mu[kkk] = data[rand.nextInt(numVariants)].annotations; final double[][] randSigma = new double[numAnnotations][numAnnotations]; for( int jjj = 0; jjj < numAnnotations; jjj++ ) { diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizer.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizer.java index 83d079582..c19ae74be 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizer.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantoptimizer/VariantOptimizer.java @@ -61,14 +61,12 @@ public class VariantOptimizer extends RodWalker private boolean IGNORE_ALL_INPUT_FILTERS = false; @Argument(fullName="ignore_filter", shortName="ignoreFilter", doc="If specified the optimizer will use variants even if the specified filter name is marked in the input VCF file", required=false) private String[] IGNORE_INPUT_FILTERS = null; - @Argument(fullName="use_annotation", shortName="an", doc="The names of the annotations which should used for calculations", required=true) private String[] USE_ANNOTATIONS = null; - @Argument(fullName="clusterFile", shortName="clusterFile", doc="The output cluster file", required=true) private String CLUSTER_FILENAME = "optimizer.cluster"; @Argument(fullName="numGaussians", shortName="nG", doc="The number of Gaussians to be used in the Gaussian Mixture model", required=false) - private int NUM_GAUSSIANS = 7; + private int NUM_GAUSSIANS = 1; @Argument(fullName="numIterations", shortName="nI", doc="The number of iterations to be performed in the Gaussian Mixture model", required=false) private int NUM_ITERATIONS = 10; @Argument(fullName="minVarInCluster", shortName="minVar", doc="The minimum number of variants in a cluster to be considered a valid cluster. It can be used to prevent overfitting.", required=false) @@ -148,9 +146,9 @@ public class VariantOptimizer extends RodWalker boolean isKnown = !vc.getAttribute("ID").equals("."); if(usingDBSNP) { isKnown = false; - for( VariantContext dbsnpVC : tracker.getVariantContexts(DbSNPHelper.STANDARD_DBSNP_TRACK_NAME, null, context.getLocation(), false, false) ) { + for( final VariantContext dbsnpVC : tracker.getVariantContexts(DbSNPHelper.STANDARD_DBSNP_TRACK_NAME, null, context.getLocation(), false, false) ) { if(dbsnpVC != null && dbsnpVC.isSNP()) { - isKnown=true; + isKnown = true; } } }