snpSelector v3 -- bootstrapping support and VCF output
git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@2004 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
parent
2fa2ae43ec
commit
3990c6d950
|
|
@ -6,6 +6,26 @@ from vcfReader import *
|
||||||
from itertools import *
|
from itertools import *
|
||||||
import math
|
import math
|
||||||
|
|
||||||
|
class CallCovariate:
    """A bin [left, right] over one call field, with an optional FP rate.

    feature    -- name of the field looked up on a call via getField()
    left       -- inclusive lower bound of the bin
    right      -- inclusive upper bound of the bin; ignored when cumulative
    FPRate     -- false-positive rate associated with the bin (may be None)
    cumulative -- when true the bin has no upper bound; this is recorded by
                  storing the wildcard '*' as the right edge
    """

    def __init__(self, feature, left, right, FPRate = None, cumulative = False):
        self.feature = feature
        self.left = left
        # '*' is the "unbounded above" marker that containsVariant() checks for.
        self.right = '*' if cumulative else right
        self.FPRate = FPRate

    def containsVariant(self, call):
        """Return true when the call's value for this feature lies in the bin."""
        value = call.getField(self.feature)
        if not (value >= self.left):
            return False
        if self.right == '*':
            # Wildcard upper edge: everything at or above left is inside.
            return True
        return value <= self.right

    def getFPRate(self):
        """Return the false-positive rate given at construction."""
        return self.FPRate

    def getFeature(self):
        """Return the name of the field this covariate bins on."""
        return self.feature

    def getCovariateField(self):
        """Return the feature name with the '_RQ' suffix appended."""
        return self.getFeature() + '_RQ'
|
||||||
|
|
||||||
class RecalibratedCall:
|
class RecalibratedCall:
|
||||||
def __init__(self, call, features):
|
def __init__(self, call, features):
|
||||||
|
|
@ -43,15 +63,17 @@ class RecalibratedCall:
|
||||||
|
|
||||||
def readVariants( file ):
|
def readVariants( file ):
|
||||||
counter = OPTIONS.skip
|
counter = OPTIONS.skip
|
||||||
|
f = open(file)
|
||||||
|
header, ignore, lines = readVCFHeader(f)
|
||||||
|
|
||||||
def parseVariant(args):
|
def parseVariant(args):
|
||||||
VCF, counter = args
|
header1, VCF, counter = args
|
||||||
if counter % OPTIONS.skip == 0:
|
if counter % OPTIONS.skip == 0:
|
||||||
return VCF
|
return VCF
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return filter(None, map(parseVariant, lines2VCF(open(file))))
|
return header, filter(None, map(parseVariant, lines2VCF(lines, extendedOutput = True)))
|
||||||
|
|
||||||
def selectVariants( variants, selector = None ):
|
def selectVariants( variants, selector = None ):
|
||||||
if selector <> None:
|
if selector <> None:
|
||||||
|
|
@ -200,29 +222,50 @@ def binString(left, right):
|
||||||
rightStr = "%5s" % str(right)
|
rightStr = "%5s" % str(right)
|
||||||
if type(right) == float: rightStr = "%.2f" % right
|
if type(right) == float: rightStr = "%.2f" % right
|
||||||
return '%8s - %8s' % (leftStr, rightStr)
|
return '%8s - %8s' % (leftStr, rightStr)
|
||||||
|
|
||||||
|
|
||||||
|
#
|
||||||
|
#
|
||||||
|
#
|
||||||
|
def recalibrateCalls(variants, fields, callCovariates):
    """Apply a set of call covariates to variants and return the updated calls.

    For every variant a RecalibratedCall is built, each covariate whose bin
    contains the variant contributes its FP rate, and the call's QUAL field is
    replaced by the phred-scaled joint FP error rate.  The previous QUAL value
    is stored in the 'OQ' field, and each matching covariate also writes its
    own phred-scaled rate under the covariate's field name.

    variants       -- iterable of calls supporting getField()
    fields         -- passed through to RecalibratedCall
    callCovariates -- iterable of CallCovariate-like objects

    Returns a list with one call (the RecalibratedCall's .call) per variant.
    NOTE(review): presumably .call is the variant object itself, in which case
    the input calls are modified in place -- confirm against RecalibratedCall.
    """
    def toPhred(rate):
        return int(round(phredScale(rate)))

    def recalibrateOne(variant):
        wrapped = RecalibratedCall(variant, fields)
        # Read the original quality before QUAL is overwritten below.
        qualBefore = variant.getField('QUAL')

        for covariate in callCovariates:
            if not covariate.containsVariant(variant):
                continue
            rate = covariate.getFPRate()
            wrapped.recalFeature(covariate.getFeature(), rate)
            wrapped.call.setField(covariate.getCovariateField(), toPhred(rate))

        wrapped.call.setField('QUAL', toPhred(wrapped.jointFPErrorRate()))
        wrapped.call.setField('OQ', qualBefore)
        return wrapped.call

    return [recalibrateOne(variant) for variant in variants]
