Better snpSelector, plus VCFmerge tool

git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@2022 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
depristo 2009-11-11 22:02:57 +00:00
parent 0c2a957ae0
commit 1a4d071d37
4 changed files with 187 additions and 56 deletions

View File

@ -0,0 +1,14 @@
from itertools import *
def readFAI(file):
# 1 247249719 3 60 61
# 2 242951149 251370554 60 61
# 3 199501827 498370892 60 61
return [line.split() for line in open(file)]
def readFAIContigOrdering(file):
# 1 247249719 3 60 61
# 2 242951149 251370554 60 61
# 3 199501827 498370892 60 61
return dict([[rec[0], i] for rec, i in izip(readFAI(file), count())])

View File

@ -0,0 +1,56 @@
import os.path
import sys
from optparse import OptionParser
from vcfReader import *
from itertools import *
import faiReader
def main():
global OPTIONS
usage = "usage: %prog [options] file1 ... fileN"
parser = OptionParser(usage=usage)
parser.add_option("-f", "--f", dest="fai",
type='string', default=None,
help="FAI file defining the sort order of the VCF")
(OPTIONS, args) = parser.parse_args()
if len(args) == 0:
parser.error("Requires at least 1 VCF to merge")
order = None
if OPTIONS.fai <> None: order = faiReader.readFAIContigOrdering(OPTIONS.fai)
#print 'Order', order
header = None
records = []
for file in args:
#print file
for header, record, counter in lines2VCF(open(file), extendedOutput = True, decodeAll = False):
records.append(record)
def cmpVCFRecords(r1, r2):
if order <> None:
c1 = order[str(r1.getChrom())]
c2 = order[str(r2.getChrom())]
orderCmp = cmp(c1, c2)
if orderCmp <> 0:
return orderCmp
return cmp(r1.getPos(), r2.getPos())
records.sort(cmpVCFRecords)
for line in formatVCF(header, records):
#pass
print line
PROFILE = False
if __name__ == "__main__":
if PROFILE:
import cProfile
cProfile.run('main()', 'fooprof')
import pstats
p = pstats.Stats('fooprof')
p.sort_stats('cumulative').print_stats(10)
p.sort_stats('time').print_stats(10)
p.sort_stats('time', 'cum').print_stats(.5, 'init')
else:
main()

View File

