# gatk-3.8/python/snpSelector.py
#
# 495 lines
# 17 KiB
# Python
# Executable File

import os.path
import sys
from optparse import OptionParser
from vcfReader import *
#import pylab
from itertools import *
import math
import random
class CallCovariate:
    """A value range [left, right] of one call feature plus its false-positive rate.

    When constructed with ``cumulative=True`` the upper bound is replaced by
    the wildcard '*', so the range matches every value >= ``left``.
    """

    def __init__(self, feature, left, right, FPRate = None, cumulative = False):
        self.feature = feature
        self.left = left
        # '*' is the open-ended marker that containsVariant() checks for.
        self.right = '*' if cumulative else right
        self.FPRate = FPRate

    def containsVariant(self, call):
        """Tell whether the call's value for this feature falls inside the range."""
        value = call.getField(self.feature)
        if not value >= self.left:
            return False
        return self.right == '*' or value <= self.right

    def getFPRate(self):
        """Return the false-positive rate stored for this range (may be None)."""
        return self.FPRate

    def getFeature(self):
        """Return the name of the feature this range applies to."""
        return self.feature

    def getCovariateField(self):
        """Return the field name used for the recalibrated quality: '<feature>_RQ'."""
        return self.getFeature() + '_RQ'
class RecalibratedCall:
    """A call together with the per-feature false-positive rates assigned to it.

    ``features`` maps each feature name to its FP rate, or to None while the
    feature has not been recalibrated.
    """

    def __init__(self, call, features):
        self.call = call
        self.features = dict([(feature, None) for feature in features])

    def recalFeature( self, feature, FPRate ):
        """Record the FP rate for ``feature``; each feature may be set only once.

        Raises KeyError for a feature that was not passed to the constructor
        and AssertionError on re-assignment or a rate outside [0, 1].
        """
        assert self.features[feature] is None  # not reassigning values
        assert 0 <= FPRate <= 1
        self.features[feature] = FPRate

    def getFeature( self, feature, missingValue = None, phredScaleValue = False ):
        """Return the FP rate of ``feature``.

        ``missingValue`` is returned when the feature has no rate yet; with
        ``phredScaleValue`` the rate is converted through phredScale().
        """
        v = self.features[feature]
        if v is None:
            return missingValue
        if phredScaleValue:
            return phredScale(v)
        return v

    def jointFPErrorRate(self):
        """Return 1 - prod(1 - r) over all assigned rates r (0 when none are set)."""
        rates = [r for r in self.features.values() if r is not None]
        # log10(1 - r) is undefined at r == 1, a value recalFeature() accepts;
        # a certain false positive makes the joint rate 1 as well.
        if [r for r in rates if r >= 1]:
            return 1.0
        logJointTPRate = sum([math.log10(1 - r) for r in rates])
        jointTPRate = math.pow(10, logJointTPRate)
        return 1 - jointTPRate

    def featureStringList(self):
        """Return 'name=Q<phred>' items joined by commas; unset features show as 'Q*'."""
        def describe(feature):
            q = self.getFeature(feature, None, True)
            # '%d' cannot format a '*' placeholder, so the unset case is
            # rendered separately.
            if q is None:
                return '%s=Q*' % feature
            return '%s=Q%d' % (feature, q)
        return ','.join([describe(feature) for feature in self.features])

    def __str__(self):
        return '[%s: %s => Q%d]' % (str(self.call), self.featureStringList(), phredScale(self.jointFPErrorRate()))
def readVariants( file, maxRecords = None, decodeAll = True ):
    """Read a VCF file and return ``(header, variants)``.

    Only the first ``maxRecords`` parsed records are considered (all of them
    when None), and of those only records whose counter is a multiple of
    OPTIONS.skip are kept.  ``decodeAll`` is forwarded to lines2VCF().
    """
    # The list is built inside the ``with`` block so that all reading is
    # finished before the handle is closed.
    with open(file) as f:
        header, ignore, lines = readVCFHeader(f)
        records = islice(lines2VCF(lines, extendedOutput = True, decodeAll = decodeAll), maxRecords)
        # The trailing truthiness test keeps the behaviour of the former
        # filter(None, ...), which also discarded falsy records.
        variants = [VCF for header1, VCF, counter in records
                    if counter % OPTIONS.skip == 0 and VCF]
    return header, variants
