Merge pull request #130 from broadinstitute/chartl_genotype_concordance_and_mathutils

This commit is contained in:
MauricioCarneiro 2013-03-29 05:24:59 -07:00
commit b818a4f219
6 changed files with 263 additions and 676 deletions

View File

@ -203,6 +203,27 @@ public class ConcordanceMetricsUnitTest extends BaseTest {
Assert.assertEquals(metrics.getOverallSiteConcordance().getSiteConcordance()[ConcordanceMetrics.SiteConcordanceType.EVAL_SUPERSET_TRUTH.ordinal()],1);
Assert.assertEquals(metrics.getOverallSiteConcordance().getSiteConcordance()[ConcordanceMetrics.SiteConcordanceType.ALLELES_DO_NOT_MATCH.ordinal()],0);
Assert.assertEquals(metrics.getOverallSiteConcordance().getSiteConcordance()[ConcordanceMetrics.SiteConcordanceType.ALLELES_MATCH.ordinal()],0);
// now flip them around
eval = data.getSecond();
truth = data.getFirst();
codec = new VCFCodec();
evalHeader = (VCFHeader)codec.readHeader(new AsciiLineReader(new PositionalBufferedStream(new StringBufferInputStream(TEST_1_HEADER))));
compHeader = (VCFHeader)codec.readHeader(new AsciiLineReader(new PositionalBufferedStream(new StringBufferInputStream(TEST_1_HEADER))));
metrics = new ConcordanceMetrics(evalHeader,compHeader);
metrics.update(eval,truth);
Assert.assertEquals(eval.getGenotype("test1_sample2").getType().ordinal(), 2);
Assert.assertEquals(truth.getGenotype("test1_sample2").getType().ordinal(),2);
Assert.assertEquals(metrics.getGenotypeConcordance("test1_sample2").getnMismatchingAlt(),1);
Assert.assertEquals(metrics.getGenotypeConcordance("test1_sample2").getTable()[1][2],0);
Assert.assertEquals(metrics.getGenotypeConcordance("test1_sample3").getTable()[1][2],0);
Assert.assertEquals(metrics.getGenotypeConcordance("test1_sample3").getTable()[3][2],1);
Assert.assertEquals(metrics.getOverallGenotypeConcordance().getTable()[1][1],1);
Assert.assertEquals(metrics.getOverallSiteConcordance().getSiteConcordance()[ConcordanceMetrics.SiteConcordanceType.EVAL_SUPERSET_TRUTH.ordinal()],0);
Assert.assertEquals(metrics.getOverallSiteConcordance().getSiteConcordance()[ConcordanceMetrics.SiteConcordanceType.EVAL_SUBSET_TRUTH.ordinal()],1);
Assert.assertEquals(metrics.getOverallSiteConcordance().getSiteConcordance()[ConcordanceMetrics.SiteConcordanceType.ALLELES_DO_NOT_MATCH.ordinal()],0);
Assert.assertEquals(metrics.getOverallSiteConcordance().getSiteConcordance()[ConcordanceMetrics.SiteConcordanceType.ALLELES_MATCH.ordinal()],0);
}
private Pair<VariantContext,VariantContext> getData3() {

View File

@ -40,7 +40,8 @@ import org.broadinstitute.sting.utils.help.HelpConstants;
import org.broadinstitute.sting.utils.pileup.ReadBackedPileup;
import java.io.PrintStream;
import java.util.Arrays;
import java.util.Comparator;
/**
* A simple Bayesian genotyper, that outputs a text based call format. Intended to be used only as an
@ -95,7 +96,7 @@ public class GATKPaperGenotyper extends LocusWalker<Integer,Long> implements Tre
likelihoods[genotype.ordinal()] += Math.log10(p / genotype.toString().length());
}
Integer sortedList[] = MathUtils.sortPermutation(likelihoods);
Integer sortedList[] = sortPermutation(likelihoods);
// create call using the best genotype (GENOTYPE.values()[sortedList[9]].toString())
// and calculate the LOD score from best - next best (9 and 8 in the sorted list, since the best likelihoods are closest to zero)
@ -110,6 +111,29 @@ public class GATKPaperGenotyper extends LocusWalker<Integer,Long> implements Tre
return 0;
}
/**
 * Computes the permutation of indices that puts {@code A} into ascending order.
 *
 * Element {@code k} of the result is the index into {@code A} of its k-th smallest value, so the
 * last element is the index of the largest value. {@code A} itself is not modified.
 *
 * Equal values (including {@code -0.0} and {@code 0.0}) keep their original relative index order,
 * because {@link Arrays#sort(Object[], Comparator)} is stable. {@code NaN} entries are ordered
 * after every other value; this keeps the comparator a consistent total ordering, whereas treating
 * NaN as "equal to everything" can leave the remaining values unsorted.
 *
 * @param A the values to rank; must not be null
 * @return a newly allocated array of length {@code A.length} holding the sorting permutation
 */
private static Integer[] sortPermutation(final double[] A) {
    final Integer[] permutation = new Integer[A.length];
    for (int i = 0; i < A.length; i++) {
        permutation[i] = i;
    }

    Arrays.sort(permutation, new Comparator<Integer>() {
        @Override
        public int compare(final Integer a, final Integer b) {
            final double left = A[a];
            final double right = A[b];
            if (left < right) {
                return -1;
            }
            if (left > right) {
                return 1;
            }
            if (left == right) {
                return 0;
            }
            // Only reachable when at least one side is NaN: Double.compare places NaN above
            // every other value and equal to itself, which keeps the ordering transitive.
            return Double.compare(left, right);
        }
    });
    return permutation;
}
/**
* Takes reference base, and three priors for hom-ref, het, hom-var, and fills in the priors vector
* appropriately.

View File

@ -102,15 +102,16 @@ public class ConcordanceMetrics {
public void update(VariantContext eval, VariantContext truth) {
overallSiteConcordance.update(eval,truth);
Set<String> alleleTruth = new HashSet<String>(8);
alleleTruth.add(truth.getReference().getBaseString());
String truthRef = truth.getReference().getBaseString();
alleleTruth.add(truthRef);
for ( Allele a : truth.getAlternateAlleles() ) {
alleleTruth.add(a.getBaseString());
}
for ( String sample : perSampleGenotypeConcordance.keySet() ) {
Genotype evalGenotype = eval.getGenotype(sample);
Genotype truthGenotype = truth.getGenotype(sample);
perSampleGenotypeConcordance.get(sample).update(evalGenotype,truthGenotype,alleleTruth);
overallGenotypeConcordance.update(evalGenotype,truthGenotype,alleleTruth);
perSampleGenotypeConcordance.get(sample).update(evalGenotype,truthGenotype,alleleTruth,truthRef);
overallGenotypeConcordance.update(evalGenotype,truthGenotype,alleleTruth,truthRef);
}
}
@ -170,10 +171,14 @@ public class ConcordanceMetrics {
}
@Requires({"eval!=null","truth != null","truthAlleles != null"})
public void update(Genotype eval, Genotype truth, Set<String> truthAlleles) {
// this is slow but correct
public void update(Genotype eval, Genotype truth, Set<String> truthAlleles, String truthRef) {
// this is slow but correct.
// NOTE: a reference call in "truth" is a special case, the eval can match *any* of the truth alleles
// that is, if the reference base is C, and a sample is C/C in truth, A/C, A/A, T/C, T/T will
// all match, so long as A and T are alleles in the truth callset.
boolean matchingAlt = true;
if ( eval.isCalled() && truth.isCalled() ) {
if ( eval.isCalled() && truth.isCalled() && truth.isHomRef() ) {
// by default, no-calls "match" between alleles, so if
// one or both sites are no-call or unavailable, the alt alleles match
// otherwise, check explicitly: if the eval has an allele that's not ref, no-call, or present in truth
@ -181,6 +186,17 @@ public class ConcordanceMetrics {
for ( Allele evalAllele : eval.getAlleles() ) {
matchingAlt &= truthAlleles.contains(evalAllele.getBaseString());
}
} else if ( eval.isCalled() && truth.isCalled() ) {
// otherwise, the eval genotype has to match either the alleles in the truth genotype, or the truth reference allele
// todo -- this can be sped up by caching the truth allele sets
Set<String> genoAlleles = new HashSet<String>(3);
genoAlleles.add(truthRef);
for ( Allele truthGenoAl : truth.getAlleles() ) {
genoAlleles.add(truthGenoAl.getBaseString());
}
for ( Allele evalAllele : eval.getAlleles() ) {
matchingAlt &= genoAlleles.contains(evalAllele.getBaseString());
}
}
if ( matchingAlt ) {

File diff suppressed because it is too large Load Diff

View File

@ -150,6 +150,21 @@ public class MathUtilsUnitTest extends BaseTest {
@Test
public void testLog10BinomialCoefficient() {
logger.warn("Executing testLog10BinomialCoefficient");
// note that we can test the binomial coefficient calculation indirectly via the binomial theorem
// (1+z)^m = sum_{k=0}^{m} (m choose k) z^k
double[] z_vals = new double[]{0.999,0.9,0.8,0.5,0.2,0.01,0.0001};
int[] exponent = new int[]{5,15,25,50,100};
for ( double z : z_vals ) {
double logz = Math.log10(z);
for ( int exp : exponent ) {
double expected_log = exp*Math.log10(1+z);
double[] newtonArray_log = new double[1+exp];
for ( int k = 0 ; k <= exp; k++ ) {
newtonArray_log[k] = MathUtils.log10BinomialCoefficient(exp,k)+k*logz;
}
Assert.assertEquals(MathUtils.log10sumLog10(newtonArray_log),expected_log,1e-6);
}
}
Assert.assertEquals(MathUtils.log10BinomialCoefficient(4, 2), 0.7781513, 1e-6);
Assert.assertEquals(MathUtils.log10BinomialCoefficient(10, 3), 2.079181, 1e-6);
@ -172,36 +187,19 @@ public class MathUtilsUnitTest extends BaseTest {
Assert.assertEquals(MathUtils.log10Factorial(12), 8.680337, 1e-6);
Assert.assertEquals(MathUtils.log10Factorial(200), 374.8969, 1e-3);
Assert.assertEquals(MathUtils.log10Factorial(12342), 45138.26, 1e-1);
}
/**
 * Checks the size and uniqueness of the arrays returned by MathUtils.randomSubset
 * for a fixed pool of ten distinct integers.
 */
@Test(enabled = true)
public void testRandomSubset() {
    final Integer[] pool = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};

    // Requests of 0..10 elements are expected to yield exactly that many; a request for
    // 11 (one more than the pool holds) is expected to yield only 10.
    for (int requested = 0; requested <= pool.length + 1; requested++) {
        final int expectedSize = Math.min(requested, pool.length);
        Assert.assertEquals(MathUtils.randomSubset(pool, requested).length, expectedSize);
    }

    // Repeated draws of five elements must never contain a duplicate.
    int round = 0;
    while (round < 25) {
        Assert.assertTrue(hasUniqueElements(MathUtils.randomSubset(pool, 5)));
        round++;
    }
}
@Test(enabled = true)
public void testArrayShuffle() {
Integer[] x = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10};
for (int i = 0; i < 25; i++) {
Object[] t = MathUtils.arrayShuffle(x);
Assert.assertTrue(hasUniqueElements(t));
Assert.assertTrue(hasAllElements(x, t));
double log10factorial_small = 0;
double log10factorial_middle = 374.8969;
double log10factorial_large = 45138.26;
int small_start = 1;
int med_start = 200;
int large_start = 12342;
for ( int i = 1; i < 1000; i++ ) {
log10factorial_small += Math.log10(i+small_start);
log10factorial_middle += Math.log10(i+med_start);
log10factorial_large += Math.log10(i+large_start);
Assert.assertEquals(MathUtils.log10Factorial(small_start+i),log10factorial_small,1e-6);
Assert.assertEquals(MathUtils.log10Factorial(med_start+i),log10factorial_middle,1e-3);
Assert.assertEquals(MathUtils.log10Factorial(large_start+i),log10factorial_large,1e-1);
}
}
@ -286,17 +284,29 @@ public class MathUtilsUnitTest extends BaseTest {
Assert.assertEquals(MathUtils.approximateLog10SumLog10(-29.1, -27.6, -26.2), Math.log10(Math.pow(10.0, -29.1) + Math.pow(10.0, -27.6) + Math.pow(10.0, -26.2)), requiredPrecision);
Assert.assertEquals(MathUtils.approximateLog10SumLog10(-0.12345, -0.23456, -0.34567), Math.log10(Math.pow(10.0, -0.12345) + Math.pow(10.0, -0.23456) + Math.pow(10.0, -0.34567)), requiredPrecision);
Assert.assertEquals(MathUtils.approximateLog10SumLog10(-15.7654, -17.0101, -17.9341), Math.log10(Math.pow(10.0, -15.7654) + Math.pow(10.0, -17.0101) + Math.pow(10.0, -17.9341)), requiredPrecision);
}
@Test
public void testNormalizeFromLog10() {
Assert.assertTrue(compareDoubleArrays(MathUtils.normalizeFromLog10(new double[] {0.0, 0.0, -1.0, -1.1, -7.8}, false, true), new double[] {0.0, 0.0, -1.0, -1.1, -7.8}));
Assert.assertTrue(compareDoubleArrays(MathUtils.normalizeFromLog10(new double[] {-1.0, -1.0, -1.0, -1.1, -7.8}, false, true), new double[] {0.0, 0.0, 0.0, -0.1, -6.8}));
Assert.assertTrue(compareDoubleArrays(MathUtils.normalizeFromLog10(new double[] {-10.0, -7.8, -10.5, -1.1, -10.0}, false, true), new double[] {-8.9, -6.7, -9.4, 0.0, -8.9}));
Assert.assertTrue(compareDoubleArrays(MathUtils.normalizeFromLog10(new double[] {-1.0, -1.0, -1.0, -1.0}), new double[] {0.25, 0.25, 0.25, 0.25}));
Assert.assertTrue(compareDoubleArrays(MathUtils.normalizeFromLog10(new double[] {-1.0, -3.0, -1.0, -1.0}), new double[] {0.1 * 1.0 / 0.301, 0.001 * 1.0 / 0.301, 0.1 * 1.0 / 0.301, 0.1 * 1.0 / 0.301}));
Assert.assertTrue(compareDoubleArrays(MathUtils.normalizeFromLog10(new double[] {-1.0, -3.0, -1.0, -2.0}), new double[] {0.1 * 1.0 / 0.211, 0.001 * 1.0 / 0.211, 0.1 * 1.0 / 0.211, 0.01 * 1.0 / 0.211}));
// magnitude of the sum doesn't matter, so we can combinatorially test this via partitions of unity
double[] mult_partitionFactor = new double[]{0.999,0.98,0.95,0.90,0.8,0.5,0.3,0.1,0.05,0.001};
int[] n_partitions = new int[] {2,4,8,16,32,64,128,256,512,1028};
for ( double alpha : mult_partitionFactor ) {
double log_alpha = Math.log10(alpha);
double log_oneMinusAlpha = Math.log10(1-alpha);
for ( int npart : n_partitions ) {
double[] multiplicative = new double[npart];
double[] equal = new double[npart];
double remaining_log = 0.0; // realspace = 1
for ( int i = 0 ; i < npart-1; i++ ) {
equal[i] = -Math.log10(npart);
double piece = remaining_log + log_alpha; // take a*remaining, leaving remaining-a*remaining = (1-a)*remaining
multiplicative[i] = piece;
remaining_log = remaining_log + log_oneMinusAlpha;
}
equal[npart-1] = -Math.log10(npart);
multiplicative[npart-1] = remaining_log;
Assert.assertEquals(MathUtils.approximateLog10SumLog10(equal),0.0,requiredPrecision,String.format("Did not sum to one: k=%d equal partitions.",npart));
Assert.assertEquals(MathUtils.approximateLog10SumLog10(multiplicative),0.0,requiredPrecision, String.format("Did not sum to one: k=%d multiplicative partitions with alpha=%f",npart,alpha));
}
}
}
@Test
@ -342,12 +352,29 @@ public class MathUtilsUnitTest extends BaseTest {
Assert.assertEquals(MathUtils.log10sumLog10(new double[] {-29.1, -27.6, -26.2}), Math.log10(Math.pow(10.0, -29.1) + Math.pow(10.0, -27.6) + Math.pow(10.0, -26.2)), requiredPrecision);
Assert.assertEquals(MathUtils.log10sumLog10(new double[] {-0.12345, -0.23456, -0.34567}), Math.log10(Math.pow(10.0, -0.12345) + Math.pow(10.0, -0.23456) + Math.pow(10.0, -0.34567)), requiredPrecision);
Assert.assertEquals(MathUtils.log10sumLog10(new double[] {-15.7654, -17.0101, -17.9341}), Math.log10(Math.pow(10.0, -15.7654) + Math.pow(10.0, -17.0101) + Math.pow(10.0, -17.9341)), requiredPrecision);
}
@Test
public void testDotProduct() {
Assert.assertEquals(MathUtils.dotProduct(new Double[]{-5.0,-3.0,2.0}, new Double[]{6.0,7.0,8.0}),-35.0,1e-3);
Assert.assertEquals(MathUtils.dotProduct(new Double[]{-5.0}, new Double[]{6.0}),-30.0,1e-3);
// magnitude of the sum doesn't matter, so we can combinatorially test this via partitions of unity
double[] mult_partitionFactor = new double[]{0.999,0.98,0.95,0.90,0.8,0.5,0.3,0.1,0.05,0.001};
int[] n_partitions = new int[] {2,4,8,16,32,64,128,256,512,1028};
for ( double alpha : mult_partitionFactor ) {
double log_alpha = Math.log10(alpha);
double log_oneMinusAlpha = Math.log10(1-alpha);
for ( int npart : n_partitions ) {
double[] multiplicative = new double[npart];
double[] equal = new double[npart];
double remaining_log = 0.0; // realspace = 1
for ( int i = 0 ; i < npart-1; i++ ) {
equal[i] = -Math.log10(npart);
double piece = remaining_log + log_alpha; // take a*remaining, leaving remaining-a*remaining = (1-a)*remaining
multiplicative[i] = piece;
remaining_log = remaining_log + log_oneMinusAlpha;
}
equal[npart-1] = -Math.log10(npart);
multiplicative[npart-1] = remaining_log;
Assert.assertEquals(MathUtils.log10sumLog10(equal),0.0,requiredPrecision);
Assert.assertEquals(MathUtils.log10sumLog10(multiplicative),0.0,requiredPrecision,String.format("Did not sum to one: nPartitions=%d, alpha=%f",npart,alpha));
}
}
}
@Test
@ -355,19 +382,4 @@ public class MathUtilsUnitTest extends BaseTest {
Assert.assertEquals(MathUtils.logDotProduct(new double[]{-5.0,-3.0,2.0}, new double[]{6.0,7.0,8.0}),10.0,1e-3);
Assert.assertEquals(MathUtils.logDotProduct(new double[]{-5.0}, new double[]{6.0}),1.0,1e-3);
}
/**
 * Element-wise comparison helper used by testNormalizeFromLog10().
 *
 * @param b1 first array
 * @param b2 second array
 * @return true when both arrays have the same length and MathUtils.compareDoubles reports 0
 *         for every pair of elements at the same index; false otherwise
 */
