Updated Bayesian phasing method to output per-site phasing statistics (and to not cap PQ at 40)

git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@4064 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
fromer 2010-08-19 19:55:47 +00:00
parent 04e5b28f6d
commit effeedf1a3
2 changed files with 101 additions and 78 deletions

View File

@ -36,14 +36,10 @@ import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.refdata.ReferenceOrderedDatum; import org.broadinstitute.sting.gatk.refdata.ReferenceOrderedDatum;
import org.broadinstitute.sting.gatk.walkers.*; import org.broadinstitute.sting.gatk.walkers.*;
import org.broadinstitute.sting.commandline.Argument; import org.broadinstitute.sting.commandline.Argument;
import org.broadinstitute.sting.utils.BaseUtils; import org.broadinstitute.sting.utils.*;
import org.broadinstitute.sting.utils.PreciseNonNegativeDouble;
import org.broadinstitute.sting.utils.QualityUtils;
import org.broadinstitute.sting.utils.StingException;
import org.broadinstitute.sting.utils.vcf.VCFUtils; import org.broadinstitute.sting.utils.vcf.VCFUtils;
import org.broadinstitute.sting.utils.genotype.vcf.VCFWriter; import org.broadinstitute.sting.utils.genotype.vcf.VCFWriter;
import org.broadinstitute.sting.utils.genotype.vcf.VCFWriterImpl; import org.broadinstitute.sting.utils.genotype.vcf.VCFWriterImpl;
import org.broadinstitute.sting.utils.GenomeLoc;
import org.broadinstitute.sting.utils.pileup.PileupElement; import org.broadinstitute.sting.utils.pileup.PileupElement;
import org.broadinstitute.sting.utils.pileup.ReadBackedPileup; import org.broadinstitute.sting.utils.pileup.ReadBackedPileup;
@ -67,13 +63,17 @@ public class ReadBackedPhasingWalker extends LocusWalker<PhasingStatsAndOutput,
@Argument(fullName = "maxPhaseSites", shortName = "maxSites", doc = "The maximum number of successive heterozygous sites permitted to be used by the phasing algorithm; [default:20]", required = false) @Argument(fullName = "maxPhaseSites", shortName = "maxSites", doc = "The maximum number of successive heterozygous sites permitted to be used by the phasing algorithm; [default:20]", required = false)
protected Integer maxPhaseSites = 20; // 2^20 == 10^6 biallelic haplotypes protected Integer maxPhaseSites = 20; // 2^20 == 10^6 biallelic haplotypes
@Argument(fullName = "phaseScoreThresh", shortName = "phaseThresh", doc = "The minimum phasing quality score required to output phasing; [default:0.66]", required = false) @Argument(fullName = "phaseQualityThresh", shortName = "phaseThresh", doc = "The minimum phasing quality score required to output phasing; [default:4.77]", required = false)
protected Double phaseScoreThresh = 0.66; protected Double phaseQualityThresh = 4.77; // PQ = 4.77 <=> P(error) = 10^(-4.77/10) = 0.33, P(correct) = 0.66, so that we have odds ratio of >= 2
@Argument(fullName = "phasedVCFFile", shortName = "phasedVCF", doc = "The name of the phased VCF file output", required = true) @Argument(fullName = "phasedVCFFile", shortName = "phasedVCF", doc = "The name of the phased VCF file output", required = true)
protected String phasedVCFFile = null; protected String phasedVCFFile = null;
@Argument(fullName = "variantStatsFilePrefix", shortName = "variantStats", doc = "The prefix of the VCF/phasing statistics files", required = false)
protected String variantStatsFilePrefix = null;
private VCFWriter writer = null; private VCFWriter writer = null;
private PhasingQualityStatsWriter statsWriter = null;
private LinkedList<VariantAndReads> siteQueue = null; private LinkedList<VariantAndReads> siteQueue = null;
private VariantAndReads prevVr = null; // the VC emitted after phasing, and the alignment bases at the position emitted private VariantAndReads prevVr = null; // the VC emitted after phasing, and the alignment bases at the position emitted
@ -96,6 +96,9 @@ public class ReadBackedPhasingWalker extends LocusWalker<PhasingStatsAndOutput,
public void initialize() { public void initialize() {
siteQueue = new LinkedList<VariantAndReads>(); siteQueue = new LinkedList<VariantAndReads>();
prevVr = new VariantAndReads(null, null, true); prevVr = new VariantAndReads(null, null, true);
if (variantStatsFilePrefix != null)
statsWriter = new PhasingQualityStatsWriter(variantStatsFilePrefix);
} }
public boolean generateExtendedEvents() { public boolean generateExtendedEvents() {
@ -279,14 +282,18 @@ public class ReadBackedPhasingWalker extends LocusWalker<PhasingStatsAndOutput,
logger.debug("\nPhasing table [AFTER NORMALIZATION]:\n" + sampleHaps + "\n"); logger.debug("\nPhasing table [AFTER NORMALIZATION]:\n" + sampleHaps + "\n");
PhasingTable.PhasingTableEntry maxEntry = sampleHaps.maxEntry(); PhasingTable.PhasingTableEntry maxEntry = sampleHaps.maxEntry();
double score = maxEntry.getScore().getValue(); double posteriorProb = maxEntry.getScore().getValue();
logger.debug("MAX hap:\t" + maxEntry.getHaplotypeClass() + "\tscore:\t" + score); int phaseQuality = new Integer(QualityUtils.probToQual(posteriorProb, 0.0)); // 0.0 <=> do NOT cap the quality!
logger.debug("MAX hap:\t" + maxEntry.getHaplotypeClass() + "\tposteriorProb:\t" + posteriorProb + "\tphaseQuality:\t" + phaseQuality);
genotypesArePhased = (score >= phaseScoreThresh); if (statsWriter != null)
statsWriter.addStat(samp, distance(prevVc, vc), phaseQuality);
genotypesArePhased = (phaseQuality >= phaseQualityThresh);
if (genotypesArePhased) { if (genotypesArePhased) {
Biallele prevBiall = new Biallele(prevVc.getGenotype(samp)); Biallele prevBiall = new Biallele(prevVc.getGenotype(samp));
ensurePhasing(biall, prevBiall, maxEntry.getHaplotypeClass().getRepresentative()); ensurePhasing(biall, prevBiall, maxEntry.getHaplotypeClass().getRepresentative());
gtAttribs.put("PQ", new Integer(QualityUtils.probToQual(score))); gtAttribs.put("PQ", phaseQuality);
logger.debug("CHOSE PHASE:\n" + biall + "\n\n"); logger.debug("CHOSE PHASE:\n" + biall + "\n\n");
} }
@ -347,6 +354,15 @@ public class ReadBackedPhasingWalker extends LocusWalker<PhasingStatsAndOutput,
return (loc1.onSameContig(loc2) && loc1.distance(loc2) <= cacheWindow); return (loc1.onSameContig(loc2) && loc1.distance(loc2) <= cacheWindow);
} }
/**
 * Computes the distance between the genomic locations of two variant contexts.
 *
 * @param vc1 the first variant context
 * @param vc2 the second variant context
 * @return the distance between the two locations as reported by GenomeLoc.distance(),
 *         or Integer.MAX_VALUE when the locations are not on the same contig
 */
