From 93640b382ebe79d269ada3c006530d8a3f49330f Mon Sep 17 00:00:00 2001
From: Mark DePristo
Date: Mon, 30 Jul 2012 08:31:38 -0400
Subject: [PATCH] Preliminary version of adaptive context covariate algorithm

-- Works according to visual inspection of output tree
---
 .../utils/recalibration/RecalDatumNode.java | 213 ++++++++++++++++++
 .../utils/recalibration/RecalDatumTree.java | 76 -------
 2 files changed, 213 insertions(+), 76 deletions(-)
 create mode 100644 public/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatumNode.java
 delete mode 100644 public/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatumTree.java

diff --git a/public/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatumNode.java b/public/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatumNode.java
new file mode 100644
index 000000000..62ea67d7c
--- /dev/null
+++ b/public/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatumNode.java
@@ -0,0 +1,213 @@
+package org.broadinstitute.sting.utils.recalibration;
+
+import com.google.java.contract.Ensures;
+import com.google.java.contract.Requires;
+import org.apache.log4j.Logger;
+import org.broadinstitute.sting.utils.collections.Pair;
+
+import java.util.HashSet;
+import java.util.Set;
+
+/**
+ * A tree of recal datum, where each contains a set of sub datum representing sub-states of the higher level one
+ *
+ * @author Mark DePristo
+ * @since 07/27/12
+ */
+public class RecalDatumNode<T extends RecalDatum> {
+    protected static Logger logger = Logger.getLogger(RecalDatumNode.class);
+    private final static double UNINITIALIZED = -1.0;
+    private final T recalDatum;
+    private double fixedPenalty = UNINITIALIZED;
+    private final Set<RecalDatumNode<T>> subnodes;
+
+    public RecalDatumNode(final T recalDatum) {
+        this(recalDatum, new HashSet<RecalDatumNode<T>>());
+    }
+
+    @Override
+    public String toString() {
+        return recalDatum.toString();
+    }
+
+    public RecalDatumNode(final T recalDatum, final Set<RecalDatumNode<T>> subnodes) {
+        this(recalDatum, UNINITIALIZED, subnodes);
+    }
+
+    protected RecalDatumNode(final T recalDatum, final double fixedPenalty) {
+        this(recalDatum, fixedPenalty, new HashSet<RecalDatumNode<T>>());
+    }
+
+    protected RecalDatumNode(final T recalDatum, final double fixedPenalty, final Set<RecalDatumNode<T>> subnodes) {
+        this.recalDatum = recalDatum;
+        this.fixedPenalty = fixedPenalty;
+        this.subnodes = new HashSet<RecalDatumNode<T>>(subnodes);
+    }
+
+    public T getRecalDatum() {
+        return recalDatum;
+    }
+
+    public Set<RecalDatumNode<T>> getSubnodes() {
+        return subnodes;
+    }
+
+    public double getPenalty() {
+        if ( fixedPenalty != UNINITIALIZED )
+            return fixedPenalty;
+        else
+            return calcPenalty(recalDatum.getEmpiricalErrorRate());
+    }
+
+    public double calcAndSetFixedPenalty(final boolean doEntireTree) {
+        fixedPenalty = calcPenalty(recalDatum.getEmpiricalErrorRate());
+        if ( doEntireTree )
+            for ( final RecalDatumNode<T> sub : subnodes )
+                sub.calcAndSetFixedPenalty(doEntireTree);
+        return fixedPenalty;
+    }
+
+    public void addSubnode(final RecalDatumNode<T> sub) {
+        subnodes.add(sub);
+    }
+
+    public boolean isLeaf() {
+        return subnodes.isEmpty();
+    }
+
+    public int getNumBranches() {
+        return subnodes.size();
+    }
+
+    public double getMinNodePenalty() {
+        if ( isLeaf() )
+            return Double.MAX_VALUE;
+        else {
+            double minPenalty = getPenalty();
+            for ( final RecalDatumNode<T> sub : subnodes )
+                minPenalty = Math.min(minPenalty, sub.getMinNodePenalty());
+            return minPenalty;
+        }
+    }
+
+    public int maxDepth() {
+        int subMax = 0;
+        for ( final RecalDatumNode<T> sub : subnodes )
+            subMax = Math.max(subMax, sub.maxDepth());
+        return subMax + 1;
+    }
+
+    public int size() {
+        int size = 1;
+        for ( final RecalDatumNode<T> sub : subnodes )
+            size += sub.size();
+        return size;
+    }
+
+    /**
+     * Calculate the penalty of this interval, given the overall error rate for the interval
+     *
+     * If the globalErrorRate is e, this value is:
+     *
+     *      sum_i |log10(e_i) - log10(e)| * nObservations_i
+     *
+     * each the index i applies to all leaves of the tree accessible from this interval
+     * (found recursively from subnodes as necessary)
+     *
+     * @param globalErrorRate overall error rate in real space against which we calculate the penalty
+     * @return the cost of approximating the bins in this interval with the globalErrorRate
+     */
+    @Requires("globalErrorRate >= 0.0")
+    @Ensures("result >= 0.0")
+    private double calcPenalty(final double globalErrorRate) {
+        if ( globalErrorRate == 0.0 ) // there were no observations, so there's no penalty
+            return 0.0;
+
+        if ( isLeaf() ) {
+            // this is leave node
+            return (Math.abs(Math.log10(recalDatum.getEmpiricalErrorRate()) - Math.log10(globalErrorRate))) * recalDatum.getNumObservations();
+            // TODO -- how we can generalize this calculation?
+//            if ( this.qEnd <= minInterestingQual )
+//                // It's free to merge up quality scores below the smallest interesting one
+//                return 0;
+//            else {
+//                return (Math.abs(Math.log10(getEmpiricalErrorRate()) - Math.log10(globalErrorRate))) * getNumObservations();
+//            }
+        } else {
+            double sum = 0;
+            for ( final RecalDatumNode<T> hrd : subnodes)
+                sum += hrd.calcPenalty(globalErrorRate);
+            return sum;
+        }
+    }
+
+    public RecalDatumNode<T> pruneToDepth(final int maxDepth) {
+        if ( maxDepth < 1 )
+            throw new IllegalArgumentException("maxDepth < 1");
+        else {
+            final Set<RecalDatumNode<T>> subPruned = new HashSet<RecalDatumNode<T>>(getNumBranches());
+            if ( maxDepth > 1 )
+                for ( final RecalDatumNode<T> sub : subnodes )
+                    subPruned.add(sub.pruneToDepth(maxDepth - 1));
+            return new RecalDatumNode<T>(getRecalDatum(), fixedPenalty, subPruned);
+        }
+    }
+
+    public RecalDatumNode<T> pruneByPenalty(final int maxElements) {
+        RecalDatumNode<T> root = this;
+
+        while ( root.size() > maxElements ) {
+            // remove the lowest penalty element, and continue
+            root = root.removeLowestPenaltyNode();
+        }
+
+        // our size is below the target, so we are good, return
+        return root;
+    }
+
+    /**
+     * Find the lowest penalty node in the tree, and return a tree without it
+     *
+     * Note this excludes the current (root) node
+     *
+     * @return
+     */
+    private RecalDatumNode<T> removeLowestPenaltyNode() {
+        final RecalDatumNode<T> oneRemoved = removeFirstNodeWithPenalty(getMinNodePenalty()).getFirst();
+        if ( oneRemoved == null )
+            throw new IllegalStateException("Removed our root node, wow, didn't expect that");
+        return oneRemoved;
+    }
+
+    private Pair<RecalDatumNode<T>, Boolean> removeFirstNodeWithPenalty(final double penaltyToRemove) {
+        if ( getPenalty() == penaltyToRemove ) {
+            logger.info("Removing " + this + " with penalty " + penaltyToRemove);
+            if ( isLeaf() )
+                throw new IllegalStateException("Trying to remove a leaf node from the tree! " + this + " " + penaltyToRemove);
+            // node is the thing we are going to remove, but without any subnodes
+            final RecalDatumNode<T> node = new RecalDatumNode<T>(getRecalDatum(), fixedPenalty);
+            return new Pair<RecalDatumNode<T>, Boolean>(node, true);
+        } else {
+            // did we remove something in a sub branch?
+            boolean removedSomething = false;
+
+            // our sub nodes with the penalty node removed
+            final Set<RecalDatumNode<T>> sub = new HashSet<RecalDatumNode<T>>(getNumBranches());
+
+            for ( final RecalDatumNode<T> sub1 : subnodes ) {
+                if ( removedSomething ) {
+                    // already removed something, just add sub1 back to sub
+                    sub.add(sub1);
+                } else {
+                    // haven't removed anything yet, so try
+                    final Pair<RecalDatumNode<T>, Boolean> maybeRemoved = sub1.removeFirstNodeWithPenalty(penaltyToRemove);
+                    removedSomething = maybeRemoved.getSecond();
+                    sub.add(maybeRemoved.getFirst());
+                }
+            }
+
+            final RecalDatumNode<T> node = new RecalDatumNode<T>(getRecalDatum(), fixedPenalty, sub);
+            return new Pair<RecalDatumNode<T>, Boolean>(node, removedSomething);
+        }
+    }
+}
diff --git a/public/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatumTree.java b/public/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatumTree.java
deleted file mode 100644
index 210ea53bf..000000000
--- a/public/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatumTree.java
+++ /dev/null
@@ -1,76 +0,0 @@
-package org.broadinstitute.sting.utils.recalibration;
-
-import com.google.java.contract.Ensures;
-import com.google.java.contract.Requires;
-
-import java.util.Collections;
-import java.util.HashSet;
-import java.util.Set;
-
-/**
- * A tree of recal datum, where each contains a set of sub datum representing sub-states of the higher level one
- *
- * @author Mark DePristo
- * @since 07/27/12
- */
-public class RecalDatumTree extends RecalDatum {
-    final Set<RecalDatumTree> subnodes;
-
-    protected RecalDatumTree(final long nObservations, final long nErrors, final byte reportedQual) {
-        this(nObservations, nErrors, reportedQual, new HashSet<RecalDatumTree>());
-    }
-
-    public RecalDatumTree(final long nObservations, final long nErrors, final byte reportedQual, final Set<RecalDatumTree> subnodes) {
-        super(nObservations, nErrors, reportedQual);
-        this.subnodes = new HashSet<RecalDatumTree>(subnodes);
-    }
-
-    public double getPenalty() {
-        return calcPenalty(getEmpiricalErrorRate());
-    }
-
-    public void addSubnode(final RecalDatumTree sub) {
-        subnodes.add(sub);
-    }
-
-    public boolean isLeaf() {
-        return subnodes.isEmpty();
-    }
-
-    /**
-     * Calculate the penalty of this interval, given the overall error rate for the interval
-     *
-     * If the globalErrorRate is e, this value is:
-     *
-     *      sum_i |log10(e_i) - log10(e)| * nObservations_i
-     *
-     * each the index i applies to all leaves of the tree accessible from this interval
-     * (found recursively from subnodes as necessary)
-     *
-     * @param globalErrorRate overall error rate in real space against which we calculate the penalty
-     * @return the cost of approximating the bins in this interval with the globalErrorRate
-     */
-    @Requires("globalErrorRate >= 0.0")
-    @Ensures("result >= 0.0")
-    private double calcPenalty(final double globalErrorRate) {
-        if ( globalErrorRate == 0.0 ) // there were no observations, so there's no penalty
-            return 0.0;
-
-        if ( isLeaf() ) {
-            // this is leave node
-            return (Math.abs(Math.log10(getEmpiricalErrorRate()) - Math.log10(globalErrorRate))) * getNumObservations();
-            // TODO -- how we can generalize this calculation?
-//            if ( this.qEnd <= minInterestingQual )
-//                // It's free to merge up quality scores below the smallest interesting one
-//                return 0;
-//            else {
-//                return (Math.abs(Math.log10(getEmpiricalErrorRate()) - Math.log10(globalErrorRate))) * getNumObservations();
-//            }
-        } else {
-            double sum = 0;
-            for ( final RecalDatumTree hrd : subnodes)
-                sum += hrd.calcPenalty(globalErrorRate);
-            return sum;
-        }
-    }
-}