From 0b37d44b0d3e41bde2078b5614a995cf7269a160 Mon Sep 17 00:00:00 2001 From: Eric Banks Date: Tue, 3 Jul 2012 13:05:11 -0400 Subject: [PATCH] Optimizations for the RecalDatum to make BQSR (Count Covariates) much faster. Needs some cleanup. --- .../sting/gatk/walkers/bqsr/BQSRGatherer.java | 18 +++-- .../sting/gatk/walkers/bqsr/Datum.java | 5 ++ .../sting/gatk/walkers/bqsr/RecalDatum.java | 70 +++++++++---------- .../walkers/bqsr/RecalibrationReport.java | 19 +++-- .../BaseRecalibrationUnitTest.java | 34 ++------- 5 files changed, 61 insertions(+), 85 deletions(-) diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BQSRGatherer.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BQSRGatherer.java index 01fa92b8c..122958ac2 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BQSRGatherer.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BQSRGatherer.java @@ -48,7 +48,7 @@ public class BQSRGatherer extends Gatherer { @Override public void gather(List inputs, File output) { RecalibrationReport generalReport = null; - PrintStream outputFile; + final PrintStream outputFile; try { outputFile = new PrintStream(output); } catch(FileNotFoundException e) { @@ -56,7 +56,7 @@ public class BQSRGatherer extends Gatherer { } for (File input : inputs) { - RecalibrationReport inputReport = new RecalibrationReport(input); + final RecalibrationReport inputReport = new RecalibrationReport(input); if (generalReport == null) generalReport = inputReport; else @@ -65,19 +65,17 @@ public class BQSRGatherer extends Gatherer { if (generalReport == null) throw new ReviewedStingException(EMPTY_INPUT_LIST); - generalReport.calculateEmpiricalAndQuantizedQualities(); + generalReport.calculateQuantizedQualities(); RecalibrationArgumentCollection RAC = generalReport.getRAC(); if (RAC.recalibrationReport != null && !RAC.NO_PLOTS) { - File recal_out = new File(output.getName() + ".original"); - RecalibrationReport originalReport = new RecalibrationReport(RAC.recalibrationReport); - // TODO -- fix me - //RecalDataManager.generateRecalibrationPlot(recal_out, originalReport.getKeysAndTablesMap(), generalReport.getKeysAndTablesMap(), RAC.KEEP_INTERMEDIATE_FILES); + final File recal_out = new File(output.getName() + ".original"); + final RecalibrationReport originalReport = new RecalibrationReport(RAC.recalibrationReport); + RecalDataManager.generateRecalibrationPlot(recal_out, originalReport.getRecalibrationTables(), generalReport.getRecalibrationTables(), generalReport.getCovariates(), RAC.KEEP_INTERMEDIATE_FILES); } else if (!RAC.NO_PLOTS) { - File recal_out = new File(output.getName() + ".recal"); - // TODO -- fix me - //RecalDataManager.generateRecalibrationPlot(recal_out, generalReport.getKeysAndTablesMap(), RAC.KEEP_INTERMEDIATE_FILES); + final File recal_out = new File(output.getName() + ".recal"); + RecalDataManager.generateRecalibrationPlot(recal_out, generalReport.getRecalibrationTables(), generalReport.getCovariates(), RAC.KEEP_INTERMEDIATE_FILES); } generalReport.output(outputFile); diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/Datum.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/Datum.java index 779500512..1ebf3941b 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/Datum.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/Datum.java @@ -71,6 +71,11 @@ public class Datum { numMismatches += incMismatches; } + synchronized void increment(final boolean isError) { + numObservations++; + numMismatches += isError ? 1:0; + } + //--------------------------------------------------------------------------------------------------------------- // // methods to derive empirical quality score diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/RecalDatum.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/RecalDatum.java index b26912c31..6dc5e70d1 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/RecalDatum.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/RecalDatum.java @@ -25,7 +25,10 @@ package org.broadinstitute.sting.gatk.walkers.bqsr; * OTHER DEALINGS IN THE SOFTWARE. */ +import com.google.java.contract.Ensures; +import com.google.java.contract.Requires; import org.broadinstitute.sting.utils.MathUtils; +import org.broadinstitute.sting.utils.QualityUtils; import java.util.Random; @@ -39,6 +42,8 @@ import java.util.Random; public class RecalDatum extends Datum { + private static final double UNINITIALIZED = -1.0; + private double estimatedQReported; // estimated reported quality score based on combined data's individual q-reporteds and number of observations private double empiricalQuality; // the empirical quality for datums that have been collapsed together (by read group and reported quality, for example) @@ -49,18 +54,10 @@ public class RecalDatum extends Datum { // //--------------------------------------------------------------------------------------------------------------- - public RecalDatum() { - numObservations = 0L; - numMismatches = 0L; - estimatedQReported = 0.0; - empiricalQuality = -1.0; - } - - public RecalDatum(final long _numObservations, final long _numMismatches, final double _estimatedQReported, final double _empiricalQuality) { + public RecalDatum(final long _numObservations, final long _numMismatches, final byte reportedQuality) { numObservations = _numObservations; numMismatches = _numMismatches; - estimatedQReported = _estimatedQReported; - empiricalQuality = _empiricalQuality; + estimatedQReported = QualityUtils.qualToErrorProb(reportedQuality); } public RecalDatum(final RecalDatum copy) { @@ -72,36 +69,39 @@ public class RecalDatum extends Datum { public void combine(final RecalDatum other) { final double sumErrors = this.calcExpectedErrors() + other.calcExpectedErrors(); - this.increment(other.numObservations, other.numMismatches); - this.estimatedQReported = -10 * Math.log10(sumErrors / this.numObservations); - this.empiricalQuality = -1.0; // reset the empirical quality calculation so we never have a wrongly calculated empirical quality stored + increment(other.numObservations, other.numMismatches); + estimatedQReported = -10 * Math.log10(sumErrors / this.numObservations); + empiricalQuality = UNINITIALIZED; } - public final void calcCombinedEmpiricalQuality() { - this.empiricalQuality = empiricalQualDouble(); // cache the value so we don't call log over and over again + @Override + public void increment(final boolean isError) { + super.increment(isError); + empiricalQuality = UNINITIALIZED; } - - public final void calcEstimatedReportedQuality() { - this.estimatedQReported = -10 * Math.log10(calcExpectedErrors() / numObservations); + + @Requires("empiricalQuality == UNINITIALIZED") + @Ensures("empiricalQuality != UNINITIALIZED") + protected final void calcEmpiricalQuality() { + empiricalQuality = empiricalQualDouble(); // cache the value so we don't call log over and over again + } + + public void setEstimatedQReported(final double estimatedQReported) { + this.estimatedQReported = estimatedQReported; } public final double getEstimatedQReported() { return estimatedQReported; } - public final double getEmpiricalQuality() { - if (empiricalQuality < 0) - calcCombinedEmpiricalQuality(); - return empiricalQuality; + public void setEmpiricalQuality(final double empiricalQuality) { + this.empiricalQuality = empiricalQuality; } - /** - * Makes a hard copy of the recal datum element - * - * @return a new recal datum object with the same contents of this datum. - */ - public RecalDatum copy() { - return new RecalDatum(numObservations, numMismatches, estimatedQReported, empiricalQuality); + public final double getEmpiricalQuality() { + if (empiricalQuality == UNINITIALIZED) + calcEmpiricalQuality(); + return empiricalQuality; } @Override @@ -122,13 +122,11 @@ public class RecalDatum extends Datum { } public static RecalDatum createRandomRecalDatum(int maxObservations, int maxErrors) { - Random random = new Random(); - int nObservations = random.nextInt(maxObservations); - int nErrors = random.nextInt(maxErrors); - Datum datum = new Datum(nObservations, nErrors); - double empiricalQuality = datum.empiricalQualDouble(); - double estimatedQReported = empiricalQuality + ((10 * random.nextDouble()) - 5); // empirical quality +/- 5. - return new RecalDatum(nObservations, nErrors, estimatedQReported, empiricalQuality); + final Random random = new Random(); + final int nObservations = random.nextInt(maxObservations); + final int nErrors = random.nextInt(maxErrors); + final int qual = random.nextInt(QualityUtils.MAX_QUAL_SCORE); + return new RecalDatum(nObservations, nErrors, (byte)qual); } /** diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/RecalibrationReport.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/RecalibrationReport.java index 05e24e98a..c49ff8b47 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/RecalibrationReport.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/RecalibrationReport.java @@ -57,7 +57,6 @@ public class RecalibrationReport { for (Covariate cov : requestedCovariates) cov.initialize(RAC); // initialize any covariate member variables using the shared argument collection - // TODO -- note that we might be able to save memory (esp. for sparse tables) by making one pass through the GATK report to see the maximum values for each covariate and using that in the constructor here recalibrationTables = new RecalibrationTables(requestedCovariates, countReadGroups(report.getTable(RecalDataManager.READGROUP_REPORT_TABLE_TITLE))); parseReadGroupTable(report.getTable(RecalDataManager.READGROUP_REPORT_TABLE_TITLE), recalibrationTables.getTable(RecalibrationTables.TableType.READ_GROUP_TABLE)); @@ -202,7 +201,10 @@ public class RecalibrationReport { (Double) reportTable.get(row, RecalDataManager.ESTIMATED_Q_REPORTED_COLUMN_NAME) : // we get it if we are in the read group table Byte.parseByte((String) reportTable.get(row, RecalDataManager.QUALITY_SCORE_COLUMN_NAME)); // or we use the reported quality if we are in any other table - return new RecalDatum(nObservations, nErrors, estimatedQReported, empiricalQuality); + final RecalDatum datum = new RecalDatum(nObservations, nErrors, (byte)1); + datum.setEstimatedQReported(estimatedQReported); + datum.setEmpiricalQuality(empiricalQuality); + return datum; } /** @@ -297,14 +299,7 @@ public class RecalibrationReport { * this functionality avoids recalculating the empirical qualities, estimated reported quality * and quantization of the quality scores during every call of combine(). Very useful for the BQSRGatherer. */ - public void calculateEmpiricalAndQuantizedQualities() { - for (RecalibrationTables.TableType type : RecalibrationTables.TableType.values()) { - final NestedIntegerArray table = recalibrationTables.getTable(type); - for (final Object value : table.getAllValues()) { - ((RecalDatum)value).calcCombinedEmpiricalQuality(); - } - } - + public void calculateQuantizedQualities() { quantizationInfo = new QuantizationInfo(recalibrationTables, RAC.QUANTIZING_LEVELS); } @@ -315,4 +310,8 @@ public class RecalibrationReport { public RecalibrationArgumentCollection getRAC() { return RAC; } + + public Covariate[] getCovariates() { + return requestedCovariates; + } } diff --git a/public/java/test/org/broadinstitute/sting/utils/recalibration/BaseRecalibrationUnitTest.java b/public/java/test/org/broadinstitute/sting/utils/recalibration/BaseRecalibrationUnitTest.java index df4c351d6..74d9420b2 100644 --- a/public/java/test/org/broadinstitute/sting/utils/recalibration/BaseRecalibrationUnitTest.java +++ b/public/java/test/org/broadinstitute/sting/utils/recalibration/BaseRecalibrationUnitTest.java @@ -82,15 +82,15 @@ public class BaseRecalibrationUnitTest { final Object[] objKey = buildObjectKey(bitKeys); Random random = new Random(); - int nObservations = random.nextInt(10000); - int nErrors = random.nextInt(10); - double estimatedQReported = 30; + final int nObservations = random.nextInt(10000); + final int nErrors = random.nextInt(10); + final byte estimatedQReported = 30; double empiricalQuality = calcEmpiricalQual(nObservations, nErrors); org.broadinstitute.sting.gatk.walkers.recalibration.RecalDatum oldDatum = new org.broadinstitute.sting.gatk.walkers.recalibration.RecalDatum(nObservations, nErrors, estimatedQReported, empiricalQuality); dataManager.addToAllTables(objKey, oldDatum, QualityUtils.MIN_USABLE_Q_SCORE); - RecalDatum newDatum = new RecalDatum(nObservations, nErrors, estimatedQReported, empiricalQuality); + RecalDatum newDatum = new RecalDatum(nObservations, nErrors, estimatedQReported); rgTable.put(newDatum, bitKeys[0], EventType.BASE_SUBSTITUTION.index); qualTable.put(newDatum, bitKeys[0], bitKeys[1], EventType.BASE_SUBSTITUTION.index); @@ -100,7 +100,7 @@ public class BaseRecalibrationUnitTest { } } - dataManager.generateEmpiricalQualities(1, QualityUtils.MAX_RECALIBRATED_Q_SCORE); + dataManager.generateEmpiricalQualities(1, QualityUtils.MAX_RECALIBRATED_Q_SCORE); List quantizedQuals = new ArrayList(); List qualCounts = new ArrayList(); @@ -135,30 +135,6 @@ public class BaseRecalibrationUnitTest { return key; } - private static void printNestedHashMap(Map table, String output) { - for (Object key : table.keySet()) { - String ret; - if (output.isEmpty()) - ret = "" + key; - else - ret = output + "," + key; - - Object next = table.get(key); - if (next instanceof org.broadinstitute.sting.gatk.walkers.recalibration.RecalDatum) - System.out.println(ret + " => " + next); - else - printNestedHashMap((Map) next, "" + ret); - } - } - - private void updateCovariateWithKeySet(final Map recalTable, final Long hashKey, final RecalDatum datum) { - RecalDatum previousDatum = recalTable.get(hashKey); // using the list of covariate values as a key, pick out the RecalDatum from the data HashMap - if (previousDatum == null) // key doesn't exist yet in the map so make a new bucket and add it - recalTable.put(hashKey, datum.copy()); - else - previousDatum.combine(datum); // add one to the number of observations and potentially one to the number of mismatches - } - /** * Implements a serial recalibration of the reads using the combinational table. * First, we perform a positional recalibration, and then a subsequent dinuc correction.