/*
 * Provenance (from the original patch header):
 *   commit  dd6d5aadf948b50e681af34cfaf983d73b43a516
 *   author  kiran, Mon, 18 Jan 2010 00:55:12 +0000
 *   subject Computes empirical confusion matrices, optionally with up to five
 *           bases of preceding context
 *   git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@2621
 *               348d0f76-0448-11de-a6fe-93d51630548a
 *   file    java/src/org/broadinstitute/sting/playground/gatk/walkers/diagnostics/ComputeConfusionMatrix.java
 */
package org.broadinstitute.sting.playground.gatk.walkers.diagnostics;

import org.broadinstitute.sting.gatk.walkers.LocusWalker;
import org.broadinstitute.sting.gatk.walkers.Reference;
import org.broadinstitute.sting.gatk.walkers.Window;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
import org.broadinstitute.sting.utils.cmdLine.Argument;
import org.broadinstitute.sting.utils.BaseUtils;
import net.sf.samtools.SAMRecord;

import java.util.HashMap;
import java.util.Arrays;
import java.util.Hashtable;
import java.util.Map;

/**
 * Computes empirical base confusion matrix, and optionally computes
 * these matrices with up to five bases of preceding context.
 *
 * <p>Only loci whose pileup contains exactly one non-reference base are
 * considered (see {@link #filter}). For each such locus the single mismatching
 * base, the reference base and the {@code WINDOW_SIZE} reference bases that
 * precede the locus are tallied. Reads on the negative strand are currently
 * skipped. The tallies are printed as a tab-separated table in
 * {@link #onTraversalDone}.
 *
 * <p>NOTE(review): the type arguments on {@code LocusWalker} are inferred from
 * the {@code map}/{@code reduce} signatures in this file - confirm against the
 * superclass declaration.
 */
@Reference(window=@Window(start=-5,stop=5))
public class ComputeConfusionMatrix extends LocusWalker<Integer, Integer> {
    /** Largest usable context size; must agree with the {@code @Window} annotation on this class. */
    private static final int MAX_CONTEXT_SIZE = 5;

    /** Number of reference bases in an untruncated window: the locus plus the flank on each side. */
    private static final int FULL_WINDOW_LENGTH = 2 * MAX_CONTEXT_SIZE + 1;

    /** Number of base indices scanned when looking for the alternate base. */
    private static final int NUM_SIMPLE_BASES = 4;

    /** Separator between the context, alt and ref components of a count key. */
    private static final String KEY_SEPARATOR = ":";

    @Argument(fullName="minimumDepth", shortName="minDepth", doc="Require locus pileup to have specified minimum depth", required=false)
    public Integer MIN_DEPTH = 10;

    @Argument(fullName="maximumDepth", shortName="maxDepth", doc="Require locus pileup to have specified maximum depth", required=false)
    public Integer MAX_DEPTH = 100;

    @Argument(fullName="contextWindowSize", shortName="window", doc="Size of context window", required=false)
    public Integer WINDOW_SIZE = 0;

    /** Observation counts keyed by {@code "context:alt:ref"}. */
    private final Map<String, Integer> confusionCounts = new Hashtable<String, Integer>();

    /**
     * Decides whether a locus is processed by {@link #map}.
     *
     * <p>A locus passes when the reference base is unambiguous, the reference
     * window is complete, the pileup size lies in
     * {@code [MIN_DEPTH, MAX_DEPTH)} (the upper bound is exclusive), and
     * exactly one base in the pileup differs from the reference.
     *
     * @param tracker reference-ordered data at this locus (unused)
     * @param ref     reference bases around this locus
     * @param context reads overlapping this locus
     * @return {@code true} if the locus should be passed to {@code map}
     * @throws IllegalArgumentException if {@code WINDOW_SIZE} is outside {@code [0, 5]}
     */
    public boolean filter(RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context) {
        checkWindowSize();

        // don't process a locus with an ambiguous reference base
        int refBaseIndex = ref.getBaseIndex();
        if (refBaseIndex < 0) {
            return false;
        }

        // don't process regions that don't have a full context window; a mere
        // odd-length test would accept a window truncated by an even number of
        // bases, for which the centre index no longer points at the locus
        if (ref.getBases().length != FULL_WINDOW_LENGTH) {
            return false;
        }

        // don't process regions without a reasonable pileup, nor suspiciously
        // overcovered regions
        int pileupSize = context.size();
        if (pileupSize < MIN_DEPTH || pileupSize >= MAX_DEPTH) {
            return false;
        }

        int numAlts = 0;
        int[] baseCounts = context.getBasePileup().getBaseCounts();
        for (int baseIndex = 0; baseIndex < baseCounts.length; baseIndex++) {
            if (baseIndex != refBaseIndex) {
                numAlts += baseCounts[baseIndex];
            }
        }

        // don't process regions that have more than one mismatching base
        return numAlts == 1;
    }

