Preliminary version of adaptive context covariate algorithm
-- Works according to visual inspection of output tree
This commit is contained in:
parent
315d25409f
commit
93640b382e
|
|
@ -0,0 +1,213 @@
|
||||||
|
package org.broadinstitute.sting.utils.recalibration;
|
||||||
|
|
||||||
|
import com.google.java.contract.Ensures;
|
||||||
|
import com.google.java.contract.Requires;
|
||||||
|
import org.apache.log4j.Logger;
|
||||||
|
import org.broadinstitute.sting.utils.collections.Pair;
|
||||||
|
|
||||||
|
import java.util.HashSet;
|
||||||
|
import java.util.Set;
|
||||||
|
|
||||||
|
/**
|
||||||
|
* A tree of recal datum, where each contains a set of sub datum representing sub-states of the higher level one
|
||||||
|
*
|
||||||
|
* @author Mark DePristo
|
||||||
|
* @since 07/27/12
|
||||||
|
*/
|
||||||
|
public class RecalDatumNode<T extends RecalDatum> {
|
||||||
|
protected static Logger logger = Logger.getLogger(RecalDatumNode.class);
|
||||||
|
private final static double UNINITIALIZED = -1.0;
|
||||||
|
private final T recalDatum;
|
||||||
|
private double fixedPenalty = UNINITIALIZED;
|
||||||
|
private final Set<RecalDatumNode<T>> subnodes;
|
||||||
|
|
||||||
|
public RecalDatumNode(final T recalDatum) {
|
||||||
|
this(recalDatum, new HashSet<RecalDatumNode<T>>());
|
||||||
|
}
|
||||||
|
|
||||||
|
@Override
|
||||||
|
public String toString() {
|
||||||
|
return recalDatum.toString();
|
||||||
|
}
|
||||||
|
|
||||||
|
public RecalDatumNode(final T recalDatum, final Set<RecalDatumNode<T>> subnodes) {
|
||||||
|
this(recalDatum, UNINITIALIZED, subnodes);
|
||||||
|
}
|
||||||
|
|
||||||
|
protected RecalDatumNode(final T recalDatum, final double fixedPenalty) {
|
||||||
|
this(recalDatum, fixedPenalty, new HashSet<RecalDatumNode<T>>());
|
||||||
|
}
|
||||||
|
|
||||||
|
protected RecalDatumNode(final T recalDatum, final double fixedPenalty, final Set<RecalDatumNode<T>> subnodes) {
|
||||||
|
this.recalDatum = recalDatum;
|
||||||
|
this.fixedPenalty = fixedPenalty;
|
||||||
|
this.subnodes = new HashSet<RecalDatumNode<T>>(subnodes);
|
||||||
|
}
|
||||||
|
|
||||||
|
public T getRecalDatum() {
|
||||||
|
return recalDatum;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Set<RecalDatumNode<T>> getSubnodes() {
|
||||||
|
return subnodes;
|
||||||
|
}
|
||||||
|
|
||||||
|
public double getPenalty() {
|
||||||
|
if ( fixedPenalty != UNINITIALIZED )
|
||||||
|
return fixedPenalty;
|
||||||
|
else
|
||||||
|
return calcPenalty(recalDatum.getEmpiricalErrorRate());
|
||||||
|
}
|
||||||
|
|
||||||
|
public double calcAndSetFixedPenalty(final boolean doEntireTree) {
|
||||||
|
fixedPenalty = calcPenalty(recalDatum.getEmpiricalErrorRate());
|
||||||
|
if ( doEntireTree )
|
||||||
|
for ( final RecalDatumNode<T> sub : subnodes )
|
||||||
|
sub.calcAndSetFixedPenalty(doEntireTree);
|
||||||
|
return fixedPenalty;
|
||||||
|
}
|
||||||
|
|
||||||
|
public void addSubnode(final RecalDatumNode<T> sub) {
|
||||||
|
subnodes.add(sub);
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean isLeaf() {
|
||||||
|
return subnodes.isEmpty();
|
||||||
|
}
|
||||||
|
|
||||||
|
public int getNumBranches() {
|
||||||
|
return subnodes.size();
|
||||||
|
}
|
||||||
|
|
||||||
|
public double getMinNodePenalty() {
|
||||||
|
if ( isLeaf() )
|
||||||
|
return Double.MAX_VALUE;
|
||||||
|
else {
|
||||||
|
double minPenalty = getPenalty();
|
||||||
|
for ( final RecalDatumNode<T> sub : subnodes )
|
||||||
|
minPenalty = Math.min(minPenalty, sub.getMinNodePenalty());
|
||||||
|
return minPenalty;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public int maxDepth() {
|
||||||
|
int subMax = 0;
|
||||||
|
for ( final RecalDatumNode<T> sub : subnodes )
|
||||||
|
subMax = Math.max(subMax, sub.maxDepth());
|
||||||
|
return subMax + 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
public int size() {
|
||||||
|
int size = 1;
|
||||||
|
for ( final RecalDatumNode<T> sub : subnodes )
|
||||||
|
size += sub.size();
|
||||||
|
return size;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Calculate the penalty of this interval, given the overall error rate for the interval
|
||||||
|
*
|
||||||
|
* If the globalErrorRate is e, this value is:
|
||||||
|
*
|
||||||
|
* sum_i |log10(e_i) - log10(e)| * nObservations_i
|
||||||
|
*
|
||||||
|
* each the index i applies to all leaves of the tree accessible from this interval
|
||||||
|
* (found recursively from subnodes as necessary)
|
||||||
|
*
|
||||||
|
* @param globalErrorRate overall error rate in real space against which we calculate the penalty
|
||||||
|
* @return the cost of approximating the bins in this interval with the globalErrorRate
|
||||||
|
*/
|
||||||
|
@Requires("globalErrorRate >= 0.0")
|
||||||
|
@Ensures("result >= 0.0")
|
||||||
|
private double calcPenalty(final double globalErrorRate) {
|
||||||
|
if ( globalErrorRate == 0.0 ) // there were no observations, so there's no penalty
|
||||||
|
return 0.0;
|
||||||
|
|
||||||
|
if ( isLeaf() ) {
|
||||||
|
// this is leave node
|
||||||
|
return (Math.abs(Math.log10(recalDatum.getEmpiricalErrorRate()) - Math.log10(globalErrorRate))) * recalDatum.getNumObservations();
|
||||||
|
// TODO -- how we can generalize this calculation?
|
||||||
|
// if ( this.qEnd <= minInterestingQual )
|
||||||
|
// // It's free to merge up quality scores below the smallest interesting one
|
||||||
|
// return 0;
|
||||||
|
// else {
|
||||||
|
// return (Math.abs(Math.log10(getEmpiricalErrorRate()) - Math.log10(globalErrorRate))) * getNumObservations();
|
||||||
|
// }
|
||||||
|
} else {
|
||||||
|
double sum = 0;
|
||||||
|
for ( final RecalDatumNode<T> hrd : subnodes)
|
||||||
|
sum += hrd.calcPenalty(globalErrorRate);
|
||||||
|
return sum;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public RecalDatumNode<T> pruneToDepth(final int maxDepth) {
|
||||||
|
if ( maxDepth < 1 )
|
||||||
|
throw new IllegalArgumentException("maxDepth < 1");
|
||||||
|
else {
|
||||||
|
final Set<RecalDatumNode<T>> subPruned = new HashSet<RecalDatumNode<T>>(getNumBranches());
|
||||||
|
if ( maxDepth > 1 )
|
||||||
|
for ( final RecalDatumNode<T> sub : subnodes )
|
||||||
|
subPruned.add(sub.pruneToDepth(maxDepth - 1));
|
||||||
|
return new RecalDatumNode<T>(getRecalDatum(), fixedPenalty, subPruned);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public RecalDatumNode<T> pruneByPenalty(final int maxElements) {
|
||||||
|
RecalDatumNode<T> root = this;
|
||||||
|
|
||||||
|
while ( root.size() > maxElements ) {
|
||||||
|
// remove the lowest penalty element, and continue
|
||||||
|
root = root.removeLowestPenaltyNode();
|
||||||
|
}
|
||||||
|
|
||||||
|
// our size is below the target, so we are good, return
|
||||||
|
return root;
|
||||||
|
}
|
||||||
|
|
||||||
|
/**
|
||||||
|
* Find the lowest penalty node in the tree, and return a tree without it
|
||||||
|
*
|
||||||
|
* Note this excludes the current (root) node
|
||||||
|
*
|
||||||
|
* @return
|
||||||
|
*/
|
||||||
|
private RecalDatumNode<T> removeLowestPenaltyNode() {
|
||||||
|
final RecalDatumNode<T> oneRemoved = removeFirstNodeWithPenalty(getMinNodePenalty()).getFirst();
|
||||||
|
if ( oneRemoved == null )
|
||||||
|
throw new IllegalStateException("Removed our root node, wow, didn't expect that");
|
||||||
|
return oneRemoved;
|
||||||
|
}
|
||||||
|
|
||||||
|
private Pair<RecalDatumNode<T>, Boolean> removeFirstNodeWithPenalty(final double penaltyToRemove) {
|
||||||
|
if ( getPenalty() == penaltyToRemove ) {
|
||||||
|
logger.info("Removing " + this + " with penalty " + penaltyToRemove);
|
||||||
|
if ( isLeaf() )
|
||||||
|
throw new IllegalStateException("Trying to remove a leaf node from the tree! " + this + " " + penaltyToRemove);
|
||||||
|
// node is the thing we are going to remove, but without any subnodes
|
||||||
|
final RecalDatumNode<T> node = new RecalDatumNode<T>(getRecalDatum(), fixedPenalty);
|
||||||
|
return new Pair<RecalDatumNode<T>, Boolean>(node, true);
|
||||||
|
} else {
|
||||||
|
// did we remove something in a sub branch?
|
||||||
|
boolean removedSomething = false;
|
||||||
|
|
||||||
|
// our sub nodes with the penalty node removed
|
||||||
|
final Set<RecalDatumNode<T>> sub = new HashSet<RecalDatumNode<T>>(getNumBranches());
|
||||||
|
|
||||||
|
for ( final RecalDatumNode<T> sub1 : subnodes ) {
|
||||||
|
if ( removedSomething ) {
|
||||||
|
// already removed something, just add sub1 back to sub
|
||||||
|
sub.add(sub1);
|
||||||
|
} else {
|
||||||
|
// haven't removed anything yet, so try
|
||||||
|
final Pair<RecalDatumNode<T>, Boolean> maybeRemoved = sub1.removeFirstNodeWithPenalty(penaltyToRemove);
|
||||||
|
removedSomething = maybeRemoved.getSecond();
|
||||||
|
sub.add(maybeRemoved.getFirst());
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
final RecalDatumNode<T> node = new RecalDatumNode<T>(getRecalDatum(), fixedPenalty, sub);
|
||||||
|
return new Pair<RecalDatumNode<T>, Boolean>(node, removedSomething);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
@ -1,76 +0,0 @@
|
||||||
package org.broadinstitute.sting.utils.recalibration;
|
|
||||||
|
|
||||||
import com.google.java.contract.Ensures;
|
|
||||||
import com.google.java.contract.Requires;
|
|
||||||
|
|
||||||
import java.util.Collections;
|
|
||||||
import java.util.HashSet;
|
|
||||||
import java.util.Set;
|
|
||||||
|
|
||||||
/**
|
|
||||||
* A tree of recal datum, where each contains a set of sub datum representing sub-states of the higher level one
|
|
||||||
*
|
|
||||||
* @author Mark DePristo
|
|
||||||
* @since 07/27/12
|
|
||||||
*/
|
|
||||||
public class RecalDatumTree extends RecalDatum {
|
|
||||||
final Set<RecalDatumTree> subnodes;
|
|
||||||
|
|
||||||
protected RecalDatumTree(final long nObservations, final long nErrors, final byte reportedQual) {
|
|
||||||
this(nObservations, nErrors, reportedQual, new HashSet<RecalDatumTree>());
|
|
||||||
}
|
|
||||||
|
|
||||||
public RecalDatumTree(final long nObservations, final long nErrors, final byte reportedQual, final Set<RecalDatumTree> subnodes) {
|
|
||||||
super(nObservations, nErrors, reportedQual);
|
|
||||||
this.subnodes = new HashSet<RecalDatumTree>(subnodes);
|
|
||||||
}
|
|
||||||
|
|
||||||
public double getPenalty() {
|
|
||||||
return calcPenalty(getEmpiricalErrorRate());
|
|
||||||
}
|
|
||||||
|
|
||||||
public void addSubnode(final RecalDatumTree sub) {
|
|
||||||
subnodes.add(sub);
|
|
||||||
}
|
|
||||||
|
|
||||||
public boolean isLeaf() {
|
|
||||||
return subnodes.isEmpty();
|
|
||||||
}
|
|
||||||
|
|
||||||
/**
|
|
||||||
* Calculate the penalty of this interval, given the overall error rate for the interval
|
|
||||||
*
|
|
||||||
* If the globalErrorRate is e, this value is:
|
|
||||||
*
|
|
||||||
* sum_i |log10(e_i) - log10(e)| * nObservations_i
|
|
||||||
*
|
|
||||||
* each the index i applies to all leaves of the tree accessible from this interval
|
|
||||||
* (found recursively from subnodes as necessary)
|
|
||||||
*
|
|
||||||
* @param globalErrorRate overall error rate in real space against which we calculate the penalty
|
|
||||||
* @return the cost of approximating the bins in this interval with the globalErrorRate
|
|
||||||
*/
|
|
||||||
@Requires("globalErrorRate >= 0.0")
|
|
||||||
@Ensures("result >= 0.0")
|
|
||||||
private double calcPenalty(final double globalErrorRate) {
|
|
||||||
if ( globalErrorRate == 0.0 ) // there were no observations, so there's no penalty
|
|
||||||
return 0.0;
|
|
||||||
|
|
||||||
if ( isLeaf() ) {
|
|
||||||
// this is leave node
|
|
||||||
return (Math.abs(Math.log10(getEmpiricalErrorRate()) - Math.log10(globalErrorRate))) * getNumObservations();
|
|
||||||
// TODO -- how we can generalize this calculation?
|
|
||||||
// if ( this.qEnd <= minInterestingQual )
|
|
||||||
// // It's free to merge up quality scores below the smallest interesting one
|
|
||||||
// return 0;
|
|
||||||
// else {
|
|
||||||
// return (Math.abs(Math.log10(getEmpiricalErrorRate()) - Math.log10(globalErrorRate))) * getNumObservations();
|
|
||||||
// }
|
|
||||||
} else {
|
|
||||||
double sum = 0;
|
|
||||||
for ( final RecalDatumTree hrd : subnodes)
|
|
||||||
sum += hrd.calcPenalty(globalErrorRate);
|
|
||||||
return sum;
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
Loading…
Reference in New Issue