RR (ReduceReads) optimization: profiling showed that the BaseCounts class was a major bottleneck because the underlying implementation was a HashMap. Given that the map index was an indexable Enum anyway, it makes a lot more sense to implement it as a native array. This knocks 30% off the runtime in bad regions.

This commit is contained in:
Eric Banks 2012-10-17 21:44:53 -04:00
parent 33df1afe0e
commit 20ffbcc86e
4 changed files with 89 additions and 111 deletions

View File

@ -1,8 +1,5 @@
package org.broadinstitute.sting.gatk.walkers.compression.reducereads; package org.broadinstitute.sting.gatk.walkers.compression.reducereads;
import java.util.HashMap;
import java.util.Map;
/** /**
* An object that keeps track of the base counts as well as the sum of the base, insertion and deletion qualities of each base. * An object that keeps track of the base counts as well as the sum of the base, insertion and deletion qualities of each base.
* *
@ -10,35 +7,31 @@ import java.util.Map;
* @since 6/15/12 * @since 6/15/12
*/ */
public class BaseAndQualsCounts extends BaseCounts { public class BaseAndQualsCounts extends BaseCounts {
private final Map<BaseIndex, Long> sumInsertionQuals; private final long[] sumInsertionQuals;
private final Map<BaseIndex, Long> sumDeletionQuals; private final long[] sumDeletionQuals;
public BaseAndQualsCounts() { public BaseAndQualsCounts() {
super(); super();
this.sumInsertionQuals = new HashMap<BaseIndex, Long>(); this.sumInsertionQuals = new long[BaseIndex.values().length];
this.sumDeletionQuals = new HashMap<BaseIndex, Long>(); this.sumDeletionQuals = new long[BaseIndex.values().length];
for (BaseIndex i : BaseIndex.values()) { for (final BaseIndex i : BaseIndex.values()) {
sumInsertionQuals.put(i, 0L); sumInsertionQuals[i.index] = 0L;
sumDeletionQuals.put(i, 0L); sumDeletionQuals[i.index] = 0L;
} }
} }
public void incr(final byte base, final byte baseQual, final byte insQual, final byte delQual) { public void incr(final byte base, final byte baseQual, final byte insQual, final byte delQual) {
super.incr(base, baseQual); final BaseIndex i = BaseIndex.byteToBase(base);
BaseIndex i = BaseIndex.byteToBase(base); super.incr(i, baseQual);
if (i != null) { // do not allow Ns sumInsertionQuals[i.index] += insQual;
sumInsertionQuals.put(i, sumInsertionQuals.get(i) + insQual); sumDeletionQuals[i.index] += delQual;
sumDeletionQuals.put(i, sumDeletionQuals.get(i) + delQual);
}
} }
public void decr(final byte base, final byte baseQual, final byte insQual, final byte delQual) { public void decr(final byte base, final byte baseQual, final byte insQual, final byte delQual) {
super.decr(base, baseQual); final BaseIndex i = BaseIndex.byteToBase(base);
BaseIndex i = BaseIndex.byteToBase(base); super.decr(i, baseQual);
if (i != null) { // do not allow Ns sumInsertionQuals[i.index] -= insQual;
sumInsertionQuals.put(i, sumInsertionQuals.get(i) - insQual); sumDeletionQuals[i.index] -= delQual;
sumDeletionQuals.put(i, sumDeletionQuals.get(i) - delQual);
}
} }
public byte averageInsertionQualsOfBase(final BaseIndex base) { public byte averageInsertionQualsOfBase(final BaseIndex base) {
@ -49,7 +42,7 @@ public class BaseAndQualsCounts extends BaseCounts {
return getGenericAverageQualOfBase(base, sumDeletionQuals); return getGenericAverageQualOfBase(base, sumDeletionQuals);
} }
private byte getGenericAverageQualOfBase(final BaseIndex base, final Map<BaseIndex, Long> sumQuals) { private byte getGenericAverageQualOfBase(final BaseIndex base, final long[] sumQuals) {
return (byte) (sumQuals.get(base) / getCount(base)); return (byte) (sumQuals[base.index] / countOfBase(base));
} }
} }

View File

@ -3,8 +3,6 @@ package org.broadinstitute.sting.gatk.walkers.compression.reducereads;
import com.google.java.contract.Ensures; import com.google.java.contract.Ensures;
import com.google.java.contract.Requires; import com.google.java.contract.Requires;
import java.util.EnumMap;
import java.util.Map;
/** /**
* An object to keep track of the number of occurrences of each base and it's quality. * An object to keep track of the number of occurrences of each base and it's quality.
@ -18,25 +16,25 @@ import java.util.Map;
public final static BaseIndex MAX_BASE_INDEX_WITH_NO_COUNTS = BaseIndex.N; public final static BaseIndex MAX_BASE_INDEX_WITH_NO_COUNTS = BaseIndex.N;
public final static byte MAX_BASE_WITH_NO_COUNTS = MAX_BASE_INDEX_WITH_NO_COUNTS.getByte(); public final static byte MAX_BASE_WITH_NO_COUNTS = MAX_BASE_INDEX_WITH_NO_COUNTS.getByte();
private final Map<BaseIndex, Integer> counts; // keeps track of the base counts private final int[] counts; // keeps track of the base counts
private final Map<BaseIndex, Long> sumQuals; // keeps track of the quals of each base private final long[] sumQuals; // keeps track of the quals of each base
private int totalCount = 0; // keeps track of total count since this is requested so often private int totalCount = 0; // keeps track of total count since this is requested so often
public BaseCounts() { public BaseCounts() {
counts = new EnumMap<BaseIndex, Integer>(BaseIndex.class); counts = new int[BaseIndex.values().length];
sumQuals = new EnumMap<BaseIndex, Long>(BaseIndex.class); sumQuals = new long[BaseIndex.values().length];
for (BaseIndex i : BaseIndex.values()) { for (final BaseIndex i : BaseIndex.values()) {
counts.put(i, 0); counts[i.index] = 0;
sumQuals.put(i, 0L); sumQuals[i.index] = 0L;
} }
} }
public static BaseCounts createWithCounts(int[] countsACGT) { public static BaseCounts createWithCounts(int[] countsACGT) {
BaseCounts baseCounts = new BaseCounts(); BaseCounts baseCounts = new BaseCounts();
baseCounts.counts.put(BaseIndex.A, countsACGT[0]); baseCounts.counts[BaseIndex.A.index] = countsACGT[0];
baseCounts.counts.put(BaseIndex.C, countsACGT[1]); baseCounts.counts[BaseIndex.C.index] = countsACGT[1];
baseCounts.counts.put(BaseIndex.G, countsACGT[2]); baseCounts.counts[BaseIndex.G.index] = countsACGT[2];
baseCounts.counts.put(BaseIndex.T, countsACGT[3]); baseCounts.counts[BaseIndex.T.index] = countsACGT[3];
baseCounts.totalCount = countsACGT[0] + countsACGT[1] + countsACGT[2] + countsACGT[3]; baseCounts.totalCount = countsACGT[0] + countsACGT[1] + countsACGT[2] + countsACGT[3];
return baseCounts; return baseCounts;
} }
@ -44,8 +42,8 @@ import java.util.Map;
@Requires("other != null") @Requires("other != null")
public void add(final BaseCounts other) { public void add(final BaseCounts other) {
for (final BaseIndex i : BaseIndex.values()) { for (final BaseIndex i : BaseIndex.values()) {
final int otherCount = other.counts.get(i); final int otherCount = other.counts[i.index];
counts.put(i, counts.get(i) + otherCount); counts[i.index] += otherCount;
totalCount += otherCount; totalCount += otherCount;
} }
} }
@ -53,8 +51,8 @@ import java.util.Map;
@Requires("other != null") @Requires("other != null")
public void sub(final BaseCounts other) { public void sub(final BaseCounts other) {
for (final BaseIndex i : BaseIndex.values()) { for (final BaseIndex i : BaseIndex.values()) {
final int otherCount = other.counts.get(i); final int otherCount = other.counts[i.index];
counts.put(i, counts.get(i) - otherCount); counts[i.index] -= otherCount;
totalCount -= otherCount; totalCount -= otherCount;
} }
} }
@ -62,49 +60,29 @@ import java.util.Map;
@Ensures("totalCount() == old(totalCount()) || totalCount() == old(totalCount()) + 1") @Ensures("totalCount() == old(totalCount()) || totalCount() == old(totalCount()) + 1")
public void incr(final byte base) { public void incr(final byte base) {
final BaseIndex i = BaseIndex.byteToBase(base); final BaseIndex i = BaseIndex.byteToBase(base);
if (i != null) { // no Ns counts[i.index]++;
counts.put(i, counts.get(i) + 1); totalCount++;
totalCount++;
}
} }
@Ensures("totalCount() == old(totalCount()) || totalCount() == old(totalCount()) + 1") @Ensures("totalCount() == old(totalCount()) || totalCount() == old(totalCount()) + 1")
public void incr(final byte base, final byte qual) { public void incr(final BaseIndex base, final byte qual) {
final BaseIndex i = BaseIndex.byteToBase(base); counts[base.index]++;
if (i != null) { // no Ns totalCount++;
counts.put(i, counts.get(i) + 1); sumQuals[base.index] += qual;
totalCount++;
sumQuals.put(i, sumQuals.get(i) + qual);
}
} }
@Ensures("totalCount() == old(totalCount()) || totalCount() == old(totalCount()) - 1") @Ensures("totalCount() == old(totalCount()) || totalCount() == old(totalCount()) - 1")
public void decr(final byte base) { public void decr(final byte base) {
final BaseIndex i = BaseIndex.byteToBase(base); final BaseIndex i = BaseIndex.byteToBase(base);
if (i != null) { // no Ns counts[i.index]--;
counts.put(i, counts.get(i) - 1); totalCount--;
totalCount--;
}
} }
@Ensures("totalCount() == old(totalCount()) || totalCount() == old(totalCount()) - 1") @Ensures("totalCount() == old(totalCount()) || totalCount() == old(totalCount()) - 1")
public void decr(final byte base, final byte qual) { public void decr(final BaseIndex base, final byte qual) {
final BaseIndex i = BaseIndex.byteToBase(base); counts[base.index]--;
if (i != null) { // no Ns totalCount--;
counts.put(i, counts.get(i) - 1); sumQuals[base.index] -= qual;
totalCount--;
sumQuals.put(i, sumQuals.get(i) - qual);
}
}
@Ensures("result >= 0")
public int getCount(final byte base) {
return getCount(BaseIndex.byteToBase(base));
}
@Ensures("result >= 0")
public int getCount(final BaseIndex base) {
return counts.get(base);
} }
@Ensures("result >= 0") @Ensures("result >= 0")
@ -114,27 +92,32 @@ import java.util.Map;
@Ensures("result >= 0") @Ensures("result >= 0")
public long getSumQuals(final BaseIndex base) { public long getSumQuals(final BaseIndex base) {
return sumQuals.get(base); return sumQuals[base.index];
} }
@Ensures("result >= 0") @Ensures("result >= 0")
public byte averageQuals(final byte base) { public byte averageQuals(final byte base) {
return (byte) (getSumQuals(base) / getCount(base)); return (byte) (getSumQuals(base) / countOfBase(base));
} }
@Ensures("result >= 0") @Ensures("result >= 0")
public byte averageQuals(final BaseIndex base) { public byte averageQuals(final BaseIndex base) {
return (byte) (getSumQuals(base) / getCount(base)); return (byte) (getSumQuals(base) / countOfBase(base));
}
@Ensures("result >= 0")
public int countOfBase(final byte base) {
return countOfBase(BaseIndex.byteToBase(base));
} }
@Ensures("result >= 0") @Ensures("result >= 0")
public int countOfBase(final BaseIndex base) { public int countOfBase(final BaseIndex base) {
return counts.get(base); return counts[base.index];
} }
@Ensures("result >= 0") @Ensures("result >= 0")
public long sumQualsOfBase(final BaseIndex base) { public long sumQualsOfBase(final BaseIndex base) {
return sumQuals.get(base); return sumQuals[base.index];
} }
@Ensures("result >= 0") @Ensures("result >= 0")
@ -151,7 +134,7 @@ import java.util.Map;
/** /**
* Given a base , it returns the proportional count of this base compared to all other bases * Given a base , it returns the proportional count of this base compared to all other bases
* *
* @param base * @param base base
* @return the proportion of this base over all other bases * @return the proportion of this base over all other bases
*/ */
@Ensures({"result >=0.0", "result<= 1.0"}) @Ensures({"result >=0.0", "result<= 1.0"})
@ -162,19 +145,19 @@ import java.util.Map;
/** /**
* Given a base , it returns the proportional count of this base compared to all other bases * Given a base , it returns the proportional count of this base compared to all other bases
* *
* @param baseIndex * @param baseIndex base
* @return the proportion of this base over all other bases * @return the proportion of this base over all other bases
*/ */
@Ensures({"result >=0.0", "result<= 1.0"}) @Ensures({"result >=0.0", "result<= 1.0"})
public double baseCountProportion(final BaseIndex baseIndex) { public double baseCountProportion(final BaseIndex baseIndex) {
return (totalCount == 0) ? 0.0 : (double)counts.get(baseIndex) / (double)totalCount; return (totalCount == 0) ? 0.0 : (double)counts[baseIndex.index] / (double)totalCount;
} }
@Ensures("result != null") @Ensures("result != null")
public String toString() { public String toString() {
StringBuilder b = new StringBuilder(); StringBuilder b = new StringBuilder();
for (Map.Entry<BaseIndex, Integer> elt : counts.entrySet()) { for (final BaseIndex i : BaseIndex.values()) {
b.append(elt.toString()).append("=").append(elt.getValue()).append(","); b.append(i.toString()).append("=").append(counts[i.index]).append(",");
} }
return b.toString(); return b.toString();
} }
@ -186,9 +169,9 @@ import java.util.Map;
@Ensures("result != null") @Ensures("result != null")
public BaseIndex baseIndexWithMostCounts() { public BaseIndex baseIndexWithMostCounts() {
BaseIndex maxI = MAX_BASE_INDEX_WITH_NO_COUNTS; BaseIndex maxI = MAX_BASE_INDEX_WITH_NO_COUNTS;
for (Map.Entry<BaseIndex, Integer> entry : counts.entrySet()) { for (final BaseIndex i : BaseIndex.values()) {
if (entry.getValue() > counts.get(maxI)) if (counts[i.index] > counts[maxI.index])
maxI = entry.getKey(); maxI = i;
} }
return maxI; return maxI;
} }
@ -196,17 +179,17 @@ import java.util.Map;
@Ensures("result != null") @Ensures("result != null")
public BaseIndex baseIndexWithMostCountsWithoutIndels() { public BaseIndex baseIndexWithMostCountsWithoutIndels() {
BaseIndex maxI = MAX_BASE_INDEX_WITH_NO_COUNTS; BaseIndex maxI = MAX_BASE_INDEX_WITH_NO_COUNTS;
for (Map.Entry<BaseIndex, Integer> entry : counts.entrySet()) { for (final BaseIndex i : BaseIndex.values()) {
if (entry.getKey().isNucleotide() && entry.getValue() > counts.get(maxI)) if (i.isNucleotide() && counts[i.index] > counts[maxI.index])
maxI = entry.getKey(); maxI = i;
} }
return maxI; return maxI;
} }
private boolean hasHigherCount(final BaseIndex targetIndex, final BaseIndex testIndex) { private boolean hasHigherCount(final BaseIndex targetIndex, final BaseIndex testIndex) {
final int targetCount = counts.get(targetIndex); final int targetCount = counts[targetIndex.index];
final int testCount = counts.get(testIndex); final int testCount = counts[testIndex.index];
return ( targetCount > testCount || (targetCount == testCount && sumQuals.get(targetIndex) > sumQuals.get(testIndex)) ); return ( targetCount > testCount || (targetCount == testCount && sumQuals[targetIndex.index] > sumQuals[testIndex.index]) );
} }
public byte baseWithMostProbability() { public byte baseWithMostProbability() {
@ -216,42 +199,42 @@ import java.util.Map;
@Ensures("result != null") @Ensures("result != null")
public BaseIndex baseIndexWithMostProbability() { public BaseIndex baseIndexWithMostProbability() {
BaseIndex maxI = MAX_BASE_INDEX_WITH_NO_COUNTS; BaseIndex maxI = MAX_BASE_INDEX_WITH_NO_COUNTS;
for (Map.Entry<BaseIndex, Long> entry : sumQuals.entrySet()) { for (final BaseIndex i : BaseIndex.values()) {
if (entry.getValue() > sumQuals.get(maxI)) if (sumQuals[i.index] > sumQuals[maxI.index])
maxI = entry.getKey(); maxI = i;
} }
return (sumQuals.get(maxI) > 0L ? maxI : baseIndexWithMostCounts()); return (sumQuals[maxI.index] > 0L ? maxI : baseIndexWithMostCounts());
} }
@Ensures("result != null") @Ensures("result != null")
public BaseIndex baseIndexWithMostProbabilityWithoutIndels() { public BaseIndex baseIndexWithMostProbabilityWithoutIndels() {
BaseIndex maxI = MAX_BASE_INDEX_WITH_NO_COUNTS; BaseIndex maxI = MAX_BASE_INDEX_WITH_NO_COUNTS;
for (Map.Entry<BaseIndex, Long> entry : sumQuals.entrySet()) { for (final BaseIndex i : BaseIndex.values()) {
if (entry.getKey().isNucleotide() && entry.getValue() > sumQuals.get(maxI)) if (i.isNucleotide() && sumQuals[i.index] > sumQuals[maxI.index])
maxI = entry.getKey(); maxI = i;
} }
return (sumQuals.get(maxI) > 0L ? maxI : baseIndexWithMostCountsWithoutIndels()); return (sumQuals[maxI.index] > 0L ? maxI : baseIndexWithMostCountsWithoutIndels());
} }
@Ensures("result >=0") @Ensures("result >=0")
public int totalCountWithoutIndels() { public int totalCountWithoutIndels() {
return totalCount - counts.get(BaseIndex.D) - counts.get(BaseIndex.I); return totalCount - counts[BaseIndex.D.index] - counts[BaseIndex.I.index];
} }
/** /**
* Calculates the proportional count of a base compared to all other bases except indels (I and D) * Calculates the proportional count of a base compared to all other bases except indels (I and D)
* *
* @param index * @param base base
* @return the proportion of this base over all other bases except indels * @return the proportion of this base over all other bases except indels
*/ */
@Requires("index.isNucleotide()") @Requires("index.isNucleotide()")
@Ensures({"result >=0.0", "result<= 1.0"}) @Ensures({"result >=0.0", "result<= 1.0"})
public double baseCountProportionWithoutIndels(final BaseIndex index) { public double baseCountProportionWithoutIndels(final BaseIndex base) {
final int total = totalCountWithoutIndels(); final int total = totalCountWithoutIndels();
return (total == 0) ? 0.0 : (double)counts.get(index) / (double)total; return (total == 0) ? 0.0 : (double)counts[base.index] / (double)total;
} }
public Object[] countsArray() { public int[] countsArray() {
return counts.values().toArray(); return counts.clone();
} }
} }

View File

@ -1,5 +1,7 @@
package org.broadinstitute.sting.gatk.walkers.compression.reducereads; package org.broadinstitute.sting.gatk.walkers.compression.reducereads;
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
/** /**
* Simple byte / base index conversions * Simple byte / base index conversions
* *
@ -56,7 +58,7 @@ public enum BaseIndex {
case 'N': case 'N':
case 'n': case 'n':
return N; return N;
default: return null; default: throw new ReviewedStingException("Tried to create a byte index for an impossible base " + base);
} }
} }
@ -68,7 +70,7 @@ public enum BaseIndex {
* @return whether or not it is a nucleotide, given the definition above * @return whether or not it is a nucleotide, given the definition above
*/ */
public boolean isNucleotide() { public boolean isNucleotide() {
return this == A || this == C || this == G || this == T || this == N; return !isIndel();
} }
/** /**

View File

@ -213,11 +213,11 @@ public class HeaderElement {
if (totalCount == 0) if (totalCount == 0)
return 0; return 0;
Object[] countsArray = consensusBaseCounts.countsArray(); int[] countsArray = consensusBaseCounts.countsArray();
Arrays.sort(countsArray); Arrays.sort(countsArray);
for (int i = countsArray.length-1; i>=0; i--) { for (int i = countsArray.length-1; i>=0; i--) {
nHaplotypes++; nHaplotypes++;
runningCount += (Integer) countsArray[i]; runningCount += countsArray[i];
if (runningCount/totalCount > minVariantProportion) if (runningCount/totalCount > minVariantProportion)
break; break;
} }