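Mark I NanoScheduling TraverseLoci
-- Refactored TraverseLoci into the old linear version (TraverseLociLinear) and a nano-scheduling version (TraverseLociNano), both extending the new abstract TraverseLociBase
-- Added a temporary, hidden GATK argument (nanoThreads) to say how many nano threads to use
-- Can efficiently scale to 3 threads before blocking on input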
This commit is contained in:
parent
757e6a0160
commit
d503ed97ab
|
|
@ -27,7 +27,6 @@ package org.broadinstitute.sting.gatk;
|
|||
import net.sf.picard.filter.SamRecordFilter;
|
||||
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
|
||||
|
||||
import java.util.Collections;
|
||||
import java.util.HashMap;
|
||||
import java.util.Map;
|
||||
import java.util.TreeMap;
|
||||
|
|
@ -119,11 +118,18 @@ public class ReadMetrics implements Cloneable {
|
|||
return nRecords;
|
||||
}
|
||||
|
||||
/**
|
||||
* Increments the number of 'iterations' (one call of the filter/map/reduce sequence) completed by the given amount.
|
||||
*/
|
||||
public void incrementNumIterations(final long by) {
|
||||
nRecords += by;
|
||||
}
|
||||
|
||||
/**
|
||||
* Increments the number of 'iterations' (one call of the filter/map/reduce sequence) completed by one.
|
||||
*/
|
||||
public void incrementNumIterations() {
|
||||
nRecords++;
|
||||
incrementNumIterations(1);
|
||||
}
|
||||
|
||||
public long getNumReadsSeen() {
|
||||
|
|
|
|||
|
|
@ -313,6 +313,10 @@ public class GATKArgumentCollection {
|
|||
@Argument(fullName = "num_bam_file_handles", shortName = "bfh", doc="The total number of BAM file handles to keep open simultaneously", required=false)
|
||||
public Integer numberOfBAMFileHandles = null;
|
||||
|
||||
@Argument(fullName="nanoThreads", shortName = "nanoThreads", doc="NanoThreading", required = false)
|
||||
@Hidden
|
||||
public int nanoThreads = 1;
|
||||
|
||||
@Input(fullName = "read_group_black_list", shortName="rgbl", doc="Filters out read groups matching <TAG>:<STRING> or a .txt file containing the filter strings one per line.", required = false)
|
||||
public List<String> readGroupBlackList = null;
|
||||
|
||||
|
|
|
|||
|
|
@ -146,7 +146,8 @@ public abstract class MicroScheduler implements MicroSchedulerMBean {
|
|||
if (walker instanceof ReadWalker) {
|
||||
traversalEngine = numThreads > 1 ? new TraverseReadsNano(numThreads) : new TraverseReads();
|
||||
} else if (walker instanceof LocusWalker) {
|
||||
traversalEngine = new TraverseLoci();
|
||||
// TODO -- refactor to use better interface
|
||||
traversalEngine = engine.getArguments().nanoThreads > 1 ? new TraverseLociNano(engine.getArguments().nanoThreads) : new TraverseLociLinear();
|
||||
} else if (walker instanceof DuplicateWalker) {
|
||||
traversalEngine = new TraverseDuplicates();
|
||||
} else if (walker instanceof ReadPairWalker) {
|
||||
|
|
|
|||
|
|
@ -3,9 +3,7 @@ package org.broadinstitute.sting.gatk.traversals;
|
|||
import org.apache.log4j.Logger;
|
||||
import org.broadinstitute.sting.gatk.WalkerManager;
|
||||
import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
|
||||
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
|
||||
import org.broadinstitute.sting.gatk.datasources.providers.*;
|
||||
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
|
||||
import org.broadinstitute.sting.gatk.walkers.DataSource;
|
||||
import org.broadinstitute.sting.gatk.walkers.LocusWalker;
|
||||
import org.broadinstitute.sting.gatk.walkers.Walker;
|
||||
|
|
@ -15,28 +13,42 @@ import org.broadinstitute.sting.utils.pileup.ReadBackedPileupImpl;
|
|||
/**
|
||||
* A simple solution to iterating over all reference positions over a series of genomic locations.
|
||||
*/
|
||||
public class TraverseLoci<M,T> extends TraversalEngine<M,T,LocusWalker<M,T>,LocusShardDataProvider> {
|
||||
public abstract class TraverseLociBase<M,T> extends TraversalEngine<M,T,LocusWalker<M,T>,LocusShardDataProvider> {
|
||||
/**
|
||||
* Our logger, which captures anything logged from this class (note: it is registered under TraversalEngine.class).
|
||||
*/
|
||||
protected static final Logger logger = Logger.getLogger(TraversalEngine.class);
|
||||
|
||||
@Override
|
||||
protected String getTraversalType() {
|
||||
protected final String getTraversalType() {
|
||||
return "sites";
|
||||
}
|
||||
|
||||
protected static class TraverseResults<T> {
|
||||
final int numIterations;
|
||||
final T reduceResult;
|
||||
|
||||
public TraverseResults(int numIterations, T reduceResult) {
|
||||
this.numIterations = numIterations;
|
||||
this.reduceResult = reduceResult;
|
||||
}
|
||||
}
|
||||
|
||||
protected abstract TraverseResults<T> traverse( final LocusWalker<M,T> walker,
|
||||
final LocusView locusView,
|
||||
final LocusReferenceView referenceView,
|
||||
final ReferenceOrderedView referenceOrderedDataView,
|
||||
final T sum);
|
||||
|
||||
@Override
|
||||
public T traverse( LocusWalker<M,T> walker,
|
||||
LocusShardDataProvider dataProvider,
|
||||
T sum) {
|
||||
logger.debug(String.format("TraverseLoci.traverse: Shard is %s", dataProvider));
|
||||
logger.debug(String.format("TraverseLociBase.traverse: Shard is %s", dataProvider));
|
||||
|
||||
LocusView locusView = getLocusView( walker, dataProvider );
|
||||
boolean done = false;
|
||||
final LocusView locusView = getLocusView( walker, dataProvider );
|
||||
|
||||
if ( locusView.hasNext() ) { // trivial optimization to avoid unnecessary processing when there's nothing here at all
|
||||
|
||||
//ReferenceOrderedView referenceOrderedDataView = new ReferenceOrderedView( dataProvider );
|
||||
ReferenceOrderedView referenceOrderedDataView = null;
|
||||
if ( WalkerManager.getWalkerDataSource(walker) != DataSource.REFERENCE_ORDERED_DATA )
|
||||
|
|
@ -44,43 +56,23 @@ public class TraverseLoci<M,T> extends TraversalEngine<M,T,LocusWalker<M,T>,Locu
|
|||
else
|
||||
referenceOrderedDataView = (RodLocusView)locusView;
|
||||
|
||||
LocusReferenceView referenceView = new LocusReferenceView( walker, dataProvider );
|
||||
final LocusReferenceView referenceView = new LocusReferenceView( walker, dataProvider );
|
||||
|
||||
// We keep processing while the next reference location is within the interval
|
||||
while( locusView.hasNext() && ! done ) {
|
||||
AlignmentContext locus = locusView.next();
|
||||
GenomeLoc location = locus.getLocation();
|
||||
|
||||
dataProvider.getShard().getReadMetrics().incrementNumIterations();
|
||||
|
||||
// create reference context. Note that if we have a pileup of "extended events", the context will
|
||||
// hold the (longest) stretch of deleted reference bases (if deletions are present in the pileup).
|
||||
ReferenceContext refContext = referenceView.getReferenceContext(location);
|
||||
|
||||
// Iterate forward to get all reference ordered data covering this location
|
||||
final RefMetaDataTracker tracker = referenceOrderedDataView.getReferenceOrderedDataAtLocus(locus.getLocation(), refContext);
|
||||
|
||||
final boolean keepMeP = walker.filter(tracker, refContext, locus);
|
||||
if (keepMeP) {
|
||||
M x = walker.map(tracker, refContext, locus);
|
||||
sum = walker.reduce(x, sum);
|
||||
done = walker.isDone();
|
||||
}
|
||||
|
||||
printProgress(dataProvider.getShard(),locus.getLocation());
|
||||
}
|
||||
final TraverseResults<T> result = traverse( walker, locusView, referenceView, referenceOrderedDataView, sum );
|
||||
sum = result.reduceResult;
|
||||
dataProvider.getShard().getReadMetrics().incrementNumIterations(result.numIterations);
|
||||
}
|
||||
|
||||
// We have a final map call to execute here to clean up the skipped bases from the
|
||||
// last position in the ROD to that in the interval
|
||||
if ( WalkerManager.getWalkerDataSource(walker) == DataSource.REFERENCE_ORDERED_DATA && ! walker.isDone() ) {
|
||||
// only do this if the walker isn't done!
|
||||
RodLocusView rodLocusView = (RodLocusView)locusView;
|
||||
long nSkipped = rodLocusView.getLastSkippedBases();
|
||||
final RodLocusView rodLocusView = (RodLocusView)locusView;
|
||||
final long nSkipped = rodLocusView.getLastSkippedBases();
|
||||
if ( nSkipped > 0 ) {
|
||||
GenomeLoc site = rodLocusView.getLocOneBeyondShard();
|
||||
AlignmentContext ac = new AlignmentContext(site, new ReadBackedPileupImpl(site), nSkipped);
|
||||
M x = walker.map(null, null, ac);
|
||||
final GenomeLoc site = rodLocusView.getLocOneBeyondShard();
|
||||
final AlignmentContext ac = new AlignmentContext(site, new ReadBackedPileupImpl(site), nSkipped);
|
||||
final M x = walker.map(null, null, ac);
|
||||
sum = walker.reduce(x, sum);
|
||||
}
|
||||
}
|
||||
|
|
@ -90,14 +82,14 @@ public class TraverseLoci<M,T> extends TraversalEngine<M,T,LocusWalker<M,T>,Locu
|
|||
|
||||
/**
|
||||
* Gets the best view of loci for this walker given the available data. The view will function as a 'trigger track'
|
||||
* of sorts, providing a consistent interface so that TraverseLoci doesn't need to be reimplemented for any new datatype
|
||||
* of sorts, providing a consistent interface so that TraverseLociBase doesn't need to be reimplemented for any new datatype
|
||||
* that comes along.
|
||||
* @param walker walker to interrogate.
|
||||
* @param dataProvider Data with which to drive the locus view.
|
||||
* @return A view of the locus data, where one iteration of the locus view maps to one iteration of the traversal.
|
||||
*/
|
||||
private LocusView getLocusView( Walker<M,T> walker, LocusShardDataProvider dataProvider ) {
|
||||
DataSource dataSource = WalkerManager.getWalkerDataSource(walker);
|
||||
final DataSource dataSource = WalkerManager.getWalkerDataSource(walker);
|
||||
if( dataSource == DataSource.READS )
|
||||
return new CoveredLocusView(dataProvider);
|
||||
else if( dataSource == DataSource.REFERENCE ) //|| ! GenomeAnalysisEngine.instance.getArguments().enableRodWalkers )
|
||||
|
|
@ -0,0 +1,48 @@
|
|||
package org.broadinstitute.sting.gatk.traversals;
|
||||
|
||||
import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
|
||||
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
|
||||
import org.broadinstitute.sting.gatk.datasources.providers.LocusReferenceView;
|
||||
import org.broadinstitute.sting.gatk.datasources.providers.LocusView;
|
||||
import org.broadinstitute.sting.gatk.datasources.providers.ReferenceOrderedView;
|
||||
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
|
||||
import org.broadinstitute.sting.gatk.walkers.LocusWalker;
|
||||
import org.broadinstitute.sting.utils.GenomeLoc;
|
||||
|
||||
/**
|
||||
* A simple, linear solution to iterating over all reference positions over a series of genomic locations: filter, map and reduce are applied to each locus in turn within a single loop.
|
||||
*/
|
||||
public class TraverseLociLinear<M,T> extends TraverseLociBase<M,T> {
|
||||
|
||||
@Override
|
||||
protected TraverseResults<T> traverse(LocusWalker<M, T> walker, LocusView locusView, LocusReferenceView referenceView, ReferenceOrderedView referenceOrderedDataView, T sum) {
|
||||
// We keep processing while the next reference location is within the interval
|
||||
boolean done = false;
|
||||
int numIterations = 0;
|
||||
|
||||
while( locusView.hasNext() && ! done ) {
|
||||
numIterations++;
|
||||
final AlignmentContext locus = locusView.next();
|
||||
final GenomeLoc location = locus.getLocation();
|
||||
|
||||
// create reference context. Note that if we have a pileup of "extended events", the context will
|
||||
// hold the (longest) stretch of deleted reference bases (if deletions are present in the pileup).
|
||||
final ReferenceContext refContext = referenceView.getReferenceContext(location);
|
||||
|
||||
// Iterate forward to get all reference ordered data covering this location
|
||||
final RefMetaDataTracker tracker = referenceOrderedDataView.getReferenceOrderedDataAtLocus(locus.getLocation(), refContext);
|
||||
|
||||
final boolean keepMeP = walker.filter(tracker, refContext, locus);
|
||||
if (keepMeP) {
|
||||
final M x = walker.map(tracker, refContext, locus);
|
||||
sum = walker.reduce(x, sum);
|
||||
done = walker.isDone();
|
||||
}
|
||||
|
||||
// TODO -- refactor printProgress to separate updating read metrics from printing progress
|
||||
//printProgress(dataProvider.getShard(),locus.getLocation());
|
||||
}
|
||||
|
||||
return new TraverseResults<T>(numIterations, sum);
|
||||
}
|
||||
}
|
||||
|
|
@ -0,0 +1,200 @@
|
|||
package org.broadinstitute.sting.gatk.traversals;
|
||||
|
||||
import org.broadinstitute.sting.gatk.contexts.AlignmentContext;
|
||||
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
|
||||
import org.broadinstitute.sting.gatk.datasources.providers.LocusReferenceView;
|
||||
import org.broadinstitute.sting.gatk.datasources.providers.LocusView;
|
||||
import org.broadinstitute.sting.gatk.datasources.providers.ReferenceOrderedView;
|
||||
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
|
||||
import org.broadinstitute.sting.gatk.walkers.LocusWalker;
|
||||
import org.broadinstitute.sting.utils.GenomeLoc;
|
||||
import org.broadinstitute.sting.utils.nanoScheduler.MapFunction;
|
||||
import org.broadinstitute.sting.utils.nanoScheduler.NanoScheduler;
|
||||
import org.broadinstitute.sting.utils.nanoScheduler.ReduceFunction;
|
||||
|
||||
import java.util.Iterator;
|
||||
|
||||
/**
|
||||
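* A nano-scheduled solution to iterating over all reference positions over a series of genomic locations: the inputs for each locus are pulled through an iterator and the map and reduce calls are handed to a NanoScheduler.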
|
||||
*/
|
||||
public class TraverseLociNano<M,T> extends TraverseLociBase<M,T> {
|
||||
/** debug flag handed to the NanoScheduler via setDebug() at the start of each traversal */
|
||||
private static final boolean DEBUG = false;
|
||||
private static final int BUFFER_SIZE = 1000;
|
||||
|
||||
final NanoScheduler<MapData, MapResult, T> nanoScheduler;
|
||||
|
||||
public TraverseLociNano(int nThreads) {
|
||||
nanoScheduler = new NanoScheduler<MapData, MapResult, T>(BUFFER_SIZE, nThreads);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected TraverseResults<T> traverse(final LocusWalker<M, T> walker,
|
||||
final LocusView locusView,
|
||||
final LocusReferenceView referenceView,
|
||||
final ReferenceOrderedView referenceOrderedDataView,
|
||||
final T sum) {
|
||||
nanoScheduler.setDebug(DEBUG);
|
||||
final TraverseLociMap myMap = new TraverseLociMap(walker);
|
||||
final TraverseLociReduce myReduce = new TraverseLociReduce(walker);
|
||||
|
||||
final MapDataIterator inputIterator = new MapDataIterator(locusView, referenceView, referenceOrderedDataView);
|
||||
final T result = nanoScheduler.execute(inputIterator, myMap, sum, myReduce);
|
||||
|
||||
// todo -- how do I print progress?
|
||||
// final GATKSAMRecord lastRead = aggregatedInputs.get(aggregatedInputs.size() - 1).read;
|
||||
// final GenomeLoc locus = engine.getGenomeLocParser().createGenomeLoc(lastRead);
|
||||
// printProgress(dataProvider.getShard(), locus, aggregatedInputs.size());
|
||||
|
||||
return new TraverseResults<T>(inputIterator.numIterations, result);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create iterator that provides inputs for all map calls into MapData, to be provided
|
||||
* to NanoScheduler for Map/Reduce
|
||||
*/
|
||||
private class MapDataIterator implements Iterator<MapData> {
|
||||
final LocusView locusView;
|
||||
final LocusReferenceView referenceView;
|
||||
final ReferenceOrderedView referenceOrderedDataView;
|
||||
int numIterations = 0;
|
||||
|
||||
private MapDataIterator(LocusView locusView, LocusReferenceView referenceView, ReferenceOrderedView referenceOrderedDataView) {
|
||||
this.locusView = locusView;
|
||||
this.referenceView = referenceView;
|
||||
this.referenceOrderedDataView = referenceOrderedDataView;
|
||||
}
|
||||
|
||||
@Override
|
||||
public boolean hasNext() {
|
||||
return locusView.hasNext();
|
||||
}
|
||||
|
||||
@Override
|
||||
public MapData next() {
|
||||
final AlignmentContext locus = locusView.next();
|
||||
final GenomeLoc location = locus.getLocation();
|
||||
|
||||
//logger.info("Pulling data from MapDataIterator at " + location);
|
||||
|
||||
// create reference context. Note that if we have a pileup of "extended events", the context will
|
||||
// hold the (longest) stretch of deleted reference bases (if deletions are present in the pileup).
|
||||
final ReferenceContext refContext = referenceView.getReferenceContext(location);
|
||||
|
||||
// Iterate forward to get all reference ordered data covering this location
|
||||
final RefMetaDataTracker tracker = referenceOrderedDataView.getReferenceOrderedDataAtLocus(location, refContext);
|
||||
|
||||
numIterations++;
|
||||
return new MapData(locus, refContext, tracker);
|
||||
}
|
||||
|
||||
@Override
|
||||
public void remove() {
|
||||
throw new UnsupportedOperationException("Cannot remove elements from MapDataIterator");
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
public void printOnTraversalDone() {
|
||||
nanoScheduler.shutdown();
|
||||
super.printOnTraversalDone();
|
||||
}
|
||||
|
||||
/**
|
||||
* The input data needed for each map call: the alignment context at the locus, the reference context, and the RODs (as a RefMetaDataTracker)
|
||||
*/
|
||||
private class MapData {
|
||||
final AlignmentContext alignmentContext;
|
||||
final ReferenceContext refContext;
|
||||
final RefMetaDataTracker tracker;
|
||||
|
||||
private MapData(final AlignmentContext alignmentContext, ReferenceContext refContext, RefMetaDataTracker tracker) {
|
||||
this.alignmentContext = alignmentContext;
|
||||
this.refContext = refContext;
|
||||
this.tracker = tracker;
|
||||
}
|
||||
|
||||
@Override
|
||||
public String toString() {
|
||||
return "MapData " + alignmentContext.getLocation();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Contains the results of a map call, indicating whether the call was good, filtered, or done
|
||||
*/
|
||||
private class MapResult {
|
||||
final M value;
|
||||
final boolean reduceMe;
|
||||
|
||||
/**
|
||||
* Create a MapResult with value that should be reduced
|
||||
*
|
||||
* @param value the value to reduce
|
||||
*/
|
||||
private MapResult(final M value) {
|
||||
this.value = value;
|
||||
this.reduceMe = true;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create a MapResult that shouldn't be reduced
|
||||
*/
|
||||
private MapResult() {
|
||||
this.value = null;
|
||||
this.reduceMe = false;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* A shared sentinel object (one per TraverseLociNano instance, not actually static) that tells reduce that the result of map should be skipped (filtered or done)
|
||||
*/
|
||||
private final MapResult SKIP_REDUCE = new MapResult();
|
||||
|
||||
/**
|
||||
* MapFunction for TraverseLoci meeting NanoScheduler interface requirements
|
||||
*
|
||||
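* Applies walker.filter and, when the locus passes, walker.map to MapData, returning a MapResult object containing the result (or SKIP_REDUCE when the locus is filtered out or the walker is already done)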
|
||||
*/
|
||||
private class TraverseLociMap implements MapFunction<MapData, MapResult> {
|
||||
final LocusWalker<M,T> walker;
|
||||
|
||||
private TraverseLociMap(LocusWalker<M, T> walker) {
|
||||
this.walker = walker;
|
||||
}
|
||||
|
||||
@Override
|
||||
public MapResult apply(final MapData data) {
|
||||
if ( ! walker.isDone() ) {
|
||||
final boolean keepMeP = walker.filter(data.tracker, data.refContext, data.alignmentContext);
|
||||
if (keepMeP) {
|
||||
final M x = walker.map(data.tracker, data.refContext, data.alignmentContext);
|
||||
return new MapResult(x);
|
||||
}
|
||||
}
|
||||
return SKIP_REDUCE;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* ReduceFunction for TraverseLoci meeting NanoScheduler interface requirements
|
||||
*
|
||||
* Takes a MapResult object and applies the walkers reduce function to each map result, when applicable
|
||||
*/
|
||||
private class TraverseLociReduce implements ReduceFunction<MapResult, T> {
|
||||
final LocusWalker<M,T> walker;
|
||||
|
||||
private TraverseLociReduce(LocusWalker<M, T> walker) {
|
||||
this.walker = walker;
|
||||
}
|
||||
|
||||
@Override
|
||||
public T apply(MapResult one, T sum) {
|
||||
if ( one.reduceMe )
|
||||
// only run reduce on map results flagged reduceMe (i.e. not the SKIP_REDUCE sentinel returned for filtered loci or a finished walker)
|
||||
return walker.reduce(one.value, sum);
|
||||
else
|
||||
return sum;
|
||||
}
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue