Working version of adaptive context calculations
-- Uses chi2 test for independence to determine if subcontext is worth representing. Gives excellent visual results -- Writes out analysis output file producing excellent results in R -- Trivial reformatting of MathUtils
This commit is contained in:
parent
93640b382e
commit
0c4e729e13
|
|
@ -2,7 +2,9 @@ package org.broadinstitute.sting.utils.recalibration;
|
||||||
|
|
||||||
import com.google.java.contract.Ensures;
|
import com.google.java.contract.Ensures;
|
||||||
import com.google.java.contract.Requires;
|
import com.google.java.contract.Requires;
|
||||||
|
import org.apache.commons.math.stat.inference.ChiSquareTestImpl;
|
||||||
import org.apache.log4j.Logger;
|
import org.apache.log4j.Logger;
|
||||||
|
import org.broadinstitute.sting.utils.MathUtils;
|
||||||
import org.broadinstitute.sting.utils.collections.Pair;
|
import org.broadinstitute.sting.utils.collections.Pair;
|
||||||
|
|
||||||
import java.util.HashSet;
|
import java.util.HashSet;
|
||||||
|
|
@ -15,8 +17,9 @@ import java.util.Set;
|
||||||
* @since 07/27/12
|
* @since 07/27/12
|
||||||
*/
|
*/
|
||||||
public class RecalDatumNode<T extends RecalDatum> {
|
public class RecalDatumNode<T extends RecalDatum> {
|
||||||
|
private final static boolean USE_CHI2 = true;
|
||||||
protected static Logger logger = Logger.getLogger(RecalDatumNode.class);
|
protected static Logger logger = Logger.getLogger(RecalDatumNode.class);
|
||||||
private final static double UNINITIALIZED = -1.0;
|
private final static double UNINITIALIZED = Double.NEGATIVE_INFINITY;
|
||||||
private final T recalDatum;
|
private final T recalDatum;
|
||||||
private double fixedPenalty = UNINITIALIZED;
|
private double fixedPenalty = UNINITIALIZED;
|
||||||
private final Set<RecalDatumNode<T>> subnodes;
|
private final Set<RecalDatumNode<T>> subnodes;
|
||||||
|
|
@ -56,11 +59,11 @@ public class RecalDatumNode<T extends RecalDatum> {
|
||||||
if ( fixedPenalty != UNINITIALIZED )
|
if ( fixedPenalty != UNINITIALIZED )
|
||||||
return fixedPenalty;
|
return fixedPenalty;
|
||||||
else
|
else
|
||||||
return calcPenalty(recalDatum.getEmpiricalErrorRate());
|
return calcPenalty();
|
||||||
}
|
}
|
||||||
|
|
||||||
public double calcAndSetFixedPenalty(final boolean doEntireTree) {
|
public double calcAndSetFixedPenalty(final boolean doEntireTree) {
|
||||||
fixedPenalty = calcPenalty(recalDatum.getEmpiricalErrorRate());
|
fixedPenalty = calcPenalty();
|
||||||
if ( doEntireTree )
|
if ( doEntireTree )
|
||||||
for ( final RecalDatumNode<T> sub : subnodes )
|
for ( final RecalDatumNode<T> sub : subnodes )
|
||||||
sub.calcAndSetFixedPenalty(doEntireTree);
|
sub.calcAndSetFixedPenalty(doEntireTree);
|
||||||
|
|
@ -79,14 +82,23 @@ public class RecalDatumNode<T extends RecalDatum> {
|
||||||
return subnodes.size();
|
return subnodes.size();
|
||||||
}
|
}
|
||||||
|
|
||||||
public double getMinNodePenalty() {
|
/**
|
||||||
|
* Total penalty is the sum of leaf node penalties
|
||||||
|
*
|
||||||
|
* This algorithm assumes that penalties have been fixed before pruning, as leaf nodes by
|
||||||
|
* definition have 0 penalty unless they represent a pruned tree with underlying -- but now
|
||||||
|
* pruned -- subtrees
|
||||||
|
*
|
||||||
|
* @return the sum of the penalties of all leaf nodes at or below this node
|
||||||
|
*/
|
||||||
|
public double totalPenalty() {
|
||||||
if ( isLeaf() )
|
if ( isLeaf() )
|
||||||
return Double.MAX_VALUE;
|
return getPenalty();
|
||||||
else {
|
else {
|
||||||
double minPenalty = getPenalty();
|
double sum = 0.0;
|
||||||
for ( final RecalDatumNode<T> sub : subnodes )
|
for ( final RecalDatumNode<T> sub : subnodes )
|
||||||
minPenalty = Math.min(minPenalty, sub.getMinNodePenalty());
|
sum += sub.totalPenalty();
|
||||||
return minPenalty;
|
return sum;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
@ -97,6 +109,17 @@ public class RecalDatumNode<T extends RecalDatum> {
|
||||||
return subMax + 1;
|
return subMax + 1;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public int minDepth() {
|
||||||
|
if ( isLeaf() )
|
||||||
|
return 1;
|
||||||
|
else {
|
||||||
|
int subMin = Integer.MAX_VALUE;
|
||||||
|
for ( final RecalDatumNode<T> sub : subnodes )
|
||||||
|
subMin = Math.min(subMin, sub.minDepth());
|
||||||
|
return subMin + 1;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public int size() {
|
public int size() {
|
||||||
int size = 1;
|
int size = 1;
|
||||||
for ( final RecalDatumNode<T> sub : subnodes )
|
for ( final RecalDatumNode<T> sub : subnodes )
|
||||||
|
|
@ -104,6 +127,58 @@ public class RecalDatumNode<T extends RecalDatum> {
|
||||||
return size;
|
return size;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
public int numLeaves() {
|
||||||
|
if ( isLeaf() )
|
||||||
|
return 1;
|
||||||
|
else {
|
||||||
|
int size = 0;
|
||||||
|
for ( final RecalDatumNode<T> sub : subnodes )
|
||||||
|
size += sub.numLeaves();
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
private double calcPenalty() {
|
||||||
|
if ( USE_CHI2 )
|
||||||
|
return calcPenaltyChi2();
|
||||||
|
else
|
||||||
|
return calcPenaltyLog10(getRecalDatum().getEmpiricalErrorRate());
|
||||||
|
}
|
||||||
|
|
||||||
|
private double calcPenaltyChi2() {
|
||||||
|
if ( isLeaf() )
|
||||||
|
return 0.0;
|
||||||
|
else {
|
||||||
|
final long[][] counts = new long[subnodes.size()][2];
|
||||||
|
|
||||||
|
int i = 0;
|
||||||
|
for ( RecalDatumNode<T> subnode : subnodes ) {
|
||||||
|
counts[i][0] = subnode.getRecalDatum().getNumMismatches();
|
||||||
|
counts[i][1] = subnode.getRecalDatum().getNumObservations();
|
||||||
|
i++;
|
||||||
|
}
|
||||||
|
|
||||||
|
final double chi2 = new ChiSquareTestImpl().chiSquare(counts);
|
||||||
|
|
||||||
|
// StringBuilder x = new StringBuilder();
|
||||||
|
// StringBuilder y = new StringBuilder();
|
||||||
|
// for ( int k = 0; k < counts.length; k++) {
|
||||||
|
// if ( k != 0 ) {
|
||||||
|
// x.append(", ");
|
||||||
|
// y.append(", ");
|
||||||
|
// }
|
||||||
|
// x.append(counts[k][0]);
|
||||||
|
// y.append(counts[k][1]);
|
||||||
|
// }
|
||||||
|
// logger.info("x = c(" + x.toString() + ")");
|
||||||
|
// logger.info("y = c(" + y.toString() + ")");
|
||||||
|
// logger.info("chi2 = " + chi2);
|
||||||
|
|
||||||
|
return chi2;
|
||||||
|
//return Math.log10(chi2);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
/**
|
/**
|
||||||
* Calculate the penalty of this interval, given the overall error rate for the interval
|
* Calculate the penalty of this interval, given the overall error rate for the interval
|
||||||
*
|
*
|
||||||
|
|
@ -119,7 +194,7 @@ public class RecalDatumNode<T extends RecalDatum> {
|
||||||
*/
|
*/
|
||||||
@Requires("globalErrorRate >= 0.0")
|
@Requires("globalErrorRate >= 0.0")
|
||||||
@Ensures("result >= 0.0")
|
@Ensures("result >= 0.0")
|
||||||
private double calcPenalty(final double globalErrorRate) {
|
private double calcPenaltyLog10(final double globalErrorRate) {
|
||||||
if ( globalErrorRate == 0.0 ) // there were no observations, so there's no penalty
|
if ( globalErrorRate == 0.0 ) // there were no observations, so there's no penalty
|
||||||
return 0.0;
|
return 0.0;
|
||||||
|
|
||||||
|
|
@ -136,7 +211,7 @@ public class RecalDatumNode<T extends RecalDatum> {
|
||||||
} else {
|
} else {
|
||||||
double sum = 0;
|
double sum = 0;
|
||||||
for ( final RecalDatumNode<T> hrd : subnodes)
|
for ( final RecalDatumNode<T> hrd : subnodes)
|
||||||
sum += hrd.calcPenalty(globalErrorRate);
|
sum += hrd.calcPenaltyLog10(globalErrorRate);
|
||||||
return sum;
|
return sum;
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
@ -173,17 +248,38 @@ public class RecalDatumNode<T extends RecalDatum> {
|
||||||
* @return
|
* @return
|
||||||
*/
|
*/
|
||||||
private RecalDatumNode<T> removeLowestPenaltyNode() {
|
private RecalDatumNode<T> removeLowestPenaltyNode() {
|
||||||
final RecalDatumNode<T> oneRemoved = removeFirstNodeWithPenalty(getMinNodePenalty()).getFirst();
|
final Pair<RecalDatumNode<T>, Double> nodeToRemove = getMinPenaltyNode();
|
||||||
|
logger.info("Removing " + nodeToRemove.getFirst() + " with penalty " + nodeToRemove.getSecond());
|
||||||
|
|
||||||
|
final Pair<RecalDatumNode<T>, Boolean> result = removeNode(nodeToRemove.getFirst());
|
||||||
|
|
||||||
|
if ( ! result.getSecond() )
|
||||||
|
throw new IllegalStateException("Never removed any node!");
|
||||||
|
|
||||||
|
final RecalDatumNode<T> oneRemoved = result.getFirst();
|
||||||
if ( oneRemoved == null )
|
if ( oneRemoved == null )
|
||||||
throw new IllegalStateException("Removed our root node, wow, didn't expect that");
|
throw new IllegalStateException("Removed our root node, wow, didn't expect that");
|
||||||
return oneRemoved;
|
return oneRemoved;
|
||||||
}
|
}
|
||||||
|
|
||||||
private Pair<RecalDatumNode<T>, Boolean> removeFirstNodeWithPenalty(final double penaltyToRemove) {
|
private Pair<RecalDatumNode<T>, Double> getMinPenaltyNode() {
|
||||||
if ( getPenalty() == penaltyToRemove ) {
|
final double myValue = isLeaf() ? Double.MAX_VALUE : getPenalty();
|
||||||
logger.info("Removing " + this + " with penalty " + penaltyToRemove);
|
Pair<RecalDatumNode<T>, Double> maxNode = new Pair<RecalDatumNode<T>, Double>(this, myValue);
|
||||||
|
|
||||||
|
for ( final RecalDatumNode<T> sub : subnodes ) {
|
||||||
|
final Pair<RecalDatumNode<T>, Double> subFind = sub.getMinPenaltyNode();
|
||||||
|
if ( subFind.getSecond() < maxNode.getSecond() ) {
|
||||||
|
maxNode = subFind;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return maxNode;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Pair<RecalDatumNode<T>, Boolean> removeNode(final RecalDatumNode<T> nodeToRemove) {
|
||||||
|
if ( this == nodeToRemove ) {
|
||||||
if ( isLeaf() )
|
if ( isLeaf() )
|
||||||
throw new IllegalStateException("Trying to remove a leaf node from the tree! " + this + " " + penaltyToRemove);
|
throw new IllegalStateException("Trying to remove a leaf node from the tree! " + this + " " + nodeToRemove);
|
||||||
// node is the thing we are going to remove, but without any subnodes
|
// node is the thing we are going to remove, but without any subnodes
|
||||||
final RecalDatumNode<T> node = new RecalDatumNode<T>(getRecalDatum(), fixedPenalty);
|
final RecalDatumNode<T> node = new RecalDatumNode<T>(getRecalDatum(), fixedPenalty);
|
||||||
return new Pair<RecalDatumNode<T>, Boolean>(node, true);
|
return new Pair<RecalDatumNode<T>, Boolean>(node, true);
|
||||||
|
|
@ -200,7 +296,7 @@ public class RecalDatumNode<T extends RecalDatum> {
|
||||||
sub.add(sub1);
|
sub.add(sub1);
|
||||||
} else {
|
} else {
|
||||||
// haven't removed anything yet, so try
|
// haven't removed anything yet, so try
|
||||||
final Pair<RecalDatumNode<T>, Boolean> maybeRemoved = sub1.removeFirstNodeWithPenalty(penaltyToRemove);
|
final Pair<RecalDatumNode<T>, Boolean> maybeRemoved = sub1.removeNode(nodeToRemove);
|
||||||
removedSomething = maybeRemoved.getSecond();
|
removedSomething = maybeRemoved.getSecond();
|
||||||
sub.add(maybeRemoved.getFirst());
|
sub.add(maybeRemoved.getFirst());
|
||||||
}
|
}
|
||||||
|
|
|
||||||
Loading…
Reference in New Issue