From abd6ce1c77f9b3872b0d9dc425354a6921233b86 Mon Sep 17 00:00:00 2001 From: depristo Date: Sat, 11 Dec 2010 23:08:13 +0000 Subject: [PATCH] A TiTv-free approach for cutting variants! Apparently much better than previous approach, and will work for indels and SV will truly minor modifications to the code. Will discuss with methods group on Monday. git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@4822 348d0f76-0448-11de-a6fe-93d51630548a --- .../walkers/variantrecalibration/Tranche.java | 34 ++- .../variantrecalibration/VariantDatum.java | 1 + .../VariantGaussianMixtureModel.java | 199 ++++++++++++++---- .../VariantRecalibrator.java | 35 ++- .../VariantGaussianMixtureModelUnitTest.java | 11 +- 5 files changed, 220 insertions(+), 60 deletions(-) diff --git a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/Tranche.java b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/Tranche.java index 2caf27e8a..76b6dd22b 100755 --- a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/Tranche.java +++ b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/Tranche.java @@ -47,17 +47,20 @@ import java.util.*; */ public class Tranche implements Comparable { - private static final int CURRENT_VERSION = 2; + private static final int CURRENT_VERSION = 3; public double fdr, minVQSLod, targetTiTv, knownTiTv, novelTiTv; public int numKnown,numNovel; public String name; - public Tranche(double fdr, double targetTiTv, double minVQSLod, int numKnown, double knownTiTv, int numNovel, double novelTiTv) { - this(fdr, targetTiTv, minVQSLod, numKnown, knownTiTv, numNovel, novelTiTv, "anonymous"); + int accessibleTruthSites = 0; + int callsAtTruthSites = 0; + + public Tranche(double fdr, double targetTiTv, double minVQSLod, int numKnown, double knownTiTv, int numNovel, double novelTiTv, int accessibleTruthSites, int callsAtTruthSites) { + this(fdr, targetTiTv, minVQSLod, numKnown, knownTiTv, numNovel, novelTiTv, accessibleTruthSites, callsAtTruthSites, "anonymous"); } - public Tranche(double fdr, double targetTiTv, double minVQSLod, int numKnown, double knownTiTv, int numNovel, double novelTiTv, String name) { + public Tranche(double fdr, double targetTiTv, double minVQSLod, int numKnown, double knownTiTv, int numNovel, double novelTiTv, int accessibleTruthSites, int callsAtTruthSites, String name ) { this.fdr = fdr; this.targetTiTv = targetTiTv; this.minVQSLod = minVQSLod; @@ -67,7 +70,10 @@ public class Tranche implements Comparable { this.numKnown = numKnown; this.name = name; - if ( fdr <= 0.0 ) + this.accessibleTruthSites = accessibleTruthSites; + this.callsAtTruthSites = callsAtTruthSites; + + if ( fdr < 0.0 ) throw new UserException("Target FDR is unreasonable " + fdr); if ( targetTiTv < 0.5 || targetTiTv > 10 ) @@ -80,13 +86,17 @@ public class Tranche implements Comparable { throw new ReviewedStingException("BUG -- name cannot be null"); } + private double getTruthSensitivity() { + return accessibleTruthSites > 0 ? callsAtTruthSites / (1.0*accessibleTruthSites) : 0.0; + } + public int compareTo(Tranche other) { return Double.compare(this.fdr, other.fdr); } public String toString() { - return String.format("Tranche fdr=%.2f minVQSLod=%.4f known=(%d @ %.2f) novel=(%d @ %.2f) name=%s]", - fdr, minVQSLod, numKnown, knownTiTv, numNovel, novelTiTv, name); + return String.format("Tranche fdr=%.2f minVQSLod=%.4f known=(%d @ %.2f) novel=(%d @ %.2f) truthSites(%d accessible, %d called), name=%s]", + fdr, minVQSLod, numKnown, knownTiTv, numNovel, novelTiTv, accessibleTruthSites, callsAtTruthSites, name); } /** @@ -105,13 +115,13 @@ public class Tranche implements Comparable { stream.println("# Variant quality score tranches file"); stream.println("# Version number " + CURRENT_VERSION); - stream.println("FDRtranche,targetTiTv,numKnown,numNovel,knownTiTv,novelTiTv,minVQSLod,filterName"); + stream.println("FDRtranche,targetTiTv,numKnown,numNovel,knownTiTv,novelTiTv,minVQSLod,filterName,accessibleTruthSites,callsAtTruthSites,truthSensitivity"); Tranche prev = null; for ( Tranche t : tranches ) { - stream.printf("%.2f,%.2f,%d,%d,%.4f,%.4f,%.4f,FDRtranche%.2fto%.2f%n", + stream.printf("%.2f,%.2f,%d,%d,%.4f,%.4f,%.4f,FDRtranche%.2fto%.2f,%d,%d,%.4f%n", t.fdr,t.targetTiTv,t.numKnown,t.numNovel,t.knownTiTv,t.novelTiTv, t.minVQSLod, - (prev == null ? 0.0 : prev.fdr), t.fdr); + (prev == null ? 0.0 : prev.fdr), t.fdr, t.accessibleTruthSites, t.callsAtTruthSites, t.getTruthSensitivity()); prev = t; } @@ -161,7 +171,7 @@ public class Tranche implements Comparable { if ( header.length == 5 ) // old style tranches file, throw an error throw new UserException.MalformedFile(f, "Unfortuanately, your tranches file is from a previous version of this tool and cannot be used with the latest code. Please rerun VariantRecalibrator"); - if ( header.length != 8 ) + if ( header.length != 8 && header.length != 11 ) throw new UserException.MalformedFile(f, "Expected 8 elements in header line " + line); } else { if ( header.length != vals.length ) @@ -176,6 +186,8 @@ public class Tranche implements Comparable { getDouble(bindings,"knownTiTv", false), getInteger(bindings,"numNovel", true), getDouble(bindings,"novelTiTv", true), + getInteger(bindings,"accessibleTruthSites", false), + getInteger(bindings,"callsAtTruthSites", false), bindings.get("filterName"))); } } diff --git a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDatum.java b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDatum.java index f49cd4b7a..57aa50f4c 100755 --- a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDatum.java +++ b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDatum.java @@ -38,6 +38,7 @@ public class VariantDatum implements Comparable { public boolean isKnown; public double lod; public double weight; + public boolean atTruthSite; public int compareTo(VariantDatum other) { return Double.compare(this.lod, other.lod); diff --git a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantGaussianMixtureModel.java b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantGaussianMixtureModel.java index 3254e4184..3375fbbdf 100755 --- a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantGaussianMixtureModel.java +++ b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantGaussianMixtureModel.java @@ -460,26 +460,114 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel // Code to determine FDR tranches for VariantDatum[] // // --------------------------------------------------------------------------------------------------------- - public final static List findTranches( final VariantDatum[] data, final double[] FDRtranches, double targetTiTv ) { - return findTranches( data, FDRtranches, targetTiTv, null ); + public static abstract class SelectionMetric { + String name = null; + + public SelectionMetric(String name) { + this.name = name; + } + + public String getName() { return name; } + + public abstract double getThreshold(double tranche); + public abstract double getTarget(); + public abstract void calculateRunningMetric(List data); + public abstract double getRunningMetric(int i); + public abstract int datumValue(VariantDatum d); } - public final static List findTranches( final VariantDatum[] data, final double[] FDRtranches, double targetTiTv, File debugFile ) { - logger.info(String.format("Finding tranches for %d variants with %d FDRs and a target TiTv of %.2f", data.length, FDRtranches.length, targetTiTv)); + public static class NovelTiTvMetric extends SelectionMetric { + double[] runningTiTv; + double targetTiTv = 0; + + public NovelTiTvMetric(double target) { + super("NovelTiTv"); + targetTiTv = target; // compute the desired TiTv + } + + public double getThreshold(double tranche) { + return fdrToTiTv(tranche, targetTiTv); + } + + public double getTarget() { return targetTiTv; } + + public void calculateRunningMetric(List data) { + int ti = 0, tv = 0; + runningTiTv = new double[data.size()]; + + for ( int i = data.size() - 1; i >= 0; i-- ) { + VariantDatum datum = data.get(i); + if ( ! datum.isKnown ) { + if ( datum.isTransition ) { ti++; } else { tv++; } + runningTiTv[i] = ti / Math.max(1.0 * tv, 1.0); + } + } + } + + public double getRunningMetric(int i) { + return runningTiTv[i]; + } + + public int datumValue(VariantDatum d) { + return d.isTransition ? 1 : 0; + } + } + + public static class TruthSensitivityMetric extends SelectionMetric { + double[] runningSensitivity; + double targetTiTv = 0; + int nTrueSites = 0; + + public TruthSensitivityMetric(int nTrueSites) { + super("TruthSensitivity"); + this.nTrueSites = nTrueSites; + } + + public double getThreshold(double tranche) { + return tranche/100; // tranche of 1 => 99% sensivity target + } + + public double getTarget() { return 1.0; } + + public void calculateRunningMetric(List data) { + int nCalledAtTruth = 0; + runningSensitivity = new double[data.size()]; + + for ( int i = data.size() - 1; i >= 0; i-- ) { + VariantDatum datum = data.get(i); + nCalledAtTruth += datum.atTruthSite ? 1 : 0; + runningSensitivity[i] = 1 - nCalledAtTruth / (1.0 * nTrueSites); + } + } + + public double getRunningMetric(int i) { + return runningSensitivity[i]; + } + + public int datumValue(VariantDatum d) { + return d.atTruthSite ? 1 : 0; + } + } + + public final static List findTranches( final VariantDatum[] data, final double[] tranches, SelectionMetric metric ) { + return findTranches( data, tranches, metric, null ); + } + + public final static List findTranches( final VariantDatum[] data, final double[] trancheThresholds, SelectionMetric metric, File debugFile ) { + logger.info(String.format("Finding %d tranches for %d variants", trancheThresholds.length, data.length)); List tranchesData = sortVariantsbyLod(data); - double[] runningTiTv = calculateRunningTiTv(tranchesData); + metric.calculateRunningMetric(tranchesData); - if ( debugFile != null) { writeTranchesDebuggingInfo(debugFile, tranchesData, runningTiTv); } + if ( debugFile != null) { writeTranchesDebuggingInfo(debugFile, tranchesData, metric); } List tranches = new ArrayList(); - for ( double fdr : FDRtranches ) { - Tranche t = findTranche(tranchesData, runningTiTv, fdr, targetTiTv); - // todo -- should abort early when t's qual is 0 -- that's the lowest we'll get to + for ( double trancheThreshold : trancheThresholds ) { + Tranche t = findTranche(tranchesData, metric, trancheThreshold); if ( t == null ) { if ( tranches.size() == 0 ) - throw new UserException("Couldn't find any tranche containing variants with a TiTv > target of " + targetTiTv); + throw new UserException(String.format("Couldn't find any tranche containing variants with a %s > %.2f", metric.getName(), metric.getThreshold(trancheThreshold))); break; } @@ -489,13 +577,15 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel return tranches; } - private static void writeTranchesDebuggingInfo(File f, List tranchesData, double[] runningTiTv ) { + private static void writeTranchesDebuggingInfo(File f, List tranchesData, SelectionMetric metric ) { try { PrintStream out = new PrintStream(f); - out.println("Qual isTransition runningTiTv"); - for ( int i = 0; i < runningTiTv.length; i++ ) { + out.println("Qual metricValue runningValue"); + for ( int i = 0; i < tranchesData.size(); i++ ) { VariantDatum d = tranchesData.get(i); - out.printf("%.4f %d %.4f%n", d.lod, d.isTransition ? 1 : 0, runningTiTv[i]); + int score = metric.datumValue(d); + double runningValue = metric.getRunningMetric(i); + out.printf("%.4f %d %.4f%n", d.lod, score, runningValue); } } catch (FileNotFoundException e) { throw new UserException.CouldNotCreateOutputFile(f, e); @@ -508,44 +598,46 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel return sorted; } - private static double[] calculateRunningTiTv(List data) { - int ti = 0, tv = 0; - double[] run = new double[data.size()]; + public final static Tranche findTranche( final List data, final SelectionMetric metric, final double trancheThreshold ) { + logger.info(String.format(" Tranche threshold %.2f => selection metric threshold %.3f", trancheThreshold, metric.getThreshold(trancheThreshold))); - for ( int i = data.size() - 1; i >= 0; i-- ) { - VariantDatum datum = data.get(i); - if ( ! datum.isKnown ) { - if ( datum.isTransition ) { ti++; } else { tv++; } - run[i] = ti / Math.max(1.0 * tv, 1.0); - } - } - - return run; - } - - public final static Tranche findTranche( final List data, double[] runningTiTv, final double desiredFDR, double targetTiTv ) { - if ( data.size() != runningTiTv.length ) throw new ReviewedStingException("BUG: data and running titv are of different sizes"); - - final double titvThreshold = fdrToTiTv(desiredFDR, targetTiTv); // compute the desired TiTv - - for ( int i = 0; i < runningTiTv.length; i++ ) { - if ( runningTiTv[i] >= titvThreshold ) { + double metricThreshold = metric.getThreshold(trancheThreshold); + int n = data.size(); + for ( int i = 0; i < n; i++ ) { + if ( metric.getRunningMetric(i) >= metricThreshold ) { // we've found the largest group of variants with Ti/Tv >= our target titv - Tranche t = trancheOfVariants(data, i, desiredFDR, targetTiTv); - logger.info(String.format(" Found tranche for %.3f FDR = %.3f TiTv threshold starting with variant %d; running TiTv is %.2f ", - desiredFDR, titvThreshold, i, runningTiTv[i])); - logger.info(String.format(" Trance is %s", t)); + Tranche t = trancheOfVariants(data, i, trancheThreshold, metric.getTarget()); + logger.info(String.format(" Found tranche for %.3f: %.3f threshold starting with variant %d; running score is %.3f ", + trancheThreshold, metricThreshold, i, metric.getRunningMetric(i))); + logger.info(String.format(" Tranche is %s", t)); return t; } } // we get here when there's no subset of variants with Ti/Tv >= threshold, in which case we should return null - logger.info(String.format(" **Couldn't find tranche for %.3f FDR = %.3f TiTv threshold; last running TiTv was %.2f ", - desiredFDR, titvThreshold, runningTiTv[runningTiTv.length-1])); + //logger.info(String.format(" **Couldn't find tranche for %.3f FDR = %.3f TiTv threshold; last running TiTv was %.2f ", + // desiredFDR, titvThreshold, runningTiTv[runningTiTv.length-1])); + + +// for ( int i = 0; i < runningTiTv.length; i++ ) { +// if ( runningTiTv[i] >= titvThreshold ) { +// // we've found the largest group of variants with Ti/Tv >= our target titv +// Tranche t = trancheOfVariants(data, i, desiredFDR, targetTiTv); +// logger.info(String.format(" Found tranche for %.3f FDR = %.3f TiTv threshold starting with variant %d; running TiTv is %.2f ", +// desiredFDR, titvThreshold, i, metric.getRunningMetric(i))); +// logger.info(String.format(" Tranche is %s", t)); +// return t; +// } +// } +// +// // we get here when there's no subset of variants with Ti/Tv >= threshold, in which case we should return null +// logger.info(String.format(" **Couldn't find tranche for %.3f FDR = %.3f TiTv threshold; last running TiTv was %.2f ", +// desiredFDR, titvThreshold, runningTiTv[runningTiTv.length-1])); +// return null; } - public final static Tranche trancheOfVariants( final List data, int minI, double fdr, double targetTiTv ) { + public final static Tranche trancheOfVariants( final List data, int minI, double fdr, double target ) { int numKnown = 0, numNovel = 0, knownTi = 0, knownTv = 0, novelTi = 0, novelTv = 0; double minLod = data.get(minI).lod; @@ -568,13 +660,34 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel double knownTiTv = knownTi / Math.max(1.0 * knownTv, 1.0); double novelTiTv = novelTi / Math.max(1.0 * novelTv, 1.0); - return new Tranche(fdr, targetTiTv, minLod, numKnown, knownTiTv, numNovel, novelTiTv); + int accessibleTruthSites = countCallsAtTruth(data, Double.NEGATIVE_INFINITY); + int nCallsAtTruth = countCallsAtTruth(data, minLod); + + + return new Tranche(fdr, target, minLod, numKnown, knownTiTv, numNovel, novelTiTv, accessibleTruthSites, nCallsAtTruth); } public final static double fdrToTiTv(double desiredFDR, double targetTiTv) { return (1.0 - desiredFDR / 100.0) * (targetTiTv - 0.5) + 0.5; } + public static int countCallsAtTruth(final VariantDatum[] data, double minLOD ) { + int n = 0; + for ( VariantDatum d : data) { n += d.atTruthSite && d.lod >= minLOD ? 1 : 0; } + return n; + } + + public static int countCallsAtTruth(final List data, double minLOD ) { + int n = 0; + for ( VariantDatum d : data) { n += d.atTruthSite && d.lod >= minLOD ? 1 : 0; } + return n; + } + + // --------------------------------------------------------------------------------------------------------- + // + // Code to evaluate vectors given an already trained model + // + // --------------------------------------------------------------------------------------------------------- private double evaluateGaussians( final VariantDatum[] data, final double[][] pVarInCluster, final int startCluster, final int stopCluster ) { final int numAnnotations = data[0].annotations.length; diff --git a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java index 3748ed7d2..dcd82c7de 100755 --- a/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java +++ b/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java @@ -94,7 +94,7 @@ public class VariantRecalibrator extends RodWalker vcsTruth = tracker.getVariantContexts(ref, "truth", null, context.getLocation(), false, true); + final VariantContext vcTruth = ( vcsTruth.size() != 0 ? vcsTruth.iterator().next() : null ); + nTruthSites += vcTruth != null && vcTruth.isVariant() ? 1 : 0; + for( final VariantContext vc : tracker.getVariantContexts(ref, inputNames, null, context.getLocation(), false, false) ) { if( vc != null && vc.isSNP() ) { if( !vc.isFiltered() || IGNORE_ALL_INPUT_FILTERS || (ignoreInputFilterSet != null && ignoreInputFilterSet.containsAll(vc.getFilters())) ) { @@ -293,6 +308,9 @@ public class VariantRecalibrator extends RodWalker attrs = new HashMap(vc.getAttributes()); attrs.put(VariantRecalibrator.VQS_LOD_KEY, String.format("%.4f", lod)); @@ -327,11 +345,22 @@ public class VariantRecalibrator extends RodWalker reduceSum ) { - final VariantDataManager dataManager = new VariantDataManager( reduceSum, theModel.dataManager.annotationKeys ); reduceSum.clear(); // Don't need this ever again, clean up some memory - List tranches = VariantGaussianMixtureModel.findTranches( dataManager.data, FDR_TRANCHES, TARGET_TITV, DEBUG_FILE ); + // deal with truth information + int nCallsAtTruth = VariantGaussianMixtureModel.countCallsAtTruth(dataManager.data, Double.NEGATIVE_INFINITY); + logger.info(String.format("Truth set size is %d, raw calls at these sites %d, maximum sensitivity of %.2f", + nTruthSites, nCallsAtTruth, (100.0*nCallsAtTruth / Math.max(nTruthSites, nCallsAtTruth)))); + + VariantGaussianMixtureModel.SelectionMetric metric = null; + switch ( SELECTION_METRIC_TYPE ) { + case NOVEL_TITV: metric = new VariantGaussianMixtureModel.NovelTiTvMetric(TARGET_TITV); break; + case TRUTH_SENSITIVITY: metric = new VariantGaussianMixtureModel.TruthSensitivityMetric(nCallsAtTruth); break; + default: throw new ReviewedStingException("BUG: unexpected SelectionMetricType " + SELECTION_METRIC_TYPE); + } + + List tranches = VariantGaussianMixtureModel.findTranches( dataManager.data, FDR_TRANCHES, metric, DEBUG_FILE ); tranchesStream.print(Tranche.tranchesString(tranches)); // Execute Rscript command to plot the optimization curve diff --git a/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantGaussianMixtureModelUnitTest.java b/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantGaussianMixtureModelUnitTest.java index 50d1484f9..037ab6c98 100644 --- a/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantGaussianMixtureModelUnitTest.java +++ b/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantGaussianMixtureModelUnitTest.java @@ -122,10 +122,15 @@ public final class VariantGaussianMixtureModelUnitTest extends BaseTest { } } + private static final List findMyTranches(List vd, double[] fdrs, double targetTiTv) { + VariantGaussianMixtureModel.SelectionMetric metric = new VariantGaussianMixtureModel.NovelTiTvMetric(targetTiTv); + return VariantGaussianMixtureModel.findTranches(vd.toArray(new VariantDatum[0]), fdrs, metric); + } + @Test public final void testFindTranches1() { List vd = readData(); - List tranches = VariantGaussianMixtureModel.findTranches(vd.toArray(new VariantDatum[0]), FDRS, TARGET_TITV); + List tranches = findMyTranches(vd, FDRS, TARGET_TITV); System.out.printf(Tranche.tranchesString(tranches)); assertTranchesAreTheSame(read(EXPECTED_TRANCHES_NEW), tranches, true, false); } @@ -133,12 +138,12 @@ public final class VariantGaussianMixtureModelUnitTest extends BaseTest { @Test(expectedExceptions = {UserException.class}) public final void testBadFDR() { List vd = readData(); - VariantGaussianMixtureModel.findTranches(vd.toArray(new VariantDatum[0]), new double[]{-1}, TARGET_TITV); + List tranches = findMyTranches(vd, new double[]{-1}, TARGET_TITV); } @Test(expectedExceptions = {UserException.class}) public final void testBadTargetTiTv() { List vd = readData(); - VariantGaussianMixtureModel.findTranches(vd.toArray(new VariantDatum[0]), FDRS, 0.1); + List tranches = findMyTranches(vd, FDRS, 0.1); } }