BQSR: use more granular locking for concurrency control

-With this change, BQSR performance scales properly by thread rather
 than gaining nothing from additional threads.
-Benefits are seen when using either -nt (HierarchicalMicroScheduler) or -nct
 (NanoScheduler)
-Removes high-level locks in the recalibration engines and NestedIntegerArray
 in favor of maximally-granular locks on and around manipulation of the leaf
 nodes of the NestedIntegerArray.
-NestedIntegerArray now creates all interior nodes upfront rather than on
 the fly to avoid the need for locking during tree traversals. This uses
 more memory in the initial part of BQSR runs, but the BQSR would eventually
 converge to use this memory anyway over the course of a typical run.

IMPORTANT NOTE: This does not mean it's safe to run the old BaseRecalibrator
walker with multiple threads. The BaseRecalibrator walker is and will never be
thread-safe, as it's a LocusWalker that uses read attributes to track
state information. ONLY the newer DelocalizedBaseRecalibrator can be made
thread-safe (and will hopefully be made so in my subsequent commits). This
commit addresses performance, not correctness.
This commit is contained in:
David Roazen 2012-10-22 13:36:07 -04:00
parent a27ee26481
commit 991658acf4
5 changed files with 270 additions and 84 deletions

View File

@ -28,26 +28,22 @@ package org.broadinstitute.sting.gatk.walkers.bqsr;
import org.broadinstitute.sting.utils.recalibration.covariates.Covariate;
import org.broadinstitute.sting.utils.BaseUtils;
import org.broadinstitute.sting.utils.classloader.ProtectedPackageSource;
import org.broadinstitute.sting.utils.collections.NestedIntegerArray;
import org.broadinstitute.sting.utils.pileup.PileupElement;
import org.broadinstitute.sting.utils.recalibration.EventType;
import org.broadinstitute.sting.utils.recalibration.ReadCovariates;
import org.broadinstitute.sting.utils.recalibration.RecalDatum;
import org.broadinstitute.sting.utils.recalibration.RecalibrationTables;
import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
import org.broadinstitute.sting.utils.threading.ThreadLocalArray;
public class AdvancedRecalibrationEngine extends StandardRecalibrationEngine implements ProtectedPackageSource {
// optimizations: don't reallocate an array each time
private byte[] tempQualArray;
private boolean[] tempErrorArray;
private double[] tempFractionalErrorArray;
// optimization: only allocate temp arrays once per thread
private final ThreadLocal<byte[]> threadLocalTempQualArray = new ThreadLocalArray<byte[]>(EventType.values().length, byte.class);
private final ThreadLocal<boolean[]> threadLocalTempErrorArray = new ThreadLocalArray<boolean[]>(EventType.values().length, boolean.class);
private final ThreadLocal<double[]> threadLocalTempFractionalErrorArray = new ThreadLocalArray<double[]>(EventType.values().length, double.class);
public void initialize(final Covariate[] covariates, final RecalibrationTables recalibrationTables) {
super.initialize(covariates, recalibrationTables);
tempQualArray = new byte[EventType.values().length];
tempErrorArray = new boolean[EventType.values().length];
tempFractionalErrorArray = new double[EventType.values().length];
}
/**
@ -60,10 +56,13 @@ public class AdvancedRecalibrationEngine extends StandardRecalibrationEngine imp
* @param refBase The reference base at this locus
*/
@Override
public synchronized void updateDataForPileupElement(final PileupElement pileupElement, final byte refBase) {
public void updateDataForPileupElement(final PileupElement pileupElement, final byte refBase) {
final int offset = pileupElement.getOffset();
final ReadCovariates readCovariates = covariateKeySetFrom(pileupElement.getRead());
byte[] tempQualArray = threadLocalTempQualArray.get();
boolean[] tempErrorArray = threadLocalTempErrorArray.get();
tempQualArray[EventType.BASE_SUBSTITUTION.index] = pileupElement.getQual();
tempErrorArray[EventType.BASE_SUBSTITUTION.index] = !BaseUtils.basesAreEqual(pileupElement.getBase(), refBase);
tempQualArray[EventType.BASE_INSERTION.index] = pileupElement.getBaseInsertionQual();
@ -77,40 +76,29 @@ public class AdvancedRecalibrationEngine extends StandardRecalibrationEngine imp
final byte qual = tempQualArray[eventIndex];
final boolean isError = tempErrorArray[eventIndex];
final NestedIntegerArray<RecalDatum> rgRecalTable = recalibrationTables.getReadGroupTable();
final RecalDatum rgPreviousDatum = rgRecalTable.get(keys[0], eventIndex);
final RecalDatum rgThisDatum = createDatumObject(qual, isError);
if (rgPreviousDatum == null) // key doesn't exist yet in the map so make a new bucket and add it
rgRecalTable.put(rgThisDatum, keys[0], eventIndex);
else
rgPreviousDatum.combine(rgThisDatum);
// TODO: should this really be combine rather than increment?
combineDatumOrPutIfNecessary(recalibrationTables.getReadGroupTable(), qual, isError, keys[0], eventIndex);
final NestedIntegerArray<RecalDatum> qualRecalTable = recalibrationTables.getQualityScoreTable();
final RecalDatum qualPreviousDatum = qualRecalTable.get(keys[0], keys[1], eventIndex);
if (qualPreviousDatum == null)
qualRecalTable.put(createDatumObject(qual, isError), keys[0], keys[1], eventIndex);
else
qualPreviousDatum.increment(isError);
incrementDatumOrPutIfNecessary(recalibrationTables.getQualityScoreTable(), qual, isError, keys[0], keys[1], eventIndex);
for (int i = 2; i < covariates.length; i++) {
if (keys[i] < 0)
continue;
final NestedIntegerArray<RecalDatum> covRecalTable = recalibrationTables.getTable(i);
final RecalDatum covPreviousDatum = covRecalTable.get(keys[0], keys[1], keys[i], eventIndex);
if (covPreviousDatum == null)
covRecalTable.put(createDatumObject(qual, isError), keys[0], keys[1], keys[i], eventIndex);
else
covPreviousDatum.increment(isError);
incrementDatumOrPutIfNecessary(recalibrationTables.getTable(i), qual, isError, keys[0], keys[1], keys[i], eventIndex);
}
}
}
@Override
public synchronized void updateDataForRead(final GATKSAMRecord read, final boolean[] skip, final double[] snpErrors, final double[] insertionErrors, final double[] deletionErrors ) {
public void updateDataForRead(final GATKSAMRecord read, final boolean[] skip, final double[] snpErrors, final double[] insertionErrors, final double[] deletionErrors ) {
for( int offset = 0; offset < read.getReadBases().length; offset++ ) {
if( !skip[offset] ) {
final ReadCovariates readCovariates = covariateKeySetFrom(read);
byte[] tempQualArray = threadLocalTempQualArray.get();
double[] tempFractionalErrorArray = threadLocalTempFractionalErrorArray.get();
tempQualArray[EventType.BASE_SUBSTITUTION.index] = read.getBaseQualities()[offset];
tempFractionalErrorArray[EventType.BASE_SUBSTITUTION.index] = snpErrors[offset];
tempQualArray[EventType.BASE_INSERTION.index] = read.getBaseInsertionQualities()[offset];
@ -124,30 +112,15 @@ public class AdvancedRecalibrationEngine extends StandardRecalibrationEngine imp
final byte qual = tempQualArray[eventIndex];
final double isError = tempFractionalErrorArray[eventIndex];
final NestedIntegerArray<RecalDatum> rgRecalTable = recalibrationTables.getReadGroupTable();
final RecalDatum rgPreviousDatum = rgRecalTable.get(keys[0], eventIndex);
final RecalDatum rgThisDatum = createDatumObject(qual, isError);
if (rgPreviousDatum == null) // key doesn't exist yet in the map so make a new bucket and add it
rgRecalTable.put(rgThisDatum, keys[0], eventIndex);
else
rgPreviousDatum.combine(rgThisDatum);
combineDatumOrPutIfNecessary(recalibrationTables.getReadGroupTable(), qual, isError, keys[0], eventIndex);
final NestedIntegerArray<RecalDatum> qualRecalTable = recalibrationTables.getQualityScoreTable();
final RecalDatum qualPreviousDatum = qualRecalTable.get(keys[0], keys[1], eventIndex);
if (qualPreviousDatum == null)
qualRecalTable.put(createDatumObject(qual, isError), keys[0], keys[1], eventIndex);
else
qualPreviousDatum.increment(1.0, isError);
incrementDatumOrPutIfNecessary(recalibrationTables.getQualityScoreTable(), qual, isError, keys[0], keys[1], eventIndex);
for (int i = 2; i < covariates.length; i++) {
if (keys[i] < 0)
continue;
final NestedIntegerArray<RecalDatum> covRecalTable = recalibrationTables.getTable(i);
final RecalDatum covPreviousDatum = covRecalTable.get(keys[0], keys[1], keys[i], eventIndex);
if (covPreviousDatum == null)
covRecalTable.put(createDatumObject(qual, isError), keys[0], keys[1], keys[i], eventIndex);
else
covPreviousDatum.increment(1.0, isError);
incrementDatumOrPutIfNecessary(recalibrationTables.getTable(i), qual, isError, keys[0], keys[1], keys[i], eventIndex);
}
}
}

View File

@ -55,7 +55,7 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP
* @param refBase The reference base at this locus
*/
@Override
public synchronized void updateDataForPileupElement(final PileupElement pileupElement, final byte refBase) {
public void updateDataForPileupElement(final PileupElement pileupElement, final byte refBase) {
final int offset = pileupElement.getOffset();
final ReadCovariates readCovariates = covariateKeySetFrom(pileupElement.getRead());
@ -65,35 +65,21 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP
final int[] keys = readCovariates.getKeySet(offset, EventType.BASE_SUBSTITUTION);
final int eventIndex = EventType.BASE_SUBSTITUTION.index;
final NestedIntegerArray<RecalDatum> rgRecalTable = recalibrationTables.getReadGroupTable();
final RecalDatum rgPreviousDatum = rgRecalTable.get(keys[0], eventIndex);
final RecalDatum rgThisDatum = createDatumObject(qual, isError);
if (rgPreviousDatum == null) // key doesn't exist yet in the map so make a new bucket and add it
rgRecalTable.put(rgThisDatum, keys[0], eventIndex);
else
rgPreviousDatum.combine(rgThisDatum);
// TODO: should this really be combine rather than increment?
combineDatumOrPutIfNecessary(recalibrationTables.getReadGroupTable(), qual, isError, keys[0], eventIndex);
final NestedIntegerArray<RecalDatum> qualRecalTable = recalibrationTables.getQualityScoreTable();
final RecalDatum qualPreviousDatum = qualRecalTable.get(keys[0], keys[1], eventIndex);
if (qualPreviousDatum == null)
qualRecalTable.put(createDatumObject(qual, isError), keys[0], keys[1], eventIndex);
else
qualPreviousDatum.increment(isError);
incrementDatumOrPutIfNecessary(recalibrationTables.getQualityScoreTable(), qual, isError, keys[0], keys[1], eventIndex);
for (int i = 2; i < covariates.length; i++) {
if (keys[i] < 0)
continue;
final NestedIntegerArray<RecalDatum> covRecalTable = recalibrationTables.getTable(i);
final RecalDatum covPreviousDatum = covRecalTable.get(keys[0], keys[1], keys[i], eventIndex);
if (covPreviousDatum == null)
covRecalTable.put(createDatumObject(qual, isError), keys[0], keys[1], keys[i], eventIndex);
else
covPreviousDatum.increment(isError);
incrementDatumOrPutIfNecessary(recalibrationTables.getTable(i), qual, isError, keys[0], keys[1], keys[i], eventIndex);
}
}
@Override
public synchronized void updateDataForRead( final GATKSAMRecord read, final boolean[] skip, final double[] snpErrors, final double[] insertionErrors, final double[] deletionErrors ) {
public void updateDataForRead( final GATKSAMRecord read, final boolean[] skip, final double[] snpErrors, final double[] insertionErrors, final double[] deletionErrors ) {
throw new UnsupportedOperationException("Delocalized BQSR is not available in the GATK-lite version");
}
@ -121,4 +107,133 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP
protected ReadCovariates covariateKeySetFrom(GATKSAMRecord read) {
return (ReadCovariates) read.getTemporaryAttribute(BaseRecalibrator.COVARS_ATTRIBUTE);
}
/**
* Increments the RecalDatum at the specified position in the specified table, or put a new item there
* if there isn't already one.
*
* Does this in a thread-safe way WITHOUT being synchronized: relies on the behavior of NestedIntegerArray.put()
* to return false if another thread inserts a new item at our position in the middle of our put operation.
*
* @param table the table that holds/will hold our item
* @param qual qual for this event
* @param isError error status for this event
* @param keys location in table of our item
*/
protected void incrementDatumOrPutIfNecessary( final NestedIntegerArray<RecalDatum> table, final byte qual, final boolean isError, final int... keys ) {
final RecalDatum existingDatum = table.get(keys);
// There are three cases here:
//
// 1. The original get succeeded (existingDatum != null) because there already was an item in this position.
// In this case we can just increment the existing item.
//
// 2. The original get failed (existingDatum == null), and we successfully put a new item in this position
// and are done.
//
// 3. The original get failed (existingDatum == null), AND the put fails because another thread comes along
// in the interim and puts an item in this position. In this case we need to do another get after the
// failed put to get the item the other thread put here and increment it.
if ( existingDatum == null ) {
// No existing item, try to put a new one
if ( ! table.put(createDatumObject(qual, isError), keys) ) {
// Failed to put a new item because another thread came along and put an item here first.
// Get the newly-put item and increment it (item is guaranteed to exist at this point)
table.get(keys).increment(isError);
}
}
else {
// Easy case: already an item here, so increment it
existingDatum.increment(isError);
}
}
/**
* Increments the RecalDatum at the specified position in the specified table, or put a new item there
* if there isn't already one.
*
* Does this in a thread-safe way WITHOUT being synchronized: relies on the behavior of NestedIntegerArray.put()
* to return false if another thread inserts a new item at our position in the middle of our put operation.
*
* @param table the table that holds/will hold our item
* @param qual qual for this event
* @param isError error value for this event
* @param keys location in table of our item
*/
protected void incrementDatumOrPutIfNecessary( final NestedIntegerArray<RecalDatum> table, final byte qual, final double isError, final int... keys ) {
final RecalDatum existingDatum = table.get(keys);
if ( existingDatum == null ) {
// No existing item, try to put a new one
if ( ! table.put(createDatumObject(qual, isError), keys) ) {
// Failed to put a new item because another thread came along and put an item here first.
// Get the newly-put item and increment it (item is guaranteed to exist at this point)
table.get(keys).increment(1.0, isError);
}
}
else {
// Easy case: already an item here, so increment it
existingDatum.increment(1.0, isError);
}
}
/**
* Combines the RecalDatum at the specified position in the specified table with a new RecalDatum, or put a
* new item there if there isn't already one.
*
* Does this in a thread-safe way WITHOUT being synchronized: relies on the behavior of NestedIntegerArray.put()
* to return false if another thread inserts a new item at our position in the middle of our put operation.
*
* @param table the table that holds/will hold our item
* @param qual qual for this event
* @param isError error status for this event
* @param keys location in table of our item
*/
protected void combineDatumOrPutIfNecessary( final NestedIntegerArray<RecalDatum> table, final byte qual, final boolean isError, final int... keys ) {
final RecalDatum existingDatum = table.get(keys);
final RecalDatum newDatum = createDatumObject(qual, isError);
if ( existingDatum == null ) {
// No existing item, try to put a new one
if ( ! table.put(newDatum, keys) ) {
// Failed to put a new item because another thread came along and put an item here first.
// Get the newly-put item and combine it with our item (item is guaranteed to exist at this point)
table.get(keys).combine(newDatum);
}
}
else {
// Easy case: already an item here, so combine it with our item
existingDatum.combine(newDatum);
}
}
/**
* Combines the RecalDatum at the specified position in the specified table with a new RecalDatum, or put a
* new item there if there isn't already one.
*
* Does this in a thread-safe way WITHOUT being synchronized: relies on the behavior of NestedIntegerArray.put()
* to return false if another thread inserts a new item at our position in the middle of our put operation.
*
* @param table the table that holds/will hold our item
* @param qual qual for this event
* @param isError error value for this event
* @param keys location in table of our item
*/
protected void combineDatumOrPutIfNecessary( final NestedIntegerArray<RecalDatum> table, final byte qual, final double isError, final int... keys ) {
final RecalDatum existingDatum = table.get(keys);
final RecalDatum newDatum = createDatumObject(qual, isError);
if ( existingDatum == null ) {
// No existing item, try to put a new one
if ( ! table.put(newDatum, keys) ) {
// Failed to put a new item because another thread came along and put an item here first.
// Get the newly-put item and combine it with our item (item is guaranteed to exist at this point)
table.get(keys).combine(newDatum);
}
}
else {
// Easy case: already an item here, so combine it with our item
existingDatum.combine(newDatum);
}
}
}

View File

@ -99,9 +99,7 @@ public class LoggingNestedIntegerArray<T> extends NestedIntegerArray<T> {
}
@Override
public void put( final T value, final int... keys ) {
super.put(value, keys);
public boolean put( final T value, final int... keys ) {
StringBuilder logEntry = new StringBuilder();
logEntry.append(logEntryLabel);
@ -116,5 +114,7 @@ public class LoggingNestedIntegerArray<T> extends NestedIntegerArray<T> {
// PrintStream methods all use synchronized blocks internally, so our logging is thread-safe
log.println(logEntry.toString());
return super.put(value, keys);
}
}

View File

@ -50,6 +50,28 @@ public class NestedIntegerArray<T> {
this.dimensions = dimensions.clone();
data = new Object[dimensions[0]];
prepopulateArray(data, 0);
}
/**
* Recursively allocate the entire tree of arrays in all its dimensions.
*
* Doing this upfront uses more memory initially, but saves time over the course of the run
* and (crucially) avoids having to make threads wait while traversing the tree to check
* whether branches exist or not.
*
* @param subarray current node in the tree
* @param dimension current level in the tree
*/
private void prepopulateArray( Object[] subarray, int dimension ) {
if ( dimension >= numDimensions - 1 ) {
return;
}
for ( int i = 0; i < subarray.length; i++ ) {
subarray[i] = new Object[dimensions[dimension + 1]];
prepopulateArray((Object[])subarray[i], dimension + 1);
}
}
public T get(final int... keys) {
@ -59,14 +81,15 @@ public class NestedIntegerArray<T> {
for( int i = 0; i < numNestedDimensions; i++ ) {
if ( keys[i] >= dimensions[i] )
return null;
myData = (Object[])myData[keys[i]];
if ( myData == null )
return null;
myData = (Object[])myData[keys[i]]; // interior nodes in the tree will never be null, so we can safely traverse
// down to the leaves
}
return (T)myData[keys[numNestedDimensions]];
}
public synchronized void put(final T value, final int... keys) { // WARNING! value comes before the keys!
public boolean put(final T value, final int... keys) { // WARNING! value comes before the keys!
if ( keys.length != numDimensions )
throw new ReviewedStingException("Exactly " + numDimensions + " keys should be passed to this NestedIntegerArray but " + keys.length + " were provided");
@ -75,15 +98,26 @@ public class NestedIntegerArray<T> {
for ( int i = 0; i < numNestedDimensions; i++ ) {
if ( keys[i] >= dimensions[i] )
throw new ReviewedStingException("Key " + keys[i] + " is too large for dimension " + i + " (max is " + (dimensions[i]-1) + ")");
Object[] temp = (Object[])myData[keys[i]];
if ( temp == null ) {
temp = new Object[dimensions[i+1]];
myData[keys[i]] = temp;
}
myData = temp;
myData = (Object[])myData[keys[i]]; // interior nodes in the tree will never be null, so we can safely traverse
// down to the leaves
}
myData[keys[numNestedDimensions]] = value;
synchronized ( myData ) { // lock the bottom row while we examine and (potentially) update it
// Insert the new value only if there still isn't any existing value in this position
if ( myData[keys[numNestedDimensions]] == null ) {
myData[keys[numNestedDimensions]] = value;
}
else {
// Already have a value for this leaf (perhaps another thread came along and inserted one
// while we traversed the tree), so return false to notify the caller that we didn't put
// the item
return false;
}
}
return true;
}
public List<T> getAllValues() {

View File

@ -0,0 +1,64 @@
/*
* Copyright (c) 2012, The Broad Institute
*
* Permission is hereby granted, free of charge, to any person
* obtaining a copy of this software and associated documentation
* files (the "Software"), to deal in the Software without
* restriction, including without limitation the rights to use,
* copy, modify, merge, publish, distribute, sublicense, and/or sell
* copies of the Software, and to permit persons to whom the
* Software is furnished to do so, subject to the following
* conditions:
*
* The above copyright notice and this permission notice shall be
* included in all copies or substantial portions of the Software.
* THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
* EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES
* OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
* NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT
* HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY,
* WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING
* FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR
* OTHER DEALINGS IN THE SOFTWARE.
*/
package org.broadinstitute.sting.utils.threading;
import java.lang.reflect.Array;
/**
* ThreadLocal implementation for arrays
*
* Example usage:
*
* private ThreadLocal<byte[]> threadLocalByteArray = new ThreadLocalArray<byte[]>(length, byte.class);
* ....
* byte[] byteArray = threadLocalByteArray.get();
*
* @param <T> the type of the array itself (eg., int[], double[], etc.)
*
* @author David Roazen
*/
public class ThreadLocalArray<T> extends ThreadLocal<T> {
private int arraySize;
private Class arrayElementType;
/**
* Create a new ThreadLocalArray
*
* @param arraySize desired length of the array
* @param arrayElementType type of the elements within the array (eg., Byte.class, Integer.class, etc.)
*/
public ThreadLocalArray( int arraySize, Class arrayElementType ) {
super();
this.arraySize = arraySize;
this.arrayElementType = arrayElementType;
}
@Override
@SuppressWarnings("unchecked")
protected T initialValue() {
return (T)Array.newInstance(arrayElementType, arraySize);
}
}