    /**
     * Tallies the single mismatching base at this locus, if it was observed on
     * a forward-strand read.
     *
     * @param tracker reference-ordered data at this locus (unused)
     * @param ref     reference bases around this locus
     * @param context reads overlapping this locus
     * @return always {@code null}; results are accumulated in this walker's state
     * @throws IllegalArgumentException if {@code WINDOW_SIZE} is outside {@code [0, 5]}
     */
    public Integer map(RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context) {
        checkWindowSize();

        int[] baseCounts = context.getBasePileup().getBaseCounts();
        int altBaseIndex = findSingletonAltBaseIndex(baseCounts, ref.getBaseIndex());
        if (altBaseIndex < 0) {
            // nothing to tally; also avoids converting index -1 to a base
            return null;
        }

        int windowLength = ref.getBases().length;
        int windowCenter = (windowLength - 1) / 2;

        String fwRefBases = new String(ref.getBases());
        String fwRefBase = String.format("%c", ref.getBase());
        String fwWindowLeft = fwRefBases.substring(windowCenter - WINDOW_SIZE, windowCenter);
        String fwAltBase = String.format("%c", BaseUtils.baseIndexToSimpleBase(altBaseIndex));

        for (int readIndex = 0; readIndex < context.getReads().size(); readIndex++) {
            SAMRecord read = context.getReads().get(readIndex);
            int offset = context.getOffsets().get(readIndex);

            char base = read.getReadString().charAt(offset);
            int baseIndex = BaseUtils.simpleBaseToBaseIndex(base);

            // TODO: negative-strand reads are skipped; tallying them would need
            // the reverse-complemented context, alt and ref bases.
            if (baseIndex == altBaseIndex && !read.getReadNegativeStrandFlag()) {
                incrementConfusionCounts(fwWindowLeft, fwAltBase, fwRefBase);
            }
        }

        return null;
    }

    /**
     * Fails fast when the requested context size cannot be served by the
     * reference window declared on this class.
     *
     * @throws IllegalArgumentException if {@code WINDOW_SIZE} is outside {@code [0, 5]}
     */
    private void checkWindowSize() {
        if (WINDOW_SIZE < 0 || WINDOW_SIZE > MAX_CONTEXT_SIZE) {
            throw new IllegalArgumentException(String.format(
                    "contextWindowSize must be between 0 and %d, but was %d",
                    MAX_CONTEXT_SIZE, WINDOW_SIZE));
        }
    }

    /**
     * Finds the non-reference base that occurs exactly once in the pileup.
     *
     * @param baseCounts   per-base-index counts for the pileup
     * @param refBaseIndex index of the reference base, which is never reported
     * @return the last matching index among the first four, or -1 if there is none
     */
    private static int findSingletonAltBaseIndex(int[] baseCounts, int refBaseIndex) {
        int altBaseIndex = -1;
        int limit = Math.min(NUM_SIMPLE_BASES, baseCounts.length);
        for (int baseIndex = 0; baseIndex < limit; baseIndex++) {
            if (baseIndex != refBaseIndex && baseCounts[baseIndex] == 1) {
                altBaseIndex = baseIndex;
            }
        }
        return altBaseIndex;
    }

    /**
     * Adds one observation for the given context / alt / ref combination.
     *
     * @param context reference bases preceding the locus (may be empty)
     * @param altBase the mismatching base seen in the read
     * @param refBase the reference base at the locus
     */
    private void incrementConfusionCounts(String context, String altBase, String refBase) {
        String key = String.format("%s:%s:%s", context, altBase, refBase);

        Integer counts = confusionCounts.get(key);
        if (counts == null) { counts = 0; }

        confusionCounts.put(key, counts + 1);
    }

    /**
     * @return always {@code null}; this walker does not use the reduce value
     */
    public Integer reduceInit() {
        return null;
    }

    /**
     * @param value result of the latest {@code map} call (ignored)
     * @param sum   running reduce value (ignored)
     * @return always {@code null}; this walker does not use the reduce value
     */
    public Integer reduce(Integer value, Integer sum) {
        return null;
    }

    /**
     * Prints the accumulated counts as a tab-separated table, one row per
     * {@code context:alt:ref} key in sorted key order.
     *
     * <p>Each row holds the key's own count, the total count over all keys
     * sharing the same context and alt base, and the ratio of the two.
     *
     * @param result final reduce value (ignored)
     */
    public void onTraversalDone(Integer result) {
        String[] keys = confusionCounts.keySet().toArray(new String[0]);
        Arrays.sort(keys);

        // Total observations per (context, alt) pair, used as the denominator below.
        Map<String, Integer> contextualNorms = new HashMap<String, Integer>();
        for (String key : keys) {
            String contextualKey = contextualKeyOf(key.split(KEY_SEPARATOR));

            Integer contextualCount = contextualNorms.get(contextualKey);
            if (contextualCount == null) { contextualCount = 0; }
            contextualNorms.put(contextualKey, contextualCount + confusionCounts.get(key));
        }

        // The header names every column printed per row, including the per-key counts.
        out.printf("confusionMatrix\tcontext\talt\tref\tcounts\tcontextualCounts\tcontextualPercentage\n");
        for (String key : keys) {
            String[] fields = key.split(KEY_SEPARATOR);
            Integer count = confusionCounts.get(key);
            Integer norm = contextualNorms.get(contextualKeyOf(fields));

            out.printf(
                "confusionMatrix\t%s\t%s\t%s\t%d\t%d\t%f\n",
                fields[0],
                fields[1],
                fields[2],
                count,
                norm,
                count / ((float) norm)
            );
        }
    }

    /**
     * Builds the normalisation key from the first two components of a split count key.
     *
     * @param fields a count key split on {@code ":"}
     * @return {@code "context:alt"}
     */
    private static String contextualKeyOf(String[] fields) {
        return String.format("%s:%s", fields[0], fields[1]);
    }
}