snpSelector v2 -- code refactoring and support for comparison with known truth. Looks great.

git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@1986 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
depristo 2009-11-07 19:32:12 +00:00
parent 84ba604611
commit f777c806d6
2 changed files with 74 additions and 17 deletions

View File

@ -192,12 +192,15 @@ def fieldRange(variants, field):
return minValue, maxValue, rangeValue, bins
def printFieldQual( left, right, variants, titv, FPRate, nErrors ):
print ' %s nVariants=%8d titv=%.2f FPRate=%.2e Q%d' % (binString(left, right), len(variants), titv, FPRate, phredScale(FPRate))
def binString(left, right):
leftStr = str(left)
if type(left) == float: leftStr = "%.2f" % left
rightStr = "%5s" % str(right)
if type(right) == float: rightStr = "%.2f" % right
#print 'FPRATe', FPRate, phredScale(FPRate)
print ' %8s - %8s nVariants=%8d titv=%.2f FPRate=%.2e Q%d' % (leftStr, rightStr, len(variants), titv, FPRate, phredScale(FPRate))
return '%8s - %8s' % (leftStr, rightStr)
#
#
#
@ -226,25 +229,35 @@ def all( p, l ):
if not p(elt): return False
return True
def variantBinsForField(variants, field):
if not all( lambda x: x.hasField(field), variants):
raise Exception('Unknown field ' + field)
minValue, maxValue, range, bins = fieldRange(variants, field)
print 'Field range', minValue, maxValue, range
print 'Partitions', bins
return bins
def mapVariantBins(variants, field):
bins = variantBinsForField(variants, field)
def variantsInBin(bin):
left, right = bin[0:2]
def select( variant ): return variant.getField(field) >= left and (right == '*' or variant.getField(field) <= right)
return left, right, selectVariants(variants, select)
return imap( variantsInBin, bins )
def calibrateFeatures(variants, fields, titvTarget, updateCalls = True, printCall = False):
if updateCalls: recalCalls = dict([[variant, RecalibratedCall(variant, fields)] for variant in variants])
for field in fields:
print 'Optimizing field', field
if not all( lambda x: x.hasField(field), variants):
raise Exception('Unknown field ' + field)
minValue, maxValue, range, bins = fieldRange(variants, field)
print 'Field range', minValue, maxValue, range
print 'Partitions', bins
titv, FPRate, nErrors = titvFPRateEstimate(variants, titvTarget)
print 'Overall FRRate:', FPRate, nErrors, phredScale(FPRate)
for left, right in map(lambda x: [x[0], x[1]], bins):
#print 'LR:', left, right
def select( variant ): return variant.getField(field) >= left and (right == '*' or variant.getField(field) <= right)
selectedVariants = selectVariants(variants, select)
for left, right, selectedVariants in mapVariantBins(variants, field):
if len(selectedVariants) > max(OPTIONS.minVariantsPerBin,1):
titv, FPRate, nErrors = titvFPRateEstimate(selectedVariants, titvTarget)
if updateCalls:
@ -261,6 +274,40 @@ def calibrateFeatures(variants, fields, titvTarget, updateCalls = True, printCal
else:
return None
class CallCmp:
def __init__(self, nTP, nFP, nFN):
self.nTP = nTP
self.nFP = nFP
self.nFN = nFN
def FPRate(self):
return (1.0*self.nFP) / max(self.nTP + self.nFP, 1)
def __str__(self):
return 'TP=%6d FP=%6d FPRate=%.2f FN=%6d' % (self.nTP, self.nFP, self.FPRate(), self.nFN)
def variantInTruth(variant, truth):
return variant.getLoc() in truth
def sensitivitySpecificity(variants, truth):
nTP, nFP = 0, 0
for variant in variants:
if variantInTruth(variant, truth):
nTP += 1
else:
if OPTIONS.printFP: print 'FP:', variant
nFP += 1
nFN = len(truth) - nTP
return CallCmp(nTP, nFP, nFN)
def compareCalls(optimizedCalls, truthCalls):
for left, right, selectedVariants in mapVariantBins(optimizedCalls, 'QUAL'):
callComparison = sensitivitySpecificity(selectedVariants, truthCalls)
print binString(left, right), 'titv=%.2f' % titv(selectedVariants)[0], callComparison
def main():
global OPTIONS
usage = "usage: %prog files.list [options]"
@ -272,7 +319,7 @@ def main():
type='string', default=None,
help="VCF formated truth file")
parser.add_option("-p", "--partitions", dest="partitions",
type='int', default=10,
type='int', default=25,
help="Number of partitions to examine")
parser.add_option("-s", "--s", dest="skip",
type='int', default=1,
@ -286,9 +333,12 @@ def main():
parser.add_option("", "--titv", dest="titvTarget",
type='float', default=None,
help="If provided, we will optimize calls to the targeted ti/tv rather than that calculated from known calls")
parser.add_option("", "--fp", dest="printFP",
action='store_true', default=False,
help="")
(OPTIONS, args) = parser.parse_args()
if len(args) != 2:
if len(args) > 2:
parser.error("incorrect number of arguments")
fields = OPTIONS.fields.split(',')
@ -301,7 +351,13 @@ def main():
print 'Ti/Tv known', titv(selectVariants(calls, VCFRecord.isKnown))
print 'Ti/Tv novel', titv(selectVariants(calls, VCFRecord.isNovel))
optimizeCalls(calls, OPTIONS.fields.split(","), OPTIONS.titvTarget)
optimizedCalls = optimizeCalls(calls, OPTIONS.fields.split(","), OPTIONS.titvTarget)
if len(args) > 1:
truthFile = args[1]
print 'Reading truth file', truthFile
truth = dict( [[v.getLoc(), v] for v in readVariants(truthFile)])
compareCalls(optimizedCalls, truth)
if __name__ == "__main__":
main()

View File

@ -31,6 +31,7 @@ class VCFRecord:
def getChrom(self): return self.get("CHROM")
def getPos(self): return self.get("POS")
def getLoc(self): return str(self.getChrom()) + ':' + str(self.getPos())
def getID(self): return self.get("ID")
def isNovel(self): return self.getID() == "."