diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/ReadBackedPhasingWalker.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/ReadBackedPhasingWalker.java index 5d8408621..6e52e1195 100755 --- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/ReadBackedPhasingWalker.java +++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/ReadBackedPhasingWalker.java @@ -54,63 +54,73 @@ import java.util.*; * Walks along all loci, caching a user-defined window of VariantContext sites, and then finishes phasing them when they go out of range (using downstream reads). * Use '-BTI variant' to only stop at positions in the VCF file bound to 'variant'. */ -@Requires(value={},referenceMetaData=@RMD(name="variant",type= ReferenceOrderedDatum.class)) -public class ReadBackedPhasingWalker extends LocusWalker>, VariantContextStats> { +@Requires(value = {}, referenceMetaData = @RMD(name = "variant", type = ReferenceOrderedDatum.class)) +public class ReadBackedPhasingWalker extends LocusWalker { - @Argument(fullName="cacheWindowSize", shortName="cacheWindow", doc="The window size (in bases) to cache variant sites and their reads; [default:300]", required=false) - protected Integer cacheWindow = 300; + @Argument(fullName = "cacheWindowSize", shortName = "cacheWindow", doc = "The window size (in bases) to cache variant sites and their reads; [default:20000]", required = false) + protected Integer cacheWindow = 20000; - @Argument(fullName="phasedVCFFile", shortName="phasedVCF", doc="The name of the phased VCF file output", required=true) + @Argument(fullName = "phasedVCFFile", shortName = "phasedVCF", doc = "The name of the phased VCF file output", required = true) protected String phasedVCFFile = null; private VCFWriter writer = null; - private LinkedList siteQueue = null; + private LinkedList siteQueue = null; + private VariantAndReads prevVr = null; // the VC emitted after phasing, and the alignment bases at the position emitted + + private static double SMALL_THRESH = 1e-6; + private static int MAX_NUM_PHASE_SITES = 20; // 2^20 == 10^6 biallelic haplotypes private void initializeVcfWriter(VariantContext vc) { // setup the header fields Set hInfo = new HashSet(); hInfo.addAll(VCFUtils.getHeaderFields(getToolkit())); hInfo.add(new VCFHeaderLine("reference", getToolkit().getArguments().referenceFile.getName())); + hInfo.add(new VCFFormatHeaderLine("PQ", 1, VCFHeaderLineType.Float, "Read-backed phasing quality score")); writer = new VCFWriterImpl(new File(phasedVCFFile)); writer.writeHeader(new VCFHeader(hInfo, new TreeSet(vc.getSampleNames()))); } public void initialize() { - siteQueue = new LinkedList(); + siteQueue = new LinkedList(); + prevVr = new VariantAndReads(null, null, true); } - public boolean generateExtendedEvents() { // want to see indels - return true; + public boolean generateExtendedEvents() { + return false; } - public VariantContextStats reduceInit() { return new VariantContextStats(); } + public PhasingStats reduceInit() { + return new PhasingStats(); + } /** * For each site of interest, cache the current site and then use the cache to phase all upstream sites * for which "sufficient" information has already been observed. * - * @param tracker the meta-data tracker - * @param ref the reference base - * @param context the context for the given locus + * @param tracker the meta-data tracker + * @param ref the reference base + * @param context the context for the given locus * @return statistics of and list of all phased VariantContexts and their base pileup that have gone out of cacheWindow range. */ - public Pair> map(RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context) { - VariantContextStats vcStats = new VariantContextStats(); - List phasedList = new LinkedList(); - if ( tracker == null ) - return new Pair>(vcStats, phasedList); + public PhasingStatsAndOutput map(RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context) { + if (tracker == null) + return null; - List rods = tracker.getReferenceMetaData("variant"); - ListIterator rodIt = rods.listIterator(); - while (rodIt.hasNext()) { - VariantContext vc = VariantContextAdaptors.toVariantContext("variant", rodIt.next(), ref); - if (vc.getType() == VariantContext.Type.MNP) { - throw new StingException("Doesn't support phasing for multiple-nucleotide polymorphism!"); + PhasingStats phaseStats = new PhasingStats(); + + LinkedList rodNames = new LinkedList(); + rodNames.add("variant"); + boolean requireStartHere = true; // only see each VariantContext once + boolean takeFirstOnly = false; + for (VariantContext vc : tracker.getVariantContexts(ref, rodNames, null, context.getLocation(), requireStartHere, takeFirstOnly)) { + boolean processVariant = true; + if (!vc.isSNP() || !vc.isBiallelic()) { + processVariant = false; } - VariantAndAlignment va = new VariantAndAlignment(vc, context); - siteQueue.add(va); + VariantAndReads vr = new VariantAndReads(vc, context, processVariant); + siteQueue.add(vr); int numReads = 0; if (context.hasBasePileup()) { @@ -119,116 +129,251 @@ public class ReadBackedPhasingWalker extends LocusWalker phasedList = processQueue(ref.getLocus(), phaseStats); - GenomeLoc refLoc = ref.getLocus(); - while (!siteQueue.isEmpty()) { - VariantContext vc = siteQueue.peek().variant; - if (!isInWindowRange(refLoc, VariantContextUtils.getLocation(vc))) { // Already saw all variant positions within cacheWindow distance ahead of vc (on its contig) - VariantContext phasedVc = this.phaseVariantAndRemove(); - phasedList.add(phasedVc); - } - else { // refLoc is still not far enough ahead of vc - break; // since we ASSUME that the VCF is ordered by - } - } - - return new Pair>(vcStats, phasedList); + return new PhasingStatsAndOutput(phaseStats, phasedList); } - /* Phase vc (head of siteQueue) using all VariantContext objects in the siteQueue that are + private List processQueue(GenomeLoc loc, PhasingStats phaseStats) { + List vcList = new LinkedList(); + + while (!siteQueue.isEmpty()) { + if (loc != null) { + VariantContext vc = siteQueue.peek().variant; + if (isInWindowRange(loc, VariantContextUtils.getLocation(vc))) { + // loc is still not far enough ahead of vc (since we ASSUME that the VCF is ordered by ) + break; + } + // Already saw all variant positions within cacheWindow distance ahead of vc (on its contig) + } + VariantContext phasedVc = finalizePhasingAndRemove(phaseStats); + vcList.add(phasedVc); + } + return vcList; + } + + /* Finalize phasing of vc (head of siteQueue) using all VariantContext objects in the siteQueue that are within cacheWindow distance ahead of vc (on its contig). ASSUMES: 1. siteQueue is NOT empty. 2. All VariantContexts in siteQueue are in positions downstream of vc (head of queue). */ - private VariantContext phaseVariantAndRemove() { - VariantContext vc = siteQueue.peek().variant; - ListIterator windowIt = siteQueue.listIterator(); - int toIndex = 0; - while (windowIt.hasNext()) { - if (isInWindowRange(vc, windowIt.next().variant)) { - toIndex++; + private VariantContext finalizePhasingAndRemove(PhasingStats phaseStats) { + VariantAndReads vr = siteQueue.remove(); // remove vr from head of queue + VariantContext vc = vr.variant; + if (!vr.processVariant) + return vc; // return vc as is + + boolean hasPreviousSite = previousIsRelevantTo(vc); + + logger.debug("Will phase vc = " + VariantContextUtils.getLocation(vc)); + + LinkedList windowVaList = new LinkedList(); + if (hasPreviousSite) { + windowVaList.add(prevVr); // need to add one position for phasing context + windowVaList.add(vr); // add position to be phased + for (VariantAndReads nextVr : siteQueue) { + if (!isInWindowRange(vc, nextVr.variant)) //nextVr too far ahead of the range used for phasing vc + break; + if (nextVr.processVariant) // include in the phasing computation + windowVaList.add(nextVr); } - else { //moved past the relevant range used for phasing - break; + + if (logger.isDebugEnabled()) { + ListIterator windowVcIt = windowVaList.listIterator(); + while (windowVcIt.hasNext()) { + VariantContext phaseInfoVc = windowVcIt.next().variant; + logger.debug("Using phaseInfoVc = " + VariantContextUtils.getLocation(phaseInfoVc)); + } } } - List windowVcList = siteQueue.subList(0,toIndex); - // - if (true) { - out.println("Will phase vc = " + VariantContextUtils.getLocation(vc)); - ListIterator windowVcIt = windowVcList.listIterator(); - while (windowVcIt.hasNext()) { - VariantContext phaseInfoVc = windowVcIt.next().variant; - out.println("Using phaseInfoVc = " + VariantContextUtils.getLocation(phaseInfoVc)); - } - out.println(""); - } - // + logger.debug(""); Map sampGenotypes = vc.getGenotypes(); + VariantContext prevVc = prevVr.variant; Map phasedGtMap = new TreeMap(); - for (Map.Entry entry : sampGenotypes.entrySet()) { - String samp = entry.getKey(); - Genotype gt = entry.getValue(); + // Perform per-sample phasing: + TreeMap samplePhaseStats = new TreeMap(); + for (Map.Entry sampGtEntry : sampGenotypes.entrySet()) { + logger.debug("sample = " + sampGtEntry.getKey()); - if (gt.getPloidy() != 2) { - throw new StingException("Doesn't support phasing for ploidy that is not 2!"); - } - Allele topAll = gt.getAllele(0); - Allele botAll = gt.getAllele(1); + boolean genotypesArePhased = true; // phase by default - ListIterator windowVcIt = windowVcList.listIterator(); - while (windowVcIt.hasNext()) { - VariantAndAlignment va = windowVcIt.next(); - VariantContext phaseInfoVc = va.variant; - AlignmentContext phaseInfoContext = va.alignment; + String samp = sampGtEntry.getKey(); + Genotype gt = sampGtEntry.getValue(); + Biallele biall = new Biallele(gt); + HashMap gtAttribs = new HashMap(gt.getAttributes()); - ReadBackedPileup reads = null; - if (phaseInfoContext.hasBasePileup()) { - reads = phaseInfoContext.getBasePileup(); + if (hasPreviousSite && gt.isHet() && prevVc.getGenotype(samp).isHet()) { //otherwise, can trivially phase + logger.debug("NON-TRIVIALLY CARE about TOP vs. BOTTOM for: "); + logger.debug("\n" + biall); + + LinkedList sampleWindowVaList = new LinkedList(); + for (VariantAndReads phaseInfoVr : windowVaList) { + VariantContext phaseInfoVc = phaseInfoVr.variant; + Genotype phaseInfoGt = phaseInfoVc.getGenotype(samp); + if (phaseInfoGt.isHet()) { // otherwise, of no value to phasing + sampleWindowVaList.add(phaseInfoVr); + logger.debug("STARTING TO PHASE USING POS = " + VariantContextUtils.getLocation(phaseInfoVc)); + } } - else if (phaseInfoContext.hasExtendedEventPileup()) { - reads = phaseInfoContext.getExtendedEventPileup(); - } - if (reads != null) { - ReadBackedPileup sampleReads = null; - if (reads.getSamples().contains(samp)) { - // Update the phasing table based on the reads for this sample: - sampleReads = reads.getPileupForSample(samp); - for (PileupElement p : sampleReads) { - SAMRecord rd = p.getRead(); - out.println("read = " + rd); + if (sampleWindowVaList.size() > MAX_NUM_PHASE_SITES) + logger.warn("Trying to phase within a window of " + cacheWindow + " bases yields " + sampleWindowVaList.size() + " heterozygous sites to phase -- EXPECT DELAYS!"); + + PhasingTable sampleHaps = new PhasingTable(); + + // Initialize phasing table with appropriate entries: + // + // 1. THIS IMPLEMENTATION IS INEFFICIENT SINCE IT DOES NOT PREALLOCATE + // THE ArrayList USED, BUT RATHER APPENDS TO IT EACH TIME. + // + // 2. THIS IMPLEMENTATION WILL FAIL WHEN NOT DEALING WITH SNP Alleles, SINCE THEN THE Allele.getBases() + // FUNCTION WILL RETURN VARIABLE-LENGTH Byte ARRAYS. IN THAT CASE, BaseArray WILL NEED TO BE CONVERTED TO + // AN ArrayList OF Allele [OR SIMILAR OBJECT] + for (VariantAndReads phaseInfoVr : sampleWindowVaList) { + VariantContext phaseInfoVc = phaseInfoVr.variant; + Genotype phaseInfoGt = phaseInfoVc.getGenotype(samp); + + if (sampleHaps.isEmpty()) { + for (Allele sampAll : phaseInfoGt.getAlleles()) { + sampleHaps.addEntry(new Haplotype(sampAll.getBases())); + } + } + else { + PhasingTable oldHaps = sampleHaps; + Iterator oldHapIt = oldHaps.iterator(); + sampleHaps = new PhasingTable(); + while (oldHapIt.hasNext()) { + PhasingTable.PhasingTableEntry pte = oldHapIt.next(); + Haplotype oldHap = pte.getHaplotype(); + for (Allele sampAll : phaseInfoGt.getAlleles()) { + ArrayList bases = oldHap.cloneBaseArrayList(); + for (byte b : sampAll.getBases()) { // LENGTH NOT PRE-DEFINED FOR NON-SNPs (MNP or INDELS!!) + bases.add(b); // INEFFICIENT! + } + Haplotype newHap = new Haplotype(BaseArray.getBasesPrimitiveNoNulls(bases)); + sampleHaps.addEntry(newHap); + } } } } + + // Assemble the "sub-reads" from the heterozygous positions for this sample: + LinkedList allPositions = new LinkedList(); + for (VariantAndReads phaseInfoVr : sampleWindowVaList) { + ReadBasesAtPosition readBases = phaseInfoVr.sampleReadBases.get(samp); + allPositions.add(readBases); + } + HashMap allReads = convertReadBasesAtPositionToReads(allPositions); + logger.debug("Number of reads at sites: " + allReads.size()); + + // Update the phasing table based on each of the sub-reads for this sample: + int numUsableReads = 0; + for (Map.Entry nameToReads : allReads.entrySet()) { + Read rd = nameToReads.getValue(); + if (rd.numNonNulls() <= 1) {// can't possibly provide any phasing information + continue; + } + if (false) + logger.debug("rd = " + rd + "\tname = " + nameToReads.getKey()); + + LinkedList compatHaps = new LinkedList(); + Iterator hapIt = sampleHaps.iterator(); + while (hapIt.hasNext()) { + PhasingTable.PhasingTableEntry pte = hapIt.next(); + if (rd.isCompatible(pte.getHaplotype())) + compatHaps.add(pte); + } + + if (!compatHaps.isEmpty()) { // otherwise, nothing to do + numUsableReads++; + double addVal = rd.matchScore() / compatHaps.size(); // don't overcount, so divide up the score evenly + for (PhasingTable.PhasingTableEntry pte : compatHaps) { + pte.addScore(addVal); + + if (false) { + if (addVal > SMALL_THRESH) { + logger.debug("score(" + rd + "," + pte.getHaplotype() + ") = " + addVal); + } + } + } + } + } + logger.debug("\nPhasing table [AFTER CALCULATION]:\n" + sampleHaps + "\n"); + logger.debug("numUsableReads = " + numUsableReads); + + /* Map a phase and its "complement" to a single representative phase, but marginalized to the first 2 positions + [i.e., the previous position and the current position]: + */ + ComplementAndMarginalizeHaplotypeMapper cmhm = new ComplementAndMarginalizeHaplotypeMapper(sampleWindowVaList, samp, 2); + sampleHaps = sampleHaps.mapHaplotypes(cmhm); + + logger.debug("\nPhasing table [AFTER MAPPING]:\n" + sampleHaps + "\n"); + + // Determine the phase at this position: + sampleHaps.normalizeScores(); + PhasingTable.PhasingTableEntry maxEntry = sampleHaps.maxEntry(); + double score = maxEntry.getScore(); + genotypesArePhased = (score > SMALL_THRESH); + + if (genotypesArePhased) { + Biallele prevBiall = new Biallele(prevVc.getGenotype(samp)); + ensurePhasing(maxEntry.getHaplotype(), biall, prevBiall); + gtAttribs.put("PQ", new Float(score)); + + logger.debug("CHOSE hap:\t" + maxEntry.getHaplotype() + "\tscore:\t" + score); + logger.debug("PHASED:\n" + biall + "\n\n"); + } + + PhaseCounts sampPhaseCounts = samplePhaseStats.get(samp); + if (sampPhaseCounts == null) { + sampPhaseCounts = new PhaseCounts(); + samplePhaseStats.put(samp, sampPhaseCounts); + } + sampPhaseCounts.numTestedSites++; + sampPhaseCounts.numPhased += (genotypesArePhased ? 1 : 0); } - Random rn = new Random(); - boolean genotypesArePhased = (rn.nextDouble() > 0.5); - - boolean swapChromosomes = (rn.nextDouble() > 0.5); - if (swapChromosomes) { - Allele tmp = topAll; - topAll = botAll; - botAll = tmp; - } - List phasedAll = new ArrayList(); - phasedAll.add(0, topAll); - phasedAll.add(1, botAll); - - Genotype phasedGt = new Genotype(gt.getSampleName(), phasedAll, gt.getNegLog10PError(), gt.getFilters(), gt.getAttributes(), genotypesArePhased); + List phasedAll = biall.getAllelesAsList(); + Genotype phasedGt = new Genotype(gt.getSampleName(), phasedAll, gt.getNegLog10PError(), gt.getFilters(), gtAttribs, genotypesArePhased); phasedGtMap.put(samp, phasedGt); } - siteQueue.remove(); // remove vc from head of queue - return new VariantContext(vc.getName(), vc.getChr(), vc.getStart(), vc.getEnd(), vc.getAlleles(), phasedGtMap, vc.getNegLog10PError(), vc.getFilters(), vc.getAttributes()); + VariantContext phasedVc = new VariantContext(vc.getName(), vc.getChr(), vc.getStart(), vc.getEnd(), vc.getAlleles(), phasedGtMap, vc.getNegLog10PError(), vc.getFilters(), vc.getAttributes()); + prevVr.variant = phasedVc; + prevVr.sampleReadBases = vr.sampleReadBases; + + phaseStats.addIn(new PhasingStats(samplePhaseStats)); + + return phasedVc; + } + + /* + Ensure that curBiall is phased relative to prevBiall as specified by hap. + */ + + public static void ensurePhasing(Haplotype hap, Biallele curBiall, Biallele prevBiall) { + if (hap.size() < 2) + throw new StingException("LOGICAL ERROR: Only considering haplotypes of length > 2!"); + + byte prevBase = hap.getBase(0); // The 1st base in the haplotype + byte curBase = hap.getBase(1); // The 2nd base in the haplotype + + boolean chosePrevTopChrom = prevBiall.matchesTopBase(prevBase); + boolean choseCurTopChrom = curBiall.matchesTopBase(curBase); + if (chosePrevTopChrom != choseCurTopChrom) + curBiall.swapAlleles(); + } + + private boolean previousIsRelevantTo(VariantContext vc) { + VariantContext prevVc = prevVr.variant; + return (prevVc != null && VariantContextUtils.getLocation(prevVc).onSameContig(VariantContextUtils.getLocation(vc))); } private boolean isInWindowRange(VariantContext vc1, VariantContext vc2) { @@ -243,7 +388,7 @@ public class ReadBackedPhasingWalker extends LocusWalker> statsAndList, VariantContextStats stats) { - Iterator varContIter = statsAndList.second.iterator(); - writeVarContIter(varContIter); - - stats.addTo(statsAndList.first); + public PhasingStats reduce(PhasingStatsAndOutput statsAndList, PhasingStats stats) { + if (statsAndList != null) { + writeVarContIter(statsAndList.output.iterator()); + stats.addIn(statsAndList.ps); + } return stats; } /** * Phase anything left in the cached siteQueue, and report the number of reads and VariantContexts processed. * - * @param result the number of reads and VariantContexts seen. + * @param result the number of reads and VariantContexts seen. */ - public void onTraversalDone(VariantContextStats result) { - List finalList = new LinkedList(); - while (!siteQueue.isEmpty()) { - VariantContext phasedVc = this.phaseVariantAndRemove(); - finalList.add(phasedVc); - } + public void onTraversalDone(PhasingStats result) { + List finalList = processQueue(null, result); writeVarContIter(finalList.iterator()); - - if ( writer != null ) + if (writer != null) writer.close(); out.println("Number of reads observed: " + result.getNumReads()); - out.println("Number of variant sites observed: " + result.getNumVarSites()); + out.println("Number of variant sites observed: " + result.getNumVarSites()); out.println("Average coverage: " + ((double) result.getNumReads() / result.getNumVarSites())); + + out.println("\n-- Phasing summary --"); + for (Map.Entry sampPhaseCountEntry : result.getPhaseCounts()) { + PhaseCounts pc = sampPhaseCountEntry.getValue(); + out.println("Sample: " + sampPhaseCountEntry.getKey() + "\tNumber of tested sites: " + pc.numTestedSites + "\tNumber of phased sites: " + pc.numPhased); + } + out.println(""); } protected void writeVarContIter(Iterator varContIter) { - while (varContIter.hasNext()) { - VariantContext vc = varContIter.next(); - writeVCF(vc); - } + while (varContIter.hasNext()) { + VariantContext vc = varContIter.next(); + writeVCF(vc); + } } - private static class VariantAndAlignment { - public VariantContext variant; - public AlignmentContext alignment; + protected static HashMap convertReadBasesAtPositionToReads(Collection basesAtPositions) { + HashMap reads = new HashMap(); - public VariantAndAlignment(VariantContext variant, AlignmentContext alignment) { + int index = 0; + for (ReadBasesAtPosition rbp : basesAtPositions) { + Iterator readBaseIt = rbp.iterator(); + while (readBaseIt.hasNext()) { + ReadBasesAtPosition.ReadBase rb = readBaseIt.next(); + String readName = rb.readName; + byte base = rb.base; + + Read rd = reads.get(readName); + if (rd == null) { + rd = new Read(basesAtPositions.size()); + reads.put(readName, rd); + } + rd.updateBase(index, base); + } + index++; + } + + return reads; + } + + + /* + Inner classes: + */ + + private static class Biallele { + public Allele top; + public Allele bottom; + + public Biallele(Genotype gt) { + if (gt.getPloidy() != 2) + throw new StingException("Doesn't support phasing for ploidy that is not 2!"); + + this.top = gt.getAllele(0); + this.bottom = gt.getAllele(1); + } + + public void swapAlleles() { + Allele tmp = top; + top = bottom; + bottom = tmp; + } + + public List getAllelesAsList() { + List allList = new ArrayList(2); + allList.add(0, top); + allList.add(1, bottom); + return allList; + } + + public byte getTopBase() { + byte[] topBases = top.getBases(); + if (topBases.length != 1) + throw new StingException("LOGICAL ERROR: should not process non-SNP sites!"); + + return topBases[0]; + } + + public byte getBottomBase() { + byte[] bottomBases = bottom.getBases(); + if (bottomBases.length != 1) + throw new StingException("LOGICAL ERROR: should not process non-SNP sites!"); + + return bottomBases[0]; + } + + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append("Top:\t" + top.getBaseString() + "\n"); + sb.append("Bot:\t" + bottom.getBaseString() + "\n"); + return sb.toString(); + } + + public boolean matchesTopBase(byte base) { + boolean matchesTop; + if (base == getTopBase()) + matchesTop = true; + else if (base == getBottomBase()) + matchesTop = false; + else + throw new StingException("LOGICAL ERROR: base MUST match either TOP or BOTTOM!"); + + return matchesTop; + } + + public byte getOtherBase(byte base) { + byte topBase = getTopBase(); + byte botBase = getBottomBase(); + + if (base == topBase) + return botBase; + else if (base == botBase) + return topBase; + else + throw new StingException("LOGICAL ERROR: base MUST match either TOP or BOTTOM!"); + } + } + + private static class VariantAndReads { + public VariantContext variant; + public HashMap sampleReadBases; + public boolean processVariant; + + public VariantAndReads(VariantContext variant, AlignmentContext alignment, boolean processVariant) { this.variant = variant; - this.alignment = alignment; + this.sampleReadBases = new HashMap(); + this.processVariant = processVariant; + + if (alignment != null) { + ReadBackedPileup pileup = null; + if (alignment.hasBasePileup()) { + pileup = alignment.getBasePileup(); + } + else if (alignment.hasExtendedEventPileup()) { + pileup = alignment.getExtendedEventPileup(); + } + if (pileup != null) { + for (String samp : pileup.getSamples()) { + ReadBackedPileup samplePileup = pileup.getPileupForSample(samp); + ReadBasesAtPosition readBases = new ReadBasesAtPosition(); + for (PileupElement p : samplePileup) { + if (!p.isDeletion()) + readBases.putReadBase(p.getRead().getReadName(), p.getBase()); + } + sampleReadBases.put(samp, readBases); + } + } + } + } + } + + private static class ReadBasesAtPosition { + // list of: + private LinkedList bases; + + public ReadBasesAtPosition() { + this.bases = new LinkedList(); + } + + public void putReadBase(String readName, byte b) { + bases.add(new ReadBase(readName, b)); + } + + public Iterator iterator() { + return bases.iterator(); + } + + private static class ReadBase { + public String readName; + public byte base; + + public ReadBase(String readName, byte base) { + this.readName = readName; + this.base = base; + } + } + } + + private static abstract class HaplotypeMapper { + abstract Haplotype map(Haplotype hap); + } + + private static class ComplementAndMarginalizeHaplotypeMapper extends HaplotypeMapper { + private List vaList; + private String sample; + private int marginalizeLength; + + public ComplementAndMarginalizeHaplotypeMapper(List vaList, String sample, int marginalizeLength) { + this.vaList = vaList; + this.sample = sample; + this.marginalizeLength = marginalizeLength; + } + + public Haplotype map(Haplotype hap) { + if (hap.size() != vaList.size()) + throw new StingException("INTERNAL ERROR: hap.size() != vaList.size()"); + + Biallele firstPosBiallele = new Biallele(vaList.get(0).variant.getGenotype(sample)); + if (firstPosBiallele.matchesTopBase(hap.getBase(0))) { + /* hap already matches the representative haplotype [arbitrarily defined to be + the one with the top base in the VariantContext at the 1st position]: + */ + return hap.subHaplotype(0, marginalizeLength); // only want first marginalizeLength positions + } + + if (false) + logger.debug("hap = " + hap); + + // Take the other base at EACH position of the Haplotype: + byte[] complementBases = new byte[Math.min(hap.size(), marginalizeLength)]; + int index = 0; + for (VariantAndReads vr : vaList) { + VariantContext vc = vr.variant; + Biallele biall = new Biallele(vc.getGenotype(sample)); + + if (false) + logger.debug("biall =\n" + biall); + + complementBases[index] = biall.getOtherBase(hap.getBase(index)); + if (++index == marginalizeLength) // only want first marginalizeLength positions + break; + } + return new Haplotype(complementBases); + } + } + + private static class PhasingTable { + private LinkedList table; + + public PhasingTable() { + this.table = new LinkedList(); + } + + public PhasingTableEntry addEntry(Haplotype haplotype) { + PhasingTableEntry pte = new PhasingTableEntry(haplotype); + table.add(pte); + return pte; + } + + public Iterator iterator() { + return table.iterator(); + } + + public boolean isEmpty() { + return table.isEmpty(); + } + + public PhasingTableEntry maxEntry() { + if (table.isEmpty()) + return null; + + PhasingTableEntry maxPte = null; + for (PhasingTableEntry pte : table) { + if (maxPte == null || pte.getScore() > maxPte.getScore()) { + maxPte = pte; + } + } + return maxPte.clone(); + } + + // Assumes that scores are NON-NEGATIVE: + + public void normalizeScores() { + double normalizeBy = 0.0; + for (PhasingTableEntry pte : table) { + normalizeBy += pte.getScore(); + } + logger.debug("normalizeBy = " + normalizeBy); + + if (normalizeBy > SMALL_THRESH) { // otherwise, will have precision problems + for (PhasingTableEntry pte : table) { + pte.setScore(pte.getScore() / normalizeBy); + } + } + } + + public PhasingTable mapHaplotypes(HaplotypeMapper hm) { + class Score { + private double d; + + Score(double d) { + this.d = d; + } + + Score addValue(double v) { + d += v; + return this; + } + + double value() { + return d; + } + } + TreeMap hapMap = new TreeMap(); + + Iterator entryIt = iterator(); + while (entryIt.hasNext()) { + PhasingTableEntry pte = entryIt.next(); + Haplotype rep = hm.map(pte.getHaplotype()); + if (false) + logger.debug("MAPPED: " + pte.getHaplotype() + " -> " + rep); + + Score score = hapMap.get(rep); + if (score == null) { + score = new Score(0.0); + hapMap.put(rep, score); + } + score.addValue(pte.getScore()); + } + + PhasingTable combo = new PhasingTable(); + for (Map.Entry hapScore : hapMap.entrySet()) { + combo.addEntry(hapScore.getKey()).addScore(hapScore.getValue().value()); + } + return combo; + } + + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append("-------------------\n"); + Iterator hapIt = iterator(); + while (hapIt.hasNext()) { + PhasingTable.PhasingTableEntry pte = hapIt.next(); + sb.append("Haplotype:\t" + pte.getHaplotype() + "\tScore:\t" + pte.getScore() + "\n"); + } + sb.append("-------------------\n"); + return sb.toString(); + } + + public static class PhasingTableEntry implements Comparable, Cloneable { + private Haplotype haplotype; + private double score; + + public PhasingTableEntry(Haplotype haplotype) { + this.haplotype = haplotype; + this.score = 0.0; + } + + public PhasingTableEntry(PhasingTableEntry other) { + this.haplotype = other.haplotype.clone(); + this.score = other.score; + } + + public PhasingTableEntry clone() { + try { + super.clone(); + } catch (CloneNotSupportedException e) { + e.printStackTrace(); //To change body of catch statement use File | Settings | File Templates. + } + return new PhasingTableEntry(this); + } + + public Haplotype getHaplotype() { + return haplotype; + } + + public double getScore() { + return score; + } + + public double addScore(double addVal) { + score += addVal; + return score; + } + + private double setScore(double newVal) { + score = newVal; + return score; + } + + public int compareTo(PhasingTableEntry that) { + return new Double(Math.signum(this.score - that.score)).intValue(); + } } } } -class VariantContextStats { +abstract class BaseArray implements Comparable { + protected ArrayList bases; + + public BaseArray(byte[] bases) { + this.bases = new ArrayList(bases.length); + for (int i = 0; i < bases.length; i++) { + this.bases.add(i, bases[i]); + } + } + + public BaseArray(Byte[] bases) { + this.bases = new ArrayList(bases.length); + for (int i = 0; i < bases.length; i++) { + this.bases.add(i, bases[i]); + } + } + + public BaseArray(int length) { + this(newNullArray(length)); + } + + static Byte[] newNullArray(int length) { + Byte[] bArr = new Byte[length]; + Arrays.fill(bArr, null); + return bArr; + } + + public void updateBase(int index, Byte base) { + bases.set(index, base); + } + + public Byte getBase(int index) { + return bases.get(index); + } + + public int size() { + return bases.size(); + } + + public static Byte[] getBases(List baseList) { + return baseList.toArray(new Byte[baseList.size()]); + } + + // Will thow NullPointerException if baseList contains Byte == null: + + public static byte[] getBasesPrimitiveNoNulls(List baseList) { + int sz = baseList.size(); + byte[] b = new byte[sz]; + for (int i = 0; i < sz; i++) { + b[i] = baseList.get(i); + } + return b; + } + + public ArrayList cloneBaseArrayList() { + return new ArrayList(bases); + } + + public String toString() { + StringBuilder sb = new StringBuilder(bases.size()); + for (Byte b : bases) { + sb.append(b != null ? (char) b.byteValue() : "_"); + } + return sb.toString(); + } + + public int compareTo(BaseArray that) { + int sz = this.bases.size(); + if (sz != that.bases.size()) + return (sz - that.bases.size()); + + for (int i = 0; i < sz; i++) { + Byte thisBase = this.getBase(i); + Byte thatBase = that.getBase(i); + if (thisBase == null || thatBase == null) { + if (thisBase == null && thatBase != null) { + return -1; + } + else if (thisBase != null && thatBase == null) { + return 1; + } + } + else if (!thisBase.equals(thatBase)) { + return thisBase - thatBase; + } + } + return 0; + } +} + +class Haplotype extends BaseArray implements Cloneable { + public Haplotype(byte[] bases) { + super(bases); + } + + private Haplotype(Byte[] bases) { + super(bases); + } + + private Haplotype(int length) { + super(length); + } + + public Haplotype(Haplotype other) { + this(getBases(other.bases)); + } + + public void updateBase(int index, Byte base) { + if (base == null) { + throw new StingException("Internal error: should NOT put null for a missing Haplotype base!"); + } + super.updateBase(index, base); + } + + public Haplotype clone() { + try { + super.clone(); + } catch (CloneNotSupportedException e) { + e.printStackTrace(); //To change body of catch statement use File | Settings | File Templates. + } + return new Haplotype(this); + } + + // Returns a new Haplotype containing the portion of this Haplotype between the specified fromIndex, inclusive, and toIndex, exclusive. + + public Haplotype subHaplotype(int fromIndex, int toIndex) { + return new Haplotype(BaseArray.getBasesPrimitiveNoNulls(cloneBaseArrayList().subList(fromIndex, toIndex))); + } +} + +class Read extends BaseArray { + // + // ADD IN SOME DATA MEMBERS [OR INPUT TO CONSTRUCTORS] WITH READ QUALITY (MAPPING QUALITY) AND ARRAY OF BASE QUALITIES... + // THESE WILL BE USED IN matchScore() + // + + public Read(byte[] bases) { + super(bases); + } + + public Read(Byte[] bases) { + super(bases); + } + + public Read(int length) { + super(length); + } + + public int numNonNulls() { + int num = 0; + for (int i = 0; i < bases.size(); i++) { + if (getBase(i) != null) + num++; + } + return num; + } + + // + + public double matchScore() { + return 1.0; + } + // + + /* Checks if the two BaseArrays are consistent where bases are not null. + */ + + public boolean isCompatible(Haplotype hap) { + int sz = this.bases.size(); + if (sz != hap.bases.size()) + throw new StingException("Read and Haplotype should have same length to be compared!"); + + for (int i = 0; i < sz; i++) { + Byte thisBase = this.getBase(i); + Byte hapBase = hap.getBase(i); + if (thisBase != null && hapBase != null && !thisBase.equals(hapBase)) { + return false; + } + } + return true; + } +} + +class PhasingStats { private int numReads; private int numVarSites; - public VariantContextStats() { - this.numReads = 0; - this.numVarSites = 0; + // Map of: sample -> PhaseCounts: + private TreeMap samplePhaseStats; + + public PhasingStats() { + this(new TreeMap()); } - public VariantContextStats(int numReads, int numVarSites) { + public PhasingStats(int numReads, int numVarSites) { this.numReads = numReads; this.numVarSites = numVarSites; + this.samplePhaseStats = new TreeMap(); } - public void addTo(VariantContextStats other) { + public PhasingStats(TreeMap samplePhaseStats) { + this.numReads = 0; + this.numVarSites = 0; + this.samplePhaseStats = samplePhaseStats; + } + + public void addIn(PhasingStats other) { this.numReads += other.numReads; this.numVarSites += other.numVarSites; + + for (Map.Entry sampPhaseEntry : other.samplePhaseStats.entrySet()) { + String sample = sampPhaseEntry.getKey(); + PhaseCounts otherCounts = sampPhaseEntry.getValue(); + PhaseCounts thisCounts = this.samplePhaseStats.get(sample); + if (thisCounts == null) { + thisCounts = new PhaseCounts(); + this.samplePhaseStats.put(sample, thisCounts); + } + thisCounts.addIn(otherCounts); + } } - public int getNumReads() {return numReads;} - public int getNumVarSites() {return numVarSites;} + public int getNumReads() { + return numReads; + } + + public int getNumVarSites() { + return numVarSites; + } + + public Collection> getPhaseCounts() { + return samplePhaseStats.entrySet(); + } +} + +class PhaseCounts { + public int numTestedSites; // number of het sites directly succeeding het sites + public int numPhased; + + public PhaseCounts() { + this.numTestedSites = 0; + this.numPhased = 0; + } + + public void addIn(PhaseCounts other) { + this.numTestedSites += other.numTestedSites; + this.numPhased += other.numPhased; + } +} + +class PhasingStatsAndOutput { + public PhasingStats ps; + public List output; + + public PhasingStatsAndOutput(PhasingStats ps, List output) { + this.ps = ps; + this.output = output; + } } \ No newline at end of file