Added a reservoir downsampler which can sample elements in an iterator uniformly

from a stream (see Vitter 1985).  Thanks to Eric and Andrey for the pointer.


git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@3197 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
hanna 2010-04-19 20:48:14 +00:00
parent c44f63c846
commit c08936d6f4
3 changed files with 315 additions and 0 deletions

View File

@ -0,0 +1,103 @@
package org.broadinstitute.sting.utils;
import net.sf.picard.util.PeekableIterator;
import java.util.*;
/**
 * Randomly downsample from a stream of elements.  This algorithm is a direct,
 * naive implementation of reservoir sampling (Algorithm R) as described in
 * "Random Sampling with a Reservoir" (Vitter 1985).  At time of writing, this paper is located here:
 * http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.138.784&rep=rep1&type=pdf
 *
 * Each call to {@link #next()} consumes one run of consecutive elements that the comparator
 * judges equal to the first element of that run, and returns at most maxElements of them.
 *
 * Note that using the ReservoirDownsampler will leave the given iterator in an undefined
 * state.  Do not attempt to use the iterator (other than closing it) after the Downsampler
 * completes.
 *
 * @author mhanna
 * @version 0.1
 */
public class ReservoirDownsampler<T> implements Iterator<Collection<T>> {
    /**
     * Random number generator with a fixed seed, so that every downsampler instance
     * makes the same sequence of draws and results are reproducible from run to run.
     */
    private final Random random = new Random(47382911L);

    /**
     * The data source, wrapped so that the next element can be inspected without consuming it.
     */
    private final PeekableIterator<T> iterator;

    /**
     * Used to identify whether two elements are 'equal' in the eyes of the downsampler.
     */
    private final Comparator<T> comparator;

    /**
     * The maximum number of elements that can be returned in a single batch.
     */
    private final int maxElements;

    /**
     * Create a new downsampler with the given source iterator and given comparator.
     * @param iterator Source of the data stream.
     * @param comparator Used to compare two records to see whether they're 'equal' at this position.
     * @param maxElements The maximum number of elements that can be returned by any single call
     *                    to {@link #next()}.  Zero is allowed and yields empty batches.
     * @throws StingException if maxElements is negative.
     */
    public ReservoirDownsampler(final Iterator<T> iterator, final Comparator<T> comparator, final int maxElements) {
        this.iterator = new PeekableIterator<T>(iterator);
        this.comparator = comparator;
        if(maxElements < 0)
            throw new StingException("Unable to work with an negative size collection of elements");
        this.maxElements = maxElements;
    }

    /**
     * Reports whether another batch is available.
     * @return True if the underlying iterator still has elements to consume.
     */
    public boolean hasNext() {
        return iterator.hasNext();
    }

    /**
     * Gets a collection of 'equal' elements, as judged by the comparator.  If the number of equal elements
     * is greater than the maximum, the returned collection is a uniform random sample of them: every
     * element of the run has the same probability of being retained.
     * @return Collection of equal elements, holding at most maxElements entries.
     * @throws NoSuchElementException if the underlying iterator is exhausted.
     */
    public Collection<T> next() {
        if(!hasNext())
            throw new NoSuchElementException("No next element is present.");

        final List<T> batch = new ArrayList<T>(maxElements);

        // The first element of the run defines the basis of equality for the whole batch.
        final T first = iterator.next();
        if(maxElements > 0)
            batch.add(first);

        // Number of elements of this run consumed so far, including the one currently being processed.
        int elementsSeen = 1;

        while(iterator.hasNext() && comparator.compare(first,iterator.peek()) == 0) {
            final T candidate = iterator.next();
            elementsSeen++;

            if(batch.size() < maxElements) {
                // Reservoir not yet full: keep everything.
                batch.add(candidate);
            }
            else if(maxElements > 0) {
                // Algorithm R: the t-th element replaces a uniformly chosen reservoir slot with
                // probability maxElements/t.  The draw must range over all t elements seen so far
                // INCLUDING the candidate; drawing over t-1 biases the sample toward late elements
                // (e.g. with a pool of one, the second element would always evict the first).
                final int slot = random.nextInt(elementsSeen);
                if(slot < maxElements)
                    batch.set(slot,candidate);
            }
        }

        return batch;
    }

    /**
     * Unsupported; throws exception to that effect.
     * @throws UnsupportedOperationException always.
     */
    public void remove() {
        throw new UnsupportedOperationException("Cannot remove from a ReservoirDownsampler.");
    }
}

View File

@ -0,0 +1,190 @@
package org.broadinstitute.sting.utils;
import org.junit.Test;
import org.broadinstitute.sting.utils.sam.AlignmentStartComparator;
import org.broadinstitute.sting.utils.sam.ArtificialSAMUtils;
import net.sf.samtools.SAMRecord;
import net.sf.samtools.SAMFileHeader;
import java.util.*;
import junit.framework.Assert;
/**
 * Sanity checks for the reservoir downsampler.  For now every test runs on
 * artificial SAM records batched by alignment start, since that is the task
 * for which the downsampler was conceived.
 *
 * @author mhanna
 * @version 0.1
 */
