Optimizing DiagnoseTargets

* Fixed output format to get a valid vcf
   * Optimzed the per sample pileup routine O(n^2) => O(n) pileup for samples
   * Added support to overlapping intervals
   * Removed expand target functionality (for now)
   * Removed total depth (pointless metric)
This commit is contained in:
Mauricio Carneiro 2012-04-10 09:46:29 -04:00
parent 1df0adf862
commit cd842b650e
4 changed files with 153 additions and 117 deletions

View File

@ -24,6 +24,7 @@
package org.broadinstitute.sting.gatk.walkers.diagnostics.targets; package org.broadinstitute.sting.gatk.walkers.diagnostics.targets;
import net.sf.picard.util.PeekableIterator;
import org.broad.tribble.Feature; import org.broad.tribble.Feature;
import org.broadinstitute.sting.commandline.*; import org.broadinstitute.sting.commandline.*;
import org.broadinstitute.sting.gatk.contexts.AlignmentContext; import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
@ -32,8 +33,6 @@ import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.walkers.*; import org.broadinstitute.sting.gatk.walkers.*;
import org.broadinstitute.sting.gatk.walkers.annotator.interfaces.AnnotatorCompatibleWalker; import org.broadinstitute.sting.gatk.walkers.annotator.interfaces.AnnotatorCompatibleWalker;
import org.broadinstitute.sting.utils.GenomeLoc; import org.broadinstitute.sting.utils.GenomeLoc;
import org.broadinstitute.sting.utils.GenomeLocComparator;
import org.broadinstitute.sting.utils.GenomeLocParser;
import org.broadinstitute.sting.utils.SampleUtils; import org.broadinstitute.sting.utils.SampleUtils;
import org.broadinstitute.sting.utils.codecs.vcf.*; import org.broadinstitute.sting.utils.codecs.vcf.*;
import org.broadinstitute.sting.utils.exceptions.UserException; import org.broadinstitute.sting.utils.exceptions.UserException;
@ -79,10 +78,7 @@ public class DiagnoseTargets extends LocusWalker<Long, Long> implements Annotato
private IntervalBinding<Feature> intervalTrack = null; private IntervalBinding<Feature> intervalTrack = null;
@Output(doc = "File to which variants should be written", required = true) @Output(doc = "File to which variants should be written", required = true)
protected VCFWriter vcfWriter = null; private VCFWriter vcfWriter = null;
@Argument(fullName = "expand_interval", shortName = "exp", doc = "", required = false)
private int expandInterval = 50;
@Argument(fullName = "minimum_base_quality", shortName = "mbq", doc = "", required = false) @Argument(fullName = "minimum_base_quality", shortName = "mbq", doc = "", required = false)
private int minimumBaseQuality = 20; private int minimumBaseQuality = 20;
@ -96,13 +92,11 @@ public class DiagnoseTargets extends LocusWalker<Long, Long> implements Annotato
@Argument(fullName = "maximum_coverage", shortName = "maxcov", doc = "", required = false) @Argument(fullName = "maximum_coverage", shortName = "maxcov", doc = "", required = false)
private int maximumCoverage = 700; private int maximumCoverage = 700;
private TreeSet<GenomeLoc> intervalList = null; // The list of intervals of interest (plus expanded intervals if user wants them)
private HashMap<GenomeLoc, IntervalStatistics> intervalMap = null; // interval => statistics private HashMap<GenomeLoc, IntervalStatistics> intervalMap = null; // interval => statistics
private Iterator<GenomeLoc> intervalListIterator; // An iterator to go over all the intervals provided as we traverse the genome private PeekableIterator<GenomeLoc> intervalListIterator; // an iterator to go over all the intervals provided as we traverse the genome
private GenomeLoc currentInterval = null; // The "current" interval loaded private Set<String> samples = null; // all the samples being processed
private IntervalStatistics currentIntervalStatistics = null; // The "current" interval being filled with statistics
private Set<String> samples = null; // All the samples being processed private final Allele SYMBOLIC_ALLELE = Allele.create("<DT>", false); // avoid creating the symbolic allele multiple times
private GenomeLocParser parser; // just an object to allow us to create genome locs (for the expanded intervals)
@Override @Override
public void initialize() { public void initialize() {
@ -111,72 +105,22 @@ public class DiagnoseTargets extends LocusWalker<Long, Long> implements Annotato
if (intervalTrack == null) if (intervalTrack == null)
throw new UserException("This tool currently only works if you provide an interval track"); throw new UserException("This tool currently only works if you provide an interval track");
parser = new GenomeLocParser(getToolkit().getMasterSequenceDictionary()); // Important to initialize the parser before creating the intervals below
List<GenomeLoc> originalList = intervalTrack.getIntervals(getToolkit()); // The original list of targets provided by the user that will be expanded or not depending on the options provided
intervalList = new TreeSet<GenomeLoc>(new GenomeLocComparator());
intervalMap = new HashMap<GenomeLoc, IntervalStatistics>(); intervalMap = new HashMap<GenomeLoc, IntervalStatistics>();
for (GenomeLoc interval : originalList) intervalListIterator = new PeekableIterator<GenomeLoc>(intervalTrack.getIntervals(getToolkit()).listIterator());
intervalList.add(interval);
//addAndExpandIntervalToMap(interval);
intervalListIterator = intervalList.iterator(); samples = SampleUtils.getSAMFileSamples(getToolkit().getSAMFileHeader()); // get all of the unique sample names for the VCF Header
vcfWriter.writeHeader(new VCFHeader(getHeaderInfo(), samples)); // initialize the VCF header
// get all of the unique sample names
samples = SampleUtils.getSAMFileSamples(getToolkit().getSAMFileHeader());
// initialize the header
Set<VCFHeaderLine> headerInfo = getHeaderInfo();
vcfWriter.writeHeader(new VCFHeader(headerInfo, samples));
}
/**
* Gets the header lines for the VCF writer
*
* @return A set of VCF header lines
*/
private Set<VCFHeaderLine> getHeaderInfo() {
Set<VCFHeaderLine> headerLines = new HashSet<VCFHeaderLine>();
// INFO fields for overall data
headerLines.add(new VCFInfoHeaderLine("END", 1, VCFHeaderLineType.Integer, "Stop position of the interval"));
headerLines.add(new VCFInfoHeaderLine("DP", 1, VCFHeaderLineType.Integer, "Total depth in the site. Sum of the depth of all pools"));
headerLines.add(new VCFInfoHeaderLine("AD", 1, VCFHeaderLineType.Float, "Average depth across the interval. Sum of the depth in a lci divided by interval size."));
headerLines.add(new VCFInfoHeaderLine("Diagnose Targets", 0, VCFHeaderLineType.Flag, "DiagnoseTargets mode"));
// FORMAT fields for each genotype
headerLines.add(new VCFFormatHeaderLine("DP", 1, VCFHeaderLineType.Integer, "Total depth in the site. Sum of the depth of all pools"));
headerLines.add(new VCFFormatHeaderLine("AD", 1, VCFHeaderLineType.Float, "Average depth across the interval. Sum of the depth in a lci divided by interval size."));
// FILTER fields
for (CallableStatus stat : CallableStatus.values()) {
headerLines.add(new VCFHeaderLine(stat.name(), stat.description));
}
return headerLines;
} }
@Override @Override
public Long map(RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context) { public Long map(RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context) {
GenomeLoc refLocus = ref.getLocus(); GenomeLoc refLocus = ref.getLocus();
while (currentInterval == null || currentInterval.isBefore(refLocus)) { // do this for first time and while currentInterval is behind current locus
if (!intervalListIterator.hasNext())
return 0L;
if (currentInterval != null) removePastIntervals(refLocus, ref.getBase()); // process and remove any intervals in the map that are don't overlap the current locus anymore
processIntervalStats(currentInterval, Allele.create(ref.getBase(), true)); addNewOverlappingIntervals(refLocus); // add all new intervals that may overlap this reference locus
currentInterval = intervalListIterator.next(); for (IntervalStatistics intervalStatistics : intervalMap.values())
addAndExpandIntervalToMap(currentInterval); intervalStatistics.addLocus(context); // Add current locus to stats
currentIntervalStatistics = intervalMap.get(currentInterval);
}
if (currentInterval.isPast(refLocus)) // skip if we are behind the current interval
return 0L;
currentIntervalStatistics.addLocus(context); // Add current locus to stats
return 1L; return 1L;
} }
@ -198,10 +142,15 @@ public class DiagnoseTargets extends LocusWalker<Long, Long> implements Annotato
return sum + value; return sum + value;
} }
/**
* Process all remaining intervals
*
* @param result number of loci processed by the walker
*/
@Override @Override
public void onTraversalDone(Long result) { public void onTraversalDone(Long result) {
for (GenomeLoc interval : intervalMap.keySet()) for (GenomeLoc interval : intervalMap.keySet())
processIntervalStats(interval, Allele.create("<DT>", true)); processIntervalStats(intervalMap.get(interval), Allele.create("A"));
} }
@Override @Override
@ -219,82 +168,111 @@ public class DiagnoseTargets extends LocusWalker<Long, Long> implements Annotato
@Override @Override
public boolean alwaysAppendDbsnpId() {return false;} public boolean alwaysAppendDbsnpId() {return false;}
private GenomeLoc createIntervalBefore(GenomeLoc interval) { /**
int start = Math.max(interval.getStart() - expandInterval, 0); * Removes all intervals that are behind the current reference locus from the intervalMap
int stop = Math.max(interval.getStart() - 1, 0); *
return parser.createGenomeLoc(interval.getContig(), interval.getContigIndex(), start, stop); * @param refLocus the current reference locus
* @param refBase the reference allele
*/
private void removePastIntervals(GenomeLoc refLocus, byte refBase) {
List<GenomeLoc> toRemove = new LinkedList<GenomeLoc>();
for (GenomeLoc interval : intervalMap.keySet())
if (interval.isBefore(refLocus)) {
processIntervalStats(intervalMap.get(interval), Allele.create(refBase, true));
toRemove.add(interval);
} }
private GenomeLoc createIntervalAfter(GenomeLoc interval) { for (GenomeLoc interval : toRemove)
int contigLimit = getToolkit().getSAMFileHeader().getSequenceDictionary().getSequence(interval.getContigIndex()).getSequenceLength(); intervalMap.remove(interval);
int start = Math.min(interval.getStop() + 1, contigLimit);
int stop = Math.min(interval.getStop() + expandInterval, contigLimit); GenomeLoc interval = intervalListIterator.peek(); // clean up all intervals that we might have skipped because there was no data
return parser.createGenomeLoc(interval.getContig(), interval.getContigIndex(), start, stop); while(interval != null && interval.isBefore(refLocus)) {
interval = intervalListIterator.next();
processIntervalStats(createIntervalStatistic(interval), Allele.create(refBase, true));
interval = intervalListIterator.peek();
}
} }
/** /**
* Takes an interval and commits it to memory. * Adds all intervals that overlap the current reference locus to the intervalMap
* It will expand it if so told by the -exp command line argument
* *
* @param interval The new interval to process * @param refLocus the current reference locus
*/ */
private void addAndExpandIntervalToMap(GenomeLoc interval) { private void addNewOverlappingIntervals(GenomeLoc refLocus) {
if (expandInterval > 0) { GenomeLoc interval = intervalListIterator.peek();
GenomeLoc before = createIntervalBefore(interval); while (interval != null && !interval.isPast(refLocus)) {
GenomeLoc after = createIntervalAfter(interval); System.out.println("LOCUS : " + refLocus + " -- " + interval);
intervalList.add(before); intervalMap.put(interval, createIntervalStatistic(interval));
intervalList.add(after); intervalListIterator.next(); // discard the interval (we've already added it to the map)
intervalMap.put(before, new IntervalStatistics(samples, before, minimumCoverage, maximumCoverage, minimumMappingQuality, minimumBaseQuality)); interval = intervalListIterator.peek();
intervalMap.put(after, new IntervalStatistics(samples, after, minimumCoverage, maximumCoverage, minimumMappingQuality, minimumBaseQuality));
} }
if (!intervalList.contains(interval))
intervalList.add(interval);
intervalMap.put(interval, new IntervalStatistics(samples, interval, minimumCoverage, maximumCoverage, minimumMappingQuality, minimumBaseQuality));
} }
/** /**
* Takes the interval, finds it in the stash, prints it to the VCF, and removes it * Takes the interval, finds it in the stash, prints it to the VCF, and removes it
* *
* @param interval The interval in memory that you want to write out and clear * @param stats The statistics of the interval
* @param allele the allele * @param refAllele the reference allele
*/ */
private void processIntervalStats(GenomeLoc interval, Allele allele) { private void processIntervalStats(IntervalStatistics stats, Allele refAllele) {
IntervalStatistics stats = intervalMap.get(interval); GenomeLoc interval = stats.getInterval();
List<Allele> alleles = new ArrayList<Allele>(); List<Allele> alleles = new ArrayList<Allele>();
Map<String, Object> attributes = new HashMap<String, Object>(); Map<String, Object> attributes = new HashMap<String, Object>();
ArrayList<Genotype> genotypes = new ArrayList<Genotype>(); ArrayList<Genotype> genotypes = new ArrayList<Genotype>();
alleles.add(allele); alleles.add(refAllele);
VariantContextBuilder vcb = new VariantContextBuilder("DiagnoseTargets", interval.getContig(), interval.getStart(), interval.getStop(), alleles); alleles.add(SYMBOLIC_ALLELE);
VariantContextBuilder vcb = new VariantContextBuilder("DiagnoseTargets", interval.getContig(), interval.getStart(), interval.getStart(), alleles);
vcb = vcb.log10PError(VariantContext.NO_LOG10_PERROR); // QUAL field makes no sense in our VCF vcb = vcb.log10PError(VariantContext.NO_LOG10_PERROR); // QUAL field makes no sense in our VCF
vcb.filters(statusesToStrings(stats.callableStatuses())); vcb.filters(statusesToStrings(stats.callableStatuses()));
attributes.put(VCFConstants.END_KEY, interval.getStop()); attributes.put(VCFConstants.END_KEY, interval.getStop());
attributes.put(VCFConstants.DEPTH_KEY, stats.totalCoverage()); attributes.put(VCFConstants.DEPTH_KEY, stats.averageCoverage());
attributes.put("AV", stats.averageCoverage());
vcb = vcb.attributes(attributes); vcb = vcb.attributes(attributes);
for (String sample : samples) { for (String sample : samples) {
Map<String, Object> infos = new HashMap<String, Object>(); Map<String, Object> infos = new HashMap<String, Object>();
infos.put("DP", stats.getSample(sample).totalCoverage()); infos.put(VCFConstants.DEPTH_KEY, stats.getSample(sample).averageCoverage());
infos.put("AV", stats.getSample(sample).averageCoverage());
Set<String> filters = new HashSet<String>(); Set<String> filters = new HashSet<String>();
filters.addAll(statusesToStrings(stats.getSample(sample).getCallableStatuses())); filters.addAll(statusesToStrings(stats.getSample(sample).getCallableStatuses()));
genotypes.add(new Genotype(sample, alleles, VariantContext.NO_LOG10_PERROR, filters, infos, false)); genotypes.add(new Genotype(sample, null, VariantContext.NO_LOG10_PERROR, filters, infos, false));
} }
vcb = vcb.genotypes(genotypes); vcb = vcb.genotypes(genotypes);
vcfWriter.add(vcb.make()); vcfWriter.add(vcb.make());
intervalMap.remove(interval);
} }
/**
* Gets the header lines for the VCF writer
*
* @return A set of VCF header lines
*/
private static Set<VCFHeaderLine> getHeaderInfo() {
Set<VCFHeaderLine> headerLines = new HashSet<VCFHeaderLine>();
// INFO fields for overall data
headerLines.add(new VCFInfoHeaderLine(VCFConstants.END_KEY, 1, VCFHeaderLineType.Integer, "Stop position of the interval"));
headerLines.add(new VCFInfoHeaderLine(VCFConstants.DEPTH_KEY, 1, VCFHeaderLineType.Float, "Average depth across the interval. Sum of the depth in a lci divided by interval size."));
headerLines.add(new VCFInfoHeaderLine("Diagnose Targets", 0, VCFHeaderLineType.Flag, "DiagnoseTargets mode"));
// FORMAT fields for each genotype
headerLines.add(new VCFFormatHeaderLine(VCFConstants.DEPTH_KEY, 1, VCFHeaderLineType.Float, "Average depth across the interval. Sum of the depth in a lci divided by interval size."));
// FILTER fields
for (CallableStatus stat : CallableStatus.values())
headerLines.add(new VCFHeaderLine(stat.name(), stat.description));
return headerLines;
}
private static Set<String> statusesToStrings(Set<CallableStatus> statuses) { private static Set<String> statusesToStrings(Set<CallableStatus> statuses) {
Set<String> output = new HashSet<String>(statuses.size()); Set<String> output = new HashSet<String>(statuses.size());
@ -303,4 +281,8 @@ public class DiagnoseTargets extends LocusWalker<Long, Long> implements Annotato
return output; return output;
} }
private IntervalStatistics createIntervalStatistic(GenomeLoc interval) {
return new IntervalStatistics(samples, interval, minimumCoverage, maximumCoverage, minimumMappingQuality, minimumBaseQuality);
}
} }

