General bug fixes for snpSelector. More robust error checking and handling of NaN values.

git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@2106 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
depristo 2009-11-21 14:48:29 +00:00
parent 797bb83209
commit da7de9960b
2 changed files with 79 additions and 36 deletions

View File

@ -7,6 +7,7 @@ from itertools import *
import math
import random
DEBUG = False
class CallCovariate:
def __init__(self, feature, left, right, FPRate = None, cumulative = False):
@ -28,6 +29,8 @@ class CallCovariate:
def getFeature(self): return self.feature
def getCovariateField(self): return self.getFeature() + '_RQ'
def __str__(self): return "[CC feature=%s left=%s right=%s]" % (self.feature, self.left, self.right)
class RecalibratedCall:
def __init__(self, call, features):
@ -35,7 +38,7 @@ class RecalibratedCall:
self.features = dict([[feature, None] for feature in features])
def recalFeature( self, feature, FPRate ):
assert self.features[feature] == None # not reassigning values
assert self.features[feature] == None, "Feature " + feature + ' has value ' + str(self.features[feature]) + ' for call ' + str(self.call) # not reassigning values
assert FPRate <= 1 and FPRate >= 0
self.features[feature] = FPRate
@ -52,6 +55,7 @@ class RecalibratedCall:
#print self.features
logTPRates = [math.log10(1-r) for r in self.features.itervalues() if r <> None]
logJointTPRate = reduce(lambda x, y: x + y, logTPRates, 0)
logJointTPRate = min(logJointTPRate, 1e-3 / 3) # approximation from het of 0.001
jointTPRate = math.pow(10, logJointTPRate)
#print logTPRates
#print logJointTPRate, jointTPRate
@ -98,8 +102,6 @@ def gaussian(x, mu, sigma):
exponent = -1 * ( x - mu )**2 / (2 * sigma**2)
return constant * math.exp(exponent)
DEBUG = False
# if target = T, and FP calls have ti/tv = 0.5, we want to know how many FP calls
# there are in N calls with ti/tv of X.
#
@ -138,15 +140,10 @@ def titvFPRateEstimate(variants, target):
FPRate = 1 - gaussianModel()
nVariants = len(variants)
if nVariants > 0:
impliedNoErrors = nVariants * FPRate
calcTiTv = (impliedNoErrors * 0.5 + target * (nVariants-impliedNoErrors)) / nVariants
else:
impliedNoErrors, calcTiTv = 0, 0
if DEBUG: print ':::', nVariants, titvRatio, target, ti, tv, FPRate, impliedNoErrors, calcTiTv
if DEBUG: print ':::', nVariants, titvRatio, target, FPRate
return titvRatio, FPRate, impliedNoErrors
return titvRatio, FPRate
def phredScale(errorRate):
return -10 * math.log10(max(errorRate, 1e-10))
@ -178,44 +175,66 @@ def frange6(*args):
yield v
v += step
def compareFieldValues( v1, v2 ):
if type(v1) <> type(v2):
#print 'Different types', type(v1), type(v2)
c = cmp(type(v1), type(v2))
else:
c = cmp(v1, v2)
#print 'Comparing %s %s = %s' % (v1, v2, c)
return c
def calculateBins(variants, field, minValue, maxValue, partitions):
sortedVariants = sorted(variants, key = lambda x: x.getField(field))
sortedVariants = sorted(variants, key = lambda x: x.getField(field)) # cmp = compareFieldValues,
sortedValues = map(lambda x: x.getField(field), sortedVariants)
targetBinSize = len(variants) / (1.0*partitions)
#print sortedValues
uniqBins = groupby(sortedValues)
binsAndSizes = map(lambda x: [x[0], len(list(x[1]))], uniqBins)
#print binsAndSizes
#print 'BINS AND SIZES', binsAndSizes
def bin2Break(bin): return [bin[0], bin[0], bin[1]]
bins = [bin2Break(binsAndSizes[0])]
for bin in binsAndSizes[1:]:
#print 'Breaks', bins
#print ' Breaks', bins
#print ' current bin', bin
curSize = bin[1]
prevSize = bins[-1][2]
#print curSize, prevSize
if curSize + prevSize > targetBinSize:
#print ' => appending', bin2Break(bin)
bins.append(bin2Break(bin))
else:
bins[-1][1] = bin[0]
bins[-1][2] += curSize
#print 'Returning ', bins
#sys.exit(1)
return bins
#
# def calculateBinsLinear(variants, minValue, maxValue, rangeValue, partitions):
# breaks = list(frange6(minValue, maxValue, rangeValue / partitions))
# if breaks[len(breaks)-1] <> maxValue:
# breaks = breaks + ['*']
# return zip(breaks, map( lambda x: x - 0.001, breaks[1:]))
def fieldRange(variants, field):
values = map(lambda v: v.getField(field), variants)
minValue = min(values)
maxValue = max(values)
#rangeValue = maxValue - minValue
bins = calculateBins(variants, field, minValue, maxValue, OPTIONS.partitions)
validateBins(bins)
return minValue, maxValue, bins
def validateBins(bins):
#print 'Bins are', bins
for left1, right1, count1 in bins:
for left2, right2, count2 in bins:
def contains2(x):
return left2 < x and x < right2
if left1 <> left2 and right1 <> right2:
if None in [left1, left2, right1, right2]:
pass # we're ok
elif contains2(left1) or contains2(right2):
raise Exception("Bad bins", left1, right1, left2, right2)
def printFieldQualHeader(more = ""):
print ' field left right nvariants titv dbSNP fprate q', more
@ -245,10 +264,12 @@ def recalibrateCalls(variants, fields, callCovariates):
FPR = callCovariate.getFPRate()
recalCall.recalFeature(callCovariate.getFeature(), FPR)
recalCall.call.setField(callCovariate.getCovariateField(), phred(FPR))
recalCall.call.setField('QUAL', phred(recalCall.jointFPErrorRate()))
recalCall.call.setField('OQ', originalQual)
#print 'recalibrating', variant.getLoc()
#print ' =>', variant
yield recalCall.call
#
@ -299,15 +320,17 @@ def calibrateFeatures(variants, fields, titvTarget, printCall = False, cumulativ
for field in fields:
if DEBUG: print 'Optimizing field', field
titv, FPRate, nErrors = titvFPRateEstimate(variants, titvTarget)
titv, FPRate = titvFPRateEstimate(variants, titvTarget)
#print 'Overall FRRate:', FPRate, nErrors, phredScale(FPRate)
for left, right, selectedVariants in mapVariantBins(variants, field, cumulative = cumulative):
if len(selectedVariants) > max(OPTIONS.minVariantsPerBin,1):
titv, FPRate, nErrors = titvFPRateEstimate(selectedVariants, titvTarget)
titv, FPRate = titvFPRateEstimate(selectedVariants, titvTarget)
dbsnp = dbSNPRate(selectedVariants)
covariates.append(CallCovariate(field, left, right, FPRate))
printFieldQual(field, left, right, selectedVariants, FPRate )
else:
print 'Not calibrating bin', left, right, 'because it contains too few variants:', len(selectedVariants)
return covariates
@ -416,13 +439,16 @@ def setup():
parser.add_option("-b", "--bootstrap", dest="bootStrap",
type='float', default=None,
help="If provided, the % of the calls used to generate the recalibration tables. [default: %default]")
parser.add_option("-r", "--dontRecalibrate", dest="dontRecalibrate",
action='store_true', default=False,
help="If provided, we will not actually do anything to the calls, they will just be assessed [default: %default]")
(OPTIONS, args) = parser.parse_args()
if len(args) > 2:
parser.error("incorrect number of arguments")
return args
def determineCovariates(file, fields):
def assessCalls(file):
print 'Counting records in VCF', file
numberOfRecords = quickCountRecords(open(file))
if OPTIONS.maxRecords <> None and OPTIONS.maxRecords < numberOfRecords:
@ -433,23 +459,27 @@ def determineCovariates(file, fields):
print 'Number of VCF records', numberOfRecords, ', max number of records for covariates is', OPTIONS.maxRecordsForCovariates, 'so keeping', downsampleFraction * 100, '% of records'
print 'Number of selected VCF records', len(allCalls)
if OPTIONS.titvTarget == None:
OPTIONS.titvTarget = titv(selectVariants(allCalls, VCFRecord.isKnown))
titvtarget = OPTIONS.titvTarget
if titvtarget == None:
titvtarget = titv(selectVariants(allCalls, VCFRecord.isKnown))
print 'Ti/Tv all ', titv(allCalls)
print 'Ti/Tv known', titv(selectVariants(allCalls, VCFRecord.isKnown))
print 'Ti/Tv novel', titv(selectVariants(allCalls, VCFRecord.isNovel))
return header, allCalls, titvtarget
def determineCovariates(allCalls, titvtarget, fields):
if OPTIONS.bootStrap:
callsToOptimize, recalEvalCalls = randomSplit(allCalls, OPTIONS.bootStrap)
else:
callsToOptimize = allCalls
recalOptCalls, covariates = optimizeCalls(callsToOptimize, fields, OPTIONS.titvTarget)
printCallQuals(list(recalOptCalls), OPTIONS.titvTarget, 'OPTIMIZED CALLS')
recalOptCalls, covariates = optimizeCalls(callsToOptimize, fields, titvtarget)
printCallQuals(list(recalOptCalls), titvtarget, 'OPTIMIZED CALLS')
if OPTIONS.bootStrap:
recalibatedEvalCalls = recalibrateCalls(recalEvalCalls, fields, covariates)
printCallQuals(list(recalibatedEvalCalls), OPTIONS.titvTarget, 'BOOTSTRAP EVAL CALLS')
printCallQuals(list(recalibatedEvalCalls), titvtarget, 'BOOTSTRAP EVAL CALLS')
return covariates
@ -494,16 +524,22 @@ def evaluateTruth(header, callVCF, truthVCF):
def main():
args = setup()
fields = OPTIONS.fields.split(',')
covariates = determineCovariates(args[0], fields)
header, callsToRecalibate = readVariants(args[0], OPTIONS.maxRecords)
RecalibratedCalls = recalibrateCalls(callsToRecalibate, fields, covariates)
writeRecalibratedCalls(OPTIONS.outputVCF, header, RecalibratedCalls)
header, allCalls, titvTarget = assessCalls(args[0])
if not OPTIONS.dontRecalibrate:
fields = OPTIONS.fields.split(',')
covariates = determineCovariates(allCalls, titvTarget, fields)
header, callsToRecalibate = readVariants(args[0], OPTIONS.maxRecords)
RecalibratedCalls = recalibrateCalls(callsToRecalibate, fields, covariates)
writeRecalibratedCalls(OPTIONS.outputVCF, header, RecalibratedCalls)
else:
printCallQuals(allCalls, titvTarget)
OPTIONS.outputVCF = args[0]
if len(args) > 1:
evaluateTruth(header, OPTIONS.outputVCF, args[1])
PROFILE = False
if __name__ == "__main__":
if PROFILE:

View File

@ -7,7 +7,10 @@ for p in ["AG", "CT"]:
TRANSITIONS[p] = True
TRANSITIONS[''.join(reversed(p))] = True
def convertToType(d, onlyKeys = None):
def is_nan(x):
return type(x) is float and x != x
def convertToType(chr, pos, d, onlyKeys = None):
out = dict()
types = [int, float, str]
for key, value in d.items():
@ -15,6 +18,9 @@ def convertToType(d, onlyKeys = None):
for type in types:
try:
out[key] = type(value)
if is_nan(out[key]):
print 'Warning, nan found at %s:%s, using NaN string' % (chr, pos)
out[key] = 'NaN'
break
except:
pass
@ -27,8 +33,9 @@ class VCFRecord:
def __init__(self, basicBindings, header=None, rest=[], decodeAll = True):
self.header = header
self.info = parseInfo(basicBindings["INFO"])
if decodeAll: self.info = convertToType(self.info)
self.bindings = convertToType(basicBindings, onlyKeys = ['POS', 'QUAL'])
chr, pos = basicBindings['CHROM'], basicBindings['POS']
self.bindings = convertToType(chr, pos, basicBindings, onlyKeys = ['POS', 'QUAL'])
if decodeAll: self.info = convertToType(chr, pos, self.info)
self.rest = rest
def hasHeader(self): return self.header <> None