Round one of "easy" zero-effort optimizations to UG's indel caller. Mostly inlining functions, avoiding repeated computation, and trying to optimize SoftMaxPair(), which is by far the biggest runtime hog. More to come...

git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@5666 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
delangel 2011-04-20 18:57:34 +00:00
parent d8b8f857f3
commit 246d8190b5
2 changed files with 82 additions and 83 deletions

View File

@ -27,6 +27,7 @@ package org.broadinstitute.sting.gatk.walkers.indels;
import net.sf.samtools.Cigar;
import net.sf.samtools.CigarElement;
import net.sf.samtools.CigarOperator;
import net.sf.samtools.SAMRecord;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
import org.broadinstitute.sting.utils.MathUtils;
@ -319,53 +320,7 @@ public class PairHMMIndelErrorModel {
for (int i = 1; i < refBytes.length; i++)
hrunArray[i] = hforward[i]+hreverse[i];
}
/**
 * Looks up the pair of gap penalties to use when updating one of the three metric tables
 * at cell (indI, indJ).
 *
 * <p>For {@code MATCH_OFFSET} both penalties are 0.0. For {@code X_OFFSET} the boundary test
 * is {@code indJ == Y_METRIC_LENGTH-1}; for {@code Y_OFFSET} it is
 * {@code indI == X_METRIC_LENGTH-1}. On the boundary both penalties are {@code END_GAP_COST}.
 * Off the boundary they come from {@code currentContextGOP[indJ-1]} /
 * {@code currentContextGCP[indJ-1]} when {@code doContextDependentPenalties} is set, and from
 * {@code logGapOpenProbability} / {@code logGapContinuationProbability} otherwise.
 *
 * @param indI              row index of the cell being updated
 * @param indJ              column index of the cell being updated
 * @param X_METRIC_LENGTH   number of rows in the metric tables
 * @param Y_METRIC_LENGTH   number of columns in the metric tables
 * @param tableToUpdate     one of MATCH_OFFSET, X_OFFSET or Y_OFFSET
 * @param currentContextGOP per-position values used for the first penalty (read only off the boundary)
 * @param currentContextGCP per-position values used for the second penalty (read only off the boundary)
 * @return pair whose first element is the first (GOP-derived) penalty and whose second element
 *         is the second (GCP-derived) penalty
 * @throws StingException if tableToUpdate is not one of the three known offsets
 */
private Pair<Double,Double> getGapPenalties(final int indI, final int indJ, final int X_METRIC_LENGTH,
final int Y_METRIC_LENGTH, final int tableToUpdate, double[] currentContextGOP, double[] currentContextGCP) {
    double openPenalty = 0.0;
    double continuationPenalty = 0.0;

    // The match table carries no gap penalty; only the X and Y tables need a lookup.
    if (tableToUpdate != MATCH_OFFSET) {
        final boolean onBoundary;
        if (tableToUpdate == X_OFFSET) {
            onBoundary = (indJ == Y_METRIC_LENGTH-1);
        }
        else if (tableToUpdate == Y_OFFSET) {
            onBoundary = (indI == X_METRIC_LENGTH-1);
        }
        else {
            throw new StingException("BUG!! Invalid table offset");
        }

        if (onBoundary) {
            openPenalty = END_GAP_COST;
            continuationPenalty = END_GAP_COST;
        }
        else if (doContextDependentPenalties) {
            // Context arrays are only touched off the boundary, exactly as the ternaries did.
            openPenalty = currentContextGOP[indJ-1];
            continuationPenalty = currentContextGCP[indJ-1];
        }
        else {
            openPenalty = logGapOpenProbability;
            continuationPenalty = logGapContinuationProbability;
        }
    }

    return new Pair<Double,Double>(Double.valueOf(openPenalty), Double.valueOf(continuationPenalty));
}
private double computeReadLikelihoodGivenHaplotypeAffineGaps(byte[] haplotypeBases, byte[] readBases, byte[] readQuals,
double[] currentGOP, double[] currentGCP) {
@ -381,6 +336,7 @@ public class PairHMMIndelErrorModel {
int[][] bestActionArrayX = new int[X_METRIC_LENGTH][Y_METRIC_LENGTH];
int[][] bestActionArrayY = new int[X_METRIC_LENGTH][Y_METRIC_LENGTH];
double c,d;
matchMetricArray[0][0]= END_GAP_COST;//Double.NEGATIVE_INFINITY;
for (int i=1; i < X_METRIC_LENGTH; i++) {
@ -402,11 +358,12 @@ public class PairHMMIndelErrorModel {
}
for (int indI=1; indI < X_METRIC_LENGTH; indI++) {
int im1 = indI-1;
for (int indJ=1; indJ < Y_METRIC_LENGTH; indJ++) {
byte x = readBases[indI-1];
byte y = haplotypeBases[indJ-1];
byte qual = readQuals[indI-1];
int jm1 = indJ-1;
byte x = readBases[im1];
byte y = haplotypeBases[jm1];
byte qual = readQuals[im1];
double bestMetric = 0.0;
int bestMetricIdx = 0;
@ -424,17 +381,19 @@ public class PairHMMIndelErrorModel {
double[] metrics = new double[3];
// update match array
metrics[MATCH_OFFSET] = matchMetricArray[indI-1][indJ-1] + pBaseRead;
metrics[X_OFFSET] = XMetricArray[indI-1][indJ-1] + pBaseRead;
metrics[Y_OFFSET] = YMetricArray[indI-1][indJ-1] + pBaseRead;
if (doViterbi) {
// update match array
metrics[MATCH_OFFSET] = matchMetricArray[im1][jm1] + pBaseRead;
metrics[X_OFFSET] = XMetricArray[im1][jm1] + pBaseRead;
metrics[Y_OFFSET] = YMetricArray[im1][jm1] + pBaseRead;
bestMetricIdx = MathUtils.maxElementIndex(metrics);
bestMetric = metrics[bestMetricIdx];
}
else
bestMetric = MathUtils.softMax(metrics);
bestMetric = MathUtils.softMax(matchMetricArray[im1][jm1] + pBaseRead, XMetricArray[im1][jm1] + pBaseRead,
YMetricArray[im1][jm1] + pBaseRead);
matchMetricArray[indI][indJ] = bestMetric;
bestActionArrayM[indI][indJ] = ACTIONS_M[bestMetricIdx];
@ -442,35 +401,48 @@ public class PairHMMIndelErrorModel {
// update X array
// State X(i,j): X(1:i) aligned to a gap in Y(1:j).
// When in last column of X, ie X(1:i) aligned to full Y, we don't want to penalize gaps
Pair<Double,Double> p = getGapPenalties(indI, indJ, X_METRIC_LENGTH, Y_METRIC_LENGTH, X_OFFSET, currentGOP, currentGCP);
metrics[MATCH_OFFSET] = matchMetricArray[indI-1][indJ] + p.first;
metrics[X_OFFSET] = XMetricArray[indI-1][indJ] + p.second;
metrics[Y_OFFSET] = Double.NEGATIVE_INFINITY; //YMetricArray[indI-1][indJ] + logGapOpenProbability;
//c = (indJ==Y_METRIC_LENGTH-1? END_GAP_COST: currentGOP[jm1]);
//d = (indJ==Y_METRIC_LENGTH-1? END_GAP_COST: currentGCP[jm1]);
c = currentGOP[jm1];
d = currentGCP[jm1];
if (indJ == Y_METRIC_LENGTH-1)
c = d = END_GAP_COST;
if (doViterbi) {
metrics[MATCH_OFFSET] = matchMetricArray[im1][indJ] + c;
metrics[X_OFFSET] = XMetricArray[im1][indJ] + d;
metrics[Y_OFFSET] = Double.NEGATIVE_INFINITY; //YMetricArray[indI-1][indJ] + logGapOpenProbability;
bestMetricIdx = MathUtils.maxElementIndex(metrics);
bestMetric = metrics[bestMetricIdx];
}
else
bestMetric = MathUtils.softMax(metrics);
bestMetric = MathUtils.softMax(matchMetricArray[im1][indJ] + c, XMetricArray[im1][indJ] + d);
XMetricArray[indI][indJ] = bestMetric;
bestActionArrayX[indI][indJ] = ACTIONS_X[bestMetricIdx];
// update Y array
p = getGapPenalties(indI, indJ, X_METRIC_LENGTH, Y_METRIC_LENGTH, Y_OFFSET, currentGOP, currentGCP);
//c = (indI==X_METRIC_LENGTH-1? END_GAP_COST: currentGOP[jm1]);
//d = (indI==X_METRIC_LENGTH-1? END_GAP_COST: currentGCP[jm1]);
c = currentGOP[jm1];
d = currentGCP[jm1];
if (indI == X_METRIC_LENGTH-1)
c = d = END_GAP_COST;
metrics[MATCH_OFFSET] = matchMetricArray[indI][indJ-1] + p.first;
metrics[X_OFFSET] = Double.NEGATIVE_INFINITY; //XMetricArray[indI][indJ-1] + logGapOpenProbability;
metrics[Y_OFFSET] = YMetricArray[indI][indJ-1] + p.second;
if (doViterbi) {
metrics[MATCH_OFFSET] = matchMetricArray[indI][jm1] + c;
metrics[X_OFFSET] = Double.NEGATIVE_INFINITY; //XMetricArray[indI][indJ-1] + logGapOpenProbability;
metrics[Y_OFFSET] = YMetricArray[indI][jm1] + d;
bestMetricIdx = MathUtils.maxElementIndex(metrics);
bestMetric = metrics[bestMetricIdx];
}
else
bestMetric = MathUtils.softMax(metrics);
bestMetric = MathUtils.softMax(matchMetricArray[indI][jm1] + c, YMetricArray[indI][jm1] + d);
YMetricArray[indI][indJ] = bestMetric;
bestActionArrayY[indI][indJ] = ACTIONS_Y[bestMetricIdx];
@ -612,6 +584,10 @@ public class PairHMMIndelErrorModel {
}
}
for (SAMRecord read : pileup.getReads()) {
/* SAMRecord read = ReadUtils.hardClipAdaptorSequence(read);
if (read == null)
continue;
*/
if(ReadUtils.is454Read(read)) {
continue;
}
@ -639,18 +615,37 @@ public class PairHMMIndelErrorModel {
// Conversely, if a read ends at [eventStart,eventStart+eventLength] we'll use all soft clipped bases in the end of the read.
long eventStartPos = ref.getLocus().getStart();
// default: discard soft-clipped bases
// compute total number of clipped bases (soft or hard clipped)
numStartSoftClippedBases = read.getAlignmentStart()- read.getUnclippedStart();
numEndSoftClippedBases = read.getUnclippedEnd()- read.getAlignmentEnd();
/*if (eventLength > 0) */
{
if ((read.getAlignmentStart()>=eventStartPos-eventLength && read.getAlignmentStart() <= eventStartPos+1) ||
(read.getAlignmentEnd() >= eventStartPos && read.getAlignmentEnd() <= eventStartPos + eventLength)) {
numStartSoftClippedBases = 0;
numEndSoftClippedBases = 0;
}
}
// check for hard clips (never consider these bases):
/* Cigar c = read.getCigar();
CigarElement first = c.getCigarElement(0);
CigarElement last = c.getCigarElement(c.numCigarElements()-1);
int numStartHardClippedBases = 0, numEndHardClippedBases = 0;
if (first.getOperator() == CigarOperator.H) {
numStartHardClippedBases = first.getLength();
}
if (last.getOperator() == CigarOperator.H) {
numEndHardClippedBases = last.getLength();
}
// correct for hard clips
numStartSoftClippedBases -= numStartHardClippedBases;
numEndSoftClippedBases -= numEndHardClippedBases;
readStart += numStartHardClippedBases;
readEnd -= numEndHardClippedBases;
*/
// remove soft clips if necessary
if ((read.getAlignmentStart()>=eventStartPos-eventLength && read.getAlignmentStart() <= eventStartPos+1) ||
(read.getAlignmentEnd() >= eventStartPos && read.getAlignmentEnd() <= eventStartPos + eventLength)) {
numStartSoftClippedBases = 0;
numEndSoftClippedBases = 0;
}
byte[] unclippedReadBases, unclippedReadQuals;

View File

@ -965,6 +965,7 @@ public class MathUtils {
public static final double[] jacobianLogTable;
public static final int JACOBIAN_LOG_TABLE_SIZE = 101;
public static final double JACOBIAN_LOG_TABLE_STEP = 0.1;
public static final double INV_JACOBIAN_LOG_TABLE_STEP = 1.0/JACOBIAN_LOG_TABLE_STEP;
public static final double MAX_JACOBIAN_TOLERANCE = 10.0;
private static final int MAXN = 10000;
@ -983,7 +984,7 @@ public class MathUtils {
}
}
static public double softMax(double[] vec) {
static public double softMax(final double[] vec) {
double acc = vec[0];
for (int k=1; k < vec.length; k++)
acc = softMax(acc,vec[k]);
@ -997,7 +998,7 @@ public class MathUtils {
return Math.max(a,x2);
}
static public double softMax(double x0, double x1, double x2) {
static public double softMax(final double x0, final double x1, final double x2) {
// compute naively log10(10^x[0] + 10^x[1]+...)
// return Math.log10(MathUtils.sumLog10(vec));
@ -1006,7 +1007,7 @@ public class MathUtils {
return softMax(a,x2);
}
static public double softMax(double x, double y) {
static public double softMax(final double x, final double y) {
if (Double.isInfinite(x))
return y;
@ -1024,20 +1025,23 @@ public class MathUtils {
// max(x,y) + log10(1+10^-abs(x-y))
// we compute the second term as a table lookup
// with integer quantization
double diff = Math.abs(x-y);
//double diff = Math.abs(x-y);
double diff = x-y;
double t1 =x;
if (y > x)
if (diff<0) { //
t1 = y;
// t has max(x,y)
diff= -diff;
}
// t has max(x,y), diff has abs(x-y)
// we have pre-stored correction for 0,0.1,0.2,... 10.0
int ind = (int)Math.round(diff/JACOBIAN_LOG_TABLE_STEP);
double t2 = jacobianLogTable[ind];
int ind = (int)Math.round(diff*INV_JACOBIAN_LOG_TABLE_STEP);
// gdebug+
//double z =Math.log10(1+Math.pow(10.0,-diff));
//System.out.format("x: %f, y:%f, app: %f, true: %f ind:%d\n",x,y,t2,z,ind);
//gdebug-
return t1+t2;
return t1+jacobianLogTable[ind];
// return Math.log10(Math.pow(10.0,x) + Math.pow(10.0,y));
}