diff --git a/java/src/org/broadinstitute/sting/playground/fourbasecaller/BasecallingBaseModel.java b/java/src/org/broadinstitute/sting/playground/fourbasecaller/BasecallingBaseModel.java index 8eb3918a1..11c50b592 100755 --- a/java/src/org/broadinstitute/sting/playground/fourbasecaller/BasecallingBaseModel.java +++ b/java/src/org/broadinstitute/sting/playground/fourbasecaller/BasecallingBaseModel.java @@ -23,9 +23,13 @@ import java.io.*; public class BasecallingBaseModel { private double[] counts; private DoubleMatrix1D[] sums; + private DoubleMatrix2D[] unscaledCovarianceSums; + + private DoubleMatrix1D[] means; private DoubleMatrix2D[] inverseCovariances; private double[] norms; + private cern.jet.math.Functions F = cern.jet.math.Functions.functions; private Algebra alg; private boolean readyToCall = false; @@ -37,11 +41,17 @@ public class BasecallingBaseModel { counts = new double[4]; sums = new DoubleMatrix1D[4]; + unscaledCovarianceSums = new DoubleMatrix2D[4]; + + means = new DoubleMatrix1D[4]; inverseCovariances = new DoubleMatrix2D[4]; norms = new double[4]; for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) { sums[baseCurIndex] = (DoubleFactory1D.dense).make(4); + unscaledCovarianceSums[baseCurIndex] = (DoubleFactory2D.dense).make(4, 4); + + means[baseCurIndex] = (DoubleFactory1D.dense).make(4); inverseCovariances[baseCurIndex] = (DoubleFactory2D.dense).make(4, 4); } @@ -49,7 +59,7 @@ public class BasecallingBaseModel { } /** - * Add a single training point to the model. + * Add a single training point to the model to estimate the means. * * @param baseCur the current cycle's base call (A, C, G, T) * @param qualCur the quality score for the current cycle's base call @@ -59,8 +69,6 @@ public class BasecallingBaseModel { int actualBaseCurIndex = baseToBaseIndex(baseCur); double actualWeight = QualityUtils.qualToProb(qualCur); - cern.jet.math.Functions F = cern.jet.math.Functions.functions; - for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) { // We want to upweight the correct theory as much as we can and spread the remainder out evenly between all other hypotheses. double weight = (baseCurIndex == actualBaseCurIndex) ? actualWeight : ((1.0 - actualWeight)/3.0); @@ -75,12 +83,17 @@ public class BasecallingBaseModel { readyToCall = false; } + /** + * Add a single training point to the model to estimate the covariances. + * + * @param baseCur the current cycle's base call (A, C, G, T) + * @param qualCur the quality score for the current cycle's base call + * @param fourintensity the four intensities for the current cycle's base call + */ public void addCovariancePoint(char baseCur, byte qualCur, double[] fourintensity) { int actualBaseCurIndex = baseToBaseIndex(baseCur); double actualWeight = QualityUtils.qualToProb(qualCur); - cern.jet.math.Functions F = cern.jet.math.Functions.functions; - for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) { // We want to upweight the correct theory as much as we can and spread the remainder out evenly between all other hypotheses. double weight = (baseCurIndex == actualBaseCurIndex) ? actualWeight : ((1.0 - actualWeight)/3.0); @@ -95,73 +108,58 @@ public class BasecallingBaseModel { alg.multOuter(sub, sub, cov); cov.assign(F.mult(weight)); - inverseCovariances[baseCurIndex].assign(cov, F.plus); + unscaledCovarianceSums[baseCurIndex].assign(cov, F.plus); } + + readyToCall = false; } /** * Precompute all the matrix inversions and determinants we'll need for computing the likelihood distributions. */ public void prepareToCallBases() { - /* - for (int basePrevIndex = 0; basePrevIndex < 4; basePrevIndex++) { - for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) { - for (int channel = 0; channel < 4; channel++) { - sums[baseCurIndex].setQuick(channel, runningChannelSums[basePrevIndex][baseCurIndex].getQuick(channel)/counts[basePrevIndex][baseCurIndex]); - - for (int cochannel = 0; cochannel < 4; cochannel++) { - // Cov(Xi, Xj) = E(XiXj) - E(Xi)E(Xj) - inverseCovariances[basePrevIndex][baseCurIndex].setQuick(channel, cochannel, (runningChannelProductSums[basePrevIndex][baseCurIndex].getQuick(channel, cochannel)/counts[basePrevIndex][baseCurIndex]) - (runningChannelSums[basePrevIndex][baseCurIndex].getQuick(channel)/counts[basePrevIndex][baseCurIndex])*(runningChannelSums[basePrevIndex][baseCurIndex].getQuick(cochannel)/counts[basePrevIndex][baseCurIndex])); - } - } + for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) { + means[baseCurIndex] = sums[baseCurIndex].copy(); + means[baseCurIndex].assign(F.div(counts[baseCurIndex])); - DoubleMatrix2D invcov = alg.inverse(inverseCovariances[basePrevIndex][baseCurIndex]); - inverseCovariances[basePrevIndex][baseCurIndex] = invcov; - norms[basePrevIndex][baseCurIndex] = Math.pow(alg.det(invcov), 0.5)/Math.pow(2.0*Math.PI, 2.0); - } + inverseCovariances[baseCurIndex] = unscaledCovarianceSums[baseCurIndex].copy(); + inverseCovariances[baseCurIndex].assign(F.div(counts[baseCurIndex])); + DoubleMatrix2D invcov = alg.inverse(inverseCovariances[baseCurIndex]); + inverseCovariances[baseCurIndex] = invcov; + + norms[baseCurIndex] = Math.pow(alg.det(invcov), 0.5)/Math.pow(2.0*Math.PI, 2.0); } - */ readyToCall = true; } /** - * Compute the likelihood matrix for a base (contextual priors included). + * Compute the likelihood matrix for a base * * @param cycle the cycle we're calling right now - * @param basePrev the previous cycle's base - * @param qualPrev the previous cycle's quality score * @param fourintensity the four intensities of the current cycle's base * @return a 4x4 matrix of likelihoods, where the row is the previous cycle base hypothesis and * the column is the current cycle base hypothesis */ - public double[][] computeLikelihoods(int cycle, char basePrev, byte qualPrev, double[] fourintensity) { + public double[] computeLikelihoods(int cycle, double[] fourintensity) { if (!readyToCall) { prepareToCallBases(); } - double[][] probdist = new double[4][4]; - /* - double probPrev = (cycle == 0) ? 1.0 : QualityUtils.qualToProb(qualPrev); - int baseIndex = (cycle == 0) ? 0 : baseToBaseIndex(basePrev); + double[] likedist = new double[4]; + for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) { + double norm = norms[baseCurIndex]; - for (int basePrevIndex = 0; basePrevIndex < ((cycle == 0) ? 1 : 4); basePrevIndex++) { - for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) { - double[] diff = new double[4]; - for (int channel = 0; channel < 4; channel++) { - diff[channel] = fourintensity[channel] - sums[basePrevIndex][baseCurIndex].getQuick(channel); - } - - DoubleMatrix1D sub = (DoubleFactory1D.dense).make(diff); - DoubleMatrix1D Ax = alg.mult(inverseCovariances[basePrevIndex][baseCurIndex], sub); + DoubleMatrix1D sub = (DoubleFactory1D.dense).make(fourintensity); + sub.assign(means[baseCurIndex], F.minus); - double exparg = -0.5*alg.mult(sub, Ax); - probdist[basePrevIndex][baseCurIndex] = (baseIndex == basePrevIndex ? probPrev : 1.0 - probPrev)*norms[basePrevIndex][baseCurIndex]*Math.exp(exparg); - } + DoubleMatrix1D Ax = alg.mult(inverseCovariances[baseCurIndex], sub); + double exparg = -0.5*alg.mult(sub, Ax); + + likedist[baseCurIndex] = norm*Math.exp(exparg); } - */ - return probdist; + return likedist; } public void write(File outparam) { @@ -176,8 +174,7 @@ public class BasecallingBaseModel { } writer.print("] (" + counts[baseCurIndex] + ")\n"); - DoubleMatrix2D cov = inverseCovariances[baseCurIndex].copy(); - cern.jet.math.Functions F = cern.jet.math.Functions.functions; + DoubleMatrix2D cov = unscaledCovarianceSums[baseCurIndex].copy(); cov.assign(F.div(counts[baseCurIndex])); writer.println("cov_" + baseIndexToBase(baseCurIndex) + " : " + cov + "\n"); diff --git a/java/src/org/broadinstitute/sting/playground/fourbasecaller/BasecallingReadModel.java b/java/src/org/broadinstitute/sting/playground/fourbasecaller/BasecallingReadModel.java index 0069c2e3b..c2cf83a67 100644 --- a/java/src/org/broadinstitute/sting/playground/fourbasecaller/BasecallingReadModel.java +++ b/java/src/org/broadinstitute/sting/playground/fourbasecaller/BasecallingReadModel.java @@ -48,45 +48,30 @@ public class BasecallingReadModel { * Compute the likelihood matrix for a given cycle. * * @param cycle the cycle number for the current base - * @param basePrev the previous cycle's base - * @param qualPrev the quality score for the previous cycle's base * @param fourintensity the four intensities for the current cycle's base * @return 4x4 matrix of likelihoods */ - public double[][] computeLikelihoods(int cycle, char basePrev, byte qualPrev, double[] fourintensity) { - return basemodels[cycle].computeLikelihoods(cycle, basePrev, qualPrev, fourintensity); + public double[] computeLikelihoods(int cycle, double[] fourintensity) { + return basemodels[cycle].computeLikelihoods(cycle, fourintensity); } /** * Compute the probability distribution for the base at a given cycle. - * Contextual components of the likelihood matrix are marginalized out. * * @param cycle the cycle number for the current base - * @param basePrev the previous cycle's base - * @param qualPrev the quality score for the previous cycle's base * @param fourintensity the four intensities for the current cycle's base * @return an instance of FourProb, which encodes a base hypothesis, its probability, * and the ranking among the other hypotheses */ - public FourProb computeProbabilities(int cycle, char basePrev, byte qualPrev, double[] fourintensity) { - double[][] likes = computeLikelihoods(cycle, basePrev, qualPrev, fourintensity); + public FourProb computeProbabilities(int cycle, double[] fourintensity) { + double[] likes = computeLikelihoods(cycle, fourintensity); - double[] probs = new double[4]; - int[] baseindices = { 0, 1, 2, 3 }; double total = 0; - for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) { - for (int basePrevIndex = 0; basePrevIndex < 4; basePrevIndex++) { - probs[baseCurIndex] += likes[basePrevIndex][baseCurIndex]; - } - total += probs[baseCurIndex]; - } + for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) { total += likes[baseCurIndex]; } + for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) { likes[baseCurIndex] /= total; } - for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) { - probs[baseCurIndex] /= total; - } - - return new FourProb(baseindices, probs); + return new FourProb(likes); } /** diff --git a/java/src/org/broadinstitute/sting/playground/fourbasecaller/FourBaseRecaller.java b/java/src/org/broadinstitute/sting/playground/fourbasecaller/FourBaseRecaller.java index da79727bf..5d777d712 100644 --- a/java/src/org/broadinstitute/sting/playground/fourbasecaller/FourBaseRecaller.java +++ b/java/src/org/broadinstitute/sting/playground/fourbasecaller/FourBaseRecaller.java @@ -1,6 +1,9 @@ package org.broadinstitute.sting.playground.fourbasecaller; import java.io.File; +import java.io.FileWriter; +import java.io.IOException; +import java.io.PrintWriter; import org.broadinstitute.sting.utils.cmdLine.CommandLineProgram; import org.broadinstitute.sting.utils.QualityUtils; @@ -38,6 +41,15 @@ public class FourBaseRecaller extends CommandLineProgram { protected int execute() { boolean isPaired = (END > 0); + // Set up debugging paths + File debugdir = new File(OUT.getPath() + ".debug/"); + debugdir.mkdir(); + PrintWriter debugout = null; + try { + debugout = new PrintWriter(debugdir.getPath() + "/debug.out"); + } catch (IOException e) { + } + BustardFileParser bfp; BustardReadData bread; @@ -49,7 +61,7 @@ public class FourBaseRecaller extends CommandLineProgram { int queryid; // learn mean parameters - //System.out.println("intensity int_a int_c int_g int_t base"); + if (debugout != null) { debugout.println("intensity int_a int_c int_g int_t base"); } queryid = 0; do { @@ -62,11 +74,9 @@ public class FourBaseRecaller extends CommandLineProgram { byte qualCur = quals[cycle]; double[] fourintensity = intensities[cycle + cycle_offset]; - /* - if (cycle == 0) { - System.out.println("intensity " + intensities[0][0] + " " + intensities[0][1] + " " + intensities[0][2] + " " + intensities[0][3] + " " + baseCur); + if (debugout != null && cycle == 0) { + debugout.println("intensity " + intensities[0][0] + " " + intensities[0][1] + " " + intensities[0][2] + " " + intensities[0][3] + " " + baseCur); } - */ model.addMeanPoint(cycle, baseCur, qualCur, fourintensity); } @@ -96,12 +106,9 @@ public class FourBaseRecaller extends CommandLineProgram { } while (queryid < TRAINING_LIMIT && bfp.hasNext() && (bread = bfp.next()) != null); // write debugging info - File debugout = new File(OUT.getParentFile().getPath() + "/model/"); - debugout.mkdir(); - model.write(debugout); + model.write(debugdir); // call bases - /* SAMFileHeader sfh = new SAMFileHeader(); SAMFileWriter sfw = new SAMFileWriterFactory().makeSAMOrBAMWriter(sfh, false, OUT); @@ -119,11 +126,13 @@ public class FourBaseRecaller extends CommandLineProgram { byte[] nextbestqual = new byte[bases.length()]; for (int cycle = 0; cycle < bases.length(); cycle++) { - char basePrev = (cycle == 0) ? '*' : (char) asciiseq[cycle - 1]; - byte qualPrev = (cycle == 0) ? 0 : bestqual[cycle - 1]; double[] fourintensity = intensities[cycle + cycle_offset]; - FourProb fp = model.computeProbabilities(cycle, basePrev, qualPrev, fourintensity); + FourProb fp = model.computeProbabilities(cycle, fourintensity); + + //if (cycle == 0) { + // System.out.println("result " + intensities[0][0] + " " + intensities[0][1] + " " + intensities[0][2] + " " + intensities[0][3] + " " + bases.charAt(0) + " " + fp.toString()); + //} asciiseq[cycle] = (byte) fp.baseAtRank(0); bestqual[cycle] = fp.qualAtRank(0); @@ -137,7 +146,6 @@ public class FourBaseRecaller extends CommandLineProgram { } while (queryid < CALLING_LIMIT && bfp.hasNext() && (bread = bfp.next()) != null); sfw.close(); - */ return 0; } diff --git a/java/src/org/broadinstitute/sting/playground/fourbasecaller/FourProb.java b/java/src/org/broadinstitute/sting/playground/fourbasecaller/FourProb.java index 065abefe4..64ff062b3 100755 --- a/java/src/org/broadinstitute/sting/playground/fourbasecaller/FourProb.java +++ b/java/src/org/broadinstitute/sting/playground/fourbasecaller/FourProb.java @@ -15,10 +15,11 @@ public class FourProb { /** * Constructor for FourProb. * - * @param baseIndices the unsorted base indices (A:0, C:1, G:2, T:3). Now that I think about it, this is stupid. - * @param baseProbs the unsorted base hypothesis probabilities. + * @param baseProbs the unsorted base hypothesis probabilities (in ACGT order). */ - public FourProb(int[] baseIndices, double[] baseProbs) { + public FourProb(double[] baseProbs) { + int[] baseIndices = {0, 1, 2, 3}; + Integer[] perm = Utils.SortPermutation(baseProbs); double[] ascendingBaseProbs = Utils.PermuteArray(baseProbs, perm); int[] ascendingBaseIndices = Utils.PermuteArray(baseIndices, perm);