From 4f6d26849ffdcc6ee36b3f97a42b5513a42a2384 Mon Sep 17 00:00:00 2001 From: jmaguire Date: Tue, 16 Jun 2009 20:03:24 +0000 Subject: [PATCH] Behold MultiSampleCaller! Complete re-write of PoolCaller algorithm, now basically beta quality code. Improvements over PoolCaller include: - more correct strand test - fractional counts from genotypes (which means no individual lod threshold needed) - signifigantly cleaner code; first beta-quality code I've written since BaitDesigner so long ago. - faster, less likely to crash! git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@1020 348d0f76-0448-11de-a6fe-93d51630548a --- .../gatk/walkers/CoverageBySample.java | 120 +++- .../gatk/walkers/MultiSampleCaller.java | 542 ++++++++++++++++++ .../playground/gatk/walkers/PoolCaller.java | 259 ++++++--- .../gatk/walkers/SingleSampleGenotyper.java | 14 - .../utils/AlleleFrequencyEstimate.java | 2 + .../playground/utils/GenotypeLikelihoods.java | 59 +- .../sting/utils/ReadBackedPileup.java | 19 +- 7 files changed, 892 insertions(+), 123 deletions(-) create mode 100644 java/src/org/broadinstitute/sting/playground/gatk/walkers/MultiSampleCaller.java diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/CoverageBySample.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/CoverageBySample.java index 2b61d53a1..3f0825d44 100644 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/CoverageBySample.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/CoverageBySample.java @@ -9,6 +9,8 @@ import org.broadinstitute.sting.gatk.LocusContext; import org.broadinstitute.sting.playground.gatk.walkers.AlleleFrequencyWalker; import org.broadinstitute.sting.playground.utils.AlleleFrequencyEstimate; import org.broadinstitute.sting.utils.cmdLine.Argument; +import org.broadinstitute.sting.utils.Utils; +import org.broadinstitute.sting.utils.BaseUtils; import java.util.*; @@ -41,14 +43,48 @@ public class CoverageBySample extends LocusWalker public String map(RefMetaDataTracker tracker, char ref, LocusContext context) { String line = context.getLocation().getContig() + " " + context.getLocation().getStart() + " " ; + + LocusContext[] contexts = filterLocusContext(context, sample_names, 0); + HashMap counts = countReadsBySample(context); - for (int i = 0; i < sample_names.size(); i++) + for (int i = 0; i < contexts.length; i++) { - int count = counts.get(sample_names.get(i)); - line += " " + count; + List reads = contexts[i].getReads(); + List offsets = contexts[i].getOffsets(); + + out.printf("%s %s ", context.getLocation(), sample_names.get(i)); + + int[] forward_counts = new int[4]; + int[] backward_counts = new int[4]; + + for (int j = 0; j < reads.size(); j++) + { + SAMRecord read = reads.get(j); + int offset = offsets.get(j); + boolean backward = read.getReadNegativeStrandFlag(); + char base = Character.toUpperCase((char)(read.getReadBases()[offset])); + + if (BaseUtils.simpleBaseToBaseIndex(base) == -1) { continue; } + + if (backward) { base = Character.toLowerCase(base); } + + if (! backward) { forward_counts[BaseUtils.simpleBaseToBaseIndex(base)]++; } + else { backward_counts[BaseUtils.simpleBaseToBaseIndex(base)]++; } + + //out.printf("%c", base); + } + out.printf("A[%d] C[%d] G[%d] T[%d] a[%d] c[%d] g[%d] t[%d]", + forward_counts[0], + forward_counts[1], + forward_counts[2], + forward_counts[3], + backward_counts[0], + backward_counts[1], + backward_counts[2], + backward_counts[3]); + out.printf("\n"); } - line += "\n"; - return line; + return ""; } private HashMap countReadsBySample(LocusContext context) @@ -68,13 +104,80 @@ public class CoverageBySample extends LocusWalker return counts; } - public void onTraversalDone() + private LocusContext[] filterLocusContext(LocusContext context, List sample_names, int downsample) + { + HashMap index = new HashMap(); + for (int i = 0; i < sample_names.size(); i++) + { + index.put(sample_names.get(i), i); + } + + LocusContext[] contexts = new LocusContext[sample_names.size()]; + ArrayList[] reads = new ArrayList[sample_names.size()]; + ArrayList[] offsets = new ArrayList[sample_names.size()]; + + for (int i = 0; i < sample_names.size(); i++) + { + reads[i] = new ArrayList(); + offsets[i] = new ArrayList(); + } + + for (int i = 0; i < context.getReads().size(); i++) + { + SAMRecord read = context.getReads().get(i); + Integer offset = context.getOffsets().get(i); + String RG = (String)(read.getAttribute("RG")); + + assert(header != null); + assert(header.getReadGroup(RG) != null); + + String sample = header.getReadGroup(RG).getSample(); + //if (SAMPLE_NAME_REGEX != null) { sample = sample.replaceAll(SAMPLE_NAME_REGEX, "$1"); } + reads[index.get(sample)].add(read); + offsets[index.get(sample)].add(offset); + } + + if (downsample != 0) + { + for (int j = 0; j < reads.length; j++) + { + List perm = new ArrayList(); + for (int i = 0; i < reads[j].size(); i++) { perm.add(i); } + perm = Utils.RandomSubset(perm, downsample); + + ArrayList downsampled_reads = new ArrayList(); + ArrayList downsampled_offsets = new ArrayList(); + + for (int i = 0; i < perm.size(); i++) + { + downsampled_reads.add(reads[j].get(perm.get(i))); + downsampled_offsets.add(offsets[j].get(perm.get(i))); + } + + reads[j] = downsampled_reads; + offsets[j] = downsampled_offsets; + contexts[j] = new LocusContext(context.getLocation(), reads[j], offsets[j]); + } + } + else + { + for (int j = 0; j < reads.length; j++) + { + contexts[j] = new LocusContext(context.getLocation(), reads[j], offsets[j]); + } + } + + return contexts; + } + + public void onTraversalDone(String result) { return; } public String reduceInit() { + /* String header = "contig offset"; for (int i = 0; i < sample_names.size(); i++) { @@ -83,11 +186,14 @@ public class CoverageBySample extends LocusWalker header += "\n"; out.print(header); return header; + */ + return ""; } public String reduce(String line, String sum) { - out.print(line); + //out.print(line); + out.flush(); return ""; } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/MultiSampleCaller.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/MultiSampleCaller.java new file mode 100644 index 000000000..9115ce440 --- /dev/null +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/MultiSampleCaller.java @@ -0,0 +1,542 @@ + +package org.broadinstitute.sting.playground.gatk.walkers; + +import net.sf.samtools.SAMFileHeader; +import net.sf.samtools.SAMReadGroupRecord; +import net.sf.samtools.SAMRecord; +import org.broadinstitute.sting.gatk.GenomeAnalysisEngine; +import org.broadinstitute.sting.gatk.LocusContext; +import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker; +import org.broadinstitute.sting.gatk.walkers.LocusWalker; +import org.broadinstitute.sting.playground.utils.AlleleFrequencyEstimate; +import org.broadinstitute.sting.playground.utils.*; +import org.broadinstitute.sting.utils.*; +import org.broadinstitute.sting.utils.ReadBackedPileup; +import org.broadinstitute.sting.utils.cmdLine.Argument; + +import java.util.*; +import java.util.zip.*; +import java.io.*; + +// Beta iterative multi-sample caller +// j.maguire 6-11-2009 + +public class MultiSampleCaller extends LocusWalker +{ + @Argument(required=false, shortName="fractional_counts", doc="should we use fractional counts?") public boolean FRACTIONAL_COUNTS = false; + @Argument(required=false, shortName="max_iterations", doc="Maximum number of iterations for EM") public int MAX_ITERATIONS = 10; + @Argument(fullName="lodThreshold", shortName="lod", required=false, doc="lod threshold for outputting individual genotypes") public Double lodThreshold = 0.0; + @Argument(fullName="discovery_output", shortName="discovery_output", required=true, doc="file to write SNP discovery output to") public String DISCOVERY_OUTPUT; + @Argument(fullName="individual_output", shortName="individual_output", required=true, doc="file to write individual SNP calls to") public String INDIVIDUAL_OUTPUT; + @Argument(fullName="sample_name_regex", shortName="sample_name_regex", required=false, doc="sample_name_regex") public String SAMPLE_NAME_REGEX = null; + + // Private state. + List sample_names; + private SAMFileHeader header; + PrintStream individual_output_file; + PrintStream discovery_output_file; + + ///////// + // Walker Interface Functions + public void initialize() + { + try + { + discovery_output_file = new PrintStream(DISCOVERY_OUTPUT); + individual_output_file = new PrintStream(new GZIPOutputStream(new FileOutputStream(INDIVIDUAL_OUTPUT))); + + discovery_output_file.println("loc ref alt lod strand_score pD pNull discovery_lod in_dbsnp pA pC pG pT EM_alt_freq EM_N n_ref n_het n_hom pD_fw nNull_fw EM_alt_freq_fw pD_bw pNull_bw EM_alt_freq_bw"); + individual_output_file.println("loc ref sample_name genotype lodVsNextBest lodVsRef in_dbsnp AA AC AG AT CC CG CT GG GT TT"); + } + catch (Exception e) + { + e.printStackTrace(); + System.exit(-1); + } + + + GenomeAnalysisEngine toolkit = this.getToolkit(); + this.header = toolkit.getEngine().getSAMHeader(); + List read_groups = header.getReadGroups(); + + sample_names = new ArrayList(); + + HashSet unique_sample_names = new HashSet(); + + for (int i = 0; i < read_groups.size(); i++) + { + String sample_name = read_groups.get(i).getSample(); + + if (SAMPLE_NAME_REGEX != null) { sample_name = sample_name.replaceAll(SAMPLE_NAME_REGEX, "$1"); } + + if (unique_sample_names.contains(sample_name)) { continue; } + unique_sample_names.add(sample_name); + sample_names.add(sample_name); + System.out.println("SAMPLE: " + sample_name); + } + } + + public String in_dbsnp = "novel"; + + public String map(RefMetaDataTracker tracker, char ref, LocusContext context) + { + this.ref = ref; + if (tracker.lookup("DBSNP", null) != null) { in_dbsnp = "known"; } else { in_dbsnp = "novel"; } + this.MultiSampleCall(context, sample_names); + return null; + } + + public void onTraversalDone(String sum) + { + out.println("MultiSampleCaller done."); + return; + } + + public String reduceInit() + { + return null; + } + + public String reduce(String record, String sum) + { + return null; + } + + // END Walker Interface Functions + ///////// + + + ///////// + // Calling Functions + + char ref; + + GenotypeLikelihoods Genotype(LocusContext context, double[] allele_likelihoods) + { + ReadBackedPileup pileup = new ReadBackedPileup(ref, context); + String bases = pileup.getBases(); + + if (bases.length() == 0) + { + GenotypeLikelihoods G = new GenotypeLikelihoods(); + return G; + } + + List reads = context.getReads(); + List offsets = context.getOffsets(); + ref = Character.toUpperCase(ref); + + /* + // Handle indels. + if (call_indels) + { + String[] indels = BasicPileup.indelPileup(reads, offsets); + IndelCall indel_call = GenotypeLikelihoods.callIndel(indels); + if (indel_call != null) + { + if (! indel_call.type.equals("ref")) + { + System.out.printf("INDEL %s %s\n", context.getLocation(), indel_call); + } + } + } + */ + + // Handle single-base polymorphisms. + GenotypeLikelihoods G = new GenotypeLikelihoods(); + for ( int i = 0; i < reads.size(); i++ ) + { + SAMRecord read = reads.get(i); + int offset = offsets.get(i); + G.add(ref, read.getReadString().charAt(offset), read.getBaseQualities()[offset]); + } + G.ApplyPrior(ref, allele_likelihoods); + + /* + // Handle 2nd-best base calls. + if (fourBaseMode && pileup.getBases().length() < 750) + { + G.applySecondBaseDistributionPrior(pileup.getBases(), pileup.getSecondaryBasePileup()); + } + */ + + return G; + } + + // thoughly check this function + double[] CountFreqs(GenotypeLikelihoods[] genotype_likelihoods) + { + double[] allele_likelihoods = new double[4]; + for (int x = 0; x < genotype_likelihoods.length; x++) + { + if (genotype_likelihoods[x].coverage == 0) { continue; } + + double Z = 0; + for(int k = 0; k < 10; k++) { Z += Math.pow(10,genotype_likelihoods[x].likelihoods[k]); } + + double[] personal_allele_likelihoods = new double[4]; + int k = 0; + for (int i = 0; i < 4; i++) + { + for (int j = i; j < 4; j++) + { + personal_allele_likelihoods[i] += Math.pow(10,genotype_likelihoods[x].likelihoods[k])/Z; + personal_allele_likelihoods[j] += Math.pow(10,genotype_likelihoods[x].likelihoods[k])/Z; + k++; + } + } + double sum = 0; + for (int y = 0; y < 4; y++) { sum += personal_allele_likelihoods[y]; } + for (int y = 0; y < 4; y++) { personal_allele_likelihoods[y] /= sum; } + for (int y = 0; y < 4; y++) { allele_likelihoods[y] += personal_allele_likelihoods[y]; } + } + + double sum = 0; + for (int i = 0; i < 4; i++) { sum += allele_likelihoods[i]; } + for (int i = 0; i < 4; i++) { allele_likelihoods[i] /= sum; } + + return allele_likelihoods; + } + + // Potential precision error here. + double Compute_pD(GenotypeLikelihoods[] genotype_likelihoods) + { + double pD = 0; + for (int i = 0; i < sample_names.size(); i++) + { + double sum = 0; + for (int j = 0; j < 10; j++) + { + sum += Math.pow(10, genotype_likelihoods[i].likelihoods[j]); + } + pD += Math.log10(sum); + } + return pD; + } + + double Compute_pNull(LocusContext[] contexts) + { + double[] allele_likelihoods = new double[4]; + for (int i = 0; i < 4; i++) { allele_likelihoods[i] = 1e-6/3.0; } + allele_likelihoods[BaseUtils.simpleBaseToBaseIndex(ref)] = 1.0-1e-6; + GenotypeLikelihoods[] G = new GenotypeLikelihoods[sample_names.size()]; + for (int j = 0; j < sample_names.size(); j++) + { + G[j] = Genotype(contexts[j], allele_likelihoods); + } + return Compute_pD(G); + } + + double LOD(LocusContext[] contexts) + { + EM_Result em_result = EM(contexts); + GenotypeLikelihoods[] G = em_result.genotype_likelihoods; + double pD = Compute_pD(G); + double pNull = Compute_pNull(contexts); + double lod = pD - pNull; + return lod; + } + + class EM_Result + { + GenotypeLikelihoods[] genotype_likelihoods; + double[] allele_likelihoods; + int EM_N; + public EM_Result(GenotypeLikelihoods[] genotype_likelihoods, double[] allele_likelihoods) + { + this.genotype_likelihoods = genotype_likelihoods; + this.allele_likelihoods = allele_likelihoods; + + EM_N = 0; + for (int i = 0; i < genotype_likelihoods.length; i++) + { + if (genotype_likelihoods[i].coverage > 0) { EM_N += 1; } + } + } + } + + EM_Result EM(LocusContext[] contexts) + { + double[] allele_likelihoods = new double[4]; + + // These initial conditions should roughly replicate classic SSG. (at least on hets) + for (int i = 0; i < 4; i++) + { + if (i == BaseUtils.simpleBaseToBaseIndex(ref)) { allele_likelihoods[i] = 0.9994999; } //sqrt(0.999) + else { allele_likelihoods[i] = 0.0005002502; } // 0.001 / (2 * sqrt(0.999) + } + + GenotypeLikelihoods[] G = new GenotypeLikelihoods[sample_names.size()]; + for (int i = 0; i < MAX_ITERATIONS; i++) + { + for (int j = 0; j < sample_names.size(); j++) + { + G[j] = Genotype(contexts[j], allele_likelihoods); + } + allele_likelihoods = CountFreqs(G); + } + + return new EM_Result(G, allele_likelihoods); + } + + // Hacky global variables for debugging. + double pNull_fw; + double pNull_bw; + double pD_fw; + double pD_bw; + double EM_alt_freq_fw; + double EM_alt_freq_bw; + double StrandScore(LocusContext context) + { + LocusContext[] contexts = filterLocusContextBySample(context, sample_names, 0); + + LocusContext fw = filterLocusContextByStrand(context, "+"); + LocusContext bw = filterLocusContextByStrand(context, "-"); + LocusContext[] contexts_fw = filterLocusContextBySample(fw, sample_names, 0); + LocusContext[] contexts_bw = filterLocusContextBySample(bw, sample_names, 0); + + EM_Result em_fw = EM(contexts_fw); + EM_Result em_bw = EM(contexts_bw); + + pNull_fw = Compute_pNull(contexts_fw); + pNull_bw = Compute_pNull(contexts_bw); + + pD_fw = Compute_pD(em_fw.genotype_likelihoods); + pD_bw = Compute_pD(em_bw.genotype_likelihoods); + + EM_alt_freq_fw = Compute_alt_freq(ref, em_fw.allele_likelihoods); + EM_alt_freq_bw = Compute_alt_freq(ref, em_bw.allele_likelihoods); + + double pNull = Compute_pNull(contexts); + + double lod = LOD(contexts); + double lod_fw = (pD_fw + pNull_bw) - pNull; + double lod_bw = (pD_bw + pNull_fw) - pNull; + double strand_score = Math.max(lod_fw - lod, lod_bw - lod); + return strand_score; + } + + GenotypeLikelihoods HardyWeinberg(double[] allele_likelihoods) + { + GenotypeLikelihoods G = new GenotypeLikelihoods(); + int k = 0; + for (int i = 0; i < 4; i++) + { + for (int j = i; j < 4; j++) + { + G.likelihoods[k] = allele_likelihoods[i] * allele_likelihoods[j]; + k++; + } + } + return G; + } + + char PickAlt(char ref, double[] allele_likelihoods) + { + Integer[] perm = Utils.SortPermutation(allele_likelihoods); + if (perm[3] != BaseUtils.simpleBaseToBaseIndex(ref)) { return BaseUtils.baseIndexToSimpleBase(perm[3]); } + else { return BaseUtils.baseIndexToSimpleBase(perm[2]); } + } + + double Compute_discovery_lod(char ref, GenotypeLikelihoods[] genotype_likelihoods) + { + double pBest = 0; + double pRef = 0; + for (int i = 0; i < genotype_likelihoods.length; i++) + { + pBest += genotype_likelihoods[i].BestPosterior(); + pRef += genotype_likelihoods[i].RefPosterior(ref); + } + return pBest - pRef; + } + + // this one is a bit of a lazy hack. + double Compute_alt_freq(char ref, double[] allele_likelihoods) + { + return allele_likelihoods[BaseUtils.simpleBaseToBaseIndex(PickAlt(ref, allele_likelihoods))]; + } + + int Compute_n_ref(char ref, GenotypeLikelihoods[] genotype_likelihoods) + { + int n = 0; + for (int i = 0; i < genotype_likelihoods.length; i++) + { + if (genotype_likelihoods[i].coverage == 0) { continue; } + String g = genotype_likelihoods[i].BestGenotype(); + if ((g.charAt(0) == ref) && (g.charAt(1) == ref)) { n += 1; } + } + return n; + } + + int Compute_n_het(char ref, GenotypeLikelihoods[] genotype_likelihoods) + { + int n = 0; + for (int i = 0; i < genotype_likelihoods.length; i++) + { + if (genotype_likelihoods[i].coverage == 0) { continue; } + String g = genotype_likelihoods[i].BestGenotype(); + if ((g.charAt(0) == ref) && (g.charAt(1) != ref)) { n += 1; } + if ((g.charAt(0) != ref) && (g.charAt(1) == ref)) { n += 1; } + } + return n; + } + + int Compute_n_hom(char ref, GenotypeLikelihoods[] genotype_likelihoods) + { + int n = 0; + for (int i = 0; i < genotype_likelihoods.length; i++) + { + if (genotype_likelihoods[i].coverage == 0) { continue; } + String g = genotype_likelihoods[i].BestGenotype(); + if ((g.charAt(0) != ref) && (g.charAt(1) != ref)) { n += 1; } + } + return n; + } + + // This should actually return a GLF Record + String MultiSampleCall(LocusContext context, List sample_names) + { + LocusContext[] contexts = filterLocusContextBySample(context, sample_names, 0); + double lod = LOD(contexts); + double strand_score = StrandScore(context); + EM_Result em_result = EM(contexts); + GenotypeLikelihoods population_genotype_likelihoods = HardyWeinberg(em_result.allele_likelihoods); + + double pD = Compute_pD(em_result.genotype_likelihoods); + double pNull = Compute_pNull(contexts); + + double discovery_lod = Compute_discovery_lod(ref, em_result.genotype_likelihoods); + double alt_freq = Compute_alt_freq(ref, em_result.allele_likelihoods); + + char alt = 'N'; + if (lod > 0.0) { alt = PickAlt(ref, em_result.allele_likelihoods); } + + int n_ref = Compute_n_ref(ref, em_result.genotype_likelihoods); + int n_het = Compute_n_het(ref, em_result.genotype_likelihoods); + int n_hom = Compute_n_hom(ref, em_result.genotype_likelihoods); + + discovery_output_file.printf("%s %c %c %f %f %f %f %f %s ", context.getLocation(), ref, alt, lod, strand_score, pD, pNull, discovery_lod, in_dbsnp); + for (int i = 0; i < 4; i++) { discovery_output_file.printf("%f ", em_result.allele_likelihoods[i]); } + discovery_output_file.printf("%f %d %d %d %d %f %f %f %f %f %f\n", alt_freq, em_result.EM_N, n_ref, n_het, n_hom, pD_fw, pNull_fw, EM_alt_freq_fw, pD_bw, pNull_bw, EM_alt_freq_bw); + + for (int i = 0; i < em_result.genotype_likelihoods.length; i++) + { + individual_output_file.printf("%s %c %s ", context.getLocation(), ref, sample_names.get(i)); + individual_output_file.printf("%s %f %f %s ", em_result.genotype_likelihoods[i].BestGenotype(), + em_result.genotype_likelihoods[i].LodVsNextBest(), + em_result.genotype_likelihoods[i].LodVsRef(ref), + in_dbsnp); + //individual_output.printf("%s ", new ReadBackedPileup(ref, contexts[i]).getBasePileupAsCountsString()); + assert(em_result.genotype_likelihoods[i] != null); + em_result.genotype_likelihoods[i].sort(); + assert(em_result.genotype_likelihoods[i].sorted_likelihoods != null); + for (int j = 0; j < em_result.genotype_likelihoods[i].sorted_likelihoods.length; j++) + { + individual_output_file.printf("%f ", em_result.genotype_likelihoods[i].likelihoods[j]); + } + individual_output_file.printf("\n"); + } + + return null; + } + + // END Calling Functions + ///////// + + ///////// + // Utility Functions + + /// Filter a locus context by forward and backward + private LocusContext filterLocusContextByStrand(LocusContext context, String strand) + { + ArrayList reads = new ArrayList(); + ArrayList offsets = new ArrayList(); + + for (int i = 0; i < context.getReads().size(); i++) + { + SAMRecord read = context.getReads().get(i); + Integer offset = context.getOffsets().get(i); + + // Filter for strandedness + if ((!strand.contains("+")) && (read.getReadNegativeStrandFlag() == false)) { continue; } + if ((!strand.contains("-")) && (read.getReadNegativeStrandFlag() == true)) { continue; } + reads.add(read); + offsets.add(offset); + } + return new LocusContext(context.getLocation(), reads, offsets); + } + + // Filter a locus context by sample ID + private LocusContext[] filterLocusContextBySample(LocusContext context, List sample_names, int downsample) + { + HashMap index = new HashMap(); + for (int i = 0; i < sample_names.size(); i++) + { + index.put(sample_names.get(i), i); + } + + LocusContext[] contexts = new LocusContext[sample_names.size()]; + ArrayList[] reads = new ArrayList[sample_names.size()]; + ArrayList[] offsets = new ArrayList[sample_names.size()]; + + for (int i = 0; i < sample_names.size(); i++) + { + reads[i] = new ArrayList(); + offsets[i] = new ArrayList(); + } + + for (int i = 0; i < context.getReads().size(); i++) + { + SAMRecord read = context.getReads().get(i); + Integer offset = context.getOffsets().get(i); + String RG = (String)(read.getAttribute("RG")); + + assert(header != null); + assert(header.getReadGroup(RG) != null); + + String sample = header.getReadGroup(RG).getSample(); + if (SAMPLE_NAME_REGEX != null) { sample = sample.replaceAll(SAMPLE_NAME_REGEX, "$1"); } + reads[index.get(sample)].add(read); + offsets[index.get(sample)].add(offset); + } + + if (downsample != 0) + { + for (int j = 0; j < reads.length; j++) + { + List perm = new ArrayList(); + for (int i = 0; i < reads[j].size(); i++) { perm.add(i); } + perm = Utils.RandomSubset(perm, downsample); + + ArrayList downsampled_reads = new ArrayList(); + ArrayList downsampled_offsets = new ArrayList(); + + for (int i = 0; i < perm.size(); i++) + { + downsampled_reads.add(reads[j].get(perm.get(i))); + downsampled_offsets.add(offsets[j].get(perm.get(i))); + } + + reads[j] = downsampled_reads; + offsets[j] = downsampled_offsets; + contexts[j] = new LocusContext(context.getLocation(), reads[j], offsets[j]); + } + } + else + { + for (int j = 0; j < reads.length; j++) + { + contexts[j] = new LocusContext(context.getLocation(), reads[j], offsets[j]); + } + } + + return contexts; + } + + // END Utility functions + ///////// + + + + +} diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/PoolCaller.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/PoolCaller.java index b19155c5a..d04804c9e 100644 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/PoolCaller.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/PoolCaller.java @@ -9,9 +9,9 @@ import org.broadinstitute.sting.gatk.LocusContext; import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker; import org.broadinstitute.sting.gatk.walkers.LocusWalker; import org.broadinstitute.sting.playground.utils.AlleleFrequencyEstimate; -import org.broadinstitute.sting.utils.GenomeLoc; +import org.broadinstitute.sting.utils.*; +import org.broadinstitute.sting.playground.utils.*; import org.broadinstitute.sting.utils.ReadBackedPileup; -import org.broadinstitute.sting.utils.Utils; import org.broadinstitute.sting.utils.cmdLine.Argument; import java.util.*; @@ -21,7 +21,7 @@ import java.io.*; // Draft iterative pooled caller // j.maguire 4-27-2009 -public class PoolCaller extends LocusWalker +public class PoolCaller extends LocusWalker { List callers = null; List sample_names = null; @@ -38,9 +38,6 @@ public class PoolCaller extends LocusWalker private PrintStream discovery_output_file; private PrintStream individual_output_file; - AlleleFrequencyEstimate[] calls; - ArrayList caller_sums; - public void initialize() { try @@ -59,15 +56,8 @@ public class PoolCaller extends LocusWalker this.header = toolkit.getEngine().getSAMHeader(); List read_groups = header.getReadGroups(); - /* - GenomeAnalysisEngine toolkit = this.getToolkit(); - SAMFileHeader header = toolkit.getSamReader().getFileHeader(); - List read_groups = header.getReadGroups(); - */ - sample_names = new ArrayList(); callers = new ArrayList(); - caller_sums = new ArrayList(); random = new Random(42); @@ -93,12 +83,11 @@ public class PoolCaller extends LocusWalker caller.SAMPLE_NAME_REGEX = SAMPLE_NAME_REGEX; caller.initialize(); caller.variantsOut = individual_output_file; - caller_sums.add(caller.reduceInit()); callers.add(caller); } } - public AlleleFrequencyEstimate map(RefMetaDataTracker tracker, char ref, LocusContext context) + public AlleleFrequencyEstimate[] map(RefMetaDataTracker tracker, char ref, LocusContext context) { if (ref == 'N') { return null; } ref = Character.toUpperCase(ref); @@ -110,66 +99,151 @@ public class PoolCaller extends LocusWalker if (forward.getReads().size() == 0) { return null; } if (backward.getReads().size() == 0) { return null; } - AlleleFrequencyEstimate estimate_both = EM(tracker, ref, context); - AlleleFrequencyEstimate estimate_forward = EM(tracker, ref, forward); - AlleleFrequencyEstimate estimate_backward = EM(tracker, ref, backward); + // Pick the alternate base + char alt = 'N'; + { + EM_Result result_both = EM(tracker, ref, context, -1, 'N', 1, lodThreshold, callers); + int[] counts = new int[4]; + if (result_both.individuals == null) { return null; } + for (int i = 0; i < result_both.individuals.length; i++) + { + if (result_both.individuals[i] == null) { continue; } + if (result_both.individuals[i].lodVsRef >= lodThreshold) + { + counts[BaseUtils.simpleBaseToBaseIndex(result_both.individuals[i].alt)] += 1; + } + Integer[] perm = Utils.SortPermutation(counts); + alt = BaseUtils.baseIndexToSimpleBase(perm[3]); + } + } - discovery_output_file.printf("%s %c %f %c %f\n", + double EM_alt_freq; + if (MAX_ITERATIONS == 1) { EM_alt_freq = -1; } + else { EM_alt_freq = 0.5; } + + EM_Result result_both = EM(tracker, ref, context, EM_alt_freq, alt, MAX_ITERATIONS, lodThreshold, callers); + EM_Result result_forward = EM(tracker, ref, forward, EM_alt_freq, alt, MAX_ITERATIONS, lodThreshold, callers); + EM_Result result_backward = EM(tracker, ref, backward, EM_alt_freq, alt, MAX_ITERATIONS, lodThreshold, callers); + + EM_Result null_both = EM(tracker, ref, context, 0, alt, 1, 1e-3, callers); + EM_Result null_forward = EM(tracker, ref, forward, 0, alt, 1, 1e-3, callers); + EM_Result null_backward = EM(tracker, ref, backward, 0, alt, 1, 1e-3, callers); + + if (result_both.pool == null) { return null; } + AlleleFrequencyEstimate estimate_both = result_both.pool; + + double lod_forward; + double lod_backward; + double lod_both; + double strand_score; + char forward_alt; + char backward_alt; + + if ((result_forward.pool == null) || + (result_backward.pool == null) || + (null_both == null) || + (null_forward == null) || + (null_backward == null)) + { + lod_forward = 0; + lod_backward = 0; + lod_both = 0; + strand_score = 0; + forward_alt = 'N'; + backward_alt = 'N'; + } + else + { + AlleleFrequencyEstimate estimate_forward = result_forward.pool; + AlleleFrequencyEstimate estimate_backward = result_backward.pool; + + + double p_D_both = 0; + double p_D_forward = 0; + double p_D_backward = 0; + double p_D_null_both = 0; + double p_D_null_forward = 0; + double p_D_null_backward = 0; + for (int i = 0; i < result_both.individuals.length; i++) + { + double sum_both = 0; + double sum_forward = 0; + double sum_backward = 0; + + double sum_null_both = 0; + double sum_null_forward = 0; + double sum_null_backward = 0; + + for (int j = 0; j < result_both.individuals[i].genotypeLikelihoods.likelihoods.length; j++) + { + sum_both += Math.pow(10, result_both.individuals[i].genotypeLikelihoods.likelihoods[j]); + sum_forward += Math.pow(10, result_forward.individuals[i].genotypeLikelihoods.likelihoods[j]); + sum_backward += Math.pow(10, result_backward.individuals[i].genotypeLikelihoods.likelihoods[j]); + sum_null_both += Math.pow(10, null_both.individuals[i].genotypeLikelihoods.likelihoods[j]); + sum_null_forward += Math.pow(10, null_forward.individuals[i].genotypeLikelihoods.likelihoods[j]); + sum_null_backward += Math.pow(10, null_backward.individuals[i].genotypeLikelihoods.likelihoods[j]); + } + + p_D_both += Math.log10(sum_both); + p_D_forward += Math.log10(sum_forward); + p_D_backward += Math.log10(sum_backward); + + p_D_null_both += Math.log10(sum_null_both); + p_D_null_forward += Math.log10(sum_null_forward); + p_D_null_backward += Math.log10(sum_null_backward); + } + forward_alt = estimate_forward.alt; + backward_alt = estimate_backward.alt; + lod_forward = (p_D_forward + p_D_null_backward) - p_D_null_both; + lod_backward = (p_D_backward + p_D_null_backward) - p_D_null_both; + lod_both = p_D_both - p_D_null_both; + strand_score = Math.max(lod_forward - lod_both, lod_backward - lod_both); + } + + System.out.printf("DBG %s %f %f %f %f\n", context.getLocation(), result_both.pool.pBest, null_both.pool.pBest, result_both.pool.pRef, null_both.pool.pBest); + + discovery_output_file.printf("%s %c %c %f\n", estimate_both.asPoolTabularString(), - estimate_forward.alt, - estimate_forward.lodVsRef, - estimate_backward.alt, - estimate_backward.lodVsRef); - //discovery_output_file.printf("%s\n", estimate_forward.asPoolTabularString()); - //discovery_output_file.printf("%s\n", estimate_backward.asPoolTabularString()); - //discovery_output_file.printf("\n"); + forward_alt, + backward_alt, + strand_score); - return null; + return result_both.individuals; } - private AlleleFrequencyEstimate EM(RefMetaDataTracker tracker, char ref, LocusContext context) + private class EM_Result + { + AlleleFrequencyEstimate pool; + AlleleFrequencyEstimate[] individuals; + + public EM_Result(AlleleFrequencyEstimate pool, AlleleFrequencyEstimate[] individuals) + { + this.pool = pool; + this.individuals = individuals; + } + + // Construct an EM_Result that indicates no data. + public EM_Result() + { + this.pool = null; + this.individuals = null; + } + } + + private EM_Result EM(RefMetaDataTracker tracker, char ref, LocusContext context, double EM_alt_freq, char alt, int MAX_ITERATIONS, double lodThreshold, List callers) { if (context.getReads().size() == 0) { return null; } LocusContext[] contexts = filterLocusContext(context, sample_names, 0); // EM Loop: - double EM_alt_freq; double EM_N = 0; - calls = null; - - // this line is kinda hacky - if (MAX_ITERATIONS == 1) { EM_alt_freq = -1; } - else { EM_alt_freq = 0.5; } + AlleleFrequencyEstimate[] calls = null; // (this loop is the EM cycle) double[] trajectory = new double[MAX_ITERATIONS + 1]; trajectory[0] = EM_alt_freq; double[] likelihood_trajectory = new double[MAX_ITERATIONS + 1]; likelihood_trajectory[0] = 0.0; boolean is_a_snp = false; - // Pick the alternate base - char alt = 'N'; - { - ReadBackedPileup pileup = new ReadBackedPileup(ref, context); - String bases = pileup.getBases(); - int A = 0; - int C = 0; - int G = 0; - int T = 0; - int max_count = -1; - for (int i = 0; i < bases.length(); i++) - { - char b = bases.charAt(i); - if (b == ref) { continue; } - switch (b) - { - case 'A' : A += 1; if (A > max_count) { max_count = A; alt = 'A'; } break; - case 'C' : C += 1; if (C > max_count) { max_count = C; alt = 'C'; } break; - case 'G' : G += 1; if (G > max_count) { max_count = G; alt = 'G'; } break; - case 'T' : T += 1; if (T > max_count) { max_count = T; alt = 'T'; } break; - } - } - } - for (int iterations = 0; iterations < MAX_ITERATIONS; iterations++) { // 6. Re-call from shallow coverage using the estimated frequency as a prior, @@ -193,12 +267,11 @@ public class PoolCaller extends LocusWalker if (! FRACTIONAL_COUNTS) { - //System.out.printf("DBG: %s %f %f\n", - // context.getLocation(), - // calls[i].lodVsNextBest, - // calls[i].lodVsRef); - EM_sum += calls[i].emperical_allele_frequency() * calls[i].N; - EM_N += calls[i].N; + if (Math.abs(calls[i].lodVsRef) >= lodThreshold) + { + EM_sum += calls[i].emperical_allele_frequency() * calls[i].N; + EM_N += calls[i].N; + } } else { @@ -218,14 +291,16 @@ public class PoolCaller extends LocusWalker if (likelihood_trajectory[iterations] == likelihood_trajectory[iterations+1]) { break; } - //System.out.printf("DBGTRAJ %s %f %f %f %f %f %f\n", - // context.getLocation(), - // EM_sum, - // EM_N, - // trajectory[iterations], - // trajectory[iterations+1], - // likelihood_trajectory[iterations], - // likelihood_trajectory[iterations+1]); + /* + System.out.printf("DBGTRAJ %s %f %f %f %f %f %f\n", + context.getLocation(), + EM_sum, + EM_N, + trajectory[iterations], + trajectory[iterations+1], + likelihood_trajectory[iterations], + likelihood_trajectory[iterations+1]); + */ } // 7. Output some statistics. @@ -246,9 +321,10 @@ public class PoolCaller extends LocusWalker { if (calls[i].depth == 0) { continue; } + if (calls[i].lodVsRef < lodThreshold) { continue; } + discovery_likelihood += calls[i].pBest; discovery_null += calls[i].pRef; - //System.out.printf("DBG %f %f %c %s\n", calls[i].pBest, calls[i].pRef, ref, calls[i].bases); if (calls[i].qhat == 0.0) { n_ref += 1; } if (calls[i].qhat == 0.5) { n_het += 1; } @@ -258,6 +334,8 @@ public class PoolCaller extends LocusWalker if (discovery_lod <= 0) { alt = 'N'; } //discovery_output_file.printf("%s %c %c %f %f %f %f %f %f %d %d %d\n", context.getLocation(), ref, alt, EM_alt_freq, discovery_likelihood, discovery_null, discovery_prior, discovery_lod, EM_N, n_ref, n_het, n_hom); + if (EM_N == 0) { return new EM_Result(); } + AlleleFrequencyEstimate estimate = new AlleleFrequencyEstimate(context.getLocation(), ref, alt, @@ -276,7 +354,7 @@ public class PoolCaller extends LocusWalker estimate.n_ref = n_ref; // HACK estimate.n_het = n_het; // HACK estimate.n_hom = n_hom; // HACK - return estimate; + return new EM_Result(estimate, calls); //for (int i = 0; i < likelihood_trajectory.length; i++) @@ -391,14 +469,7 @@ public class PoolCaller extends LocusWalker return contexts; } - private void CollectStrandInformation(char ref, LocusContext context) - { - List reads = context.getReads(); - List offsets = context.getOffsets(); - - } - - public void onTraversalDone(String result) + public void onTraversalDone(String[] result) { try { @@ -411,32 +482,30 @@ public class PoolCaller extends LocusWalker { e.printStackTrace(); } + out.println("PoolCaller done.\n"); return; } - public String reduceInit() + public String[] single_sample_reduce_sums = null; + public String[] reduceInit() { - discovery_output_file.printf("loc ref alt EM_alt_freq discovery_likelihood discovery_null discovery_prior discovery_lod EM_N n_ref n_het n_hom fw_alt fw_lod bw_alt bw_lod\n"); + discovery_output_file.printf("loc ref alt EM_alt_freq discovery_likelihood discovery_null discovery_prior discovery_lod EM_N n_ref n_het n_hom fw_alt bw_alt strand_score\n"); + String[] single_sample_reduce_sums = new String[callers.size()]; for (int i = 0; i < callers.size(); i++) { - callers.get(i).reduceInit(); + single_sample_reduce_sums[i] = callers.get(i).reduceInit(); } - return ""; + return single_sample_reduce_sums; } - public String reduce(AlleleFrequencyEstimate alleleFreq, String sum) + public String[] reduce(AlleleFrequencyEstimate[] alleleFreqs, String[] sum) { - if (calls == null) { return ""; } + if (alleleFreqs == null) { return sum; } for (int i = 0; i < callers.size(); i++) { - if (calls == null) { System.err.printf("calls == null\n"); } - if (calls[i] == null) { System.err.printf("calls[%d] == null\n", i); } - if (caller_sums == null) { System.err.printf("caller_sums == null\n"); } - if (callers.get(i) == null) { System.err.printf("callers[%d] == null\n", i); } - if (caller_sums.get(i) == null) { System.err.printf("caller_sums[%d] == null\n", i); } - caller_sums.set(i, callers.get(i).reduce(calls[i], caller_sums.get(i))); + sum[i] = callers.get(i).reduce(alleleFreqs[i], sum[i]); } - return ""; + return sum; } diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/SingleSampleGenotyper.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/SingleSampleGenotyper.java index 163755a9f..9d43067e6 100644 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/SingleSampleGenotyper.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/SingleSampleGenotyper.java @@ -272,26 +272,12 @@ public class SingleSampleGenotyper extends LocusWalker