From 7cb51dbc31e112661a87ecb1cd6dda7c5f2d44db Mon Sep 17 00:00:00 2001
From: depristo
Date: Fri, 6 Nov 2009 23:00:46 +0000
Subject: [PATCH] snpSelector v1 -- and supporting changes to VCF reader

git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@1983 348d0f76-0448-11de-a6fe-93d51630548a
---
 NOTE(review): the diffs below were re-flowed and revised after export, so the
 blob ids on the "index" lines no longer match the resulting file contents.

 python/snpSelector.py | 347 ++++++++++++++++++++++++++++++++++++++++++
 python/vcfReader.py   |  59 +++++++-
 2 files changed, 402 insertions(+), 4 deletions(-)
 create mode 100755 python/snpSelector.py

diff --git a/python/snpSelector.py b/python/snpSelector.py
new file mode 100755
index 000000000..0982e9e33
--- /dev/null
+++ b/python/snpSelector.py
@@ -0,0 +1,347 @@
+import os.path
+import sys
+from optparse import OptionParser
+from vcfReader import *
+#import pylab
+from itertools import *
+import math
+
+
+class RecalibratedCall:
+    """A call plus one estimated false-positive rate per feature (None until assigned)."""
+    def __init__(self, call, features):
+        self.call = call
+        self.features = dict([[feature, None] for feature in features])
+
+    def recalFeature( self, feature, FPRate ):
+        """Store FPRate (must lie in [0, 1]) for feature; a feature may be set only once."""
+        assert self.features[feature] == None # not reassigning values
+        assert FPRate <= 1 and FPRate >= 0
+        self.features[feature] = FPRate
+
+    def getFeature( self, feature, missingValue = None, phredScaleValue = False ):
+        """Return the rate stored for feature, phred-scaled on request; missingValue if unset."""
+        v = self.features[feature]
+        if v == None:
+            return missingValue
+        elif phredScaleValue:
+            return phredScale(v)
+        else:
+            return v
+
+    def jointFPErrorRate(self):
+        """Return 1 - product(1 - r) over every feature that has a rate r assigned."""
+        logTPRates = [math.log10(1-r) for r in self.features.itervalues() if r <> None]
+        logJointTPRate = reduce(lambda x, y: x + y, logTPRates, 0)
+        jointTPRate = math.pow(10, logJointTPRate)
+        return 1 - jointTPRate
+
+    def featureStringList(self):
+        """Return 'name=Q<phred>' entries joined by ','; unset features are shown as 'name=Q*'."""
+        def describe(feature):
+            q = self.getFeature(feature, None, True)
+            # '%d' cannot format a '*' placeholder, so unset features get their own branch
+            if q == None: return '%s=Q*' % feature
+            return '%s=Q%d' % (feature, q)
+        return ','.join(map(describe, self.features.iterkeys()))
+
+    def __str__(self):
+        return '[%s: %s => Q%d]' % (str(self.call), self.featureStringList(), phredScale(self.jointFPErrorRate()))
+
+def readVariants( file ):
+    """Open the path file; return the lines2VCF records whose counter is a multiple of OPTIONS.skip."""
+    def parseVariant(args):
+        VCF, counter = args
+        if counter % OPTIONS.skip == 0:
+            return VCF
+        else:
+            return None
+
+    source = open(file)
+    try:
+        # map() builds the whole list here, so the file can be closed right after
+        return filter(None, map(parseVariant, lines2VCF(source)))
+    finally:
+        source.close()
+
+def selectVariants( variants, selector = None ):
+    """Return the variants accepted by selector, or variants unchanged when selector is None."""
+    if selector <> None:
+        return filter(selector, variants)
+    else:
+        return variants
+
+def titv(variants):
+    """Return (ratio, ti, tv); the ratio divides by max(tv, 1) so it is defined when tv is 0."""
+    ti = len(filter(VCFRecord.isTransition, variants))
+    tv = len(variants) - ti
+    titv = ti / (1.0*max(tv,1))
+
+    return titv, ti, tv
+
+
+def gaussian(x, mu, sigma):
+    """Normal probability density with mean mu and standard deviation sigma, evaluated at x."""
+    constant = 1 / math.sqrt(2 * math.pi * sigma**2)
+    exponent = -1 * ( x - mu )**2 / (2 * sigma**2)
+    return constant * math.exp(exponent)
+
+DEBUG = False
+
+# if target = T, and FP calls have ti/tv = 0.5, we want to know how many FP calls
+# there are in N calls with ti/tv of X.
+#
+def titvFPRateEstimate(variants, target):
+    """Return (observed ti/tv, estimated FP rate, FP rate * len(variants)) for variants."""
+    titvRatio, ti, tv = titv(variants)
+
+    # f <- function(To,T) { (To - T) / (1/2 - T) + 0.001 }
+    # NOTE: linear alternative to gaussianModel(); defined but not called below
+    def theoreticalCalc():
+        if titvRatio >= target:
+            FPRate = 0
+        else:
+            FPRate = (titvRatio - target) / (0.5 - target)
+        FPRate = min(max(FPRate, 0), 1)
+        TPRate = max(min(1 - FPRate, 1 - dephredScale(OPTIONS.maxQScore)), dephredScale(OPTIONS.maxQScore))
+        print 'FPRate', FPRate, titvRatio, target
+        assert FPRate >= 0 and FPRate <= 1
+        return TPRate
+
+    # gaussian model
+    def gaussianModel():
+        LEFT_HANDED = True
+        sigma = 5
+        constant = 1 / math.sqrt(2 * math.pi * sigma**2)
+        exponent = -1 * ( titvRatio - target )**2 / (2 * sigma**2)
+        TPRate = gaussian(titvRatio, target, sigma) / gaussian(target, target, sigma)
+        if LEFT_HANDED and titvRatio >= target:
+            TPRate = 1
+        TPRate -= dephredScale(OPTIONS.maxQScore)
+        if DEBUG: print 'TPRate', TPRate, constant, exponent, dephredScale(OPTIONS.maxQScore)
+        return TPRate
+
+    #denom = (0.2 - 0.8 * titvRatio)
+    #FPRate = 1
+    #if denom <> 0:
+    #    FPRate = (1.0 / (target+1)) * (titvRatio - target) / denom
+
+    FPRate = 1 - gaussianModel()
+    nVariants = len(variants)
+    if nVariants > 0:
+        impliedNoErrors = nVariants * FPRate
+        calcTiTv = (impliedNoErrors * 0.5 + target * (nVariants-impliedNoErrors)) / nVariants
+    else:
+        impliedNoErrors, calcTiTv = 0, 0
+
+    if DEBUG: print ':::', nVariants, titvRatio, target, ti, tv, FPRate, impliedNoErrors, calcTiTv
+
+    return titvRatio, FPRate, impliedNoErrors
+
+def phredScale(errorRate):
+    """Return -10 * log10(errorRate), clamping errorRate to at least 1e-10 (at most Q100)."""
+    return -10 * math.log10(max(errorRate, 1e-10))
+
+def dephredScale(qscore):
+    """Return 10 ** (-qscore / 10)."""
+    # float divisor: an int qscore such as 25 would otherwise floor-divide to -3
+    return math.pow(10, qscore / -10.0)
+
+def frange6(*args):
+    """A float range generator."""
+    start = 0.0
+    step = 1.0
+
+    l = len(args)
+    if l == 1:
+        end = args[0]
+    elif l == 2:
+        start, end = args
+    elif l == 3:
+        start, end, step = args
+        if step == 0.0:
+            raise ValueError, "step must not be zero"
+    else:
+        raise TypeError, "frange expects 1-3 arguments, got %d" % l
+
+    v = start
+    while True:
+        if (step > 0 and v >= end) or (step < 0 and v <= end):
+            raise StopIteration
+        yield v
+        v += step
+
+def calculateBins(variants, field, minValue, maxValue, rangeValue, partitions):
+    """Bin the sorted values of field, aiming at len(variants)/partitions variants per bin.
+
+    Returns a list of [lowest value, highest value, count]; equal values always share
+    a bin. minValue, maxValue and rangeValue are accepted but not used.
+    """
+    sortedVariants = sorted(variants, key = lambda x: x.getField(field))
+    sortedValues = map(lambda x: x.getField(field), sortedVariants)
+
+    targetBinSize = len(variants) / (1.0*partitions)
+    print 'targetBinSize', targetBinSize
+    uniqBins = groupby(sortedValues)
+    binsAndSizes = map(lambda x: [x[0], len(list(x[1]))], uniqBins)
+    # nothing to bin; binsAndSizes[0] below would raise IndexError
+    if len(binsAndSizes) == 0: return []
+
+    def bin2Break(bin): return [bin[0], bin[0], bin[1]]
+    bins = [bin2Break(binsAndSizes[0])]
+    for bin in binsAndSizes[1:]:
+        curSize = bin[1]
+        prevSize = bins[-1][2]
+        if curSize + prevSize > targetBinSize:
+            bins.append(bin2Break(bin))
+        else:
+            bins[-1][1] = bin[0]
+            bins[-1][2] += curSize
+
+    return bins
+
+
+def calculateBinsLinear(variants, minValue, maxValue, rangeValue, partitions):
+    """Return (left, right) pairs from breaks spaced rangeValue / partitions apart; '*' is an open right edge."""
+    breaks = list(frange6(minValue, maxValue, rangeValue / partitions))
+    if breaks[len(breaks)-1] <> maxValue:
+        breaks = breaks + ['*']
+    def rightEdge(x):
+        # the '*' sentinel is not a number; subtracting from it raised TypeError
+        if x == '*': return x
+        return x - 0.001
+    return zip(breaks, map(rightEdge, breaks[1:]))
+
+def fieldRange(variants, field):
+    """Return (min, max, max - min, bins) for the values of field over variants."""
+    values = map(lambda v: v.getField(field), variants)
+    minValue = min(values)
+    maxValue = max(values)
+    rangeValue = maxValue - minValue
+    bins = calculateBins(variants, field, minValue, maxValue, rangeValue, OPTIONS.partitions)
+    return minValue, maxValue, rangeValue, bins
+
+def printFieldQual( left, right, variants, titv, FPRate, nErrors ):
+    """Print a one-line summary of the bin left..right; nErrors is accepted but not printed."""
+    leftStr = str(left)
+    if type(left) == float: leftStr = "%.2f" % left
+    rightStr = "%5s" % str(right)
+    if type(right) == float: rightStr = "%.2f" % right
+    print ' %8s - %8s nVariants=%8d titv=%.2f FPRate=%.2e Q%d' % (leftStr, rightStr, len(variants), titv, FPRate, phredScale(FPRate))
+
+def optimizeCalls(variants, fields, titvTarget):
+    """Set each call's QUAL to its phred-scaled joint FP rate over fields (old QUAL goes to OQ).
+
+    Prints up to 10 recalibrated calls and a QUAL calibration report; returns the calls.
+    """
+    recalCalls = calibrateFeatures(variants, fields, titvTarget)
+
+    newCalls = list()
+    for recalCall in recalCalls.itervalues():
+        originalQual = recalCall.call.getField('QUAL')
+        recalCall.call.setField('QUAL', int(round(phredScale(recalCall.jointFPErrorRate()))))
+        recalCall.call.setField('OQ', originalQual)
+        newCalls.append(recalCall.call)
+
+    for recalCall in islice(recalCalls.itervalues(), 10):
+        print recalCall
+
+    print '--------------------------------------------------------------------------------'
+    print 'RECALIBRATED CALLS'
+    calibrateFeatures(newCalls, ['QUAL'], titvTarget, updateCalls = False, printCall = True )
+
+    return newCalls
+
+def all( p, l ):
+    """Return True when p(elt) holds for every elt in l (shadows the builtin all() in this module)."""
+    for elt in l:
+        if not p(elt): return False
+    return True
+
+def calibrateFeatures(variants, fields, titvTarget, updateCalls = True, printCall = False):
+    """Estimate an FP rate for each bin of each field and print a per-bin report.
+
+    With updateCalls, returns a dict mapping each variant to a RecalibratedCall holding
+    its per-field rates; otherwise returns None. Bins holding no more than
+    max(OPTIONS.minVariantsPerBin, 1) variants are skipped.
+    """
+    if updateCalls: recalCalls = dict([[variant, RecalibratedCall(variant, fields)] for variant in variants])
+
+    for field in fields:
+        print 'Optimizing field', field
+
+        if not all( lambda x: x.hasField(field), variants):
+            raise Exception('Unknown field ' + field)
+
+        minValue, maxValue, rangeValue, bins = fieldRange(variants, field)
+        print 'Field range', minValue, maxValue, rangeValue
+        print 'Partitions', bins
+        titvRatio, FPRate, nErrors = titvFPRateEstimate(variants, titvTarget)
+        print 'Overall FRRate:', FPRate, nErrors, phredScale(FPRate)
+
+        for left, right in map(lambda x: [x[0], x[1]], bins):
+            def select( variant ): return variant.getField(field) >= left and (right == '*' or variant.getField(field) <= right)
+            selectedVariants = selectVariants(variants, select)
+            if len(selectedVariants) > max(OPTIONS.minVariantsPerBin,1):
+                titvRatio, FPRate, nErrors = titvFPRateEstimate(selectedVariants, titvTarget)
+                if updateCalls:
+                    for variant in selectedVariants: recalCalls[variant].recalFeature(field, FPRate)
+                if printCall:
+                    for call in selectedVariants:
+                        if titvRatio < 0.5:
+                            print call
+
+                printFieldQual( left, right, selectedVariants, titvRatio, FPRate, nErrors )
+
+    if updateCalls:
+        return recalCalls
+    else:
+        return None
+
+def main():
+    """Parse the command line, read the calls and recalibrate their QUAL values."""
+    global OPTIONS
+    usage = "usage: %prog files.list [options]"
+    parser = OptionParser(usage=usage)
+    parser.add_option("-f", "--f", dest="fields",
+                      type='string', default="QUAL",
+                      help="Comma-separated list of fields to exact")
+    parser.add_option("-t", "--truth", dest="truth",
+                      type='string', default=None,
+                      help="VCF formated truth file")
+    parser.add_option("-p", "--partitions", dest="partitions",
+                      type='int', default=10,
+                      help="Number of partitions to examine")
+    parser.add_option("-s", "--s", dest="skip",
+                      type='int', default=1,
+                      help="Only work with every 1 / skip records")
+    parser.add_option("-m", "--minVariantsPerBin", dest="minVariantsPerBin",
+                      type='int', default=10,
+                      help="")
+    parser.add_option("-q", "--qMax", dest="maxQScore",
+                      type='int', default=30,
+                      help="")
+    parser.add_option("", "--titv", dest="titvTarget",
+                      type='float', default=None,
+                      help="If provided, we will optimize calls to the targeted ti/tv rather than that calculated from known calls")
+
+    (OPTIONS, args) = parser.parse_args()
+    # only args[0] is read; a second positional argument is still tolerated
+    if len(args) < 1 or len(args) > 2:
+        parser.error("incorrect number of arguments")
+
+    fields = OPTIONS.fields.split(',')
+    calls = readVariants(args[0])
+    print 'Read', len(calls), 'calls'
+
+    if OPTIONS.titvTarget == None:
+        # titv() takes only the variants and returns (ratio, ti, tv): use the known calls' ratio
+        OPTIONS.titvTarget = titv(selectVariants(calls, VCFRecord.isKnown))[0]
+    print 'Ti/Tv all ', titv(calls)
+    print 'Ti/Tv known', titv(selectVariants(calls, VCFRecord.isKnown))
+    print 'Ti/Tv novel', titv(selectVariants(calls, VCFRecord.isNovel))
+
+    optimizeCalls(calls, fields, OPTIONS.titvTarget)
+
+if __name__ == "__main__":
+    main()
\ No newline at end of file
diff --git a/python/vcfReader.py b/python/vcfReader.py
index e793f8aae..c8b8ec8e5 100755
--- a/python/vcfReader.py
+++ b/python/vcfReader.py
@@ -1,9 +1,28 @@
+TRANSITIONS = dict()
+for p in ["AG", "CT"]:
+    TRANSITIONS[p] = True
+    TRANSITIONS[''.join(reversed(p))] = True
+
+def convertToType(d):
+    """Return a copy of d with every value coerced by the first of int, float, str that accepts it."""
+    out = dict()
+    types = [int, float, str]
+    for key, value in d.items():
+        for convert in types:
+            try:
+                # booleans are kept as-is: int(True) would turn them into 1
+                out[key] = value if isinstance(value, bool) else convert(value)
+                break
+            except (ValueError, TypeError):
+                pass
+    return out
+
 class VCFRecord:
     """Simple support for accessing a VCF record"""
     def __init__(self, basicBindings, header=None):
         self.header = header
-        self.bindings = basicBindings
-        self.info = parseInfo(basicBindings["INFO"])
+        self.info = convertToType(parseInfo(basicBindings["INFO"]))
+        self.bindings = convertToType(basicBindings)
 
     def hasHeader(self): return self.header <> None
     def getHeader(self): return self.header
@@ -21,11 +40,34 @@ class VCFRecord:
     def getAlt(self): return self.get("ALT")
     def getQual(self): return self.get("QUAL")
 
+    def getVariation(self): return self.getRef() + self.getAlt()
+
+    def isTransition(self):
+        """True when REF+ALT is one of the pairs listed in TRANSITIONS (AG, GA, CT, TC)."""
+        return self.getVariation() in TRANSITIONS
+    def isTransversion(self):
+        return not self.isTransition()
+
     def getFilter(self): return self.get("FILTER")
     def failsFilters(self): return not self.passesFilters()
     def passesFilters(self):
         #print self.getFilter(), ">>>", self
         return self.getFilter() == "." or self.getFilter() == "0"
+
+    def hasField(self, field):
+        return field in self.bindings or field in self.info
+
+    def setField(self, field, value):
+        """Set a top-level binding if field names one, else an INFO key (and rebuild the INFO binding)."""
+        assert value <> None
+
+        if field in self.bindings:
+            self.bindings[field] = value
+        else:
+            self.info[field] = value
+            # Re-serialize INFO directly; recursing through setField would
+            # never terminate for a record without an INFO binding.
+            self.bindings["INFO"] = self.getInfo()
 
     def getField(self, field, default = None):
         if field in self.bindings:
@@ -35,7 +77,15 @@ class VCFRecord:
         else:
             return default
 
-    def getInfo(self): return self.get("INFO")
+    def getInfo(self):
+        """Serialize the info dict as 'key=value' items joined by ';' (bool values print as the bare key)."""
+        def info2str(x,y):
+            if isinstance(y, bool):
+                return str(x)
+            else:
+                return str(x) + '=' + str(y)
+        return ';'.join(map(lambda x: info2str(*x), self.info.iteritems()))
+
     def getInfoDict(self): return self.info
 
     def getInfoKey(self, name, default = None):
@@ -49,7 +99,8 @@ class VCFRecord:
         return all(map(lambda key: key in self.getInfo(), keys))
 
     def __str__(self):
-        return str(self.bindings) + " INFO: " + str(self.info)
+        """Render the top-level bindings as space-separated key=value pairs."""
+        return ' '.join(['%s=%s' % (x,y) for x,y in self.bindings.iteritems()])
 
 def parseInfo(s):
     def handleBoolean(key_val):