2009-11-07 07:00:46 +08:00
import os . path
import sys
from optparse import OptionParser
from vcfReader import *
#import pylab
2009-12-15 22:38:39 +08:00
import operator
2009-11-07 07:00:46 +08:00
from itertools import *
import math
2009-11-12 06:02:57 +08:00
import random
2009-11-21 22:48:29 +08:00
# Module-wide switch: when True, the `if DEBUG:` guards in this file print extra diagnostics.
DEBUG = False
2009-11-07 07:00:46 +08:00
2009-12-15 22:38:39 +08:00
class Range:
    """An interval from `left` to `right` whose ends may be open, closed or unbounded.

    Passing Range.ANY as a bound leaves that side unbounded.  A closed end
    accepts a value equal to the bound; an open end does not.
    """
    ANY = ' * '

    def __init__(self, left=ANY, right=ANY, leftOpen=False, rightOpen=False):
        self.left, self.right = left, right
        self.leftOpen, self.rightOpen = leftOpen, rightOpen

    def __str__(self):
        # Round brackets mark an open end, square brackets a closed one.
        opening = ' ( ' if self.leftOpen else ' [ '
        closing = ' ) ' if self.rightOpen else ' ] '
        return ' %s %s , %s %s ' % (opening, self.left, self.right, closing)

    __repr__ = __str__

    def dashedString(self):
        """Same text as str(self) with the separating comma replaced by a dash."""
        return self.__str__().replace(" , ", " - ")

    def contains(self, v):
        """Return whether `v` lies inside this range."""
        def withinBound(bound, strictlyInside, boundIsOpen):
            if bound == Range.ANY:
                return True
            if strictlyInside(v, bound):
                return True
            # Sitting exactly on the bound only counts for a closed end.
            return (not boundIsOpen) and v == bound

        if not withinBound(self.left, operator.gt, self.leftOpen):
            return False
        return withinBound(self.right, operator.lt, self.rightOpen)
2009-11-10 06:48:51 +08:00
2009-12-15 22:38:39 +08:00
class CallCovariate:
    """One calibration bin: a feature interval, a quality interval and a false-positive rate.

    feature      -- name of the field read from a call through call.getField()
    featureRange -- object whose contains(value) is applied to that field
    qualRange    -- object whose contains(value) is applied to call.getQual()
    FPRate       -- rate attached to the bin, None until one is assigned
    """

    def __init__(self, feature, featureRange, qualRange, FPRate=None):
        self.feature, self.featureRange = feature, featureRange
        self.qualRange, self.FPRate = qualRange, FPRate

    def containsVariant(self, call):
        """Tell whether the call's feature value and its quality both fall in this bin."""
        # Both lookups are always performed, even when the first one already fails.
        featureHit = self.featureRange.contains(call.getField(self.feature))
        qualHit = self.qualRange.contains(call.getQual())
        if not featureHit:
            return featureHit
        return qualHit

    def getFPRate(self):
        return self.FPRate

    def getFeature(self):
        return self.feature

    def getCovariateField(self):
        """Field name built from the feature name plus the ' _RQ ' suffix."""
        suffix = ' _RQ '
        return self.getFeature() + suffix

    def __str__(self):
        parts = (self.feature, self.featureRange, self.qualRange)
        return " [CC feature= %s range= %s qualRange= %s ] " % parts
2009-11-07 07:00:46 +08:00
class RecalibratedCall:
    """A call together with the per-feature false-positive rates assigned to it.

    call     -- the underlying call object (only str() is applied to it here)
    features -- iterable of feature names; each starts out unassigned (None)
    """

    def __init__(self, call, features):
        self.call = call
        self.features = dict.fromkeys(features)

    def recalFeature(self, feature, FPRate):
        """Record `FPRate` (a value in [0, 1]) for `feature`; a feature may be set only once."""
        assert self.features[feature] is None, " Feature " + feature + ' has value ' + str(self.features[feature]) + ' for call ' + str(self.call)  # not reassigning values
        assert 0 <= FPRate <= 1
        self.features[feature] = FPRate

    def getFeature(self, feature, missingValue=None, phredScaleValue=False):
        """Return the rate stored for `feature`.

        An unassigned feature yields `missingValue`.  With `phredScaleValue`
        set, an assigned rate is passed through phredScale() first.
        """
        v = self.features[feature]
        if v is None:
            return missingValue
        if phredScaleValue:
            return phredScale(v)
        return v

    def jointFPErrorRate(self):
        """Combine the assigned rates as 1 - prod(1 - rate); unassigned features are skipped."""
        rates = [r for r in self.features.values() if r is not None]
        # A rate of exactly 1 makes the product of true-positive rates zero;
        # log10(0) is undefined, so answer directly instead of raising.
        if any(r >= 1 for r in rates):
            return 1.0
        logJointTPRate = sum([math.log10(1 - r) for r in rates], 0)
        # approximation from het of 0.001
        # NOTE(review): each log term is <= 0 for rates in [0, 1], so this positive
        # cap looks like it can never bind -- confirm the intended bound.
        logJointTPRate = min(logJointTPRate, 1e-3 / 3)
        jointTPRate = math.pow(10, logJointTPRate)
        return 1 - jointTPRate

    def featureStringList(self):
        """Describe every feature with its phred-scaled rate, or with ' * ' when unassigned."""
        described = []
        for feature in self.features:
            q = self.getFeature(feature, None, True)
            if q is None:
                # The placeholder is text, so it cannot go through %d.
                described.append(' %s =Q %s ' % (feature, ' * '))
            else:
                described.append(' %s =Q %d ' % (feature, q))
        return ' , '.join(described)

    def __str__(self):
        return ' [ %s : %s => Q %d ] ' % (str(self.call), self.featureStringList(), phredScale(self.jointFPErrorRate()))
