General bug fixes for snpSelector. More robust error checking and handling of NaN values.
git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@2106 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
parent
797bb83209
commit
da7de9960b
|
|
@ -7,6 +7,7 @@ from itertools import *
|
|||
import math
|
||||
import random
|
||||
|
||||
DEBUG = False
|
||||
|
||||
class CallCovariate:
|
||||
def __init__(self, feature, left, right, FPRate = None, cumulative = False):
|
||||
|
|
@ -28,6 +29,8 @@ class CallCovariate:
|
|||
def getFeature(self): return self.feature
|
||||
|
||||
def getCovariateField(self): return self.getFeature() + '_RQ'
|
||||
|
||||
def __str__(self): return "[CC feature=%s left=%s right=%s]" % (self.feature, self.left, self.right)
|
||||
|
||||
class RecalibratedCall:
|
||||
def __init__(self, call, features):
|
||||
|
|
@ -35,7 +38,7 @@ class RecalibratedCall:
|
|||
self.features = dict([[feature, None] for feature in features])
|
||||
|
||||
def recalFeature( self, feature, FPRate ):
|
||||
assert self.features[feature] == None # not reassigning values
|
||||
assert self.features[feature] == None, "Feature " + feature + ' has value ' + str(self.features[feature]) + ' for call ' + str(self.call) # not reassigning values
|
||||
assert FPRate <= 1 and FPRate >= 0
|
||||
self.features[feature] = FPRate
|
||||
|
||||
|
|
@ -52,6 +55,7 @@ class RecalibratedCall:
|
|||
#print self.features
|
||||
logTPRates = [math.log10(1-r) for r in self.features.itervalues() if r <> None]
|
||||
logJointTPRate = reduce(lambda x, y: x + y, logTPRates, 0)
|
||||
logJointTPRate = min(logJointTPRate, 1e-3 / 3) # approximation from het of 0.001
|
||||
jointTPRate = math.pow(10, logJointTPRate)
|
||||
#print logTPRates
|
||||
#print logJointTPRate, jointTPRate
|
||||
|
|
@ -98,8 +102,6 @@ def gaussian(x, mu, sigma):
|
|||
exponent = -1 * ( x - mu )**2 / (2 * sigma**2)
|
||||
return constant * math.exp(exponent)
|
||||
|
||||
DEBUG = False
|
||||
|
||||
# if target = T, and FP calls have ti/tv = 0.5, we want to know how many FP calls
|
||||
# there are in N calls with ti/tv of X.
|
||||
#
|
||||
|
|
@ -138,15 +140,10 @@ def titvFPRateEstimate(variants, target):
|
|||
|
||||
FPRate = 1 - gaussianModel()
|
||||
nVariants = len(variants)
|
||||
if nVariants > 0:
|
||||
impliedNoErrors = nVariants * FPRate
|
||||
calcTiTv = (impliedNoErrors * 0.5 + target * (nVariants-impliedNoErrors)) / nVariants
|
||||
else:
|
||||
impliedNoErrors, calcTiTv = 0, 0
|
||||
|
||||
if DEBUG: print ':::', nVariants, titvRatio, target, ti, tv, FPRate, impliedNoErrors, calcTiTv
|
||||
if DEBUG: print ':::', nVariants, titvRatio, target, FPRate
|
||||
|
||||
return titvRatio, FPRate, impliedNoErrors
|
||||
return titvRatio, FPRate
|
||||
|
||||
def phredScale(errorRate):
    """Convert an error rate into a Phred-scaled quality, -10 * log10(rate).

    The rate is floored at 1e-10 before the logarithm is taken, so a rate
    of 0 does not raise and the result is capped at 100.
    """
    flooredRate = max(errorRate, 1e-10)
    return -10 * math.log10(flooredRate)
|
||||
|
|
@ -178,44 +175,66 @@ def frange6(*args):
|
|||
yield v
|
||||
v += step
|
||||
|
||||
def compareFieldValues( v1, v2 ):
    """cmp-style comparator for two field values.

    When the two values have different types, the types themselves are
    compared; otherwise the values are compared directly. The result is
    whatever the builtin cmp() returns for the chosen pair.

    NOTE: relies on the Python 2 builtin cmp().
    """
    type1, type2 = type(v1), type(v2)
    if type1 != type2:
        # mixed types: order by type rather than by value
        return cmp(type1, type2)
    return cmp(v1, v2)
