From 3990c6d950078d9685d57e52ce75e9aa2ad73718 Mon Sep 17 00:00:00 2001 From: depristo Date: Mon, 9 Nov 2009 22:48:51 +0000 Subject: [PATCH] snpSelector v3 -- bootstrapping support and VCF output git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@2004 348d0f76-0448-11de-a6fe-93d51630548a --- python/snpSelector.py | 179 +++++++++++++++++++++++++++++++----------- python/vcfReader.py | 48 ++++++++--- 2 files changed, 170 insertions(+), 57 deletions(-) diff --git a/python/snpSelector.py b/python/snpSelector.py index 10d38b303..2c18d81d0 100755 --- a/python/snpSelector.py +++ b/python/snpSelector.py @@ -6,6 +6,26 @@ from vcfReader import * from itertools import * import math +class CallCovariate: + def __init__(self, feature, left, right, FPRate = None, cumulative = False): + self.feature = feature + self.left = left + + if cumulative: + self.right = '*' + else: + self.right = right + + self.FPRate = FPRate + + def containsVariant(self, call): + fieldVal = call.getField(self.feature) + return fieldVal >= self.left and (self.right == '*' or fieldVal <= self.right) + + def getFPRate(self): return self.FPRate + def getFeature(self): return self.feature + + def getCovariateField(self): return self.getFeature() + '_RQ' class RecalibratedCall: def __init__(self, call, features): @@ -43,15 +63,17 @@ class RecalibratedCall: def readVariants( file ): counter = OPTIONS.skip + f = open(file) + header, ignore, lines = readVCFHeader(f) def parseVariant(args): - VCF, counter = args + header1, VCF, counter = args if counter % OPTIONS.skip == 0: return VCF else: return None - return filter(None, map(parseVariant, lines2VCF(open(file)))) + return header, filter(None, map(parseVariant, lines2VCF(lines, extendedOutput = True))) def selectVariants( variants, selector = None ): if selector <> None: @@ -200,29 +222,50 @@ def binString(left, right): rightStr = "%5s" % str(right) if type(right) == float: rightStr = "%.2f" % right return '%8s - %8s' % (leftStr, rightStr) + + +# +# +# +def recalibrateCalls(variants, fields, callCovariates): + def phred(v): return int(round(phredScale(v))) + + newCalls = list() + for variant in variants: + recalCall = RecalibratedCall(variant, fields) + originalQual = variant.getField('QUAL') + + for callCovariate in callCovariates: + if callCovariate.containsVariant(variant): + FPR = callCovariate.getFPRate() + recalCall.recalFeature(callCovariate.getFeature(), FPR) + recalCall.call.setField(callCovariate.getCovariateField(), phred(FPR)) + + recalCall.call.setField('QUAL', phred(recalCall.jointFPErrorRate())) + recalCall.call.setField('OQ', originalQual) + newCalls.append(recalCall.call) + + return newCalls # # # def optimizeCalls(variants, fields, titvTarget): - recalCalls = calibrateFeatures(variants, fields, titvTarget) + callCovariates = calibrateFeatures(variants, fields, titvTarget) + recalCalls = recalibrateCalls(variants, fields, callCovariates) + return recalCalls, callCovariates - newCalls = list() - for recalCall in recalCalls.itervalues(): - originalQual = recalCall.call.getField('QUAL') - recalCall.call.setField('QUAL', int(round(phredScale(recalCall.jointFPErrorRate())))) - recalCall.call.setField('OQ', originalQual) - newCalls.append(recalCall.call) - - for recalCall in islice(recalCalls.itervalues(), 10): - print recalCall +def printCallQuals(recalCalls, titvTarget, info = ""): + #for recalCall in islice(recalCalls, 10): + # print recalCall print '--------------------------------------------------------------------------------' - print 'RECALIBRATED CALLS' - #newCalls = [x.call for x in recalCalls.itervalues()] - calibrateFeatures(newCalls, ['QUAL'], titvTarget, updateCalls = False, printCall = True ) + print info + calibrateFeatures(recalCalls, ['QUAL'], titvTarget, printCall = True, cumulative = False ) + print 'Cumulative' + calibrateFeatures(recalCalls, ['QUAL'], titvTarget, printCall = True, cumulative = True ) + - return newCalls def all( p, l ): for elt in l: @@ -238,41 +281,33 @@ def variantBinsForField(variants, field): print 'Partitions', bins return bins -def mapVariantBins(variants, field): +def mapVariantBins(variants, field, cumulative = False): bins = variantBinsForField(variants, field) def variantsInBin(bin): - left, right = bin[0:2] - def select( variant ): return variant.getField(field) >= left and (right == '*' or variant.getField(field) <= right) - return left, right, selectVariants(variants, select) + cc = CallCovariate(field, bin[0], bin[1], cumulative = cumulative) + + return cc.left, cc.right, selectVariants(variants, lambda v: cc.containsVariant(v)) return imap( variantsInBin, bins ) -def calibrateFeatures(variants, fields, titvTarget, updateCalls = True, printCall = False): - if updateCalls: recalCalls = dict([[variant, RecalibratedCall(variant, fields)] for variant in variants]) - +def calibrateFeatures(variants, fields, titvTarget, printCall = False, cumulative = False ): + covariates = [] + for field in fields: print 'Optimizing field', field titv, FPRate, nErrors = titvFPRateEstimate(variants, titvTarget) print 'Overall FRRate:', FPRate, nErrors, phredScale(FPRate) - for left, right, selectedVariants in mapVariantBins(variants, field): + for left, right, selectedVariants in mapVariantBins(variants, field, cumulative = cumulative): if len(selectedVariants) > max(OPTIONS.minVariantsPerBin,1): titv, FPRate, nErrors = titvFPRateEstimate(selectedVariants, titvTarget) - if updateCalls: - for variant in selectedVariants: recalCalls[variant].recalFeature(field, FPRate) - if printCall: - for call in selectedVariants: - if titv < 0.5: - print call - + covariates.append(CallCovariate(field, left, right, FPRate)) printFieldQual( left, right, selectedVariants, titv, FPRate, nErrors ) - if updateCalls: - return recalCalls - else: - return None + + return covariates class CallCmp: def __init__(self, nTP, nFP, nFN): @@ -301,12 +336,29 @@ def sensitivitySpecificity(variants, truth): return CallCmp(nTP, nFP, nFN) -def compareCalls(optimizedCalls, truthCalls): - for left, right, selectedVariants in mapVariantBins(optimizedCalls, 'QUAL'): - callComparison = sensitivitySpecificity(selectedVariants, truthCalls) - print binString(left, right), 'titv=%.2f' % titv(selectedVariants)[0], callComparison +def compareCalls(calls, truthCalls): + def compare1(cumulative): + for left, right, selectedVariants in mapVariantBins(calls, 'QUAL', cumulative = cumulative): + callComparison = sensitivitySpecificity(selectedVariants, truthCalls) + print binString(left, right), 'titv=%.2f' % titv(selectedVariants)[0], callComparison + print 'PER BIN' + compare1(False) + print 'CUMULATIVE' + compare1(True) + +def randomSplit(l, pLeft): + import random + + def keep(elt, p): + if p < pLeft: + return elt, None + else: + return None, elt + data = [keep(elt, p) for elt, p in zip(l, map(lambda x: random.random(), l))] + def get(i): return filter(lambda x: x <> None, [x[i] for x in data]) + return get(0), get(1) def main(): global OPTIONS @@ -330,34 +382,67 @@ def main(): parser.add_option("-q", "--qMax", dest="maxQScore", type='int', default=30, help="") + parser.add_option("-o", "--outputVCF", dest="outputVCF", + type='string', default=None, + help="If provided, VCF file will be written out to this file") parser.add_option("", "--titv", dest="titvTarget", type='float', default=None, help="If provided, we will optimize calls to the targeted ti/tv rather than that calculated from known calls") parser.add_option("", "--fp", dest="printFP", action='store_true', default=False, help="") + parser.add_option("-b", "--bootstrap", dest="bootStrap", + type='float', default=0.0, + help="If provided, the % of the calls used to generate the recalibration tables.") (OPTIONS, args) = parser.parse_args() if len(args) > 2: parser.error("incorrect number of arguments") fields = OPTIONS.fields.split(',') - calls = readVariants(args[0]) - print 'Read', len(calls), 'calls' + header, allCalls = readVariants(args[0]) + print 'Read', len(allCalls), 'calls' + print 'header is', header if OPTIONS.titvTarget == None: OPTIONS.titvTarget = titv(calls, VCFRecord.isKnown) - print 'Ti/Tv all ', titv(calls) - print 'Ti/Tv known', titv(selectVariants(calls, VCFRecord.isKnown)) - print 'Ti/Tv novel', titv(selectVariants(calls, VCFRecord.isNovel)) + print 'Ti/Tv all ', titv(allCalls) + print 'Ti/Tv known', titv(selectVariants(allCalls, VCFRecord.isKnown)) + print 'Ti/Tv novel', titv(selectVariants(allCalls, VCFRecord.isNovel)) - optimizedCalls = optimizeCalls(calls, OPTIONS.fields.split(","), OPTIONS.titvTarget) + if OPTIONS.bootStrap: + callsToOptimize, callsToEval = randomSplit(allCalls, OPTIONS.bootStrap) + else: + callsToOptimize, callsToEval = allCalls, allCalls + + recalOptCalls, covariates = optimizeCalls(callsToOptimize, fields, OPTIONS.titvTarget) + printCallQuals(recalOptCalls, OPTIONS.titvTarget, 'OPTIMIZED CALLS') + + if callsToEval <> callsToOptimize: + recalEvalCalls = recalibrateCalls(callsToEval, fields, covariates) + printCallQuals(recalEvalCalls, OPTIONS.titvTarget, 'BOOTSTRAP EVAL CALLS') if len(args) > 1: truthFile = args[1] print 'Reading truth file', truthFile - truth = dict( [[v.getLoc(), v] for v in readVariants(truthFile)]) - compareCalls(optimizedCalls, truth) + truth = dict( [[v.getLoc(), v] for v in readVariants(truthFile)[1]]) + + print '--------------------------------------------------------------------------------' + print 'Comparing calls to truth', truthFile + print '' + + print 'Calls used in optimization' + compareCalls(recalOptCalls, truth) + if callsToEval <> callsToOptimize: + print 'Calls held in reserve (bootstrap)' + compareCalls(recalEvalCalls, truth) + + if OPTIONS.outputVCF: + f = open(OPTIONS.outputVCF, 'w') + print 'HEADER', header + for line in formatVCF(header, allCalls): + print >> f, line + f.close() if __name__ == "__main__": main() \ No newline at end of file diff --git a/python/vcfReader.py b/python/vcfReader.py index a806025a8..f837fadd4 100755 --- a/python/vcfReader.py +++ b/python/vcfReader.py @@ -1,3 +1,7 @@ +import itertools + +VCF_KEYS = "CHROM POS ID REF ALT QUAL FILTER INFO".split() + TRANSITIONS = dict() for p in ["AG", "CT"]: TRANSITIONS[p] = True @@ -19,10 +23,11 @@ def convertToType(d): class VCFRecord: """Simple support for accessing a VCF record""" - def __init__(self, basicBindings, header=None): + def __init__(self, basicBindings, header=None, rest=[]): self.header = header self.info = convertToType(parseInfo(basicBindings["INFO"])) self.bindings = convertToType(basicBindings) + self.rest = rest def hasHeader(self): return self.header <> None def getHeader(self): return self.header @@ -102,6 +107,9 @@ class VCFRecord: def __str__(self): #return str(self.bindings) + " INFO: " + str(self.info) return ' '.join(['%s=%s' % (x,y) for x,y in self.bindings.iteritems()]) + + def format(self): + return '\t'.join([str(self.getField(key)) for key in VCF_KEYS] + self.rest) def parseInfo(s): def handleBoolean(key_val): @@ -116,21 +124,41 @@ def parseInfo(s): def string2VCF(line, header=None): if line[0] != "#": s = line.split() - keys = "CHROM POS ID REF ALT QUAL FILTER INFO".split() - bindings = dict(zip(keys, s[0:8])) - return VCFRecord(bindings, header) + bindings = dict(zip(VCF_KEYS, s[0:8])) + return VCFRecord(bindings, header, rest=s[8:]) else: return None -def lines2VCF(lines): - header = None +def readVCFHeader(lines): + header = [] + columnNames = None + for line in lines: + if line[0] == "#": + header.append(line.strip()) + else: + if header <> []: + columnNames = header[-1] + return header, columnNames, itertools.chain([line], lines) + + +def lines2VCF(lines, extendedOutput = False): + header, columnNames, lines = readVCFHeader(lines) counter = 0 + for line in lines: if line[0] != "#": counter += 1 - vcf = string2VCF(line, header=header) + vcf = string2VCF(line, header=columnNames) if vcf <> None: - yield vcf, counter - else: - header = line[1:].split() + if extendedOutput: + yield header, vcf, counter + else: + yield vcf raise StopIteration() + + +def formatVCF(header, records): + #print records + print records[0] + return itertools.chain(header, map(VCFRecord.format, records)) +