Refactor the ContextCovariate to significantly reduce runtime

This commit is contained in:
Eric Banks 2012-06-28 02:29:35 -04:00
parent 1fafd9f6c8
commit dc7636b923
1 changed files with 86 additions and 13 deletions

View File

@ -32,6 +32,8 @@ import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
import org.broadinstitute.sting.utils.exceptions.UserException;
import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
import java.util.ArrayList;
/**
* Created by IntelliJ IDEA.
* User: rpoplin
@ -43,6 +45,16 @@ public class ContextCovariate implements StandardCovariate {
private int mismatchesContextSize;
private int indelsContextSize;
private int mismatchesKeyMask;
private int indelsKeyMask;
private static final int LENGTH_BITS = 4;
private static final int LENGTH_MASK = 15;
// temporary lists to use for creating context covariate keys
private final ArrayList<Integer> mismatchKeys = new ArrayList<Integer>(200);
private final ArrayList<Integer> indelKeys = new ArrayList<Integer>(200);
// the maximum context size (number of bases) permitted; we need to keep the leftmost base free so that values are
// not negative and we reserve 4 more bits to represent the length of the context; it takes 2 bits to encode one base.
static final private int MAX_DNA_CONTEXT = 13;
@ -62,6 +74,9 @@ public class ContextCovariate implements StandardCovariate {
if (mismatchesContextSize <= 0 || indelsContextSize <= 0)
throw new UserException(String.format("Context size must be positive, if you don't want to use the context covariate, just turn it off instead. Mismatches: %d Indels: %d", mismatchesContextSize, indelsContextSize));
mismatchesKeyMask = createMask(mismatchesContextSize);
indelsKeyMask = createMask(indelsContextSize);
}
@Override
@ -75,10 +90,15 @@ public class ContextCovariate implements StandardCovariate {
if (negativeStrand)
bases = BaseUtils.simpleReverseComplement(bases);
final int readLength = clippedRead.getReadLength();
mismatchKeys.clear();
indelKeys.clear();
contextWith(bases, mismatchesContextSize, mismatchKeys, mismatchesKeyMask);
contextWith(bases, indelsContextSize, indelKeys, indelsKeyMask);
final int readLength = bases.length;
for (int i = 0; i < readLength; i++) {
final int indelKey = contextWith(bases, i, indelsContextSize);
values.addCovariate(contextWith(bases, i, mismatchesContextSize), indelKey, indelKey, (negativeStrand ? readLength - i - 1 : i));
final int indelKey = indelKeys.get(i);
values.addCovariate(mismatchKeys.get(i), indelKey, indelKey, (negativeStrand ? readLength - i - 1 : i));
}
}
@ -101,17 +121,72 @@ public class ContextCovariate implements StandardCovariate {
return keyFromContext((String) value);
}
private static int createMask(final int contextSize) {
int mask = 0;
// create 2*contextSize worth of bits
for (int i = 0; i < contextSize; i++)
mask = (mask << 2) | 3;
// shift 4 bits to mask out the bits used to encode the length
return mask << LENGTH_BITS;
}
/**
* calculates the context of a base independent of the covariate mode (mismatch, insertion or deletion)
*
* @param bases the bases in the read to build the context from
* @param offset the position in the read to calculate the context for
* @param contextSize context size to use building the context
* @return the key representing the context
* @param keys list to store the keys
* @param mask mask for pulling out just the context bits
*/
private int contextWith(final byte[] bases, final int offset, final int contextSize) {
final int start = offset - contextSize + 1;
return (start >= 0) ? keyFromContext(bases, start, offset + 1) : -1;
private static void contextWith(final byte[] bases, final int contextSize, final ArrayList<Integer> keys, final int mask) {
// the first contextSize-1 bases will not have enough previous context
for (int i = 1; i < contextSize && i <= bases.length; i++)
keys.add(-1);
if (bases.length < contextSize)
return;
final int newBaseOffset = 2 * (contextSize - 1) + LENGTH_BITS;
// get (and add) the key for the context starting at the first base
int currentKey = keyFromContext(bases, 0, contextSize);
keys.add(currentKey);
// if the first key was -1 then there was an N in the context; figure out how many more consecutive contexts it affects
int currentNPenalty = 0;
if (currentKey == -1) {
currentKey = 0;
currentNPenalty = contextSize - 1;
int offset = newBaseOffset;
while (bases[currentNPenalty] != 'N') {
final int baseIndex = BaseUtils.simpleBaseToBaseIndex(bases[currentNPenalty]);
currentKey |= (baseIndex << offset);
offset -= 2;
currentNPenalty--;
}
}
final int readLength = bases.length;
for (int currentIndex = contextSize; currentIndex < readLength; currentIndex++) {
final int baseIndex = BaseUtils.simpleBaseToBaseIndex(bases[currentIndex]);
if (baseIndex == -1) { // ignore non-ACGT bases
currentNPenalty = contextSize;
currentKey = 0; // reset the key
} else {
// push this base's contribution onto the key: shift everything 2 bits, mask out the non-context bits, and add the new base and the length in
currentKey = (currentKey >> 2) & mask;
currentKey |= (baseIndex << newBaseOffset);
currentKey |= contextSize;
}
if (currentNPenalty == 0) {
keys.add(currentKey);
} else {
currentNPenalty--;
keys.add(-1);
}
}
}
public static int keyFromContext(final String dna) {
@ -126,9 +201,7 @@ public class ContextCovariate implements StandardCovariate {
* @param end the end position in the array (exclusive)
* @return the key representing the dna sequence
*/
public static int keyFromContext(final byte[] dna, final int start, final int end) {
// TODO -- bit fiddle to ge this all working in a single call to the method (mask out length, shift, OR length back in)
private static int keyFromContext(final byte[] dna, final int start, final int end) {
int key = end - start;
int bitOffset = 4;
@ -152,8 +225,8 @@ public class ContextCovariate implements StandardCovariate {
if (key < 0)
throw new ReviewedStingException("dna conversion cannot handle negative numbers. Possible overflow?");
final int length = key & 15; // the first 4 bits represent the length (in bp) of the context
int mask = 48; // use the mask to pull out bases
final int length = key & LENGTH_MASK; // the first bits represent the length (in bp) of the context
int mask = 48; // use the mask to pull out bases
int offset = 4;
StringBuilder dna = new StringBuilder();