Optimizing DiagnoseTargets

* Fixed output format to get a valid vcf
   * Optimized the per-sample pileup routine: O(n^2) => O(n) pileup for samples
   * Added support for overlapping intervals
   * Removed expand target functionality (for now)
   * Removed total depth (pointless metric)
This commit is contained in:
Mauricio Carneiro 2012-04-10 09:46:29 -04:00
parent 1df0adf862
commit cd842b650e
4 changed files with 153 additions and 117 deletions

View File

@ -24,6 +24,7 @@
package org.broadinstitute.sting.gatk.walkers.diagnostics.targets;
import net.sf.picard.util.PeekableIterator;
import org.broad.tribble.Feature;
import org.broadinstitute.sting.commandline.*;
import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
@ -32,8 +33,6 @@ import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.walkers.*;
import org.broadinstitute.sting.gatk.walkers.annotator.interfaces.AnnotatorCompatibleWalker;
import org.broadinstitute.sting.utils.GenomeLoc;
import org.broadinstitute.sting.utils.GenomeLocComparator;
import org.broadinstitute.sting.utils.GenomeLocParser;
import org.broadinstitute.sting.utils.SampleUtils;
import org.broadinstitute.sting.utils.codecs.vcf.*;
import org.broadinstitute.sting.utils.exceptions.UserException;
@ -79,10 +78,7 @@ public class DiagnoseTargets extends LocusWalker<Long, Long> implements Annotato
private IntervalBinding<Feature> intervalTrack = null;
@Output(doc = "File to which variants should be written", required = true)
protected VCFWriter vcfWriter = null;
@Argument(fullName = "expand_interval", shortName = "exp", doc = "", required = false)
private int expandInterval = 50;
private VCFWriter vcfWriter = null;
@Argument(fullName = "minimum_base_quality", shortName = "mbq", doc = "", required = false)
private int minimumBaseQuality = 20;
@ -96,13 +92,11 @@ public class DiagnoseTargets extends LocusWalker<Long, Long> implements Annotato
@Argument(fullName = "maximum_coverage", shortName = "maxcov", doc = "", required = false)
private int maximumCoverage = 700;
private TreeSet<GenomeLoc> intervalList = null; // The list of intervals of interest (plus expanded intervals if user wants them)
private HashMap<GenomeLoc, IntervalStatistics> intervalMap = null; // interval => statistics
private Iterator<GenomeLoc> intervalListIterator; // An iterator to go over all the intervals provided as we traverse the genome
private GenomeLoc currentInterval = null; // The "current" interval loaded
private IntervalStatistics currentIntervalStatistics = null; // The "current" interval being filled with statistics
private Set<String> samples = null; // All the samples being processed
private GenomeLocParser parser; // just an object to allow us to create genome locs (for the expanded intervals)
private PeekableIterator<GenomeLoc> intervalListIterator; // an iterator to go over all the intervals provided as we traverse the genome
private Set<String> samples = null; // all the samples being processed
private final Allele SYMBOLIC_ALLELE = Allele.create("<DT>", false); // avoid creating the symbolic allele multiple times
@Override
public void initialize() {
@ -111,72 +105,22 @@ public class DiagnoseTargets extends LocusWalker<Long, Long> implements Annotato
if (intervalTrack == null)
throw new UserException("This tool currently only works if you provide an interval track");
parser = new GenomeLocParser(getToolkit().getMasterSequenceDictionary()); // Important to initialize the parser before creating the intervals below
List<GenomeLoc> originalList = intervalTrack.getIntervals(getToolkit()); // The original list of targets provided by the user that will be expanded or not depending on the options provided
intervalList = new TreeSet<GenomeLoc>(new GenomeLocComparator());
intervalMap = new HashMap<GenomeLoc, IntervalStatistics>();
for (GenomeLoc interval : originalList)
intervalList.add(interval);
//addAndExpandIntervalToMap(interval);
intervalListIterator = new PeekableIterator<GenomeLoc>(intervalTrack.getIntervals(getToolkit()).listIterator());
intervalListIterator = intervalList.iterator();
// get all of the unique sample names
samples = SampleUtils.getSAMFileSamples(getToolkit().getSAMFileHeader());
// initialize the header
Set<VCFHeaderLine> headerInfo = getHeaderInfo();
vcfWriter.writeHeader(new VCFHeader(headerInfo, samples));
}
/**
* Gets the header lines for the VCF writer
*
* @return A set of VCF header lines
*/
private Set<VCFHeaderLine> getHeaderInfo() {
Set<VCFHeaderLine> headerLines = new HashSet<VCFHeaderLine>();
// INFO fields for overall data
headerLines.add(new VCFInfoHeaderLine("END", 1, VCFHeaderLineType.Integer, "Stop position of the interval"));
headerLines.add(new VCFInfoHeaderLine("DP", 1, VCFHeaderLineType.Integer, "Total depth in the site. Sum of the depth of all pools"));
headerLines.add(new VCFInfoHeaderLine("AD", 1, VCFHeaderLineType.Float, "Average depth across the interval. Sum of the depth in a lci divided by interval size."));
headerLines.add(new VCFInfoHeaderLine("Diagnose Targets", 0, VCFHeaderLineType.Flag, "DiagnoseTargets mode"));
// FORMAT fields for each genotype
headerLines.add(new VCFFormatHeaderLine("DP", 1, VCFHeaderLineType.Integer, "Total depth in the site. Sum of the depth of all pools"));
headerLines.add(new VCFFormatHeaderLine("AD", 1, VCFHeaderLineType.Float, "Average depth across the interval. Sum of the depth in a lci divided by interval size."));
// FILTER fields
for (CallableStatus stat : CallableStatus.values()) {
headerLines.add(new VCFHeaderLine(stat.name(), stat.description));
}
return headerLines;
samples = SampleUtils.getSAMFileSamples(getToolkit().getSAMFileHeader()); // get all of the unique sample names for the VCF Header
vcfWriter.writeHeader(new VCFHeader(getHeaderInfo(), samples)); // initialize the VCF header
}
@Override
public Long map(RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context) {
GenomeLoc refLocus = ref.getLocus();
while (currentInterval == null || currentInterval.isBefore(refLocus)) { // do this for first time and while currentInterval is behind current locus
if (!intervalListIterator.hasNext())
return 0L;
if (currentInterval != null)
processIntervalStats(currentInterval, Allele.create(ref.getBase(), true));
removePastIntervals(refLocus, ref.getBase()); // process and remove any intervals in the map that are don't overlap the current locus anymore
addNewOverlappingIntervals(refLocus); // add all new intervals that may overlap this reference locus
currentInterval = intervalListIterator.next();
addAndExpandIntervalToMap(currentInterval);
currentIntervalStatistics = intervalMap.get(currentInterval);
}
if (currentInterval.isPast(refLocus)) // skip if we are behind the current interval
return 0L;
currentIntervalStatistics.addLocus(context); // Add current locus to stats
for (IntervalStatistics intervalStatistics : intervalMap.values())
intervalStatistics.addLocus(context); // Add current locus to stats
return 1L;
}
@ -198,10 +142,15 @@ public class DiagnoseTargets extends LocusWalker<Long, Long> implements Annotato
return sum + value;
}
/**
* Process all remaining intervals
*
* @param result number of loci processed by the walker
*/
@Override
public void onTraversalDone(Long result) {
for (GenomeLoc interval : intervalMap.keySet())
processIntervalStats(interval, Allele.create("<DT>", true));
for (GenomeLoc interval : intervalMap.keySet())
processIntervalStats(intervalMap.get(interval), Allele.create("A"));
}
@Override
@ -219,82 +168,111 @@ public class DiagnoseTargets extends LocusWalker<Long, Long> implements Annotato
@Override
public boolean alwaysAppendDbsnpId() {return false;}
private GenomeLoc createIntervalBefore(GenomeLoc interval) {
int start = Math.max(interval.getStart() - expandInterval, 0);
int stop = Math.max(interval.getStart() - 1, 0);
return parser.createGenomeLoc(interval.getContig(), interval.getContigIndex(), start, stop);
}
/**
* Removes all intervals that are behind the current reference locus from the intervalMap
*
* @param refLocus the current reference locus
* @param refBase the reference allele
*/
private void removePastIntervals(GenomeLoc refLocus, byte refBase) {
List<GenomeLoc> toRemove = new LinkedList<GenomeLoc>();
for (GenomeLoc interval : intervalMap.keySet())
if (interval.isBefore(refLocus)) {
processIntervalStats(intervalMap.get(interval), Allele.create(refBase, true));
toRemove.add(interval);
}
private GenomeLoc createIntervalAfter(GenomeLoc interval) {
int contigLimit = getToolkit().getSAMFileHeader().getSequenceDictionary().getSequence(interval.getContigIndex()).getSequenceLength();
int start = Math.min(interval.getStop() + 1, contigLimit);
int stop = Math.min(interval.getStop() + expandInterval, contigLimit);
return parser.createGenomeLoc(interval.getContig(), interval.getContigIndex(), start, stop);
for (GenomeLoc interval : toRemove)
intervalMap.remove(interval);
GenomeLoc interval = intervalListIterator.peek(); // clean up all intervals that we might have skipped because there was no data
while(interval != null && interval.isBefore(refLocus)) {
interval = intervalListIterator.next();
processIntervalStats(createIntervalStatistic(interval), Allele.create(refBase, true));
interval = intervalListIterator.peek();
}
}
/**
* Takes an interval and commits it to memory.
* It will expand it if so told by the -exp command line argument
* Adds all intervals that overlap the current reference locus to the intervalMap
*
* @param interval The new interval to process
* @param refLocus the current reference locus
*/
private void addAndExpandIntervalToMap(GenomeLoc interval) {
if (expandInterval > 0) {
GenomeLoc before = createIntervalBefore(interval);
GenomeLoc after = createIntervalAfter(interval);
intervalList.add(before);
intervalList.add(after);
intervalMap.put(before, new IntervalStatistics(samples, before, minimumCoverage, maximumCoverage, minimumMappingQuality, minimumBaseQuality));
intervalMap.put(after, new IntervalStatistics(samples, after, minimumCoverage, maximumCoverage, minimumMappingQuality, minimumBaseQuality));
private void addNewOverlappingIntervals(GenomeLoc refLocus) {
GenomeLoc interval = intervalListIterator.peek();
while (interval != null && !interval.isPast(refLocus)) {
System.out.println("LOCUS : " + refLocus + " -- " + interval);
intervalMap.put(interval, createIntervalStatistic(interval));
intervalListIterator.next(); // discard the interval (we've already added it to the map)
interval = intervalListIterator.peek();
}
if (!intervalList.contains(interval))
intervalList.add(interval);
intervalMap.put(interval, new IntervalStatistics(samples, interval, minimumCoverage, maximumCoverage, minimumMappingQuality, minimumBaseQuality));
}
/**
* Takes the interval, finds it in the stash, prints it to the VCF, and removes it
*
* @param interval The interval in memory that you want to write out and clear
* @param allele the allele
* @param stats The statistics of the interval
* @param refAllele the reference allele
*/
private void processIntervalStats(GenomeLoc interval, Allele allele) {
IntervalStatistics stats = intervalMap.get(interval);
private void processIntervalStats(IntervalStatistics stats, Allele refAllele) {
GenomeLoc interval = stats.getInterval();
List<Allele> alleles = new ArrayList<Allele>();
Map<String, Object> attributes = new HashMap<String, Object>();
ArrayList<Genotype> genotypes = new ArrayList<Genotype>();
alleles.add(allele);
VariantContextBuilder vcb = new VariantContextBuilder("DiagnoseTargets", interval.getContig(), interval.getStart(), interval.getStop(), alleles);
alleles.add(refAllele);
alleles.add(SYMBOLIC_ALLELE);
VariantContextBuilder vcb = new VariantContextBuilder("DiagnoseTargets", interval.getContig(), interval.getStart(), interval.getStart(), alleles);
vcb = vcb.log10PError(VariantContext.NO_LOG10_PERROR); // QUAL field makes no sense in our VCF
vcb.filters(statusesToStrings(stats.callableStatuses()));
attributes.put(VCFConstants.END_KEY, interval.getStop());
attributes.put(VCFConstants.DEPTH_KEY, stats.totalCoverage());
attributes.put("AV", stats.averageCoverage());
attributes.put(VCFConstants.DEPTH_KEY, stats.averageCoverage());
vcb = vcb.attributes(attributes);
for (String sample : samples) {
Map<String, Object> infos = new HashMap<String, Object>();
infos.put("DP", stats.getSample(sample).totalCoverage());
infos.put("AV", stats.getSample(sample).averageCoverage());
infos.put(VCFConstants.DEPTH_KEY, stats.getSample(sample).averageCoverage());
Set<String> filters = new HashSet<String>();
filters.addAll(statusesToStrings(stats.getSample(sample).getCallableStatuses()));
genotypes.add(new Genotype(sample, alleles, VariantContext.NO_LOG10_PERROR, filters, infos, false));
genotypes.add(new Genotype(sample, null, VariantContext.NO_LOG10_PERROR, filters, infos, false));
}
vcb = vcb.genotypes(genotypes);
vcfWriter.add(vcb.make());
intervalMap.remove(interval);
}
/**
 * Gets the header lines for the VCF writer.
 *
 * Declares the INFO fields (END, DP, DiagnoseTargets flag), the per-genotype
 * FORMAT field (DP), and one FILTER line per CallableStatus value.
 *
 * @return A set of VCF header lines
 */
private static Set<VCFHeaderLine> getHeaderInfo() {
    Set<VCFHeaderLine> headerLines = new HashSet<VCFHeaderLine>();

    // INFO fields for overall data
    headerLines.add(new VCFInfoHeaderLine(VCFConstants.END_KEY, 1, VCFHeaderLineType.Integer, "Stop position of the interval"));
    headerLines.add(new VCFInfoHeaderLine(VCFConstants.DEPTH_KEY, 1, VCFHeaderLineType.Float, "Average depth across the interval. Sum of the depth in a locus divided by interval size."));
    // NOTE(review): the VCF spec forbids whitespace in INFO IDs -- the previous ID
    // "Diagnose Targets" produced an invalid header line, defeating the goal of a valid VCF.
    headerLines.add(new VCFInfoHeaderLine("DiagnoseTargets", 0, VCFHeaderLineType.Flag, "DiagnoseTargets mode"));

    // FORMAT fields for each genotype
    headerLines.add(new VCFFormatHeaderLine(VCFConstants.DEPTH_KEY, 1, VCFHeaderLineType.Float, "Average depth across the interval. Sum of the depth in a locus divided by interval size."));

    // FILTER fields: one line per possible callable status
    for (CallableStatus stat : CallableStatus.values())
        headerLines.add(new VCFHeaderLine(stat.name(), stat.description));

    return headerLines;
}
private static Set<String> statusesToStrings(Set<CallableStatus> statuses) {
Set<String> output = new HashSet<String>(statuses.size());
@ -303,4 +281,8 @@ public class DiagnoseTargets extends LocusWalker<Long, Long> implements Annotato
return output;
}
/**
 * Builds a fresh, empty statistics accumulator for the given target interval,
 * configured with the walker's coverage and quality thresholds.
 *
 * @param interval the target interval the statistics will describe
 * @return a new IntervalStatistics for the interval
 */
private IntervalStatistics createIntervalStatistic(GenomeLoc interval) {
    final IntervalStatistics freshStats =
            new IntervalStatistics(samples, interval, minimumCoverage, maximumCoverage, minimumMappingQuality, minimumBaseQuality);
    return freshStats;
}
}

View File

@ -26,6 +26,7 @@ package org.broadinstitute.sting.gatk.walkers.diagnostics.targets;
import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
import org.broadinstitute.sting.utils.GenomeLoc;
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
import org.broadinstitute.sting.utils.pileup.ReadBackedPileup;
import java.util.HashMap;
@ -52,18 +53,28 @@ public class IntervalStatistics {
return samples.get(sample);
}
/** Returns the genomic interval these statistics describe. */
public GenomeLoc getInterval() {
return interval;
}
public void addLocus(AlignmentContext context) {
ReadBackedPileup pileup = context.getBasePileup();
for (String sample : samples.keySet())
getSample(sample).addLocus(context.getLocation(), pileup.getPileupForSample(sample));
Map<String, ReadBackedPileup> samplePileups = pileup.getPileupsForSamples(samples.keySet());
for (Map.Entry<String, ReadBackedPileup> entry : samplePileups.entrySet()) {
String sample = entry.getKey();
ReadBackedPileup samplePileup = entry.getValue();
SampleStatistics sampleStatistics = samples.get(sample);
if (sampleStatistics == null)
throw new ReviewedStingException(String.format("Trying to add locus statistics to a sample (%s) that doesn't exist in the Interval.", sample));
sampleStatistics.addLocus(context.getLocation(), samplePileup);
}
}
public long totalCoverage() {
if (preComputedTotalCoverage < 0)
calculateTotalCoverage();
return preComputedTotalCoverage;
}
public double averageCoverage() {
if (preComputedTotalCoverage < 0)

View File

@ -677,11 +677,11 @@ public abstract class AbstractReadBackedPileup<RBP extends AbstractReadBackedPil
PileupElementTracker<PE> filteredElements = tracker.getElements(sampleNames);
return filteredElements != null ? (RBP) createNewPileup(loc, filteredElements) : null;
} else {
HashSet<String> hashSampleNames = new HashSet<String>(sampleNames); // to speed up the "contains" access in the for loop
HashSet<String> hashSampleNames = new HashSet<String>(sampleNames); // to speed up the "contains" access in the for loop
UnifiedPileupElementTracker<PE> filteredTracker = new UnifiedPileupElementTracker<PE>();
for (PE p : pileupElementTracker) {
GATKSAMRecord read = p.getRead();
if (sampleNames != null) { // still checking on sampleNames because hashSampleNames will never be null. And empty means something else.
if (sampleNames != null) { // still checking on sampleNames because hashSampleNames will never be null. And empty means something else.
if (read.getReadGroup() != null && hashSampleNames.contains(read.getReadGroup().getSample()))
filteredTracker.add(p);
} else {
@ -693,6 +693,38 @@ public abstract class AbstractReadBackedPileup<RBP extends AbstractReadBackedPil
}
}
@Override
public Map<String, ReadBackedPileup> getPileupsForSamples(Collection<String> sampleNames) {
    Map<String, ReadBackedPileup> result = new HashMap<String, ReadBackedPileup>();

    if (pileupElementTracker instanceof PerSamplePileupElementTracker) {
        PerSamplePileupElementTracker<PE> tracker = (PerSamplePileupElementTracker<PE>) pileupElementTracker;
        for (String sample : sampleNames) {
            // Bug fix: query the tracker for THIS sample only. The previous code passed the
            // whole sampleNames collection, so every entry in the result mapped to the combined
            // pileup of all requested samples instead of that sample's own pileup.
            PileupElementTracker<PE> filteredElements = tracker.getElements(sample);
            if (filteredElements != null)
                result.put(sample, createNewPileup(loc, filteredElements));
        }
    } else {
        Map<String, UnifiedPileupElementTracker<PE>> trackerMap = new HashMap<String, UnifiedPileupElementTracker<PE>>();
        for (String sample : sampleNames) { // initialize pileups for each sample
            UnifiedPileupElementTracker<PE> filteredTracker = new UnifiedPileupElementTracker<PE>();
            trackerMap.put(sample, filteredTracker);
        }
        for (PE p : pileupElementTracker) { // go through all pileup elements only once and add them to the respective sample's pileup
            GATKSAMRecord read = p.getRead();
            if (read.getReadGroup() != null) {
                String sample = read.getReadGroup().getSample();
                UnifiedPileupElementTracker<PE> tracker = trackerMap.get(sample);
                if (tracker != null) // we only add the pileup to the requested samples. Completely ignore the rest
                    tracker.add(p);
            }
        }
        for (Map.Entry<String, UnifiedPileupElementTracker<PE>> entry : trackerMap.entrySet()) // create the RBP for each sample
            result.put(entry.getKey(), createNewPileup(loc, entry.getValue()));
    }
    return result;
}
@Override
public RBP getPileupForSample(String sampleName) {

View File

@ -32,6 +32,7 @@ import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
/**
* A data retrieval interface for accessing parts of the pileup.
@ -159,6 +160,16 @@ public interface ReadBackedPileup extends Iterable<PileupElement>, HasGenomeLoca
*/
public ReadBackedPileup getPileupForSamples(Collection<String> sampleNames);
/**
* Gets the particular subset of this pileup for each given sample name.
*
* Same as calling getPileupForSample for all samples, but in O(n) instead of O(n^2).
*
* @param sampleNames Name of the sample to use.
* @return A subset of this pileup containing only reads with the given sample.
*/
public Map<String, ReadBackedPileup> getPileupsForSamples(Collection<String> sampleNames);
/**
* Gets the particular subset of this pileup with the given sample name.