From 0791beab8ff65600c57cc8b3fd4dc38f61fb8f4e Mon Sep 17 00:00:00 2001 From: hanna Date: Mon, 17 May 2010 21:00:44 +0000 Subject: [PATCH] Checking in downsampling iterator alongside LocusIteratorByState, and removing the reference implementation. Also implemented a heap size monitor that can be used to programmatically report the current heap size. git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@3367 348d0f76-0448-11de-a6fe-93d51630548a --- .../sting/gatk/GenomeAnalysisEngine.java | 12 +- ... => DownsamplingLocusIteratorByState.java} | 436 +++++++++++------- .../sting/utils/HeapSizeMonitor.java | 80 ++++ .../sting/utils/ReservoirDownsampler.java | 238 ++++------ .../utils/ReservoirDownsamplerUnitTest.java | 124 ++--- 5 files changed, 505 insertions(+), 385 deletions(-) rename java/src/org/broadinstitute/sting/gatk/iterators/{DownsamplingReferenceImplementation.java => DownsamplingLocusIteratorByState.java} (64%) create mode 100644 java/src/org/broadinstitute/sting/utils/HeapSizeMonitor.java diff --git a/java/src/org/broadinstitute/sting/gatk/GenomeAnalysisEngine.java b/java/src/org/broadinstitute/sting/gatk/GenomeAnalysisEngine.java index 9493b2a0e..a8b48bd63 100755 --- a/java/src/org/broadinstitute/sting/gatk/GenomeAnalysisEngine.java +++ b/java/src/org/broadinstitute/sting/gatk/GenomeAnalysisEngine.java @@ -138,6 +138,9 @@ public class GenomeAnalysisEngine { * @return the value of this traversal. */ public Object execute(GATKArgumentCollection args, Walker my_walker, Collection filters) { + //HeapSizeMonitor monitor = new HeapSizeMonitor(); + //monitor.start(); + // validate our parameters if (args == null) { throw new StingException("The GATKArgumentCollection passed to GenomeAnalysisEngine can not be null."); @@ -169,7 +172,12 @@ public class GenomeAnalysisEngine { readsDataSource != null ? 
readsDataSource.getReadsInfo().getValidationExclusionList() : null); // execute the microscheduler, storing the results - return microScheduler.execute(my_walker, shardStrategy, argCollection.maximumEngineIterations); + Object result = microScheduler.execute(my_walker, shardStrategy, argCollection.maximumEngineIterations); + + //monitor.stop(); + //logger.info(String.format("Maximum heap size consumed: %d",monitor.getMaxMemoryUsed())); + + return result; } /** @@ -694,7 +702,7 @@ public class GenomeAnalysisEngine { else throw new StingException("The GATK cannot currently process unindexed BAM files"); - return new MonolithicShardStrategy(shardType); + return new (shardType); } ShardStrategy shardStrategy = null; diff --git a/java/src/org/broadinstitute/sting/gatk/iterators/DownsamplingReferenceImplementation.java b/java/src/org/broadinstitute/sting/gatk/iterators/DownsamplingLocusIteratorByState.java similarity index 64% rename from java/src/org/broadinstitute/sting/gatk/iterators/DownsamplingReferenceImplementation.java rename to java/src/org/broadinstitute/sting/gatk/iterators/DownsamplingLocusIteratorByState.java index 95d992048..7dd44b7f4 100755 --- a/java/src/org/broadinstitute/sting/gatk/iterators/DownsamplingReferenceImplementation.java +++ b/java/src/org/broadinstitute/sting/gatk/iterators/DownsamplingLocusIteratorByState.java @@ -32,7 +32,6 @@ import org.broadinstitute.sting.gatk.Reads; import org.broadinstitute.sting.gatk.GenomeAnalysisEngine; import org.broadinstitute.sting.gatk.contexts.AlignmentContext; import org.broadinstitute.sting.utils.*; -import org.broadinstitute.sting.utils.sam.AlignmentStartComparator; import org.broadinstitute.sting.utils.pileup.PileupElement; import org.broadinstitute.sting.utils.pileup.ReadBackedPileup; import org.broadinstitute.sting.utils.pileup.ExtendedEventPileupElement; @@ -41,27 +40,20 @@ import org.broadinstitute.sting.utils.pileup.ReadBackedExtendedEventPileup; import java.util.*; /** Iterator that traverses a SAM 
File, accumulating information on a per-locus basis */ -public class DownsamplingReferenceImplementation extends LocusIterator { - // TODO: Reintegrate LocusOverflowTracker +public class DownsamplingLocusIteratorByState extends LocusIterator { /** our log, which we want to capture anything from this class */ - private static Logger logger = Logger.getLogger(DownsamplingReferenceImplementation.class); - - /** - * Store a random number generator with a consistent seed for consistent downsampling from run to run. - * Note that each shard will be initialized with the same random seed; this will ensure consistent results - * across parallelized runs, at the expense of decreasing our level of randomness. - */ - private Random downsampleRandomizer = new Random(38148309L); + private static Logger logger = Logger.getLogger(LocusIteratorByState.class); // ----------------------------------------------------------------------------------------------------------------- // // member fields // // ----------------------------------------------------------------------------------------------------------------- - private final PeekableIterator> downsamplingIterator; private boolean hasExtendedEvents = false; // will be set to true if at least one read had an indel right before the current position - private Collection sampleNames = new ArrayList(); + + private final Collection sampleNames = new ArrayList(); + private final ReadStateManager readStates; private class SAMRecordState { SAMRecord read; @@ -252,28 +244,24 @@ public class DownsamplingReferenceImplementation extends LocusIterator { } } - private LinkedList readStates = new LinkedList(); //final boolean DEBUG = false; //final boolean DEBUG2 = false && DEBUG; private Reads readInfo; + private AlignmentContext nextAlignmentContext; // ----------------------------------------------------------------------------------------------------------------- // // constructors and other basic operations // // 
----------------------------------------------------------------------------------------------------------------- - public DownsamplingReferenceImplementation(final Iterator samIterator, Reads readInformation) { - ReservoirDownsampler downsampler = new ReservoirDownsampler(samIterator, - new AlignmentStartComparator(), - new SampleNamePartitioner(), - readInformation.getMaxReadsAtLocus()); - this.downsamplingIterator = new PeekableIterator>(downsampler); - this.readInfo = readInformation; - + public DownsamplingLocusIteratorByState(final Iterator samIterator, Reads readInformation) { // Aggregate all sample names. // TODO: Push in header via constructor if(GenomeAnalysisEngine.instance.getDataSource() != null) sampleNames.addAll(SampleUtils.getSAMFileSamples(GenomeAnalysisEngine.instance.getSAMFileHeader())); + readStates = new ReadStateManager(samIterator,sampleNames,readInformation.getMaxReadsAtLocus()); + this.readInfo = readInformation; + } public Iterator iterator() { @@ -285,9 +273,9 @@ public class DownsamplingReferenceImplementation extends LocusIterator { } public boolean hasNext() { - boolean r = ! readStates.isEmpty() || downsamplingIterator.hasNext(); + lazyLoadNextAlignmentContext(); + boolean r = (nextAlignmentContext != null); //if ( DEBUG ) System.out.printf("hasNext() = %b%n", r); - return r; } @@ -300,11 +288,6 @@ public class DownsamplingReferenceImplementation extends LocusIterator { } } - public void clear() { - logger.debug(String.format(("clear() called"))); - readStates.clear(); - } - private GenomeLoc getLocation() { return readStates.isEmpty() ? 
null : readStates.getFirst().getLocation(); } @@ -315,12 +298,20 @@ public class DownsamplingReferenceImplementation extends LocusIterator { // // ----------------------------------------------------------------------------------------------------------------- public AlignmentContext next() { - // keep iterating forward until we encounter a reference position that has something "real" hanging over it - // (i.e. either a real base, or a real base or a deletion if includeReadsWithDeletion is true) - - - while(true) { + lazyLoadNextAlignmentContext(); + if(!hasNext()) + throw new NoSuchElementException("LocusIteratorByState: out of elements."); + AlignmentContext currentAlignmentContext = nextAlignmentContext; + nextAlignmentContext = null; + return currentAlignmentContext; + } + /** + * Creates the next alignment context from the given state. Note that this is implemented as a lazy load method. + * nextAlignmentContext MUST BE null in order for this method to advance to the next entry. + */ + private void lazyLoadNextAlignmentContext() { + while(nextAlignmentContext == null && readStates.hasNext()) { // this call will set hasExtendedEvents to true if it picks up a read with indel right before the current position on the ref: collectPendingReads(readInfo.getMaxReadsAtLocus()); @@ -330,12 +321,12 @@ public class DownsamplingReferenceImplementation extends LocusIterator { int nMQ0Reads = 0; - // if extended events are requested, and if previous traversal step brought us over an indel in + // if extended events are requested, and if previous traversal step brought us over an indel in // at least one read, we emit extended pileup (making sure that it is associated with the previous base, // i.e. the one right *before* the indel) and do NOT shift the current position on the ref. // In this case, the subsequent call to next() will emit the normal pileup at the current base // and shift the position. 
- if ( readInfo.generateExtendedEvents() && hasExtendedEvents ) { + if (readInfo.generateExtendedEvents() && hasExtendedEvents) { ArrayList indelPile = new ArrayList(readStates.size()); int maxDeletionLength = 0; @@ -382,12 +373,10 @@ public class DownsamplingReferenceImplementation extends LocusIterator { GenomeLoc loc = GenomeLocParser.incPos(our1stState.getLocation(),-1); // System.out.println("Indel(s) at "+loc); // for ( ExtendedEventPileupElement pe : indelPile ) { if ( pe.isIndel() ) System.out.println(" "+pe.toString()); } - return new AlignmentContext(loc, new ReadBackedExtendedEventPileup(loc, indelPile, size, maxDeletionLength, nInsertions, nDeletions, nMQ0Reads)); + nextAlignmentContext = new AlignmentContext(loc, new ReadBackedExtendedEventPileup(loc, indelPile, size, maxDeletionLength, nInsertions, nDeletions, nMQ0Reads)); } else { - ArrayList pile = new ArrayList(readStates.size()); - // todo -- performance problem -- should be lazy, really for ( SAMRecordState state : readStates ) { if ( state.getCurrentCigarOperator() != CigarOperator.D && state.getCurrentCigarOperator() != CigarOperator.N ) { @@ -410,9 +399,8 @@ public class DownsamplingReferenceImplementation extends LocusIterator { GenomeLoc loc = getLocation(); updateReadStates(); // critical - must be called after we get the current state offsets and location // if we got reads with non-D/N over the current position, we are done - if ( pile.size() != 0 ) return new AlignmentContext(loc, new ReadBackedPileup(loc, pile, size, nDeletions, nMQ0Reads)); + if ( pile.size() != 0 ) nextAlignmentContext = new AlignmentContext(loc, new ReadBackedPileup(loc, pile, size, nDeletions, nMQ0Reads)); } - } } @@ -455,125 +443,10 @@ public class DownsamplingReferenceImplementation extends LocusIterator { // } // } - private void collectPendingReads(int maxReadsPerSample) { - if(maxReadsPerSample <= 0) - throw new StingException("maxReadsPerSample is too low; it is " + maxReadsPerSample + ", but must be greater 
than 0"); - - while (downsamplingIterator.hasNext()) { - Collection reads = downsamplingIterator.peek(); - if(!reads.isEmpty() && !readIsPastCurrentPosition(reads.iterator().next())) { - // Consume the collection of reads. - downsamplingIterator.next(); - - for(String sampleName: sampleNames) { - LinkedList newReads = getReadsForGivenSample(reads,sampleName); - LinkedList existingReadStates = getReadStateForGivenSample(readStates,sampleName); - - if(existingReadStates.size()+newReads.size() <= maxReadsPerSample) { - for(SAMRecord read: newReads) { - SAMRecordState state = new SAMRecordState(read, readInfo.generateExtendedEvents()); - state.stepForwardOnGenome(); - readStates.add(state); - // TODO: What if we downsample the extended events away? - if (state.hadIndel()) hasExtendedEvents = true; - } - } - else { - // If we've reached this point, the active list of read states needs to be pruned. Start by - // pruning one off each alignment start, working backward. Repeat until there's either < 1 - // read available at any locus or - - // readStatesAtAlignmentStart stores a full complement of reads starting at a given locus. - List readStatesAtAlignmentStart = new ArrayList(); - List readStatesToPrune = new LinkedList(); - - while((existingReadStates.size()-readStatesToPrune.size()+newReads.size())>maxReadsPerSample) { - readStatesToPrune.clear(); - Iterator descendingIterator = existingReadStates.descendingIterator(); - while(descendingIterator.hasNext()) { - // Accumulate all reads at a given alignment start. 
- SAMRecordState currentState = descendingIterator.next(); - if(readStatesAtAlignmentStart.isEmpty() || - readStatesAtAlignmentStart.get(0).getRead().getAlignmentStart()==currentState.getRead().getAlignmentStart()) - readStatesAtAlignmentStart.add(currentState); - else { - if(readStatesAtAlignmentStart.size() > 1) { - SAMRecordState stateToRemove = readStatesAtAlignmentStart.get(downsampleRandomizer.nextInt(readStatesAtAlignmentStart.size())); - readStatesToPrune.add(stateToRemove); - if((existingReadStates.size()-readStatesToPrune.size()+newReads.size())<=maxReadsPerSample) - break; - } - readStatesAtAlignmentStart.clear(); - readStatesAtAlignmentStart.add(currentState); - } - } - - // Cleanup on last locus viewed. - if(readStatesAtAlignmentStart.size() > 1 && (existingReadStates.size()-readStatesToPrune.size()+newReads.size())>maxReadsPerSample) { - SAMRecordState stateToRemove = readStatesAtAlignmentStart.get(downsampleRandomizer.nextInt(readStatesAtAlignmentStart.size())); - readStatesToPrune.add(stateToRemove); - } - readStatesAtAlignmentStart.clear(); - - // Nothing left to prune. Break out to avoid infinite loop. - if(readStatesToPrune.isEmpty()) - break; - - // Get rid of all the chosen reads. - existingReadStates.removeAll(readStatesToPrune); - readStates.removeAll(readStatesToPrune); - } - - // Still no space available? Prune the leftmost read. - if(existingReadStates.size() >= maxReadsPerSample) { - SAMRecordState initialReadState = existingReadStates.remove(); - readStates.remove(initialReadState); - } - - // Fill from the list of new reads until we're either out of new reads or at capacity. - for(SAMRecord read: newReads) { - SAMRecordState state = new SAMRecordState(read, readInfo.generateExtendedEvents()); - state.stepForwardOnGenome(); - existingReadStates.add(state); - readStates.add(state); - // TODO: What if we downsample the extended events away? 
- if (state.hadIndel()) hasExtendedEvents = true; - if(existingReadStates.size()>=maxReadsPerSample) - break; - } - } - } - - //if (DEBUG) logger.debug(String.format(" ... added read %s", read.getReadName())); - } - else if(readIsPastCurrentPosition(reads.iterator().next())) - break; - } + private void collectPendingReads(int maxReads) { + readStates.collectPendingReads(); } - private LinkedList getReadsForGivenSample(final Collection reads, final String sampleName) { - // TODO: What about files with no read groups? What about files with no samples? - LinkedList readsForGivenSample = new LinkedList(); - for(SAMRecord read: reads) { - Object readSampleName = read.getReadGroup().getSample(); - if(readSampleName != null && readSampleName.equals(sampleName)) - readsForGivenSample.add(read); - } - return readsForGivenSample; - } - - private LinkedList getReadStateForGivenSample(final Collection readStates, final String sampleName) { - // TODO: What about files with no read groups? What about files with no samples? - LinkedList readStatesForGivenSample = new LinkedList(); - for(SAMRecordState readState: readStates) { - Object readSampleName = readState.getRead().getReadGroup().getSample(); - if(readSampleName != null && readSampleName.equals(sampleName)) - readStatesForGivenSample.add(readState); - } - return readStatesForGivenSample; - } - - // fast testing of position private boolean readIsPastCurrentPosition(SAMRecord read) { if ( readStates.isEmpty() ) @@ -632,14 +505,245 @@ public class DownsamplingReferenceImplementation extends LocusIterator { return null; } - /** - * Partitions a dataset by sample name. 
- */ - private class SampleNamePartitioner implements ReservoirDownsampler.Partitioner { - public Object partition(SAMRecord read) { - if(read.getReadGroup() != null && read.getReadGroup().getAttribute("SM") != null) - return read.getReadGroup().getAttribute("SM"); - return null; + private class ReadStateManager implements Iterable { + private final PeekableIterator iterator; + private final Map> downsamplersBySampleName = new HashMap>(); + private final int maxReadsPerSample; + + private final Deque>> readStatesByAlignmentStart; + + /** + * Store a random number generator with a consistent seed for consistent downsampling from run to run. + * Note that each shard will be initialized with the same random seed; this will ensure consistent results + * across parallelized runs, at the expense of decreasing our level of randomness. + */ + private Random downsampleRandomizer = new Random(38148309L); + + public ReadStateManager(Iterator source, Collection sampleNames, int maxReadsPerSample) { + this.iterator = new PeekableIterator(source); + this.maxReadsPerSample = maxReadsPerSample; + for(String sampleName: sampleNames) + downsamplersBySampleName.put(sampleName,new ReservoirDownsampler(maxReadsPerSample)); + this.readStatesByAlignmentStart = new LinkedList>>(); + } + + public Iterator iterator() { + return new Iterator() { + private final Iterator>> alignmentStartIterator; + private Iterator> sampleIterator; + private Iterator readStateIterator; + private SAMRecordState nextReadState; + private int readsInHanger = countReadsInHanger(); + + { + pruneEmptyElementsInHanger(); + alignmentStartIterator = readStatesByAlignmentStart.iterator(); + sampleIterator = alignmentStartIterator.hasNext() ? alignmentStartIterator.next().values().iterator() : null; + readStateIterator = (sampleIterator!=null && sampleIterator.hasNext()) ? 
sampleIterator.next().iterator() : null; + } + + public boolean hasNext() { + return readsInHanger > 0; + } + + public SAMRecordState next() { + advance(); + if(nextReadState==null) throw new NoSuchElementException("reader is out of elements"); + try { + return nextReadState; + } + finally { + nextReadState = null; + } + } + + public void remove() { + if(readStateIterator == null) + throw new StingException("Attempted to remove read, but no previous read was found."); + readStateIterator.remove(); + } + + private void advance() { + nextReadState = null; + if(readStateIterator!=null && readStateIterator.hasNext()) + nextReadState = readStateIterator.next(); + else if(sampleIterator!=null && sampleIterator.hasNext()) { + readStateIterator = sampleIterator.next().iterator(); + nextReadState = readStateIterator.hasNext() ? readStateIterator.next() : null; + } + else if(alignmentStartIterator!=null && alignmentStartIterator.hasNext()) { + sampleIterator = alignmentStartIterator.next().values().iterator(); + readStateIterator = sampleIterator.hasNext() ? sampleIterator.next().iterator() : null; + nextReadState = (readStateIterator!=null && readStateIterator.hasNext()) ? 
readStateIterator.next() : null; + } + + if(nextReadState != null) readsInHanger--; + } + }; + } + + public boolean isEmpty() { + pruneEmptyElementsInHanger(); + return readStatesByAlignmentStart.isEmpty(); + } + + public int size() { + int size = 0; + for(Map> readStatesBySample: readStatesByAlignmentStart) { + for(Collection readStates: readStatesBySample.values()) + size += readStates.size(); + } + return size; + } + + public SAMRecordState getFirst() { + return iterator().next(); + } + + public boolean hasNext() { + pruneEmptyElementsInHanger(); + return !readStatesByAlignmentStart.isEmpty() || iterator.hasNext(); + } + + public void collectPendingReads() { + while (iterator.hasNext() && !readIsPastCurrentPosition(iterator.peek())) { + SAMRecord read = iterator.next(); + downsamplersBySampleName.get(read.getReadGroup().getSample()).add(read); + } + + Map> culledReadStatesBySample = new HashMap>(); + + for(Map.Entry> entry: downsamplersBySampleName.entrySet()) { + String sampleName = entry.getKey(); + ReservoirDownsampler downsampler = entry.getValue(); + + Collection newReads = downsampler.getDownsampledContents(); + downsampler.clear(); + int readsInHanger = countReadsInHanger(sampleName); + + if(readsInHanger+newReads.size() <= maxReadsPerSample) + addReadsToHanger(culledReadStatesBySample,sampleName,newReads,newReads.size()); + else { + Iterator>> backIterator = readStatesByAlignmentStart.descendingIterator(); + boolean readPruned = true; + while(readsInHanger+newReads.size()>maxReadsPerSample && readPruned) { + readPruned = false; + while(readsInHanger+newReads.size()>maxReadsPerSample && backIterator.hasNext()) { + List readsAtLocus = backIterator.next().get(sampleName); + if(readsAtLocus.size() > 1) { + readsAtLocus.remove(downsampleRandomizer.nextInt(readsAtLocus.size())); + readPruned = true; + readsInHanger--; + } + } + } + + if(readsInHanger == maxReadsPerSample) { + Collection firstHangerForSample = 
readStatesByAlignmentStart.getFirst().get(sampleName); + readsInHanger -= firstHangerForSample.size(); + firstHangerForSample.clear(); + } + + addReadsToHanger(culledReadStatesBySample,sampleName,newReads,maxReadsPerSample-readsInHanger); + } + + readStatesByAlignmentStart.add(culledReadStatesBySample); + } + +/* else { + if() { + // Consume the collection of reads. + downsamplingIterator.next(); + + Map> newReadsBySample = new HashMap>(); + Map> culledReadStatesBySample = new HashMap>(); + + for(String sampleName: sampleNames) + newReadsBySample.put(sampleName,getReadsForGivenSample(reads,sampleName)); + + for(String sampleName: newReadsBySample.keySet()) { + Collection newReads = newReadsBySample.get(sampleName); + int readsInHanger = countReadsInHanger(sampleName); + + //if(readsInHanger+newReads.size() <= maxReadsPerSample) + addReadsToHanger(culledReadStatesBySample,sampleName,newReads,newReads.size()); + Iterator>> backIterator = readStatesByAlignmentStart.descendingIterator(); + boolean readPruned = true; + while(readsInHanger+newReads.size()>maxReadsPerSample && readPruned) { + readPruned = false; + while(readsInHanger+newReads.size()>maxReadsPerSample && backIterator.hasNext()) { + List readsAtLocus = backIterator.next().get(sampleName); + if(readsAtLocus.size() > 1) { + readsAtLocus.remove(downsampleRandomizer.nextInt(readsAtLocus.size())); + readPruned = true; + readsInHanger--; + } + } + } + + if(readsInHanger == maxReadsPerSample) { + Collection firstHangerForSample = readStatesByAlignmentStart.getFirst().get(sampleName); + readsInHanger -= firstHangerForSample.size(); + firstHangerForSample.clear(); + } + + addReadsToHanger(culledReadStatesBySample,sampleName,newReads,maxReadsPerSample-readsInHanger); + } + } + + readStatesByAlignmentStart.add(culledReadStatesBySample); + } + else if(readIsPastCurrentPosition(reads.iterator().next())) + break; + } +*/ + } + + private int countReadsInHanger() { + int count = 0; + for(Map> hangerEntry: 
readStatesByAlignmentStart) { + for(List reads: hangerEntry.values()) + count += reads.size(); + } + return count; + } + + private int countReadsInHanger(final String sampleName) { + int count = 0; + for(Map> hangerEntry: readStatesByAlignmentStart) { + if(sampleName == null && hangerEntry.containsKey(sampleName)) + count += hangerEntry.get(sampleName).size(); + } + return count; + } + + private void addReadsToHanger(final Map> newHanger, final String sampleName, final Collection reads, final int maxReads) { + List hangerEntry = new LinkedList(); + int readCount = 0; + for(SAMRecord read: reads) { + if(readCount >= maxReads) + break; + SAMRecordState state = new SAMRecordState(read, readInfo.generateExtendedEvents()); + state.stepForwardOnGenome(); + hangerEntry.add(state); + // TODO: What if we downsample the extended events away? + if (state.hadIndel()) hasExtendedEvents = true; + readCount++; + } + newHanger.put(sampleName,hangerEntry); + } + + private void pruneEmptyElementsInHanger() { + Iterator>> hangerIterator = readStatesByAlignmentStart.iterator(); + while(hangerIterator.hasNext()) { + Map> hangerEntry = hangerIterator.next(); + for(String sampleName: sampleNames) { + if(hangerEntry.containsKey(sampleName) && hangerEntry.get(sampleName).size() == 0) + hangerEntry.remove(sampleName); + } + if(hangerEntry.size() == 0) + hangerIterator.remove(); + } } } } diff --git a/java/src/org/broadinstitute/sting/utils/HeapSizeMonitor.java b/java/src/org/broadinstitute/sting/utils/HeapSizeMonitor.java new file mode 100644 index 000000000..07211d67a --- /dev/null +++ b/java/src/org/broadinstitute/sting/utils/HeapSizeMonitor.java @@ -0,0 +1,80 @@ +package org.broadinstitute.sting.utils; + +import java.lang.management.ManagementFactory; +import java.lang.management.MemoryMXBean; + +/** + * Monitor the current heap size, allowing the application to programmatically + * access the data. 
+ * + * @author mhanna + * @version 0.1 + */ +public class HeapSizeMonitor { + private final int monitorFrequencyMillis; + private final MonitorRunnable monitorRunnable; + + private Thread monitorThread; + + public HeapSizeMonitor() { + this(1000); + } + + public HeapSizeMonitor(final int monitorFrequencyMillis) { + this.monitorFrequencyMillis = monitorFrequencyMillis; + this.monitorRunnable = new MonitorRunnable(); + } + + public long getMaxMemoryUsed() { + return monitorRunnable.getMaxMemoryUsed(); + } + + public void start() { + monitorThread = new Thread(monitorRunnable); + monitorThread.start(); + } + + public void stop() { + monitorRunnable.stop = true; + try { + monitorThread.join(); + } + catch(InterruptedException ex) { + throw new StingException("Unable to connect to monitor thread"); + } + monitorThread = null; + } + + private class MonitorRunnable implements Runnable { + private MemoryMXBean monitor; + + private long maxMemoryUsed; + private boolean stop; + + public MonitorRunnable() { + monitor = ManagementFactory.getMemoryMXBean(); + } + + public void reset() { + maxMemoryUsed = 0L; + stop = false; + } + + public long getMaxMemoryUsed() { + return maxMemoryUsed; + } + + public void run() { + while(!stop) { + System.gc(); + maxMemoryUsed = Math.max(monitor.getHeapMemoryUsage().getUsed(),maxMemoryUsed); + try { + Thread.sleep(monitorFrequencyMillis); + } + catch(InterruptedException ex) { + throw new StingException("Unable to continue monitoring heap consumption",ex); + } + } + } + } +} diff --git a/java/src/org/broadinstitute/sting/utils/ReservoirDownsampler.java b/java/src/org/broadinstitute/sting/utils/ReservoirDownsampler.java index 13af0041c..ea4b848c0 100644 --- a/java/src/org/broadinstitute/sting/utils/ReservoirDownsampler.java +++ b/java/src/org/broadinstitute/sting/utils/ReservoirDownsampler.java @@ -9,37 +9,20 @@ import java.util.*; * naive implementation of reservoir downsampling as described in "Random Downsampling * with a Reservoir" 
(Vitter 1985). At time of writing, this paper is located here: * http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.138.784&rep=rep1&type=pdf - * - * Contains an enhancement allowing users to partition downsampled data. If a partitioner - * is used, each partition will be allowed to contain maxElements elements. - * - * Note that using the ReservoirDownsampler will leave the given iterator in an undefined - * state. Do not attempt to use the iterator (other than closing it) after the Downsampler - * completes. - * + * @author mhanna * @version 0.1 */ -public class ReservoirDownsampler implements Iterator> { +public class ReservoirDownsampler implements Collection { /** * Create a random number generator with a random, but reproducible, seed. */ private final Random random = new Random(47382911L); /** - * The data source, wrapped in a peekable input stream. + * The reservoir of elements tracked by this downsampler. */ - private final PeekableIterator iterator; - - /** - * Used to identify whether two elements are 'equal' in the eyes of the downsampler. - */ - private final Comparator comparator; - - /** - * Partitions the elements into subsets, each having an equal number of maxElements. - */ - private final Partitioner partitioner; + private final ArrayList reservoir; /** * What is the maximum number of reads that can be returned in a single batch. @@ -48,138 +31,105 @@ public class ReservoirDownsampler implements Iterator> { /** * Create a new downsampler with the given source iterator and given comparator. - * @param iterator Source of the data stream. - * @param comparator Used to compare two records to see whether they're 'equal' at this position. - * @param maxElements What is the maximum number of reads that can be returned in any partition of any call of this iterator. 
- */ - public ReservoirDownsampler(final Iterator iterator, final Comparator comparator, final int maxElements) { - this(iterator,comparator,null,maxElements); - } - - /** - * Create a new downsampler with the given source iterator and given comparator. - * @param iterator Source of the data stream. - * @param comparator Used to compare two records to see whether they're 'equal' at this position. - * @param partitioner Used to divide the elements into bins. Each bin can have maxElements elements. * @param maxElements What is the maximum number of reads that can be returned in any call of this */ - public ReservoirDownsampler(final Iterator iterator, final Comparator comparator, final Partitioner partitioner, final int maxElements) { - this.iterator = new PeekableIterator(iterator); - this.comparator = comparator; - this.partitioner = partitioner; + public ReservoirDownsampler(final int maxElements) { if(maxElements < 0) throw new StingException("Unable to work with an negative size collection of elements"); + this.reservoir = new ArrayList(maxElements); this.maxElements = maxElements; } - public boolean hasNext() { - return iterator.hasNext(); - } - - /** - * Gets a collection of 'equal' elements, as judged by the comparator. If the number of equal elements - * is greater than the maximum, then the elements in the collection should be a truly random sampling. - * @return Collection of equal elements. - */ - public Collection next() { - if(!hasNext()) - throw new NoSuchElementException("No next element is present."); - - Map> partitions = new HashMap>(); - - // Determine our basis of equality. 
- T first = iterator.next(); - - if(maxElements > 0) - getPartitionForEntry(partitions,first).add(first); - - while(iterator.hasNext() && comparator.compare(first,iterator.peek()) == 0) { - T candidate = iterator.next(); - getPartitionForEntry(partitions,candidate).add(candidate); - } - - LinkedList batch = new LinkedList(); - for(Partition partition: partitions.values()) - batch.addAll(partition.elements); - - return batch; - } - - /** - * Gets the appropriate partition for the given entry from storage. - * @param partitions List of partitions from which to choose. - * @param entry Entry for which to compute the partition. - * @return The partition associated with this entry. Will be created if not present. - */ - private Partition getPartitionForEntry(final Map> partitions, final T entry) { - Object partition = partitioner!=null ? partitioner.partition(entry) : null; - if(!partitions.containsKey(partition)) - partitions.put(partition,new Partition(maxElements)); - return partitions.get(partition); - } - - /** - * Unsupported; throws exception to that effect. - */ - public void remove() { - throw new UnsupportedOperationException("Cannot remove from a ReservoirDownsampler."); - } - - /** - * A common interface for a functor that can take data of - * some type and return an object that can be used to partition - * that data in some way. Really just a declaration of a - * specialized map function. - */ - public interface Partitioner { - public Object partition(T input); - } - - /** - * Models a partition of a given set of elements. Knows how to select - * random elements with replacement. - * @param Data type for the elements of the partition. - */ - private class Partition { - /** - * How large can this partition grow? - */ - private final int partitionSize; - - /** - * The elements of the partition. - */ - private List elements = new ArrayList(); - - /** - * The total number of elements seen. 
- */ - private long elementsSeen = 0; - - public Partition(final int partitionSize) { - this.partitionSize = partitionSize; + @Override + public boolean add(T element) { + if(maxElements <= 0) + return false; + else if(reservoir.size() < maxElements) { + reservoir.add(element); + return true; } - - /** - * Add a new element to this collection, downsampling as necessary so that the partition - * stays under partitionSize elements. - * @param element Element to conditionally add. - */ - public void add(T element) { - if(elements.size() < partitionSize) - elements.add(element); - else { - // Get a uniformly distributed long > 0 and remap it to the range from [0,elementsSeen). - long slot = random.nextLong(); - while(slot == Long.MIN_VALUE) - slot = random.nextLong(); - slot = (long)(((float)Math.abs(slot))/Long.MAX_VALUE * (elementsSeen-1)); - - // If the chosen slot lives within the partition, replace the entry in that slot with the newest entry. - if(slot >= 0 && slot < partitionSize) - elements.set((int)slot,element); + else { + // Get a uniformly distributed int. If the chosen slot lives within the reservoir, replace the entry in that slot with the newest entry. + int slot = random.nextInt(maxElements); + if(slot >= 0 && slot < maxElements) { + reservoir.set(slot,element); + return true; } - elementsSeen++; + else + return false; } } + + @Override + public boolean addAll(Collection elements) { + boolean added = false; + for(T element: elements) + added |= add(element); + return added; + } + + /** + * Returns the contents of this reservoir, downsampled to the given value. Note that the return value is an unmodifiable view of the reservoir's contents. + * @return The downsampled contents of this reservoir. 
+ */ + public Collection getDownsampledContents() { + return Collections.unmodifiableCollection(reservoir); + } + + @Override + public void clear() { + reservoir.clear(); + } + + @Override + public boolean isEmpty() { + return reservoir.isEmpty(); + } + + @Override + public int size() { + return reservoir.size(); + } + + @Override + public Iterator iterator() { + return reservoir.iterator(); + } + + @Override + public boolean contains(Object o) { + return reservoir.contains(o); + } + + @Override + public boolean containsAll(Collection elements) { + return reservoir.containsAll(elements); + } + + @Override + public boolean retainAll(Collection elements) { + return reservoir.retainAll(elements); + } + + @Override + public boolean remove(Object o) { + return reservoir.remove(o); + } + + @Override + public boolean removeAll(Collection elements) { + return reservoir.removeAll(elements); + } + + @Override + public Object[] toArray() { + Object[] contents = new Object[reservoir.size()]; + reservoir.toArray(contents); + return contents; + } + + @Override + public T[] toArray(T[] array) { + return reservoir.toArray(array); + } } diff --git a/java/test/org/broadinstitute/sting/utils/ReservoirDownsamplerUnitTest.java b/java/test/org/broadinstitute/sting/utils/ReservoirDownsamplerUnitTest.java index c166904e6..636bddbcb 100644 --- a/java/test/org/broadinstitute/sting/utils/ReservoirDownsamplerUnitTest.java +++ b/java/test/org/broadinstitute/sting/utils/ReservoirDownsamplerUnitTest.java @@ -3,6 +3,7 @@ package org.broadinstitute.sting.utils; import org.junit.Test; import org.broadinstitute.sting.utils.sam.AlignmentStartComparator; import org.broadinstitute.sting.utils.sam.ArtificialSAMUtils; +import org.broadinstitute.sting.gatk.iterators.NullSAMIterator; import net.sf.samtools.SAMRecord; import net.sf.samtools.SAMFileHeader; @@ -24,19 +25,18 @@ public class ReservoirDownsamplerUnitTest { @Test public void testEmptyIterator() { - ReservoirDownsampler downsampler = new 
ReservoirDownsampler(Collections.emptyList().iterator(), - new AlignmentStartComparator(),1); - Assert.assertFalse("Downsampler is not empty but should be.",downsampler.hasNext()); + ReservoirDownsampler downsampler = new ReservoirDownsampler(1); + Assert.assertTrue("Downsampler is not empty but should be.",downsampler.isEmpty()); } @Test public void testOneElementWithPoolSizeOne() { List reads = Collections.singletonList(ArtificialSAMUtils.createArtificialRead(header,"read1",0,1,76)); - ReservoirDownsampler downsampler = new ReservoirDownsampler(reads.iterator(), - new AlignmentStartComparator(),1); + ReservoirDownsampler downsampler = new ReservoirDownsampler(1); + downsampler.addAll(reads); - Assert.assertTrue("Downsampler is empty but shouldn't be",downsampler.hasNext()); - Collection batchedReads = downsampler.next(); + Assert.assertFalse("Downsampler is empty but shouldn't be",downsampler.isEmpty()); + Collection batchedReads = downsampler.getDownsampledContents(); Assert.assertEquals("Downsampler is returning the wrong number of reads",1,batchedReads.size()); Assert.assertSame("Downsampler is returning an incorrect read",reads.get(0),batchedReads.iterator().next()); } @@ -44,11 +44,11 @@ public class ReservoirDownsamplerUnitTest { @Test public void testOneElementWithPoolSizeGreaterThanOne() { List reads = Collections.singletonList(ArtificialSAMUtils.createArtificialRead(header,"read1",0,1,76)); - ReservoirDownsampler downsampler = new ReservoirDownsampler(reads.iterator(), - new AlignmentStartComparator(),5); + ReservoirDownsampler downsampler = new ReservoirDownsampler(5); + downsampler.addAll(reads); - Assert.assertTrue("Downsampler is empty but shouldn't be",downsampler.hasNext()); - Collection batchedReads = downsampler.next(); + Assert.assertFalse("Downsampler is empty but shouldn't be",downsampler.isEmpty()); + Collection batchedReads = downsampler.getDownsampledContents(); Assert.assertEquals("Downsampler is returning the wrong number of 
reads",1,batchedReads.size()); Assert.assertSame("Downsampler is returning an incorrect read",reads.get(0),batchedReads.iterator().next()); @@ -60,11 +60,11 @@ public class ReservoirDownsamplerUnitTest { reads.add(ArtificialSAMUtils.createArtificialRead(header,"read1",0,1,76)); reads.add(ArtificialSAMUtils.createArtificialRead(header,"read2",0,1,76)); reads.add(ArtificialSAMUtils.createArtificialRead(header,"read3",0,1,76)); - ReservoirDownsampler downsampler = new ReservoirDownsampler(reads.iterator(), - new AlignmentStartComparator(),5); + ReservoirDownsampler downsampler = new ReservoirDownsampler(5); + downsampler.addAll(reads); - Assert.assertTrue("Downsampler is empty but shouldn't be",downsampler.hasNext()); - List batchedReads = new ArrayList(downsampler.next()); + Assert.assertFalse("Downsampler is empty but shouldn't be",downsampler.isEmpty()); + List batchedReads = new ArrayList(downsampler.getDownsampledContents()); Assert.assertEquals("Downsampler is returning the wrong number of reads",3,batchedReads.size()); Assert.assertSame("Downsampler read 1 is incorrect",reads.get(0),batchedReads.get(0)); @@ -80,11 +80,11 @@ public class ReservoirDownsamplerUnitTest { reads.add(ArtificialSAMUtils.createArtificialRead(header,"read3",0,1,76)); reads.add(ArtificialSAMUtils.createArtificialRead(header,"read4",0,1,76)); reads.add(ArtificialSAMUtils.createArtificialRead(header,"read5",0,1,76)); - ReservoirDownsampler downsampler = new ReservoirDownsampler(reads.iterator(), - new AlignmentStartComparator(),5); + ReservoirDownsampler downsampler = new ReservoirDownsampler(5); + downsampler.addAll(reads); - Assert.assertTrue("Downsampler is empty but shouldn't be",downsampler.hasNext()); - List batchedReads = new ArrayList(downsampler.next()); + Assert.assertFalse("Downsampler is empty but shouldn't be",downsampler.isEmpty()); + List batchedReads = new ArrayList(downsampler.getDownsampledContents()); Assert.assertEquals("Downsampler is returning the wrong number of 
reads",5,batchedReads.size()); Assert.assertSame("Downsampler is returning an incorrect read",reads.get(0),batchedReads.iterator().next()); @@ -101,13 +101,12 @@ public class ReservoirDownsamplerUnitTest { reads.add(ArtificialSAMUtils.createArtificialRead(header,"read1",0,1,76)); reads.add(ArtificialSAMUtils.createArtificialRead(header,"read2",0,1,76)); reads.add(ArtificialSAMUtils.createArtificialRead(header,"read3",0,1,76)); - ReservoirDownsampler downsampler = new ReservoirDownsampler(reads.iterator(), - new AlignmentStartComparator(),0); + ReservoirDownsampler downsampler = new ReservoirDownsampler(0); + downsampler.addAll(reads); - Assert.assertTrue("Downsampler is empty but shouldn't be",downsampler.hasNext()); - List batchedReads = new ArrayList(downsampler.next()); + Assert.assertTrue("Downsampler isn't empty but should be",downsampler.isEmpty()); + List batchedReads = new ArrayList(downsampler.getDownsampledContents()); Assert.assertEquals("Downsampler is returning the wrong number of reads",0,batchedReads.size()); - Assert.assertFalse("Downsampler is not empty but should be",downsampler.hasNext()); } @Test @@ -118,73 +117,52 @@ public class ReservoirDownsamplerUnitTest { reads.add(ArtificialSAMUtils.createArtificialRead(header,"read3",0,1,76)); reads.add(ArtificialSAMUtils.createArtificialRead(header,"read4",0,1,76)); reads.add(ArtificialSAMUtils.createArtificialRead(header,"read5",0,1,76)); - ReservoirDownsampler downsampler = new ReservoirDownsampler(reads.iterator(), - new AlignmentStartComparator(),1); + ReservoirDownsampler downsampler = new ReservoirDownsampler(1); + downsampler.addAll(reads); - Assert.assertTrue("Downsampler is empty but shouldn't be",downsampler.hasNext()); - List batchedReads = new ArrayList(downsampler.next()); + Assert.assertFalse("Downsampler is empty but shouldn't be",downsampler.isEmpty()); + List batchedReads = new ArrayList(downsampler.getDownsampledContents()); Assert.assertEquals("Downsampler is returning the wrong 
number of reads",1,batchedReads.size()); Assert.assertTrue("Downsampler is returning a bad read.",reads.contains(batchedReads.get(0))) ; - Assert.assertFalse("Downsampler is not empty but should be",downsampler.hasNext()); } @Test public void testFillingAcrossLoci() { List reads = new ArrayList(); reads.add(ArtificialSAMUtils.createArtificialRead(header,"read1",0,1,76)); - reads.add(ArtificialSAMUtils.createArtificialRead(header,"read2",0,2,76)); - reads.add(ArtificialSAMUtils.createArtificialRead(header,"read3",0,2,76)); - reads.add(ArtificialSAMUtils.createArtificialRead(header,"read4",0,3,76)); - reads.add(ArtificialSAMUtils.createArtificialRead(header,"read5",0,3,76)); - ReservoirDownsampler downsampler = new ReservoirDownsampler(reads.iterator(), - new AlignmentStartComparator(),5); + ReservoirDownsampler downsampler = new ReservoirDownsampler(5); + downsampler.addAll(reads); - Assert.assertTrue("Downsampler is empty but shouldn't be",downsampler.hasNext()); - List batchedReads = new ArrayList(downsampler.next()); + Assert.assertFalse("Downsampler is empty but shouldn't be",downsampler.isEmpty()); + List batchedReads = new ArrayList(downsampler.getDownsampledContents()); Assert.assertEquals("Downsampler is returning the wrong number of reads",1,batchedReads.size()); Assert.assertEquals("Downsampler is returning an incorrect read.",reads.get(0),batchedReads.get(0)); - Assert.assertTrue("Downsampler is empty but shouldn't be",downsampler.hasNext()); - batchedReads = new ArrayList(downsampler.next()); - Assert.assertEquals("Downsampler is returning the wrong number of reads",2,batchedReads.size()); - Assert.assertEquals("Downsampler is returning an incorrect read.",reads.get(1),batchedReads.get(0)); - Assert.assertEquals("Downsampler is returning an incorrect read.",reads.get(2),batchedReads.get(1)); - - Assert.assertTrue("Downsampler is empty but shouldn't be",downsampler.hasNext()); - batchedReads = new ArrayList(downsampler.next()); - 
Assert.assertEquals("Downsampler is returning the wrong number of reads",2,batchedReads.size()); - Assert.assertEquals("Downsampler is returning an incorrect read.",reads.get(3),batchedReads.get(0)); - Assert.assertEquals("Downsampler is returning an incorrect read.",reads.get(4),batchedReads.get(1)); - - Assert.assertFalse("Downsampler is not empty but should be",downsampler.hasNext()); - } - - @Test - public void testDownsamplingAcrossLoci() { - List reads = new ArrayList(); - reads.add(ArtificialSAMUtils.createArtificialRead(header,"read1",0,1,76)); + reads.clear(); reads.add(ArtificialSAMUtils.createArtificialRead(header,"read2",0,2,76)); reads.add(ArtificialSAMUtils.createArtificialRead(header,"read3",0,2,76)); + + downsampler.clear(); + downsampler.addAll(reads); + + Assert.assertFalse("Downsampler is empty but shouldn't be",downsampler.isEmpty()); + batchedReads = new ArrayList(downsampler.getDownsampledContents()); + Assert.assertEquals("Downsampler is returning the wrong number of reads",2,batchedReads.size()); + Assert.assertEquals("Downsampler is returning an incorrect read.",reads.get(0),batchedReads.get(0)); + Assert.assertEquals("Downsampler is returning an incorrect read.",reads.get(1),batchedReads.get(1)); + + reads.clear(); reads.add(ArtificialSAMUtils.createArtificialRead(header,"read4",0,3,76)); reads.add(ArtificialSAMUtils.createArtificialRead(header,"read5",0,3,76)); - ReservoirDownsampler downsampler = new ReservoirDownsampler(reads.iterator(), - new AlignmentStartComparator(),1); - Assert.assertTrue("Downsampler is empty but shouldn't be",downsampler.hasNext()); - List batchedReads = new ArrayList(downsampler.next()); - Assert.assertEquals("Downsampler is returning the wrong number of reads",1,batchedReads.size()); + downsampler.clear(); + downsampler.addAll(reads); + + Assert.assertFalse("Downsampler is empty but shouldn't be",downsampler.isEmpty()); + batchedReads = new ArrayList(downsampler.getDownsampledContents()); + 
Assert.assertEquals("Downsampler is returning the wrong number of reads",2,batchedReads.size()); Assert.assertEquals("Downsampler is returning an incorrect read.",reads.get(0),batchedReads.get(0)); - - Assert.assertTrue("Downsampler is empty but shouldn't be",downsampler.hasNext()); - batchedReads = new ArrayList(downsampler.next()); - Assert.assertEquals("Downsampler is returning the wrong number of reads",1,batchedReads.size()); - Assert.assertTrue("Downsampler is returning an incorrect read.",batchedReads.get(0).equals(reads.get(1)) || batchedReads.get(0).equals(reads.get(2))); - - Assert.assertTrue("Downsampler is empty but shouldn't be",downsampler.hasNext()); - batchedReads = new ArrayList(downsampler.next()); - Assert.assertEquals("Downsampler is returning the wrong number of reads",1,batchedReads.size()); - Assert.assertTrue("Downsampler is returning an incorrect read.",batchedReads.get(0).equals(reads.get(3)) || batchedReads.get(0).equals(reads.get(4))); - - Assert.assertFalse("Downsampler is not empty but should be",downsampler.hasNext()); + Assert.assertEquals("Downsampler is returning an incorrect read.",reads.get(1),batchedReads.get(1)); } + }