def selectVariants( variants, selector = None ):
    """Return the variants accepted by ``selector`` as a list.

    Without a selector the input is returned unchanged (same object).
    """
    if selector is None:
        return variants
    # A real list (not a lazy filter object) so callers can take len() of it.
    return [v for v in variants if selector(v)]
def titv(variants):
    """Return ``(ratio, nTransitions, nTransversions)`` for the given variants.

    The ratio divides by max(nTransversions, 1), so it is defined even when
    there are no transversions.
    """
    nTi = sum(1 for v in variants if VCFRecord.isTransition(v))
    nTv = len(variants) - nTi
    ratio = nTi / (1.0 * max(nTv, 1))
    return ratio, nTi, nTv
def gaussian(x, mu, sigma):
    """Return the normal probability density with mean ``mu`` and std-dev ``sigma`` at ``x``.

    Raises ZeroDivisionError when ``sigma`` is 0.
    """
    # Force float arithmetic: with all-integer arguments Python 2 would
    # floor-divide the exponent (e.g. -1 / 50 == -1).
    variance = float(sigma) ** 2
    constant = 1 / math.sqrt(2 * math.pi * variance)
    exponent = -((x - mu) ** 2) / (2 * variance)
    return constant * math.exp(exponent)
DEBUG = False


# If the target ti/tv is T and false-positive calls have ti/tv = 0.5, we want
# to know how many false-positive calls there are among N calls whose ti/tv
# is X.
def titvFPRateEstimate(variants, target):
    """Estimate the false-positive rate of ``variants`` from their ti/tv ratio.

    Returns ``(titvRatio, FPRate, impliedNoErrors)`` where ``impliedNoErrors``
    is ``len(variants) * FPRate`` (0 for an empty input).  The rate comes from
    the nested gaussianModel().
    """
    observedRatio, ti, tv = titv(variants)

    def theoreticalCalc():
        # Linear model; not called by the code below.
        # R form: f <- function(To,T) { (To - T) / (1/2 - T) + 0.001 }
        if observedRatio >= target:
            rawRate = 0
        else:
            rawRate = (observedRatio - target) / (0.5 - target)
        FPRate = min(max(rawRate, 0), 1)
        qBound = dephredScale(OPTIONS.maxQScore)
        TPRate = max(min(1 - FPRate, 1 - qBound), qBound)
        print(' '.join([str(item) for item in ('FPRate', FPRate, observedRatio, target)]))
        assert FPRate >= 0 and FPRate <= 1
        return TPRate

    def gaussianModel():
        # True-positive rate is the gaussian density at the observed ratio
        # relative to its peak at the target; with LEFT_HANDED, ratios at or
        # above the target count as fully true.
        LEFT_HANDED = True
        sigma = 5
        if LEFT_HANDED and observedRatio >= target:
            TPRate = 1
        else:
            TPRate = gaussian(observedRatio, target, sigma) / gaussian(target, target, sigma)
        TPRate -= dephredScale(OPTIONS.maxQScore)
        if DEBUG:
            constant = 1 / math.sqrt(2 * math.pi * sigma**2)
            exponent = -1 * ( observedRatio - target )**2 / (2 * sigma**2)
            print(' '.join([str(item) for item in
                            ('TPRate', TPRate, constant, exponent, dephredScale(OPTIONS.maxQScore))]))
        return TPRate

    FPRate = 1 - gaussianModel()

    nVariants = len(variants)
    impliedNoErrors, calcTiTv = 0, 0
    if nVariants > 0:
        impliedNoErrors = nVariants * FPRate
        # ti/tv expected if the implied errors sit at 0.5 and the rest at target.
        calcTiTv = (impliedNoErrors * 0.5 + target * (nVariants-impliedNoErrors)) / nVariants

    if DEBUG:
        print(' '.join([str(item) for item in
                        (':::', nVariants, observedRatio, target, ti, tv, FPRate, impliedNoErrors, calcTiTv)]))

    return observedRatio, FPRate, impliedNoErrors
def phredScale(errorRate):
    """Convert an error rate to a phred-scaled quality, -10 * log10(rate).

    Rates below 1e-10 are clamped to 1e-10, which caps the result at 100.
    """
    floored = 1e-10 if 1e-10 > errorRate else errorRate
    return -10 * math.log10(floored)
def dephredScale(qscore):
    """Convert a phred-scaled quality back to an error rate, 10 ** (-qscore / 10)."""
    # Divide by a float: on Python 2 an int qscore would be floor-divided
    # (25 / -10 == -3), giving 1e-3 instead of 10 ** -2.5.
    return math.pow(10, qscore / -10.0)
def frange6(*args):
    """A float range generator.

    Accepts ``(end)``, ``(start, end)`` or ``(start, end, step)`` like range();
    start defaults to 0.0 and step to 1.0.  Values are produced by repeatedly
    adding ``step`` and stop before reaching ``end``.

    Because this is a generator, argument errors surface on the first
    iteration: ValueError for a zero step, TypeError for a wrong argument count.
    """
    start = 0.0
    step = 1.0
    nargs = len(args)
    if nargs == 1:
        end = args[0]
    elif nargs == 2:
        start, end = args
    elif nargs == 3:
        start, end, step = args
        if step == 0.0:
            raise ValueError("step must not be zero")
    else:
        raise TypeError("frange expects 1-3 arguments, got %d" % nargs)
    v = start
    while True:
        if (step > 0 and v >= end) or (step < 0 and v <= end):
            # A plain return ends the generator; raising StopIteration here
            # is turned into RuntimeError under PEP 479.
            return
        yield v
        v += step
def calculateBins(variants, field, minValue, maxValue, partitions):
    """Group the distinct values of ``field`` into bins of roughly equal population.

    Returns a list of ``[left, right, count]`` entries in ascending value
    order.  Each distinct value starts as its own candidate bin and is merged
    into the previous bin as long as the combined count does not exceed
    ``len(variants) / partitions``.  An empty input yields an empty list.

    ``minValue`` and ``maxValue`` are accepted for interface compatibility and
    are not used.  Raises ZeroDivisionError when ``partitions`` is 0.
    """
    sortedValues = sorted([v.getField(field) for v in variants])
    targetBinSize = len(variants) / (1.0*partitions)
    print('targetBinSize %s' % (targetBinSize,))

    bins = []
    for value, run in groupby(sortedValues):
        curSize = len(list(run))
        if bins and not (curSize + bins[-1][2] > targetBinSize):
            # Still room in the previous bin: extend its right edge.
            bins[-1][1] = value
            bins[-1][2] += curSize
        else:
            bins.append([value, value, curSize])
    return bins
#
# def calculateBinsLinear(variants, minValue, maxValue, rangeValue, partitions):
# breaks = list(frange6(minValue, maxValue, rangeValue / partitions))
# if breaks[len(breaks)-1] <> maxValue:
# breaks = breaks + ['*']
# return zip(breaks, map( lambda x: x - 0.001, breaks[1:]))
def fieldRange(variants, field):
    """Return ``(minValue, maxValue, bins)`` for ``field`` over the variants.

    ``bins`` comes from calculateBins() using OPTIONS.partitions partitions.
    Raises ValueError (from min()) when ``variants`` is empty.
    """
    values = [v.getField(field) for v in variants]
    lowest = min(values)
    highest = max(values)
    return lowest, highest, calculateBins(variants, field, lowest, highest, OPTIONS.partitions)
def printFieldQual( left, right, variants, titv, FPRate, nErrors ):
    """Print a one-line summary of a bin: range, size, ti/tv, FP rate and phred quality.

    ``nErrors`` is accepted but not printed.
    """
    label = binString(left, right)
    quality = phredScale(FPRate)
    summary = ' %s nVariants=%8d titv=%.2f FPRate=%.2e Q%d' % (label, len(variants), titv, FPRate, quality)
    print(summary)
def binString(left, right):
    """Format a bin as '<left> - <right>' with both edges right-aligned to 8 columns.

    Float edges are shown with two decimals; anything else goes through str().
    """
    if type(left) == float:
        leftStr = "%.2f" % left
    else:
        leftStr = str(left)

    if type(right) == float:
        rightStr = "%.2f" % right
    else:
        rightStr = "%5s" % str(right)

    return '%8s - %8s' % (leftStr, rightStr)
#
#
#
def recalibrateCalls(variants, fields, callCovariates):
    """Apply calibrated covariates to each variant and return the updated calls.

    For every covariate whose range contains the variant, the covariate's FP
    rate is recorded and written to the '<feature>_RQ' field as a rounded
    phred value.  QUAL is then replaced by the phred-scaled joint FP rate and
    the previous QUAL is stored in 'OQ'.  The variant objects are modified in
    place; the returned list holds those same objects.
    """
    def toQual(errorRate):
        return int(round(phredScale(errorRate)))

    updated = []
    for variant in variants:
        recal = RecalibratedCall(variant, fields)
        priorQual = variant.getField('QUAL')

        # Lazy selection keeps the containsVariant() checks interleaved with
        # the field updates.
        applicable = (cc for cc in callCovariates if cc.containsVariant(variant))
        for covariate in applicable:
            rate = covariate.getFPRate()
            recal.recalFeature(covariate.getFeature(), rate)
            recal.call.setField(covariate.getCovariateField(), toQual(rate))

        recal.call.setField('QUAL', toQual(recal.jointFPErrorRate()))
        recal.call.setField('OQ', priorQual)
        updated.append(recal.call)
    return updated
#
#
#
def optimizeCalls(variants, fields, titvTarget):
    """Calibrate ``fields`` on the variants, then recalibrate those same variants.

    Returns ``(recalibratedCalls, callCovariates)``.
    """
    covariates = calibrateFeatures(variants, fields, titvTarget)
    return recalibrateCalls(variants, fields, covariates), covariates
def printCallQuals(recalCalls, titvTarget, info = ""):
    """Print per-bin and then cumulative QUAL calibration reports under a banner."""
    print('--------------------------------------------------------------------------------')
    print(info)
    for heading, cumulative in ((None, False), ('Cumulative', True)):
        if heading is not None:
            print(heading)
        # Called for its printed output; the returned covariates are discarded.
        calibrateFeatures(recalCalls, ['QUAL'], titvTarget, printCall = True, cumulative = cumulative )
def all( p, l ):
    """Return True when predicate ``p`` holds for every element of ``l``.

    Note: this shadows the builtin all() and takes the predicate as its first
    argument.  Evaluation stops at the first failing element.
    """
    return not any(not p(elt) for elt in l)
def variantBinsForField(variants, field):
    """Compute the bins for ``field``, print the value range and bins, and return the bins."""
    #if not all( lambda x: x.hasField(field), variants):
    #    raise Exception('Unknown field ' + field)
    summary = fieldRange(variants, field)
    bins = summary[2]
    print('Field range %s %s' % (summary[0], summary[1]))
    print('Partitions %s' % (bins,))
    return bins
def mapVariantBins(variants, field, cumulative = False):
    """Return a lazy iterator of ``(left, right, variantsInBin)`` per bin of ``field``.

    The bins are computed (and printed) immediately; the variants of each bin
    are selected only when the iterator is advanced.  With ``cumulative`` the
    right edge of each bin is '*'.
    """
    def describe(lower, upper):
        covariate = CallCovariate(field, lower, upper, cumulative = cumulative)
        members = selectVariants(variants, covariate.containsVariant)
        return covariate.left, covariate.right, members

    bins = variantBinsForField(variants, field)
    return (describe(b[0], b[1]) for b in bins)
def calibrateFeatures(variants, fields, titvTarget, printCall = False, cumulative = False ):
    """Estimate an FP rate for every sufficiently populated bin of each field.

    Returns a list of CallCovariate objects, one per bin holding more than
    max(OPTIONS.minVariantsPerBin, 1) variants.  Progress and per-bin results
    are printed.  ``printCall`` is accepted but not used.
    """
    covariates = []

    for field in fields:
        print('Optimizing field %s' % (field,))

        overall = titvFPRateEstimate(variants, titvTarget)
        print('Overall FRRate: %s %s %s' % (overall[1], overall[2], phredScale(overall[1])))

        for left, right, selected in mapVariantBins(variants, field, cumulative = cumulative):
            # Bins that are too small are skipped: no covariate, no report line.
            if not len(selected) > max(OPTIONS.minVariantsPerBin,1):
                continue
            binRatio, binFPRate, binErrors = titvFPRateEstimate(selected, titvTarget)
            covariates.append(CallCovariate(field, left, right, binFPRate))
            printFieldQual( left, right, selected, binRatio, binFPRate, binErrors )

    return covariates
class CallCmp:
    """Counts of true positives, false positives and false negatives from a comparison."""

    def __init__(self, nTP, nFP, nFN):
        self.nTP, self.nFP, self.nFN = nTP, nFP, nFN

    def FPRate(self):
        """Return FP / (TP + FP); the denominator is at least 1."""
        calls = self.nTP + self.nFP
        return (1.0*self.nFP) / max(calls, 1)

    def FNRate(self):
        """Return FN / (TP + FN); the denominator is at least 1."""
        truths = self.nTP + self.nFN
        return (1.0*self.nFN) / max(truths, 1)

    def __str__(self):
        fields = (self.nTP, self.nFP, self.FPRate(), self.nFN, self.FNRate())
        return 'TP=%6d FP=%6d FPRate=%.2f FN=%6d FNRate=%.2f' % fields
def variantInTruth(variant, truth):
    """Return the truth entry stored at the variant's location, or False if there is none."""
    return truth.get(variant.getLoc(), False)
def sensitivitySpecificity(variants, truth):
    """Compare variants against ``truth`` and return ``(CallCmp, falsePositives)``.

    A variant with a truthy entry in ``truth`` counts as a true positive and
    that entry gets its "FOUND" field set to 1; every other variant is a false
    positive (printed when OPTIONS.printFP is set).  False negatives are the
    truth entries that were not matched, computed as len(truth) - nTP.
    """
    nTP = 0
    FPs = []
    for variant in variants:
        match = variantInTruth(variant, truth)
        if not match:
            if OPTIONS.printFP:
                print('FP: %s' % (variant,))
            FPs.append(variant)
            continue
        match.setField("FOUND", 1)
        nTP += 1

    nFN = len(truth) - nTP
    return CallCmp(nTP, len(FPs), nFN), FPs
def compareCalls(calls, truthCalls):
    """Print TP/FP/FN statistics per QUAL bin, first per bin and then cumulatively.

    Side effect: every false positive found gets its "FP" field set to 1.
    """
    def compare1(cumulative):
        for left, right, selected in mapVariantBins(calls, 'QUAL', cumulative = cumulative):
            comparison, falsePositives = sensitivitySpecificity(selected, truthCalls)
            for fp in falsePositives:
                fp.setField("FP", 1)
            ratioText = 'titv=%.2f' % titv(selected)[0]
            print('%s %s %s' % (binString(left, right), ratioText, comparison))

    for banner, cumulative in (('PER BIN nCalls=', False), ('CUMULATIVE nCalls=', True)):
        print('%s %s' % (banner, len(calls)))
        compare1(cumulative)
def randomSplit(l, pLeft):
    """Randomly partition ``l`` into two lists, preserving element order.

    Each element independently goes to the first list with probability
    ``pLeft`` (one random.random() draw per element, in input order) and to
    the second list otherwise.  Returns ``(left, right)``.
    """
    left, right = [], []
    for elt in l:
        # Appending directly avoids a None placeholder, which used to make
        # elements equal to None vanish from both halves.
        if random.random() < pLeft:
            left.append(elt)
        else:
            right.append(elt)
    return left, right
def main():
    """Command-line entry point: recalibrate VCF call qualities against a ti/tv target.

    Positional arguments: the calls VCF and, optionally, a truth VCF.
    Steps: read the calls, pick the ti/tv target (``--titv`` or the ratio of
    the known calls), calibrate and recalibrate the calls (optionally on a
    random bootstrap subset, evaluating the remainder), compare against the
    truth file when one is given, and write the output VCFs if requested.

    Sets the module globals OPTIONS and header.
    """
    global OPTIONS, header

    usage = "usage: %prog files.list [options]"
    parser = OptionParser(usage=usage)
    parser.add_option("-f", "--f", dest="fields",
                      type='string', default="QUAL",
                      help="Comma-separated list of fields to exact")
    parser.add_option("-t", "--truth", dest="truth",
                      type='string', default=None,
                      help="VCF formated truth file")
    parser.add_option("", "--unFilteredTruth", dest="unFilteredTruth",
                      action='store_true', default=False,
                      help="If provided, the unfiltered truth calls will be used in comparisons")
    parser.add_option("-p", "--partitions", dest="partitions",
                      type='int', default=25,
                      help="Number of partitions to examine")
    parser.add_option("-s", "--s", dest="skip",
                      type='int', default=1,
                      help="Only work with every 1 / skip records")
    parser.add_option("-m", "--minVariantsPerBin", dest="minVariantsPerBin",
                      type='int', default=10,
                      help="")
    parser.add_option("-M", "--maxRecords", dest="maxRecords",
                      type='int', default=None,
                      help="")
    parser.add_option("-q", "--qMax", dest="maxQScore",
                      type='int', default=30,
                      help="")
    parser.add_option("-o", "--outputVCF", dest="outputVCF",
                      type='string', default=None,
                      help="If provided, VCF file will be written out to this file")
    parser.add_option("", "--FNoutputVCF", dest="FNoutputVCF",
                      type='string', default=None,
                      help="If provided, VCF file will be written out to this file")
    parser.add_option("", "--titv", dest="titvTarget",
                      type='float', default=None,
                      help="If provided, we will optimize calls to the targeted ti/tv rather than that calculated from known calls")
    parser.add_option("", "--fp", dest="printFP",
                      action='store_true', default=False,
                      help="")
    parser.add_option("-b", "--bootstrap", dest="bootStrap",
                      type='float', default=0.0,
                      help="If provided, the % of the calls used to generate the recalibration tables.")

    (OPTIONS, args) = parser.parse_args()
    # args[0] is required below, so zero positional arguments is an error too.
    if len(args) < 1 or len(args) > 2:
        parser.error("incorrect number of arguments")

    fields = OPTIONS.fields.split(',')
    header, allCalls = readVariants(args[0], OPTIONS.maxRecords)
    print('Read %s calls' % (len(allCalls),))

    if OPTIONS.titvTarget is None:
        # Default target: the ti/tv ratio (first element of titv()'s result)
        # of the known calls.
        OPTIONS.titvTarget = titv(selectVariants(allCalls, VCFRecord.isKnown))[0]

    print('%s %s' % ('Ti/Tv all ', titv(allCalls)))
    print('%s %s' % ('Ti/Tv known', titv(selectVariants(allCalls, VCFRecord.isKnown))))
    print('%s %s' % ('Ti/Tv novel', titv(selectVariants(allCalls, VCFRecord.isNovel))))

    if OPTIONS.bootStrap:
        callsToOptimize, callsToEval = randomSplit(allCalls, OPTIONS.bootStrap)
    else:
        callsToOptimize, callsToEval = allCalls, allCalls
    # Without bootstrapping both names refer to the same list.
    haveHeldOutCalls = callsToEval is not callsToOptimize

    recalOptCalls, covariates = optimizeCalls(callsToOptimize, fields, OPTIONS.titvTarget)
    printCallQuals(recalOptCalls, OPTIONS.titvTarget, 'OPTIMIZED CALLS')

    if haveHeldOutCalls:
        recalEvalCalls = recalibrateCalls(callsToEval, fields, covariates)
        printCallQuals(recalEvalCalls, OPTIONS.titvTarget, 'BOOTSTRAP EVAL CALLS')

    truth = None
    if len(args) > 1:
        truthFile = args[1]
        print('Reading truth file %s' % (truthFile,))
        rawTruth = readVariants(truthFile, maxRecords = None, decodeAll = False)[1]

        def keepVariant(t):
            return OPTIONS.unFilteredTruth or t.passesFilters()

        truth = dict([(v.getLoc(), v) for v in rawTruth if keepVariant(v)])
        print('%s %s' % (len(rawTruth), len(truth)))

        print('--------------------------------------------------------------------------------')
        print('Comparing calls to truth %s' % (truthFile,))
        print('')

        print('Calls used in optimization')
        compareCalls(recalOptCalls, truth)

        if haveHeldOutCalls:
            print('Calls held in reserve (bootstrap)')
            compareCalls(recalEvalCalls, truth)

    if OPTIONS.outputVCF:
        with open(OPTIONS.outputVCF, 'w') as f:
            for line in formatVCF(header, allCalls):
                f.write('%s\n' % (line,))

    if truth is not None and OPTIONS.FNoutputVCF:
        # False negatives: truth records never marked FOUND by the comparison.
        missed = [t for t in truth.values() if not t.hasField("FOUND")]
        with open(OPTIONS.FNoutputVCF, 'w') as f:
            for line in formatVCF(header, missed):
                f.write('%s\n' % (line,))
# Set to True to run main() under cProfile and print the top entries.
PROFILE = False

if __name__ == "__main__":
    if not PROFILE:
        main()
    else:
        import cProfile
        import pstats

        # Profile data is written to 'fooprof' in the working directory.
        cProfile.run('main()', 'fooprof')
        stats = pstats.Stats('fooprof')
        stats.sort_stats('cumulative').print_stats(10)
        stats.sort_stats('time').print_stats(10)
        stats.sort_stats('time', 'cum').print_stats(.5, 'init')