diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/GaussianMixtureModel.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/GaussianMixtureModel.java
index 935467f1b..a9f9f95de 100755
--- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/GaussianMixtureModel.java
+++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/GaussianMixtureModel.java
@@ -28,8 +28,9 @@ public class GaussianMixtureModel {
     private final ArrayList gaussians;
     private final double shrinkage;
     private final double dirichletParameter;
+    private final double degreesOfFreedom;
     private final double[] empiricalMu; // BUGBUG: move these to VariantData class
-    private final Matrix empiricalSigma;
+    private final Matrix empiricalSigma; // BUGBUG: move these to VariantData class
     public boolean isModelReadyForEvaluation;
 
     public GaussianMixtureModel( final int numGaussians, final int numAnnotations,
@@ -42,6 +43,7 @@ public class GaussianMixtureModel {
         }
         this.shrinkage = shrinkage;
         this.dirichletParameter = dirichletParameter;
+        degreesOfFreedom = numAnnotations;
         empiricalMu = new double[numAnnotations];
         empiricalSigma = new Matrix(numAnnotations, numAnnotations);
         isModelReadyForEvaluation = false;
@@ -57,27 +59,49 @@ public class GaussianMixtureModel {
         for( int iii = 0; iii < annotationLines.size(); iii++ ) {
             gaussian.mu[iii] = Double.parseDouble(vals[2+iii]); //BUGBUG: recreated here to match the integration tests
             for( int jjj = 0; jjj < annotationLines.size(); jjj++ ) {
-                gaussian.sigma.set(iii, jjj, Double.parseDouble(vals[2+annotationLines.size()+(iii*annotationLines.size())+jjj]) * 1.3); // BUGBUG: VRAC backOff, or get rid of this completely!?
+                gaussian.sigma.set(iii, jjj, 1.3*Double.parseDouble(vals[2+annotationLines.size()+(iii*annotationLines.size())+jjj])); //BUGBUG: VRAC backoff
             }
         }
         gaussians.add( gaussian );
     }
 
-        this.shrinkage = 0.0; // not used when evaluating data, BUGBUG: move this to VariantData class
-        this.dirichletParameter = 0.0; // not used when evaluating data
+        shrinkage = 0.0; // not used when evaluating data, BUGBUG: move this to VariantData class
+        dirichletParameter = 0.0; // not used when evaluating data
+        degreesOfFreedom = 0.0; // not used when evaluating data
         empiricalMu = null; // not used when evaluating data
         empiricalSigma = null; // not used when evaluating data
         isModelReadyForEvaluation = false;
     }
 
     public void cacheEmpiricalStats( final List data ) {
+        //final double[][] tmpSigmaVals = new double[empiricalMu.length][empiricalMu.length];
+        for( int iii = 0; iii < empiricalMu.length; iii++ ) {
+            empiricalMu[iii] = 0.0;
+            //for( int jjj = iii; jjj < empiricalMu.length; jjj++ ) {
+            //    tmpSigmaVals[iii][jjj] = 0.0;
+            //}
+        }
+
         for( final VariantDatum datum : data ) {
-            for( int jjj = 0; jjj < empiricalMu.length; jjj++ ) {
-                empiricalMu[jjj] += datum.annotations[jjj] / ((double) data.size());
+            for( int iii = 0; iii < empiricalMu.length; iii++ ) {
+                empiricalMu[iii] += datum.annotations[iii] / ((double) data.size());
             }
         }
-        empiricalSigma.setMatrix(0, empiricalMu.length - 1, 0, empiricalMu.length - 1, Matrix.identity(empiricalMu.length, empiricalMu.length)); // BUGBUG: why does the identity matrix work best here?
-        // is it because of a bug in the old implementation in which std>X variants were still counted in this calculation??
+
+        /*
+        for( final VariantDatum datum : data ) {
+            for( int iii = 0; iii < empiricalMu.length; iii++ ) {
+                for( int jjj = 0; jjj < empiricalMu.length; jjj++ ) {
+                    tmpSigmaVals[iii][jjj] += (datum.annotations[iii]-empiricalMu[iii]) * (datum.annotations[jjj]-empiricalMu[jjj]);
+                }
+            }
+        }
+
+        empiricalSigma.setMatrix(0, empiricalMu.length - 1, 0, empiricalMu.length - 1, new Matrix(tmpSigmaVals));
+        empiricalSigma.timesEquals( 1.0 / ((double) data.size()) );
+        empiricalSigma.timesEquals( 1.0 / (Math.pow(gaussians.size(), 2.0 / ((double) empiricalMu.length))) );
+        */
+        empiricalSigma.setMatrix(0, empiricalMu.length - 1, 0, empiricalMu.length - 1, Matrix.identity(empiricalMu.length, empiricalMu.length));
     }
 
     public void initializeRandomModel( final List data, final Random rand ) {
@@ -94,7 +118,7 @@ public class GaussianMixtureModel {
         for( final MultivariateGaussian gaussian : gaussians ) {
             gaussian.pMixtureLog10 = Math.log10( 1.0 / ((double) gaussians.size()) );
             gaussian.initializeRandomSigma( rand );
-            gaussian.hyperParameter_a = gaussian.mu.length;
+            gaussian.hyperParameter_a = degreesOfFreedom;
             gaussian.hyperParameter_b = shrinkage;
             gaussian.hyperParameter_lambda = dirichletParameter;
         }
@@ -164,7 +188,7 @@ public class GaussianMixtureModel {
 
     public void maximizationStep( final List data ) {
         for( final MultivariateGaussian gaussian : gaussians ) {
-            gaussian.maximizeGaussian( data, empiricalMu, empiricalSigma, shrinkage, dirichletParameter );
+            gaussian.maximizeGaussian( data, empiricalMu, empiricalSigma, shrinkage, dirichletParameter, degreesOfFreedom );
         }
     }
 
@@ -186,9 +210,9 @@ public class GaussianMixtureModel {
             for(int jjj = 0; jjj < gaussian.mu.length; jjj++ ) {
                 clusterFile.print(String.format(",%.8f", gaussian.mu[jjj]));
             }
-            for(int jjj = 0; jjj < gaussian.mu.length; jjj++ ) {
-                for(int ppp = 0; ppp < gaussian.mu.length; ppp++ ) {
-                    clusterFile.print(String.format(",%.8f", (sigmaVals[jjj][ppp] / gaussian.hyperParameter_a) )); // BUGBUG: this is a bug which should be fixed after passing integration tests
+            for(int iii = 0; iii < gaussian.mu.length; iii++ ) {
+                for(int jjj = 0; jjj < gaussian.mu.length; jjj++ ) {
+                    clusterFile.print(String.format(",%.8f", (sigmaVals[iii][jjj] / gaussian.hyperParameter_a)));
                 }
             }
             clusterFile.println();
diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/MultivariateGaussian.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/MultivariateGaussian.java
index 9a33bec5d..af507d351 100755
--- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/MultivariateGaussian.java
+++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/MultivariateGaussian.java
@@ -18,7 +18,6 @@ public class MultivariateGaussian {
     public double pMixtureLog10;
     final public double[] mu;
     final public Matrix sigma;
-    private double cachedDeterminant;
     public double hyperParameter_a;
     public double hyperParameter_b;
     public double hyperParameter_lambda;
@@ -65,7 +64,6 @@ public class MultivariateGaussian {
         Matrix tmp = new Matrix( randSigma );
         tmp = tmp.times(tmp.transpose());
         sigma.setMatrix(0, mu.length - 1, 0, mu.length - 1, tmp);
-        cachedDeterminant = sigma.det();
     }
 
     public double calculateDistanceFromMeanSquared( final VariantDatum datum ) {
@@ -90,9 +88,7 @@ public class MultivariateGaussian {
 
     public void precomputeDenominatorForEvaluation() {
         cachedSigmaInverse = sigma.inverse();
-        cachedDenomLog10 = -1.0 * ( Math.log10(Math.pow(2.0 * Math.PI, ((double) mu.length) / 2.0)) + Math.log10(Math.pow(sigma.det(), 0.5)) );
-        //BUGBUG: This should be determinant of sigma inverse?
-        //BUGBUG: Denom --> constant factor log10
+        cachedDenomLog10 = Math.log10(Math.pow(2.0 * Math.PI, -1.0 * ((double) mu.length) / 2.0)) + Math.log10(Math.pow(sigma.det(), -0.5));
     }
 
     public void precomputeDenominatorForVariationalBayes( final double sumHyperParameterLambda ) {
@@ -102,7 +98,7 @@ public class MultivariateGaussian {
         for(int jjj = 1; jjj < mu.length; jjj++) {
             sum += MathUtils.diGamma( (hyperParameter_a + 1.0 - jjj) / 2.0 );
         }
-        sum -= Math.log( cachedDeterminant );
+        sum -= Math.log( sigma.det() );
         sum += Math.log(2.0) * mu.length;
         final double gamma = 0.5 * sum;
         final double pi = MathUtils.diGamma( hyperParameter_lambda ) - MathUtils.diGamma( sumHyperParameterLambda );
@@ -135,7 +131,7 @@ public class MultivariateGaussian {
     }
 
     public void maximizeGaussian( final List data, final double[] empiricalMu, final Matrix empiricalSigma,
-                                  final double SHRINKAGE, final double DIRICHLET_PARAMETER ) {
+                                  final double SHRINKAGE, final double DIRICHLET_PARAMETER, final double DEGREES_OF_FREEDOM ) {
         double sumProb = 0.0;
         Matrix wishart = new Matrix(mu.length, mu.length);
         zeroOutMu();
@@ -173,11 +169,10 @@ public class MultivariateGaussian {
 
         sigma.plusEquals( empiricalSigma );
         sigma.plusEquals( wishart );
-        cachedDeterminant = sigma.det();
 
         pMixtureLog10 = sumProb; // will be normalized later by GaussianMixtureModel so no need to do it every iteration
 
-        hyperParameter_a = sumProb + mu.length;
+        hyperParameter_a = sumProb + DEGREES_OF_FREEDOM;
         hyperParameter_b = sumProb + SHRINKAGE;
         hyperParameter_lambda = sumProb + DIRICHLET_PARAMETER;
 
diff --git a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java
index 4f56bc030..415d88f4a 100755
--- a/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java
+++ b/java/src/org/broadinstitute/sting/playground/gatk/walkers/variantrecalibration/VariantRecalibratorEngine.java
@@ -40,7 +40,7 @@ public class VariantRecalibratorEngine {
     }
 
     public GaussianMixtureModel generateModel( final List data ) {
-        final GaussianMixtureModel model = new GaussianMixtureModel( 4, 3, 0.0001, 1000.0 ); //BUGBUG: VRAC.maxGaussians, VRAC.numAnnotations
+        final GaussianMixtureModel model = new GaussianMixtureModel( 4, 3, 0.0001, 1000.0 ); //BUGBUG: VRAC arguments
         variationalBayesExpectationMaximization( model, data );
         return model;
     }
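
Note on the cachedDenomLog10 rewrite in precomputeDenominatorForEvaluation above: the old and new expressions compute the same quantity, the log10 of the multivariate normal normalizing constant 1 / ( (2*pi)^(D/2) * |Sigma|^(1/2) ); the patch only folds the outer negation into the exponents and drops the stale BUGBUG comments. A minimal standalone sketch of the equivalence (the class name and the test values for D and det are hypothetical, not part of the patch):

// Hypothetical check, not part of the patch: both forms of cachedDenomLog10
// evaluate log10( 1 / ( (2*pi)^(D/2) * sqrt(det(Sigma)) ) ).
public class DenomLog10Check {
    public static void main(final String[] args) {
        final int D = 3;          // stands in for mu.length
        final double det = 2.7;   // stands in for sigma.det()

        // Old form: -( log10((2*pi)^(D/2)) + log10(det^0.5) )
        final double oldDenom = -1.0 * (Math.log10(Math.pow(2.0 * Math.PI, ((double) D) / 2.0))
                + Math.log10(Math.pow(det, 0.5)));

        // New form: log10((2*pi)^(-D/2)) + log10(det^-0.5)
        final double newDenom = Math.log10(Math.pow(2.0 * Math.PI, -1.0 * ((double) D) / 2.0))
                + Math.log10(Math.pow(det, -0.5));

        // The two agree to floating-point precision.
        System.out.printf("old = %.12f, new = %.12f%n", oldDenom, newDenom);
    }
}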