2009-12-14 01:59:32 +08:00
def readVariants(file, maxRecords=None, decodeAll=True, downsampleFraction=1, filter=None, minQScore=-1, mustBeVariant=False):
    """Open a VCF file and return the tuple (header, variants).

    `variants` is a lazy itertools pipeline (islice -> imap -> ifilter): records
    are parsed, filtered and down-sampled only as the caller consumes it.

    file               -- path handed to open()
    maxRecords         -- stop value given to islice(); None puts no limit on it
    decodeAll          -- forwarded to lines2VCF()
    downsampleFraction -- a surviving record is kept when random.random() <= this value
    filter             -- when true, records failing passesFilters() are dropped;
                          None means "not OPTIONS.unfiltered"
    minQScore          -- records with getQual() <= minQScore are dropped
    mustBeVariant      -- currently ignored (its test is disabled by `False and`)
    """
    if filter == None:
        filter = not OPTIONS.unfiltered
    # NOTE(review): this handle is never closed in this function; the returned
    # iterator presumably still reads from it -- confirm against vcfReader.
    f = open(file)
    header, columnNames, lines = readVCFHeader(f)

    # NOTE(review): nothing increments this counter (the increment inside
    # parseVariant is commented out), so the report below never prints.
    nLowQual = 0
    def parseVariant(args):
        """Return the VCF record if it survives every filter, otherwise None."""
        # NOTE(review): `global` names a module-level nLowQual, not the local above.
        global nLowQual
        header1, VCF, counter = args
        if filter and not VCF.passesFilters() or (False and mustBeVariant == True and not VCF.isVariant()): # currently ignore mustBeVariant
            #print 'filtering', VCF
            return None
        elif VCF.getQual() <= minQScore:
            #print 'filtering', VCF
            #nLowQual += 1
            return None
        elif random.random() <= downsampleFraction:
            # random draw happens only for records that passed the filters above
            return VCF
        else:
            return None
    # ifilter(None, ...) drops every falsy result, i.e. the None returned for rejected records.
    variants = ifilter(None, imap(parseVariant, islice(lines2VCF(lines, header=header, columnNames=columnNames, extendedOutput=True, decodeAll=decodeAll), maxRecords)))
    if nLowQual > 0:
        print ' %d snps filtered due to QUAL < %d ' % (nLowQual, minQScore)
    return header, variants
2009-11-07 07:00:46 +08:00
def selectVariants(variants, selector=None):
    """Return filter(selector, variants), or `variants` itself when no selector is given."""
    if selector is None:
        return variants
    return filter(selector, variants)
def titv(variants):
    """Ratio of transitions to all other records in `variants`.

    Records accepted by VCFRecord.isTransition are the numerator; every
    remaining record counts toward the denominator, which is floored at 1
    so an all-transition (or empty) input does not divide by zero.
    """
    transitions = [v for v in variants if VCFRecord.isTransition(v)]
    nTransitions = len(transitions)
    nOthers = len(variants) - nTransitions
    return nTransitions / (1.0 * max(nOthers, 1))
2009-11-07 07:00:46 +08:00
2009-11-14 05:46:31 +08:00
def dbSNPRate(variants):
    """Fraction of `variants` accepted by VCFRecord.isKnown (0.0 for an empty input)."""
    known = [v for v in variants if VCFRecord.isKnown(v)]
    total = max(len(variants), 1)  # floor at 1 so an empty input does not divide by zero
    return float(len(known)) / total
2009-11-07 07:00:46 +08:00
def gaussian(x, mu, sigma):
    """Density of the normal distribution N(mu, sigma**2) evaluated at `x`.

    Raises ZeroDivisionError when `sigma` is 0.
    """
    variance = sigma ** 2
    constant = 1.0 / math.sqrt(2 * math.pi * variance)
    # Float literals keep the division true division even when every argument
    # is an int (Python 2 would otherwise floor the exponent).
    exponent = -1.0 * (x - mu) ** 2 / (2.0 * variance)
    return constant * math.exp(exponent)
def titvFPRateEstimate(variants, target):
    """Estimate a false-positive rate for `variants` from their ti/tv ratio.

    If the target ratio is T and false-positive calls have ti/tv = 0.5, we
    want to know how many false positives there are among N calls whose
    ti/tv is X.

    Returns the tuple (observed ti/tv ratio, estimated false-positive rate).
    """
    titvRatio = titv(variants)

    # f <- function(To,T) { (To - T) / (1/2 - T) + 0.001 }
    def theoreticalCalc():
        FPRate = 0 if titvRatio >= target else (titvRatio - target) / (0.5 - target)
        FPRate = min(max(FPRate, 0), 1)
        # Clamp the true-positive rate between dephredScale(maxQScore) and
        # 1 - dephredScale(maxQScore).
        TPRate = min(1 - FPRate, 1 - dephredScale(OPTIONS.maxQScore))
        TPRate = max(TPRate, dephredScale(OPTIONS.maxQScore))
        if DEBUG:
            print(' '.join(str(item) for item in (' FPRate ', FPRate, titvRatio, target)))
        assert 0 <= FPRate <= 1
        return TPRate

    # gaussian model -- alternative estimator, not called at present
    def gaussianModel():
        LEFT_HANDED = True
        sigma = 1  # old value is 5
        constant = 1 / math.sqrt(2 * math.pi * sigma ** 2)
        exponent = -1 * (titvRatio - target) ** 2 / (2 * sigma ** 2)
        TPRate = gaussian(titvRatio, target, sigma) / gaussian(target, target, sigma)
        if LEFT_HANDED and titvRatio >= target:
            TPRate = 1
        TPRate -= dephredScale(OPTIONS.maxQScore)
        if DEBUG:
            print(' '.join(str(item) for item in (' TPRate ', TPRate, constant, exponent, dephredScale(OPTIONS.maxQScore))))
        return TPRate

    FPRate = 1 - theoreticalCalc()
    nVariants = len(variants)
    if DEBUG:
        print(' '.join(str(item) for item in (' ::: ', nVariants, titvRatio, target, FPRate)))

    return titvRatio, FPRate
2009-11-07 07:00:46 +08:00
def phredScale(errorRate):
    """Convert an error rate to -10 * log10(rate); rates below 1e-10 are treated as 1e-10."""
    floored = errorRate
    if 1e-10 > errorRate:
        floored = 1e-10
    return -10 * math.log10(floored)
def dephredScale(qscore):
    """Convert a phred-style score back to a rate: 10 ** (qscore / -10)."""
    exponent = float(qscore) / -10
    return math.pow(10, exponent)
2009-11-07 07:00:46 +08:00
def frange6(*args):
    """A float range generator.

    Accepts (end), (start, end) or (start, end, step); start defaults to 0.0
    and step to 1.0.  Values are produced by repeated addition of `step` and
    stop before reaching `end`.

    Because this is a generator, the argument errors below surface on the
    first iteration rather than at call time:
    ValueError -- a three-argument call with step == 0
    TypeError  -- called with anything other than 1-3 arguments
    """
    start = 0.0
    step = 1.0

    l = len(args)
    if l == 1:
        end = args[0]
    elif l == 2:
        start, end = args
    elif l == 3:
        start, end, step = args
        if step == 0.0:
            raise ValueError(" step must not be zero ")
    else:
        raise TypeError(" frange expects 1-3 arguments, got %d " % l)

    v = start
    while True:
        if (step > 0 and v >= end) or (step < 0 and v <= end):
            # A plain return ends the generator; raising StopIteration here
            # turns into RuntimeError on interpreters implementing PEP 479.
            return
        yield v
        v += step
2009-11-21 22:48:29 +08:00
def compareFieldValues(v1, v2):
    """cmp()-style comparison that orders by type when the two values' types differ."""
    firstType, secondType = type(v1), type(v2)
    if firstType != secondType:
        return cmp(firstType, secondType)
    return cmp(v1, v2)
2009-11-12 06:02:57 +08:00
def calculateBins(variants, field, minValue, maxValue, partitions):
    """Read `field` from every variant and bin the values with calculateBinsForValues()."""
    fieldValues = [variant.getField(field) for variant in variants]
    return calculateBinsForValues(fieldValues, field, minValue, maxValue, partitions)
def calculateBinsForValues(values, field, minValue, maxValue, partitions):
    """Group `values` into bins of roughly len(values) / partitions entries.

    Returns a list of [left, right, count] lists in sorted order; an empty
    input yields an empty list.  Equal values always share a bin.  A new bin
    is started when adding the next run of equal values would push the
    current bin past the target size, or when the current bin starts at a
    non-number and the next value is a number.

    `minValue` and `maxValue` are accepted but not used here.
    Raises ZeroDivisionError when `partitions` is 0.
    """
    sortedValues = sorted(values)
    captureFieldRangeForPrinting(field, sortedValues)
    if not sortedValues:
        # No values means no bins; indexing the first bin would fail.
        return []

    targetBinSize = len(sortedValues) / (1.0 * partitions)

    bins = []
    for value, run in groupby(sortedValues):
        runSize = len(list(run))
        if not bins:
            bins.append([value, value, runSize])
            continue
        current = bins[-1]
        if runSize + current[2] > targetBinSize or (not isNumber(current[0]) and isNumber(value)):
            bins.append([value, value, runSize])
        else:
            # Absorb this run: extend the right edge and the count.
            current[1] = value
            current[2] += runSize
    return bins
def fieldRange(variants, field):
    """Return (smallest value, largest value, validated bins) for `field` over `variants`.

    The number of bins requested comes from OPTIONS.partitions.
    """
    observed = [v.getField(field) for v in variants]
    lowest, highest = min(observed), max(observed)
    bins = calculateBins(variants, field, lowest, highest, OPTIONS.partitions)
    validateBins(bins)
    return lowest, highest, bins
2009-11-07 07:00:46 +08:00
2009-11-21 22:48:29 +08:00
def validateBins(bins):
    """Raise Exception if an edge of one bin lies strictly inside another bin.

    `bins` is an iterable of (left, right, count) triples.  Pairs sharing a
    left or a right edge are skipped, as are pairs in which any edge is None.
    """
    for left1, right1, count1 in bins:
        for left2, right2, count2 in bins:
            if left1 == left2 or right1 == right2:
                continue
            if None in [left1, left2, right1, right2]:
                continue  # we're ok
            # Test both edges of the first bin against the second bin's
            # interior (the right-edge test used to look at right2, which can
            # never be strictly inside its own bin).
            if left2 < left1 < right2 or left2 < right1 < right2:
                raise Exception(" Bad bins ", left1, right1, left2, right2)
2009-11-30 06:20:40 +08:00
def printFieldQualHeader():
    """Write the column-header line of the per-bin table to stdout and to RECAL_LOG.

    CallCmp.HEADER is appended when TRUTH_CALLS is set; a stream that is
    None is skipped.
    """
    extra = " " if TRUTH_CALLS is None else CallCmp.HEADER
    for stream in (sys.stdout, RECAL_LOG):
        if stream is None:
            continue
        columns = ' %20s %20s left right %15s nVariants nNovels titv titvNovels dbSNP Q ' % (" category ", " field ", " qRange ")
        stream.write('%s %s\n' % (columns, extra))
2009-11-14 05:46:31 +08:00
2009-12-15 22:38:39 +08:00
def printFieldQual(category, field, cc, variants):
    """Write one table row describing bin `cc` to stdout and to RECAL_LOG.

    The row holds the bin edges, the quality range, the counts of all and
    novel variants, their ti/tv ratios, the dbSNP rate as a percentage and
    the phred-scaled cc.FPRate.  When TRUTH_CALLS is set, the comparison
    against it is appended.  A stream that is None is skipped.
    """
    extra = " "
    if TRUTH_CALLS is not None:
        callComparison, theseFPs = sensitivitySpecificity(variants, TRUTH_CALLS)
        extra = str(callComparison)
    novels = selectVariants(variants, VCFRecord.isNovel)

    for stream in (sys.stdout, RECAL_LOG):
        if stream is None:
            continue
        # The row is rebuilt for every stream that is actually written.
        row = ' %20s %20s %15s %15s %8d %8d %2.2f %2.2f %3.2f %3d ' % (
            category, field, binString(field, cc.featureRange), cc.qualRange.dashedString(),
            len(variants), len(novels), titv(variants), titv(novels),
            dbSNPRate(variants) * 100, phredScale(cc.FPRate))
        stream.write('%s %s\n' % (row, extra))
2009-11-08 03:32:12 +08:00
2009-12-04 07:52:35 +08:00
# Smallest numeric value seen per field, remembered for binString().
FIELD_RANGES = {}

def captureFieldRangeForPrinting(field, sortedValues):
    """Remember the first numeric entry of `sortedValues` under FIELD_RANGES[field].

    Nothing is stored when the values contain no number.
    """
    for value in sortedValues:
        if isNumber(value):
            FIELD_RANGES[field] = value
            break
try:
    _NUMBER_TYPES = (int, long, float)  # interpreters with a separate long type
except NameError:
    _NUMBER_TYPES = (int, float)  # no `long` builtin: int already covers it

def isNumber(x):
    """Return True when `x` is an int, a long (where that type exists) or a float."""
    return isinstance(x, _NUMBER_TYPES)
2009-12-15 22:38:39 +08:00
def binString(field, cc):
    """Format the left/right edges of range `cc` as two right-aligned columns.

    Numeric edges are printed with four decimals.  With
    OPTIONS.plottableNones set, non-numeric edges are replaced by values just
    below the smallest number recorded for `field` (or just below a numeric
    right edge) so that the row can still be plotted.
    """
    epsilon = 1e-2
    left, right = cc.left, cc.right
    leftStr, rightStr = str(left), " %5s " % str(right)

    if OPTIONS.plottableNones:
        if not isNumber(left) and not isNumber(right) and field in FIELD_RANGES:
            left = right = FIELD_RANGES[field] - epsilon
        if not isNumber(left) and isNumber(right):
            left = right - epsilon

    if isNumber(left):
        leftStr = " %.4f " % left
    if isNumber(right):
        rightStr = " %.4f " % right
    return ' %12s %12s ' % (leftStr, rightStr)
2009-11-10 06:48:51 +08:00
2009-11-07 07:00:46 +08:00
#
#
#
2009-11-10 06:48:51 +08:00
def recalibrateCalls(variants, fields, callCovariates):
    """Lazily yield each variant after rewriting its quality from the matching covariates.

    For every covariate whose bin contains the variant, the covariate's
    false-positive rate is recorded and written, phred-scaled, to the
    covariate's own field.  The joint rate then replaces ' QUAL ' and the
    previous quality is stored under ' OQ '.
    """
    for variant in variants:
        recalCall = RecalibratedCall(variant, fields)
        originalQual = variant.getField(' QUAL ')
        for callCovariate in callCovariates:
            if not callCovariate.containsVariant(variant):
                continue
            FPR = callCovariate.getFPRate()
            recalCall.recalFeature(callCovariate.getFeature(), FPR)
            recalCall.call.setField(callCovariate.getCovariateField(), phredScale(FPR))

        call = recalCall.call
        call.setField(' QUAL ', phredScale(recalCall.jointFPErrorRate()))
        call.setField(' OQ ', originalQual)
        yield call
2009-11-10 06:48:51 +08:00
#
#
#
def optimizeCalls(variants, fields, titvTarget):
    """Derive covariates from `variants` and return (recalibrated-call generator, covariates)."""
    callCovariates = calibrateFeatures(variants, fields, titvTarget,
                                       category=" covariates ", useBreaks=True)
    return recalibrateCalls(variants, fields, callCovariates), callCovariates
2009-12-01 23:42:06 +08:00
def printCallQuals(field, recalCalls, titvTarget, info=" "):
    """Print the per-bin table for `field` twice: once per bin, once cumulatively."""
    print(' -------------------------------------------------------------------------------- ')
    print(info)
    # Everything except the cumulative flag and the row prefix is shared.
    shared = dict(printCall=True, forcePrint=True, printHeader=False, category=" optimized-calls ")
    calibrateFeatures(recalCalls, [field], titvTarget, cumulative=False, prefix=" OPT- ", **shared)
    print(' Cumulative ')
    calibrateFeatures(recalCalls, [field], titvTarget, cumulative=True, prefix=" OPTCUM- ", **shared)
2009-11-07 07:00:46 +08:00
def all(p, l):
    """Return True when predicate `p` holds for every element of `l` (True for an empty `l`).

    Evaluation stops at the first element that fails.
    """
    return not any(not p(elt) for elt in l)
2009-11-08 03:32:12 +08:00
2009-12-15 22:38:39 +08:00
def mapVariantBins(variants, field, cumulative=False, breakQuals=[Range()]):
    """Iterate over (CallCovariate, selected variants) for every feature-bin / quality-range pair.

    The feature bins come from fieldRange(); each is combined with every
    entry of `breakQuals`.  With `cumulative` set, the right edge of each
    feature range is left unbounded.  The pairs are produced lazily.
    """
    minValue, maxValue, featureBins = fieldRange(variants, field)
    binPairs = []
    for featureBin in featureBins:
        for qualRange in breakQuals:
            binPairs.append((featureBin, qualRange))

    def variantsInBin(featureBin, qualRange):
        upper = featureBin[1]
        if cumulative:
            upper = Range.ANY
        cc = CallCovariate(field, Range(featureBin[0], upper), qualRange)
        return cc, selectVariants(variants, cc.containsVariant)

    return starmap(variantsInBin, binPairs)
def qBreaksRanges(qBreaks, useBreaks):
    """Turn a break-point string into consecutive right-open Range objects.

    Without break points, or with `useBreaks` false, a single unbounded
    Range is returned.  Otherwise the first range is unbounded on the left
    and the last one unbounded on the right.
    """
    if qBreaks is None or not useBreaks:
        return [Range()]  # include everything in a single range
    breaks = [float(piece) for piece in qBreaks.split(' , ')]
    lowerEdges = [Range.ANY] + breaks
    upperEdges = breaks + [Range.ANY]
    return [Range(lower, upper, rightOpen=True) for lower, upper in zip(lowerEdges, upperEdges)]
2009-11-08 03:32:12 +08:00
2009-12-15 22:38:39 +08:00
def calibrateFeatures(variants, fields, titvTarget, printCall=False, cumulative=False, forcePrint=False, prefix=' ', printHeader=True, category=None, useBreaks=False):
    """Estimate a false-positive rate for every bin of every field.

    Each bin holding more than max(OPTIONS.minVariantsPerBin, 1) variants
    (or every bin when `forcePrint` is set) gets its FPRate assigned, is
    printed through printFieldQual() and is added to the returned list.
    Smaller bins are reported and skipped.  `printCall` is accepted but not
    used here.
    """
    covariates = []
    if printHeader:
        printFieldQualHeader()

    for field in fields:
        if DEBUG:
            print('%s %s' % (' Optimizing field ', field))

        # Overall estimate for the field; the result is not used further down.
        overallTiTv, FPRate = titvFPRateEstimate(variants, titvTarget)

        breakQuals = qBreaksRanges(OPTIONS.QBreaks, useBreaks and field != ' QUAL ')
        for cc, selectedVariants in mapVariantBins(variants, field, cumulative=cumulative, breakQuals=breakQuals):
            binSize = len(selectedVariants)
            if not (binSize > max(OPTIONS.minVariantsPerBin, 1) or forcePrint):
                print('%s %s %s %s' % (' Not calibrating bin ', cc, ' because it contains too few variants: ', binSize))
                continue
            binTiTv, FPRate = titvFPRateEstimate(selectedVariants, titvTarget)
            cc.FPRate = FPRate
            covariates.append(cc)
            printFieldQual(category, prefix + field, cc, selectedVariants)

    return covariates
2009-11-07 07:00:46 +08:00
2009-11-08 03:32:12 +08:00
class CallCmp:
    """Counts of true positives, false positives and false negatives with derived rates.

    Every rate floors its denominator at 1, so all-zero counts give 0.0.
    """
    HEADER = " TP FP FN FNRate Sensitivity PPV "

    def __init__(self, nTP, nFP, nFN):
        self.nTP, self.nFP, self.nFN = nTP, nFP, nFN

    def FNRate(self):
        """FN / (TP + FN)"""
        positives = max(self.nTP + self.nFN, 1)
        return (1.0 * self.nFN) / positives

    def sensitivity(self):
        """TP / (TP + FN)"""
        positives = max(self.nTP + self.nFN, 1)
        return (1.0 * self.nTP) / positives

    def PPV(self):
        """TP / (TP + FP)"""
        called = max(self.nTP + self.nFP, 1)
        return (1.0 * self.nTP) / called

    def __str__(self):
        fields = (self.nTP, self.nFP, self.nFN, self.FNRate(), self.sensitivity(), self.PPV())
        return ' %6d %6d %6d %.3f %.3f %.3f ' % fields
2009-11-08 03:32:12 +08:00
def variantInTruth(variant, truth):
    """Return the truth entry stored at the variant's location, or False when there is none."""
    if variant.getLoc() not in truth:
        return False
    return truth[variant.getLoc()]
2009-11-08 03:32:12 +08:00
2009-12-02 22:36:03 +08:00
def isVariantInSample(t, sample):
    """Return True unless record `t` stores the value " 0/0 " in field `sample`."""
    genotype = t.getField(sample)
    return genotype != " 0/0 "
def variantsInTruth(truth):
    """Count the records in `truth` that isVariantInSample() accepts for OPTIONS.useSample."""
    # fixme
    return sum(1 for record in truth if isVariantInSample(record, OPTIONS.useSample))
2009-11-08 03:32:12 +08:00
def sensitivitySpecificity(variants, truth):
    """Compare `variants` against `truth` and return (CallCmp, list of false-positive variants).

    variants -- iterable of call records (getLoc() and setField() are used)
    truth    -- dict of truth records keyed by location

    Side effects: every true positive gets " TP " = 1 and its truth record
    gets " FN " = 0; every false positive gets " TP " = 0.
    """
    nTP, nFP = 0, 0
    FPs = []
    for variant in variants:
        # t is the truth record at this location, or False when there is none
        t = variantInTruth(variant, truth)
        isTP, isFP = False, False
        if OPTIONS.useSample or OPTIONS.onlyAtTruth:
            # In these modes a call without a truth site is neither TP nor FP.
            if t: # we have a site
                # TP when the sample genotype and the record's variant status agree
                isTP = (isVariantInSample(t, OPTIONS.useSample) and t.isVariant()) or (not isVariantInSample(t, OPTIONS.useSample) and not t.isVariant())
                isFP = not isTP
        else:
            # Plain mode: present in truth => TP, absent => FP
            isTP = t
            isFP = not t
        #if variant.getLoc() == "1:867694":
        #    print variant, 'T: [', t, '] isTP, isFP', isTP, isFP
        if isTP:
            t.setField(" FN ", 0)
            variant.setField(" TP ", 1)
            nTP += 1
        elif isFP:
            nFP += 1
            variant.setField(" TP ", 0)
            #print t, variant, "is a FP!"
            FPs.append(variant)
    # truth records that are not variants
    nRef = len(filter(lambda x: not x.isVariant(), truth.itervalues()))
    # NOTE(review): nFN is (truth records variant in the sample) - nTP - nRef;
    # confirm that subtracting nRef is intended in every mode.
    nFN = variantsInTruth(truth.itervalues()) - nTP - nRef
    #print 'nRef', nTP, nFP, nFN, nRef
    return CallCmp(nTP, nFP, nFN), FPs
2009-11-08 03:32:12 +08:00
2009-12-02 22:36:03 +08:00
def markTruth(calls):
    """Set the " TP " field of every record in the dict `calls` to 0, unless OPTIONS.useSample is set."""
    if OPTIONS.useSample:
        return
    for variant in calls.values():
        variant.setField(" TP ", 0)
2009-11-10 06:48:51 +08:00
def compareCalls(calls, truthCalls):
    """Print per-bin and cumulative tables for the " QUAL " and " OQ " fields of `calls`.

    `truthCalls` is accepted but not used here; printFieldQual() reads the
    module-level TRUTH_CALLS instead.
    """
    def compare1(name, cumulative):
        category = " truth-comparison- " + name
        for field in (" QUAL ", " OQ "):
            for cc, selectedVariants in mapVariantBins(calls, field, cumulative=cumulative):
                printFieldQual(category, field, cc, selectedVariants)

    for banner, name, cumulative in ((' PER BIN nCalls= ', ' per-bin ', False),
                                     (' CUMULATIVE nCalls= ', ' cum ', True)):
        print('%s %s' % (banner, len(calls)))
        compare1(name, cumulative)
2009-11-10 06:48:51 +08:00
def randomSplit(l, pLeft):
    """Split `l` at random into (left, right), preserving order within each part.

    One random.random() draw is made per element, in order; a draw below
    `pLeft` sends the element left, otherwise right.  Elements equal to None
    are dropped from both parts.
    """
    draws = [random.random() for elt in l]
    left, right = [], []
    for elt, draw in zip(l, draws):
        side = left if draw < pLeft else right
        if elt != None:
            side.append(elt)
    return left, right
2009-11-08 03:32:12 +08:00
2009-11-14 05:46:31 +08:00
def setup ( ) :
2009-11-12 06:02:57 +08:00
global OPTIONS , header
2009-11-07 07:00:46 +08:00
usage = " usage: % prog files.list [options] "
parser = OptionParser ( usage = usage )
parser . add_option ( " -f " , " --f " , dest = " fields " ,
type = ' string ' , default = " QUAL " ,
2009-11-14 05:46:31 +08:00
help = " Comma-separated list of fields (either in the VCF columns of as INFO keys) to use during optimization [default: %d efault] " )
2009-11-07 07:00:46 +08:00
parser . add_option ( " -t " , " --truth " , dest = " truth " ,
type = ' string ' , default = None ,
2009-11-14 05:46:31 +08:00
help = " VCF formated truth file. If provided, the script will compare the input calls with the truth calls. It also emits calls tagged as TP and a separate file of FP calls " )
2009-11-30 06:20:40 +08:00
parser . add_option ( " -l " , " --recalLog " , dest = " recalLog " ,
type = ' string ' , default = " recal.log " ,
help = " VCF formated truth file. If provided, the script will compare the input calls with the truth calls. It also emits calls tagged as TP and a separate file of FP calls " )
2009-12-07 11:37:14 +08:00
parser . add_option ( " -u " , " --unfiltered " , dest = " unfiltered " ,
2009-11-12 06:02:57 +08:00
action = ' store_true ' , default = False ,
2009-12-07 11:37:14 +08:00
help = " If provided, unfiltered calls will be used in comparisons [default: %d efault] " )
2009-12-04 07:52:35 +08:00
parser . add_option ( " " , " --plottable " , dest = " plottableNones " ,
action = ' store_true ' , default = False ,
help = " If provided, will generate fake plottable points for annotations with None values -- doesn ' t effect the behavior of the system just makes it easy to plot outputs [default: %d efault] " )
2009-12-14 01:59:32 +08:00
parser . add_option ( " " , " --onlyAtTruth " , dest = " onlyAtTruth " ,
action = ' store_true ' , default = False ,
help = " If provided, we only consider TP/FP/FN at truth sites[default: %d efault] " )
2009-11-07 07:00:46 +08:00
parser . add_option ( " -p " , " --partitions " , dest = " partitions " ,
2009-11-08 03:32:12 +08:00
type = ' int ' , default = 25 ,
2009-11-14 05:46:31 +08:00
help = " Number of partitions to use for each feature. Don ' t use so many that the number of variants per bin is very low. [default: %d efault] " )
parser . add_option ( " " , " --maxRecordsForCovariates " , dest = " maxRecordsForCovariates " ,
2009-12-04 07:52:35 +08:00
type = ' int ' , default = 2000000 ,
2009-11-14 05:46:31 +08:00
help = " Derive covariate information from up to this many VCF records. For files with more than this number of records, the system downsamples the reads [default: %d efault] " )
2009-11-07 07:00:46 +08:00
parser . add_option ( " -m " , " --minVariantsPerBin " , dest = " minVariantsPerBin " ,
type = ' int ' , default = 10 ,
2009-11-14 05:46:31 +08:00
help = " Don ' t include any covariates with fewer than this number of variants in the bin, if such a thing happens. NEEDS TO BE FIXED " )
2009-11-12 06:02:57 +08:00
parser . add_option ( " -M " , " --maxRecords " , dest = " maxRecords " ,
type = ' int ' , default = None ,
2009-11-14 05:46:31 +08:00
help = " Maximum number of input VCF records to process, if provided. Default is all records " )
2009-12-14 01:59:32 +08:00
parser . add_option ( " -Q " , " --qMin " , dest = " minQScore " ,
type = ' int ' , default = - 1 ,
help = " The minimum Q score of the initial SNP list to consider for selection [default: %d efault] " )
2009-12-15 22:38:39 +08:00
parser . add_option ( " " , " --QBreaks " , dest = " QBreaks " ,
type = ' string ' , default = None ,
help = " Breaks in QUAL for generating covarites [default: %d efault] " )
2009-11-07 07:00:46 +08:00
parser . add_option ( " -q " , " --qMax " , dest = " maxQScore " ,
2009-11-26 06:08:12 +08:00
type = ' int ' , default = 60 ,
2009-11-14 05:46:31 +08:00
help = " The maximum Q score allowed for both a single covariate and the overall QUAL score [default: %d efault] " )
2009-11-10 06:48:51 +08:00
parser . add_option ( " -o " , " --outputVCF " , dest = " outputVCF " ,
2009-11-30 06:20:40 +08:00
type = ' string ' , default = " recal.vcf " ,
2009-11-14 05:46:31 +08:00
help = " If provided, a VCF file will be written out to this file [default: %d efault] " )
2009-11-12 06:02:57 +08:00
parser . add_option ( " " , " --FNoutputVCF " , dest = " FNoutputVCF " ,
type = ' string ' , default = None ,
2009-11-14 05:46:31 +08:00
help = " If provided, VCF file will be written out to this file [default: %d efault] " )
2009-11-07 07:00:46 +08:00
parser . add_option ( " " , " --titv " , dest = " titvTarget " ,
type = ' float ' , default = None ,
2009-11-14 05:46:31 +08:00
help = " If provided, we will optimize calls to the targeted ti/tv rather than that calculated from known calls [default: %d efault] " )
2009-11-10 06:48:51 +08:00
parser . add_option ( " -b " , " --bootstrap " , dest = " bootStrap " ,
2009-11-14 05:46:31 +08:00
type = ' float ' , default = None ,
help = " If provided, the % o f the calls used to generate the recalibration tables. [default: %d efault] " )
2009-12-02 22:36:03 +08:00
parser . add_option ( " -s " , " --useSample " , dest = " useSample " ,
type = ' string ' , default = False ,
help = " If provided, we will examine sample genotypes for this sample, and consider TP/FP/FN in the truth conditional on sample genotypes [default: %d efault] " )
2009-11-21 22:48:29 +08:00
parser . add_option ( " -r " , " --dontRecalibrate " , dest = " dontRecalibrate " ,
action = ' store_true ' , default = False ,
help = " If provided, we will not actually do anything to the calls, they will just be assessed [default: %d efault] " )
2009-11-12 06:02:57 +08:00
2009-11-07 07:00:46 +08:00
( OPTIONS , args ) = parser . parse_args ( )
2009-11-08 03:32:12 +08:00
if len ( args ) > 2 :
2009-11-07 07:00:46 +08:00
parser . error ( " incorrect number of arguments " )
2009-11-14 05:46:31 +08:00
return args
2009-11-21 22:48:29 +08:00
def assessCalls(file):
    """Load the calls from VCF `file` and report their ti/tv ratios.

    The file is counted first; when it holds more records than
    OPTIONS.maxRecordsForCovariates, the calls are down-sampled while being
    read.  Returns (header, list of calls, ti/tv target), where the target
    is OPTIONS.titvTarget or, when that is None, the ti/tv of the calls
    accepted by VCFRecord.isKnown.
    """
    print('%s %s' % (' Counting records in VCF ', file))
    countHandle = open(file)
    try:
        numberOfRecords = quickCountRecords(countHandle)
    finally:
        countHandle.close()  # the count is complete once the call returns
    if OPTIONS.maxRecords is not None and OPTIONS.maxRecords < numberOfRecords:
        numberOfRecords = OPTIONS.maxRecords
    if numberOfRecords > 0:
        downsampleFraction = min(float(OPTIONS.maxRecordsForCovariates) / numberOfRecords, 1)
    else:
        # Nothing to down-sample; avoid dividing by zero.
        downsampleFraction = 1
    header, allCalls = readVariants(file, OPTIONS.maxRecords, downsampleFraction=downsampleFraction, minQScore=OPTIONS.minQScore)
    allCalls = list(allCalls)
    print(' '.join(str(item) for item in (
        ' Number of VCF records ', numberOfRecords,
        ' , max number of records for covariates is ', OPTIONS.maxRecordsForCovariates,
        ' so keeping ', downsampleFraction * 100, ' % o f records ')))
    print('%s %s' % (' Number of selected VCF records ', len(allCalls)))

    titvtarget = OPTIONS.titvTarget
    if titvtarget is None:
        titvtarget = titv(selectVariants(allCalls, VCFRecord.isKnown))
    print('%s %s' % (' Ti/Tv all ', titv(allCalls)))
    print('%s %s' % (' Ti/Tv known ', titv(selectVariants(allCalls, VCFRecord.isKnown))))
    print('%s %s' % (' Ti/Tv novel ', titv(selectVariants(allCalls, VCFRecord.isNovel))))

    return header, allCalls, titvtarget
2009-11-07 07:00:46 +08:00
2009-11-21 22:48:29 +08:00
def determineCovariates(allCalls, titvtarget, fields):
    """Derive the covariates from `allCalls` and print the resulting quality tables.

    With OPTIONS.bootStrap set, the calls are split at random: one part is
    used to derive the covariates, the other is recalibrated with them and
    printed as an evaluation.  Returns the covariates.
    """
    bootstrapping = OPTIONS.bootStrap
    callsToOptimize = allCalls
    if bootstrapping:
        callsToOptimize, recalEvalCalls = randomSplit(allCalls, OPTIONS.bootStrap)

    recalOptCalls, covariates = optimizeCalls(callsToOptimize, fields, titvtarget)
    printCallQuals(" QUAL ", list(recalOptCalls), titvtarget, ' OPTIMIZED CALLS ')

    if not OPTIONS.bootStrap:
        return covariates
    recalibatedEvalCalls = recalibrateCalls(recalEvalCalls, fields, covariates)
    printCallQuals(" QUAL ", list(recalibatedEvalCalls), titvtarget, ' BOOTSTRAP EVAL CALLS ')
    return covariates
def writeRecalibratedCalls(file, header, calls):
    """Write `calls` in VCF form to the path `file`; does nothing when `file` is falsy.

    Progress is reported on stdout every 10000 records.
    """
    if not file:
        return
    f = open(file, 'w')
    try:
        for i, line in enumerate(formatVCF(header, calls)):
            if i % 10000 == 0:
                print('%s %s' % (' writing VCF record ', i))
            f.write('%s\n' % line)
    finally:
        # Close the output even when formatting a record fails part-way.
        f.close()
2009-11-07 07:00:46 +08:00
2009-11-30 06:20:40 +08:00
def readTruth(truthVCF):
    """Read the truth VCF and return its records as a dict keyed by getLoc().

    Records sharing a location overwrite one another, which is why the raw
    and the keyed counts are both printed.
    """
    print('%s %s' % (' Reading truth file ', truthVCF))
    header, records = readVariants(truthVCF, maxRecords=None, decodeAll=True, mustBeVariant=True)
    rawTruth = list(records)
    truth = {}
    for v in rawTruth:
        truth[v.getLoc()] = v
    print('%s %s %s' % (' Number of raw and passing filter truth calls ', len(rawTruth), len(truth)))
    return truth
2009-11-14 05:46:31 +08:00
2009-11-30 06:20:40 +08:00
def evaluateTruth(header, callVCF, truth, truthVCF):
    """Re-read the calls from `callVCF`, compare them with `truth` and write them back.

    The calls are rewritten to `callVCF` after the comparison.  When `truth`
    is given and OPTIONS.FNoutputVCF is set, the truth records that are
    variant in OPTIONS.useSample and carry no " FN " field are written to
    that file as well.  The `header` argument is replaced by the header read
    from `callVCF`.
    """
    print('%s %s' % (' Reading variants back in from ', callVCF))
    header, calls = readVariants(callVCF)
    calls = list(calls)

    print(' -------------------------------------------------------------------------------- ')
    print('%s %s' % (' Comparing calls to truth ', truthVCF))
    print(' ')

    compareCalls(calls, truth)

    writeRecalibratedCalls(callVCF, header, calls)

    def isFN(v):
        return isVariantInSample(v, OPTIONS.useSample) and not v.hasField(" FN ")

    if truth is not None and OPTIONS.FNoutputVCF:
        falseNegatives = [v for v in truth.values() if isFN(v)]
        f = open(OPTIONS.FNoutputVCF, 'w')
        try:
            for line in formatVCF(header, falseNegatives):
                f.write('%s\n' % line)
        finally:
            # Close the output even when formatting a record fails part-way.
            f.close()
2009-11-30 06:20:40 +08:00
# Truth records keyed by location, assigned by main() from readTruth(); None when no truth file is given.
TRUTH_CALLS = None
# Open handle of the recalibration log, assigned by main(); None while no log is open.
RECAL_LOG = None
2009-11-14 05:46:31 +08:00
def main():
    """Script entry point: parse options, assess and recalibrate the calls, then compare with truth.

    Sets the module-level TRUTH_CALLS and RECAL_LOG while running.  The log
    file is closed, and RECAL_LOG reset to None, when the run ends, whether
    it succeeds or fails.
    """
    global TRUTH_CALLS, RECAL_LOG

    args = setup()
    if not args:
        # The first positional argument is the VCF that every later step reads.
        sys.exit(" incorrect number of arguments ")
    fields = OPTIONS.fields.split(' , ')

    truthVCF = None
    if OPTIONS.truth is not None:
        truthVCF = OPTIONS.truth
        TRUTH_CALLS = readTruth(truthVCF)

    try:
        if OPTIONS.recalLog is not None:
            RECAL_LOG = open(OPTIONS.recalLog, "w")
            RECAL_LOG.write('%s %s\n' % (" # optimized vcf ", args[0]))
            RECAL_LOG.write('%s %s\n' % (" # truth vcf ", truthVCF))
            for key, value in OPTIONS.__dict__.items():
                RECAL_LOG.write('%s %s %s\n' % (' # ', key, value))

        header, allCalls, titvTarget = assessCalls(args[0])
        if not OPTIONS.dontRecalibrate:
            covariates = determineCovariates(allCalls, titvTarget, fields)
            # Second pass over the input: recalibrate every record, not just the sample.
            header, callsToRecalibate = readVariants(args[0], OPTIONS.maxRecords, minQScore=OPTIONS.minQScore)
            RecalibratedCalls = recalibrateCalls(callsToRecalibate, fields, covariates)
            writeRecalibratedCalls(OPTIONS.outputVCF, header, RecalibratedCalls)
        else:
            printFieldQualHeader()
            printCallQuals(" QUAL ", allCalls, titvTarget)
            OPTIONS.outputVCF = args[0]

        if truthVCF is not None:
            evaluateTruth(header, OPTIONS.outputVCF, TRUTH_CALLS, truthVCF)
    finally:
        if RECAL_LOG is not None:
            RECAL_LOG.close()
            RECAL_LOG = None
2009-11-14 05:46:31 +08:00
2009-11-21 22:48:29 +08:00
2009-11-12 06:02:57 +08:00
# Set to True to run main() under cProfile and print the hottest entries.
PROFILE = False

if __name__ == "__main__":
    if PROFILE:
        import cProfile
        cProfile.run('main()', ' fooprof ')
        import pstats
        p = pstats.Stats(' fooprof ')
        p.sort_stats('cumulative').print_stats(10)
        p.sort_stats('time').print_stats(10)
        p.sort_stats('time', 'cum').print_stats(.5, ' init ')
    else:
        main()