From 52494d8176e484f04fdb48295d7c631134488aef Mon Sep 17 00:00:00 2001 From: depristo Date: Fri, 13 Nov 2009 21:46:31 +0000 Subject: [PATCH] cleanup of SNP selector -- ready for some additional testing git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@2042 348d0f76-0448-11de-a6fe-93d51630548a --- python/snpSelector.py | 201 +++++++++++++++++++++++------------------- python/vcfReader.py | 10 ++- 2 files changed, 121 insertions(+), 90 deletions(-) diff --git a/python/snpSelector.py b/python/snpSelector.py index ef061748a..e3c3ecdde 100755 --- a/python/snpSelector.py +++ b/python/snpSelector.py @@ -63,19 +63,18 @@ class RecalibratedCall: def __str__(self): return '[%s: %s => Q%d]' % (str(self.call), self.featureStringList(), phredScale(self.jointFPErrorRate())) -def readVariants( file, maxRecords = None, decodeAll = True ): - counter = OPTIONS.skip +def readVariants( file, maxRecords = None, decodeAll = True, downsampleFraction = 1 ): f = open(file) header, ignore, lines = readVCFHeader(f) def parseVariant(args): header1, VCF, counter = args - if counter % OPTIONS.skip == 0: + if random.random() <= downsampleFraction: return VCF else: return None - return header, filter(None, map(parseVariant, islice(lines2VCF(lines, extendedOutput = True, decodeAll = decodeAll), maxRecords))) + return header, ifilter(None, imap(parseVariant, islice(lines2VCF(lines, extendedOutput = True, decodeAll = decodeAll), maxRecords))) def selectVariants( variants, selector = None ): if selector <> None: @@ -88,8 +87,11 @@ def titv(variants): tv = len(variants) - ti titv = ti / (1.0*max(tv,1)) - return titv, ti, tv + return titv +def dbSNPRate(variants): + inDBSNP = len(filter(VCFRecord.isKnown, variants)) + return float(inDBSNP) / len(variants) def gaussian(x, mu, sigma): constant = 1 / math.sqrt(2 * math.pi * sigma**2) @@ -102,7 +104,7 @@ DEBUG = False # there are in N calls with ti/tv of X. # def titvFPRateEstimate(variants, target): - titvRatio, ti, tv = titv(variants) + titvRatio = titv(variants) # f <- function(To,T) { (To - T) / (1/2 - T) + 0.001 } def theoreticalCalc(): @@ -112,7 +114,7 @@ def titvFPRateEstimate(variants, target): FPRate = (titvRatio - target) / (0.5 - target) FPRate = min(max(FPRate, 0), 1) TPRate = max(min(1 - FPRate, 1 - dephredScale(OPTIONS.maxQScore)), dephredScale(OPTIONS.maxQScore)) - print 'FPRate', FPRate, titvRatio, target + if DEBUG: print 'FPRate', FPRate, titvRatio, target assert FPRate >= 0 and FPRate <= 1 return TPRate @@ -150,7 +152,7 @@ def phredScale(errorRate): return -10 * math.log10(max(errorRate, 1e-10)) def dephredScale(qscore): - return math.pow(10, qscore / -10) + return math.pow(10, float(qscore) / -10) def frange6(*args): """A float range generator.""" @@ -181,7 +183,6 @@ def calculateBins(variants, field, minValue, maxValue, partitions): sortedValues = map(lambda x: x.getField(field), sortedVariants) targetBinSize = len(variants) / (1.0*partitions) - print 'targetBinSize', targetBinSize uniqBins = groupby(sortedValues) binsAndSizes = map(lambda x: [x[0], len(list(x[1]))], uniqBins) #print binsAndSizes @@ -215,15 +216,18 @@ def fieldRange(variants, field): bins = calculateBins(variants, field, minValue, maxValue, OPTIONS.partitions) return minValue, maxValue, bins -def printFieldQual( left, right, variants, titv, FPRate, nErrors ): - print ' %s nVariants=%8d titv=%.2f FPRate=%.2e Q%d' % (binString(left, right), len(variants), titv, FPRate, phredScale(FPRate)) +def printFieldQualHeader(more = ""): + print ' field left right nvariants titv dbSNP fprate q', more + +def printFieldQual( field, left, right, variants, FPRate, more = ""): + print ' %s %s %8d %.2f %.2f %.2e %d' % (field, binString(left, right), len(variants), titv(variants), dbSNPRate(variants), FPRate, phredScale(FPRate)), more def binString(left, right): leftStr = str(left) if type(left) == float: leftStr = "%.2f" % left rightStr = "%5s" % str(right) if type(right) == float: rightStr = "%.2f" % right - return '%8s - %8s' % (leftStr, rightStr) + return '%8s %8s' % (leftStr, rightStr) # @@ -232,7 +236,6 @@ def binString(left, right): def recalibrateCalls(variants, fields, callCovariates): def phred(v): return int(round(phredScale(v))) - newCalls = list() for variant in variants: recalCall = RecalibratedCall(variant, fields) originalQual = variant.getField('QUAL') @@ -245,9 +248,8 @@ def recalibrateCalls(variants, fields, callCovariates): recalCall.call.setField('QUAL', phred(recalCall.jointFPErrorRate())) recalCall.call.setField('OQ', originalQual) - newCalls.append(recalCall.call) - - return newCalls + #print 'recalibrating', variant.getLoc() + yield recalCall.call # # @@ -258,9 +260,6 @@ def optimizeCalls(variants, fields, titvTarget): return recalCalls, callCovariates def printCallQuals(recalCalls, titvTarget, info = ""): - #for recalCall in islice(recalCalls, 10): - # print recalCall - print '--------------------------------------------------------------------------------' print info calibrateFeatures(recalCalls, ['QUAL'], titvTarget, printCall = True, cumulative = False ) @@ -279,8 +278,8 @@ def variantBinsForField(variants, field): # raise Exception('Unknown field ' + field) minValue, maxValue, bins = fieldRange(variants, field) - print 'Field range', minValue, maxValue - print 'Partitions', bins + if DEBUG: print 'Field range', minValue, maxValue + if DEBUG: print 'Partitions', bins return bins def mapVariantBins(variants, field, cumulative = False): @@ -296,18 +295,19 @@ def mapVariantBins(variants, field, cumulative = False): def calibrateFeatures(variants, fields, titvTarget, printCall = False, cumulative = False ): covariates = [] + printFieldQualHeader() for field in fields: - print 'Optimizing field', field + if DEBUG: print 'Optimizing field', field titv, FPRate, nErrors = titvFPRateEstimate(variants, titvTarget) - print 'Overall FRRate:', FPRate, nErrors, phredScale(FPRate) + #print 'Overall FRRate:', FPRate, nErrors, phredScale(FPRate) for left, right, selectedVariants in mapVariantBins(variants, field, cumulative = cumulative): if len(selectedVariants) > max(OPTIONS.minVariantsPerBin,1): titv, FPRate, nErrors = titvFPRateEstimate(selectedVariants, titvTarget) + dbsnp = dbSNPRate(selectedVariants) covariates.append(CallCovariate(field, left, right, FPRate)) - printFieldQual( left, right, selectedVariants, titv, FPRate, nErrors ) - + printFieldQual(field, left, right, selectedVariants, FPRate ) return covariates @@ -324,7 +324,7 @@ class CallCmp: return (1.0*self.nFN) / max(self.nTP + self.nFN, 1) def __str__(self): - return 'TP=%6d FP=%6d FPRate=%.2f FN=%6d FNRate=%.2f' % (self.nTP, self.nFP, self.FPRate(), self.nFN, self.FNRate()) + return '%6d %6d %.2f %6d %.2f' % (self.nTP, self.nFP, self.FPRate(), self.nFN, self.FNRate()) def variantInTruth(variant, truth): if variant.getLoc() in truth: @@ -338,10 +338,10 @@ def sensitivitySpecificity(variants, truth): for variant in variants: t = variantInTruth(variant, truth) if t: - t.setField("FOUND", 1) + t.setField("FN", 0) + variant.setField("TP", 1) nTP += 1 else: - if OPTIONS.printFP: print 'FP:', variant nFP += 1 #if variant.getPos() == 1520727: # print "Variant is missing", variant @@ -351,18 +351,20 @@ def sensitivitySpecificity(variants, truth): def compareCalls(calls, truthCalls): - def compare1(cumulative): + for variant in calls: variant.setField("TP", 0) # set the TP field to 0 + + def compare1(name, cumulative): for left, right, selectedVariants in mapVariantBins(calls, 'QUAL', cumulative = cumulative): callComparison, theseFPs = sensitivitySpecificity(selectedVariants, truthCalls) - for fp in theseFPs: fp.setField("FP", 1) - #FPsVariants.append(theseFPs) - print binString(left, right), 'titv=%.2f' % titv(selectedVariants)[0], callComparison + #print selectedVariants[0] + printFieldQual(name, left, right, selectedVariants, dephredScale(left), str(callComparison)) print 'PER BIN nCalls=', len(calls) - compare1(False) + printFieldQualHeader("TP FP FPRate FN FNRate") + compare1('TRUTH-PER-BIN', False) print 'CUMULATIVE nCalls=', len(calls) - compare1(True) + compare1('TRUTH-CUMULATIVE', True) def randomSplit(l, pLeft): def keep(elt, p): @@ -374,113 +376,134 @@ def randomSplit(l, pLeft): def get(i): return filter(lambda x: x <> None, [x[i] for x in data]) return get(0), get(1) -def main(): +def setup(): global OPTIONS, header usage = "usage: %prog files.list [options]" parser = OptionParser(usage=usage) parser.add_option("-f", "--f", dest="fields", type='string', default="QUAL", - help="Comma-separated list of fields to exact") + help="Comma-separated list of fields (either in the VCF columns of as INFO keys) to use during optimization [default: %default]") parser.add_option("-t", "--truth", dest="truth", type='string', default=None, - help="VCF formated truth file") + help="VCF formated truth file. If provided, the script will compare the input calls with the truth calls. It also emits calls tagged as TP and a separate file of FP calls") parser.add_option("", "--unFilteredTruth", dest="unFilteredTruth", action='store_true', default=False, - help="If provided, the unfiltered truth calls will be used in comparisons") + help="If provided, the unfiltered truth calls will be used in comparisons [default: %default]") parser.add_option("-p", "--partitions", dest="partitions", type='int', default=25, - help="Number of partitions to examine") - parser.add_option("-s", "--s", dest="skip", - type='int', default=1, - help="Only work with every 1 / skip records") + help="Number of partitions to use for each feature. Don't use so many that the number of variants per bin is very low. [default: %default]") + parser.add_option("", "--maxRecordsForCovariates", dest="maxRecordsForCovariates", + type='int', default=200000, + help="Derive covariate information from up to this many VCF records. For files with more than this number of records, the system downsamples the reads [default: %default]") parser.add_option("-m", "--minVariantsPerBin", dest="minVariantsPerBin", type='int', default=10, - help="") + help="Don't include any covariates with fewer than this number of variants in the bin, if such a thing happens. NEEDS TO BE FIXED") parser.add_option("-M", "--maxRecords", dest="maxRecords", type='int', default=None, - help="") + help="Maximum number of input VCF records to process, if provided. Default is all records") parser.add_option("-q", "--qMax", dest="maxQScore", type='int', default=30, - help="") + help="The maximum Q score allowed for both a single covariate and the overall QUAL score [default: %default]") parser.add_option("-o", "--outputVCF", dest="outputVCF", type='string', default=None, - help="If provided, VCF file will be written out to this file") + help="If provided, a VCF file will be written out to this file [default: %default]") parser.add_option("", "--FNoutputVCF", dest="FNoutputVCF", type='string', default=None, - help="If provided, VCF file will be written out to this file") + help="If provided, VCF file will be written out to this file [default: %default]") parser.add_option("", "--titv", dest="titvTarget", type='float', default=None, - help="If provided, we will optimize calls to the targeted ti/tv rather than that calculated from known calls") - parser.add_option("", "--fp", dest="printFP", - action='store_true', default=False, - help="") + help="If provided, we will optimize calls to the targeted ti/tv rather than that calculated from known calls [default: %default]") parser.add_option("-b", "--bootstrap", dest="bootStrap", - type='float', default=0.0, - help="If provided, the % of the calls used to generate the recalibration tables.") - + type='float', default=None, + help="If provided, the % of the calls used to generate the recalibration tables. [default: %default]") (OPTIONS, args) = parser.parse_args() if len(args) > 2: parser.error("incorrect number of arguments") + return args - fields = OPTIONS.fields.split(',') - header, allCalls = readVariants(args[0], OPTIONS.maxRecords) - print 'Read', len(allCalls), 'calls' - #print 'header is', header +def determineCovariates(file, fields): + print 'Counting records in VCF', file + numberOfRecords = quickCountRecords(open(file)) + if OPTIONS.maxRecords <> None and OPTIONS.maxRecords < numberOfRecords: + numberOfRecords = OPTIONS.maxRecords + downsampleFraction = min(float(OPTIONS.maxRecordsForCovariates) / numberOfRecords, 1) + header, allCalls = readVariants(file, OPTIONS.maxRecords, downsampleFraction=downsampleFraction) + allCalls = list(allCalls) + print 'Number of VCF records', numberOfRecords, ', max number of records for covariates is', OPTIONS.maxRecordsForCovariates, 'so keeping', downsampleFraction * 100, '% of records' + print 'Number of selected VCF records', len(allCalls) if OPTIONS.titvTarget == None: - OPTIONS.titvTarget = titv(calls, VCFRecord.isKnown) + OPTIONS.titvTarget = titv(selectVariants(allCalls, VCFRecord.isKnown)) print 'Ti/Tv all ', titv(allCalls) print 'Ti/Tv known', titv(selectVariants(allCalls, VCFRecord.isKnown)) print 'Ti/Tv novel', titv(selectVariants(allCalls, VCFRecord.isNovel)) if OPTIONS.bootStrap: - callsToOptimize, callsToEval = randomSplit(allCalls, OPTIONS.bootStrap) + callsToOptimize, recalEvalCalls = randomSplit(allCalls, OPTIONS.bootStrap) else: - callsToOptimize, callsToEval = allCalls, allCalls + callsToOptimize = allCalls recalOptCalls, covariates = optimizeCalls(callsToOptimize, fields, OPTIONS.titvTarget) - printCallQuals(recalOptCalls, OPTIONS.titvTarget, 'OPTIMIZED CALLS') - - if callsToEval <> callsToOptimize: - recalEvalCalls = recalibrateCalls(callsToEval, fields, covariates) - printCallQuals(recalEvalCalls, OPTIONS.titvTarget, 'BOOTSTRAP EVAL CALLS') + printCallQuals(list(recalOptCalls), OPTIONS.titvTarget, 'OPTIMIZED CALLS') - truth = None - if len(args) > 1: - truthFile = args[1] - print 'Reading truth file', truthFile - rawTruth = readVariants(truthFile, maxRecords = None, decodeAll = False)[1] - def keepVariant(t): - #print t.getPos(), t.getLoc() - return OPTIONS.unFilteredTruth or t.passesFilters() - truth = dict( [[v.getLoc(), v] for v in filter(keepVariant, rawTruth)]) - print len(rawTruth), len(truth) - - print '--------------------------------------------------------------------------------' - print 'Comparing calls to truth', truthFile - print '' + if OPTIONS.bootStrap: + recalibatedEvalCalls = recalibrateCalls(recalEvalCalls, fields, covariates) + printCallQuals(list(recalibatedEvalCalls), OPTIONS.titvTarget, 'BOOTSTRAP EVAL CALLS') - print 'Calls used in optimization' - compareCalls(recalOptCalls, truth) - if callsToEval <> callsToOptimize: - print 'Calls held in reserve (bootstrap)' - compareCalls(recalEvalCalls, truth) + return covariates - if OPTIONS.outputVCF: - f = open(OPTIONS.outputVCF, 'w') +def writeRecalibratedCalls(file, header, calls): + if file: + f = open(file, 'w') #print 'HEADER', header - for line in formatVCF(header, allCalls): + i = 0 + for line in formatVCF(header, calls): + if i % 10000 == 0: print 'writing VCF record', i + i += 1 print >> f, line f.close() +def evaluateTruth(header, callVCF, truthVCF): + print 'Reading truth file', truthVCF + rawTruth = list(readVariants(truthVCF, maxRecords = None, decodeAll = False)[1]) + def keepVariant(t): + #print t.getPos(), t.getLoc() + return OPTIONS.unFilteredTruth or t.passesFilters() + truth = dict( [[v.getLoc(), v] for v in filter(keepVariant, rawTruth)]) + print len(rawTruth), len(truth) + + print 'Reading variants back in from', callVCF + header, calls = readVariants(callVCF) + calls = list(calls) + + print '--------------------------------------------------------------------------------' + print 'Comparing calls to truth', truthVCF + print '' + + compareCalls(calls, truth) + + writeRecalibratedCalls(callVCF, header, calls) + if truth <> None and OPTIONS.FNoutputVCF: f = open(OPTIONS.FNoutputVCF, 'w') #print 'HEADER', header - for line in formatVCF(header, filter( lambda x: not x.hasField("FOUND"), truth.itervalues())): + for line in formatVCF(header, filter( lambda x: not x.hasField("TP"), truth.itervalues())): print >> f, line f.close() +def main(): + args = setup() + fields = OPTIONS.fields.split(',') + + covariates = determineCovariates(args[0], fields) + header, callsToRecalibate = readVariants(args[0], OPTIONS.maxRecords) + RecalibratedCalls = recalibrateCalls(callsToRecalibate, fields, covariates) + writeRecalibratedCalls(OPTIONS.outputVCF, header, RecalibratedCalls) + + if len(args) > 1: + evaluateTruth(header, OPTIONS.outputVCF, args[1]) + PROFILE = False if __name__ == "__main__": if PROFILE: diff --git a/python/vcfReader.py b/python/vcfReader.py index 4f6869119..60271260b 100755 --- a/python/vcfReader.py +++ b/python/vcfReader.py @@ -155,6 +155,14 @@ def readVCFHeader(lines): # we reach this point for empty files return header, columnNames, [] +def quickCountRecords(lines): + counter = 0 + for line in lines: + if line[0] != "#": + counter += 1 + return counter + + def lines2VCF(lines, extendedOutput = False, decodeAll = True): header, columnNames, lines = readVCFHeader(lines) counter = 0 @@ -174,5 +182,5 @@ def lines2VCF(lines, extendedOutput = False, decodeAll = True): def formatVCF(header, records): #print records #print records[0] - return itertools.chain(header, map(VCFRecord.format, records)) + return itertools.chain(header, itertools.imap(VCFRecord.format, records))