From 569e1a1089c766913ba17f727cc43db794d65693 Mon Sep 17 00:00:00 2001 From: Mark DePristo Date: Tue, 23 Aug 2011 16:53:06 -0400 Subject: [PATCH] Walker.isDone() aborts execution early -- Useful if you want to have a parameter like MAX_RECORDS that wants the walker to stop after some number of map calls without having to resort to the old System.exit() call directly. --- .../gatk/executive/LinearMicroScheduler.java | 6 +++- .../gatk/traversals/TraverseDuplicates.java | 3 ++ .../sting/gatk/traversals/TraverseLoci.java | 11 +++++--- .../gatk/traversals/TraverseReadPairs.java | 4 +++ .../sting/gatk/traversals/TraverseReads.java | 3 ++ .../sting/gatk/walkers/Walker.java | 11 ++++++++ .../walkers/variantutils/VariantsToTable.java | 28 +++++++++---------- 7 files changed, 47 insertions(+), 19 deletions(-) diff --git a/public/java/src/org/broadinstitute/sting/gatk/executive/LinearMicroScheduler.java b/public/java/src/org/broadinstitute/sting/gatk/executive/LinearMicroScheduler.java index 48fd73e0b..65ff27497 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/executive/LinearMicroScheduler.java +++ b/public/java/src/org/broadinstitute/sting/gatk/executive/LinearMicroScheduler.java @@ -48,9 +48,10 @@ public class LinearMicroScheduler extends MicroScheduler { walker.initialize(); Accumulator accumulator = Accumulator.create(engine,walker); + boolean done = walker.isDone(); int counter = 0; for (Shard shard : shardStrategy ) { - if ( shard == null ) // we ran out of shards that aren't owned + if ( done || shard == null ) // we ran out of shards that aren't owned break; if(shard.getShardType() == Shard.ShardType.LOCUS) { @@ -61,6 +62,7 @@ public class LinearMicroScheduler extends MicroScheduler { Object result = traversalEngine.traverse(walker, dataProvider, accumulator.getReduceInit()); accumulator.accumulate(dataProvider,result); dataProvider.close(); + if ( walker.isDone() ) break; } windowMaker.close(); } @@ -70,6 +72,8 @@ public class LinearMicroScheduler extends 
MicroScheduler { accumulator.accumulate(dataProvider,result); dataProvider.close(); } + + done = walker.isDone(); } Object result = accumulator.finishTraversal(); diff --git a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseDuplicates.java b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseDuplicates.java index 1ba48ca5f..046003154 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseDuplicates.java +++ b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseDuplicates.java @@ -173,7 +173,9 @@ public class TraverseDuplicates extends TraversalEngine those with the same mate pair position, for paired reads * -> those flagged as unpaired and duplicated but having the same start and end */ + boolean done = walker.isDone(); for (SAMRecord read : iter) { + if ( done ) break; // get the genome loc from the read GenomeLoc site = engine.getGenomeLocParser().createGenomeLoc(read); @@ -194,6 +196,7 @@ public class TraverseDuplicates extends TraversalEngine extends TraversalEngine,Locu logger.debug(String.format("TraverseLoci.traverse: Shard is %s", dataProvider)); LocusView locusView = getLocusView( walker, dataProvider ); + boolean done = false; if ( locusView.hasNext() ) { // trivial optimization to avoid unnecessary processing when there's nothing here at all @@ -46,7 +47,7 @@ public class TraverseLoci extends TraversalEngine,Locu LocusReferenceView referenceView = new LocusReferenceView( walker, dataProvider ); // We keep processing while the next reference location is within the interval - while( locusView.hasNext() ) { + while( locusView.hasNext() && ! 
done ) { AlignmentContext locus = locusView.next(); GenomeLoc location = locus.getLocation(); @@ -76,15 +77,17 @@ public class TraverseLoci extends TraversalEngine,Locu if (keepMeP) { M x = walker.map(tracker, refContext, locus); sum = walker.reduce(x, sum); + done = walker.isDone(); } printProgress(dataProvider.getShard(),locus.getLocation()); } } - // We have a final map call to execute here to clean up the skipped based from the - // last position in the ROD to that in the interval - if ( WalkerManager.getWalkerDataSource(walker) == DataSource.REFERENCE_ORDERED_DATA ) { + // We have a final map call to execute here to clean up the skipped bases from the + // last position in the ROD to that in the interval + if ( WalkerManager.getWalkerDataSource(walker) == DataSource.REFERENCE_ORDERED_DATA && ! walker.isDone() ) { + // only do this if the walker isn't done! RodLocusView rodLocusView = (RodLocusView)locusView; long nSkipped = rodLocusView.getLastSkippedBases(); if ( nSkipped > 0 ) { diff --git a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseReadPairs.java b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseReadPairs.java index 196d54036..dd4402d82 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseReadPairs.java +++ b/public/java/src/org/broadinstitute/sting/gatk/traversals/TraverseReadPairs.java @@ -50,7 +50,9 @@ public class TraverseReadPairs extends TraversalEngine pairs = new ArrayList(); + boolean done = walker.isDone(); for(SAMRecord read: reads) { + if ( done ) break; dataProvider.getShard().getReadMetrics().incrementNumReadsSeen(); if(pairs.size() == 0 || pairs.get(0).getReadName().equals(read.getReadName())) { @@ -65,6 +67,8 @@ public class TraverseReadPairs extends TraversalEngine extends TraversalEngine,Read // get the reference ordered data ReadBasedReferenceOrderedView rodView = new ReadBasedReferenceOrderedView(dataProvider); + boolean done = walker.isDone(); // while we still have more 
reads for (SAMRecord read : reads) { + if ( done ) break; // ReferenceContext -- the reference bases covered by the read ReferenceContext refContext = null; @@ -106,6 +108,7 @@ public class TraverseReads extends TraversalEngine,Read GenomeLoc locus = read.getReferenceIndex() == SAMRecord.NO_ALIGNMENT_REFERENCE_INDEX ? null : engine.getGenomeLocParser().createGenomeLoc(read.getReferenceName(),read.getAlignmentStart()); printProgress(dataProvider.getShard(),locus); + done = walker.isDone(); } return sum; } diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/Walker.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/Walker.java index 9e261a0b1..c88c7c3c4 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/Walker.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/Walker.java @@ -126,6 +126,17 @@ public abstract class Walker { public void initialize() { } + /** + * A function for overriding in subclasses providing a mechanism to abort early from a walker. + * + * If this ever returns true, then the Traversal engine will stop executing map calls + * and start the process of shutting down the walker in an orderly fashion. + * @return true if the walker wants the traversal to stop early; this default implementation always returns false + */ + public boolean isDone() { + return false; + } + /** * Provide an initial value for reduce computations. * @return Initial value of reduce. 
diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/variantutils/VariantsToTable.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/variantutils/VariantsToTable.java index 19db58e0c..5dd75c858 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/variantutils/VariantsToTable.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/variantutils/VariantsToTable.java @@ -147,22 +147,22 @@ public class VariantsToTable extends RodWalker { if ( tracker == null ) // RodWalkers can make funky map calls return 0; - if ( ++nRecords < MAX_RECORDS || MAX_RECORDS == -1 ) { - for ( VariantContext vc : tracker.getValues(variantCollection.variants, context.getLocation())) { - if ( (keepMultiAllelic || vc.isBiallelic()) && ( showFiltered || vc.isNotFiltered() ) ) { - List vals = extractFields(vc, fieldsToTake, ALLOW_MISSING_DATA); - out.println(Utils.join("\t", vals)); - } + nRecords++; + for ( VariantContext vc : tracker.getValues(variantCollection.variants, context.getLocation())) { + if ( (keepMultiAllelic || vc.isBiallelic()) && ( showFiltered || vc.isNotFiltered() ) ) { + List vals = extractFields(vc, fieldsToTake, ALLOW_MISSING_DATA); + out.println(Utils.join("\t", vals)); } - - return 1; - } else { - if ( nRecords >= MAX_RECORDS ) { - logger.warn("Calling sys exit to leave after " + nRecords + " records"); - System.exit(0); // todo -- what's the recommend way to abort like this? - } - return 0; } + + return 1; + } + + @Override + public boolean isDone() { + boolean done = MAX_RECORDS != -1 && nRecords >= MAX_RECORDS; + if ( done) logger.warn("isDone() will return true to leave after " + nRecords + " records"); + return done ; } private static final boolean isWildCard(String s) {