diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/GaussianMixtureModel.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/GaussianMixtureModel.java index eef9da84a..92b0d4df2 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/GaussianMixtureModel.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/GaussianMixtureModel.java @@ -47,7 +47,6 @@ package org.broadinstitute.sting.gatk.walkers.variantrecalibration; import Jama.Matrix; -import cern.jet.random.Normal; import org.apache.log4j.Logger; import org.broadinstitute.sting.gatk.GenomeAnalysisEngine; import org.broadinstitute.sting.utils.MathUtils; @@ -243,12 +242,10 @@ public class GaussianMixtureModel { public Double evaluateDatumInOneDimension( final VariantDatum datum, final int iii ) { if(datum.isNull[iii]) { return null; } - final Normal normal = new Normal(0.0, 1.0, null); final double[] pVarInGaussianLog10 = new double[gaussians.size()]; int gaussianIndex = 0; for( final MultivariateGaussian gaussian : gaussians ) { - normal.setState( gaussian.mu[iii], gaussian.sigma.get(iii, iii) ); - pVarInGaussianLog10[gaussianIndex++] = gaussian.pMixtureLog10 + Math.log10( normal.pdf( datum.annotations[iii] ) ); + pVarInGaussianLog10[gaussianIndex++] = gaussian.pMixtureLog10 + MathUtils.normalDistributionLog10(gaussian.mu[iii], gaussian.sigma.get(iii, iii), datum.annotations[iii]); } return MathUtils.log10sumLog10(pVarInGaussianLog10); // Sum(pi_k * p(v|n,k)) } diff --git a/public/java/src/org/broadinstitute/sting/utils/MathUtils.java b/public/java/src/org/broadinstitute/sting/utils/MathUtils.java index 38c131bc6..c8cf9d6a1 100644 --- a/public/java/src/org/broadinstitute/sting/utils/MathUtils.java +++ b/public/java/src/org/broadinstitute/sting/utils/MathUtils.java @@ -63,6 +63,8 @@ public class MathUtils { */ public final static double LOG10_P_OF_ZERO = -1000000.0; public 
final static double FAIR_BINOMIAL_PROB_LOG10_0_5 = Math.log10(0.5); + private final static double NATURAL_LOG_OF_TEN = Math.log(10.0); + private final static double SQUARE_ROOT_OF_TWO_TIMES_PI = Math.sqrt(2.0 * Math.PI); static { log10Cache = new double[LOG10_CACHE_SIZE]; @@ -301,12 +303,46 @@ public class MathUtils { return 1; } - public static double NormalDistribution(final double mean, final double sd, final double x) { - double a = 1.0 / (sd * Math.sqrt(2.0 * Math.PI)); - double b = Math.exp(-1.0 * (Math.pow(x - mean, 2.0) / (2.0 * sd * sd))); + /** + * Calculate f(x) = Normal(x | mu = mean, sigma = sd) + * @param mean the desired mean of the Normal distribution + * @param sd the desired standard deviation of the Normal distribution + * @param x the value to evaluate + * @return a well-formed double + */ + public static double normalDistribution(final double mean, final double sd, final double x) { + final double a = 1.0 / (sd * SQUARE_ROOT_OF_TWO_TIMES_PI); + final double b = Math.exp(-1.0 * (square(x - mean) / (2.0 * square(sd)))); return a * b; } + /** + * Calculate f(x) = log10 ( Normal(x | mu = mean, sigma = sd) ) + * @param mean the desired mean of the Normal distribution + * @param sd the desired standard deviation of the Normal distribution; must be strictly positive, otherwise an IllegalArgumentException is thrown + * @param x the value to evaluate + * @return a well-formed double + */ + + public static double normalDistributionLog10(final double mean, final double sd, final double x) { + if( sd <= 0 ) + throw new IllegalArgumentException("sd: Standard deviation of normal must be >0"); + if ( ! wellFormedDouble(mean) || ! 
wellFormedDouble(x) ) + throw new IllegalArgumentException("mean, sd, or, x : Normal parameters must be well formatted (non-INF, non-NAN)"); + final double a = -1.0 * Math.log10(sd * SQUARE_ROOT_OF_TWO_TIMES_PI); + final double b = -1.0 * (square(x - mean) / (2.0 * square(sd))) / NATURAL_LOG_OF_TEN; + return a + b; + } + + /** + * Calculate f(x) = x^2 + * @param x the value to square + * @return x * x + */ + public static double square(final double x) { + return x * x; + } + /** * Calculates the log10 of the binomial coefficient. Designed to prevent * overflows even with very large numbers. diff --git a/public/java/src/org/broadinstitute/sting/utils/activeregion/BandPassActivityProfile.java b/public/java/src/org/broadinstitute/sting/utils/activeregion/BandPassActivityProfile.java index f2bc86dfc..f352bc332 100644 --- a/public/java/src/org/broadinstitute/sting/utils/activeregion/BandPassActivityProfile.java +++ b/public/java/src/org/broadinstitute/sting/utils/activeregion/BandPassActivityProfile.java @@ -31,7 +31,6 @@ import org.broadinstitute.sting.utils.GenomeLocParser; import org.broadinstitute.sting.utils.GenomeLocSortedSet; import org.broadinstitute.sting.utils.MathUtils; -import java.util.ArrayList; import java.util.Collection; import java.util.LinkedList; @@ -108,7 +107,7 @@ public class BandPassActivityProfile extends ActivityProfile { final int bandSize = 2 * filterSize + 1; final double[] kernel = new double[bandSize]; for( int iii = 0; iii < bandSize; iii++ ) { - kernel[iii] = MathUtils.NormalDistribution(filterSize, sigma, iii); + kernel[iii] = MathUtils.normalDistribution(filterSize, sigma, iii); } return MathUtils.normalizeFromRealSpace(kernel); } diff --git a/public/java/test/org/broadinstitute/sting/utils/MathUtilsUnitTest.java b/public/java/test/org/broadinstitute/sting/utils/MathUtilsUnitTest.java index 27af8ec68..e4c74a0ad 100644 --- a/public/java/test/org/broadinstitute/sting/utils/MathUtilsUnitTest.java +++ 
b/public/java/test/org/broadinstitute/sting/utils/MathUtilsUnitTest.java @@ -25,6 +25,7 @@ package org.broadinstitute.sting.utils; +import cern.jet.random.Normal; import org.broadinstitute.sting.BaseTest; import org.testng.Assert; import org.testng.annotations.BeforeClass; @@ -398,4 +399,20 @@ public class MathUtilsUnitTest extends BaseTest { Assert.assertEquals(MathUtils.logDotProduct(new double[]{-5.0,-3.0,2.0}, new double[]{6.0,7.0,8.0}),10.0,1e-3); Assert.assertEquals(MathUtils.logDotProduct(new double[]{-5.0}, new double[]{6.0}),1.0,1e-3); } + + @Test + public void testNormalDistribution() { + final double requiredPrecision = 1E-10; + + final Normal n = new Normal(0.0, 1.0, null); + for( final double mu : new double[]{-5.0, -3.2, -1.5, 0.0, 1.2, 3.0, 5.8977} ) { + for( final double sigma : new double[]{1.2, 3.0, 5.8977} ) { + for( final double x : new double[]{-5.0, -3.2, -1.5, 0.0, 1.2, 3.0, 5.8977} ) { + n.setState(mu, sigma); + Assert.assertEquals(n.pdf(x), MathUtils.normalDistribution(mu, sigma, x), requiredPrecision); + Assert.assertEquals(Math.log10(n.pdf(x)), MathUtils.normalDistributionLog10(mu, sigma, x), requiredPrecision); + } + } + } + } } diff --git a/public/java/test/org/broadinstitute/sting/utils/activeregion/ActivityProfileUnitTest.java b/public/java/test/org/broadinstitute/sting/utils/activeregion/ActivityProfileUnitTest.java index 9be250b8e..f208815f7 100644 --- a/public/java/test/org/broadinstitute/sting/utils/activeregion/ActivityProfileUnitTest.java +++ b/public/java/test/org/broadinstitute/sting/utils/activeregion/ActivityProfileUnitTest.java @@ -450,7 +450,7 @@ public class ActivityProfileUnitTest extends BaseTest { private double[] makeGaussian(final int mean, final int range, final double sigma) { final double[] gauss = new double[range]; for( int iii = 0; iii < range; iii++ ) { - gauss[iii] = MathUtils.NormalDistribution(mean, sigma, iii) + ActivityProfile.ACTIVE_PROB_THRESHOLD; + gauss[iii] = MathUtils.normalDistribution(mean, sigma, 
iii) + ActivityProfile.ACTIVE_PROB_THRESHOLD; } return gauss; }