Adding VariantConcordanceROCCurveWalker to create ROC curves comparing concordance between optimized call sets and validation truth sets in VCF format, in order to evaluate the performance of the variant optimizer independently of achieving a particular novel ti/tv ratio. Added an option to ignore only the specified filters in the input call sets via --ignore_filter <String>. Added an option to provide a prior estimate of error for known SNPs via --known_prior <qual>. The het and hom calls are clustered independently. Infrastructure is in place to use the ti/tv of known SNPs to inform p(true) of novel SNPs. Tweaked protection against overfitting based on suggestions from several people. Minor edits to AnalyzeAnnotations.

git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@3071 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
rpoplin 2010-03-24 19:43:10 +00:00
parent 47e30aba92
commit 06a212e612
13 changed files with 756 additions and 145 deletions

View File

@ -0,0 +1,21 @@
#!/broad/tools/apps/R-2.6.0/bin/Rscript
# Plot one ROC curve per input call set from the comma-separated table produced by
# VariantConcordanceROCCurveWalker. Columns come in triples per curve:
#   (call set name / qual cutoff, sensitivity, 1 - specificity)
# followed by a trailing sentinel column.
cmdArgs <- commandArgs(TRUE)
verbose = TRUE
inputFile <- cmdArgs[1]
rocTable <- read.table(inputFile, sep = ",", head = T)
nCurves <- (length(rocTable) - 1) / 3
# Every third column holds the "1 - specificity" values; use their max to bound the x axis.
xMax <- max(rocTable[, seq_len(nCurves) * 3])
pdfName <- paste(inputFile, ".variantROCCurve.pdf", sep = "")
pdf(pdfName, height = 7, width = 7)
par(cex = 1.3)
# Empty plot establishing the axes; individual curves are drawn one at a time below.
plot(rocTable$specificity1, rocTable$sensitivity1, type = "n",
     xlim = c(0, xMax), ylim = c(0, 1),
     xlab = "1 - Specificity", ylab = "Sensitivity")
for (curve in seq_len(nCurves)) {
    points(rocTable[, curve * 3], rocTable[, (curve - 1) * 3 + 2],
           lwd = 3, type = "l", col = curve)
}
# The first column of each triple is named after its input call set; use those names as labels.
legend("bottomright", names(rocTable)[(seq_len(nCurves) - 1) * 3 + 1],
       col = 1:nCurves, lwd = 3)
dev.off()

View File

