diff --git a/python/snpSelector.py b/python/snpSelector.py index 19abce23f..1fd7999cd 100755 --- a/python/snpSelector.py +++ b/python/snpSelector.py @@ -69,7 +69,7 @@ class RecalibratedCall: def readVariants( file, maxRecords = None, decodeAll = True, downsampleFraction = 1 ): f = open(file) - header, ignore, lines = readVCFHeader(f) + header, columnNames, lines = readVCFHeader(f) def parseVariant(args): header1, VCF, counter = args @@ -78,7 +78,8 @@ def readVariants( file, maxRecords = None, decodeAll = True, downsampleFraction else: return None - return header, ifilter(None, imap(parseVariant, islice(lines2VCF(lines, extendedOutput = True, decodeAll = decodeAll), maxRecords))) + variants = ifilter(None, imap(parseVariant, islice(lines2VCF(lines, header=header, columnNames = columnNames, extendedOutput = True, decodeAll = decodeAll), maxRecords))) + return header, variants def selectVariants( variants, selector = None ): if selector <> None: @@ -378,25 +379,51 @@ def variantInTruth(variant, truth): else: return False +def isVariantInSample(t, sample): + #print "isVariantInSample", t.getLoc(), t.getField(sample), x + return t.getField(sample) <> "0/0" + +def variantsInTruth(truth): + # fixme + return len(filter(lambda x: isVariantInSample(x, OPTIONS.useSample), truth)) + def sensitivitySpecificity(variants, truth): - nTP, nFP = 0, 0 + nTP, nFP = 0, 0 FPs = [] for variant in variants: t = variantInTruth(variant, truth) - if t: + + isTP, isFP = False, False + if OPTIONS.useSample: + if t: # we have a site + isTP = isVariantInSample(t, OPTIONS.useSample) + isFP = not isTP + else: + isTP = t + isFP = not t + + #if variant.getLoc() == "1:867694": + # print variant, 'T: [', t, '] isTP, isFP', isTP, isFP + + if isTP: t.setField("FN", 0) variant.setField("TP", 1) nTP += 1 - else: + elif isFP: nFP += 1 - #if variant.getPos() == 1520727: - # print "Variant is missing", variant + variant.setField("TP", 0) + #print t, variant, "is a FP!" FPs.append(variant) - nFN = len(truth) - nTP + nFN = variantsInTruth(truth.itervalues()) - nTP return CallCmp(nTP, nFP, nFN), FPs +def markTruth(calls): + if not OPTIONS.useSample: + for variant in calls.itervalues(): + variant.setField("TP", 0) # set the TP field to 0 + def compareCalls(calls, truthCalls): - for variant in calls: variant.setField("TP", 0) # set the TP field to 0 + #markTruth(truthCalls) def compare1(name, cumulative): for field in ["QUAL", "OQ"]: @@ -405,7 +432,6 @@ def compareCalls(calls, truthCalls): printFieldQual("truth-comparison-" + name, field, left, right, selectedVariants, dephredScale(left)) print 'PER BIN nCalls=', len(calls) - # printFieldQualHeader() compare1('per-bin', False) print 'CUMULATIVE nCalls=', len(calls) @@ -464,6 +490,9 @@ def setup(): parser.add_option("-b", "--bootstrap", dest="bootStrap", type='float', default=None, help="If provided, the % of the calls used to generate the recalibration tables. [default: %default]") + parser.add_option("-s", "--useSample", dest="useSample", + type='string', default=False, + help="If provided, we will examine sample genotypes for this sample, and consider TP/FP/FN in the truth conditional on sample genotypes [default: %default]") parser.add_option("-r", "--dontRecalibrate", dest="dontRecalibrate", action='store_true', default=False, help="If provided, we will not actually do anything to the calls, they will just be assessed [default: %default]") @@ -521,12 +550,12 @@ def writeRecalibratedCalls(file, header, calls): def readTruth(truthVCF): print 'Reading truth file', truthVCF - rawTruth = list(readVariants(truthVCF, maxRecords = None, decodeAll = False)[1]) + rawTruth = list(readVariants(truthVCF, maxRecords = None, decodeAll = True)[1]) def keepVariant(t): #print t.getPos(), t.getLoc() return OPTIONS.unFilteredTruth or t.passesFilters() truth = dict( [[v.getLoc(), v] for v in filter(keepVariant, rawTruth)]) - print len(rawTruth), len(truth) + print 'Number of raw and passing filter truth calls', len(rawTruth), len(truth) return truth def evaluateTruth(header, callVCF, truth, truthVCF): @@ -561,6 +590,8 @@ def main(): if len(args) > 1: truthVCF = args[1] TRUTH_CALLS = readTruth(truthVCF) + #for v in TRUTH_CALLS.itervalues(): print v.getField("NA12878") + #sys.exit(1) if OPTIONS.recalLog <> None: RECAL_LOG = open(OPTIONS.recalLog, "w") diff --git a/python/vcf2table.py b/python/vcf2table.py index 955254ee0..1f2197f14 100755 --- a/python/vcf2table.py +++ b/python/vcf2table.py @@ -23,7 +23,8 @@ if __name__ == "__main__": counter = OPTIONS.skip fields = OPTIONS.fields.split(',') - for vcf,count in lines2VCF(sys.stdin): + print sys.stdin + for header, vcf, count in lines2VCF(sys.stdin, extendedOutput = True): #print vcf, count if count == 1 and vcf.hasHeader(): print '\t'.join(fields) diff --git a/python/vcfReader.py b/python/vcfReader.py index bb1857235..1d58963aa 100755 --- a/python/vcfReader.py +++ b/python/vcfReader.py @@ -30,11 +30,12 @@ def convertToType(chr, pos, d, onlyKeys = None): class VCFRecord: """Simple support for accessing a VCF record""" - def __init__(self, basicBindings, header=None, rest=[], decodeAll = True): + def __init__(self, basicBindings, header=None, rest=[], moreFields = dict(), decodeAll = True): self.header = header self.info = parseInfo(basicBindings["INFO"]) chr, pos = basicBindings['CHROM'], basicBindings['POS'] self.bindings = convertToType(chr, pos, basicBindings, onlyKeys = ['POS', 'QUAL']) + self.bindings.update(moreFields) if decodeAll: self.info = convertToType(chr, pos, self.info) self.rest = rest @@ -144,7 +145,12 @@ def string2VCF(line, header=None, decodeAll = True): if line[0] != "#": s = line.split() bindings = dict(zip(VCF_KEYS, s[0:8])) - return VCFRecord(bindings, header, rest=s[8:], decodeAll = decodeAll) + moreFields = dict() + #print 'HELLO', header, s, decodeAll + if header <> None and decodeAll: + moreFields = dict(zip(header[8:], s[8:])) + #print header, moreFields + return VCFRecord(bindings, header, rest=s[8:], moreFields = moreFields, decodeAll = decodeAll) else: return None @@ -156,7 +162,7 @@ def readVCFHeader(lines): header.append(line.strip()) else: if header <> []: - columnNames = header[-1] + columnNames = header[-1].strip("#").split() return header, columnNames, itertools.chain([line], lines) # we reach this point for empty files @@ -170,10 +176,11 @@ def quickCountRecords(lines): return counter -def lines2VCF(lines, extendedOutput = False, decodeAll = True): - header, columnNames, lines = readVCFHeader(lines) +def lines2VCF(lines, extendedOutput = False, decodeAll = True, header=None, columnNames = None): + if header == None: + header, columnNames, lines = readVCFHeader(lines) counter = 0 - + for line in lines: if line[0] != "#": counter += 1