Optimized counting of filtered records by filter.

Don't map class to counts in the ReadMetrics (necessitating 2 HashMap lookups for every increment).
Instead, wrap the ReadFilters with a counting version and then set those counts only when updating global metrics.
This commit is contained in:
Eric Banks 2013-05-21 18:19:23 -04:00
parent 3cfe2dcc64
commit 881b2b50ab
3 changed files with 84 additions and 23 deletions

View File

@ -42,7 +42,7 @@ public class ReadMetrics implements Cloneable {
private long nReads;
// keep track of filtered records by filter type (class)
private Map<Class, Long> filterCounter = new HashMap<>();
private Map<String, Long> filterCounter = new HashMap<>();
/**
* Combines these metrics with a set of other metrics, storing the results in this class.
@ -51,9 +51,9 @@ public class ReadMetrics implements Cloneable {
public synchronized void incrementMetrics(ReadMetrics metrics) {
nRecords += metrics.nRecords;
nReads += metrics.nReads;
for(Map.Entry<Class,Long> counterEntry: metrics.filterCounter.entrySet()) {
Class counterType = counterEntry.getKey();
long newValue = (filterCounter.containsKey(counterType) ? filterCounter.get(counterType) : 0) + counterEntry.getValue();
for(Map.Entry<String, Long> counterEntry: metrics.filterCounter.entrySet()) {
final String counterType = counterEntry.getKey();
final long newValue = (filterCounter.containsKey(counterType) ? filterCounter.get(counterType) : 0) + counterEntry.getValue();
filterCounter.put(counterType, newValue);
}
}
@ -78,21 +78,12 @@ public class ReadMetrics implements Cloneable {
}
public void incrementFilter(SamRecordFilter filter) {
long c = 0;
if ( filterCounter.containsKey(filter.getClass()) ) {
c = filterCounter.get(filter.getClass());
}
filterCounter.put(filter.getClass(), c + 1L);
/**
 * Stores the number of filtered records for the named filter, replacing any
 * count already held under that name.
 *
 * @param filter name under which the count is stored
 * @param count  number of filtered records to record for that name
 */
public void setFilterCount(final String filter, final long count) {
    this.filterCounter.put(filter, Long.valueOf(count));
}
public Map<String,Long> getCountsByFilter() {
final TreeMap<String, Long> sortedCounts = new TreeMap<>();
for(Map.Entry<Class,Long> counterEntry: filterCounter.entrySet()) {
sortedCounts.put(counterEntry.getKey().getSimpleName(),counterEntry.getValue());
}
return sortedCounts;
return new TreeMap<>(filterCounter);
}
/**

View File

@ -31,9 +31,7 @@ import net.sf.samtools.util.CloseableIterator;
import net.sf.samtools.util.CloserUtil;
import org.broadinstitute.sting.gatk.ReadMetrics;
import java.util.Collection;
import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.*;
/**
* Filtering Iterator which takes a filter and an iterator and iterates
@ -44,9 +42,27 @@ public class CountingFilteringIterator implements CloseableIterator<SAMRecord> {
private final ReadMetrics globalRuntimeMetrics;
private final ReadMetrics privateRuntimeMetrics;
private final Iterator<SAMRecord> iterator;
private final Collection<ReadFilter> filters;
private final List<CountingReadFilter> filters = new ArrayList<>();
private SAMRecord next = null;
// wrapper around ReadFilters to count the number of filtered reads
private final class CountingReadFilter extends ReadFilter {
protected final ReadFilter readFilter;
protected long counter = 0L;
public CountingReadFilter(final ReadFilter readFilter) {
this.readFilter = readFilter;
}
@Override
public boolean filterOut(final SAMRecord record) {
final boolean result = readFilter.filterOut(record);
if ( result )
counter++;
return result;
}
}
/**
* Constructor
*
@ -58,7 +74,8 @@ public class CountingFilteringIterator implements CloseableIterator<SAMRecord> {
this.globalRuntimeMetrics = metrics;
privateRuntimeMetrics = new ReadMetrics();
this.iterator = iterator;
this.filters = filters;
for ( final ReadFilter filter : filters )
this.filters.add(new CountingReadFilter(filter));
next = getNextRecord();
}
@ -97,8 +114,11 @@ public class CountingFilteringIterator implements CloseableIterator<SAMRecord> {
public void close() {
    // Releases the underlying iterator, then publishes everything this iterator
    // counted (read totals and per-filter tallies) to the global metrics.
    CloserUtil.close(iterator);

    // Store the per-filter tallies in the private metrics rather than writing them
    // straight into the global object: setFilterCount replaces the stored value, so
    // calling it on the shared metrics would discard counts already contributed
    // by other iterators.
    for ( final CountingReadFilter filter : filters ) {
        privateRuntimeMetrics.setFilterCount(filter.readFilter.getClass().getSimpleName(), filter.counter);
    }

    // update the global metrics with all the data we collected here; incrementMetrics
    // is synchronized and adds the per-filter counts to the existing totals
    globalRuntimeMetrics.incrementMetrics(privateRuntimeMetrics);
}
/**
@ -117,7 +137,6 @@ public class CountingFilteringIterator implements CloseableIterator<SAMRecord> {
boolean filtered = false;
for(SamRecordFilter filter: filters) {
if(filter.filterOut(record)) {
privateRuntimeMetrics.incrementFilter(filter);
filtered = true;
break;
}

View File

@ -34,7 +34,6 @@ import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
import org.broadinstitute.sting.gatk.datasources.providers.LocusShardDataProvider;
import org.broadinstitute.sting.gatk.datasources.providers.ReadShardDataProvider;
import org.broadinstitute.sting.gatk.datasources.providers.ShardDataProvider;
import org.broadinstitute.sting.gatk.datasources.reads.*;
import org.broadinstitute.sting.gatk.datasources.rmd.ReferenceOrderedDataSource;
import org.broadinstitute.sting.gatk.executive.WindowMaker;
@ -263,6 +262,43 @@ public class ReadMetricsUnitTest extends BaseTest {
Assert.assertEquals(engine.getCumulativeMetrics().getNumIterations(), contigs.size() * numReadsPerContig);
}
/**
 * Traverses every read in the test BAM with a filter that rejects each tenth
 * read, then checks that the cumulative metrics report one tenth of the reads
 * as filtered by that filter.
 */
@Test
public void testFilteredCounts() {
    final GenomeAnalysisEngine engine = new GenomeAnalysisEngine();
    engine.setGenomeLocParser(genomeLocParser);

    // single input: the test BAM
    final Collection<SAMReaderID> bamInputs = new ArrayList<>();
    bamInputs.add(new SAMReaderID(testBAM, new Tags()));

    // the only read filter in play is the one whose count is asserted below
    final List<ReadFilter> readFilters = new ArrayList<>();
    readFilters.add(new EveryTenthReadFilter());

    final SAMDataSource reads = new SAMDataSource(bamInputs, new ThreadAllocation(), null, genomeLocParser,
            false,
            SAMFileReader.ValidationStringency.STRICT,
            null,
            null,
            new ValidationExclusion(),
            readFilters,
            new ArrayList<ReadTransformer>(),
            false, (byte)30, false, true);
    engine.setReadsDataSource(reads);

    final TraverseReadsNano traversal = new TraverseReadsNano(1);
    final DummyReadWalker readWalker = new DummyReadWalker();
    traversal.initialize(engine, readWalker, null);

    // walk every shard, closing each provider once it has been traversed
    for ( final Shard currentShard : reads.createShardIteratorOverAllReads(new ReadShardBalancer()) ) {
        final ReadShardDataProvider provider = new ReadShardDataProvider(currentShard, engine.getGenomeLocParser(), reads.seek(currentShard), reference, new ArrayList<ReferenceOrderedDataSource>());
        traversal.traverse(readWalker, provider, 0);
        provider.close();
    }

    // counts are keyed by the filter's simple class name
    final long filteredByEveryTenth = engine.getCumulativeMetrics().getCountsByFilter().get(EveryTenthReadFilter.class.getSimpleName());
    Assert.assertEquals(filteredByEveryTenth, contigs.size() * numReadsPerContig / 10);
}
class DummyLocusWalker extends LocusWalker<Integer, Integer> {
@Override
public Integer map(RefMetaDataTracker tracker, ReferenceContext ref, AlignmentContext context) {
@ -318,4 +354,19 @@ public class ReadMetricsUnitTest extends BaseTest {
return 0;
}
}
/**
 * Test filter that rejects every tenth record it is shown and lets the nine
 * records in between pass.
 */
private final class EveryTenthReadFilter extends ReadFilter {
    // number of records seen since the last rejected one
    private int myCounter = 0;

    @Override
    public boolean filterOut(final SAMRecord record) {
        myCounter++;
        if ( myCounter != 10 ) {
            return false;
        }
        // tenth record of the cycle: restart the count and filter it out
        myCounter = 0;
        return true;
    }
}
}