Misc cleanup in variant recalibrator.
git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@3380 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
parent
eb200e4cce
commit
9e15299475
|
|
@ -193,7 +193,7 @@ public class ApplyVariantClustersWalker extends RodWalker<ExpandingArrayList<Var
|
||||||
|
|
||||||
vcf.addInfoField("OQ", ((Double)vc.getPhredScaledQual()).toString() );
|
vcf.addInfoField("OQ", ((Double)vc.getPhredScaledQual()).toString() );
|
||||||
vcf.setQual( variantDatum.qual );
|
vcf.setQual( variantDatum.qual );
|
||||||
vcf.setFilterString(VCFRecord.UNFILTERED); //BUGBUG: Set to passes filters
|
vcf.setFilterString(VCFRecord.PASSES_FILTERS);
|
||||||
vcfWriter.addRecord( vcf );
|
vcfWriter.addRecord( vcf );
|
||||||
|
|
||||||
} else { // not a SNP or is filtered so just dump it out to the VCF file
|
} else { // not a SNP or is filtered so just dump it out to the VCF file
|
||||||
|
|
|
||||||
|
|
@ -1,303 +0,0 @@
|
||||||
/*
|
|
||||||
* Copyright (c) 2010 The Broad Institute
|
|
||||||
*
|
|
||||||
* Permission is hereby granted, free of charge, to any person
|
|
||||||
* obtaining a copy of this software and associated documentation
|
|
||||||
* files (the "Software"), to deal in the Software without
|
|
||||||
* restriction, including without limitation the rights to use,
|
|
||||||
* copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
||||||
* copies of the Software, and to permit persons to whom the
|
|
||||||
* Software is furnished to do so, subject to the following
|
|
||||||
* conditions:
|
|
||||||
*
|
|
||||||
* The above copyright notice and this permission notice shall be
|
|
||||||
* included in all copies or substantial portions of the Software.
|
|
||||||
*
|
|
||||||
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
|
|
||||||
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
|
|
||||||
* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
|
|
||||||
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
|
|
||||||
* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
|
|
||||||
* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
|
|
||||||
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR
|
|
||||||
* THE USE OR OTHER DEALINGS IN THE SOFTWARE.
|
|
||||||
*/
|
|
||||||
|
|
||||||
package org.broadinstitute.sting.playground.gatk.walkers.variantoptimizer;
|
|
||||||
|
|
||||||
import org.broad.tribble.vcf.VCFRecord;
|
|
||||||
import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
|
|
||||||
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
|
|
||||||
import org.broadinstitute.sting.gatk.contexts.variantcontext.VariantContext;
|
|
||||||
import org.broadinstitute.sting.gatk.datasources.simpleDataSources.ReferenceOrderedDataSource;
|
|
||||||
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
|
|
||||||
import org.broadinstitute.sting.gatk.walkers.RodWalker;
|
|
||||||
import org.broadinstitute.sting.utils.collections.ExpandingArrayList;
|
|
||||||
import org.broadinstitute.sting.utils.collections.Pair;
|
|
||||||
import org.broadinstitute.sting.utils.StingException;
|
|
||||||
import org.broadinstitute.sting.commandline.Argument;
|
|
||||||
|
|
||||||
import java.io.IOException;
|
|
||||||
import java.io.PrintStream;
|
|
||||||
import java.util.HashMap;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Calculates variant concordance with the given truth sets and plots ROC curves for each input call set.
|
|
||||||
*
|
|
||||||
* @author rpoplin
|
|
||||||
* @since Mar 18, 2010
|
|
||||||
*
|
|
||||||
* @help.summary Calculates variant concordance with the given truth sets and plots ROC curves for each input call set.
|
|
||||||
*/
|
|
||||||
|
|
||||||
public class VariantConcordanceROCCurveWalker extends RodWalker<ExpandingArrayList<Pair<String,VariantDatum>>, HashMap<String,ExpandingArrayList<VariantDatum>>> {
|
|
||||||
|
|
||||||
/////////////////////////////
|
|
||||||
// Command Line Arguments
|
|
||||||
/////////////////////////////
|
|
||||||
@Argument(fullName="output_prefix", shortName="output", doc="The prefix added to output VCF file name and optimization curve pdf file name", required=false)
|
|
||||||
private String OUTPUT_PREFIX = "optimizer";
|
|
||||||
@Argument(fullName = "path_to_Rscript", shortName = "Rscript", doc = "The path to your implementation of Rscript. For Broad users this is probably /broad/tools/apps/R-2.6.0/bin/Rscript", required = false)
|
|
||||||
private String PATH_TO_RSCRIPT = "/broad/tools/apps/R-2.6.0/bin/Rscript";
|
|
||||||
@Argument(fullName = "path_to_resources", shortName = "resources", doc = "Path to resources folder holding the Sting R scripts.", required = false)
|
|
||||||
private String PATH_TO_RESOURCES = "R/";
|
|
||||||
|
|
||||||
/////////////////////////////
|
|
||||||
// Private Member Variables
|
|
||||||
/////////////////////////////
|
|
||||||
private final ExpandingArrayList<String> inputRodNames = new ExpandingArrayList<String>();
|
|
||||||
private int numCurves;
|
|
||||||
private int[] trueNegGlobal;
|
|
||||||
private int[] falseNegGlobal;
|
|
||||||
private String sampleName = null;
|
|
||||||
private boolean multiSampleCalls = false;
|
|
||||||
|
|
||||||
//---------------------------------------------------------------------------------------------------------------
|
|
||||||
//
|
|
||||||
// initialize
|
|
||||||
//
|
|
||||||
//---------------------------------------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
public void initialize() {
|
|
||||||
if( !PATH_TO_RESOURCES.endsWith("/") ) { PATH_TO_RESOURCES = PATH_TO_RESOURCES + "/"; }
|
|
||||||
|
|
||||||
for( ReferenceOrderedDataSource rod : this.getToolkit().getRodDataSources() ) {
|
|
||||||
if( rod != null && !rod.getName().toUpperCase().startsWith("TRUTH") ) {
|
|
||||||
if( rod.getReferenceOrderedData().getIterator().hasNext() && rod.getReferenceOrderedData().getIterator().next().getUnderlyingObject() instanceof VCFRecord ) {
|
|
||||||
inputRodNames.add(rod.getName());
|
|
||||||
System.out.println("Adding " + rod.getName() + " to input RodVCF list.");
|
|
||||||
if( sampleName == null && !multiSampleCalls ) {
|
|
||||||
final String[] samples = ((VCFRecord)rod.getReferenceOrderedData().getIterator().next().getUnderlyingObject()).getSampleNames();
|
|
||||||
|
|
||||||
if( samples.length > 1 ) {
|
|
||||||
multiSampleCalls = true;
|
|
||||||
System.out.println("Found multi sample calls.");
|
|
||||||
} else {
|
|
||||||
sampleName = samples[0];
|
|
||||||
System.out.println("Found single sample calls.");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
numCurves = inputRodNames.size();
|
|
||||||
trueNegGlobal = new int[numCurves];
|
|
||||||
falseNegGlobal = new int[numCurves];
|
|
||||||
for( int kkk = 0; kkk < numCurves; kkk++ ) {
|
|
||||||
trueNegGlobal[kkk] = 0;
|
|
||||||
falseNegGlobal[kkk] = 0;
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
//---------------------------------------------------------------------------------------------------------------
|
|
||||||
//
|
|
||||||
// map
|
|
||||||
//
|
|
||||||
//---------------------------------------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
public ExpandingArrayList<Pair<String,VariantDatum>> map( RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context ) {
|
|
||||||
|
|
||||||
final ExpandingArrayList<Pair<String,VariantDatum>> mapList = new ExpandingArrayList<Pair<String,VariantDatum>>();
|
|
||||||
|
|
||||||
if( tracker == null ) { // For some reason RodWalkers get map calls with null trackers
|
|
||||||
return mapList;
|
|
||||||
}
|
|
||||||
|
|
||||||
boolean isInTruthSet = false;
|
|
||||||
boolean isTrueVariant = false;
|
|
||||||
|
|
||||||
for( final VariantContext vc : tracker.getAllVariantContexts(ref, null, context.getLocation(), false, false) ) {
|
|
||||||
if( vc != null && vc.getName().toUpperCase().startsWith("TRUTH") ) {
|
|
||||||
if( vc.isSNP() && !vc.isFiltered() ) {
|
|
||||||
if( multiSampleCalls ) {
|
|
||||||
isInTruthSet = true;
|
|
||||||
if( vc.isPolymorphic() ) {
|
|
||||||
isTrueVariant = true;
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if( !vc.getGenotype(sampleName).isNoCall() ) {
|
|
||||||
isInTruthSet = true;
|
|
||||||
|
|
||||||
if( !vc.getGenotype(sampleName).isHomRef() ) {
|
|
||||||
isTrueVariant = true;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if( !isInTruthSet ) {
|
|
||||||
return mapList; // jump out if this location isn't in the truth set
|
|
||||||
}
|
|
||||||
|
|
||||||
int curveIndex = 0;
|
|
||||||
for( final String inputName : inputRodNames ) {
|
|
||||||
final VariantContext vc = tracker.getVariantContext( inputName, null, context.getLocation(), false ); // assuming single variant per track per location
|
|
||||||
|
|
||||||
if( vc != null && vc.isPolymorphic() && !vc.isFiltered() ) {
|
|
||||||
final VariantDatum variantDatum = new VariantDatum();
|
|
||||||
variantDatum.qual = vc.getPhredScaledQual();
|
|
||||||
variantDatum.isTrueVariant = isTrueVariant;
|
|
||||||
mapList.add( new Pair<String,VariantDatum>(inputName, variantDatum) );
|
|
||||||
}
|
|
||||||
else if ( vc != null && vc.isFiltered() && vc.getFilters().contains("HARD_TO_VALIDATE") ) { // hard to validate sites are hard to validate so don't count them
|
|
||||||
|
|
||||||
}
|
|
||||||
else { // Either not in the call set at all, is a monomorphic call, or was filtered out
|
|
||||||
if( isTrueVariant ) { // ... but it is a true variant so this is a false negative call
|
|
||||||
falseNegGlobal[curveIndex]++;
|
|
||||||
} else { // ... and it is not a variant site so this is a true negative call
|
|
||||||
trueNegGlobal[curveIndex]++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
curveIndex++;
|
|
||||||
}
|
|
||||||
|
|
||||||
return mapList;
|
|
||||||
}
|
|
||||||
|
|
||||||
//---------------------------------------------------------------------------------------------------------------
|
|
||||||
//
|
|
||||||
// reduce
|
|
||||||
//
|
|
||||||
//---------------------------------------------------------------------------------------------------------------
|
|
||||||
|
|
||||||
public HashMap<String,ExpandingArrayList<VariantDatum>> reduceInit() {
|
|
||||||
final HashMap<String,ExpandingArrayList<VariantDatum>> init = new HashMap<String,ExpandingArrayList<VariantDatum>>();
|
|
||||||
for( final String inputName : inputRodNames ) {
|
|
||||||
init.put( inputName, new ExpandingArrayList<VariantDatum>() );
|
|
||||||
}
|
|
||||||
return init;
|
|
||||||
}
|
|
||||||
|
|
||||||
public HashMap<String,ExpandingArrayList<VariantDatum>> reduce( final ExpandingArrayList<Pair<String,VariantDatum>> mapValue, final HashMap<String,ExpandingArrayList<VariantDatum>> reduceSum ) {
|
|
||||||
for( Pair<String,VariantDatum> value : mapValue ) {
|
|
||||||
final ExpandingArrayList<VariantDatum> list = reduceSum.get(value.getFirst());
|
|
||||||
list.add(value.getSecond());
|
|
||||||
reduceSum.put(value.getFirst(),list);
|
|
||||||
}
|
|
||||||
return reduceSum;
|
|
||||||
}
|
|
||||||
|
|
||||||
public void onTraversalDone( HashMap<String,ExpandingArrayList<VariantDatum>> reduceSum ) {
|
|
||||||
|
|
||||||
final int NUM_CURVES = numCurves;
|
|
||||||
final HashMap<String, VariantDataManager> dataManagerMap = new HashMap<String, VariantDataManager>();
|
|
||||||
for( final String inputName : inputRodNames ) {
|
|
||||||
System.out.println("Creating data manager for: " + inputName);
|
|
||||||
dataManagerMap.put(inputName, new VariantDataManager( reduceSum.get(inputName), null ));
|
|
||||||
}
|
|
||||||
reduceSum.clear(); // Don't need this ever again, clean up some memory
|
|
||||||
|
|
||||||
final double[] minQual = new double[NUM_CURVES];
|
|
||||||
final double[] maxQual = new double[NUM_CURVES];
|
|
||||||
final double[] incrementQual = new double[NUM_CURVES];
|
|
||||||
final double[] qualCut = new double[NUM_CURVES];
|
|
||||||
|
|
||||||
final int NUM_STEPS = 200;
|
|
||||||
|
|
||||||
int curveIndex = 0;
|
|
||||||
for( final String inputName : inputRodNames ) {
|
|
||||||
final VariantDataManager dataManager = dataManagerMap.get(inputName);
|
|
||||||
minQual[curveIndex] = dataManager.data[0].qual;
|
|
||||||
maxQual[curveIndex] = dataManager.data[0].qual;
|
|
||||||
for( int iii = 1; iii < dataManager.data.length; iii++ ) {
|
|
||||||
final double qual = dataManager.data[iii].qual;
|
|
||||||
if( qual < minQual[curveIndex] ) { minQual[curveIndex] = qual; }
|
|
||||||
else if( qual > maxQual[curveIndex] ) { maxQual[curveIndex] = qual; }
|
|
||||||
}
|
|
||||||
incrementQual[curveIndex] = (maxQual[curveIndex] - minQual[curveIndex]) / ((double)NUM_STEPS);
|
|
||||||
qualCut[curveIndex] = minQual[curveIndex];
|
|
||||||
curveIndex++;
|
|
||||||
}
|
|
||||||
|
|
||||||
final int[] truePos = new int[NUM_CURVES];
|
|
||||||
final int[] falsePos = new int[NUM_CURVES];
|
|
||||||
final int[] trueNeg = new int[NUM_CURVES];
|
|
||||||
final int[] falseNeg = new int[NUM_CURVES];
|
|
||||||
|
|
||||||
PrintStream outputFile;
|
|
||||||
try {
|
|
||||||
outputFile = new PrintStream( OUTPUT_PREFIX + ".dat" );
|
|
||||||
} catch (Exception e) {
|
|
||||||
throw new StingException( "Unable to create output file: " + OUTPUT_PREFIX + ".dat" );
|
|
||||||
}
|
|
||||||
int jjj = 1;
|
|
||||||
for( final String inputName : inputRodNames ) {
|
|
||||||
outputFile.print(inputName + ",sensitivity" + jjj + ",specificity" + jjj + ",");
|
|
||||||
jjj++;
|
|
||||||
}
|
|
||||||
outputFile.println("sentinel");
|
|
||||||
|
|
||||||
for( int step = 0; step < NUM_STEPS; step++ ) {
|
|
||||||
curveIndex = 0;
|
|
||||||
for( final String inputName : inputRodNames ) {
|
|
||||||
final VariantDataManager dataManager = dataManagerMap.get(inputName);
|
|
||||||
truePos[curveIndex] = 0;
|
|
||||||
falsePos[curveIndex] = 0;
|
|
||||||
trueNeg[curveIndex] = 0;
|
|
||||||
falseNeg[curveIndex] = 0;
|
|
||||||
final int NUM_VARIANTS = dataManager.data.length;
|
|
||||||
for( int iii = 0; iii < NUM_VARIANTS; iii++ ) {
|
|
||||||
if( dataManager.data[iii].qual >= qualCut[curveIndex] ) { // this var is in this hypothetical call set
|
|
||||||
if( dataManager.data[iii].isTrueVariant ) {
|
|
||||||
truePos[curveIndex]++;
|
|
||||||
} else {
|
|
||||||
falsePos[curveIndex]++;
|
|
||||||
}
|
|
||||||
} else { // this var is out of this hypothetical call set
|
|
||||||
if( dataManager.data[iii].isTrueVariant ) {
|
|
||||||
falseNeg[curveIndex]++;
|
|
||||||
} else {
|
|
||||||
trueNeg[curveIndex]++;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
final double sensitivity = ((double) truePos[curveIndex]) / ((double) truePos[curveIndex] + falseNegGlobal[curveIndex] + falseNeg[curveIndex]);
|
|
||||||
final double specificity = ((double) trueNegGlobal[curveIndex] + trueNeg[curveIndex]) /
|
|
||||||
((double) falsePos[curveIndex] + trueNegGlobal[curveIndex] + trueNeg[curveIndex]);
|
|
||||||
outputFile.print( String.format("%.8f,%.8f,%.8f,", qualCut[curveIndex], sensitivity, 1.0 - specificity) );
|
|
||||||
qualCut[curveIndex] += incrementQual[curveIndex];
|
|
||||||
curveIndex++;
|
|
||||||
}
|
|
||||||
outputFile.println("-1");
|
|
||||||
}
|
|
||||||
|
|
||||||
outputFile.close();
|
|
||||||
|
|
||||||
// Execute Rscript command to plot the optimization curve
|
|
||||||
// Print out the command line to make it clear to the user what is being executed and how one might modify it
|
|
||||||
final String rScriptCommandLine = PATH_TO_RSCRIPT + " " + PATH_TO_RESOURCES + "plot_variantROCCurve.R" + " " + OUTPUT_PREFIX + ".dat";
|
|
||||||
System.out.println( rScriptCommandLine );
|
|
||||||
|
|
||||||
// Execute the RScript command to plot the table of truth values
|
|
||||||
try {
|
|
||||||
Runtime.getRuntime().exec( rScriptCommandLine );
|
|
||||||
} catch ( IOException e ) {
|
|
||||||
throw new StingException( "Unable to execute RScript command: " + rScriptCommandLine );
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
@ -35,7 +35,5 @@ public class VariantDatum {
|
||||||
public double[] annotations;
|
public double[] annotations;
|
||||||
public boolean isTransition;
|
public boolean isTransition;
|
||||||
public boolean isKnown;
|
public boolean isKnown;
|
||||||
public boolean isTrueVariant;
|
|
||||||
public boolean isHet;
|
|
||||||
public double qual;
|
public double qual;
|
||||||
}
|
}
|
||||||
|
|
|
||||||
|
|
@ -61,7 +61,6 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
|
||||||
private final double[][] mu; // The means for the clusters
|
private final double[][] mu; // The means for the clusters
|
||||||
private final Matrix[] sigma; // The variances for the clusters, sigma is really sigma^2
|
private final Matrix[] sigma; // The variances for the clusters, sigma is really sigma^2
|
||||||
private final Matrix[] sigmaInverse;
|
private final Matrix[] sigmaInverse;
|
||||||
private final boolean[] deadCluster;
|
|
||||||
private final double[] pCluster;
|
private final double[] pCluster;
|
||||||
private final double[] determinant;
|
private final double[] determinant;
|
||||||
private final double[] clusterTITV;
|
private final double[] clusterTITV;
|
||||||
|
|
@ -82,7 +81,6 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
|
||||||
mu = new double[numGaussians][];
|
mu = new double[numGaussians][];
|
||||||
sigma = new Matrix[numGaussians];
|
sigma = new Matrix[numGaussians];
|
||||||
determinant = new double[numGaussians];
|
determinant = new double[numGaussians];
|
||||||
deadCluster = new boolean[numGaussians];
|
|
||||||
pCluster = new double[numGaussians];
|
pCluster = new double[numGaussians];
|
||||||
clusterTITV = new double[numGaussians];
|
clusterTITV = new double[numGaussians];
|
||||||
clusterTruePositiveRate = new double[numGaussians];
|
clusterTruePositiveRate = new double[numGaussians];
|
||||||
|
|
@ -114,7 +112,6 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
|
||||||
// Several of the clustering parameters aren't used the second time around in ApplyVariantClusters
|
// Several of the clustering parameters aren't used the second time around in ApplyVariantClusters
|
||||||
numIterations = 0;
|
numIterations = 0;
|
||||||
clusterTITV = null;
|
clusterTITV = null;
|
||||||
deadCluster = null;
|
|
||||||
minVarInCluster = 0;
|
minVarInCluster = 0;
|
||||||
|
|
||||||
// BUGBUG: move this parsing out of the constructor
|
// BUGBUG: move this parsing out of the constructor
|
||||||
|
|
@ -144,7 +141,6 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
|
||||||
sigma[kkk] = new Matrix(sigmaVals[kkk]);
|
sigma[kkk] = new Matrix(sigmaVals[kkk]);
|
||||||
sigmaInverse[kkk] = sigma[kkk].inverse(); // Precompute all the inverses and determinants for use later
|
sigmaInverse[kkk] = sigma[kkk].inverse(); // Precompute all the inverses and determinants for use later
|
||||||
determinant[kkk] = sigma[kkk].det();
|
determinant[kkk] = sigma[kkk].det();
|
||||||
//if( determinant[kkk] < MIN_DETERMINANT ) { determinant[kkk] = MIN_DETERMINANT; }
|
|
||||||
kkk++;
|
kkk++;
|
||||||
}
|
}
|
||||||
isUsingTiTvModel = _isUsingTiTvModel;
|
isUsingTiTvModel = _isUsingTiTvModel;
|
||||||
|
|
@ -252,12 +248,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
|
||||||
|
|
||||||
// Set up the initial random Gaussians
|
// Set up the initial random Gaussians
|
||||||
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
|
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
|
||||||
deadCluster[kkk] = false;
|
|
||||||
pCluster[kkk] = 1.0 / ((double) (stopCluster - startCluster));
|
pCluster[kkk] = 1.0 / ((double) (stopCluster - startCluster));
|
||||||
//final double[] randMu = new double[numAnnotations];
|
|
||||||
//for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
|
|
||||||
// randMu[jjj] = -1.5 + 3.0 * rand.nextDouble();
|
|
||||||
//}
|
|
||||||
mu[kkk] = data[rand.nextInt(numVariants)].annotations;
|
mu[kkk] = data[rand.nextInt(numVariants)].annotations;
|
||||||
final double[][] randSigma = new double[numAnnotations][numAnnotations];
|
final double[][] randSigma = new double[numAnnotations][numAnnotations];
|
||||||
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
|
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
|
||||||
|
|
@ -270,18 +261,11 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
|
||||||
tmp = tmp.times(tmp.transpose());
|
tmp = tmp.times(tmp.transpose());
|
||||||
sigma[kkk] = tmp;
|
sigma[kkk] = tmp;
|
||||||
determinant[kkk] = sigma[kkk].det();
|
determinant[kkk] = sigma[kkk].det();
|
||||||
//if( determinant[kkk] < MIN_DETERMINANT ) { deadCluster[kkk] = true; }
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// The EM loop
|
// The EM loop
|
||||||
for( int ttt = 0; ttt < numIterations; ttt++ ) {
|
for( int ttt = 0; ttt < numIterations; ttt++ ) {
|
||||||
|
|
||||||
//int numValidClusters = 0;
|
|
||||||
//for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
|
|
||||||
// if( !deadCluster[kkk] ) { numValidClusters++; }
|
|
||||||
//}
|
|
||||||
//logger.info("Starting iteration " + (ttt+1) + " with " + numValidClusters + " clusters.");
|
|
||||||
|
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
// Expectation Step (calculate the probability that each data point is in each cluster)
|
// Expectation Step (calculate the probability that each data point is in each cluster)
|
||||||
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||||
|
|
@ -309,7 +293,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
|
||||||
clusterTITV[kkk] = 0.0;
|
clusterTITV[kkk] = 0.0;
|
||||||
clusterTruePositiveRate[kkk] = 1.0;
|
clusterTruePositiveRate[kkk] = 1.0;
|
||||||
}
|
}
|
||||||
printClusterParamters( clusterFileName + ".WithoutTiTv." + iterationNumber );
|
printClusterParameters( clusterFileName + ".WithoutTiTv." + iterationNumber );
|
||||||
return;
|
return;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -340,29 +324,26 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
|
||||||
final boolean isTransition = data[iii].isTransition;
|
final boolean isTransition = data[iii].isTransition;
|
||||||
final boolean isKnown = data[iii].isKnown;
|
final boolean isKnown = data[iii].isKnown;
|
||||||
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
|
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
|
||||||
if( !deadCluster[kkk] ) {
|
final double prob = pVarInCluster[kkk][iii];
|
||||||
final double prob = pVarInCluster[kkk][iii];
|
if( isKnown ) { // known
|
||||||
if( isKnown ) { // known
|
probKnown[kkk] += prob;
|
||||||
probKnown[kkk] += prob;
|
if( isTransition ) { // transition
|
||||||
if( isTransition ) { // transition
|
probKnownTi[kkk] += prob;
|
||||||
probKnownTi[kkk] += prob;
|
probTi[kkk] += prob;
|
||||||
probTi[kkk] += prob;
|
} else { // transversion
|
||||||
} else { // transversion
|
probKnownTv[kkk] += prob;
|
||||||
probKnownTv[kkk] += prob;
|
probTv[kkk] += prob;
|
||||||
probTv[kkk] += prob;
|
}
|
||||||
}
|
} else { //novel
|
||||||
} else { //novel
|
probNovel[kkk] += prob;
|
||||||
probNovel[kkk] += prob;
|
if( isTransition ) { // transition
|
||||||
if( isTransition ) { // transition
|
probNovelTi[kkk] += prob;
|
||||||
probNovelTi[kkk] += prob;
|
probTi[kkk] += prob;
|
||||||
probTi[kkk] += prob;
|
} else { // transversion
|
||||||
} else { // transversion
|
probNovelTv[kkk] += prob;
|
||||||
probNovelTv[kkk] += prob;
|
probTv[kkk] += prob;
|
||||||
probTv[kkk] += prob;
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -376,50 +357,46 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
|
||||||
knownAlphaFactor = 0.5;
|
knownAlphaFactor = 0.5;
|
||||||
}
|
}
|
||||||
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
|
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
|
||||||
if( !deadCluster[kkk] ) {
|
clusterTITV[kkk] = probTi[kkk] / probTv[kkk];
|
||||||
clusterTITV[kkk] = probTi[kkk] / probTv[kkk];
|
if( probKnown[kkk] > 500.0 && probNovel[kkk] > 500.0 ) {
|
||||||
if( probKnown[kkk] > 500.0 && probNovel[kkk] > 500.0 ) {
|
clusterTruePositiveRate[kkk] = calcTruePositiveRateFromKnownTITV( probKnownTi[kkk] / probKnownTv[kkk], probNovelTi[kkk] / probNovelTv[kkk], clusterTITV[kkk], knownAlphaFactor );
|
||||||
clusterTruePositiveRate[kkk] = calcTruePositiveRateFromKnownTITV( probKnownTi[kkk] / probKnownTv[kkk], probNovelTi[kkk] / probNovelTv[kkk], clusterTITV[kkk], knownAlphaFactor );
|
} else {
|
||||||
} else {
|
clusterTruePositiveRate[kkk] = calcTruePositiveRateFromTITV( clusterTITV[kkk] );
|
||||||
clusterTruePositiveRate[kkk] = calcTruePositiveRateFromTITV( clusterTITV[kkk] );
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if( ttt == 0 ) {
|
if( ttt == 0 ) {
|
||||||
printClusterParamters( clusterFileName + ".TargetTiTv." + iterationNumber );
|
printClusterParameters( clusterFileName + ".TargetTiTv." + iterationNumber );
|
||||||
} else if( ttt == 1 ) {
|
} else if( ttt == 1 ) {
|
||||||
printClusterParamters( clusterFileName + ".KnownTiTv." + iterationNumber );
|
printClusterParameters( clusterFileName + ".KnownTiTv." + iterationNumber );
|
||||||
} else if( ttt == 2 ) {
|
} else if( ttt == 2 ) {
|
||||||
printClusterParamters( clusterFileName + ".BlendedTiTv." + iterationNumber );
|
printClusterParameters( clusterFileName + ".BlendedTiTv." + iterationNumber );
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
private void printClusterParamters( final String clusterFileName ) {
|
private void printClusterParameters( final String clusterFileName ) {
|
||||||
try {
|
try {
|
||||||
final PrintStream outputFile = new PrintStream( clusterFileName );
|
final PrintStream outputFile = new PrintStream( clusterFileName );
|
||||||
dataManager.printClusterFileHeader( outputFile );
|
dataManager.printClusterFileHeader( outputFile );
|
||||||
final int numAnnotations = mu[0].length;
|
final int numAnnotations = mu[0].length;
|
||||||
final int numVariants = dataManager.numVariants;
|
final int numVariants = dataManager.numVariants;
|
||||||
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
|
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
|
||||||
if( !deadCluster[kkk] ) {
|
if( pCluster[kkk] * numVariants > minVarInCluster ) {
|
||||||
if( pCluster[kkk] * numVariants > minVarInCluster ) {
|
final double sigmaVals[][] = sigma[kkk].getArray();
|
||||||
final double sigmaVals[][] = sigma[kkk].getArray();
|
outputFile.print("@!CLUSTER,");
|
||||||
outputFile.print("@!CLUSTER,");
|
outputFile.print(pCluster[kkk] + ",");
|
||||||
outputFile.print(pCluster[kkk] + ",");
|
outputFile.print(clusterTITV[kkk] + ",");
|
||||||
outputFile.print(clusterTITV[kkk] + ",");
|
outputFile.print(clusterTruePositiveRate[kkk] + ",");
|
||||||
outputFile.print(clusterTruePositiveRate[kkk] + ",");
|
for(int jjj = 0; jjj < numAnnotations; jjj++ ) {
|
||||||
for(int jjj = 0; jjj < numAnnotations; jjj++ ) {
|
outputFile.print(mu[kkk][jjj] + ",");
|
||||||
outputFile.print(mu[kkk][jjj] + ",");
|
|
||||||
}
|
|
||||||
for(int jjj = 0; jjj < numAnnotations; jjj++ ) {
|
|
||||||
for(int ppp = 0; ppp < numAnnotations; ppp++ ) {
|
|
||||||
outputFile.print(sigmaVals[jjj][ppp] + ",");
|
|
||||||
}
|
|
||||||
}
|
|
||||||
outputFile.println(-1);
|
|
||||||
}
|
}
|
||||||
|
for(int jjj = 0; jjj < numAnnotations; jjj++ ) {
|
||||||
|
for(int ppp = 0; ppp < numAnnotations; ppp++ ) {
|
||||||
|
outputFile.print(sigmaVals[jjj][ppp] + ",");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
outputFile.println(-1);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
outputFile.close();
|
outputFile.close();
|
||||||
|
|
@ -458,52 +435,11 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
|
||||||
|
|
||||||
evaluateGaussiansForSingleVariant( annotations, pVarInCluster );
|
evaluateGaussiansForSingleVariant( annotations, pVarInCluster );
|
||||||
|
|
||||||
//if( isUsingTiTvModel ) {
|
|
||||||
// Sum prob model
|
|
||||||
double sum = 0.0;
|
|
||||||
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
|
|
||||||
sum += pVarInCluster[kkk] * clusterTruePositiveRate[kkk];
|
|
||||||
}
|
|
||||||
return sum;
|
|
||||||
/*
|
|
||||||
} else {
|
|
||||||
// Max prob model
|
|
||||||
double maxProb = 0.0;
|
|
||||||
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
|
|
||||||
if( pVarInCluster[kkk] > maxProb ) {
|
|
||||||
maxProb = pVarInCluster[kkk];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return maxProb;
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Max prob model
|
|
||||||
/*
|
|
||||||
double maxProb = 0.0;
|
|
||||||
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
|
|
||||||
if( pVarInCluster[kkk] > maxProb ) {
|
|
||||||
maxProb = pVarInCluster[kkk];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return maxProb;
|
|
||||||
*/
|
|
||||||
|
|
||||||
// Entropy model
|
|
||||||
/*
|
|
||||||
double sum = 0.0;
|
double sum = 0.0;
|
||||||
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
|
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
|
||||||
//if( isHetCluster[kkk] == isHet ) {
|
sum += pVarInCluster[kkk] * clusterTruePositiveRate[kkk];
|
||||||
sum += pVarInCluster[kkk] * Math.log(pVarInCluster[kkk]);
|
|
||||||
//}
|
|
||||||
}
|
}
|
||||||
|
return sum;
|
||||||
double entropy = -1.0 * sum;
|
|
||||||
double maxEntropy = -1.0 * Math.log( 1.0 / ((double) numGaussians));
|
|
||||||
|
|
||||||
//System.out.println("H = " + entropy + ", pTrue = " + ( 1.0 - (entropy / maxEntropy) ));
|
|
||||||
return ( 1.0 - (entropy / maxEntropy) );
|
|
||||||
*/
|
|
||||||
}
|
}
|
||||||
|
|
||||||
public final void outputClusterReports( final String outputPrefix ) {
|
public final void outputClusterReports( final String outputPrefix ) {
|
||||||
|
|
@ -688,60 +624,56 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
|
||||||
final double sigmaVals[][][] = new double[numGaussians][][];
|
final double sigmaVals[][][] = new double[numGaussians][][];
|
||||||
final double denom[] = new double[numGaussians];
|
final double denom[] = new double[numGaussians];
|
||||||
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
|
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
|
||||||
if( !deadCluster[kkk] ) {
|
sigmaVals[kkk] = sigma[kkk].inverse().getArray();
|
||||||
sigmaVals[kkk] = sigma[kkk].inverse().getArray();
|
denom[kkk] = Math.pow(2.0 * 3.14159, ((double)numAnnotations) / 2.0) * Math.pow(Math.abs(determinant[kkk]), 0.5);
|
||||||
denom[kkk] = Math.pow(2.0 * 3.14159, ((double)numAnnotations) / 2.0) * Math.pow(Math.abs(determinant[kkk]), 0.5);
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
final double mult[] = new double[numAnnotations];
|
final double mult[] = new double[numAnnotations];
|
||||||
for( int iii = 0; iii < data.length; iii++ ) {
|
for( int iii = 0; iii < data.length; iii++ ) {
|
||||||
double sumProb = 0.0;
|
double sumProb = 0.0;
|
||||||
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
|
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
|
||||||
if( !deadCluster[kkk] ) {
|
double sum = 0.0;
|
||||||
double sum = 0.0;
|
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
|
||||||
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
|
mult[jjj] = 0.0;
|
||||||
mult[jjj] = 0.0;
|
for( int ppp = 0; ppp < numAnnotations; ppp++ ) {
|
||||||
for( int ppp = 0; ppp < numAnnotations; ppp++ ) {
|
mult[jjj] += (data[iii].annotations[ppp] - mu[kkk][ppp]) * sigmaVals[kkk][ppp][jjj];
|
||||||
mult[jjj] += (data[iii].annotations[ppp] - mu[kkk][ppp]) * sigmaVals[kkk][ppp][jjj];
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
|
|
||||||
sum += mult[jjj] * (data[iii].annotations[jjj] - mu[kkk][jjj]);
|
|
||||||
}
|
|
||||||
|
|
||||||
pVarInCluster[kkk][iii] = pCluster[kkk] * (Math.exp( -0.5 * sum ) / denom[kkk]);
|
|
||||||
likelihood += pVarInCluster[kkk][iii];
|
|
||||||
if(Double.isNaN(denom[kkk]) || determinant[kkk] < 0.5 * MIN_DETERMINANT) {
|
|
||||||
System.out.println("det = " + sigma[kkk].det());
|
|
||||||
System.out.println("denom = " + denom[kkk]);
|
|
||||||
System.out.println("sumExp = " + sum);
|
|
||||||
System.out.println("pVar = " + pVarInCluster[kkk][iii]);
|
|
||||||
System.out.println("=-------=");
|
|
||||||
throw new StingException("Numerical Instability! determinant of covariance matrix <= 0. Try running with fewer clusters and then with better behaved annotation values.");
|
|
||||||
}
|
|
||||||
if(sum < 0.0) {
|
|
||||||
System.out.println("det = " + sigma[kkk].det());
|
|
||||||
System.out.println("denom = " + denom[kkk]);
|
|
||||||
System.out.println("sumExp = " + sum);
|
|
||||||
System.out.println("pVar = " + pVarInCluster[kkk][iii]);
|
|
||||||
System.out.println("=-------=");
|
|
||||||
throw new StingException("Numerical Instability! covariance matrix no longer positive definite. Try running with fewer clusters and then with better behaved annotation values.");
|
|
||||||
}
|
|
||||||
if(pVarInCluster[kkk][iii] > 1.0) {
|
|
||||||
System.out.println("det = " + sigma[kkk].det());
|
|
||||||
System.out.println("denom = " + denom[kkk]);
|
|
||||||
System.out.println("sumExp = " + sum);
|
|
||||||
System.out.println("pVar = " + pVarInCluster[kkk][iii]);
|
|
||||||
System.out.println("=-------=");
|
|
||||||
throw new StingException("Numerical Instability! probability distribution returns > 1.0. Try running with fewer clusters and then with better behaved annotation values.");
|
|
||||||
}
|
|
||||||
|
|
||||||
if( pVarInCluster[kkk][iii] < MIN_PROB) { // Very small numbers are a very big problem
|
|
||||||
pVarInCluster[kkk][iii] = MIN_PROB;// + MIN_PROB * rand.nextDouble();
|
|
||||||
}
|
|
||||||
|
|
||||||
sumProb += pVarInCluster[kkk][iii];
|
|
||||||
}
|
}
|
||||||
|
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
|
||||||
|
sum += mult[jjj] * (data[iii].annotations[jjj] - mu[kkk][jjj]);
|
||||||
|
}
|
||||||
|
|
||||||
|
pVarInCluster[kkk][iii] = pCluster[kkk] * (Math.exp( -0.5 * sum ) / denom[kkk]);
|
||||||
|
likelihood += pVarInCluster[kkk][iii];
|
||||||
|
if(Double.isNaN(denom[kkk]) || determinant[kkk] < 0.5 * MIN_DETERMINANT) {
|
||||||
|
System.out.println("det = " + sigma[kkk].det());
|
||||||
|
System.out.println("denom = " + denom[kkk]);
|
||||||
|
System.out.println("sumExp = " + sum);
|
||||||
|
System.out.println("pVar = " + pVarInCluster[kkk][iii]);
|
||||||
|
System.out.println("=-------=");
|
||||||
|
throw new StingException("Numerical Instability! determinant of covariance matrix <= 0. Try running with fewer clusters and then with better behaved annotation values.");
|
||||||
|
}
|
||||||
|
if(sum < 0.0) {
|
||||||
|
System.out.println("det = " + sigma[kkk].det());
|
||||||
|
System.out.println("denom = " + denom[kkk]);
|
||||||
|
System.out.println("sumExp = " + sum);
|
||||||
|
System.out.println("pVar = " + pVarInCluster[kkk][iii]);
|
||||||
|
System.out.println("=-------=");
|
||||||
|
throw new StingException("Numerical Instability! covariance matrix no longer positive definite. Try running with fewer clusters and then with better behaved annotation values.");
|
||||||
|
}
|
||||||
|
if(pVarInCluster[kkk][iii] > 1.0) {
|
||||||
|
System.out.println("det = " + sigma[kkk].det());
|
||||||
|
System.out.println("denom = " + denom[kkk]);
|
||||||
|
System.out.println("sumExp = " + sum);
|
||||||
|
System.out.println("pVar = " + pVarInCluster[kkk][iii]);
|
||||||
|
System.out.println("=-------=");
|
||||||
|
throw new StingException("Numerical Instability! probability distribution returns > 1.0. Try running with fewer clusters and then with better behaved annotation values.");
|
||||||
|
}
|
||||||
|
|
||||||
|
if( pVarInCluster[kkk][iii] < MIN_PROB) { // Very small numbers are a very big problem
|
||||||
|
pVarInCluster[kkk][iii] = MIN_PROB;// + MIN_PROB * rand.nextDouble();
|
||||||
|
}
|
||||||
|
|
||||||
|
sumProb += pVarInCluster[kkk][iii];
|
||||||
}
|
}
|
||||||
|
|
||||||
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
|
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
|
||||||
|
|
@ -757,8 +689,6 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
|
||||||
private void evaluateGaussiansForSingleVariant( final double[] annotations, final double[] pVarInCluster ) {
|
private void evaluateGaussiansForSingleVariant( final double[] annotations, final double[] pVarInCluster ) {
|
||||||
|
|
||||||
final int numAnnotations = annotations.length;
|
final int numAnnotations = annotations.length;
|
||||||
|
|
||||||
double sumProb = 0.0;
|
|
||||||
final double mult[] = new double[numAnnotations];
|
final double mult[] = new double[numAnnotations];
|
||||||
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
|
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
|
||||||
final double sigmaVals[][] = sigmaInverse[kkk].getArray();
|
final double sigmaVals[][] = sigmaInverse[kkk].getArray();
|
||||||
|
|
@ -775,28 +705,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
|
||||||
|
|
||||||
final double denom = Math.pow(2.0 * 3.14159, ((double)numAnnotations) / 2.0) * Math.pow(determinant[kkk], 0.5);
|
final double denom = Math.pow(2.0 * 3.14159, ((double)numAnnotations) / 2.0) * Math.pow(determinant[kkk], 0.5);
|
||||||
pVarInCluster[kkk] = pCluster[kkk] * (Math.exp( -0.5 * sum )) / denom;
|
pVarInCluster[kkk] = pCluster[kkk] * (Math.exp( -0.5 * sum )) / denom;
|
||||||
|
|
||||||
/*
|
|
||||||
if( isUsingTiTvModel ) {
|
|
||||||
//pVarInCluster[kkk] = Math.exp( -0.5 * sum );
|
|
||||||
if( pVarInCluster[kkk] < MIN_PROB) { // Very small numbers are a very big problem
|
|
||||||
pVarInCluster[kkk] = MIN_PROB;
|
|
||||||
}
|
|
||||||
sumProb += pVarInCluster[kkk];
|
|
||||||
} else {
|
|
||||||
//final double denom = Math.pow(2.0 * 3.14159, ((double)numAnnotations) / 2.0) * Math.pow(determinant[kkk], 0.5);
|
|
||||||
//pVarInCluster[kkk] = pCluster[kkk] * (Math.exp( -0.5 * sum )) / denom;
|
|
||||||
//pVarInCluster[kkk] = Math.exp( -0.5 * sum );
|
|
||||||
// BUGBUG: should pCluster be the distribution from the GMM or a uniform distribution here?
|
|
||||||
}
|
|
||||||
*/
|
|
||||||
}
|
}
|
||||||
|
|
||||||
//if( isUsingTiTvModel ) {
|
|
||||||
// for( int kkk = 0; kkk < numGaussians; kkk++ ) {
|
|
||||||
// pVarInCluster[kkk] /= sumProb;
|
|
||||||
// }
|
|
||||||
//}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -816,53 +725,48 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
|
||||||
}
|
}
|
||||||
double sumPK = 0.0;
|
double sumPK = 0.0;
|
||||||
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
|
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
|
||||||
if( !deadCluster[kkk] ) {
|
double sumProb = 0.0;
|
||||||
double sumProb = 0.0;
|
for( int iii = 0; iii < numVariants; iii++ ) {
|
||||||
for( int iii = 0; iii < numVariants; iii++ ) {
|
final double prob = pVarInCluster[kkk][iii];
|
||||||
final double prob = pVarInCluster[kkk][iii];
|
sumProb += prob;
|
||||||
sumProb += prob;
|
|
||||||
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
|
|
||||||
mu[kkk][jjj] += prob * data[iii].annotations[jjj];
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
|
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
|
||||||
mu[kkk][jjj] /= sumProb;
|
mu[kkk][jjj] += prob * data[iii].annotations[jjj];
|
||||||
}
|
|
||||||
|
|
||||||
for( int iii = 0; iii < numVariants; iii++ ) {
|
|
||||||
final double prob = pVarInCluster[kkk][iii];
|
|
||||||
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
|
|
||||||
for( int ppp = jjj; ppp < numAnnotations; ppp++ ) {
|
|
||||||
sigmaVals[kkk][jjj][ppp] += prob * (data[iii].annotations[jjj]-mu[kkk][jjj]) * (data[iii].annotations[ppp]-mu[kkk][ppp]);
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
|
|
||||||
for( int ppp = jjj; ppp < numAnnotations; ppp++ ) {
|
|
||||||
if( sigmaVals[kkk][jjj][ppp] < MIN_SIGMA ) { // Very small numbers are a very big problem
|
|
||||||
sigmaVals[kkk][jjj][ppp] = MIN_SIGMA;// + MIN_SIGMA * rand.nextDouble();
|
|
||||||
}
|
|
||||||
sigmaVals[kkk][ppp][jjj] = sigmaVals[kkk][jjj][ppp]; // sigma must be a symmetric matrix
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
|
|
||||||
for( int ppp = 0; ppp < numAnnotations; ppp++ ) {
|
|
||||||
sigmaVals[kkk][jjj][ppp] /= sumProb;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
sigma[kkk] = new Matrix(sigmaVals[kkk]);
|
|
||||||
determinant[kkk] = sigma[kkk].det();
|
|
||||||
//if( determinant[kkk] < MIN_DETERMINANT ) { deadCluster[kkk] = true; }
|
|
||||||
|
|
||||||
if( !deadCluster[kkk] ) {
|
|
||||||
pCluster[kkk] = sumProb / numVariants;
|
|
||||||
sumPK += pCluster[kkk];
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
|
||||||
|
mu[kkk][jjj] /= sumProb;
|
||||||
|
}
|
||||||
|
|
||||||
|
for( int iii = 0; iii < numVariants; iii++ ) {
|
||||||
|
final double prob = pVarInCluster[kkk][iii];
|
||||||
|
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
|
||||||
|
for( int ppp = jjj; ppp < numAnnotations; ppp++ ) {
|
||||||
|
sigmaVals[kkk][jjj][ppp] += prob * (data[iii].annotations[jjj]-mu[kkk][jjj]) * (data[iii].annotations[ppp]-mu[kkk][ppp]);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
|
||||||
|
for( int ppp = jjj; ppp < numAnnotations; ppp++ ) {
|
||||||
|
if( sigmaVals[kkk][jjj][ppp] < MIN_SIGMA ) { // Very small numbers are a very big problem
|
||||||
|
sigmaVals[kkk][jjj][ppp] = MIN_SIGMA;// + MIN_SIGMA * rand.nextDouble();
|
||||||
|
}
|
||||||
|
sigmaVals[kkk][ppp][jjj] = sigmaVals[kkk][jjj][ppp]; // sigma must be a symmetric matrix
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
|
||||||
|
for( int ppp = 0; ppp < numAnnotations; ppp++ ) {
|
||||||
|
sigmaVals[kkk][jjj][ppp] /= sumProb;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sigma[kkk] = new Matrix(sigmaVals[kkk]);
|
||||||
|
determinant[kkk] = sigma[kkk].det();
|
||||||
|
|
||||||
|
pCluster[kkk] = sumProb / numVariants;
|
||||||
|
sumPK += pCluster[kkk];
|
||||||
}
|
}
|
||||||
|
|
||||||
// ensure pCluster sums to one, it doesn't automatically due to very small numbers getting capped
|
// ensure pCluster sums to one, it doesn't automatically due to very small numbers getting capped
|
||||||
|
|
@ -924,15 +828,10 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
|
||||||
|
|
||||||
// Replace extremely small clusters with another random draw from the dataset
|
// Replace extremely small clusters with another random draw from the dataset
|
||||||
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
|
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
|
||||||
//if(determinant[kkk] < MIN_DETERMINANT ) {
|
|
||||||
if( pCluster[kkk] < 0.0005 * (1.0 / ((double) (stopCluster-startCluster))) ||
|
if( pCluster[kkk] < 0.0005 * (1.0 / ((double) (stopCluster-startCluster))) ||
|
||||||
determinant[kkk] < MIN_DETERMINANT ) { // This is a very small cluster compared to all the others
|
determinant[kkk] < MIN_DETERMINANT ) { // This is a very small cluster compared to all the others
|
||||||
logger.info("!! Found singular cluster! Initializing a new random cluster.");
|
logger.info("!! Found singular cluster! Initializing a new random cluster.");
|
||||||
pCluster[kkk] = 0.1 / ((double) (stopCluster-startCluster)); // 0.5 /
|
pCluster[kkk] = 0.1 / ((double) (stopCluster-startCluster));
|
||||||
//final double[] randMu = new double[numAnnotations];
|
|
||||||
//for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
|
|
||||||
// randMu[jjj] = -1.5 + 3.0 * rand.nextDouble();
|
|
||||||
//}
|
|
||||||
mu[kkk] = data[rand.nextInt(numVariants)].annotations;
|
mu[kkk] = data[rand.nextInt(numVariants)].annotations;
|
||||||
final double[][] randSigma = new double[numAnnotations][numAnnotations];
|
final double[][] randSigma = new double[numAnnotations][numAnnotations];
|
||||||
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
|
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
|
||||||
|
|
|
||||||
|
|
@ -61,14 +61,12 @@ public class VariantOptimizer extends RodWalker<ExpandingArrayList<VariantDatum>
|
||||||
private boolean IGNORE_ALL_INPUT_FILTERS = false;
|
private boolean IGNORE_ALL_INPUT_FILTERS = false;
|
||||||
@Argument(fullName="ignore_filter", shortName="ignoreFilter", doc="If specified the optimizer will use variants even if the specified filter name is marked in the input VCF file", required=false)
|
@Argument(fullName="ignore_filter", shortName="ignoreFilter", doc="If specified the optimizer will use variants even if the specified filter name is marked in the input VCF file", required=false)
|
||||||
private String[] IGNORE_INPUT_FILTERS = null;
|
private String[] IGNORE_INPUT_FILTERS = null;
|
||||||
|
|
||||||
@Argument(fullName="use_annotation", shortName="an", doc="The names of the annotations which should used for calculations", required=true)
|
@Argument(fullName="use_annotation", shortName="an", doc="The names of the annotations which should used for calculations", required=true)
|
||||||
private String[] USE_ANNOTATIONS = null;
|
private String[] USE_ANNOTATIONS = null;
|
||||||
|
|
||||||
@Argument(fullName="clusterFile", shortName="clusterFile", doc="The output cluster file", required=true)
|
@Argument(fullName="clusterFile", shortName="clusterFile", doc="The output cluster file", required=true)
|
||||||
private String CLUSTER_FILENAME = "optimizer.cluster";
|
private String CLUSTER_FILENAME = "optimizer.cluster";
|
||||||
@Argument(fullName="numGaussians", shortName="nG", doc="The number of Gaussians to be used in the Gaussian Mixture model", required=false)
|
@Argument(fullName="numGaussians", shortName="nG", doc="The number of Gaussians to be used in the Gaussian Mixture model", required=false)
|
||||||
private int NUM_GAUSSIANS = 7;
|
private int NUM_GAUSSIANS = 1;
|
||||||
@Argument(fullName="numIterations", shortName="nI", doc="The number of iterations to be performed in the Gaussian Mixture model", required=false)
|
@Argument(fullName="numIterations", shortName="nI", doc="The number of iterations to be performed in the Gaussian Mixture model", required=false)
|
||||||
private int NUM_ITERATIONS = 10;
|
private int NUM_ITERATIONS = 10;
|
||||||
@Argument(fullName="minVarInCluster", shortName="minVar", doc="The minimum number of variants in a cluster to be considered a valid cluster. It can be used to prevent overfitting.", required=false)
|
@Argument(fullName="minVarInCluster", shortName="minVar", doc="The minimum number of variants in a cluster to be considered a valid cluster. It can be used to prevent overfitting.", required=false)
|
||||||
|
|
@ -148,9 +146,9 @@ public class VariantOptimizer extends RodWalker<ExpandingArrayList<VariantDatum>
|
||||||
boolean isKnown = !vc.getAttribute("ID").equals(".");
|
boolean isKnown = !vc.getAttribute("ID").equals(".");
|
||||||
if(usingDBSNP) {
|
if(usingDBSNP) {
|
||||||
isKnown = false;
|
isKnown = false;
|
||||||
for( VariantContext dbsnpVC : tracker.getVariantContexts(DbSNPHelper.STANDARD_DBSNP_TRACK_NAME, null, context.getLocation(), false, false) ) {
|
for( final VariantContext dbsnpVC : tracker.getVariantContexts(DbSNPHelper.STANDARD_DBSNP_TRACK_NAME, null, context.getLocation(), false, false) ) {
|
||||||
if(dbsnpVC != null && dbsnpVC.isSNP()) {
|
if(dbsnpVC != null && dbsnpVC.isSNP()) {
|
||||||
isKnown=true;
|
isKnown = true;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue