snpSelector v2 -- code refactoring and support for comparison with known truth. Looks great.
git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@1986 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
parent
84ba604611
commit
f777c806d6
|
|
@ -192,12 +192,15 @@ def fieldRange(variants, field):
|
||||||
return minValue, maxValue, rangeValue, bins
|
return minValue, maxValue, rangeValue, bins
|
||||||
|
|
||||||
def printFieldQual( left, right, variants, titv, FPRate, nErrors ):
|
def printFieldQual( left, right, variants, titv, FPRate, nErrors ):
|
||||||
|
print ' %s nVariants=%8d titv=%.2f FPRate=%.2e Q%d' % (binString(left, right), len(variants), titv, FPRate, phredScale(FPRate))
|
||||||
|
|
||||||
|
def binString(left, right):
|
||||||
leftStr = str(left)
|
leftStr = str(left)
|
||||||
if type(left) == float: leftStr = "%.2f" % left
|
if type(left) == float: leftStr = "%.2f" % left
|
||||||
rightStr = "%5s" % str(right)
|
rightStr = "%5s" % str(right)
|
||||||
if type(right) == float: rightStr = "%.2f" % right
|
if type(right) == float: rightStr = "%.2f" % right
|
||||||
#print 'FPRATe', FPRate, phredScale(FPRate)
|
return '%8s - %8s' % (leftStr, rightStr)
|
||||||
print ' %8s - %8s nVariants=%8d titv=%.2f FPRate=%.2e Q%d' % (leftStr, rightStr, len(variants), titv, FPRate, phredScale(FPRate))
|
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
#
|
#
|
||||||
|
|
@ -226,25 +229,35 @@ def all( p, l ):
|
||||||
if not p(elt): return False
|
if not p(elt): return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def calibrateFeatures(variants, fields, titvTarget, updateCalls = True, printCall = False):
|
def variantBinsForField(variants, field):
|
||||||
if updateCalls: recalCalls = dict([[variant, RecalibratedCall(variant, fields)] for variant in variants])
|
|
||||||
|
|
||||||
for field in fields:
|
|
||||||
print 'Optimizing field', field
|
|
||||||
|
|
||||||
if not all( lambda x: x.hasField(field), variants):
|
if not all( lambda x: x.hasField(field), variants):
|
||||||
raise Exception('Unknown field ' + field)
|
raise Exception('Unknown field ' + field)
|
||||||
|
|
||||||
minValue, maxValue, range, bins = fieldRange(variants, field)
|
minValue, maxValue, range, bins = fieldRange(variants, field)
|
||||||
print 'Field range', minValue, maxValue, range
|
print 'Field range', minValue, maxValue, range
|
||||||
print 'Partitions', bins
|
print 'Partitions', bins
|
||||||
|
return bins
|
||||||
|
|
||||||
|
def mapVariantBins(variants, field):
|
||||||
|
bins = variantBinsForField(variants, field)
|
||||||
|
|
||||||
|
def variantsInBin(bin):
|
||||||
|
left, right = bin[0:2]
|
||||||
|
def select( variant ): return variant.getField(field) >= left and (right == '*' or variant.getField(field) <= right)
|
||||||
|
return left, right, selectVariants(variants, select)
|
||||||
|
|
||||||
|
return imap( variantsInBin, bins )
|
||||||
|
|
||||||
|
def calibrateFeatures(variants, fields, titvTarget, updateCalls = True, printCall = False):
|
||||||
|
if updateCalls: recalCalls = dict([[variant, RecalibratedCall(variant, fields)] for variant in variants])
|
||||||
|
|
||||||
|
for field in fields:
|
||||||
|
print 'Optimizing field', field
|
||||||
|
|
||||||
titv, FPRate, nErrors = titvFPRateEstimate(variants, titvTarget)
|
titv, FPRate, nErrors = titvFPRateEstimate(variants, titvTarget)
|
||||||
print 'Overall FRRate:', FPRate, nErrors, phredScale(FPRate)
|
print 'Overall FRRate:', FPRate, nErrors, phredScale(FPRate)
|
||||||
|
|
||||||
for left, right in map(lambda x: [x[0], x[1]], bins):
|
for left, right, selectedVariants in mapVariantBins(variants, field):
|
||||||
#print 'LR:', left, right
|
|
||||||
def select( variant ): return variant.getField(field) >= left and (right == '*' or variant.getField(field) <= right)
|
|
||||||
selectedVariants = selectVariants(variants, select)
|
|
||||||
if len(selectedVariants) > max(OPTIONS.minVariantsPerBin,1):
|
if len(selectedVariants) > max(OPTIONS.minVariantsPerBin,1):
|
||||||
titv, FPRate, nErrors = titvFPRateEstimate(selectedVariants, titvTarget)
|
titv, FPRate, nErrors = titvFPRateEstimate(selectedVariants, titvTarget)
|
||||||
if updateCalls:
|
if updateCalls:
|
||||||
|
|
@ -261,6 +274,40 @@ def calibrateFeatures(variants, fields, titvTarget, updateCalls = True, printCal
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
class CallCmp:
|
||||||
|
def __init__(self, nTP, nFP, nFN):
|
||||||
|
self.nTP = nTP
|
||||||
|
self.nFP = nFP
|
||||||
|
self.nFN = nFN
|
||||||
|
|
||||||
|
def FPRate(self):
|
||||||
|
return (1.0*self.nFP) / max(self.nTP + self.nFP, 1)
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return 'TP=%6d FP=%6d FPRate=%.2f FN=%6d' % (self.nTP, self.nFP, self.FPRate(), self.nFN)
|
||||||
|
|
||||||
|
def variantInTruth(variant, truth):
|
||||||
|
return variant.getLoc() in truth
|
||||||
|
|
||||||
|
def sensitivitySpecificity(variants, truth):
|
||||||
|
nTP, nFP = 0, 0
|
||||||
|
for variant in variants:
|
||||||
|
if variantInTruth(variant, truth):
|
||||||
|
nTP += 1
|
||||||
|
else:
|
||||||
|
if OPTIONS.printFP: print 'FP:', variant
|
||||||
|
nFP += 1
|
||||||
|
nFN = len(truth) - nTP
|
||||||
|
return CallCmp(nTP, nFP, nFN)
|
||||||
|
|
||||||
|
|
||||||
|
def compareCalls(optimizedCalls, truthCalls):
|
||||||
|
for left, right, selectedVariants in mapVariantBins(optimizedCalls, 'QUAL'):
|
||||||
|
callComparison = sensitivitySpecificity(selectedVariants, truthCalls)
|
||||||
|
print binString(left, right), 'titv=%.2f' % titv(selectedVariants)[0], callComparison
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
global OPTIONS
|
global OPTIONS
|
||||||
usage = "usage: %prog files.list [options]"
|
usage = "usage: %prog files.list [options]"
|
||||||
|
|
@ -272,7 +319,7 @@ def main():
|
||||||
type='string', default=None,
|
type='string', default=None,
|
||||||
help="VCF formated truth file")
|
help="VCF formated truth file")
|
||||||
parser.add_option("-p", "--partitions", dest="partitions",
|
parser.add_option("-p", "--partitions", dest="partitions",
|
||||||
type='int', default=10,
|
type='int', default=25,
|
||||||
help="Number of partitions to examine")
|
help="Number of partitions to examine")
|
||||||
parser.add_option("-s", "--s", dest="skip",
|
parser.add_option("-s", "--s", dest="skip",
|
||||||
type='int', default=1,
|
type='int', default=1,
|
||||||
|
|
@ -286,9 +333,12 @@ def main():
|
||||||
parser.add_option("", "--titv", dest="titvTarget",
|
parser.add_option("", "--titv", dest="titvTarget",
|
||||||
type='float', default=None,
|
type='float', default=None,
|
||||||
help="If provided, we will optimize calls to the targeted ti/tv rather than that calculated from known calls")
|
help="If provided, we will optimize calls to the targeted ti/tv rather than that calculated from known calls")
|
||||||
|
parser.add_option("", "--fp", dest="printFP",
|
||||||
|
action='store_true', default=False,
|
||||||
|
help="")
|
||||||
|
|
||||||
(OPTIONS, args) = parser.parse_args()
|
(OPTIONS, args) = parser.parse_args()
|
||||||
if len(args) != 2:
|
if len(args) > 2:
|
||||||
parser.error("incorrect number of arguments")
|
parser.error("incorrect number of arguments")
|
||||||
|
|
||||||
fields = OPTIONS.fields.split(',')
|
fields = OPTIONS.fields.split(',')
|
||||||
|
|
@ -301,7 +351,13 @@ def main():
|
||||||
print 'Ti/Tv known', titv(selectVariants(calls, VCFRecord.isKnown))
|
print 'Ti/Tv known', titv(selectVariants(calls, VCFRecord.isKnown))
|
||||||
print 'Ti/Tv novel', titv(selectVariants(calls, VCFRecord.isNovel))
|
print 'Ti/Tv novel', titv(selectVariants(calls, VCFRecord.isNovel))
|
||||||
|
|
||||||
optimizeCalls(calls, OPTIONS.fields.split(","), OPTIONS.titvTarget)
|
optimizedCalls = optimizeCalls(calls, OPTIONS.fields.split(","), OPTIONS.titvTarget)
|
||||||
|
|
||||||
|
if len(args) > 1:
|
||||||
|
truthFile = args[1]
|
||||||
|
print 'Reading truth file', truthFile
|
||||||
|
truth = dict( [[v.getLoc(), v] for v in readVariants(truthFile)])
|
||||||
|
compareCalls(optimizedCalls, truth)
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
@ -31,6 +31,7 @@ class VCFRecord:
|
||||||
|
|
||||||
def getChrom(self): return self.get("CHROM")
|
def getChrom(self): return self.get("CHROM")
|
||||||
def getPos(self): return self.get("POS")
|
def getPos(self): return self.get("POS")
|
||||||
|
def getLoc(self): return str(self.getChrom()) + ':' + str(self.getPos())
|
||||||
|
|
||||||
def getID(self): return self.get("ID")
|
def getID(self): return self.get("ID")
|
||||||
def isNovel(self): return self.getID() == "."
|
def isNovel(self): return self.getID() == "."
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue