From 20ffbcc86e98794ff40d97dff6a3d9b9859bbc15 Mon Sep 17 00:00:00 2001 From: Eric Banks Date: Wed, 17 Oct 2012 21:44:53 -0400 Subject: [PATCH] RR optimization: profiling was showing that the BaseCounts class was a major bottleneck because the underlying implementation was a HashMap. Given that the map index was an indexable Enum anyways, it makes a lot more sense to implement as a native array. Knocks 30% off the runtime in bad regions. --- .../reducereads/BaseAndQualsCounts.java | 41 ++--- .../compression/reducereads/BaseCounts.java | 149 ++++++++---------- .../compression/reducereads/BaseIndex.java | 6 +- .../reducereads/HeaderElement.java | 4 +- 4 files changed, 89 insertions(+), 111 deletions(-) diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/BaseAndQualsCounts.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/BaseAndQualsCounts.java index d5afc5722..654e0af09 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/BaseAndQualsCounts.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/BaseAndQualsCounts.java @@ -1,8 +1,5 @@ package org.broadinstitute.sting.gatk.walkers.compression.reducereads; -import java.util.HashMap; -import java.util.Map; - /** * An object that keeps track of the base counts as well as the sum of the base, insertion and deletion qualities of each base. * @@ -10,35 +7,31 @@ import java.util.Map; * @since 6/15/12 */ public class BaseAndQualsCounts extends BaseCounts { - private final Map sumInsertionQuals; - private final Map sumDeletionQuals; + private final long[] sumInsertionQuals; + private final long[] sumDeletionQuals; public BaseAndQualsCounts() { super(); - this.sumInsertionQuals = new HashMap(); - this.sumDeletionQuals = new HashMap(); - for (BaseIndex i : BaseIndex.values()) { - sumInsertionQuals.put(i, 0L); - sumDeletionQuals.put(i, 0L); + this.sumInsertionQuals = new long[BaseIndex.values().length]; + this.sumDeletionQuals = new long[BaseIndex.values().length]; + for (final BaseIndex i : BaseIndex.values()) { + sumInsertionQuals[i.index] = 0L; + sumDeletionQuals[i.index] = 0L; } } public void incr(final byte base, final byte baseQual, final byte insQual, final byte delQual) { - super.incr(base, baseQual); - BaseIndex i = BaseIndex.byteToBase(base); - if (i != null) { // do not allow Ns - sumInsertionQuals.put(i, sumInsertionQuals.get(i) + insQual); - sumDeletionQuals.put(i, sumDeletionQuals.get(i) + delQual); - } + final BaseIndex i = BaseIndex.byteToBase(base); + super.incr(i, baseQual); + sumInsertionQuals[i.index] += insQual; + sumDeletionQuals[i.index] += delQual; } public void decr(final byte base, final byte baseQual, final byte insQual, final byte delQual) { - super.decr(base, baseQual); - BaseIndex i = BaseIndex.byteToBase(base); - if (i != null) { // do not allow Ns - sumInsertionQuals.put(i, sumInsertionQuals.get(i) - insQual); - sumDeletionQuals.put(i, sumDeletionQuals.get(i) - delQual); - } + final BaseIndex i = BaseIndex.byteToBase(base); + super.decr(i, baseQual); + sumInsertionQuals[i.index] -= insQual; + sumDeletionQuals[i.index] -= delQual; } public byte averageInsertionQualsOfBase(final BaseIndex base) { @@ -49,7 +42,7 @@ public class BaseAndQualsCounts extends BaseCounts { return getGenericAverageQualOfBase(base, sumDeletionQuals); } - private byte getGenericAverageQualOfBase(final BaseIndex base, final Map sumQuals) { - return (byte) (sumQuals.get(base) / getCount(base)); + private byte getGenericAverageQualOfBase(final BaseIndex base, final long[] sumQuals) { + return (byte) (sumQuals[base.index] / countOfBase(base)); } } diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/BaseCounts.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/BaseCounts.java index fb76ef291..3a3905710 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/BaseCounts.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/BaseCounts.java @@ -3,8 +3,6 @@ package org.broadinstitute.sting.gatk.walkers.compression.reducereads; import com.google.java.contract.Ensures; import com.google.java.contract.Requires; -import java.util.EnumMap; -import java.util.Map; /** * An object to keep track of the number of occurrences of each base and it's quality. @@ -18,25 +16,25 @@ import java.util.Map; public final static BaseIndex MAX_BASE_INDEX_WITH_NO_COUNTS = BaseIndex.N; public final static byte MAX_BASE_WITH_NO_COUNTS = MAX_BASE_INDEX_WITH_NO_COUNTS.getByte(); - private final Map counts; // keeps track of the base counts - private final Map sumQuals; // keeps track of the quals of each base - private int totalCount = 0; // keeps track of total count since this is requested so often + private final int[] counts; // keeps track of the base counts + private final long[] sumQuals; // keeps track of the quals of each base + private int totalCount = 0; // keeps track of total count since this is requested so often public BaseCounts() { - counts = new EnumMap(BaseIndex.class); - sumQuals = new EnumMap(BaseIndex.class); - for (BaseIndex i : BaseIndex.values()) { - counts.put(i, 0); - sumQuals.put(i, 0L); + counts = new int[BaseIndex.values().length]; + sumQuals = new long[BaseIndex.values().length]; + for (final BaseIndex i : BaseIndex.values()) { + counts[i.index] = 0; + sumQuals[i.index] = 0L; } } public static BaseCounts createWithCounts(int[] countsACGT) { BaseCounts baseCounts = new BaseCounts(); - baseCounts.counts.put(BaseIndex.A, countsACGT[0]); - baseCounts.counts.put(BaseIndex.C, countsACGT[1]); - baseCounts.counts.put(BaseIndex.G, countsACGT[2]); - baseCounts.counts.put(BaseIndex.T, countsACGT[3]); + baseCounts.counts[BaseIndex.A.index] = countsACGT[0]; + baseCounts.counts[BaseIndex.C.index] = countsACGT[1]; + baseCounts.counts[BaseIndex.G.index] = countsACGT[2]; + baseCounts.counts[BaseIndex.T.index] = countsACGT[3]; baseCounts.totalCount = countsACGT[0] + countsACGT[1] + countsACGT[2] + countsACGT[3]; return baseCounts; } @@ -44,8 +42,8 @@ import java.util.Map; @Requires("other != null") public void add(final BaseCounts other) { for (final BaseIndex i : BaseIndex.values()) { - final int otherCount = other.counts.get(i); - counts.put(i, counts.get(i) + otherCount); + final int otherCount = other.counts[i.index]; + counts[i.index] += otherCount; totalCount += otherCount; } } @@ -53,8 +51,8 @@ import java.util.Map; @Requires("other != null") public void sub(final BaseCounts other) { for (final BaseIndex i : BaseIndex.values()) { - final int otherCount = other.counts.get(i); - counts.put(i, counts.get(i) - otherCount); + final int otherCount = other.counts[i.index]; + counts[i.index] -= otherCount; totalCount -= otherCount; } } @@ -62,49 +60,29 @@ import java.util.Map; @Ensures("totalCount() == old(totalCount()) || totalCount() == old(totalCount()) + 1") public void incr(final byte base) { final BaseIndex i = BaseIndex.byteToBase(base); - if (i != null) { // no Ns - counts.put(i, counts.get(i) + 1); - totalCount++; - } + counts[i.index]++; + totalCount++; } @Ensures("totalCount() == old(totalCount()) || totalCount() == old(totalCount()) + 1") - public void incr(final byte base, final byte qual) { - final BaseIndex i = BaseIndex.byteToBase(base); - if (i != null) { // no Ns - counts.put(i, counts.get(i) + 1); - totalCount++; - sumQuals.put(i, sumQuals.get(i) + qual); - } + public void incr(final BaseIndex base, final byte qual) { + counts[base.index]++; + totalCount++; + sumQuals[base.index] += qual; } @Ensures("totalCount() == old(totalCount()) || totalCount() == old(totalCount()) - 1") public void decr(final byte base) { final BaseIndex i = BaseIndex.byteToBase(base); - if (i != null) { // no Ns - counts.put(i, counts.get(i) - 1); - totalCount--; - } + counts[i.index]--; + totalCount--; } @Ensures("totalCount() == old(totalCount()) || totalCount() == old(totalCount()) - 1") - public void decr(final byte base, final byte qual) { - final BaseIndex i = BaseIndex.byteToBase(base); - if (i != null) { // no Ns - counts.put(i, counts.get(i) - 1); - totalCount--; - sumQuals.put(i, sumQuals.get(i) - qual); - } - } - - @Ensures("result >= 0") - public int getCount(final byte base) { - return getCount(BaseIndex.byteToBase(base)); - } - - @Ensures("result >= 0") - public int getCount(final BaseIndex base) { - return counts.get(base); + public void decr(final BaseIndex base, final byte qual) { + counts[base.index]--; + totalCount--; + sumQuals[base.index] -= qual; } @Ensures("result >= 0") @@ -114,27 +92,32 @@ import java.util.Map; @Ensures("result >= 0") public long getSumQuals(final BaseIndex base) { - return sumQuals.get(base); + return sumQuals[base.index]; } @Ensures("result >= 0") public byte averageQuals(final byte base) { - return (byte) (getSumQuals(base) / getCount(base)); + return (byte) (getSumQuals(base) / countOfBase(base)); } @Ensures("result >= 0") public byte averageQuals(final BaseIndex base) { - return (byte) (getSumQuals(base) / getCount(base)); + return (byte) (getSumQuals(base) / countOfBase(base)); + } + + @Ensures("result >= 0") + public int countOfBase(final byte base) { + return countOfBase(BaseIndex.byteToBase(base)); } @Ensures("result >= 0") public int countOfBase(final BaseIndex base) { - return counts.get(base); + return counts[base.index]; } @Ensures("result >= 0") public long sumQualsOfBase(final BaseIndex base) { - return sumQuals.get(base); + return sumQuals[base.index]; } @Ensures("result >= 0") @@ -151,7 +134,7 @@ import java.util.Map; /** * Given a base , it returns the proportional count of this base compared to all other bases * - * @param base + * @param base base * @return the proportion of this base over all other bases */ @Ensures({"result >=0.0", "result<= 1.0"}) @@ -162,19 +145,19 @@ import java.util.Map; /** * Given a base , it returns the proportional count of this base compared to all other bases * - * @param baseIndex + * @param baseIndex base * @return the proportion of this base over all other bases */ @Ensures({"result >=0.0", "result<= 1.0"}) public double baseCountProportion(final BaseIndex baseIndex) { - return (totalCount == 0) ? 0.0 : (double)counts.get(baseIndex) / (double)totalCount; + return (totalCount == 0) ? 0.0 : (double)counts[baseIndex.index] / (double)totalCount; } @Ensures("result != null") public String toString() { StringBuilder b = new StringBuilder(); - for (Map.Entry elt : counts.entrySet()) { - b.append(elt.toString()).append("=").append(elt.getValue()).append(","); + for (final BaseIndex i : BaseIndex.values()) { + b.append(i.toString()).append("=").append(counts[i.index]).append(","); } return b.toString(); } @@ -186,9 +169,9 @@ import java.util.Map; @Ensures("result != null") public BaseIndex baseIndexWithMostCounts() { BaseIndex maxI = MAX_BASE_INDEX_WITH_NO_COUNTS; - for (Map.Entry entry : counts.entrySet()) { - if (entry.getValue() > counts.get(maxI)) - maxI = entry.getKey(); + for (final BaseIndex i : BaseIndex.values()) { + if (counts[i.index] > counts[maxI.index]) + maxI = i; } return maxI; } @@ -196,17 +179,17 @@ import java.util.Map; @Ensures("result != null") public BaseIndex baseIndexWithMostCountsWithoutIndels() { BaseIndex maxI = MAX_BASE_INDEX_WITH_NO_COUNTS; - for (Map.Entry entry : counts.entrySet()) { - if (entry.getKey().isNucleotide() && entry.getValue() > counts.get(maxI)) - maxI = entry.getKey(); + for (final BaseIndex i : BaseIndex.values()) { + if (i.isNucleotide() && counts[i.index] > counts[maxI.index]) + maxI = i; } return maxI; } private boolean hasHigherCount(final BaseIndex targetIndex, final BaseIndex testIndex) { - final int targetCount = counts.get(targetIndex); - final int testCount = counts.get(testIndex); - return ( targetCount > testCount || (targetCount == testCount && sumQuals.get(targetIndex) > sumQuals.get(testIndex)) ); + final int targetCount = counts[targetIndex.index]; + final int testCount = counts[testIndex.index]; + return ( targetCount > testCount || (targetCount == testCount && sumQuals[targetIndex.index] > sumQuals[testIndex.index]) ); } public byte baseWithMostProbability() { @@ -216,42 +199,42 @@ import java.util.Map; @Ensures("result != null") public BaseIndex baseIndexWithMostProbability() { BaseIndex maxI = MAX_BASE_INDEX_WITH_NO_COUNTS; - for (Map.Entry entry : sumQuals.entrySet()) { - if (entry.getValue() > sumQuals.get(maxI)) - maxI = entry.getKey(); + for (final BaseIndex i : BaseIndex.values()) { + if (sumQuals[i.index] > sumQuals[maxI.index]) + maxI = i; } - return (sumQuals.get(maxI) > 0L ? maxI : baseIndexWithMostCounts()); + return (sumQuals[maxI.index] > 0L ? maxI : baseIndexWithMostCounts()); } @Ensures("result != null") public BaseIndex baseIndexWithMostProbabilityWithoutIndels() { BaseIndex maxI = MAX_BASE_INDEX_WITH_NO_COUNTS; - for (Map.Entry entry : sumQuals.entrySet()) { - if (entry.getKey().isNucleotide() && entry.getValue() > sumQuals.get(maxI)) - maxI = entry.getKey(); + for (final BaseIndex i : BaseIndex.values()) { + if (i.isNucleotide() && sumQuals[i.index] > sumQuals[maxI.index]) + maxI = i; } - return (sumQuals.get(maxI) > 0L ? maxI : baseIndexWithMostCountsWithoutIndels()); + return (sumQuals[maxI.index] > 0L ? maxI : baseIndexWithMostCountsWithoutIndels()); } @Ensures("result >=0") public int totalCountWithoutIndels() { - return totalCount - counts.get(BaseIndex.D) - counts.get(BaseIndex.I); + return totalCount - counts[BaseIndex.D.index] - counts[BaseIndex.I.index]; } /** * Calculates the proportional count of a base compared to all other bases except indels (I and D) * - * @param index + * @param base base * @return the proportion of this base over all other bases except indels */ @Requires("index.isNucleotide()") @Ensures({"result >=0.0", "result<= 1.0"}) - public double baseCountProportionWithoutIndels(final BaseIndex index) { + public double baseCountProportionWithoutIndels(final BaseIndex base) { final int total = totalCountWithoutIndels(); - return (total == 0) ? 0.0 : (double)counts.get(index) / (double)total; + return (total == 0) ? 0.0 : (double)counts[base.index] / (double)total; } - public Object[] countsArray() { - return counts.values().toArray(); + public int[] countsArray() { + return counts.clone(); } } diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/BaseIndex.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/BaseIndex.java index a64db5874..02f867bcb 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/BaseIndex.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/BaseIndex.java @@ -1,5 +1,7 @@ package org.broadinstitute.sting.gatk.walkers.compression.reducereads; +import org.broadinstitute.sting.utils.exceptions.ReviewedStingException; + /** * Simple byte / base index conversions * @@ -56,7 +58,7 @@ public enum BaseIndex { case 'N': case 'n': return N; - default: return null; + default: throw new ReviewedStingException("Tried to create a byte index for an impossible base " + base); } } @@ -68,7 +70,7 @@ public enum BaseIndex { * @return whether or not it is a nucleotide, given the definition above */ public boolean isNucleotide() { - return this == A || this == C || this == G || this == T || this == N; + return !isIndel(); } /** diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/HeaderElement.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/HeaderElement.java index 272512bdb..3097c2ee9 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/HeaderElement.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/compression/reducereads/HeaderElement.java @@ -213,11 +213,11 @@ public class HeaderElement { if (totalCount == 0) return 0; - Object[] countsArray = consensusBaseCounts.countsArray(); + int[] countsArray = consensusBaseCounts.countsArray(); Arrays.sort(countsArray); for (int i = countsArray.length-1; i>=0; i--) { nHaplotypes++; - runningCount += (Integer) countsArray[i]; + runningCount += countsArray[i]; if (runningCount/totalCount > minVariantProportion) break; }