public class ReservoirDownsamplerUnitTest {
    private static final SAMFileHeader header = ArtificialSAMUtils.createArtificialSamHeader(1,1,200);

    /**
     * Creates one artificial read with the given name and alignment start.
     * The remaining arguments (0 and 76) are fixed for every read in this suite;
     * presumably they are the reference index and read length -- confirm against ArtificialSAMUtils.
     */
    private static SAMRecord readAt(final String name, final int alignmentStart) {
        return ArtificialSAMUtils.createArtificialRead(header,name,0,alignmentStart,76);
    }

    /**
     * Creates reads named read1, read2, ... whose alignment starts are the given values, in order.
     */
    private static List<SAMRecord> readsAt(final int... alignmentStarts) {
        final List<SAMRecord> reads = new ArrayList<SAMRecord>();
        for(int i = 0; i < alignmentStarts.length; i++)
            reads.add(readAt("read" + (i+1),alignmentStarts[i]));
        return reads;
    }

    /**
     * Wraps the given reads in a downsampler that batches by alignment start with the given pool size.
     */
    private static ReservoirDownsampler<SAMRecord> downsamplerFor(final List<SAMRecord> reads, final int poolSize) {
        return new ReservoirDownsampler<SAMRecord>(reads.iterator(),new AlignmentStartComparator(),poolSize);
    }

    /**
     * Asserts that another batch is available, then pulls it into a random-access list.
     */
    private static List<SAMRecord> drawBatch(final ReservoirDownsampler<SAMRecord> downsampler) {
        Assert.assertTrue("Downsampler is empty but shouldn't be",downsampler.hasNext());
        return new ArrayList<SAMRecord>(downsampler.next());
    }

    /**
     * Asserts that each expected read is the very same object found at the same position in the batch.
     */
    private static void assertSameReadsInOrder(final List<SAMRecord> expected, final List<SAMRecord> batch) {
        for(int i = 0; i < expected.size(); i++)
            Assert.assertSame("Downsampler read " + (i+1) + " is incorrect",expected.get(i),batch.get(i));
    }

    /**
     * Shared body of the single-read tests: one read in, that same read out, regardless of pool size.
     */
    private static void checkSingleReadPassesThrough(final int poolSize) {
        final List<SAMRecord> reads = Collections.singletonList(readAt("read1",1));
        final ReservoirDownsampler<SAMRecord> downsampler = downsamplerFor(reads,poolSize);

        Assert.assertTrue("Downsampler is empty but shouldn't be",downsampler.hasNext());
        final Collection<SAMRecord> batch = downsampler.next();
        Assert.assertEquals("Downsampler is returning the wrong number of reads",1,batch.size());
        Assert.assertSame("Downsampler is returning an incorrect read",reads.get(0),batch.iterator().next());
    }

    @Test
    public void testEmptyIterator() {
        final ReservoirDownsampler<SAMRecord> downsampler = downsamplerFor(Collections.<SAMRecord>emptyList(),1);
        Assert.assertFalse("Downsampler is not empty but should be.",downsampler.hasNext());
    }

    @Test
    public void testOneElementWithPoolSizeOne() {
        checkSingleReadPassesThrough(1);
    }

    @Test
    public void testOneElementWithPoolSizeGreaterThanOne() {
        checkSingleReadPassesThrough(5);
    }

    @Test
    public void testPoolFilledPartially() {
        // Three reads at one locus, pool of five: everything comes back in input order.
        final List<SAMRecord> reads = readsAt(1,1,1);
        final ReservoirDownsampler<SAMRecord> downsampler = downsamplerFor(reads,5);

        final List<SAMRecord> batch = drawBatch(downsampler);
        Assert.assertEquals("Downsampler is returning the wrong number of reads",3,batch.size());
        assertSameReadsInOrder(reads,batch);
    }

    @Test
    public void testPoolFilledExactly() {
        // Five reads at one locus, pool of five: everything comes back in input order.
        final List<SAMRecord> reads = readsAt(1,1,1,1,1);
        final ReservoirDownsampler<SAMRecord> downsampler = downsamplerFor(reads,5);

        final List<SAMRecord> batch = drawBatch(downsampler);
        Assert.assertEquals("Downsampler is returning the wrong number of reads",5,batch.size());
        Assert.assertSame("Downsampler is returning an incorrect read",reads.get(0),batch.iterator().next());
        assertSameReadsInOrder(reads,batch);
    }