@ -5,6 +5,8 @@ from vcfReader import *
#import pylab #import pylab
from itertools import * from itertools import *
import math import math
import random
class CallCovariate: class CallCovariate:
def __init__(self, feature, left, right, FPRate = None, cumulative = False): def __init__(self, feature, left, right, FPRate = None, cumulative = False):
@ -61,7 +63,7 @@ class RecalibratedCall:
def __str__(self): def __str__(self):
return '[%s: %s => Q%d]' % (str(self.call), self.featureStringList(), phredScale(self.jointFPErrorRate())) return '[%s: %s => Q%d]' % (str(self.call), self.featureStringList(), phredScale(self.jointFPErrorRate()))
def readVariants( file ): def readVariants( file, maxRecords = None, decodeAll = True ):
counter = OPTIONS.skip counter = OPTIONS.skip
f = open(file) f = open(file)
header, ignore, lines = readVCFHeader(f) header, ignore, lines = readVCFHeader(f)
@ -73,7 +75,7 @@ def readVariants( file ):
else: else:
return None return None
return header, filter(None, map(parseVariant, lines2VCF(lines, extendedOutput = True))) return header, filter(None, map(parseVariant, islice(lines2VCF(lines, extendedOutput = True, decodeAll = decodeAll), maxRecords)))
def selectVariants( variants, selector = None ): def selectVariants( variants, selector = None ):
if selector <> None: if selector <> None:
@ -174,7 +176,7 @@ def frange6(*args):
yield v yield v
v += step v += step
def calculateBins(variants, field, minValue, maxValue, rangeValue, partitions): def calculateBins(variants, field, minValue, maxValue, partitions):
sortedVariants = sorted(variants, key = lambda x: x.getField(field)) sortedVariants = sorted(variants, key = lambda x: x.getField(field))
sortedValues = map(lambda x: x.getField(field), sortedVariants) sortedValues = map(lambda x: x.getField(field), sortedVariants)
@ -198,20 +200,20 @@ def calculateBins(variants, field, minValue, maxValue, rangeValue, partitions):
return bins return bins
#
def calculateBinsLinear(variants, minValue, maxValue, rangeValue, partitions): # def calculateBinsLinear(variants, minValue, maxValue, rangeValue, partitions):
breaks = list(frange6(minValue, maxValue, rangeValue / partitions)) # breaks = list(frange6(minValue, maxValue, rangeValue / partitions))
if breaks[len(breaks)-1] <> maxValue: # if breaks[len(breaks)-1] <> maxValue:
breaks = breaks + ['*'] # breaks = breaks + ['*']
return zip(breaks, map( lambda x: x - 0.001, breaks[1:])) # return zip(breaks, map( lambda x: x - 0.001, breaks[1:]))
def fieldRange(variants, field): def fieldRange(variants, field):
values = map(lambda v: v.getField(field), variants) values = map(lambda v: v.getField(field), variants)
minValue = min(values) minValue = min(values)
maxValue = max(values) maxValue = max(values)
rangeValue = maxValue - minValue #rangeValue = maxValue - minValue
bins = calculateBins(variants, field, minValue, maxValue, rangeValue, OPTIONS.partitions) bins = calculateBins(variants, field, minValue, maxValue, OPTIONS.partitions)
return minValue, maxValue, rangeValue, bins return minValue, maxValue, bins
def printFieldQual( left, right, variants, titv, FPRate, nErrors ): def printFieldQual( left, right, variants, titv, FPRate, nErrors ):
print ' %s nVariants=%8d titv=%.2f FPRate=%.2e Q%d' % (binString(left, right), len(variants), titv, FPRate, phredScale(FPRate)) print ' %s nVariants=%8d titv=%.2f FPRate=%.2e Q%d' % (binString(left, right), len(variants), titv, FPRate, phredScale(FPRate))
@ -273,11 +275,11 @@ def all( p, l ):
return True return True
def variantBinsForField(variants, field): def variantBinsForField(variants, field):
if not all( lambda x: x.hasField(field), variants): #if not all( lambda x: x.hasField(field), variants):
raise Exception('Unknown field ' + field) # raise Exception('Unknown field ' + field)
minValue, maxValue, range, bins = fieldRange(variants, field) minValue, maxValue, bins = fieldRange(variants, field)
print 'Field range', minValue, maxValue, range print 'Field range', minValue, maxValue
print 'Partitions', bins print 'Partitions', bins
return bins return bins
@ -317,40 +319,52 @@ class CallCmp:
def FPRate(self): def FPRate(self):
return (1.0*self.nFP) / max(self.nTP + self.nFP, 1) return (1.0*self.nFP) / max(self.nTP + self.nFP, 1)
def FNRate(self):
return (1.0*self.nFN) / max(self.nTP + self.nFN, 1)
def __str__(self): def __str__(self):
return 'TP=%6d FP=%6d FPRate=%.2f FN=%6d' % (self.nTP, self.nFP, self.FPRate(), self.nFN) return 'TP=%6d FP=%6d FPRate=%.2f FN=%6d FNRate=%.2f' % (self.nTP, self.nFP, self.FPRate(), self.nFN, self.FNRate())
def variantInTruth(variant, truth): def variantInTruth(variant, truth):
return variant.getLoc() in truth if variant.getLoc() in truth:
return truth[variant.getLoc()]
else:
return False
def sensitivitySpecificity(variants, truth): def sensitivitySpecificity(variants, truth):
nTP, nFP = 0, 0 nTP, nFP = 0, 0
FPs = []
for variant in variants: for variant in variants:
if variantInTruth(variant, truth): t = variantInTruth(variant, truth)
if t:
t.setField("FOUND", 1)
nTP += 1 nTP += 1
else: else:
if OPTIONS.printFP: print 'FP:', variant if OPTIONS.printFP: print 'FP:', variant
nFP += 1 nFP += 1
#if variant.getPos() == 1520727:
# print "Variant is missing", variant
FPs.append(variant)
nFN = len(truth) - nTP nFN = len(truth) - nTP
return CallCmp(nTP, nFP, nFN) return CallCmp(nTP, nFP, nFN), FPs
def compareCalls(calls, truthCalls): def compareCalls(calls, truthCalls):
def compare1(cumulative): def compare1(cumulative):
for left, right, selectedVariants in mapVariantBins(calls, 'QUAL', cumulative = cumulative): for left, right, selectedVariants in mapVariantBins(calls, 'QUAL', cumulative = cumulative):
callComparison = sensitivitySpecificity(selectedVariants, truthCalls) callComparison, theseFPs = sensitivitySpecificity(selectedVariants, truthCalls)
for fp in theseFPs: fp.setField("FP", 1)
#FPsVariants.append(theseFPs)
print binString(left, right), 'titv=%.2f' % titv(selectedVariants)[0], callComparison print binString(left, right), 'titv=%.2f' % titv(selectedVariants)[0], callComparison
print 'PER BIN' print 'PER BIN nCalls=', len(calls)
compare1(False) compare1(False)
print 'CUMULATIVE' print 'CUMULATIVE nCalls=', len(calls)
compare1(True) compare1(True)
def randomSplit(l, pLeft): def randomSplit(l, pLeft):
import random
def keep(elt, p): def keep(elt, p):
if p < pLeft: if p < pLeft:
return elt, None return elt, None
@ -361,7 +375,7 @@ def randomSplit(l, pLeft):
return get(0), get(1) return get(0), get(1)
def main(): def main():
global OPTIONS global OPTIONS, header
usage = "usage: %prog files.list [options]" usage = "usage: %prog files.list [options]"
parser = OptionParser(usage=usage) parser = OptionParser(usage=usage)
parser.add_option("-f", "--f", dest="fields", parser.add_option("-f", "--f", dest="fields",
@ -370,6 +384,9 @@ def main():
parser.add_option("-t", "--truth", dest="truth", parser.add_option("-t", "--truth", dest="truth",
type='string', default=None, type='string', default=None,
help="VCF formated truth file") help="VCF formated truth file")
parser.add_option("", "--unFilteredTruth", dest="unFilteredTruth",
action='store_true', default=False,
help="If provided, the unfiltered truth calls will be used in comparisons")
parser.add_option("-p", "--partitions", dest="partitions", parser.add_option("-p", "--partitions", dest="partitions",
type='int', default=25, type='int', default=25,
help="Number of partitions to examine") help="Number of partitions to examine")
@ -379,12 +396,18 @@ def main():
parser.add_option("-m", "--minVariantsPerBin", dest="minVariantsPerBin", parser.add_option("-m", "--minVariantsPerBin", dest="minVariantsPerBin",
type='int', default=10, type='int', default=10,
help="") help="")
parser.add_option("-M", "--maxRecords", dest="maxRecords",
type='int', default=None,
help="")
parser.add_option("-q", "--qMax", dest="maxQScore", parser.add_option("-q", "--qMax", dest="maxQScore",
type='int', default=30, type='int', default=30,
help="") help="")
parser.add_option("-o", "--outputVCF", dest="outputVCF", parser.add_option("-o", "--outputVCF", dest="outputVCF",
type='string', default=None, type='string', default=None,
help="If provided, VCF file will be written out to this file") help="If provided, VCF file will be written out to this file")
parser.add_option("", "--FNoutputVCF", dest="FNoutputVCF",
type='string', default=None,
help="If provided, VCF file will be written out to this file")
parser.add_option("", "--titv", dest="titvTarget", parser.add_option("", "--titv", dest="titvTarget",
type='float', default=None, type='float', default=None,
help="If provided, we will optimize calls to the targeted ti/tv rather than that calculated from known calls") help="If provided, we will optimize calls to the targeted ti/tv rather than that calculated from known calls")
@ -394,15 +417,16 @@ def main():
parser.add_option("-b", "--bootstrap", dest="bootStrap", parser.add_option("-b", "--bootstrap", dest="bootStrap",
type='float', default=0.0, type='float', default=0.0,
help="If provided, the % of the calls used to generate the recalibration tables.") help="If provided, the % of the calls used to generate the recalibration tables.")
(OPTIONS, args) = parser.parse_args() (OPTIONS, args) = parser.parse_args()
if len(args) > 2: if len(args) > 2:
parser.error("incorrect number of arguments") parser.error("incorrect number of arguments")
fields = OPTIONS.fields.split(',') fields = OPTIONS.fields.split(',')
header, allCalls = readVariants(args[0]) header, allCalls = readVariants(args[0], OPTIONS.maxRecords)
print 'Read', len(allCalls), 'calls' print 'Read', len(allCalls), 'calls'
print 'header is', header #print 'header is', header
if OPTIONS.titvTarget == None: if OPTIONS.titvTarget == None:
OPTIONS.titvTarget = titv(calls, VCFRecord.isKnown) OPTIONS.titvTarget = titv(calls, VCFRecord.isKnown)
@ -422,10 +446,16 @@ def main():
recalEvalCalls = recalibrateCalls(callsToEval, fields, covariates) recalEvalCalls = recalibrateCalls(callsToEval, fields, covariates)
printCallQuals(recalEvalCalls, OPTIONS.titvTarget, 'BOOTSTRAP EVAL CALLS') printCallQuals(recalEvalCalls, OPTIONS.titvTarget, 'BOOTSTRAP EVAL CALLS')
truth = None
if len(args) > 1: if len(args) > 1:
truthFile = args[1] truthFile = args[1]
print 'Reading truth file', truthFile print 'Reading truth file', truthFile
truth = dict( [[v.getLoc(), v] for v in readVariants(truthFile)[1]]) rawTruth = readVariants(truthFile, maxRecords = None, decodeAll = False)[1]
def keepVariant(t):
#print t.getPos(), t.getLoc()
return OPTIONS.unFilteredTruth or t.passesFilters()
truth = dict( [[v.getLoc(), v] for v in filter(keepVariant, rawTruth)])
print len(rawTruth), len(truth)
print '--------------------------------------------------------------------------------' print '--------------------------------------------------------------------------------'
print 'Comparing calls to truth', truthFile print 'Comparing calls to truth', truthFile
@ -439,10 +469,27 @@ def main():
if OPTIONS.outputVCF: if OPTIONS.outputVCF:
f = open(OPTIONS.outputVCF, 'w') f = open(OPTIONS.outputVCF, 'w')
print 'HEADER', header #print 'HEADER', header
for line in formatVCF(header, allCalls): for line in formatVCF(header, allCalls):
print >> f, line print >> f, line
f.close() f.close()
if truth <> None and OPTIONS.FNoutputVCF:
f = open(OPTIONS.FNoutputVCF, 'w')
#print 'HEADER', header
for line in formatVCF(header, filter( lambda x: not x.hasField("FOUND"), truth.itervalues())):
print >> f, line
f.close()
PROFILE = False
if __name__ == "__main__": if __name__ == "__main__":
main() if PROFILE:
import cProfile
cProfile.run('main()', 'fooprof')
import pstats
p = pstats.Stats('fooprof')
p.sort_stats('cumulative').print_stats(10)
p.sort_stats('time').print_stats(10)
p.sort_stats('time', 'cum').print_stats(.5, 'init')
else:
main()

View File

@ -7,26 +7,28 @@ for p in ["AG", "CT"]:
TRANSITIONS[p] = True TRANSITIONS[p] = True
TRANSITIONS[''.join(reversed(p))] = True TRANSITIONS[''.join(reversed(p))] = True
def convertToType(d): def convertToType(d, onlyKeys = None):
out = dict() out = dict()
types = [int, float, str] types = [int, float, str]
for key, value in d.items(): for key, value in d.items():
for type in types: if onlyKeys == None or key in onlyKeys:
try: for type in types:
#print 'Parsing', key, value, type try:
out[key] = type(value) out[key] = type(value)
#print ' Parsed as', key, value, type break
break except:
except: pass
pass else:
out[key] = value
return out return out
class VCFRecord: class VCFRecord:
"""Simple support for accessing a VCF record""" """Simple support for accessing a VCF record"""
def __init__(self, basicBindings, header=None, rest=[]): def __init__(self, basicBindings, header=None, rest=[], decodeAll = True):
self.header = header self.header = header
self.info = convertToType(parseInfo(basicBindings["INFO"])) self.info = parseInfo(basicBindings["INFO"])
self.bindings = convertToType(basicBindings) if decodeAll: self.info = convertToType(self.info)
self.bindings = convertToType(basicBindings, onlyKeys = ['POS', 'QUAL'])
self.rest = rest self.rest = rest
def hasHeader(self): return self.header <> None def hasHeader(self): return self.header <> None
@ -112,20 +114,30 @@ class VCFRecord:
return '\t'.join([str(self.getField(key)) for key in VCF_KEYS] + self.rest) return '\t'.join([str(self.getField(key)) for key in VCF_KEYS] + self.rest)
def parseInfo(s): def parseInfo(s):
def handleBoolean(key_val): d = dict()
if len(key_val) == 1: for elt in s.split(";"):
return [key_val[0], 1] if '=' in elt:
key, val = elt.split('=')
else: else:
return key_val key, val = elt, 1
d[key] = val
return d
key_val = map( lambda x: handleBoolean(x.split("=")), s.split(";")) # def parseInfo(s):
return dict(key_val) # def handleBoolean(key_val):
# if len(key_val) == 1:
# return [key_val[0], 1]
# else:
# return key_val
#
# key_val = map( lambda x: handleBoolean(x.split("=")), s.split(";"))
# return dict(key_val)
def string2VCF(line, header=None): def string2VCF(line, header=None, decodeAll = True):
if line[0] != "#": if line[0] != "#":
s = line.split() s = line.split()
bindings = dict(zip(VCF_KEYS, s[0:8])) bindings = dict(zip(VCF_KEYS, s[0:8]))
return VCFRecord(bindings, header, rest=s[8:]) return VCFRecord(bindings, header, rest=s[8:], decodeAll = decodeAll)
else: else:
return None return None
@ -139,16 +151,18 @@ def readVCFHeader(lines):
if header <> []: if header <> []:
columnNames = header[-1] columnNames = header[-1]
return header, columnNames, itertools.chain([line], lines) return header, columnNames, itertools.chain([line], lines)
# we reach this point for empty files
return header, columnNames, []
def lines2VCF(lines, extendedOutput = False): def lines2VCF(lines, extendedOutput = False, decodeAll = True):
header, columnNames, lines = readVCFHeader(lines) header, columnNames, lines = readVCFHeader(lines)
counter = 0 counter = 0
for line in lines: for line in lines:
if line[0] != "#": if line[0] != "#":
counter += 1 counter += 1
vcf = string2VCF(line, header=columnNames) vcf = string2VCF(line, header=columnNames, decodeAll = decodeAll)
if vcf <> None: if vcf <> None:
if extendedOutput: if extendedOutput:
yield header, vcf, counter yield header, vcf, counter
@ -159,6 +173,6 @@ def lines2VCF(lines, extendedOutput = False):
def formatVCF(header, records): def formatVCF(header, records):
#print records #print records
print records[0] #print records[0]
return itertools.chain(header, map(VCFRecord.format, records)) return itertools.chain(header, map(VCFRecord.format, records))