By using the AC info field instead of parsing the genotypes, we cut 78% off the runtime of VariantRecalibrator. A new argument (--dontTrustACField) forces the parsing of genotypes when necessary. Various other optimizations throughout.

git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@4383 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
rpoplin 2010-09-29 18:56:50 +00:00
parent 2d1265771f
commit a6c7de95c8
3 changed files with 70 additions and 81 deletions

View File

@ -65,9 +65,10 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
private final double[][] mu; // The means for each cluster
private final Matrix[] sigma; // The covariance matrix for each cluster
private final Matrix[] sigmaInverse;
private final double[][][] sigmaInverse;
private double[] pClusterLog10;
private final double[] determinant;
private final double[] sqrtDeterminantLog10;
private final double stdThreshold;
private double singletonFPRate = -1; // Estimated FP rate for singleton calls. Used to estimate FP rate as a function of AC
@ -78,6 +79,8 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
private final double[] hyperParameter_b;
private final double[] hyperParameter_lambda;
private final double CONSTANT_GAUSSIAN_DENOM_LOG10;
private static final Pattern COMMENT_PATTERN = Pattern.compile("^##.*");
private static final Pattern ANNOTATION_PATTERN = Pattern.compile("^@!ANNOTATION.*");
private static final Pattern CLUSTER_PATTERN = Pattern.compile("^@!CLUSTER.*");
@ -93,10 +96,12 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
sigma = null;
sigmaInverse = null;
determinant = null;
sqrtDeterminantLog10 = null;
stdThreshold = 0;
hyperParameter_a = null;
hyperParameter_b = null;
hyperParameter_lambda = null;
CONSTANT_GAUSSIAN_DENOM_LOG10 = 0.0;
}
public VariantGaussianMixtureModel( final VariantDataManager _dataManager, final int _maxGaussians, final int _maxIterations,
@ -108,6 +113,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
mu = new double[maxGaussians][];
sigma = new Matrix[maxGaussians];
determinant = new double[maxGaussians];
sqrtDeterminantLog10 = null;
pClusterLog10 = new double[maxGaussians];
stdThreshold = _stdThreshold;
FORCE_INDEPENDENT_ANNOTATIONS = _forceIndependent;
@ -116,6 +122,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
hyperParameter_lambda = new double[maxGaussians];
sigmaInverse = null; // This field isn't used during GenerateVariantClusters pass
CONSTANT_GAUSSIAN_DENOM_LOG10 = Math.log10(Math.pow(2.0 * Math.PI, ((double)dataManager.numAnnotations) / 2.0));
SHRINKAGE = _shrinkage;
DIRICHLET_PARAMETER = _dirichlet;
}
@ -149,15 +156,17 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
hyperParameter_a = null;
hyperParameter_b = null;
hyperParameter_lambda = null;
determinant = null;
// BUGBUG: move this parsing out of the constructor
CONSTANT_GAUSSIAN_DENOM_LOG10 = Math.log10(Math.pow(2.0 * Math.PI, ((double)dataManager.numAnnotations) / 2.0));
maxGaussians = clusterLines.size();
mu = new double[maxGaussians][dataManager.numAnnotations];
final double sigmaVals[][][] = new double[maxGaussians][dataManager.numAnnotations][dataManager.numAnnotations];
sigma = new Matrix[maxGaussians];
sigmaInverse = new Matrix[maxGaussians];
sigmaInverse = new double[maxGaussians][dataManager.numAnnotations][dataManager.numAnnotations];
pClusterLog10 = new double[maxGaussians];
determinant = new double[maxGaussians];
sqrtDeterminantLog10 = new double[maxGaussians];
int kkk = 0;
for( final String line : clusterLines ) {
@ -171,8 +180,8 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
}
sigma[kkk] = new Matrix(sigmaVals[kkk]);
sigmaInverse[kkk] = sigma[kkk].inverse(); // Precompute all the inverses and determinants for use later
determinant[kkk] = sigma[kkk].det();
sigmaInverse[kkk] = sigma[kkk].inverse().getArray(); // Precompute all the inverses and determinants for use later
sqrtDeterminantLog10[kkk] = Math.log10(Math.pow(sigma[kkk].det(), 0.5)); // Precompute for use later
kkk++;
}
@ -381,7 +390,7 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
logger.info("Finished iteration " + ttt );
ttt++;
if( Math.abs(currentLikelihood - previousLikelihood) < MIN_PROB_CONVERGENCE) {
if( Math.abs(currentLikelihood - previousLikelihood) < MIN_PROB_CONVERGENCE ) {
logger.info("Convergence!");
break;
}
@ -452,59 +461,6 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
return sum;
}
/**
 * Writes one per-annotation histogram report file comparing the distribution of
 * known vs. novel variants over standardized annotation values.
 * For each annotation key a file named "&lt;outputPrefix&gt;.&lt;annotation&gt;.dat" is
 * created containing CSV rows of (annotationValue, knownDist, novelDist).
 *
 * @param outputPrefix prefix for the per-annotation .dat output file paths
 * @throws UserException.CouldNotCreateOutputFile if an output file cannot be opened
 */
