snpSelector v2 -- code refactoring and support for comparison with known truth. Looks great.
git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@1986 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
parent
84ba604611
commit
f777c806d6
|
|
@ -192,12 +192,15 @@ def fieldRange(variants, field):
|
|||
return minValue, maxValue, rangeValue, bins
|
||||
|
||||
def printFieldQual( left, right, variants, titv, FPRate, nErrors ):
|
||||
print ' %s nVariants=%8d titv=%.2f FPRate=%.2e Q%d' % (binString(left, right), len(variants), titv, FPRate, phredScale(FPRate))
|
||||
|
||||
def binString(left, right):
|
||||
leftStr = str(left)
|
||||
if type(left) == float: leftStr = "%.2f" % left
|
||||
rightStr = "%5s" % str(right)
|
||||
if type(right) == float: rightStr = "%.2f" % right
|
||||
#print 'FPRATe', FPRate, phredScale(FPRate)
|
||||
print ' %8s - %8s nVariants=%8d titv=%.2f FPRate=%.2e Q%d' % (leftStr, rightStr, len(variants), titv, FPRate, phredScale(FPRate))
|
||||
return '%8s - %8s' % (leftStr, rightStr)
|
||||
|
||||
#
|
||||
#
|
||||
#
|
||||
|
|
@ -226,25 +229,35 @@ def all( p, l ):
|
|||
if not p(elt): return False
|
||||
return True
|
||||
|
||||
def variantBinsForField(variants, field):
|
||||
if not all( lambda x: x.hasField(field), variants):
|
||||
raise Exception('Unknown field ' + field)
|
||||
|
||||
minValue, maxValue, range, bins = fieldRange(variants, field)
|
||||
print 'Field range', minValue, maxValue, range
|
||||
print 'Partitions', bins
|
||||
return bins
|
||||
|
||||
def mapVariantBins(variants, field):
|
||||
bins = variantBinsForField(variants, field)
|
||||
|
||||
def variantsInBin(bin):
|
||||
left, right = bin[0:2]
|
||||
def select( variant ): return variant.getField(field) >= left and (right == '*' or variant.getField(field) <= right)
|
||||
return left, right, selectVariants(variants, select)
|
||||
|
||||
return imap( variantsInBin, bins )
|
||||
|
||||
def calibrateFeatures(variants, fields, titvTarget, updateCalls = True, printCall = False):
|
||||
if updateCalls: recalCalls = dict([[variant, RecalibratedCall(variant, fields)] for variant in variants])
|
||||
|
||||
for field in fields:
|
||||
print 'Optimizing field', field
|
||||
|
||||
if not all( lambda x: x.hasField(field), variants):
|
||||
raise Exception('Unknown field ' + field)
|
||||
|
||||
minValue, maxValue, range, bins = fieldRange(variants, field)
|
||||
print 'Field range', minValue, maxValue, range
|
||||
print 'Partitions', bins
|
||||
titv, FPRate, nErrors = titvFPRateEstimate(variants, titvTarget)
|
||||
print 'Overall FRRate:', FPRate, nErrors, phredScale(FPRate)
|
||||
|
||||
for left, right in map(lambda x: [x[0], x[1]], bins):
|
||||
#print 'LR:', left, right
|
||||
def select( variant ): return variant.getField(field) >= left and (right == '*' or variant.getField(field) <= right)
|
||||
selectedVariants = selectVariants(variants, select)
|
||||
|
||||
for left, right, selectedVariants in mapVariantBins(variants, field):
|
||||
if len(selectedVariants) > max(OPTIONS.minVariantsPerBin,1):
|
||||
titv, FPRate, nErrors = titvFPRateEstimate(selectedVariants, titvTarget)
|
||||
if updateCalls:
|
||||
|
|
@ -261,6 +274,40 @@ def calibrateFeatures(variants, fields, titvTarget, updateCalls = True, printCal
|
|||
else:
|
||||
return None
|
||||
|
||||
class CallCmp:
|
||||
def __init__(self, nTP, nFP, nFN):
|
||||
self.nTP = nTP
|
||||
self.nFP = nFP
|
||||
self.nFN = nFN
|
||||
|
||||
def FPRate(self):
|
||||
return (1.0*self.nFP) / max(self.nTP + self.nFP, 1)
|
||||
|
||||
def __str__(self):
|
||||
return 'TP=%6d FP=%6d FPRate=%.2f FN=%6d' % (self.nTP, self.nFP, self.FPRate(), self.nFN)
|
||||
|
||||
def variantInTruth(variant, truth):
|
||||
return variant.getLoc() in truth
|
||||
|
||||
def sensitivitySpecificity(variants, truth):
|
||||
nTP, nFP = 0, 0
|
||||
for variant in variants:
|
||||
if variantInTruth(variant, truth):
|
||||
nTP += 1
|
||||
else:
|
||||
if OPTIONS.printFP: print 'FP:', variant
|
||||
nFP += 1
|
||||
nFN = len(truth) - nTP
|
||||
return CallCmp(nTP, nFP, nFN)
|
||||
|
||||
|
||||
def compareCalls(optimizedCalls, truthCalls):
|
||||
for left, right, selectedVariants in mapVariantBins(optimizedCalls, 'QUAL'):
|
||||
callComparison = sensitivitySpecificity(selectedVariants, truthCalls)
|
||||
print binString(left, right), 'titv=%.2f' % titv(selectedVariants)[0], callComparison
|
||||
|
||||
|
||||
|
||||
def main():
|
||||
global OPTIONS
|
||||
usage = "usage: %prog files.list [options]"
|
||||
|
|
@ -272,7 +319,7 @@ def main():
|
|||
type='string', default=None,
|
||||
help="VCF formated truth file")
|
||||
parser.add_option("-p", "--partitions", dest="partitions",
|
||||
type='int', default=10,
|
||||
type='int', default=25,
|
||||
help="Number of partitions to examine")
|
||||
parser.add_option("-s", "--s", dest="skip",
|
||||
type='int', default=1,
|
||||
|
|
@ -286,9 +333,12 @@ def main():
|
|||
parser.add_option("", "--titv", dest="titvTarget",
|
||||
type='float', default=None,
|
||||
help="If provided, we will optimize calls to the targeted ti/tv rather than that calculated from known calls")
|
||||
|
||||
parser.add_option("", "--fp", dest="printFP",
|
||||
action='store_true', default=False,
|
||||
help="")
|
||||
|
||||
(OPTIONS, args) = parser.parse_args()
|
||||
if len(args) != 2:
|
||||
if len(args) > 2:
|
||||
parser.error("incorrect number of arguments")
|
||||
|
||||
fields = OPTIONS.fields.split(',')
|
||||
|
|
@ -301,7 +351,13 @@ def main():
|
|||
print 'Ti/Tv known', titv(selectVariants(calls, VCFRecord.isKnown))
|
||||
print 'Ti/Tv novel', titv(selectVariants(calls, VCFRecord.isNovel))
|
||||
|
||||
optimizeCalls(calls, OPTIONS.fields.split(","), OPTIONS.titvTarget)
|
||||
optimizedCalls = optimizeCalls(calls, OPTIONS.fields.split(","), OPTIONS.titvTarget)
|
||||
|
||||
if len(args) > 1:
|
||||
truthFile = args[1]
|
||||
print 'Reading truth file', truthFile
|
||||
truth = dict( [[v.getLoc(), v] for v in readVariants(truthFile)])
|
||||
compareCalls(optimizedCalls, truth)
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
@ -31,6 +31,7 @@ class VCFRecord:
|
|||
|
||||
def getChrom(self): return self.get("CHROM")
|
||||
def getPos(self): return self.get("POS")
|
||||
def getLoc(self): return str(self.getChrom()) + ':' + str(self.getPos())
|
||||
|
||||
def getID(self): return self.get("ID")
|
||||
def isNovel(self): return self.getID() == "."
|
||||
|
|
|
|||
Loading…
Reference in New Issue