diff --git a/java/src/org/broadinstitute/sting/playground/fourbasecaller/BasecallingBaseModel.java b/java/src/org/broadinstitute/sting/playground/fourbasecaller/BasecallingBaseModel.java new file mode 100755 index 000000000..6debe4ed5 --- /dev/null +++ b/java/src/org/broadinstitute/sting/playground/fourbasecaller/BasecallingBaseModel.java @@ -0,0 +1,134 @@ +package org.broadinstitute.sting.playground.fourbasecaller; + +import cern.colt.matrix.DoubleMatrix1D; +import cern.colt.matrix.DoubleFactory1D; +import cern.colt.matrix.DoubleMatrix2D; +import cern.colt.matrix.DoubleFactory2D; +import cern.colt.matrix.linalg.Algebra; + +import org.broadinstitute.sting.utils.QualityUtils; + +public class BasecallingBaseModel { + private double[][] counts; + + private DoubleMatrix1D[][] runningChannelSums; + private DoubleMatrix2D[][] runningChannelProductSums; + + private boolean readyToCall = false; + private DoubleMatrix1D[][] means; + private DoubleMatrix2D[][] inverseCovariances; + private double[][] norms; + + Algebra alg; + + public BasecallingBaseModel() { + counts = new double[4][4]; + + runningChannelSums = new DoubleMatrix1D[4][4]; + runningChannelProductSums = new DoubleMatrix2D[4][4]; + + means = new DoubleMatrix1D[4][4]; + inverseCovariances = new DoubleMatrix2D[4][4]; + norms = new double[4][4]; + + for (int basePrevIndex = 0; basePrevIndex < 4; basePrevIndex++) { + for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) { + runningChannelSums[basePrevIndex][baseCurIndex] = (DoubleFactory1D.dense).make(4); + runningChannelProductSums[basePrevIndex][baseCurIndex] = (DoubleFactory2D.dense).make(4, 4); + + means[basePrevIndex][baseCurIndex] = (DoubleFactory1D.dense).make(4); + inverseCovariances[basePrevIndex][baseCurIndex] = (DoubleFactory2D.dense).make(4, 4); + } + } + + alg = new Algebra(); + } + + public void addTrainingPoint(char basePrev, char baseCur, byte qualCur, double[] fourintensity) { + int basePrevIndex = baseToBaseIndex(basePrev); + int baseCurIndex = baseToBaseIndex(baseCur); + + if (basePrevIndex >= 0 && baseCurIndex >= 0) { + for (int channel = 0; channel < 4; channel++) { + double weight = QualityUtils.qualToProb(qualCur); + double channelIntensity = fourintensity[channel]; + + runningChannelSums[basePrevIndex][baseCurIndex].setQuick(channel, runningChannelSums[basePrevIndex][baseCurIndex].getQuick(channel) + weight*channelIntensity); + + for (int cochannel = 0; cochannel < 4; cochannel++) { + double cochannelIntensity = fourintensity[cochannel]; + runningChannelProductSums[basePrevIndex][baseCurIndex].setQuick(channel, cochannel, runningChannelProductSums[basePrevIndex][baseCurIndex].getQuick(channel, cochannel) + weight*channelIntensity*cochannelIntensity); + } + } + + counts[basePrevIndex][baseCurIndex]++; + } + + readyToCall = false; + } + + public void prepareToCallBases() { + for (int basePrevIndex = 0; basePrevIndex < 4; basePrevIndex++) { + for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) { + for (int channel = 0; channel < 4; channel++) { + means[basePrevIndex][baseCurIndex].setQuick(channel, runningChannelSums[basePrevIndex][baseCurIndex].getQuick(channel)/counts[basePrevIndex][baseCurIndex]); + + for (int cochannel = 0; cochannel < 4; cochannel++) { + // Cov(Xi, Xj) = E(XiXj) - E(Xi)E(Xj) + inverseCovariances[basePrevIndex][baseCurIndex].setQuick(channel, cochannel, (runningChannelProductSums[basePrevIndex][baseCurIndex].getQuick(channel, cochannel)/counts[basePrevIndex][baseCurIndex]) - (runningChannelSums[basePrevIndex][baseCurIndex].getQuick(channel)/counts[basePrevIndex][baseCurIndex])*(runningChannelSums[basePrevIndex][baseCurIndex].getQuick(cochannel)/counts[basePrevIndex][baseCurIndex])); + } + } + + DoubleMatrix2D invcov = alg.inverse(inverseCovariances[basePrevIndex][baseCurIndex]); + inverseCovariances[basePrevIndex][baseCurIndex] = invcov; + norms[basePrevIndex][baseCurIndex] = Math.pow(alg.det(invcov), 0.5)/Math.pow(2.0*Math.PI, 2.0); + } + } + + readyToCall = true; + } + + public double[][] computeLikelihoods(int cycle, char basePrev, byte qualPrev, double[] fourintensity) { + if (!readyToCall) { + prepareToCallBases(); + } + + double[][] probdist = new double[4][4]; + + for (int basePrevIndex = 0; basePrevIndex < ((cycle == 0) ? 1 : 4); basePrevIndex++) { + for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) { + double[] diff = new double[4]; + for (int channel = 0; channel < 4; channel++) { + diff[channel] = fourintensity[channel] - means[basePrevIndex][baseCurIndex].getQuick(channel); + } + + DoubleMatrix1D sub = (DoubleFactory1D.dense).make(diff); + DoubleMatrix1D Ax = alg.mult(inverseCovariances[basePrevIndex][baseCurIndex], sub); + + double exparg = -0.5*alg.mult(sub, Ax); + probdist[basePrevIndex][baseCurIndex] = norms[basePrevIndex][baseCurIndex]*Math.exp(exparg); + } + } + + return probdist; + } + + private int baseToBaseIndex(char base) { + switch (base) { + case 'A': + case 'a': + case '*': return 0; + + case 'C': + case 'c': return 1; + + case 'G': + case 'g': return 2; + + case 'T': + case 't': return 3; + } + + return -1; + } +}