auto-merge

This commit is contained in:
Ryan Poplin 2012-12-19 10:12:49 -05:00
commit cda0c48570
7 changed files with 206 additions and 416 deletions

View File

@ -359,7 +359,10 @@ public class GATKBAMIndex {
int bytesExpected = buffer.limit(); int bytesExpected = buffer.limit();
//BufferedInputStream cannot read directly into a byte buffer, so we read into an array //BufferedInputStream cannot read directly into a byte buffer, so we read into an array
//and put the result into the bytebuffer after the if statement. //and put the result into the bytebuffer after the if statement.
int bytesRead = bufferedStream.read(byteArray,0,bytesExpected);
//SeekableBufferedStream is evil, it will "read" beyond the end of the file if you let it!
final int bytesToRead = (int) Math.min(bufferedStream.length() - bufferedStream.position(), bytesExpected); //min of int and long will definitely be castable to an int.
int bytesRead = bufferedStream.read(byteArray,0,bytesToRead);
// We have a rigid expectation here to read in exactly the number of bytes we've limited // We have a rigid expectation here to read in exactly the number of bytes we've limited
// our buffer to -- if we read in fewer bytes than this, or encounter EOF (-1), the index // our buffer to -- if we read in fewer bytes than this, or encounter EOF (-1), the index

View File

@ -104,6 +104,56 @@ public class PerSampleDownsamplingReadsIterator implements StingSAMIterator {
readComparator.compare(orderedDownsampledReadsCache.peek(), earliestPendingRead) <= 0; readComparator.compare(orderedDownsampledReadsCache.peek(), earliestPendingRead) <= 0;
} }
private boolean fillDownsampledReadsCache() {
SAMRecord prevRead = null;
int numPositionalChanges = 0;
// Continue submitting reads to the per-sample downsamplers until the read at the top of the priority queue
// can be released without violating global sort order
while ( nestedSAMIterator.hasNext() && ! readyToReleaseReads() ) {
SAMRecord read = nestedSAMIterator.next();
String sampleName = read.getReadGroup() != null ? read.getReadGroup().getSample() : null;
ReadsDownsampler<SAMRecord> thisSampleDownsampler = perSampleDownsamplers.get(sampleName);
if ( thisSampleDownsampler == null ) {
thisSampleDownsampler = downsamplerFactory.newInstance();
perSampleDownsamplers.put(sampleName, thisSampleDownsampler);
}
thisSampleDownsampler.submit(read);
processFinalizedAndPendingItems(thisSampleDownsampler);
if ( prevRead != null && prevRead.getAlignmentStart() != read.getAlignmentStart() ) {
numPositionalChanges++;
}
// Periodically inform all downsamplers of the current position in the read stream. This is
// to prevent downsamplers for samples with sparser reads than others from getting stuck too
// long in a pending state.
if ( numPositionalChanges > 0 && numPositionalChanges % DOWNSAMPLER_POSITIONAL_UPDATE_INTERVAL == 0 ) {
for ( ReadsDownsampler<SAMRecord> perSampleDownsampler : perSampleDownsamplers.values() ) {
perSampleDownsampler.signalNoMoreReadsBefore(read);
processFinalizedAndPendingItems(perSampleDownsampler);
}
}
prevRead = read;
}
if ( ! nestedSAMIterator.hasNext() ) {
for ( ReadsDownsampler<SAMRecord> perSampleDownsampler : perSampleDownsamplers.values() ) {
perSampleDownsampler.signalEndOfInput();
if ( perSampleDownsampler.hasFinalizedItems() ) {
orderedDownsampledReadsCache.addAll(perSampleDownsampler.consumeFinalizedItems());
}
}
earliestPendingRead = null;
earliestPendingDownsampler = null;
}
return readyToReleaseReads();
}
private void updateEarliestPendingRead( ReadsDownsampler<SAMRecord> currentDownsampler ) { private void updateEarliestPendingRead( ReadsDownsampler<SAMRecord> currentDownsampler ) {
// If there is no recorded earliest pending read and this downsampler has pending items, // If there is no recorded earliest pending read and this downsampler has pending items,
// then this downsampler's first pending item becomes the new earliest pending read: // then this downsampler's first pending item becomes the new earliest pending read:
@ -135,57 +185,11 @@ public class PerSampleDownsamplingReadsIterator implements StingSAMIterator {
} }
} }
private boolean fillDownsampledReadsCache() { private void processFinalizedAndPendingItems( ReadsDownsampler<SAMRecord> currentDownsampler ) {
SAMRecord prevRead = null; if ( currentDownsampler.hasFinalizedItems() ) {
int numPositionalChanges = 0; orderedDownsampledReadsCache.addAll(currentDownsampler.consumeFinalizedItems());
// Continue submitting reads to the per-sample downsamplers until the read at the top of the priority queue
// can be released without violating global sort order
while ( nestedSAMIterator.hasNext() && ! readyToReleaseReads() ) {
SAMRecord read = nestedSAMIterator.next();
String sampleName = read.getReadGroup() != null ? read.getReadGroup().getSample() : null;
ReadsDownsampler<SAMRecord> thisSampleDownsampler = perSampleDownsamplers.get(sampleName);
if ( thisSampleDownsampler == null ) {
thisSampleDownsampler = downsamplerFactory.newInstance();
perSampleDownsamplers.put(sampleName, thisSampleDownsampler);
}
thisSampleDownsampler.submit(read);
updateEarliestPendingRead(thisSampleDownsampler);
if ( prevRead != null && prevRead.getAlignmentStart() != read.getAlignmentStart() ) {
numPositionalChanges++;
}
// Periodically inform all downsamplers of the current position in the read stream. This is
// to prevent downsamplers for samples with sparser reads than others from getting stuck too
// long in a pending state.
if ( numPositionalChanges > 0 && numPositionalChanges % DOWNSAMPLER_POSITIONAL_UPDATE_INTERVAL == 0 ) {
for ( ReadsDownsampler<SAMRecord> perSampleDownsampler : perSampleDownsamplers.values() ) {
perSampleDownsampler.signalNoMoreReadsBefore(read);
updateEarliestPendingRead(perSampleDownsampler);
}
}
prevRead = read;
} }
updateEarliestPendingRead(currentDownsampler);
if ( ! nestedSAMIterator.hasNext() ) {
for ( ReadsDownsampler<SAMRecord> perSampleDownsampler : perSampleDownsamplers.values() ) {
perSampleDownsampler.signalEndOfInput();
}
earliestPendingRead = null;
earliestPendingDownsampler = null;
}
for ( ReadsDownsampler<SAMRecord> perSampleDownsampler : perSampleDownsamplers.values() ) {
if ( perSampleDownsampler.hasFinalizedItems() ) {
orderedDownsampledReadsCache.addAll(perSampleDownsampler.consumeFinalizedItems());
}
}
return readyToReleaseReads();
} }
public void remove() { public void remove() {

View File

@ -1,16 +1,15 @@
package org.broadinstitute.sting.utils.nanoScheduler; package org.broadinstitute.sting.utils.nanoScheduler;
import org.apache.log4j.Logger; import org.apache.log4j.Logger;
import org.broadinstitute.sting.utils.MultiThreadedErrorTracker;
import java.util.Iterator; import java.util.Iterator;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch; import java.util.concurrent.CountDownLatch;
/** /**
* Producer Thread that reads input values from an inputReads and puts them into an output queue * Helper class that allows multiple threads to reads input values from
* an iterator, and track the number of items read from that iterator.
*/ */
class InputProducer<InputType> implements Runnable { class InputProducer<InputType> {
private final static Logger logger = Logger.getLogger(InputProducer.class); private final static Logger logger = Logger.getLogger(InputProducer.class);
/** /**
@ -18,13 +17,6 @@ class InputProducer<InputType> implements Runnable {
*/ */
final Iterator<InputType> inputReader; final Iterator<InputType> inputReader;
/**
* Where we put our input values for consumption
*/
final BlockingQueue<InputValue> outputQueue;
final MultiThreadedErrorTracker errorTracker;
/** /**
* Have we read the last value from inputReader? * Have we read the last value from inputReader?
* *
@ -34,6 +26,14 @@ class InputProducer<InputType> implements Runnable {
*/ */
boolean readLastValue = false; boolean readLastValue = false;
/**
* Once we've readLastValue, lastValue contains a continually
* updating InputValue where EOF is true. It's not necessarily
* a single value, as each read updates lastValue with the
* next EOF marker
*/
private InputValue lastValue = null;
int nRead = 0; int nRead = 0;
int inputID = -1; int inputID = -1;
@ -43,16 +43,9 @@ class InputProducer<InputType> implements Runnable {
*/ */
final CountDownLatch latch = new CountDownLatch(1); final CountDownLatch latch = new CountDownLatch(1);
public InputProducer(final Iterator<InputType> inputReader, public InputProducer(final Iterator<InputType> inputReader) {
final MultiThreadedErrorTracker errorTracker,
final BlockingQueue<InputValue> outputQueue) {
if ( inputReader == null ) throw new IllegalArgumentException("inputReader cannot be null"); if ( inputReader == null ) throw new IllegalArgumentException("inputReader cannot be null");
if ( errorTracker == null ) throw new IllegalArgumentException("errorTracker cannot be null");
if ( outputQueue == null ) throw new IllegalArgumentException("OutputQueue cannot be null");
this.inputReader = inputReader; this.inputReader = inputReader;
this.errorTracker = errorTracker;
this.outputQueue = outputQueue;
} }
/** /**
@ -82,9 +75,8 @@ class InputProducer<InputType> implements Runnable {
* This method is synchronized, as it manipulates local state accessed across multiple threads. * This method is synchronized, as it manipulates local state accessed across multiple threads.
* *
* @return the next input stream value, or null if the stream contains no more elements * @return the next input stream value, or null if the stream contains no more elements
* @throws InterruptedException
*/ */
private synchronized InputType readNextItem() throws InterruptedException { private synchronized InputType readNextItem() {
if ( ! inputReader.hasNext() ) { if ( ! inputReader.hasNext() ) {
// we are done, mark ourselves as such and return null // we are done, mark ourselves as such and return null
readLastValue = true; readLastValue = true;
@ -100,49 +92,60 @@ class InputProducer<InputType> implements Runnable {
} }
/** /**
* Run this input producer, looping over all items in the input reader and * Are there currently more values in the iterator?
* enqueueing them as InputValues into the outputQueue. After the *
* end of the stream has been encountered, any threads waiting because * Note the word currently. It's possible that some already submitted
* they called waitForDone() will be freed. * job will read a value from this InputProvider, so in some sense
* there are no more values and in the future there'll be no next
* value. That said, once this returns false it means that all
* of the possible values have been read
*
* @return true if a future call to next might return a non-EOF value, false if
* the underlying iterator is definitely empty
*/ */
public void run() { public synchronized boolean hasNext() {
try { return ! allInputsHaveBeenRead();
while ( true ) {
final InputType value = readNextItem();
if ( value == null ) {
if ( ! readLastValue )
throw new IllegalStateException("value == null but readLastValue is false!");
// add the EOF object so our consumer knows we are done in all inputs
// note that we do not increase inputID here, so that variable indicates the ID
// of the last real value read from the queue
outputQueue.put(new InputValue(inputID + 1));
break;
} else {
// add the actual value to the outputQueue
outputQueue.put(new InputValue(++inputID, value));
}
}
latch.countDown();
} catch (Throwable ex) {
errorTracker.notifyOfError(ex);
} finally {
// logger.info("Exiting input thread readLastValue = " + readLastValue);
}
} }
/** /**
* Block until all of the items have been read from inputReader. * Get the next InputValue from this producer. The next value is
* either (1) the next value from the iterator, in which case the
* the return value is an InputValue containing that value, or (2)
* an InputValue with the EOF marker, indicating that the underlying
* iterator has been exhausted.
* *
* Note that this call doesn't actually read anything. You have to submit a thread * This function never fails -- it can be called endlessly and
* to actually execute run() directly. * while the underlying iterator has values it returns them, and then
* it returns a succession of EOF marking input values.
* *
* @throws InterruptedException * @return an InputValue containing the next value in the underlying
* iterator, or one with EOF marker, if the iterator is exhausted
*/ */
public void waitForDone() throws InterruptedException { public synchronized InputValue next() {
latch.await(); if ( readLastValue ) {
// we read the last value, so our value is the next
// EOF marker based on the last value. Make sure to
// update the last value so the markers keep incrementing
// their job ids
lastValue = lastValue.nextEOF();
return lastValue;
} else {
final InputType value = readNextItem();
if ( value == null ) {
if ( ! readLastValue )
throw new IllegalStateException("value == null but readLastValue is false!");
// add the EOF object so our consumer knows we are done in all inputs
// note that we do not increase inputID here, so that variable indicates the ID
// of the last real value read from the queue
lastValue = new InputValue(inputID + 1);
return lastValue;
} else {
// add the actual value to the outputQueue
return new InputValue(++inputID, value);
}
}
} }
/** /**

View File

@ -47,7 +47,6 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
final int bufferSize; final int bufferSize;
final int nThreads; final int nThreads;
final ExecutorService inputExecutor;
final ExecutorService masterExecutor; final ExecutorService masterExecutor;
final ExecutorService mapExecutor; final ExecutorService mapExecutor;
final Semaphore runningMapJobSlots; final Semaphore runningMapJobSlots;
@ -75,14 +74,12 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
this.nThreads = nThreads; this.nThreads = nThreads;
if ( nThreads == 1 ) { if ( nThreads == 1 ) {
this.mapExecutor = this.inputExecutor = this.masterExecutor = null; this.mapExecutor = this.masterExecutor = null;
runningMapJobSlots = null; runningMapJobSlots = null;
} else { } else {
this.mapExecutor = Executors.newFixedThreadPool(nThreads - 1, new NamedThreadFactory("NS-map-thread-%d"));
runningMapJobSlots = new Semaphore(this.bufferSize);
this.inputExecutor = Executors.newSingleThreadExecutor(new NamedThreadFactory("NS-input-thread-%d"));
this.masterExecutor = Executors.newSingleThreadExecutor(new NamedThreadFactory("NS-master-thread-%d")); this.masterExecutor = Executors.newSingleThreadExecutor(new NamedThreadFactory("NS-master-thread-%d"));
this.mapExecutor = Executors.newFixedThreadPool(nThreads, new NamedThreadFactory("NS-map-thread-%d"));
runningMapJobSlots = new Semaphore(this.bufferSize);
} }
} }
@ -111,7 +108,6 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
*/ */
public void shutdown() { public void shutdown() {
if ( nThreads > 1 ) { if ( nThreads > 1 ) {
shutdownExecutor("inputExecutor", inputExecutor);
shutdownExecutor("mapExecutor", mapExecutor); shutdownExecutor("mapExecutor", mapExecutor);
shutdownExecutor("masterExecutor", masterExecutor); shutdownExecutor("masterExecutor", masterExecutor);
} }
@ -323,7 +319,6 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
if ( errorTracker.hasAnErrorOccurred() ) { if ( errorTracker.hasAnErrorOccurred() ) {
masterExecutor.shutdownNow(); masterExecutor.shutdownNow();
mapExecutor.shutdownNow(); mapExecutor.shutdownNow();
inputExecutor.shutdownNow();
errorTracker.throwErrorIfPending(); errorTracker.throwErrorIfPending();
} }
} }
@ -351,15 +346,8 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
@Override @Override
public ReduceType call() { public ReduceType call() {
// a blocking queue that limits the number of input datum to the requested buffer size
// note we need +1 because we continue to enqueue the lastObject
final BlockingQueue<InputProducer<InputType>.InputValue> inputQueue
= new LinkedBlockingDeque<InputProducer<InputType>.InputValue>(bufferSize+1);
// Create the input producer and start it running // Create the input producer and start it running
final InputProducer<InputType> inputProducer = final InputProducer<InputType> inputProducer = new InputProducer<InputType>(inputReader);
new InputProducer<InputType>(inputReader, errorTracker, inputQueue);
inputExecutor.submit(inputProducer);
// a priority queue that stores up to bufferSize elements // a priority queue that stores up to bufferSize elements
// produced by completed map jobs. // produced by completed map jobs.
@ -376,7 +364,7 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
// acquire a slot to run a map job. Blocks if too many jobs are enqueued // acquire a slot to run a map job. Blocks if too many jobs are enqueued
runningMapJobSlots.acquire(); runningMapJobSlots.acquire();
mapExecutor.submit(new MapReduceJob(inputQueue, mapResultQueue, map, reducer)); mapExecutor.submit(new ReadMapReduceJob(inputProducer, mapResultQueue, map, reducer));
nSubmittedJobs++; nSubmittedJobs++;
} }
@ -402,10 +390,6 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
// logger.warn("waiting for final reduce"); // logger.warn("waiting for final reduce");
final ReduceType finalSum = reducer.waitForFinalReduce(); final ReduceType finalSum = reducer.waitForFinalReduce();
// now wait for the input provider thread to terminate
// logger.warn("waiting on inputProducer");
inputProducer.waitForDone();
// wait for all the map threads to finish by acquiring and then releasing all map job semaphores // wait for all the map threads to finish by acquiring and then releasing all map job semaphores
// logger.warn("waiting on map"); // logger.warn("waiting on map");
runningMapJobSlots.acquire(bufferSize); runningMapJobSlots.acquire(bufferSize);
@ -434,17 +418,17 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
} }
} }
private class MapReduceJob implements Runnable { private class ReadMapReduceJob implements Runnable {
final BlockingQueue<InputProducer<InputType>.InputValue> inputQueue; final InputProducer<InputType> inputProducer;
final PriorityBlockingQueue<MapResult<MapType>> mapResultQueue; final PriorityBlockingQueue<MapResult<MapType>> mapResultQueue;
final NSMapFunction<InputType, MapType> map; final NSMapFunction<InputType, MapType> map;
final Reducer<MapType, ReduceType> reducer; final Reducer<MapType, ReduceType> reducer;
private MapReduceJob(BlockingQueue<InputProducer<InputType>.InputValue> inputQueue, private ReadMapReduceJob(final InputProducer<InputType> inputProducer,
final PriorityBlockingQueue<MapResult<MapType>> mapResultQueue, final PriorityBlockingQueue<MapResult<MapType>> mapResultQueue,
final NSMapFunction<InputType, MapType> map, final NSMapFunction<InputType, MapType> map,
final Reducer<MapType, ReduceType> reducer) { final Reducer<MapType, ReduceType> reducer) {
this.inputQueue = inputQueue; this.inputProducer = inputProducer;
this.mapResultQueue = mapResultQueue; this.mapResultQueue = mapResultQueue;
this.map = map; this.map = map;
this.reducer = reducer; this.reducer = reducer;
@ -453,10 +437,10 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
@Override @Override
public void run() { public void run() {
try { try {
//debugPrint("Running MapReduceJob " + jobID); // get the next item from the input producer
final InputProducer<InputType>.InputValue inputWrapper = inputQueue.take(); final InputProducer<InputType>.InputValue inputWrapper = inputProducer.next();
final int jobID = inputWrapper.getId();
// depending on inputWrapper, actually do some work or not, putting result input result object
final MapResult<MapType> result; final MapResult<MapType> result;
if ( ! inputWrapper.isEOFMarker() ) { if ( ! inputWrapper.isEOFMarker() ) {
// just skip doing anything if we don't have work to do, which is possible // just skip doing anything if we don't have work to do, which is possible
@ -468,23 +452,19 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
final MapType mapValue = map.apply(input); final MapType mapValue = map.apply(input);
// enqueue the result into the mapResultQueue // enqueue the result into the mapResultQueue
result = new MapResult<MapType>(mapValue, jobID); result = new MapResult<MapType>(mapValue, inputWrapper.getId());
if ( progressFunction != null ) if ( progressFunction != null )
progressFunction.progress(input); progressFunction.progress(input);
} else { } else {
// push back the EOF marker so other waiting threads can read it
inputQueue.put(inputWrapper.nextEOF());
// if there's no input we push empty MapResults with jobIDs for synchronization with Reducer // if there's no input we push empty MapResults with jobIDs for synchronization with Reducer
result = new MapResult<MapType>(jobID); result = new MapResult<MapType>(inputWrapper.getId());
} }
mapResultQueue.put(result); mapResultQueue.put(result);
final int nReduced = reducer.reduceAsMuchAsPossible(mapResultQueue); final int nReduced = reducer.reduceAsMuchAsPossible(mapResultQueue);
} catch (Throwable ex) { } catch (Throwable ex) {
// logger.warn("Map job got exception " + ex);
errorTracker.notifyOfError(ex); errorTracker.notifyOfError(ex);
} finally { } finally {
// we finished a map job, release the job queue semaphore // we finished a map job, release the job queue semaphore

View File

@ -84,7 +84,7 @@ class Reducer<MapType, ReduceType> {
if ( nextMapResult == null ) { if ( nextMapResult == null ) {
return false; return false;
} else if ( nextMapResult.getJobID() < prevJobID + 1 ) { } else if ( nextMapResult.getJobID() < prevJobID + 1 ) {
throw new IllegalStateException("Next job ID " + nextMapResult.getJobID() + " is < previous job id " + prevJobID); throw new IllegalStateException("Next job ID " + nextMapResult.getJobID() + " is not < previous job id " + prevJobID);
} else if ( nextMapResult.getJobID() == prevJobID + 1 ) { } else if ( nextMapResult.getJobID() == prevJobID + 1 ) {
return true; return true;
} else { } else {

View File

@ -2,6 +2,10 @@ package org.broadinstitute.sting.gatk.traversals;
import com.google.java.contract.PreconditionError; import com.google.java.contract.PreconditionError;
import net.sf.samtools.*; import net.sf.samtools.*;
import org.broadinstitute.sting.commandline.Tags;
import org.broadinstitute.sting.gatk.datasources.reads.*;
import org.broadinstitute.sting.gatk.resourcemanagement.ThreadAllocation;
import org.broadinstitute.sting.utils.GenomeLocSortedSet;
import org.broadinstitute.sting.utils.activeregion.ActiveRegionReadState; import org.broadinstitute.sting.utils.activeregion.ActiveRegionReadState;
import org.broadinstitute.sting.utils.interval.IntervalMergingRule; import org.broadinstitute.sting.utils.interval.IntervalMergingRule;
import org.broadinstitute.sting.utils.interval.IntervalUtils; import org.broadinstitute.sting.utils.interval.IntervalUtils;
@ -12,11 +16,8 @@ import org.broadinstitute.sting.gatk.GenomeAnalysisEngine;
import org.broadinstitute.sting.gatk.contexts.AlignmentContext; import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext; import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
import org.broadinstitute.sting.gatk.datasources.providers.LocusShardDataProvider; import org.broadinstitute.sting.gatk.datasources.providers.LocusShardDataProvider;
import org.broadinstitute.sting.gatk.datasources.reads.MockLocusShard;
import org.broadinstitute.sting.gatk.datasources.reads.Shard;
import org.broadinstitute.sting.gatk.datasources.rmd.ReferenceOrderedDataSource; import org.broadinstitute.sting.gatk.datasources.rmd.ReferenceOrderedDataSource;
import org.broadinstitute.sting.gatk.executive.WindowMaker; import org.broadinstitute.sting.gatk.executive.WindowMaker;
import org.broadinstitute.sting.gatk.iterators.StingSAMIterator;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker; import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.walkers.ActiveRegionWalker; import org.broadinstitute.sting.gatk.walkers.ActiveRegionWalker;
import org.broadinstitute.sting.utils.GenomeLoc; import org.broadinstitute.sting.utils.GenomeLoc;
@ -101,7 +102,8 @@ public class TraverseActiveRegionsTest extends BaseTest {
private GenomeLocParser genomeLocParser; private GenomeLocParser genomeLocParser;
private List<GenomeLoc> intervals; private List<GenomeLoc> intervals;
private List<GATKSAMRecord> reads;
private static final String testBAM = "TraverseActiveRegionsTest.bam";
@BeforeClass @BeforeClass
private void init() throws FileNotFoundException { private void init() throws FileNotFoundException {
@ -110,6 +112,13 @@ public class TraverseActiveRegionsTest extends BaseTest {
genomeLocParser = new GenomeLocParser(dictionary); genomeLocParser = new GenomeLocParser(dictionary);
// TODO: test shard boundaries // TODO: test shard boundaries
// TODO: reads with indels
// TODO: reads which span many regions
// TODO: reads which are partially between intervals (in/outside extension)
// TODO: duplicate reads
// TODO: should we assign reads which are completely outside intervals but within extension?
intervals = new ArrayList<GenomeLoc>(); intervals = new ArrayList<GenomeLoc>();
intervals.add(genomeLocParser.createGenomeLoc("1", 10, 20)); intervals.add(genomeLocParser.createGenomeLoc("1", 10, 20));
@ -117,24 +126,34 @@ public class TraverseActiveRegionsTest extends BaseTest {
intervals.add(genomeLocParser.createGenomeLoc("1", 1000, 1999)); intervals.add(genomeLocParser.createGenomeLoc("1", 1000, 1999));
intervals.add(genomeLocParser.createGenomeLoc("1", 2000, 2999)); intervals.add(genomeLocParser.createGenomeLoc("1", 2000, 2999));
intervals.add(genomeLocParser.createGenomeLoc("1", 10000, 20000)); intervals.add(genomeLocParser.createGenomeLoc("1", 10000, 20000));
intervals.add(genomeLocParser.createGenomeLoc("1", 249250600, 249250621));
intervals.add(genomeLocParser.createGenomeLoc("2", 1, 100)); intervals.add(genomeLocParser.createGenomeLoc("2", 1, 100));
intervals.add(genomeLocParser.createGenomeLoc("20", 10000, 10100)); intervals.add(genomeLocParser.createGenomeLoc("20", 10000, 10100));
intervals = IntervalUtils.sortAndMergeIntervals(genomeLocParser, intervals, IntervalMergingRule.OVERLAPPING_ONLY).toList(); intervals = IntervalUtils.sortAndMergeIntervals(genomeLocParser, intervals, IntervalMergingRule.OVERLAPPING_ONLY).toList();
reads = new ArrayList<GATKSAMRecord>(); List<GATKSAMRecord> reads = new ArrayList<GATKSAMRecord>();
reads.add(buildSAMRecord("simple", "1", 100, 200)); reads.add(buildSAMRecord("simple", "1", 100, 200));
reads.add(buildSAMRecord("overlap_equal", "1", 10, 20)); reads.add(buildSAMRecord("overlap_equal", "1", 10, 20));
reads.add(buildSAMRecord("overlap_unequal", "1", 10, 21)); reads.add(buildSAMRecord("overlap_unequal", "1", 10, 21));
reads.add(buildSAMRecord("boundary_equal", "1", 1990, 2009)); reads.add(buildSAMRecord("boundary_equal", "1", 1990, 2009));
reads.add(buildSAMRecord("boundary_unequal", "1", 1990, 2008)); reads.add(buildSAMRecord("boundary_unequal", "1", 1990, 2008));
reads.add(buildSAMRecord("boundary_1_pre", "1", 1950, 2000));
reads.add(buildSAMRecord("boundary_1_post", "1", 1999, 2050));
reads.add(buildSAMRecord("extended_and_np", "1", 990, 1990)); reads.add(buildSAMRecord("extended_and_np", "1", 990, 1990));
reads.add(buildSAMRecord("outside_intervals", "1", 5000, 6000)); reads.add(buildSAMRecord("outside_intervals", "1", 5000, 6000));
reads.add(buildSAMRecord("end_of_chr1", "1", 249250600, 249250700));
reads.add(buildSAMRecord("simple20", "20", 10025, 10075)); reads.add(buildSAMRecord("simple20", "20", 10025, 10075));
// required by LocusIteratorByState, and I prefer to list them in test case order above createBAM(reads);
ReadUtils.sortReadsByCoordinate(reads); }
private void createBAM(List<GATKSAMRecord> reads) {
File outFile = new File(testBAM);
outFile.deleteOnExit();
SAMFileWriter out = new SAMFileWriterFactory().makeBAMWriter(reads.get(0).getHeader(), true, outFile);
for (GATKSAMRecord read : ReadUtils.sortReadsByCoordinate(reads)) {
out.addAlignment(read);
}
out.close();
} }
@Test @Test
@ -148,7 +167,7 @@ public class TraverseActiveRegionsTest extends BaseTest {
private List<GenomeLoc> getIsActiveIntervals(DummyActiveRegionWalker walker, List<GenomeLoc> intervals) { private List<GenomeLoc> getIsActiveIntervals(DummyActiveRegionWalker walker, List<GenomeLoc> intervals) {
List<GenomeLoc> activeIntervals = new ArrayList<GenomeLoc>(); List<GenomeLoc> activeIntervals = new ArrayList<GenomeLoc>();
for (LocusShardDataProvider dataProvider : createDataProviders(intervals)) { for (LocusShardDataProvider dataProvider : createDataProviders(intervals, testBAM)) {
t.traverse(walker, dataProvider, 0); t.traverse(walker, dataProvider, 0);
activeIntervals.addAll(walker.isActiveCalls); activeIntervals.addAll(walker.isActiveCalls);
} }
@ -230,73 +249,26 @@ public class TraverseActiveRegionsTest extends BaseTest {
// overlap_unequal: Primary in 1:1-999 // overlap_unequal: Primary in 1:1-999
// boundary_equal: Non-Primary in 1:1000-1999, Primary in 1:2000-2999 // boundary_equal: Non-Primary in 1:1000-1999, Primary in 1:2000-2999
// boundary_unequal: Primary in 1:1000-1999, Non-Primary in 1:2000-2999 // boundary_unequal: Primary in 1:1000-1999, Non-Primary in 1:2000-2999
// boundary_1_pre: Primary in 1:1000-1999, Non-Primary in 1:2000-2999
// boundary_1_post: Non-Primary in 1:1000-1999, Primary in 1:2000-2999
// extended_and_np: Non-Primary in 1:1-999, Primary in 1:1000-1999, Extended in 1:2000-2999 // extended_and_np: Non-Primary in 1:1-999, Primary in 1:1000-1999, Extended in 1:2000-2999
// outside_intervals: none // outside_intervals: none
// end_of_chr1: Primary in 1:249250600-249250621
// simple20: Primary in 20:10000-10100 // simple20: Primary in 20:10000-10100
Map<GenomeLoc, ActiveRegion> activeRegions = getActiveRegions(walker, intervals); Map<GenomeLoc, ActiveRegion> activeRegions = getActiveRegions(walker, intervals);
ActiveRegion region; ActiveRegion region;
region = activeRegions.get(genomeLocParser.createGenomeLoc("1", 1, 999)); region = activeRegions.get(genomeLocParser.createGenomeLoc("1", 1, 999));
verifyReadMapping(region, "simple", "overlap_equal", "overlap_unequal");
getRead(region, "simple");
getRead(region, "overlap_equal");
getRead(region, "overlap_unequal");
verifyReadNotPlaced(region, "boundary_equal");
verifyReadNotPlaced(region, "boundary_unequal");
verifyReadNotPlaced(region, "extended_and_np");
verifyReadNotPlaced(region, "outside_intervals");
verifyReadNotPlaced(region, "end_of_chr1");
verifyReadNotPlaced(region, "simple20");
region = activeRegions.get(genomeLocParser.createGenomeLoc("1", 1000, 1999)); region = activeRegions.get(genomeLocParser.createGenomeLoc("1", 1000, 1999));
verifyReadMapping(region, "boundary_unequal", "extended_and_np", "boundary_1_pre");
verifyReadNotPlaced(region, "simple");
verifyReadNotPlaced(region, "overlap_equal");
verifyReadNotPlaced(region, "overlap_unequal");
verifyReadNotPlaced(region, "boundary_equal");
getRead(region, "boundary_unequal");
getRead(region, "extended_and_np");
verifyReadNotPlaced(region, "outside_intervals");
verifyReadNotPlaced(region, "end_of_chr1");
verifyReadNotPlaced(region, "simple20");
region = activeRegions.get(genomeLocParser.createGenomeLoc("1", 2000, 2999)); region = activeRegions.get(genomeLocParser.createGenomeLoc("1", 2000, 2999));
verifyReadMapping(region, "boundary_equal", "boundary_1_post");
verifyReadNotPlaced(region, "simple");
verifyReadNotPlaced(region, "overlap_equal");
verifyReadNotPlaced(region, "overlap_unequal");
getRead(region, "boundary_equal");
verifyReadNotPlaced(region, "boundary_unequal");
verifyReadNotPlaced(region, "extended_and_np");
verifyReadNotPlaced(region, "outside_intervals");
verifyReadNotPlaced(region, "end_of_chr1");
verifyReadNotPlaced(region, "simple20");
region = activeRegions.get(genomeLocParser.createGenomeLoc("1", 249250600, 249250621));
verifyReadNotPlaced(region, "simple");
verifyReadNotPlaced(region, "overlap_equal");
verifyReadNotPlaced(region, "overlap_unequal");
verifyReadNotPlaced(region, "boundary_equal");
verifyReadNotPlaced(region, "boundary_unequal");
verifyReadNotPlaced(region, "extended_and_np");
verifyReadNotPlaced(region, "outside_intervals");
getRead(region, "end_of_chr1");
verifyReadNotPlaced(region, "simple20");
region = activeRegions.get(genomeLocParser.createGenomeLoc("20", 10000, 10100)); region = activeRegions.get(genomeLocParser.createGenomeLoc("20", 10000, 10100));
verifyReadMapping(region, "simple20");
verifyReadNotPlaced(region, "simple");
verifyReadNotPlaced(region, "overlap_equal");
verifyReadNotPlaced(region, "overlap_unequal");
verifyReadNotPlaced(region, "boundary_equal");
verifyReadNotPlaced(region, "boundary_unequal");
verifyReadNotPlaced(region, "extended_and_np");
verifyReadNotPlaced(region, "outside_intervals");
verifyReadNotPlaced(region, "end_of_chr1");
getRead(region, "simple20");
} }
@Test @Test
@ -314,73 +286,26 @@ public class TraverseActiveRegionsTest extends BaseTest {
// overlap_unequal: Primary in 1:1-999 // overlap_unequal: Primary in 1:1-999
// boundary_equal: Non-Primary in 1:1000-1999, Primary in 1:2000-2999 // boundary_equal: Non-Primary in 1:1000-1999, Primary in 1:2000-2999
// boundary_unequal: Primary in 1:1000-1999, Non-Primary in 1:2000-2999 // boundary_unequal: Primary in 1:1000-1999, Non-Primary in 1:2000-2999
// boundary_1_pre: Primary in 1:1000-1999, Non-Primary in 1:2000-2999
// boundary_1_post: Non-Primary in 1:1000-1999, Primary in 1:2000-2999
// extended_and_np: Non-Primary in 1:1-999, Primary in 1:1000-1999, Extended in 1:2000-2999 // extended_and_np: Non-Primary in 1:1-999, Primary in 1:1000-1999, Extended in 1:2000-2999
// outside_intervals: none // outside_intervals: none
// end_of_chr1: Primary in 1:249250600-249250621
// simple20: Primary in 20:10000-10100 // simple20: Primary in 20:10000-10100
Map<GenomeLoc, ActiveRegion> activeRegions = getActiveRegions(walker, intervals); Map<GenomeLoc, ActiveRegion> activeRegions = getActiveRegions(walker, intervals);
ActiveRegion region; ActiveRegion region;
region = activeRegions.get(genomeLocParser.createGenomeLoc("1", 1, 999)); region = activeRegions.get(genomeLocParser.createGenomeLoc("1", 1, 999));
verifyReadMapping(region, "simple", "overlap_equal", "overlap_unequal", "extended_and_np");
getRead(region, "simple");
getRead(region, "overlap_equal");
getRead(region, "overlap_unequal");
verifyReadNotPlaced(region, "boundary_equal");
verifyReadNotPlaced(region, "boundary_unequal");
getRead(region, "extended_and_np");
verifyReadNotPlaced(region, "outside_intervals");
verifyReadNotPlaced(region, "end_of_chr1");
verifyReadNotPlaced(region, "simple20");
region = activeRegions.get(genomeLocParser.createGenomeLoc("1", 1000, 1999)); region = activeRegions.get(genomeLocParser.createGenomeLoc("1", 1000, 1999));
verifyReadMapping(region, "boundary_equal", "boundary_unequal", "extended_and_np", "boundary_1_pre", "boundary_1_post");
verifyReadNotPlaced(region, "simple");
verifyReadNotPlaced(region, "overlap_equal");
verifyReadNotPlaced(region, "overlap_unequal");
getRead(region, "boundary_equal");
getRead(region, "boundary_unequal");
getRead(region, "extended_and_np");
verifyReadNotPlaced(region, "outside_intervals");
verifyReadNotPlaced(region, "end_of_chr1");
verifyReadNotPlaced(region, "simple20");
region = activeRegions.get(genomeLocParser.createGenomeLoc("1", 2000, 2999)); region = activeRegions.get(genomeLocParser.createGenomeLoc("1", 2000, 2999));
verifyReadMapping(region, "boundary_equal", "boundary_unequal", "boundary_1_pre", "boundary_1_post");
verifyReadNotPlaced(region, "simple");
verifyReadNotPlaced(region, "overlap_equal");
verifyReadNotPlaced(region, "overlap_unequal");
getRead(region, "boundary_equal");
getRead(region, "boundary_unequal");
verifyReadNotPlaced(region, "extended_and_np");
verifyReadNotPlaced(region, "outside_intervals");
verifyReadNotPlaced(region, "end_of_chr1");
verifyReadNotPlaced(region, "simple20");
region = activeRegions.get(genomeLocParser.createGenomeLoc("1", 249250600, 249250621));
verifyReadNotPlaced(region, "simple");
verifyReadNotPlaced(region, "overlap_equal");
verifyReadNotPlaced(region, "overlap_unequal");
verifyReadNotPlaced(region, "boundary_equal");
verifyReadNotPlaced(region, "boundary_unequal");
verifyReadNotPlaced(region, "extended_and_np");
verifyReadNotPlaced(region, "outside_intervals");
getRead(region, "end_of_chr1");
verifyReadNotPlaced(region, "simple20");
region = activeRegions.get(genomeLocParser.createGenomeLoc("20", 10000, 10100)); region = activeRegions.get(genomeLocParser.createGenomeLoc("20", 10000, 10100));
verifyReadMapping(region, "simple20");
verifyReadNotPlaced(region, "simple");
verifyReadNotPlaced(region, "overlap_equal");
verifyReadNotPlaced(region, "overlap_unequal");
verifyReadNotPlaced(region, "boundary_equal");
verifyReadNotPlaced(region, "boundary_unequal");
verifyReadNotPlaced(region, "extended_and_np");
verifyReadNotPlaced(region, "outside_intervals");
verifyReadNotPlaced(region, "end_of_chr1");
getRead(region, "simple20");
} }
@Test @Test
@ -399,73 +324,26 @@ public class TraverseActiveRegionsTest extends BaseTest {
// overlap_unequal: Primary in 1:1-999 // overlap_unequal: Primary in 1:1-999
// boundary_equal: Non-Primary in 1:1000-1999, Primary in 1:2000-2999 // boundary_equal: Non-Primary in 1:1000-1999, Primary in 1:2000-2999
// boundary_unequal: Primary in 1:1000-1999, Non-Primary in 1:2000-2999 // boundary_unequal: Primary in 1:1000-1999, Non-Primary in 1:2000-2999
// boundary_1_pre: Primary in 1:1000-1999, Non-Primary in 1:2000-2999
// boundary_1_post: Non-Primary in 1:1000-1999, Primary in 1:2000-2999
// extended_and_np: Non-Primary in 1:1-999, Primary in 1:1000-1999, Extended in 1:2000-2999 // extended_and_np: Non-Primary in 1:1-999, Primary in 1:1000-1999, Extended in 1:2000-2999
// outside_intervals: none // outside_intervals: none
// end_of_chr1: Primary in 1:249250600-249250621
// simple20: Primary in 20:10000-10100 // simple20: Primary in 20:10000-10100
Map<GenomeLoc, ActiveRegion> activeRegions = getActiveRegions(walker, intervals); Map<GenomeLoc, ActiveRegion> activeRegions = getActiveRegions(walker, intervals);
ActiveRegion region; ActiveRegion region;
region = activeRegions.get(genomeLocParser.createGenomeLoc("1", 1, 999)); region = activeRegions.get(genomeLocParser.createGenomeLoc("1", 1, 999));
verifyReadMapping(region, "simple", "overlap_equal", "overlap_unequal", "extended_and_np");
getRead(region, "simple");
getRead(region, "overlap_equal");
getRead(region, "overlap_unequal");
verifyReadNotPlaced(region, "boundary_equal");
verifyReadNotPlaced(region, "boundary_unequal");
getRead(region, "extended_and_np");
verifyReadNotPlaced(region, "outside_intervals");
verifyReadNotPlaced(region, "end_of_chr1");
verifyReadNotPlaced(region, "simple20");
region = activeRegions.get(genomeLocParser.createGenomeLoc("1", 1000, 1999)); region = activeRegions.get(genomeLocParser.createGenomeLoc("1", 1000, 1999));
verifyReadMapping(region, "boundary_equal", "boundary_unequal", "extended_and_np", "boundary_1_pre", "boundary_1_post");
verifyReadNotPlaced(region, "simple");
verifyReadNotPlaced(region, "overlap_equal");
verifyReadNotPlaced(region, "overlap_unequal");
getRead(region, "boundary_equal");
getRead(region, "boundary_unequal");
getRead(region, "extended_and_np");
verifyReadNotPlaced(region, "outside_intervals");
verifyReadNotPlaced(region, "end_of_chr1");
verifyReadNotPlaced(region, "simple20");
region = activeRegions.get(genomeLocParser.createGenomeLoc("1", 2000, 2999)); region = activeRegions.get(genomeLocParser.createGenomeLoc("1", 2000, 2999));
verifyReadMapping(region, "boundary_equal", "boundary_unequal", "extended_and_np", "boundary_1_pre", "boundary_1_post");
verifyReadNotPlaced(region, "simple");
verifyReadNotPlaced(region, "overlap_equal");
verifyReadNotPlaced(region, "overlap_unequal");
getRead(region, "boundary_equal");
getRead(region, "boundary_unequal");
getRead(region, "extended_and_np");
verifyReadNotPlaced(region, "outside_intervals");
verifyReadNotPlaced(region, "end_of_chr1");
verifyReadNotPlaced(region, "simple20");
region = activeRegions.get(genomeLocParser.createGenomeLoc("1", 249250600, 249250621));
verifyReadNotPlaced(region, "simple");
verifyReadNotPlaced(region, "overlap_equal");
verifyReadNotPlaced(region, "overlap_unequal");
verifyReadNotPlaced(region, "boundary_equal");
verifyReadNotPlaced(region, "boundary_unequal");
verifyReadNotPlaced(region, "extended_and_np");
verifyReadNotPlaced(region, "outside_intervals");
getRead(region, "end_of_chr1");
verifyReadNotPlaced(region, "simple20");
region = activeRegions.get(genomeLocParser.createGenomeLoc("20", 10000, 10100)); region = activeRegions.get(genomeLocParser.createGenomeLoc("20", 10000, 10100));
verifyReadMapping(region, "simple20");
verifyReadNotPlaced(region, "simple");
verifyReadNotPlaced(region, "overlap_equal");
verifyReadNotPlaced(region, "overlap_unequal");
verifyReadNotPlaced(region, "boundary_equal");
verifyReadNotPlaced(region, "boundary_unequal");
verifyReadNotPlaced(region, "extended_and_np");
verifyReadNotPlaced(region, "outside_intervals");
verifyReadNotPlaced(region, "end_of_chr1");
getRead(region, "simple20");
} }
@Test @Test
@ -473,25 +351,19 @@ public class TraverseActiveRegionsTest extends BaseTest {
// TODO // TODO
} }
private void verifyReadNotPlaced(ActiveRegion region, String readName) { private void verifyReadMapping(ActiveRegion region, String... reads) {
Collection<String> wantReads = new ArrayList<String>(Arrays.asList(reads));
for (SAMRecord read : region.getReads()) { for (SAMRecord read : region.getReads()) {
if (read.getReadName().equals(readName)) String regionReadName = read.getReadName();
Assert.fail("Read " + readName + " found in active region " + region); Assert.assertTrue(wantReads.contains(regionReadName), "Read " + regionReadName + " assigned to active region " + region);
} wantReads.remove(regionReadName);
}
private SAMRecord getRead(ActiveRegion region, String readName) {
for (SAMRecord read : region.getReads()) {
if (read.getReadName().equals(readName))
return read;
} }
Assert.fail("Read " + readName + " not assigned to active region " + region); Assert.assertTrue(wantReads.isEmpty(), "Reads missing in active region " + region);
return null;
} }
private Map<GenomeLoc, ActiveRegion> getActiveRegions(DummyActiveRegionWalker walker, List<GenomeLoc> intervals) { private Map<GenomeLoc, ActiveRegion> getActiveRegions(DummyActiveRegionWalker walker, List<GenomeLoc> intervals) {
for (LocusShardDataProvider dataProvider : createDataProviders(intervals)) for (LocusShardDataProvider dataProvider : createDataProviders(intervals, testBAM))
t.traverse(walker, dataProvider, 0); t.traverse(walker, dataProvider, 0);
t.endTraversal(walker, 0); t.endTraversal(walker, 0);
@ -536,7 +408,7 @@ public class TraverseActiveRegionsTest extends BaseTest {
// copied from LocusViewTemplate // copied from LocusViewTemplate
protected GATKSAMRecord buildSAMRecord(String readName, String contig, int alignmentStart, int alignmentEnd) { protected GATKSAMRecord buildSAMRecord(String readName, String contig, int alignmentStart, int alignmentEnd) {
SAMFileHeader header = new SAMFileHeader(); SAMFileHeader header = ArtificialSAMUtils.createDefaultReadGroup(new SAMFileHeader(), "test", "test");
header.setSequenceDictionary(dictionary); header.setSequenceDictionary(dictionary);
GATKSAMRecord record = new GATKSAMRecord(header); GATKSAMRecord record = new GATKSAMRecord(header);
@ -548,23 +420,28 @@ public class TraverseActiveRegionsTest extends BaseTest {
int len = alignmentEnd - alignmentStart + 1; int len = alignmentEnd - alignmentStart + 1;
cigar.add(new CigarElement(len, CigarOperator.M)); cigar.add(new CigarElement(len, CigarOperator.M));
record.setCigar(cigar); record.setCigar(cigar);
record.setReadBases(new byte[len]); record.setReadString(new String(new char[len]).replace("\0", "A"));
record.setBaseQualities(new byte[len]); record.setBaseQualities(new byte[len]);
return record; return record;
} }
private List<LocusShardDataProvider> createDataProviders(List<GenomeLoc> intervals) { private List<LocusShardDataProvider> createDataProviders(List<GenomeLoc> intervals, String bamFile) {
GenomeAnalysisEngine engine = new GenomeAnalysisEngine(); GenomeAnalysisEngine engine = new GenomeAnalysisEngine();
engine.setGenomeLocParser(genomeLocParser); engine.setGenomeLocParser(genomeLocParser);
t.initialize(engine); t.initialize(engine);
StingSAMIterator iterator = ArtificialSAMUtils.createReadIterator(new ArrayList<SAMRecord>(reads)); Collection<SAMReaderID> samFiles = new ArrayList<SAMReaderID>();
Shard shard = new MockLocusShard(genomeLocParser, intervals); SAMReaderID readerID = new SAMReaderID(new File(bamFile), new Tags());
samFiles.add(readerID);
SAMDataSource dataSource = new SAMDataSource(samFiles, new ThreadAllocation(), null, genomeLocParser);
List<LocusShardDataProvider> providers = new ArrayList<LocusShardDataProvider>(); List<LocusShardDataProvider> providers = new ArrayList<LocusShardDataProvider>();
for (WindowMaker.WindowMakerIterator window : new WindowMaker(shard, genomeLocParser, iterator, shard.getGenomeLocs())) { for (Shard shard : dataSource.createShardIteratorOverIntervals(new GenomeLocSortedSet(genomeLocParser, intervals), new LocusShardBalancer())) {
providers.add(new LocusShardDataProvider(shard, shard.getReadProperties(), genomeLocParser, window.getLocus(), window, reference, new ArrayList<ReferenceOrderedDataSource>())); for (WindowMaker.WindowMakerIterator window : new WindowMaker(shard, genomeLocParser, dataSource.seek(shard), shard.getGenomeLocs())) {
providers.add(new LocusShardDataProvider(shard, shard.getReadProperties(), genomeLocParser, window.getLocus(), window, reference, new ArrayList<ReferenceOrderedDataSource>()));
}
} }
return providers; return providers;

View File

@ -1,19 +1,13 @@
package org.broadinstitute.sting.utils.nanoScheduler; package org.broadinstitute.sting.utils.nanoScheduler;
import org.broadinstitute.sting.BaseTest; import org.broadinstitute.sting.BaseTest;
import org.broadinstitute.sting.utils.MultiThreadedErrorTracker;
import org.testng.Assert; import org.testng.Assert;
import org.testng.annotations.DataProvider; import org.testng.annotations.DataProvider;
import org.testng.annotations.Test; import org.testng.annotations.Test;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.Semaphore;
/** /**
* UnitTests for the InputProducer * UnitTests for the InputProducer
@ -42,34 +36,23 @@ public class InputProducerUnitTest extends BaseTest {
final List<Integer> elements = new ArrayList<Integer>(nElements); final List<Integer> elements = new ArrayList<Integer>(nElements);
for ( int i = 0; i < nElements; i++ ) elements.add(i); for ( int i = 0; i < nElements; i++ ) elements.add(i);
final LinkedBlockingDeque<InputProducer<Integer>.InputValue> readQueue = final InputProducer<Integer> ip = new InputProducer<Integer>(elements.iterator());
new LinkedBlockingDeque<InputProducer<Integer>.InputValue>(queueSize);
final InputProducer<Integer> ip = new InputProducer<Integer>(elements.iterator(), new MultiThreadedErrorTracker(), readQueue);
final ExecutorService es = Executors.newSingleThreadExecutor();
Assert.assertFalse(ip.allInputsHaveBeenRead(), "InputProvider said that all inputs have been read, but I haven't started reading yet"); Assert.assertFalse(ip.allInputsHaveBeenRead(), "InputProvider said that all inputs have been read, but I haven't started reading yet");
Assert.assertEquals(ip.getNumInputValues(), -1, "InputProvider told me that the queue was done, but I haven't started reading yet"); Assert.assertEquals(ip.getNumInputValues(), -1, "InputProvider told me that the queue was done, but I haven't started reading yet");
es.submit(ip);
int lastValue = -1; int lastValue = -1;
int nRead = 0; int nRead = 0;
while ( true ) { while ( ip.hasNext() ) {
final int nTotalElements = ip.getNumInputValues(); final int nTotalElements = ip.getNumInputValues();
final int observedQueueSize = readQueue.size();
Assert.assertTrue(observedQueueSize <= queueSize,
"Reader is enqueuing more elements " + observedQueueSize + " than allowed " + queueSize);
if ( nRead + observedQueueSize < nElements ) if ( nRead < nElements )
Assert.assertEquals(nTotalElements, -1, "getNumInputValues should have returned -1 with not all elements read"); Assert.assertEquals(nTotalElements, -1, "getNumInputValues should have returned -1 with not all elements read");
// note, cannot test else case because elements input could have emptied between calls // note, cannot test else case because elements input could have emptied between calls
final InputProducer<Integer>.InputValue value = readQueue.take(); final InputProducer<Integer>.InputValue value = ip.next();
if ( value.isEOFMarker() ) { if ( value.isEOFMarker() ) {
Assert.assertEquals(nRead, nElements, "Number of input values " + nRead + " not all that are expected " + nElements); Assert.assertEquals(nRead, nElements, "Number of input values " + nRead + " not all that are expected " + nElements);
Assert.assertEquals(readQueue.size(), 0, "Last queue element found but queue contains more values!");
break; break;
} else { } else {
Assert.assertTrue(lastValue < value.getValue(), "Read values coming out of order!"); Assert.assertTrue(lastValue < value.getValue(), "Read values coming out of order!");
@ -82,65 +65,5 @@ public class InputProducerUnitTest extends BaseTest {
Assert.assertTrue(ip.allInputsHaveBeenRead(), "InputProvider said that all inputs haven't been read, but I read them all"); Assert.assertTrue(ip.allInputsHaveBeenRead(), "InputProvider said that all inputs haven't been read, but I read them all");
Assert.assertEquals(ip.getNumInputValues(), nElements, "Wrong number of total elements getNumInputValues"); Assert.assertEquals(ip.getNumInputValues(), nElements, "Wrong number of total elements getNumInputValues");
es.shutdownNow();
}
@Test(enabled = true, dataProvider = "InputProducerTest", timeOut = NanoSchedulerUnitTest.NANO_SCHEDULE_MAX_RUNTIME)
public void testInputProducerLocking(final int nElements, final int queueSize) throws InterruptedException {
final List<Integer> elements = new ArrayList<Integer>(nElements);
for ( int i = 0; i < nElements; i++ ) elements.add(i);
final LinkedBlockingDeque<InputProducer<Integer>.InputValue> readQueue =
new LinkedBlockingDeque<InputProducer<Integer>.InputValue>();
final InputProducer<Integer> ip = new InputProducer<Integer>(elements.iterator(), new MultiThreadedErrorTracker(), readQueue);
final ExecutorService es = Executors.newSingleThreadExecutor();
es.submit(ip);
ip.waitForDone();
Assert.assertEquals(ip.getNumInputValues(), nElements, "InputProvider told me that the queue was done, but I haven't started reading yet");
Assert.assertEquals(readQueue.size(), nElements + 1, "readQueue should have had all elements read into it");
}
final static class BlockingIterator<T> implements Iterator<T> {
final Semaphore blockNext = new Semaphore(0);
final Semaphore blockOnNext = new Semaphore(0);
final Iterator<T> underlyingIterator;
BlockingIterator(Iterator<T> underlyingIterator) {
this.underlyingIterator = underlyingIterator;
}
public void allowNext() {
blockNext.release(1);
}
public void blockTillNext() throws InterruptedException {
blockOnNext.acquire(1);
}
@Override
public boolean hasNext() {
return underlyingIterator.hasNext();
}
@Override
public T next() {
try {
blockNext.acquire(1);
T value = underlyingIterator.next();
blockOnNext.release(1);
return value;
} catch (InterruptedException ex) {
throw new RuntimeException(ex);
}
}
@Override
public void remove() {
throw new UnsupportedOperationException("x");
}
} }
} }