Misc cleanup in variant recalibrator.

git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@3380 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
rpoplin 2010-05-18 17:37:01 +00:00
parent eb200e4cce
commit 9e15299475
5 changed files with 131 additions and 539 deletions

View File

@ -193,7 +193,7 @@ public class ApplyVariantClustersWalker extends RodWalker<ExpandingArrayList<Var
vcf.addInfoField("OQ", ((Double)vc.getPhredScaledQual()).toString() );
vcf.setQual( variantDatum.qual );
vcf.setFilterString(VCFRecord.UNFILTERED); //BUGBUG: Set to passes filters
vcf.setFilterString(VCFRecord.PASSES_FILTERS);
vcfWriter.addRecord( vcf );
} else { // not a SNP or is filtered so just dump it out to the VCF file

View File

@ -1,303 +0,0 @@
/*
* Copyright (c) 2010 The Broad Institute
*
* Permission is hereby granted, free of charge, to any person
* obtaining a copy of this software and associated documentation
* files (the "Software"), to deal in the Software without
* restriction, including without limitation the rights to use,
* copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following
* conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR
* THE USE OR OTHER DEALINGS IN THE SOFTWARE.
*/
package org.broadinstitute.sting.playground.gatk.walkers.variantoptimizer;
import org.broad.tribble.vcf.VCFRecord;
import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
import org.broadinstitute.sting.gatk.contexts.variantcontext.VariantContext;
import org.broadinstitute.sting.gatk.datasources.simpleDataSources.ReferenceOrderedDataSource;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.walkers.RodWalker;
import org.broadinstitute.sting.utils.collections.ExpandingArrayList;
import org.broadinstitute.sting.utils.collections.Pair;
import org.broadinstitute.sting.utils.StingException;
import org.broadinstitute.sting.commandline.Argument;
import java.io.IOException;
import java.io.PrintStream;
import java.util.HashMap;
/**
* Calculates variant concordance with the given truth sets and plots ROC curves for each input call set.
*
* @author rpoplin
* @since Mar 18, 2010
*
* @help.summary Calculates variant concordance with the given truth sets and plots ROC curves for each input call set.
*/
public class VariantConcordanceROCCurveWalker extends RodWalker<ExpandingArrayList<Pair<String,VariantDatum>>, HashMap<String,ExpandingArrayList<VariantDatum>>> {
/////////////////////////////
// Command Line Arguments
/////////////////////////////
@Argument(fullName="output_prefix", shortName="output", doc="The prefix added to output VCF file name and optimization curve pdf file name", required=false)
private String OUTPUT_PREFIX = "optimizer";
@Argument(fullName = "path_to_Rscript", shortName = "Rscript", doc = "The path to your implementation of Rscript. For Broad users this is probably /broad/tools/apps/R-2.6.0/bin/Rscript", required = false)
private String PATH_TO_RSCRIPT = "/broad/tools/apps/R-2.6.0/bin/Rscript";
@Argument(fullName = "path_to_resources", shortName = "resources", doc = "Path to resources folder holding the Sting R scripts.", required = false)
private String PATH_TO_RESOURCES = "R/";
/////////////////////////////
// Private Member Variables
/////////////////////////////
// Names of the input call-set ROD tracks: every track whose name does NOT start with "TRUTH"
// and whose records are VCFRecords. One ROC curve is produced per entry.
private final ExpandingArrayList<String> inputRodNames = new ExpandingArrayList<String>();
// Number of ROC curves == inputRodNames.size(); cached in initialize().
private int numCurves;
// Per-curve counts accumulated during traversal for truth-set sites that are ABSENT from
// (or filtered/monomorphic in) the call set; merged with the per-threshold counts in
// onTraversalDone() when computing sensitivity/specificity.
private int[] trueNegGlobal;
private int[] falseNegGlobal;
// Sample used for truth-set genotype lookups in single-sample mode; null until detected.
private String sampleName = null;
// True when the first input VCF record carries more than one sample; switches the truth
// comparison in map() from per-genotype checks to site-level isPolymorphic().
private boolean multiSampleCalls = false;
//---------------------------------------------------------------------------------------------------------------
//
// initialize
//
//---------------------------------------------------------------------------------------------------------------
// Scans the bound ROD data sources to discover the input call-set tracks, detects
// single- vs multi-sample mode from the first VCF record seen, and zeroes the
// global per-curve false-negative / true-negative counters.
public void initialize() {
if( !PATH_TO_RESOURCES.endsWith("/") ) { PATH_TO_RESOURCES = PATH_TO_RESOURCES + "/"; }
for( ReferenceOrderedDataSource rod : this.getToolkit().getRodDataSources() ) {
if( rod != null && !rod.getName().toUpperCase().startsWith("TRUTH") ) {
// NOTE(review): getIterator() is invoked twice on this line and once more below;
// this assumes each call returns a fresh iterator positioned at the first record —
// if the same iterator instance were returned, the second next() would read the
// SECOND record. TODO confirm against ReferenceOrderedDataSource.
if( rod.getReferenceOrderedData().getIterator().hasNext() && rod.getReferenceOrderedData().getIterator().next().getUnderlyingObject() instanceof VCFRecord ) {
inputRodNames.add(rod.getName());
System.out.println("Adding " + rod.getName() + " to input RodVCF list.");
// Mode detection runs only until the first record is classified; all tracks are
// presumed to share the same single-/multi-sample layout afterwards.
if( sampleName == null && !multiSampleCalls ) {
final String[] samples = ((VCFRecord)rod.getReferenceOrderedData().getIterator().next().getUnderlyingObject()).getSampleNames();
if( samples.length > 1 ) {
multiSampleCalls = true;
System.out.println("Found multi sample calls.");
} else {
// NOTE(review): samples[0] throws ArrayIndexOutOfBoundsException for a
// sites-only VCF with zero sample columns — TODO confirm inputs always
// carry at least one sample.
sampleName = samples[0];
System.out.println("Found single sample calls.");
}
}
}
}
}
numCurves = inputRodNames.size();
trueNegGlobal = new int[numCurves];
falseNegGlobal = new int[numCurves];
// Explicit zeroing (Java array elements already default to 0; kept for clarity).
for( int kkk = 0; kkk < numCurves; kkk++ ) {
trueNegGlobal[kkk] = 0;
falseNegGlobal[kkk] = 0;
}
}
//---------------------------------------------------------------------------------------------------------------
//
// map
//
//---------------------------------------------------------------------------------------------------------------
// For each locus: decide from the TRUTH* tracks whether the site is in the truth set and
// whether it is a true variant, then for every input call set either emit a
// (trackName, VariantDatum) pair carrying the call's QUAL, or — when the call is missing,
// monomorphic, or filtered — bump that curve's global FN/TN counter immediately.
public ExpandingArrayList<Pair<String,VariantDatum>> map( RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context ) {
final ExpandingArrayList<Pair<String,VariantDatum>> mapList = new ExpandingArrayList<Pair<String,VariantDatum>>();
if( tracker == null ) { // For some reason RodWalkers get map calls with null trackers
return mapList;
}
boolean isInTruthSet = false;
boolean isTrueVariant = false;
// A site counts as "in the truth set" only for unfiltered truth SNPs; in single-sample
// mode the chosen sample must additionally have a called genotype.
for( final VariantContext vc : tracker.getAllVariantContexts(ref, null, context.getLocation(), false, false) ) {
if( vc != null && vc.getName().toUpperCase().startsWith("TRUTH") ) {
if( vc.isSNP() && !vc.isFiltered() ) {
if( multiSampleCalls ) {
isInTruthSet = true;
if( vc.isPolymorphic() ) {
isTrueVariant = true;
}
} else {
if( !vc.getGenotype(sampleName).isNoCall() ) {
isInTruthSet = true;
if( !vc.getGenotype(sampleName).isHomRef() ) {
isTrueVariant = true;
}
}
}
}
}
}
if( !isInTruthSet ) {
return mapList; // jump out if this location isn't in the truth set
}
// curveIndex tracks position in inputRodNames so the parallel global counter arrays
// stay aligned with the per-track iteration order.
int curveIndex = 0;
for( final String inputName : inputRodNames ) {
final VariantContext vc = tracker.getVariantContext( inputName, null, context.getLocation(), false ); // assuming single variant per track per location
if( vc != null && vc.isPolymorphic() && !vc.isFiltered() ) {
// Passing polymorphic call: defer TP/FP classification to the QUAL sweep in
// onTraversalDone(); here we only record QUAL and the truth status.
final VariantDatum variantDatum = new VariantDatum();
variantDatum.qual = vc.getPhredScaledQual();
variantDatum.isTrueVariant = isTrueVariant;
mapList.add( new Pair<String,VariantDatum>(inputName, variantDatum) );
}
else if ( vc != null && vc.isFiltered() && vc.getFilters().contains("HARD_TO_VALIDATE") ) { // hard to validate sites are hard to validate so don't count them
}
else { // Either not in the call set at all, is a monomorphic call, or was filtered out
if( isTrueVariant ) { // ... but it is a true variant so this is a false negative call
falseNegGlobal[curveIndex]++;
} else { // ... and it is not a variant site so this is a true negative call
trueNegGlobal[curveIndex]++;
}
}
curveIndex++;
}
return mapList;
}
//---------------------------------------------------------------------------------------------------------------
//
// reduce
//
//---------------------------------------------------------------------------------------------------------------
// Seeds the reduction with one empty datum list per input track.
public HashMap<String,ExpandingArrayList<VariantDatum>> reduceInit() {
final HashMap<String,ExpandingArrayList<VariantDatum>> init = new HashMap<String,ExpandingArrayList<VariantDatum>>();
for( final String inputName : inputRodNames ) {
init.put( inputName, new ExpandingArrayList<VariantDatum>() );
}
return init;
}
// Appends each mapped (trackName, datum) pair onto that track's accumulated list.
public HashMap<String,ExpandingArrayList<VariantDatum>> reduce( final ExpandingArrayList<Pair<String,VariantDatum>> mapValue, final HashMap<String,ExpandingArrayList<VariantDatum>> reduceSum ) {
for( Pair<String,VariantDatum> value : mapValue ) {
final ExpandingArrayList<VariantDatum> list = reduceSum.get(value.getFirst());
list.add(value.getSecond());
// NOTE(review): re-putting is redundant — `list` is the live list already in the map.
reduceSum.put(value.getFirst(),list);
}
return reduceSum;
}
// Sweeps a QUAL threshold from each track's min to max QUAL in NUM_STEPS increments,
// computes (sensitivity, 1 - specificity) at every step — folding in the global FN/TN
// counts from traversal — writes the CSV table to <OUTPUT_PREFIX>.dat, and launches the
// plot_variantROCCurve.R script on it.
public void onTraversalDone( HashMap<String,ExpandingArrayList<VariantDatum>> reduceSum ) {
final int NUM_CURVES = numCurves;
final HashMap<String, VariantDataManager> dataManagerMap = new HashMap<String, VariantDataManager>();
for( final String inputName : inputRodNames ) {
System.out.println("Creating data manager for: " + inputName);
dataManagerMap.put(inputName, new VariantDataManager( reduceSum.get(inputName), null ));
}
reduceSum.clear(); // Don't need this ever again, clean up some memory
final double[] minQual = new double[NUM_CURVES];
final double[] maxQual = new double[NUM_CURVES];
final double[] incrementQual = new double[NUM_CURVES];
final double[] qualCut = new double[NUM_CURVES];
final int NUM_STEPS = 200;
int curveIndex = 0;
// Pass 1: find each track's QUAL range and derive the per-step threshold increment.
for( final String inputName : inputRodNames ) {
final VariantDataManager dataManager = dataManagerMap.get(inputName);
// NOTE(review): assumes every track yielded at least one datum; data[0] throws
// otherwise — TODO confirm or guard upstream.
minQual[curveIndex] = dataManager.data[0].qual;
maxQual[curveIndex] = dataManager.data[0].qual;
for( int iii = 1; iii < dataManager.data.length; iii++ ) {
final double qual = dataManager.data[iii].qual;
if( qual < minQual[curveIndex] ) { minQual[curveIndex] = qual; }
else if( qual > maxQual[curveIndex] ) { maxQual[curveIndex] = qual; }
}
incrementQual[curveIndex] = (maxQual[curveIndex] - minQual[curveIndex]) / ((double)NUM_STEPS);
qualCut[curveIndex] = minQual[curveIndex];
curveIndex++;
}
final int[] truePos = new int[NUM_CURVES];
final int[] falsePos = new int[NUM_CURVES];
final int[] trueNeg = new int[NUM_CURVES];
final int[] falseNeg = new int[NUM_CURVES];
PrintStream outputFile;
try {
outputFile = new PrintStream( OUTPUT_PREFIX + ".dat" );
} catch (Exception e) {
// NOTE(review): the cause `e` is dropped here; StingException loses the original
// stack trace — consider chaining if StingException has a (String, Throwable) ctor.
throw new StingException( "Unable to create output file: " + OUTPUT_PREFIX + ".dat" );
}
// Header row: "<track>,sensitivityN,specificityN," triplets, terminated by "sentinel"
// (the trailing comma from the last triplet makes the sentinel column necessary).
int jjj = 1;
for( final String inputName : inputRodNames ) {
outputFile.print(inputName + ",sensitivity" + jjj + ",specificity" + jjj + ",");
jjj++;
}
outputFile.println("sentinel");
// Pass 2: the threshold sweep. Each row holds, per curve:
// qualCut, sensitivity, 1 - specificity — then a "-1" sentinel column.
for( int step = 0; step < NUM_STEPS; step++ ) {
curveIndex = 0;
for( final String inputName : inputRodNames ) {
final VariantDataManager dataManager = dataManagerMap.get(inputName);
truePos[curveIndex] = 0;
falsePos[curveIndex] = 0;
trueNeg[curveIndex] = 0;
falseNeg[curveIndex] = 0;
final int NUM_VARIANTS = dataManager.data.length;
for( int iii = 0; iii < NUM_VARIANTS; iii++ ) {
if( dataManager.data[iii].qual >= qualCut[curveIndex] ) { // this var is in this hypothetical call set
if( dataManager.data[iii].isTrueVariant ) {
truePos[curveIndex]++;
} else {
falsePos[curveIndex]++;
}
} else { // this var is out of this hypothetical call set
if( dataManager.data[iii].isTrueVariant ) {
falseNeg[curveIndex]++;
} else {
trueNeg[curveIndex]++;
}
}
}
// Sensitivity/specificity combine the threshold-dependent counts with the
// global counts for sites missing from the call set entirely.
final double sensitivity = ((double) truePos[curveIndex]) / ((double) truePos[curveIndex] + falseNegGlobal[curveIndex] + falseNeg[curveIndex]);
final double specificity = ((double) trueNegGlobal[curveIndex] + trueNeg[curveIndex]) /
((double) falsePos[curveIndex] + trueNegGlobal[curveIndex] + trueNeg[curveIndex]);
outputFile.print( String.format("%.8f,%.8f,%.8f,", qualCut[curveIndex], sensitivity, 1.0 - specificity) );
qualCut[curveIndex] += incrementQual[curveIndex];
curveIndex++;
}
outputFile.println("-1");
}
outputFile.close();
// Execute Rscript command to plot the optimization curve
// Print out the command line to make it clear to the user what is being executed and how one might modify it
final String rScriptCommandLine = PATH_TO_RSCRIPT + " " + PATH_TO_RESOURCES + "plot_variantROCCurve.R" + " " + OUTPUT_PREFIX + ".dat";
System.out.println( rScriptCommandLine );
// Execute the RScript command to plot the table of truth values
// NOTE(review): fire-and-forget — the exit code is never checked and the child's
// stdout/stderr are not consumed, so a failing or blocked Rscript is silent here.
try {
Runtime.getRuntime().exec( rScriptCommandLine );
} catch ( IOException e ) {
// NOTE(review): cause `e` dropped, same caveat as the PrintStream catch above.
throw new StingException( "Unable to execute RScript command: " + rScriptCommandLine );
}
}
}

View File

@ -35,7 +35,5 @@ public class VariantDatum {
public double[] annotations;
public boolean isTransition;
public boolean isKnown;
public boolean isTrueVariant;
public boolean isHet;
public double qual;
}

View File

@ -61,7 +61,6 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
private final double[][] mu; // The means for the clusters
private final Matrix[] sigma; // The variances for the clusters, sigma is really sigma^2
private final Matrix[] sigmaInverse;
private final boolean[] deadCluster;
private final double[] pCluster;
private final double[] determinant;
private final double[] clusterTITV;
@ -82,7 +81,6 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
mu = new double[numGaussians][];
sigma = new Matrix[numGaussians];
determinant = new double[numGaussians];
deadCluster = new boolean[numGaussians];
pCluster = new double[numGaussians];
clusterTITV = new double[numGaussians];
clusterTruePositiveRate = new double[numGaussians];
@ -114,7 +112,6 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
// Several of the clustering parameters aren't used the second time around in ApplyVariantClusters
numIterations = 0;
clusterTITV = null;
deadCluster = null;
minVarInCluster = 0;
// BUGBUG: move this parsing out of the constructor
@ -144,7 +141,6 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
sigma[kkk] = new Matrix(sigmaVals[kkk]);
sigmaInverse[kkk] = sigma[kkk].inverse(); // Precompute all the inverses and determinants for use later
determinant[kkk] = sigma[kkk].det();
//if( determinant[kkk] < MIN_DETERMINANT ) { determinant[kkk] = MIN_DETERMINANT; }
kkk++;
}
isUsingTiTvModel = _isUsingTiTvModel;
@ -252,12 +248,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
// Set up the initial random Gaussians
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
deadCluster[kkk] = false;
pCluster[kkk] = 1.0 / ((double) (stopCluster - startCluster));
//final double[] randMu = new double[numAnnotations];
//for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
// randMu[jjj] = -1.5 + 3.0 * rand.nextDouble();
//}
mu[kkk] = data[rand.nextInt(numVariants)].annotations;
final double[][] randSigma = new double[numAnnotations][numAnnotations];
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
@ -270,18 +261,11 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
tmp = tmp.times(tmp.transpose());
sigma[kkk] = tmp;
determinant[kkk] = sigma[kkk].det();
//if( determinant[kkk] < MIN_DETERMINANT ) { deadCluster[kkk] = true; }
}
// The EM loop
for( int ttt = 0; ttt < numIterations; ttt++ ) {
//int numValidClusters = 0;
//for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
// if( !deadCluster[kkk] ) { numValidClusters++; }
//}
//logger.info("Starting iteration " + (ttt+1) + " with " + numValidClusters + " clusters.");
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Expectation Step (calculate the probability that each data point is in each cluster)
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
@ -309,7 +293,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
clusterTITV[kkk] = 0.0;
clusterTruePositiveRate[kkk] = 1.0;
}
printClusterParamters( clusterFileName + ".WithoutTiTv." + iterationNumber );
printClusterParameters( clusterFileName + ".WithoutTiTv." + iterationNumber );
return;
}
@ -340,29 +324,26 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
final boolean isTransition = data[iii].isTransition;
final boolean isKnown = data[iii].isKnown;
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
if( !deadCluster[kkk] ) {
final double prob = pVarInCluster[kkk][iii];
if( isKnown ) { // known
probKnown[kkk] += prob;
if( isTransition ) { // transition
probKnownTi[kkk] += prob;
probTi[kkk] += prob;
} else { // transversion
probKnownTv[kkk] += prob;
probTv[kkk] += prob;
}
} else { //novel
probNovel[kkk] += prob;
if( isTransition ) { // transition
probNovelTi[kkk] += prob;
probTi[kkk] += prob;
} else { // transversion
probNovelTv[kkk] += prob;
probTv[kkk] += prob;
}
final double prob = pVarInCluster[kkk][iii];
if( isKnown ) { // known
probKnown[kkk] += prob;
if( isTransition ) { // transition
probKnownTi[kkk] += prob;
probTi[kkk] += prob;
} else { // transversion
probKnownTv[kkk] += prob;
probTv[kkk] += prob;
}
} else { //novel
probNovel[kkk] += prob;
if( isTransition ) { // transition
probNovelTi[kkk] += prob;
probTi[kkk] += prob;
} else { // transversion
probNovelTv[kkk] += prob;
probTv[kkk] += prob;
}
}
}
}
@ -376,50 +357,46 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
knownAlphaFactor = 0.5;
}
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
if( !deadCluster[kkk] ) {
clusterTITV[kkk] = probTi[kkk] / probTv[kkk];
if( probKnown[kkk] > 500.0 && probNovel[kkk] > 500.0 ) {
clusterTruePositiveRate[kkk] = calcTruePositiveRateFromKnownTITV( probKnownTi[kkk] / probKnownTv[kkk], probNovelTi[kkk] / probNovelTv[kkk], clusterTITV[kkk], knownAlphaFactor );
} else {
clusterTruePositiveRate[kkk] = calcTruePositiveRateFromTITV( clusterTITV[kkk] );
}
clusterTITV[kkk] = probTi[kkk] / probTv[kkk];
if( probKnown[kkk] > 500.0 && probNovel[kkk] > 500.0 ) {
clusterTruePositiveRate[kkk] = calcTruePositiveRateFromKnownTITV( probKnownTi[kkk] / probKnownTv[kkk], probNovelTi[kkk] / probNovelTv[kkk], clusterTITV[kkk], knownAlphaFactor );
} else {
clusterTruePositiveRate[kkk] = calcTruePositiveRateFromTITV( clusterTITV[kkk] );
}
}
if( ttt == 0 ) {
printClusterParamters( clusterFileName + ".TargetTiTv." + iterationNumber );
printClusterParameters( clusterFileName + ".TargetTiTv." + iterationNumber );
} else if( ttt == 1 ) {
printClusterParamters( clusterFileName + ".KnownTiTv." + iterationNumber );
printClusterParameters( clusterFileName + ".KnownTiTv." + iterationNumber );
} else if( ttt == 2 ) {
printClusterParamters( clusterFileName + ".BlendedTiTv." + iterationNumber );
printClusterParameters( clusterFileName + ".BlendedTiTv." + iterationNumber );
}
}
}
private void printClusterParamters( final String clusterFileName ) {
private void printClusterParameters( final String clusterFileName ) {
try {
final PrintStream outputFile = new PrintStream( clusterFileName );
dataManager.printClusterFileHeader( outputFile );
final int numAnnotations = mu[0].length;
final int numVariants = dataManager.numVariants;
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
if( !deadCluster[kkk] ) {
if( pCluster[kkk] * numVariants > minVarInCluster ) {
final double sigmaVals[][] = sigma[kkk].getArray();
outputFile.print("@!CLUSTER,");
outputFile.print(pCluster[kkk] + ",");
outputFile.print(clusterTITV[kkk] + ",");
outputFile.print(clusterTruePositiveRate[kkk] + ",");
for(int jjj = 0; jjj < numAnnotations; jjj++ ) {
outputFile.print(mu[kkk][jjj] + ",");
}
for(int jjj = 0; jjj < numAnnotations; jjj++ ) {
for(int ppp = 0; ppp < numAnnotations; ppp++ ) {
outputFile.print(sigmaVals[jjj][ppp] + ",");
}
}
outputFile.println(-1);
if( pCluster[kkk] * numVariants > minVarInCluster ) {
final double sigmaVals[][] = sigma[kkk].getArray();
outputFile.print("@!CLUSTER,");
outputFile.print(pCluster[kkk] + ",");
outputFile.print(clusterTITV[kkk] + ",");
outputFile.print(clusterTruePositiveRate[kkk] + ",");
for(int jjj = 0; jjj < numAnnotations; jjj++ ) {
outputFile.print(mu[kkk][jjj] + ",");
}
for(int jjj = 0; jjj < numAnnotations; jjj++ ) {
for(int ppp = 0; ppp < numAnnotations; ppp++ ) {
outputFile.print(sigmaVals[jjj][ppp] + ",");
}
}
outputFile.println(-1);
}
}
outputFile.close();
@ -458,52 +435,11 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
evaluateGaussiansForSingleVariant( annotations, pVarInCluster );
//if( isUsingTiTvModel ) {
// Sum prob model
double sum = 0.0;
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
sum += pVarInCluster[kkk] * clusterTruePositiveRate[kkk];
}
return sum;
/*
} else {
// Max prob model
double maxProb = 0.0;
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
if( pVarInCluster[kkk] > maxProb ) {
maxProb = pVarInCluster[kkk];
}
}
return maxProb;
}
*/
// Max prob model
/*
double maxProb = 0.0;
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
if( pVarInCluster[kkk] > maxProb ) {
maxProb = pVarInCluster[kkk];
}
}
return maxProb;
*/
// Entropy model
/*
double sum = 0.0;
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
//if( isHetCluster[kkk] == isHet ) {
sum += pVarInCluster[kkk] * Math.log(pVarInCluster[kkk]);
//}
sum += pVarInCluster[kkk] * clusterTruePositiveRate[kkk];
}
double entropy = -1.0 * sum;
double maxEntropy = -1.0 * Math.log( 1.0 / ((double) numGaussians));
//System.out.println("H = " + entropy + ", pTrue = " + ( 1.0 - (entropy / maxEntropy) ));
return ( 1.0 - (entropy / maxEntropy) );
*/
return sum;
}
public final void outputClusterReports( final String outputPrefix ) {
@ -688,60 +624,56 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
final double sigmaVals[][][] = new double[numGaussians][][];
final double denom[] = new double[numGaussians];
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
if( !deadCluster[kkk] ) {
sigmaVals[kkk] = sigma[kkk].inverse().getArray();
denom[kkk] = Math.pow(2.0 * 3.14159, ((double)numAnnotations) / 2.0) * Math.pow(Math.abs(determinant[kkk]), 0.5);
}
sigmaVals[kkk] = sigma[kkk].inverse().getArray();
denom[kkk] = Math.pow(2.0 * 3.14159, ((double)numAnnotations) / 2.0) * Math.pow(Math.abs(determinant[kkk]), 0.5);
}
final double mult[] = new double[numAnnotations];
for( int iii = 0; iii < data.length; iii++ ) {
double sumProb = 0.0;
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
if( !deadCluster[kkk] ) {
double sum = 0.0;
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
mult[jjj] = 0.0;
for( int ppp = 0; ppp < numAnnotations; ppp++ ) {
mult[jjj] += (data[iii].annotations[ppp] - mu[kkk][ppp]) * sigmaVals[kkk][ppp][jjj];
}
double sum = 0.0;
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
mult[jjj] = 0.0;
for( int ppp = 0; ppp < numAnnotations; ppp++ ) {
mult[jjj] += (data[iii].annotations[ppp] - mu[kkk][ppp]) * sigmaVals[kkk][ppp][jjj];
}
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
sum += mult[jjj] * (data[iii].annotations[jjj] - mu[kkk][jjj]);
}
pVarInCluster[kkk][iii] = pCluster[kkk] * (Math.exp( -0.5 * sum ) / denom[kkk]);
likelihood += pVarInCluster[kkk][iii];
if(Double.isNaN(denom[kkk]) || determinant[kkk] < 0.5 * MIN_DETERMINANT) {
System.out.println("det = " + sigma[kkk].det());
System.out.println("denom = " + denom[kkk]);
System.out.println("sumExp = " + sum);
System.out.println("pVar = " + pVarInCluster[kkk][iii]);
System.out.println("=-------=");
throw new StingException("Numerical Instability! determinant of covariance matrix <= 0. Try running with fewer clusters and then with better behaved annotation values.");
}
if(sum < 0.0) {
System.out.println("det = " + sigma[kkk].det());
System.out.println("denom = " + denom[kkk]);
System.out.println("sumExp = " + sum);
System.out.println("pVar = " + pVarInCluster[kkk][iii]);
System.out.println("=-------=");
throw new StingException("Numerical Instability! covariance matrix no longer positive definite. Try running with fewer clusters and then with better behaved annotation values.");
}
if(pVarInCluster[kkk][iii] > 1.0) {
System.out.println("det = " + sigma[kkk].det());
System.out.println("denom = " + denom[kkk]);
System.out.println("sumExp = " + sum);
System.out.println("pVar = " + pVarInCluster[kkk][iii]);
System.out.println("=-------=");
throw new StingException("Numerical Instability! probability distribution returns > 1.0. Try running with fewer clusters and then with better behaved annotation values.");
}
if( pVarInCluster[kkk][iii] < MIN_PROB) { // Very small numbers are a very big problem
pVarInCluster[kkk][iii] = MIN_PROB;// + MIN_PROB * rand.nextDouble();
}
sumProb += pVarInCluster[kkk][iii];
}
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
sum += mult[jjj] * (data[iii].annotations[jjj] - mu[kkk][jjj]);
}
pVarInCluster[kkk][iii] = pCluster[kkk] * (Math.exp( -0.5 * sum ) / denom[kkk]);
likelihood += pVarInCluster[kkk][iii];
if(Double.isNaN(denom[kkk]) || determinant[kkk] < 0.5 * MIN_DETERMINANT) {
System.out.println("det = " + sigma[kkk].det());
System.out.println("denom = " + denom[kkk]);
System.out.println("sumExp = " + sum);
System.out.println("pVar = " + pVarInCluster[kkk][iii]);
System.out.println("=-------=");
throw new StingException("Numerical Instability! determinant of covariance matrix <= 0. Try running with fewer clusters and then with better behaved annotation values.");
}
if(sum < 0.0) {
System.out.println("det = " + sigma[kkk].det());
System.out.println("denom = " + denom[kkk]);
System.out.println("sumExp = " + sum);
System.out.println("pVar = " + pVarInCluster[kkk][iii]);
System.out.println("=-------=");
throw new StingException("Numerical Instability! covariance matrix no longer positive definite. Try running with fewer clusters and then with better behaved annotation values.");
}
if(pVarInCluster[kkk][iii] > 1.0) {
System.out.println("det = " + sigma[kkk].det());
System.out.println("denom = " + denom[kkk]);
System.out.println("sumExp = " + sum);
System.out.println("pVar = " + pVarInCluster[kkk][iii]);
System.out.println("=-------=");
throw new StingException("Numerical Instability! probability distribution returns > 1.0. Try running with fewer clusters and then with better behaved annotation values.");
}
if( pVarInCluster[kkk][iii] < MIN_PROB) { // Very small numbers are a very big problem
pVarInCluster[kkk][iii] = MIN_PROB;// + MIN_PROB * rand.nextDouble();
}
sumProb += pVarInCluster[kkk][iii];
}
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
@ -757,8 +689,6 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
private void evaluateGaussiansForSingleVariant( final double[] annotations, final double[] pVarInCluster ) {
final int numAnnotations = annotations.length;
double sumProb = 0.0;
final double mult[] = new double[numAnnotations];
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
final double sigmaVals[][] = sigmaInverse[kkk].getArray();
@ -775,28 +705,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
final double denom = Math.pow(2.0 * 3.14159, ((double)numAnnotations) / 2.0) * Math.pow(determinant[kkk], 0.5);
pVarInCluster[kkk] = pCluster[kkk] * (Math.exp( -0.5 * sum )) / denom;
/*
if( isUsingTiTvModel ) {
//pVarInCluster[kkk] = Math.exp( -0.5 * sum );
if( pVarInCluster[kkk] < MIN_PROB) { // Very small numbers are a very big problem
pVarInCluster[kkk] = MIN_PROB;
}
sumProb += pVarInCluster[kkk];
} else {
//final double denom = Math.pow(2.0 * 3.14159, ((double)numAnnotations) / 2.0) * Math.pow(determinant[kkk], 0.5);
//pVarInCluster[kkk] = pCluster[kkk] * (Math.exp( -0.5 * sum )) / denom;
//pVarInCluster[kkk] = Math.exp( -0.5 * sum );
// BUGBUG: should pCluster be the distribution from the GMM or a uniform distribution here?
}
*/
}
//if( isUsingTiTvModel ) {
// for( int kkk = 0; kkk < numGaussians; kkk++ ) {
// pVarInCluster[kkk] /= sumProb;
// }
//}
}
@ -816,53 +725,48 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
}
double sumPK = 0.0;
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
if( !deadCluster[kkk] ) {
double sumProb = 0.0;
for( int iii = 0; iii < numVariants; iii++ ) {
final double prob = pVarInCluster[kkk][iii];
sumProb += prob;
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
mu[kkk][jjj] += prob * data[iii].annotations[jjj];
}
}
double sumProb = 0.0;
for( int iii = 0; iii < numVariants; iii++ ) {
final double prob = pVarInCluster[kkk][iii];
sumProb += prob;
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
mu[kkk][jjj] /= sumProb;
}
for( int iii = 0; iii < numVariants; iii++ ) {
final double prob = pVarInCluster[kkk][iii];
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
for( int ppp = jjj; ppp < numAnnotations; ppp++ ) {
sigmaVals[kkk][jjj][ppp] += prob * (data[iii].annotations[jjj]-mu[kkk][jjj]) * (data[iii].annotations[ppp]-mu[kkk][ppp]);
}
}
}
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
for( int ppp = jjj; ppp < numAnnotations; ppp++ ) {
if( sigmaVals[kkk][jjj][ppp] < MIN_SIGMA ) { // Very small numbers are a very big problem
sigmaVals[kkk][jjj][ppp] = MIN_SIGMA;// + MIN_SIGMA * rand.nextDouble();
}
sigmaVals[kkk][ppp][jjj] = sigmaVals[kkk][jjj][ppp]; // sigma must be a symmetric matrix
}
}
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
for( int ppp = 0; ppp < numAnnotations; ppp++ ) {
sigmaVals[kkk][jjj][ppp] /= sumProb;
}
}
sigma[kkk] = new Matrix(sigmaVals[kkk]);
determinant[kkk] = sigma[kkk].det();
//if( determinant[kkk] < MIN_DETERMINANT ) { deadCluster[kkk] = true; }
if( !deadCluster[kkk] ) {
pCluster[kkk] = sumProb / numVariants;
sumPK += pCluster[kkk];
mu[kkk][jjj] += prob * data[iii].annotations[jjj];
}
}
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
mu[kkk][jjj] /= sumProb;
}
for( int iii = 0; iii < numVariants; iii++ ) {
final double prob = pVarInCluster[kkk][iii];
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
for( int ppp = jjj; ppp < numAnnotations; ppp++ ) {
sigmaVals[kkk][jjj][ppp] += prob * (data[iii].annotations[jjj]-mu[kkk][jjj]) * (data[iii].annotations[ppp]-mu[kkk][ppp]);
}
}
}
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
for( int ppp = jjj; ppp < numAnnotations; ppp++ ) {
if( sigmaVals[kkk][jjj][ppp] < MIN_SIGMA ) { // Very small numbers are a very big problem
sigmaVals[kkk][jjj][ppp] = MIN_SIGMA;// + MIN_SIGMA * rand.nextDouble();
}
sigmaVals[kkk][ppp][jjj] = sigmaVals[kkk][jjj][ppp]; // sigma must be a symmetric matrix
}
}
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
for( int ppp = 0; ppp < numAnnotations; ppp++ ) {
sigmaVals[kkk][jjj][ppp] /= sumProb;
}
}
sigma[kkk] = new Matrix(sigmaVals[kkk]);
determinant[kkk] = sigma[kkk].det();
pCluster[kkk] = sumProb / numVariants;
sumPK += pCluster[kkk];
}
// ensure pCluster sums to one, it doesn't automatically due to very small numbers getting capped
@ -924,15 +828,10 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
// Replace extremely small clusters with another random draw from the dataset
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
//if(determinant[kkk] < MIN_DETERMINANT ) {
if( pCluster[kkk] < 0.0005 * (1.0 / ((double) (stopCluster-startCluster))) ||
determinant[kkk] < MIN_DETERMINANT ) { // This is a very small cluster compared to all the others
logger.info("!! Found singular cluster! Initializing a new random cluster.");
pCluster[kkk] = 0.1 / ((double) (stopCluster-startCluster)); // 0.5 /
//final double[] randMu = new double[numAnnotations];
//for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
// randMu[jjj] = -1.5 + 3.0 * rand.nextDouble();
//}
pCluster[kkk] = 0.1 / ((double) (stopCluster-startCluster));
mu[kkk] = data[rand.nextInt(numVariants)].annotations;
final double[][] randSigma = new double[numAnnotations][numAnnotations];
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {

View File

@ -61,14 +61,12 @@ public class VariantOptimizer extends RodWalker<ExpandingArrayList<VariantDatum>
private boolean IGNORE_ALL_INPUT_FILTERS = false;
@Argument(fullName="ignore_filter", shortName="ignoreFilter", doc="If specified the optimizer will use variants even if the specified filter name is marked in the input VCF file", required=false)
private String[] IGNORE_INPUT_FILTERS = null;
@Argument(fullName="use_annotation", shortName="an", doc="The names of the annotations which should used for calculations", required=true)
private String[] USE_ANNOTATIONS = null;
@Argument(fullName="clusterFile", shortName="clusterFile", doc="The output cluster file", required=true)
private String CLUSTER_FILENAME = "optimizer.cluster";
@Argument(fullName="numGaussians", shortName="nG", doc="The number of Gaussians to be used in the Gaussian Mixture model", required=false)
private int NUM_GAUSSIANS = 7;
private int NUM_GAUSSIANS = 1;
@Argument(fullName="numIterations", shortName="nI", doc="The number of iterations to be performed in the Gaussian Mixture model", required=false)
private int NUM_ITERATIONS = 10;
@Argument(fullName="minVarInCluster", shortName="minVar", doc="The minimum number of variants in a cluster to be considered a valid cluster. It can be used to prevent overfitting.", required=false)
@ -148,9 +146,9 @@ public class VariantOptimizer extends RodWalker<ExpandingArrayList<VariantDatum>
boolean isKnown = !vc.getAttribute("ID").equals(".");
if(usingDBSNP) {
isKnown = false;
for( VariantContext dbsnpVC : tracker.getVariantContexts(DbSNPHelper.STANDARD_DBSNP_TRACK_NAME, null, context.getLocation(), false, false) ) {
for( final VariantContext dbsnpVC : tracker.getVariantContexts(DbSNPHelper.STANDARD_DBSNP_TRACK_NAME, null, context.getLocation(), false, false) ) {
if(dbsnpVC != null && dbsnpVC.isSNP()) {
isKnown=true;
isKnown = true;
}
}
}