snpSelector v1 -- and supporting changes to VCF reader
git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@1983 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
parent
66d4a995e6
commit
7cb51dbc31
|
|
@ -0,0 +1,307 @@
|
||||||
|
import os.path
|
||||||
|
import sys
|
||||||
|
from optparse import OptionParser
|
||||||
|
from vcfReader import *
|
||||||
|
#import pylab
|
||||||
|
from itertools import *
|
||||||
|
import math
|
||||||
|
|
||||||
|
|
||||||
|
class RecalibratedCall:
|
||||||
|
def __init__(self, call, features):
|
||||||
|
self.call = call
|
||||||
|
self.features = dict([[feature, None] for feature in features])
|
||||||
|
|
||||||
|
def recalFeature( self, feature, FPRate ):
|
||||||
|
assert self.features[feature] == None # not reassigning values
|
||||||
|
assert FPRate <= 1 and FPRate >= 0
|
||||||
|
self.features[feature] = FPRate
|
||||||
|
|
||||||
|
def getFeature( self, feature, missingValue = None, phredScaleValue = False ):
|
||||||
|
v = self.features[feature]
|
||||||
|
if v == None:
|
||||||
|
return missingValue
|
||||||
|
elif phredScaleValue:
|
||||||
|
return phredScale(v)
|
||||||
|
else:
|
||||||
|
return v
|
||||||
|
|
||||||
|
def jointFPErrorRate(self):
|
||||||
|
#print self.features
|
||||||
|
logTPRates = [math.log10(1-r) for r in self.features.itervalues() if r <> None]
|
||||||
|
logJointTPRate = reduce(lambda x, y: x + y, logTPRates, 0)
|
||||||
|
jointTPRate = math.pow(10, logJointTPRate)
|
||||||
|
#print logTPRates
|
||||||
|
#print logJointTPRate, jointTPRate
|
||||||
|
return 1 - jointTPRate
|
||||||
|
|
||||||
|
def featureStringList(self):
|
||||||
|
return ','.join(map(lambda feature: '%s=Q%d' % (feature, self.getFeature(feature, '*', True)), self.features.iterkeys()))
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return '[%s: %s => Q%d]' % (str(self.call), self.featureStringList(), phredScale(self.jointFPErrorRate()))
|
||||||
|
|
||||||
|
def readVariants( file ):
|
||||||
|
counter = OPTIONS.skip
|
||||||
|
|
||||||
|
def parseVariant(args):
|
||||||
|
VCF, counter = args
|
||||||
|
if counter % OPTIONS.skip == 0:
|
||||||
|
return VCF
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
return filter(None, map(parseVariant, lines2VCF(open(file))))
|
||||||
|
|
||||||
|
def selectVariants( variants, selector = None ):
|
||||||
|
if selector <> None:
|
||||||
|
return filter(selector, variants)
|
||||||
|
else:
|
||||||
|
return variants
|
||||||
|
|
||||||
|
def titv(variants):
|
||||||
|
ti = len(filter(VCFRecord.isTransition, variants))
|
||||||
|
tv = len(variants) - ti
|
||||||
|
titv = ti / (1.0*max(tv,1))
|
||||||
|
|
||||||
|
return titv, ti, tv
|
||||||
|
|
||||||
|
|
||||||
|
def gaussian(x, mu, sigma):
|
||||||
|
constant = 1 / math.sqrt(2 * math.pi * sigma**2)
|
||||||
|
exponent = -1 * ( x - mu )**2 / (2 * sigma**2)
|
||||||
|
return constant * math.exp(exponent)
|
||||||
|
|
||||||
|
DEBUG = False
|
||||||
|
|
||||||
|
# if target = T, and FP calls have ti/tv = 0.5, we want to know how many FP calls
|
||||||
|
# there are in N calls with ti/tv of X.
|
||||||
|
#
|
||||||
|
def titvFPRateEstimate(variants, target):
|
||||||
|
titvRatio, ti, tv = titv(variants)
|
||||||
|
|
||||||
|
# f <- function(To,T) { (To - T) / (1/2 - T) + 0.001 }
|
||||||
|
def theoreticalCalc():
|
||||||
|
if titvRatio >= target:
|
||||||
|
FPRate = 0
|
||||||
|
else:
|
||||||
|
FPRate = (titvRatio - target) / (0.5 - target)
|
||||||
|
FPRate = min(max(FPRate, 0), 1)
|
||||||
|
TPRate = max(min(1 - FPRate, 1 - dephredScale(OPTIONS.maxQScore)), dephredScale(OPTIONS.maxQScore))
|
||||||
|
print 'FPRate', FPRate, titvRatio, target
|
||||||
|
assert FPRate >= 0 and FPRate <= 1
|
||||||
|
return TPRate
|
||||||
|
|
||||||
|
# gaussian model
|
||||||
|
def gaussianModel():
|
||||||
|
LEFT_HANDED = True
|
||||||
|
sigma = 5
|
||||||
|
constant = 1 / math.sqrt(2 * math.pi * sigma**2)
|
||||||
|
exponent = -1 * ( titvRatio - target )**2 / (2 * sigma**2)
|
||||||
|
TPRate = gaussian(titvRatio, target, sigma) / gaussian(target, target, sigma)
|
||||||
|
if LEFT_HANDED and titvRatio >= target:
|
||||||
|
TPRate = 1
|
||||||
|
TPRate -= dephredScale(OPTIONS.maxQScore)
|
||||||
|
if DEBUG: print 'TPRate', TPRate, constant, exponent, dephredScale(OPTIONS.maxQScore)
|
||||||
|
return TPRate
|
||||||
|
|
||||||
|
#denom = (0.2 - 0.8 * titvRatio)
|
||||||
|
#FPRate = 1
|
||||||
|
#if denom <> 0:
|
||||||
|
# FPRate = (1.0 / (target+1)) * (titvRatio - target) / denom
|
||||||
|
|
||||||
|
FPRate = 1 - gaussianModel()
|
||||||
|
nVariants = len(variants)
|
||||||
|
if nVariants > 0:
|
||||||
|
impliedNoErrors = nVariants * FPRate
|
||||||
|
calcTiTv = (impliedNoErrors * 0.5 + target * (nVariants-impliedNoErrors)) / nVariants
|
||||||
|
else:
|
||||||
|
impliedNoErrors, calcTiTv = 0, 0
|
||||||
|
|
||||||
|
if DEBUG: print ':::', nVariants, titvRatio, target, ti, tv, FPRate, impliedNoErrors, calcTiTv
|
||||||
|
|
||||||
|
return titvRatio, FPRate, impliedNoErrors
|
||||||
|
|
||||||
|
def phredScale(errorRate):
|
||||||
|
return -10 * math.log10(max(errorRate, 1e-10))
|
||||||
|
|
||||||
|
def dephredScale(qscore):
|
||||||
|
return math.pow(10, qscore / -10)
|
||||||
|
|
||||||
|
def frange6(*args):
|
||||||
|
"""A float range generator."""
|
||||||
|
start = 0.0
|
||||||
|
step = 1.0
|
||||||
|
|
||||||
|
l = len(args)
|
||||||
|
if l == 1:
|
||||||
|
end = args[0]
|
||||||
|
elif l == 2:
|
||||||
|
start, end = args
|
||||||
|
elif l == 3:
|
||||||
|
start, end, step = args
|
||||||
|
if step == 0.0:
|
||||||
|
raise ValueError, "step must not be zero"
|
||||||
|
else:
|
||||||
|
raise TypeError, "frange expects 1-3 arguments, got %d" % l
|
||||||
|
|
||||||
|
v = start
|
||||||
|
while True:
|
||||||
|
if (step > 0 and v >= end) or (step < 0 and v <= end):
|
||||||
|
raise StopIteration
|
||||||
|
yield v
|
||||||
|
v += step
|
||||||
|
|
||||||
|
def calculateBins(variants, field, minValue, maxValue, rangeValue, partitions):
|
||||||
|
sortedVariants = sorted(variants, key = lambda x: x.getField(field))
|
||||||
|
sortedValues = map(lambda x: x.getField(field), sortedVariants)
|
||||||
|
|
||||||
|
targetBinSize = len(variants) / (1.0*partitions)
|
||||||
|
print 'targetBinSize', targetBinSize
|
||||||
|
uniqBins = groupby(sortedValues)
|
||||||
|
binsAndSizes = map(lambda x: [x[0], len(list(x[1]))], uniqBins)
|
||||||
|
#print binsAndSizes
|
||||||
|
|
||||||
|
def bin2Break(bin): return [bin[0], bin[0], bin[1]]
|
||||||
|
bins = [bin2Break(binsAndSizes[0])]
|
||||||
|
for bin in binsAndSizes[1:]:
|
||||||
|
#print 'Breaks', bins
|
||||||
|
curSize = bin[1]
|
||||||
|
prevSize = bins[-1][2]
|
||||||
|
if curSize + prevSize > targetBinSize:
|
||||||
|
bins.append(bin2Break(bin))
|
||||||
|
else:
|
||||||
|
bins[-1][1] = bin[0]
|
||||||
|
bins[-1][2] += curSize
|
||||||
|
|
||||||
|
return bins
|
||||||
|
|
||||||
|
|
||||||
|
def calculateBinsLinear(variants, minValue, maxValue, rangeValue, partitions):
|
||||||
|
breaks = list(frange6(minValue, maxValue, rangeValue / partitions))
|
||||||
|
if breaks[len(breaks)-1] <> maxValue:
|
||||||
|
breaks = breaks + ['*']
|
||||||
|
return zip(breaks, map( lambda x: x - 0.001, breaks[1:]))
|
||||||
|
|
||||||
|
def fieldRange(variants, field):
|
||||||
|
values = map(lambda v: v.getField(field), variants)
|
||||||
|
minValue = min(values)
|
||||||
|
maxValue = max(values)
|
||||||
|
rangeValue = maxValue - minValue
|
||||||
|
bins = calculateBins(variants, field, minValue, maxValue, rangeValue, OPTIONS.partitions)
|
||||||
|
return minValue, maxValue, rangeValue, bins
|
||||||
|
|
||||||
|
def printFieldQual( left, right, variants, titv, FPRate, nErrors ):
|
||||||
|
leftStr = str(left)
|
||||||
|
if type(left) == float: leftStr = "%.2f" % left
|
||||||
|
rightStr = "%5s" % str(right)
|
||||||
|
if type(right) == float: rightStr = "%.2f" % right
|
||||||
|
#print 'FPRATe', FPRate, phredScale(FPRate)
|
||||||
|
print ' %8s - %8s nVariants=%8d titv=%.2f FPRate=%.2e Q%d' % (leftStr, rightStr, len(variants), titv, FPRate, phredScale(FPRate))
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
def optimizeCalls(variants, fields, titvTarget):
|
||||||
|
recalCalls = calibrateFeatures(variants, fields, titvTarget)
|
||||||
|
|
||||||
|
newCalls = list()
|
||||||
|
for recalCall in recalCalls.itervalues():
|
||||||
|
originalQual = recalCall.call.getField('QUAL')
|
||||||
|
recalCall.call.setField('QUAL', int(round(phredScale(recalCall.jointFPErrorRate()))))
|
||||||
|
recalCall.call.setField('OQ', originalQual)
|
||||||
|
newCalls.append(recalCall.call)
|
||||||
|
|
||||||
|
for recalCall in islice(recalCalls.itervalues(), 10):
|
||||||
|
print recalCall
|
||||||
|
|
||||||
|
print '--------------------------------------------------------------------------------'
|
||||||
|
print 'RECALIBRATED CALLS'
|
||||||
|
#newCalls = [x.call for x in recalCalls.itervalues()]
|
||||||
|
calibrateFeatures(newCalls, ['QUAL'], titvTarget, updateCalls = False, printCall = True )
|
||||||
|
|
||||||
|
return newCalls
|
||||||
|
|
||||||
|
def all( p, l ):
|
||||||
|
for elt in l:
|
||||||
|
if not p(elt): return False
|
||||||
|
return True
|
||||||
|
|
||||||
|
def calibrateFeatures(variants, fields, titvTarget, updateCalls = True, printCall = False):
|
||||||
|
if updateCalls: recalCalls = dict([[variant, RecalibratedCall(variant, fields)] for variant in variants])
|
||||||
|
|
||||||
|
for field in fields:
|
||||||
|
print 'Optimizing field', field
|
||||||
|
|
||||||
|
if not all( lambda x: x.hasField(field), variants):
|
||||||
|
raise Exception('Unknown field ' + field)
|
||||||
|
|
||||||
|
minValue, maxValue, range, bins = fieldRange(variants, field)
|
||||||
|
print 'Field range', minValue, maxValue, range
|
||||||
|
print 'Partitions', bins
|
||||||
|
titv, FPRate, nErrors = titvFPRateEstimate(variants, titvTarget)
|
||||||
|
print 'Overall FRRate:', FPRate, nErrors, phredScale(FPRate)
|
||||||
|
|
||||||
|
for left, right in map(lambda x: [x[0], x[1]], bins):
|
||||||
|
#print 'LR:', left, right
|
||||||
|
def select( variant ): return variant.getField(field) >= left and (right == '*' or variant.getField(field) <= right)
|
||||||
|
selectedVariants = selectVariants(variants, select)
|
||||||
|
if len(selectedVariants) > max(OPTIONS.minVariantsPerBin,1):
|
||||||
|
titv, FPRate, nErrors = titvFPRateEstimate(selectedVariants, titvTarget)
|
||||||
|
if updateCalls:
|
||||||
|
for variant in selectedVariants: recalCalls[variant].recalFeature(field, FPRate)
|
||||||
|
if printCall:
|
||||||
|
for call in selectedVariants:
|
||||||
|
if titv < 0.5:
|
||||||
|
print call
|
||||||
|
|
||||||
|
printFieldQual( left, right, selectedVariants, titv, FPRate, nErrors )
|
||||||
|
|
||||||
|
if updateCalls:
|
||||||
|
return recalCalls
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def main():
|
||||||
|
global OPTIONS
|
||||||
|
usage = "usage: %prog files.list [options]"
|
||||||
|
parser = OptionParser(usage=usage)
|
||||||
|
parser.add_option("-f", "--f", dest="fields",
|
||||||
|
type='string', default="QUAL",
|
||||||
|
help="Comma-separated list of fields to exact")
|
||||||
|
parser.add_option("-t", "--truth", dest="truth",
|
||||||
|
type='string', default=None,
|
||||||
|
help="VCF formated truth file")
|
||||||
|
parser.add_option("-p", "--partitions", dest="partitions",
|
||||||
|
type='int', default=10,
|
||||||
|
help="Number of partitions to examine")
|
||||||
|
parser.add_option("-s", "--s", dest="skip",
|
||||||
|
type='int', default=1,
|
||||||
|
help="Only work with every 1 / skip records")
|
||||||
|
parser.add_option("-m", "--minVariantsPerBin", dest="minVariantsPerBin",
|
||||||
|
type='int', default=10,
|
||||||
|
help="")
|
||||||
|
parser.add_option("-q", "--qMax", dest="maxQScore",
|
||||||
|
type='int', default=30,
|
||||||
|
help="")
|
||||||
|
parser.add_option("", "--titv", dest="titvTarget",
|
||||||
|
type='float', default=None,
|
||||||
|
help="If provided, we will optimize calls to the targeted ti/tv rather than that calculated from known calls")
|
||||||
|
|
||||||
|
(OPTIONS, args) = parser.parse_args()
|
||||||
|
if len(args) != 2:
|
||||||
|
parser.error("incorrect number of arguments")
|
||||||
|
|
||||||
|
fields = OPTIONS.fields.split(',')
|
||||||
|
calls = readVariants(args[0])
|
||||||
|
print 'Read', len(calls), 'calls'
|
||||||
|
|
||||||
|
if OPTIONS.titvTarget == None:
|
||||||
|
OPTIONS.titvTarget = titv(calls, VCFRecord.isKnown)
|
||||||
|
print 'Ti/Tv all ', titv(calls)
|
||||||
|
print 'Ti/Tv known', titv(selectVariants(calls, VCFRecord.isKnown))
|
||||||
|
print 'Ti/Tv novel', titv(selectVariants(calls, VCFRecord.isNovel))
|
||||||
|
|
||||||
|
optimizeCalls(calls, OPTIONS.fields.split(","), OPTIONS.titvTarget)
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
|
|
@ -1,9 +1,28 @@
|
||||||
|
TRANSITIONS = dict()
|
||||||
|
for p in ["AG", "CT"]:
|
||||||
|
TRANSITIONS[p] = True
|
||||||
|
TRANSITIONS[''.join(reversed(p))] = True
|
||||||
|
|
||||||
|
def convertToType(d):
|
||||||
|
out = dict()
|
||||||
|
types = [int, float, str]
|
||||||
|
for key, value in d.items():
|
||||||
|
for type in types:
|
||||||
|
try:
|
||||||
|
#print 'Parsing', key, value, type
|
||||||
|
out[key] = type(value)
|
||||||
|
#print ' Parsed as', key, value, type
|
||||||
|
break
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return out
|
||||||
|
|
||||||
class VCFRecord:
|
class VCFRecord:
|
||||||
"""Simple support for accessing a VCF record"""
|
"""Simple support for accessing a VCF record"""
|
||||||
def __init__(self, basicBindings, header=None):
|
def __init__(self, basicBindings, header=None):
|
||||||
self.header = header
|
self.header = header
|
||||||
self.bindings = basicBindings
|
self.info = convertToType(parseInfo(basicBindings["INFO"]))
|
||||||
self.info = parseInfo(basicBindings["INFO"])
|
self.bindings = convertToType(basicBindings)
|
||||||
|
|
||||||
def hasHeader(self): return self.header <> None
|
def hasHeader(self): return self.header <> None
|
||||||
def getHeader(self): return self.header
|
def getHeader(self): return self.header
|
||||||
|
|
@ -21,11 +40,34 @@ class VCFRecord:
|
||||||
def getAlt(self): return self.get("ALT")
|
def getAlt(self): return self.get("ALT")
|
||||||
def getQual(self): return self.get("QUAL")
|
def getQual(self): return self.get("QUAL")
|
||||||
|
|
||||||
|
def getVariation(self): return self.getRef() + self.getAlt()
|
||||||
|
|
||||||
|
def isTransition(self):
|
||||||
|
#print self.getVariation(), TRANSITIONS
|
||||||
|
return self.getVariation() in TRANSITIONS
|
||||||
|
def isTransversion(self):
|
||||||
|
return not self.isTransition()
|
||||||
|
|
||||||
def getFilter(self): return self.get("FILTER")
|
def getFilter(self): return self.get("FILTER")
|
||||||
def failsFilters(self): return not self.passesFilters()
|
def failsFilters(self): return not self.passesFilters()
|
||||||
def passesFilters(self):
|
def passesFilters(self):
|
||||||
#print self.getFilter(), ">>>", self
|
#print self.getFilter(), ">>>", self
|
||||||
return self.getFilter() == "." or self.getFilter() == "0"
|
return self.getFilter() == "." or self.getFilter() == "0"
|
||||||
|
|
||||||
|
def hasField(self, field):
|
||||||
|
return field in self.bindings or field in self.info
|
||||||
|
|
||||||
|
def setField(self, field, value):
|
||||||
|
assert value <> None
|
||||||
|
|
||||||
|
#print 'setting field', field, value
|
||||||
|
#print 'getInfo', self.getInfo()
|
||||||
|
if field in self.bindings:
|
||||||
|
self.bindings[field] = value
|
||||||
|
else:
|
||||||
|
self.info[field] = value
|
||||||
|
self.setField("INFO", self.getInfo())
|
||||||
|
#print 'getInfo', self.getInfo()
|
||||||
|
|
||||||
def getField(self, field, default = None):
|
def getField(self, field, default = None):
|
||||||
if field in self.bindings:
|
if field in self.bindings:
|
||||||
|
|
@ -35,7 +77,15 @@ class VCFRecord:
|
||||||
else:
|
else:
|
||||||
return default
|
return default
|
||||||
|
|
||||||
def getInfo(self): return self.get("INFO")
|
#def getInfo(self): return self.get("INFO")
|
||||||
|
def getInfo(self):
|
||||||
|
def info2str(x,y):
|
||||||
|
if type(y) == bool:
|
||||||
|
return str(x)
|
||||||
|
else:
|
||||||
|
return str(x) + '=' + str(y)
|
||||||
|
return ';'.join(map(lambda x: info2str(*x), self.info.iteritems()))
|
||||||
|
|
||||||
def getInfoDict(self): return self.info
|
def getInfoDict(self): return self.info
|
||||||
|
|
||||||
def getInfoKey(self, name, default = None):
|
def getInfoKey(self, name, default = None):
|
||||||
|
|
@ -49,7 +99,8 @@ class VCFRecord:
|
||||||
return all(map(lambda key: key in self.getInfo(), keys))
|
return all(map(lambda key: key in self.getInfo(), keys))
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return str(self.bindings) + " INFO: " + str(self.info)
|
#return str(self.bindings) + " INFO: " + str(self.info)
|
||||||
|
return ' '.join(['%s=%s' % (x,y) for x,y in self.bindings.iteritems()])
|
||||||
|
|
||||||
def parseInfo(s):
|
def parseInfo(s):
|
||||||
def handleBoolean(key_val):
|
def handleBoolean(key_val):
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue