Optimizations for the RecalDatum to make BQSR (Count Covariates) much faster. Needs some cleanup.

This commit is contained in:
Eric Banks 2012-07-03 13:05:11 -04:00
parent 031322ff00
commit 0b37d44b0d
5 changed files with 61 additions and 85 deletions

View File

@ -48,7 +48,7 @@ public class BQSRGatherer extends Gatherer {
@Override
public void gather(List<File> inputs, File output) {
RecalibrationReport generalReport = null;
PrintStream outputFile;
final PrintStream outputFile;
try {
outputFile = new PrintStream(output);
} catch(FileNotFoundException e) {
@ -56,7 +56,7 @@ public class BQSRGatherer extends Gatherer {
}
for (File input : inputs) {
RecalibrationReport inputReport = new RecalibrationReport(input);
final RecalibrationReport inputReport = new RecalibrationReport(input);
if (generalReport == null)
generalReport = inputReport;
else
@ -65,19 +65,17 @@ public class BQSRGatherer extends Gatherer {
if (generalReport == null)
throw new ReviewedStingException(EMPTY_INPUT_LIST);
generalReport.calculateEmpiricalAndQuantizedQualities();
generalReport.calculateQuantizedQualities();
RecalibrationArgumentCollection RAC = generalReport.getRAC();
if (RAC.recalibrationReport != null && !RAC.NO_PLOTS) {
File recal_out = new File(output.getName() + ".original");
RecalibrationReport originalReport = new RecalibrationReport(RAC.recalibrationReport);
// TODO -- fix me
//RecalDataManager.generateRecalibrationPlot(recal_out, originalReport.getKeysAndTablesMap(), generalReport.getKeysAndTablesMap(), RAC.KEEP_INTERMEDIATE_FILES);
final File recal_out = new File(output.getName() + ".original");
final RecalibrationReport originalReport = new RecalibrationReport(RAC.recalibrationReport);
RecalDataManager.generateRecalibrationPlot(recal_out, originalReport.getRecalibrationTables(), generalReport.getRecalibrationTables(), generalReport.getCovariates(), RAC.KEEP_INTERMEDIATE_FILES);
}
else if (!RAC.NO_PLOTS) {
File recal_out = new File(output.getName() + ".recal");
// TODO -- fix me
//RecalDataManager.generateRecalibrationPlot(recal_out, generalReport.getKeysAndTablesMap(), RAC.KEEP_INTERMEDIATE_FILES);
final File recal_out = new File(output.getName() + ".recal");
RecalDataManager.generateRecalibrationPlot(recal_out, generalReport.getRecalibrationTables(), generalReport.getCovariates(), RAC.KEEP_INTERMEDIATE_FILES);
}
generalReport.output(outputFile);

View File

@ -71,6 +71,11 @@ public class Datum {
numMismatches += incMismatches;
}
synchronized void increment(final boolean isError) {
numObservations++;
numMismatches += isError ? 1:0;
}
//---------------------------------------------------------------------------------------------------------------
//
// methods to derive empirical quality score

View File

@ -25,7 +25,10 @@ package org.broadinstitute.sting.gatk.walkers.bqsr;
* OTHER DEALINGS IN THE SOFTWARE.
*/
import com.google.java.contract.Ensures;
import com.google.java.contract.Requires;
import org.broadinstitute.sting.utils.MathUtils;
import org.broadinstitute.sting.utils.QualityUtils;
import java.util.Random;
@ -39,6 +42,8 @@ import java.util.Random;
public class RecalDatum extends Datum {
private static final double UNINITIALIZED = -1.0;
private double estimatedQReported; // estimated reported quality score based on combined data's individual q-reporteds and number of observations
private double empiricalQuality; // the empirical quality for datums that have been collapsed together (by read group and reported quality, for example)
@ -49,18 +54,10 @@ public class RecalDatum extends Datum {
//
//---------------------------------------------------------------------------------------------------------------
public RecalDatum() {
numObservations = 0L;
numMismatches = 0L;
estimatedQReported = 0.0;
empiricalQuality = -1.0;
}
public RecalDatum(final long _numObservations, final long _numMismatches, final double _estimatedQReported, final double _empiricalQuality) {
public RecalDatum(final long _numObservations, final long _numMismatches, final byte reportedQuality) {
numObservations = _numObservations;
numMismatches = _numMismatches;
estimatedQReported = _estimatedQReported;
empiricalQuality = _empiricalQuality;
estimatedQReported = QualityUtils.qualToErrorProb(reportedQuality);
}
public RecalDatum(final RecalDatum copy) {
@ -72,36 +69,39 @@ public class RecalDatum extends Datum {
public void combine(final RecalDatum other) {
final double sumErrors = this.calcExpectedErrors() + other.calcExpectedErrors();
this.increment(other.numObservations, other.numMismatches);
this.estimatedQReported = -10 * Math.log10(sumErrors / this.numObservations);
this.empiricalQuality = -1.0; // reset the empirical quality calculation so we never have a wrongly calculated empirical quality stored
increment(other.numObservations, other.numMismatches);
estimatedQReported = -10 * Math.log10(sumErrors / this.numObservations);
empiricalQuality = UNINITIALIZED;
}
public final void calcCombinedEmpiricalQuality() {
this.empiricalQuality = empiricalQualDouble(); // cache the value so we don't call log over and over again
@Override
public void increment(final boolean isError) {
super.increment(isError);
empiricalQuality = UNINITIALIZED;
}
public final void calcEstimatedReportedQuality() {
this.estimatedQReported = -10 * Math.log10(calcExpectedErrors() / numObservations);
@Requires("empiricalQuality == UNINITIALIZED")
@Ensures("empiricalQuality != UNINITIALIZED")
protected final void calcEmpiricalQuality() {
empiricalQuality = empiricalQualDouble(); // cache the value so we don't call log over and over again
}
public void setEstimatedQReported(final double estimatedQReported) {
this.estimatedQReported = estimatedQReported;
}
public final double getEstimatedQReported() {
return estimatedQReported;
}
public final double getEmpiricalQuality() {
if (empiricalQuality < 0)
calcCombinedEmpiricalQuality();
return empiricalQuality;
public void setEmpiricalQuality(final double empiricalQuality) {
this.empiricalQuality = empiricalQuality;
}
/**
* Makes a hard copy of the recal datum element
*
* @return a new recal datum object with the same contents of this datum.
*/
public RecalDatum copy() {
return new RecalDatum(numObservations, numMismatches, estimatedQReported, empiricalQuality);
public final double getEmpiricalQuality() {
if (empiricalQuality == UNINITIALIZED)
calcEmpiricalQuality();
return empiricalQuality;
}
@Override
@ -122,13 +122,11 @@ public class RecalDatum extends Datum {
}
public static RecalDatum createRandomRecalDatum(int maxObservations, int maxErrors) {
Random random = new Random();
int nObservations = random.nextInt(maxObservations);
int nErrors = random.nextInt(maxErrors);
Datum datum = new Datum(nObservations, nErrors);
double empiricalQuality = datum.empiricalQualDouble();
double estimatedQReported = empiricalQuality + ((10 * random.nextDouble()) - 5); // empirical quality +/- 5.
return new RecalDatum(nObservations, nErrors, estimatedQReported, empiricalQuality);
final Random random = new Random();
final int nObservations = random.nextInt(maxObservations);
final int nErrors = random.nextInt(maxErrors);
final int qual = random.nextInt(QualityUtils.MAX_QUAL_SCORE);
return new RecalDatum(nObservations, nErrors, (byte)qual);
}
/**

View File

@ -57,7 +57,6 @@ public class RecalibrationReport {
for (Covariate cov : requestedCovariates)
cov.initialize(RAC); // initialize any covariate member variables using the shared argument collection
// TODO -- note that we might be able to save memory (esp. for sparse tables) by making one pass through the GATK report to see the maximum values for each covariate and using that in the constructor here
recalibrationTables = new RecalibrationTables(requestedCovariates, countReadGroups(report.getTable(RecalDataManager.READGROUP_REPORT_TABLE_TITLE)));
parseReadGroupTable(report.getTable(RecalDataManager.READGROUP_REPORT_TABLE_TITLE), recalibrationTables.getTable(RecalibrationTables.TableType.READ_GROUP_TABLE));
@ -202,7 +201,10 @@ public class RecalibrationReport {
(Double) reportTable.get(row, RecalDataManager.ESTIMATED_Q_REPORTED_COLUMN_NAME) : // we get it if we are in the read group table
Byte.parseByte((String) reportTable.get(row, RecalDataManager.QUALITY_SCORE_COLUMN_NAME)); // or we use the reported quality if we are in any other table
return new RecalDatum(nObservations, nErrors, estimatedQReported, empiricalQuality);
final RecalDatum datum = new RecalDatum(nObservations, nErrors, (byte)1);
datum.setEstimatedQReported(estimatedQReported);
datum.setEmpiricalQuality(empiricalQuality);
return datum;
}
/**
@ -297,14 +299,7 @@ public class RecalibrationReport {
* this functionality avoids recalculating the empirical qualities, estimated reported quality
* and quantization of the quality scores during every call of combine(). Very useful for the BQSRGatherer.
*/
public void calculateEmpiricalAndQuantizedQualities() {
for (RecalibrationTables.TableType type : RecalibrationTables.TableType.values()) {
final NestedIntegerArray table = recalibrationTables.getTable(type);
for (final Object value : table.getAllValues()) {
((RecalDatum)value).calcCombinedEmpiricalQuality();
}
}
public void calculateQuantizedQualities() {
quantizationInfo = new QuantizationInfo(recalibrationTables, RAC.QUANTIZING_LEVELS);
}
@ -315,4 +310,8 @@ public class RecalibrationReport {
public RecalibrationArgumentCollection getRAC() {
return RAC;
}
public Covariate[] getCovariates() {
return requestedCovariates;
}
}

View File

@ -82,15 +82,15 @@ public class BaseRecalibrationUnitTest {
final Object[] objKey = buildObjectKey(bitKeys);
Random random = new Random();
int nObservations = random.nextInt(10000);
int nErrors = random.nextInt(10);
double estimatedQReported = 30;
final int nObservations = random.nextInt(10000);
final int nErrors = random.nextInt(10);
final byte estimatedQReported = 30;
double empiricalQuality = calcEmpiricalQual(nObservations, nErrors);
org.broadinstitute.sting.gatk.walkers.recalibration.RecalDatum oldDatum = new org.broadinstitute.sting.gatk.walkers.recalibration.RecalDatum(nObservations, nErrors, estimatedQReported, empiricalQuality);
dataManager.addToAllTables(objKey, oldDatum, QualityUtils.MIN_USABLE_Q_SCORE);
RecalDatum newDatum = new RecalDatum(nObservations, nErrors, estimatedQReported, empiricalQuality);
RecalDatum newDatum = new RecalDatum(nObservations, nErrors, estimatedQReported);
rgTable.put(newDatum, bitKeys[0], EventType.BASE_SUBSTITUTION.index);
qualTable.put(newDatum, bitKeys[0], bitKeys[1], EventType.BASE_SUBSTITUTION.index);
@ -100,7 +100,7 @@ public class BaseRecalibrationUnitTest {
}
}
dataManager.generateEmpiricalQualities(1, QualityUtils.MAX_RECALIBRATED_Q_SCORE);
dataManager.generateEmpiricalQualities(1, QualityUtils.MAX_RECALIBRATED_Q_SCORE);
List<Byte> quantizedQuals = new ArrayList<Byte>();
List<Long> qualCounts = new ArrayList<Long>();
@ -135,30 +135,6 @@ public class BaseRecalibrationUnitTest {
return key;
}
private static void printNestedHashMap(Map table, String output) {
for (Object key : table.keySet()) {
String ret;
if (output.isEmpty())
ret = "" + key;
else
ret = output + "," + key;
Object next = table.get(key);
if (next instanceof org.broadinstitute.sting.gatk.walkers.recalibration.RecalDatum)
System.out.println(ret + " => " + next);
else
printNestedHashMap((Map) next, "" + ret);
}
}
private void updateCovariateWithKeySet(final Map<Long, RecalDatum> recalTable, final Long hashKey, final RecalDatum datum) {
RecalDatum previousDatum = recalTable.get(hashKey); // using the list of covariate values as a key, pick out the RecalDatum from the data HashMap
if (previousDatum == null) // key doesn't exist yet in the map so make a new bucket and add it
recalTable.put(hashKey, datum.copy());
else
previousDatum.combine(datum); // add one to the number of observations and potentially one to the number of mismatches
}
/**
* Implements a serial recalibration of the reads using the combinational table.
* First, we perform a positional recalibration, and then a subsequent dinuc correction.