From 9be978e0068c760a83ed0dbb2a4dc58cc115e17e Mon Sep 17 00:00:00 2001 From: kiran Date: Tue, 7 Apr 2009 01:20:15 +0000 Subject: [PATCH] Intermediate commit (debugging info). git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@309 348d0f76-0448-11de-a6fe-93d51630548a --- .../fourbasecaller/BasecallingBaseModel.java | 129 +++++++++--------- .../fourbasecaller/BasecallingReadModel.java | 9 +- .../fourbasecaller/FourBaseRecaller.java | 37 ++++- 3 files changed, 104 insertions(+), 71 deletions(-) diff --git a/java/src/org/broadinstitute/sting/playground/fourbasecaller/BasecallingBaseModel.java b/java/src/org/broadinstitute/sting/playground/fourbasecaller/BasecallingBaseModel.java index 41cee7de4..8eb3918a1 100755 --- a/java/src/org/broadinstitute/sting/playground/fourbasecaller/BasecallingBaseModel.java +++ b/java/src/org/broadinstitute/sting/playground/fourbasecaller/BasecallingBaseModel.java @@ -21,39 +21,28 @@ import java.io.*; * @author Kiran Garimella */ public class BasecallingBaseModel { - private double[][] counts; - - private DoubleMatrix1D[][] runningChannelSums; - private DoubleMatrix2D[][] runningChannelProductSums; - - private boolean readyToCall = false; - private DoubleMatrix1D[][] means; - private DoubleMatrix2D[][] inverseCovariances; - private double[][] norms; + private double[] counts; + private DoubleMatrix1D[] sums; + private DoubleMatrix2D[] inverseCovariances; + private double[] norms; private Algebra alg; + private boolean readyToCall = false; + /** * Constructor for BasecallingBaseModel */ public BasecallingBaseModel() { - counts = new double[4][4]; + counts = new double[4]; - runningChannelSums = new DoubleMatrix1D[4][4]; - runningChannelProductSums = new DoubleMatrix2D[4][4]; + sums = new DoubleMatrix1D[4]; + inverseCovariances = new DoubleMatrix2D[4]; + norms = new double[4]; - means = new DoubleMatrix1D[4][4]; - inverseCovariances = new DoubleMatrix2D[4][4]; - norms = new double[4][4]; - - for (int basePrevIndex = 0; basePrevIndex < 4; basePrevIndex++) { - for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) { - runningChannelSums[basePrevIndex][baseCurIndex] = (DoubleFactory1D.dense).make(4); - runningChannelProductSums[basePrevIndex][baseCurIndex] = (DoubleFactory2D.dense).make(4, 4); - - means[basePrevIndex][baseCurIndex] = (DoubleFactory1D.dense).make(4); - inverseCovariances[basePrevIndex][baseCurIndex] = (DoubleFactory2D.dense).make(4, 4); - } + for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) { + sums[baseCurIndex] = (DoubleFactory1D.dense).make(4); + inverseCovariances[baseCurIndex] = (DoubleFactory2D.dense).make(4, 4); } alg = new Algebra(); @@ -62,61 +51,63 @@ public class BasecallingBaseModel { /** * Add a single training point to the model. * - * @param basePrev the previous cycle's base call (A, C, G, T, or * for the first cycle) * @param baseCur the current cycle's base call (A, C, G, T) * @param qualCur the quality score for the current cycle's base call * @param fourintensity the four intensities for the current cycle's base call */ - public void addTrainingPoint(char basePrev, char baseCur, byte qualCur, double[] fourintensity) { - int actualBasePrevIndex = baseToBaseIndex(basePrev); + public void addMeanPoint(char baseCur, byte qualCur, double[] fourintensity) { int actualBaseCurIndex = baseToBaseIndex(baseCur); double actualWeight = QualityUtils.qualToProb(qualCur); - double otherTheories = (basePrev == '*') ? 3.0 : 15.0; cern.jet.math.Functions F = cern.jet.math.Functions.functions; - for (int basePrevIndex = 0; basePrevIndex < 4; basePrevIndex++) { - for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) { - // We want to upweight the correct theory as much as we can and spread the remainder out evenly between all other hypotheses. - double weight = (basePrevIndex == actualBasePrevIndex && baseCurIndex == actualBaseCurIndex) ? actualWeight : ((1.0 - actualWeight)/otherTheories); + for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) { + // We want to upweight the correct theory as much as we can and spread the remainder out evenly between all other hypotheses. + double weight = (baseCurIndex == actualBaseCurIndex) ? actualWeight : ((1.0 - actualWeight)/3.0); - DoubleMatrix1D weightedChannelIntensities = (DoubleFactory1D.dense).make(fourintensity); - weightedChannelIntensities.assign(F.mult(weight)); + DoubleMatrix1D weightedChannelIntensities = (DoubleFactory1D.dense).make(fourintensity); + weightedChannelIntensities.assign(F.mult(weight)); - runningChannelSums[basePrevIndex][baseCurIndex].assign(weightedChannelIntensities, F.plus); - counts[basePrevIndex][baseCurIndex] += weight; - } + sums[baseCurIndex].assign(weightedChannelIntensities, F.plus); + counts[baseCurIndex] += weight; } - /* - if (basePrevIndex >= 0 && baseCurIndex >= 0) { - for (int channel = 0; channel < 4; channel++) { - double weight = QualityUtils.qualToProb(qualCur); - double channelIntensity = fourintensity[channel]; - - runningChannelSums[basePrevIndex][baseCurIndex].setQuick(channel, runningChannelSums[basePrevIndex][baseCurIndex].getQuick(channel) + weight*channelIntensity); - - for (int cochannel = 0; cochannel < 4; cochannel++) { - double cochannelIntensity = fourintensity[cochannel]; - runningChannelProductSums[basePrevIndex][baseCurIndex].setQuick(channel, cochannel, runningChannelProductSums[basePrevIndex][baseCurIndex].getQuick(channel, cochannel) + weight*channelIntensity*cochannelIntensity); - } - } - - counts[basePrevIndex][baseCurIndex]++; - } - */ - readyToCall = false; } + public void addCovariancePoint(char baseCur, byte qualCur, double[] fourintensity) { + int actualBaseCurIndex = baseToBaseIndex(baseCur); + double actualWeight = QualityUtils.qualToProb(qualCur); + + cern.jet.math.Functions F = cern.jet.math.Functions.functions; + + for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) { + // We want to upweight the correct theory as much as we can and spread the remainder out evenly between all other hypotheses. + double weight = (baseCurIndex == actualBaseCurIndex) ? actualWeight : ((1.0 - actualWeight)/3.0); + + DoubleMatrix1D mean = sums[baseCurIndex].copy(); + mean.assign(F.div(counts[baseCurIndex])); + + DoubleMatrix1D sub = (DoubleFactory1D.dense).make(fourintensity); + sub.assign(mean, F.minus); + + DoubleMatrix2D cov = (DoubleFactory2D.dense).make(4, 4); + alg.multOuter(sub, sub, cov); + + cov.assign(F.mult(weight)); + inverseCovariances[baseCurIndex].assign(cov, F.plus); + } + } + /** * Precompute all the matrix inversions and determinants we'll need for computing the likelihood distributions. */ public void prepareToCallBases() { + /* for (int basePrevIndex = 0; basePrevIndex < 4; basePrevIndex++) { for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) { for (int channel = 0; channel < 4; channel++) { - means[basePrevIndex][baseCurIndex].setQuick(channel, runningChannelSums[basePrevIndex][baseCurIndex].getQuick(channel)/counts[basePrevIndex][baseCurIndex]); + sums[baseCurIndex].setQuick(channel, runningChannelSums[basePrevIndex][baseCurIndex].getQuick(channel)/counts[basePrevIndex][baseCurIndex]); for (int cochannel = 0; cochannel < 4; cochannel++) { // Cov(Xi, Xj) = E(XiXj) - E(Xi)E(Xj) @@ -129,6 +120,7 @@ public class BasecallingBaseModel { norms[basePrevIndex][baseCurIndex] = Math.pow(alg.det(invcov), 0.5)/Math.pow(2.0*Math.PI, 2.0); } } + */ readyToCall = true; } @@ -149,6 +141,7 @@ public class BasecallingBaseModel { } double[][] probdist = new double[4][4]; + /* double probPrev = (cycle == 0) ? 1.0 : QualityUtils.qualToProb(qualPrev); int baseIndex = (cycle == 0) ? 0 : baseToBaseIndex(basePrev); @@ -156,7 +149,7 @@ public class BasecallingBaseModel { for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) { double[] diff = new double[4]; for (int channel = 0; channel < 4; channel++) { - diff[channel] = fourintensity[channel] - means[basePrevIndex][baseCurIndex].getQuick(channel); + diff[channel] = fourintensity[channel] - sums[basePrevIndex][baseCurIndex].getQuick(channel); } DoubleMatrix1D sub = (DoubleFactory1D.dense).make(diff); @@ -166,6 +159,7 @@ public class BasecallingBaseModel { probdist[basePrevIndex][baseCurIndex] = (baseIndex == basePrevIndex ? probPrev : 1.0 - probPrev)*norms[basePrevIndex][baseCurIndex]*Math.exp(exparg); } } + */ return probdist; } @@ -174,15 +168,19 @@ public class BasecallingBaseModel { try { PrintWriter writer = new PrintWriter(outparam); - for (int basePrevIndex = 0; basePrevIndex < 4; basePrevIndex++) { - for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) { - writer.print("mean_" + baseIndexToBase(basePrevIndex) + "" + baseIndexToBase(baseCurIndex) + " : [ "); - for (int channel = 0; channel < 4; channel++) { - writer.print(runningChannelSums[basePrevIndex][baseCurIndex].getQuick(channel)/counts[basePrevIndex][baseCurIndex]); - writer.print(" "); - } - writer.print("]\n"); + for (int baseCurIndex = 0; baseCurIndex < 4; baseCurIndex++) { + writer.print("mean_" + baseIndexToBase(baseCurIndex) + " : [ "); + for (int channel = 0; channel < 4; channel++) { + writer.print(sums[baseCurIndex].getQuick(channel)/counts[baseCurIndex]); + writer.print(" "); } + writer.print("] (" + counts[baseCurIndex] + ")\n"); + + DoubleMatrix2D cov = inverseCovariances[baseCurIndex].copy(); + cern.jet.math.Functions F = cern.jet.math.Functions.functions; + cov.assign(F.div(counts[baseCurIndex])); + + writer.println("cov_" + baseIndexToBase(baseCurIndex) + " : " + cov + "\n"); } writer.close(); @@ -193,6 +191,7 @@ public class BasecallingBaseModel { /** * Utility method for converting a base ([Aa*], [Cc], [Gg], [Tt]) to an index (0, 1, 2, 3); + * * @param base * @return 0, 1, 2, 3, or -1 if the base can't be understood. */ diff --git a/java/src/org/broadinstitute/sting/playground/fourbasecaller/BasecallingReadModel.java b/java/src/org/broadinstitute/sting/playground/fourbasecaller/BasecallingReadModel.java index 59e2c6506..0069c2e3b 100644 --- a/java/src/org/broadinstitute/sting/playground/fourbasecaller/BasecallingReadModel.java +++ b/java/src/org/broadinstitute/sting/playground/fourbasecaller/BasecallingReadModel.java @@ -32,13 +32,16 @@ public class BasecallingReadModel { * Add a single training point to the model. * * @param cycle the cycle for which this point should be added - * @param basePrev the previous base * @param baseCur the current base * @param qualCur the current base's quality * @param fourintensity the four intensities of the current base */ - public void addTrainingPoint(int cycle, char basePrev, char baseCur, byte qualCur, double[] fourintensity) { - basemodels[cycle].addTrainingPoint(basePrev, baseCur, qualCur, fourintensity); + public void addMeanPoint(int cycle, char baseCur, byte qualCur, double[] fourintensity) { + basemodels[cycle].addMeanPoint(baseCur, qualCur, fourintensity); + } + + public void addCovariancePoint(int cycle, char baseCur, byte qualCur, double[] fourintensity) { + basemodels[cycle].addCovariancePoint(baseCur, qualCur, fourintensity); } /** diff --git a/java/src/org/broadinstitute/sting/playground/fourbasecaller/FourBaseRecaller.java b/java/src/org/broadinstitute/sting/playground/fourbasecaller/FourBaseRecaller.java index 7169c78e4..da79727bf 100644 --- a/java/src/org/broadinstitute/sting/playground/fourbasecaller/FourBaseRecaller.java +++ b/java/src/org/broadinstitute/sting/playground/fourbasecaller/FourBaseRecaller.java @@ -48,7 +48,9 @@ public class FourBaseRecaller extends CommandLineProgram { BasecallingReadModel model = new BasecallingReadModel(bread.getFirstReadSequence().length()); int queryid; - // learn initial parameters + // learn mean parameters + //System.out.println("intensity int_a int_c int_g int_t base"); + queryid = 0; do { String bases = (END <= 1) ? bread.getFirstReadSequence() : bread.getSecondReadSequence(); @@ -56,22 +58,50 @@ public class FourBaseRecaller extends CommandLineProgram { double[][] intensities = bread.getIntensities(); for (int cycle = 0; cycle < bases.length(); cycle++) { - char basePrev = (cycle == 0) ? '*' : bases.charAt(cycle - 1); char baseCur = bases.charAt(cycle); byte qualCur = quals[cycle]; double[] fourintensity = intensities[cycle + cycle_offset]; - model.addTrainingPoint(cycle, basePrev, baseCur, qualCur, fourintensity); + /* + if (cycle == 0) { + System.out.println("intensity " + intensities[0][0] + " " + intensities[0][1] + " " + intensities[0][2] + " " + intensities[0][3] + " " + baseCur); + } + */ + + model.addMeanPoint(cycle, baseCur, qualCur, fourintensity); } queryid++; } while (queryid < TRAINING_LIMIT && bfp.hasNext() && (bread = bfp.next()) != null); + // learn covariance parameters + bfp = new BustardFileParser(DIR, LANE, isPaired, "FB"); + bread = bfp.next(); + + queryid = 0; + do { + String bases = (END <= 1) ? bread.getFirstReadSequence() : bread.getSecondReadSequence(); + byte[] quals = (END <= 1) ? bread.getFirstReadPhredBinaryQualities() : bread.getSecondReadPhredBinaryQualities(); + double[][] intensities = bread.getIntensities(); + + for (int cycle = 0; cycle < bases.length(); cycle++) { + char baseCur = bases.charAt(cycle); + byte qualCur = quals[cycle]; + double[] fourintensity = intensities[cycle + cycle_offset]; + + model.addCovariancePoint(cycle, baseCur, qualCur, fourintensity); + } + + queryid++; + } while (queryid < TRAINING_LIMIT && bfp.hasNext() && (bread = bfp.next()) != null); + + // write debugging info File debugout = new File(OUT.getParentFile().getPath() + "/model/"); debugout.mkdir(); model.write(debugout); // call bases + /* SAMFileHeader sfh = new SAMFileHeader(); SAMFileWriter sfw = new SAMFileWriterFactory().makeSAMOrBAMWriter(sfh, false, OUT); @@ -107,6 +137,7 @@ public class FourBaseRecaller extends CommandLineProgram { } while (queryid < CALLING_LIMIT && bfp.hasNext() && (bread = bfp.next()) != null); sfw.close(); + */ return 0; }