diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/AdvancedRecalibrationEngine.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/AdvancedRecalibrationEngine.java index b89f68e24..be8357679 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/AdvancedRecalibrationEngine.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/AdvancedRecalibrationEngine.java @@ -28,26 +28,22 @@ package org.broadinstitute.sting.gatk.walkers.bqsr; import org.broadinstitute.sting.utils.recalibration.covariates.Covariate; import org.broadinstitute.sting.utils.BaseUtils; import org.broadinstitute.sting.utils.classloader.ProtectedPackageSource; -import org.broadinstitute.sting.utils.collections.NestedIntegerArray; import org.broadinstitute.sting.utils.pileup.PileupElement; import org.broadinstitute.sting.utils.recalibration.EventType; import org.broadinstitute.sting.utils.recalibration.ReadCovariates; -import org.broadinstitute.sting.utils.recalibration.RecalDatum; import org.broadinstitute.sting.utils.recalibration.RecalibrationTables; import org.broadinstitute.sting.utils.sam.GATKSAMRecord; +import org.broadinstitute.sting.utils.threading.ThreadLocalArray; public class AdvancedRecalibrationEngine extends StandardRecalibrationEngine implements ProtectedPackageSource { - // optimizations: don't reallocate an array each time - private byte[] tempQualArray; - private boolean[] tempErrorArray; - private double[] tempFractionalErrorArray; + // optimization: only allocate temp arrays once per thread + private final ThreadLocal threadLocalTempQualArray = new ThreadLocalArray(EventType.values().length, byte.class); + private final ThreadLocal threadLocalTempErrorArray = new ThreadLocalArray(EventType.values().length, boolean.class); + private final ThreadLocal threadLocalTempFractionalErrorArray = new ThreadLocalArray(EventType.values().length, double.class); public void initialize(final Covariate[] covariates, final RecalibrationTables recalibrationTables) { super.initialize(covariates, recalibrationTables); - tempQualArray = new byte[EventType.values().length]; - tempErrorArray = new boolean[EventType.values().length]; - tempFractionalErrorArray = new double[EventType.values().length]; } /** @@ -60,10 +56,13 @@ public class AdvancedRecalibrationEngine extends StandardRecalibrationEngine imp * @param refBase The reference base at this locus */ @Override - public synchronized void updateDataForPileupElement(final PileupElement pileupElement, final byte refBase) { + public void updateDataForPileupElement(final PileupElement pileupElement, final byte refBase) { final int offset = pileupElement.getOffset(); final ReadCovariates readCovariates = covariateKeySetFrom(pileupElement.getRead()); + byte[] tempQualArray = threadLocalTempQualArray.get(); + boolean[] tempErrorArray = threadLocalTempErrorArray.get(); + tempQualArray[EventType.BASE_SUBSTITUTION.index] = pileupElement.getQual(); tempErrorArray[EventType.BASE_SUBSTITUTION.index] = !BaseUtils.basesAreEqual(pileupElement.getBase(), refBase); tempQualArray[EventType.BASE_INSERTION.index] = pileupElement.getBaseInsertionQual(); @@ -77,40 +76,29 @@ public class AdvancedRecalibrationEngine extends StandardRecalibrationEngine imp final byte qual = tempQualArray[eventIndex]; final boolean isError = tempErrorArray[eventIndex]; - final NestedIntegerArray rgRecalTable = recalibrationTables.getReadGroupTable(); - final RecalDatum rgPreviousDatum = rgRecalTable.get(keys[0], eventIndex); - final RecalDatum rgThisDatum = createDatumObject(qual, isError); - if (rgPreviousDatum == null) // key doesn't exist yet in the map so make a new bucket and add it - rgRecalTable.put(rgThisDatum, keys[0], eventIndex); - else - rgPreviousDatum.combine(rgThisDatum); + // TODO: should this really be combine rather than increment? + combineDatumOrPutIfNecessary(recalibrationTables.getReadGroupTable(), qual, isError, keys[0], eventIndex); - final NestedIntegerArray qualRecalTable = recalibrationTables.getQualityScoreTable(); - final RecalDatum qualPreviousDatum = qualRecalTable.get(keys[0], keys[1], eventIndex); - if (qualPreviousDatum == null) - qualRecalTable.put(createDatumObject(qual, isError), keys[0], keys[1], eventIndex); - else - qualPreviousDatum.increment(isError); + incrementDatumOrPutIfNecessary(recalibrationTables.getQualityScoreTable(), qual, isError, keys[0], keys[1], eventIndex); for (int i = 2; i < covariates.length; i++) { if (keys[i] < 0) continue; - final NestedIntegerArray covRecalTable = recalibrationTables.getTable(i); - final RecalDatum covPreviousDatum = covRecalTable.get(keys[0], keys[1], keys[i], eventIndex); - if (covPreviousDatum == null) - covRecalTable.put(createDatumObject(qual, isError), keys[0], keys[1], keys[i], eventIndex); - else - covPreviousDatum.increment(isError); + + incrementDatumOrPutIfNecessary(recalibrationTables.getTable(i), qual, isError, keys[0], keys[1], keys[i], eventIndex); } } } @Override - public synchronized void updateDataForRead(final GATKSAMRecord read, final boolean[] skip, final double[] snpErrors, final double[] insertionErrors, final double[] deletionErrors ) { + public void updateDataForRead(final GATKSAMRecord read, final boolean[] skip, final double[] snpErrors, final double[] insertionErrors, final double[] deletionErrors ) { for( int offset = 0; offset < read.getReadBases().length; offset++ ) { if( !skip[offset] ) { final ReadCovariates readCovariates = covariateKeySetFrom(read); + byte[] tempQualArray = threadLocalTempQualArray.get(); + double[] tempFractionalErrorArray = threadLocalTempFractionalErrorArray.get(); + tempQualArray[EventType.BASE_SUBSTITUTION.index] = read.getBaseQualities()[offset]; tempFractionalErrorArray[EventType.BASE_SUBSTITUTION.index] = snpErrors[offset]; tempQualArray[EventType.BASE_INSERTION.index] = read.getBaseInsertionQualities()[offset]; @@ -124,30 +112,15 @@ public class AdvancedRecalibrationEngine extends StandardRecalibrationEngine imp final byte qual = tempQualArray[eventIndex]; final double isError = tempFractionalErrorArray[eventIndex]; - final NestedIntegerArray rgRecalTable = recalibrationTables.getReadGroupTable(); - final RecalDatum rgPreviousDatum = rgRecalTable.get(keys[0], eventIndex); - final RecalDatum rgThisDatum = createDatumObject(qual, isError); - if (rgPreviousDatum == null) // key doesn't exist yet in the map so make a new bucket and add it - rgRecalTable.put(rgThisDatum, keys[0], eventIndex); - else - rgPreviousDatum.combine(rgThisDatum); + combineDatumOrPutIfNecessary(recalibrationTables.getReadGroupTable(), qual, isError, keys[0], eventIndex); - final NestedIntegerArray qualRecalTable = recalibrationTables.getQualityScoreTable(); - final RecalDatum qualPreviousDatum = qualRecalTable.get(keys[0], keys[1], eventIndex); - if (qualPreviousDatum == null) - qualRecalTable.put(createDatumObject(qual, isError), keys[0], keys[1], eventIndex); - else - qualPreviousDatum.increment(1.0, isError); + incrementDatumOrPutIfNecessary(recalibrationTables.getQualityScoreTable(), qual, isError, keys[0], keys[1], eventIndex); for (int i = 2; i < covariates.length; i++) { if (keys[i] < 0) continue; - final NestedIntegerArray covRecalTable = recalibrationTables.getTable(i); - final RecalDatum covPreviousDatum = covRecalTable.get(keys[0], keys[1], keys[i], eventIndex); - if (covPreviousDatum == null) - covRecalTable.put(createDatumObject(qual, isError), keys[0], keys[1], keys[i], eventIndex); - else - covPreviousDatum.increment(1.0, isError); + + incrementDatumOrPutIfNecessary(recalibrationTables.getTable(i), qual, isError, keys[0], keys[1], keys[i], eventIndex); } } } diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/GenotypingEngine.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/GenotypingEngine.java index 8738def50..d91df82e2 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/GenotypingEngine.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/GenotypingEngine.java @@ -217,7 +217,7 @@ public class GenotypingEngine { } cleanUpSymbolicUnassembledEvents( haplotypes ); - if( activeAllelesToGenotype.isEmpty() && haplotypes.get(0).getSampleKeySet().size() >= 3 ) { // if not in GGA mode and have at least 3 samples try to create MNP and complex events by looking at LD structure + if( activeAllelesToGenotype.isEmpty() && haplotypes.get(0).getSampleKeySet().size() >= 10 ) { // if not in GGA mode and have at least 10 samples try to create MNP and complex events by looking at LD structure mergeConsecutiveEventsBasedOnLD( haplotypes, startPosKeySet, ref, refLoc ); } if( !activeAllelesToGenotype.isEmpty() ) { // we are in GGA mode! diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngine.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngine.java index e1a9ef7c8..a0924623b 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngine.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/haplotypecaller/LikelihoodCalculationEngine.java @@ -140,7 +140,7 @@ public class LikelihoodCalculationEngine { } private static int computeFirstDifferingPosition( final byte[] b1, final byte[] b2 ) { - for( int iii = 0; iii < b1.length && iii < b2.length; iii++ ){ + for( int iii = 0; iii < b1.length && iii < b2.length; iii++ ) { if( b1[iii] != b2[iii] ) { return iii; } diff --git a/protected/java/test/org/broadinstitute/sting/gatk/walkers/bqsr/BQSRIntegrationTest.java b/protected/java/test/org/broadinstitute/sting/gatk/walkers/bqsr/BQSRIntegrationTest.java index 26ee78484..90d67d393 100644 --- a/protected/java/test/org/broadinstitute/sting/gatk/walkers/bqsr/BQSRIntegrationTest.java +++ b/protected/java/test/org/broadinstitute/sting/gatk/walkers/bqsr/BQSRIntegrationTest.java @@ -73,11 +73,6 @@ public class BQSRIntegrationTest extends WalkerTest { params.getCommandLine(), Arrays.asList(params.md5)); executeTest("testBQSR-"+params.args, spec).getFirst(); - - WalkerTestSpec specNT2 = new WalkerTestSpec( - params.getCommandLine() + " -nt 2", - Arrays.asList(params.md5)); - executeTest("testBQSR-nt2-"+params.args, specNT2).getFirst(); } @Test diff --git a/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCallerIntegrationTest.java b/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCallerIntegrationTest.java index 1d4fad19e..71f03ea00 100644 --- a/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCallerIntegrationTest.java +++ b/protected/java/test/org/broadinstitute/sting/gatk/walkers/haplotypecaller/HaplotypeCallerIntegrationTest.java @@ -21,7 +21,7 @@ public class HaplotypeCallerIntegrationTest extends WalkerTest { @Test public void testHaplotypeCallerMultiSample() { - HCTest(CEUTRIO_BAM, "", "ee866a8694a6f6c77242041275350ab9"); + HCTest(CEUTRIO_BAM, "", "1d826f852ec1457efbfa7ee828c5783a"); } @Test @@ -42,7 +42,7 @@ public class HaplotypeCallerIntegrationTest extends WalkerTest { @Test public void testHaplotypeCallerMultiSampleComplex() { - HCTestComplexVariants(CEUTRIO_BAM, "", "30598abeeb0b0ae5816ffdbf0c4044fd"); + HCTestComplexVariants(privateTestDir + "AFR.complex.variants.bam", "", "966d338f423c86a390d685aa6336ec69"); } private void HCTestSymbolicVariants(String bam, String args, String md5) { @@ -94,6 +94,4 @@ public class HaplotypeCallerIntegrationTest extends WalkerTest { Arrays.asList("79af83432dc4a1768b3ebffffc4d2b8f")); executeTest("HC calling on a ReducedRead BAM", spec); } - - } diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BaseRecalibrator.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BaseRecalibrator.java index 2c7a12a36..245c8e3e8 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BaseRecalibrator.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BaseRecalibrator.java @@ -109,7 +109,7 @@ import java.util.ArrayList; @ReadFilters({MappingQualityZeroFilter.class, MappingQualityUnavailableFilter.class}) // only look at covered loci, not every loci of the reference file @Requires({DataSource.READS, DataSource.REFERENCE}) // filter out all reads with zero or unavailable mapping quality @PartitionBy(PartitionType.LOCUS) // this walker requires both -I input.bam and -R reference.fasta -public class BaseRecalibrator extends LocusWalker implements TreeReducible, NanoSchedulable { +public class BaseRecalibrator extends LocusWalker { @ArgumentCollection private final RecalibrationArgumentCollection RAC = new RecalibrationArgumentCollection(); // all the command line arguments for BQSR and it's covariates @@ -277,11 +277,6 @@ public class BaseRecalibrator extends LocusWalker implements TreeRed return sum; } - public Long treeReduce(Long sum1, Long sum2) { - sum1 += sum2; - return sum1; - } - @Override public void onTraversalDone(Long result) { logger.info("Calculating quantized quality scores..."); diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/StandardRecalibrationEngine.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/StandardRecalibrationEngine.java index 4fe9c5323..1f7fbbf42 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/StandardRecalibrationEngine.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/StandardRecalibrationEngine.java @@ -55,7 +55,7 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP * @param refBase The reference base at this locus */ @Override - public synchronized void updateDataForPileupElement(final PileupElement pileupElement, final byte refBase) { + public void updateDataForPileupElement(final PileupElement pileupElement, final byte refBase) { final int offset = pileupElement.getOffset(); final ReadCovariates readCovariates = covariateKeySetFrom(pileupElement.getRead()); @@ -65,35 +65,21 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP final int[] keys = readCovariates.getKeySet(offset, EventType.BASE_SUBSTITUTION); final int eventIndex = EventType.BASE_SUBSTITUTION.index; - final NestedIntegerArray rgRecalTable = recalibrationTables.getReadGroupTable(); - final RecalDatum rgPreviousDatum = rgRecalTable.get(keys[0], eventIndex); - final RecalDatum rgThisDatum = createDatumObject(qual, isError); - if (rgPreviousDatum == null) // key doesn't exist yet in the map so make a new bucket and add it - rgRecalTable.put(rgThisDatum, keys[0], eventIndex); - else - rgPreviousDatum.combine(rgThisDatum); + // TODO: should this really be combine rather than increment? + combineDatumOrPutIfNecessary(recalibrationTables.getReadGroupTable(), qual, isError, keys[0], eventIndex); - final NestedIntegerArray qualRecalTable = recalibrationTables.getQualityScoreTable(); - final RecalDatum qualPreviousDatum = qualRecalTable.get(keys[0], keys[1], eventIndex); - if (qualPreviousDatum == null) - qualRecalTable.put(createDatumObject(qual, isError), keys[0], keys[1], eventIndex); - else - qualPreviousDatum.increment(isError); + incrementDatumOrPutIfNecessary(recalibrationTables.getQualityScoreTable(), qual, isError, keys[0], keys[1], eventIndex); for (int i = 2; i < covariates.length; i++) { if (keys[i] < 0) continue; - final NestedIntegerArray covRecalTable = recalibrationTables.getTable(i); - final RecalDatum covPreviousDatum = covRecalTable.get(keys[0], keys[1], keys[i], eventIndex); - if (covPreviousDatum == null) - covRecalTable.put(createDatumObject(qual, isError), keys[0], keys[1], keys[i], eventIndex); - else - covPreviousDatum.increment(isError); + + incrementDatumOrPutIfNecessary(recalibrationTables.getTable(i), qual, isError, keys[0], keys[1], keys[i], eventIndex); } } @Override - public synchronized void updateDataForRead( final GATKSAMRecord read, final boolean[] skip, final double[] snpErrors, final double[] insertionErrors, final double[] deletionErrors ) { + public void updateDataForRead( final GATKSAMRecord read, final boolean[] skip, final double[] snpErrors, final double[] insertionErrors, final double[] deletionErrors ) { throw new UnsupportedOperationException("Delocalized BQSR is not available in the GATK-lite version"); } @@ -121,4 +107,133 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP protected ReadCovariates covariateKeySetFrom(GATKSAMRecord read) { return (ReadCovariates) read.getTemporaryAttribute(BaseRecalibrator.COVARS_ATTRIBUTE); } + + /** + * Increments the RecalDatum at the specified position in the specified table, or put a new item there + * if there isn't already one. + * + * Does this in a thread-safe way WITHOUT being synchronized: relies on the behavior of NestedIntegerArray.put() + * to return false if another thread inserts a new item at our position in the middle of our put operation. + * + * @param table the table that holds/will hold our item + * @param qual qual for this event + * @param isError error status for this event + * @param keys location in table of our item + */ + protected void incrementDatumOrPutIfNecessary( final NestedIntegerArray table, final byte qual, final boolean isError, final int... keys ) { + final RecalDatum existingDatum = table.get(keys); + + // There are three cases here: + // + // 1. The original get succeeded (existingDatum != null) because there already was an item in this position. + // In this case we can just increment the existing item. + // + // 2. The original get failed (existingDatum == null), and we successfully put a new item in this position + // and are done. + // + // 3. The original get failed (existingDatum == null), AND the put fails because another thread comes along + // in the interim and puts an item in this position. In this case we need to do another get after the + // failed put to get the item the other thread put here and increment it. + if ( existingDatum == null ) { + // No existing item, try to put a new one + if ( ! table.put(createDatumObject(qual, isError), keys) ) { + // Failed to put a new item because another thread came along and put an item here first. + // Get the newly-put item and increment it (item is guaranteed to exist at this point) + table.get(keys).increment(isError); + } + } + else { + // Easy case: already an item here, so increment it + existingDatum.increment(isError); + } + } + + /** + * Increments the RecalDatum at the specified position in the specified table, or put a new item there + * if there isn't already one. + * + * Does this in a thread-safe way WITHOUT being synchronized: relies on the behavior of NestedIntegerArray.put() + * to return false if another thread inserts a new item at our position in the middle of our put operation. + * + * @param table the table that holds/will hold our item + * @param qual qual for this event + * @param isError error value for this event + * @param keys location in table of our item + */ + protected void incrementDatumOrPutIfNecessary( final NestedIntegerArray table, final byte qual, final double isError, final int... keys ) { + final RecalDatum existingDatum = table.get(keys); + + if ( existingDatum == null ) { + // No existing item, try to put a new one + if ( ! table.put(createDatumObject(qual, isError), keys) ) { + // Failed to put a new item because another thread came along and put an item here first. + // Get the newly-put item and increment it (item is guaranteed to exist at this point) + table.get(keys).increment(1.0, isError); + } + } + else { + // Easy case: already an item here, so increment it + existingDatum.increment(1.0, isError); + } + } + + /** + * Combines the RecalDatum at the specified position in the specified table with a new RecalDatum, or put a + * new item there if there isn't already one. + * + * Does this in a thread-safe way WITHOUT being synchronized: relies on the behavior of NestedIntegerArray.put() + * to return false if another thread inserts a new item at our position in the middle of our put operation. + * + * @param table the table that holds/will hold our item + * @param qual qual for this event + * @param isError error status for this event + * @param keys location in table of our item + */ + protected void combineDatumOrPutIfNecessary( final NestedIntegerArray table, final byte qual, final boolean isError, final int... keys ) { + final RecalDatum existingDatum = table.get(keys); + final RecalDatum newDatum = createDatumObject(qual, isError); + + if ( existingDatum == null ) { + // No existing item, try to put a new one + if ( ! table.put(newDatum, keys) ) { + // Failed to put a new item because another thread came along and put an item here first. + // Get the newly-put item and combine it with our item (item is guaranteed to exist at this point) + table.get(keys).combine(newDatum); + } + } + else { + // Easy case: already an item here, so combine it with our item + existingDatum.combine(newDatum); + } + } + + /** + * Combines the RecalDatum at the specified position in the specified table with a new RecalDatum, or put a + * new item there if there isn't already one. + * + * Does this in a thread-safe way WITHOUT being synchronized: relies on the behavior of NestedIntegerArray.put() + * to return false if another thread inserts a new item at our position in the middle of our put operation. + * + * @param table the table that holds/will hold our item + * @param qual qual for this event + * @param isError error value for this event + * @param keys location in table of our item + */ + protected void combineDatumOrPutIfNecessary( final NestedIntegerArray table, final byte qual, final double isError, final int... keys ) { + final RecalDatum existingDatum = table.get(keys); + final RecalDatum newDatum = createDatumObject(qual, isError); + + if ( existingDatum == null ) { + // No existing item, try to put a new one + if ( ! table.put(newDatum, keys) ) { + // Failed to put a new item because another thread came along and put an item here first. + // Get the newly-put item and combine it with our item (item is guaranteed to exist at this point) + table.get(keys).combine(newDatum); + } + } + else { + // Easy case: already an item here, so combine it with our item + existingDatum.combine(newDatum); + } + } } diff --git a/public/java/src/org/broadinstitute/sting/utils/collections/LoggingNestedIntegerArray.java b/public/java/src/org/broadinstitute/sting/utils/collections/LoggingNestedIntegerArray.java index 6fda0245b..63927ac84 100644 --- a/public/java/src/org/broadinstitute/sting/utils/collections/LoggingNestedIntegerArray.java +++ b/public/java/src/org/broadinstitute/sting/utils/collections/LoggingNestedIntegerArray.java @@ -99,9 +99,7 @@ public class LoggingNestedIntegerArray extends NestedIntegerArray { } @Override - public void put( final T value, final int... keys ) { - super.put(value, keys); - + public boolean put( final T value, final int... keys ) { StringBuilder logEntry = new StringBuilder(); logEntry.append(logEntryLabel); @@ -116,5 +114,7 @@ public class LoggingNestedIntegerArray extends NestedIntegerArray { // PrintStream methods all use synchronized blocks internally, so our logging is thread-safe log.println(logEntry.toString()); + + return super.put(value, keys); } } diff --git a/public/java/src/org/broadinstitute/sting/utils/collections/NestedIntegerArray.java b/public/java/src/org/broadinstitute/sting/utils/collections/NestedIntegerArray.java index 31d316555..c83de8ec7 100755 --- a/public/java/src/org/broadinstitute/sting/utils/collections/NestedIntegerArray.java +++ b/public/java/src/org/broadinstitute/sting/utils/collections/NestedIntegerArray.java @@ -50,6 +50,28 @@ public class NestedIntegerArray { this.dimensions = dimensions.clone(); data = new Object[dimensions[0]]; + prepopulateArray(data, 0); + } + + /** + * Recursively allocate the entire tree of arrays in all its dimensions. + * + * Doing this upfront uses more memory initially, but saves time over the course of the run + * and (crucially) avoids having to make threads wait while traversing the tree to check + * whether branches exist or not. + * + * @param subarray current node in the tree + * @param dimension current level in the tree + */ + private void prepopulateArray( Object[] subarray, int dimension ) { + if ( dimension >= numDimensions - 1 ) { + return; + } + + for ( int i = 0; i < subarray.length; i++ ) { + subarray[i] = new Object[dimensions[dimension + 1]]; + prepopulateArray((Object[])subarray[i], dimension + 1); + } } public T get(final int... keys) { @@ -59,14 +81,29 @@ public class NestedIntegerArray { for( int i = 0; i < numNestedDimensions; i++ ) { if ( keys[i] >= dimensions[i] ) return null; - myData = (Object[])myData[keys[i]]; - if ( myData == null ) - return null; + + myData = (Object[])myData[keys[i]]; // interior nodes in the tree will never be null, so we can safely traverse + // down to the leaves } + return (T)myData[keys[numNestedDimensions]]; } - public synchronized void put(final T value, final int... keys) { // WARNING! value comes before the keys! + /** + * Insert a value at the position specified by the given keys. + * + * This method is THREAD-SAFE despite not being synchronized, however the caller MUST + * check the return value to see if the put succeeded. This method RETURNS FALSE if + * the value could not be inserted because there already was a value present + * at the specified location. In this case the caller should do a get() to get + * the already-existing value and (potentially) update it. + * + * @param value value to insert + * @param keys keys specifying the location of the value in the tree + * @return true if the value was inserted, false if it could not be inserted because there was already + * a value at the specified position + */ + public boolean put(final T value, final int... keys) { // WARNING! value comes before the keys! if ( keys.length != numDimensions ) throw new ReviewedStingException("Exactly " + numDimensions + " keys should be passed to this NestedIntegerArray but " + keys.length + " were provided"); @@ -75,15 +112,26 @@ public class NestedIntegerArray { for ( int i = 0; i < numNestedDimensions; i++ ) { if ( keys[i] >= dimensions[i] ) throw new ReviewedStingException("Key " + keys[i] + " is too large for dimension " + i + " (max is " + (dimensions[i]-1) + ")"); - Object[] temp = (Object[])myData[keys[i]]; - if ( temp == null ) { - temp = new Object[dimensions[i+1]]; - myData[keys[i]] = temp; - } - myData = temp; + + myData = (Object[])myData[keys[i]]; // interior nodes in the tree will never be null, so we can safely traverse + // down to the leaves } - myData[keys[numNestedDimensions]] = value; + synchronized ( myData ) { // lock the bottom row while we examine and (potentially) update it + + // Insert the new value only if there still isn't any existing value in this position + if ( myData[keys[numNestedDimensions]] == null ) { + myData[keys[numNestedDimensions]] = value; + } + else { + // Already have a value for this leaf (perhaps another thread came along and inserted one + // while we traversed the tree), so return false to notify the caller that we didn't put + // the item + return false; + } + } + + return true; } public List getAllValues() { diff --git a/public/java/src/org/broadinstitute/sting/utils/recalibration/covariates/ReadGroupCovariate.java b/public/java/src/org/broadinstitute/sting/utils/recalibration/covariates/ReadGroupCovariate.java index 85568dac9..29c15adf7 100755 --- a/public/java/src/org/broadinstitute/sting/utils/recalibration/covariates/ReadGroupCovariate.java +++ b/public/java/src/org/broadinstitute/sting/utils/recalibration/covariates/ReadGroupCovariate.java @@ -66,7 +66,9 @@ public class ReadGroupCovariate implements RequiredCovariate { } @Override - public String formatKey(final int key) { + public synchronized String formatKey(final int key) { + // This method is synchronized so that we don't attempt to do a get() + // from the reverse lookup table while that table is being updated return readGroupReverseLookupTable.get(key); } @@ -76,17 +78,32 @@ public class ReadGroupCovariate implements RequiredCovariate { } private int keyForReadGroup(final String readGroupId) { - if (!readGroupLookupTable.containsKey(readGroupId)) { - readGroupLookupTable.put(readGroupId, nextId); - readGroupReverseLookupTable.put(nextId, readGroupId); - nextId++; + // Rather than synchronize this entire method (which would be VERY expensive for walkers like the BQSR), + // synchronize only the table updates. + + // Before entering the synchronized block, check to see if this read group is not in our tables. + // If it's not, either we will have to insert it, OR another thread will insert it first. + // This preliminary check avoids doing any synchronization most of the time. + if ( ! readGroupLookupTable.containsKey(readGroupId) ) { + + synchronized ( this ) { + + // Now we need to make sure the key is STILL not there, since another thread may have come along + // and inserted it while we were waiting to enter this synchronized block! + if ( ! readGroupLookupTable.containsKey(readGroupId) ) { + readGroupLookupTable.put(readGroupId, nextId); + readGroupReverseLookupTable.put(nextId, readGroupId); + nextId++; + } + } } return readGroupLookupTable.get(readGroupId); } @Override - public int maximumKeyValue() { + public synchronized int maximumKeyValue() { + // Synchronized so that we don't query table size while the tables are being updated return readGroupLookupTable.size() - 1; } diff --git a/public/java/src/org/broadinstitute/sting/utils/threading/ThreadLocalArray.java b/public/java/src/org/broadinstitute/sting/utils/threading/ThreadLocalArray.java new file mode 100644 index 000000000..cc50152ac --- /dev/null +++ b/public/java/src/org/broadinstitute/sting/utils/threading/ThreadLocalArray.java @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2012, The Broad Institute + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, + * copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ + +package org.broadinstitute.sting.utils.threading; + +import java.lang.reflect.Array; + +/** + * ThreadLocal implementation for arrays + * + * Example usage: + * + * private ThreadLocal threadLocalByteArray = new ThreadLocalArray(length, byte.class); + * .... + * byte[] byteArray = threadLocalByteArray.get(); + * + * @param the type of the array itself (eg., int[], double[], etc.) + * + * @author David Roazen + */ +public class ThreadLocalArray extends ThreadLocal { + private int arraySize; + private Class arrayElementType; + + /** + * Create a new ThreadLocalArray + * + * @param arraySize desired length of the array + * @param arrayElementType type of the elements within the array (eg., Byte.class, Integer.class, etc.) + */ + public ThreadLocalArray( int arraySize, Class arrayElementType ) { + super(); + + this.arraySize = arraySize; + this.arrayElementType = arrayElementType; + } + + @Override + @SuppressWarnings("unchecked") + protected T initialValue() { + return (T)Array.newInstance(arrayElementType, arraySize); + } +}