private static int distance(VariantContext vc1, VariantContext vc2) {
    final GenomeLoc firstLoc = VariantContextUtils.getLocation(vc1);
    final GenomeLoc secondLoc = VariantContextUtils.getLocation(vc2);

    // Locations on different contigs have no meaningful distance; use the largest int as a sentinel.
    return firstLoc.onSameContig(secondLoc) ? firstLoc.distance(secondLoc) : Integer.MAX_VALUE;
}
private void writeVCF(VariantContext vc) { private void writeVCF(VariantContext vc) {
if (writer == null) if (writer == null)
initializeVcfWriter(vc); initializeVcfWriter(vc);
@ -381,12 +397,14 @@ public class ReadBackedPhasingWalker extends LocusWalker<PhasingStatsAndOutput,
writeVarContList(finalList); writeVarContList(finalList);
if (writer != null) if (writer != null)
writer.close(); writer.close();
if (statsWriter != null)
statsWriter.close();
out.println("Number of reads observed: " + result.getNumReads()); out.println("Number of reads observed: " + result.getNumReads());
out.println("Number of variant sites observed: " + result.getNumVarSites()); out.println("Number of variant sites observed: " + result.getNumVarSites());
out.println("Average coverage: " + ((double) result.getNumReads() / result.getNumVarSites())); out.println("Average coverage: " + ((double) result.getNumReads() / result.getNumVarSites()));
out.println("\n-- Phasing summary [minimal haplotype probability: " + phaseScoreThresh + "] --"); out.println("\n-- Phasing summary [minimal haplotype probability: " + phaseQualityThresh + "] --");
for (Map.Entry<String, PhaseCounts> sampPhaseCountEntry : result.getPhaseCounts()) { for (Map.Entry<String, PhaseCounts> sampPhaseCountEntry : result.getPhaseCounts()) {
PhaseCounts pc = sampPhaseCountEntry.getValue(); PhaseCounts pc = sampPhaseCountEntry.getValue();
out.println("Sample: " + sampPhaseCountEntry.getKey() + "\tNumber of tested sites: " + pc.numTestedSites + "\tNumber of phased sites: " + pc.numPhased); out.println("Sample: " + sampPhaseCountEntry.getKey() + "\tNumber of tested sites: " + pc.numTestedSites + "\tNumber of phased sites: " + pc.numPhased);
@ -1085,56 +1103,45 @@ class PhasingStatsAndOutput {
} }
} }
class CardinalityCounter implements Iterator<int[]>, Iterable<int[]> { class PhasingQualityStatsWriter {
private int[] cards; private String variantStatsFilePrefix;
private int[] valList; private HashMap<String, BufferedWriter> sampleToStatsWriter = new HashMap<String, BufferedWriter>();
private boolean hasNext;
public CardinalityCounter(int[] cards) { public PhasingQualityStatsWriter(String variantStatsFilePrefix) {
this.cards = cards; this.variantStatsFilePrefix = variantStatsFilePrefix;
this.valList = new int[cards.length];
for (int i = 0; i < cards.length; i++) {
if (this.cards[i] <= 0)
throw new StingException("CANNOT have zero cardinalities!");
this.valList[i] = 0;
}
this.hasNext = true;
} }
public boolean hasNext() { public void addStat(String sample, int distanceFromPrevious, int phasingQuality) {
return hasNext; BufferedWriter sampWriter = sampleToStatsWriter.get(sample);
} if (sampWriter == null) {
String fileName = variantStatsFilePrefix + "." + sample + ".distance_PQ.txt";
public int[] next() { FileOutputStream output;
if (!hasNext()) try {
throw new StingException("CANNOT iterate past end!"); output = new FileOutputStream(fileName);
} catch (FileNotFoundException e) {
// Copy the assignment to be returned: throw new RuntimeException("Unable to create phasing quality stats file at location: " + fileName);
int[] nextList = new int[valList.length];
for (int i = 0; i < valList.length; i++)
nextList[i] = valList[i];
// Find the assignment after this one:
hasNext = false;
int i = cards.length - 1;
for (; i >= 0; i--) {
if (valList[i] < (cards[i] - 1)) {
valList[i]++;
hasNext = true;
break;
} }
valList[i] = 0; sampWriter = new BufferedWriter(new OutputStreamWriter(output));
sampleToStatsWriter.put(sample, sampWriter);
}
try {
sampWriter.write(distanceFromPrevious + "\t" + phasingQuality + "\n");
sampWriter.flush();
} catch (IOException e) {
throw new RuntimeException("Unable to write to per-sample phasing quality stats file", e);
} }
return nextList;
} }
public void remove() { public void close() {
throw new StingException("Cannot remove from CardinalityCounter!"); for (Map.Entry<String, BufferedWriter> sampWriterEntry : sampleToStatsWriter.entrySet()) {
BufferedWriter sampWriter = sampWriterEntry.getValue();
try {
sampWriter.flush();
sampWriter.close();
} catch (IOException e) {
throw new RuntimeException("Unable to close per-sample phasing quality stats file");
}
}
} }
}
public Iterator<int[]> iterator() {
return this;
}
}

View File

@ -32,7 +32,7 @@ public class PreciseNonNegativeDouble implements Comparable<PreciseNonNegativeDo
} }
} }
public PreciseNonNegativeDouble(org.broadinstitute.sting.utils.PreciseNonNegativeDouble pd) { public PreciseNonNegativeDouble(PreciseNonNegativeDouble pd) {
this.logValue = pd.logValue; this.logValue = pd.logValue;
} }
@ -44,24 +44,28 @@ public class PreciseNonNegativeDouble implements Comparable<PreciseNonNegativeDo
return logValue; return logValue;
} }
public org.broadinstitute.sting.utils.PreciseNonNegativeDouble setEqual(org.broadinstitute.sting.utils.PreciseNonNegativeDouble other) { public PreciseNonNegativeDouble setEqual(PreciseNonNegativeDouble other) {
logValue = other.logValue; logValue = other.logValue;
return this; return this;
} }
public org.broadinstitute.sting.utils.PreciseNonNegativeDouble plus(org.broadinstitute.sting.utils.PreciseNonNegativeDouble other) { public PreciseNonNegativeDouble plus(PreciseNonNegativeDouble other) {
return new org.broadinstitute.sting.utils.PreciseNonNegativeDouble(this).plusEqual(other); return new PreciseNonNegativeDouble(this).plusEqual(other);
} }
public org.broadinstitute.sting.utils.PreciseNonNegativeDouble times(org.broadinstitute.sting.utils.PreciseNonNegativeDouble other) { public PreciseNonNegativeDouble times(PreciseNonNegativeDouble other) {
return new org.broadinstitute.sting.utils.PreciseNonNegativeDouble(this).timesEqual(other); return new PreciseNonNegativeDouble(this).timesEqual(other);
} }
public org.broadinstitute.sting.utils.PreciseNonNegativeDouble div(org.broadinstitute.sting.utils.PreciseNonNegativeDouble other) { public PreciseNonNegativeDouble div(PreciseNonNegativeDouble other) {
return new org.broadinstitute.sting.utils.PreciseNonNegativeDouble(this).divEqual(other); return new PreciseNonNegativeDouble(this).divEqual(other);
} }
public int compareTo(org.broadinstitute.sting.utils.PreciseNonNegativeDouble other) { public PreciseNonNegativeDouble absDiff(PreciseNonNegativeDouble other) {
return new PreciseNonNegativeDouble(absSubLog(this.logValue, other.logValue), true);
}
public int compareTo(PreciseNonNegativeDouble other) {
// Since log is monotonic: e^a R e^b <=> a R b, where R is one of: >, <, == // Since log is monotonic: e^a R e^b <=> a R b, where R is one of: >, <, ==
double logValDiff = this.logValue - other.logValue; double logValDiff = this.logValue - other.logValue;
if (Math.abs(logValDiff) <= EQUALS_THRESH) if (Math.abs(logValDiff) <= EQUALS_THRESH)
@ -70,23 +74,33 @@ public class PreciseNonNegativeDouble implements Comparable<PreciseNonNegativeDo
return new Double(Math.signum(logValDiff)).intValue(); return new Double(Math.signum(logValDiff)).intValue();
} }
public boolean equals(org.broadinstitute.sting.utils.PreciseNonNegativeDouble other) { public boolean equals(PreciseNonNegativeDouble other) {
return (this.compareTo(other) == 0); return (this.compareTo(other) == 0);
} }
public boolean gt(org.broadinstitute.sting.utils.PreciseNonNegativeDouble other) { public boolean gt(PreciseNonNegativeDouble other) {
return (this.compareTo(other) > 0); return (this.compareTo(other) > 0);
} }
public boolean lt(org.broadinstitute.sting.utils.PreciseNonNegativeDouble other) { public boolean lt(PreciseNonNegativeDouble other) {
return (this.compareTo(other) < 0); return (this.compareTo(other) < 0);
} }
public org.broadinstitute.sting.utils.PreciseNonNegativeDouble plusEqual(org.broadinstitute.sting.utils.PreciseNonNegativeDouble other) { public PreciseNonNegativeDouble plusEqual(PreciseNonNegativeDouble other) {
logValue = addInLogSpace(logValue, other.logValue); logValue = addInLogSpace(logValue, other.logValue);
return this; return this;
} }
/**
 * Multiplies this value by other in place. Values are held as logarithms,
 * so the product is formed by adding the two log values.
 *
 * @param other the factor to multiply by
 * @return this instance, after the update
 */
