Miscellaneous cleanup in VQSR (Variant Quality Score Recalibration)

git-svn-id: file:///humgen/gsa-scr1/gsa-engineering/svn_contents/trunk@5732 348d0f76-0448-11de-a6fe-93d51630548a
This commit is contained in:
rpoplin 2011-05-03 18:00:22 +00:00
parent f3bd11a02e
commit 6323fb8673
5 changed files with 14 additions and 23 deletions

View File

@ -210,7 +210,7 @@ public class ContrastiveRecalibrator extends RodWalker<ExpandingArrayList<Varian
public void onTraversalDone( final ExpandingArrayList<VariantDatum> reduceSum ) { public void onTraversalDone( final ExpandingArrayList<VariantDatum> reduceSum ) {
dataManager.setData( reduceSum ); dataManager.setData( reduceSum );
dataManager.normalizeData(); dataManager.normalizeData(); // Each data point is now (x - mean) / standard deviation
final GaussianMixtureModel goodModel = engine.generateModel( dataManager.getTrainingData() ); final GaussianMixtureModel goodModel = engine.generateModel( dataManager.getTrainingData() );
engine.evaluateData( dataManager.getData(), goodModel, false ); engine.evaluateData( dataManager.getData(), goodModel, false );
final GaussianMixtureModel badModel = engine.generateModel( dataManager.selectWorstVariants( VRAC.PERCENT_BAD_VARIANTS ) ); final GaussianMixtureModel badModel = engine.generateModel( dataManager.selectWorstVariants( VRAC.PERCENT_BAD_VARIANTS ) );
@ -250,12 +250,10 @@ public class ContrastiveRecalibrator extends RodWalker<ExpandingArrayList<Varian
createArrangeFunction( stream ); createArrangeFunction( stream );
stream.println("pdf(\"" + RSCRIPT_FILE + ".pdf\")"); stream.println("pdf(\"" + RSCRIPT_FILE + ".pdf\")"); // Unfortunately this is a huge pdf file, BUGBUG: need to work on reducing the file size
for(int iii = 0; iii < USE_ANNOTATIONS.length; iii++) { for(int iii = 0; iii < USE_ANNOTATIONS.length; iii++) {
for( int jjj = iii + 1; jjj < USE_ANNOTATIONS.length; jjj++) { for( int jjj = iii + 1; jjj < USE_ANNOTATIONS.length; jjj++) {
//stream.println("png(\"" + RSCRIPT_FILE + "." + USE_ANNOTATIONS[iii] + "." + USE_ANNOTATIONS[jjj] + ".png\", type=\"cairo\", width = 960, height = 960)");
//stream.println("pdf(\"" + RSCRIPT_FILE + "." + USE_ANNOTATIONS[iii] + "." + USE_ANNOTATIONS[jjj] + ".pdf\")");
logger.info( "Building " + USE_ANNOTATIONS[iii] + " x " + USE_ANNOTATIONS[jjj] + " plot..."); logger.info( "Building " + USE_ANNOTATIONS[iii] + " x " + USE_ANNOTATIONS[jjj] + " plot...");
final ExpandingArrayList<VariantDatum> fakeData = new ExpandingArrayList<VariantDatum>(); final ExpandingArrayList<VariantDatum> fakeData = new ExpandingArrayList<VariantDatum>();
@ -266,6 +264,7 @@ public class ContrastiveRecalibrator extends RodWalker<ExpandingArrayList<Varian
minAnn2 = Math.min(minAnn2, datum.annotations[jjj]); minAnn2 = Math.min(minAnn2, datum.annotations[jjj]);
maxAnn2 = Math.max(maxAnn2, datum.annotations[jjj]); maxAnn2 = Math.max(maxAnn2, datum.annotations[jjj]);
} }
// Create a fake set of data which spans the full extent of these two annotation dimensions in order to calculate the model PDF projected to 2D
for(double ann1 = minAnn1; ann1 <= maxAnn1; ann1+=0.1) { for(double ann1 = minAnn1; ann1 <= maxAnn1; ann1+=0.1) {
for(double ann2 = minAnn2; ann2 <= maxAnn2; ann2+=0.1) { for(double ann2 = minAnn2; ann2 <= maxAnn2; ann2+=0.1) {
final VariantDatum datum = new VariantDatum(); final VariantDatum datum = new VariantDatum();

View File

@ -42,9 +42,6 @@ public class GaussianMixtureModel {
empiricalMu = new double[numAnnotations]; empiricalMu = new double[numAnnotations];
empiricalSigma = new Matrix(numAnnotations, numAnnotations); empiricalSigma = new Matrix(numAnnotations, numAnnotations);
isModelReadyForEvaluation = false; isModelReadyForEvaluation = false;
}
public void cacheEmpiricalStats() {
Arrays.fill(empiricalMu, 0.0); Arrays.fill(empiricalMu, 0.0);
empiricalSigma.setMatrix(0, empiricalMu.length - 1, 0, empiricalMu.length - 1, Matrix.identity(empiricalMu.length, empiricalMu.length).times(200.0).inverse()); empiricalSigma.setMatrix(0, empiricalMu.length - 1, 0, empiricalMu.length - 1, Matrix.identity(empiricalMu.length, empiricalMu.length).times(200.0).inverse());
} }
@ -75,6 +72,7 @@ public class GaussianMixtureModel {
int ttt = 0; int ttt = 0;
while( ttt++ < numIterations ) { while( ttt++ < numIterations ) {
// Estep: assign each variant to the nearest cluster
for( final VariantDatum datum : data ) { for( final VariantDatum datum : data ) {
double minDistance = Double.MAX_VALUE; double minDistance = Double.MAX_VALUE;
MultivariateGaussian minGaussian = null; MultivariateGaussian minGaussian = null;
@ -89,6 +87,7 @@ public class GaussianMixtureModel {
datum.assignment = minGaussian; datum.assignment = minGaussian;
} }
// Mstep: update gaussian means based on assigned variants
for( final MultivariateGaussian gaussian : gaussians ) { for( final MultivariateGaussian gaussian : gaussians ) {
gaussian.zeroOutMu(); gaussian.zeroOutMu();
int numAssigned = 0; int numAssigned = 0;
@ -188,7 +187,7 @@ public class GaussianMixtureModel {
for( final MultivariateGaussian gaussian : gaussians ) { for( final MultivariateGaussian gaussian : gaussians ) {
pVarInGaussianLog10[gaussianIndex++] = gaussian.pMixtureLog10 + gaussian.evaluateDatumLog10( datum ); pVarInGaussianLog10[gaussianIndex++] = gaussian.pMixtureLog10 + gaussian.evaluateDatumLog10( datum );
} }
return MathUtils.log10sumLog10(pVarInGaussianLog10); return MathUtils.log10sumLog10(pVarInGaussianLog10); // Sum(pi_k * p(v|n,k))
} }
public double evaluateDatumMarginalized( final VariantDatum datum ) { public double evaluateDatumMarginalized( final VariantDatum datum ) {
@ -197,6 +196,7 @@ public class GaussianMixtureModel {
int numIter = 10; int numIter = 10;
final double[] pVarInGaussianLog10 = new double[gaussians.size()]; final double[] pVarInGaussianLog10 = new double[gaussians.size()];
for( int iii = 0; iii < datum.annotations.length; iii++ ) { for( int iii = 0; iii < datum.annotations.length; iii++ ) {
// marginalize over the missing dimension by drawing X random values for the missing annotation and averaging the lod
if( datum.isNull[iii] ) { if( datum.isNull[iii] ) {
for( int ttt = 0; ttt < numIter; ttt++ ) { for( int ttt = 0; ttt < numIter; ttt++ ) {
datum.annotations[iii] = Normal.staticNextDouble(0.0, 1.0); datum.annotations[iii] = Normal.staticNextDouble(0.0, 1.0);

View File

@ -94,6 +94,7 @@ public class MultivariateGaussian {
} }
public void precomputeDenominatorForVariationalBayes( final double sumHyperParameterLambda ) { public void precomputeDenominatorForVariationalBayes( final double sumHyperParameterLambda ) {
// Variational Bayes calculations from Bishop
cachedSigmaInverse = sigma.inverse(); cachedSigmaInverse = sigma.inverse();
cachedSigmaInverse.timesEquals( hyperParameter_a ); cachedSigmaInverse.timesEquals( hyperParameter_a );
double sum = 0.0; double sum = 0.0;

View File

@ -73,7 +73,7 @@ public class VariantDataManager {
throw new UserException.BadInput( "Found annotations with zero variance. They must be excluded before proceeding." ); throw new UserException.BadInput( "Found annotations with zero variance. They must be excluded before proceeding." );
} }
// trim data by standard deviation threshold and place into two sets: data and failingData // trim data by standard deviation threshold and mark failing data for exclusion later
for( final VariantDatum datum : data ) { for( final VariantDatum datum : data ) {
boolean remove = false; boolean remove = false;
for( final double val : datum.annotations ) { for( final double val : datum.annotations ) {

View File

@ -29,7 +29,6 @@ public class VariantRecalibratorEngine {
public VariantRecalibratorEngine( final VariantRecalibratorArgumentCollection VRAC ) { public VariantRecalibratorEngine( final VariantRecalibratorArgumentCollection VRAC ) {
this.VRAC = VRAC; this.VRAC = VRAC;
initialize( this.VRAC );
} }
public GaussianMixtureModel generateModel( final List<VariantDatum> data ) { public GaussianMixtureModel generateModel( final List<VariantDatum> data ) {
@ -50,36 +49,28 @@ public class VariantRecalibratorEngine {
} }
} }
/////////////////////////////
// Private Methods used for initialization
/////////////////////////////
private void initialize( final VariantRecalibratorArgumentCollection VRAC ) {
}
///////////////////////////// /////////////////////////////
// Private Methods used for generating a GaussianMixtureModel // Private Methods used for generating a GaussianMixtureModel
///////////////////////////// /////////////////////////////
private void variationalBayesExpectationMaximization( final GaussianMixtureModel model, final List<VariantDatum> data ) { private void variationalBayesExpectationMaximization( final GaussianMixtureModel model, final List<VariantDatum> data ) {
model.cacheEmpiricalStats();
model.initializeRandomModel( data, VRAC.NUM_KMEANS_ITERATIONS ); model.initializeRandomModel( data, VRAC.NUM_KMEANS_ITERATIONS );
// The VBEM loop // The VBEM loop
model.normalizePMixtureLog10(); model.normalizePMixtureLog10();
model.expectationStep( data ); model.expectationStep( data );
double currentLikelihood; double currentChangeInMixtureCoefficients;
int iteration = 0; int iteration = 0;
logger.info("Finished iteration " + iteration ); logger.info("Finished iteration " + iteration );
while( iteration < VRAC.MAX_ITERATIONS ) { while( iteration < VRAC.MAX_ITERATIONS ) {
iteration++; iteration++;
model.maximizationStep( data ); model.maximizationStep( data );
currentLikelihood = model.normalizePMixtureLog10(); currentChangeInMixtureCoefficients = model.normalizePMixtureLog10();
model.expectationStep( data ); model.expectationStep(data);
logger.info("Current change in mixture coefficients = " + String.format("%.5f", currentLikelihood)); logger.info("Current change in mixture coefficients = " + String.format("%.5f", currentChangeInMixtureCoefficients));
logger.info("Finished iteration " + iteration ); logger.info("Finished iteration " + iteration );
if( iteration > 2 && currentLikelihood < MIN_PROB_CONVERGENCE ) { if( iteration > 2 && currentChangeInMixtureCoefficients < MIN_PROB_CONVERGENCE ) {
logger.info("Convergence!"); logger.info("Convergence!");
break; break;
} }