|
||||
|
||||
def calculateBins(variants, field, minValue, maxValue, partitions):
    """Group the distinct values of `field` into bins of similar size.

    The field values are sorted and collapsed into runs of equal values.
    Runs are then merged left to right: a run joins the current bin unless
    doing so would push that bin's size above len(variants) / partitions,
    in which case the run starts a new bin. A single run is never split,
    so a bin may exceed the target size.

    Args:
        variants: objects exposing getField(field).
        field: name of the field to bin on.
        minValue, maxValue: accepted for interface compatibility; not used
            by this implementation.
        partitions: desired number of bins; sets the target bin size.

    Returns:
        A list of [left, right, count] lists, where left and right are the
        smallest and largest field values in the bin and count is the
        number of variants it holds. Returns an empty list when there are
        no variants (previously this raised IndexError).
    """
    sortedValues = sorted(v.getField(field) for v in variants)
    if not sortedValues:
        return []

    targetBinSize = len(sortedValues) / (1.0 * partitions)

    bins = []
    for value, run in groupby(sortedValues):
        runSize = len(list(run))
        if bins and runSize + bins[-1][2] <= targetBinSize:
            # still within the target: widen the current bin to this value
            bins[-1][1] = value
            bins[-1][2] += runSize
        else:
            # first run, or the current bin would overflow: open a new bin
            bins.append([value, value, runSize])

    return bins
|
||||
|
||||
#
|
||||
# def calculateBinsLinear(variants, minValue, maxValue, rangeValue, partitions):
|
||||
# breaks = list(frange6(minValue, maxValue, rangeValue / partitions))
|
||||
# if breaks[len(breaks)-1] <> maxValue:
|
||||
# breaks = breaks + ['*']
|
||||
# return zip(breaks, map( lambda x: x - 0.001, breaks[1:]))
|
||||
|
||||
def fieldRange(variants, field):
    """Return (minValue, maxValue, bins) for `field` across `variants`.

    The bins come from calculateBins() using OPTIONS.partitions and are
    passed through validateBins() before being returned.
    """
    fieldValues = [variant.getField(field) for variant in variants]
    lowest, highest = min(fieldValues), max(fieldValues)
    bins = calculateBins(variants, field, lowest, highest, OPTIONS.partitions)
    validateBins(bins)
    return lowest, highest, bins
|
||||
|
||||
def validateBins(bins):
    """Check that no bin has an edge lying strictly inside another bin.

    Every ordered pair of bins is compared, except pairs that share a left
    edge or a right edge (which also excludes a bin paired with itself) and
    pairs in which any edge is None.

    Args:
        bins: a re-iterable sequence of (left, right, count) triples.

    Raises:
        Exception: with args ("Bad bins", left1, right1, left2, right2)
            when either edge of the first bin falls strictly between the
            edges of the second.
    """
    for left1, right1, count1 in bins:
        for left2, right2, count2 in bins:
            if left1 != left2 and right1 != right2:
                if None in [left1, left2, right1, right2]:
                    continue  # pairs with a None edge are not compared
                # Both edges of bin 1 are tested against bin 2. The earlier
                # code tested right2 against its own interval, which can
                # never be true, so the right edge was never checked.
                if left2 < left1 < right2 or left2 < right1 < right2:
                    raise Exception("Bad bins", left1, right1, left2, right2)
|
||||
|
||||
def printFieldQualHeader(more = ""):
|
||||
print ' field left right nvariants titv dbSNP fprate q', more
|
||||
|
||||
|
|
@ -245,10 +264,12 @@ def recalibrateCalls(variants, fields, callCovariates):
|
|||
FPR = callCovariate.getFPRate()
|
||||
recalCall.recalFeature(callCovariate.getFeature(), FPR)
|
||||
recalCall.call.setField(callCovariate.getCovariateField(), phred(FPR))
|
||||
|
||||
|
||||
recalCall.call.setField('QUAL', phred(recalCall.jointFPErrorRate()))
|
||||
recalCall.call.setField('OQ', originalQual)
|
||||
#print 'recalibrating', variant.getLoc()
|
||||
#print ' =>', variant
|
||||
yield recalCall.call
|
||||
|
||||
#
|
||||
|
|
@ -299,15 +320,17 @@ def calibrateFeatures(variants, fields, titvTarget, printCall = False, cumulativ
|
|||
for field in fields:
|
||||
if DEBUG: print 'Optimizing field', field
|
||||
|
||||
titv, FPRate, nErrors = titvFPRateEstimate(variants, titvTarget)
|
||||
titv, FPRate = titvFPRateEstimate(variants, titvTarget)
|
||||
#print 'Overall FRRate:', FPRate, nErrors, phredScale(FPRate)
|
||||
|
||||
for left, right, selectedVariants in mapVariantBins(variants, field, cumulative = cumulative):
|
||||
if len(selectedVariants) > max(OPTIONS.minVariantsPerBin,1):
|
||||
titv, FPRate, nErrors = titvFPRateEstimate(selectedVariants, titvTarget)
|
||||
titv, FPRate = titvFPRateEstimate(selectedVariants, titvTarget)
|
||||
dbsnp = dbSNPRate(selectedVariants)
|
||||
covariates.append(CallCovariate(field, left, right, FPRate))
|
||||
printFieldQual(field, left, right, selectedVariants, FPRate )
|
||||
else:
|
||||
print 'Not calibrating bin', left, right, 'because it contains too few variants:', len(selectedVariants)
|
||||
|
||||
return covariates
|
||||
|
||||
|
|
@ -416,13 +439,16 @@ def setup():
|
|||
parser.add_option("-b", "--bootstrap", dest="bootStrap",
|
||||
type='float', default=None,
|
||||
help="If provided, the % of the calls used to generate the recalibration tables. [default: %default]")
|
||||
parser.add_option("-r", "--dontRecalibrate", dest="dontRecalibrate",
|
||||
action='store_true', default=False,
|
||||
help="If provided, we will not actually do anything to the calls, they will just be assessed [default: %default]")
|
||||
|
||||
(OPTIONS, args) = parser.parse_args()
|
||||
if len(args) > 2:
|
||||
parser.error("incorrect number of arguments")
|
||||
return args
|
||||
|
||||
def determineCovariates(file, fields):
|
||||
def assessCalls(file):
|
||||
print 'Counting records in VCF', file
|
||||
numberOfRecords = quickCountRecords(open(file))
|
||||
if OPTIONS.maxRecords <> None and OPTIONS.maxRecords < numberOfRecords:
|
||||
|
|
@ -433,23 +459,27 @@ def determineCovariates(file, fields):
|
|||
print 'Number of VCF records', numberOfRecords, ', max number of records for covariates is', OPTIONS.maxRecordsForCovariates, 'so keeping', downsampleFraction * 100, '% of records'
|
||||
print 'Number of selected VCF records', len(allCalls)
|
||||
|
||||
if OPTIONS.titvTarget == None:
|
||||
OPTIONS.titvTarget = titv(selectVariants(allCalls, VCFRecord.isKnown))
|
||||
titvtarget = OPTIONS.titvTarget
|
||||
if titvtarget == None:
|
||||
titvtarget = titv(selectVariants(allCalls, VCFRecord.isKnown))
|
||||
print 'Ti/Tv all ', titv(allCalls)
|
||||
print 'Ti/Tv known', titv(selectVariants(allCalls, VCFRecord.isKnown))
|
||||
print 'Ti/Tv novel', titv(selectVariants(allCalls, VCFRecord.isNovel))
|
||||
|
||||
return header, allCalls, titvtarget
|
||||
|
||||
def determineCovariates(allCalls, titvtarget, fields):
|
||||
if OPTIONS.bootStrap:
|
||||
callsToOptimize, recalEvalCalls = randomSplit(allCalls, OPTIONS.bootStrap)
|
||||
else:
|
||||
callsToOptimize = allCalls
|
||||
|
||||
recalOptCalls, covariates = optimizeCalls(callsToOptimize, fields, OPTIONS.titvTarget)
|
||||
printCallQuals(list(recalOptCalls), OPTIONS.titvTarget, 'OPTIMIZED CALLS')
|
||||
recalOptCalls, covariates = optimizeCalls(callsToOptimize, fields, titvtarget)
|
||||
printCallQuals(list(recalOptCalls), titvtarget, 'OPTIMIZED CALLS')
|
||||
|
||||
if OPTIONS.bootStrap:
|
||||
recalibatedEvalCalls = recalibrateCalls(recalEvalCalls, fields, covariates)
|
||||
printCallQuals(list(recalibatedEvalCalls), OPTIONS.titvTarget, 'BOOTSTRAP EVAL CALLS')
|
||||
printCallQuals(list(recalibatedEvalCalls), titvtarget, 'BOOTSTRAP EVAL CALLS')
|
||||
|
||||
return covariates
|
||||
|
||||
|
|
@ -494,16 +524,22 @@ def evaluateTruth(header, callVCF, truthVCF):
|
|||
|
||||
def main():
|
||||
args = setup()
|
||||
fields = OPTIONS.fields.split(',')
|
||||
|
||||
covariates = determineCovariates(args[0], fields)
|
||||
header, callsToRecalibate = readVariants(args[0], OPTIONS.maxRecords)
|
||||
RecalibratedCalls = recalibrateCalls(callsToRecalibate, fields, covariates)
|
||||
writeRecalibratedCalls(OPTIONS.outputVCF, header, RecalibratedCalls)
|
||||
header, allCalls, titvTarget = assessCalls(args[0])
|
||||
if not OPTIONS.dontRecalibrate:
|
||||
fields = OPTIONS.fields.split(',')
|
||||
covariates = determineCovariates(allCalls, titvTarget, fields)
|
||||
header, callsToRecalibate = readVariants(args[0], OPTIONS.maxRecords)
|
||||
RecalibratedCalls = recalibrateCalls(callsToRecalibate, fields, covariates)
|
||||
writeRecalibratedCalls(OPTIONS.outputVCF, header, RecalibratedCalls)
|
||||
else:
|
||||
printCallQuals(allCalls, titvTarget)
|
||||
OPTIONS.outputVCF = args[0]
|
||||
|
||||
if len(args) > 1:
|
||||
evaluateTruth(header, OPTIONS.outputVCF, args[1])
|
||||
|
||||
|
||||
PROFILE = False
|
||||
if __name__ == "__main__":
|
||||
if PROFILE:
|
||||
|
|
|
|||
|
|
@ -7,7 +7,10 @@ for p in ["AG", "CT"]:
|
|||
TRANSITIONS[p] = True
|
||||
TRANSITIONS[''.join(reversed(p))] = True
|
||||
|
||||
def convertToType(d, onlyKeys = None):
|
||||
def is_nan(x):
    """Return True only when x is exactly a float holding NaN.

    Values of any other type (ints, None, strings such as 'NaN') are
    always reported as False.
    """
    if type(x) is not float:
        return False
    # NaN is the only float that compares unequal to itself
    return x != x
|
||||
|
||||
def convertToType(chr, pos, d, onlyKeys = None):
|
||||
out = dict()
|
||||
types = [int, float, str]
|
||||
for key, value in d.items():
|
||||
|
|
@ -15,6 +18,9 @@ def convertToType(d, onlyKeys = None):
|
|||
for type in types:
|
||||
try:
|
||||
out[key] = type(value)
|
||||
if is_nan(out[key]):
|
||||
print 'Warning, nan found at %s:%s, using NaN string' % (chr, pos)
|
||||
out[key] = 'NaN'
|
||||
break
|
||||
except:
|
||||
pass
|
||||
|
|
@ -27,8 +33,9 @@ class VCFRecord:
|
|||
def __init__(self, basicBindings, header=None, rest=[], decodeAll = True):
|
||||
self.header = header
|
||||
self.info = parseInfo(basicBindings["INFO"])
|
||||
if decodeAll: self.info = convertToType(self.info)
|
||||
self.bindings = convertToType(basicBindings, onlyKeys = ['POS', 'QUAL'])
|
||||
chr, pos = basicBindings['CHROM'], basicBindings['POS']
|
||||
self.bindings = convertToType(chr, pos, basicBindings, onlyKeys = ['POS', 'QUAL'])
|
||||
if decodeAll: self.info = convertToType(chr, pos, self.info)
|
||||
self.rest = rest
|
||||
|
||||
def hasHeader(self): return self.header <> None
|
||||
|
|
|
|||
Loading…
Reference in New Issue