Working version of adaptive context calculations

-- Uses a chi2 test for independence to determine if a subcontext is worth representing.  Gives excellent visual results
-- Writes out analysis output file producing excellent results in R
-- Trivial reformatting of MathUtils
This commit is contained in:
Mark DePristo 2012-07-30 15:44:33 -04:00
parent 93640b382e
commit 0c4e729e13
1 changed files with 112 additions and 16 deletions

View File

@ -2,7 +2,9 @@ package org.broadinstitute.sting.utils.recalibration;
import com.google.java.contract.Ensures;
import com.google.java.contract.Requires;
import org.apache.commons.math.stat.inference.ChiSquareTestImpl;
import org.apache.log4j.Logger;
import org.broadinstitute.sting.utils.MathUtils;
import org.broadinstitute.sting.utils.collections.Pair;
import java.util.HashSet;
@ -15,8 +17,9 @@ import java.util.Set;
* @since 07/27/12
*/
public class RecalDatumNode<T extends RecalDatum> {
private final static boolean USE_CHI2 = true;
protected static Logger logger = Logger.getLogger(RecalDatumNode.class);
private final static double UNINITIALIZED = -1.0;
private final static double UNINITIALIZED = Double.NEGATIVE_INFINITY;
private final T recalDatum;
private double fixedPenalty = UNINITIALIZED;
private final Set<RecalDatumNode<T>> subnodes;
@ -56,11 +59,11 @@ public class RecalDatumNode<T extends RecalDatum> {
if ( fixedPenalty != UNINITIALIZED )
return fixedPenalty;
else
return calcPenalty(recalDatum.getEmpiricalErrorRate());
return calcPenalty();
}
public double calcAndSetFixedPenalty(final boolean doEntireTree) {
fixedPenalty = calcPenalty(recalDatum.getEmpiricalErrorRate());
fixedPenalty = calcPenalty();
if ( doEntireTree )
for ( final RecalDatumNode<T> sub : subnodes )
sub.calcAndSetFixedPenalty(doEntireTree);
@ -79,14 +82,23 @@ public class RecalDatumNode<T extends RecalDatum> {
return subnodes.size();
}
public double getMinNodePenalty() {
/**
* Total penalty is the sum of leaf node penalties
*
* This algorithm assumes that penalties have been fixed before pruning, as leaf nodes by
* definition have 0 penalty unless they represent a pruned tree with underlying -- but now
* pruned -- subtrees
*
* @return the penalty of this node if it is a leaf, otherwise the sum of totalPenalty() over its subnodes
*/
public double totalPenalty() {
if ( isLeaf() )
return Double.MAX_VALUE;
return getPenalty();
else {
double minPenalty = getPenalty();
double sum = 0.0;
for ( final RecalDatumNode<T> sub : subnodes )
minPenalty = Math.min(minPenalty, sub.getMinNodePenalty());
return minPenalty;
sum += sub.totalPenalty();
return sum;
}
}
@ -97,6 +109,17 @@ public class RecalDatumNode<T extends RecalDatum> {
return subMax + 1;
}
/**
 * Returns the number of nodes on the shortest path from this node down to a leaf,
 * counting this node itself.
 *
 * @return 1 for a leaf, otherwise one more than the smallest minDepth() among the subnodes
 */
public int minDepth() {
    if ( isLeaf() ) {
        return 1;
    }

    // track the shallowest subtree seen so far
    int shallowest = Integer.MAX_VALUE;
    for ( final RecalDatumNode<T> child : subnodes ) {
        final int childDepth = child.minDepth();
        if ( childDepth < shallowest ) {
            shallowest = childDepth;
        }
    }
    return shallowest + 1;
}
public int size() {
int size = 1;
for ( final RecalDatumNode<T> sub : subnodes )
@ -104,6 +127,58 @@ public class RecalDatumNode<T extends RecalDatum> {
return size;
}
/**
 * Counts the leaf nodes in the subtree rooted at this node.
 *
 * @return 1 if this node is a leaf, otherwise the sum of numLeaves() over all subnodes
 */
public int numLeaves() {
    if ( isLeaf() ) {
        return 1;
    }

    int leafCount = 0;
    for ( final RecalDatumNode<T> child : subnodes ) {
        leafCount += child.numLeaves();
    }
    return leafCount;
}
/**
 * Computes the penalty of this node using the method selected by USE_CHI2.
 *
 * @return the result of calcPenaltyChi2() when USE_CHI2 is set, otherwise the result of
 *         calcPenaltyLog10() evaluated at this node's empirical error rate
 */
private double calcPenalty() {
    return USE_CHI2
            ? calcPenaltyChi2()
            : calcPenaltyLog10(getRecalDatum().getEmpiricalErrorRate());
}
/**
 * Calculates the penalty of this node as the chi-square statistic computed over the
 * counts of its immediate subnodes.
 *
 * Each subnode contributes one row to the table: its number of mismatches and its
 * number of observations.
 *
 * NOTE(review): the second column holds the total number of observations rather than the
 * number of matches (observations - mismatches) -- confirm this is the intended table.
 *
 * @return 0.0 for a leaf or a node with fewer than two subnodes, otherwise the chi-square
 *         statistic of the subnode counts
 */
private double calcPenaltyChi2() {
    // ChiSquareTestImpl.chiSquare(long[][]) rejects tables with fewer than two rows, so a
    // node with a single subnode has nothing to be compared against and carries no penalty
    if ( isLeaf() || subnodes.size() < 2 )
        return 0.0;

    final long[][] counts = new long[subnodes.size()][2];
    int row = 0;
    for ( final RecalDatumNode<T> subnode : subnodes ) {
        counts[row][0] = subnode.getRecalDatum().getNumMismatches();
        counts[row][1] = subnode.getRecalDatum().getNumObservations();
        row++;
    }

    return new ChiSquareTestImpl().chiSquare(counts);
}
/**
* Calculate the penalty of this interval, given the overall error rate for the interval
*
@ -119,7 +194,7 @@ public class RecalDatumNode<T extends RecalDatum> {
*/
@Requires("globalErrorRate >= 0.0")
@Ensures("result >= 0.0")
private double calcPenalty(final double globalErrorRate) {
private double calcPenaltyLog10(final double globalErrorRate) {
if ( globalErrorRate == 0.0 ) // there were no observations, so there's no penalty
return 0.0;
@ -136,7 +211,7 @@ public class RecalDatumNode<T extends RecalDatum> {
} else {
double sum = 0;
for ( final RecalDatumNode<T> hrd : subnodes)
sum += hrd.calcPenalty(globalErrorRate);
sum += hrd.calcPenaltyLog10(globalErrorRate);
return sum;
}
}
@ -173,17 +248,38 @@ public class RecalDatumNode<T extends RecalDatum> {
* @return
*/
private RecalDatumNode<T> removeLowestPenaltyNode() {
final RecalDatumNode<T> oneRemoved = removeFirstNodeWithPenalty(getMinNodePenalty()).getFirst();
final Pair<RecalDatumNode<T>, Double> nodeToRemove = getMinPenaltyNode();
logger.info("Removing " + nodeToRemove.getFirst() + " with penalty " + nodeToRemove.getSecond());
final Pair<RecalDatumNode<T>, Boolean> result = removeNode(nodeToRemove.getFirst());
if ( ! result.getSecond() )
throw new IllegalStateException("Never removed any node!");
final RecalDatumNode<T> oneRemoved = result.getFirst();
if ( oneRemoved == null )
throw new IllegalStateException("Removed our root node, wow, didn't expect that");
return oneRemoved;
}
private Pair<RecalDatumNode<T>, Boolean> removeFirstNodeWithPenalty(final double penaltyToRemove) {
if ( getPenalty() == penaltyToRemove ) {
logger.info("Removing " + this + " with penalty " + penaltyToRemove);
/**
 * Finds the node with the smallest penalty in the subtree rooted at this node.
 *
 * Leaf nodes are scored as Double.MAX_VALUE rather than their own penalty.  A node found
 * in a subtree only replaces the current best when its penalty is strictly smaller.
 *
 * @return a pair holding the lowest-penalty node and the penalty it was scored with
 */
private Pair<RecalDatumNode<T>, Double> getMinPenaltyNode() {
    final double ownPenalty;
    if ( isLeaf() )
        ownPenalty = Double.MAX_VALUE;
    else
        ownPenalty = getPenalty();

    // start with this node as the best candidate, then let each subtree try to beat it
    Pair<RecalDatumNode<T>, Double> best = new Pair<RecalDatumNode<T>, Double>(this, ownPenalty);
    for ( final RecalDatumNode<T> child : subnodes ) {
        final Pair<RecalDatumNode<T>, Double> candidate = child.getMinPenaltyNode();
        if ( candidate.getSecond() < best.getSecond() )
            best = candidate;
    }
    return best;
}
private Pair<RecalDatumNode<T>, Boolean> removeNode(final RecalDatumNode<T> nodeToRemove) {
if ( this == nodeToRemove ) {
if ( isLeaf() )
throw new IllegalStateException("Trying to remove a leaf node from the tree! " + this + " " + penaltyToRemove);
throw new IllegalStateException("Trying to remove a leaf node from the tree! " + this + " " + nodeToRemove);
// node is the thing we are going to remove, but without any subnodes
final RecalDatumNode<T> node = new RecalDatumNode<T>(getRecalDatum(), fixedPenalty);
return new Pair<RecalDatumNode<T>, Boolean>(node, true);
@ -200,7 +296,7 @@ public class RecalDatumNode<T extends RecalDatum> {
sub.add(sub1);
} else {
// haven't removed anything yet, so try
final Pair<RecalDatumNode<T>, Boolean> maybeRemoved = sub1.removeFirstNodeWithPenalty(penaltyToRemove);
final Pair<RecalDatumNode<T>, Boolean> maybeRemoved = sub1.removeNode(nodeToRemove);
removedSomething = maybeRemoved.getSecond();
sub.add(maybeRemoved.getFirst());
}