Better snpSelector, plus VCFmerge tool

git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@2022 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
depristo 2009-11-11 22:02:57 +00:00
parent 0c2a957ae0
commit 1a4d071d37
4 changed files with 187 additions and 56 deletions

View File

@ -0,0 +1,14 @@
from itertools import *
def readFAI(file):
# 1 247249719 3 60 61
# 2 242951149 251370554 60 61
# 3 199501827 498370892 60 61
return [line.split() for line in open(file)]
def readFAIContigOrdering(file):
# 1 247249719 3 60 61
# 2 242951149 251370554 60 61
# 3 199501827 498370892 60 61
return dict([[rec[0], i] for rec, i in izip(readFAI(file), count())])

View File

@ -0,0 +1,56 @@
import os.path
import sys
from optparse import OptionParser
from vcfReader import *
from itertools import *
import faiReader
def main():
global OPTIONS
usage = "usage: %prog [options] file1 ... fileN"
parser = OptionParser(usage=usage)
parser.add_option("-f", "--f", dest="fai",
type='string', default=None,
help="FAI file defining the sort order of the VCF")
(OPTIONS, args) = parser.parse_args()
if len(args) == 0:
parser.error("Requires at least 1 VCF to merge")
order = None
if OPTIONS.fai <> None: order = faiReader.readFAIContigOrdering(OPTIONS.fai)
#print 'Order', order
header = None
records = []
for file in args:
#print file
for header, record, counter in lines2VCF(open(file), extendedOutput = True, decodeAll = False):
records.append(record)
def cmpVCFRecords(r1, r2):
if order <> None:
c1 = order[str(r1.getChrom())]
c2 = order[str(r2.getChrom())]
orderCmp = cmp(c1, c2)
if orderCmp <> 0:
return orderCmp
return cmp(r1.getPos(), r2.getPos())
records.sort(cmpVCFRecords)
for line in formatVCF(header, records):
#pass
print line
PROFILE = False
if __name__ == "__main__":
if PROFILE:
import cProfile
cProfile.run('main()', 'fooprof')
import pstats
p = pstats.Stats('fooprof')
p.sort_stats('cumulative').print_stats(10)
p.sort_stats('time').print_stats(10)
p.sort_stats('time', 'cum').print_stats(.5, 'init')
else:
main()

View File