public PreciseNonNegativeDouble timesEqual(PreciseNonNegativeDouble other) {
    // log(a * b) = log(a) + log(b)
    this.logValue = this.logValue + other.logValue;
    return this;
}
/**
 * Divides this value by other in place. Values are held as logarithms,
 * so the quotient is formed by subtracting the other log value from this one.
 *
 * @param other the divisor
 * @return this instance, after the update
 */
public PreciseNonNegativeDouble divEqual(PreciseNonNegativeDouble other) {
    // log(a / b) = log(a) - log(b)
    this.logValue = this.logValue - other.logValue;
    return this;
}
// If x = log(a), y = log(b), returns log(a+b) // If x = log(a), y = log(b), returns log(a+b)
public static double addInLogSpace(double x, double y) { public static double addInLogSpace(double x, double y) {
if (x == INFINITY || y == INFINITY) return INFINITY; //log( e^INFINITY + e^y ) = INFINITY if (x == INFINITY || y == INFINITY) return INFINITY; //log( e^INFINITY + e^y ) = INFINITY
@ -108,15 +122,17 @@ public class PreciseNonNegativeDouble implements Comparable<PreciseNonNegativeDo
return maxVal + Math.log(1.0 + Math.exp(negDiff)); return maxVal + Math.log(1.0 + Math.exp(negDiff));
} }
public org.broadinstitute.sting.utils.PreciseNonNegativeDouble timesEqual(org.broadinstitute.sting.utils.PreciseNonNegativeDouble other) { // If x = log(a), y = log(b), returns log |a-b|
logValue += other.logValue; double absSubLog(double x, double y) {
return this; if (x == -INFINITY && y == -INFINITY) {
} // log |e^-INFINITY - e^-INFINITY| = log |0-0| = log(0) = -INFINITY
return -INFINITY;
public org.broadinstitute.sting.utils.PreciseNonNegativeDouble divEqual(org.broadinstitute.sting.utils.PreciseNonNegativeDouble other) { }
logValue -= other.logValue; else if (x >= y) // x + log(1-e^(y-x)) = log(a) + log(1-e^(log(b)-log(a))) = log(a) + log(1-b/a) = log(a*(1-b/a)) = log(a-b) = log|a-b|, since x >= y implies a >= b
return this; return x + Math.log(1 - Math.exp(y-x));
} else
return y + Math.log(1 - Math.exp(x-y));
}
public String toString() { public String toString() {
return new StringBuilder().append(getValue()).toString(); return new StringBuilder().append(getValue()).toString();