|
||||||
|
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
def optimizeCalls(variants, fields, titvTarget):
|
def optimizeCalls(variants, fields, titvTarget):
|
||||||
recalCalls = calibrateFeatures(variants, fields, titvTarget)
|
callCovariates = calibrateFeatures(variants, fields, titvTarget)
|
||||||
|
recalCalls = recalibrateCalls(variants, fields, callCovariates)
|
||||||
|
return recalCalls, callCovariates
|
||||||
|
|
||||||
newCalls = list()
|
def printCallQuals(recalCalls, titvTarget, info = ""):
|
||||||
for recalCall in recalCalls.itervalues():
|
#for recalCall in islice(recalCalls, 10):
|
||||||
originalQual = recalCall.call.getField('QUAL')
|
# print recalCall
|
||||||
recalCall.call.setField('QUAL', int(round(phredScale(recalCall.jointFPErrorRate()))))
|
|
||||||
recalCall.call.setField('OQ', originalQual)
|
|
||||||
newCalls.append(recalCall.call)
|
|
||||||
|
|
||||||
for recalCall in islice(recalCalls.itervalues(), 10):
|
|
||||||
print recalCall
|
|
||||||
|
|
||||||
print '--------------------------------------------------------------------------------'
|
print '--------------------------------------------------------------------------------'
|
||||||
print 'RECALIBRATED CALLS'
|
print info
|
||||||
#newCalls = [x.call for x in recalCalls.itervalues()]
|
calibrateFeatures(recalCalls, ['QUAL'], titvTarget, printCall = True, cumulative = False )
|
||||||
calibrateFeatures(newCalls, ['QUAL'], titvTarget, updateCalls = False, printCall = True )
|
print 'Cumulative'
|
||||||
|
calibrateFeatures(recalCalls, ['QUAL'], titvTarget, printCall = True, cumulative = True )
|
||||||
|
|
||||||
|
|
||||||
return newCalls
|
|
||||||
|
|
||||||
def all( p, l ):
|
def all( p, l ):
|
||||||
for elt in l:
|
for elt in l:
|
||||||
|
|
@ -238,41 +281,33 @@ def variantBinsForField(variants, field):
|
||||||
print 'Partitions', bins
|
print 'Partitions', bins
|
||||||
return bins
|
return bins
|
||||||
|
|
||||||
def mapVariantBins(variants, field):
|
def mapVariantBins(variants, field, cumulative = False):
|
||||||
bins = variantBinsForField(variants, field)
|
bins = variantBinsForField(variants, field)
|
||||||
|
|
||||||
def variantsInBin(bin):
|
def variantsInBin(bin):
|
||||||
left, right = bin[0:2]
|
cc = CallCovariate(field, bin[0], bin[1], cumulative = cumulative)
|
||||||
def select( variant ): return variant.getField(field) >= left and (right == '*' or variant.getField(field) <= right)
|
|
||||||
return left, right, selectVariants(variants, select)
|
return cc.left, cc.right, selectVariants(variants, lambda v: cc.containsVariant(v))
|
||||||
|
|
||||||
return imap( variantsInBin, bins )
|
return imap( variantsInBin, bins )
|
||||||
|
|
||||||
def calibrateFeatures(variants, fields, titvTarget, updateCalls = True, printCall = False):
|
def calibrateFeatures(variants, fields, titvTarget, printCall = False, cumulative = False ):
|
||||||
if updateCalls: recalCalls = dict([[variant, RecalibratedCall(variant, fields)] for variant in variants])
|
covariates = []
|
||||||
|
|
||||||
for field in fields:
|
for field in fields:
|
||||||
print 'Optimizing field', field
|
print 'Optimizing field', field
|
||||||
|
|
||||||
titv, FPRate, nErrors = titvFPRateEstimate(variants, titvTarget)
|
titv, FPRate, nErrors = titvFPRateEstimate(variants, titvTarget)
|
||||||
print 'Overall FRRate:', FPRate, nErrors, phredScale(FPRate)
|
print 'Overall FRRate:', FPRate, nErrors, phredScale(FPRate)
|
||||||
|
|
||||||
for left, right, selectedVariants in mapVariantBins(variants, field):
|
for left, right, selectedVariants in mapVariantBins(variants, field, cumulative = cumulative):
|
||||||
if len(selectedVariants) > max(OPTIONS.minVariantsPerBin,1):
|
if len(selectedVariants) > max(OPTIONS.minVariantsPerBin,1):
|
||||||
titv, FPRate, nErrors = titvFPRateEstimate(selectedVariants, titvTarget)
|
titv, FPRate, nErrors = titvFPRateEstimate(selectedVariants, titvTarget)
|
||||||
if updateCalls:
|
covariates.append(CallCovariate(field, left, right, FPRate))
|
||||||
for variant in selectedVariants: recalCalls[variant].recalFeature(field, FPRate)
|
|
||||||
if printCall:
|
|
||||||
for call in selectedVariants:
|
|
||||||
if titv < 0.5:
|
|
||||||
print call
|
|
||||||
|
|
||||||
printFieldQual( left, right, selectedVariants, titv, FPRate, nErrors )
|
printFieldQual( left, right, selectedVariants, titv, FPRate, nErrors )
|
||||||
|
|
||||||
if updateCalls:
|
|
||||||
return recalCalls
|
return covariates
|
||||||
else:
|
|
||||||
return None
|
|
||||||
|
|
||||||
class CallCmp:
|
class CallCmp:
|
||||||
def __init__(self, nTP, nFP, nFN):
|
def __init__(self, nTP, nFP, nFN):
|
||||||
|
|
@ -301,12 +336,29 @@ def sensitivitySpecificity(variants, truth):
|
||||||
return CallCmp(nTP, nFP, nFN)
|
return CallCmp(nTP, nFP, nFN)
|
||||||
|
|
||||||
|
|
||||||
def compareCalls(optimizedCalls, truthCalls):
|
def compareCalls(calls, truthCalls):
|
||||||
for left, right, selectedVariants in mapVariantBins(optimizedCalls, 'QUAL'):
|
def compare1(cumulative):
|
||||||
callComparison = sensitivitySpecificity(selectedVariants, truthCalls)
|
for left, right, selectedVariants in mapVariantBins(calls, 'QUAL', cumulative = cumulative):
|
||||||
print binString(left, right), 'titv=%.2f' % titv(selectedVariants)[0], callComparison
|
callComparison = sensitivitySpecificity(selectedVariants, truthCalls)
|
||||||
|
print binString(left, right), 'titv=%.2f' % titv(selectedVariants)[0], callComparison
|
||||||
|
|
||||||
|
print 'PER BIN'
|
||||||
|
compare1(False)
|
||||||
|
|
||||||
|
print 'CUMULATIVE'
|
||||||
|
compare1(True)
|
||||||
|
|
||||||
|
def randomSplit(l, pLeft):
    """Randomly partition the elements of l into two lists.

    Each element independently goes to the first list when a fresh
    random.random() draw is below pLeft, and to the second list otherwise.
    Relative order of the elements is preserved in both lists.

    l     -- any iterable (it is traversed exactly once)
    pLeft -- probability, in [0, 1], of an element landing in the first list

    Returns a (left, right) tuple of lists.

    Elements are placed by the draw alone, so an element that is itself None
    is kept rather than being filtered out of both halves.
    """
    import random

    left, right = [], []
    for elt in l:
        # One draw per element, in input order.
        if random.random() < pLeft:
            left.append(elt)
        else:
            right.append(elt)
    return left, right
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
global OPTIONS
|
global OPTIONS
|
||||||
|
|
@ -330,34 +382,67 @@ def main():
|
||||||
parser.add_option("-q", "--qMax", dest="maxQScore",
|
parser.add_option("-q", "--qMax", dest="maxQScore",
|
||||||
type='int', default=30,
|
type='int', default=30,
|
||||||
help="")
|
help="")
|
||||||
|
parser.add_option("-o", "--outputVCF", dest="outputVCF",
|
||||||
|
type='string', default=None,
|
||||||
|
help="If provided, VCF file will be written out to this file")
|
||||||
parser.add_option("", "--titv", dest="titvTarget",
|
parser.add_option("", "--titv", dest="titvTarget",
|
||||||
type='float', default=None,
|
type='float', default=None,
|
||||||
help="If provided, we will optimize calls to the targeted ti/tv rather than that calculated from known calls")
|
help="If provided, we will optimize calls to the targeted ti/tv rather than that calculated from known calls")
|
||||||
parser.add_option("", "--fp", dest="printFP",
|
parser.add_option("", "--fp", dest="printFP",
|
||||||
action='store_true', default=False,
|
action='store_true', default=False,
|
||||||
help="")
|
help="")
|
||||||
|
parser.add_option("-b", "--bootstrap", dest="bootStrap",
|
||||||
|
type='float', default=0.0,
|
||||||
|
help="If provided, the % of the calls used to generate the recalibration tables.")
|
||||||
|
|
||||||
(OPTIONS, args) = parser.parse_args()
|
(OPTIONS, args) = parser.parse_args()
|
||||||
if len(args) > 2:
|
if len(args) > 2:
|
||||||
parser.error("incorrect number of arguments")
|
parser.error("incorrect number of arguments")
|
||||||
|
|
||||||
fields = OPTIONS.fields.split(',')
|
fields = OPTIONS.fields.split(',')
|
||||||
calls = readVariants(args[0])
|
header, allCalls = readVariants(args[0])
|
||||||
print 'Read', len(calls), 'calls'
|
print 'Read', len(allCalls), 'calls'
|
||||||
|
print 'header is', header
|
||||||
|
|
||||||
if OPTIONS.titvTarget == None:
|
if OPTIONS.titvTarget == None:
|
||||||
OPTIONS.titvTarget = titv(calls, VCFRecord.isKnown)
|
OPTIONS.titvTarget = titv(calls, VCFRecord.isKnown)
|
||||||
print 'Ti/Tv all ', titv(calls)
|
print 'Ti/Tv all ', titv(allCalls)
|
||||||
print 'Ti/Tv known', titv(selectVariants(calls, VCFRecord.isKnown))
|
print 'Ti/Tv known', titv(selectVariants(allCalls, VCFRecord.isKnown))
|
||||||
print 'Ti/Tv novel', titv(selectVariants(calls, VCFRecord.isNovel))
|
print 'Ti/Tv novel', titv(selectVariants(allCalls, VCFRecord.isNovel))
|
||||||
|
|
||||||
optimizedCalls = optimizeCalls(calls, OPTIONS.fields.split(","), OPTIONS.titvTarget)
|
if OPTIONS.bootStrap:
|
||||||
|
callsToOptimize, callsToEval = randomSplit(allCalls, OPTIONS.bootStrap)
|
||||||
|
else:
|
||||||
|
callsToOptimize, callsToEval = allCalls, allCalls
|
||||||
|
|
||||||
|
recalOptCalls, covariates = optimizeCalls(callsToOptimize, fields, OPTIONS.titvTarget)
|
||||||
|
printCallQuals(recalOptCalls, OPTIONS.titvTarget, 'OPTIMIZED CALLS')
|
||||||
|
|
||||||
|
if callsToEval <> callsToOptimize:
|
||||||
|
recalEvalCalls = recalibrateCalls(callsToEval, fields, covariates)
|
||||||
|
printCallQuals(recalEvalCalls, OPTIONS.titvTarget, 'BOOTSTRAP EVAL CALLS')
|
||||||
|
|
||||||
if len(args) > 1:
|
if len(args) > 1:
|
||||||
truthFile = args[1]
|
truthFile = args[1]
|
||||||
print 'Reading truth file', truthFile
|
print 'Reading truth file', truthFile
|
||||||
truth = dict( [[v.getLoc(), v] for v in readVariants(truthFile)])
|
truth = dict( [[v.getLoc(), v] for v in readVariants(truthFile)[1]])
|
||||||
compareCalls(optimizedCalls, truth)
|
|
||||||
|
print '--------------------------------------------------------------------------------'
|
||||||
|
print 'Comparing calls to truth', truthFile
|
||||||
|
print ''
|
||||||
|
|
||||||
|
print 'Calls used in optimization'
|
||||||
|
compareCalls(recalOptCalls, truth)
|
||||||
|
if callsToEval <> callsToOptimize:
|
||||||
|
print 'Calls held in reserve (bootstrap)'
|
||||||
|
compareCalls(recalEvalCalls, truth)
|
||||||
|
|
||||||
|
if OPTIONS.outputVCF:
|
||||||
|
f = open(OPTIONS.outputVCF, 'w')
|
||||||
|
print 'HEADER', header
|
||||||
|
for line in formatVCF(header, allCalls):
|
||||||
|
print >> f, line
|
||||||
|
f.close()
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
@ -1,3 +1,7 @@
|
||||||
|
import itertools
|
||||||
|
|
||||||
|
VCF_KEYS = "CHROM POS ID REF ALT QUAL FILTER INFO".split()
|
||||||
|
|
||||||
TRANSITIONS = dict()
|
TRANSITIONS = dict()
|
||||||
for p in ["AG", "CT"]:
|
for p in ["AG", "CT"]:
|
||||||
TRANSITIONS[p] = True
|
TRANSITIONS[p] = True
|
||||||
|
|
@ -19,10 +23,11 @@ def convertToType(d):
|
||||||
|
|
||||||
class VCFRecord:
|
class VCFRecord:
|
||||||
"""Simple support for accessing a VCF record"""
|
"""Simple support for accessing a VCF record"""
|
||||||
def __init__(self, basicBindings, header=None):
|
def __init__(self, basicBindings, header=None, rest=[]):
|
||||||
self.header = header
|
self.header = header
|
||||||
self.info = convertToType(parseInfo(basicBindings["INFO"]))
|
self.info = convertToType(parseInfo(basicBindings["INFO"]))
|
||||||
self.bindings = convertToType(basicBindings)
|
self.bindings = convertToType(basicBindings)
|
||||||
|
self.rest = rest
|
||||||
|
|
||||||
def hasHeader(self): return self.header <> None
|
def hasHeader(self): return self.header <> None
|
||||||
def getHeader(self): return self.header
|
def getHeader(self): return self.header
|
||||||
|
|
@ -102,6 +107,9 @@ class VCFRecord:
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
#return str(self.bindings) + " INFO: " + str(self.info)
|
#return str(self.bindings) + " INFO: " + str(self.info)
|
||||||
return ' '.join(['%s=%s' % (x,y) for x,y in self.bindings.iteritems()])
|
return ' '.join(['%s=%s' % (x,y) for x,y in self.bindings.iteritems()])
|
||||||
|
|
||||||
|
def format(self):
|
||||||
|
return '\t'.join([str(self.getField(key)) for key in VCF_KEYS] + self.rest)
|
||||||
|
|
||||||
def parseInfo(s):
|
def parseInfo(s):
|
||||||
def handleBoolean(key_val):
|
def handleBoolean(key_val):
|
||||||
|
|
@ -116,21 +124,41 @@ def parseInfo(s):
|
||||||
def string2VCF(line, header=None):
|
def string2VCF(line, header=None):
|
||||||
if line[0] != "#":
|
if line[0] != "#":
|
||||||
s = line.split()
|
s = line.split()
|
||||||
keys = "CHROM POS ID REF ALT QUAL FILTER INFO".split()
|
bindings = dict(zip(VCF_KEYS, s[0:8]))
|
||||||
bindings = dict(zip(keys, s[0:8]))
|
return VCFRecord(bindings, header, rest=s[8:])
|
||||||
return VCFRecord(bindings, header)
|
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def lines2VCF(lines):
|
def readVCFHeader(lines):
    """Split the leading '#' header lines off a stream of VCF lines.

    lines -- iterable of text lines (an open file, a list, a generator, ...)

    Returns a (header, columnNames, remaining) tuple:
      header      -- list of the leading lines that start with '#', stripped
      columnNames -- the last header line, or None when there is no header
      remaining   -- iterator over the first non-header line and all lines
                     after it; empty when the input holds only header lines

    A result is always returned, including for empty or header-only input, so
    callers can safely unpack three values.
    """
    # Work on a single iterator so that re-attaching the first data line below
    # never replays lines already consumed (which would happen for a list).
    lines = iter(lines)

    header = []
    remaining = iter(())
    for line in lines:
        if line.startswith("#"):
            header.append(line.strip())
        else:
            # First data line: push it back in front of the unread lines.
            remaining = itertools.chain([line], lines)
            break

    columnNames = header[-1] if header else None
    return header, columnNames, remaining
|
||||||
|
|
||||||
|
|
||||||
|
def lines2VCF(lines, extendedOutput = False):
|
||||||
|
header, columnNames, lines = readVCFHeader(lines)
|
||||||
counter = 0
|
counter = 0
|
||||||
|
|
||||||
for line in lines:
|
for line in lines:
|
||||||
if line[0] != "#":
|
if line[0] != "#":
|
||||||
counter += 1
|
counter += 1
|
||||||
vcf = string2VCF(line, header=header)
|
vcf = string2VCF(line, header=columnNames)
|
||||||
if vcf <> None:
|
if vcf <> None:
|
||||||
yield vcf, counter
|
if extendedOutput:
|
||||||
else:
|
yield header, vcf, counter
|
||||||
header = line[1:].split()
|
else:
|
||||||
|
yield vcf
|
||||||
raise StopIteration()
|
raise StopIteration()
|
||||||
|
|
||||||
|
|
||||||
|
def formatVCF(header, records):
    """Return an iterator over the text lines of a VCF file.

    header  -- iterable of header lines, emitted first and unchanged
    records -- iterable of records; each must provide a format() method that
               returns the record's output line

    The first record, when there is one, is also printed to stdout as a
    progress/debug aid.  An empty record collection is accepted and yields
    only the header lines.
    """
    records = list(records)
    if records:
        # Debug output only; guarded so an empty call set does not raise.
        print(records[0])
    return itertools.chain(header, [record.format() for record in records])
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue