From b57054c63cf95bf639486374e8e99c44a90ada58 Mon Sep 17 00:00:00 2001 From: Ryan Poplin Date: Mon, 30 Sep 2013 11:14:35 -0400 Subject: [PATCH] Various VQSR optimizations in both runtime and accuracy. -- For very large whole genome datasets with over 2M variants overlapping the training data randomly downsample the training set that gets used to build the Gaussian mixture model. -- Annotations are ordered by the difference in means between known and novel instead of by their standard deviation. -- Removed the training set quality score threshold. -- Now uses 2 gaussians by default for the negative model. -- Num bad argument has been removed and the cutoffs are now chosen by the model itself by looking at the LOD scores. -- Model plots are now generated much faster. -- Stricter threshold for determining model convergence. -- All VQSR integration tests change because of these changes to the model. -- Add test for downsampling of training data. --- .../ApplyRecalibration.java | 112 +++++++++----- .../GaussianMixtureModel.java | 2 +- .../variantrecalibration/TrancheManager.java | 6 +- .../VariantDataManager.java | 139 ++++++++---------- .../VariantRecalibrator.java | 62 ++++---- ...VariantRecalibratorArgumentCollection.java | 39 +++-- .../VariantRecalibratorEngine.java | 2 +- .../ApplyRecalibrationUnitTest.java | 68 +++++++++ .../VariantDataManagerUnitTest.java | 82 ++++++++++- ...ntRecalibrationWalkersIntegrationTest.java | 26 ++-- 10 files changed, 364 insertions(+), 174 deletions(-) create mode 100644 protected/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/ApplyRecalibrationUnitTest.java diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/ApplyRecalibration.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/ApplyRecalibration.java index 314efe2a2..4b5237087 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/ApplyRecalibration.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/ApplyRecalibration.java @@ -46,10 +46,8 @@ package org.broadinstitute.sting.gatk.walkers.variantrecalibration; -import org.broadinstitute.sting.commandline.Argument; -import org.broadinstitute.sting.commandline.Input; -import org.broadinstitute.sting.commandline.Output; -import org.broadinstitute.sting.commandline.RodBinding; +import org.apache.commons.math.util.MathUtils; +import org.broadinstitute.sting.commandline.*; import org.broadinstitute.sting.gatk.CommandLineGATK; import org.broadinstitute.sting.gatk.contexts.AlignmentContext; import org.broadinstitute.sting.gatk.contexts.ReferenceContext; @@ -112,6 +110,9 @@ import java.util.*; @PartitionBy(PartitionType.LOCUS) public class ApplyRecalibration extends RodWalker implements TreeReducible { + public static final String LOW_VQSLOD_FILTER_NAME = "LOW_VQSLOD"; + private final double DEFAULT_VQSLOD_CUTOFF = 0.0; + ///////////////////////////// // Inputs ///////////////////////////// @@ -122,7 +123,7 @@ public class ApplyRecalibration extends RodWalker implements T public List> input; @Input(fullName="recal_file", shortName="recalFile", doc="The input recal file used by ApplyRecalibration", required=true) protected RodBinding recal; - @Input(fullName="tranches_file", shortName="tranchesFile", doc="The input tranches file describing where to cut the data", required=true) + @Input(fullName="tranches_file", shortName="tranchesFile", doc="The input tranches file describing where to cut the data", required=false) protected File TRANCHES_FILE; ///////////////////////////// @@ -134,8 +135,13 @@ public class ApplyRecalibration extends RodWalker implements T ///////////////////////////// // Command Line Arguments ///////////////////////////// + @Advanced @Argument(fullName="ts_filter_level", shortName="ts_filter_level", doc="The truth sensitivity level at which to start filtering", required=false) - protected double TS_FILTER_LEVEL = 99.0; + protected Double TS_FILTER_LEVEL = null; + @Advanced + @Argument(fullName="lodCutoff", shortName="lodCutoff", doc="The VQSLOD score below which to start filtering", required=false) + protected Double VQSLOD_CUTOFF = null; + /** * For this to work properly, the -ignoreFilter argument should also be applied to the VariantRecalibration command. */ @@ -160,13 +166,15 @@ public class ApplyRecalibration extends RodWalker implements T //--------------------------------------------------------------------------------------------------------------- public void initialize() { - for ( final Tranche t : Tranche.readTranches(TRANCHES_FILE) ) { - if ( t.ts >= TS_FILTER_LEVEL ) { - tranches.add(t); + if( TS_FILTER_LEVEL != null ) { + for ( final Tranche t : Tranche.readTranches(TRANCHES_FILE) ) { + if ( t.ts >= TS_FILTER_LEVEL ) { + tranches.add(t); + } + logger.info(String.format("Read tranche " + t)); } - logger.info(String.format("Read tranche " + t)); + Collections.reverse(tranches); // this algorithm wants the tranches ordered from best (lowest truth sensitivity) to worst (highest truth sensitivity) } - Collections.reverse(tranches); // this algorithm wants the tranches ordered from best (lowest truth sensitivity) to worst (highest truth sensitivity) for( final RodBinding rod : input ) { inputNames.add( rod.getName() ); @@ -183,19 +191,32 @@ public class ApplyRecalibration extends RodWalker implements T final TreeSet samples = new TreeSet<>(); samples.addAll(SampleUtils.getUniqueSamplesFromRods(getToolkit(), inputNames)); - if( tranches.size() >= 2 ) { - for( int iii = 0; iii < tranches.size() - 1; iii++ ) { - final Tranche t = tranches.get(iii); - hInfo.add(new VCFFilterHeaderLine(t.name, String.format("Truth sensitivity tranche level for " + t.model.toString() + " model at VQS Lod: " + t.minVQSLod + " <= x < " + tranches.get(iii+1).minVQSLod))); + if( TS_FILTER_LEVEL != null ) { + // if the user specifies both ts_filter_level and lodCutoff then throw a user error + if( VQSLOD_CUTOFF != null ) { + throw new UserException("Arguments --ts_filter_level and --lodCutoff are mutually exclusive. Please only specify one option."); } - } - if( tranches.size() >= 1 ) { - hInfo.add(new VCFFilterHeaderLine(tranches.get(0).name + "+", String.format("Truth sensitivity tranche level for " + tranches.get(0).model.toString() + " model at VQS Lod < " + tranches.get(0).minVQSLod))); - } else { - throw new UserException("No tranches were found in the file or were above the truth sensitivity filter level " + TS_FILTER_LEVEL); - } - logger.info("Keeping all variants in tranche " + tranches.get(tranches.size()-1)); + if( tranches.size() >= 2 ) { + for( int iii = 0; iii < tranches.size() - 1; iii++ ) { + final Tranche t = tranches.get(iii); + hInfo.add(new VCFFilterHeaderLine(t.name, String.format("Truth sensitivity tranche level for " + t.model.toString() + " model at VQS Lod: " + t.minVQSLod + " <= x < " + tranches.get(iii+1).minVQSLod))); + } + } + if( tranches.size() >= 1 ) { + hInfo.add(new VCFFilterHeaderLine(tranches.get(0).name + "+", String.format("Truth sensitivity tranche level for " + tranches.get(0).model.toString() + " model at VQS Lod < " + tranches.get(0).minVQSLod))); + } else { + throw new UserException("No tranches were found in the file or were above the truth sensitivity filter level " + TS_FILTER_LEVEL); + } + + logger.info("Keeping all variants in tranche " + tranches.get(tranches.size()-1)); + } else { + if( VQSLOD_CUTOFF == null ) { + VQSLOD_CUTOFF = DEFAULT_VQSLOD_CUTOFF; + } + hInfo.add(new VCFFilterHeaderLine(LOW_VQSLOD_FILTER_NAME, "VQSLOD < " + VQSLOD_CUTOFF)); + logger.info("Keeping all variants with VQSLOD >= " + VQSLOD_CUTOFF); + } final VCFHeader vcfHeader = new VCFHeader(hInfo, samples); vcfWriter.writeHeader(vcfHeader); @@ -245,7 +266,6 @@ public class ApplyRecalibration extends RodWalker implements T } VariantContextBuilder builder = new VariantContextBuilder(vc); - String filterString = null; // Annotate the new record with its VQSLOD and the worst performing annotation builder.attribute(VariantRecalibrator.VQS_LOD_KEY, lod); @@ -255,21 +275,7 @@ public class ApplyRecalibration extends RodWalker implements T if ( recalDatum.hasAttribute(VariantRecalibrator.NEGATIVE_LABEL_KEY)) builder.attribute(VariantRecalibrator.NEGATIVE_LABEL_KEY, true); - for( int i = tranches.size() - 1; i >= 0; i-- ) { - final Tranche tranche = tranches.get(i); - if( lod >= tranche.minVQSLod ) { - if( i == tranches.size() - 1 ) { - filterString = VCFConstants.PASSES_FILTERS_v4; - } else { - filterString = tranche.name; - } - break; - } - } - - if( filterString == null ) { - filterString = tranches.get(0).name+"+"; - } + final String filterString = generateFilterString(lod); if( filterString.equals(VCFConstants.PASSES_FILTERS_v4) ) { builder.passFilters(); @@ -289,6 +295,36 @@ public class ApplyRecalibration extends RodWalker implements T return 1; // This value isn't used for anything } + /** + * Generate the VCF filter string for this record based on the provided lod score + * @param lod non-null double + * @return the String to use as the VCF filter field + */ + protected String generateFilterString( final double lod ) { + String filterString = null; + if( TS_FILTER_LEVEL != null ) { + for( int i = tranches.size() - 1; i >= 0; i-- ) { + final Tranche tranche = tranches.get(i); + if( lod >= tranche.minVQSLod ) { + if( i == tranches.size() - 1 ) { + filterString = VCFConstants.PASSES_FILTERS_v4; + } else { + filterString = tranche.name; + } + break; + } + } + + if( filterString == null ) { + filterString = tranches.get(0).name+"+"; + } + } else { + filterString = ( lod < VQSLOD_CUTOFF ? LOW_VQSLOD_FILTER_NAME : VCFConstants.PASSES_FILTERS_v4 ); + } + + return filterString; + } + private static VariantContext getMatchingRecalVC(final VariantContext target, final List recalVCs) { for( final VariantContext recalVC : recalVCs ) { if ( target.getEnd() == recalVC.getEnd() ) { diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/GaussianMixtureModel.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/GaussianMixtureModel.java index ed26bc17a..9e36e5dbe 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/GaussianMixtureModel.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/GaussianMixtureModel.java @@ -267,7 +267,7 @@ public class GaussianMixtureModel { public double evaluateDatumMarginalized( final VariantDatum datum ) { int numRandomDraws = 0; double sumPVarInGaussian = 0.0; - final int numIterPerMissingAnnotation = 10; // Trade off here between speed of computation and accuracy of the marginalization + final int numIterPerMissingAnnotation = 20; // Trade off here between speed of computation and accuracy of the marginalization final double[] pVarInGaussianLog10 = new double[gaussians.size()]; // for each dimension for( int iii = 0; iii < datum.annotations.length; iii++ ) { diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/TrancheManager.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/TrancheManager.java index 30377b63e..ab6b4adda 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/TrancheManager.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/TrancheManager.java @@ -160,11 +160,11 @@ public class TrancheManager { } } - public static List findTranches( final ArrayList data, final double[] tranches, final SelectionMetric metric, final VariantRecalibratorArgumentCollection.Mode model ) { + public static List findTranches( final List data, final double[] tranches, final SelectionMetric metric, final VariantRecalibratorArgumentCollection.Mode model ) { return findTranches( data, tranches, metric, model, null ); } - public static List findTranches( final ArrayList data, final double[] trancheThresholds, final SelectionMetric metric, final VariantRecalibratorArgumentCollection.Mode model, final File debugFile ) { + public static List findTranches( final List data, final double[] trancheThresholds, final SelectionMetric metric, final VariantRecalibratorArgumentCollection.Mode model, final File debugFile ) { logger.info(String.format("Finding %d tranches for %d variants", trancheThresholds.length, data.size())); Collections.sort( data, new VariantDatum.VariantDatumLODComparator() ); @@ -172,7 +172,7 @@ public class TrancheManager { if ( debugFile != null) { writeTranchesDebuggingInfo(debugFile, data, metric); } - List tranches = new ArrayList(); + List tranches = new ArrayList<>(); for ( double trancheThreshold : trancheThresholds ) { Tranche t = findTranche(data, metric, trancheThreshold, model); diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManager.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManager.java index 65b1c2322..ac4654f73 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManager.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantDataManager.java @@ -71,7 +71,7 @@ import java.util.*; */ public class VariantDataManager { - private ExpandingArrayList data; + private List data; private double[] meanVector; private double[] varianceVector; // this is really the standard deviation public List annotationKeys; @@ -88,30 +88,30 @@ public class VariantDataManager { trainingSets = new ArrayList<>(); } - public void setData( final ExpandingArrayList data ) { + public void setData( final List data ) { this.data = data; } - public ExpandingArrayList getData() { + public List getData() { return data; } public void normalizeData() { boolean foundZeroVarianceAnnotation = false; for( int iii = 0; iii < meanVector.length; iii++ ) { - final double theMean = mean(iii); - final double theSTD = standardDeviation(theMean, iii); + final double theMean = mean(iii, true); + final double theSTD = standardDeviation(theMean, iii, true); logger.info( annotationKeys.get(iii) + String.format(": \t mean = %.2f\t standard deviation = %.2f", theMean, theSTD) ); if( Double.isNaN(theMean) ) { throw new UserException.BadInput("Values for " + annotationKeys.get(iii) + " annotation not detected for ANY training variant in the input callset. VariantAnnotator may be used to add these annotations. See " + HelpConstants.forumPost("discussion/49/using-variant-annotator")); } - foundZeroVarianceAnnotation = foundZeroVarianceAnnotation || (theSTD < 1E-6); + foundZeroVarianceAnnotation = foundZeroVarianceAnnotation || (theSTD < 1E-5); meanVector[iii] = theMean; varianceVector[iii] = theSTD; for( final VariantDatum datum : data ) { // Transform each data point via: (x - mean) / standard deviation - datum.annotations[iii] = ( datum.isNull[iii] ? GenomeAnalysisEngine.getRandomGenerator().nextGaussian() : ( datum.annotations[iii] - theMean ) / theSTD ); + datum.annotations[iii] = ( datum.isNull[iii] ? 0.1 * GenomeAnalysisEngine.getRandomGenerator().nextGaussian() : ( datum.annotations[iii] - theMean ) / theSTD ); } } if( foundZeroVarianceAnnotation ) { @@ -129,7 +129,7 @@ public class VariantDataManager { // re-order the data by increasing standard deviation so that the results don't depend on the order things were specified on the command line // standard deviation over the training points is used as a simple proxy for information content, perhaps there is a better thing to use here - final List theOrder = calculateSortOrder(varianceVector); + final List theOrder = calculateSortOrder(meanVector); annotationKeys = reorderList(annotationKeys, theOrder); varianceVector = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(varianceVector), theOrder)); meanVector = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(meanVector), theOrder)); @@ -137,40 +137,41 @@ public class VariantDataManager { datum.annotations = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(datum.annotations), theOrder)); datum.isNull = ArrayUtils.toPrimitive(reorderArray(ArrayUtils.toObject(datum.isNull), theOrder)); } + logger.info("Annotations are now ordered by their information content: " + annotationKeys.toString()); } /** * Get a list of indices which give the ascending sort order of the data array - * @param data the data to consider + * @param inputVector the data to consider * @return a non-null list of integers with length matching the length of the input array */ - protected List calculateSortOrder(final double[] data) { - final List theOrder = new ArrayList<>(data.length); - final List sortedData = new ArrayList<>(data.length); + protected List calculateSortOrder(final double[] inputVector) { + final List theOrder = new ArrayList<>(inputVector.length); + final List toBeSorted = new ArrayList<>(inputVector.length); int count = 0; - for( final double d : data ) { - sortedData.add(new MyStandardDeviation(d, count++)); + for( int iii = 0; iii < inputVector.length; iii++ ) { + toBeSorted.add(new MyDoubleForSorting(-1.0 * Math.abs(inputVector[iii] - mean(iii, false)), count++)); } - Collections.sort(sortedData); // sort the data in ascending order - for( final MyStandardDeviation d : sortedData ) { + Collections.sort(toBeSorted); + for( final MyDoubleForSorting d : toBeSorted ) { theOrder.add(d.originalIndex); // read off the sort order by looking at the index field } return theOrder; } - // small private class to assist in reading off the new ordering of the standard deviation array - private class MyStandardDeviation implements Comparable { - final Double standardDeviation; + // small private class to assist in reading off the new ordering of the annotation array + private class MyDoubleForSorting implements Comparable { + final Double myData; final int originalIndex; - public MyStandardDeviation( final double standardDeviation, final int originalIndex ) { - this.standardDeviation = standardDeviation; + public MyDoubleForSorting(final double myData, final int originalIndex) { + this.myData = myData; this.originalIndex = originalIndex; } @Override - public int compareTo(final MyStandardDeviation other) { - return standardDeviation.compareTo(other.standardDeviation); + public int compareTo(final MyDoubleForSorting other) { + return myData.compareTo(other.myData); } } @@ -233,92 +234,77 @@ public class VariantDataManager { return false; } - public ExpandingArrayList getTrainingData() { - final ExpandingArrayList trainingData = new ExpandingArrayList<>(); + public List getTrainingData() { + final List trainingData = new ExpandingArrayList<>(); for( final VariantDatum datum : data ) { - if( datum.atTrainingSite && !datum.failingSTDThreshold && datum.originalQual > VRAC.QUAL_THRESHOLD ) { + if( datum.atTrainingSite && !datum.failingSTDThreshold ) { trainingData.add( datum ); } } logger.info( "Training with " + trainingData.size() + " variants after standard deviation thresholding." ); - if( trainingData.size() < VRAC.NUM_BAD_VARIANTS) { + if( trainingData.size() < VRAC.MIN_NUM_BAD_VARIANTS ) { logger.warn( "WARNING: Training with very few variant sites! Please check the model reporting PDF to ensure the quality of the model is reliable." ); + } else if( trainingData.size() > VRAC.MAX_NUM_TRAINING_DATA ) { + logger.warn( "WARNING: Very large training set detected. Downsampling to " + VRAC.MAX_NUM_TRAINING_DATA + " training variants." ); + Collections.shuffle(trainingData); + return trainingData.subList(0, VRAC.MAX_NUM_TRAINING_DATA); } return trainingData; } - public ExpandingArrayList selectWorstVariants( final int minimumNumber ) { - // The return value is the list of training variants - final ExpandingArrayList trainingData = new ExpandingArrayList<>(); + public List selectWorstVariants() { + final List trainingData = new ExpandingArrayList<>(); - // First add to the training list all sites overlapping any bad sites training tracks for( final VariantDatum datum : data ) { - if( datum.atAntiTrainingSite && !datum.failingSTDThreshold && !Double.isInfinite(datum.lod) ) { - trainingData.add( datum ); - } - } - final int numBadSitesAdded = trainingData.size(); - logger.info( "Found " + numBadSitesAdded + " variants overlapping bad sites training tracks." ); - - // Next sort the variants by the LOD coming from the positive model and add to the list the bottom X percent of variants - Collections.sort( data, new VariantDatum.VariantDatumLODComparator() ); - final int numToAdd = minimumNumber - trainingData.size(); - if( numToAdd > data.size() ) { - throw new UserException.BadInput( "Error during negative model training. Minimum number of variants to use in training is larger than the whole call set. You can try lowering the --numBadVariants argument but this is unsafe." ); - } - int index = 0, numAdded = 0; - while( numAdded < numToAdd && index < data.size() ) { - final VariantDatum datum = data.get(index++); - if( datum != null && !datum.atAntiTrainingSite && !datum.failingSTDThreshold && !Double.isInfinite(datum.lod) ) { + if( datum != null && !datum.failingSTDThreshold && !Double.isInfinite(datum.lod) && datum.lod < VRAC.BAD_LOD_CUTOFF ) { datum.atAntiTrainingSite = true; trainingData.add( datum ); - numAdded++; } } - logger.info( "Additionally training with worst " + numToAdd + " scoring variants --> " + (trainingData.size() - numBadSitesAdded) + " variants with LOD <= " + String.format("%.4f", data.get(index).lod) + "." ); + + logger.info( "Training with worst " + trainingData.size() + " scoring variants --> variants with LOD <= " + String.format("%.4f", VRAC.BAD_LOD_CUTOFF) + "." ); + return trainingData; } - public ExpandingArrayList getRandomDataForPlotting( int numToAdd ) { - numToAdd = Math.min(numToAdd, data.size()); - final ExpandingArrayList returnData = new ExpandingArrayList<>(); - // add numToAdd non-anti training sites to plot - for( int iii = 0; iii < numToAdd; iii++) { - final VariantDatum datum = data.get(GenomeAnalysisEngine.getRandomGenerator().nextInt(data.size())); - if( ! datum.atAntiTrainingSite && !datum.failingSTDThreshold ) { - returnData.add(datum); - } - } + public List getEvaluationData() { + final List evaluationData = new ExpandingArrayList<>(); - final int MAX_ANTI_TRAINING_SITES = 10000; - int nAntiTrainingAdded = 0; - // Add all anti-training sites to visual for( final VariantDatum datum : data ) { - if ( nAntiTrainingAdded > MAX_ANTI_TRAINING_SITES ) - break; - else if ( datum.atAntiTrainingSite ) { - returnData.add(datum); - nAntiTrainingAdded++; + if( datum != null && !datum.failingSTDThreshold && !datum.atTrainingSite && !datum.atAntiTrainingSite ) { + evaluationData.add( datum ); } } + return evaluationData; + } + + public List getRandomDataForPlotting( final int numToAdd, final List trainingData, final List antiTrainingData, final List evaluationData ) { + final List returnData = new ExpandingArrayList<>(); + Collections.shuffle(trainingData); + Collections.shuffle(antiTrainingData); + Collections.shuffle(evaluationData); + returnData.addAll(trainingData.subList(0, Math.min(numToAdd, trainingData.size()))); + returnData.addAll(antiTrainingData.subList(0, Math.min(numToAdd, antiTrainingData.size()))); + returnData.addAll(evaluationData.subList(0, Math.min(numToAdd, evaluationData.size()))); + Collections.shuffle(returnData); return returnData; } - private double mean( final int index ) { + protected double mean( final int index, final boolean trainingData ) { double sum = 0.0; int numNonNull = 0; for( final VariantDatum datum : data ) { - if( datum.atTrainingSite && !datum.isNull[index] ) { sum += datum.annotations[index]; numNonNull++; } + if( (trainingData == datum.atTrainingSite) && !datum.isNull[index] ) { sum += datum.annotations[index]; numNonNull++; } } return sum / ((double) numNonNull); } - private double standardDeviation( final double mean, final int index ) { + protected double standardDeviation( final double mean, final int index, final boolean trainingData ) { double sum = 0.0; int numNonNull = 0; for( final VariantDatum datum : data ) { - if( datum.atTrainingSite && !datum.isNull[index] ) { sum += ((datum.annotations[index] - mean)*(datum.annotations[index] - mean)); numNonNull++; } + if( (trainingData == datum.atTrainingSite) && !datum.isNull[index] ) { sum += ((datum.annotations[index] - mean)*(datum.annotations[index] - mean)); numNonNull++; } } return Math.sqrt( sum / ((double) numNonNull) ); } @@ -343,12 +329,9 @@ public class VariantDataManager { try { value = vc.getAttributeAsDouble( annotationKey, Double.NaN ); if( Double.isInfinite(value) ) { value = Double.NaN; } - if( jitter && annotationKey.equalsIgnoreCase("HRUN") ) { // Integer valued annotations must be jittered a bit to work in this GMM - value += -0.25 + 0.5 * GenomeAnalysisEngine.getRandomGenerator().nextDouble(); - } - - if( jitter && annotationKey.equalsIgnoreCase("HaplotypeScore") && MathUtils.compareDoubles(value, 0.0, 0.0001) == 0 ) { value = -0.2 + 0.4*GenomeAnalysisEngine.getRandomGenerator().nextDouble(); } - if( jitter && annotationKey.equalsIgnoreCase("FS") && MathUtils.compareDoubles(value, 0.0, 0.001) == 0 ) { value = -0.2 + 0.4*GenomeAnalysisEngine.getRandomGenerator().nextDouble(); } + if( jitter && annotationKey.equalsIgnoreCase("HaplotypeScore") && MathUtils.compareDoubles(value, 0.0, 0.01) == 0 ) { value += 0.01 * GenomeAnalysisEngine.getRandomGenerator().nextGaussian(); } + if( jitter && annotationKey.equalsIgnoreCase("FS") && MathUtils.compareDoubles(value, 0.0, 0.01) == 0 ) { value += 0.01 * GenomeAnalysisEngine.getRandomGenerator().nextGaussian(); } + if( jitter && annotationKey.equalsIgnoreCase("InbreedingCoeff") && MathUtils.compareDoubles(value, 0.0, 0.01) == 0 ) { value += 0.01 * GenomeAnalysisEngine.getRandomGenerator().nextGaussian(); } } catch( Exception e ) { value = Double.NaN; // The VQSR works with missing data by marginalizing over the missing dimension when evaluating the Gaussian mixture model } diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java index c3f575022..1c32b852b 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrator.java @@ -48,6 +48,7 @@ package org.broadinstitute.sting.gatk.walkers.variantrecalibration; import org.broadinstitute.sting.commandline.*; import org.broadinstitute.sting.gatk.CommandLineGATK; +import org.broadinstitute.sting.gatk.GenomeAnalysisEngine; import org.broadinstitute.sting.gatk.contexts.AlignmentContext; import org.broadinstitute.sting.gatk.contexts.ReferenceContext; import org.broadinstitute.sting.gatk.refdata.RefMetaDataTracker; @@ -142,7 +143,7 @@ public class VariantRecalibrator extends RodWalker replicate = new ArrayList<>(); ///////////////////////////// // Debug Arguments @@ -223,7 +227,7 @@ public class VariantRecalibrator extends RodWalker ignoreInputFilterSet = new TreeSet(); + private final Set ignoreInputFilterSet = new TreeSet<>(); private final VariantRecalibratorEngine engine = new VariantRecalibratorEngine( VRAC ); //--------------------------------------------------------------------------------------------------------------- @@ -232,8 +236,9 @@ public class VariantRecalibrator extends RodWalker(Arrays.asList(USE_ANNOTATIONS)), VRAC ); + dataManager = new VariantDataManager( new ArrayList<>(Arrays.asList(USE_ANNOTATIONS)), VRAC ); if (RSCRIPT_FILE != null && !RScriptExecutor.RSCRIPT_EXISTS) Utils.warnUser(logger, String.format( @@ -262,9 +267,13 @@ public class VariantRecalibrator extends RodWalker hInfo = new HashSet(); + final Set hInfo = new HashSet<>(); ApplyRecalibration.addVQSRStandardHeaderLines(hInfo); recalWriter.writeHeader( new VCFHeader(hInfo) ); + + for( int iii = 0; iii < REPLICATE * 2; iii++ ) { + replicate.add(GenomeAnalysisEngine.getRandomGenerator().nextDouble()); + } } //--------------------------------------------------------------------------------------------------------------- @@ -273,8 +282,9 @@ public class VariantRecalibrator extends RodWalker map( final RefMetaDataTracker tracker, final ReferenceContext ref, final AlignmentContext context ) { - final ExpandingArrayList mapList = new ExpandingArrayList(); + final ExpandingArrayList mapList = new ExpandingArrayList<>(); if( tracker == null ) { // For some reason RodWalkers get map calls with null trackers return mapList; @@ -294,7 +304,7 @@ public class VariantRecalibrator extends RodWalker reduceInit() { - return new ExpandingArrayList(); + return new ExpandingArrayList<>(); } + @Override public ExpandingArrayList reduce( final ExpandingArrayList mapValue, final ExpandingArrayList reduceSum ) { reduceSum.addAll( mapValue ); return reduceSum; } + @Override public ExpandingArrayList treeReduce( final ExpandingArrayList lhs, final ExpandingArrayList rhs ) { rhs.addAll( lhs ); return rhs; @@ -331,21 +344,23 @@ public class VariantRecalibrator extends RodWalker reduceSum ) { dataManager.setData( reduceSum ); dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation // Generate the positive model using the training data and evaluate each variant - final GaussianMixtureModel goodModel = engine.generateModel( dataManager.getTrainingData(), VRAC.MAX_GAUSSIANS ); + final List positiveTrainingData = dataManager.getTrainingData(); + final GaussianMixtureModel goodModel = engine.generateModel( positiveTrainingData, VRAC.MAX_GAUSSIANS ); engine.evaluateData( dataManager.getData(), goodModel, false ); // Generate the negative model using the worst performing data and evaluate each variant contrastively - final ExpandingArrayList negativeTrainingData = dataManager.selectWorstVariants( VRAC.NUM_BAD_VARIANTS ); + final List negativeTrainingData = dataManager.selectWorstVariants(); final GaussianMixtureModel badModel = engine.generateModel( negativeTrainingData, Math.min(VRAC.MAX_GAUSSIANS_FOR_NEGATIVE_MODEL, VRAC.MAX_GAUSSIANS)); engine.evaluateData( dataManager.getData(), badModel, true ); if( badModel.failedToConverge || goodModel.failedToConverge ) { - throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --numBadVariants 3000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example).") ); + throw new UserException("NaN LOD value assigned. Clustering with this few variants and these annotations is unsafe. Please consider " + (badModel.failedToConverge ? "raising the number of variants used to train the negative model (via --minNumBadVariants 5000, for example)." : "lowering the maximum number of Gaussians allowed for use in the model (via --maxGaussians 4, for example).") ); } engine.calculateWorstPerformingAnnotation( dataManager.getData(), goodModel, badModel ); @@ -356,19 +371,11 @@ public class VariantRecalibrator extends RodWalker tranches = TrancheManager.findTranches( dataManager.getData(), TS_TRANCHES, metric, VRAC.MODE ); tranchesStream.print(Tranche.tranchesString( tranches )); - // Find the filtering lodCutoff for display on the model PDFs. Red variants are those which were below the cutoff and filtered out of the final callset. - double lodCutoff = 0.0; - for( final Tranche tranche : tranches ) { - if( MathUtils.compareDoubles(tranche.ts, TS_FILTER_LEVEL, 0.0001) == 0 ) { - lodCutoff = tranche.minVQSLod; - } - } - logger.info( "Writing out recalibration table..." ); dataManager.writeOutRecalibrationTable( recalWriter ); if( RSCRIPT_FILE != null ) { logger.info( "Writing out visualization Rscript file..."); - createVisualizationScript( dataManager.getRandomDataForPlotting( 6000 ), goodModel, badModel, lodCutoff, dataManager.getAnnotationKeys().toArray(new String[USE_ANNOTATIONS.length]) ); + createVisualizationScript( dataManager.getRandomDataForPlotting( 1000, positiveTrainingData, negativeTrainingData, dataManager.getEvaluationData() ), goodModel, badModel, 0.0, dataManager.getAnnotationKeys().toArray(new String[USE_ANNOTATIONS.length]) ); } if(VRAC.MODE == VariantRecalibratorArgumentCollection.Mode.INDEL) { @@ -385,7 +392,7 @@ public class VariantRecalibrator extends RodWalker randomData, final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, final double lodCutoff, final String[] annotationKeys ) { + private void createVisualizationScript( final List randomData, final GaussianMixtureModel goodModel, final GaussianMixtureModel badModel, final double lodCutoff, final String[] annotationKeys ) { PrintStream stream; try { stream = new PrintStream(RSCRIPT_FILE); @@ -409,7 +416,7 @@ public class VariantRecalibrator extends RodWalker fakeData = new ExpandingArrayList(); + final List fakeData = new ExpandingArrayList<>(); double minAnn1 = 100.0, maxAnn1 = -100.0, minAnn2 = 100.0, maxAnn2 = -100.0; for( final VariantDatum datum : randomData ) { minAnn1 = Math.min(minAnn1, datum.annotations[iii]); @@ -418,8 +425,9 @@ public class VariantRecalibrator extends RodWalker(), new VariantRecalibratorArgumentCollection()); - final List order = vdm.calculateSortOrder(data); - Assert.assertArrayEquals(new int[]{2,3,0,1,4,6,5}, ArrayUtils.toPrimitive(order.toArray(new Integer[order.size()]))); + final double passingQual = 400.0; + final VariantRecalibratorArgumentCollection VRAC = new VariantRecalibratorArgumentCollection(); + + VariantDataManager vdm = new VariantDataManager(new ArrayList(), VRAC); + + final List theData = new ArrayList<>(); + final VariantDatum datum1 = new VariantDatum(); + datum1.atTrainingSite = true; + datum1.failingSTDThreshold = false; + datum1.originalQual = passingQual; + datum1.annotations = new double[]{0.0,-10.0,10.0}; + datum1.isNull = new boolean[]{false, false, false}; + theData.add(datum1); + + final VariantDatum datum2 = new VariantDatum(); + datum2.atTrainingSite = true; + datum2.failingSTDThreshold = false; + datum2.originalQual = passingQual; + datum2.annotations = new double[]{0.0,-9.0,15.0}; + datum2.isNull = new boolean[]{false, false, false}; + theData.add(datum2); + + final VariantDatum datum3 = new VariantDatum(); + datum3.atTrainingSite = false; + datum3.failingSTDThreshold = false; + datum3.originalQual = passingQual; + datum3.annotations = new double[]{0.0,1.0,999.0}; + datum3.isNull = new boolean[]{false, false, false}; + theData.add(datum3); + + final VariantDatum datum4 = new VariantDatum(); + datum4.atTrainingSite = false; + datum4.failingSTDThreshold = false; + datum4.originalQual = passingQual; + datum4.annotations = new double[]{0.015,2.0,1001.11}; + datum4.isNull = new boolean[]{false, false, false}; + theData.add(datum4); + + vdm.setData(theData); + + final double[] meanVector = new double[3]; + for( int iii = 0; iii < meanVector.length; iii++ ) { + meanVector[iii] = vdm.mean(iii, true); + } + final List order = vdm.calculateSortOrder(meanVector); + Assert.assertArrayEquals(new int[]{2,1,0}, ArrayUtils.toPrimitive(order.toArray(new Integer[order.size()]))); + } + + @Test + public final void testDownSamplingTrainingData() { + final int MAX_NUM_TRAINING_DATA = 5000; + final double passingQual = 400.0; + final VariantRecalibratorArgumentCollection VRAC = new VariantRecalibratorArgumentCollection(); + VRAC.MAX_NUM_TRAINING_DATA = MAX_NUM_TRAINING_DATA; + + VariantDataManager vdm = new VariantDataManager(new ArrayList(), VRAC); + final List theData = new ArrayList<>(); + for( int iii = 0; iii < MAX_NUM_TRAINING_DATA * 10; iii++) { + final VariantDatum datum = new VariantDatum(); + datum.atTrainingSite = true; + datum.failingSTDThreshold = false; + datum.originalQual = passingQual; + theData.add(datum); + } + + for( int iii = 0; iii < MAX_NUM_TRAINING_DATA * 2; iii++) { + final VariantDatum datum = new VariantDatum(); + datum.atTrainingSite = false; + datum.failingSTDThreshold = false; + datum.originalQual = passingQual; + theData.add(datum); + } + + vdm.setData(theData); + final List trainingData = vdm.getTrainingData(); + + Assert.assertTrue( trainingData.size() == MAX_NUM_TRAINING_DATA ); } } diff --git a/protected/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersIntegrationTest.java b/protected/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersIntegrationTest.java index aab4c8d3d..f3e57b48a 100644 --- a/protected/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersIntegrationTest.java +++ b/protected/java/test/org/broadinstitute/sting/gatk/walkers/variantrecalibration/VariantRecalibrationWalkersIntegrationTest.java @@ -79,9 +79,9 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { } VRTest lowPass = new VRTest(validationDataLocation + "phase1.projectConsensus.chr20.raw.snps.vcf", - "0f4ceeeb8e4a3c89f8591d5e531d8410", // tranches - "c979a102669498ef40dde47ca4133c42", // recal file - "8f60fd849537610b653b321869e94641"); // cut VCF + "6f029dc7d16e63e19c006613cd0a5cff", // tranches + "73c7897441622c9b37376eb4f071c560", // recal file + "11a28df79b92229bd317ac49a3ed0fa1"); // cut VCF @DataProvider(name = "VRTest") public Object[][] createData1() { @@ -126,9 +126,9 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { } VRTest bcfTest = new VRTest(privateTestDir + "vqsr.bcf_test.snps.unfiltered.bcf", - "6539e025997579cd0c7da12219cbc572", // tranches - "778e61f81ab3d468b75f684bef0478e5", // recal file - "21e96b0bb47e2976f53f11181f920e51"); // cut VCF + "3ad7f55fb3b072f373cbce0b32b66df4", // tranches + "e747c08131d58d9a4800720f6ca80e0c", // recal file + "e5808af3af0f2611ba5a3d172ab2557b"); // cut VCF @DataProvider(name = "VRBCFTest") public Object[][] createVRBCFTest() { @@ -178,15 +178,15 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { VRTest indelUnfiltered = new VRTest( validationDataLocation + "combined.phase1.chr20.raw.indels.unfiltered.sites.vcf", // all FILTERs as . - "8906fdae8beca712f5ff2808d35ef02d", // tranches - "07ffea25e04f6ef53079bccb30bd6a7b", // recal file - "8b3ef71cad71e8eb48a856a27ae4f8d5"); // cut VCF + "9a331328370889168a7aa3a625f73620", // tranches + "2cbbd146d68c40200b782e0226f71976", // recal file + "64dd98a5ab80cf5fd9a36eb66b38268e"); // cut VCF VRTest indelFiltered = new VRTest( validationDataLocation + "combined.phase1.chr20.raw.indels.filtered.sites.vcf", // all FILTERs as PASS - "8906fdae8beca712f5ff2808d35ef02d", // tranches - "07ffea25e04f6ef53079bccb30bd6a7b", // recal file - "3d69b280370cdd9611695e4893591306"); // cut VCF + "9a331328370889168a7aa3a625f73620", // tranches + "2cbbd146d68c40200b782e0226f71976", // recal file + "c0ec662001e829f5779a9d13b1d77d80"); // cut VCF @DataProvider(name = "VRIndelTest") public Object[][] createTestVariantRecalibratorIndel() { @@ -242,7 +242,7 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest { " -o %s" + " -tranchesFile " + privateTestDir + "VQSR.mixedTest.tranches" + " -recalFile " + privateTestDir + "VQSR.mixedTest.recal", - Arrays.asList("20c23643a78c5b95abd1526fdab8960d")); + Arrays.asList("03a0ed00af6aac76d39e569f90594a02")); executeTest("testApplyRecalibrationSnpAndIndelTogether", spec); }