private boolean compareDoubleArrays(double[] b1, double[] b2) {
    final int length = b1.length;
    if (length != b2.length) {
        return false; // arrays of different size can never match
    }
    // advance while the paired elements compare as equal; stop at the first mismatch
    int idx = 0;
    while (idx < length && MathUtils.compareDoubles(b1[idx], b2[idx]) == 0) {
        idx++;
    }
    return idx == length;
}
}

View File

@ -120,12 +120,21 @@ public class BandPassActivityProfileUnitTest extends BaseTest {
for( int iii = 0; iii < activeProbArray.length; iii++ ) {
final double[] kernel = ArrayUtils.subarray(GaussianKernel, Math.max(profile.getFilteredSize() - iii, 0), Math.min(GaussianKernel.length, profile.getFilteredSize() + activeProbArray.length - iii));
final double[] activeProbSubArray = ArrayUtils.subarray(activeProbArray, Math.max(0,iii - profile.getFilteredSize()), Math.min(activeProbArray.length,iii + profile.getFilteredSize() + 1));
bandPassProbArray[iii] = MathUtils.dotProduct(activeProbSubArray, kernel);
bandPassProbArray[iii] = dotProduct(activeProbSubArray, kernel);
}
return bandPassProbArray;
}
/**
 * Computes the dot product of two vectors: the sum over k of {@code v1[k] * v2[k]}.
 *
 * The lengths are checked with a test assertion, so a mismatch fails the calling test
 * rather than silently truncating to the shorter array.
 *
 * @param v1 first vector
 * @param v2 second vector; must have the same length as {@code v1}
 * @return the sum of the element-wise products (0.0 for empty vectors)
 */
public static double dotProduct(double[] v1, double[] v2) {
    Assert.assertEquals(v1.length, v2.length, "Array lengths do not match in dotProduct");
    double result = 0.0;
    for (int k = 0; k < v1.length; k++) {
        result += v1[k] * v2[k];
    }
    return result;
}
@DataProvider(name = "BandPassComposition")
public Object[][] makeBandPassComposition() {
final List<Object[]> tests = new LinkedList<Object[]>();