diff --git a/python/snpSelector.py b/python/snpSelector.py index e3c3ecdde..9d498c1e9 100755 --- a/python/snpSelector.py +++ b/python/snpSelector.py @@ -7,6 +7,7 @@ from itertools import * import math import random +DEBUG = False class CallCovariate: def __init__(self, feature, left, right, FPRate = None, cumulative = False): @@ -28,6 +29,8 @@ class CallCovariate: def getFeature(self): return self.feature def getCovariateField(self): return self.getFeature() + '_RQ' + + def __str__(self): return "[CC feature=%s left=%s right=%s]" % (self.feature, self.left, self.right) class RecalibratedCall: def __init__(self, call, features): @@ -35,7 +38,7 @@ class RecalibratedCall: self.features = dict([[feature, None] for feature in features]) def recalFeature( self, feature, FPRate ): - assert self.features[feature] == None # not reassigning values + assert self.features[feature] == None, "Feature " + feature + ' has value ' + str(self.features[feature]) + ' for call ' + str(self.call) # not reassigning values assert FPRate <= 1 and FPRate >= 0 self.features[feature] = FPRate @@ -52,6 +55,7 @@ class RecalibratedCall: #print self.features logTPRates = [math.log10(1-r) for r in self.features.itervalues() if r <> None] logJointTPRate = reduce(lambda x, y: x + y, logTPRates, 0) + logJointTPRate = min(logJointTPRate, 1e-3 / 3) # approximation from het of 0.001 jointTPRate = math.pow(10, logJointTPRate) #print logTPRates #print logJointTPRate, jointTPRate @@ -98,8 +102,6 @@ def gaussian(x, mu, sigma): exponent = -1 * ( x - mu )**2 / (2 * sigma**2) return constant * math.exp(exponent) -DEBUG = False - # if target = T, and FP calls have ti/tv = 0.5, we want to know how many FP calls # there are in N calls with ti/tv of X. # @@ -138,15 +140,10 @@ def titvFPRateEstimate(variants, target): FPRate = 1 - gaussianModel() nVariants = len(variants) - if nVariants > 0: - impliedNoErrors = nVariants * FPRate - calcTiTv = (impliedNoErrors * 0.5 + target * (nVariants-impliedNoErrors)) / nVariants - else: - impliedNoErrors, calcTiTv = 0, 0 - if DEBUG: print ':::', nVariants, titvRatio, target, ti, tv, FPRate, impliedNoErrors, calcTiTv + if DEBUG: print ':::', nVariants, titvRatio, target, FPRate - return titvRatio, FPRate, impliedNoErrors + return titvRatio, FPRate def phredScale(errorRate): return -10 * math.log10(max(errorRate, 1e-10)) @@ -178,44 +175,66 @@ def frange6(*args): yield v v += step +def compareFieldValues( v1, v2 ): + if type(v1) <> type(v2): + #print 'Different types', type(v1), type(v2) + c = cmp(type(v1), type(v2)) + else: + c = cmp(v1, v2) + #print 'Comparing %s %s = %s' % (v1, v2, c) + return c + def calculateBins(variants, field, minValue, maxValue, partitions): - sortedVariants = sorted(variants, key = lambda x: x.getField(field)) + sortedVariants = sorted(variants, key = lambda x: x.getField(field)) # cmp = compareFieldValues, sortedValues = map(lambda x: x.getField(field), sortedVariants) targetBinSize = len(variants) / (1.0*partitions) + #print sortedValues uniqBins = groupby(sortedValues) binsAndSizes = map(lambda x: [x[0], len(list(x[1]))], uniqBins) - #print binsAndSizes + #print 'BINS AND SIZES', binsAndSizes def bin2Break(bin): return [bin[0], bin[0], bin[1]] bins = [bin2Break(binsAndSizes[0])] for bin in binsAndSizes[1:]: - #print 'Breaks', bins + #print ' Breaks', bins + #print ' current bin', bin curSize = bin[1] prevSize = bins[-1][2] + #print curSize, prevSize if curSize + prevSize > targetBinSize: + #print ' => appending', bin2Break(bin) bins.append(bin2Break(bin)) else: bins[-1][1] = bin[0] bins[-1][2] += curSize + #print 'Returning ', bins + #sys.exit(1) return bins -# -# def calculateBinsLinear(variants, minValue, maxValue, rangeValue, partitions): -# breaks = list(frange6(minValue, maxValue, rangeValue / partitions)) -# if breaks[len(breaks)-1] <> maxValue: -# breaks = breaks + ['*'] -# return zip(breaks, map( lambda x: x - 0.001, breaks[1:])) - def fieldRange(variants, field): values = map(lambda v: v.getField(field), variants) minValue = min(values) maxValue = max(values) #rangeValue = maxValue - minValue bins = calculateBins(variants, field, minValue, maxValue, OPTIONS.partitions) + validateBins(bins) return minValue, maxValue, bins +def validateBins(bins): + #print 'Bins are', bins + for left1, right1, count1 in bins: + for left2, right2, count2 in bins: + def contains2(x): + return left2 < x and x < right2 + + if left1 <> left2 and right1 <> right2: + if None in [left1, left2, right1, right2]: + pass # we're ok + elif contains2(left1) or contains2(right2): + raise Exception("Bad bins", left1, right1, left2, right2) + def printFieldQualHeader(more = ""): print ' field left right nvariants titv dbSNP fprate q', more @@ -245,10 +264,12 @@ def recalibrateCalls(variants, fields, callCovariates): FPR = callCovariate.getFPRate() recalCall.recalFeature(callCovariate.getFeature(), FPR) recalCall.call.setField(callCovariate.getCovariateField(), phred(FPR)) + recalCall.call.setField('QUAL', phred(recalCall.jointFPErrorRate())) recalCall.call.setField('OQ', originalQual) #print 'recalibrating', variant.getLoc() + #print ' =>', variant yield recalCall.call # @@ -299,15 +320,17 @@ def calibrateFeatures(variants, fields, titvTarget, printCall = False, cumulativ for field in fields: if DEBUG: print 'Optimizing field', field - titv, FPRate, nErrors = titvFPRateEstimate(variants, titvTarget) + titv, FPRate = titvFPRateEstimate(variants, titvTarget) #print 'Overall FRRate:', FPRate, nErrors, phredScale(FPRate) for left, right, selectedVariants in mapVariantBins(variants, field, cumulative = cumulative): if len(selectedVariants) > max(OPTIONS.minVariantsPerBin,1): - titv, FPRate, nErrors = titvFPRateEstimate(selectedVariants, titvTarget) + titv, FPRate = titvFPRateEstimate(selectedVariants, titvTarget) dbsnp = dbSNPRate(selectedVariants) covariates.append(CallCovariate(field, left, right, FPRate)) printFieldQual(field, left, right, selectedVariants, FPRate ) + else: + print 'Not calibrating bin', left, right, 'because it contains too few variants:', len(selectedVariants) return covariates @@ -416,13 +439,16 @@ def setup(): parser.add_option("-b", "--bootstrap", dest="bootStrap", type='float', default=None, help="If provided, the % of the calls used to generate the recalibration tables. [default: %default]") + parser.add_option("-r", "--dontRecalibrate", dest="dontRecalibrate", + action='store_true', default=False, + help="If provided, we will not actually do anything to the calls, they will just be assessed [default: %default]") (OPTIONS, args) = parser.parse_args() if len(args) > 2: parser.error("incorrect number of arguments") return args -def determineCovariates(file, fields): +def assessCalls(file): print 'Counting records in VCF', file numberOfRecords = quickCountRecords(open(file)) if OPTIONS.maxRecords <> None and OPTIONS.maxRecords < numberOfRecords: @@ -433,23 +459,27 @@ def determineCovariates(file, fields): print 'Number of VCF records', numberOfRecords, ', max number of records for covariates is', OPTIONS.maxRecordsForCovariates, 'so keeping', downsampleFraction * 100, '% of records' print 'Number of selected VCF records', len(allCalls) - if OPTIONS.titvTarget == None: - OPTIONS.titvTarget = titv(selectVariants(allCalls, VCFRecord.isKnown)) + titvtarget = OPTIONS.titvTarget + if titvtarget == None: + titvtarget = titv(selectVariants(allCalls, VCFRecord.isKnown)) print 'Ti/Tv all ', titv(allCalls) print 'Ti/Tv known', titv(selectVariants(allCalls, VCFRecord.isKnown)) print 'Ti/Tv novel', titv(selectVariants(allCalls, VCFRecord.isNovel)) + + return header, allCalls, titvtarget +def determineCovariates(allCalls, titvtarget, fields): if OPTIONS.bootStrap: callsToOptimize, recalEvalCalls = randomSplit(allCalls, OPTIONS.bootStrap) else: callsToOptimize = allCalls - recalOptCalls, covariates = optimizeCalls(callsToOptimize, fields, OPTIONS.titvTarget) - printCallQuals(list(recalOptCalls), OPTIONS.titvTarget, 'OPTIMIZED CALLS') + recalOptCalls, covariates = optimizeCalls(callsToOptimize, fields, titvtarget) + printCallQuals(list(recalOptCalls), titvtarget, 'OPTIMIZED CALLS') if OPTIONS.bootStrap: recalibatedEvalCalls = recalibrateCalls(recalEvalCalls, fields, covariates) - printCallQuals(list(recalibatedEvalCalls), OPTIONS.titvTarget, 'BOOTSTRAP EVAL CALLS') + printCallQuals(list(recalibatedEvalCalls), titvtarget, 'BOOTSTRAP EVAL CALLS') return covariates @@ -494,16 +524,22 @@ def evaluateTruth(header, callVCF, truthVCF): def main(): args = setup() - fields = OPTIONS.fields.split(',') - covariates = determineCovariates(args[0], fields) - header, callsToRecalibate = readVariants(args[0], OPTIONS.maxRecords) - RecalibratedCalls = recalibrateCalls(callsToRecalibate, fields, covariates) - writeRecalibratedCalls(OPTIONS.outputVCF, header, RecalibratedCalls) + header, allCalls, titvTarget = assessCalls(args[0]) + if not OPTIONS.dontRecalibrate: + fields = OPTIONS.fields.split(',') + covariates = determineCovariates(allCalls, titvTarget, fields) + header, callsToRecalibate = readVariants(args[0], OPTIONS.maxRecords) + RecalibratedCalls = recalibrateCalls(callsToRecalibate, fields, covariates) + writeRecalibratedCalls(OPTIONS.outputVCF, header, RecalibratedCalls) + else: + printCallQuals(allCalls, titvTarget) + OPTIONS.outputVCF = args[0] if len(args) > 1: evaluateTruth(header, OPTIONS.outputVCF, args[1]) + PROFILE = False if __name__ == "__main__": if PROFILE: diff --git a/python/vcfReader.py b/python/vcfReader.py index 60271260b..c54972544 100755 --- a/python/vcfReader.py +++ b/python/vcfReader.py @@ -7,7 +7,10 @@ for p in ["AG", "CT"]: TRANSITIONS[p] = True TRANSITIONS[''.join(reversed(p))] = True -def convertToType(d, onlyKeys = None): +def is_nan(x): + return type(x) is float and x != x + +def convertToType(chr, pos, d, onlyKeys = None): out = dict() types = [int, float, str] for key, value in d.items(): @@ -15,6 +18,9 @@ def convertToType(d, onlyKeys = None): for type in types: try: out[key] = type(value) + if is_nan(out[key]): + print 'Warning, nan found at %s:%s, using NaN string' % (chr, pos) + out[key] = 'NaN' break except: pass @@ -27,8 +33,9 @@ class VCFRecord: def __init__(self, basicBindings, header=None, rest=[], decodeAll = True): self.header = header self.info = parseInfo(basicBindings["INFO"]) - if decodeAll: self.info = convertToType(self.info) - self.bindings = convertToType(basicBindings, onlyKeys = ['POS', 'QUAL']) + chr, pos = basicBindings['CHROM'], basicBindings['POS'] + self.bindings = convertToType(chr, pos, basicBindings, onlyKeys = ['POS', 'QUAL']) + if decodeAll: self.info = convertToType(chr, pos, self.info) self.rest = rest def hasHeader(self): return self.header <> None