@ -108,7 +108,7 @@ public class AnalyzeAnnotationsWalker extends RodWalker<Integer, Integer> {
// First find out if this variant is in the truth sets
boolean isInTruthSet = false;
boolean isTrueVariant = false;
for( ReferenceOrderedDatum rod : tracker.getAllRods() ) {
for( final ReferenceOrderedDatum rod : tracker.getAllRods() ) {
if( rod != null && rod.getName().toUpperCase().startsWith("TRUTH") ) {
isInTruthSet = true;
@ -132,7 +132,7 @@ public class AnalyzeAnnotationsWalker extends RodWalker<Integer, Integer> {
}
// Add each annotation in this VCF Record to the dataManager
for( ReferenceOrderedDatum rod : tracker.getAllRods() ) {
for( final ReferenceOrderedDatum rod : tracker.getAllRods() ) {
if( rod != null && rod instanceof RodVCF && !rod.getName().toUpperCase().startsWith("TRUTH") ) {
final RodVCF variant = (RodVCF) rod;
if( variant.isSNP() ) {

View File

@ -63,7 +63,7 @@ public class AnnotationDataManager {
// Loop over each annotation in the vcf record
final Map<String,String> infoField = variant.getInfoValues();
infoField.put("QUAL", ((Double)variant.getQual()).toString() ); // add QUAL field to annotations
for( String annotationKey : infoField.keySet() ) {
for( final String annotationKey : infoField.keySet() ) {
float value;
try {
@ -102,7 +102,7 @@ public class AnnotationDataManager {
System.out.println( "\nFinished reading variants into memory. Executing RScript commands:" );
// For each annotation we've seen
for( String annotationKey : data.keySet() ) {
for( final String annotationKey : data.keySet() ) {
PrintStream output;
try {
@ -117,7 +117,7 @@ public class AnnotationDataManager {
// Bin SNPs and calculate truth metrics for each bin
thisAnnotationBin.clearBin();
for( AnnotationDatum datum : data.get( annotationKey ) ) {
for( final AnnotationDatum datum : data.get( annotationKey ) ) {
thisAnnotationBin.combine( datum );
if( thisAnnotationBin.numVariants( AnnotationDatum.FULL_SET ) >= MAX_VARIANTS_PER_BIN ) { // This annotation bin is full
output.println( thisAnnotationBin.value + "\t" + thisAnnotationBin.calcTiTv( AnnotationDatum.FULL_SET ) + "\t" + thisAnnotationBin.calcDBsnpRate() + "\t" + thisAnnotationBin.calcTPrate() +

View File

@ -59,7 +59,7 @@ public class AnnotationDatum implements Comparator<AnnotationDatum> {
}
}
public AnnotationDatum( float _value ) {
public AnnotationDatum( final float _value ) {
value = _value;
ti = new int[NUM_SETS];

View File

@ -12,10 +12,7 @@ import org.broadinstitute.sting.utils.genotype.vcf.*;
import java.io.File;
import java.io.IOException;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Set;
import java.util.TreeSet;
import java.util.*;
/*
* Copyright (c) 2010 The Broad Institute
@ -60,8 +57,12 @@ public class ApplyVariantClustersWalker extends RodWalker<ExpandingArrayList<Var
private double TARGET_TITV = 2.1;
@Argument(fullName="desired_num_variants", shortName="dV", doc="The desired number of variants to keep in a theoretically filtered set", required=false)
private int DESIRED_NUM_VARIANTS = 0;
@Argument(fullName="ignore_input_filters", shortName="ignoreFilters", doc="If specified the optimizer will use variants even if the FILTER column is marked in the VCF file", required=false)
private boolean IGNORE_INPUT_FILTERS = false;
@Argument(fullName="ignore_all_input_filters", shortName="ignoreAllFilters", doc="If specified the optimizer will use variants even if the FILTER column is marked in the VCF file", required=false)
private boolean IGNORE_ALL_INPUT_FILTERS = false;
@Argument(fullName="ignore_filter", shortName="ignoreFilter", doc="If specified the optimizer will use variants even if the specified filter name is marked in the input VCF file", required=false)
private String[] IGNORE_INPUT_FILTERS = null;
@Argument(fullName="known_prior", shortName="knownPrior", doc="A prior on the quality of known variants, a phred scaled probability of being true. Default is 30.0", required=false)
private double KNOWN_VAR_QUAL_PRIOR = 30.0;
@Argument(fullName="output_prefix", shortName="output", doc="The prefix added to output VCF file name and optimization curve pdf file name", required=false)
private String OUTPUT_PREFIX = "optimizer";
@Argument(fullName="clusterFile", shortName="clusterFile", doc="The output cluster file", required=true)
@ -78,6 +79,8 @@ public class ApplyVariantClustersWalker extends RodWalker<ExpandingArrayList<Var
/////////////////////////////
private VariantGaussianMixtureModel theModel = null;
private VCFWriter vcfWriter;
private Set<String> ignoreInputFilterSet = null;
//---------------------------------------------------------------------------------------------------------------
//
@ -88,6 +91,10 @@ public class ApplyVariantClustersWalker extends RodWalker<ExpandingArrayList<Var
public void initialize() {
if( !PATH_TO_RESOURCES.endsWith("/") ) { PATH_TO_RESOURCES = PATH_TO_RESOURCES + "/"; }
if( IGNORE_INPUT_FILTERS != null ) {
ignoreInputFilterSet = new TreeSet<String>(Arrays.asList(IGNORE_INPUT_FILTERS));
}
switch (OPTIMIZATION_MODEL) {
case GAUSSIAN_MIXTURE_MODEL:
theModel = new VariantGaussianMixtureModel( CLUSTER_FILENAME );
@ -101,14 +108,14 @@ public class ApplyVariantClustersWalker extends RodWalker<ExpandingArrayList<Var
// setup the header fields
Set<VCFHeaderLine> hInfo = new HashSet<VCFHeaderLine>();
final Set<VCFHeaderLine> hInfo = new HashSet<VCFHeaderLine>();
hInfo.addAll(VCFUtils.getHeaderFields(getToolkit()));
hInfo.add(new VCFInfoHeaderLine("OQ", 1, VCFInfoHeaderLine.INFO_TYPE.Float, "The original variant quality score"));
hInfo.add(new VCFHeaderLine("source", "VariantOptimizer"));
vcfWriter = new VCFWriter( new File(OUTPUT_PREFIX + ".vcf") );
TreeSet<String> samples = new TreeSet<String>();
final TreeSet<String> samples = new TreeSet<String>();
SampleUtils.getUniquifiedSamplesFromRods(getToolkit(), samples, new HashMap<Pair<String, String>, String>());
VCFHeader vcfHeader = new VCFHeader(hInfo, samples);
final VCFHeader vcfHeader = new VCFHeader(hInfo, samples);
vcfWriter.writeHeader(vcfHeader);
}
@ -126,22 +133,41 @@ public class ApplyVariantClustersWalker extends RodWalker<ExpandingArrayList<Var
}
for( ReferenceOrderedDatum rod : tracker.getAllRods() ) {
for( final ReferenceOrderedDatum rod : tracker.getAllRods() ) {
if( rod != null && rod instanceof RodVCF ) {
final RodVCF rodVCF = ((RodVCF) rod);
if( rodVCF.isSNP() && (IGNORE_INPUT_FILTERS || !rodVCF.isFiltered()) ) {
final double pTrue = theModel.evaluateVariant( rodVCF.getInfoValues(), rodVCF.getQual() );
final double recalQual = QualityUtils.phredScaleErrorRate( Math.max( 1.0 - pTrue, 0.000000001) );
rodVCF.mCurrentRecord.addInfoField("OQ", ((Double)rodVCF.getQual()).toString() );
rodVCF.mCurrentRecord.setQual( recalQual );
vcfWriter.addRecord( rodVCF.mCurrentRecord );
VariantDatum variantDatum = new VariantDatum();
//BUGBUG: figure out how to make this use VariantContext to be consistent with other VariantOptimizer walkers
// need to convert vc into VCFRecord to write it out?
if( rodVCF.isSNP() &&
(!rodVCF.isFiltered() || IGNORE_ALL_INPUT_FILTERS || (ignoreInputFilterSet != null && ignoreInputFilterSet.containsAll(Arrays.asList(rodVCF.getFilteringCodes())))) ) {
final VariantDatum variantDatum = new VariantDatum();
variantDatum.isTransition = BaseUtils.isTransition((byte)rodVCF.getAlternativeBaseForSNP(), (byte)rodVCF.getReferenceForSNP()); //vc.getSNPSubstitutionType().compareTo(BaseUtils.BaseSubstitutionType.TRANSITION) == 0;
variantDatum.isKnown = !rodVCF.isNovel(); //!vc.getAttribute("ID").equals(".");
variantDatum.qual = recalQual;
int numHet = 0;
int numHom = 0;
for( final VCFGenotypeRecord rec : rodVCF.getVCFGenotypeRecords() ) {
if( rec.isHet() ) { numHet++; }
else if( rec.isHom() ) { numHom++; }
}
variantDatum.isHet = numHet > numHom; //vc.getHetCount() > vc.getHomVarCount(); // BUGBUG: what to do here for multi sample calls?
final double pTrue = theModel.evaluateVariant( rodVCF.getInfoValues(), rodVCF.getQual(), variantDatum.isHet );
final double recalQual = QualityUtils.phredScaleErrorRate( Math.max( 1.0 - pTrue, 0.000000001) );
if( variantDatum.isKnown ) {
variantDatum.qual = 0.5 * recalQual + 0.5 * KNOWN_VAR_QUAL_PRIOR;
} else {
variantDatum.qual = recalQual;
}
mapList.add( variantDatum );
rodVCF.mCurrentRecord.addInfoField("OQ", ((Double)rodVCF.getQual()).toString() );
rodVCF.mCurrentRecord.setQual( variantDatum.qual );
rodVCF.mCurrentRecord.setFilterString(VCFRecord.UNFILTERED);
vcfWriter.addRecord( rodVCF.mCurrentRecord );
} else { // not a SNP or is filtered so just dump it out to the VCF file
vcfWriter.addRecord( rodVCF.mCurrentRecord );
}
@ -169,6 +195,8 @@ public class ApplyVariantClustersWalker extends RodWalker<ExpandingArrayList<Var
public void onTraversalDone( ExpandingArrayList<VariantDatum> reduceSum ) {
vcfWriter.close();
final VariantDataManager dataManager = new VariantDataManager( reduceSum, theModel.dataManager.annotationKeys );
reduceSum.clear(); // Don't need this ever again, clean up some memory

View File

@ -32,6 +32,6 @@ package org.broadinstitute.sting.playground.gatk.walkers.variantoptimizer;
*/
public interface VariantClusteringModel extends VariantOptimizationInterface {
public void createClusters( final VariantDatum[] data );
public void createClusters( final VariantDatum[] data, final int startCluster, final int stopCluster );
//public void applyClusters( final VariantDatum[] data, final String outputPrefix );
}

View File

@ -0,0 +1,281 @@
package org.broadinstitute.sting.playground.gatk.walkers.variantoptimizer;
import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
import org.broadinstitute.sting.gatk.contexts.variantcontext.VariantContext;
import org.broadinstitute.sting.gatk.datasources.simpleDataSources.ReferenceOrderedDataSource;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.refdata.RodVCF;
import org.broadinstitute.sting.gatk.walkers.RodWalker;
import org.broadinstitute.sting.utils.ExpandingArrayList;
import org.broadinstitute.sting.utils.Pair;
import org.broadinstitute.sting.utils.StingException;
import org.broadinstitute.sting.utils.cmdLine.Argument;
import java.io.IOException;
import java.io.PrintStream;
import java.util.HashMap;
/*
* Copyright (c) 2010 The Broad Institute
*
* Permission is hereby granted, free of charge, to any person
* obtaining a copy of this software and associated documentation
* files (the "Software"), to deal in the Software without
* restriction, including without limitation the rights to use,
* copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following
* conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
*
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
* OTHER DEALINGS IN THE SOFTWARE.
*/
/**
* Calculates variant concordance with the given truth sets and plots ROC curves for each input call set.
*
* @author rpoplin
* @since Mar 18, 2010
*
* @help.summary Calculates variant concordance with the given truth sets and plots ROC curves for each input call set.
*/
public class VariantConcordanceROCCurveWalker extends RodWalker<ExpandingArrayList<Pair<String,VariantDatum>>, HashMap<String,ExpandingArrayList<VariantDatum>>> {
/////////////////////////////
// Command Line Arguments
/////////////////////////////
@Argument(fullName="output_prefix", shortName="output", doc="The prefix added to output VCF file name and optimization curve pdf file name", required=false)
private String OUTPUT_PREFIX = "optimizer";
@Argument(fullName = "path_to_Rscript", shortName = "Rscript", doc = "The path to your implementation of Rscript. For Broad users this is probably /broad/tools/apps/R-2.6.0/bin/Rscript", required = false)
private String PATH_TO_RSCRIPT = "/broad/tools/apps/R-2.6.0/bin/Rscript";
@Argument(fullName = "path_to_resources", shortName = "resources", doc = "Path to resources folder holding the Sting R scripts.", required = false)
private String PATH_TO_RESOURCES = "R/";
/////////////////////////////
// Private Member Variables
/////////////////////////////
// One ROC curve is produced per non-truth input ROD track; these are the track names, in order.
private final ExpandingArrayList<String> inputRodNames = new ExpandingArrayList<String>();
private int numCurves;
// Per-curve counts of truth-set sites where the call set made no passing polymorphic call.
// These are independent of the quality cutoff, hence accumulated once ("Global") during map().
private int[] trueNegGlobal;
private int[] falseNegGlobal;
// Sample whose truth genotypes are evaluated; taken from the first input VCF track found.
private String sampleName = null;
//---------------------------------------------------------------------------------------------------------------
//
// initialize
//
//---------------------------------------------------------------------------------------------------------------
// Discover the input VCF tracks (every ROD whose name does not start with "TRUTH") and size the
// per-curve counters accordingly.
public void initialize() {
if( !PATH_TO_RESOURCES.endsWith("/") ) { PATH_TO_RESOURCES = PATH_TO_RESOURCES + "/"; }
for( ReferenceOrderedDataSource rod : this.getToolkit().getRodDataSources() ) {
if( rod != null && !rod.getName().toUpperCase().startsWith("TRUTH") ) {
if( rod.getReferenceOrderedData().iterator().next().get(0) instanceof RodVCF ) {
inputRodNames.add(rod.getName());
if( sampleName == null ) {
sampleName = ((RodVCF)rod.getReferenceOrderedData().iterator().next().get(0)).getSampleNames()[0]; // BUGBUG: single sample calls only for now
}
}
}
}
numCurves = inputRodNames.size();
trueNegGlobal = new int[numCurves];
falseNegGlobal = new int[numCurves];
for( int kkk = 0; kkk < numCurves; kkk++ ) {
trueNegGlobal[kkk] = 0;
falseNegGlobal[kkk] = 0;
}
}
//---------------------------------------------------------------------------------------------------------------
//
// map
//
//---------------------------------------------------------------------------------------------------------------
// For each locus covered by a truth set: emit a (track name, VariantDatum) pair for every input
// call set with a passing polymorphic call here, and bump the cutoff-independent negative
// counters for the call sets that have no such call.
public ExpandingArrayList<Pair<String,VariantDatum>> map( RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context ) {
final ExpandingArrayList<Pair<String,VariantDatum>> mapList = new ExpandingArrayList<Pair<String,VariantDatum>>();
if( tracker == null ) { // For some reason RodWalkers get map calls with null trackers
return mapList;
}
// Determine the truth status of this locus: covered by a truth set at all, and if so whether
// the evaluated sample is truly variant (a called genotype that is not hom-ref).
boolean isInTruthSet = false;
boolean isTrueVariant = false;
for( final VariantContext vc : tracker.getAllVariantContexts(null, context.getLocation(), false, false) ) {
if( vc != null && vc.getName().toUpperCase().startsWith("TRUTH") ) {
if( !vc.getGenotype(sampleName).isNoCall() ) {
isInTruthSet = true;
if( !vc.getGenotype(sampleName).isHomRef() ) {
isTrueVariant = true;
}
}
//if( vc.isPolymorphic() ) { //BUGBUG: I don't think this is the right thing to do here, there are many polymorphic sites in the truth data because there are many samples
// isTrueVariant = true;
//}
}
}
if( !isInTruthSet ) {
return mapList; // jump out if this location isn't in the truth set
}
// curveIndex tracks the position in inputRodNames so the Global counters stay aligned per curve.
int curveIndex = 0;
for( final String inputName : inputRodNames ) {
final VariantContext vc = tracker.getVariantContext( inputName, null, context.getLocation(), false ); // assuming single variant per track per location
if( vc != null && vc.isPolymorphic() && !vc.isFiltered() ) {
// Record this call's quality and truth status for the cutoff sweep in onTraversalDone().
final VariantDatum variantDatum = new VariantDatum();
variantDatum.qual = vc.getPhredScaledQual();
variantDatum.isTrueVariant = isTrueVariant;
mapList.add( new Pair<String,VariantDatum>(inputName, variantDatum) );
} else { // Either not in the call set at all, is a monomorphic call, or was filtered out
if( isTrueVariant ) { // ... but it is a true variant so this is a false negative call
falseNegGlobal[curveIndex]++;
} else { // ... and it is not a variant site so this is a true negative call
trueNegGlobal[curveIndex]++;
}
}
curveIndex++;
}
return mapList;
}
//---------------------------------------------------------------------------------------------------------------
//
// reduce
//
//---------------------------------------------------------------------------------------------------------------
// One list of collected calls per input track name.
public HashMap<String,ExpandingArrayList<VariantDatum>> reduceInit() {
final HashMap<String,ExpandingArrayList<VariantDatum>> init = new HashMap<String,ExpandingArrayList<VariantDatum>>();
for( final String inputName : inputRodNames ) {
init.put( inputName, new ExpandingArrayList<VariantDatum>() );
}
return init;
}
// Append each mapped datum onto its track's list.
public HashMap<String,ExpandingArrayList<VariantDatum>> reduce( final ExpandingArrayList<Pair<String,VariantDatum>> mapValue, final HashMap<String,ExpandingArrayList<VariantDatum>> reduceSum ) {
for( Pair<String,VariantDatum> value : mapValue ) {
final ExpandingArrayList<VariantDatum> list = reduceSum.get(value.getFirst());
list.add(value.getSecond());
reduceSum.put(value.getFirst(),list); // NOTE(review): redundant put — 'list' is already the value mapped to this key
}
return reduceSum;
}
// Sweep NUM_STEPS quality cutoffs per call set, compute (sensitivity, 1 - specificity) at each
// cutoff, write the curves to <OUTPUT_PREFIX>.dat, and invoke the R plotting script.
public void onTraversalDone( HashMap<String,ExpandingArrayList<VariantDatum>> reduceSum ) {
final int NUM_CURVES = numCurves;
final HashMap<String, VariantDataManager> dataManagerMap = new HashMap<String, VariantDataManager>();
for( final String inputName : inputRodNames ) {
System.out.println("Creating data manager for: " + inputName);
dataManagerMap.put(inputName, new VariantDataManager( reduceSum.get(inputName), null ));
}
reduceSum.clear(); // Don't need this ever again, clean up some memory
// Per-curve observed quality range; the cutoff steps from min to max in NUM_STEPS increments.
final double[] minQual = new double[NUM_CURVES];
final double[] maxQual = new double[NUM_CURVES];
final double[] incrementQual = new double[NUM_CURVES];
final double[] qualCut = new double[NUM_CURVES];
final int NUM_STEPS = 200;
int curveIndex = 0;
for( final String inputName : inputRodNames ) {
final VariantDataManager dataManager = dataManagerMap.get(inputName);
minQual[curveIndex] = dataManager.data[0].qual;
maxQual[curveIndex] = dataManager.data[0].qual;
for( int iii = 1; iii < dataManager.data.length; iii++ ) {
final double qual = dataManager.data[iii].qual;
if( qual < minQual[curveIndex] ) { minQual[curveIndex] = qual; }
else if( qual > maxQual[curveIndex] ) { maxQual[curveIndex] = qual; }
}
incrementQual[curveIndex] = (maxQual[curveIndex] - minQual[curveIndex]) / ((double)NUM_STEPS);
qualCut[curveIndex] = minQual[curveIndex];
curveIndex++;
}
final int[] truePos = new int[NUM_CURVES];
final int[] falsePos = new int[NUM_CURVES];
final int[] trueNeg = new int[NUM_CURVES];
final int[] falseNeg = new int[NUM_CURVES];
PrintStream outputFile;
try {
outputFile = new PrintStream( OUTPUT_PREFIX + ".dat" );
} catch (Exception e) {
throw new StingException( "Unable to create output file: " + OUTPUT_PREFIX + ".dat" ); // NOTE(review): original exception cause is dropped
}
// Header row: one (name, sensitivityN, specificityN) column triple per curve plus a sentinel;
// this layout is what the R script plot_variantROCCurve.R expects.
int jjj = 1;
for( final String inputName : inputRodNames ) {
outputFile.print(inputName + ",sensitivity" + jjj + ",specificity" + jjj + ",");
jjj++;
}
outputFile.println("sentinel");
for( int step = 0; step < NUM_STEPS; step++ ) {
curveIndex = 0;
for( final String inputName : inputRodNames ) {
final VariantDataManager dataManager = dataManagerMap.get(inputName);
truePos[curveIndex] = 0;
falsePos[curveIndex] = 0;
trueNeg[curveIndex] = 0;
falseNeg[curveIndex] = 0;
final int NUM_VARIANTS = dataManager.data.length;
// Classify every collected call against the current hypothetical quality cutoff.
for( int iii = 0; iii < NUM_VARIANTS; iii++ ) {
if( dataManager.data[iii].qual >= qualCut[curveIndex] ) { // this var is in this hypothetical call set
if( dataManager.data[iii].isTrueVariant ) {
truePos[curveIndex]++;
} else {
falsePos[curveIndex]++;
}
} else { // this var is out of this hypothetical call set
if( dataManager.data[iii].isTrueVariant ) {
falseNeg[curveIndex]++;
} else {
trueNeg[curveIndex]++;
}
}
}
// Fold in the cutoff-independent "Global" counts from map() so absent/filtered calls
// at truth sites are always counted as negatives, regardless of cutoff.
final double sensitivity = ((double) truePos[curveIndex]) / ((double) truePos[curveIndex] + falseNegGlobal[curveIndex] + falseNeg[curveIndex]);
final double specificity = ((double) trueNegGlobal[curveIndex] + trueNeg[curveIndex]) /
((double) falsePos[curveIndex] + trueNegGlobal[curveIndex] + trueNeg[curveIndex]);
outputFile.print( String.format("%.4f,%.4f,%.4f,", qualCut[curveIndex], sensitivity, 1.0 - specificity) );
qualCut[curveIndex] += incrementQual[curveIndex];
curveIndex++;
}
outputFile.println("-1"); // value for the sentinel column, closing the row
}
outputFile.close();
// Execute Rscript command to plot the optimization curve
// Print out the command line to make it clear to the user what is being executed and how one might modify it
final String rScriptCommandLine = PATH_TO_RSCRIPT + " " + PATH_TO_RESOURCES + "plot_variantROCCurve.R" + " " + OUTPUT_PREFIX + ".dat";
System.out.println( rScriptCommandLine );
// Execute the RScript command to plot the table of truth values
try {
Runtime.getRuntime().exec( rScriptCommandLine ); // NOTE(review): exit status and process output are not checked; plot failures are silent
} catch ( IOException e ) {
throw new StingException( "Unable to execute RScript command: " + rScriptCommandLine ); // NOTE(review): original exception cause is dropped
}
}
}

View File

@ -77,8 +77,8 @@ public class VariantDataManager {
annotationKeys = new ExpandingArrayList<String>();
int jjj = 0;
for( String line : annotationLines ) {
String[] vals = line.split(",");
for( final String line : annotationLines ) {
final String[] vals = line.split(",");
annotationKeys.add(vals[1]);
meanVector[jjj] = Double.parseDouble(vals[2]);
varianceVector[jjj] = Double.parseDouble(vals[3]);
@ -87,12 +87,15 @@ public class VariantDataManager {
}
public void normalizeData() {
boolean foundZeroVarianceAnnotation = false;
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
final double theMean = mean(data, jjj);
final double theSTD = standardDeviation(data, theMean, jjj);
System.out.println( (jjj == numAnnotations-1 ? "QUAL" : annotationKeys.get(jjj)) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD) );
if( theSTD < 1E-8 ) {
throw new StingException("Zero variance is a problem: standard deviation = " + theSTD);
foundZeroVarianceAnnotation = true;
System.out.println("Zero variance is a problem: standard deviation = " + theSTD);
System.out.println("User must -exclude annotations with zero variance. Annotation = " + (jjj == numAnnotations-1 ? "QUAL" : annotationKeys.get(jjj)));
}
meanVector[jjj] = theMean;
varianceVector[jjj] = theSTD;
@ -101,6 +104,9 @@ public class VariantDataManager {
}
}
isNormalized = true; // Each data point is now [ (x - mean) / standard deviation ]
if( foundZeroVarianceAnnotation ) {
throw new StingException("Found annotations with zero variance. They must be excluded before proceeding.");
}
}
private static double mean( final VariantDatum[] data, final int index ) {

View File

@ -36,5 +36,6 @@ public class VariantDatum {
public boolean isTransition;
public boolean isKnown;
public boolean isTrueVariant;
public boolean isHet;
public double qual;
}

View File

@ -55,6 +55,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
private final double[][] mu; // The means for the clusters
private final double[][] sigma; // The variances for the clusters, sigma is really sigma^2
private final double[] pCluster;
private final boolean[] isHetCluster;
private final int[] numMaxClusterKnown;
private final int[] numMaxClusterNovel;
private final double[] clusterTITV;
@ -68,12 +69,13 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
public VariantGaussianMixtureModel( final VariantDataManager _dataManager, final double _targetTITV, final int _numGaussians, final int _numIterations, final int _minVarInCluster ) {
super( _targetTITV );
dataManager = _dataManager;
numGaussians = _numGaussians;
numGaussians = ( _numGaussians % 2 == 0 ? _numGaussians : _numGaussians + 1 );
numIterations = _numIterations;
mu = new double[numGaussians][];
sigma = new double[numGaussians][];
pCluster = new double[numGaussians];
isHetCluster = null;
numMaxClusterKnown = new int[numGaussians];
numMaxClusterNovel = new int[numGaussians];
clusterTITV = new double[numGaussians];
@ -83,8 +85,8 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
public VariantGaussianMixtureModel( final String clusterFileName ) {
super( 0.0 );
ExpandingArrayList<String> annotationLines = new ExpandingArrayList<String>();
ExpandingArrayList<String> clusterLines = new ExpandingArrayList<String>();
final ExpandingArrayList<String> annotationLines = new ExpandingArrayList<String>();
final ExpandingArrayList<String> clusterLines = new ExpandingArrayList<String>();
try {
for ( String line : new xReadLines(new File( clusterFileName )) ) {
@ -102,24 +104,28 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
dataManager = new VariantDataManager( annotationLines );
numIterations = 0;
pCluster = null;
numMaxClusterKnown = null;
numMaxClusterNovel = null;
clusterTITV = null;
minVarInCluster = 0;
//BUGBUG: move this parsing out of the constructor
numGaussians = clusterLines.size();
mu = new double[numGaussians][dataManager.numAnnotations];
sigma = new double[numGaussians][dataManager.numAnnotations];
pCluster = new double[numGaussians];
isHetCluster = new boolean[numGaussians];
clusterTruePositiveRate = new double[numGaussians];
int kkk = 0;
for( String line : clusterLines ) {
String[] vals = line.split(",");
clusterTruePositiveRate[kkk] = Double.parseDouble(vals[4]);
final String[] vals = line.split(",");
isHetCluster[kkk] = Integer.parseInt(vals[1]) == 1;
pCluster[kkk] = Double.parseDouble(vals[2]);
clusterTruePositiveRate[kkk] = Double.parseDouble(vals[6]); //BUGBUG: #define these magic index numbers, very easy to make a mistake here
for( int jjj = 0; jjj < dataManager.numAnnotations; jjj++ ) {
mu[kkk][jjj] = Double.parseDouble(vals[5+jjj]);
sigma[kkk][jjj] = Double.parseDouble(vals[5+dataManager.numAnnotations+jjj]);
mu[kkk][jjj] = Double.parseDouble(vals[7+jjj]);
sigma[kkk][jjj] = Double.parseDouble(vals[7+dataManager.numAnnotations+jjj]) * 3; //BUGBUG: *1.3, suggestion by Nick to prevent GMM from over fitting and producing low likelihoods for most points
}
kkk++;
}
@ -129,13 +135,28 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
public final void run( final String clusterFileName ) {
final int MAX_VARS = 1000000; //BUGBUG: make this a command line argument
// Create the subset of the data to cluster with
int numNovel = 0;
int numKnown = 0;
int numHet = 0;
int numHom = 0;
for( final VariantDatum datum : dataManager.data ) {
if( !datum.isKnown ) {
if( datum.isKnown ) {
numKnown++;
} else {
numNovel++;
}
if( datum.isHet ) {
numHet++;
} else {
numHom++;
}
}
// This block of code is used to cluster with novels + 1.5x knowns mixed together
/*
VariantDatum[] data;
// Grab a set of data that is all of the novel variants plus 1.5x as many known variants drawn at random
@ -167,9 +188,235 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
printClusters( clusterFileName );
//System.out.println("Applying clusters to all variants...");
//applyClusters( dataManager.data, outputPrefix ); // Using all the data
*/
// This block of code is to cluster knowns and novels separately
/*
final VariantDatum[] dataNovel = new VariantDatum[Math.min(numNovel,MAX_VARS)];
final VariantDatum[] dataKnown = new VariantDatum[Math.min(numKnown,MAX_VARS)];
//BUGBUG: This is ugly
int jjj = 0;
if(numNovel <= MAX_VARS) {
for( final VariantDatum datum : dataManager.data ) {
if( !datum.isKnown ) {
dataNovel[jjj++] = datum;
}
}
} else {
System.out.println("Capped at " + MAX_VARS + " novel variants.");
while( jjj < MAX_VARS ) {
final VariantDatum datum = dataManager.data[rand.nextInt(dataManager.numVariants)];
if( !datum.isKnown ) {
dataNovel[jjj++] = datum;
}
}
}
int iii = 0;
if(numKnown <= MAX_VARS) {
for( final VariantDatum datum : dataManager.data ) {
if( datum.isKnown ) {
dataKnown[iii++] = datum;
}
}
} else {
System.out.println("Capped at " + MAX_VARS + " known variants.");
while( iii < MAX_VARS ) {
final VariantDatum datum = dataManager.data[rand.nextInt(dataManager.numVariants)];
if( datum.isKnown ) {
dataKnown[iii++] = datum;
}
}
}
System.out.println("Clustering with " + Math.min(numNovel,MAX_VARS) + " novel variants.");
createClusters( dataNovel, 0, numGaussians / 2 );
System.out.println("Clustering with " + Math.min(numKnown,MAX_VARS) + " known variants.");
createClusters( dataKnown, numGaussians / 2, numGaussians );
System.out.println("Outputting cluster parameters...");
printClusters( clusterFileName );
*/
// This block of code is to cluster het and hom calls separately, but mixing together knowns and novels
final VariantDatum[] dataHet = new VariantDatum[Math.min(numHet,MAX_VARS)];
final VariantDatum[] dataHom = new VariantDatum[Math.min(numHom,MAX_VARS)];
//BUGBUG: This is ugly
int jjj = 0;
if(numHet <= MAX_VARS) {
for( final VariantDatum datum : dataManager.data ) {
if( datum.isHet ) {
dataHet[jjj++] = datum;
}
}
} else {
System.out.println("Found " + numHet + " het variants but capped at clustering with " + MAX_VARS + ".");
while( jjj < MAX_VARS ) {
final VariantDatum datum = dataManager.data[rand.nextInt(dataManager.numVariants)];
if( datum.isHet ) {
dataHet[jjj++] = datum;
}
}
}
int iii = 0;
if(numHom <= MAX_VARS) {
for( final VariantDatum datum : dataManager.data ) {
if( !datum.isHet ) {
dataHom[iii++] = datum;
}
}
} else {
System.out.println("Found " + numHom + " hom variants but capped at clustering with " + MAX_VARS + ".");
while( iii < MAX_VARS ) {
final VariantDatum datum = dataManager.data[rand.nextInt(dataManager.numVariants)];
if( !datum.isHet ) {
dataHom[iii++] = datum;
}
}
}
System.out.println("Clustering with " + Math.min(numHet,MAX_VARS) + " het variants.");
createClusters( dataHet, 0, numGaussians / 2 );
System.out.println("Clustering with " + Math.min(numHom,MAX_VARS) + " hom variants.");
createClusters( dataHom, numGaussians / 2, numGaussians );
System.out.println("Outputting cluster parameters...");
printClusters( clusterFileName );
}
public final void createClusters( final VariantDatum[] data ) {
public final void createClusters( final VariantDatum[] data, int startCluster, int stopCluster ) {
final int numVariants = data.length;
final int numAnnotations = data[0].annotations.length;
final double[][] pVarInCluster = new double[numGaussians][numVariants];
final double[] probTi = new double[numGaussians];
final double[] probTv = new double[numGaussians];
final double[] probKnown = new double[numGaussians];
final double[] probNovel = new double[numGaussians];
final double[] probKnownTi = new double[numGaussians];
final double[] probKnownTv = new double[numGaussians];
final double[] probNovelTi = new double[numGaussians];
final double[] probNovelTv = new double[numGaussians];
// loop control variables:
// iii - loop over data points
// jjj - loop over annotations (features)
// kkk - loop over clusters
// ttt - loop over EM iterations
// Set up the initial random Gaussians
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
numMaxClusterKnown[kkk] = 0;
numMaxClusterNovel[kkk] = 0;
pCluster[kkk] = 1.0 / ((double) (stopCluster - startCluster));
mu[kkk] = data[rand.nextInt(numVariants)].annotations;
final double[] randSigma = new double[numAnnotations];
if( dataManager.isNormalized ) {
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
randSigma[jjj] = 0.7 + 0.4 * rand.nextDouble();
}
} else { // BUGBUG: if not normalized then the varianceVector hasn't been calculated --> null pointer
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
randSigma[jjj] = dataManager.varianceVector[jjj] + ((1.0 + rand.nextDouble()) * 0.01 * dataManager.varianceVector[jjj]);
}
}
sigma[kkk] = randSigma;
}
for( int ttt = 0; ttt < numIterations; ttt++ ) {
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Expectation Step (calculate the probability that each data point is in each cluster)
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
evaluateGaussians( data, pVarInCluster, startCluster, stopCluster );
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Maximization Step (move the clusters to maximize the sum probability of each data point)
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
maximizeGaussians( data, pVarInCluster, startCluster, stopCluster );
System.out.println("Finished iteration " + (ttt+1) );
}
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Evaluate the clusters using titv as an estimate of the true positive rate
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
evaluateGaussians( data, pVarInCluster, startCluster, stopCluster ); // One final evaluation because the Gaussians moved in the last maximization step
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
probTi[kkk] = 0.0;
probTv[kkk] = 0.0;
probKnown[kkk] = 0.0;
probNovel[kkk] = 0.0;
}
for( int iii = 0; iii < numVariants; iii++ ) {
final boolean isTransition = data[iii].isTransition;
final boolean isKnown = data[iii].isKnown;
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
final double prob = pVarInCluster[kkk][iii];
if( isKnown ) { // known
probKnown[kkk] += prob;
if( isTransition ) { // transition
probKnownTi[kkk] += prob;
probTi[kkk] += prob;
} else { // transversion
probKnownTv[kkk] += prob;
probTv[kkk] += prob;
}
} else { //novel
probNovel[kkk] += prob;
if( isTransition ) { // transition
probNovelTi[kkk] += prob;
probTi[kkk] += prob;
} else { // transversion
probNovelTv[kkk] += prob;
probTv[kkk] += prob;
}
}
}
double maxProb = pVarInCluster[startCluster][iii];
int maxCluster = startCluster;
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
if( pVarInCluster[kkk][iii] > maxProb ) {
maxProb = pVarInCluster[kkk][iii];
maxCluster = kkk;
}
}
if( isKnown ) {
numMaxClusterKnown[maxCluster]++;
} else {
numMaxClusterNovel[maxCluster]++;
}
}
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
clusterTITV[kkk] = probTi[kkk] / probTv[kkk];
if( probKnown[kkk] > 2000.0 ) { // BUGBUG: make this a command line argument, parameterize performance based on this important argument
clusterTruePositiveRate[kkk] = calcTruePositiveRateFromKnownTITV( probKnownTi[kkk] / probKnownTv[kkk], probNovelTi[kkk] / probNovelTv[kkk] );
} else {
clusterTruePositiveRate[kkk] = calcTruePositiveRateFromTITV( clusterTITV[kkk] );
}
}
}
// This cluster method doesn't make use of the differences between known and novel Ti/Tv ratios
/*
public final void createClusters( final VariantDatum[] data, final int startCluster, final int stopCluster ) {
final int numVariants = data.length;
final int numAnnotations = data[0].annotations.length;
@ -185,15 +432,15 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
// ttt - loop over EM iterations
// Set up the initial random Gaussians
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
numMaxClusterKnown[kkk] = 0;
numMaxClusterNovel[kkk] = 0;
pCluster[kkk] = 1.0 / ((double) numGaussians);
pCluster[kkk] = 1.0 / ((double) (stopCluster - startCluster));
mu[kkk] = data[rand.nextInt(numVariants)].annotations;
final double[] randSigma = new double[numAnnotations];
if( dataManager.isNormalized ) {
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
randSigma[jjj] = 0.75 + 0.4 * rand.nextDouble();
randSigma[jjj] = 0.7 + 0.4 * rand.nextDouble();
}
} else { // BUGBUG: if not normalized then the varianceVector hasn't been calculated --> null pointer
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
@ -209,12 +456,12 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Expectation Step (calculate the probability that each data point is in each cluster)
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
evaluateGaussians( data, pVarInCluster );
evaluateGaussians( data, pVarInCluster, startCluster, stopCluster );
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Maximization Step (move the clusters to maximize the sum probability of each data point)
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
maximizeGaussians( data, pVarInCluster );
maximizeGaussians( data, pVarInCluster, startCluster, stopCluster );
System.out.println("Finished iteration " + (ttt+1) );
}
@ -222,28 +469,28 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// Evaluate the clusters using titv as an estimate of the true positive rate
/////////////////////////////////////////////////////////////////////////////////////////////////////////////////
evaluateGaussians( data, pVarInCluster ); // One final evaluation because the Gaussians moved in the last maximization step
evaluateGaussians( data, pVarInCluster, startCluster, stopCluster ); // One final evaluation because the Gaussians moved in the last maximization step
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
probTi[kkk] = 0.0;
probTv[kkk] = 0.0;
}
// Use the cluster's probabilistic Ti/Tv ratio as the indication of the cluster's true positive rate
for( int iii = 0; iii < numVariants; iii++ ) {
if( data[iii].isTransition ) { // transition
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
probTi[kkk] += pVarInCluster[kkk][iii];
}
} else { // transversion
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
probTv[kkk] += pVarInCluster[kkk][iii];
}
}
// Calculate which cluster has the maximum probability for this variant for use as a metric of how well clustered the data is
double maxProb = pVarInCluster[0][iii];
int maxCluster = 0;
for( int kkk = 1; kkk < numGaussians; kkk++ ) {
double maxProb = pVarInCluster[startCluster][iii];
int maxCluster = startCluster;
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
if( pVarInCluster[kkk][iii] > maxProb ) {
maxProb = pVarInCluster[kkk][iii];
maxCluster = kkk;
@ -255,11 +502,13 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
numMaxClusterNovel[maxCluster]++;
}
}
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
clusterTITV[kkk] = probTi[kkk] / probTv[kkk];
clusterTruePositiveRate[kkk] = calcTruePositiveRateFromTITV( clusterTITV[kkk] );
}
}
*/
private void printClusters( final String clusterFileName ) {
try {
@ -269,6 +518,8 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
if( numMaxClusterKnown[kkk] + numMaxClusterNovel[kkk] >= minVarInCluster ) {
outputFile.print("@!CLUSTER,");
outputFile.print( (kkk < numGaussians / 2 ? 1 : 0) + "," ); // is het cluster?
outputFile.print(pCluster[kkk] + ",");
outputFile.print(numMaxClusterKnown[kkk] + ",");
outputFile.print(numMaxClusterNovel[kkk] + ",");
outputFile.print(clusterTITV[kkk] + ",");
@ -284,12 +535,11 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
}
outputFile.close();
} catch (Exception e) {
e.printStackTrace();
System.exit(-1);
throw new StingException( "Unable to create output file: " + clusterFileName );
}
}
public final double evaluateVariant( final Map<String,String> annotationMap, final double qualityScore ) {
public final double evaluateVariant( final Map<String,String> annotationMap, final double qualityScore, final boolean isHet ) {
final double[] pVarInCluster = new double[numGaussians];
final double[] annotations = new double[dataManager.numAnnotations];
@ -314,11 +564,13 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
annotations[jjj] = (value - dataManager.meanVector[jjj]) / dataManager.varianceVector[jjj];
}
evaluateGaussiansForSingleVariant( annotations, pVarInCluster );
evaluateGaussiansForSingleVariant( annotations, pVarInCluster, isHet );
double sum = 0;
double sum = 0.0;
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
sum += pVarInCluster[kkk] * clusterTruePositiveRate[kkk];
if( isHetCluster[kkk] == isHet ) {
sum += pVarInCluster[kkk] * clusterTruePositiveRate[kkk];
}
}
return sum;
@ -333,12 +585,11 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
markedVariant[iii] = false;
}
PrintStream outputFile = null;
PrintStream outputFile;
try {
outputFile = new PrintStream( outputPrefix + ".dat" );
} catch (Exception e) {
e.printStackTrace();
System.exit(-1);
throw new StingException( "Unable to create output file: " + outputPrefix + ".dat" );
}
int numKnown = 0;
@ -382,21 +633,21 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
foundDesiredNumVariants = true;
}
outputFile.println( pCut + "," + numKnown + "," + numNovel + "," +
( numKnownTi ==0 || numKnownTv == 0 ? "NaN" : ( ((double)numKnownTi) / ((double)numKnownTv) ) ) + "," +
( numNovelTi ==0 || numNovelTv == 0 ? "NaN" : ( ((double)numNovelTi) / ((double)numNovelTv) )));
( numKnownTi == 0 || numKnownTv == 0 ? "NaN" : ( ((double)numKnownTi) / ((double)numKnownTv) ) ) + "," +
( numNovelTi == 0 || numNovelTv == 0 ? "NaN" : ( ((double)numNovelTi) / ((double)numNovelTv) )));
}
outputFile.close();
}
private void evaluateGaussians( final VariantDatum[] data, final double[][] pVarInCluster ) {
private void evaluateGaussians( final VariantDatum[] data, final double[][] pVarInCluster, final int startCluster, final int stopCluster ) {
final int numAnnotations = data[0].annotations.length;
for( int iii = 0; iii < data.length; iii++ ) {
double sumProb = 0.0;
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
double sum = 0.0;
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
sum += ( (data[iii].annotations[jjj] - mu[kkk][jjj]) * (data[iii].annotations[jjj] - mu[kkk][jjj]) )
@ -412,7 +663,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
}
if( sumProb > MIN_SUM_PROB ) { // Very small numbers are a very big problem
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
pVarInCluster[kkk][iii] /= sumProb;
}
}
@ -420,48 +671,53 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
}
private void evaluateGaussiansForSingleVariant( final double[] annotations, final double[] pVarInCluster ) {
private void evaluateGaussiansForSingleVariant( final double[] annotations, final double[] pVarInCluster, final boolean isHet ) {
final int numAnnotations = annotations.length;
double sumProb = 0.0;
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
double sum = 0.0;
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
sum += ( (annotations[jjj] - mu[kkk][jjj]) * (annotations[jjj] - mu[kkk][jjj]) )
/ sigma[kkk][jjj];
}
//BUGBUG: removed pCluster[kkk]*, for the second pass
//pVarInCluster[kkk] = pCluster[kkk] * Math.exp( -0.5 * sum );
pVarInCluster[kkk] = Math.exp( -0.5 * sum );
if( isHetCluster[kkk] == isHet ) {
double sum = 0.0;
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
sum += ( (annotations[jjj] - mu[kkk][jjj]) * (annotations[jjj] - mu[kkk][jjj]) )
/ sigma[kkk][jjj];
}
if( pVarInCluster[kkk] < MIN_PROB) { // Very small numbers are a very big problem
pVarInCluster[kkk] = MIN_PROB;
}
//BUGBUG: pCluster[kkk] should be removed here?
pVarInCluster[kkk] = pCluster[kkk] * Math.exp( -0.5 * sum );
//pVarInCluster[kkk] = Math.exp( -0.5 * sum );
sumProb += pVarInCluster[kkk];
if( pVarInCluster[kkk] < MIN_PROB) { // Very small numbers are a very big problem
pVarInCluster[kkk] = MIN_PROB;
}
sumProb += pVarInCluster[kkk];
}
}
if( sumProb > MIN_SUM_PROB ) { // Very small numbers are a very big problem
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
pVarInCluster[kkk] /= sumProb;
if( isHetCluster[kkk] == isHet ) {
pVarInCluster[kkk] /= sumProb;
}
}
}
}
private void maximizeGaussians( final VariantDatum[] data, final double[][] pVarInCluster ) {
private void maximizeGaussians( final VariantDatum[] data, final double[][] pVarInCluster, final int startCluster, final int stopCluster ) {
final int numVariants = data.length;
final int numAnnotations = data[0].annotations.length;
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
mu[kkk][jjj] = 0.0;
sigma[kkk][jjj] = 0.0;
}
}
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
double sumProb = 0.0;
for( int iii = 0; iii < numVariants; iii++ ) {
final double prob = pVarInCluster[kkk][iii];
@ -475,7 +731,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
mu[kkk][jjj] /= sumProb;
}
}
} //BUGBUG: clean up dead clusters to speed up computation
for( int iii = 0; iii < numVariants; iii++ ) {
final double prob = pVarInCluster[kkk][iii];
@ -499,9 +755,10 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
pCluster[kkk] = sumProb / numVariants;
}
// Clean up extra big or extra small clusters
//BUGBUG: Is this a good idea?
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
if( pCluster[kkk] > 0.45 ) { // This is a very large cluster compared to all the others
final int numToReplace = 4;
final double[] savedSigma = sigma[kkk];
@ -513,7 +770,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
randVarIndex = rand.nextInt( numVariants );
final double probK = pVarInCluster[kkk][randVarIndex];
boolean inClusterK = true;
for( int ccc = 0; ccc < numGaussians; ccc++ ) {
for( int ccc = startCluster; ccc < stopCluster; ccc++ ) {
if( pVarInCluster[ccc][randVarIndex] > probK ) {
inClusterK = false;
break;
@ -525,11 +782,11 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
// Find a place to put the example variant
if( rrr == 0 ) { // Replace the big cluster that kicked this process off
mu[kkk] = data[randVarIndex].annotations;
pCluster[kkk] = 1.0 / ((double) numGaussians);
pCluster[kkk] = 1.0 / ((double) (stopCluster-startCluster));
} else { // Replace the cluster with the minimum prob
double minProb = pCluster[0];
int minClusterIndex = 0;
for( int ccc = 1; ccc < numGaussians; ccc++ ) {
double minProb = pCluster[startCluster];
int minClusterIndex = startCluster;
for( int ccc = startCluster; ccc < stopCluster; ccc++ ) {
if( pCluster[ccc] < minProb ) {
minProb = pCluster[ccc];
minClusterIndex = ccc;
@ -543,21 +800,22 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
sigma[minClusterIndex][jjj] = MIN_SUM_PROB;
}
}
pCluster[minClusterIndex] = 1.0 / ((double) numGaussians);
pCluster[minClusterIndex] = 1.0 / ((double) (stopCluster-startCluster));
}
}
}
}
// Replace small clusters with another random draw from the dataset
for( int kkk = 0; kkk < numGaussians; kkk++ ) {
if( pCluster[kkk] < 0.05 * (1.0 / ((double) numGaussians)) ) { // This is a very small cluster compared to all the others
pCluster[kkk] = 1.0 / ((double) numGaussians);
for( int kkk = startCluster; kkk < stopCluster; kkk++ ) {
if( pCluster[kkk] < 0.07 * (1.0 / ((double) (stopCluster-startCluster))) ) { // This is a very small cluster compared to all the others
pCluster[kkk] = 1.0 / ((double) (stopCluster-startCluster));
mu[kkk] = data[rand.nextInt(numVariants)].annotations;
final double[] randSigma = new double[numAnnotations];
if( dataManager.isNormalized ) {
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
randSigma[jjj] = 0.6 + 0.4 * rand.nextDouble(); // Explore a wider range of possible sigma values since we are tossing out clusters anyway
randSigma[jjj] = 0.7 + 0.4 * rand.nextDouble(); // BUGBUG: Explore a wider range of possible sigma values since we are tossing out clusters anyway?
}
} else { // BUGBUG: if not normalized then the varianceVector hasn't been calculated --> null pointer
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
@ -567,6 +825,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
sigma[kkk] = randSigma;
}
}
}
}

View File

@ -62,12 +62,11 @@ public final class VariantNearestNeighborsModel extends VariantOptimizationModel
pTrueVariant[iii] = calcTruePositiveRateFromTITV( vTree.calcNeighborhoodTITV( dataManager.data[iii] ) );
}
PrintStream outputFile = null;
PrintStream outputFile;
try {
outputFile = new PrintStream( outputPrefix + ".knn.optimize" );
} catch (Exception e) {
e.printStackTrace();
System.exit(-1);
throw new StingException( "Unable to create output file: " + outputPrefix + ".knn.optimize" );
}
for(int iii = 0; iii < numVariants; iii++) {
outputFile.print(String.format("%.4f",pTrueVariant[iii]) + ",");

View File

@ -51,4 +51,10 @@ public abstract class VariantOptimizationModel implements VariantOptimizationInt
//if( titv < 0.0 ) { titv = 0.0; }
//return ( titv / targetTITV );
}
public final double calcTruePositiveRateFromKnownTITV( final double knownTITV, double novelTITV ) {
if( novelTITV > knownTITV ) { novelTITV -= 2.0f*(novelTITV-knownTITV); }
if( novelTITV < 0.5 ) { novelTITV = 0.5; }
return ( (novelTITV - 0.5) / (knownTITV - 0.5) );
}
}

View File

@ -10,7 +10,9 @@ import org.broadinstitute.sting.utils.ExpandingArrayList;
import org.broadinstitute.sting.utils.StingException;
import org.broadinstitute.sting.utils.cmdLine.Argument;
import java.io.IOException;
import java.util.Arrays;
import java.util.Set;
import java.util.TreeSet;
/*
* Copyright (c) 2010 The Broad Institute
@ -53,8 +55,10 @@ public class VariantOptimizer extends RodWalker<ExpandingArrayList<VariantDatum>
/////////////////////////////
@Argument(fullName="target_titv", shortName="titv", doc="The target Ti/Tv ratio towards which to optimize. (~~2.1 for whole genome experiments)", required=true)
private double TARGET_TITV = 2.1;
@Argument(fullName="ignore_input_filters", shortName="ignoreFilters", doc="If specified the optimizer will use variants even if the FILTER column is marked in the VCF file", required=false)
private boolean IGNORE_INPUT_FILTERS = false;
@Argument(fullName="ignore_all_input_filters", shortName="ignoreAllFilters", doc="If specified the optimizer will use variants even if the FILTER column is marked in the VCF file", required=false)
private boolean IGNORE_ALL_INPUT_FILTERS = false;
@Argument(fullName="ignore_filter", shortName="ignoreFilter", doc="If specified the optimizer will use variants even if the specified filter name is marked in the input VCF file", required=false)
private String[] IGNORE_INPUT_FILTERS = null;
@Argument(fullName="exclude_annotation", shortName="exclude", doc="The names of the annotations which should be excluded from the calculations", required=false)
private String[] EXCLUDED_ANNOTATIONS = null;
@Argument(fullName="force_annotation", shortName="force", doc="The names of the annotations which should be forced into the calculations even if they aren't present in every variant", required=false)
@ -62,11 +66,11 @@ public class VariantOptimizer extends RodWalker<ExpandingArrayList<VariantDatum>
@Argument(fullName="clusterFile", shortName="clusterFile", doc="The output cluster file", required=true)
private String CLUSTER_FILENAME = "optimizer.cluster";
@Argument(fullName="numGaussians", shortName="nG", doc="The number of Gaussians to be used in the Gaussian Mixture model", required=false)
private int NUM_GAUSSIANS = 32;
private int NUM_GAUSSIANS = 100;
@Argument(fullName="numIterations", shortName="nI", doc="The number of iterations to be performed in the Gaussian Mixture model", required=false)
private int NUM_ITERATIONS = 10;
@Argument(fullName="minVarInCluster", shortName="minVar", doc="The minimum number of variants in a cluster to be considered a valid cluster. Used to prevent overfitting. Default is 2000.", required=true)
private int MIN_VAR_IN_CLUSTER = 2000;
private int NUM_ITERATIONS = 7;
@Argument(fullName="minVarInCluster", shortName="minVar", doc="The minimum number of variants in a cluster to be considered a valid cluster. It can be used to prevent overfitting.", required=false)
private int MIN_VAR_IN_CLUSTER = 0;
//@Argument(fullName="knn", shortName="knn", doc="The number of nearest neighbors to be used in the k-Nearest Neighbors model", required=false)
//private int NUM_KNN = 2000;
@ -80,6 +84,7 @@ public class VariantOptimizer extends RodWalker<ExpandingArrayList<VariantDatum>
private boolean firstVariant = true;
private int numAnnotations = 0;
private static final double INFINITE_ANNOTATION_VALUE = 10000.0;
private Set<String> ignoreInputFilterSet = null;
//---------------------------------------------------------------------------------------------------------------
//
@ -88,7 +93,9 @@ public class VariantOptimizer extends RodWalker<ExpandingArrayList<VariantDatum>
//---------------------------------------------------------------------------------------------------------------
public void initialize() {
//if( !PATH_TO_RESOURCES.endsWith("/") ) { PATH_TO_RESOURCES = PATH_TO_RESOURCES + "/"; }
if( IGNORE_INPUT_FILTERS != null ) {
ignoreInputFilterSet = new TreeSet<String>(Arrays.asList(IGNORE_INPUT_FILTERS));
}
}
//---------------------------------------------------------------------------------------------------------------
@ -109,49 +116,52 @@ public class VariantOptimizer extends RodWalker<ExpandingArrayList<VariantDatum>
for( final VariantContext vc : tracker.getAllVariantContexts(null, context.getLocation(), false, false) ) {
if( vc != null && vc.isSNP() && (IGNORE_INPUT_FILTERS || !vc.isFiltered()) ) {
if( firstVariant ) { // This is the first variant encountered so set up the list of annotations
annotationKeys.addAll( vc.getAttributes().keySet() );
if( annotationKeys.contains("ID") ) { annotationKeys.remove("ID"); } // ID field is added to the vc's INFO field?
if( annotationKeys.contains("DB") ) { annotationKeys.remove("DB"); }
if( EXCLUDED_ANNOTATIONS != null ) {
for( final String excludedAnnotation : EXCLUDED_ANNOTATIONS ) {
if( annotationKeys.contains( excludedAnnotation ) ) { annotationKeys.remove( excludedAnnotation ); }
if( vc != null && vc.isSNP() ) {
if( !vc.isFiltered() || IGNORE_ALL_INPUT_FILTERS || (ignoreInputFilterSet != null && ignoreInputFilterSet.containsAll(vc.getFilters())) ) {
if( firstVariant ) { // This is the first variant encountered so set up the list of annotations
annotationKeys.addAll( vc.getAttributes().keySet() );
if( annotationKeys.contains("ID") ) { annotationKeys.remove("ID"); } // ID field is added to the vc's INFO field?
if( annotationKeys.contains("DB") ) { annotationKeys.remove("DB"); }
if( EXCLUDED_ANNOTATIONS != null ) {
for( final String excludedAnnotation : EXCLUDED_ANNOTATIONS ) {
if( annotationKeys.contains( excludedAnnotation ) ) { annotationKeys.remove( excludedAnnotation ); }
}
}
}
if( FORCED_ANNOTATIONS != null ) {
for( final String forcedAnnotation : FORCED_ANNOTATIONS ) {
if( !annotationKeys.contains( forcedAnnotation ) ) { annotationKeys.add( forcedAnnotation ); }
if( FORCED_ANNOTATIONS != null ) {
for( final String forcedAnnotation : FORCED_ANNOTATIONS ) {
if( !annotationKeys.contains( forcedAnnotation ) ) { annotationKeys.add( forcedAnnotation ); }
}
}
numAnnotations = annotationKeys.size() + 1; // +1 for variant quality ("QUAL")
annotationValues = new double[numAnnotations];
firstVariant = false;
}
numAnnotations = annotationKeys.size() + 1; // +1 for variant quality ("QUAL")
annotationValues = new double[numAnnotations];
firstVariant = false;
int iii = 0;
for( final String key : annotationKeys ) {
double value = 0.0;
try {
value = Double.parseDouble( (String)vc.getAttribute( key, "0.0" ) );
if( Double.isInfinite(value) ) {
value = ( value > 0 ? 1.0 : -1.0 ) * INFINITE_ANNOTATION_VALUE;
}
} catch( NumberFormatException e ) {
// do nothing, default value is 0.0
}
annotationValues[iii++] = value;
}
// Variant quality ("QUAL") is not in the list of annotations, but is useful so add it here.
annotationValues[iii] = vc.getPhredScaledQual();
final VariantDatum variantDatum = new VariantDatum();
variantDatum.annotations = annotationValues;
variantDatum.isTransition = vc.getSNPSubstitutionType().compareTo(BaseUtils.BaseSubstitutionType.TRANSITION) == 0;
variantDatum.isKnown = !vc.getAttribute("ID").equals(".");
variantDatum.isHet = vc.getHetCount() > vc.getHomVarCount(); // BUGBUG: what to do here for multi sample calls?
mapList.add( variantDatum );
}
int iii = 0;
for( final String key : annotationKeys ) {
double value = 0.0;
try {
value = Double.parseDouble( (String)vc.getAttribute( key, "0.0" ) );
if( Double.isInfinite(value) ) {
value = ( value > 0 ? 1.0 : -1.0 ) * INFINITE_ANNOTATION_VALUE;
}
} catch( NumberFormatException e ) {
// do nothing, default value is 0.0
}
annotationValues[iii++] = value;
}
// Variant quality ("QUAL") is not in the list of annotations, but is useful so add it here.
annotationValues[iii] = vc.getPhredScaledQual();
VariantDatum variantDatum = new VariantDatum();
variantDatum.annotations = annotationValues;
variantDatum.isTransition = vc.getSNPSubstitutionType().compareTo(BaseUtils.BaseSubstitutionType.TRANSITION) == 0;
variantDatum.isKnown = !vc.getAttribute("ID").equals(".");
mapList.add( variantDatum );
}
}