By using the AC info field instead of parsing the genotypes we cut 78% off the runtime of VariantRecalibrator. There is a new argument to force the parsing of genotypes if necessary. Various other optimizations throughout.

git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@4383 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
rpoplin 2010-09-29 18:56:50 +00:00
parent 2d1265771f
commit a6c7de95c8
3 changed files with 70 additions and 81 deletions

View File

@@ -65,9 +65,10 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
private final double[][] mu; // The means for each cluster private final double[][] mu; // The means for each cluster
private final Matrix[] sigma; // The covariance matrix for each cluster private final Matrix[] sigma; // The covariance matrix for each cluster
private final Matrix[] sigmaInverse; private final double[][][] sigmaInverse;
private double[] pClusterLog10; private double[] pClusterLog10;
private final double[] determinant; private final double[] determinant;
private final double[] sqrtDeterminantLog10;
private final double stdThreshold; private final double stdThreshold;
private double singletonFPRate = -1; // Estimated FP rate for singleton calls. Used to estimate FP rate as a function of AC private double singletonFPRate = -1; // Estimated FP rate for singleton calls. Used to estimate FP rate as a function of AC
@@ -78,6 +79,8 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
private final double[] hyperParameter_b; private final double[] hyperParameter_b;
private final double[] hyperParameter_lambda; private final double[] hyperParameter_lambda;
private final double CONSTANT_GAUSSIAN_DENOM_LOG10;
private static final Pattern COMMENT_PATTERN = Pattern.compile("^##.*"); private static final Pattern COMMENT_PATTERN = Pattern.compile("^##.*");
private static final Pattern ANNOTATION_PATTERN = Pattern.compile("^@!ANNOTATION.*"); private static final Pattern ANNOTATION_PATTERN = Pattern.compile("^@!ANNOTATION.*");
private static final Pattern CLUSTER_PATTERN = Pattern.compile("^@!CLUSTER.*"); private static final Pattern CLUSTER_PATTERN = Pattern.compile("^@!CLUSTER.*");
@@ -93,10 +96,12 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
sigma = null; sigma = null;
sigmaInverse = null; sigmaInverse = null;
determinant = null; determinant = null;
sqrtDeterminantLog10 = null;
stdThreshold = 0; stdThreshold = 0;
hyperParameter_a = null; hyperParameter_a = null;
hyperParameter_b = null; hyperParameter_b = null;
hyperParameter_lambda = null; hyperParameter_lambda = null;
CONSTANT_GAUSSIAN_DENOM_LOG10 = 0.0;
} }
public VariantGaussianMixtureModel( final VariantDataManager _dataManager, final int _maxGaussians, final int _maxIterations, public VariantGaussianMixtureModel( final VariantDataManager _dataManager, final int _maxGaussians, final int _maxIterations,
@@ -108,6 +113,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
mu = new double[maxGaussians][]; mu = new double[maxGaussians][];
sigma = new Matrix[maxGaussians]; sigma = new Matrix[maxGaussians];
determinant = new double[maxGaussians]; determinant = new double[maxGaussians];
sqrtDeterminantLog10 = null;
pClusterLog10 = new double[maxGaussians]; pClusterLog10 = new double[maxGaussians];
stdThreshold = _stdThreshold; stdThreshold = _stdThreshold;
FORCE_INDEPENDENT_ANNOTATIONS = _forceIndependent; FORCE_INDEPENDENT_ANNOTATIONS = _forceIndependent;
@@ -116,6 +122,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
hyperParameter_lambda = new double[maxGaussians]; hyperParameter_lambda = new double[maxGaussians];
sigmaInverse = null; // This field isn't used during GenerateVariantClusters pass sigmaInverse = null; // This field isn't used during GenerateVariantClusters pass
CONSTANT_GAUSSIAN_DENOM_LOG10 = Math.log10(Math.pow(2.0 * Math.PI, ((double)dataManager.numAnnotations) / 2.0));
SHRINKAGE = _shrinkage; SHRINKAGE = _shrinkage;
DIRICHLET_PARAMETER = _dirichlet; DIRICHLET_PARAMETER = _dirichlet;
} }
@@ -149,15 +156,17 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
hyperParameter_a = null; hyperParameter_a = null;
hyperParameter_b = null; hyperParameter_b = null;
hyperParameter_lambda = null; hyperParameter_lambda = null;
determinant = null;
// BUGBUG: move this parsing out of the constructor // BUGBUG: move this parsing out of the constructor
CONSTANT_GAUSSIAN_DENOM_LOG10 = Math.log10(Math.pow(2.0 * Math.PI, ((double)dataManager.numAnnotations) / 2.0));
maxGaussians = clusterLines.size(); maxGaussians = clusterLines.size();
mu = new double[maxGaussians][dataManager.numAnnotations]; mu = new double[maxGaussians][dataManager.numAnnotations];
final double sigmaVals[][][] = new double[maxGaussians][dataManager.numAnnotations][dataManager.numAnnotations]; final double sigmaVals[][][] = new double[maxGaussians][dataManager.numAnnotations][dataManager.numAnnotations];
sigma = new Matrix[maxGaussians]; sigma = new Matrix[maxGaussians];
sigmaInverse = new Matrix[maxGaussians]; sigmaInverse = new double[maxGaussians][dataManager.numAnnotations][dataManager.numAnnotations];
pClusterLog10 = new double[maxGaussians]; pClusterLog10 = new double[maxGaussians];
determinant = new double[maxGaussians]; sqrtDeterminantLog10 = new double[maxGaussians];
int kkk = 0; int kkk = 0;
for( final String line : clusterLines ) { for( final String line : clusterLines ) {
@@ -171,8 +180,8 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
} }
sigma[kkk] = new Matrix(sigmaVals[kkk]); sigma[kkk] = new Matrix(sigmaVals[kkk]);
sigmaInverse[kkk] = sigma[kkk].inverse(); // Precompute all the inverses and determinants for use later sigmaInverse[kkk] = sigma[kkk].inverse().getArray(); // Precompute all the inverses and determinants for use later
determinant[kkk] = sigma[kkk].det(); sqrtDeterminantLog10[kkk] = Math.log10(Math.pow(sigma[kkk].det(), 0.5)); // Precompute for use later
kkk++; kkk++;
} }
@@ -381,7 +390,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
logger.info("Finished iteration " + ttt ); logger.info("Finished iteration " + ttt );
ttt++; ttt++;
if( Math.abs(currentLikelihood - previousLikelihood) < MIN_PROB_CONVERGENCE) { if( Math.abs(currentLikelihood - previousLikelihood) < MIN_PROB_CONVERGENCE ) {
logger.info("Convergence!"); logger.info("Convergence!");
break; break;
} }
@@ -452,59 +461,6 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
return sum; return sum;
} }
public final void outputClusterReports( final String outputPrefix ) {
final double STD_STEP = 0.2;
final double MAX_STD = 4.0;
final double MIN_STD = -4.0;
final int NUM_BINS = (int)Math.floor((Math.abs(MIN_STD) + Math.abs(MAX_STD)) / STD_STEP);
final int numAnnotations = dataManager.numAnnotations;
int totalCountsKnown = 0;
int totalCountsNovel = 0;
final int counts[][][] = new int[numAnnotations][NUM_BINS][2];
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
for( int iii = 0; iii < NUM_BINS; iii++ ) {
counts[jjj][iii][0] = 0;
counts[jjj][iii][1] = 0;
}
}
for( final VariantDatum datum : dataManager.data ) {
final int isKnown = ( datum.isKnown ? 1 : 0 );
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
int histBin = (int)Math.round((datum.annotations[jjj]-MIN_STD) * (1.0 / STD_STEP));
if(histBin < 0) { histBin = 0; }
if(histBin > NUM_BINS-1) { histBin = NUM_BINS-1; }
if(histBin >= 0 && histBin <= NUM_BINS-1) {
counts[jjj][histBin][isKnown]++;
}
}
if( isKnown == 1 ) { totalCountsKnown++; }
else { totalCountsNovel++; }
}
int annIndex = 0;
for( final String annotation : dataManager.annotationKeys ) {
PrintStream outputFile;
File file = new File(outputPrefix + "." + annotation + ".dat");
try {
outputFile = new PrintStream( file );
} catch (FileNotFoundException e) {
throw new UserException.CouldNotCreateOutputFile( file, e );
}
outputFile.println("annotationValue,knownDist,novelDist");
for( int iii = 0; iii < NUM_BINS; iii++ ) {
final double annotationValue = (((double)iii * STD_STEP)+MIN_STD) * dataManager.varianceVector[annIndex] + dataManager.meanVector[annIndex];
outputFile.println( annotationValue + "," + ( ((double)counts[annIndex][iii][1])/((double)totalCountsKnown) ) +
"," + ( ((double)counts[annIndex][iii][0])/((double)totalCountsNovel) ));
}
annIndex++;
}
}
public final void outputOptimizationCurve( final VariantDatum[] data, final PrintStream outputReportDatFile, final PrintStream tranchesOutputFile, public final void outputOptimizationCurve( final VariantDatum[] data, final PrintStream outputReportDatFile, final PrintStream tranchesOutputFile,
final int desiredNumVariants, final Double[] FDRtranches, final double QUAL_STEP ) { final int desiredNumVariants, final Double[] FDRtranches, final double QUAL_STEP ) {
@@ -721,9 +677,8 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
final int numAnnotations = annotations.length; final int numAnnotations = annotations.length;
final double evalGaussianPDFLog10[] = new double[maxGaussians];
for( int kkk = 0; kkk < maxGaussians; kkk++ ) { for( int kkk = 0; kkk < maxGaussians; kkk++ ) {
final double sigmaVals[][] = sigmaInverse[kkk].getArray(); final double sigmaVals[][] = sigmaInverse[kkk];
double sum = 0.0; double sum = 0.0;
for( int jjj = 0; jjj < numAnnotations; jjj++ ) { for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
double value = 0.0; double value = 0.0;
@@ -733,15 +688,14 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
final double mySigma = sigmaVals[ppp][jjj]; final double mySigma = sigmaVals[ppp][jjj];
value += (myAnn - myMu) * mySigma; value += (myAnn - myMu) * mySigma;
} }
double jNorm = annotations[jjj] - mu[kkk][jjj]; final double jNorm = annotations[jjj] - mu[kkk][jjj];
double prod = value * jNorm; final double prod = value * jNorm;
sum += prod; sum += prod;
} }
final double log10SqrtDet = Math.log10(Math.pow(determinant[kkk], 0.5)); final double denomLog10 = CONSTANT_GAUSSIAN_DENOM_LOG10 + sqrtDeterminantLog10[kkk];
final double denomLog10 = Math.log10(Math.pow(2.0 * Math.PI, ((double)numAnnotations) / 2.0)) + log10SqrtDet; final double evalGaussianPDFLog10 = (( -0.5 * sum ) / Math.log(10.0)) - denomLog10;
evalGaussianPDFLog10[kkk] = (( -0.5 * sum ) / Math.log(10.0)) - denomLog10; final double pVar1 = Math.pow(10.0, pClusterLog10[kkk] + evalGaussianPDFLog10);
double pVar1 = Math.pow(10.0, pClusterLog10[kkk] + evalGaussianPDFLog10[kkk]);
pVarInCluster[kkk] = pVar1; pVarInCluster[kkk] = pVar1;
} }

View File

@@ -38,6 +38,7 @@ import org.broadinstitute.sting.gatk.refdata.utils.helpers.DbSNPHelper;
import org.broadinstitute.sting.gatk.walkers.RodWalker; import org.broadinstitute.sting.gatk.walkers.RodWalker;
import org.broadinstitute.sting.utils.*; import org.broadinstitute.sting.utils.*;
import org.broadinstitute.sting.utils.collections.ExpandingArrayList; import org.broadinstitute.sting.utils.collections.ExpandingArrayList;
import org.broadinstitute.sting.utils.collections.NestedHashMap;
import org.broadinstitute.sting.utils.exceptions.UserException; import org.broadinstitute.sting.utils.exceptions.UserException;
import org.broadinstitute.sting.utils.vcf.VCFUtils; import org.broadinstitute.sting.utils.vcf.VCFUtils;
@@ -70,7 +71,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
private PrintStream TRANCHES_FILE; private PrintStream TRANCHES_FILE;
@Output(fullName="report_dat_file", shortName="reportDatFile", doc="The output report .dat file used with Rscript to create the optimization curve PDF file", required=true) @Output(fullName="report_dat_file", shortName="reportDatFile", doc="The output report .dat file used with Rscript to create the optimization curve PDF file", required=true)
private File REPORT_DAT_FILE; private File REPORT_DAT_FILE;
@Output(doc="File to which recalibrated variants should be written",required=true) @Output(doc="File to which recalibrated variants should be written", required=true)
private VCFWriter vcfWriter = null; private VCFWriter vcfWriter = null;
///////////////////////////// /////////////////////////////
@@ -106,6 +107,8 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
private double SINGLETON_FP_RATE = 0.5; private double SINGLETON_FP_RATE = 0.5;
@Argument(fullName="quality_scale_factor", shortName="qScale", doc="Multiply all final quality scores by this value. Needed to normalize the quality scores.", required=false) @Argument(fullName="quality_scale_factor", shortName="qScale", doc="Multiply all final quality scores by this value. Needed to normalize the quality scores.", required=false)
private double QUALITY_SCALE_FACTOR = 100.0; private double QUALITY_SCALE_FACTOR = 100.0;
@Argument(fullName="dontTrustACField", shortName="dontTrustACField", doc="If specified the VR will not use the AC field and will instead always parse the genotypes to figure out how many variant chromosomes there are at a given site.", required=false)
private boolean NEVER_TRUST_AC_FIELD = false;
///////////////////////////// /////////////////////////////
// Debug Arguments // Debug Arguments
@@ -121,6 +124,8 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
private VariantGaussianMixtureModel theModel = null; private VariantGaussianMixtureModel theModel = null;
private Set<String> ignoreInputFilterSet = null; private Set<String> ignoreInputFilterSet = null;
private Set<String> inputNames = new HashSet<String>(); private Set<String> inputNames = new HashSet<String>();
private NestedHashMap priorCache = new NestedHashMap();
private boolean trustACField = false;
//--------------------------------------------------------------------------------------------------------------- //---------------------------------------------------------------------------------------------------------------
// //
@@ -228,26 +233,56 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
} else if( dbsnp != null ) { } else if( dbsnp != null ) {
knownPrior_qScore = PRIOR_DBSNP; knownPrior_qScore = PRIOR_DBSNP;
} }
final double knownPrior = QualityUtils.qualToProb(knownPrior_qScore);
final int alleleCount = vc.getChromosomeCount(vc.getAlternateAllele(0)); // BUGBUG: assumes file has genotypes. Also, what to do about tri-allelic sites?
final double acPrior = theModel.getAlleleCountPrior( alleleCount ); // If we can trust the AC field then use it instead of parsing all the genotypes. Results in a substantial speed up.
final double totalPrior = 1.0 - ((1.0 - acPrior) * (1.0 - knownPrior)); final int alleleCount = (trustACField ? Integer.parseInt(vc.getAttribute("AC",-1).toString()) : vc.getChromosomeCount(vc.getAlternateAllele(0)));
if( !trustACField && !NEVER_TRUST_AC_FIELD ) {
if( !vc.getAttributes().containsKey("AC") ) {
NEVER_TRUST_AC_FIELD = true;
} else {
if( alleleCount == Integer.parseInt( vc.getAttribute("AC").toString()) ) {
// The AC field is correct at this record so we trust it forever
trustACField = true;
} else { // We found a record in which the AC field wasn't correct but we are trying to trust it
throw new UserException.BadInput("AC info field doesn't match the variant chromosome count so we can't trust it! Please run with --dontTrustACField which will force the walker to parse the genotypes for each record, drastically increasing the runtime." +
"First observed at " + vc.getChr() + ":" + vc.getStart());
}
}
}
if( trustACField && alleleCount == -1 ) {
throw new UserException.BadInput("AC info field doesn't exist for all records (although it does for some) so we can't trust it! Please run with --dontTrustACField which will force the walker to parse the genotypes for each record, drastically increasing the runtime." +
"First observed at " + vc.getChr() + ":" + vc.getStart());
}
if( MathUtils.compareDoubles(totalPrior, 1.0, 1E-8) == 0 || MathUtils.compareDoubles(totalPrior, 0.0, 1E-8) == 0 ) { final Object[] priorKey = new Object[2];
throw new UserException.CommandLineException("Something is wrong with the prior that was entered by the user: Prior = " + totalPrior); // TODO - fix this up later priorKey[0] = alleleCount;
priorKey[1] = knownPrior_qScore;
Double priorLodFactor = (Double)priorCache.get( priorKey );
// If this prior factor hasn't been calculated before, do so now
if(priorLodFactor == null) {
final double knownPrior = QualityUtils.qualToProb(knownPrior_qScore);
final double acPrior = theModel.getAlleleCountPrior( alleleCount );
final double totalPrior = 1.0 - ((1.0 - acPrior) * (1.0 - knownPrior));
if( MathUtils.compareDoubles(totalPrior, 1.0, 1E-8) == 0 || MathUtils.compareDoubles(totalPrior, 0.0, 1E-8) == 0 ) {
throw new UserException.CommandLineException("Something is wrong with the prior that was entered by the user: Prior = " + totalPrior); // TODO - fix this up later
}
priorLodFactor = Math.log10(totalPrior) - Math.log10(1.0 - totalPrior) - Math.log10(1.0);
priorCache.put( priorLodFactor, false, priorKey );
} }
final double pVar = theModel.evaluateVariant( vc ); final double pVar = theModel.evaluateVariant( vc );
final double lod = priorLodFactor + Math.log10(pVar);
final double lod = (Math.log10(totalPrior) + Math.log10(pVar)) - ((Math.log10(1.0 - totalPrior)) + Math.log10(1.0));
variantDatum.qual = Math.abs( QUALITY_SCALE_FACTOR * QualityUtils.lodToPhredScaleErrorRate(lod) ); variantDatum.qual = Math.abs( QUALITY_SCALE_FACTOR * QualityUtils.lodToPhredScaleErrorRate(lod) );
mapList.add( variantDatum ); mapList.add( variantDatum );
final Map<String, Object> attrs = new HashMap<String, Object>(vc.getAttributes());
Map<String, Object> attrs = new HashMap<String, Object>(vc.getAttributes()); attrs.put("OQ", String.format("%.2f", vc.getPhredScaledQual()));
attrs.put("OQ", String.format("%.2f", ((Double)vc.getPhredScaledQual())));
attrs.put("LOD", String.format("%.4f", lod)); attrs.put("LOD", String.format("%.4f", lod));
VariantContext newVC = VariantContext.modifyPErrorFiltersAndAttributes(vc, variantDatum.qual / 10.0, new HashSet<String>(), attrs); VariantContext newVC = VariantContext.modifyPErrorFiltersAndAttributes(vc, variantDatum.qual / 10.0, new HashSet<String>(), attrs);