    @Test
    public void testLargerPileWithZeroElementPool() {
        // A pool of zero still consumes the whole locus but hands back nothing.
        final List<SAMRecord> reads = readsAt(1,1,1);
        final ReservoirDownsampler<SAMRecord> downsampler = downsamplerFor(reads,0);

        final List<SAMRecord> batch = drawBatch(downsampler);
        Assert.assertEquals("Downsampler is returning the wrong number of reads",0,batch.size());
        Assert.assertFalse("Downsampler is not empty but should be",downsampler.hasNext());
    }

    @Test
    public void testLargerPileWithSingleElementPool() {
        // Five reads at one locus, pool of one: exactly one of the inputs survives.
        final List<SAMRecord> reads = readsAt(1,1,1,1,1);
        final ReservoirDownsampler<SAMRecord> downsampler = downsamplerFor(reads,1);

        final List<SAMRecord> batch = drawBatch(downsampler);
        Assert.assertEquals("Downsampler is returning the wrong number of reads",1,batch.size());
        Assert.assertTrue("Downsampler is returning a bad read.",reads.contains(batch.get(0))) ;
        Assert.assertFalse("Downsampler is not empty but should be",downsampler.hasNext());
    }

    @Test
    public void testFillingAcrossLoci() {
        // Loci of one, two and two reads; the pool is big enough that nothing is dropped.
        final List<SAMRecord> reads = readsAt(1,2,2,3,3);
        final ReservoirDownsampler<SAMRecord> downsampler = downsamplerFor(reads,5);

        List<SAMRecord> batch = drawBatch(downsampler);
        Assert.assertEquals("Downsampler is returning the wrong number of reads",1,batch.size());
        Assert.assertEquals("Downsampler is returning an incorrect read.",reads.get(0),batch.get(0));

        batch = drawBatch(downsampler);
        Assert.assertEquals("Downsampler is returning the wrong number of reads",2,batch.size());
        Assert.assertEquals("Downsampler is returning an incorrect read.",reads.get(1),batch.get(0));
        Assert.assertEquals("Downsampler is returning an incorrect read.",reads.get(2),batch.get(1));

        batch = drawBatch(downsampler);
        Assert.assertEquals("Downsampler is returning the wrong number of reads",2,batch.size());
        Assert.assertEquals("Downsampler is returning an incorrect read.",reads.get(3),batch.get(0));
        Assert.assertEquals("Downsampler is returning an incorrect read.",reads.get(4),batch.get(1));

        Assert.assertFalse("Downsampler is not empty but should be",downsampler.hasNext());
    }

    @Test
    public void testDownsamplingAcrossLoci() {
        // Same loci as above, but a pool of one forces a choice at the two-read loci.
        final List<SAMRecord> reads = readsAt(1,2,2,3,3);
        final ReservoirDownsampler<SAMRecord> downsampler = downsamplerFor(reads,1);

        List<SAMRecord> batch = drawBatch(downsampler);
        Assert.assertEquals("Downsampler is returning the wrong number of reads",1,batch.size());
        Assert.assertEquals("Downsampler is returning an incorrect read.",reads.get(0),batch.get(0));

        batch = drawBatch(downsampler);
        Assert.assertEquals("Downsampler is returning the wrong number of reads",1,batch.size());
        SAMRecord survivor = batch.get(0);
        Assert.assertTrue("Downsampler is returning an incorrect read.",survivor.equals(reads.get(1)) || survivor.equals(reads.get(2)));

        batch = drawBatch(downsampler);
        Assert.assertEquals("Downsampler is returning the wrong number of reads",1,batch.size());
        survivor = batch.get(0);
        Assert.assertTrue("Downsampler is returning an incorrect read.",survivor.equals(reads.get(3)) || survivor.equals(reads.get(4)));

        Assert.assertFalse("Downsampler is not empty but should be",downsampler.hasNext());
    }
}

View File

@ -0,0 +1,22 @@
package org.broadinstitute.sting.utils.sam;
import net.sf.samtools.SAMRecord;
import java.util.Comparator;
/**
 * Compares two SAMRecords on the basis of alignment start alone.  No other
 * property of the records is consulted, so any two SAM records with the same
 * alignment start will be considered equal.
 *
 * Unmapped alignments will all be considered equal.
 *
 * @author mhanna
 * @version 0.1
 */
public class AlignmentStartComparator implements Comparator<SAMRecord> {
    /**
     * Orders two records by alignment start.
     * @param lhs Left-hand record.
     * @param rhs Right-hand record.
     * @return The left alignment start minus the right one: negative, zero or positive as
     *         lhs starts before, at, or after rhs.
     */
    public int compare(SAMRecord lhs, SAMRecord rhs) {
        final int lhsStart = lhs.getAlignmentStart();
        final int rhsStart = rhs.getAlignmentStart();
        // Note: plain subtraction cannot overflow here because alignment starts are >= 0.
        return lhsStart - rhsStart;
    }
}