RR (ReduceReads) optimization: profiling showed that the BaseCounts class was a major bottleneck because the underlying implementation was a HashMap. Given that the map key was an indexable Enum anyway, it makes a lot more sense to implement it as a native array. This knocks 30% off the runtime in bad regions.

This commit is contained in:
Eric Banks 2012-10-17 21:44:53 -04:00
parent 33df1afe0e
commit 20ffbcc86e
4 changed files with 89 additions and 111 deletions

View File

@ -1,8 +1,5 @@
package org.broadinstitute.sting.gatk.walkers.compression.reducereads;
import java.util.HashMap;
import java.util.Map;
/**
* An object that keeps track of the base counts as well as the sum of the base, insertion and deletion qualities of each base.
*
@ -10,35 +7,31 @@ import java.util.Map;
* @since 6/15/12
*/
public class BaseAndQualsCounts extends BaseCounts {
private final Map<BaseIndex, Long> sumInsertionQuals;
private final Map<BaseIndex, Long> sumDeletionQuals;
private final long[] sumInsertionQuals;
private final long[] sumDeletionQuals;
public BaseAndQualsCounts() {
super();
this.sumInsertionQuals = new HashMap<BaseIndex, Long>();
this.sumDeletionQuals = new HashMap<BaseIndex, Long>();
for (BaseIndex i : BaseIndex.values()) {
sumInsertionQuals.put(i, 0L);
sumDeletionQuals.put(i, 0L);
this.sumInsertionQuals = new long[BaseIndex.values().length];
this.sumDeletionQuals = new long[BaseIndex.values().length];
for (final BaseIndex i : BaseIndex.values()) {
sumInsertionQuals[i.index] = 0L;
sumDeletionQuals[i.index] = 0L;
}
}
public void incr(final byte base, final byte baseQual, final byte insQual, final byte delQual) {
super.incr(base, baseQual);
BaseIndex i = BaseIndex.byteToBase(base);
if (i != null) { // do not allow Ns
sumInsertionQuals.put(i, sumInsertionQuals.get(i) + insQual);
sumDeletionQuals.put(i, sumDeletionQuals.get(i) + delQual);
}
final BaseIndex i = BaseIndex.byteToBase(base);
super.incr(i, baseQual);
sumInsertionQuals[i.index] += insQual;
sumDeletionQuals[i.index] += delQual;
}
public void decr(final byte base, final byte baseQual, final byte insQual, final byte delQual) {
super.decr(base, baseQual);
BaseIndex i = BaseIndex.byteToBase(base);
if (i != null) { // do not allow Ns
sumInsertionQuals.put(i, sumInsertionQuals.get(i) - insQual);
sumDeletionQuals.put(i, sumDeletionQuals.get(i) - delQual);
}
final BaseIndex i = BaseIndex.byteToBase(base);
super.decr(i, baseQual);
sumInsertionQuals[i.index] -= insQual;
sumDeletionQuals[i.index] -= delQual;
}
public byte averageInsertionQualsOfBase(final BaseIndex base) {
@ -49,7 +42,7 @@ public class BaseAndQualsCounts extends BaseCounts {
return getGenericAverageQualOfBase(base, sumDeletionQuals);
}
private byte getGenericAverageQualOfBase(final BaseIndex base, final Map<BaseIndex, Long> sumQuals) {
return (byte) (sumQuals.get(base) / getCount(base));
private byte getGenericAverageQualOfBase(final BaseIndex base, final long[] sumQuals) {
return (byte) (sumQuals[base.index] / countOfBase(base));
}
}

View File

@ -3,8 +3,6 @@ package org.broadinstitute.sting.gatk.walkers.compression.reducereads;
import com.google.java.contract.Ensures;
import com.google.java.contract.Requires;
import java.util.EnumMap;
import java.util.Map;
/**
* An object to keep track of the number of occurrences of each base and its quality.
@ -18,25 +16,25 @@ import java.util.Map;
public final static BaseIndex MAX_BASE_INDEX_WITH_NO_COUNTS = BaseIndex.N;
public final static byte MAX_BASE_WITH_NO_COUNTS = MAX_BASE_INDEX_WITH_NO_COUNTS.getByte();
private final Map<BaseIndex, Integer> counts; // keeps track of the base counts
private final Map<BaseIndex, Long> sumQuals; // keeps track of the quals of each base
private int totalCount = 0; // keeps track of total count since this is requested so often
private final int[] counts; // keeps track of the base counts
private final long[] sumQuals; // keeps track of the quals of each base
private int totalCount = 0; // keeps track of total count since this is requested so often
public BaseCounts() {
counts = new EnumMap<BaseIndex, Integer>(BaseIndex.class);
sumQuals = new EnumMap<BaseIndex, Long>(BaseIndex.class);
for (BaseIndex i : BaseIndex.values()) {
counts.put(i, 0);
sumQuals.put(i, 0L);
counts = new int[BaseIndex.values().length];
sumQuals = new long[BaseIndex.values().length];
for (final BaseIndex i : BaseIndex.values()) {
counts[i.index] = 0;
sumQuals[i.index] = 0L;
}
}
public static BaseCounts createWithCounts(int[] countsACGT) {
BaseCounts baseCounts = new BaseCounts();
baseCounts.counts.put(BaseIndex.A, countsACGT[0]);
baseCounts.counts.put(BaseIndex.C, countsACGT[1]);
baseCounts.counts.put(BaseIndex.G, countsACGT[2]);
baseCounts.counts.put(BaseIndex.T, countsACGT[3]);
baseCounts.counts[BaseIndex.A.index] = countsACGT[0];
baseCounts.counts[BaseIndex.C.index] = countsACGT[1];
baseCounts.counts[BaseIndex.G.index] = countsACGT[2];
baseCounts.counts[BaseIndex.T.index] = countsACGT[3];
baseCounts.totalCount = countsACGT[0] + countsACGT[1] + countsACGT[2] + countsACGT[3];
return baseCounts;
}
@ -44,8 +42,8 @@ import java.util.Map;
@Requires("other != null")
public void add(final BaseCounts other) {
for (final BaseIndex i : BaseIndex.values()) {
final int otherCount = other.counts.get(i);
counts.put(i, counts.get(i) + otherCount);
final int otherCount = other.counts[i.index];
counts[i.index] += otherCount;
totalCount += otherCount;
}
}
@ -53,8 +51,8 @@ import java.util.Map;
@Requires("other != null")
public void sub(final BaseCounts other) {
for (final BaseIndex i : BaseIndex.values()) {
final int otherCount = other.counts.get(i);
counts.put(i, counts.get(i) - otherCount);
final int otherCount = other.counts[i.index];
counts[i.index] -= otherCount;
totalCount -= otherCount;
}
}
@ -62,49 +60,29 @@ import java.util.Map;
@Ensures("totalCount() == old(totalCount()) || totalCount() == old(totalCount()) + 1")
public void incr(final byte base) {
final BaseIndex i = BaseIndex.byteToBase(base);
if (i != null) { // no Ns
counts.put(i, counts.get(i) + 1);
totalCount++;
}
counts[i.index]++;
totalCount++;
}
@Ensures("totalCount() == old(totalCount()) || totalCount() == old(totalCount()) + 1")
public void incr(final byte base, final byte qual) {
final BaseIndex i = BaseIndex.byteToBase(base);
if (i != null) { // no Ns
counts.put(i, counts.get(i) + 1);
totalCount++;
sumQuals.put(i, sumQuals.get(i) + qual);
}
public void incr(final BaseIndex base, final byte qual) {
counts[base.index]++;
totalCount++;
sumQuals[base.index] += qual;
}
@Ensures("totalCount() == old(totalCount()) || totalCount() == old(totalCount()) - 1")
public void decr(final byte base) {
final BaseIndex i = BaseIndex.byteToBase(base);
if (i != null) { // no Ns
counts.put(i, counts.get(i) - 1);
totalCount--;
}
counts[i.index]--;
totalCount--;
}
@Ensures("totalCount() == old(totalCount()) || totalCount() == old(totalCount()) - 1")
public void decr(final byte base, final byte qual) {
final BaseIndex i = BaseIndex.byteToBase(base);
if (i != null) { // no Ns
counts.put(i, counts.get(i) - 1);
totalCount--;
sumQuals.put(i, sumQuals.get(i) - qual);
}
}
@Ensures("result >= 0")
public int getCount(final byte base) {
return getCount(BaseIndex.byteToBase(base));
}
@Ensures("result >= 0")
public int getCount(final BaseIndex base) {
return counts.get(base);
public void decr(final BaseIndex base, final byte qual) {
counts[base.index]--;
totalCount--;
sumQuals[base.index] -= qual;
}
@Ensures("result >= 0")
@ -114,27 +92,32 @@ import java.util.Map;
@Ensures("result >= 0")
public long getSumQuals(final BaseIndex base) {
return sumQuals.get(base);
return sumQuals[base.index];
}
@Ensures("result >= 0")
public byte averageQuals(final byte base) {
return (byte) (getSumQuals(base) / getCount(base));
return (byte) (getSumQuals(base) / countOfBase(base));
}
@Ensures("result >= 0")
public byte averageQuals(final BaseIndex base) {
return (byte) (getSumQuals(base) / getCount(base));
return (byte) (getSumQuals(base) / countOfBase(base));
}
@Ensures("result >= 0")
public int countOfBase(final byte base) {
return countOfBase(BaseIndex.byteToBase(base));
}
@Ensures("result >= 0")
public int countOfBase(final BaseIndex base) {
return counts.get(base);
return counts[base.index];
}
@Ensures("result >= 0")
public long sumQualsOfBase(final BaseIndex base) {
return sumQuals.get(base);
return sumQuals[base.index];
}
@Ensures("result >= 0")
@ -151,7 +134,7 @@ import java.util.Map;
/**
* Given a base, it returns the proportional count of this base compared to all other bases
*
* @param base
* @param base base
* @return the proportion of this base over all other bases
*/
@Ensures({"result >=0.0", "result<= 1.0"})
@ -162,19 +145,19 @@ import java.util.Map;
/**
* Given a base, it returns the proportional count of this base compared to all other bases
*
* @param baseIndex
* @param baseIndex base
* @return the proportion of this base over all other bases
*/
@Ensures({"result >=0.0", "result<= 1.0"})
public double baseCountProportion(final BaseIndex baseIndex) {
return (totalCount == 0) ? 0.0 : (double)counts.get(baseIndex) / (double)totalCount;
return (totalCount == 0) ? 0.0 : (double)counts[baseIndex.index] / (double)totalCount;
}
@Ensures("result != null")
public String toString() {
StringBuilder b = new StringBuilder();
for (Map.Entry<BaseIndex, Integer> elt : counts.entrySet()) {
b.append(elt.toString()).append("=").append(elt.getValue()).append(",");
for (final BaseIndex i : BaseIndex.values()) {
b.append(i.toString()).append("=").append(counts[i.index]).append(",");
}
return b.toString();
}
@ -186,9 +169,9 @@ import java.util.Map;
@Ensures("result != null")
public BaseIndex baseIndexWithMostCounts() {
BaseIndex maxI = MAX_BASE_INDEX_WITH_NO_COUNTS;
for (Map.Entry<BaseIndex, Integer> entry : counts.entrySet()) {
if (entry.getValue() > counts.get(maxI))
maxI = entry.getKey();
for (final BaseIndex i : BaseIndex.values()) {
if (counts[i.index] > counts[maxI.index])
maxI = i;
}
return maxI;
}
@ -196,17 +179,17 @@ import java.util.Map;
@Ensures("result != null")
public BaseIndex baseIndexWithMostCountsWithoutIndels() {
BaseIndex maxI = MAX_BASE_INDEX_WITH_NO_COUNTS;
for (Map.Entry<BaseIndex, Integer> entry : counts.entrySet()) {
if (entry.getKey().isNucleotide() && entry.getValue() > counts.get(maxI))
maxI = entry.getKey();
for (final BaseIndex i : BaseIndex.values()) {
if (i.isNucleotide() && counts[i.index] > counts[maxI.index])
maxI = i;
}
return maxI;
}
private boolean hasHigherCount(final BaseIndex targetIndex, final BaseIndex testIndex) {
final int targetCount = counts.get(targetIndex);
final int testCount = counts.get(testIndex);
return ( targetCount > testCount || (targetCount == testCount && sumQuals.get(targetIndex) > sumQuals.get(testIndex)) );
final int targetCount = counts[targetIndex.index];
final int testCount = counts[testIndex.index];
return ( targetCount > testCount || (targetCount == testCount && sumQuals[targetIndex.index] > sumQuals[testIndex.index]) );
}
public byte baseWithMostProbability() {
@ -216,42 +199,42 @@ import java.util.Map;
@Ensures("result != null")
public BaseIndex baseIndexWithMostProbability() {
BaseIndex maxI = MAX_BASE_INDEX_WITH_NO_COUNTS;
for (Map.Entry<BaseIndex, Long> entry : sumQuals.entrySet()) {
if (entry.getValue() > sumQuals.get(maxI))
maxI = entry.getKey();
for (final BaseIndex i : BaseIndex.values()) {
if (sumQuals[i.index] > sumQuals[maxI.index])
maxI = i;
}
return (sumQuals.get(maxI) > 0L ? maxI : baseIndexWithMostCounts());
return (sumQuals[maxI.index] > 0L ? maxI : baseIndexWithMostCounts());
}
@Ensures("result != null")
public BaseIndex baseIndexWithMostProbabilityWithoutIndels() {
BaseIndex maxI = MAX_BASE_INDEX_WITH_NO_COUNTS;
for (Map.Entry<BaseIndex, Long> entry : sumQuals.entrySet()) {
if (entry.getKey().isNucleotide() && entry.getValue() > sumQuals.get(maxI))
maxI = entry.getKey();
for (final BaseIndex i : BaseIndex.values()) {
if (i.isNucleotide() && sumQuals[i.index] > sumQuals[maxI.index])
maxI = i;
}
return (sumQuals.get(maxI) > 0L ? maxI : baseIndexWithMostCountsWithoutIndels());
return (sumQuals[maxI.index] > 0L ? maxI : baseIndexWithMostCountsWithoutIndels());
}
@Ensures("result >=0")
public int totalCountWithoutIndels() {
return totalCount - counts.get(BaseIndex.D) - counts.get(BaseIndex.I);
return totalCount - counts[BaseIndex.D.index] - counts[BaseIndex.I.index];
}
/**
* Calculates the proportional count of a base compared to all other bases except indels (I and D)
*
* @param index
* @param base base
* @return the proportion of this base over all other bases except indels
*/
@Requires("index.isNucleotide()")
@Ensures({"result >=0.0", "result<= 1.0"})
public double baseCountProportionWithoutIndels(final BaseIndex index) {
public double baseCountProportionWithoutIndels(final BaseIndex base) {
final int total = totalCountWithoutIndels();
return (total == 0) ? 0.0 : (double)counts.get(index) / (double)total;
return (total == 0) ? 0.0 : (double)counts[base.index] / (double)total;
}
public Object[] countsArray() {
return counts.values().toArray();
public int[] countsArray() {
return counts.clone();
}
}

View File

@ -1,5 +1,7 @@
package org.broadinstitute.sting.gatk.walkers.compression.reducereads;
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
/**
* Simple byte / base index conversions
*
@ -56,7 +58,7 @@ public enum BaseIndex {
case 'N':
case 'n':
return N;
default: return null;
default: throw new ReviewedStingException("Tried to create a byte index for an impossible base " + base);
}
}
@ -68,7 +70,7 @@ public enum BaseIndex {
* @return whether or not it is a nucleotide, given the definition above
*/
public boolean isNucleotide() {
return this == A || this == C || this == G || this == T || this == N;
return !isIndel();
}
/**

View File

@ -213,11 +213,11 @@ public class HeaderElement {
if (totalCount == 0)
return 0;
Object[] countsArray = consensusBaseCounts.countsArray();
int[] countsArray = consensusBaseCounts.countsArray();
Arrays.sort(countsArray);
for (int i = countsArray.length-1; i>=0; i--) {
nHaplotypes++;
runningCount += (Integer) countsArray[i];
runningCount += countsArray[i];
if (runningCount/totalCount > minVariantProportion)
break;
}