diff --git a/public/java/src/org/broadinstitute/sting/gatk/ReadMetrics.java b/public/java/src/org/broadinstitute/sting/gatk/ReadMetrics.java index f73e7ccd5..29372abcd 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/ReadMetrics.java +++ b/public/java/src/org/broadinstitute/sting/gatk/ReadMetrics.java @@ -42,7 +42,7 @@ public class ReadMetrics implements Cloneable { private long nReads; // keep track of filtered records by filter type (class) - private Map filterCounter = new HashMap<>(); + private Map filterCounter = new HashMap<>(); /** * Combines these metrics with a set of other metrics, storing the results in this class. @@ -51,9 +51,9 @@ public class ReadMetrics implements Cloneable { public synchronized void incrementMetrics(ReadMetrics metrics) { nRecords += metrics.nRecords; nReads += metrics.nReads; - for(Map.Entry counterEntry: metrics.filterCounter.entrySet()) { - Class counterType = counterEntry.getKey(); - long newValue = (filterCounter.containsKey(counterType) ? filterCounter.get(counterType) : 0) + counterEntry.getValue(); + for(Map.Entry counterEntry: metrics.filterCounter.entrySet()) { + final String counterType = counterEntry.getKey(); + final long newValue = (filterCounter.containsKey(counterType) ? 
filterCounter.get(counterType) : 0) + counterEntry.getValue(); filterCounter.put(counterType, newValue); } } @@ -78,21 +78,12 @@ public class ReadMetrics implements Cloneable { } - public void incrementFilter(SamRecordFilter filter) { - long c = 0; - if ( filterCounter.containsKey(filter.getClass()) ) { - c = filterCounter.get(filter.getClass()); - } - - filterCounter.put(filter.getClass(), c + 1L); + public void setFilterCount(final String filter, final long count) { + filterCounter.put(filter, count); } public Map getCountsByFilter() { - final TreeMap sortedCounts = new TreeMap<>(); - for(Map.Entry counterEntry: filterCounter.entrySet()) { - sortedCounts.put(counterEntry.getKey().getSimpleName(),counterEntry.getValue()); - } - return sortedCounts; + return new TreeMap<>(filterCounter); } /** diff --git a/public/java/src/org/broadinstitute/sting/gatk/filters/CountingFilteringIterator.java b/public/java/src/org/broadinstitute/sting/gatk/filters/CountingFilteringIterator.java index 6c926e3cf..1942fc19a 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/filters/CountingFilteringIterator.java +++ b/public/java/src/org/broadinstitute/sting/gatk/filters/CountingFilteringIterator.java @@ -31,9 +31,7 @@ import net.sf.samtools.util.CloseableIterator; import net.sf.samtools.util.CloserUtil; import org.broadinstitute.sting.gatk.ReadMetrics; -import java.util.Collection; -import java.util.Iterator; -import java.util.NoSuchElementException; +import java.util.*; /** * Filtering Iterator which takes a filter and an iterator and iterates @@ -44,9 +42,27 @@ public class CountingFilteringIterator implements CloseableIterator { private final ReadMetrics globalRuntimeMetrics; private final ReadMetrics privateRuntimeMetrics; private final Iterator iterator; - private final Collection filters; + private final List filters = new ArrayList<>(); private SAMRecord next = null; + // wrapper around ReadFilters to count the number of filtered reads + private final class 
CountingReadFilter extends ReadFilter { + protected final ReadFilter readFilter; + protected long counter = 0L; + + public CountingReadFilter(final ReadFilter readFilter) { + this.readFilter = readFilter; + } + + @Override + public boolean filterOut(final SAMRecord record) { + final boolean result = readFilter.filterOut(record); + if ( result ) + counter++; + return result; + } + } + /** * Constructor * @@ -58,7 +74,8 @@ public class CountingFilteringIterator implements CloseableIterator { this.globalRuntimeMetrics = metrics; privateRuntimeMetrics = new ReadMetrics(); this.iterator = iterator; - this.filters = filters; + for ( final ReadFilter filter : filters ) + this.filters.add(new CountingReadFilter(filter)); next = getNextRecord(); } @@ -97,8 +114,12 @@ public class CountingFilteringIterator implements CloseableIterator { public void close() { CloserUtil.close(iterator); + // stage this iterator's per-filter tallies in its private metrics so the merge below adds them to the global totals instead of overwriting them + for ( final CountingReadFilter filter : filters ) + privateRuntimeMetrics.setFilterCount(filter.readFilter.getClass().getSimpleName(), filter.counter); + // update the global metrics with all the data we collected here globalRuntimeMetrics.incrementMetrics(privateRuntimeMetrics); } /** @@ -117,7 +138,6 @@ public class CountingFilteringIterator implements CloseableIterator { boolean filtered = false; for(SamRecordFilter filter: filters) { if(filter.filterOut(record)) { - privateRuntimeMetrics.incrementFilter(filter); filtered = true; break; } diff --git a/public/java/test/org/broadinstitute/sting/gatk/ReadMetricsUnitTest.java b/public/java/test/org/broadinstitute/sting/gatk/ReadMetricsUnitTest.java index 32fd35d95..3225a128c 100644 --- a/public/java/test/org/broadinstitute/sting/gatk/ReadMetricsUnitTest.java +++ b/public/java/test/org/broadinstitute/sting/gatk/ReadMetricsUnitTest.java @@ -34,7 +34,6 @@ import org.broadinstitute.sting.gatk.contexts.AlignmentContext; import org.broadinstitute.sting.gatk.contexts.ReferenceContext; import 
org.broadinstitute.sting.gatk.datasources.providers.LocusShardDataProvider; import org.broadinstitute.sting.gatk.datasources.providers.ReadShardDataProvider; -import org.broadinstitute.sting.gatk.datasources.providers.ShardDataProvider; import org.broadinstitute.sting.gatk.datasources.reads.*; import org.broadinstitute.sting.gatk.datasources.rmd.ReferenceOrderedDataSource; import org.broadinstitute.sting.gatk.executive.WindowMaker; @@ -263,6 +262,43 @@ public class ReadMetricsUnitTest extends BaseTest { Assert.assertEquals(engine.getCumulativeMetrics().getNumIterations(), contigs.size() * numReadsPerContig); } + @Test + public void testFilteredCounts() { + final GenomeAnalysisEngine engine = new GenomeAnalysisEngine(); + engine.setGenomeLocParser(genomeLocParser); + + final Collection samFiles = new ArrayList<>(); + final SAMReaderID readerID = new SAMReaderID(testBAM, new Tags()); + samFiles.add(readerID); + + final List filters = new ArrayList<>(); + filters.add(new EveryTenthReadFilter()); + + final SAMDataSource dataSource = new SAMDataSource(samFiles, new ThreadAllocation(), null, genomeLocParser, + false, + SAMFileReader.ValidationStringency.STRICT, + null, + null, + new ValidationExclusion(), + filters, + new ArrayList(), + false, (byte)30, false, true); + + engine.setReadsDataSource(dataSource); + + final TraverseReadsNano traverseReadsNano = new TraverseReadsNano(1); + final DummyReadWalker walker = new DummyReadWalker(); + traverseReadsNano.initialize(engine, walker, null); + + for ( final Shard shard : dataSource.createShardIteratorOverAllReads(new ReadShardBalancer()) ) { + final ReadShardDataProvider dataProvider = new ReadShardDataProvider(shard, engine.getGenomeLocParser(), dataSource.seek(shard), reference, new ArrayList()); + traverseReadsNano.traverse(walker, dataProvider, 0); + dataProvider.close(); + } + + Assert.assertEquals((long)engine.getCumulativeMetrics().getCountsByFilter().get(EveryTenthReadFilter.class.getSimpleName()), contigs.size() * 
numReadsPerContig / 10); + } + class DummyLocusWalker extends LocusWalker { @Override public Integer map(RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context) { @@ -318,4 +354,19 @@ public class ReadMetricsUnitTest extends BaseTest { return 0; } } + + private final class EveryTenthReadFilter extends ReadFilter { + + private int myCounter = 0; + + @Override + public boolean filterOut(final SAMRecord record) { + if ( ++myCounter == 10 ) { + myCounter = 0; + return true; + } + + return false; + } + } } \ No newline at end of file