From eb22cd7222bf28facdd894e5065819c67a5b9d0d Mon Sep 17 00:00:00 2001 From: Mauricio Carneiro Date: Wed, 18 Apr 2012 23:02:10 -0400 Subject: [PATCH] Unit test to guarantee BQSR sequential calculation accuracy This test brings together the old and the new BQSR, building a recalibration table using the two separate frameworks and performing the recalibration calculation using the two different frameworks for 10,000+ bases and asserting that the calculations match in every case. --- .../sting/gatk/walkers/bqsr/RecalDatum.java | 2 +- .../walkers/recalibration/RecalDatum.java | 6 + .../recalibration/BaseRecalibration.java | 15 +- .../BaseRecalibrationUnitTest.java | 287 +++++++++++++++++- 4 files changed, 297 insertions(+), 13 deletions(-) diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/RecalDatum.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/RecalDatum.java index d232fde81..c71a00a3a 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/RecalDatum.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/bqsr/RecalDatum.java @@ -102,7 +102,7 @@ public class RecalDatum extends Datum { @Override public String toString() { - return String.format("%d,%d,%d", numObservations, numMismatches, (byte) Math.floor(getEmpiricalQuality())); + return String.format("%d,%d,%d,%d", numObservations, numMismatches, (byte) Math.floor(getEmpiricalQuality()), (byte) Math.floor(getEstimatedQReported())); } diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/recalibration/RecalDatum.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/recalibration/RecalDatum.java index adc352b1b..aa9098549 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/recalibration/RecalDatum.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/recalibration/RecalDatum.java @@ -109,4 +109,10 @@ public class RecalDatum extends RecalDatumOptimized { private double qualToErrorProb( final double qual ) { 
return Math.pow(10.0, qual / -10.0); } + + + @Override + public String toString() { + return String.format("%d,%d,%d,%d", numObservations, numMismatches, (byte) Math.floor(getEmpiricalQuality()), (byte) Math.floor(getEstimatedQReported())); + } } diff --git a/public/java/src/org/broadinstitute/sting/utils/recalibration/BaseRecalibration.java b/public/java/src/org/broadinstitute/sting/utils/recalibration/BaseRecalibration.java index 2badca44c..70eb9426b 100644 --- a/public/java/src/org/broadinstitute/sting/utils/recalibration/BaseRecalibration.java +++ b/public/java/src/org/broadinstitute/sting/utils/recalibration/BaseRecalibration.java @@ -65,6 +65,19 @@ public class BaseRecalibration { quantizationInfo.quantizeQualityScores(quantizationLevels); } + /** + * This constructor only exists for testing purposes. + * + * @param quantizationInfo the quantization info to use; stored as given (this constructor itself applies no quantization) + * @param keysAndTablesMap the map of recalibration tables to use; stored as given + * @param requestedCovariates the list of covariates to use; stored as given + */ + protected BaseRecalibration(QuantizationInfo quantizationInfo, LinkedHashMap> keysAndTablesMap, ArrayList requestedCovariates) { + this.quantizationInfo = quantizationInfo; + this.keysAndTablesMap = keysAndTablesMap; + this.requestedCovariates = requestedCovariates; + } + /** * Recalibrates the base qualities of a read * @@ -110,7 +123,7 @@ public class BaseRecalibration { * @param errorModel the event type * @return A recalibrated quality score as a byte */ - private byte performSequentialQualityCalculation(BitSet[] key, EventType errorModel) { + protected byte performSequentialQualityCalculation(BitSet[] key, EventType errorModel) { final String UNRECOGNIZED_REPORT_TABLE_EXCEPTION = "Unrecognized table. Did you add an extra required covariate? 
This is a hard check that needs propagate through the code"; final String TOO_MANY_KEYS_EXCEPTION = "There should only be one key for the RG collapsed table, something went wrong here"; diff --git a/public/java/test/org/broadinstitute/sting/utils/recalibration/BaseRecalibrationUnitTest.java b/public/java/test/org/broadinstitute/sting/utils/recalibration/BaseRecalibrationUnitTest.java index f8f1ead9b..4f0d39991 100644 --- a/public/java/test/org/broadinstitute/sting/utils/recalibration/BaseRecalibrationUnitTest.java +++ b/public/java/test/org/broadinstitute/sting/utils/recalibration/BaseRecalibrationUnitTest.java @@ -1,13 +1,17 @@ package org.broadinstitute.sting.utils.recalibration; -import net.sf.samtools.SAMReadGroupRecord; -import org.broadinstitute.sting.utils.NGSPlatform; +import org.broadinstitute.sting.gatk.walkers.bqsr.*; +import org.broadinstitute.sting.utils.QualityUtils; +import org.broadinstitute.sting.utils.Utils; +import org.broadinstitute.sting.utils.collections.NestedHashMap; import org.broadinstitute.sting.utils.sam.GATKSAMReadGroupRecord; import org.broadinstitute.sting.utils.sam.GATKSAMRecord; import org.broadinstitute.sting.utils.sam.ReadUtils; +import org.testng.Assert; +import org.testng.annotations.BeforeClass; import org.testng.annotations.Test; -import java.io.File; +import java.util.*; /** * Unit tests for on-the-fly recalibration. 
@@ -17,13 +21,274 @@ import java.io.File; */ public class BaseRecalibrationUnitTest { - @Test(enabled=false) - public void testReadingReport() { - File csv = new File("public/testdata/exampleGATKREPORT.grp"); - BaseRecalibration baseRecalibration = new BaseRecalibration(csv, -1); - GATKSAMRecord read = ReadUtils.createRandomRead(1000); - read.setReadGroup(new GATKSAMReadGroupRecord(new SAMReadGroupRecord("exampleBAM.bam.bam"), NGSPlatform.ILLUMINA)); - baseRecalibration.recalibrateRead(read); - System.out.println("Success"); + private org.broadinstitute.sting.gatk.walkers.recalibration.RecalDataManager dataManager; + private LinkedHashMap> keysAndTablesMap; + + private BQSRKeyManager rgKeyManager; + private BQSRKeyManager qsKeyManager; + private BQSRKeyManager cvKeyManager; + + private ReadGroupCovariate rgCovariate; + private QualityScoreCovariate qsCovariate; + private ContextCovariate cxCovariate; + private CycleCovariate cyCovariate; + + private GATKSAMRecord read = ReadUtils.createRandomRead(10000); + private BaseRecalibration baseRecalibration; + private ReadCovariates readCovariates; + + + @BeforeClass + public void init() { + GATKSAMReadGroupRecord rg = new GATKSAMReadGroupRecord("rg"); + rg.setPlatform("illumina"); + read.setReadGroup(rg); + + byte[] quals = new byte[read.getReadLength()]; + for (int i = 0; i < read.getReadLength(); i++) + quals[i] = 20; + read.setBaseQualities(quals); + + RecalibrationArgumentCollection RAC = new RecalibrationArgumentCollection(); + List requiredCovariates = new ArrayList(); + List optionalCovariates = new ArrayList(); + ArrayList requestedCovariates = new ArrayList(); + + dataManager = new org.broadinstitute.sting.gatk.walkers.recalibration.RecalDataManager(true, 4); + keysAndTablesMap = new LinkedHashMap>(); + + rgCovariate = new ReadGroupCovariate(); + rgCovariate.initialize(RAC); + requiredCovariates.add(rgCovariate); + rgKeyManager = new BQSRKeyManager(requiredCovariates, optionalCovariates); + 
keysAndTablesMap.put(rgKeyManager, new HashMap()); + + qsCovariate = new QualityScoreCovariate(); + qsCovariate.initialize(RAC); + requiredCovariates.add(qsCovariate); + qsKeyManager = new BQSRKeyManager(requiredCovariates, optionalCovariates); + keysAndTablesMap.put(qsKeyManager, new HashMap()); + + cxCovariate = new ContextCovariate(); + cxCovariate.initialize(RAC); + optionalCovariates.add(cxCovariate); + cyCovariate = new CycleCovariate(); + cyCovariate.initialize(RAC); + optionalCovariates.add(cyCovariate); + cvKeyManager = new BQSRKeyManager(requiredCovariates, optionalCovariates); + keysAndTablesMap.put(cvKeyManager, new HashMap()); + + + for (Covariate cov : requiredCovariates) + requestedCovariates.add(cov); + for (Covariate cov : optionalCovariates) + requestedCovariates.add(cov); + + readCovariates = RecalDataManager.computeCovariates(read, requestedCovariates); + + for (int i=0; i> mapEntry : keysAndTablesMap.entrySet()) { + List keys = mapEntry.getKey().bitSetsFromAllKeys(bitKeys, EventType.BASE_SUBSTITUTION); + for (BitSet key : keys) + updateCovariateWithKeySet(mapEntry.getValue(), key, newDatum); + } + } + dataManager.generateEmpiricalQualities(1, QualityUtils.MAX_QUAL_SCORE); + + List quantizedQuals = new ArrayList(); + List qualCounts = new ArrayList(); + for (byte i = 0; i <= QualityUtils.MAX_QUAL_SCORE; i++) { + quantizedQuals.add(i); + qualCounts.add(1L); + } + QuantizationInfo quantizationInfo = new QuantizationInfo(quantizedQuals, qualCounts); + quantizationInfo.noQuantization(); + baseRecalibration = new BaseRecalibration(quantizationInfo, keysAndTablesMap, requestedCovariates); + + } + + + @Test(enabled=true) + public void testGoldStandardComparison() { + debugTables(); + for (int i = 0; i < read.getReadLength(); i++) { + BitSet [] bitKey = readCovariates.getKeySet(i, EventType.BASE_SUBSTITUTION); + Object [] objKey = buildObjectKey(bitKey); + byte v2 = baseRecalibration.performSequentialQualityCalculation(bitKey, 
EventType.BASE_SUBSTITUTION); + byte v1 = goldStandardSequentialCalculation(objKey); + Assert.assertEquals(v2, v1); + } + } + + private Object[] buildObjectKey(BitSet[] bitKey) { + Object[] key = new Object[bitKey.length]; + key[0] = rgCovariate.keyFromBitSet(bitKey[0]); + key[1] = qsCovariate.keyFromBitSet(bitKey[1]); + key[2] = cxCovariate.keyFromBitSet(bitKey[2]); + key[3] = cyCovariate.keyFromBitSet(bitKey[3]); + return key; + } + + private void debugTables() { + System.out.println("\nV1 Table\n"); + System.out.println("ReadGroup Table:"); + NestedHashMap nestedTable = dataManager.getCollapsedTable(0); + printNestedHashMap(nestedTable.data, ""); + System.out.println("\nQualityScore Table:"); + nestedTable = dataManager.getCollapsedTable(1); + printNestedHashMap(nestedTable.data, ""); + System.out.println("\nCovariates Table:"); + nestedTable = dataManager.getCollapsedTable(2); + printNestedHashMap(nestedTable.data, ""); + nestedTable = dataManager.getCollapsedTable(3); + printNestedHashMap(nestedTable.data, ""); + + + int i = 0; + System.out.println("\nV2 Table\n"); + for (Map.Entry> mapEntry : keysAndTablesMap.entrySet()) { + BQSRKeyManager keyManager = mapEntry.getKey(); + Map table = mapEntry.getValue(); + switch(i++) { + case 0 : + System.out.println("ReadGroup Table:"); + break; + case 1 : + System.out.println("QualityScore Table:"); + break; + case 2 : + System.out.println("Covariates Table:"); + break; + } + for (Map.Entry entry : table.entrySet()) { + BitSet key = entry.getKey(); + RecalDatum datum = entry.getValue(); + List keySet = keyManager.keySetFrom(key); + System.out.println(String.format("%s => %s", Utils.join(",", keySet), datum)); + } + System.out.println(); + } + + + } + + private static void printNestedHashMap(Map table, String output) { + for (Object key : table.keySet()) { + String ret = ""; + if (output.isEmpty()) + ret = "" + key; + else + ret = output + "," + key; + + Object next = table.get(key); + if (next instanceof 
org.broadinstitute.sting.gatk.walkers.recalibration.RecalDatum) + System.out.println(ret + " => " + next); + else + printNestedHashMap((Map) next, "" + ret); + } + } + + private void updateCovariateWithKeySet(final Map recalTable, final BitSet hashKey, final RecalDatum datum) { + RecalDatum previousDatum = recalTable.get(hashKey); // using the list of covariate values as a key, pick out the RecalDatum from the data HashMap + if (previousDatum == null) // key doesn't exist yet in the map so make a new bucket and add it + recalTable.put(hashKey, datum.copy()); + else + previousDatum.combine(datum); // add one to the number of observations and potentially one to the number of mismatches + } + + /** + * Implements a serial recalibration of the reads using the combinational table. + * First, we perform a positional recalibration, and then a subsequent dinuc correction. + * + * Given the full recalibration table, we perform the following preprocessing steps: + * + * - calculate the global quality score shift across all data [DeltaQ] + * - calculate for each of cycle and dinuc the shift of the quality scores relative to the global shift + * -- i.e., DeltaQ(dinuc) = [ Sum(pos) Sum(qual) ( Qempirical(pos, qual, dinuc) - Qreported(pos, qual, dinuc) ) ] / ( Npos * Nqual ) + * - The final shift equation is: + * + * Qrecal = Qreported + DeltaQ + DeltaQ(pos) + DeltaQ(dinuc) + DeltaQ( ... any other covariate ... ) + * + * @param key The list of Comparables that were calculated from the covariates + * @return A recalibrated quality score as a byte + */ + private byte goldStandardSequentialCalculation(final Object... 
key) { + + final byte qualFromRead = (byte) Integer.parseInt(key[1].toString()); + final Object[] readGroupCollapsedKey = new Object[1]; + final Object[] qualityScoreCollapsedKey = new Object[2]; + final Object[] covariateCollapsedKey = new Object[3]; + + // The global quality shift (over the read group only) + readGroupCollapsedKey[0] = key[0]; + final org.broadinstitute.sting.gatk.walkers.recalibration.RecalDatum globalRecalDatum = ((org.broadinstitute.sting.gatk.walkers.recalibration.RecalDatum) dataManager.getCollapsedTable(0).get(readGroupCollapsedKey)); + double globalDeltaQ = 0.0; + if (globalRecalDatum != null) { + final double globalDeltaQEmpirical = globalRecalDatum.getEmpiricalQuality(); + final double aggregrateQReported = globalRecalDatum.getEstimatedQReported(); + globalDeltaQ = globalDeltaQEmpirical - aggregrateQReported; + } + + // The shift in quality between reported and empirical + qualityScoreCollapsedKey[0] = key[0]; + qualityScoreCollapsedKey[1] = key[1]; + final org.broadinstitute.sting.gatk.walkers.recalibration.RecalDatum qReportedRecalDatum = ((org.broadinstitute.sting.gatk.walkers.recalibration.RecalDatum) dataManager.getCollapsedTable(1).get(qualityScoreCollapsedKey)); + double deltaQReported = 0.0; + if (qReportedRecalDatum != null) { + final double deltaQReportedEmpirical = qReportedRecalDatum.getEmpiricalQuality(); + deltaQReported = deltaQReportedEmpirical - qualFromRead - globalDeltaQ; + } + + // The shift in quality due to each covariate by itself in turn + double deltaQCovariates = 0.0; + double deltaQCovariateEmpirical; + covariateCollapsedKey[0] = key[0]; + covariateCollapsedKey[1] = key[1]; + for (int iii = 2; iii < key.length; iii++) { + covariateCollapsedKey[2] = key[iii]; // The given covariate + final org.broadinstitute.sting.gatk.walkers.recalibration.RecalDatum covariateRecalDatum = ((org.broadinstitute.sting.gatk.walkers.recalibration.RecalDatum) dataManager.getCollapsedTable(iii).get(covariateCollapsedKey)); + if 
(covariateRecalDatum != null) { + deltaQCovariateEmpirical = covariateRecalDatum.getEmpiricalQuality(); + deltaQCovariates += (deltaQCovariateEmpirical - qualFromRead - (globalDeltaQ + deltaQReported)); + } + } + + final double newQuality = qualFromRead + globalDeltaQ + deltaQReported + deltaQCovariates; + return QualityUtils.boundQual((int) Math.round(newQuality), QualityUtils.MAX_QUAL_SCORE); + + // Verbose printouts used to validate with old recalibrator + //if(key.contains(null)) { + // System.out.println( key + String.format(" => %d + %.2f + %.2f + %.2f + %.2f = %d", + // qualFromRead, globalDeltaQ, deltaQReported, deltaQPos, deltaQDinuc, newQualityByte)); + //} + //else { + // System.out.println( String.format("%s %s %s %s => %d + %.2f + %.2f + %.2f + %.2f = %d", + // key.get(0).toString(), key.get(3).toString(), key.get(2).toString(), key.get(1).toString(), qualFromRead, globalDeltaQ, deltaQReported, deltaQPos, deltaQDinuc, newQualityByte) ); + //} + + //return newQualityByte; + } + + public static double calcEmpiricalQual(final int observations, final int errors) { + final int smoothing = 1; + final double doubleMismatches = (double) (errors + smoothing); + final double doubleObservations = (double) ( observations + smoothing ); + double empiricalQual = -10 * Math.log10(doubleMismatches / doubleObservations); + return Math.min(QualityUtils.MAX_QUAL_SCORE, empiricalQual); } }