diff --git a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/Tranche.java b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/Tranche.java index 2caf27e8a..76b6dd22b 100755 --- a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/Tranche.java +++ b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/Tranche.java @@ -47,17 +47,20 @@ import java.util.*; */ public class Tranche implements Comparable { - private static final int CURRENT_VERSION = 2; + private static final int CURRENT_VERSION = 3; public double fdr, minVQSLod, targetTiTv, knownTiTv, novelTiTv; public int numKnown,numNovel; public String name; - public Tranche(double fdr, double targetTiTv, double minVQSLod, int numKnown, double knownTiTv, int numNovel, double novelTiTv) { - this(fdr, targetTiTv, minVQSLod, numKnown, knownTiTv, numNovel, novelTiTv, "anonymous"); + int accessibleTruthSites = 0; + int callsAtTruthSites = 0; + + public Tranche(double fdr, double targetTiTv, double minVQSLod, int numKnown, double knownTiTv, int numNovel, double novelTiTv, int accessibleTruthSites, int callsAtTruthSites) { + this(fdr, targetTiTv, minVQSLod, numKnown, knownTiTv, numNovel, novelTiTv, accessibleTruthSites, callsAtTruthSites, "anonymous"); } - public Tranche(double fdr, double targetTiTv, double minVQSLod, int numKnown, double knownTiTv, int numNovel, double novelTiTv, String name) { + public Tranche(double fdr, double targetTiTv, double minVQSLod, int numKnown, double knownTiTv, int numNovel, double novelTiTv, int accessibleTruthSites, int callsAtTruthSites, String name ) { this.fdr = fdr; this.targetTiTv = targetTiTv; this.minVQSLod = minVQSLod; @@ -67,7 +70,10 @@ public class Tranche implements Comparable { this.numKnown = numKnown; this.name = name; - if ( fdr <= 0.0 ) + this.accessibleTruthSites = accessibleTruthSites; + this.callsAtTruthSites = callsAtTruthSites; + + if ( fdr < 0.0 ) throw new UserException("Target FDR is unreasonable " + fdr); if ( targetTiTv < 0.5 || targetTiTv > 10 ) @@ -80,13 +86,17 @@ public class Tranche implements Comparable { throw new ReviewedStingException("BUG -- name cannot be null"); } + private double getTruthSensitivity() { + return accessibleTruthSites > 0 ? callsAtTruthSites / (1.0*accessibleTruthSites) : 0.0; + } + public int compareTo(Tranche other) { return Double.compare(this.fdr, other.fdr); } public String toString() { - return String.format("Tranche fdr=%.2f minVQSLod=%.4f known=(%d @ %.2f) novel=(%d @ %.2f) name=%s]", - fdr, minVQSLod, numKnown, knownTiTv, numNovel, novelTiTv, name); + return String.format("Tranche fdr=%.2f minVQSLod=%.4f known=(%d @ %.2f) novel=(%d @ %.2f) truthSites(%d accessible, %d called), name=%s]", + fdr, minVQSLod, numKnown, knownTiTv, numNovel, novelTiTv, accessibleTruthSites, callsAtTruthSites, name); } /** @@ -105,13 +115,13 @@ public class Tranche implements Comparable { stream.println("# Variant quality score tranches file"); stream.println("# Version number " + CURRENT_VERSION); - stream.println("FDRtranche,targetTiTv,numKnown,numNovel,knownTiTv,novelTiTv,minVQSLod,filterName"); + stream.println("FDRtranche,targetTiTv,numKnown,numNovel,knownTiTv,novelTiTv,minVQSLod,filterName,accessibleTruthSites,callsAtTruthSites,truthSensitivity"); Tranche prev = null; for ( Tranche t : tranches ) { - stream.printf("%.2f,%.2f,%d,%d,%.4f,%.4f,%.4f,FDRtranche%.2fto%.2f%n", + stream.printf("%.2f,%.2f,%d,%d,%.4f,%.4f,%.4f,FDRtranche%.2fto%.2f,%d,%d,%.4f%n", t.fdr,t.targetTiTv,t.numKnown,t.numNovel,t.knownTiTv,t.novelTiTv, t.minVQSLod, - (prev == null ? 0.0 : prev.fdr), t.fdr); + (prev == null ? 0.0 : prev.fdr), t.fdr, t.accessibleTruthSites, t.callsAtTruthSites, t.getTruthSensitivity()); prev = t; } @@ -161,7 +171,7 @@ public class Tranche implements Comparable { if ( header.length == 5 ) // old style tranches file, throw an error throw new UserException.MalformedFile(f, "Unfortuanately, your tranches file is from a previous version of this tool and cannot be used with the latest code. Please rerun VariantRecalibrator"); - if ( header.length != 8 ) + if ( header.length != 8 && header.length != 11 ) throw new UserException.MalformedFile(f, "Expected 8 elements in header line " + line); } else { if ( header.length != vals.length ) @@ -176,6 +186,8 @@ public class Tranche implements Comparable { getDouble(bindings,"knownTiTv", false), getInteger(bindings,"numNovel", true), getDouble(bindings,"novelTiTv", true), + getInteger(bindings,"accessibleTruthSites", false), + getInteger(bindings,"callsAtTruthSites", false), bindings.get("filterName"))); } } diff --git a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDatum.java b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDatum.java index f49cd4b7a..57aa50f4c 100755 --- a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDatum.java +++ b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDatum.java @@ -38,6 +38,7 @@ public class VariantDatum implements Comparable { public boolean isKnown; public double lod; public double weight; + public boolean atTruthSite; public int compareTo(VariantDatum other) { return Double.compare(this.lod, other.lod); diff --git a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantGaussianMixtureModel.java b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantGaussianMixtureModel.java index 3254e4184..3375fbbdf 100755 --- a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantGaussianMixtureModel.java +++ b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantGaussianMixtureModel.java @@ -460,26 +460,114 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel // Code to determine FDR tranches for VariantDatum[] // // --------------------------------------------------------------------------------------------------------- - public final static List findTranches( final VariantDatum[] data, final double[] FDRtranches, double targetTiTv ) { - return findTranches( data, FDRtranches, targetTiTv, null ); + public static abstract class SelectionMetric { + String name = null; + + public SelectionMetric(String name) { + this.name = name; + } + + public String getName() { return name; } + + public abstract double getThreshold(double tranche); + public abstract double getTarget(); + public abstract void calculateRunningMetric(List data); + public abstract double getRunningMetric(int i); + public abstract int datumValue(VariantDatum d); } - public final static List findTranches( final VariantDatum[] data, final double[] FDRtranches, double targetTiTv, File debugFile ) { - logger.info(String.format("Finding tranches for %d variants with %d FDRs and a target TiTv of %.2f", data.length, FDRtranches.length, targetTiTv)); + public static class NovelTiTvMetric extends SelectionMetric { + double[] runningTiTv; + double targetTiTv = 0; + + public NovelTiTvMetric(double target) { + super("NovelTiTv"); + targetTiTv = target; // compute the desired TiTv + } + + public double getThreshold(double tranche) { + return fdrToTiTv(tranche, targetTiTv); + } + + public double getTarget() { return targetTiTv; } + + public void calculateRunningMetric(List data) { + int ti = 0, tv = 0; + runningTiTv = new double[data.size()]; + + for ( int i = data.size() - 1; i >= 0; i-- ) { + VariantDatum datum = data.get(i); + if ( ! datum.isKnown ) { + if ( datum.isTransition ) { ti++; } else { tv++; } + runningTiTv[i] = ti / Math.max(1.0 * tv, 1.0); + } + } + } + + public double getRunningMetric(int i) { + return runningTiTv[i]; + } + + public int datumValue(VariantDatum d) { + return d.isTransition ? 1 : 0; + } + } + + public static class TruthSensitivityMetric extends SelectionMetric { + double[] runningSensitivity; + double targetTiTv = 0; + int nTrueSites = 0; + + public TruthSensitivityMetric(int nTrueSites) { + super("TruthSensitivity"); + this.nTrueSites = nTrueSites; + } + + public double getThreshold(double tranche) { + return tranche/100; // tranche of 1 => 99% sensivity target + } + + public double getTarget() { return 1.0; } + + public void calculateRunningMetric(List data) { + int nCalledAtTruth = 0; + runningSensitivity = new double[data.size()]; + + for ( int i = data.size() - 1; i >= 0; i-- ) { + VariantDatum datum = data.get(i); + nCalledAtTruth += datum.atTruthSite ? 1 : 0; + runningSensitivity[i] = 1 - nCalledAtTruth / (1.0 * nTrueSites); + } + } + + public double getRunningMetric(int i) { + return runningSensitivity[i]; + } + + public int datumValue(VariantDatum d) { + return d.atTruthSite ? 1 : 0; + } + } + + public final static List findTranches( final VariantDatum[] data, final double[] tranches, SelectionMetric metric ) { + return findTranches( data, tranches, metric, null ); + } + + public final static List findTranches( final VariantDatum[] data, final double[] trancheThresholds, SelectionMetric metric, File debugFile ) { + logger.info(String.format("Finding %d tranches for %d variants", trancheThresholds.length, data.length)); List tranchesData = sortVariantsbyLod(data); - double[] runningTiTv = calculateRunningTiTv(tranchesData); + metric.calculateRunningMetric(tranchesData); - if ( debugFile != null) { writeTranchesDebuggingInfo(debugFile, tranchesData, runningTiTv); } + if ( debugFile != null) { writeTranchesDebuggingInfo(debugFile, tranchesData, metric); } List tranches = new ArrayList(); - for ( double fdr : FDRtranches ) { - Tranche t = findTranche(tranchesData, runningTiTv, fdr, targetTiTv); - // todo -- should abort early when t's qual is 0 -- that's the lowest we'll get to + for ( double trancheThreshold : trancheThresholds ) { + Tranche t = findTranche(tranchesData, metric, trancheThreshold); if ( t == null ) { if ( tranches.size() == 0 ) - throw new UserException("Couldn't find any tranche containing variants with a TiTv > target of " + targetTiTv); + throw new UserException(String.format("Couldn't find any tranche containing variants with a %s > %.2f", metric.getName(), metric.getThreshold(trancheThreshold))); break; } @@ -489,13 +577,15 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel return tranches; } - private static void writeTranchesDebuggingInfo(File f, List tranchesData, double[] runningTiTv ) { + private static void writeTranchesDebuggingInfo(File f, List tranchesData, SelectionMetric metric ) { try { PrintStream out = new PrintStream(f); - out.println("Qual isTransition runningTiTv"); - for ( int i = 0; i < runningTiTv.length; i++ ) { + out.println("Qual metricValue runningValue"); + for ( int i = 0; i < tranchesData.size(); i++ ) { VariantDatum d = tranchesData.get(i); - out.printf("%.4f %d %.4f%n", d.lod, d.isTransition ? 1 : 0, runningTiTv[i]); + int score = metric.datumValue(d); + double runningValue = metric.getRunningMetric(i); + out.printf("%.4f %d %.4f%n", d.lod, score, runningValue); } } catch (FileNotFoundException e) { throw new UserException.CouldNotCreateOutputFile(f, e); @@ -508,44 +598,46 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel return sorted; } - private static double[] calculateRunningTiTv(List data) { - int ti = 0, tv = 0; - double[] run = new double[data.size()]; + public final static Tranche findTranche( final List data, final SelectionMetric metric, final double trancheThreshold ) { + logger.info(String.format(" Tranche threshold %.2f => selection metric threshold %.3f", trancheThreshold, metric.getThreshold(trancheThreshold))); - for ( int i = data.size() - 1; i >= 0; i-- ) { - VariantDatum datum = data.get(i); - if ( ! datum.isKnown ) { - if ( datum.isTransition ) { ti++; } else { tv++; } - run[i] = ti / Math.max(1.0 * tv, 1.0); - } - } - - return run; - } - - public final static Tranche findTranche( final List data, double[] runningTiTv, final double desiredFDR, double targetTiTv ) { - if ( data.size() != runningTiTv.length ) throw new ReviewedStingException("BUG: data and running titv are of different sizes"); - - final double titvThreshold = fdrToTiTv(desiredFDR, targetTiTv); // compute the desired TiTv - - for ( int i = 0; i < runningTiTv.length; i++ ) { - if ( runningTiTv[i] >= titvThreshold ) { + double metricThreshold = metric.getThreshold(trancheThreshold); + int n = data.size(); + for ( int i = 0; i < n; i++ ) { + if ( metric.getRunningMetric(i) >= metricThreshold ) { // we've found the largest group of variants with Ti/Tv >= our target titv - Tranche t = trancheOfVariants(data, i, desiredFDR, targetTiTv); - logger.info(String.format(" Found tranche for %.3f FDR = %.3f TiTv threshold starting with variant %d; running TiTv is %.2f ", - desiredFDR, titvThreshold, i, runningTiTv[i])); - logger.info(String.format(" Trance is %s", t)); + Tranche t = trancheOfVariants(data, i, trancheThreshold, metric.getTarget()); + logger.info(String.format(" Found tranche for %.3f: %.3f threshold starting with variant %d; running score is %.3f ", + trancheThreshold, metricThreshold, i, metric.getRunningMetric(i))); + logger.info(String.format(" Tranche is %s", t)); return t; } } // we get here when there's no subset of variants with Ti/Tv >= threshold, in which case we should return null - logger.info(String.format(" **Couldn't find tranche for %.3f FDR = %.3f TiTv threshold; last running TiTv was %.2f ", - desiredFDR, titvThreshold, runningTiTv[runningTiTv.length-1])); + //logger.info(String.format(" **Couldn't find tranche for %.3f FDR = %.3f TiTv threshold; last running TiTv was %.2f ", + // desiredFDR, titvThreshold, runningTiTv[runningTiTv.length-1])); + + +// for ( int i = 0; i < runningTiTv.length; i++ ) { +// if ( runningTiTv[i] >= titvThreshold ) { +// // we've found the largest group of variants with Ti/Tv >= our target titv +// Tranche t = trancheOfVariants(data, i, desiredFDR, targetTiTv); +// logger.info(String.format(" Found tranche for %.3f FDR = %.3f TiTv threshold starting with variant %d; running TiTv is %.2f ", +// desiredFDR, titvThreshold, i, metric.getRunningMetric(i))); +// logger.info(String.format(" Tranche is %s", t)); +// return t; +// } +// } +// +// // we get here when there's no subset of variants with Ti/Tv >= threshold, in which case we should return null +// logger.info(String.format(" **Couldn't find tranche for %.3f FDR = %.3f TiTv threshold; last running TiTv was %.2f ", +// desiredFDR, titvThreshold, runningTiTv[runningTiTv.length-1])); +// return null; } - public final static Tranche trancheOfVariants( final List data, int minI, double fdr, double targetTiTv ) { + public final static Tranche trancheOfVariants( final List data, int minI, double fdr, double target ) { int numKnown = 0, numNovel = 0, knownTi = 0, knownTv = 0, novelTi = 0, novelTv = 0; double minLod = data.get(minI).lod; @@ -568,13 +660,34 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel double knownTiTv = knownTi / Math.max(1.0 * knownTv, 1.0); double novelTiTv = novelTi / Math.max(1.0 * novelTv, 1.0); - return new Tranche(fdr, targetTiTv, minLod, numKnown, knownTiTv, numNovel, novelTiTv); + int accessibleTruthSites = countCallsAtTruth(data, Double.NEGATIVE_INFINITY); + int nCallsAtTruth = countCallsAtTruth(data, minLod); + + + return new Tranche(fdr, target, minLod, numKnown, knownTiTv, numNovel, novelTiTv, accessibleTruthSites, nCallsAtTruth); } public final static double fdrToTiTv(double desiredFDR, double targetTiTv) { return (1.0 - desiredFDR / 100.0) * (targetTiTv - 0.5) + 0.5; } + public static int countCallsAtTruth(final VariantDatum[] data, double minLOD ) { + int n = 0; + for ( VariantDatum d : data) { n += d.atTruthSite && d.lod >= minLOD ? 1 : 0; } + return n; + } + + public static int countCallsAtTruth(final List data, double minLOD ) { + int n = 0; + for ( VariantDatum d : data) { n += d.atTruthSite && d.lod >= minLOD ? 1 : 0; } + return n; + } + + // --------------------------------------------------------------------------------------------------------- + // + // Code to evaluate vectors given an already trained model + // + // --------------------------------------------------------------------------------------------------------- private double evaluateGaussians( final VariantDatum[] data, final double[][] pVarInCluster, final int startCluster, final int stopCluster ) { final int numAnnotations = data[0].annotations.length; diff --git a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java index 3748ed7d2..dcd82c7de 100755 --- a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java +++ b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java @@ -94,7 +94,7 @@ public class VariantRecalibrator extends RodWalker vcsTruth = tracker.getVariantContexts(ref, "truth", null, context.getLocation(), false, true); + final VariantContext vcTruth = ( vcsTruth.size() != 0 ? vcsTruth.iterator().next() : null ); + nTruthSites += vcTruth != null && vcTruth.isVariant() ? 1 : 0; + for( final VariantContext vc : tracker.getVariantContexts(ref, inputNames, null, context.getLocation(), false, false) ) { if( vc != null && vc.isSNP() ) { if( !vc.isFiltered() || IGNORE_ALL_INPUT_FILTERS || (ignoreInputFilterSet != null && ignoreInputFilterSet.containsAll(vc.getFilters())) ) { @@ -293,6 +308,9 @@ public class VariantRecalibrator extends RodWalker attrs = new HashMap(vc.getAttributes()); attrs.put(VariantRecalibrator.VQS_LOD_KEY, String.format("%.4f", lod)); @@ -327,11 +345,22 @@ public class VariantRecalibrator extends RodWalker reduceSum ) { - final VariantDataManager dataManager = new VariantDataManager( reduceSum, theModel.dataManager.annotationKeys ); reduceSum.clear(); // Don't need this ever again, clean up some memory - List tranches = VariantGaussianMixtureModel.findTranches( dataManager.data, FDR_TRANCHES, TARGET_TITV, DEBUG_FILE ); + // deal with truth information + int nCallsAtTruth = VariantGaussianMixtureModel.countCallsAtTruth(dataManager.data, Double.NEGATIVE_INFINITY); + logger.info(String.format("Truth set size is %d, raw calls at these sites %d, maximum sensitivity of %.2f", + nTruthSites, nCallsAtTruth, (100.0*nCallsAtTruth / Math.max(nTruthSites, nCallsAtTruth)))); + + VariantGaussianMixtureModel.SelectionMetric metric = null; + switch ( SELECTION_METRIC_TYPE ) { + case NOVEL_TITV: metric = new VariantGaussianMixtureModel.NovelTiTvMetric(TARGET_TITV); break; + case TRUTH_SENSITIVITY: metric = new VariantGaussianMixtureModel.TruthSensitivityMetric(nCallsAtTruth); break; + default: throw new ReviewedStingException("BUG: unexpected SelectionMetricType " + SELECTION_METRIC_TYPE); + } + + List tranches = VariantGaussianMixtureModel.findTranches( dataManager.data, FDR_TRANCHES, metric, DEBUG_FILE ); tranchesStream.print(Tranche.tranchesString(tranches)); // Execute Rscript command to plot the optimization curve diff --git a/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantGaussianMixtureModelUnitTest.java b/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantGaussianMixtureModelUnitTest.java index 50d1484f9..037ab6c98 100644 --- a/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantGaussianMixtureModelUnitTest.java +++ b/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantGaussianMixtureModelUnitTest.java @@ -122,10 +122,15 @@ public final class VariantGaussianMixtureModelUnitTest extends BaseTest { } } + private static final List findMyTranches(List vd, double[] fdrs, double targetTiTv) { + VariantGaussianMixtureModel.SelectionMetric metric = new VariantGaussianMixtureModel.NovelTiTvMetric(targetTiTv); + return VariantGaussianMixtureModel.findTranches(vd.toArray(new VariantDatum[0]), fdrs, metric); + } + @Test public final void testFindTranches1() { List vd = readData(); - List tranches = VariantGaussianMixtureModel.findTranches(vd.toArray(new VariantDatum[0]), FDRS, TARGET_TITV); + List tranches = findMyTranches(vd, FDRS, TARGET_TITV); System.out.printf(Tranche.tranchesString(tranches)); assertTranchesAreTheSame(read(EXPECTED_TRANCHES_NEW), tranches, true, false); } @@ -133,12 +138,12 @@ public final class VariantGaussianMixtureModelUnitTest extends BaseTest { @Test(expectedExceptions = {UserException.class}) public final void testBadFDR() { List vd = readData(); - VariantGaussianMixtureModel.findTranches(vd.toArray(new VariantDatum[0]), new double[]{-1}, TARGET_TITV); + List tranches = findMyTranches(vd, new double[]{-1}, TARGET_TITV); } @Test(expectedExceptions = {UserException.class}) public final void testBadTargetTiTv() { List vd = readData(); - VariantGaussianMixtureModel.findTranches(vd.toArray(new VariantDatum[0]), FDRS, 0.1); + List tranches = findMyTranches(vd, FDRS, 0.1); } }