Improvement to RecalDatum and VisualizeContextTree

-- Reorganize functions in RecalDatum so that the error rate can be computed independently.  Added unit tests.  Removed the equals() method, which is buggy without its associated implementation of hashCode()
-- New class RecalDatumTree based on QualIntervals that inherits from RecalDatum but includes the concept of sub data
-- VisualizeContextTree now uses RecalDatumTree and can trivially compute the penalty function for merging nodes, which it displays in the graph
This commit is contained in:
Mark DePristo 2012-07-29 13:13:13 -04:00
parent 57b45bfb1e
commit 315d25409f
2 changed files with 133 additions and 28 deletions

View File

@ -136,6 +136,14 @@ public class RecalDatum {
this.estimatedQReported = estimatedQReported; this.estimatedQReported = estimatedQReported;
} }
/**
 * Creates a RecalDatum with a randomly chosen number of observations, number of errors
 * and reported quality.
 *
 * The number of errors is capped at the number of observations, so the generated datum
 * never reports more errors than observations (which would imply an error rate above 1).
 *
 * @param maxObservations exclusive upper bound on the number of observations; must be positive,
 *                        otherwise Random.nextInt throws an IllegalArgumentException
 * @param maxErrors exclusive upper bound on the number of errors; must be positive,
 *                  otherwise Random.nextInt throws an IllegalArgumentException
 * @return a new RecalDatum built from the randomly drawn values
 */
public static RecalDatum createRandomRecalDatum(int maxObservations, int maxErrors) {
    final Random random = new Random();
    final int nObservations = random.nextInt(maxObservations);
    // errors are a subset of observations, so never draw more errors than observations
    final int nErrors = Math.min(random.nextInt(maxErrors), nObservations);
    final int qual = random.nextInt(QualityUtils.MAX_QUAL_SCORE);
    return new RecalDatum(nObservations, nErrors, (byte)qual);
}
public final double getEstimatedQReported() { public final double getEstimatedQReported() {
return estimatedQReported; return estimatedQReported;
} }
@ -143,6 +151,29 @@ public class RecalDatum {
return (byte)(int)(Math.round(getEstimatedQReported())); return (byte)(int)(Math.round(getEstimatedQReported()));
} }
//---------------------------------------------------------------------------------------------------------------
//
// Empirical quality score -- derived from the num mismatches and observations
//
//---------------------------------------------------------------------------------------------------------------
/**
 * Computes the empirical error rate (in real space) of this datum.
 *
 * With at least one observation the rate is smoothed: SMOOTHING_CONSTANT is added once to the
 * number of mismatches and twice to the number of observations before dividing.
 *
 * @return 0.0 if there are no observations, otherwise the smoothed ratio
 *         (numMismatches + SMOOTHING_CONSTANT) / (numObservations + 2 * SMOOTHING_CONSTANT)
 */
