diff --git a/python/CoverageEval.py b/python/CoverageEval.py index c441014f1..c592d47ac 100755 --- a/python/CoverageEval.py +++ b/python/CoverageEval.py @@ -11,7 +11,7 @@ def subset_list_by_indices(indices, list): def chunk_generator(record_gen, key_fields): """Input: - line_gen: generator that produces dictionaries + record_gen: generator that produces dictionaries (records in database speak) key_fields: keys in each dictionary used to determine chunk membership Output: locus_chunk: list of consecutive lines that have the same key_fields""" @@ -33,59 +33,79 @@ Output: class call_stats: - def __init__(self, acc_conf_calls, conf_call_rate, cum_corr_calls, cum_calls, coverage): - self.AccuracyConfidentCalls = acc_conf_calls - self.ConfidentCallRate = conf_call_rate - self.CumulativeConfidentCorrectCalls = cum_corr_calls - self.CumulativeCalls = cum_calls - self.Coverage = coverage -# def stat_generator(chunk): + def __init__(self, call_type, coverage): #, acc_conf_calls, conf_call_rate, cum_corr_calls, cum_calls, coverage): + self.call_type = call_type + self.coverage = coverage + self.calls = 0 + self.conf_ref_calls = 0 + self.conf_het_calls = 0 + self.conf_hom_calls = 0 + self.conf_var_calls = 0 + self.conf_genotype_calls = 0 + self.conf_refvar_calls = 0 + self.correct_genotype = 0 + self.correct_refvar = 0 + + def add_stat(self, calls, conf_ref_calls, conf_het_calls, conf_hom_calls, conf_var_calls, conf_genotype_calls, conf_refvar_calls, correct_genotype, correct_refvar): + self.calls += calls + self.conf_ref_calls += conf_ref_calls + self.conf_het_calls += conf_het_calls + self.conf_hom_calls += conf_hom_calls + self.conf_var_calls += conf_var_calls + self.conf_genotype_calls += conf_genotype_calls + self.conf_refvar_calls += conf_refvar_calls + self.correct_genotype += correct_genotype + self.correct_refvar += correct_refvar @staticmethod def calc_discovery_stats(chunk): - calls = 0 conf_calls = 0 correct_genotype = 0 + calls = 0 for record in chunk: + calls += 1 if abs(float(record["BtrLod"])) >= 5: conf_calls += 1 if call_type.discovery_call_correct(record): - #if call_type.genotyping_call_correct(record): correct_genotype += 1 - - calls += 1 return correct_genotype, conf_calls, calls @staticmethod def calc_genotyping_stats(chunk): - calls = 0 conf_calls = 0 correct_genotype = 0 + calls = 0 for record in chunk: + calls += 1 if abs(float(record["BtnbLod"])) >= 5: conf_calls += 1 if call_type.genotyping_call_correct(record): correct_genotype += 1 - calls += 1 - return correct_genotype, conf_calls, calls - #return call_stats(float(correct_genotype) / max(conf_calls,1), float(conf_calls) / max(calls,1)) + @staticmethod + def stats_header(): + return "TrueGenotype,Coverage,AccuracyConfidentGenotypingCalls,ConfidentGenotypingCallRate,AccuracyConfidentDiscoveryCalls,ConfidentDiscoveryCallRate,Calls,ConfRefCalls,ConfHetCalls,ConfHomCalls,ConfGenotypeCalls,CorrectGenotypes,ConfVarCalls,ConfDiscoveryCalls,CorrectDiscovery" + def __str__(self): - return "%d,%.5f,%.5f,%d,%d,%.5f" % (self.Coverage, self.AccuracyConfidentCalls, self.ConfidentCallRate, self.CumulativeConfidentCorrectCalls, self.CumulativeCalls, float(self.CumulativeConfidentCorrectCalls)/self.CumulativeCalls ) + return ",".join(map(str, (self.calls, self.conf_ref_calls, self.conf_het_calls, self.conf_hom_calls, self.conf_genotype_calls, self.correct_genotype, self.conf_var_calls, self.conf_refvar_calls, self.correct_refvar))) + + def stats_str(self): + return "%s,%d,%.5f,%.5f,%.5f,%.5f,%s" % (self.call_type, self.coverage, float(self.correct_genotype) / max(self.conf_genotype_calls,1), float(self.conf_genotype_calls) / max(self.calls,1), float(self.correct_refvar) / max(self.conf_refvar_calls,1), float(self.conf_refvar_calls) / max(self.calls,1), self.__str__()) class call_type: """Class that returns an Enum with the type of call provided by a record""" call_types_3_state = Enum("HomozygousSNP","HeterozygousSNP","HomozygousReference") + call_types_3_state_short = Enum("Hom","Het","Ref") call_types_2_state = Enum("Variant","Reference") @staticmethod - def from_record_3_state(record): + def from_record_3_state(ref, genotype): # record): """Given reference base as string, determine whether called genotype is homref, het, homvar""" - ref = record["ReferenceBase"][0] - genotype = record["HapmapChipGenotype"] + #ref = record["ReferenceBase"][0] + #genotype = record["HapmapChipGenotype"] return call_type.call_types_3_state[genotype.count(ref)] @staticmethod @@ -98,17 +118,27 @@ class call_type: @staticmethod def genotyping_call_correct(record): return record["HapmapChipGenotype"] == record["BestGenotype"] + @staticmethod def discovery_call_correct(record): return call_type.from_record_2_state(record["ReferenceBase"][0], record["HapmapChipGenotype"]) == call_type.from_record_2_state(record["ReferenceBase"][0], record["BestGenotype"]) +def print_record_debug(rec): + print "TEST Genotyping:", call_type.genotyping_call_correct(rec) + print "TEST Discovery:", call_type.discovery_call_correct(rec) + print "DIFF:", call_type.genotyping_call_correct(rec) != call_type.discovery_call_correct(rec) + print "\n".join([" %20s => '%s'" % (k,v) for k,v in sorted(rec.items())]) + print call_type.from_record_2_state(rec["ReferenceBase"][0],rec["HapmapChipGenotype"]) + print call_type.from_record_2_state(rec["ReferenceBase"][0],rec["BestGenotype"]) + print -def aggregate_stats(filename, max_loci): +def aggregate_stats(filename, max_loci, table_filename, debug): aggregate = dict() + fout = open(table_filename,"w") + fout.write(call_stats.stats_header()+"\n") locus_gen = chunk_generator(FlatFileTable.record_generator(filename, None), ("Sequence","Position")) - #print "Fraction correct genotype\tCoverage sampled\tLocus\tReference base\tHapmap chip genotype (Max. coverage genotype call for reference calls)" for index, locus_chunk in enumerate(locus_gen): if index >= max_loci: break @@ -118,60 +148,49 @@ def aggregate_stats(filename, max_loci): covs = dict() coverage_chunk_gen = chunk_generator(locus_chunk, ("DownsampledCoverage", "Sequence", "Position")) for cov_chunk in coverage_chunk_gen: - #record = cov_chunk[0] + first_record = cov_chunk[0] #stat = call_stats.calc_stats(cov_chunk) for record in cov_chunk: - key = call_type.from_record_3_state(record), int(record["DownsampledCoverage"]) + hapmap_genotyping_call_type = call_type.from_record_3_state(record["ReferenceBase"][0],record["HapmapChipGenotype"]) + key = hapmap_genotyping_call_type, int(record["DownsampledCoverage"]) #key = call_type.from_record_3_state(record)#, int(record["DownsampledCoverage"]) record["DownsampledCoverage"] = int(record["DownsampledCoverage"]) - record["HapmapChipCallType"] = call_type.from_record_3_state(record) + record["HapmapChipCallType"] = hapmap_genotyping_call_type value = record - if aggregate.has_key(key): - aggregate[key].append(value) - else: - aggregate[key] = [value] - #print "\t".join(map(str,("%.4f\t%.4f" % (stat.AccuracyConfidordentCalls, stat.ConfidentCallRate), record["DownsampledCoverage"], record["Sequence"]+":"+record["Position"],record["ReferenceBase"],record["HapmapChipGenotype"]))) - #print "\n".join(map(str,sorted(aggregate.items()))) + correct_genotype, conf_genotype_calls, genotype_calls = call_stats.calc_genotyping_stats([record]) + correct_refvar, conf_refvar_calls, refvar_calls = call_stats.calc_discovery_stats([record]) + assert(genotype_calls == refvar_calls) - return aggregate + conf_ref_calls = 0 + conf_het_calls = 0 + conf_hom_calls = 0 + best_genotyping_call_type = call_type.from_record_3_state(record["ReferenceBase"][0],record["BestGenotype"]) + if conf_genotype_calls: + if best_genotyping_call_type.index == 0: conf_hom_calls = 1 + if best_genotyping_call_type.index == 1: conf_het_calls = 1 + if best_genotyping_call_type.index == 2: conf_ref_calls = 1 -def create_coverage_stats_table(aggregate, table_filename, debug): - fout = open(table_filename,"w") + conf_var_calls = 0 + if conf_refvar_calls: + this_variant_call_type = call_type.from_record_2_state(record["ReferenceBase"][0],record["BestGenotype"]) + conf_var_calls = 1 if this_variant_call_type.index == 0 else 0 + + aggregate.setdefault(key, call_stats(*key)) + #print ",".join(map(str,(genotype_calls, conf_ref_calls, conf_het_calls, conf_hom_calls, conf_var_calls, conf_genotype_calls, conf_refvar_calls, correct_genotype, correct_refvar))) + aggregate[key].add_stat(genotype_calls, conf_ref_calls, conf_het_calls, conf_hom_calls, conf_var_calls, conf_genotype_calls, conf_refvar_calls, correct_genotype, correct_refvar) - print >>fout, "CallType,Coverage,AccuracyConfidentCalls,ConfidentCallRate,CumCorrectCalls,CumCalls,CumCorrectFraction" + if debug:# and conf_refvar_calls: + print "KEYS:",key + print_record_debug(record) - cum_correct_calls = [0,0,0] - cum_calls = [0,0,0] for key, records in sorted(aggregate.items()): - if debug: - print "KEYS:",key - for rec in records: - if True: #abs(float(rec["BtrLod"])) > 5: - print "TEST Genotyping:", call_type.genotyping_call_correct(rec) - print "TEST Discovery:", call_type.discovery_call_correct(rec) - print "DIFF:", call_type.genotyping_call_correct(rec) != call_type.discovery_call_correct(rec) - print "\n".join([" %20s => '%s'" % (k,v) for k,v in sorted(rec.items())]) - print call_type.from_record_2_state(rec["ReferenceBase"][0],rec["HapmapChipGenotype"]) - print call_type.from_record_2_state(rec["ReferenceBase"][0],rec["BestGenotype"]) - print - print + fout.write(records.stats_str()+"\n") - #print "\n".join(["%s => %s" % record.items() for record in records]) - if options.do_discovery: - correct_genotype, conf_calls, calls = call_stats.calc_discovery_stats(records) - else: - correct_genotype, conf_calls, calls = call_stats.calc_genotyping_stats(records) - this_call_type = call_type.from_record_3_state(records[0]) - cum_correct_calls[this_call_type.index] += correct_genotype - cum_calls[this_call_type.index] += calls - #yield call_stats(float(correct_genotype) / max(conf_calls,1), float(conf_calls) / max(calls,1)) - record = records[0] - - print >>fout, str(record["HapmapChipCallType"])+","+str(call_stats(float(correct_genotype) / max(conf_calls,1), float(conf_calls) / max(calls,1), cum_correct_calls[this_call_type.index], cum_calls[this_call_type.index], record["DownsampledCoverage"])) - # record["HapmapChipCallType"]) + fout.close() + #return aggregate class weighted_avg: @@ -182,17 +201,22 @@ class weighted_avg: def add(self, value, counts): self.sum += value*counts self.count += counts - #print value, counts, self.sum, self.count def return_avg(self): return float(self.sum) / max(self.count,1) -def stats_from_hist(depth_hist_filename, stats_filename): +def stats_from_hist(depth_hist_filename, stats_filename, variant_eval_dir, depth_multiplier=1.0): - #hist_zero = {"CallType" : ,"Coverage","AccuracyConfidentCalls","ConfidentCallRate","CumCorrectCalls","CumCalls","CumCorrectFraction"} + #hist_zero = {"CallType" : ,"Coverage","AccuracyConfidentCalls","ConfidentCallRate","CumCorrectCalls","CumCalls","CumgCorrectFraction"} + #prob_genotype = [1e-5, 1e-3, 1-1e-3] + #prob_genotype = [0.37, 0.62, .0] + #prob_genotype = [0.203, 0.304, .491] + #prob_genotype = [0.216, 0.302, .481] + #prob_genotype = [0.205, 0.306, 0.491] # Based on CEU NA12878 actual hapmap chip calls + prob_genotype = [0.213, 0.313, 0.474] # Based on YRB NA19240 actual hapmap chip calls hist = [] - hist_gen = FlatFileTable.record_generator(depth_hist_filename, sep=" ", skip_n_lines=2) + hist_gen = FlatFileTable.record_generator(depth_hist_filename, sep=" ", skip_n_lines=3) for index, record in enumerate(hist_gen): assert int(record["depth"]) == index hist.append(int(record["count"])) @@ -201,47 +225,64 @@ def stats_from_hist(depth_hist_filename, stats_filename): stats_gen = FlatFileTable.record_generator(stats_filename, sep=",") for record in stats_gen: key1 = int(record["Coverage"]) - key2 = record["CallType"] + key2 = record["TrueGenotype"] + #print key2 stats_dict.setdefault(key1, dict()) # create a nested dict if it doesn't exist stats_dict[key1][key2] = record # create an entry for these keys - #print stats_dict - - acc = dict() #[weighted_avg()] * 3 - call_rate = dict() #[weighted_avg()] * 3 + acc = dict() + call_rate = dict() + conf_calls = dict() - start = 10 - end = 10 - for depth, depth_count in enumerate(hist[start:end+1],start): + start = 1 + end = 1000 + for depth, depth_count in enumerate(hist[start:end+1],start): # For Cd = depth count #print "DEPTH: "+str(depth) try: + depth = max(int(float(depth*depth_multiplier)),1) depth_entries = stats_dict[depth] - for calltype, stat in depth_entries.items(): - acc.setdefault(calltype,weighted_avg()) - call_rate.setdefault(calltype,weighted_avg()) - acc[calltype].add(float(stat["AccuracyConfidentCalls"]), depth_count) - call_rate[calltype].add(float(stat["ConfidentCallRate"]), depth_count) - - # acc[calltype] = stat - #ref = depth_entries["HomozygousReference"] - #het = depth_entries["HeterozygousSNP"] - #hom = depth_entries["HomozygousSNP"] except KeyError: + print "Stopped on depth",depth break - - #acc.add(float(ref["AccuracyConfidentCalls"]), depth_count) - #call_rate.add(float(ref["ConfidentCallRate"]), depth_count) + if True: + for true_genotype, stat in depth_entries.items(): # For t (SNP type) = true_genotype + #print "TRUE_GENOTYPE: "+str(true_genotype) + for genotype in call_type.call_types_3_state: + conf_calls.setdefault(genotype, 0.0) + prob_conf_x_call = float(stat["Conf"+str(call_type.call_types_3_state_short[genotype.index])+"Calls"])/float(stat["Calls"]) + conf_calls[genotype] += depth_count * prob_conf_x_call * prob_genotype[genotype.index] + #if genotype.index == 1: + # print "%.5f " % prob_conf_x_call, depth, depth_count, conf_calls[genotype], int(stat["Conf"+str(call_type.call_types_3_state_short[genotype.index])+"Calls"]), int(stat["Calls"]) - #print(float(het["AccuracyConfidentCalls"]), depth_count) - #print(float(het["ConfidentCallRate"]), depth_count) + acc.setdefault(true_genotype,weighted_avg()) + call_rate.setdefault(true_genotype,weighted_avg()) + acc[true_genotype].add(float(stat["AccuracyConfidentGenotypingCalls"]), depth_count) + call_rate[true_genotype].add(float(stat["ConfidentGenotypingCallRate"]), depth_count) - for calltype in ("HomozygousSNP","HeterozygousSNP","HomozygousReference"): - print "%25s accuracy : %.3f" % (calltype, acc[calltype].return_avg()) - print "%25s call rate: %.3f" % (calltype, call_rate[calltype].return_avg()) + import numpy + for genotype in call_type.call_types_3_state: + print "%19s accuracy : %.3f" % (str(genotype), acc[str(genotype)].return_avg()) + print "%19s call rate: %.3f" % (str(genotype), call_rate[str(genotype)].return_avg()) - + print "\nExpected performance given perfect accuracy and call rate:" + print "%19s %7s %7s %7s" % ("", "Actual", "Perfect", "Diff.") + total_hist_sites = numpy.sum(hist) + for genotype in call_type.call_types_3_state: + predicted = conf_calls[genotype] + perfect = prob_genotype[genotype.index]*total_hist_sites + diff = perfect - predicted + print "%19s calls: %7d %7d %7d" % (genotype, predicted, perfect, diff) + + #stats_gen = FlatFileTable.record_generator(stats_filename, sep=",") + #for chunk in chunk_generator(stats_gen, key_fields=("True_Genotype")): + + print "\nCoverage histogram mean: %.2f" % numpy.average(range(len(hist)), weights=hist) + + # STEM AND LEAF PLOT + + # If VariantEval directory given, compare with those results + def usage(parser): - #print "Usage: CoverageEval.py geli_file OPTIONS" parser.print_usage() sys.exit() @@ -250,32 +291,32 @@ if __name__ == "__main__": parser = OptionParser() parser.add_option("-c", "--genotype_call_file", help="GELI file to use in generating coverage stat table", dest="genotype_filename", ) - parser.add_option("-s", "--stats_file", help="file to containing empirical genotyping stats", dest="stats_filename") - parser.add_option("-g", "--histogram_file", help="file to containing counts of each depth of coverage", dest="hist_filename") + parser.add_option("-s", "--stats_file", help="file to containing empirical genotyping stats", dest="stats_filename", default=None) + parser.add_option("-g", "--histogram_file", help="file containing counts of each depth of coverage", dest="hist_filename") parser.add_option("-m", "--max_loci", help="maximum number of loci to parse (for debugging)", default=sys.maxint, dest="max_loci", type="int") - parser.add_option("-v", "--discovery", help="run discovery rather than genotyping calls", default=False, dest="do_discovery", action="store_true") parser.add_option("-e", "--evaluate", help="evaluate genotypes; requires a stats file and a histogram file", default=False, dest="evaluate_genotypes", action="store_true") parser.add_option("-d", "--debug", help="provide debugging output", default=False, dest="debug", action="store_true") - - + parser.add_option("-p", "--depth_multiplier", help="multiply all depths in histogram by this value; for \"downsampling\" depth", default=1.0, dest="depth_multiplier", type=float) + parser.add_option("-v", "--variant_eval_dir", help="directory with output of VariantEval to compare this prediction to", default=None, dest="variant_eval_dir") + (options, args) = parser.parse_args() - #if len(args) < 1: - # usage(parser) - #genotype_filename = args[0] - - if options.evaluate_genotypes: + if not options.evaluate_genotypes: + print "Creating performance tables from genotypes file" + #if options.stats_filename == None: + # sys.exit("Must provide -s stats fliname option") + if options.genotype_filename == None: + sys.exit("Must provide -c genotype call filename option") + stats_filename = options.stats_filename if options.stats_filename != None else options.genotype_filename+".stats" + aggregate_stats(options.genotype_filename, options.max_loci, options.stats_filename, options.debug) + else: print "Evaluating genotypes" + print "Depth multiplier:",options.depth_multiplier if options.hist_filename == None: sys.exit("Must provide -g histogram filename option") if options.stats_filename == None: sys.exit("Must provide -s stats fliname option") - stats_from_hist(options.hist_filename, options.stats_filename) - else: - print "Creating performance tables from genotypes file" - aggregate = aggregate_stats(options.genotype_filename, options.max_loci) - stats_filename = options.genotype_filename+".stats" - create_coverage_stats_table(aggregate, stats_filename, options.debug) + stats_from_hist(options.hist_filename, options.stats_filename, options.variant_eval_dir, options.depth_multiplier) diff --git a/python/FlatFileTable.py b/python/FlatFileTable.py index fe75344c1..9f54c603e 100644 --- a/python/FlatFileTable.py +++ b/python/FlatFileTable.py @@ -2,12 +2,17 @@ import sys, itertools -def record_generator(filename, sep="\t"): +def record_generator(filename, sep="\t", skip_n_lines=0): """Given a file with field headers on the first line and records on subsequent lines, generates a dictionary for each line keyed by the header fields""" fin = open(filename) - header = fin.readline().rstrip().split() # pull off header - for line in fin: + + for i in range(skip_n_lines): # Skip a number of lines + fin.readline() + + header = fin.readline().rstrip().split(sep) # Pull off header + + for line in fin: # fields = line.rstrip().split(sep) record = dict(itertools.izip(header, fields)) yield record