public final void outputClusterReports( final String outputPrefix ) {
// Histogram covers standardized values in [-4, +4] in steps of 0.2 (40 bins).
final double STD_STEP = 0.2;
final double MAX_STD = 4.0;
final double MIN_STD = -4.0;
final int NUM_BINS = (int)Math.floor((Math.abs(MIN_STD) + Math.abs(MAX_STD)) / STD_STEP);
final int numAnnotations = dataManager.numAnnotations;
int totalCountsKnown = 0;
int totalCountsNovel = 0;
// counts[annotation][bin][0] tallies novel variants; [1] tallies known variants
final int counts[][][] = new int[numAnnotations][NUM_BINS][2];
// Explicitly zero the histogram (Java already zero-initializes arrays; kept for clarity).
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
for( int iii = 0; iii < NUM_BINS; iii++ ) {
counts[jjj][iii][0] = 0;
counts[jjj][iii][1] = 0;
}
}
// Bin every datum's (already standardized) annotation values, clamping to the edge bins.
for( final VariantDatum datum : dataManager.data ) {
final int isKnown = ( datum.isKnown ? 1 : 0 );
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
int histBin = (int)Math.round((datum.annotations[jjj]-MIN_STD) * (1.0 / STD_STEP));
if(histBin < 0) { histBin = 0; }
if(histBin > NUM_BINS-1) { histBin = NUM_BINS-1; }
// Range check is always true after clamping above; kept as a defensive guard.
if(histBin >= 0 && histBin <= NUM_BINS-1) {
counts[jjj][histBin][isKnown]++;
}
}
if( isKnown == 1 ) { totalCountsKnown++; }
else { totalCountsNovel++; }
}
// Emit one CSV report per annotation, de-standardizing bin centers back to raw
// annotation units using the stored per-annotation mean and variance.
int annIndex = 0;
for( final String annotation : dataManager.annotationKeys ) {
PrintStream outputFile;
File file = new File(outputPrefix + "." + annotation + ".dat");
try {
outputFile = new PrintStream( file );
} catch (FileNotFoundException e) {
throw new UserException.CouldNotCreateOutputFile( file, e );
}
outputFile.println("annotationValue,knownDist,novelDist");
for( int iii = 0; iii < NUM_BINS; iii++ ) {
// NOTE(review): varianceVector is multiplied (not sqrt'd) to undo standardization —
// presumably it already holds standard deviations; confirm against the standardizer.
final double annotationValue = (((double)iii * STD_STEP)+MIN_STD) * dataManager.varianceVector[annIndex] + dataManager.meanVector[annIndex];
// Column order: index 1 = known fraction, index 0 = novel fraction (see counts layout above).
outputFile.println( annotationValue + "," + ( ((double)counts[annIndex][iii][1])/((double)totalCountsKnown) ) +
"," + ( ((double)counts[annIndex][iii][0])/((double)totalCountsNovel) ));
}
annIndex++;
}
}
public final void outputOptimizationCurve( final VariantDatum[] data, final PrintStream outputReportDatFile, final PrintStream tranchesOutputFile,
final int desiredNumVariants, final Double[] FDRtranches, final double QUAL_STEP ) {
@ -721,9 +677,8 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
final int numAnnotations = annotations.length;
final double evalGaussianPDFLog10[] = new double[maxGaussians];
for( int kkk = 0; kkk < maxGaussians; kkk++ ) {
final double sigmaVals[][] = sigmaInverse[kkk].getArray();
final double sigmaVals[][] = sigmaInverse[kkk];
double sum = 0.0;
for( int jjj = 0; jjj < numAnnotations; jjj++ ) {
double value = 0.0;
@ -733,15 +688,14 @@ public final class VariantGaussianMixtureModel extends VariantOptimizationModel
final double mySigma = sigmaVals[ppp][jjj];
value += (myAnn - myMu) * mySigma;
}
double jNorm = annotations[jjj] - mu[kkk][jjj];
double prod = value * jNorm;
final double jNorm = annotations[jjj] - mu[kkk][jjj];
final double prod = value * jNorm;
sum += prod;
}
final double log10SqrtDet = Math.log10(Math.pow(determinant[kkk], 0.5));
final double denomLog10 = Math.log10(Math.pow(2.0 * Math.PI, ((double)numAnnotations) / 2.0)) + log10SqrtDet;
evalGaussianPDFLog10[kkk] = (( -0.5 * sum ) / Math.log(10.0)) - denomLog10;
double pVar1 = Math.pow(10.0, pClusterLog10[kkk] + evalGaussianPDFLog10[kkk]);
final double denomLog10 = CONSTANT_GAUSSIAN_DENOM_LOG10 + sqrtDeterminantLog10[kkk];
final double evalGaussianPDFLog10 = (( -0.5 * sum ) / Math.log(10.0)) - denomLog10;
final double pVar1 = Math.pow(10.0, pClusterLog10[kkk] + evalGaussianPDFLog10);
pVarInCluster[kkk] = pVar1;
}

