Unit test to guarantee BQSR sequential calculation accuracy

This test brings together the old and the new BQSR, building a recalibration table using the two separate frameworks and performing the recalibration calculation using the two different frameworks for 10,000+ bases and asserting that the calculations match in every case.
This commit is contained in:
Mauricio Carneiro 2012-04-18 23:02:10 -04:00
parent 68d0211fa1
commit eb22cd7222
4 changed files with 297 additions and 13 deletions

View File

@ -102,7 +102,7 @@ public class RecalDatum extends Datum {
@Override
public String toString() {
return String.format("%d,%d,%d", numObservations, numMismatches, (byte) Math.floor(getEmpiricalQuality()));
return String.format("%d,%d,%d,%d", numObservations, numMismatches, (byte) Math.floor(getEmpiricalQuality()), (byte) Math.floor(getEstimatedQReported()));
}

View File

@ -109,4 +109,10 @@ public class RecalDatum extends RecalDatumOptimized {
private double qualToErrorProb( final double qual ) {
return Math.pow(10.0, qual / -10.0);
}
@Override
public String toString() {
return String.format("%d,%d,%d,%d", numObservations, numMismatches, (byte) Math.floor(getEmpiricalQuality()), (byte) Math.floor(getEstimatedQReported()));
}
}

View File

@ -65,6 +65,19 @@ public class BaseRecalibration {
quantizationInfo.quantizeQualityScores(quantizationLevels);
}
/**
* This constructor only exists for testing purposes.
*
* @param quantizationInfo
* @param keysAndTablesMap
* @param requestedCovariates
*/
protected BaseRecalibration(QuantizationInfo quantizationInfo, LinkedHashMap<BQSRKeyManager, Map<BitSet, RecalDatum>> keysAndTablesMap, ArrayList<Covariate> requestedCovariates) {
this.quantizationInfo = quantizationInfo;
this.keysAndTablesMap = keysAndTablesMap;
this.requestedCovariates = requestedCovariates;
}
/**
* Recalibrates the base qualities of a read
*
@ -110,7 +123,7 @@ public class BaseRecalibration {
* @param errorModel the event type
* @return A recalibrated quality score as a byte
*/
private byte performSequentialQualityCalculation(BitSet[] key, EventType errorModel) {
protected byte performSequentialQualityCalculation(BitSet[] key, EventType errorModel) {
final String UNRECOGNIZED_REPORT_TABLE_EXCEPTION = "Unrecognized table. Did you add an extra required covariate? This is a hard check that needs propagate through the code";
final String TOO_MANY_KEYS_EXCEPTION = "There should only be one key for the RG collapsed table, something went wrong here";

View File

@ -1,13 +1,17 @@
package org.broadinstitute.sting.utils.recalibration;
import net.sf.samtools.SAMReadGroupRecord;
import org.broadinstitute.sting.utils.NGSPlatform;
import org.broadinstitute.sting.gatk.walkers.bqsr.*;
import org.broadinstitute.sting.utils.QualityUtils;
import org.broadinstitute.sting.utils.Utils;
import org.broadinstitute.sting.utils.collections.NestedHashMap;
import org.broadinstitute.sting.utils.sam.GATKSAMReadGroupRecord;
import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
import org.broadinstitute.sting.utils.sam.ReadUtils;
import org.testng.Assert;
import org.testng.annotations.BeforeClass;
import org.testng.annotations.Test;
import java.io.File;
import java.util.*;
/**
* Unit tests for on-the-fly recalibration.
@ -17,13 +21,274 @@ import java.io.File;
*/
public class BaseRecalibrationUnitTest {
@Test(enabled=false)
public void testReadingReport() {
File csv = new File("public/testdata/exampleGATKREPORT.grp");
BaseRecalibration baseRecalibration = new BaseRecalibration(csv, -1);
GATKSAMRecord read = ReadUtils.createRandomRead(1000);
read.setReadGroup(new GATKSAMReadGroupRecord(new SAMReadGroupRecord("exampleBAM.bam.bam"), NGSPlatform.ILLUMINA));
baseRecalibration.recalibrateRead(read);
System.out.println("Success");
private org.broadinstitute.sting.gatk.walkers.recalibration.RecalDataManager dataManager;
private LinkedHashMap<BQSRKeyManager, Map<BitSet, RecalDatum>> keysAndTablesMap;
private BQSRKeyManager rgKeyManager;
private BQSRKeyManager qsKeyManager;
private BQSRKeyManager cvKeyManager;
private ReadGroupCovariate rgCovariate;
private QualityScoreCovariate qsCovariate;
private ContextCovariate cxCovariate;
private CycleCovariate cyCovariate;
private GATKSAMRecord read = ReadUtils.createRandomRead(10000);
private BaseRecalibration baseRecalibration;
private ReadCovariates readCovariates;
@BeforeClass
public void init() {
GATKSAMReadGroupRecord rg = new GATKSAMReadGroupRecord("rg");
rg.setPlatform("illumina");
read.setReadGroup(rg);
byte[] quals = new byte[read.getReadLength()];
for (int i = 0; i < read.getReadLength(); i++)
quals[i] = 20;
read.setBaseQualities(quals);
RecalibrationArgumentCollection RAC = new RecalibrationArgumentCollection();
List<Covariate> requiredCovariates = new ArrayList<Covariate>();
List<Covariate> optionalCovariates = new ArrayList<Covariate>();
ArrayList<Covariate> requestedCovariates = new ArrayList<Covariate>();
dataManager = new org.broadinstitute.sting.gatk.walkers.recalibration.RecalDataManager(true, 4);
keysAndTablesMap = new LinkedHashMap<BQSRKeyManager, Map<BitSet, RecalDatum>>();
rgCovariate = new ReadGroupCovariate();
rgCovariate.initialize(RAC);
requiredCovariates.add(rgCovariate);
rgKeyManager = new BQSRKeyManager(requiredCovariates, optionalCovariates);
keysAndTablesMap.put(rgKeyManager, new HashMap<BitSet, RecalDatum>());
qsCovariate = new QualityScoreCovariate();
qsCovariate.initialize(RAC);
requiredCovariates.add(qsCovariate);
qsKeyManager = new BQSRKeyManager(requiredCovariates, optionalCovariates);
keysAndTablesMap.put(qsKeyManager, new HashMap<BitSet, RecalDatum>());
cxCovariate = new ContextCovariate();
cxCovariate.initialize(RAC);
optionalCovariates.add(cxCovariate);
cyCovariate = new CycleCovariate();
cyCovariate.initialize(RAC);
optionalCovariates.add(cyCovariate);
cvKeyManager = new BQSRKeyManager(requiredCovariates, optionalCovariates);
keysAndTablesMap.put(cvKeyManager, new HashMap<BitSet, RecalDatum>());
for (Covariate cov : requiredCovariates)
requestedCovariates.add(cov);
for (Covariate cov : optionalCovariates)
requestedCovariates.add(cov);
readCovariates = RecalDataManager.computeCovariates(read, requestedCovariates);
for (int i=0; i<read.getReadLength(); i++) {
BitSet[] bitKeys = readCovariates.getMismatchesKeySet(i);
Object[] objKey = buildObjectKey(bitKeys);
Random random = new Random();
int nObservations = random.nextInt(10000);
int nErrors = random.nextInt(10);
double estimatedQReported = 30;
double empiricalQuality = calcEmpiricalQual(nObservations, nErrors);
org.broadinstitute.sting.gatk.walkers.recalibration.RecalDatum oldDatum = new org.broadinstitute.sting.gatk.walkers.recalibration.RecalDatum(nObservations, nErrors, estimatedQReported, empiricalQuality);
dataManager.addToAllTables(objKey, oldDatum, QualityUtils.MIN_USABLE_Q_SCORE);
RecalDatum newDatum = new RecalDatum(nObservations, nErrors, estimatedQReported, empiricalQuality);
for (Map.Entry<BQSRKeyManager, Map<BitSet, RecalDatum>> mapEntry : keysAndTablesMap.entrySet()) {
List<BitSet> keys = mapEntry.getKey().bitSetsFromAllKeys(bitKeys, EventType.BASE_SUBSTITUTION);
for (BitSet key : keys)
updateCovariateWithKeySet(mapEntry.getValue(), key, newDatum);
}
}
dataManager.generateEmpiricalQualities(1, QualityUtils.MAX_QUAL_SCORE);
List<Byte> quantizedQuals = new ArrayList<Byte>();
List<Long> qualCounts = new ArrayList<Long>();
for (byte i = 0; i <= QualityUtils.MAX_QUAL_SCORE; i++) {
quantizedQuals.add(i);
qualCounts.add(1L);
}
QuantizationInfo quantizationInfo = new QuantizationInfo(quantizedQuals, qualCounts);
quantizationInfo.noQuantization();
baseRecalibration = new BaseRecalibration(quantizationInfo, keysAndTablesMap, requestedCovariates);
}
@Test(enabled=true)
public void testGoldStandardComparison() {
debugTables();
for (int i = 0; i < read.getReadLength(); i++) {
BitSet [] bitKey = readCovariates.getKeySet(i, EventType.BASE_SUBSTITUTION);
Object [] objKey = buildObjectKey(bitKey);
byte v2 = baseRecalibration.performSequentialQualityCalculation(bitKey, EventType.BASE_SUBSTITUTION);
byte v1 = goldStandardSequentialCalculation(objKey);
Assert.assertEquals(v2, v1);
}
}
private Object[] buildObjectKey(BitSet[] bitKey) {
Object[] key = new Object[bitKey.length];
key[0] = rgCovariate.keyFromBitSet(bitKey[0]);
key[1] = qsCovariate.keyFromBitSet(bitKey[1]);
key[2] = cxCovariate.keyFromBitSet(bitKey[2]);
key[3] = cyCovariate.keyFromBitSet(bitKey[3]);
return key;
}
private void debugTables() {
System.out.println("\nV1 Table\n");
System.out.println("ReadGroup Table:");
NestedHashMap nestedTable = dataManager.getCollapsedTable(0);
printNestedHashMap(nestedTable.data, "");
System.out.println("\nQualityScore Table:");
nestedTable = dataManager.getCollapsedTable(1);
printNestedHashMap(nestedTable.data, "");
System.out.println("\nCovariates Table:");
nestedTable = dataManager.getCollapsedTable(2);
printNestedHashMap(nestedTable.data, "");
nestedTable = dataManager.getCollapsedTable(3);
printNestedHashMap(nestedTable.data, "");
int i = 0;
System.out.println("\nV2 Table\n");
for (Map.Entry<BQSRKeyManager, Map<BitSet, RecalDatum>> mapEntry : keysAndTablesMap.entrySet()) {
BQSRKeyManager keyManager = mapEntry.getKey();
Map<BitSet, RecalDatum> table = mapEntry.getValue();
switch(i++) {
case 0 :
System.out.println("ReadGroup Table:");
break;
case 1 :
System.out.println("QualityScore Table:");
break;
case 2 :
System.out.println("Covariates Table:");
break;
}
for (Map.Entry<BitSet, RecalDatum> entry : table.entrySet()) {
BitSet key = entry.getKey();
RecalDatum datum = entry.getValue();
List<Object> keySet = keyManager.keySetFrom(key);
System.out.println(String.format("%s => %s", Utils.join(",", keySet), datum));
}
System.out.println();
}
}
private static void printNestedHashMap(Map<Object,Object> table, String output) {
for (Object key : table.keySet()) {
String ret = "";
if (output.isEmpty())
ret = "" + key;
else
ret = output + "," + key;
Object next = table.get(key);
if (next instanceof org.broadinstitute.sting.gatk.walkers.recalibration.RecalDatum)
System.out.println(ret + " => " + next);
else
printNestedHashMap((Map<Object, Object>) next, "" + ret);
}
}
private void updateCovariateWithKeySet(final Map<BitSet, RecalDatum> recalTable, final BitSet hashKey, final RecalDatum datum) {
RecalDatum previousDatum = recalTable.get(hashKey); // using the list of covariate values as a key, pick out the RecalDatum from the data HashMap
if (previousDatum == null) // key doesn't exist yet in the map so make a new bucket and add it
recalTable.put(hashKey, datum.copy());
else
previousDatum.combine(datum); // add one to the number of observations and potentially one to the number of mismatches
}
/**
* Implements a serial recalibration of the reads using the combinational table.
* First, we perform a positional recalibration, and then a subsequent dinuc correction.
*
* Given the full recalibration table, we perform the following preprocessing steps:
*
* - calculate the global quality score shift across all data [DeltaQ]
* - calculate for each of cycle and dinuc the shift of the quality scores relative to the global shift
* -- i.e., DeltaQ(dinuc) = Sum(pos) Sum(Qual) Qempirical(pos, qual, dinuc) - Qreported(pos, qual, dinuc) / Npos * Nqual
* - The final shift equation is:
*
* Qrecal = Qreported + DeltaQ + DeltaQ(pos) + DeltaQ(dinuc) + DeltaQ( ... any other covariate ... )
*
* @param key The list of Comparables that were calculated from the covariates
* @return A recalibrated quality score as a byte
*/
private byte goldStandardSequentialCalculation(final Object... key) {
final byte qualFromRead = (byte) Integer.parseInt(key[1].toString());
final Object[] readGroupCollapsedKey = new Object[1];
final Object[] qualityScoreCollapsedKey = new Object[2];
final Object[] covariateCollapsedKey = new Object[3];
// The global quality shift (over the read group only)
readGroupCollapsedKey[0] = key[0];
final org.broadinstitute.sting.gatk.walkers.recalibration.RecalDatum globalRecalDatum = ((org.broadinstitute.sting.gatk.walkers.recalibration.RecalDatum) dataManager.getCollapsedTable(0).get(readGroupCollapsedKey));
double globalDeltaQ = 0.0;
if (globalRecalDatum != null) {
final double globalDeltaQEmpirical = globalRecalDatum.getEmpiricalQuality();
final double aggregrateQReported = globalRecalDatum.getEstimatedQReported();
globalDeltaQ = globalDeltaQEmpirical - aggregrateQReported;
}
// The shift in quality between reported and empirical
qualityScoreCollapsedKey[0] = key[0];
qualityScoreCollapsedKey[1] = key[1];
final org.broadinstitute.sting.gatk.walkers.recalibration.RecalDatum qReportedRecalDatum = ((org.broadinstitute.sting.gatk.walkers.recalibration.RecalDatum) dataManager.getCollapsedTable(1).get(qualityScoreCollapsedKey));
double deltaQReported = 0.0;
if (qReportedRecalDatum != null) {
final double deltaQReportedEmpirical = qReportedRecalDatum.getEmpiricalQuality();
deltaQReported = deltaQReportedEmpirical - qualFromRead - globalDeltaQ;
}
// The shift in quality due to each covariate by itself in turn
double deltaQCovariates = 0.0;
double deltaQCovariateEmpirical;
covariateCollapsedKey[0] = key[0];
covariateCollapsedKey[1] = key[1];
for (int iii = 2; iii < key.length; iii++) {
covariateCollapsedKey[2] = key[iii]; // The given covariate
final org.broadinstitute.sting.gatk.walkers.recalibration.RecalDatum covariateRecalDatum = ((org.broadinstitute.sting.gatk.walkers.recalibration.RecalDatum) dataManager.getCollapsedTable(iii).get(covariateCollapsedKey));
if (covariateRecalDatum != null) {
deltaQCovariateEmpirical = covariateRecalDatum.getEmpiricalQuality();
deltaQCovariates += (deltaQCovariateEmpirical - qualFromRead - (globalDeltaQ + deltaQReported));
}
}
final double newQuality = qualFromRead + globalDeltaQ + deltaQReported + deltaQCovariates;
return QualityUtils.boundQual((int) Math.round(newQuality), QualityUtils.MAX_QUAL_SCORE);
// Verbose printouts used to validate with old recalibrator
//if(key.contains(null)) {
// System.out.println( key + String.format(" => %d + %.2f + %.2f + %.2f + %.2f = %d",
// qualFromRead, globalDeltaQ, deltaQReported, deltaQPos, deltaQDinuc, newQualityByte));
//}
//else {
// System.out.println( String.format("%s %s %s %s => %d + %.2f + %.2f + %.2f + %.2f = %d",
// key.get(0).toString(), key.get(3).toString(), key.get(2).toString(), key.get(1).toString(), qualFromRead, globalDeltaQ, deltaQReported, deltaQPos, deltaQDinuc, newQualityByte) );
//}
//return newQualityByte;
}
public static double calcEmpiricalQual(final int observations, final int errors) {
final int smoothing = 1;
final double doubleMismatches = (double) (errors + smoothing);
final double doubleObservations = (double) ( observations + smoothing );
double empiricalQual = -10 * Math.log10(doubleMismatches / doubleObservations);
return Math.min(QualityUtils.MAX_QUAL_SCORE, empiricalQual);
}
}