diff --git a/python/snpSelector.py b/python/snpSelector.py index 67cdd93d6..19abce23f 100755 --- a/python/snpSelector.py +++ b/python/snpSelector.py @@ -181,10 +181,13 @@ def compareFieldValues( v1, v2 ): return c def calculateBins(variants, field, minValue, maxValue, partitions): - sortedVariants = sorted(variants, key = lambda x: x.getField(field)) # cmp = compareFieldValues, - sortedValues = map(lambda x: x.getField(field), sortedVariants) + values = map(lambda x: x.getField(field), variants) + return calculateBinsForValues(values, field, minValue, maxValue, partitions) + +def calculateBinsForValues(values, field, minValue, maxValue, partitions): + sortedValues = sorted(values) - targetBinSize = len(variants) / (1.0*partitions) + targetBinSize = len(values) / (1.0*partitions) #print sortedValues uniqBins = groupby(sortedValues) binsAndSizes = map(lambda x: [x[0], len(list(x[1]))], uniqBins) @@ -234,14 +237,14 @@ def validateBins(bins): def printFieldQualHeader(): more = "" if TRUTH_CALLS <> None: - more = "TP FP FPRate FN FNRate" + more = CallCmp.HEADER def p(stream): if stream <> None: - print >> stream, ' field left right nVariants nNovels titv titvNovels dbSNP fprate q', more + print >> stream, ' %20s %20s left right nVariants nNovels titv titvNovels dbSNP FPEstimate Q' % ("category", "field"), more p(sys.stdout) p(RECAL_LOG) -def printFieldQual( field, left, right, variants, FPRate ): +def printFieldQual( category, field, left, right, variants, FPRate ): more = "" if TRUTH_CALLS <> None: callComparison, theseFPs = sensitivitySpecificity(variants, TRUTH_CALLS) @@ -249,7 +252,7 @@ def printFieldQual( field, left, right, variants, FPRate ): novels = selectVariants(variants, VCFRecord.isNovel) def p(stream): if stream <> None: - print >> stream, ' %s %s %8d %8d %.2f %.2f %.2f %.2e %d' % (field, binString(left, right), len(variants), len(novels), titv(variants), titv(novels), dbSNPRate(variants), FPRate, phredScale(FPRate)), more + print >> stream, ' %20s %20s %s %8d %8d %.2f %.2f %.2f %.2e %3d' % (category, field, binString(left, right), len(variants), len(novels), titv(variants), titv(novels), dbSNPRate(variants), FPRate, phredScale(FPRate)), more p(sys.stdout) p(RECAL_LOG) @@ -265,7 +268,7 @@ def binString(left, right): # # def recalibrateCalls(variants, fields, callCovariates): - def phred(v): return int(round(phredScale(v))) + def phred(v): return round(phredScale(v), 2) for variant in variants: recalCall = RecalibratedCall(variant, fields) @@ -277,7 +280,7 @@ def recalibrateCalls(variants, fields, callCovariates): recalCall.recalFeature(callCovariate.getFeature(), FPR) recalCall.call.setField(callCovariate.getCovariateField(), phred(FPR)) - + #recalCall.call.setField('QUAL', phred(recalCall.jointFPErrorRate())) recalCall.call.setField('QUAL', phred(recalCall.jointFPErrorRate())) recalCall.call.setField('OQ', originalQual) #print 'recalibrating', variant.getLoc() @@ -288,16 +291,16 @@ def recalibrateCalls(variants, fields, callCovariates): # # def optimizeCalls(variants, fields, titvTarget): - callCovariates = calibrateFeatures(variants, fields, titvTarget) + callCovariates = calibrateFeatures(variants, fields, titvTarget, category = "covariates") recalCalls = recalibrateCalls(variants, fields, callCovariates) return recalCalls, callCovariates -def printCallQuals(recalCalls, titvTarget, info = ""): +def printCallQuals(field, recalCalls, titvTarget, info = ""): print '--------------------------------------------------------------------------------' print info - calibrateFeatures(recalCalls, ['QUAL'], titvTarget, printCall = True, cumulative = False, forcePrint = True, prefix = "OPT-", printHeader = False ) + calibrateFeatures(recalCalls, [field], titvTarget, printCall = True, cumulative = False, forcePrint = True, prefix = "OPT-", printHeader = False, category = "optimized-calls" ) print 'Cumulative' - calibrateFeatures(recalCalls, ['QUAL'], titvTarget, printCall = True, cumulative = True, forcePrint = True, prefix = "OPTCUM-", printHeader = False ) + calibrateFeatures(recalCalls, [field], titvTarget, printCall = True, cumulative = True, forcePrint = True, prefix = "OPTCUM-", printHeader = False, category = "optimized-calls" ) def all( p, l ): for elt in l: @@ -323,7 +326,7 @@ def mapVariantBins(variants, field, cumulative = False): return imap( variantsInBin, bins ) -def calibrateFeatures(variants, fields, titvTarget, printCall = False, cumulative = False, forcePrint = False, prefix = '', printHeader = True ): +def calibrateFeatures(variants, fields, titvTarget, printCall = False, cumulative = False, forcePrint = False, prefix = '', printHeader = True, category = None ): covariates = [] if printHeader: printFieldQualHeader() @@ -338,7 +341,7 @@ def calibrateFeatures(variants, fields, titvTarget, printCall = False, cumulativ titv, FPRate = titvFPRateEstimate(selectedVariants, titvTarget) dbsnp = dbSNPRate(selectedVariants) covariates.append(CallCovariate(field, left, right, FPRate)) - printFieldQual( prefix + field, left, right, selectedVariants, FPRate ) + printFieldQual( category, prefix + field, left, right, selectedVariants, FPRate ) else: print 'Not calibrating bin', left, right, 'because it contains too few variants:', len(selectedVariants) @@ -350,14 +353,24 @@ class CallCmp: self.nFP = nFP self.nFN = nFN - def FPRate(self): - return (1.0*self.nFP) / max(self.nTP + self.nFP, 1) +# def FPRate(self): +# return (1.0*self.nFP) / max(self.nTP + self.nFP, 1) def FNRate(self): return (1.0*self.nFN) / max(self.nTP + self.nFN, 1) + + def sensitivity(self): + # = TP / (TP + FN) + return (1.0*self.nTP) / max( self.nTP + self.nFN,1 ) + + def PPV(self): + # = TP / (TP + FP) + return (1.0*self.nTP) / max( self.nTP + self.nFP, 1 ) + + HEADER = "TP FP FN FNRate Sensitivity PPV" def __str__(self): - return '%6d %6d %.2f %6d %.2f' % (self.nTP, self.nFP, self.FPRate(), self.nFN, self.FNRate()) + return '%6d %6d %6d %.2f %.2f %.2f' % (self.nTP, self.nFP, self.nFN, self.FNRate(), self.sensitivity(), self.PPV()) def variantInTruth(variant, truth): if variant.getLoc() in truth: @@ -386,16 +399,17 @@ def compareCalls(calls, truthCalls): for variant in calls: variant.setField("TP", 0) # set the TP field to 0 def compare1(name, cumulative): - for left, right, selectedVariants in mapVariantBins(calls, 'QUAL', cumulative = cumulative): - #print selectedVariants[0] - printFieldQual(name, left, right, selectedVariants, dephredScale(left)) + for field in ["QUAL", "OQ"]: + for left, right, selectedVariants in mapVariantBins(calls, field, cumulative = cumulative): + #print selectedVariants[0] + printFieldQual("truth-comparison-" + name, field, left, right, selectedVariants, dephredScale(left)) print 'PER BIN nCalls=', len(calls) # printFieldQualHeader() - compare1('RECAL-QUAL-TRUTH-PER-BIN', False) + compare1('per-bin', False) print 'CUMULATIVE nCalls=', len(calls) - compare1('RECAL-QUAL-TRUTH-CUMULATIVE', True) + compare1('cum', True) def randomSplit(l, pLeft): def keep(elt, p): @@ -486,11 +500,11 @@ def determineCovariates(allCalls, titvtarget, fields): callsToOptimize = allCalls recalOptCalls, covariates = optimizeCalls(callsToOptimize, fields, titvtarget) - printCallQuals(list(recalOptCalls), titvtarget, 'OPTIMIZED CALLS') + printCallQuals("QUAL", list(recalOptCalls), titvtarget, 'OPTIMIZED CALLS') if OPTIONS.bootStrap: recalibatedEvalCalls = recalibrateCalls(recalEvalCalls, fields, covariates) - printCallQuals(list(recalibatedEvalCalls), titvtarget, 'BOOTSTRAP EVAL CALLS') + printCallQuals("QUAL", list(recalibatedEvalCalls), titvtarget, 'BOOTSTRAP EVAL CALLS') return covariates @@ -562,7 +576,8 @@ def main(): RecalibratedCalls = recalibrateCalls(callsToRecalibate, fields, covariates) writeRecalibratedCalls(OPTIONS.outputVCF, header, RecalibratedCalls) else: - printCallQuals(allCalls, titvTarget) + printFieldQualHeader() + printCallQuals("QUAL", allCalls, titvTarget) OPTIONS.outputVCF = args[0] if len(args) > 1: