cleanup of SNP selector -- ready for some additional testing
git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@2042 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
parent
8eff1cc436
commit
52494d8176
|
|
@ -63,19 +63,18 @@ class RecalibratedCall:
|
|||
def __str__(self):
|
||||
return '[%s: %s => Q%d]' % (str(self.call), self.featureStringList(), phredScale(self.jointFPErrorRate()))
|
||||
|
||||
def readVariants( file, maxRecords = None, decodeAll = True ):
|
||||
counter = OPTIONS.skip
|
||||
def readVariants( file, maxRecords = None, decodeAll = True, downsampleFraction = 1 ):
|
||||
f = open(file)
|
||||
header, ignore, lines = readVCFHeader(f)
|
||||
|
||||
def parseVariant(args):
|
||||
header1, VCF, counter = args
|
||||
if counter % OPTIONS.skip == 0:
|
||||
if random.random() <= downsampleFraction:
|
||||
return VCF
|
||||
else:
|
||||
return None
|
||||
|
||||
return header, filter(None, map(parseVariant, islice(lines2VCF(lines, extendedOutput = True, decodeAll = decodeAll), maxRecords)))
|
||||
return header, ifilter(None, imap(parseVariant, islice(lines2VCF(lines, extendedOutput = True, decodeAll = decodeAll), maxRecords)))
|
||||
|
||||
def selectVariants( variants, selector = None ):
|
||||
if selector <> None:
|
||||
|
|
@ -88,8 +87,11 @@ def titv(variants):
|
|||
tv = len(variants) - ti
|
||||
titv = ti / (1.0*max(tv,1))
|
||||
|
||||
return titv, ti, tv
|
||||
return titv
|
||||
|
||||
def dbSNPRate(variants):
|
||||
inDBSNP = len(filter(VCFRecord.isKnown, variants))
|
||||
return float(inDBSNP) / len(variants)
|
||||
|
||||
def gaussian(x, mu, sigma):
|
||||
constant = 1 / math.sqrt(2 * math.pi * sigma**2)
|
||||
|
|
@ -102,7 +104,7 @@ DEBUG = False
|
|||
# there are in N calls with ti/tv of X.
|
||||
#
|
||||
def titvFPRateEstimate(variants, target):
|
||||
titvRatio, ti, tv = titv(variants)
|
||||
titvRatio = titv(variants)
|
||||
|
||||
# f <- function(To,T) { (To - T) / (1/2 - T) + 0.001 }
|
||||
def theoreticalCalc():
|
||||
|
|
@ -112,7 +114,7 @@ def titvFPRateEstimate(variants, target):
|
|||
FPRate = (titvRatio - target) / (0.5 - target)
|
||||
FPRate = min(max(FPRate, 0), 1)
|
||||
TPRate = max(min(1 - FPRate, 1 - dephredScale(OPTIONS.maxQScore)), dephredScale(OPTIONS.maxQScore))
|
||||
print 'FPRate', FPRate, titvRatio, target
|
||||
if DEBUG: print 'FPRate', FPRate, titvRatio, target
|
||||
assert FPRate >= 0 and FPRate <= 1
|
||||
return TPRate
|
||||
|
||||
|
|
@ -150,7 +152,7 @@ def phredScale(errorRate):
|
|||
return -10 * math.log10(max(errorRate, 1e-10))
|
||||
|
||||
def dephredScale(qscore):
|
||||
return math.pow(10, qscore / -10)
|
||||
return math.pow(10, float(qscore) / -10)
|
||||
|
||||
def frange6(*args):
|
||||
"""A float range generator."""
|
||||
|
|
@ -181,7 +183,6 @@ def calculateBins(variants, field, minValue, maxValue, partitions):
|
|||
sortedValues = map(lambda x: x.getField(field), sortedVariants)
|
||||
|
||||
targetBinSize = len(variants) / (1.0*partitions)
|
||||
print 'targetBinSize', targetBinSize
|
||||
uniqBins = groupby(sortedValues)
|
||||
binsAndSizes = map(lambda x: [x[0], len(list(x[1]))], uniqBins)
|
||||
#print binsAndSizes
|
||||
|
|
@ -215,15 +216,18 @@ def fieldRange(variants, field):
|
|||
bins = calculateBins(variants, field, minValue, maxValue, OPTIONS.partitions)
|
||||
return minValue, maxValue, bins
|
||||
|
||||
def printFieldQual( left, right, variants, titv, FPRate, nErrors ):
|
||||
print ' %s nVariants=%8d titv=%.2f FPRate=%.2e Q%d' % (binString(left, right), len(variants), titv, FPRate, phredScale(FPRate))
|
||||
def printFieldQualHeader(more = ""):
|
||||
print ' field left right nvariants titv dbSNP fprate q', more
|
||||
|
||||
def printFieldQual( field, left, right, variants, FPRate, more = ""):
|
||||
print ' %s %s %8d %.2f %.2f %.2e %d' % (field, binString(left, right), len(variants), titv(variants), dbSNPRate(variants), FPRate, phredScale(FPRate)), more
|
||||
|
||||
def binString(left, right):
|
||||
leftStr = str(left)
|
||||
if type(left) == float: leftStr = "%.2f" % left
|
||||
rightStr = "%5s" % str(right)
|
||||
if type(right) == float: rightStr = "%.2f" % right
|
||||
return '%8s - %8s' % (leftStr, rightStr)
|
||||
return '%8s %8s' % (leftStr, rightStr)
|
||||
|
||||
|
||||
#
|
||||
|
|
@ -232,7 +236,6 @@ def binString(left, right):
|
|||
def recalibrateCalls(variants, fields, callCovariates):
|
||||
def phred(v): return int(round(phredScale(v)))
|
||||
|
||||
newCalls = list()
|
||||
for variant in variants:
|
||||
recalCall = RecalibratedCall(variant, fields)
|
||||
originalQual = variant.getField('QUAL')
|
||||
|
|
@ -245,9 +248,8 @@ def recalibrateCalls(variants, fields, callCovariates):
|
|||
|
||||
recalCall.call.setField('QUAL', phred(recalCall.jointFPErrorRate()))
|
||||
recalCall.call.setField('OQ', originalQual)
|
||||
newCalls.append(recalCall.call)
|
||||
|
||||
return newCalls
|
||||
#print 'recalibrating', variant.getLoc()
|
||||
yield recalCall.call
|
||||
|
||||
#
|
||||
#
|
||||
|
|
@ -258,9 +260,6 @@ def optimizeCalls(variants, fields, titvTarget):
|
|||
return recalCalls, callCovariates
|
||||
|
||||
def printCallQuals(recalCalls, titvTarget, info = ""):
|
||||
#for recalCall in islice(recalCalls, 10):
|
||||
# print recalCall
|
||||
|
||||
print '--------------------------------------------------------------------------------'
|
||||
print info
|
||||
calibrateFeatures(recalCalls, ['QUAL'], titvTarget, printCall = True, cumulative = False )
|
||||
|
|
@ -279,8 +278,8 @@ def variantBinsForField(variants, field):
|
|||
# raise Exception('Unknown field ' + field)
|
||||
|
||||
minValue, maxValue, bins = fieldRange(variants, field)
|
||||
print 'Field range', minValue, maxValue
|
||||
print 'Partitions', bins
|
||||
if DEBUG: print 'Field range', minValue, maxValue
|
||||
if DEBUG: print 'Partitions', bins
|
||||
return bins
|
||||
|
||||
def mapVariantBins(variants, field, cumulative = False):
|
||||
|
|
@ -296,18 +295,19 @@ def mapVariantBins(variants, field, cumulative = False):
|
|||
def calibrateFeatures(variants, fields, titvTarget, printCall = False, cumulative = False ):
|
||||
covariates = []
|
||||
|
||||
printFieldQualHeader()
|
||||
for field in fields:
|
||||
print 'Optimizing field', field
|
||||
if DEBUG: print 'Optimizing field', field
|
||||
|
||||
titv, FPRate, nErrors = titvFPRateEstimate(variants, titvTarget)
|
||||
print 'Overall FRRate:', FPRate, nErrors, phredScale(FPRate)
|
||||
#print 'Overall FRRate:', FPRate, nErrors, phredScale(FPRate)
|
||||
|
||||
for left, right, selectedVariants in mapVariantBins(variants, field, cumulative = cumulative):
|
||||
if len(selectedVariants) > max(OPTIONS.minVariantsPerBin,1):
|
||||
titv, FPRate, nErrors = titvFPRateEstimate(selectedVariants, titvTarget)
|
||||
dbsnp = dbSNPRate(selectedVariants)
|
||||
covariates.append(CallCovariate(field, left, right, FPRate))
|
||||
printFieldQual( left, right, selectedVariants, titv, FPRate, nErrors )
|
||||
|
||||
printFieldQual(field, left, right, selectedVariants, FPRate )
|
||||
|
||||
return covariates
|
||||
|
||||
|
|
@ -324,7 +324,7 @@ class CallCmp:
|
|||
return (1.0*self.nFN) / max(self.nTP + self.nFN, 1)
|
||||
|
||||
def __str__(self):
|
||||
return 'TP=%6d FP=%6d FPRate=%.2f FN=%6d FNRate=%.2f' % (self.nTP, self.nFP, self.FPRate(), self.nFN, self.FNRate())
|
||||
return '%6d %6d %.2f %6d %.2f' % (self.nTP, self.nFP, self.FPRate(), self.nFN, self.FNRate())
|
||||
|
||||
def variantInTruth(variant, truth):
|
||||
if variant.getLoc() in truth:
|
||||
|
|
@ -338,10 +338,10 @@ def sensitivitySpecificity(variants, truth):
|
|||
for variant in variants:
|
||||
t = variantInTruth(variant, truth)
|
||||
if t:
|
||||
t.setField("FOUND", 1)
|
||||
t.setField("FN", 0)
|
||||
variant.setField("TP", 1)
|
||||
nTP += 1
|
||||
else:
|
||||
if OPTIONS.printFP: print 'FP:', variant
|
||||
nFP += 1
|
||||
#if variant.getPos() == 1520727:
|
||||
# print "Variant is missing", variant
|
||||
|
|
@ -351,18 +351,20 @@ def sensitivitySpecificity(variants, truth):
|
|||
|
||||
|
||||
def compareCalls(calls, truthCalls):
|
||||
def compare1(cumulative):
|
||||
for variant in calls: variant.setField("TP", 0) # set the TP field to 0
|
||||
|
||||
def compare1(name, cumulative):
|
||||
for left, right, selectedVariants in mapVariantBins(calls, 'QUAL', cumulative = cumulative):
|
||||
callComparison, theseFPs = sensitivitySpecificity(selectedVariants, truthCalls)
|
||||
for fp in theseFPs: fp.setField("FP", 1)
|
||||
#FPsVariants.append(theseFPs)
|
||||
print binString(left, right), 'titv=%.2f' % titv(selectedVariants)[0], callComparison
|
||||
#print selectedVariants[0]
|
||||
printFieldQual(name, left, right, selectedVariants, dephredScale(left), str(callComparison))
|
||||
|
||||
print 'PER BIN nCalls=', len(calls)
|
||||
compare1(False)
|
||||
printFieldQualHeader("TP FP FPRate FN FNRate")
|
||||
compare1('TRUTH-PER-BIN', False)
|
||||
|
||||
print 'CUMULATIVE nCalls=', len(calls)
|
||||
compare1(True)
|
||||
compare1('TRUTH-CUMULATIVE', True)
|
||||
|
||||
def randomSplit(l, pLeft):
|
||||
def keep(elt, p):
|
||||
|
|
@ -374,113 +376,134 @@ def randomSplit(l, pLeft):
|
|||
def get(i): return filter(lambda x: x <> None, [x[i] for x in data])
|
||||
return get(0), get(1)
|
||||
|
||||
def main():
|
||||
def setup():
|
||||
global OPTIONS, header
|
||||
usage = "usage: %prog files.list [options]"
|
||||
parser = OptionParser(usage=usage)
|
||||
parser.add_option("-f", "--f", dest="fields",
|
||||
type='string', default="QUAL",
|
||||
help="Comma-separated list of fields to exact")
|
||||
help="Comma-separated list of fields (either in the VCF columns of as INFO keys) to use during optimization [default: %default]")
|
||||
parser.add_option("-t", "--truth", dest="truth",
|
||||
type='string', default=None,
|
||||
help="VCF formated truth file")
|
||||
help="VCF formated truth file. If provided, the script will compare the input calls with the truth calls. It also emits calls tagged as TP and a separate file of FP calls")
|
||||
parser.add_option("", "--unFilteredTruth", dest="unFilteredTruth",
|
||||
action='store_true', default=False,
|
||||
help="If provided, the unfiltered truth calls will be used in comparisons")
|
||||
help="If provided, the unfiltered truth calls will be used in comparisons [default: %default]")
|
||||
parser.add_option("-p", "--partitions", dest="partitions",
|
||||
type='int', default=25,
|
||||
help="Number of partitions to examine")
|
||||
parser.add_option("-s", "--s", dest="skip",
|
||||
type='int', default=1,
|
||||
help="Only work with every 1 / skip records")
|
||||
help="Number of partitions to use for each feature. Don't use so many that the number of variants per bin is very low. [default: %default]")
|
||||
parser.add_option("", "--maxRecordsForCovariates", dest="maxRecordsForCovariates",
|
||||
type='int', default=200000,
|
||||
help="Derive covariate information from up to this many VCF records. For files with more than this number of records, the system downsamples the reads [default: %default]")
|
||||
parser.add_option("-m", "--minVariantsPerBin", dest="minVariantsPerBin",
|
||||
type='int', default=10,
|
||||
help="")
|
||||
help="Don't include any covariates with fewer than this number of variants in the bin, if such a thing happens. NEEDS TO BE FIXED")
|
||||
parser.add_option("-M", "--maxRecords", dest="maxRecords",
|
||||
type='int', default=None,
|
||||
help="")
|
||||
help="Maximum number of input VCF records to process, if provided. Default is all records")
|
||||
parser.add_option("-q", "--qMax", dest="maxQScore",
|
||||
type='int', default=30,
|
||||
help="")
|
||||
help="The maximum Q score allowed for both a single covariate and the overall QUAL score [default: %default]")
|
||||
parser.add_option("-o", "--outputVCF", dest="outputVCF",
|
||||
type='string', default=None,
|
||||
help="If provided, VCF file will be written out to this file")
|
||||
help="If provided, a VCF file will be written out to this file [default: %default]")
|
||||
parser.add_option("", "--FNoutputVCF", dest="FNoutputVCF",
|
||||
type='string', default=None,
|
||||
help="If provided, VCF file will be written out to this file")
|
||||
help="If provided, VCF file will be written out to this file [default: %default]")
|
||||
parser.add_option("", "--titv", dest="titvTarget",
|
||||
type='float', default=None,
|
||||
help="If provided, we will optimize calls to the targeted ti/tv rather than that calculated from known calls")
|
||||
parser.add_option("", "--fp", dest="printFP",
|
||||
action='store_true', default=False,
|
||||
help="")
|
||||
help="If provided, we will optimize calls to the targeted ti/tv rather than that calculated from known calls [default: %default]")
|
||||
parser.add_option("-b", "--bootstrap", dest="bootStrap",
|
||||
type='float', default=0.0,
|
||||
help="If provided, the % of the calls used to generate the recalibration tables.")
|
||||
|
||||
type='float', default=None,
|
||||
help="If provided, the % of the calls used to generate the recalibration tables. [default: %default]")
|
||||
|
||||
(OPTIONS, args) = parser.parse_args()
|
||||
if len(args) > 2:
|
||||
parser.error("incorrect number of arguments")
|
||||
return args
|
||||
|
||||
fields = OPTIONS.fields.split(',')
|
||||
header, allCalls = readVariants(args[0], OPTIONS.maxRecords)
|
||||
print 'Read', len(allCalls), 'calls'
|
||||
#print 'header is', header
|
||||
def determineCovariates(file, fields):
|
||||
print 'Counting records in VCF', file
|
||||
numberOfRecords = quickCountRecords(open(file))
|
||||
if OPTIONS.maxRecords <> None and OPTIONS.maxRecords < numberOfRecords:
|
||||
numberOfRecords = OPTIONS.maxRecords
|
||||
downsampleFraction = min(float(OPTIONS.maxRecordsForCovariates) / numberOfRecords, 1)
|
||||
header, allCalls = readVariants(file, OPTIONS.maxRecords, downsampleFraction=downsampleFraction)
|
||||
allCalls = list(allCalls)
|
||||
print 'Number of VCF records', numberOfRecords, ', max number of records for covariates is', OPTIONS.maxRecordsForCovariates, 'so keeping', downsampleFraction * 100, '% of records'
|
||||
print 'Number of selected VCF records', len(allCalls)
|
||||
|
||||
if OPTIONS.titvTarget == None:
|
||||
OPTIONS.titvTarget = titv(calls, VCFRecord.isKnown)
|
||||
OPTIONS.titvTarget = titv(selectVariants(allCalls, VCFRecord.isKnown))
|
||||
print 'Ti/Tv all ', titv(allCalls)
|
||||
print 'Ti/Tv known', titv(selectVariants(allCalls, VCFRecord.isKnown))
|
||||
print 'Ti/Tv novel', titv(selectVariants(allCalls, VCFRecord.isNovel))
|
||||
|
||||
if OPTIONS.bootStrap:
|
||||
callsToOptimize, callsToEval = randomSplit(allCalls, OPTIONS.bootStrap)
|
||||
callsToOptimize, recalEvalCalls = randomSplit(allCalls, OPTIONS.bootStrap)
|
||||
else:
|
||||
callsToOptimize, callsToEval = allCalls, allCalls
|
||||
callsToOptimize = allCalls
|
||||
|
||||
recalOptCalls, covariates = optimizeCalls(callsToOptimize, fields, OPTIONS.titvTarget)
|
||||
printCallQuals(recalOptCalls, OPTIONS.titvTarget, 'OPTIMIZED CALLS')
|
||||
|
||||
if callsToEval <> callsToOptimize:
|
||||
recalEvalCalls = recalibrateCalls(callsToEval, fields, covariates)
|
||||
printCallQuals(recalEvalCalls, OPTIONS.titvTarget, 'BOOTSTRAP EVAL CALLS')
|
||||
printCallQuals(list(recalOptCalls), OPTIONS.titvTarget, 'OPTIMIZED CALLS')
|
||||
|
||||
truth = None
|
||||
if len(args) > 1:
|
||||
truthFile = args[1]
|
||||
print 'Reading truth file', truthFile
|
||||
rawTruth = readVariants(truthFile, maxRecords = None, decodeAll = False)[1]
|
||||
def keepVariant(t):
|
||||
#print t.getPos(), t.getLoc()
|
||||
return OPTIONS.unFilteredTruth or t.passesFilters()
|
||||
truth = dict( [[v.getLoc(), v] for v in filter(keepVariant, rawTruth)])
|
||||
print len(rawTruth), len(truth)
|
||||
|
||||
print '--------------------------------------------------------------------------------'
|
||||
print 'Comparing calls to truth', truthFile
|
||||
print ''
|
||||
if OPTIONS.bootStrap:
|
||||
recalibatedEvalCalls = recalibrateCalls(recalEvalCalls, fields, covariates)
|
||||
printCallQuals(list(recalibatedEvalCalls), OPTIONS.titvTarget, 'BOOTSTRAP EVAL CALLS')
|
||||
|
||||
print 'Calls used in optimization'
|
||||
compareCalls(recalOptCalls, truth)
|
||||
if callsToEval <> callsToOptimize:
|
||||
print 'Calls held in reserve (bootstrap)'
|
||||
compareCalls(recalEvalCalls, truth)
|
||||
return covariates
|
||||
|
||||
if OPTIONS.outputVCF:
|
||||
f = open(OPTIONS.outputVCF, 'w')
|
||||
def writeRecalibratedCalls(file, header, calls):
|
||||
if file:
|
||||
f = open(file, 'w')
|
||||
#print 'HEADER', header
|
||||
for line in formatVCF(header, allCalls):
|
||||
i = 0
|
||||
for line in formatVCF(header, calls):
|
||||
if i % 10000 == 0: print 'writing VCF record', i
|
||||
i += 1
|
||||
print >> f, line
|
||||
f.close()
|
||||
|
||||
def evaluateTruth(header, callVCF, truthVCF):
|
||||
print 'Reading truth file', truthVCF
|
||||
rawTruth = list(readVariants(truthVCF, maxRecords = None, decodeAll = False)[1])
|
||||
def keepVariant(t):
|
||||
#print t.getPos(), t.getLoc()
|
||||
return OPTIONS.unFilteredTruth or t.passesFilters()
|
||||
truth = dict( [[v.getLoc(), v] for v in filter(keepVariant, rawTruth)])
|
||||
print len(rawTruth), len(truth)
|
||||
|
||||
print 'Reading variants back in from', callVCF
|
||||
header, calls = readVariants(callVCF)
|
||||
calls = list(calls)
|
||||
|
||||
print '--------------------------------------------------------------------------------'
|
||||
print 'Comparing calls to truth', truthVCF
|
||||
print ''
|
||||
|
||||
compareCalls(calls, truth)
|
||||
|
||||
writeRecalibratedCalls(callVCF, header, calls)
|
||||
|
||||
if truth <> None and OPTIONS.FNoutputVCF:
|
||||
f = open(OPTIONS.FNoutputVCF, 'w')
|
||||
#print 'HEADER', header
|
||||
for line in formatVCF(header, filter( lambda x: not x.hasField("FOUND"), truth.itervalues())):
|
||||
for line in formatVCF(header, filter( lambda x: not x.hasField("TP"), truth.itervalues())):
|
||||
print >> f, line
|
||||
f.close()
|
||||
|
||||
def main():
|
||||
args = setup()
|
||||
fields = OPTIONS.fields.split(',')
|
||||
|
||||
covariates = determineCovariates(args[0], fields)
|
||||
header, callsToRecalibate = readVariants(args[0], OPTIONS.maxRecords)
|
||||
RecalibratedCalls = recalibrateCalls(callsToRecalibate, fields, covariates)
|
||||
writeRecalibratedCalls(OPTIONS.outputVCF, header, RecalibratedCalls)
|
||||
|
||||
if len(args) > 1:
|
||||
evaluateTruth(header, OPTIONS.outputVCF, args[1])
|
||||
|
||||
PROFILE = False
|
||||
if __name__ == "__main__":
|
||||
if PROFILE:
|
||||
|
|
|
|||
|
|
@ -155,6 +155,14 @@ def readVCFHeader(lines):
|
|||
# we reach this point for empty files
|
||||
return header, columnNames, []
|
||||
|
||||
def quickCountRecords(lines):
|
||||
counter = 0
|
||||
for line in lines:
|
||||
if line[0] != "#":
|
||||
counter += 1
|
||||
return counter
|
||||
|
||||
|
||||
def lines2VCF(lines, extendedOutput = False, decodeAll = True):
|
||||
header, columnNames, lines = readVCFHeader(lines)
|
||||
counter = 0
|
||||
|
|
@ -174,5 +182,5 @@ def lines2VCF(lines, extendedOutput = False, decodeAll = True):
|
|||
def formatVCF(header, records):
|
||||
#print records
|
||||
#print records[0]
|
||||
return itertools.chain(header, map(VCFRecord.format, records))
|
||||
return itertools.chain(header, itertools.imap(VCFRecord.format, records))
|
||||
|
||||
|
|
|
|||
Loading…
Reference in New Issue