Fix GSA-529: Fix RODs for parallel read walkers

-- TraverseReadsNano modified to read in all input data before invoking maps, so each input to the nano scheduler is now a MapData object holding the SAM record, the reference context, and the RefMetaDataTracker.
-- Updated ValidateRODForReads to be tree reducible, using a synchronized map and explicitly sorting the output map from locations -> counts in onTraversalDone.
-- Expanded integration tests to test nt 1, 2, and 4.
This commit is contained in:
Mark DePristo 2012-08-30 15:10:58 -04:00
parent 7d95176539
commit 7a462399ce
2 changed files with 58 additions and 38 deletions

View File

@ -27,16 +27,21 @@ package org.broadinstitute.sting.gatk.traversals;
import net.sf.samtools.SAMRecord;
import org.apache.log4j.Logger;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
import org.broadinstitute.sting.gatk.datasources.providers.*;
import org.broadinstitute.sting.gatk.datasources.providers.ReadBasedReferenceOrderedView;
import org.broadinstitute.sting.gatk.datasources.providers.ReadReferenceView;
import org.broadinstitute.sting.gatk.datasources.providers.ReadShardDataProvider;
import org.broadinstitute.sting.gatk.datasources.providers.ReadView;
import org.broadinstitute.sting.gatk.datasources.reads.ReadShard;
import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker;
import org.broadinstitute.sting.gatk.walkers.ReadWalker;
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
import org.broadinstitute.sting.utils.nanoScheduler.MapFunction;
import org.broadinstitute.sting.utils.nanoScheduler.NanoScheduler;
import org.broadinstitute.sting.utils.nanoScheduler.ReduceFunction;
import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
import java.util.ArrayList;
import java.util.List;
/**
* @author aaron
* @version 1.0
@ -50,12 +55,13 @@ public class TraverseReadsNano<M,T> extends TraversalEngine<M,T,ReadWalker<M,T>,
/** our log, which we want to capture anything from this class */
protected static final Logger logger = Logger.getLogger(TraverseReadsNano.class);
private static final boolean DEBUG = false;
final NanoScheduler<SAMRecord, M, T> nanoScheduler;
private static final int MIN_GROUP_SIZE = 100;
final NanoScheduler<MapData, M, T> nanoScheduler;
/**
 * Creates the traversal engine and the NanoScheduler it dispatches map/reduce work to.
 *
 * NOTE(review): this span appears to interleave the pre-change and post-change lines of the
 * commit diff (mapGroupSize and nanoScheduler are each assigned twice) -- confirm against the
 * actual source file before relying on it.
 *
 * @param nThreads forwarded as the third argument of the NanoScheduler constructor
 */
public TraverseReadsNano(int nThreads) {
// buffer is sized from the read shard's buffer size, plus one
final int bufferSize = ReadShard.getReadBufferSize() + 1; // actually has 1 more than max
// group size as one tenth of the buffer, plus one
final int mapGroupSize = bufferSize / 10 + 1;
nanoScheduler = new NanoScheduler<SAMRecord, M, T>(bufferSize, mapGroupSize, nThreads);
// group size as ceil(bufferSize / 50 + 1), but never smaller than MIN_GROUP_SIZE
final int mapGroupSize = (int)Math.max(Math.ceil(bufferSize / 50.0 + 1), MIN_GROUP_SIZE);
// scheduler whose input element type is MapData rather than SAMRecord
nanoScheduler = new NanoScheduler<MapData, M, T>(bufferSize, mapGroupSize, nThreads);
}
@Override
@ -79,24 +85,42 @@ public class TraverseReadsNano<M,T> extends TraversalEngine<M,T,ReadWalker<M,T>,
if( !dataProvider.hasReads() )
throw new IllegalArgumentException("Unable to traverse reads; no read data is available.");
if ( dataProvider.hasReferenceOrderedData() )
throw new ReviewedStingException("Parallel read walkers currently don't support access to reference ordered data");
final ReadView reads = new ReadView(dataProvider);
final ReadReferenceView reference = new ReadReferenceView(dataProvider);
final ReadBasedReferenceOrderedView rodView = new ReadBasedReferenceOrderedView(dataProvider);
nanoScheduler.setDebug(DEBUG);
final TraverseReadsMap myMap = new TraverseReadsMap(reads, reference, rodView, walker);
final TraverseReadsMap myMap = new TraverseReadsMap(walker);
final TraverseReadsReduce myReduce = new TraverseReadsReduce(walker);
T result = nanoScheduler.execute(reads.iterator().iterator(), myMap, sum, myReduce);
T result = nanoScheduler.execute(aggregateMapData(dataProvider).iterator(), myMap, sum, myReduce);
// TODO -- how do we print progress?
//printProgress(dataProvider.getShard(), ???);
return result;
}
/**
 * Iterates over every read of the shard and packages each one, together with its
 * reference context and reference-ordered-data tracker, into a MapData element.
 *
 * For each read the reference context is looked up only when the read's unmapped flag
 * is not set, and the tracker only when the read's reference index is >= 0; otherwise
 * the corresponding slot is null. The shard's read metrics iteration counter is
 * incremented once per read.
 *
 * @param dataProvider the provider that the read, reference and ROD views are built over
 * @return one MapData per read, in the iteration order of the ReadView
 */
private List<MapData> aggregateMapData(final ReadShardDataProvider dataProvider) {
    final ReadView readView = new ReadView(dataProvider);
    final ReadReferenceView referenceView = new ReadReferenceView(dataProvider);
    final ReadBasedReferenceOrderedView orderedView = new ReadBasedReferenceOrderedView(dataProvider);

    final List<MapData> aggregated = new ArrayList<MapData>(); // TODO -- need size of reads

    for ( final SAMRecord record : readView ) {
        // reference bases are only fetched for reads not flagged as unmapped
        ReferenceContext context = null;
        if ( ! record.getReadUnmappedFlag() ) {
            context = referenceView.getReferenceContext(record);
        }

        // a metadata tracker is only built for reads that carry a reference index
        RefMetaDataTracker rodTracker = null;
        if ( record.getReferenceIndex() >= 0 ) {
            rodTracker = orderedView.getReferenceOrderedDataForRead(record);
        }

        // count this read in the shard's metrics
        dataProvider.getShard().getReadMetrics().incrementNumIterations();

        final GATKSAMRecord gatkRecord = (GATKSAMRecord) record;
        aggregated.add(new MapData(gatkRecord, context, rodTracker));
    }

    return aggregated;
}
@Override
public void printOnTraversalDone() {
nanoScheduler.shutdown();
@ -116,36 +140,31 @@ public class TraverseReadsNano<M,T> extends TraversalEngine<M,T,ReadWalker<M,T>,
}
}
private class TraverseReadsMap implements MapFunction<SAMRecord, M> {
final ReadView reads;
final ReadReferenceView reference;
final ReadBasedReferenceOrderedView rodView;
private class MapData {
final GATKSAMRecord read;
final ReferenceContext refContext;
final RefMetaDataTracker tracker;
private MapData(GATKSAMRecord read, ReferenceContext refContext, RefMetaDataTracker tracker) {
this.read = read;
this.refContext = refContext;
this.tracker = tracker;
}
}
private class TraverseReadsMap implements MapFunction<MapData, M> {
final ReadWalker<M,T> walker;
private TraverseReadsMap(ReadView reads, ReadReferenceView reference, ReadBasedReferenceOrderedView rodView, ReadWalker<M, T> walker) {
this.reads = reads;
this.reference = reference;
this.rodView = rodView;
private TraverseReadsMap(ReadWalker<M, T> walker) {
this.walker = walker;
}
@Override
public M apply(final SAMRecord read) {
public M apply(final MapData data) {
if ( ! walker.isDone() ) {
// ReferenceContext -- the reference bases covered by the read
final ReferenceContext refContext = ! read.getReadUnmappedFlag() && reference != null
? reference.getReferenceContext(read)
: null;
// update the number of reads we've seen
//dataProvider.getShard().getReadMetrics().incrementNumIterations();
// if the read is mapped, create a metadata tracker
final RefMetaDataTracker tracker = read.getReferenceIndex() >= 0 ? rodView.getReferenceOrderedDataForRead(read) : null;
final boolean keepMeP = walker.filter(refContext, (GATKSAMRecord) read);
final boolean keepMeP = walker.filter(data.refContext, data.read);
if (keepMeP) {
return walker.map(refContext, (GATKSAMRecord) read, tracker);
return walker.map(data.refContext, data.read, data.tracker);
}
}

View File

@ -43,7 +43,8 @@ import java.util.concurrent.*;
* Time: 9:47 AM
*/
public class NanoScheduler<InputType, MapType, ReduceType> {
private static Logger logger = Logger.getLogger(NanoScheduler.class);
private final static Logger logger = Logger.getLogger(NanoScheduler.class);
private final static boolean ALLOW_SINGLE_THREAD_FASTPATH = true;
final int bufferSize;
final int mapGroupSize;
@ -172,7 +173,7 @@ public class NanoScheduler<InputType, MapType, ReduceType> {
if ( map == null ) throw new IllegalArgumentException("map function cannot be null");
if ( reduce == null ) throw new IllegalArgumentException("reduce function cannot be null");
if ( getnThreads() == 1 ) {
if ( ALLOW_SINGLE_THREAD_FASTPATH && getnThreads() == 1 ) {
return executeSingleThreaded(inputReader, map, initialValue, reduce);
} else {
return executeMultiThreaded(inputReader, map, initialValue, reduce);