diff --git a/public/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatumNode.java b/public/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatumNode.java index 62ea67d7c..3af91be16 100644 --- a/public/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatumNode.java +++ b/public/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatumNode.java @@ -2,7 +2,9 @@ package org.broadinstitute.sting.utils.recalibration; import com.google.java.contract.Ensures; import com.google.java.contract.Requires; +import org.apache.commons.math.stat.inference.ChiSquareTestImpl; import org.apache.log4j.Logger; +import org.broadinstitute.sting.utils.MathUtils; import org.broadinstitute.sting.utils.collections.Pair; import java.util.HashSet; @@ -15,8 +17,9 @@ import java.util.Set; * @since 07/27/12 */ public class RecalDatumNode { + private final static boolean USE_CHI2 = true; /* when true, calcPenalty() uses calcPenaltyChi2(); when false, calcPenaltyLog10() */ protected static Logger logger = Logger.getLogger(RecalDatumNode.class); - private final static double UNINITIALIZED = -1.0; + private final static double UNINITIALIZED = Double.NEGATIVE_INFINITY; /* sentinel marking fixedPenalty as not yet set; getPenalty() then computes the penalty on demand */ private final T recalDatum; private double fixedPenalty = UNINITIALIZED; private final Set> subnodes; @@ -56,11 +59,11 @@ public class RecalDatumNode { if ( fixedPenalty != UNINITIALIZED ) return fixedPenalty; else - return calcPenalty(recalDatum.getEmpiricalErrorRate()); + return calcPenalty(); } public double calcAndSetFixedPenalty(final boolean doEntireTree) { - fixedPenalty = calcPenalty(recalDatum.getEmpiricalErrorRate()); + fixedPenalty = calcPenalty(); if ( doEntireTree ) for ( final RecalDatumNode sub : subnodes ) sub.calcAndSetFixedPenalty(doEntireTree); @@ -79,14 +82,23 @@ public class RecalDatumNode { return subnodes.size(); } - public double getMinNodePenalty() { + /** + * Total penalty is the sum of leaf node penalties + * + * This algorithm assumes that penalties have been fixed before pruning, as leaf nodes by + * definition have 0 penalty unless they represent a 
pruned tree with underlying -- but now + * pruned -- subtrees + * + * @return getPenalty() when this node is a leaf, otherwise the sum of totalPenalty() over all subnodes + */ + public double totalPenalty() { if ( isLeaf() ) - return Double.MAX_VALUE; + return getPenalty(); else { - double minPenalty = getPenalty(); + double sum = 0.0; for ( final RecalDatumNode sub : subnodes ) - minPenalty = Math.min(minPenalty, sub.getMinNodePenalty()); - return minPenalty; + sum += sub.totalPenalty(); + return sum; } } @@ -97,6 +109,17 @@ public class RecalDatumNode { return subMax + 1; } + public int minDepth() { /* number of nodes on the shortest path from this node down to a leaf; a leaf itself counts as 1 */ + if ( isLeaf() ) + return 1; + else { + int subMin = Integer.MAX_VALUE; + for ( final RecalDatumNode sub : subnodes ) + subMin = Math.min(subMin, sub.minDepth()); + return subMin + 1; + } + } + public int size() { int size = 1; for ( final RecalDatumNode sub : subnodes ) @@ -104,6 +127,58 @@ public class RecalDatumNode { return size; } + public int numLeaves() { /* number of leaf nodes in the subtree rooted at this node; a leaf counts itself */ + if ( isLeaf() ) + return 1; + else { + int size = 0; + for ( final RecalDatumNode sub : subnodes ) + size += sub.numLeaves(); + return size; + } + } + + private double calcPenalty() { /* single entry point for the penalty; the formula is chosen by USE_CHI2 */ + if ( USE_CHI2 ) + return calcPenaltyChi2(); + else + return calcPenaltyLog10(getRecalDatum().getEmpiricalErrorRate()); + } + + private double calcPenaltyChi2() { /* chi-square statistic over the direct subnodes of this node; leaves get 0.0 */ + if ( isLeaf() ) + return 0.0; + else { + final long[][] counts = new long[subnodes.size()][2]; /* one row per direct subnode: { mismatches, observations } */ + + int i = 0; + for ( RecalDatumNode subnode : subnodes ) { + counts[i][0] = subnode.getRecalDatum().getNumMismatches(); + counts[i][1] = subnode.getRecalDatum().getNumObservations(); /* NOTE(review): column 1 is the observation count, not observations minus mismatches; confirm this is the intended contingency table */ + i++; + } + + final double chi2 = new ChiSquareTestImpl().chiSquare(counts); /* NOTE(review): presumably rejects tables with fewer than two rows; confirm behavior for a node with a single subnode */ + +// StringBuilder x = new StringBuilder(); +// StringBuilder y = new StringBuilder(); +// for ( int k = 0; k < counts.length; k++) { +// if ( k != 0 ) { +// x.append(", "); +// y.append(", "); +// } +// x.append(counts[k][0]); +// y.append(counts[k][1]); +// } +// logger.info("x = c(" + x.toString() + ")"); +// logger.info("y = c(" + y.toString() + ")"); +// logger.info("chi2 = " + chi2); + + return 
chi2; + //return Math.log10(chi2); + } + } + /** * Calculate the penalty of this interval, given the overall error rate for the interval * @@ -119,7 +194,7 @@ public class RecalDatumNode { */ @Requires("globalErrorRate >= 0.0") @Ensures("result >= 0.0") - private double calcPenalty(final double globalErrorRate) { + private double calcPenaltyLog10(final double globalErrorRate) { if ( globalErrorRate == 0.0 ) // there were no observations, so there's no penalty return 0.0; @@ -136,7 +211,7 @@ public class RecalDatumNode { } else { double sum = 0; for ( final RecalDatumNode hrd : subnodes) - sum += hrd.calcPenalty(globalErrorRate); + sum += hrd.calcPenaltyLog10(globalErrorRate); return sum; } } @@ -173,17 +248,38 @@ public class RecalDatumNode { * @return */ private RecalDatumNode removeLowestPenaltyNode() { - final RecalDatumNode oneRemoved = removeFirstNodeWithPenalty(getMinNodePenalty()).getFirst(); + final Pair, Double> nodeToRemove = getMinPenaltyNode(); + logger.info("Removing " + nodeToRemove.getFirst() + " with penalty " + nodeToRemove.getSecond()); + + final Pair, Boolean> result = removeNode(nodeToRemove.getFirst()); + + if ( ! result.getSecond() ) + throw new IllegalStateException("Never removed any node!"); + + final RecalDatumNode oneRemoved = result.getFirst(); if ( oneRemoved == null ) throw new IllegalStateException("Removed our root node, wow, didn't expect that"); return oneRemoved; } - private Pair, Boolean> removeFirstNodeWithPenalty(final double penaltyToRemove) { - if ( getPenalty() == penaltyToRemove ) { - logger.info("Removing " + this + " with penalty " + penaltyToRemove); + private Pair, Double> getMinPenaltyNode() { + final double myValue = isLeaf() ? 
Double.MAX_VALUE : getPenalty(); /* a leaf is scored MAX_VALUE so that any non-leaf with a smaller penalty replaces it as the candidate */ + Pair, Double> maxNode = new Pair, Double>(this, myValue); /* despite its name, holds the minimum-penalty candidate found so far */ + + for ( final RecalDatumNode sub : subnodes ) { + final Pair, Double> subFind = sub.getMinPenaltyNode(); + if ( subFind.getSecond() < maxNode.getSecond() ) { /* strict comparison: on ties the earlier candidate is kept */ + maxNode = subFind; + } + } + + return maxNode; + } + + private Pair, Boolean> removeNode(final RecalDatumNode nodeToRemove) { + if ( this == nodeToRemove ) { /* identity comparison: matches only the exact node instance passed in */ if ( isLeaf() ) - throw new IllegalStateException("Trying to remove a leaf node from the tree! " + this + " " + penaltyToRemove); + throw new IllegalStateException("Trying to remove a leaf node from the tree! " + this + " " + nodeToRemove); // node is the thing we are going to remove, but without any subnodes final RecalDatumNode node = new RecalDatumNode(getRecalDatum(), fixedPenalty); return new Pair, Boolean>(node, true); @@ -200,7 +296,7 @@ public class RecalDatumNode { sub.add(sub1); } else { // haven't removed anything yet, so try - final Pair, Boolean> maybeRemoved = sub1.removeFirstNodeWithPenalty(penaltyToRemove); + final Pair, Boolean> maybeRemoved = sub1.removeNode(nodeToRemove); removedSomething = maybeRemoved.getSecond(); sub.add(maybeRemoved.getFirst()); }