2009-05-13 04:24:18 +08:00
|
|
|
package org.broadinstitute.sting.secondarybase;
|
2009-04-03 06:07:47 +08:00
|
|
|
|
|
|
|
|
import cern.colt.matrix.DoubleFactory1D;
|
|
|
|
|
import cern.colt.matrix.DoubleFactory2D;
|
2009-05-15 02:58:43 +08:00
|
|
|
import cern.colt.matrix.DoubleMatrix1D;
|
|
|
|
|
import cern.colt.matrix.DoubleMatrix2D;
|
2009-04-03 06:07:47 +08:00
|
|
|
import cern.colt.matrix.linalg.Algebra;
|
2009-04-15 12:18:07 +08:00
|
|
|
import org.broadinstitute.sting.utils.BaseUtils;
|
2009-06-09 09:00:33 +08:00
|
|
|
import org.broadinstitute.sting.utils.MathUtils;
|
2009-04-03 06:07:47 +08:00
|
|
|
|
2009-05-15 02:58:43 +08:00
|
|
|
import java.io.File;
|
|
|
|
|
import java.io.IOException;
|
|
|
|
|
import java.io.PrintWriter;
|
2009-04-07 06:00:58 +08:00
|
|
|
|
2009-04-04 03:19:17 +08:00
|
|
|
/**
|
|
|
|
|
* BasecallingBaseModel is a class that represents the statistical
|
|
|
|
|
* model for all bases at a given cycle. It allows for easy, one
|
|
|
|
|
* pass training via the addTrainingPoint() method. Once the model
|
|
|
|
|
* is trained, computeLikelihoods will return the probability matrix
|
|
|
|
|
* over previous cycle's base hypotheses and current cycle base
|
|
|
|
|
* hypotheses (contextual prior is included in these likelihoods).
|
|
|
|
|
*
|
|
|
|
|
* @author Kiran Garimella
|
|
|
|
|
*/
|
2009-04-03 06:07:47 +08:00
|
|
|
public class BasecallingBaseModel {
|
2009-04-15 12:18:07 +08:00
|
|
|
private double[][] counts;
|
|
|
|
|
private DoubleMatrix1D[][] sums;
|
|
|
|
|
private DoubleMatrix2D[][] unscaledCovarianceSums;
|
2009-04-07 10:18:13 +08:00
|
|
|
|
2009-04-15 12:18:07 +08:00
|
|
|
private DoubleMatrix1D[][] means;
|
|
|
|
|
private DoubleMatrix2D[][] inverseCovariances;
|
|
|
|
|
private double[][] norms;
|
2009-04-03 06:07:47 +08:00
|
|
|
|
2009-04-07 10:18:13 +08:00
|
|
|
private cern.jet.math.Functions F = cern.jet.math.Functions.functions;
|
2009-04-07 09:20:15 +08:00
|
|
|
private Algebra alg;
|
2009-04-03 06:07:47 +08:00
|
|
|
|
2009-04-15 12:18:07 +08:00
|
|
|
private boolean correctForContext = false;
|
|
|
|
|
private int numTheories = 1;
|
|
|
|
|
|
2009-04-03 06:07:47 +08:00
|
|
|
private boolean readyToCall = false;
|
2009-06-09 09:00:33 +08:00
|
|
|
private boolean bustedCycle = false;
|
2009-04-03 06:07:47 +08:00
|
|
|
|
2009-04-04 03:19:17 +08:00
|
|
|
/**
|
2009-05-22 03:40:47 +08:00
|
|
|
* Constructor for BasecallingBaseModel.
|
|
|
|
|
*
|
|
|
|
|
* @param correctForContext should we attempt to correct for contextual sequence effects?
|
2009-04-04 03:19:17 +08:00
|
|
|
*/
|
2009-04-15 12:18:07 +08:00
|
|
|
public BasecallingBaseModel(boolean correctForContext) {
|
|
|
|
|
this.correctForContext = correctForContext;
|
|
|
|
|
this.numTheories = (correctForContext) ? 4 : 1;
|
|
|
|
|
|
|
|
|
|
counts = new double[this.numTheories][4];
|
2009-04-03 06:07:47 +08:00
|
|
|
|
2009-04-15 12:18:07 +08:00
|
|
|
sums = new DoubleMatrix1D[this.numTheories][4];
|
|
|
|
|
unscaledCovarianceSums = new DoubleMatrix2D[this.numTheories][4];
|
2009-04-07 10:18:13 +08:00
|
|
|
|
2009-04-15 12:18:07 +08:00
|
|
|
means = new DoubleMatrix1D[this.numTheories][4];
|
|
|
|
|
inverseCovariances = new DoubleMatrix2D[this.numTheories][4];
|
|
|
|
|
norms = new double[this.numTheories][4];
|
2009-04-03 06:07:47 +08:00
|
|
|
|
2009-04-15 12:18:07 +08:00
|
|
|
for (int basePrevIndex = 0; basePrevIndex < this.numTheories; basePrevIndex++) {
|
|
|
|
|
for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) {
|
|
|
|
|
sums[basePrevIndex][baseCurIndex] = (DoubleFactory1D.dense).make(4);
|
|
|
|
|
unscaledCovarianceSums[basePrevIndex][baseCurIndex] = (DoubleFactory2D.dense).make(4, 4);
|
2009-04-07 10:18:13 +08:00
|
|
|
|
2009-04-15 12:18:07 +08:00
|
|
|
means[basePrevIndex][baseCurIndex] = (DoubleFactory1D.dense).make(4);
|
|
|
|
|
inverseCovariances[basePrevIndex][baseCurIndex] = (DoubleFactory2D.dense).make(4, 4);
|
|
|
|
|
}
|
2009-04-03 06:07:47 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
alg = new Algebra();
|
|
|
|
|
}
|
|
|
|
|
|
2009-05-22 03:40:47 +08:00
|
|
|
|
2009-04-04 03:19:17 +08:00
|
|
|
/**
|
2009-04-07 10:18:13 +08:00
|
|
|
* Add a single training point to the model to estimate the means.
|
2009-05-22 03:40:47 +08:00
|
|
|
*
|
|
|
|
|
* @param probMatrix the matrix of probabilities for the base
|
|
|
|
|
* @param fourIntensity the four raw intensities for the base
|
2009-04-04 03:19:17 +08:00
|
|
|
*/
|
2009-05-13 03:47:41 +08:00
|
|
|
public void addMeanPoint(double[][] probMatrix, double[] fourIntensity) {
|
2009-04-15 12:18:07 +08:00
|
|
|
for (int basePrevIndex = 0; basePrevIndex < numTheories; basePrevIndex++) {
|
|
|
|
|
for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) {
|
2009-05-13 03:47:41 +08:00
|
|
|
double weight = probMatrix[basePrevIndex][baseCurIndex];
|
2009-04-03 06:07:47 +08:00
|
|
|
|
2009-05-13 03:47:41 +08:00
|
|
|
DoubleMatrix1D weightedChannelIntensities = (DoubleFactory1D.dense).make(fourIntensity);
|
2009-04-15 12:18:07 +08:00
|
|
|
weightedChannelIntensities.assign(F.mult(weight));
|
2009-04-07 06:00:58 +08:00
|
|
|
|
2009-04-15 12:18:07 +08:00
|
|
|
sums[basePrevIndex][baseCurIndex].assign(weightedChannelIntensities, F.plus);
|
|
|
|
|
counts[basePrevIndex][baseCurIndex] += weight;
|
|
|
|
|
}
|
2009-04-07 06:00:58 +08:00
|
|
|
}
|
|
|
|
|
|
2009-04-07 09:20:15 +08:00
|
|
|
readyToCall = false;
|
|
|
|
|
}
|
2009-04-03 06:07:47 +08:00
|
|
|
|
2009-04-07 10:18:13 +08:00
|
|
|
/**
|
|
|
|
|
* Add a single training point to the model to estimate the covariances.
|
2009-05-22 03:40:47 +08:00
|
|
|
*
|
|
|
|
|
* @param probMatrix the matrix of probabilities for the base
|
|
|
|
|
* @param fourIntensity the four raw intensities for the base
|
2009-04-07 10:18:13 +08:00
|
|
|
*/
|
2009-05-13 03:47:41 +08:00
|
|
|
public void addCovariancePoint(double[][] probMatrix, double[] fourIntensity) {
|
2009-04-15 12:18:07 +08:00
|
|
|
for (int basePrevIndex = 0; basePrevIndex < numTheories; basePrevIndex++) {
|
|
|
|
|
for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) {
|
2009-05-13 03:47:41 +08:00
|
|
|
double weight = probMatrix[basePrevIndex][baseCurIndex];
|
2009-04-03 06:07:47 +08:00
|
|
|
|
2009-04-15 12:18:07 +08:00
|
|
|
DoubleMatrix1D mean = sums[basePrevIndex][baseCurIndex].copy();
|
|
|
|
|
mean.assign(F.div(counts[basePrevIndex][baseCurIndex]));
|
2009-04-07 09:20:15 +08:00
|
|
|
|
2009-05-13 03:47:41 +08:00
|
|
|
DoubleMatrix1D sub = (DoubleFactory1D.dense).make(fourIntensity);
|
2009-04-15 12:18:07 +08:00
|
|
|
sub.assign(mean, F.minus);
|
2009-04-07 09:20:15 +08:00
|
|
|
|
2009-04-15 12:18:07 +08:00
|
|
|
DoubleMatrix2D cov = (DoubleFactory2D.dense).make(4, 4);
|
|
|
|
|
alg.multOuter(sub, sub, cov);
|
2009-04-07 09:20:15 +08:00
|
|
|
|
2009-04-15 12:18:07 +08:00
|
|
|
cov.assign(F.mult(weight));
|
|
|
|
|
unscaledCovarianceSums[basePrevIndex][baseCurIndex].assign(cov, F.plus);
|
|
|
|
|
}
|
2009-04-07 09:20:15 +08:00
|
|
|
}
|
2009-04-07 10:18:13 +08:00
|
|
|
|
|
|
|
|
readyToCall = false;
|
2009-04-03 06:07:47 +08:00
|
|
|
}
|
|
|
|
|
|
2009-04-04 03:19:17 +08:00
|
|
|
/**
|
|
|
|
|
* Precompute all the matrix inversions and determinants we'll need for computing the likelihood distributions.
|
|
|
|
|
*/
|
2009-04-03 06:07:47 +08:00
|
|
|
public void prepareToCallBases() {
|
2009-04-15 12:18:07 +08:00
|
|
|
for (int basePrevIndex = 0; basePrevIndex < numTheories; basePrevIndex++) {
|
|
|
|
|
for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) {
|
|
|
|
|
means[basePrevIndex][baseCurIndex] = sums[basePrevIndex][baseCurIndex].copy();
|
|
|
|
|
means[basePrevIndex][baseCurIndex].assign(F.div(counts[basePrevIndex][baseCurIndex]));
|
2009-04-07 10:18:13 +08:00
|
|
|
|
2009-04-15 12:18:07 +08:00
|
|
|
inverseCovariances[basePrevIndex][baseCurIndex] = unscaledCovarianceSums[basePrevIndex][baseCurIndex].copy();
|
|
|
|
|
inverseCovariances[basePrevIndex][baseCurIndex].assign(F.div(counts[basePrevIndex][baseCurIndex]));
|
2009-06-09 09:00:33 +08:00
|
|
|
|
|
|
|
|
if (MathUtils.compareDoubles(alg.det(inverseCovariances[basePrevIndex][baseCurIndex]), 0.0) == 0) {
|
|
|
|
|
bustedCycle = true;
|
|
|
|
|
readyToCall = true;
|
|
|
|
|
|
|
|
|
|
return;
|
|
|
|
|
}
|
|
|
|
|
|
2009-04-15 12:18:07 +08:00
|
|
|
DoubleMatrix2D invcov = alg.inverse(inverseCovariances[basePrevIndex][baseCurIndex]);
|
2009-06-09 09:00:33 +08:00
|
|
|
|
2009-04-15 12:18:07 +08:00
|
|
|
inverseCovariances[basePrevIndex][baseCurIndex] = invcov;
|
2009-04-03 06:07:47 +08:00
|
|
|
|
2009-04-15 12:18:07 +08:00
|
|
|
norms[basePrevIndex][baseCurIndex] = Math.pow(alg.det(invcov), 0.5)/Math.pow(2.0*Math.PI, 2.0);
|
|
|
|
|
}
|
2009-04-03 06:07:47 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
readyToCall = true;
|
|
|
|
|
}
|
|
|
|
|
|
2009-04-04 03:19:17 +08:00
|
|
|
/**
|
2009-05-22 03:40:47 +08:00
|
|
|
* Compute the likelihood matrix for a base.
|
2009-04-04 03:19:17 +08:00
|
|
|
*
|
|
|
|
|
* @param cycle the cycle we're calling right now
|
|
|
|
|
* @param fourintensity the four intensities of the current cycle's base
|
|
|
|
|
* @return a 4x4 matrix of likelihoods, where the row is the previous cycle base hypothesis and
|
|
|
|
|
* the column is the current cycle base hypothesis
|
|
|
|
|
*/
|
2009-04-15 12:18:07 +08:00
|
|
|
public double[][] computeLikelihoods(int cycle, double[] fourintensity) {
|
2009-04-03 06:07:47 +08:00
|
|
|
if (!readyToCall) {
|
|
|
|
|
prepareToCallBases();
|
|
|
|
|
}
|
|
|
|
|
|
2009-04-15 12:18:07 +08:00
|
|
|
double[][] likedist = new double[numTheories][4];
|
2009-04-03 06:07:47 +08:00
|
|
|
|
2009-06-09 09:00:33 +08:00
|
|
|
if (bustedCycle) {
|
|
|
|
|
likedist[0][0] = 1.0;
|
|
|
|
|
} else {
|
|
|
|
|
for (int basePrevIndex = 0; basePrevIndex < numTheories; basePrevIndex++) {
|
|
|
|
|
for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) {
|
|
|
|
|
double norm = norms[basePrevIndex][baseCurIndex];
|
|
|
|
|
|
|
|
|
|
DoubleMatrix1D sub = (DoubleFactory1D.dense).make(fourintensity);
|
|
|
|
|
sub.assign(means[basePrevIndex][baseCurIndex], F.minus);
|
2009-04-07 10:18:13 +08:00
|
|
|
|
2009-06-09 09:00:33 +08:00
|
|
|
DoubleMatrix1D Ax = alg.mult(inverseCovariances[basePrevIndex][baseCurIndex], sub);
|
|
|
|
|
double exparg = -0.5*alg.mult(sub, Ax);
|
2009-04-03 06:07:47 +08:00
|
|
|
|
2009-06-09 09:00:33 +08:00
|
|
|
likedist[basePrevIndex][baseCurIndex] = norm*Math.exp(exparg);
|
|
|
|
|
}
|
2009-04-15 12:18:07 +08:00
|
|
|
}
|
2009-04-03 06:07:47 +08:00
|
|
|
}
|
|
|
|
|
|
2009-04-07 10:18:13 +08:00
|
|
|
return likedist;
|
2009-04-03 06:07:47 +08:00
|
|
|
}
|
|
|
|
|
|
2009-05-22 03:40:47 +08:00
|
|
|
/**
|
|
|
|
|
* Write the model parameters to disk.
|
|
|
|
|
*
|
|
|
|
|
* @param outparam the file in which the output parameters should be stored
|
|
|
|
|
*/
|
2009-04-07 06:00:58 +08:00
|
|
|
public void write(File outparam) {
|
|
|
|
|
try {
|
|
|
|
|
PrintWriter writer = new PrintWriter(outparam);
|
|
|
|
|
|
2009-04-15 12:18:07 +08:00
|
|
|
for (int basePrevIndex = 0; basePrevIndex < numTheories; basePrevIndex++) {
|
|
|
|
|
for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) {
|
|
|
|
|
writer.print("mean_" + BaseUtils.baseIndexToSimpleBase(baseCurIndex) + " = c(");
|
|
|
|
|
for (int channel = 0; channel < 4; channel++) {
|
|
|
|
|
writer.print(sums[basePrevIndex][baseCurIndex].getQuick(channel)/counts[basePrevIndex][baseCurIndex]);
|
2009-04-13 03:45:33 +08:00
|
|
|
|
2009-04-15 12:18:07 +08:00
|
|
|
if (channel < 3) {
|
|
|
|
|
writer.print(", ");
|
|
|
|
|
}
|
2009-04-13 03:45:33 +08:00
|
|
|
}
|
2009-04-15 12:18:07 +08:00
|
|
|
writer.println(");");
|
2009-04-07 09:20:15 +08:00
|
|
|
|
2009-04-15 12:18:07 +08:00
|
|
|
DoubleMatrix2D cov = unscaledCovarianceSums[basePrevIndex][baseCurIndex].copy();
|
|
|
|
|
cov.assign(F.div(counts[basePrevIndex][baseCurIndex]));
|
2009-04-07 09:20:15 +08:00
|
|
|
|
2009-04-15 12:18:07 +08:00
|
|
|
writer.print("cov_" + BaseUtils.baseIndexToSimpleBase(baseCurIndex) + " = matrix(c(");
|
|
|
|
|
for (int channel1 = 0; channel1 < 4; channel1++) {
|
|
|
|
|
for (int channel2 = 0; channel2 < 4; channel2++) {
|
|
|
|
|
writer.print(cov.get(channel2, channel1) + (channel1 == 3 && channel2 == 3 ? "" : ","));
|
|
|
|
|
}
|
2009-04-13 03:45:33 +08:00
|
|
|
}
|
2009-04-15 12:18:07 +08:00
|
|
|
writer.println("), nr=4, nc=4);\n");
|
2009-04-13 03:45:33 +08:00
|
|
|
}
|
2009-04-07 06:00:58 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
writer.close();
|
|
|
|
|
} catch (IOException e) {
|
|
|
|
|
}
|
|
|
|
|
}
|
2009-04-03 06:07:47 +08:00
|
|
|
}
|