2009-04-03 06:07:47 +08:00
|
|
|
package org.broadinstitute.sting.playground.fourbasecaller;
|
|
|
|
|
|
|
|
|
|
import cern.colt.matrix.DoubleMatrix1D;
|
|
|
|
|
import cern.colt.matrix.DoubleFactory1D;
|
|
|
|
|
import cern.colt.matrix.DoubleMatrix2D;
|
|
|
|
|
import cern.colt.matrix.DoubleFactory2D;
|
|
|
|
|
import cern.colt.matrix.linalg.Algebra;
|
|
|
|
|
|
|
|
|
|
import org.broadinstitute.sting.utils.QualityUtils;
|
|
|
|
|
|
2009-04-07 06:00:58 +08:00
|
|
|
import java.io.*;
|
|
|
|
|
|
2009-04-04 03:19:17 +08:00
|
|
|
/**
|
|
|
|
|
* BasecallingBaseModel is a class that represents the statistical
|
|
|
|
|
* model for all bases at a given cycle. It allows for easy, one
|
|
|
|
|
* pass training via the addTrainingPoint() method. Once the model
|
|
|
|
|
* is trained, computeLikelihoods will return the probability matrix
|
|
|
|
|
* over previous cycle's base hypotheses and current cycle base
|
|
|
|
|
* hypotheses (contextual prior is included in these likelihoods).
|
|
|
|
|
*
|
|
|
|
|
* @author Kiran Garimella
|
|
|
|
|
*/
|
2009-04-03 06:07:47 +08:00
|
|
|
public class BasecallingBaseModel {
|
2009-04-07 09:20:15 +08:00
|
|
|
private double[] counts;
|
|
|
|
|
private DoubleMatrix1D[] sums;
|
|
|
|
|
private DoubleMatrix2D[] inverseCovariances;
|
|
|
|
|
private double[] norms;
|
2009-04-03 06:07:47 +08:00
|
|
|
|
2009-04-07 09:20:15 +08:00
|
|
|
private Algebra alg;
|
2009-04-03 06:07:47 +08:00
|
|
|
|
|
|
|
|
private boolean readyToCall = false;
|
|
|
|
|
|
2009-04-04 03:19:17 +08:00
|
|
|
/**
|
|
|
|
|
* Constructor for BasecallingBaseModel
|
|
|
|
|
*/
|
2009-04-03 06:07:47 +08:00
|
|
|
public BasecallingBaseModel() {
|
2009-04-07 09:20:15 +08:00
|
|
|
counts = new double[4];
|
2009-04-03 06:07:47 +08:00
|
|
|
|
2009-04-07 09:20:15 +08:00
|
|
|
sums = new DoubleMatrix1D[4];
|
|
|
|
|
inverseCovariances = new DoubleMatrix2D[4];
|
|
|
|
|
norms = new double[4];
|
2009-04-03 06:07:47 +08:00
|
|
|
|
2009-04-07 09:20:15 +08:00
|
|
|
for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) {
|
|
|
|
|
sums[baseCurIndex] = (DoubleFactory1D.dense).make(4);
|
|
|
|
|
inverseCovariances[baseCurIndex] = (DoubleFactory2D.dense).make(4, 4);
|
2009-04-03 06:07:47 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
alg = new Algebra();
|
|
|
|
|
}
|
|
|
|
|
|
2009-04-04 03:19:17 +08:00
|
|
|
/**
|
|
|
|
|
* Add a single training point to the model.
|
2009-04-07 06:00:58 +08:00
|
|
|
*
|
2009-04-04 03:19:17 +08:00
|
|
|
* @param baseCur the current cycle's base call (A, C, G, T)
|
|
|
|
|
* @param qualCur the quality score for the current cycle's base call
|
|
|
|
|
* @param fourintensity the four intensities for the current cycle's base call
|
|
|
|
|
*/
|
2009-04-07 09:20:15 +08:00
|
|
|
public void addMeanPoint(char baseCur, byte qualCur, double[] fourintensity) {
|
2009-04-07 06:00:58 +08:00
|
|
|
int actualBaseCurIndex = baseToBaseIndex(baseCur);
|
|
|
|
|
double actualWeight = QualityUtils.qualToProb(qualCur);
|
|
|
|
|
|
|
|
|
|
cern.jet.math.Functions F = cern.jet.math.Functions.functions;
|
|
|
|
|
|
2009-04-07 09:20:15 +08:00
|
|
|
for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) {
|
|
|
|
|
// We want to upweight the correct theory as much as we can and spread the remainder out evenly between all other hypotheses.
|
|
|
|
|
double weight = (baseCurIndex == actualBaseCurIndex) ? actualWeight : ((1.0 - actualWeight)/3.0);
|
2009-04-03 06:07:47 +08:00
|
|
|
|
2009-04-07 09:20:15 +08:00
|
|
|
DoubleMatrix1D weightedChannelIntensities = (DoubleFactory1D.dense).make(fourintensity);
|
|
|
|
|
weightedChannelIntensities.assign(F.mult(weight));
|
2009-04-07 06:00:58 +08:00
|
|
|
|
2009-04-07 09:20:15 +08:00
|
|
|
sums[baseCurIndex].assign(weightedChannelIntensities, F.plus);
|
|
|
|
|
counts[baseCurIndex] += weight;
|
2009-04-07 06:00:58 +08:00
|
|
|
}
|
|
|
|
|
|
2009-04-07 09:20:15 +08:00
|
|
|
readyToCall = false;
|
|
|
|
|
}
|
2009-04-03 06:07:47 +08:00
|
|
|
|
2009-04-07 09:20:15 +08:00
|
|
|
public void addCovariancePoint(char baseCur, byte qualCur, double[] fourintensity) {
|
|
|
|
|
int actualBaseCurIndex = baseToBaseIndex(baseCur);
|
|
|
|
|
double actualWeight = QualityUtils.qualToProb(qualCur);
|
2009-04-03 06:07:47 +08:00
|
|
|
|
2009-04-07 09:20:15 +08:00
|
|
|
cern.jet.math.Functions F = cern.jet.math.Functions.functions;
|
2009-04-03 06:07:47 +08:00
|
|
|
|
2009-04-07 09:20:15 +08:00
|
|
|
for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) {
|
|
|
|
|
// We want to upweight the correct theory as much as we can and spread the remainder out evenly between all other hypotheses.
|
|
|
|
|
double weight = (baseCurIndex == actualBaseCurIndex) ? actualWeight : ((1.0 - actualWeight)/3.0);
|
2009-04-03 06:07:47 +08:00
|
|
|
|
2009-04-07 09:20:15 +08:00
|
|
|
DoubleMatrix1D mean = sums[baseCurIndex].copy();
|
|
|
|
|
mean.assign(F.div(counts[baseCurIndex]));
|
|
|
|
|
|
|
|
|
|
DoubleMatrix1D sub = (DoubleFactory1D.dense).make(fourintensity);
|
|
|
|
|
sub.assign(mean, F.minus);
|
|
|
|
|
|
|
|
|
|
DoubleMatrix2D cov = (DoubleFactory2D.dense).make(4, 4);
|
|
|
|
|
alg.multOuter(sub, sub, cov);
|
|
|
|
|
|
|
|
|
|
cov.assign(F.mult(weight));
|
|
|
|
|
inverseCovariances[baseCurIndex].assign(cov, F.plus);
|
|
|
|
|
}
|
2009-04-03 06:07:47 +08:00
|
|
|
}
|
|
|
|
|
|
2009-04-04 03:19:17 +08:00
|
|
|
/**
|
|
|
|
|
* Precompute all the matrix inversions and determinants we'll need for computing the likelihood distributions.
|
|
|
|
|
*/
|
2009-04-03 06:07:47 +08:00
|
|
|
public void prepareToCallBases() {
|
2009-04-07 09:20:15 +08:00
|
|
|
/*
|
2009-04-03 06:07:47 +08:00
|
|
|
for (int basePrevIndex = 0; basePrevIndex < 4; basePrevIndex++) {
|
|
|
|
|
for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) {
|
|
|
|
|
for (int channel = 0; channel < 4; channel++) {
|
2009-04-07 09:20:15 +08:00
|
|
|
sums[baseCurIndex].setQuick(channel, runningChannelSums[basePrevIndex][baseCurIndex].getQuick(channel)/counts[basePrevIndex][baseCurIndex]);
|
2009-04-03 06:07:47 +08:00
|
|
|
|
|
|
|
|
for (int cochannel = 0; cochannel < 4; cochannel++) {
|
|
|
|
|
// Cov(Xi, Xj) = E(XiXj) - E(Xi)E(Xj)
|
|
|
|
|
inverseCovariances[basePrevIndex][baseCurIndex].setQuick(channel, cochannel, (runningChannelProductSums[basePrevIndex][baseCurIndex].getQuick(channel, cochannel)/counts[basePrevIndex][baseCurIndex]) - (runningChannelSums[basePrevIndex][baseCurIndex].getQuick(channel)/counts[basePrevIndex][baseCurIndex])*(runningChannelSums[basePrevIndex][baseCurIndex].getQuick(cochannel)/counts[basePrevIndex][baseCurIndex]));
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
DoubleMatrix2D invcov = alg.inverse(inverseCovariances[basePrevIndex][baseCurIndex]);
|
|
|
|
|
inverseCovariances[basePrevIndex][baseCurIndex] = invcov;
|
|
|
|
|
norms[basePrevIndex][baseCurIndex] = Math.pow(alg.det(invcov), 0.5)/Math.pow(2.0*Math.PI, 2.0);
|
|
|
|
|
}
|
|
|
|
|
}
|
2009-04-07 09:20:15 +08:00
|
|
|
*/
|
2009-04-03 06:07:47 +08:00
|
|
|
|
|
|
|
|
readyToCall = true;
|
|
|
|
|
}
|
|
|
|
|
|
2009-04-04 03:19:17 +08:00
|
|
|
/**
|
|
|
|
|
* Compute the likelihood matrix for a base (contextual priors included).
|
|
|
|
|
*
|
|
|
|
|
* @param cycle the cycle we're calling right now
|
|
|
|
|
* @param basePrev the previous cycle's base
|
|
|
|
|
* @param qualPrev the previous cycle's quality score
|
|
|
|
|
* @param fourintensity the four intensities of the current cycle's base
|
|
|
|
|
* @return a 4x4 matrix of likelihoods, where the row is the previous cycle base hypothesis and
|
|
|
|
|
* the column is the current cycle base hypothesis
|
|
|
|
|
*/
|
2009-04-03 06:07:47 +08:00
|
|
|
public double[][] computeLikelihoods(int cycle, char basePrev, byte qualPrev, double[] fourintensity) {
|
|
|
|
|
if (!readyToCall) {
|
|
|
|
|
prepareToCallBases();
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
double[][] probdist = new double[4][4];
|
2009-04-07 09:20:15 +08:00
|
|
|
/*
|
2009-04-03 23:47:47 +08:00
|
|
|
double probPrev = (cycle == 0) ? 1.0 : QualityUtils.qualToProb(qualPrev);
|
|
|
|
|
int baseIndex = (cycle == 0) ? 0 : baseToBaseIndex(basePrev);
|
2009-04-03 06:07:47 +08:00
|
|
|
|
|
|
|
|
for (int basePrevIndex = 0; basePrevIndex < ((cycle == 0) ? 1 : 4); basePrevIndex++) {
|
|
|
|
|
for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) {
|
|
|
|
|
double[] diff = new double[4];
|
|
|
|
|
for (int channel = 0; channel < 4; channel++) {
|
2009-04-07 09:20:15 +08:00
|
|
|
diff[channel] = fourintensity[channel] - sums[basePrevIndex][baseCurIndex].getQuick(channel);
|
2009-04-03 06:07:47 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
DoubleMatrix1D sub = (DoubleFactory1D.dense).make(diff);
|
|
|
|
|
DoubleMatrix1D Ax = alg.mult(inverseCovariances[basePrevIndex][baseCurIndex], sub);
|
|
|
|
|
|
|
|
|
|
double exparg = -0.5*alg.mult(sub, Ax);
|
2009-04-03 23:47:47 +08:00
|
|
|
probdist[basePrevIndex][baseCurIndex] = (baseIndex == basePrevIndex ? probPrev : 1.0 - probPrev)*norms[basePrevIndex][baseCurIndex]*Math.exp(exparg);
|
2009-04-03 06:07:47 +08:00
|
|
|
}
|
|
|
|
|
}
|
2009-04-07 09:20:15 +08:00
|
|
|
*/
|
2009-04-03 06:07:47 +08:00
|
|
|
|
|
|
|
|
return probdist;
|
|
|
|
|
}
|
|
|
|
|
|
2009-04-07 06:00:58 +08:00
|
|
|
public void write(File outparam) {
|
|
|
|
|
try {
|
|
|
|
|
PrintWriter writer = new PrintWriter(outparam);
|
|
|
|
|
|
2009-04-07 09:20:15 +08:00
|
|
|
for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) {
|
|
|
|
|
writer.print("mean_" + baseIndexToBase(baseCurIndex) + " : [ ");
|
|
|
|
|
for (int channel = 0; channel < 4; channel++) {
|
|
|
|
|
writer.print(sums[baseCurIndex].getQuick(channel)/counts[baseCurIndex]);
|
|
|
|
|
writer.print(" ");
|
2009-04-07 06:00:58 +08:00
|
|
|
}
|
2009-04-07 09:20:15 +08:00
|
|
|
writer.print("] (" + counts[baseCurIndex] + ")\n");
|
|
|
|
|
|
|
|
|
|
DoubleMatrix2D cov = inverseCovariances[baseCurIndex].copy();
|
|
|
|
|
cern.jet.math.Functions F = cern.jet.math.Functions.functions;
|
|
|
|
|
cov.assign(F.div(counts[baseCurIndex]));
|
|
|
|
|
|
|
|
|
|
writer.println("cov_" + baseIndexToBase(baseCurIndex) + " : " + cov + "\n");
|
2009-04-07 06:00:58 +08:00
|
|
|
}
|
|
|
|
|
|
|
|
|
|
writer.close();
|
|
|
|
|
} catch (IOException e) {
|
|
|
|
|
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
2009-04-04 03:19:17 +08:00
|
|
|
/**
|
|
|
|
|
* Utility method for converting a base ([Aa*], [Cc], [Gg], [Tt]) to an index (0, 1, 2, 3);
|
2009-04-07 09:20:15 +08:00
|
|
|
*
|
2009-04-04 03:19:17 +08:00
|
|
|
* @param base
|
|
|
|
|
* @return 0, 1, 2, 3, or -1 if the base can't be understood.
|
|
|
|
|
*/
|
2009-04-03 06:07:47 +08:00
|
|
|
private int baseToBaseIndex(char base) {
|
|
|
|
|
switch (base) {
|
|
|
|
|
case 'A':
|
|
|
|
|
case 'a':
|
|
|
|
|
case '*': return 0;
|
|
|
|
|
|
|
|
|
|
case 'C':
|
|
|
|
|
case 'c': return 1;
|
|
|
|
|
|
|
|
|
|
case 'G':
|
|
|
|
|
case 'g': return 2;
|
|
|
|
|
|
|
|
|
|
case 'T':
|
|
|
|
|
case 't': return 3;
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
return -1;
|
|
|
|
|
}
|
2009-04-07 06:00:58 +08:00
|
|
|
|
|
|
|
|
private char baseIndexToBase(int baseIndex) {
|
|
|
|
|
switch (baseIndex) {
|
|
|
|
|
case 0: return 'A';
|
|
|
|
|
case 1: return 'C';
|
|
|
|
|
case 2: return 'G';
|
|
|
|
|
case 3: return 'T';
|
|
|
|
|
default: return '.';
|
|
|
|
|
}
|
|
|
|
|
}
|
2009-04-03 06:07:47 +08:00
|
|
|
}
|