From 5ec25797b3c9c9e0abc1b4d318dbbbee59d365b7 Mon Sep 17 00:00:00 2001 From: Mark DePristo Date: Wed, 12 Dec 2012 13:54:14 -0500 Subject: [PATCH] Optimizations for BaseRecalibrator -- No longer computes at each update the overall read group table. Now computes this derived table only at the end of the computation, using the ByQual table as input. Reduces BQSR runtime by 1/3 in my test --- .../bqsr/AdvancedRecalibrationEngine.java | 6 +- .../gatk/walkers/bqsr/BaseRecalibrator.java | 2 + .../walkers/bqsr/RecalibrationEngine.java | 5 +- .../bqsr/StandardRecalibrationEngine.java | 64 +++++++++---------- .../utils/collections/NestedIntegerArray.java | 14 ++-- .../sting/utils/recalibration/RecalDatum.java | 26 ++------ .../covariates/ReadGroupCovariate.java | 12 +++- 7 files changed, 57 insertions(+), 72 deletions(-) diff --git a/protected/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/AdvancedRecalibrationEngine.java b/protected/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/AdvancedRecalibrationEngine.java index d0bcd0eb3..ffd681056 100644 --- a/protected/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/AdvancedRecalibrationEngine.java +++ b/protected/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/AdvancedRecalibrationEngine.java @@ -25,13 +25,11 @@ package org.broadinstitute.sting.gatk.walkers.bqsr; * OTHER DEALINGS IN THE SOFTWARE. */ -import org.broadinstitute.sting.utils.recalibration.covariates.Covariate; -import org.broadinstitute.sting.utils.BaseUtils; import org.broadinstitute.sting.utils.classloader.ProtectedPackageSource; -import org.broadinstitute.sting.utils.pileup.PileupElement; import org.broadinstitute.sting.utils.recalibration.EventType; import org.broadinstitute.sting.utils.recalibration.ReadCovariates; import org.broadinstitute.sting.utils.recalibration.RecalibrationTables; +import org.broadinstitute.sting.utils.recalibration.covariates.Covariate; import org.broadinstitute.sting.utils.sam.GATKSAMRecord; import org.broadinstitute.sting.utils.threading.ThreadLocalArray; @@ -67,8 +65,6 @@ public class AdvancedRecalibrationEngine extends StandardRecalibrationEngine imp final byte qual = tempQualArray[eventIndex]; final double isError = tempFractionalErrorArray[eventIndex]; - combineDatumOrPutIfNecessary(recalibrationTables.getReadGroupTable(), qual, isError, keys[0], eventIndex); - incrementDatumOrPutIfNecessary(recalibrationTables.getQualityScoreTable(), qual, isError, keys[0], keys[1], eventIndex); for (int i = 2; i < covariates.length; i++) { diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BaseRecalibrator.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BaseRecalibrator.java index 7ce98cf1d..2410aefce 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BaseRecalibrator.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/BaseRecalibrator.java @@ -452,6 +452,8 @@ public class BaseRecalibrator extends ReadWalker implements NanoSche @Override public void onTraversalDone(Long result) { + recalibrationEngine.finalizeData(); + logger.info("Calculating quantized quality scores..."); quantizeQualityScores(); diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/RecalibrationEngine.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/RecalibrationEngine.java index 962d62d5e..35375eb1d 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/RecalibrationEngine.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/RecalibrationEngine.java @@ -1,8 +1,7 @@ package org.broadinstitute.sting.gatk.walkers.bqsr; -import org.broadinstitute.sting.utils.recalibration.covariates.Covariate; -import org.broadinstitute.sting.utils.pileup.PileupElement; import org.broadinstitute.sting.utils.recalibration.RecalibrationTables; +import org.broadinstitute.sting.utils.recalibration.covariates.Covariate; import org.broadinstitute.sting.utils.sam.GATKSAMRecord; /* @@ -34,4 +33,6 @@ public interface RecalibrationEngine { public void initialize(final Covariate[] covariates, final RecalibrationTables recalibrationTables); public void updateDataForRead(final GATKSAMRecord read, final boolean[] skip, final double[] snpErrors, final double[] insertionErrors, final double[] deletionErrors); + + public void finalizeData(); } diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/StandardRecalibrationEngine.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/StandardRecalibrationEngine.java index 6031aa955..1e166dfd0 100644 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/StandardRecalibrationEngine.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/StandardRecalibrationEngine.java @@ -25,15 +25,13 @@ package org.broadinstitute.sting.gatk.walkers.bqsr; * OTHER DEALINGS IN THE SOFTWARE. */ -import org.broadinstitute.sting.utils.recalibration.covariates.Covariate; -import org.broadinstitute.sting.utils.BaseUtils; import org.broadinstitute.sting.utils.classloader.PublicPackageSource; import org.broadinstitute.sting.utils.collections.NestedIntegerArray; -import org.broadinstitute.sting.utils.pileup.PileupElement; import org.broadinstitute.sting.utils.recalibration.EventType; import org.broadinstitute.sting.utils.recalibration.ReadCovariates; import org.broadinstitute.sting.utils.recalibration.RecalDatum; import org.broadinstitute.sting.utils.recalibration.RecalibrationTables; +import org.broadinstitute.sting.utils.recalibration.covariates.Covariate; import org.broadinstitute.sting.utils.sam.GATKSAMRecord; public class StandardRecalibrationEngine implements RecalibrationEngine, PublicPackageSource { @@ -58,8 +56,6 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP final int[] keys = readCovariates.getKeySet(offset, EventType.BASE_SUBSTITUTION); final int eventIndex = EventType.BASE_SUBSTITUTION.index; - combineDatumOrPutIfNecessary(recalibrationTables.getReadGroupTable(), qual, isError, keys[0], eventIndex); - incrementDatumOrPutIfNecessary(recalibrationTables.getQualityScoreTable(), qual, isError, keys[0], keys[1], eventIndex); for (int i = 2; i < covariates.length; i++) { @@ -93,6 +89,34 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP return (ReadCovariates) read.getTemporaryAttribute(BaseRecalibrator.COVARS_ATTRIBUTE); } + /** + * Create derived recalibration data tables + * + * Assumes that all of the principal tables (by quality score) have been completely updated, + * and walks over this data to create summary data tables like by read group table. + */ + @Override + public void finalizeData() { + final NestedIntegerArray byReadGroupTable = recalibrationTables.getReadGroupTable(); + final NestedIntegerArray byQualTable = recalibrationTables.getQualityScoreTable(); + + // iterate over all values in the qual table + for ( NestedIntegerArray.Leaf leaf : byQualTable.getAllLeaves() ) { + final int rgKey = leaf.keys[0]; + final int eventIndex = leaf.keys[2]; + final RecalDatum rgDatum = byReadGroupTable.get(rgKey, eventIndex); + final RecalDatum qualDatum = leaf.value; + + if ( rgDatum == null ) { + // create a copy of qualDatum, and initialize byReadGroup table with it + byReadGroupTable.put(new RecalDatum(qualDatum), rgKey, eventIndex); + } else { + // combine the qual datum with the existing datum in the byReadGroup table + rgDatum.combine(qualDatum); + } + } + } + /** * Increments the RecalDatum at the specified position in the specified table, or put a new item there * if there isn't already one. @@ -121,34 +145,4 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP existingDatum.increment(1.0, isError); } } - - /** - * Combines the RecalDatum at the specified position in the specified table with a new RecalDatum, or put a - * new item there if there isn't already one. - * - * Does this in a thread-safe way WITHOUT being synchronized: relies on the behavior of NestedIntegerArray.put() - * to return false if another thread inserts a new item at our position in the middle of our put operation. - * - * @param table the table that holds/will hold our item - * @param qual qual for this event - * @param isError error value for this event - * @param keys location in table of our item - */ - protected void combineDatumOrPutIfNecessary( final NestedIntegerArray table, final byte qual, final double isError, final int... keys ) { - final RecalDatum existingDatum = table.get(keys); - final RecalDatum newDatum = createDatumObject(qual, isError); - - if ( existingDatum == null ) { - // No existing item, try to put a new one - if ( ! table.put(newDatum, keys) ) { - // Failed to put a new item because another thread came along and put an item here first. - // Get the newly-put item and combine it with our item (item is guaranteed to exist at this point) - table.get(keys).combine(newDatum); - } - } - else { - // Easy case: already an item here, so combine it with our item - existingDatum.combine(newDatum); - } - } } diff --git a/public/java/src/org/broadinstitute/sting/utils/collections/NestedIntegerArray.java b/public/java/src/org/broadinstitute/sting/utils/collections/NestedIntegerArray.java index 050ed52ac..2e45eabe1 100755 --- a/public/java/src/org/broadinstitute/sting/utils/collections/NestedIntegerArray.java +++ b/public/java/src/org/broadinstitute/sting/utils/collections/NestedIntegerArray.java @@ -174,23 +174,23 @@ public class NestedIntegerArray { } } - public static class Leaf { + public static class Leaf { public final int[] keys; - public final Object value; + public final T value; - public Leaf(final int[] keys, final Object value) { + public Leaf(final int[] keys, final T value) { this.keys = keys; this.value = value; } } - public List getAllLeaves() { - final List result = new ArrayList(); + public List> getAllLeaves() { + final List> result = new ArrayList>(); fillAllLeaves(data, new int[0], result); return result; } - private void fillAllLeaves(final Object[] array, final int[] path, final List result) { + private void fillAllLeaves(final Object[] array, final int[] path, final List> result) { for ( int key = 0; key < array.length; key++ ) { final Object value = array[key]; if ( value == null ) @@ -199,7 +199,7 @@ public class NestedIntegerArray { if ( value instanceof Object[] ) { fillAllLeaves((Object[]) value, newPath, result); } else { - result.add(new Leaf(newPath, value)); + result.add(new Leaf(newPath, (T)value)); } } } diff --git a/public/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatum.java b/public/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatum.java index 207988749..31eb40d24 100755 --- a/public/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatum.java +++ b/public/java/src/org/broadinstitute/sting/utils/recalibration/RecalDatum.java @@ -191,9 +191,9 @@ public class RecalDatum { return (byte)(Math.round(getEmpiricalQuality())); } - //--------------------------------------------------------------------------------------------------------------- + //--------------------------------------------------------------------------------------------------------------- // - // increment methods + // toString methods // //--------------------------------------------------------------------------------------------------------------- @@ -206,22 +206,6 @@ public class RecalDatum { return String.format("%s,%.2f,%.2f", toString(), getEstimatedQReported(), getEmpiricalQuality() - getEstimatedQReported()); } -// /** -// * We don't compare the estimated quality reported because it may be different when read from -// * report tables. -// * -// * @param o the other recal datum -// * @return true if the two recal datums have the same number of observations, errors and empirical quality. -// */ -// @Override -// public boolean equals(Object o) { -// if (!(o instanceof RecalDatum)) -// return false; -// RecalDatum other = (RecalDatum) o; -// return super.equals(o) && -// MathUtils.compareDoubles(this.empiricalQuality, other.empiricalQuality, 0.001) == 0; -// } - //--------------------------------------------------------------------------------------------------------------- // // increment methods @@ -270,9 +254,7 @@ public class RecalDatum { @Ensures({"numObservations == old(numObservations) + 1", "numMismatches >= old(numMismatches)"}) public synchronized void increment(final boolean isError) { - incrementNumObservations(1); - if ( isError ) - incrementNumMismatches(1); + increment(1, isError ? 1 : 0.0); } // ------------------------------------------------------------------------------------- @@ -286,7 +268,7 @@ public class RecalDatum { */ @Requires("empiricalQuality == UNINITIALIZED") @Ensures("empiricalQuality != UNINITIALIZED") - private synchronized final void calcEmpiricalQuality() { + private synchronized void calcEmpiricalQuality() { final double empiricalQual = -10 * Math.log10(getEmpiricalErrorRate()); empiricalQuality = Math.min(empiricalQual, (double) QualityUtils.MAX_RECALIBRATED_Q_SCORE); } diff --git a/public/java/src/org/broadinstitute/sting/utils/recalibration/covariates/ReadGroupCovariate.java b/public/java/src/org/broadinstitute/sting/utils/recalibration/covariates/ReadGroupCovariate.java index 29c15adf7..47f11312a 100755 --- a/public/java/src/org/broadinstitute/sting/utils/recalibration/covariates/ReadGroupCovariate.java +++ b/public/java/src/org/broadinstitute/sting/utils/recalibration/covariates/ReadGroupCovariate.java @@ -1,11 +1,13 @@ package org.broadinstitute.sting.utils.recalibration.covariates; -import org.broadinstitute.sting.utils.recalibration.ReadCovariates; import org.broadinstitute.sting.gatk.walkers.bqsr.RecalibrationArgumentCollection; +import org.broadinstitute.sting.utils.recalibration.ReadCovariates; import org.broadinstitute.sting.utils.sam.GATKSAMReadGroupRecord; import org.broadinstitute.sting.utils.sam.GATKSAMRecord; import java.util.HashMap; +import java.util.Map; +import java.util.Set; /* * Copyright (c) 2009 The Broad Institute @@ -77,6 +79,14 @@ public class ReadGroupCovariate implements RequiredCovariate { return keyForReadGroup((String) value); } + /** + * Get the mapping from read group names to integer key values for all read groups in this covariate + * @return a set of mappings from read group names -> integer key values + */ + public Set> getKeyMap() { + return readGroupLookupTable.entrySet(); + } + private int keyForReadGroup(final String readGroupId) { // Rather than synchronize this entire method (which would be VERY expensive for walkers like the BQSR), // synchronize only the table updates.