diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/AdvancedRecalibrationEngine.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/AdvancedRecalibrationEngine.java index b89f68e24..be8357679 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/AdvancedRecalibrationEngine.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/AdvancedRecalibrationEngine.java @@ -28,26 +28,22 @@ package org.broadinstitute.sting.gatk.walkers.bqsr; import org.broadinstitute.sting.utils.recalibration.covariates.Covariate; import org.broadinstitute.sting.utils.BaseUtils; import org.broadinstitute.sting.utils.classloader.ProtectedPackageSource; -import org.broadinstitute.sting.utils.collections.NestedIntegerArray; import org.broadinstitute.sting.utils.pileup.PileupElement; import org.broadinstitute.sting.utils.recalibration.EventType; import org.broadinstitute.sting.utils.recalibration.ReadCovariates; -import org.broadinstitute.sting.utils.recalibration.RecalDatum; import org.broadinstitute.sting.utils.recalibration.RecalibrationTables; import org.broadinstitute.sting.utils.sam.GATKSAMRecord; +import org.broadinstitute.sting.utils.threading.ThreadLocalArray; public class AdvancedRecalibrationEngine extends StandardRecalibrationEngine implements ProtectedPackageSource { - // optimizations: don't reallocate an array each time - private byte[] tempQualArray; - private boolean[] tempErrorArray; - private double[] tempFractionalErrorArray; + // optimization: only allocate temp arrays once per thread + private final ThreadLocal threadLocalTempQualArray = new ThreadLocalArray(EventType.values().length, byte.class); + private final ThreadLocal threadLocalTempErrorArray = new ThreadLocalArray(EventType.values().length, boolean.class); + private final ThreadLocal threadLocalTempFractionalErrorArray = new ThreadLocalArray(EventType.values().length, double.class); public void initialize(final Covariate[] covariates, final RecalibrationTables recalibrationTables) { super.initialize(covariates, recalibrationTables); - tempQualArray = new byte[EventType.values().length]; - tempErrorArray = new boolean[EventType.values().length]; - tempFractionalErrorArray = new double[EventType.values().length]; } /** @@ -60,10 +56,13 @@ public class AdvancedRecalibrationEngine extends StandardRecalibrationEngine imp * @param refBase The reference base at this locus */ @Override - public synchronized void updateDataForPileupElement(final PileupElement pileupElement, final byte refBase) { + public void updateDataForPileupElement(final PileupElement pileupElement, final byte refBase) { final int offset = pileupElement.getOffset(); final ReadCovariates readCovariates = covariateKeySetFrom(pileupElement.getRead()); + byte[] tempQualArray = threadLocalTempQualArray.get(); + boolean[] tempErrorArray = threadLocalTempErrorArray.get(); + tempQualArray[EventType.BASE_SUBSTITUTION.index] = pileupElement.getQual(); tempErrorArray[EventType.BASE_SUBSTITUTION.index] = !BaseUtils.basesAreEqual(pileupElement.getBase(), refBase); tempQualArray[EventType.BASE_INSERTION.index] = pileupElement.getBaseInsertionQual(); @@ -77,40 +76,29 @@ public class AdvancedRecalibrationEngine extends StandardRecalibrationEngine imp final byte qual = tempQualArray[eventIndex]; final boolean isError = tempErrorArray[eventIndex]; - final NestedIntegerArray rgRecalTable = recalibrationTables.getReadGroupTable(); - final RecalDatum rgPreviousDatum = rgRecalTable.get(keys[0], eventIndex); - final RecalDatum rgThisDatum = createDatumObject(qual, isError); - if (rgPreviousDatum == null) // key doesn't exist yet in the map so make a new bucket and add it - rgRecalTable.put(rgThisDatum, keys[0], eventIndex); - else - rgPreviousDatum.combine(rgThisDatum); + // TODO: should this really be combine rather than increment? + combineDatumOrPutIfNecessary(recalibrationTables.getReadGroupTable(), qual, isError, keys[0], eventIndex); - final NestedIntegerArray qualRecalTable = recalibrationTables.getQualityScoreTable(); - final RecalDatum qualPreviousDatum = qualRecalTable.get(keys[0], keys[1], eventIndex); - if (qualPreviousDatum == null) - qualRecalTable.put(createDatumObject(qual, isError), keys[0], keys[1], eventIndex); - else - qualPreviousDatum.increment(isError); + incrementDatumOrPutIfNecessary(recalibrationTables.getQualityScoreTable(), qual, isError, keys[0], keys[1], eventIndex); for (int i = 2; i < covariates.length; i++) { if (keys[i] < 0) continue; - final NestedIntegerArray covRecalTable = recalibrationTables.getTable(i); - final RecalDatum covPreviousDatum = covRecalTable.get(keys[0], keys[1], keys[i], eventIndex); - if (covPreviousDatum == null) - covRecalTable.put(createDatumObject(qual, isError), keys[0], keys[1], keys[i], eventIndex); - else - covPreviousDatum.increment(isError); + + incrementDatumOrPutIfNecessary(recalibrationTables.getTable(i), qual, isError, keys[0], keys[1], keys[i], eventIndex); } } } @Override - public synchronized void updateDataForRead(final GATKSAMRecord read, final boolean[] skip, final double[] snpErrors, final double[] insertionErrors, final double[] deletionErrors ) { + public void updateDataForRead(final GATKSAMRecord read, final boolean[] skip, final double[] snpErrors, final double[] insertionErrors, final double[] deletionErrors ) { for( int offset = 0; offset < read.getReadBases().length; offset++ ) { if( !skip[offset] ) { final ReadCovariates readCovariates = covariateKeySetFrom(read); + byte[] tempQualArray = threadLocalTempQualArray.get(); + double[] tempFractionalErrorArray = threadLocalTempFractionalErrorArray.get(); + tempQualArray[EventType.BASE_SUBSTITUTION.index] = read.getBaseQualities()[offset]; tempFractionalErrorArray[EventType.BASE_SUBSTITUTION.index] = snpErrors[offset]; tempQualArray[EventType.BASE_INSERTION.index] = read.getBaseInsertionQualities()[offset]; @@ -124,30 +112,15 @@ public class AdvancedRecalibrationEngine extends StandardRecalibrationEngine imp final byte qual = tempQualArray[eventIndex]; final double isError = tempFractionalErrorArray[eventIndex]; - final NestedIntegerArray rgRecalTable = recalibrationTables.getReadGroupTable(); - final RecalDatum rgPreviousDatum = rgRecalTable.get(keys[0], eventIndex); - final RecalDatum rgThisDatum = createDatumObject(qual, isError); - if (rgPreviousDatum == null) // key doesn't exist yet in the map so make a new bucket and add it - rgRecalTable.put(rgThisDatum, keys[0], eventIndex); - else - rgPreviousDatum.combine(rgThisDatum); + combineDatumOrPutIfNecessary(recalibrationTables.getReadGroupTable(), qual, isError, keys[0], eventIndex); - final NestedIntegerArray qualRecalTable = recalibrationTables.getQualityScoreTable(); - final RecalDatum qualPreviousDatum = qualRecalTable.get(keys[0], keys[1], eventIndex); - if (qualPreviousDatum == null) - qualRecalTable.put(createDatumObject(qual, isError), keys[0], keys[1], eventIndex); - else - qualPreviousDatum.increment(1.0, isError); + incrementDatumOrPutIfNecessary(recalibrationTables.getQualityScoreTable(), qual, isError, keys[0], keys[1], eventIndex); for (int i = 2; i < covariates.length; i++) { if (keys[i] < 0) continue; - final NestedIntegerArray covRecalTable = recalibrationTables.getTable(i); - final RecalDatum covPreviousDatum = covRecalTable.get(keys[0], keys[1], keys[i], eventIndex); - if (covPreviousDatum == null) - covRecalTable.put(createDatumObject(qual, isError), keys[0], keys[1], keys[i], eventIndex); - else - covPreviousDatum.increment(1.0, isError); + + incrementDatumOrPutIfNecessary(recalibrationTables.getTable(i), qual, isError, keys[0], keys[1], keys[i], eventIndex); } } } diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/StandardRecalibrationEngine.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/StandardRecalibrationEngine.java index 4fe9c5323..1f7fbbf42 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/StandardRecalibrationEngine.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/StandardRecalibrationEngine.java @@ -55,7 +55,7 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP * @param refBase The reference base at this locus */ @Override - public synchronized void updateDataForPileupElement(final PileupElement pileupElement, final byte refBase) { + public void updateDataForPileupElement(final PileupElement pileupElement, final byte refBase) { final int offset = pileupElement.getOffset(); final ReadCovariates readCovariates = covariateKeySetFrom(pileupElement.getRead()); @@ -65,35 +65,21 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP final int[] keys = readCovariates.getKeySet(offset, EventType.BASE_SUBSTITUTION); final int eventIndex = EventType.BASE_SUBSTITUTION.index; - final NestedIntegerArray rgRecalTable = recalibrationTables.getReadGroupTable(); - final RecalDatum rgPreviousDatum = rgRecalTable.get(keys[0], eventIndex); - final RecalDatum rgThisDatum = createDatumObject(qual, isError); - if (rgPreviousDatum == null) // key doesn't exist yet in the map so make a new bucket and add it - rgRecalTable.put(rgThisDatum, keys[0], eventIndex); - else - rgPreviousDatum.combine(rgThisDatum); + // TODO: should this really be combine rather than increment? + combineDatumOrPutIfNecessary(recalibrationTables.getReadGroupTable(), qual, isError, keys[0], eventIndex); - final NestedIntegerArray qualRecalTable = recalibrationTables.getQualityScoreTable(); - final RecalDatum qualPreviousDatum = qualRecalTable.get(keys[0], keys[1], eventIndex); - if (qualPreviousDatum == null) - qualRecalTable.put(createDatumObject(qual, isError), keys[0], keys[1], eventIndex); - else - qualPreviousDatum.increment(isError); + incrementDatumOrPutIfNecessary(recalibrationTables.getQualityScoreTable(), qual, isError, keys[0], keys[1], eventIndex); for (int i = 2; i < covariates.length; i++) { if (keys[i] < 0) continue; - final NestedIntegerArray covRecalTable = recalibrationTables.getTable(i); - final RecalDatum covPreviousDatum = covRecalTable.get(keys[0], keys[1], keys[i], eventIndex); - if (covPreviousDatum == null) - covRecalTable.put(createDatumObject(qual, isError), keys[0], keys[1], keys[i], eventIndex); - else - covPreviousDatum.increment(isError); + + incrementDatumOrPutIfNecessary(recalibrationTables.getTable(i), qual, isError, keys[0], keys[1], keys[i], eventIndex); } } @Override - public synchronized void updateDataForRead( final GATKSAMRecord read, final boolean[] skip, final double[] snpErrors, final double[] insertionErrors, final double[] deletionErrors ) { + public void updateDataForRead( final GATKSAMRecord read, final boolean[] skip, final double[] snpErrors, final double[] insertionErrors, final double[] deletionErrors ) { throw new UnsupportedOperationException("Delocalized BQSR is not available in the GATK-lite version"); } @@ -121,4 +107,133 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP protected ReadCovariates covariateKeySetFrom(GATKSAMRecord read) { return (ReadCovariates) read.getTemporaryAttribute(BaseRecalibrator.COVARS_ATTRIBUTE); } + + /** + * Increments the RecalDatum at the specified position in the specified table, or put a new item there + * if there isn't already one. + * + * Does this in a thread-safe way WITHOUT being synchronized: relies on the behavior of NestedIntegerArray.put() + * to return false if another thread inserts a new item at our position in the middle of our put operation. + * + * @param table the table that holds/will hold our item + * @param qual qual for this event + * @param isError error status for this event + * @param keys location in table of our item + */ + protected void incrementDatumOrPutIfNecessary( final NestedIntegerArray table, final byte qual, final boolean isError, final int... keys ) { + final RecalDatum existingDatum = table.get(keys); + + // There are three cases here: + // + // 1. The original get succeeded (existingDatum != null) because there already was an item in this position. + // In this case we can just increment the existing item. + // + // 2. The original get failed (existingDatum == null), and we successfully put a new item in this position + // and are done. + // + // 3. The original get failed (existingDatum == null), AND the put fails because another thread comes along + // in the interim and puts an item in this position. In this case we need to do another get after the + // failed put to get the item the other thread put here and increment it. + if ( existingDatum == null ) { + // No existing item, try to put a new one + if ( ! table.put(createDatumObject(qual, isError), keys) ) { + // Failed to put a new item because another thread came along and put an item here first. + // Get the newly-put item and increment it (item is guaranteed to exist at this point) + table.get(keys).increment(isError); + } + } + else { + // Easy case: already an item here, so increment it + existingDatum.increment(isError); + } + } + + /** + * Increments the RecalDatum at the specified position in the specified table, or put a new item there + * if there isn't already one. + * + * Does this in a thread-safe way WITHOUT being synchronized: relies on the behavior of NestedIntegerArray.put() + * to return false if another thread inserts a new item at our position in the middle of our put operation. + * + * @param table the table that holds/will hold our item + * @param qual qual for this event + * @param isError error value for this event + * @param keys location in table of our item + */ + protected void incrementDatumOrPutIfNecessary( final NestedIntegerArray table, final byte qual, final double isError, final int... keys ) { + final RecalDatum existingDatum = table.get(keys); + + if ( existingDatum == null ) { + // No existing item, try to put a new one + if ( ! table.put(createDatumObject(qual, isError), keys) ) { + // Failed to put a new item because another thread came along and put an item here first. + // Get the newly-put item and increment it (item is guaranteed to exist at this point) + table.get(keys).increment(1.0, isError); + } + } + else { + // Easy case: already an item here, so increment it + existingDatum.increment(1.0, isError); + } + } + + /** + * Combines the RecalDatum at the specified position in the specified table with a new RecalDatum, or put a + * new item there if there isn't already one. + * + * Does this in a thread-safe way WITHOUT being synchronized: relies on the behavior of NestedIntegerArray.put() + * to return false if another thread inserts a new item at our position in the middle of our put operation. + * + * @param table the table that holds/will hold our item + * @param qual qual for this event + * @param isError error status for this event + * @param keys location in table of our item + */ + protected void combineDatumOrPutIfNecessary( final NestedIntegerArray table, final byte qual, final boolean isError, final int... keys ) { + final RecalDatum existingDatum = table.get(keys); + final RecalDatum newDatum = createDatumObject(qual, isError); + + if ( existingDatum == null ) { + // No existing item, try to put a new one + if ( ! table.put(newDatum, keys) ) { + // Failed to put a new item because another thread came along and put an item here first. + // Get the newly-put item and combine it with our item (item is guaranteed to exist at this point) + table.get(keys).combine(newDatum); + } + } + else { + // Easy case: already an item here, so combine it with our item + existingDatum.combine(newDatum); + } + } + + /** + * Combines the RecalDatum at the specified position in the specified table with a new RecalDatum, or put a + * new item there if there isn't already one. + * + * Does this in a thread-safe way WITHOUT being synchronized: relies on the behavior of NestedIntegerArray.put() + * to return false if another thread inserts a new item at our position in the middle of our put operation. + * + * @param table the table that holds/will hold our item + * @param qual qual for this event + * @param isError error value for this event + * @param keys location in table of our item + */ + protected void combineDatumOrPutIfNecessary( final NestedIntegerArray table, final byte qual, final double isError, final int... keys ) { + final RecalDatum existingDatum = table.get(keys); + final RecalDatum newDatum = createDatumObject(qual, isError); + + if ( existingDatum == null ) { + // No existing item, try to put a new one + if ( ! table.put(newDatum, keys) ) { + // Failed to put a new item because another thread came along and put an item here first. + // Get the newly-put item and combine it with our item (item is guaranteed to exist at this point) + table.get(keys).combine(newDatum); + } + } + else { + // Easy case: already an item here, so combine it with our item + existingDatum.combine(newDatum); + } + } } diff --git a/public/java/src/org/broadinstitute/sting/utils/collections/LoggingNestedIntegerArray.java b/public/java/src/org/broadinstitute/sting/utils/collections/LoggingNestedIntegerArray.java index 6fda0245b..63927ac84 100644 --- a/public/java/src/org/broadinstitute/sting/utils/collections/LoggingNestedIntegerArray.java +++ b/public/java/src/org/broadinstitute/sting/utils/collections/LoggingNestedIntegerArray.java @@ -99,9 +99,7 @@ public class LoggingNestedIntegerArray extends NestedIntegerArray { } @Override - public void put( final T value, final int... keys ) { - super.put(value, keys); - + public boolean put( final T value, final int... keys ) { StringBuilder logEntry = new StringBuilder(); logEntry.append(logEntryLabel); @@ -116,5 +114,7 @@ public class LoggingNestedIntegerArray extends NestedIntegerArray { // PrintStream methods all use synchronized blocks internally, so our logging is thread-safe log.println(logEntry.toString()); + + return super.put(value, keys); } } diff --git a/public/java/src/org/broadinstitute/sting/utils/collections/NestedIntegerArray.java b/public/java/src/org/broadinstitute/sting/utils/collections/NestedIntegerArray.java index 31d316555..9f1eb546d 100755 --- a/public/java/src/org/broadinstitute/sting/utils/collections/NestedIntegerArray.java +++ b/public/java/src/org/broadinstitute/sting/utils/collections/NestedIntegerArray.java @@ -50,6 +50,28 @@ public class NestedIntegerArray { this.dimensions = dimensions.clone(); data = new Object[dimensions[0]]; + prepopulateArray(data, 0); + } + + /** + * Recursively allocate the entire tree of arrays in all its dimensions. + * + * Doing this upfront uses more memory initially, but saves time over the course of the run + * and (crucially) avoids having to make threads wait while traversing the tree to check + * whether branches exist or not. + * + * @param subarray current node in the tree + * @param dimension current level in the tree + */ + private void prepopulateArray( Object[] subarray, int dimension ) { + if ( dimension >= numDimensions - 1 ) { + return; + } + + for ( int i = 0; i < subarray.length; i++ ) { + subarray[i] = new Object[dimensions[dimension + 1]]; + prepopulateArray((Object[])subarray[i], dimension + 1); + } } public T get(final int... keys) { @@ -59,14 +81,15 @@ public class NestedIntegerArray { for( int i = 0; i < numNestedDimensions; i++ ) { if ( keys[i] >= dimensions[i] ) return null; - myData = (Object[])myData[keys[i]]; - if ( myData == null ) - return null; + + myData = (Object[])myData[keys[i]]; // interior nodes in the tree will never be null, so we can safely traverse + // down to the leaves } + return (T)myData[keys[numNestedDimensions]]; } - public synchronized void put(final T value, final int... keys) { // WARNING! value comes before the keys! + public boolean put(final T value, final int... keys) { // WARNING! value comes before the keys! if ( keys.length != numDimensions ) throw new ReviewedStingException("Exactly " + numDimensions + " keys should be passed to this NestedIntegerArray but " + keys.length + " were provided"); @@ -75,15 +98,26 @@ public class NestedIntegerArray { for ( int i = 0; i < numNestedDimensions; i++ ) { if ( keys[i] >= dimensions[i] ) throw new ReviewedStingException("Key " + keys[i] + " is too large for dimension " + i + " (max is " + (dimensions[i]-1) + ")"); - Object[] temp = (Object[])myData[keys[i]]; - if ( temp == null ) { - temp = new Object[dimensions[i+1]]; - myData[keys[i]] = temp; - } - myData = temp; + + myData = (Object[])myData[keys[i]]; // interior nodes in the tree will never be null, so we can safely traverse + // down to the leaves } - myData[keys[numNestedDimensions]] = value; + synchronized ( myData ) { // lock the bottom row while we examine and (potentially) update it + + // Insert the new value only if there still isn't any existing value in this position + if ( myData[keys[numNestedDimensions]] == null ) { + myData[keys[numNestedDimensions]] = value; + } + else { + // Already have a value for this leaf (perhaps another thread came along and inserted one + // while we traversed the tree), so return false to notify the caller that we didn't put + // the item + return false; + } + } + + return true; } public List getAllValues() { diff --git a/public/java/src/org/broadinstitute/sting/utils/threading/ThreadLocalArray.java b/public/java/src/org/broadinstitute/sting/utils/threading/ThreadLocalArray.java new file mode 100644 index 000000000..cc50152ac --- /dev/null +++ b/public/java/src/org/broadinstitute/sting/utils/threading/ThreadLocalArray.java @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2012, The Broad Institute + * + * Permission is hereby granted, free of charge, to any person + * obtaining a copy of this software and associated documentation + * files (the "Software"), to deal in the Software without + * restriction, including without limitation the rights to use, + * copy, modify, merge, publish, distribute, sublicense, and/or sell + * copies of the Software, and to permit persons to whom the + * Software is furnished to do so, subject to the following + * conditions: + * + * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES + * OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT + * HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, + * WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING + * FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR + * OTHER DEALINGS IN THE SOFTWARE. + */ + +package org.broadinstitute.sting.utils.threading; + +import java.lang.reflect.Array; + +/** + * ThreadLocal implementation for arrays + * + * Example usage: + * + * private ThreadLocal threadLocalByteArray = new ThreadLocalArray(length, byte.class); + * .... + * byte[] byteArray = threadLocalByteArray.get(); + * + * @param the type of the array itself (eg., int[], double[], etc.) + * + * @author David Roazen + */ +public class ThreadLocalArray extends ThreadLocal { + private int arraySize; + private Class arrayElementType; + + /** + * Create a new ThreadLocalArray + * + * @param arraySize desired length of the array + * @param arrayElementType type of the elements within the array (eg., Byte.class, Integer.class, etc.) + */ + public ThreadLocalArray( int arraySize, Class arrayElementType ) { + super(); + + this.arraySize = arraySize; + this.arrayElementType = arrayElementType; + } + + @Override + @SuppressWarnings("unchecked") + protected T initialValue() { + return (T)Array.newInstance(arrayElementType, arraySize); + } +}