View File

@ -26,6 +26,7 @@ package org.broadinstitute.sting.gatk.walkers.diagnostics.targets;
import org.broadinstitute.sting.gatk.contexts.AlignmentContext; import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
import org.broadinstitute.sting.utils.GenomeLoc; import org.broadinstitute.sting.utils.GenomeLoc;
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
import org.broadinstitute.sting.utils.pileup.ReadBackedPileup; import org.broadinstitute.sting.utils.pileup.ReadBackedPileup;
import java.util.HashMap; import java.util.HashMap;
@ -52,19 +53,29 @@ public class IntervalStatistics {
return samples.get(sample); return samples.get(sample);
} }
public GenomeLoc getInterval() {
return interval;
}
public void addLocus(AlignmentContext context) { public void addLocus(AlignmentContext context) {
ReadBackedPileup pileup = context.getBasePileup(); ReadBackedPileup pileup = context.getBasePileup();
for (String sample : samples.keySet()) Map<String, ReadBackedPileup> samplePileups = pileup.getPileupsForSamples(samples.keySet());
getSample(sample).addLocus(context.getLocation(), pileup.getPileupForSample(sample));
for (Map.Entry<String, ReadBackedPileup> entry : samplePileups.entrySet()) {
String sample = entry.getKey();
ReadBackedPileup samplePileup = entry.getValue();
SampleStatistics sampleStatistics = samples.get(sample);
if (sampleStatistics == null)
throw new ReviewedStingException(String.format("Trying to add locus statistics to a sample (%s) that doesn't exist in the Interval.", sample));
sampleStatistics.addLocus(context.getLocation(), samplePileup);
} }
public long totalCoverage() {
if (preComputedTotalCoverage < 0)
calculateTotalCoverage();
return preComputedTotalCoverage;
} }
public double averageCoverage() { public double averageCoverage() {
if (preComputedTotalCoverage < 0) if (preComputedTotalCoverage < 0)
calculateTotalCoverage(); calculateTotalCoverage();

View File

@ -693,6 +693,38 @@ public abstract class AbstractReadBackedPileup<RBP extends AbstractReadBackedPil
} }
} }
@Override
public Map<String, ReadBackedPileup> getPileupsForSamples(Collection<String> sampleNames) {
Map<String, ReadBackedPileup> result = new HashMap<String, ReadBackedPileup>();
if (pileupElementTracker instanceof PerSamplePileupElementTracker) {
PerSamplePileupElementTracker<PE> tracker = (PerSamplePileupElementTracker<PE>) pileupElementTracker;
for (String sample : sampleNames) {
PileupElementTracker<PE> filteredElements = tracker.getElements(sampleNames);
if (filteredElements != null)
result.put(sample, createNewPileup(loc, filteredElements));
}
} else {
Map<String, UnifiedPileupElementTracker<PE>> trackerMap = new HashMap<String, UnifiedPileupElementTracker<PE>>();
for (String sample : sampleNames) { // initialize pileups for each sample
UnifiedPileupElementTracker<PE> filteredTracker = new UnifiedPileupElementTracker<PE>();
trackerMap.put(sample, filteredTracker);
}
for (PE p : pileupElementTracker) { // go through all pileup elements only once and add them to the respective sample's pileup
GATKSAMRecord read = p.getRead();
if (read.getReadGroup() != null) {
String sample = read.getReadGroup().getSample();
UnifiedPileupElementTracker<PE> tracker = trackerMap.get(sample);
if (tracker != null) // we only add the pileup the requested samples. Completely ignore the rest
tracker.add(p);
}
}
for (Map.Entry<String, UnifiedPileupElementTracker<PE>> entry : trackerMap.entrySet()) // create the RBP for each sample
result.put(entry.getKey(), createNewPileup(loc, entry.getValue()));
}
return result;
}
@Override @Override
public RBP getPileupForSample(String sampleName) { public RBP getPileupForSample(String sampleName) {

View File

@ -32,6 +32,7 @@ import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
import java.util.Collection; import java.util.Collection;
import java.util.HashSet; import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Map;
/** /**
* A data retrieval interface for accessing parts of the pileup. * A data retrieval interface for accessing parts of the pileup.
@ -159,6 +160,16 @@ public interface ReadBackedPileup extends Iterable<PileupElement>, HasGenomeLoca
*/ */
public ReadBackedPileup getPileupForSamples(Collection<String> sampleNames); public ReadBackedPileup getPileupForSamples(Collection<String> sampleNames);
/**
* Gets the particular subset of this pileup for each given sample name.
*
* Same as calling getPileupForSample for all samples, but in O(n) instead of O(n^2).
*
* @param sampleNames Name of the sample to use.
* @return A subset of this pileup containing only reads with the given sample.
*/
public Map<String, ReadBackedPileup> getPileupsForSamples(Collection<String> sampleNames);
/** /**
* Gets the particular subset of this pileup with the given sample name. * Gets the particular subset of this pileup with the given sample name.