View File

@@ -46,7 +46,7 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
public void testVariantRecalibrator() { public void testVariantRecalibrator() {
HashMap<String, List<String>> e = new HashMap<String, List<String>>(); HashMap<String, List<String>> e = new HashMap<String, List<String>>();
e.put( validationDataLocation + "yri.trio.gatk_glftrio.intersection.annotated.filtered.chr1.vcf", e.put( validationDataLocation + "yri.trio.gatk_glftrio.intersection.annotated.filtered.chr1.vcf",
Arrays.asList("0c6b5085a678b6312ab4bc8ce7b4eee4", "038c31c5bb46a4df89b8ee69ec740812","7d42bbdfb69fdfb18cbda13a63d92602")); // Each test checks the md5 of three output files Arrays.asList("e94b02016e6f7936999f02979b801c30", "038c31c5bb46a4df89b8ee69ec740812","7d42bbdfb69fdfb18cbda13a63d92602")); // Each test checks the md5 of three output files
e.put( validationDataLocation + "lowpass.N3.chr1.raw.vcf", e.put( validationDataLocation + "lowpass.N3.chr1.raw.vcf",
Arrays.asList("bbdffb7fa611f4ae80e919cdf86b9bc6", "661360e85392af9c97e386399871854a","371e5a70a4006420737c5ab259e0e23e")); // Each test checks the md5 of three output files Arrays.asList("bbdffb7fa611f4ae80e919cdf86b9bc6", "661360e85392af9c97e386399871854a","371e5a70a4006420737c5ab259e0e23e")); // Each test checks the md5 of three output files
@@ -84,7 +84,7 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
@Test @Test
public void testApplyVariantCuts() { public void testApplyVariantCuts() {
HashMap<String, String> e = new HashMap<String, String>(); HashMap<String, String> e = new HashMap<String, String>();
e.put( validationDataLocation + "yri.trio.gatk_glftrio.intersection.annotated.filtered.chr1.vcf", "7429ed494131eb1dab5a9169cf65d1f0" ); e.put( validationDataLocation + "yri.trio.gatk_glftrio.intersection.annotated.filtered.chr1.vcf", "e06aa6b734cc3c881d95cf5ee9315664" );
e.put( validationDataLocation + "lowpass.N3.chr1.raw.vcf", "ad8661cba3b04a7977c97a541fd8a668" ); e.put( validationDataLocation + "lowpass.N3.chr1.raw.vcf", "ad8661cba3b04a7977c97a541fd8a668" );
for ( Map.Entry<String, String> entry : e.entrySet() ) { for ( Map.Entry<String, String> entry : e.entrySet() ) {