continuing improvements in output of snpSelector

git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@2198 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
depristo 2009-12-01 15:42:06 +00:00
parent 21a9a717e4
commit c93d37d9fb
1 changed file with 41 additions and 26 deletions

View File

@ -181,10 +181,13 @@ def compareFieldValues( v1, v2 ):
return c return c
def calculateBins(variants, field, minValue, maxValue, partitions): def calculateBins(variants, field, minValue, maxValue, partitions):
sortedVariants = sorted(variants, key = lambda x: x.getField(field)) # cmp = compareFieldValues, values = map(lambda x: x.getField(field), variants)
sortedValues = map(lambda x: x.getField(field), sortedVariants) return calculateBinsForValues(values, field, minValue, maxValue, partitions)
def calculateBinsForValues(values, field, minValue, maxValue, partitions):
sortedValues = sorted(values)
targetBinSize = len(variants) / (1.0*partitions) targetBinSize = len(values) / (1.0*partitions)
#print sortedValues #print sortedValues
uniqBins = groupby(sortedValues) uniqBins = groupby(sortedValues)
binsAndSizes = map(lambda x: [x[0], len(list(x[1]))], uniqBins) binsAndSizes = map(lambda x: [x[0], len(list(x[1]))], uniqBins)
@ -234,14 +237,14 @@ def validateBins(bins):
def printFieldQualHeader(): def printFieldQualHeader():
more = "" more = ""
if TRUTH_CALLS <> None: if TRUTH_CALLS <> None:
more = "TP FP FPRate FN FNRate" more = CallCmp.HEADER
def p(stream): def p(stream):
if stream <> None: if stream <> None:
print >> stream, ' field left right nVariants nNovels titv titvNovels dbSNP fprate q', more print >> stream, ' %20s %20s left right nVariants nNovels titv titvNovels dbSNP FPEstimate Q' % ("category", "field"), more
p(sys.stdout) p(sys.stdout)
p(RECAL_LOG) p(RECAL_LOG)
def printFieldQual( field, left, right, variants, FPRate ): def printFieldQual( category, field, left, right, variants, FPRate ):
more = "" more = ""
if TRUTH_CALLS <> None: if TRUTH_CALLS <> None:
callComparison, theseFPs = sensitivitySpecificity(variants, TRUTH_CALLS) callComparison, theseFPs = sensitivitySpecificity(variants, TRUTH_CALLS)
@ -249,7 +252,7 @@ def printFieldQual( field, left, right, variants, FPRate ):
novels = selectVariants(variants, VCFRecord.isNovel) novels = selectVariants(variants, VCFRecord.isNovel)
def p(stream): def p(stream):
if stream <> None: if stream <> None:
print >> stream, ' %s %s %8d %8d %.2f %.2f %.2f %.2e %d' % (field, binString(left, right), len(variants), len(novels), titv(variants), titv(novels), dbSNPRate(variants), FPRate, phredScale(FPRate)), more print >> stream, ' %20s %20s %s %8d %8d %.2f %.2f %.2f %.2e %3d' % (category, field, binString(left, right), len(variants), len(novels), titv(variants), titv(novels), dbSNPRate(variants), FPRate, phredScale(FPRate)), more
p(sys.stdout) p(sys.stdout)
p(RECAL_LOG) p(RECAL_LOG)
@ -265,7 +268,7 @@ def binString(left, right):
# #
# #
def recalibrateCalls(variants, fields, callCovariates): def recalibrateCalls(variants, fields, callCovariates):
def phred(v): return int(round(phredScale(v))) def phred(v): return round(phredScale(v), 2)
for variant in variants: for variant in variants:
recalCall = RecalibratedCall(variant, fields) recalCall = RecalibratedCall(variant, fields)
@ -277,7 +280,7 @@ def recalibrateCalls(variants, fields, callCovariates):
recalCall.recalFeature(callCovariate.getFeature(), FPR) recalCall.recalFeature(callCovariate.getFeature(), FPR)
recalCall.call.setField(callCovariate.getCovariateField(), phred(FPR)) recalCall.call.setField(callCovariate.getCovariateField(), phred(FPR))
#recalCall.call.setField('QUAL', phred(recalCall.jointFPErrorRate()))
recalCall.call.setField('QUAL', phred(recalCall.jointFPErrorRate())) recalCall.call.setField('QUAL', phred(recalCall.jointFPErrorRate()))
recalCall.call.setField('OQ', originalQual) recalCall.call.setField('OQ', originalQual)
#print 'recalibrating', variant.getLoc() #print 'recalibrating', variant.getLoc()
@ -288,16 +291,16 @@ def recalibrateCalls(variants, fields, callCovariates):
# #
# #
def optimizeCalls(variants, fields, titvTarget): def optimizeCalls(variants, fields, titvTarget):
callCovariates = calibrateFeatures(variants, fields, titvTarget) callCovariates = calibrateFeatures(variants, fields, titvTarget, category = "covariates")
recalCalls = recalibrateCalls(variants, fields, callCovariates) recalCalls = recalibrateCalls(variants, fields, callCovariates)
return recalCalls, callCovariates return recalCalls, callCovariates
def printCallQuals(recalCalls, titvTarget, info = ""): def printCallQuals(field, recalCalls, titvTarget, info = ""):
print '--------------------------------------------------------------------------------' print '--------------------------------------------------------------------------------'
print info print info
calibrateFeatures(recalCalls, ['QUAL'], titvTarget, printCall = True, cumulative = False, forcePrint = True, prefix = "OPT-", printHeader = False ) calibrateFeatures(recalCalls, [field], titvTarget, printCall = True, cumulative = False, forcePrint = True, prefix = "OPT-", printHeader = False, category = "optimized-calls" )
print 'Cumulative' print 'Cumulative'
calibrateFeatures(recalCalls, ['QUAL'], titvTarget, printCall = True, cumulative = True, forcePrint = True, prefix = "OPTCUM-", printHeader = False ) calibrateFeatures(recalCalls, [field], titvTarget, printCall = True, cumulative = True, forcePrint = True, prefix = "OPTCUM-", printHeader = False, category = "optimized-calls" )
def all( p, l ): def all( p, l ):
for elt in l: for elt in l:
@ -323,7 +326,7 @@ def mapVariantBins(variants, field, cumulative = False):
return imap( variantsInBin, bins ) return imap( variantsInBin, bins )
def calibrateFeatures(variants, fields, titvTarget, printCall = False, cumulative = False, forcePrint = False, prefix = '', printHeader = True ): def calibrateFeatures(variants, fields, titvTarget, printCall = False, cumulative = False, forcePrint = False, prefix = '', printHeader = True, category = None ):
covariates = [] covariates = []
if printHeader: printFieldQualHeader() if printHeader: printFieldQualHeader()
@ -338,7 +341,7 @@ def calibrateFeatures(variants, fields, titvTarget, printCall = False, cumulativ
titv, FPRate = titvFPRateEstimate(selectedVariants, titvTarget) titv, FPRate = titvFPRateEstimate(selectedVariants, titvTarget)
dbsnp = dbSNPRate(selectedVariants) dbsnp = dbSNPRate(selectedVariants)
covariates.append(CallCovariate(field, left, right, FPRate)) covariates.append(CallCovariate(field, left, right, FPRate))
printFieldQual( prefix + field, left, right, selectedVariants, FPRate ) printFieldQual( category, prefix + field, left, right, selectedVariants, FPRate )
else: else:
print 'Not calibrating bin', left, right, 'because it contains too few variants:', len(selectedVariants) print 'Not calibrating bin', left, right, 'because it contains too few variants:', len(selectedVariants)
@ -350,14 +353,24 @@ class CallCmp:
self.nFP = nFP self.nFP = nFP
self.nFN = nFN self.nFN = nFN
def FPRate(self): # def FPRate(self):
return (1.0*self.nFP) / max(self.nTP + self.nFP, 1) # return (1.0*self.nFP) / max(self.nTP + self.nFP, 1)
def FNRate(self): def FNRate(self):
return (1.0*self.nFN) / max(self.nTP + self.nFN, 1) return (1.0*self.nFN) / max(self.nTP + self.nFN, 1)
def sensitivity(self):
# = TP / (TP + FN)
return (1.0*self.nTP) / max( self.nTP + self.nFN,1 )
def PPV(self):
# = TP / (TP + FP)
return (1.0*self.nTP) / max( self.nTP + self.nFP, 1 )
HEADER = "TP FP FN FNRate Sensitivity PPV"
def __str__(self): def __str__(self):
return '%6d %6d %.2f %6d %.2f' % (self.nTP, self.nFP, self.FPRate(), self.nFN, self.FNRate()) return '%6d %6d %6d %.2f %.2f %.2f' % (self.nTP, self.nFP, self.nFN, self.FNRate(), self.sensitivity(), self.PPV())
def variantInTruth(variant, truth): def variantInTruth(variant, truth):
if variant.getLoc() in truth: if variant.getLoc() in truth:
@ -386,16 +399,17 @@ def compareCalls(calls, truthCalls):
for variant in calls: variant.setField("TP", 0) # set the TP field to 0 for variant in calls: variant.setField("TP", 0) # set the TP field to 0
def compare1(name, cumulative): def compare1(name, cumulative):
for left, right, selectedVariants in mapVariantBins(calls, 'QUAL', cumulative = cumulative): for field in ["QUAL", "OQ"]:
#print selectedVariants[0] for left, right, selectedVariants in mapVariantBins(calls, field, cumulative = cumulative):
printFieldQual(name, left, right, selectedVariants, dephredScale(left)) #print selectedVariants[0]
printFieldQual("truth-comparison-" + name, field, left, right, selectedVariants, dephredScale(left))
print 'PER BIN nCalls=', len(calls) print 'PER BIN nCalls=', len(calls)
# printFieldQualHeader() # printFieldQualHeader()
compare1('RECAL-QUAL-TRUTH-PER-BIN', False) compare1('per-bin', False)
print 'CUMULATIVE nCalls=', len(calls) print 'CUMULATIVE nCalls=', len(calls)
compare1('RECAL-QUAL-TRUTH-CUMULATIVE', True) compare1('cum', True)
def randomSplit(l, pLeft): def randomSplit(l, pLeft):
def keep(elt, p): def keep(elt, p):
@ -486,11 +500,11 @@ def determineCovariates(allCalls, titvtarget, fields):
callsToOptimize = allCalls callsToOptimize = allCalls
recalOptCalls, covariates = optimizeCalls(callsToOptimize, fields, titvtarget) recalOptCalls, covariates = optimizeCalls(callsToOptimize, fields, titvtarget)
printCallQuals(list(recalOptCalls), titvtarget, 'OPTIMIZED CALLS') printCallQuals("QUAL", list(recalOptCalls), titvtarget, 'OPTIMIZED CALLS')
if OPTIONS.bootStrap: if OPTIONS.bootStrap:
recalibatedEvalCalls = recalibrateCalls(recalEvalCalls, fields, covariates) recalibatedEvalCalls = recalibrateCalls(recalEvalCalls, fields, covariates)
printCallQuals(list(recalibatedEvalCalls), titvtarget, 'BOOTSTRAP EVAL CALLS') printCallQuals("QUAL", list(recalibatedEvalCalls), titvtarget, 'BOOTSTRAP EVAL CALLS')
return covariates return covariates
@ -562,7 +576,8 @@ def main():
RecalibratedCalls = recalibrateCalls(callsToRecalibate, fields, covariates) RecalibratedCalls = recalibrateCalls(callsToRecalibate, fields, covariates)
writeRecalibratedCalls(OPTIONS.outputVCF, header, RecalibratedCalls) writeRecalibratedCalls(OPTIONS.outputVCF, header, RecalibratedCalls)
else: else:
printCallQuals(allCalls, titvTarget) printFieldQualHeader()
printCallQuals("QUAL", allCalls, titvTarget)
OPTIONS.outputVCF = args[0] OPTIONS.outputVCF = args[0]
if len(args) > 1: if len(args) > 1: