2009-07-03 16:07:02 +08:00
#!/usr/bin/env python
2009-07-09 06:04:26 +08:00
import sys , itertools , FlatFileTable
2009-07-23 00:58:23 +08:00
from enum import Enum
2009-07-03 16:07:02 +08:00
def subset_list_by_indices ( indices , list ) :
subset = [ ]
for index in indices :
subset . append ( list [ index ] )
return subset
2009-07-09 06:04:26 +08:00
def chunk_generator ( record_gen , key_fields ) :
2009-07-03 16:07:02 +08:00
""" Input:
2009-07-23 00:58:23 +08:00
line_gen : generator that produces dictionaries
key_fields : keys in each dictionary used to determine chunk membership
2009-07-03 16:07:02 +08:00
Output :
2009-07-09 06:04:26 +08:00
locus_chunk : list of consecutive lines that have the same key_fields """
2009-07-03 16:07:02 +08:00
locus_chunk = [ ]
last_key = " "
2009-07-09 06:04:26 +08:00
first_record = True
for record in record_gen :
key = [ record [ f ] for f in key_fields ]
if key == last_key or first_record :
locus_chunk . append ( record )
first_record = False
2009-07-03 16:07:02 +08:00
else :
if locus_chunk != [ ] :
yield locus_chunk
2009-07-09 06:04:26 +08:00
locus_chunk = [ record ]
2009-07-08 10:05:40 +08:00
last_key = key
2009-07-03 16:07:02 +08:00
yield locus_chunk
2009-07-09 06:04:26 +08:00
2009-07-23 00:58:23 +08:00
class call_stats :
def __init__ ( self , acc_conf_calls , conf_call_rate , cum_corr_calls , cum_calls , coverage ) :
self . AccuracyConfidentCalls = acc_conf_calls
self . ConfidentCallRate = conf_call_rate
self . CumulativeConfidentCorrectCalls = cum_corr_calls
self . CumulativeCalls = cum_calls
self . Coverage = coverage
# def stat_generator(chunk):
2009-07-03 16:07:02 +08:00
2009-07-23 00:58:23 +08:00
@staticmethod
def calc_discovery_stats ( chunk ) :
calls = 0
conf_calls = 0
correct_genotype = 0
for record in chunk :
if abs ( float ( record [ " BtrLod " ] ) ) > = 5 :
conf_calls + = 1
if call_type . discovery_call_correct ( record ) :
#if call_type.genotyping_call_correct(record):
correct_genotype + = 1
calls + = 1
return correct_genotype , conf_calls , calls
@staticmethod
def calc_genotyping_stats ( chunk ) :
calls = 0
conf_calls = 0
correct_genotype = 0
for record in chunk :
if abs ( float ( record [ " BtnbLod " ] ) ) > = 5 :
conf_calls + = 1
if call_type . genotyping_call_correct ( record ) :
correct_genotype + = 1
calls + = 1
return correct_genotype , conf_calls , calls
#return call_stats(float(correct_genotype) / max(conf_calls,1), float(conf_calls) / max(calls,1))
def __str__ ( self ) :
return " %d , %.5f , %.5f , %d , %d , %.5f " % ( self . Coverage , self . AccuracyConfidentCalls , self . ConfidentCallRate , self . CumulativeConfidentCorrectCalls , self . CumulativeCalls , float ( self . CumulativeConfidentCorrectCalls ) / self . CumulativeCalls )
class call_type :
""" Class that returns an Enum with the type of call provided by a record """
call_types_3_state = Enum ( " HomozygousSNP " , " HeterozygousSNP " , " HomozygousReference " )
call_types_2_state = Enum ( " Variant " , " Reference " )
@staticmethod
def from_record_3_state ( record ) :
""" Given reference base as string, determine whether called genotype is homref, het, homvar """
ref = record [ " ReferenceBase " ] [ 0 ]
genotype = record [ " HapmapChipGenotype " ]
return call_type . call_types_3_state [ genotype . count ( ref ) ]
@staticmethod
def from_record_2_state ( ref , genotype ) :
""" Given reference base as string, determine whether called genotype is ref or var """
#ref = record["ReferenceBase"][0]
#genotype = record["HapmapChipGenotype"]
return call_type . call_types_2_state [ 0 ] if genotype . count ( ref ) < 2 else call_type . call_types_2_state [ 1 ]
@staticmethod
def genotyping_call_correct ( record ) :
return record [ " HapmapChipGenotype " ] == record [ " BestGenotype " ]
@staticmethod
def discovery_call_correct ( record ) :
return call_type . from_record_2_state ( record [ " ReferenceBase " ] [ 0 ] , record [ " HapmapChipGenotype " ] ) == call_type . from_record_2_state ( record [ " ReferenceBase " ] [ 0 ] , record [ " BestGenotype " ] )
def aggregate_stats ( filename , max_loci ) :
aggregate = dict ( )
2009-07-03 16:07:02 +08:00
2009-07-09 06:04:26 +08:00
locus_gen = chunk_generator ( FlatFileTable . record_generator ( filename , None ) , ( " Sequence " , " Position " ) )
2009-07-23 00:58:23 +08:00
#print "Fraction correct genotype\tCoverage sampled\tLocus\tReference base\tHapmap chip genotype (Max. coverage genotype call for reference calls)"
for index , locus_chunk in enumerate ( locus_gen ) :
if index > = max_loci :
break
if ( index % 1000 ) == 0 :
sys . stderr . write ( str ( index ) + " loci processed, at: " + locus_chunk [ 0 ] [ " Sequence " ] + " : " + locus_chunk [ 0 ] [ " Position " ] + " \n " )
2009-07-03 16:07:02 +08:00
covs = dict ( )
2009-07-23 00:58:23 +08:00
coverage_chunk_gen = chunk_generator ( locus_chunk , ( " DownsampledCoverage " , " Sequence " , " Position " ) )
2009-07-03 16:07:02 +08:00
for cov_chunk in coverage_chunk_gen :
2009-07-23 00:58:23 +08:00
#record = cov_chunk[0]
#stat = call_stats.calc_stats(cov_chunk)
2009-07-03 16:07:02 +08:00
2009-07-23 00:58:23 +08:00
for record in cov_chunk :
key = call_type . from_record_3_state ( record ) , int ( record [ " DownsampledCoverage " ] )
#key = call_type.from_record_3_state(record)#, int(record["DownsampledCoverage"])
record [ " DownsampledCoverage " ] = int ( record [ " DownsampledCoverage " ] )
record [ " HapmapChipCallType " ] = call_type . from_record_3_state ( record )
value = record
if aggregate . has_key ( key ) :
aggregate [ key ] . append ( value )
else :
aggregate [ key ] = [ value ]
#print "\t".join(map(str,("%.4f\t%.4f" % (stat.AccuracyConfidordentCalls, stat.ConfidentCallRate), record["DownsampledCoverage"], record["Sequence"]+":"+record["Position"],record["ReferenceBase"],record["HapmapChipGenotype"])))
#print "\n".join(map(str,sorted(aggregate.items())))
return aggregate
def create_coverage_stats_table ( aggregate , table_filename , debug ) :
fout = open ( table_filename , " w " )
print >> fout , " CallType,Coverage,AccuracyConfidentCalls,ConfidentCallRate,CumCorrectCalls,CumCalls,CumCorrectFraction "
cum_correct_calls = [ 0 , 0 , 0 ]
cum_calls = [ 0 , 0 , 0 ]
for key , records in sorted ( aggregate . items ( ) ) :
if debug :
print " KEYS: " , key
for rec in records :
if True : #abs(float(rec["BtrLod"])) > 5:
print " TEST Genotyping: " , call_type . genotyping_call_correct ( rec )
print " TEST Discovery: " , call_type . discovery_call_correct ( rec )
print " DIFF: " , call_type . genotyping_call_correct ( rec ) != call_type . discovery_call_correct ( rec )
print " \n " . join ( [ " %20s => ' %s ' " % ( k , v ) for k , v in sorted ( rec . items ( ) ) ] )
print call_type . from_record_2_state ( rec [ " ReferenceBase " ] [ 0 ] , rec [ " HapmapChipGenotype " ] )
print call_type . from_record_2_state ( rec [ " ReferenceBase " ] [ 0 ] , rec [ " BestGenotype " ] )
print
print
#print "\n".join(["%s => %s" % record.items() for record in records])
if options . do_discovery :
correct_genotype , conf_calls , calls = call_stats . calc_discovery_stats ( records )
else :
correct_genotype , conf_calls , calls = call_stats . calc_genotyping_stats ( records )
this_call_type = call_type . from_record_3_state ( records [ 0 ] )
cum_correct_calls [ this_call_type . index ] + = correct_genotype
cum_calls [ this_call_type . index ] + = calls
#yield call_stats(float(correct_genotype) / max(conf_calls,1), float(conf_calls) / max(calls,1))
record = records [ 0 ]
print >> fout , str ( record [ " HapmapChipCallType " ] ) + " , " + str ( call_stats ( float ( correct_genotype ) / max ( conf_calls , 1 ) , float ( conf_calls ) / max ( calls , 1 ) , cum_correct_calls [ this_call_type . index ] , cum_calls [ this_call_type . index ] , record [ " DownsampledCoverage " ] ) )
# record["HapmapChipCallType"])
class weighted_avg :
def __init__ ( self ) :
self . sum = 0.0
self . count = 0
def add ( self , value , counts ) :
self . sum + = value * counts
self . count + = counts
#print value, counts, self.sum, self.count
def return_avg ( self ) :
return float ( self . sum ) / max ( self . count , 1 )
def stats_from_hist ( depth_hist_filename , stats_filename ) :
#hist_zero = {"CallType" : ,"Coverage","AccuracyConfidentCalls","ConfidentCallRate","CumCorrectCalls","CumCalls","CumCorrectFraction"}
hist = [ ]
hist_gen = FlatFileTable . record_generator ( depth_hist_filename , sep = " " , skip_n_lines = 2 )
for index , record in enumerate ( hist_gen ) :
assert int ( record [ " depth " ] ) == index
hist . append ( int ( record [ " count " ] ) )
stats_dict = dict ( )
stats_gen = FlatFileTable . record_generator ( stats_filename , sep = " , " )
for record in stats_gen :
key1 = int ( record [ " Coverage " ] )
key2 = record [ " CallType " ]
stats_dict . setdefault ( key1 , dict ( ) ) # create a nested dict if it doesn't exist
stats_dict [ key1 ] [ key2 ] = record # create an entry for these keys
#print stats_dict
acc = dict ( ) #[weighted_avg()] * 3
call_rate = dict ( ) #[weighted_avg()] * 3
start = 10
end = 10
for depth , depth_count in enumerate ( hist [ start : end + 1 ] , start ) :
#print "DEPTH: "+str(depth)
try :
depth_entries = stats_dict [ depth ]
for calltype , stat in depth_entries . items ( ) :
acc . setdefault ( calltype , weighted_avg ( ) )
call_rate . setdefault ( calltype , weighted_avg ( ) )
acc [ calltype ] . add ( float ( stat [ " AccuracyConfidentCalls " ] ) , depth_count )
call_rate [ calltype ] . add ( float ( stat [ " ConfidentCallRate " ] ) , depth_count )
# acc[calltype] = stat
#ref = depth_entries["HomozygousReference"]
#het = depth_entries["HeterozygousSNP"]
#hom = depth_entries["HomozygousSNP"]
except KeyError :
break
#acc.add(float(ref["AccuracyConfidentCalls"]), depth_count)
#call_rate.add(float(ref["ConfidentCallRate"]), depth_count)
#print(float(het["AccuracyConfidentCalls"]), depth_count)
#print(float(het["ConfidentCallRate"]), depth_count)
for calltype in ( " HomozygousSNP " , " HeterozygousSNP " , " HomozygousReference " ) :
print " %25s accuracy : %.3f " % ( calltype , acc [ calltype ] . return_avg ( ) )
print " %25s call rate: %.3f " % ( calltype , call_rate [ calltype ] . return_avg ( ) )
def usage ( parser ) :
#print "Usage: CoverageEval.py geli_file OPTIONS"
parser . print_usage ( )
sys . exit ( )
if __name__ == " __main__ " :
from optparse import OptionParser
parser = OptionParser ( )
parser . add_option ( " -c " , " --genotype_call_file " , help = " GELI file to use in generating coverage stat table " , dest = " genotype_filename " , )
parser . add_option ( " -s " , " --stats_file " , help = " file to containing empirical genotyping stats " , dest = " stats_filename " )
parser . add_option ( " -g " , " --histogram_file " , help = " file to containing counts of each depth of coverage " , dest = " hist_filename " )
parser . add_option ( " -m " , " --max_loci " , help = " maximum number of loci to parse (for debugging) " , default = sys . maxint , dest = " max_loci " , type = " int " )
parser . add_option ( " -v " , " --discovery " , help = " run discovery rather than genotyping calls " , default = False , dest = " do_discovery " , action = " store_true " )
parser . add_option ( " -e " , " --evaluate " , help = " evaluate genotypes; requires a stats file and a histogram file " , default = False , dest = " evaluate_genotypes " , action = " store_true " )
parser . add_option ( " -d " , " --debug " , help = " provide debugging output " , default = False , dest = " debug " , action = " store_true " )
( options , args ) = parser . parse_args ( )
#if len(args) < 1:
# usage(parser)
#genotype_filename = args[0]
if options . evaluate_genotypes :
print " Evaluating genotypes "
if options . hist_filename == None :
sys . exit ( " Must provide -g histogram filename option " )
if options . stats_filename == None :
sys . exit ( " Must provide -s stats fliname option " )
stats_from_hist ( options . hist_filename , options . stats_filename )
else :
print " Creating performance tables from genotypes file "
aggregate = aggregate_stats ( options . genotype_filename , options . max_loci )
stats_filename = options . genotype_filename + " .stats "
create_coverage_stats_table ( aggregate , stats_filename , options . debug )