diff --git a/python/snpSelector.py b/python/snpSelector.py index 0982e9e33..10d38b303 100755 --- a/python/snpSelector.py +++ b/python/snpSelector.py @@ -192,12 +192,15 @@ def fieldRange(variants, field): return minValue, maxValue, rangeValue, bins def printFieldQual( left, right, variants, titv, FPRate, nErrors ): + print ' %s nVariants=%8d titv=%.2f FPRate=%.2e Q%d' % (binString(left, right), len(variants), titv, FPRate, phredScale(FPRate)) + +def binString(left, right): leftStr = str(left) if type(left) == float: leftStr = "%.2f" % left rightStr = "%5s" % str(right) if type(right) == float: rightStr = "%.2f" % right - #print 'FPRATe', FPRate, phredScale(FPRate) - print ' %8s - %8s nVariants=%8d titv=%.2f FPRate=%.2e Q%d' % (leftStr, rightStr, len(variants), titv, FPRate, phredScale(FPRate)) + return '%8s - %8s' % (leftStr, rightStr) + # # # @@ -226,25 +229,35 @@ def all( p, l ): if not p(elt): return False return True +def variantBinsForField(variants, field): + if not all( lambda x: x.hasField(field), variants): + raise Exception('Unknown field ' + field) + + minValue, maxValue, range, bins = fieldRange(variants, field) + print 'Field range', minValue, maxValue, range + print 'Partitions', bins + return bins + +def mapVariantBins(variants, field): + bins = variantBinsForField(variants, field) + + def variantsInBin(bin): + left, right = bin[0:2] + def select( variant ): return variant.getField(field) >= left and (right == '*' or variant.getField(field) <= right) + return left, right, selectVariants(variants, select) + + return imap( variantsInBin, bins ) + def calibrateFeatures(variants, fields, titvTarget, updateCalls = True, printCall = False): if updateCalls: recalCalls = dict([[variant, RecalibratedCall(variant, fields)] for variant in variants]) for field in fields: print 'Optimizing field', field - if not all( lambda x: x.hasField(field), variants): - raise Exception('Unknown field ' + field) - - minValue, maxValue, range, bins = fieldRange(variants, field) - print 'Field range', minValue, maxValue, range - print 'Partitions', bins titv, FPRate, nErrors = titvFPRateEstimate(variants, titvTarget) print 'Overall FRRate:', FPRate, nErrors, phredScale(FPRate) - - for left, right in map(lambda x: [x[0], x[1]], bins): - #print 'LR:', left, right - def select( variant ): return variant.getField(field) >= left and (right == '*' or variant.getField(field) <= right) - selectedVariants = selectVariants(variants, select) + + for left, right, selectedVariants in mapVariantBins(variants, field): if len(selectedVariants) > max(OPTIONS.minVariantsPerBin,1): titv, FPRate, nErrors = titvFPRateEstimate(selectedVariants, titvTarget) if updateCalls: @@ -261,6 +274,40 @@ def calibrateFeatures(variants, fields, titvTarget, updateCalls = True, printCal else: return None +class CallCmp: + def __init__(self, nTP, nFP, nFN): + self.nTP = nTP + self.nFP = nFP + self.nFN = nFN + + def FPRate(self): + return (1.0*self.nFP) / max(self.nTP + self.nFP, 1) + + def __str__(self): + return 'TP=%6d FP=%6d FPRate=%.2f FN=%6d' % (self.nTP, self.nFP, self.FPRate(), self.nFN) + +def variantInTruth(variant, truth): + return variant.getLoc() in truth + +def sensitivitySpecificity(variants, truth): + nTP, nFP = 0, 0 + for variant in variants: + if variantInTruth(variant, truth): + nTP += 1 + else: + if OPTIONS.printFP: print 'FP:', variant + nFP += 1 + nFN = len(truth) - nTP + return CallCmp(nTP, nFP, nFN) + + +def compareCalls(optimizedCalls, truthCalls): + for left, right, selectedVariants in mapVariantBins(optimizedCalls, 'QUAL'): + callComparison = sensitivitySpecificity(selectedVariants, truthCalls) + print binString(left, right), 'titv=%.2f' % titv(selectedVariants)[0], callComparison + + + def main(): global OPTIONS usage = "usage: %prog files.list [options]" @@ -272,7 +319,7 @@ def main(): type='string', default=None, help="VCF formated truth file") parser.add_option("-p", "--partitions", dest="partitions", - type='int', default=10, + type='int', default=25, help="Number of partitions to examine") parser.add_option("-s", "--s", dest="skip", type='int', default=1, @@ -286,9 +333,12 @@ def main(): parser.add_option("", "--titv", dest="titvTarget", type='float', default=None, help="If provided, we will optimize calls to the targeted ti/tv rather than that calculated from known calls") - + parser.add_option("", "--fp", dest="printFP", + action='store_true', default=False, + help="") + (OPTIONS, args) = parser.parse_args() - if len(args) != 2: + if len(args) > 2: parser.error("incorrect number of arguments") fields = OPTIONS.fields.split(',') @@ -301,7 +351,13 @@ def main(): print 'Ti/Tv known', titv(selectVariants(calls, VCFRecord.isKnown)) print 'Ti/Tv novel', titv(selectVariants(calls, VCFRecord.isNovel)) - optimizeCalls(calls, OPTIONS.fields.split(","), OPTIONS.titvTarget) + optimizedCalls = optimizeCalls(calls, OPTIONS.fields.split(","), OPTIONS.titvTarget) + + if len(args) > 1: + truthFile = args[1] + print 'Reading truth file', truthFile + truth = dict( [[v.getLoc(), v] for v in readVariants(truthFile)]) + compareCalls(optimizedCalls, truth) if __name__ == "__main__": main() \ No newline at end of file diff --git a/python/vcfReader.py b/python/vcfReader.py index c8b8ec8e5..a806025a8 100755 --- a/python/vcfReader.py +++ b/python/vcfReader.py @@ -31,6 +31,7 @@ class VCFRecord: def getChrom(self): return self.get("CHROM") def getPos(self): return self.get("POS") + def getLoc(self): return str(self.getChrom()) + ':' + str(self.getPos()) def getID(self): return self.get("ID") def isNovel(self): return self.getID() == "."