diff --git a/python/snpSelector.py b/python/snpSelector.py index b7db252cc..67cdd93d6 100755 --- a/python/snpSelector.py +++ b/python/snpSelector.py @@ -231,12 +231,27 @@ def validateBins(bins): elif contains2(left1) or contains2(right2): raise Exception("Bad bins", left1, right1, left2, right2) -def printFieldQualHeader(more = ""): - print ' field left right nVariants nNovels titv titvNovels dbSNP fprate q', more +def printFieldQualHeader(): + more = "" + if TRUTH_CALLS <> None: + more = "TP FP FPRate FN FNRate" + def p(stream): + if stream <> None: + print >> stream, ' field left right nVariants nNovels titv titvNovels dbSNP fprate q', more + p(sys.stdout) + p(RECAL_LOG) -def printFieldQual( field, left, right, variants, FPRate, more = ""): +def printFieldQual( field, left, right, variants, FPRate ): + more = "" + if TRUTH_CALLS <> None: + callComparison, theseFPs = sensitivitySpecificity(variants, TRUTH_CALLS) + more = str(callComparison) novels = selectVariants(variants, VCFRecord.isNovel) - print ' %s %s %8d %8d %.2f %.2f %.2f %.2e %d' % (field, binString(left, right), len(variants), len(novels), titv(variants), titv(novels), dbSNPRate(variants), FPRate, phredScale(FPRate)), more + def p(stream): + if stream <> None: + print >> stream, ' %s %s %8d %8d %.2f %.2f %.2f %.2e %d' % (field, binString(left, right), len(variants), len(novels), titv(variants), titv(novels), dbSNPRate(variants), FPRate, phredScale(FPRate)), more + p(sys.stdout) + p(RECAL_LOG) def binString(left, right): leftStr = str(left) @@ -280,11 +295,9 @@ def optimizeCalls(variants, fields, titvTarget): def printCallQuals(recalCalls, titvTarget, info = ""): print '--------------------------------------------------------------------------------' print info - calibrateFeatures(recalCalls, ['QUAL'], titvTarget, printCall = True, cumulative = False, forcePrint = True ) + calibrateFeatures(recalCalls, ['QUAL'], titvTarget, printCall = True, cumulative = False, forcePrint = True, prefix = "OPT-", printHeader = False ) print 'Cumulative' - calibrateFeatures(recalCalls, ['QUAL'], titvTarget, printCall = True, cumulative = True, forcePrint = True ) - - + calibrateFeatures(recalCalls, ['QUAL'], titvTarget, printCall = True, cumulative = True, forcePrint = True, prefix = "OPTCUM-", printHeader = False ) def all( p, l ): for elt in l: @@ -310,10 +323,10 @@ def mapVariantBins(variants, field, cumulative = False): return imap( variantsInBin, bins ) -def calibrateFeatures(variants, fields, titvTarget, printCall = False, cumulative = False, forcePrint = False ): +def calibrateFeatures(variants, fields, titvTarget, printCall = False, cumulative = False, forcePrint = False, prefix = '', printHeader = True ): covariates = [] - printFieldQualHeader() + if printHeader: printFieldQualHeader() for field in fields: if DEBUG: print 'Optimizing field', field @@ -325,7 +338,7 @@ def calibrateFeatures(variants, fields, titvTarget, printCall = False, cumulativ titv, FPRate = titvFPRateEstimate(selectedVariants, titvTarget) dbsnp = dbSNPRate(selectedVariants) covariates.append(CallCovariate(field, left, right, FPRate)) - printFieldQual(field, left, right, selectedVariants, FPRate ) + printFieldQual( prefix + field, left, right, selectedVariants, FPRate ) else: print 'Not calibrating bin', left, right, 'because it contains too few variants:', len(selectedVariants) @@ -369,22 +382,20 @@ def sensitivitySpecificity(variants, truth): nFN = len(truth) - nTP return CallCmp(nTP, nFP, nFN), FPs - def compareCalls(calls, truthCalls): for variant in calls: variant.setField("TP", 0) # set the TP field to 0 def compare1(name, cumulative): for left, right, selectedVariants in mapVariantBins(calls, 'QUAL', cumulative = cumulative): - callComparison, theseFPs = sensitivitySpecificity(selectedVariants, truthCalls) #print selectedVariants[0] - printFieldQual(name, left, right, selectedVariants, dephredScale(left), str(callComparison)) + printFieldQual(name, left, right, selectedVariants, dephredScale(left)) print 'PER BIN nCalls=', len(calls) - printFieldQualHeader("TP FP FPRate FN FNRate") - compare1('TRUTH-PER-BIN', False) + # printFieldQualHeader() + compare1('RECAL-QUAL-TRUTH-PER-BIN', False) print 'CUMULATIVE nCalls=', len(calls) - compare1('TRUTH-CUMULATIVE', True) + compare1('RECAL-QUAL-TRUTH-CUMULATIVE', True) def randomSplit(l, pLeft): def keep(elt, p): @@ -406,6 +417,9 @@ def setup(): parser.add_option("-t", "--truth", dest="truth", type='string', default=None, help="VCF formated truth file. If provided, the script will compare the input calls with the truth calls. It also emits calls tagged as TP and a separate file of FP calls") + parser.add_option("-l", "--recalLog", dest="recalLog", + type='string', default="recal.log", + help="VCF formated truth file. If provided, the script will compare the input calls with the truth calls. It also emits calls tagged as TP and a separate file of FP calls") parser.add_option("", "--unFilteredTruth", dest="unFilteredTruth", action='store_true', default=False, help="If provided, the unfiltered truth calls will be used in comparisons [default: %default]") @@ -425,7 +439,7 @@ def setup(): type='int', default=60, help="The maximum Q score allowed for both a single covariate and the overall QUAL score [default: %default]") parser.add_option("-o", "--outputVCF", dest="outputVCF", - type='string', default=None, + type='string', default="recal.vcf", help="If provided, a VCF file will be written out to this file [default: %default]") parser.add_option("", "--FNoutputVCF", dest="FNoutputVCF", type='string', default=None, @@ -491,7 +505,7 @@ def writeRecalibratedCalls(file, header, calls): print >> f, line f.close() -def evaluateTruth(header, callVCF, truthVCF): +def readTruth(truthVCF): print 'Reading truth file', truthVCF rawTruth = list(readVariants(truthVCF, maxRecords = None, decodeAll = False)[1]) def keepVariant(t): @@ -499,7 +513,9 @@ def evaluateTruth(header, callVCF, truthVCF): return OPTIONS.unFilteredTruth or t.passesFilters() truth = dict( [[v.getLoc(), v] for v in filter(keepVariant, rawTruth)]) print len(rawTruth), len(truth) + return truth +def evaluateTruth(header, callVCF, truth, truthVCF): print 'Reading variants back in from', callVCF header, calls = readVariants(callVCF) calls = list(calls) @@ -519,12 +535,28 @@ def evaluateTruth(header, callVCF, truthVCF): print >> f, line f.close() +TRUTH_CALLS = None +RECAL_LOG = None def main(): + global TRUTH_CALLS, RECAL_LOG + args = setup() + fields = OPTIONS.fields.split(',') + + truthVCF = None + if len(args) > 1: + truthVCF = args[1] + TRUTH_CALLS = readTruth(truthVCF) + + if OPTIONS.recalLog <> None: + RECAL_LOG = open(OPTIONS.recalLog, "w") + print >> RECAL_LOG, "# optimized vcf", args[0] + print >> RECAL_LOG, "# truth vcf", truthVCF + for key, value in OPTIONS.__dict__.iteritems(): + print >> RECAL_LOG, '#', key, value header, allCalls, titvTarget = assessCalls(args[0]) if not OPTIONS.dontRecalibrate: - fields = OPTIONS.fields.split(',') covariates = determineCovariates(allCalls, titvTarget, fields) header, callsToRecalibate = readVariants(args[0], OPTIONS.maxRecords) RecalibratedCalls = recalibrateCalls(callsToRecalibate, fields, covariates) @@ -534,7 +566,7 @@ def main(): OPTIONS.outputVCF = args[0] if len(args) > 1: - evaluateTruth(header, OPTIONS.outputVCF, args[1]) + evaluateTruth(header, OPTIONS.outputVCF, TRUTH_CALLS, truthVCF) PROFILE = False