Mark I NanoScheduling TraverseLoci

-- Refactored TraverseLoci into old linear version and nano scheduling version
-- Temporary GATK argument specifying how many nano threads to use
-- Can efficiently scale to 3 threads before blocking on input
This commit is contained in:
Mark DePristo 2012-09-04 13:47:40 -04:00
parent 757e6a0160
commit d503ed97ab
6 changed files with 293 additions and 42 deletions

View File

@ -27,7 +27,6 @@ package org.broadinstitute.sting.gatk;
import net.sf.picard.filter.SamRecordFilter;
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.TreeMap;
@ -119,11 +118,18 @@ public class ReadMetrics implements Cloneable {
return nRecords;
}
/**
 * Increments the number of 'iterations' (one call of the filter/map/reduce sequence)
 * completed, advancing the count by the given amount.
 *
 * NOTE(review): the count is backed by the nRecords field, so iterations and records
 * seen share the same counter — confirm this aliasing is intended.
 *
 * @param by the number of completed iterations to add to the running total
 */
public void incrementNumIterations(final long by) {
    nRecords += by;
}
/**
* Increments the number of 'iterations' (one call of filter/map/reduce sequence) completed.
*/
public void incrementNumIterations() {
nRecords++;
incrementNumIterations(1);
}
public long getNumReadsSeen() {

View File

@ -313,6 +313,10 @@ public class GATKArgumentCollection {
@Argument(fullName = "num_bam_file_handles", shortName = "bfh", doc="The total number of BAM file handles to keep open simultaneously", required=false)
public Integer numberOfBAMFileHandles = null;
@Argument(fullName="nanoThreads", shortName = "nanoThreads", doc="NanoThreading", required = false)
@Hidden
public int nanoThreads = 1;
@Input(fullName = "read_group_black_list", shortName="rgbl", doc="Filters out read groups matching <TAG>:<STRING> or a .txt file containing the filter strings one per line.", required = false)
public List<String> readGroupBlackList = null;

View File

@ -146,7 +146,8 @@ public abstract class MicroScheduler implements MicroSchedulerMBean {
if (walker instanceof ReadWalker) {
traversalEngine = numThreads > 1 ? new TraverseReadsNano(numThreads) : new TraverseReads();
} else if (walker instanceof LocusWalker) {
traversalEngine = new TraverseLoci();
// TODO -- refactor to use better interface
traversalEngine = engine.getArguments().nanoThreads > 1 ? new TraverseLociNano(engine.getArguments().nanoThreads) : new TraverseLociLinear();
} else if (walker instanceof DuplicateWalker) {
traversalEngine = new TraverseDuplicates();
} else if (walker instanceof ReadPairWalker) {

View File

@ -3,9 +3,7 @@ package org.broadinstitute.sting.gatk.traversals;
import org.apache.log4j.Logger;
import org.broadinstitute.sting.gatk.WalkerManager;
import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
import org.broadinstitute.sting.gatk.datasources.providers.*;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.walkers.DataSource;
import org.broadinstitute.sting.gatk.walkers.LocusWalker;
import org.broadinstitute.sting.gatk.walkers.Walker;
@ -15,28 +13,42 @@ import org.broadinstitute.sting.utils.pileup.ReadBackedPileupImpl;
/**
* A simple solution to iterating over all reference positions over a series of genomic locations.
*/
public class TraverseLoci<M,T> extends TraversalEngine<M,T,LocusWalker<M,T>,LocusShardDataProvider> {
public abstract class TraverseLociBase<M,T> extends TraversalEngine<M,T,LocusWalker<M,T>,LocusShardDataProvider> {
/**
* our log, which we want to capture anything from this class
*/
protected static final Logger logger = Logger.getLogger(TraversalEngine.class);
@Override
protected String getTraversalType() {
protected final String getTraversalType() {
return "sites";
}
protected static class TraverseResults<T> {
final int numIterations;
final T reduceResult;
public TraverseResults(int numIterations, T reduceResult) {
this.numIterations = numIterations;
this.reduceResult = reduceResult;
}
}
protected abstract TraverseResults<T> traverse( final LocusWalker<M,T> walker,
final LocusView locusView,
final LocusReferenceView referenceView,
final ReferenceOrderedView referenceOrderedDataView,
final T sum);
@Override
public T traverse( LocusWalker<M,T> walker,
LocusShardDataProvider dataProvider,
T sum) {
logger.debug(String.format("TraverseLoci.traverse: Shard is %s", dataProvider));
logger.debug(String.format("TraverseLociBase.traverse: Shard is %s", dataProvider));
LocusView locusView = getLocusView( walker, dataProvider );
boolean done = false;
final LocusView locusView = getLocusView( walker, dataProvider );
if ( locusView.hasNext() ) { // trivial optimization to avoid unnecessary processing when there's nothing here at all
//ReferenceOrderedView referenceOrderedDataView = new ReferenceOrderedView( dataProvider );
ReferenceOrderedView referenceOrderedDataView = null;
if ( WalkerManager.getWalkerDataSource(walker) != DataSource.REFERENCE_ORDERED_DATA )
@ -44,43 +56,23 @@ public class TraverseLoci<M,T> extends TraversalEngine<M,T,LocusWalker<M,T>,Locu
else
referenceOrderedDataView = (RodLocusView)locusView;
LocusReferenceView referenceView = new LocusReferenceView( walker, dataProvider );
final LocusReferenceView referenceView = new LocusReferenceView( walker, dataProvider );
// We keep processing while the next reference location is within the interval
while( locusView.hasNext() && ! done ) {
AlignmentContext locus = locusView.next();
GenomeLoc location = locus.getLocation();
dataProvider.getShard().getReadMetrics().incrementNumIterations();
// create reference context. Note that if we have a pileup of "extended events", the context will
// hold the (longest) stretch of deleted reference bases (if deletions are present in the pileup).
ReferenceContext refContext = referenceView.getReferenceContext(location);
// Iterate forward to get all reference ordered data covering this location
final RefMetaDataTracker tracker = referenceOrderedDataView.getReferenceOrderedDataAtLocus(locus.getLocation(), refContext);
final boolean keepMeP = walker.filter(tracker, refContext, locus);
if (keepMeP) {
M x = walker.map(tracker, refContext, locus);
sum = walker.reduce(x, sum);
done = walker.isDone();
}
printProgress(dataProvider.getShard(),locus.getLocation());
}
final TraverseResults<T> result = traverse( walker, locusView, referenceView, referenceOrderedDataView, sum );
sum = result.reduceResult;
dataProvider.getShard().getReadMetrics().incrementNumIterations(result.numIterations);
}
// We have a final map call to execute here to clean up the skipped bases from the
// last position in the ROD to that in the interval
if ( WalkerManager.getWalkerDataSource(walker) == DataSource.REFERENCE_ORDERED_DATA && ! walker.isDone() ) {
// only do this if the walker isn't done!
RodLocusView rodLocusView = (RodLocusView)locusView;
long nSkipped = rodLocusView.getLastSkippedBases();
final RodLocusView rodLocusView = (RodLocusView)locusView;
final long nSkipped = rodLocusView.getLastSkippedBases();
if ( nSkipped > 0 ) {
GenomeLoc site = rodLocusView.getLocOneBeyondShard();
AlignmentContext ac = new AlignmentContext(site, new ReadBackedPileupImpl(site), nSkipped);
M x = walker.map(null, null, ac);
final GenomeLoc site = rodLocusView.getLocOneBeyondShard();
final AlignmentContext ac = new AlignmentContext(site, new ReadBackedPileupImpl(site), nSkipped);
final M x = walker.map(null, null, ac);
sum = walker.reduce(x, sum);
}
}
@ -90,14 +82,14 @@ public class TraverseLoci<M,T> extends TraversalEngine<M,T,LocusWalker<M,T>,Locu
/**
* Gets the best view of loci for this walker given the available data. The view will function as a 'trigger track'
* of sorts, providing a consistent interface so that TraverseLoci doesn't need to be reimplemented for any new datatype
* of sorts, providing a consistent interface so that TraverseLociBase doesn't need to be reimplemented for any new datatype
* that comes along.
* @param walker walker to interrogate.
* @param dataProvider Data with which to drive the locus view.
* @return A view of the locus data, where one iteration of the locus view maps to one iteration of the traversal.
*/
private LocusView getLocusView( Walker<M,T> walker, LocusShardDataProvider dataProvider ) {
DataSource dataSource = WalkerManager.getWalkerDataSource(walker);
final DataSource dataSource = WalkerManager.getWalkerDataSource(walker);
if( dataSource == DataSource.READS )
return new CoveredLocusView(dataProvider);
else if( dataSource == DataSource.REFERENCE ) //|| ! GenomeAnalysisEngine.instance.getArguments().enableRodWalkers )

View File

@ -0,0 +1,48 @@
package org.broadinstitute.sting.gatk.traversals;
import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
import org.broadinstitute.sting.gatk.datasources.providers.LocusReferenceView;
import org.broadinstitute.sting.gatk.datasources.providers.LocusView;
import org.broadinstitute.sting.gatk.datasources.providers.ReferenceOrderedView;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.walkers.LocusWalker;
import org.broadinstitute.sting.utils.GenomeLoc;
/**
 * Single-threaded implementation of the locus traversal: walks each reference
 * position provided by the locus view in turn, applying the walker's
 * filter/map/reduce at every site until the view is exhausted or the walker
 * declares itself done.
 */
public class TraverseLociLinear<M,T> extends TraverseLociBase<M,T> {
    @Override
    protected TraverseResults<T> traverse(LocusWalker<M, T> walker, LocusView locusView, LocusReferenceView referenceView, ReferenceOrderedView referenceOrderedDataView, T sum) {
        int nSitesVisited = 0;

        // Visit every locus the view provides, exiting early if the walker signals completion
        while ( locusView.hasNext() ) {
            nSitesVisited++;
            final AlignmentContext context = locusView.next();
            final GenomeLoc site = context.getLocation();

            // Reference bases at this site. For a pileup of "extended events" the context
            // holds the (longest) stretch of deleted reference bases present in the pileup.
            final ReferenceContext ref = referenceView.getReferenceContext(site);

            // Advance the ROD view to gather all reference-ordered data covering this site
            final RefMetaDataTracker metaData = referenceOrderedDataView.getReferenceOrderedDataAtLocus(site, ref);

            if ( walker.filter(metaData, ref, context) ) {
                sum = walker.reduce(walker.map(metaData, ref, context), sum);
                if ( walker.isDone() )
                    break;
            }

            // TODO -- refactor printProgress to separate updating read metrics from printing progress
            //printProgress(dataProvider.getShard(),locus.getLocation());
        }

        return new TraverseResults<T>(nSitesVisited, sum);
    }
}

View File

@ -0,0 +1,200 @@
package org.broadinstitute.sting.gatk.traversals;
import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
import org.broadinstitute.sting.gatk.datasources.providers.LocusReferenceView;
import org.broadinstitute.sting.gatk.datasources.providers.LocusView;
import org.broadinstitute.sting.gatk.datasources.providers.ReferenceOrderedView;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.walkers.LocusWalker;
import org.broadinstitute.sting.utils.GenomeLoc;
import org.broadinstitute.sting.utils.nanoScheduler.MapFunction;
import org.broadinstitute.sting.utils.nanoScheduler.NanoScheduler;
import org.broadinstitute.sting.utils.nanoScheduler.ReduceFunction;
import java.util.Iterator;
/**
 * Nano-scheduled implementation of the locus traversal.
 *
 * Per-locus inputs (pileup, reference context, ROD tracker) are pulled from a
 * producer iterator and handed to a NanoScheduler, which applies the walker's
 * filter/map on its worker threads and folds the results with the walker's
 * reduce function. Intended to be behaviorally equivalent to TraverseLociLinear
 * modulo thread scheduling — see the NanoScheduler contract for ordering and
 * threading guarantees.
 */
public class TraverseLociNano<M,T> extends TraverseLociBase<M,T> {
    /** When true, enables NanoScheduler debug output (applied in traverse()). */
    private static final boolean DEBUG = false;

    /** Number of map inputs the NanoScheduler buffers between producer and workers. */
    private static final int BUFFER_SIZE = 1000;

    /** The scheduler that runs map calls across the requested number of nano threads. */
    final NanoScheduler<MapData, MapResult, T> nanoScheduler;

    /**
     * Create a nano-scheduled locus traversal.
     *
     * @param nThreads the number of nano threads the scheduler may use
     */
    public TraverseLociNano(int nThreads) {
        nanoScheduler = new NanoScheduler<MapData, MapResult, T>(BUFFER_SIZE, nThreads);
    }

    /**
     * Run the map/reduce traversal over all loci in locusView via the NanoScheduler.
     *
     * NOTE(review): walker.filter/map/isDone are invoked from scheduler worker threads
     * (see TraverseLociMap) — confirm LocusWalker implementations are safe to call this way.
     *
     * @return the final reduce result plus the number of map inputs produced
     */
    @Override
    protected TraverseResults<T> traverse(final LocusWalker<M, T> walker,
                                          final LocusView locusView,
                                          final LocusReferenceView referenceView,
                                          final ReferenceOrderedView referenceOrderedDataView,
                                          final T sum) {
        nanoScheduler.setDebug(DEBUG);

        final TraverseLociMap myMap = new TraverseLociMap(walker);
        final TraverseLociReduce myReduce = new TraverseLociReduce(walker);

        final MapDataIterator inputIterator = new MapDataIterator(locusView, referenceView, referenceOrderedDataView);
        final T result = nanoScheduler.execute(inputIterator, myMap, sum, myReduce);

        // todo -- how do I print progress?
        // final GATKSAMRecord lastRead = aggregatedInputs.get(aggregatedInputs.size() - 1).read;
        // final GenomeLoc locus = engine.getGenomeLocParser().createGenomeLoc(lastRead);
        // printProgress(dataProvider.getShard(), locus, aggregatedInputs.size());

        return new TraverseResults<T>(inputIterator.numIterations, result);
    }

    /**
     * Create iterator that provides inputs for all map calls into MapData, to be provided
     * to NanoScheduler for Map/Reduce.
     *
     * Runs on the producer side; numIterations counts how many MapData objects were handed out.
     */
    private class MapDataIterator implements Iterator<MapData> {
        final LocusView locusView;
        final LocusReferenceView referenceView;
        final ReferenceOrderedView referenceOrderedDataView;
        // Number of MapData elements produced so far; read by traverse() after execute() returns
        int numIterations = 0;

        private MapDataIterator(LocusView locusView, LocusReferenceView referenceView, ReferenceOrderedView referenceOrderedDataView) {
            this.locusView = locusView;
            this.referenceView = referenceView;
            this.referenceOrderedDataView = referenceOrderedDataView;
        }

        @Override
        public boolean hasNext() {
            return locusView.hasNext();
        }

        /**
         * Bundle the next locus with its reference context and ROD tracker into a MapData.
         */
        @Override
        public MapData next() {
            final AlignmentContext locus = locusView.next();
            final GenomeLoc location = locus.getLocation();

            //logger.info("Pulling data from MapDataIterator at " + location);

            // create reference context. Note that if we have a pileup of "extended events", the context will
            // hold the (longest) stretch of deleted reference bases (if deletions are present in the pileup).
            final ReferenceContext refContext = referenceView.getReferenceContext(location);

            // Iterate forward to get all reference ordered data covering this location
            final RefMetaDataTracker tracker = referenceOrderedDataView.getReferenceOrderedDataAtLocus(location, refContext);

            numIterations++;
            return new MapData(locus, refContext, tracker);
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException("Cannot remove elements from MapDataIterator");
        }
    }

    /**
     * Shut down the nano scheduler's threads before reporting, then delegate to the
     * superclass's end-of-traversal reporting.
     */
    @Override
    public void printOnTraversalDone() {
        nanoScheduler.shutdown();
        super.printOnTraversalDone();
    }

    /**
     * The input data needed for each map call. The read, the reference, and the RODs
     */
    private class MapData {
        final AlignmentContext alignmentContext;
        final ReferenceContext refContext;
        final RefMetaDataTracker tracker;

        private MapData(final AlignmentContext alignmentContext, ReferenceContext refContext, RefMetaDataTracker tracker) {
            this.alignmentContext = alignmentContext;
            this.refContext = refContext;
            this.tracker = tracker;
        }

        @Override
        public String toString() {
            return "MapData " + alignmentContext.getLocation();
        }
    }

    /**
     * Contains the results of a map call, indicating whether the call was good, filtered, or done
     */
    private class MapResult {
        // The walker's map value; null when reduceMe is false
        final M value;
        // When false, reduce skips this result (filtered site or walker already done)
        final boolean reduceMe;

        /**
         * Create a MapResult with value that should be reduced
         *
         * @param value the value to reduce
         */
        private MapResult(final M value) {
            this.value = value;
            this.reduceMe = true;
        }

        /**
         * Create a MapResult that shouldn't be reduced
         */
        private MapResult() {
            this.value = null;
            this.reduceMe = false;
        }
    }

    /**
     * A static object that tells reduce that the result of map should be skipped (filtered or done)
     */
    private final MapResult SKIP_REDUCE = new MapResult();

    /**
     * MapFunction for TraverseLoci meeting NanoScheduler interface requirements
     *
     * Applies walker.map to MapData, returning a MapResult object containing the result.
     * NOTE(review): apply() is presumably invoked on NanoScheduler worker threads, so the
     * walker's filter/map/isDone must tolerate concurrent calls — confirm.
     */
    private class TraverseLociMap implements MapFunction<MapData, MapResult> {
        final LocusWalker<M,T> walker;

        private TraverseLociMap(LocusWalker<M, T> walker) {
            this.walker = walker;
        }

        @Override
        public MapResult apply(final MapData data) {
            if ( ! walker.isDone() ) {
                final boolean keepMeP = walker.filter(data.tracker, data.refContext, data.alignmentContext);
                if (keepMeP) {
                    final M x = walker.map(data.tracker, data.refContext, data.alignmentContext);
                    return new MapResult(x);
                }
            }
            return SKIP_REDUCE;
        }
    }

    /**
     * ReduceFunction for TraverseLoci meeting NanoScheduler interface requirements
     *
     * Takes a MapResult object and applies the walkers reduce function to each map result, when applicable
     */
    private class TraverseLociReduce implements ReduceFunction<MapResult, T> {
        final LocusWalker<M,T> walker;

        private TraverseLociReduce(LocusWalker<M, T> walker) {
            this.walker = walker;
        }

        @Override
        public T apply(MapResult one, T sum) {
            if ( one.reduceMe )
                // only run reduce on values that aren't DONE or FAILED
                return walker.reduce(one.value, sum);
            else
                return sum;
        }
    }
}