From 1fa71aa31d9fada1e88681c0030bae7aeafef5b3 Mon Sep 17 00:00:00 2001 From: jmaguire Date: Tue, 7 Jul 2009 15:29:31 +0000 Subject: [PATCH] Now outputs stats. Doesn't do the downsampling thing because I think I'll have enough counts. git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@1184 348d0f76-0448-11de-a6fe-93d51630548a --- .../MultiSampleCallerAccuracyTest.java | 159 +++++++++++++++--- 1 file changed, 135 insertions(+), 24 deletions(-) diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/MultiSampleCallerAccuracyTest.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/MultiSampleCallerAccuracyTest.java index c1a8f3e70..da07eeea7 100644 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/MultiSampleCallerAccuracyTest.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/MultiSampleCallerAccuracyTest.java @@ -13,6 +13,7 @@ import org.broadinstitute.sting.playground.utils.*; import org.broadinstitute.sting.utils.*; import org.broadinstitute.sting.utils.ReadBackedPileup; import org.broadinstitute.sting.utils.cmdLine.Argument; +import org.broadinstitute.sting.playground.indels.Matrix; import java.util.*; import java.util.zip.*; @@ -24,54 +25,122 @@ import java.io.*; public class MultiSampleCallerAccuracyTest extends MultiSampleCaller { @Argument(required=false, shortName="lod_threshold", doc="") public double LOD_THRESHOLD = 1e-6; + @Argument(required=true, shortName="stats_output", doc="") public String STATS_OUTPUT; + Matrix n_variants; + Matrix n_found; + + PrintStream stats_output; public void initialize() { this.DISCOVERY_OUTPUT = "/dev/null"; this.INDIVIDUAL_OUTPUT = "/dev/null"; + super.initialize(); + + n_variants = new Matrix(sample_names.size()*2, sample_names.size()*2); + n_found = new Matrix(sample_names.size()*2, sample_names.size()*2); + + for (int i = 0; i < sample_names.size()*2; i++) + { + for (int j = 0; j < sample_names.size()*2; j++) + { + n_variants.set(i,j,0); + n_found.set(i,j,0); + } + } + + try + { + stats_output = new PrintStream(STATS_OUTPUT); + } + catch (Exception e) + { + throw new RuntimeException(e); + } + } public MultiSampleCallResult map(RefMetaDataTracker tracker, char ref, LocusContext context) { HapMapGenotypeROD hapmap = (HapMapGenotypeROD)tracker.lookup("hapmap", null); - MultiSampleCallResult call_result = super.map(tracker, ref, context); - EM_Result em_result = call_result.em_result; + // Collect all the variants and the normals. + ArrayList variant_samples = new ArrayList(); + ArrayList reference_samples = new ArrayList(); - // Compute individual accuracy. - double n_calls = 0; - double n_correct = 0; - for (int i = 0; i < em_result.sample_names.length; i++) + int n_ref_chromosomes = 0; + int n_alt_chromosomes = 0; + + String reference_genotype = String.format("%c%c", ref, ref); + for (int i = 0; i < sample_names.size(); i++) { - String sample_name = em_result.sample_names[i]; - String hyp_genotype = em_result.genotype_likelihoods[i].BestGenotype(); - String ref_genotype = hapmap.get(sample_name); - double lod = em_result.genotype_likelihoods[i].LodVsNextBest(); + String true_genotype = hapmap.get(sample_names.get(i)); + if (true_genotype == null) { continue; } - if ((lod > LOD_THRESHOLD) && (ref_genotype != null)) - { - n_calls += 1; - if (hyp_genotype.equals(ref_genotype)) - { - n_correct += 1; - } - } + if (true_genotype.equals(reference_genotype)) { reference_samples.add(sample_names.get(i)); } + else { variant_samples.add(sample_names.get(i)); } + + if (true_genotype.equals(reference_genotype)) { n_ref_chromosomes += 1; } + else if (true_genotype.contains(String.format("%c",ref))) { n_ref_chromosomes += 1; n_alt_chromosomes += 1; } + else { n_alt_chromosomes += 2; } } - out.printf("%s %.0f %.0f %.2f%%\n", - context.getLocation(), - n_calls, - n_correct, - 100.0*n_correct / n_calls); + // Put together a context. + ArrayList working_samples = new ArrayList(); + working_samples.addAll(variant_samples); + working_samples.addAll(reference_samples); + LocusContext working_context = filterLocusContextBySamples(context, working_samples); - return call_result; + // Call. + MultiSampleCallResult call_result = super.map(tracker, ref, working_context); + EM_Result em_result = call_result.em_result; + + // Compute Statistics. + if (n_variants == null) { System.out.printf("n_variants is null\n"); } + if (n_found == null) { System.out.printf("n_found is null\n"); } + n_variants.set(n_ref_chromosomes, n_alt_chromosomes, n_variants.get(n_ref_chromosomes, n_alt_chromosomes)+1); + if ((call_result.lod > LOD_THRESHOLD) && (n_alt_chromosomes >= 1)) + { + n_found.set(n_ref_chromosomes, n_alt_chromosomes, n_found.get(n_ref_chromosomes, n_alt_chromosomes)+1); + } + + return null; + } + + private void PrintStats() + { + stats_output.printf("n_reference_chromosomes n_variant_chromosomes n_sites n_found fraction_found\n"); + for (int i = 0; i < sample_names.size()*2; i++) + { + for (int j = 0; j < sample_names.size()*2; j++) + { + int N = (int)n_variants.get(i,j); + int found = (int)n_found.get(i,j); + + if (N == 0) { continue; } + if (found == 0) { continue; } + + double fraction_found = 100.0 * (double)found / (double)N; + n_variants.set(i,j,0); + n_found.set(i,j,0); + stats_output.printf("%d %d %d %d %f\n", + i, + j, + N, + found, + fraction_found); + } + } } public void onTraversalDone(String sum) { + PrintStats(); + stats_output.flush(); + stats_output.close(); out.println("MultiSampleCallerAccuracyTest done."); return; } @@ -89,4 +158,46 @@ public class MultiSampleCallerAccuracyTest extends MultiSampleCaller // END Walker Interface Functions ///////// + + ///////// + // BEGIN Utility Functions + + // Filter a locus context by sample IDs + // (pulls out only reads from the specified samples, and returns them in one context). + private LocusContext filterLocusContextBySamples(LocusContext context, List sample_names) + { + HashSet index = new HashSet(); + for (int i = 0; i < sample_names.size(); i++) + { + index.add(sample_names.get(i)); + } + + ArrayList reads = new ArrayList(); + ArrayList offsets = new ArrayList(); + + for (int i = 0; i < context.getReads().size(); i++) + { + SAMRecord read = context.getReads().get(i); + Integer offset = context.getOffsets().get(i); + String RG = (String)(read.getAttribute("RG")); + + assert(header != null); + assert(header.getReadGroup(RG) != null); + + String sample = header.getReadGroup(RG).getSample(); + if (SAMPLE_NAME_REGEX != null) { sample = sample.replaceAll(SAMPLE_NAME_REGEX, "$1"); } + + if (index.contains(sample)) + { + reads.add(read); + offsets.add(offset); + } + } + + return new LocusContext(context.getLocation(), reads, offsets); + } + + // END Utility Functions + ///////// + }