misc. bug fixes

git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@2212 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
depristo 2009-12-02 14:36:03 +00:00
parent d487428468
commit 8a87d5add1
3 changed files with 58 additions and 19 deletions

View File

@ -69,7 +69,7 @@ class RecalibratedCall:
def readVariants( file, maxRecords = None, decodeAll = True, downsampleFraction = 1 ):
f = open(file)
header, ignore, lines = readVCFHeader(f)
header, columnNames, lines = readVCFHeader(f)
def parseVariant(args):
header1, VCF, counter = args
@ -78,7 +78,8 @@ def readVariants( file, maxRecords = None, decodeAll = True, downsampleFraction
else:
return None
return header, ifilter(None, imap(parseVariant, islice(lines2VCF(lines, extendedOutput = True, decodeAll = decodeAll), maxRecords)))
variants = ifilter(None, imap(parseVariant, islice(lines2VCF(lines, header=header, columnNames = columnNames, extendedOutput = True, decodeAll = decodeAll), maxRecords)))
return header, variants
def selectVariants( variants, selector = None ):
if selector <> None:
@ -378,25 +379,51 @@ def variantInTruth(variant, truth):
else:
return False
def isVariantInSample(t, sample):
#print "isVariantInSample", t.getLoc(), t.getField(sample), x
return t.getField(sample) <> "0/0"
def variantsInTruth(truth):
# fixme
return len(filter(lambda x: isVariantInSample(x, OPTIONS.useSample), truth))
def sensitivitySpecificity(variants, truth):
nTP, nFP = 0, 0
nTP, nFP = 0, 0
FPs = []
for variant in variants:
t = variantInTruth(variant, truth)
if t:
isTP, isFP = False, False
if OPTIONS.useSample:
if t: # we have a site
isTP = isVariantInSample(t, OPTIONS.useSample)
isFP = not isTP
else:
isTP = t
isFP = not t
#if variant.getLoc() == "1:867694":
# print variant, 'T: [', t, '] isTP, isFP', isTP, isFP
if isTP:
t.setField("FN", 0)
variant.setField("TP", 1)
nTP += 1
else:
elif isFP:
nFP += 1
#if variant.getPos() == 1520727:
# print "Variant is missing", variant
variant.setField("TP", 0)
#print t, variant, "is a FP!"
FPs.append(variant)
nFN = len(truth) - nTP
nFN = variantsInTruth(truth.itervalues()) - nTP
return CallCmp(nTP, nFP, nFN), FPs
def markTruth(calls):
if not OPTIONS.useSample:
for variant in calls.itervalues():
variant.setField("TP", 0) # set the TP field to 0
def compareCalls(calls, truthCalls):
for variant in calls: variant.setField("TP", 0) # set the TP field to 0
#markTruth(truthCalls)
def compare1(name, cumulative):
for field in ["QUAL", "OQ"]:
@ -405,7 +432,6 @@ def compareCalls(calls, truthCalls):
printFieldQual("truth-comparison-" + name, field, left, right, selectedVariants, dephredScale(left))
print 'PER BIN nCalls=', len(calls)
# printFieldQualHeader()
compare1('per-bin', False)
print 'CUMULATIVE nCalls=', len(calls)
@ -464,6 +490,9 @@ def setup():
parser.add_option("-b", "--bootstrap", dest="bootStrap",
type='float', default=None,
help="If provided, the % of the calls used to generate the recalibration tables. [default: %default]")
parser.add_option("-s", "--useSample", dest="useSample",
type='string', default=False,
help="If provided, we will examine sample genotypes for this sample, and consider TP/FP/FN in the truth conditional on sample genotypes [default: %default]")
parser.add_option("-r", "--dontRecalibrate", dest="dontRecalibrate",
action='store_true', default=False,
help="If provided, we will not actually do anything to the calls, they will just be assessed [default: %default]")
@ -521,12 +550,12 @@ def writeRecalibratedCalls(file, header, calls):
def readTruth(truthVCF):
print 'Reading truth file', truthVCF
rawTruth = list(readVariants(truthVCF, maxRecords = None, decodeAll = False)[1])
rawTruth = list(readVariants(truthVCF, maxRecords = None, decodeAll = True)[1])
def keepVariant(t):
#print t.getPos(), t.getLoc()
return OPTIONS.unFilteredTruth or t.passesFilters()
truth = dict( [[v.getLoc(), v] for v in filter(keepVariant, rawTruth)])
print len(rawTruth), len(truth)
print 'Number of raw and passing filter truth calls', len(rawTruth), len(truth)
return truth
def evaluateTruth(header, callVCF, truth, truthVCF):
@ -561,6 +590,8 @@ def main():
if len(args) > 1:
truthVCF = args[1]
TRUTH_CALLS = readTruth(truthVCF)
#for v in TRUTH_CALLS.itervalues(): print v.getField("NA12878")
#sys.exit(1)
if OPTIONS.recalLog <> None:
RECAL_LOG = open(OPTIONS.recalLog, "w")

View File

@ -23,7 +23,8 @@ if __name__ == "__main__":
counter = OPTIONS.skip
fields = OPTIONS.fields.split(',')
for vcf,count in lines2VCF(sys.stdin):
print sys.stdin
for header, vcf, count in lines2VCF(sys.stdin, extendedOutput = True):
#print vcf, count
if count == 1 and vcf.hasHeader():
print '\t'.join(fields)

View File

@ -30,11 +30,12 @@ def convertToType(chr, pos, d, onlyKeys = None):
class VCFRecord:
"""Simple support for accessing a VCF record"""
def __init__(self, basicBindings, header=None, rest=[], decodeAll = True):
def __init__(self, basicBindings, header=None, rest=[], moreFields = dict(), decodeAll = True):
self.header = header
self.info = parseInfo(basicBindings["INFO"])
chr, pos = basicBindings['CHROM'], basicBindings['POS']
self.bindings = convertToType(chr, pos, basicBindings, onlyKeys = ['POS', 'QUAL'])
self.bindings.update(moreFields)
if decodeAll: self.info = convertToType(chr, pos, self.info)
self.rest = rest
@ -144,7 +145,12 @@ def string2VCF(line, header=None, decodeAll = True):
if line[0] != "#":
s = line.split()
bindings = dict(zip(VCF_KEYS, s[0:8]))
return VCFRecord(bindings, header, rest=s[8:], decodeAll = decodeAll)
moreFields = dict()
#print 'HELLO', header, s, decodeAll
if header <> None and decodeAll:
moreFields = dict(zip(header[8:], s[8:]))
#print header, moreFields
return VCFRecord(bindings, header, rest=s[8:], moreFields = moreFields, decodeAll = decodeAll)
else:
return None
@ -156,7 +162,7 @@ def readVCFHeader(lines):
header.append(line.strip())
else:
if header <> []:
columnNames = header[-1]
columnNames = header[-1].strip("#").split()
return header, columnNames, itertools.chain([line], lines)
# we reach this point for empty files
@ -170,10 +176,11 @@ def quickCountRecords(lines):
return counter
def lines2VCF(lines, extendedOutput = False, decodeAll = True):
header, columnNames, lines = readVCFHeader(lines)
def lines2VCF(lines, extendedOutput = False, decodeAll = True, header=None, columnNames = None):
if header == None:
header, columnNames, lines = readVCFHeader(lines)
counter = 0
for line in lines:
if line[0] != "#":
counter += 1