From dc7636b923c7e0fae628e62304c5261c7206cd26 Mon Sep 17 00:00:00 2001 From: Eric Banks Date: Thu, 28 Jun 2012 02:29:35 -0400 Subject: [PATCH] Refactor the ContextCovariate to significantly reduce runtime --- .../gatk/walkers/bqsr/ContextCovariate.java | 99 ++++++++++++++++--- 1 file changed, 86 insertions(+), 13 deletions(-) diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/ContextCovariate.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/ContextCovariate.java index 365c816c7..7da3c372e 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/ContextCovariate.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/ContextCovariate.java @@ -32,6 +32,8 @@ import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; import org.broadinstitute.sting.utils.exceptions.UserException; import org.broadinstitute.sting.utils.sam.GATKSAMRecord; +import java.util.ArrayList; + /** * Created by IntelliJ IDEA. * User: rpoplin @@ -43,6 +45,16 @@ public class ContextCovariate implements StandardCovariate { private int mismatchesContextSize; private int indelsContextSize; + private int mismatchesKeyMask; + private int indelsKeyMask; + + private static final int LENGTH_BITS = 4; + private static final int LENGTH_MASK = 15; + + // temporary lists to use for creating context covariate keys + private final ArrayList mismatchKeys = new ArrayList(200); + private final ArrayList indelKeys = new ArrayList(200); + // the maximum context size (number of bases) permitted; we need to keep the leftmost base free so that values are // not negative and we reserve 4 more bits to represent the length of the context; it takes 2 bits to encode one base. static final private int MAX_DNA_CONTEXT = 13; @@ -62,6 +74,9 @@ public class ContextCovariate implements StandardCovariate { if (mismatchesContextSize <= 0 || indelsContextSize <= 0) throw new UserException(String.format("Context size must be positive, if you don't want to use the context covariate, just turn it off instead. Mismatches: %d Indels: %d", mismatchesContextSize, indelsContextSize)); + + mismatchesKeyMask = createMask(mismatchesContextSize); + indelsKeyMask = createMask(indelsContextSize); } @Override @@ -75,10 +90,15 @@ public class ContextCovariate implements StandardCovariate { if (negativeStrand) bases = BaseUtils.simpleReverseComplement(bases); - final int readLength = clippedRead.getReadLength(); + mismatchKeys.clear(); + indelKeys.clear(); + contextWith(bases, mismatchesContextSize, mismatchKeys, mismatchesKeyMask); + contextWith(bases, indelsContextSize, indelKeys, indelsKeyMask); + + final int readLength = bases.length; for (int i = 0; i < readLength; i++) { - final int indelKey = contextWith(bases, i, indelsContextSize); - values.addCovariate(contextWith(bases, i, mismatchesContextSize), indelKey, indelKey, (negativeStrand ? readLength - i - 1 : i)); + final int indelKey = indelKeys.get(i); + values.addCovariate(mismatchKeys.get(i), indelKey, indelKey, (negativeStrand ? readLength - i - 1 : i)); } } @@ -101,17 +121,72 @@ public class ContextCovariate implements StandardCovariate { return keyFromContext((String) value); } + private static int createMask(final int contextSize) { + int mask = 0; + // create 2*contextSize worth of bits + for (int i = 0; i < contextSize; i++) + mask = (mask << 2) | 3; + // shift 4 bits to mask out the bits used to encode the length + return mask << LENGTH_BITS; + } + /** * calculates the context of a base independent of the covariate mode (mismatch, insertion or deletion) * * @param bases the bases in the read to build the context from - * @param offset the position in the read to calculate the context for * @param contextSize context size to use building the context - * @return the key representing the context + * @param keys list to store the keys + * @param mask mask for pulling out just the context bits */ - private int contextWith(final byte[] bases, final int offset, final int contextSize) { - final int start = offset - contextSize + 1; - return (start >= 0) ? keyFromContext(bases, start, offset + 1) : -1; + private static void contextWith(final byte[] bases, final int contextSize, final ArrayList keys, final int mask) { + + // the first contextSize-1 bases will not have enough previous context + for (int i = 1; i < contextSize && i <= bases.length; i++) + keys.add(-1); + + if (bases.length < contextSize) + return; + + final int newBaseOffset = 2 * (contextSize - 1) + LENGTH_BITS; + + // get (and add) the key for the context starting at the first base + int currentKey = keyFromContext(bases, 0, contextSize); + keys.add(currentKey); + + // if the first key was -1 then there was an N in the context; figure out how many more consecutive contexts it affects + int currentNPenalty = 0; + if (currentKey == -1) { + currentKey = 0; + currentNPenalty = contextSize - 1; + int offset = newBaseOffset; + while (bases[currentNPenalty] != 'N') { + final int baseIndex = BaseUtils.simpleBaseToBaseIndex(bases[currentNPenalty]); + currentKey |= (baseIndex << offset); + offset -= 2; + currentNPenalty--; + } + } + + final int readLength = bases.length; + for (int currentIndex = contextSize; currentIndex < readLength; currentIndex++) { + final int baseIndex = BaseUtils.simpleBaseToBaseIndex(bases[currentIndex]); + if (baseIndex == -1) { // ignore non-ACGT bases + currentNPenalty = contextSize; + currentKey = 0; // reset the key + } else { + // push this base's contribution onto the key: shift everything 2 bits, mask out the non-context bits, and add the new base and the length in + currentKey = (currentKey >> 2) & mask; + currentKey |= (baseIndex << newBaseOffset); + currentKey |= contextSize; + } + + if (currentNPenalty == 0) { + keys.add(currentKey); + } else { + currentNPenalty--; + keys.add(-1); + } + } } public static int keyFromContext(final String dna) { @@ -126,9 +201,7 @@ public class ContextCovariate implements StandardCovariate { * @param end the end position in the array (exclusive) * @return the key representing the dna sequence */ - public static int keyFromContext(final byte[] dna, final int start, final int end) { - - // TODO -- bit fiddle to ge this all working in a single call to the method (mask out length, shift, OR length back in) + private static int keyFromContext(final byte[] dna, final int start, final int end) { int key = end - start; int bitOffset = 4; @@ -152,8 +225,8 @@ public class ContextCovariate implements StandardCovariate { if (key < 0) throw new ReviewedStingException("dna conversion cannot handle negative numbers. Possible overflow?"); - final int length = key & 15; // the first 4 bits represent the length (in bp) of the context - int mask = 48; // use the mask to pull out bases + final int length = key & LENGTH_MASK; // the first bits represent the length (in bp) of the context + int mask = 48; // use the mask to pull out bases int offset = 4; StringBuilder dna = new StringBuilder();