Refactor the ContextCovariate to significantly reduce runtime
This commit is contained in:
parent
1fafd9f6c8
commit
dc7636b923
|
|
@ -32,6 +32,8 @@ import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
|
|||
import org.broadinstitute.sting.utils.exceptions.UserException;
|
||||
import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
|
||||
|
||||
import java.util.ArrayList;
|
||||
|
||||
/**
|
||||
* Created by IntelliJ IDEA.
|
||||
* User: rpoplin
|
||||
|
|
@ -43,6 +45,16 @@ public class ContextCovariate implements StandardCovariate {
|
|||
private int mismatchesContextSize;
|
||||
private int indelsContextSize;
|
||||
|
||||
private int mismatchesKeyMask;
|
||||
private int indelsKeyMask;
|
||||
|
||||
private static final int LENGTH_BITS = 4;
|
||||
private static final int LENGTH_MASK = 15;
|
||||
|
||||
// temporary lists to use for creating context covariate keys
|
||||
private final ArrayList<Integer> mismatchKeys = new ArrayList<Integer>(200);
|
||||
private final ArrayList<Integer> indelKeys = new ArrayList<Integer>(200);
|
||||
|
||||
// the maximum context size (number of bases) permitted; we need to keep the leftmost base free so that values are
|
||||
// not negative and we reserve 4 more bits to represent the length of the context; it takes 2 bits to encode one base.
|
||||
static final private int MAX_DNA_CONTEXT = 13;
|
||||
|
|
@ -62,6 +74,9 @@ public class ContextCovariate implements StandardCovariate {
|
|||
|
||||
if (mismatchesContextSize <= 0 || indelsContextSize <= 0)
|
||||
throw new UserException(String.format("Context size must be positive, if you don't want to use the context covariate, just turn it off instead. Mismatches: %d Indels: %d", mismatchesContextSize, indelsContextSize));
|
||||
|
||||
mismatchesKeyMask = createMask(mismatchesContextSize);
|
||||
indelsKeyMask = createMask(indelsContextSize);
|
||||
}
|
||||
|
||||
@Override
|
||||
|
|
@ -75,10 +90,15 @@ public class ContextCovariate implements StandardCovariate {
|
|||
if (negativeStrand)
|
||||
bases = BaseUtils.simpleReverseComplement(bases);
|
||||
|
||||
final int readLength = clippedRead.getReadLength();
|
||||
mismatchKeys.clear();
|
||||
indelKeys.clear();
|
||||
contextWith(bases, mismatchesContextSize, mismatchKeys, mismatchesKeyMask);
|
||||
contextWith(bases, indelsContextSize, indelKeys, indelsKeyMask);
|
||||
|
||||
final int readLength = bases.length;
|
||||
for (int i = 0; i < readLength; i++) {
|
||||
final int indelKey = contextWith(bases, i, indelsContextSize);
|
||||
values.addCovariate(contextWith(bases, i, mismatchesContextSize), indelKey, indelKey, (negativeStrand ? readLength - i - 1 : i));
|
||||
final int indelKey = indelKeys.get(i);
|
||||
values.addCovariate(mismatchKeys.get(i), indelKey, indelKey, (negativeStrand ? readLength - i - 1 : i));
|
||||
}
|
||||
}
|
||||
|
||||
|
|
@ -101,17 +121,72 @@ public class ContextCovariate implements StandardCovariate {
|
|||
return keyFromContext((String) value);
|
||||
}
|
||||
|
||||
private static int createMask(final int contextSize) {
|
||||
int mask = 0;
|
||||
// create 2*contextSize worth of bits
|
||||
for (int i = 0; i < contextSize; i++)
|
||||
mask = (mask << 2) | 3;
|
||||
// shift 4 bits to mask out the bits used to encode the length
|
||||
return mask << LENGTH_BITS;
|
||||
}
|
||||
|
||||
/**
|
||||
* calculates the context of a base independent of the covariate mode (mismatch, insertion or deletion)
|
||||
*
|
||||
* @param bases the bases in the read to build the context from
|
||||
* @param offset the position in the read to calculate the context for
|
||||
* @param contextSize context size to use building the context
|
||||
* @return the key representing the context
|
||||
* @param keys list to store the keys
|
||||
* @param mask mask for pulling out just the context bits
|
||||
*/
|
||||
private int contextWith(final byte[] bases, final int offset, final int contextSize) {
|
||||
final int start = offset - contextSize + 1;
|
||||
return (start >= 0) ? keyFromContext(bases, start, offset + 1) : -1;
|
||||
private static void contextWith(final byte[] bases, final int contextSize, final ArrayList<Integer> keys, final int mask) {
|
||||
|
||||
// the first contextSize-1 bases will not have enough previous context
|
||||
for (int i = 1; i < contextSize && i <= bases.length; i++)
|
||||
keys.add(-1);
|
||||
|
||||
if (bases.length < contextSize)
|
||||
return;
|
||||
|
||||
final int newBaseOffset = 2 * (contextSize - 1) + LENGTH_BITS;
|
||||
|
||||
// get (and add) the key for the context starting at the first base
|
||||
int currentKey = keyFromContext(bases, 0, contextSize);
|
||||
keys.add(currentKey);
|
||||
|
||||
// if the first key was -1 then there was an N in the context; figure out how many more consecutive contexts it affects
|
||||
int currentNPenalty = 0;
|
||||
if (currentKey == -1) {
|
||||
currentKey = 0;
|
||||
currentNPenalty = contextSize - 1;
|
||||
int offset = newBaseOffset;
|
||||
while (bases[currentNPenalty] != 'N') {
|
||||
final int baseIndex = BaseUtils.simpleBaseToBaseIndex(bases[currentNPenalty]);
|
||||
currentKey |= (baseIndex << offset);
|
||||
offset -= 2;
|
||||
currentNPenalty--;
|
||||
}
|
||||
}
|
||||
|
||||
final int readLength = bases.length;
|
||||
for (int currentIndex = contextSize; currentIndex < readLength; currentIndex++) {
|
||||
final int baseIndex = BaseUtils.simpleBaseToBaseIndex(bases[currentIndex]);
|
||||
if (baseIndex == -1) { // ignore non-ACGT bases
|
||||
currentNPenalty = contextSize;
|
||||
currentKey = 0; // reset the key
|
||||
} else {
|
||||
// push this base's contribution onto the key: shift everything 2 bits, mask out the non-context bits, and add the new base and the length in
|
||||
currentKey = (currentKey >> 2) & mask;
|
||||
currentKey |= (baseIndex << newBaseOffset);
|
||||
currentKey |= contextSize;
|
||||
}
|
||||
|
||||
if (currentNPenalty == 0) {
|
||||
keys.add(currentKey);
|
||||
} else {
|
||||
currentNPenalty--;
|
||||
keys.add(-1);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
public static int keyFromContext(final String dna) {
|
||||
|
|
@ -126,9 +201,7 @@ public class ContextCovariate implements StandardCovariate {
|
|||
* @param end the end position in the array (exclusive)
|
||||
* @return the key representing the dna sequence
|
||||
*/
|
||||
public static int keyFromContext(final byte[] dna, final int start, final int end) {
|
||||
|
||||
// TODO -- bit fiddle to get this all working in a single call to the method (mask out length, shift, OR length back in)
|
||||
private static int keyFromContext(final byte[] dna, final int start, final int end) {
|
||||
|
||||
int key = end - start;
|
||||
int bitOffset = 4;
|
||||
|
|
@ -152,8 +225,8 @@ public class ContextCovariate implements StandardCovariate {
|
|||
if (key < 0)
|
||||
throw new ReviewedStingException("dna conversion cannot handle negative numbers. Possible overflow?");
|
||||
|
||||
final int length = key & 15; // the first 4 bits represent the length (in bp) of the context
|
||||
int mask = 48; // use the mask to pull out bases
|
||||
final int length = key & LENGTH_MASK; // the first bits represent the length (in bp) of the context
|
||||
int mask = 48; // use the mask to pull out bases
|
||||
int offset = 4;
|
||||
|
||||
StringBuilder dna = new StringBuilder();
|
||||
|
|
|
|||
Loading…
Reference in New Issue