Optimization for ActiveRegion.removeAll

-- The previous version took a Collection<GATKSAMRecord> to remove, and called ArrayList.removeAll() with this collection to remove reads from the ActiveRegion. This can be very slow when there are lots of reads, as ArrayList.removeAll ultimately calls indexOf(), which searches through the list calling equals() on each element. The new version takes a Set, and uses an iterator on the list of reads to remove() from the iterator any read that is in the set. Given that we were already iterating over the list of reads to update the read span, this algorithm is actually simpler and faster than the previous one.
-- Update HaplotypeCaller filterNonPassingReads to use a Set rather than a List.
-- Expanded the unit tests a bit for ActiveRegion.removeAll
This commit is contained in:
Mark DePristo 2013-05-21 15:35:43 -04:00
parent d9cdc5d006
commit a1093ad230
3 changed files with 28 additions and 15 deletions

View File

@ -678,7 +678,7 @@ public class HaplotypeCaller extends ActiveRegionWalker<List<VariantContext>, In
if (dontGenotype) return NO_CALLS; // user requested we not proceed
// filter out reads from genotyping which fail mapping quality based criteria
final List<GATKSAMRecord> filteredReads = filterNonPassingReads( assemblyResult.regionForGenotyping );
final Collection<GATKSAMRecord> filteredReads = filterNonPassingReads( assemblyResult.regionForGenotyping );
final Map<String, List<GATKSAMRecord>> perSampleFilteredReadList = splitReadsBySample( filteredReads );
if( assemblyResult.regionForGenotyping.size() == 0 ) { return NO_CALLS; } // no reads remain after filtering so nothing else to do!
@ -918,17 +918,14 @@ public class HaplotypeCaller extends ActiveRegionWalker<List<VariantContext>, In
activeRegion.addAll(DownsamplingUtils.levelCoverageByPosition(ReadUtils.sortReadsByCoordinate(readsToUse), maxReadsInRegionPerSample, minReadsPerAlignmentStart));
}
private List<GATKSAMRecord> filterNonPassingReads( final org.broadinstitute.sting.utils.activeregion.ActiveRegion activeRegion ) {
final List<GATKSAMRecord> readsToRemove = new ArrayList<>();
// logger.info("Filtering non-passing regions: n incoming " + activeRegion.getReads().size());
private Set<GATKSAMRecord> filterNonPassingReads( final org.broadinstitute.sting.utils.activeregion.ActiveRegion activeRegion ) {
final Set<GATKSAMRecord> readsToRemove = new LinkedHashSet<>();
for( final GATKSAMRecord rec : activeRegion.getReads() ) {
if( rec.getReadLength() < MIN_READ_LENGTH || rec.getMappingQuality() < 20 || BadMateFilter.hasBadMate(rec) || (keepRG != null && !rec.getReadGroup().getId().equals(keepRG)) ) {
readsToRemove.add(rec);
// logger.info("\tremoving read " + rec + " len " + rec.getReadLength());
}
}
activeRegion.removeAll( readsToRemove );
// logger.info("Filtered non-passing regions: n remaining " + activeRegion.getReads().size());
return readsToRemove;
}
@ -938,7 +935,7 @@ public class HaplotypeCaller extends ActiveRegionWalker<List<VariantContext>, In
return getToolkit().getGenomeLocParser().createGenomeLoc(activeRegion.getExtendedLoc().getContig(), padLeft, padRight);
}
private Map<String, List<GATKSAMRecord>> splitReadsBySample( final List<GATKSAMRecord> reads ) {
private Map<String, List<GATKSAMRecord>> splitReadsBySample( final Collection<GATKSAMRecord> reads ) {
final Map<String, List<GATKSAMRecord>> returnMap = new HashMap<String, List<GATKSAMRecord>>();
for( final String sample : samplesList) {
List<GATKSAMRecord> readList = returnMap.get( sample );

View File

@ -336,13 +336,17 @@ public class ActiveRegion implements HasGenomeLocation {
/**
* Remove all of the reads in readsToRemove from this active region
* @param readsToRemove the collection of reads we want to remove
* @param readsToRemove the set of reads we want to remove
*/
public void removeAll( final Collection<GATKSAMRecord> readsToRemove ) {
reads.removeAll(readsToRemove);
public void removeAll( final Set<GATKSAMRecord> readsToRemove ) {
final Iterator<GATKSAMRecord> it = reads.iterator();
spanIncludingReads = extendedLoc;
for ( final GATKSAMRecord read : reads ) {
spanIncludingReads = spanIncludingReads.union( genomeLocParser.createGenomeLoc(read) );
while ( it.hasNext() ) {
final GATKSAMRecord read = it.next();
if ( readsToRemove.contains(read) )
it.remove();
else
spanIncludingReads = spanIncludingReads.union( genomeLocParser.createGenomeLoc(read) );
}
}

View File

@ -144,7 +144,7 @@ public class ActiveRegionUnitTest extends BaseTest {
}
@Test(enabled = !DEBUG, dataProvider = "ActiveRegionReads")
public void testActiveRegionReads(final GenomeLoc loc, final GATKSAMRecord read) {
public void testActiveRegionReads(final GenomeLoc loc, final GATKSAMRecord read) throws Exception {
final GenomeLoc expectedSpan = loc.union(genomeLocParser.createGenomeLoc(read));
final ActiveRegion region = new ActiveRegion(loc, null, true, genomeLocParser, 0);
@ -176,19 +176,31 @@ public class ActiveRegionUnitTest extends BaseTest {
Assert.assertEquals(region.getReadSpanLoc(), expectedSpan);
Assert.assertTrue(region.equalExceptReads(region2));
region.removeAll(Collections.<GATKSAMRecord>emptyList());
region.removeAll(Collections.<GATKSAMRecord>emptySet());
Assert.assertEquals(region.getReads(), Collections.singletonList(read));
Assert.assertEquals(region.size(), 1);
Assert.assertEquals(region.getExtendedLoc(), loc);
Assert.assertEquals(region.getReadSpanLoc(), expectedSpan);
Assert.assertTrue(region.equalExceptReads(region2));
region.removeAll(Collections.singletonList(read));
region.removeAll(Collections.singleton(read));
Assert.assertEquals(region.getReads(), Collections.emptyList());
Assert.assertEquals(region.size(), 0);
Assert.assertEquals(region.getExtendedLoc(), loc);
Assert.assertEquals(region.getReadSpanLoc(), loc);
Assert.assertTrue(region.equalExceptReads(region2));
final GATKSAMRecord read2 = (GATKSAMRecord)read.clone();
read2.setReadName(read.getReadName() + ".clone");
for ( final GATKSAMRecord readToKeep : Arrays.asList(read, read2)) {
region.addAll(Arrays.asList(read, read2));
final GATKSAMRecord readToDiscard = readToKeep == read ? read2 : read;
region.removeAll(Collections.singleton(readToDiscard));
Assert.assertEquals(region.getReads(), Arrays.asList(readToKeep));
Assert.assertEquals(region.size(), 1);
Assert.assertEquals(region.getExtendedLoc(), loc);
}
}
// -----------------------------------------------------------------------------------------------