diff --git a/protected/java/test/org/broadinstitute/sting/gatk/walkers/variantutils/ConcordanceMetricsUnitTest.java b/protected/java/test/org/broadinstitute/sting/gatk/walkers/variantutils/ConcordanceMetricsUnitTest.java index 2e31f6725..bca912d63 100644 --- a/protected/java/test/org/broadinstitute/sting/gatk/walkers/variantutils/ConcordanceMetricsUnitTest.java +++ b/protected/java/test/org/broadinstitute/sting/gatk/walkers/variantutils/ConcordanceMetricsUnitTest.java @@ -203,6 +203,27 @@ public class ConcordanceMetricsUnitTest extends BaseTest { Assert.assertEquals(metrics.getOverallSiteConcordance().getSiteConcordance()[ConcordanceMetrics.SiteConcordanceType.EVAL_SUPERSET_TRUTH.ordinal()],1); Assert.assertEquals(metrics.getOverallSiteConcordance().getSiteConcordance()[ConcordanceMetrics.SiteConcordanceType.ALLELES_DO_NOT_MATCH.ordinal()],0); Assert.assertEquals(metrics.getOverallSiteConcordance().getSiteConcordance()[ConcordanceMetrics.SiteConcordanceType.ALLELES_MATCH.ordinal()],0); + + // now flip them around + + eval = data.getSecond(); + truth = data.getFirst(); + codec = new VCFCodec(); + evalHeader = (VCFHeader)codec.readHeader(new AsciiLineReader(new PositionalBufferedStream(new StringBufferInputStream(TEST_1_HEADER)))); + compHeader = (VCFHeader)codec.readHeader(new AsciiLineReader(new PositionalBufferedStream(new StringBufferInputStream(TEST_1_HEADER)))); + metrics = new ConcordanceMetrics(evalHeader,compHeader); + metrics.update(eval,truth); + Assert.assertEquals(eval.getGenotype("test1_sample2").getType().ordinal(), 2); + Assert.assertEquals(truth.getGenotype("test1_sample2").getType().ordinal(),2); + Assert.assertEquals(metrics.getGenotypeConcordance("test1_sample2").getnMismatchingAlt(),1); + Assert.assertEquals(metrics.getGenotypeConcordance("test1_sample2").getTable()[1][2],0); + Assert.assertEquals(metrics.getGenotypeConcordance("test1_sample3").getTable()[1][2],0); + Assert.assertEquals(metrics.getGenotypeConcordance("test1_sample3").getTable()[3][2],1); + Assert.assertEquals(metrics.getOverallGenotypeConcordance().getTable()[1][1],1); + Assert.assertEquals(metrics.getOverallSiteConcordance().getSiteConcordance()[ConcordanceMetrics.SiteConcordanceType.EVAL_SUPERSET_TRUTH.ordinal()],0); + Assert.assertEquals(metrics.getOverallSiteConcordance().getSiteConcordance()[ConcordanceMetrics.SiteConcordanceType.EVAL_SUBSET_TRUTH.ordinal()],1); + Assert.assertEquals(metrics.getOverallSiteConcordance().getSiteConcordance()[ConcordanceMetrics.SiteConcordanceType.ALLELES_DO_NOT_MATCH.ordinal()],0); + Assert.assertEquals(metrics.getOverallSiteConcordance().getSiteConcordance()[ConcordanceMetrics.SiteConcordanceType.ALLELES_MATCH.ordinal()],0); } private Pair getData3() { diff --git a/public/java/src/org/broadinstitute/sting/gatk/examples/GATKPaperGenotyper.java b/public/java/src/org/broadinstitute/sting/gatk/examples/GATKPaperGenotyper.java index 7b56852d3..07ec088cf 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/examples/GATKPaperGenotyper.java +++ b/public/java/src/org/broadinstitute/sting/gatk/examples/GATKPaperGenotyper.java @@ -40,7 +40,8 @@ import org.broadinstitute.sting.utils.help.HelpConstants; import org.broadinstitute.sting.utils.pileup.ReadBackedPileup; import java.io.PrintStream; - +import java.util.Arrays; +import java.util.Comparator; /** * A simple Bayesian genotyper, that outputs a text based call format. Intended to be used only as an @@ -95,7 +96,7 @@ public class GATKPaperGenotyper extends LocusWalker implements Tre likelihoods[genotype.ordinal()] += Math.log10(p / genotype.toString().length()); } - Integer sortedList[] = MathUtils.sortPermutation(likelihoods); + Integer sortedList[] = sortPermutation(likelihoods); // create call using the best genotype (GENOTYPE.values()[sortedList[9]].toString()) // and calculate the LOD score from best - next best (9 and 8 in the sorted list, since the best likelihoods are closest to zero) @@ -110,6 +111,29 @@ public class GATKPaperGenotyper extends LocusWalker implements Tre return 0; } + private static Integer[] sortPermutation(final double[] A) { + class comparator implements Comparator { + public int compare(Integer a, Integer b) { + if (A[a.intValue()] < A[b.intValue()]) { + return -1; + } + if (A[a.intValue()] == A[b.intValue()]) { + return 0; + } + if (A[a.intValue()] > A[b.intValue()]) { + return 1; + } + return 0; + } + } + Integer[] permutation = new Integer[A.length]; + for (int i = 0; i < A.length; i++) { + permutation[i] = i; + } + Arrays.sort(permutation, new comparator()); + return permutation; + } + /** * Takes reference base, and three priors for hom-ref, het, hom-var, and fills in the priors vector * appropriately. diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/variantutils/ConcordanceMetrics.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/variantutils/ConcordanceMetrics.java index efb84edef..b3b4857b6 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/variantutils/ConcordanceMetrics.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/variantutils/ConcordanceMetrics.java @@ -102,15 +102,16 @@ public class ConcordanceMetrics { public void update(VariantContext eval, VariantContext truth) { overallSiteConcordance.update(eval,truth); Set alleleTruth = new HashSet(8); - alleleTruth.add(truth.getReference().getBaseString()); + String truthRef = truth.getReference().getBaseString(); + alleleTruth.add(truthRef); for ( Allele a : truth.getAlternateAlleles() ) { alleleTruth.add(a.getBaseString()); } for ( String sample : perSampleGenotypeConcordance.keySet() ) { Genotype evalGenotype = eval.getGenotype(sample); Genotype truthGenotype = truth.getGenotype(sample); - perSampleGenotypeConcordance.get(sample).update(evalGenotype,truthGenotype,alleleTruth); - overallGenotypeConcordance.update(evalGenotype,truthGenotype,alleleTruth); + perSampleGenotypeConcordance.get(sample).update(evalGenotype,truthGenotype,alleleTruth,truthRef); + overallGenotypeConcordance.update(evalGenotype,truthGenotype,alleleTruth,truthRef); } } @@ -170,10 +171,14 @@ public class ConcordanceMetrics { } @Requires({"eval!=null","truth != null","truthAlleles != null"}) - public void update(Genotype eval, Genotype truth, Set truthAlleles) { - // this is slow but correct + public void update(Genotype eval, Genotype truth, Set truthAlleles, String truthRef) { + // this is slow but correct. + + // NOTE: a reference call in "truth" is a special case, the eval can match *any* of the truth alleles + // that is, if the reference base is C, and a sample is C/C in truth, A/C, A/A, T/C, T/T will + // all match, so long as A and T are alleles in the truth callset. boolean matchingAlt = true; - if ( eval.isCalled() && truth.isCalled() ) { + if ( eval.isCalled() && truth.isCalled() && truth.isHomRef() ) { // by default, no-calls "match" between alleles, so if // one or both sites are no-call or unavailable, the alt alleles match // otherwise, check explicitly: if the eval has an allele that's not ref, no-call, or present in truth @@ -181,6 +186,17 @@ public class ConcordanceMetrics { for ( Allele evalAllele : eval.getAlleles() ) { matchingAlt &= truthAlleles.contains(evalAllele.getBaseString()); } + } else if ( eval.isCalled() && truth.isCalled() ) { + // otherwise, the eval genotype has to match either the alleles in the truth genotype, or the truth reference allele + // todo -- this can be sped up by caching the truth allele sets + Set genoAlleles = new HashSet(3); + genoAlleles.add(truthRef); + for ( Allele truthGenoAl : truth.getAlleles() ) { + genoAlleles.add(truthGenoAl.getBaseString()); + } + for ( Allele evalAllele : eval.getAlleles() ) { + matchingAlt &= genoAlleles.contains(evalAllele.getBaseString()); + } } if ( matchingAlt ) { diff --git a/public/java/src/org/broadinstitute/sting/utils/MathUtils.java b/public/java/src/org/broadinstitute/sting/utils/MathUtils.java index 2459c1d36..ebbc3945f 100644 --- a/public/java/src/org/broadinstitute/sting/utils/MathUtils.java +++ b/public/java/src/org/broadinstitute/sting/utils/MathUtils.java @@ -88,14 +88,14 @@ public class MathUtils { * @param max upper bound of the range * @return a random int >= min and <= max */ - public static int randomIntegerInRange( int min, int max ) { + public static int randomIntegerInRange( final int min, final int max ) { return GenomeAnalysisEngine.getRandomGenerator().nextInt(max - min + 1) + min; } // A fast implementation of the Math.round() method. This method does not perform // under/overflow checking, so this shouldn't be used in the general case (but is fine // if one is already make those checks before calling in to the rounding). - public static int fastRound(double d) { + public static int fastRound(final double d) { return (d > 0.0) ? (int) (d + 0.5d) : (int) (d - 0.5d); } @@ -123,7 +123,7 @@ public class MathUtils { return approxSum; } - public static double approximateLog10SumLog10(double a, double b, double c) { + public static double approximateLog10SumLog10(final double a, final double b, final double c) { return approximateLog10SumLog10(a, approximateLog10SumLog10(b, c)); } @@ -152,97 +152,53 @@ public class MathUtils { return big + MathUtils.jacobianLogTable[ind]; } - public static double sum(Collection numbers) { - return sum(numbers, false); - } - - public static double sum(Collection numbers, boolean ignoreNan) { - double sum = 0; - for (Number n : numbers) { - if (!ignoreNan || !Double.isNaN(n.doubleValue())) { - sum += n.doubleValue(); - } - } - - return sum; - } - - public static int nonNanSize(Collection numbers) { - int size = 0; - for (Number n : numbers) { - size += Double.isNaN(n.doubleValue()) ? 0 : 1; - } - - return size; - } - - public static double average(Collection x) { - return sum(x) / x.size(); - } - - public static double average(Collection numbers, boolean ignoreNan) { - if (ignoreNan) { - return sum(numbers, true) / nonNanSize(numbers); - } - else { - return sum(numbers, false) / nonNanSize(numbers); - } - } - - public static double variance(Collection numbers, Number mean, boolean ignoreNan) { - double mn = mean.doubleValue(); - double var = 0; - for (Number n : numbers) { - var += (!ignoreNan || !Double.isNaN(n.doubleValue())) ? (n.doubleValue() - mn) * (n.doubleValue() - mn) : 0; - } - if (ignoreNan) { - return var / (nonNanSize(numbers) - 1); - } - return var / (numbers.size() - 1); - } - - public static double variance(Collection numbers, Number mean) { - return variance(numbers, mean, false); - } - - public static double variance(Collection numbers, boolean ignoreNan) { - return variance(numbers, average(numbers, ignoreNan), ignoreNan); - } - - public static double variance(Collection numbers) { - return variance(numbers, average(numbers, false), false); - } - - public static double sum(double[] values) { + public static double sum(final double[] values) { double s = 0.0; for (double v : values) s += v; return s; } - public static long sum(int[] x) { + public static long sum(final int[] x) { long total = 0; for (int v : x) total += v; return total; } - public static int sum(byte[] x) { + public static int sum(final byte[] x) { int total = 0; for (byte v : x) total += (int)v; return total; } - /** - * Calculates the log10 cumulative sum of an array with log10 probabilities - * - * @param log10p the array with log10 probabilities - * @param upTo index in the array to calculate the cumsum up to - * @return the log10 of the cumulative sum - */ - public static double log10CumulativeSumLog10(double[] log10p, int upTo) { - return log10sumLog10(log10p, 0, upTo); + public static double percentage(int x, int base) { + return (base > 0 ? ((double) x / (double) base) * 100.0 : 0); + } + + public static double ratio(final int num, final int denom) { + if ( denom > 0 ) { + return ((double) num)/denom; + } else { + if ( num == 0 && denom == 0) { + return 0.0; + } else { + throw new ReviewedStingException(String.format("The denominator of a ratio cannot be zero or less than zero: %d/%d",num,denom)); + } + } + } + + public static double ratio(final long num, final long denom) { + if ( denom > 0L ) { + return ((double) num)/denom; + } else { + if ( num == 0L && denom == 0L ) { + return 0.0; + } else { + throw new ReviewedStingException(String.format("The denominator of a ratio cannot be zero or less than zero: %d/%d",num,denom)); + } + } } /** @@ -251,18 +207,18 @@ public class MathUtils { * @param prRealSpace * @return */ - public static double[] toLog10(double[] prRealSpace) { + public static double[] toLog10(final double[] prRealSpace) { double[] log10s = new double[prRealSpace.length]; for (int i = 0; i < prRealSpace.length; i++) log10s[i] = Math.log10(prRealSpace[i]); return log10s; } - public static double log10sumLog10(double[] log10p, int start) { + public static double log10sumLog10(final double[] log10p, final int start) { return log10sumLog10(log10p, start, log10p.length); } - public static double log10sumLog10(double[] log10p, int start, int finish) { + public static double log10sumLog10(final double[] log10p,final int start,final int finish) { double sum = 0.0; double maxValue = arrayMax(log10p, finish); @@ -276,56 +232,42 @@ public class MathUtils { return Math.log10(sum) + maxValue; } - public static double sumDoubles(List values) { - double s = 0.0; - for (double v : values) - s += v; - return s; - } - - public static int sumIntegers(List values) { - int s = 0; - for (int v : values) - s += v; - return s; - } - - public static double sumLog10(double[] log10values) { + public static double sumLog10(final double[] log10values) { return Math.pow(10.0, log10sumLog10(log10values)); // double s = 0.0; // for ( double v : log10values) s += Math.pow(10.0, v); // return s; } - public static double log10sumLog10(double[] log10values) { + public static double log10sumLog10(final double[] log10values) { return log10sumLog10(log10values, 0); } - public static boolean wellFormedDouble(double val) { + public static boolean wellFormedDouble(final double val) { return !Double.isInfinite(val) && !Double.isNaN(val); } - public static double bound(double value, double minBoundary, double maxBoundary) { + public static double bound(final double value, final double minBoundary, final double maxBoundary) { return Math.max(Math.min(value, maxBoundary), minBoundary); } - public static boolean isBounded(double val, double lower, double upper) { + public static boolean isBounded(final double val, final double lower, final double upper) { return val >= lower && val <= upper; } - public static boolean isPositive(double val) { + public static boolean isPositive(final double val) { return !isNegativeOrZero(val); } - public static boolean isPositiveOrZero(double val) { + public static boolean isPositiveOrZero(final double val) { return isBounded(val, 0.0, Double.POSITIVE_INFINITY); } - public static boolean isNegativeOrZero(double val) { + public static boolean isNegativeOrZero(final double val) { return isBounded(val, Double.NEGATIVE_INFINITY, 0.0); } - public static boolean isNegative(double val) { + public static boolean isNegative(final double val) { return !isPositiveOrZero(val); } @@ -336,7 +278,7 @@ public class MathUtils { * @param b the second double value * @return -1 if a is greater than b, 0 if a is equal to be within 1e-6, 1 if b is greater than a. */ - public static byte compareDoubles(double a, double b) { + public static byte compareDoubles(final double a, final double b) { return compareDoubles(a, b, 1e-6); } @@ -348,7 +290,7 @@ public class MathUtils { * @param epsilon the precision within which two double values will be considered equal * @return -1 if a is greater than b, 0 if a is equal to be within epsilon, 1 if b is greater than a. */ - public static byte compareDoubles(double a, double b, double epsilon) { + public static byte compareDoubles(final double a, final double b, final double epsilon) { if (Math.abs(a - b) < epsilon) { return 0; } @@ -358,42 +300,13 @@ public class MathUtils { return 1; } - /** - * Compares float values for equality (within 1e-6), or inequality. - * - * @param a the first float value - * @param b the second float value - * @return -1 if a is greater than b, 0 if a is equal to be within 1e-6, 1 if b is greater than a. - */ - public static byte compareFloats(float a, float b) { - return compareFloats(a, b, 1e-6f); - } - - /** - * Compares float values for equality (within epsilon), or inequality. - * - * @param a the first float value - * @param b the second float value - * @param epsilon the precision within which two float values will be considered equal - * @return -1 if a is greater than b, 0 if a is equal to be within epsilon, 1 if b is greater than a. - */ - public static byte compareFloats(float a, float b, float epsilon) { - if (Math.abs(a - b) < epsilon) { - return 0; - } - if (a > b) { - return -1; - } - return 1; - } - - public static double NormalDistribution(double mean, double sd, double x) { + public static double NormalDistribution(final double mean, final double sd, final double x) { double a = 1.0 / (sd * Math.sqrt(2.0 * Math.PI)); double b = Math.exp(-1.0 * (Math.pow(x - mean, 2.0) / (2.0 * sd * sd))); return a * b; } - public static double binomialCoefficient(int n, int k) { + public static double binomialCoefficient(final int n, final int k) { return Math.pow(10, log10BinomialCoefficient(n, k)); } @@ -409,7 +322,7 @@ public class MathUtils { * @param p probability of success * @return the binomial probability of the specified configuration. Computes values down to about 1e-237. */ - public static double binomialProbability(int n, int k, double p) { + public static double binomialProbability(final int n, final int k, final double p) { return Math.pow(10, log10BinomialProbability(n, k, Math.log10(p))); } @@ -422,7 +335,7 @@ public class MathUtils { * @param probHit - probability of a successful hit * @return - returns the cumulative probability */ - public static double binomialCumulativeProbability(int start, int end, int total, double probHit) { + public static double binomialCumulativeProbability(final int start, final int end, final int total, final double probHit) { double cumProb = 0.0; double prevProb; BigDecimal probCache = BigDecimal.ZERO; @@ -454,7 +367,7 @@ public class MathUtils { * @param k an int[] of counts, where each element represents the number of times a certain outcome was observed * @return the multinomial of the specified configuration. */ - public static double multinomialCoefficient(int[] k) { + public static double multinomialCoefficient(final int[] k) { int n = 0; for (int xi : k) { n += xi; @@ -477,7 +390,7 @@ public class MathUtils { * @param p a double[] of probabilities, where each element represents the probability a given outcome can occur * @return the multinomial probability of the specified configuration. */ - public static double multinomialProbability(int[] k, double[] p) { + public static double multinomialProbability(final int[] k, final double[] p) { if (p.length != k.length) throw new UserException.BadArgumentValue("p and k", "Array of log10 probabilities must have the same size as the array of number of sucesses: " + p.length + ", " + k.length); @@ -496,7 +409,7 @@ public class MathUtils { * @param x an byte[] of numbers * @return the RMS of the specified numbers. */ - public static double rms(byte[] x) { + public static double rms(final byte[] x) { if (x.length == 0) return 0.0; @@ -513,7 +426,7 @@ public class MathUtils { * @param x an int[] of numbers * @return the RMS of the specified numbers. */ - public static double rms(int[] x) { + public static double rms(final int[] x) { if (x.length == 0) return 0.0; @@ -530,7 +443,7 @@ public class MathUtils { * @param x a double[] of numbers * @return the RMS of the specified numbers. */ - public static double rms(Double[] x) { + public static double rms(final Double[] x) { if (x.length == 0) return 0.0; @@ -541,7 +454,7 @@ public class MathUtils { return Math.sqrt(rms); } - public static double rms(Collection l) { + public static double rms(final Collection l) { if (l.size() == 0) return 0.0; @@ -560,7 +473,7 @@ public class MathUtils { return dist; } - public static double round(double num, int digits) { + public static double round(final double num, final int digits) { double result = num * Math.pow(10.0, (double) digits); result = Math.round(result); result = result / Math.pow(10.0, (double) digits); @@ -574,7 +487,7 @@ public class MathUtils { * @param takeLog10OfOutput if true, the output will be transformed back into log10 units * @return a newly allocated array corresponding the normalized values in array, maybe log10 transformed */ - public static double[] normalizeFromLog10(double[] array, boolean takeLog10OfOutput) { + public static double[] normalizeFromLog10(final double[] array, final boolean takeLog10OfOutput) { return normalizeFromLog10(array, takeLog10OfOutput, false); } @@ -587,7 +500,7 @@ public class MathUtils { * * @return */ - public static double[] normalizeFromLog10(double[] array, boolean takeLog10OfOutput, boolean keepInLogSpace) { + public static double[] normalizeFromLog10(final double[] array, final boolean takeLog10OfOutput, final boolean keepInLogSpace) { // for precision purposes, we need to add (or really subtract, since they're // all negative) the largest value; also, we need to convert to normal-space. double maxValue = arrayMax(array); @@ -630,7 +543,7 @@ public class MathUtils { * @param array the array to be normalized * @return a newly allocated array corresponding the normalized values in array */ - public static double[] normalizeFromLog10(double[] array) { + public static double[] normalizeFromLog10(final double[] array) { return normalizeFromLog10(array, false); } @@ -683,7 +596,7 @@ public class MathUtils { return maxElementIndex(array, array.length); } - public static int maxElementIndex(final int[] array, int endIndex) { + public static int maxElementIndex(final int[] array, final int endIndex) { if (array == null || array.length == 0) throw new IllegalArgumentException("Array cannot be null!"); @@ -696,7 +609,7 @@ public class MathUtils { return maxI; } - public static int maxElementIndex(final byte[] array, int endIndex) { + public static int maxElementIndex(final byte[] array, final int endIndex) { if (array == null || array.length == 0) throw new IllegalArgumentException("Array cannot be null!"); @@ -722,19 +635,19 @@ public class MathUtils { return array[maxElementIndex(array, endIndex)]; } - public static double arrayMin(double[] array) { + public static double arrayMin(final double[] array) { return array[minElementIndex(array)]; } - public static int arrayMin(int[] array) { + public static int arrayMin(final int[] array) { return array[minElementIndex(array)]; } - public static byte arrayMin(byte[] array) { + public static byte arrayMin(final byte[] array) { return array[minElementIndex(array)]; } - public static int minElementIndex(double[] array) { + public static int minElementIndex(final double[] array) { if (array == null || array.length == 0) throw new IllegalArgumentException("Array cannot be null!"); @@ -747,7 +660,7 @@ public class MathUtils { return minI; } - public static int minElementIndex(byte[] array) { + public static int minElementIndex(final byte[] array) { if (array == null || array.length == 0) throw new IllegalArgumentException("Array cannot be null!"); @@ -760,7 +673,7 @@ public class MathUtils { return minI; } - public static int minElementIndex(int[] array) { + public static int minElementIndex(final int[] array) { if (array == null || array.length == 0) throw new IllegalArgumentException("Array cannot be null!"); @@ -773,7 +686,7 @@ public class MathUtils { return minI; } - public static int arrayMaxInt(List array) { + public static int arrayMaxInt(final List array) { if (array == null) throw new IllegalArgumentException("Array cannot be null!"); if (array.size() == 0) @@ -785,19 +698,15 @@ public class MathUtils { return m; } - public static double arrayMaxDouble(List array) { - if (array == null) - throw new IllegalArgumentException("Array cannot be null!"); - if (array.size() == 0) - throw new IllegalArgumentException("Array size cannot be 0!"); - - double m = array.get(0); - for (double e : array) - m = Math.max(m, e); - return m; + public static int sum(final List list ) { + int sum = 0; + for ( Integer i : list ) { + sum += i; + } + return sum; } - public static double average(List vals, int maxI) { + public static double average(final List vals, final int maxI) { long sum = 0L; int i = 0; @@ -814,201 +723,11 @@ public class MathUtils { return (1.0 * sum) / i; } - public static double averageDouble(List vals, int maxI) { - double sum = 0.0; - - int i = 0; - for (double x : vals) { - if (i > maxI) - break; - sum += x; - i++; - } - return (1.0 * sum) / i; - } - - public static double average(List vals) { + public static double average(final List vals) { return average(vals, vals.size()); } - public static double average(int[] x) { - int sum = 0; - for (int v : x) - sum += v; - return (double) sum / x.length; - } - - public static byte average(byte[] vals) { - int sum = 0; - for (byte v : vals) { - sum += v; - } - return (byte) (sum / vals.length); - } - - public static double averageDouble(List vals) { - return averageDouble(vals, vals.size()); - } - - // Java Generics can't do primitive types, so I had to do this the simplistic way - - public static Integer[] sortPermutation(final int[] A) { - class comparator implements Comparator { - public int compare(Integer a, Integer b) { - if (A[a.intValue()] < A[b.intValue()]) { - return -1; - } - if (A[a.intValue()] == A[b.intValue()]) { - return 0; - } - if (A[a.intValue()] > A[b.intValue()]) { - return 1; - } - return 0; - } - } - Integer[] permutation = new Integer[A.length]; - for (int i = 0; i < A.length; i++) { - permutation[i] = i; - } - Arrays.sort(permutation, new comparator()); - return permutation; - } - - public static Integer[] sortPermutation(final double[] A) { - class comparator implements Comparator { - public int compare(Integer a, Integer b) { - if (A[a.intValue()] < A[b.intValue()]) { - return -1; - } - if (A[a.intValue()] == A[b.intValue()]) { - return 0; - } - if (A[a.intValue()] > A[b.intValue()]) { - return 1; - } - return 0; - } - } - Integer[] permutation = new Integer[A.length]; - for (int i = 0; i < A.length; i++) { - permutation[i] = i; - } - Arrays.sort(permutation, new comparator()); - return permutation; - } - - public static Integer[] sortPermutation(List A) { - final Object[] data = A.toArray(); - - class comparator implements Comparator { - public int compare(Integer a, Integer b) { - return ((T) data[a]).compareTo(data[b]); - } - } - Integer[] permutation = new Integer[A.size()]; - for (int i = 0; i < A.size(); i++) { - permutation[i] = i; - } - Arrays.sort(permutation, new comparator()); - return permutation; - } - - public static int[] permuteArray(int[] array, Integer[] permutation) { - int[] output = new int[array.length]; - for (int i = 0; i < output.length; i++) { - output[i] = array[permutation[i]]; - } - return output; - } - - public static double[] permuteArray(double[] array, Integer[] permutation) { - double[] output = new double[array.length]; - for (int i = 0; i < output.length; i++) { - output[i] = array[permutation[i]]; - } - return output; - } - - public static Object[] permuteArray(Object[] array, Integer[] permutation) { - Object[] output = new Object[array.length]; - for (int i = 0; i < output.length; i++) { - output[i] = array[permutation[i]]; - } - return output; - } - - public static String[] permuteArray(String[] array, Integer[] permutation) { - String[] output = new String[array.length]; - for (int i = 0; i < output.length; i++) { - output[i] = array[permutation[i]]; - } - return output; - } - - public static List permuteList(List list, Integer[] permutation) { - List output = new ArrayList(); - for (int i = 0; i < permutation.length; i++) { - output.add(list.get(permutation[i])); - } - return output; - } - - /** - * Draw N random elements from list. - */ - public static List randomSubset(List list, int N) { - if (list.size() <= N) { - return list; - } - - int idx[] = new int[list.size()]; - for (int i = 0; i < list.size(); i++) { - idx[i] = GenomeAnalysisEngine.getRandomGenerator().nextInt(); - } - - Integer[] perm = sortPermutation(idx); - - List ans = new ArrayList(); - for (int i = 0; i < N; i++) { - ans.add(list.get(perm[i])); - } - - return ans; - } - - /** - * Draw N random elements from an array. - * - * @param array your objects - * @param n number of elements to select at random from the list - * @return a new list with the N randomly chosen elements from list - */ - @Requires({"array != null", "n>=0"}) - @Ensures({"result != null", "result.length == Math.min(n, array.length)"}) - public static Object[] randomSubset(final Object[] array, final int n) { - if (array.length <= n) - return array.clone(); - - Object[] shuffledArray = arrayShuffle(array); - Object[] result = new Object[n]; - System.arraycopy(shuffledArray, 0, result, 0, n); - return result; - } - - public static double percentage(double x, double base) { - return (base > 0 ? (x / base) * 100.0 : 0); - } - - public static double percentage(int x, int base) { - return (base > 0 ? ((double) x / (double) base) * 100.0 : 0); - } - - public static double percentage(long x, long base) { - return (base > 0 ? ((double) x / (double) base) * 100.0 : 0); - } - - public static int countOccurrences(char c, String s) { + public static int countOccurrences(final char c, final String s) { int count = 0; for (int i = 0; i < s.length(); i++) { count += s.charAt(i) == c ? 1 : 0; @@ -1036,27 +755,6 @@ public class MathUtils { return count; } - /** - * Returns the top (larger) N elements of the array. Naive n^2 implementation (Selection Sort). - * Better than sorting if N (number of elements to return) is small - * - * @param array the array - * @param n number of top elements to return - * @return the n larger elements of the array - */ - public static Collection getNMaxElements(double[] array, int n) { - ArrayList maxN = new ArrayList(n); - double lastMax = Double.MAX_VALUE; - for (int i = 0; i < n; i++) { - double max = Double.MIN_VALUE; - for (double x : array) { - max = Math.min(lastMax, Math.max(x, max)); - } - maxN.add(max); - lastMax = max; - } - return maxN; - } /** * Returns n random indices drawn with replacement from the range 0..(k-1) @@ -1065,7 +763,7 @@ public class MathUtils { * @param k the number of random indices to draw (with replacement) * @return a list of k random indices ranging from 0 to (n-1) with possible duplicates */ - static public ArrayList sampleIndicesWithReplacement(int n, int k) { + static public ArrayList sampleIndicesWithReplacement(final int n, final int k) { ArrayList chosen_balls = new ArrayList(k); for (int i = 0; i < k; i++) { @@ -1084,7 +782,7 @@ public class MathUtils { * @param k the number of random indices to draw (without replacement) * @return a list of k random indices ranging from 0 to (n-1) without duplicates */ - static public ArrayList sampleIndicesWithoutReplacement(int n, int k) { + static public ArrayList sampleIndicesWithoutReplacement(final int n, final int k) { ArrayList chosen_balls = new ArrayList(k); for (int i = 0; i < n; i++) { @@ -1105,7 +803,7 @@ public class MathUtils { * @param the template type of the ArrayList * @return a new ArrayList consisting of the elements at the specified indices */ - static public ArrayList sliceListByIndices(List indices, List list) { + static public ArrayList sliceListByIndices(final List indices, final List list) { ArrayList subset = new ArrayList(); for (int i : indices) { @@ -1115,35 +813,6 @@ public class MathUtils { return subset; } - public static Comparable orderStatisticSearch(int orderStat, List list) { - // this finds the order statistic of the list (kth largest element) - // the list is assumed *not* to be sorted - - final Comparable x = list.get(orderStat); - ArrayList lessThanX = new ArrayList(); - ArrayList equalToX = new ArrayList(); - ArrayList greaterThanX = new ArrayList(); - - for (Comparable y : list) { - if (x.compareTo(y) > 0) { - lessThanX.add(y); - } - else if (x.compareTo(y) < 0) { - greaterThanX.add(y); - } - else - equalToX.add(y); - } - - if (lessThanX.size() > orderStat) - return orderStatisticSearch(orderStat, lessThanX); - else if (lessThanX.size() + equalToX.size() >= orderStat) - return orderStat; - else - return orderStatisticSearch(orderStat - lessThanX.size() - equalToX.size(), greaterThanX); - - } - /** * Given two log-probability vectors, compute log of vector product of them: * in Matlab notation, return log10(10.*x'*10.^y) @@ -1151,7 +820,7 @@ public class MathUtils { * @param y vector 2 * @return a double representing log (dotProd(10.^x,10.^y) */ - public static double logDotProduct(double [] x, double[] y) { + public static double logDotProduct(final double [] x, final double[] y) { if (x.length != y.length) throw new ReviewedStingException("BUG: Vectors of different lengths"); @@ -1165,57 +834,6 @@ public class MathUtils { - } - public static Object getMedian(List list) { - return orderStatisticSearch((int) Math.ceil(list.size() / 2), list); - } - - public static byte getQScoreOrderStatistic(List reads, List offsets, int k) { - // version of the order statistic calculator for SAMRecord/Integer lists, where the - // list index maps to a q-score only through the offset index - // returns the kth-largest q-score. - - if (reads.size() == 0) { - return 0; - } - - ArrayList lessThanQReads = new ArrayList(); - ArrayList equalToQReads = new ArrayList(); - ArrayList greaterThanQReads = new ArrayList(); - ArrayList lessThanQOffsets = new ArrayList(); - ArrayList greaterThanQOffsets = new ArrayList(); - - final byte qk = reads.get(k).getBaseQualities()[offsets.get(k)]; - - for (int iter = 0; iter < reads.size(); iter++) { - SAMRecord read = reads.get(iter); - int offset = offsets.get(iter); - byte quality = read.getBaseQualities()[offset]; - - if (quality < qk) { - lessThanQReads.add(read); - lessThanQOffsets.add(offset); - } - else if (quality > qk) { - greaterThanQReads.add(read); - greaterThanQOffsets.add(offset); - } - else { - equalToQReads.add(reads.get(iter)); - } - } - - if (lessThanQReads.size() > k) - return getQScoreOrderStatistic(lessThanQReads, lessThanQOffsets, k); - else if (equalToQReads.size() + lessThanQReads.size() >= k) - return qk; - else - return getQScoreOrderStatistic(greaterThanQReads, greaterThanQOffsets, k - lessThanQReads.size() - equalToQReads.size()); - - } - - public static byte getQScoreMedian(List reads, List offsets) { - return getQScoreOrderStatistic(reads, offsets, (int) Math.floor(reads.size() / 2.)); } /** @@ -1336,29 +954,6 @@ public class MathUtils { // // useful common utility routines // - public static double rate(long n, long d) { - return n / (1.0 * Math.max(d, 1)); - } - - public static double rate(int n, int d) { - return n / (1.0 * Math.max(d, 1)); - } - - public static long inverseRate(long n, long d) { - return n == 0 ? 0 : d / Math.max(n, 1); - } - - public static long inverseRate(int n, int d) { - return n == 0 ? 0 : d / Math.max(n, 1); - } - - public static double ratio(int num, int denom) { - return ((double) num) / (Math.max(denom, 1)); - } - - public static double ratio(long num, long denom) { - return ((double) num) / (Math.max(denom, 1)); - } static public double max(double x0, double x1, double x2) { double a = Math.max(x0, x1); @@ -1371,8 +966,8 @@ public class MathUtils { * @param ln log(x) * @return log10(x) */ - public static double lnToLog10(double ln) { - return ln * Math.log10(Math.exp(1)); + public static double lnToLog10(final double ln) { + return ln * Math.log10(Math.E); } /** @@ -1384,7 +979,7 @@ public class MathUtils { * Efficient rounding functions to simplify the log gamma function calculation * double to long with 32 bit shift */ - private static final int HI(double x) { + private static final int HI(final double x) { return (int) (Double.doubleToLongBits(x) >> 32); } @@ -1392,7 +987,7 @@ public class MathUtils { * Efficient rounding functions to simplify the log gamma function calculation * double to long without shift */ - private static final int LO(double x) { + private static final int LO(final double x) { return (int) Double.doubleToLongBits(x); } @@ -1400,7 +995,7 @@ public class MathUtils { * Most efficent implementation of the lnGamma (FDLIBM) * Use via the log10Gamma wrapper method. */ - private static double lnGamma(double x) { + private static double lnGamma(final double x) { double t, y, z, p, p1, p2, p3, q, r, w; int i; @@ -1521,7 +1116,7 @@ public class MathUtils { * @param x the x parameter * @return the log10 of the gamma function at x. */ - public static double log10Gamma(double x) { + public static double log10Gamma(final double x) { return lnToLog10(lnGamma(x)); } @@ -1533,11 +1128,11 @@ public class MathUtils { * @param k number of successes * @return the log10 of the binomial coefficient */ - public static double log10BinomialCoefficient(int n, int k) { + public static double log10BinomialCoefficient(final int n, final int k) { return log10Factorial(n) - log10Factorial(k) - log10Factorial(n - k); } - public static double log10BinomialProbability(int n, int k, double log10p) { + public static double log10BinomialProbability(final int n, final int k, final double log10p) { double log10OneMinusP = Math.log10(1 - Math.pow(10, log10p)); return log10BinomialCoefficient(n, k) + log10p * k + log10OneMinusP * (n - k); } @@ -1550,10 +1145,10 @@ public class MathUtils { * @param k array of any size with the number of successes for each grouping (k1, k2, k3, ..., km) * @return */ - public static double log10MultinomialCoefficient(int n, int[] k) { + public static double log10MultinomialCoefficient(final int n, final int[] k) { double denominator = 0.0; for (int x : k) { - denominator += log10Factorial(x ); + denominator += log10Factorial(x); } return log10Factorial(n) - denominator; } @@ -1567,7 +1162,7 @@ public class MathUtils { * @param log10p array of log10 probabilities * @return */ - public static double log10MultinomialProbability(int n, int[] k, double[] log10p) { + public static double log10MultinomialProbability(final int n, final int[] k, final double[] log10p) { if (log10p.length != k.length) throw new UserException.BadArgumentValue("p and k", "Array of log10 probabilities must have the same size as the array of number of sucesses: " + log10p.length + ", " + k.length); double log10Prod = 0.0; @@ -1577,12 +1172,12 @@ public class MathUtils { return log10MultinomialCoefficient(n, k) + log10Prod; } - public static double factorial(int x) { + public static double factorial(final int x) { // avoid rounding errors caused by fact that 10^log(x) might be slightly lower than x and flooring may produce 1 less than real value return (double)Math.round(Math.pow(10, log10Factorial(x))); } - public static double log10Factorial(int x) { + public static double log10Factorial(final int x) { if (x >= log10FactorialCache.length || x < 0) return log10Gamma(x + 1); else @@ -1598,57 +1193,20 @@ public class MathUtils { */ @Requires("a.length == b.length") @Ensures("result.length == a.length") - public static int[] addArrays(int[] a, int[] b) { + public static int[] addArrays(final int[] a, final int[] b) { int[] c = new int[a.length]; for (int i = 0; i < a.length; i++) c[i] = a[i] + b[i]; return c; } - /** - * Quick implementation of the Knuth-shuffle algorithm to generate a random - * permutation of the given array. - * - * @param array the original array - * @return a new array with the elements shuffled - */ - public static Object[] arrayShuffle(Object[] array) { - int n = array.length; - Object[] shuffled = array.clone(); - for (int i = 0; i < n; i++) { - int j = i + GenomeAnalysisEngine.getRandomGenerator().nextInt(n - i); - Object tmp = shuffled[i]; - shuffled[i] = shuffled[j]; - shuffled[j] = tmp; - } - return shuffled; - } - - /** - * Vector operations - * - * @param v1 first numerical array - * @param v2 second numerical array - * @return a new array with the elements added - */ - public static Double[] vectorSum(E v1[], E v2[]) { - if (v1.length != v2.length) - throw new UserException("BUG: vectors v1, v2 of different size in vectorSum()"); - - Double[] result = new Double[v1.length]; - for (int k = 0; k < v1.length; k++) - result[k] = v1[k].doubleValue() + v2[k].doubleValue(); - - return result; - } - /** Same routine, unboxed types for efficiency * * @param x First vector * @param y Second vector * @return Vector of same length as x and y so that z[k] = x[k]+y[k] */ - public static double[] vectorSum(double[]x, double[] y) { + public static double[] vectorSum(final double[]x, final double[] y) { if (x.length != y.length) throw new ReviewedStingException("BUG: Lengths of x and y must be the same"); @@ -1665,24 +1223,7 @@ public class MathUtils { * @param y Second vector * @return Vector of same length as x and y so that z[k] = x[k]-y[k] */ - public static double[] vectorDiff(double[]x, double[] y) { - if (x.length != y.length) - throw new ReviewedStingException("BUG: Lengths of x and y must be the same"); - - double[] result = new double[x.length]; - for (int k=0; k Double[] scalarTimesVector(E a, E[] v1) { - - Double result[] = new Double[v1.length]; - for (int k = 0; k < v1.length; k++) - result[k] = a.doubleValue() * v1[k].doubleValue(); - - return result; - } - - public static Double dotProduct(E[] v1, E[] v2) { - if (v1.length != v2.length) - throw new UserException("BUG: vectors v1, v2 of different size in vectorSum()"); - - Double result = 0.0; - for (int k = 0; k < v1.length; k++) - result += v1[k].doubleValue() * v2[k].doubleValue(); - - return result; - } - - public static double dotProduct(double[] v1, double[] v2) { - if (v1.length != v2.length) - throw new UserException("BUG: vectors v1, v2 of different size in vectorSum()"); - - double result = 0.0; - for (int k = 0; k < v1.length; k++) - result += v1[k] * v2[k]; - - return result; - } - - public static double[] vectorLog10(double v1[]) { - double result[] = new double[v1.length]; - for (int k = 0; k < v1.length; k++) - result[k] = Math.log10(v1[k]); - - return result; - - } - - // todo - silly overloading, just because Java can't unbox/box arrays of primitive types, and we can't do generics with primitive types! - public static Double[] vectorLog10(Double v1[]) { - Double result[] = new Double[v1.length]; - for (int k = 0; k < v1.length; k++) - result[k] = Math.log10(v1[k]); - - return result; - - } - /** * Returns a series of integer values between start and stop, inclusive, * expontentially distributed between the two. That is, if there are @@ -1796,4 +1287,18 @@ public class MathUtils { return Double.isInfinite(d) || d > 0.0 ? 0.0 : d; } } + + /** + * Draw N random elements from list + * @param list - the list from which to draw randomly + * @param N - the number of elements to draw + */ + public static List randomSubset(final List list, final int N) { + if (list.size() <= N) { + return list; + } + + return sliceListByIndices(sampleIndicesWithoutReplacement(list.size(),N),list); + } + } diff --git a/public/java/test/org/broadinstitute/sting/utils/MathUtilsUnitTest.java b/public/java/test/org/broadinstitute/sting/utils/MathUtilsUnitTest.java index 2c57e8b33..2560bcd11 100644 --- a/public/java/test/org/broadinstitute/sting/utils/MathUtilsUnitTest.java +++ b/public/java/test/org/broadinstitute/sting/utils/MathUtilsUnitTest.java @@ -150,6 +150,21 @@ public class MathUtilsUnitTest extends BaseTest { @Test public void testLog10BinomialCoefficient() { logger.warn("Executing testLog10BinomialCoefficient"); + // note that we can test the binomial coefficient calculation indirectly via Newton's identity + // (1+z)^m = sum (m choose k)z^k + double[] z_vals = new double[]{0.999,0.9,0.8,0.5,0.2,0.01,0.0001}; + int[] exponent = new int[]{5,15,25,50,100}; + for ( double z : z_vals ) { + double logz = Math.log10(z); + for ( int exp : exponent ) { + double expected_log = exp*Math.log10(1+z); + double[] newtonArray_log = new double[1+exp]; + for ( int k = 0 ; k <= exp; k++ ) { + newtonArray_log[k] = MathUtils.log10BinomialCoefficient(exp,k)+k*logz; + } + Assert.assertEquals(MathUtils.log10sumLog10(newtonArray_log),expected_log,1e-6); + } + } Assert.assertEquals(MathUtils.log10BinomialCoefficient(4, 2), 0.7781513, 1e-6); Assert.assertEquals(MathUtils.log10BinomialCoefficient(10, 3), 2.079181, 1e-6); @@ -172,36 +187,19 @@ public class MathUtilsUnitTest extends BaseTest { Assert.assertEquals(MathUtils.log10Factorial(12), 8.680337, 1e-6); Assert.assertEquals(MathUtils.log10Factorial(200), 374.8969, 1e-3); Assert.assertEquals(MathUtils.log10Factorial(12342), 45138.26, 1e-1); - } - - @Test(enabled = true) - public void testRandomSubset() { - Integer[] x = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; - Assert.assertEquals(MathUtils.randomSubset(x, 0).length, 0); - Assert.assertEquals(MathUtils.randomSubset(x, 1).length, 1); - Assert.assertEquals(MathUtils.randomSubset(x, 2).length, 2); - Assert.assertEquals(MathUtils.randomSubset(x, 3).length, 3); - Assert.assertEquals(MathUtils.randomSubset(x, 4).length, 4); - Assert.assertEquals(MathUtils.randomSubset(x, 5).length, 5); - Assert.assertEquals(MathUtils.randomSubset(x, 6).length, 6); - Assert.assertEquals(MathUtils.randomSubset(x, 7).length, 7); - Assert.assertEquals(MathUtils.randomSubset(x, 8).length, 8); - Assert.assertEquals(MathUtils.randomSubset(x, 9).length, 9); - Assert.assertEquals(MathUtils.randomSubset(x, 10).length, 10); - Assert.assertEquals(MathUtils.randomSubset(x, 11).length, 10); - - for (int i = 0; i < 25; i++) - Assert.assertTrue(hasUniqueElements(MathUtils.randomSubset(x, 5))); - - } - - @Test(enabled = true) - public void testArrayShuffle() { - Integer[] x = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}; - for (int i = 0; i < 25; i++) { - Object[] t = MathUtils.arrayShuffle(x); - Assert.assertTrue(hasUniqueElements(t)); - Assert.assertTrue(hasAllElements(x, t)); + double log10factorial_small = 0; + double log10factorial_middle = 374.8969; + double log10factorial_large = 45138.26; + int small_start = 1; + int med_start = 200; + int large_start = 12342; + for ( int i = 1; i < 1000; i++ ) { + log10factorial_small += Math.log10(i+small_start); + log10factorial_middle += Math.log10(i+med_start); + log10factorial_large += Math.log10(i+large_start); + Assert.assertEquals(MathUtils.log10Factorial(small_start+i),log10factorial_small,1e-6); + Assert.assertEquals(MathUtils.log10Factorial(med_start+i),log10factorial_middle,1e-3); + Assert.assertEquals(MathUtils.log10Factorial(large_start+i),log10factorial_large,1e-1); } } @@ -286,17 +284,29 @@ public class MathUtilsUnitTest extends BaseTest { Assert.assertEquals(MathUtils.approximateLog10SumLog10(-29.1, -27.6, -26.2), Math.log10(Math.pow(10.0, -29.1) + Math.pow(10.0, -27.6) + Math.pow(10.0, -26.2)), requiredPrecision); Assert.assertEquals(MathUtils.approximateLog10SumLog10(-0.12345, -0.23456, -0.34567), Math.log10(Math.pow(10.0, -0.12345) + Math.pow(10.0, -0.23456) + Math.pow(10.0, -0.34567)), requiredPrecision); Assert.assertEquals(MathUtils.approximateLog10SumLog10(-15.7654, -17.0101, -17.9341), Math.log10(Math.pow(10.0, -15.7654) + Math.pow(10.0, -17.0101) + Math.pow(10.0, -17.9341)), requiredPrecision); - } - @Test - public void testNormalizeFromLog10() { - Assert.assertTrue(compareDoubleArrays(MathUtils.normalizeFromLog10(new double[] {0.0, 0.0, -1.0, -1.1, -7.8}, false, true), new double[] {0.0, 0.0, -1.0, -1.1, -7.8})); - Assert.assertTrue(compareDoubleArrays(MathUtils.normalizeFromLog10(new double[] {-1.0, -1.0, -1.0, -1.1, -7.8}, false, true), new double[] {0.0, 0.0, 0.0, -0.1, -6.8})); - Assert.assertTrue(compareDoubleArrays(MathUtils.normalizeFromLog10(new double[] {-10.0, -7.8, -10.5, -1.1, -10.0}, false, true), new double[] {-8.9, -6.7, -9.4, 0.0, -8.9})); - - Assert.assertTrue(compareDoubleArrays(MathUtils.normalizeFromLog10(new double[] {-1.0, -1.0, -1.0, -1.0}), new double[] {0.25, 0.25, 0.25, 0.25})); - Assert.assertTrue(compareDoubleArrays(MathUtils.normalizeFromLog10(new double[] {-1.0, -3.0, -1.0, -1.0}), new double[] {0.1 * 1.0 / 0.301, 0.001 * 1.0 / 0.301, 0.1 * 1.0 / 0.301, 0.1 * 1.0 / 0.301})); - Assert.assertTrue(compareDoubleArrays(MathUtils.normalizeFromLog10(new double[] {-1.0, -3.0, -1.0, -2.0}), new double[] {0.1 * 1.0 / 0.211, 0.001 * 1.0 / 0.211, 0.1 * 1.0 / 0.211, 0.01 * 1.0 / 0.211})); + // magnitude of the sum doesn't matter, so we can combinatorially test this via partitions of unity + double[] mult_partitionFactor = new double[]{0.999,0.98,0.95,0.90,0.8,0.5,0.3,0.1,0.05,0.001}; + int[] n_partitions = new int[] {2,4,8,16,32,64,128,256,512,1028}; + for ( double alpha : mult_partitionFactor ) { + double log_alpha = Math.log10(alpha); + double log_oneMinusAlpha = Math.log10(1-alpha); + for ( int npart : n_partitions ) { + double[] multiplicative = new double[npart]; + double[] equal = new double[npart]; + double remaining_log = 0.0; // realspace = 1 + for ( int i = 0 ; i < npart-1; i++ ) { + equal[i] = -Math.log10(npart); + double piece = remaining_log + log_alpha; // take a*remaining, leaving remaining-a*remaining = (1-a)*remaining + multiplicative[i] = piece; + remaining_log = remaining_log + log_oneMinusAlpha; + } + equal[npart-1] = -Math.log10(npart); + multiplicative[npart-1] = remaining_log; + Assert.assertEquals(MathUtils.approximateLog10SumLog10(equal),0.0,requiredPrecision,String.format("Did not sum to one: k=%d equal partitions.",npart)); + Assert.assertEquals(MathUtils.approximateLog10SumLog10(multiplicative),0.0,requiredPrecision, String.format("Did not sum to one: k=%d multiplicative partitions with alpha=%f",npart,alpha)); + } + } } @Test @@ -342,12 +352,29 @@ public class MathUtilsUnitTest extends BaseTest { Assert.assertEquals(MathUtils.log10sumLog10(new double[] {-29.1, -27.6, -26.2}), Math.log10(Math.pow(10.0, -29.1) + Math.pow(10.0, -27.6) + Math.pow(10.0, -26.2)), requiredPrecision); Assert.assertEquals(MathUtils.log10sumLog10(new double[] {-0.12345, -0.23456, -0.34567}), Math.log10(Math.pow(10.0, -0.12345) + Math.pow(10.0, -0.23456) + Math.pow(10.0, -0.34567)), requiredPrecision); Assert.assertEquals(MathUtils.log10sumLog10(new double[] {-15.7654, -17.0101, -17.9341}), Math.log10(Math.pow(10.0, -15.7654) + Math.pow(10.0, -17.0101) + Math.pow(10.0, -17.9341)), requiredPrecision); - } - @Test - public void testDotProduct() { - Assert.assertEquals(MathUtils.dotProduct(new Double[]{-5.0,-3.0,2.0}, new Double[]{6.0,7.0,8.0}),-35.0,1e-3); - Assert.assertEquals(MathUtils.dotProduct(new Double[]{-5.0}, new Double[]{6.0}),-30.0,1e-3); + // magnitude of the sum doesn't matter, so we can combinatorially test this via partitions of unity + double[] mult_partitionFactor = new double[]{0.999,0.98,0.95,0.90,0.8,0.5,0.3,0.1,0.05,0.001}; + int[] n_partitions = new int[] {2,4,8,16,32,64,128,256,512,1028}; + for ( double alpha : mult_partitionFactor ) { + double log_alpha = Math.log10(alpha); + double log_oneMinusAlpha = Math.log10(1-alpha); + for ( int npart : n_partitions ) { + double[] multiplicative = new double[npart]; + double[] equal = new double[npart]; + double remaining_log = 0.0; // realspace = 1 + for ( int i = 0 ; i < npart-1; i++ ) { + equal[i] = -Math.log10(npart); + double piece = remaining_log + log_alpha; // take a*remaining, leaving remaining-a*remaining = (1-a)*remaining + multiplicative[i] = piece; + remaining_log = remaining_log + log_oneMinusAlpha; + } + equal[npart-1] = -Math.log10(npart); + multiplicative[npart-1] = remaining_log; + Assert.assertEquals(MathUtils.log10sumLog10(equal),0.0,requiredPrecision); + Assert.assertEquals(MathUtils.log10sumLog10(multiplicative),0.0,requiredPrecision,String.format("Did not sum to one: nPartitions=%d, alpha=%f",npart,alpha)); + } + } } @Test @@ -355,19 +382,4 @@ public class MathUtilsUnitTest extends BaseTest { Assert.assertEquals(MathUtils.logDotProduct(new double[]{-5.0,-3.0,2.0}, new double[]{6.0,7.0,8.0}),10.0,1e-3); Assert.assertEquals(MathUtils.logDotProduct(new double[]{-5.0}, new double[]{6.0}),1.0,1e-3); } - - /** - * Private function used by testNormalizeFromLog10() - */ - private boolean compareDoubleArrays(double[] b1, double[] b2) { - if (b1.length != b2.length) { - return false; // sanity check - } - - for (int i = 0; i < b1.length; i++) { - if (MathUtils.compareDoubles(b1[i], b2[i]) != 0) - return false; - } - return true; - } } diff --git a/public/java/test/org/broadinstitute/sting/utils/activeregion/BandPassActivityProfileUnitTest.java b/public/java/test/org/broadinstitute/sting/utils/activeregion/BandPassActivityProfileUnitTest.java index d5231c30b..2470364c4 100644 --- a/public/java/test/org/broadinstitute/sting/utils/activeregion/BandPassActivityProfileUnitTest.java +++ b/public/java/test/org/broadinstitute/sting/utils/activeregion/BandPassActivityProfileUnitTest.java @@ -120,12 +120,21 @@ public class BandPassActivityProfileUnitTest extends BaseTest { for( int iii = 0; iii < activeProbArray.length; iii++ ) { final double[] kernel = ArrayUtils.subarray(GaussianKernel, Math.max(profile.getFilteredSize() - iii, 0), Math.min(GaussianKernel.length, profile.getFilteredSize() + activeProbArray.length - iii)); final double[] activeProbSubArray = ArrayUtils.subarray(activeProbArray, Math.max(0,iii - profile.getFilteredSize()), Math.min(activeProbArray.length,iii + profile.getFilteredSize() + 1)); - bandPassProbArray[iii] = MathUtils.dotProduct(activeProbSubArray, kernel); + bandPassProbArray[iii] = dotProduct(activeProbSubArray, kernel); } return bandPassProbArray; } + public static double dotProduct(double[] v1, double[] v2) { + Assert.assertEquals(v1.length,v2.length,"Array lengths do not mach in dotProduct"); + double result = 0.0; + for (int k = 0; k < v1.length; k++) + result += v1[k] * v2[k]; + + return result; + } + @DataProvider(name = "BandPassComposition") public Object[][] makeBandPassComposition() { final List tests = new LinkedList();