Ensure thread-safety of CachingIndexedFastaSequenceFile

-- Cosmetic cleanup of ReadReferenceView
-- TraverseReadsNano provides the reference context, since it's thread-safe
-- Cleanup CachingIndexedFastaSequenceFile.  Add docs, remove unnecessary setters
-- Expand CachingIndexedFastaSequenceFileUnitTest to test explicitly multi-threaded safety.
This commit is contained in:
Mark DePristo 2012-08-27 12:11:38 -04:00
parent e5b1f1c7f4
commit 63a9ae817a
4 changed files with 170 additions and 106 deletions

View File

@ -59,16 +59,18 @@ public class ReadReferenceView extends ReferenceView {
}
/**
 * Returns the reference bases for {@code loc}, as produced by {@code getReferenceBases(loc)}.
 *
 * @return the bases returned by getReferenceBases for loc
 */
public byte[] getBases() {
    final byte[] bases = getReferenceBases(loc);
    return bases;
}
}
public ReferenceContext getReferenceContext( SAMRecord read ) {
/**
 * Builds the reference context covering the span of a mapped read.
 *
 * @param read the mapped read whose alignment span determines the context
 * @return a new ReferenceContext constructed with the read's GenomeLoc supplied for both
 *         location arguments, and a bases provider obtained from getReferenceBasesProvider
 */
public ReferenceContext getReferenceContext( final SAMRecord read ) {
    final GenomeLoc readSpan = genomeLocParser.createGenomeLoc(read);
    return new ReferenceContext( genomeLocParser, readSpan, readSpan, getReferenceBasesProvider(readSpan) );
}

View File

@ -84,7 +84,7 @@ public class TraverseReadsNano<M,T> extends TraversalEngine<M,T,ReadWalker<M,T>,
throw new ReviewedStingException("Parallel read walkers currently don't support access to reference ordered data");
final ReadView reads = new ReadView(dataProvider);
final ReadReferenceView reference = new NotImplementedReadReferenceView(dataProvider);
final ReadReferenceView reference = new ReadReferenceView(dataProvider);
final ReadBasedReferenceOrderedView rodView = new ReadBasedReferenceOrderedView(dataProvider);
nanoScheduler.setDebug(DEBUG);
@ -101,23 +101,7 @@ public class TraverseReadsNano<M,T> extends TraversalEngine<M,T,ReadWalker<M,T>,
@Override
public void printOnTraversalDone() {
nanoScheduler.shutdown();
super.printOnTraversalDone(); //To change body of overridden methods use File | Settings | File Templates.
}
private static class NotImplementedReadReferenceView extends ReadReferenceView {
private NotImplementedReadReferenceView(ShardDataProvider provider) {
super(provider);
}
@Override
protected byte[] getReferenceBases(SAMRecord read) {
throw new ReviewedStingException("Parallel read walkers don't support accessing reference yet");
}
@Override
protected byte[] getReferenceBases(GenomeLoc genomeLoc) {
throw new ReviewedStingException("Parallel read walkers don't support accessing reference yet");
}
super.printOnTraversalDone();
}
private class TraverseReadsReduce implements ReduceFunction<M, T> {

View File

@ -29,6 +29,7 @@ import net.sf.picard.reference.FastaSequenceIndex;
import net.sf.picard.reference.IndexedFastaSequenceFile;
import net.sf.picard.reference.ReferenceSequence;
import net.sf.samtools.SAMSequenceRecord;
import org.apache.log4j.Priority;
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
import java.io.File;
@ -38,14 +39,11 @@ import java.util.Arrays;
/**
* A caching version of the IndexedFastaSequenceFile that avoids going to disk as often as the raw indexer.
*
* Thread-safe! Uses a lock object to protect write and access to the cache.
* Thread-safe! Uses a thread-local cache
*/
public class CachingIndexedFastaSequenceFile extends IndexedFastaSequenceFile {
protected static final org.apache.log4j.Logger logger = org.apache.log4j.Logger.getLogger(CachingIndexedFastaSequenceFile.class);
/** global enable flag */
private static final boolean USE_CACHE = true;
/** do we want to print debugging information about cache efficiency? */
private static final boolean PRINT_EFFICIENCY = false;
@ -53,31 +51,29 @@ public class CachingIndexedFastaSequenceFile extends IndexedFastaSequenceFile {
private static final int PRINT_FREQUENCY = 10000;
/** The default cache size in bp */
private static final long DEFAULT_CACHE_SIZE = 1000000;
public static final long DEFAULT_CACHE_SIZE = 1000000;
/** The cache size of this CachingIndexedFastaSequenceFile */
final long cacheSize;
/** When we have a cache miss at position X, we load sequence from X - cacheMissBackup */
final long cacheMissBackup;
// information about checking efficiency
long cacheHits = 0;
long cacheMisses = 0;
/** The cache size of this CachingIndexedFastaSequenceFile */
long cacheSize = DEFAULT_CACHE_SIZE;
/** When we have a cache miss at position X, we load sequence from X - cacheMissBackup */
long cacheMissBackup = 100;
/**
 * One cached stretch of reference sequence: the bases held in seq together with the
 * start and stop coordinates they cover.  A freshly created Cache has start and stop
 * set to -1 and no sequence.
 */
private static class Cache {
    long start = -1;
    long stop = -1;
    ReferenceSequence seq = null;
}
/**
* Thread local cache to allow multi-threaded use of this class
*/
private ThreadLocal<Cache> cache;
{
resetThreadLocalCache();
}
protected void resetThreadLocalCache() {
cache = new ThreadLocal<Cache> () {
@Override protected Cache initialValue() {
return new Cache();
@ -87,76 +83,107 @@ public class CachingIndexedFastaSequenceFile extends IndexedFastaSequenceFile {
/**
* Same as general constructor but allows one to override the default cacheSize
* @param file
*
* @param fasta
* @param index
* @param cacheSize
*/
public CachingIndexedFastaSequenceFile(final File file, final FastaSequenceIndex index, long cacheSize) {
super(file, index);
setCacheSize(cacheSize);
}
private void setCacheSize(long cacheSize) {
public CachingIndexedFastaSequenceFile(final File fasta, final FastaSequenceIndex index, final long cacheSize) {
    super(fasta, index);
    // Only negative sizes are rejected; a cacheSize of zero passes this check.
    if ( cacheSize < 0 ) {
        throw new IllegalArgumentException("cacheSize must be >= 0 but was " + cacheSize);
    }
    this.cacheSize = cacheSize;
    // On a cache miss we reload starting cacheMissBackup bases before the requested
    // position: one thousandth of the cache size, but never less than one base.
    this.cacheMissBackup = Math.max(cacheSize / 1000, 1);
}
/**
* Open the given indexed fasta sequence file. Throw an exception if the file cannot be opened.
* @param file The file to open.
*
* @param fasta The file to open.
* @param index Pre-built FastaSequenceIndex, for the case in which one does not exist on disk.
* @throws java.io.FileNotFoundException If the fasta or any of its supporting files cannot be found.
*/
public CachingIndexedFastaSequenceFile(final File file, final FastaSequenceIndex index) {
this(file, index, DEFAULT_CACHE_SIZE);
public CachingIndexedFastaSequenceFile(final File fasta, final FastaSequenceIndex index) {
this(fasta, index, DEFAULT_CACHE_SIZE);
}
/**
* Open the given indexed fasta sequence file. Throw an exception if the file cannot be opened.
* @param file The file to open.
*
* Looks for an index file for the fasta on disk
*
* @param fasta The file to open.
*/
public CachingIndexedFastaSequenceFile(final File file) throws FileNotFoundException {
this(file, DEFAULT_CACHE_SIZE);
public CachingIndexedFastaSequenceFile(final File fasta) throws FileNotFoundException {
this(fasta, DEFAULT_CACHE_SIZE);
}
public CachingIndexedFastaSequenceFile(final File file, long cacheSize ) throws FileNotFoundException {
super(file);
setCacheSize(cacheSize);
/**
 * Open the given indexed fasta sequence file.  Throw an exception if the file cannot be opened.
 *
 * Looks for an index file for the fasta on disk.
 * Uses the provided cacheSize instead of the default.
 *
 * @param fasta The file to open.
 * @param cacheSize the size of the cache to use, in bp; must not be negative
 * @throws FileNotFoundException If the fasta or any of its supporting files cannot be found.
 * @throws IllegalArgumentException if cacheSize is negative
 */
public CachingIndexedFastaSequenceFile(final File fasta, final long cacheSize ) throws FileNotFoundException {
    super(fasta);
    // Only negative sizes are rejected; a cacheSize of zero passes this check.
    if ( cacheSize < 0 ) {
        throw new IllegalArgumentException("cacheSize must be >= 0 but was " + cacheSize);
    }
    this.cacheSize = cacheSize;
    // On a cache miss we reload starting cacheMissBackup bases before the requested
    // position: one thousandth of the cache size, but never less than one base.
    this.cacheMissBackup = Math.max(cacheSize / 1000, 1);
}
public void printEfficiency() {
// comment out to disable tracking
if ( (cacheHits + cacheMisses) % PRINT_FREQUENCY == 0 ) {
logger.info(String.format("### CachingIndexedFastaReader: hits=%d misses=%d efficiency %.6f%%%n", cacheHits, cacheMisses, calcEfficiency()));
}
/**
 * Logs the current hit count, miss count and efficiency percentage at the given priority.
 *
 * @param priority the log4j priority at which the message is written
 */
public void printEfficiency(final Priority priority) {
    final String report = String.format("### CachingIndexedFastaReader: hits=%d misses=%d efficiency %.6f%%", cacheHits, cacheMisses, calcEfficiency());
    logger.log(priority, report);
}
/**
 * Returns the efficiency of this object: the percentage of all queries that were cache hits.
 *
 * @return 100 * hits / (hits + misses), or 0.0 when no queries have been recorded yet
 */
public double calcEfficiency() {
    // Without this guard the division below would be 0.0 / 0.0, which evaluates to NaN.
    if ( cacheHits == 0 && cacheMisses == 0 ) {
        return 0.0;
    }
    return 100.0 * cacheHits / (cacheMisses + cacheHits * 1.0);
}
/**
 * Reports the hit counter.
 *
 * NOTE(review): cacheHits is a plain long field; confirm how it is updated when
 * several threads share this object before relying on an exact count.
 *
 * @return the number of cache hits that have occurred
 */
public long getCacheHits() {
    return this.cacheHits;
}
/**
 * Reports the miss counter.
 *
 * NOTE(review): cacheMisses is a plain long field bumped with ++ in getSubsequenceAt;
 * confirm whether the count can be trusted as exact under multi-threaded use.
 *
 * @return the number of cache misses that have occurred
 */
public long getCacheMisses() {
    return this.cacheMisses;
}
/**
 * Reports the cache size that was fixed when this object was constructed.
 *
 * @return the size of the cache we are using
 */
public long getCacheSize() {
    return this.cacheSize;
}
/**
* Gets the subsequence of the contig in the range [start,stop]
*
* Uses the sequence cache if possible, or updates the cache to handle the request. If the range
* is larger than the cache itself, just loads the sequence directly, not changing the cache at all
*
* @param contig Contig whose subsequence to retrieve.
* @param start inclusive, 1-based start of region.
* @param stop inclusive, 1-based stop of region.
* @return The partial reference sequence associated with this range.
*/
public ReferenceSequence getSubsequenceAt( String contig, long start, long stop ) {
ReferenceSequence result;
Cache myCache = cache.get();
//System.out.printf("getSubsequentAt cache=%s%n", myCache);
public ReferenceSequence getSubsequenceAt( final String contig, final long start, final long stop ) {
final ReferenceSequence result;
final Cache myCache = cache.get();
if ( ! USE_CACHE || (stop - start) >= cacheSize ) {
if ( (stop - start) >= cacheSize ) {
cacheMisses++;
result = super.getSubsequenceAt(contig, start, stop);
} else {
@ -177,8 +204,8 @@ public class CachingIndexedFastaSequenceFile extends IndexedFastaSequenceFile {
}
// at this point we determine where in the cache we want to extract the requested subsequence
int cacheOffsetStart = (int)(start - myCache.start);
int cacheOffsetStop = (int)(stop - start + cacheOffsetStart + 1);
final int cacheOffsetStart = (int)(start - myCache.start);
final int cacheOffsetStop = (int)(stop - start + cacheOffsetStart + 1);
try {
result = new ReferenceSequence(myCache.seq.getName(), myCache.seq.getContigIndex(), Arrays.copyOfRange(myCache.seq.getBases(), cacheOffsetStart, cacheOffsetStop));
@ -188,12 +215,8 @@ public class CachingIndexedFastaSequenceFile extends IndexedFastaSequenceFile {
}
}
// // comment out to disable testing
// ReferenceSequence verify = super.getSubsequenceAt(contig, start, stop);
// if ( ! Arrays.equals(verify.getBases(), result.getBases()) )
// throw new ReviewedStingException(String.format("BUG: cached reference sequence not the same as clean fetched version at %s %d %d", contig, start, stop));
if ( PRINT_EFFICIENCY ) printEfficiency();
if ( PRINT_EFFICIENCY && (getCacheHits() + getCacheMisses()) % PRINT_FREQUENCY == 0 )
printEfficiency(Priority.INFO);
return result;
}
}

View File

@ -5,21 +5,24 @@ package org.broadinstitute.sting.utils.fasta;
// the imports for unit testing.
import org.broadinstitute.sting.utils.exceptions.UserException;
import org.testng.Assert;
import org.testng.annotations.Test;
import org.testng.annotations.DataProvider;
import org.broadinstitute.sting.BaseTest;
import java.io.File;
import java.io.FileNotFoundException;
import java.util.Arrays;
import java.util.List;
import java.util.ArrayList;
import net.sf.picard.reference.IndexedFastaSequenceFile;
import net.sf.picard.reference.ReferenceSequence;
import net.sf.samtools.SAMSequenceRecord;
import org.apache.log4j.Priority;
import org.broadinstitute.sting.BaseTest;
import org.testng.Assert;
import org.testng.annotations.DataProvider;
import org.testng.annotations.Test;
import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
/**
* Basic unit test for CachingIndexedFastaSequenceFile
@ -30,7 +33,7 @@ public class CachingIndexedFastaSequenceFileUnitTest extends BaseTest {
//private static final List<Integer> QUERY_SIZES = Arrays.asList(1);
private static final List<Integer> QUERY_SIZES = Arrays.asList(1, 10, 100);
private static final List<Integer> CACHE_SIZES = Arrays.asList(-1, 1000);
private static final List<Integer> CACHE_SIZES = Arrays.asList(-1, 100, 1000);
@DataProvider(name = "fastas")
public Object[][] createData1() {
@ -46,20 +49,24 @@ public class CachingIndexedFastaSequenceFileUnitTest extends BaseTest {
return params.toArray(new Object[][]{});
}
@Test(dataProvider = "fastas", enabled = true)
public void testCachingIndexedFastaReaderSequential1(File fasta, int cacheSize, int querySize) {
IndexedFastaSequenceFile caching, uncached;
try {
caching = cacheSize == -1 ? new CachingIndexedFastaSequenceFile(fasta) : new CachingIndexedFastaSequenceFile(fasta, cacheSize);
uncached = new IndexedFastaSequenceFile(fasta);
}
catch(FileNotFoundException ex) {
throw new UserException.CouldNotReadInputFile(fasta,ex);
}
/**
 * Translates a requested cache size into the size actually handed to the reader.
 *
 * @param cacheSizeRequested the requested size, or -1 to ask for the default
 * @return CachingIndexedFastaSequenceFile.DEFAULT_CACHE_SIZE when -1 was requested,
 *         otherwise the requested size unchanged
 */
private static long getCacheSize(final long cacheSizeRequested) {
    if ( cacheSizeRequested == -1 ) {
        return CachingIndexedFastaSequenceFile.DEFAULT_CACHE_SIZE;
    }
    return cacheSizeRequested;
}
SAMSequenceRecord contig = uncached.getSequenceDictionary().getSequence(0);
/**
 * Opens fasta with the requested cache size (-1 selects the default) and runs the
 * sequential comparison in testSequential with the given query size.
 */
@Test(dataProvider = "fastas", enabled = true)
public void testCachingIndexedFastaReaderSequential1(File fasta, int cacheSize, int querySize) throws FileNotFoundException {
    final long effectiveCacheSize = getCacheSize(cacheSize);
    final CachingIndexedFastaSequenceFile caching = new CachingIndexedFastaSequenceFile(fasta, effectiveCacheSize);

    final SAMSequenceRecord firstContig = caching.getSequenceDictionary().getSequence(0);
    final String banner = String.format("Checking contig %s length %d with cache size %d and query size %d",
            firstContig.getSequenceName(), firstContig.getSequenceLength(), cacheSize, querySize);
    logger.warn(banner);

    testSequential(caching, fasta, querySize);
}
private void testSequential(final CachingIndexedFastaSequenceFile caching, final File fasta, final int querySize) throws FileNotFoundException {
final IndexedFastaSequenceFile uncached = new IndexedFastaSequenceFile(fasta);
SAMSequenceRecord contig = uncached.getSequenceDictionary().getSequence(0);
for ( int i = 0; i < contig.getSequenceLength(); i += STEP_SIZE ) {
int start = i;
int stop = start + querySize;
@ -72,19 +79,23 @@ public class CachingIndexedFastaSequenceFileUnitTest extends BaseTest {
Assert.assertEquals(cachedVal.getBases(), uncachedVal.getBases());
}
}
// Assert cache efficiency.  We make contig.length / STEP_SIZE queries, each spanning
// start -> start + querySize, against a cache of size X.  A query should hit the cache
// whenever its range falls within the cached region, which should happen at least
// (min(X, contig.length) - querySize * 2) / STEP_SIZE times.
final int minExpectedHits = (int)Math.floor((Math.min(caching.getCacheSize(), contig.getSequenceLength()) - querySize * 2.0) / STEP_SIZE);
caching.printEfficiency(Priority.WARN);
Assert.assertTrue(caching.getCacheHits() >= minExpectedHits, "Expected at least " + minExpectedHits + " cache hits but only got " + caching.getCacheHits());
}
// Tests grabbing sequences around a middle cached value.
@Test(dataProvider = "fastas", enabled = true)
public void testCachingIndexedFastaReaderTwoStage(File fasta, int cacheSize, int querySize) {
IndexedFastaSequenceFile caching, uncached;
try {
uncached = new IndexedFastaSequenceFile(fasta);
caching = new CachingIndexedFastaSequenceFile(fasta, cacheSize);
}
catch(FileNotFoundException ex) {
throw new UserException.CouldNotReadInputFile(fasta,ex);
}
public void testCachingIndexedFastaReaderTwoStage(File fasta, int cacheSize, int querySize) throws FileNotFoundException {
final IndexedFastaSequenceFile uncached = new IndexedFastaSequenceFile(fasta);
final CachingIndexedFastaSequenceFile caching = new CachingIndexedFastaSequenceFile(fasta, getCacheSize(cacheSize));
SAMSequenceRecord contig = uncached.getSequenceDictionary().getSequence(0);
@ -108,4 +119,48 @@ public class CachingIndexedFastaSequenceFileUnitTest extends BaseTest {
}
}
}
/**
 * Supplies the parameter rows for the parallel reader test: every combination of
 * fasta, cache size (from CACHE_SIZES), query size (from QUERY_SIZES) and a thread
 * count of 1 through 4.
 *
 * @return one row per combination, laid out as {fasta, cacheSize, querySize, nt}
 */
@DataProvider(name = "ParallelFastaTest")
public Object[][] createParallelFastaTest() {
    final List<File> fastas = Arrays.asList(simpleFasta);
    final List<Integer> threadCounts = Arrays.asList(1, 2, 3, 4);

    final List<Object[]> rows = new ArrayList<Object[]>();
    for ( final File fasta : fastas ) {
        for ( final int cacheSize : CACHE_SIZES ) {
            for ( final int querySize : QUERY_SIZES ) {
                for ( final int nt : threadCounts ) {
                    rows.add(new Object[]{fasta, cacheSize, querySize, nt});
                }
            }
        }
    }

    return rows.toArray(new Object[rows.size()][]);
}
/**
 * Runs testSequential concurrently from nt tasks that all share a single
 * CachingIndexedFastaSequenceFile, then surfaces any failure raised inside a task.
 *
 * @param fasta     the fasta file to read
 * @param cacheSize requested cache size, or -1 for the default
 * @param querySize query size passed through to testSequential
 * @param nt        number of pool threads, and of tasks submitted to them
 * @throws FileNotFoundException if the reader cannot be opened on fasta
 * @throws InterruptedException  if interrupted while waiting for the tasks
 * @throws ExecutionException    if any task threw, e.g. because an assertion failed
 */
@Test(dataProvider = "ParallelFastaTest", enabled = true, timeOut = 60000)
public void testCachingIndexedFastaReaderParallel(final File fasta, final int cacheSize, final int querySize, final int nt) throws FileNotFoundException, InterruptedException, ExecutionException {
    final CachingIndexedFastaSequenceFile caching = new CachingIndexedFastaSequenceFile(fasta, getCacheSize(cacheSize));

    logger.warn(String.format("Parallel caching index fasta reader test cacheSize %d querySize %d nt %d", caching.getCacheSize(), querySize, nt));
    for ( int iterations = 0; iterations < 1; iterations++ ) {
        final ExecutorService executor = Executors.newFixedThreadPool(nt);
        try {
            final Collection<Callable<Object>> tasks = new ArrayList<Callable<Object>>(nt);
            for ( int i = 0; i < nt; i++ ) {
                tasks.add(new Callable<Object>() {
                    @Override
                    public Object call() throws Exception {
                        testSequential(caching, fasta, querySize);
                        return null;
                    }
                });
            }

            // invokeAll stores anything a task throws inside its Future instead of
            // propagating it; calling get() rethrows it (wrapped in ExecutionException)
            // so a failure in a worker actually fails this test.
            for ( final Future<Object> outcome : executor.invokeAll(tasks) ) {
                outcome.get();
            }
        } finally {
            // Always release the pool threads, even when a task failed.
            executor.shutdownNow();
        }
    }
}
}