Working version of adaptive context calculations

-- Uses chi2 test for independences to determine if subcontext is worth representing.   Give excellent visual results
-- Writes out analysis output file producing excellent results in R
-- Trivial reformatting of MathUtils
This commit is contained in:
Mark DePristo 2012-07-30 15:44:33 -04:00
parent 93640b382e
commit 0c4e729e13
1 changed files with 112 additions and 16 deletions

View File

@ -2,7 +2,9 @@ package org.broadinstitute.sting.utils.recalibration;
import com.google.java.contract.Ensures; import com.google.java.contract.Ensures;
import com.google.java.contract.Requires; import com.google.java.contract.Requires;
import org.apache.commons.math.stat.inference.ChiSquareTestImpl;
import org.apache.log4j.Logger; import org.apache.log4j.Logger;
import org.broadinstitute.sting.utils.MathUtils;
import org.broadinstitute.sting.utils.collections.Pair; import org.broadinstitute.sting.utils.collections.Pair;
import java.util.HashSet; import java.util.HashSet;
@ -15,8 +17,9 @@ import java.util.Set;
* @since 07/27/12 * @since 07/27/12
*/ */
public class RecalDatumNode<T extends RecalDatum> { public class RecalDatumNode<T extends RecalDatum> {
private final static boolean USE_CHI2 = true;
protected static Logger logger = Logger.getLogger(RecalDatumNode.class); protected static Logger logger = Logger.getLogger(RecalDatumNode.class);
private final static double UNINITIALIZED = -1.0; private final static double UNINITIALIZED = Double.NEGATIVE_INFINITY;
private final T recalDatum; private final T recalDatum;
private double fixedPenalty = UNINITIALIZED; private double fixedPenalty = UNINITIALIZED;
private final Set<RecalDatumNode<T>> subnodes; private final Set<RecalDatumNode<T>> subnodes;
@ -56,11 +59,11 @@ public class RecalDatumNode<T extends RecalDatum> {
if ( fixedPenalty != UNINITIALIZED ) if ( fixedPenalty != UNINITIALIZED )
return fixedPenalty; return fixedPenalty;
else else
return calcPenalty(recalDatum.getEmpiricalErrorRate()); return calcPenalty();
} }
public double calcAndSetFixedPenalty(final boolean doEntireTree) { public double calcAndSetFixedPenalty(final boolean doEntireTree) {
fixedPenalty = calcPenalty(recalDatum.getEmpiricalErrorRate()); fixedPenalty = calcPenalty();
if ( doEntireTree ) if ( doEntireTree )
for ( final RecalDatumNode<T> sub : subnodes ) for ( final RecalDatumNode<T> sub : subnodes )
sub.calcAndSetFixedPenalty(doEntireTree); sub.calcAndSetFixedPenalty(doEntireTree);
@ -79,14 +82,23 @@ public class RecalDatumNode<T extends RecalDatum> {
return subnodes.size(); return subnodes.size();
} }
public double getMinNodePenalty() { /**
* Total penalty is the sum of leaf node penalties
*
* This algorithm assumes that penalties have been fixed before pruning, as leaf nodes by
* definition have 0 penalty unless they represent a pruned tree with underlying -- but now
* pruned -- subtrees
*
* @return
*/
public double totalPenalty() {
if ( isLeaf() ) if ( isLeaf() )
return Double.MAX_VALUE; return getPenalty();
else { else {
double minPenalty = getPenalty(); double sum = 0.0;
for ( final RecalDatumNode<T> sub : subnodes ) for ( final RecalDatumNode<T> sub : subnodes )
minPenalty = Math.min(minPenalty, sub.getMinNodePenalty()); sum += sub.totalPenalty();
return minPenalty; return sum;
} }
} }
@ -97,6 +109,17 @@ public class RecalDatumNode<T extends RecalDatum> {
return subMax + 1; return subMax + 1;
} }
public int minDepth() {
if ( isLeaf() )
return 1;
else {
int subMin = Integer.MAX_VALUE;
for ( final RecalDatumNode<T> sub : subnodes )
subMin = Math.min(subMin, sub.minDepth());
return subMin + 1;
}
}
public int size() { public int size() {
int size = 1; int size = 1;
for ( final RecalDatumNode<T> sub : subnodes ) for ( final RecalDatumNode<T> sub : subnodes )
@ -104,6 +127,58 @@ public class RecalDatumNode<T extends RecalDatum> {
return size; return size;
} }
public int numLeaves() {
if ( isLeaf() )
return 1;
else {
int size = 0;
for ( final RecalDatumNode<T> sub : subnodes )
size += sub.numLeaves();
return size;
}
}
private double calcPenalty() {
if ( USE_CHI2 )
return calcPenaltyChi2();
else
return calcPenaltyLog10(getRecalDatum().getEmpiricalErrorRate());
}
private double calcPenaltyChi2() {
if ( isLeaf() )
return 0.0;
else {
final long[][] counts = new long[subnodes.size()][2];
int i = 0;
for ( RecalDatumNode<T> subnode : subnodes ) {
counts[i][0] = subnode.getRecalDatum().getNumMismatches();
counts[i][1] = subnode.getRecalDatum().getNumObservations();
i++;
}
final double chi2 = new ChiSquareTestImpl().chiSquare(counts);
// StringBuilder x = new StringBuilder();
// StringBuilder y = new StringBuilder();
// for ( int k = 0; k < counts.length; k++) {
// if ( k != 0 ) {
// x.append(", ");
// y.append(", ");
// }
// x.append(counts[k][0]);
// y.append(counts[k][1]);
// }
// logger.info("x = c(" + x.toString() + ")");
// logger.info("y = c(" + y.toString() + ")");
// logger.info("chi2 = " + chi2);
return chi2;
//return Math.log10(chi2);
}
}
/** /**
* Calculate the penalty of this interval, given the overall error rate for the interval * Calculate the penalty of this interval, given the overall error rate for the interval
* *
@ -119,7 +194,7 @@ public class RecalDatumNode<T extends RecalDatum> {
*/ */
@Requires("globalErrorRate >= 0.0") @Requires("globalErrorRate >= 0.0")
@Ensures("result >= 0.0") @Ensures("result >= 0.0")
private double calcPenalty(final double globalErrorRate) { private double calcPenaltyLog10(final double globalErrorRate) {
if ( globalErrorRate == 0.0 ) // there were no observations, so there's no penalty if ( globalErrorRate == 0.0 ) // there were no observations, so there's no penalty
return 0.0; return 0.0;
@ -136,7 +211,7 @@ public class RecalDatumNode<T extends RecalDatum> {
} else { } else {
double sum = 0; double sum = 0;
for ( final RecalDatumNode<T> hrd : subnodes) for ( final RecalDatumNode<T> hrd : subnodes)
sum += hrd.calcPenalty(globalErrorRate); sum += hrd.calcPenaltyLog10(globalErrorRate);
return sum; return sum;
} }
} }
@ -173,17 +248,38 @@ public class RecalDatumNode<T extends RecalDatum> {
* @return * @return
*/ */
private RecalDatumNode<T> removeLowestPenaltyNode() { private RecalDatumNode<T> removeLowestPenaltyNode() {
final RecalDatumNode<T> oneRemoved = removeFirstNodeWithPenalty(getMinNodePenalty()).getFirst(); final Pair<RecalDatumNode<T>, Double> nodeToRemove = getMinPenaltyNode();
logger.info("Removing " + nodeToRemove.getFirst() + " with penalty " + nodeToRemove.getSecond());
final Pair<RecalDatumNode<T>, Boolean> result = removeNode(nodeToRemove.getFirst());
if ( ! result.getSecond() )
throw new IllegalStateException("Never removed any node!");
final RecalDatumNode<T> oneRemoved = result.getFirst();
if ( oneRemoved == null ) if ( oneRemoved == null )
throw new IllegalStateException("Removed our root node, wow, didn't expect that"); throw new IllegalStateException("Removed our root node, wow, didn't expect that");
return oneRemoved; return oneRemoved;
} }
private Pair<RecalDatumNode<T>, Boolean> removeFirstNodeWithPenalty(final double penaltyToRemove) { private Pair<RecalDatumNode<T>, Double> getMinPenaltyNode() {
if ( getPenalty() == penaltyToRemove ) { final double myValue = isLeaf() ? Double.MAX_VALUE : getPenalty();
logger.info("Removing " + this + " with penalty " + penaltyToRemove); Pair<RecalDatumNode<T>, Double> maxNode = new Pair<RecalDatumNode<T>, Double>(this, myValue);
for ( final RecalDatumNode<T> sub : subnodes ) {
final Pair<RecalDatumNode<T>, Double> subFind = sub.getMinPenaltyNode();
if ( subFind.getSecond() < maxNode.getSecond() ) {
maxNode = subFind;
}
}
return maxNode;
}
private Pair<RecalDatumNode<T>, Boolean> removeNode(final RecalDatumNode<T> nodeToRemove) {
if ( this == nodeToRemove ) {
if ( isLeaf() ) if ( isLeaf() )
throw new IllegalStateException("Trying to remove a leaf node from the tree! " + this + " " + penaltyToRemove); throw new IllegalStateException("Trying to remove a leaf node from the tree! " + this + " " + nodeToRemove);
// node is the thing we are going to remove, but without any subnodes // node is the thing we are going to remove, but without any subnodes
final RecalDatumNode<T> node = new RecalDatumNode<T>(getRecalDatum(), fixedPenalty); final RecalDatumNode<T> node = new RecalDatumNode<T>(getRecalDatum(), fixedPenalty);
return new Pair<RecalDatumNode<T>, Boolean>(node, true); return new Pair<RecalDatumNode<T>, Boolean>(node, true);
@ -200,7 +296,7 @@ public class RecalDatumNode<T extends RecalDatum> {
sub.add(sub1); sub.add(sub1);
} else { } else {
// haven't removed anything yet, so try // haven't removed anything yet, so try
final Pair<RecalDatumNode<T>, Boolean> maybeRemoved = sub1.removeFirstNodeWithPenalty(penaltyToRemove); final Pair<RecalDatumNode<T>, Boolean> maybeRemoved = sub1.removeNode(nodeToRemove);
removedSomething = maybeRemoved.getSecond(); removedSomething = maybeRemoved.getSecond();
sub.add(maybeRemoved.getFirst()); sub.add(maybeRemoved.getFirst());
} }