@Ensures("result >= 0.0")
public double getEmpiricalErrorRate() {
    // without any observations there is nothing to estimate the rate from
    if ( numObservations == 0 ) {
        return 0.0;
    }

    // smoothing adds SMOOTHING_CONSTANT errors ...
    final double smoothedErrors = (double) (numMismatches + SMOOTHING_CONSTANT);
    // ... and the same number of errors plus non-errors to the total
    final double smoothedTotal = (double) (numObservations + SMOOTHING_CONSTANT + SMOOTHING_CONSTANT);
    return smoothedErrors / smoothedTotal;
}
public synchronized void setEmpiricalQuality(final double empiricalQuality) { public synchronized void setEmpiricalQuality(final double empiricalQuality) {
if ( empiricalQuality < 0 ) throw new IllegalArgumentException("empiricalQuality < 0"); if ( empiricalQuality < 0 ) throw new IllegalArgumentException("empiricalQuality < 0");
if ( Double.isInfinite(empiricalQuality) ) throw new IllegalArgumentException("empiricalQuality is infinite"); if ( Double.isInfinite(empiricalQuality) ) throw new IllegalArgumentException("empiricalQuality is infinite");
@ -157,6 +188,16 @@ public class RecalDatum {
return empiricalQuality; return empiricalQuality;
} }
/**
 * Returns the empirical quality rounded to the nearest integer and narrowed to a byte.
 *
 * @return the value of getEmpiricalQuality(), rounded with Math.round and cast to byte
 */
public final byte getEmpiricalQualityAsByte() {
    final long roundedQuality = Math.round(getEmpiricalQuality());
    return (byte) roundedQuality;
}
//---------------------------------------------------------------------------------------------------------------
//
// increment methods
//
//---------------------------------------------------------------------------------------------------------------
@Override @Override
public String toString() { public String toString() {
return String.format("%d,%d,%d", getNumObservations(), getNumMismatches(), (byte) Math.floor(getEmpiricalQuality())); return String.format("%d,%d,%d", getNumObservations(), getNumMismatches(), (byte) Math.floor(getEmpiricalQuality()));
@ -166,29 +207,21 @@ public class RecalDatum {
return String.format("%s,%d,%.2f", toString(), (byte) Math.floor(getEstimatedQReported()), getEmpiricalQuality() - getEstimatedQReported()); return String.format("%s,%d,%.2f", toString(), (byte) Math.floor(getEstimatedQReported()), getEmpiricalQuality() - getEstimatedQReported());
} }
public static RecalDatum createRandomRecalDatum(int maxObservations, int maxErrors) { // /**
final Random random = new Random(); // * We don't compare the estimated quality reported because it may be different when read from
final int nObservations = random.nextInt(maxObservations); // * report tables.
final int nErrors = random.nextInt(maxErrors); // *
final int qual = random.nextInt(QualityUtils.MAX_QUAL_SCORE); // * @param o the other recal datum
return new RecalDatum(nObservations, nErrors, (byte)qual); // * @return true if the two recal datums have the same number of observations, errors and empirical quality.
} // */
// @Override
/** // public boolean equals(Object o) {
* We don't compare the estimated quality reported because it may be different when read from // if (!(o instanceof RecalDatum))
* report tables. // return false;
* // RecalDatum other = (RecalDatum) o;
* @param o the other recal datum // return super.equals(o) &&
* @return true if the two recal datums have the same number of observations, errors and empirical quality. // MathUtils.compareDoubles(this.empiricalQuality, other.empiricalQuality, 0.001) == 0;
*/ // }
@Override
public boolean equals(Object o) {
if (!(o instanceof RecalDatum))
return false;
RecalDatum other = (RecalDatum) o;
return super.equals(o) &&
MathUtils.compareDoubles(this.empiricalQuality, other.empiricalQuality, 0.001) == 0;
}
//--------------------------------------------------------------------------------------------------------------- //---------------------------------------------------------------------------------------------------------------
// //
@ -255,11 +288,7 @@ public class RecalDatum {
@Requires("empiricalQuality == UNINITIALIZED") @Requires("empiricalQuality == UNINITIALIZED")
@Ensures("empiricalQuality != UNINITIALIZED") @Ensures("empiricalQuality != UNINITIALIZED")
private synchronized final void calcEmpiricalQuality() { private synchronized final void calcEmpiricalQuality() {
// cache the value so we don't call log over and over again final double empiricalQual = -10 * Math.log10(getEmpiricalErrorRate());
final double doubleMismatches = (double) (numMismatches + SMOOTHING_CONSTANT);
// smoothing is one error and one non-error observation, for example
final double doubleObservations = (double) (numObservations + SMOOTHING_CONSTANT + SMOOTHING_CONSTANT);
final double empiricalQual = -10 * Math.log10(doubleMismatches / doubleObservations);
empiricalQuality = Math.min(empiricalQual, (double) QualityUtils.MAX_RECALIBRATED_Q_SCORE); empiricalQuality = Math.min(empiricalQual, (double) QualityUtils.MAX_RECALIBRATED_Q_SCORE);
} }

View File

@ -0,0 +1,76 @@
package org.broadinstitute.sting.utils.recalibration;
import com.google.java.contract.Ensures;
import com.google.java.contract.Requires;
import java.util.Collections;
import java.util.HashSet;
import java.util.Set;
/**
 * A tree of recal datum, where each node contains a set of sub datum representing
 * sub-states of the higher level one.
 *
 * A node without subnodes is a leaf.  The penalty of a node is computed recursively
 * over all leaves reachable from it (see {@link #getPenalty()}).
 *
 * @author Mark DePristo
 * @since 07/27/12
 */
public class RecalDatumTree extends RecalDatum {
    /** The direct children of this node; empty when this node is a leaf */
    final Set<RecalDatumTree> subnodes;

    /**
     * Creates a leaf node, i.e. a node with no subnodes.
     *
     * @param nObservations number of observations passed on to RecalDatum
     * @param nErrors number of errors passed on to RecalDatum
     * @param reportedQual reported quality passed on to RecalDatum
     */
    protected RecalDatumTree(final long nObservations, final long nErrors, final byte reportedQual) {
        this(nObservations, nErrors, reportedQual, new HashSet<RecalDatumTree>());
    }

    /**
     * Creates a node with the given subnodes.
     *
     * The provided set is copied, so later changes to the caller's set do not affect this node.
     *
     * @param nObservations number of observations passed on to RecalDatum
     * @param nErrors number of errors passed on to RecalDatum
     * @param reportedQual reported quality passed on to RecalDatum
     * @param subnodes the initial children of this node; must not be null
     */
    public RecalDatumTree(final long nObservations, final long nErrors, final byte reportedQual, final Set<RecalDatumTree> subnodes) {
        super(nObservations, nErrors, reportedQual);
        this.subnodes = new HashSet<RecalDatumTree>(subnodes);
    }

    /**
     * Calculates the penalty of approximating all leaves below this node with this node's
     * own empirical error rate.
     *
     * @return the penalty, as computed by calcPenalty with getEmpiricalErrorRate() as the global rate
     */
    public double getPenalty() {
        return calcPenalty(getEmpiricalErrorRate());
    }

    /**
     * Adds a child to this node.
     *
     * @param sub the subnode to add
     */
    public void addSubnode(final RecalDatumTree sub) {
        subnodes.add(sub);
    }

    /**
     * @return true if this node has no subnodes
     */
    public boolean isLeaf() {
        return subnodes.isEmpty();
    }

    /**
     * Calculate the penalty of this interval, given the overall error rate for the interval
     *
     * If the globalErrorRate is e, this value is:
     *
     * sum_i |log10(e_i) - log10(e)| * nObservations_i
     *
     * where the index i applies to all leaves of the tree accessible from this interval
     * (found recursively from subnodes as necessary).  Leaves without any observations
     * contribute nothing to the sum.
     *
     * @param globalErrorRate overall error rate in real space against which we calculate the penalty
     * @return the cost of approximating the bins in this interval with the globalErrorRate
     */
    @Requires("globalErrorRate >= 0.0")
    @Ensures("result >= 0.0")
    private double calcPenalty(final double globalErrorRate) {
        if ( globalErrorRate == 0.0 ) // there were no observations, so there's no penalty
            return 0.0;

        if ( isLeaf() ) {
            // A leaf without observations has an empirical error rate of 0.0, whose log10 is
            // -Infinity; multiplying the resulting infinite distance by 0 observations would
            // yield NaN and poison the sum of the enclosing node.  Such a leaf costs nothing.
            if ( getNumObservations() == 0 )
                return 0.0;

            // this is a leaf node
            return (Math.abs(Math.log10(getEmpiricalErrorRate()) - Math.log10(globalErrorRate))) * getNumObservations();

            // TODO -- how we can generalize this calculation?
//            if ( this.qEnd <= minInterestingQual )
//                // It's free to merge up quality scores below the smallest interesting one
//                return 0;
//            else {
//                return (Math.abs(Math.log10(getEmpiricalErrorRate()) - Math.log10(globalErrorRate))) * getNumObservations();
//            }
        } else {
            // internal node: the penalty is the sum over all subtrees, each measured
            // against the same global error rate
            double sum = 0;
            for ( final RecalDatumTree hrd : subnodes)
                sum += hrd.calcPenalty(globalErrorRate);
            return sum;
        }
    }
}