View File

@ -38,6 +38,7 @@ import org.broadinstitute.sting.gatk.refdata.utils.helpers.DbSNPHelper;
import org.broadinstitute.sting.gatk.walkers.RodWalker;
import org.broadinstitute.sting.utils.*;
import org.broadinstitute.sting.utils.collections.ExpandingArrayList;
import org.broadinstitute.sting.utils.collections.NestedHashMap;
import org.broadinstitute.sting.utils.exceptions.UserException;
import org.broadinstitute.sting.utils.vcf.VCFUtils;
@ -70,7 +71,7 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
private PrintStream TRANCHES_FILE;
@Output(fullName="report_dat_file", shortName="reportDatFile", doc="The output report .dat file used with Rscript to create the optimization curve PDF file", required=true)
private File REPORT_DAT_FILE;
@Output(doc="File to which recalibrated variants should be written",required=true)
@Output(doc="File to which recalibrated variants should be written", required=true)
private VCFWriter vcfWriter = null;
/////////////////////////////
@ -106,6 +107,8 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
private double SINGLETON_FP_RATE = 0.5;
@Argument(fullName="quality_scale_factor", shortName="qScale", doc="Multiply all final quality scores by this value. Needed to normalize the quality scores.", required=false)
private double QUALITY_SCALE_FACTOR = 100.0;
@Argument(fullName="dontTrustACField", shortName="dontTrustACField", doc="If specified the VR will not use the AC field and will instead always parse the genotypes to figure out how many variant chromosomes there are at a given site.", required=false)
private boolean NEVER_TRUST_AC_FIELD = false;
/////////////////////////////
// Debug Arguments
@ -121,6 +124,8 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
private VariantGaussianMixtureModel theModel = null;
private Set<String> ignoreInputFilterSet = null;
private Set<String> inputNames = new HashSet<String>();
private NestedHashMap priorCache = new NestedHashMap();
private boolean trustACField = false;
//---------------------------------------------------------------------------------------------------------------
//
@ -228,26 +233,56 @@ public class VariantRecalibrator extends RodWalker<ExpandingArrayList<VariantDat
} else if( dbsnp != null ) {
knownPrior_qScore = PRIOR_DBSNP;
}
final double knownPrior = QualityUtils.qualToProb(knownPrior_qScore);
final int alleleCount = vc.getChromosomeCount(vc.getAlternateAllele(0)); // BUGBUG: assumes file has genotypes. Also, what to do about tri-allelic sites?
final double acPrior = theModel.getAlleleCountPrior( alleleCount );
final double totalPrior = 1.0 - ((1.0 - acPrior) * (1.0 - knownPrior));
// If we can trust the AC field then use it instead of parsing all the genotypes. Results in a substantial speed up.
final int alleleCount = (trustACField ? Integer.parseInt(vc.getAttribute("AC",-1).toString()) : vc.getChromosomeCount(vc.getAlternateAllele(0)));
if( !trustACField && !NEVER_TRUST_AC_FIELD ) {
if( !vc.getAttributes().containsKey("AC") ) {
NEVER_TRUST_AC_FIELD = true;
} else {
if( alleleCount == Integer.parseInt( vc.getAttribute("AC").toString()) ) {
// The AC field is correct at this record so we trust it forever
trustACField = true;
} else { // We found a record in which the AC field wasn't correct but we are trying to trust it
throw new UserException.BadInput("AC info field doesn't match the variant chromosome count so we can't trust it! Please run with --dontTrustACField which will force the walker to parse the genotypes for each record, drastically increasing the runtime." +
"First observed at " + vc.getChr() + ":" + vc.getStart());
}
}
}
if( trustACField && alleleCount == -1 ) {
throw new UserException.BadInput("AC info field doesn't exist for all records (although it does for some) so we can't trust it! Please run with --dontTrustACField which will force the walker to parse the genotypes for each record, drastically increasing the runtime." +
"First observed at " + vc.getChr() + ":" + vc.getStart());
}
if( MathUtils.compareDoubles(totalPrior, 1.0, 1E-8) == 0 || MathUtils.compareDoubles(totalPrior, 0.0, 1E-8) == 0 ) {
throw new UserException.CommandLineException("Something is wrong with the prior that was entered by the user: Prior = " + totalPrior); // TODO - fix this up later
final Object[] priorKey = new Object[2];
priorKey[0] = alleleCount;
priorKey[1] = knownPrior_qScore;
Double priorLodFactor = (Double)priorCache.get( priorKey );
// If this prior factor hasn't been calculated before, do so now
if(priorLodFactor == null) {
final double knownPrior = QualityUtils.qualToProb(knownPrior_qScore);
final double acPrior = theModel.getAlleleCountPrior( alleleCount );
final double totalPrior = 1.0 - ((1.0 - acPrior) * (1.0 - knownPrior));
if( MathUtils.compareDoubles(totalPrior, 1.0, 1E-8) == 0 || MathUtils.compareDoubles(totalPrior, 0.0, 1E-8) == 0 ) {
throw new UserException.CommandLineException("Something is wrong with the prior that was entered by the user: Prior = " + totalPrior); // TODO - fix this up later
}
priorLodFactor = Math.log10(totalPrior) - Math.log10(1.0 - totalPrior) - Math.log10(1.0);
priorCache.put( priorLodFactor, false, priorKey );
}
final double pVar = theModel.evaluateVariant( vc );
final double lod = (Math.log10(totalPrior) + Math.log10(pVar)) - ((Math.log10(1.0 - totalPrior)) + Math.log10(1.0));
final double lod = priorLodFactor + Math.log10(pVar);
variantDatum.qual = Math.abs( QUALITY_SCALE_FACTOR * QualityUtils.lodToPhredScaleErrorRate(lod) );
mapList.add( variantDatum );
Map<String, Object> attrs = new HashMap<String, Object>(vc.getAttributes());
attrs.put("OQ", String.format("%.2f", ((Double)vc.getPhredScaledQual())));
final Map<String, Object> attrs = new HashMap<String, Object>(vc.getAttributes());
attrs.put("OQ", String.format("%.2f", vc.getPhredScaledQual()));
attrs.put("LOD", String.format("%.4f", lod));
VariantContext newVC = VariantContext.modifyPErrorFiltersAndAttributes(vc, variantDatum.qual / 10.0, new HashSet<String>(), attrs);

View File

@ -46,7 +46,7 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
public void testVariantRecalibrator() {
HashMap<String, List<String>> e = new HashMap<String, List<String>>();
e.put( validationDataLocation + "yri.trio.gatk_glftrio.intersection.annotated.filtered.chr1.vcf",
Arrays.asList("0c6b5085a678b6312ab4bc8ce7b4eee4", "038c31c5bb46a4df89b8ee69ec740812","7d42bbdfb69fdfb18cbda13a63d92602")); // Each test checks the md5 of three output files
Arrays.asList("e94b02016e6f7936999f02979b801c30", "038c31c5bb46a4df89b8ee69ec740812","7d42bbdfb69fdfb18cbda13a63d92602")); // Each test checks the md5 of three output files
e.put( validationDataLocation + "lowpass.N3.chr1.raw.vcf",
Arrays.asList("bbdffb7fa611f4ae80e919cdf86b9bc6", "661360e85392af9c97e386399871854a","371e5a70a4006420737c5ab259e0e23e")); // Each test checks the md5 of three output files
@ -84,7 +84,7 @@ public class VariantRecalibrationWalkersIntegrationTest extends WalkerTest {
@Test
public void testApplyVariantCuts() {
HashMap<String, String> e = new HashMap<String, String>();
e.put( validationDataLocation + "yri.trio.gatk_glftrio.intersection.annotated.filtered.chr1.vcf", "7429ed494131eb1dab5a9169cf65d1f0" );
e.put( validationDataLocation + "yri.trio.gatk_glftrio.intersection.annotated.filtered.chr1.vcf", "e06aa6b734cc3c881d95cf5ee9315664" );
e.put( validationDataLocation + "lowpass.N3.chr1.raw.vcf", "ad8661cba3b04a7977c97a541fd8a668" );
for ( Map.Entry<String, String> entry : e.entrySet() ) {