@ -5,6 +5,8 @@ from vcfReader import *
#import pylab
from itertools import *
import math
import random
class CallCovariate:
def __init__(self, feature, left, right, FPRate = None, cumulative = False):
@ -61,7 +63,7 @@ class RecalibratedCall:
def __str__(self):
return '[%s: %s => Q%d]' % (str(self.call), self.featureStringList(), phredScale(self.jointFPErrorRate()))
def readVariants( file ):
def readVariants( file, maxRecords = None, decodeAll = True ):
counter = OPTIONS.skip
f = open(file)
header, ignore, lines = readVCFHeader(f)
@ -73,7 +75,7 @@ def readVariants( file ):
else:
return None
return header, filter(None, map(parseVariant, lines2VCF(lines, extendedOutput = True)))
return header, filter(None, map(parseVariant, islice(lines2VCF(lines, extendedOutput = True, decodeAll = decodeAll), maxRecords)))
def selectVariants( variants, selector = None ):
if selector <> None:
@ -174,7 +176,7 @@ def frange6(*args):
yield v
v += step
def calculateBins(variants, field, minValue, maxValue, rangeValue, partitions):
def calculateBins(variants, field, minValue, maxValue, partitions):
sortedVariants = sorted(variants, key = lambda x: x.getField(field))
sortedValues = map(lambda x: x.getField(field), sortedVariants)
@ -198,20 +200,20 @@ def calculateBins(variants, field, minValue, maxValue, rangeValue, partitions):
return bins
def calculateBinsLinear(variants, minValue, maxValue, rangeValue, partitions):
breaks = list(frange6(minValue, maxValue, rangeValue / partitions))
if breaks[len(breaks)-1] <> maxValue:
breaks = breaks + ['*']
return zip(breaks, map( lambda x: x - 0.001, breaks[1:]))
#
# def calculateBinsLinear(variants, minValue, maxValue, rangeValue, partitions):
# breaks = list(frange6(minValue, maxValue, rangeValue / partitions))
# if breaks[len(breaks)-1] <> maxValue:
# breaks = breaks + ['*']
# return zip(breaks, map( lambda x: x - 0.001, breaks[1:]))
def fieldRange(variants, field):
values = map(lambda v: v.getField(field), variants)
minValue = min(values)
maxValue = max(values)
rangeValue = maxValue - minValue
bins = calculateBins(variants, field, minValue, maxValue, rangeValue, OPTIONS.partitions)
return minValue, maxValue, rangeValue, bins
#rangeValue = maxValue - minValue
bins = calculateBins(variants, field, minValue, maxValue, OPTIONS.partitions)
return minValue, maxValue, bins
def printFieldQual( left, right, variants, titv, FPRate, nErrors ):
print ' %s nVariants=%8d titv=%.2f FPRate=%.2e Q%d' % (binString(left, right), len(variants), titv, FPRate, phredScale(FPRate))
@ -273,11 +275,11 @@ def all( p, l ):
return True
def variantBinsForField(variants, field):
if not all( lambda x: x.hasField(field), variants):
raise Exception('Unknown field ' + field)
#if not all( lambda x: x.hasField(field), variants):
# raise Exception('Unknown field ' + field)
minValue, maxValue, range, bins = fieldRange(variants, field)
print 'Field range', minValue, maxValue, range
minValue, maxValue, bins = fieldRange(variants, field)
print 'Field range', minValue, maxValue
print 'Partitions', bins
return bins
@ -317,40 +319,52 @@ class CallCmp:
def FPRate(self):
return (1.0*self.nFP) / max(self.nTP + self.nFP, 1)
def FNRate(self):
return (1.0*self.nFN) / max(self.nTP + self.nFN, 1)
def __str__(self):
return 'TP=%6d FP=%6d FPRate=%.2f FN=%6d' % (self.nTP, self.nFP, self.FPRate(), self.nFN)
return 'TP=%6d FP=%6d FPRate=%.2f FN=%6d FNRate=%.2f' % (self.nTP, self.nFP, self.FPRate(), self.nFN, self.FNRate())
def variantInTruth(variant, truth):
return variant.getLoc() in truth
if variant.getLoc() in truth:
return truth[variant.getLoc()]
else:
return False
def sensitivitySpecificity(variants, truth):
nTP, nFP = 0, 0
FPs = []
for variant in variants:
if variantInTruth(variant, truth):
t = variantInTruth(variant, truth)
if t:
t.setField("FOUND", 1)
nTP += 1
else:
if OPTIONS.printFP: print 'FP:', variant
nFP += 1
#if variant.getPos() == 1520727:
# print "Variant is missing", variant
FPs.append(variant)
nFN = len(truth) - nTP
return CallCmp(nTP, nFP, nFN)
return CallCmp(nTP, nFP, nFN), FPs
def compareCalls(calls, truthCalls):
def compare1(cumulative):
for left, right, selectedVariants in mapVariantBins(calls, 'QUAL', cumulative = cumulative):
callComparison = sensitivitySpecificity(selectedVariants, truthCalls)
callComparison, theseFPs = sensitivitySpecificity(selectedVariants, truthCalls)
for fp in theseFPs: fp.setField("FP", 1)
#FPsVariants.append(theseFPs)
print binString(left, right), 'titv=%.2f' % titv(selectedVariants)[0], callComparison
print 'PER BIN'
print 'PER BIN nCalls=', len(calls)
compare1(False)
print 'CUMULATIVE'
print 'CUMULATIVE nCalls=', len(calls)
compare1(True)
def randomSplit(l, pLeft):
import random
def keep(elt, p):
if p < pLeft:
return elt, None
@ -361,7 +375,7 @@ def randomSplit(l, pLeft):
return get(0), get(1)
def main():
global OPTIONS
global OPTIONS, header
usage = "usage: %prog files.list [options]"
parser = OptionParser(usage=usage)
parser.add_option("-f", "--f", dest="fields",
@ -370,6 +384,9 @@ def main():
parser.add_option("-t", "--truth", dest="truth",
type='string', default=None,
help="VCF formated truth file")
parser.add_option("", "--unFilteredTruth", dest="unFilteredTruth",
action='store_true', default=False,
help="If provided, the unfiltered truth calls will be used in comparisons")
parser.add_option("-p", "--partitions", dest="partitions",
type='int', default=25,
help="Number of partitions to examine")
@ -379,12 +396,18 @@ def main():
parser.add_option("-m", "--minVariantsPerBin", dest="minVariantsPerBin",
type='int', default=10,
help="")
parser.add_option("-M", "--maxRecords", dest="maxRecords",
type='int', default=None,
help="")
parser.add_option("-q", "--qMax", dest="maxQScore",
type='int', default=30,
help="")
parser.add_option("-o", "--outputVCF", dest="outputVCF",
type='string', default=None,
help="If provided, VCF file will be written out to this file")
parser.add_option("", "--FNoutputVCF", dest="FNoutputVCF",
type='string', default=None,
help="If provided, VCF file will be written out to this file")
parser.add_option("", "--titv", dest="titvTarget",
type='float', default=None,
help="If provided, we will optimize calls to the targeted ti/tv rather than that calculated from known calls")
@ -394,15 +417,16 @@ def main():
parser.add_option("-b", "--bootstrap", dest="bootStrap",
type='float', default=0.0,
help="If provided, the % of the calls used to generate the recalibration tables.")
(OPTIONS, args) = parser.parse_args()
if len(args) > 2:
parser.error("incorrect number of arguments")
fields = OPTIONS.fields.split(',')
header, allCalls = readVariants(args[0])
header, allCalls = readVariants(args[0], OPTIONS.maxRecords)
print 'Read', len(allCalls), 'calls'
print 'header is', header
#print 'header is', header
if OPTIONS.titvTarget == None:
OPTIONS.titvTarget = titv(calls, VCFRecord.isKnown)
@ -422,10 +446,16 @@ def main():
recalEvalCalls = recalibrateCalls(callsToEval, fields, covariates)
printCallQuals(recalEvalCalls, OPTIONS.titvTarget, 'BOOTSTRAP EVAL CALLS')
truth = None
if len(args) > 1:
truthFile = args[1]
print 'Reading truth file', truthFile
truth = dict( [[v.getLoc(), v] for v in readVariants(truthFile)[1]])
rawTruth = readVariants(truthFile, maxRecords = None, decodeAll = False)[1]
def keepVariant(t):
#print t.getPos(), t.getLoc()
return OPTIONS.unFilteredTruth or t.passesFilters()
truth = dict( [[v.getLoc(), v] for v in filter(keepVariant, rawTruth)])
print len(rawTruth), len(truth)
print '--------------------------------------------------------------------------------'
print 'Comparing calls to truth', truthFile
@ -439,10 +469,27 @@ def main():
if OPTIONS.outputVCF:
f = open(OPTIONS.outputVCF, 'w')
print 'HEADER', header
#print 'HEADER', header
for line in formatVCF(header, allCalls):
print >> f, line
f.close()
if truth <> None and OPTIONS.FNoutputVCF:
f = open(OPTIONS.FNoutputVCF, 'w')
#print 'HEADER', header
for line in formatVCF(header, filter( lambda x: not x.hasField("FOUND"), truth.itervalues())):
print >> f, line
f.close()
PROFILE = False
if __name__ == "__main__":
main()
if PROFILE:
import cProfile
cProfile.run('main()', 'fooprof')
import pstats
p = pstats.Stats('fooprof')
p.sort_stats('cumulative').print_stats(10)
p.sort_stats('time').print_stats(10)
p.sort_stats('time', 'cum').print_stats(.5, 'init')
else:
main()

View File

@ -7,26 +7,28 @@ for p in ["AG", "CT"]:
TRANSITIONS[p] = True
TRANSITIONS[''.join(reversed(p))] = True
def convertToType(d):
def convertToType(d, onlyKeys = None):
out = dict()
types = [int, float, str]
for key, value in d.items():
for type in types:
try:
#print 'Parsing', key, value, type
out[key] = type(value)
#print ' Parsed as', key, value, type
break
except:
pass
if onlyKeys == None or key in onlyKeys:
for type in types:
try:
out[key] = type(value)
break
except:
pass
else:
out[key] = value
return out
class VCFRecord:
"""Simple support for accessing a VCF record"""
def __init__(self, basicBindings, header=None, rest=[]):
def __init__(self, basicBindings, header=None, rest=[], decodeAll = True):
self.header = header
self.info = convertToType(parseInfo(basicBindings["INFO"]))
self.bindings = convertToType(basicBindings)
self.info = parseInfo(basicBindings["INFO"])
if decodeAll: self.info = convertToType(self.info)
self.bindings = convertToType(basicBindings, onlyKeys = ['POS', 'QUAL'])
self.rest = rest
def hasHeader(self): return self.header <> None
@ -112,20 +114,30 @@ class VCFRecord:
return '\t'.join([str(self.getField(key)) for key in VCF_KEYS] + self.rest)
def parseInfo(s):
def handleBoolean(key_val):
if len(key_val) == 1:
return [key_val[0], 1]
d = dict()
for elt in s.split(";"):
if '=' in elt:
key, val = elt.split('=')
else:
return key_val
key, val = elt, 1
d[key] = val
return d
key_val = map( lambda x: handleBoolean(x.split("=")), s.split(";"))
return dict(key_val)
# def parseInfo(s):
# def handleBoolean(key_val):
# if len(key_val) == 1:
# return [key_val[0], 1]
# else:
# return key_val
#
# key_val = map( lambda x: handleBoolean(x.split("=")), s.split(";"))
# return dict(key_val)
def string2VCF(line, header=None):
def string2VCF(line, header=None, decodeAll = True):
if line[0] != "#":
s = line.split()
bindings = dict(zip(VCF_KEYS, s[0:8]))
return VCFRecord(bindings, header, rest=s[8:])
return VCFRecord(bindings, header, rest=s[8:], decodeAll = decodeAll)
else:
return None
@ -139,16 +151,18 @@ def readVCFHeader(lines):
if header <> []:
columnNames = header[-1]
return header, columnNames, itertools.chain([line], lines)
# we reach this point for empty files
return header, columnNames, []
def lines2VCF(lines, extendedOutput = False):
def lines2VCF(lines, extendedOutput = False, decodeAll = True):
header, columnNames, lines = readVCFHeader(lines)
counter = 0
for line in lines:
if line[0] != "#":
counter += 1
vcf = string2VCF(line, header=columnNames)
vcf = string2VCF(line, header=columnNames, decodeAll = decodeAll)
if vcf <> None:
if extendedOutput:
yield header, vcf, counter
@ -159,6 +173,6 @@ def lines2VCF(lines, extendedOutput = False):
def formatVCF(header, records):
#print records
print records[0]
#print records[0]
return itertools.chain(header, map(VCFRecord.format, records))