Cleanup and more unit tests for RecalibrationTables in BQSR

-- Added unit tests for combining RecalibrationTables.  As a side effect now has serious tests for incrementDatumOrPutIfNecessary
-- Removed unnecessary enum.index system from RecalibrationTables.
-- Moved what were really static utility methods out of RecalibrationEngine and into RecalUtils.
This commit is contained in:
Mark DePristo 2013-01-04 17:13:31 -05:00
parent 9df30880cb
commit 69bf70c42e
6 changed files with 175 additions and 90 deletions

View File

@ -27,10 +27,7 @@ package org.broadinstitute.sting.gatk.walkers.bqsr;
import com.google.java.contract.Requires;
import org.broadinstitute.sting.utils.collections.NestedIntegerArray;
import org.broadinstitute.sting.utils.recalibration.EventType;
import org.broadinstitute.sting.utils.recalibration.ReadCovariates;
import org.broadinstitute.sting.utils.recalibration.RecalDatum;
import org.broadinstitute.sting.utils.recalibration.RecalibrationTables;
import org.broadinstitute.sting.utils.recalibration.*;
import org.broadinstitute.sting.utils.recalibration.covariates.Covariate;
import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
@ -128,29 +125,19 @@ public class RecalibrationEngine {
final byte qual = recalInfo.getQual(eventType, offset);
final double isError = recalInfo.getErrorFraction(eventType, offset);
incrementDatumOrPutIfNecessary(qualityScoreTable, qual, isError, keys[0], keys[1], eventIndex);
RecalUtils.incrementDatumOrPutIfNecessary(qualityScoreTable, qual, isError, keys[0], keys[1], eventIndex);
for (int i = 2; i < covariates.length; i++) {
if (keys[i] < 0)
continue;
incrementDatumOrPutIfNecessary(tables.getTable(i), qual, isError, keys[0], keys[1], keys[i], eventIndex);
RecalUtils.incrementDatumOrPutIfNecessary(tables.getTable(i), qual, isError, keys[0], keys[1], keys[i], eventIndex);
}
}
}
}
}
/**
* creates a datum object with one observation and one or zero error
*
* @param reportedQual the quality score reported by the instrument for this base
* @param isError whether or not the observation is an error
* @return a new RecalDatum object with the observation and the error
*/
protected RecalDatum createDatumObject(final byte reportedQual, final double isError) {
return new RecalDatum(1, isError, reportedQual);
}
/**
* Finalize, if appropriate, all derived data in recalibrationTables.
@ -226,36 +213,4 @@ public class RecalibrationEngine {
if ( ! finalized ) throw new IllegalStateException("Cannot get final recalibration tables until finalizeData() has been called");
return finalRecalibrationTables;
}
/**
* Increments the RecalDatum at the specified position in the specified table, or put a new item there
* if there isn't already one.
*
* Does this in a thread-safe way WITHOUT being synchronized: relies on the behavior of NestedIntegerArray.put()
* to return false if another thread inserts a new item at our position in the middle of our put operation.
*
* @param table the table that holds/will hold our item
* @param qual qual for this event
* @param isError error value for this event
* @param keys location in table of our item
*/
protected void incrementDatumOrPutIfNecessary( final NestedIntegerArray<RecalDatum> table,
final byte qual,
final double isError,
final int... keys ) {
final RecalDatum existingDatum = table.get(keys);
if ( existingDatum == null ) {
// No existing item, try to put a new one
if ( ! table.put(createDatumObject(qual, isError), keys) ) {
// Failed to put a new item because another thread came along and put an item here first.
// Get the newly-put item and increment it (item is guaranteed to exist at this point)
table.get(keys).increment(1.0, isError);
}
}
else {
// Easy case: already an item here, so increment it
existingDatum.increment(1.0, isError);
}
}
}

View File

@ -269,9 +269,9 @@ public class RecalUtils {
final ArrayList<Pair<String, String>> columnNames = new ArrayList<Pair<String, String>>(); // initialize the array to hold the column names
columnNames.add(new Pair<String, String>(covariateNameMap.get(requestedCovariates[0]), "%s")); // save the required covariate name so we can reference it in the future
if (tableIndex != RecalibrationTables.TableType.READ_GROUP_TABLE.index) {
if (tableIndex != RecalibrationTables.TableType.READ_GROUP_TABLE.ordinal()) {
columnNames.add(new Pair<String, String>(covariateNameMap.get(requestedCovariates[1]), "%s")); // save the required covariate name so we can reference it in the future
if (tableIndex >= RecalibrationTables.TableType.OPTIONAL_COVARIATE_TABLES_START.index) {
if (tableIndex >= RecalibrationTables.TableType.OPTIONAL_COVARIATE_TABLES_START.ordinal()) {
columnNames.add(covariateValue);
columnNames.add(covariateName);
}
@ -279,13 +279,13 @@ public class RecalUtils {
columnNames.add(eventType); // the order of these column names is important here
columnNames.add(empiricalQuality);
if (tableIndex == RecalibrationTables.TableType.READ_GROUP_TABLE.index)
if (tableIndex == RecalibrationTables.TableType.READ_GROUP_TABLE.ordinal())
columnNames.add(estimatedQReported); // only the read group table needs the estimated Q reported
columnNames.add(nObservations);
columnNames.add(nErrors);
final GATKReportTable reportTable;
if (tableIndex <= RecalibrationTables.TableType.OPTIONAL_COVARIATE_TABLES_START.index) {
if (tableIndex <= RecalibrationTables.TableType.OPTIONAL_COVARIATE_TABLES_START.ordinal()) {
if(sortByCols) {
reportTable = new GATKReportTable("RecalTable" + reportTableIndex++, "", columnNames.size(), GATKReportTable.TableSortingWay.SORT_BY_COLUMN);
} else {
@ -295,7 +295,7 @@ public class RecalUtils {
reportTable.addColumn(columnName.getFirst(), columnName.getSecond());
rowIndex = 0; // reset the row index since we're starting with a new table
} else {
reportTable = result.get(RecalibrationTables.TableType.OPTIONAL_COVARIATE_TABLES_START.index);
reportTable = result.get(RecalibrationTables.TableType.OPTIONAL_COVARIATE_TABLES_START.ordinal());
}
final NestedIntegerArray<RecalDatum> table = recalibrationTables.getTable(tableIndex);
@ -306,9 +306,9 @@ public class RecalUtils {
int columnIndex = 0;
int keyIndex = 0;
reportTable.set(rowIndex, columnNames.get(columnIndex++).getFirst(), requestedCovariates[0].formatKey(keys[keyIndex++]));
if (tableIndex != RecalibrationTables.TableType.READ_GROUP_TABLE.index) {
if (tableIndex != RecalibrationTables.TableType.READ_GROUP_TABLE.ordinal()) {
reportTable.set(rowIndex, columnNames.get(columnIndex++).getFirst(), requestedCovariates[1].formatKey(keys[keyIndex++]));
if (tableIndex >= RecalibrationTables.TableType.OPTIONAL_COVARIATE_TABLES_START.index) {
if (tableIndex >= RecalibrationTables.TableType.OPTIONAL_COVARIATE_TABLES_START.ordinal()) {
final Covariate covariate = requestedCovariates[tableIndex];
reportTable.set(rowIndex, columnNames.get(columnIndex++).getFirst(), covariate.formatKey(keys[keyIndex++]));
@ -320,7 +320,7 @@ public class RecalUtils {
reportTable.set(rowIndex, columnNames.get(columnIndex++).getFirst(), event.toString());
reportTable.set(rowIndex, columnNames.get(columnIndex++).getFirst(), datum.getEmpiricalQuality());
if (tableIndex == RecalibrationTables.TableType.READ_GROUP_TABLE.index)
if (tableIndex == RecalibrationTables.TableType.READ_GROUP_TABLE.ordinal())
reportTable.set(rowIndex, columnNames.get(columnIndex++).getFirst(), datum.getEstimatedQReported()); // we only add the estimated Q reported in the RG table
reportTable.set(rowIndex, columnNames.get(columnIndex++).getFirst(), datum.getNumObservations());
reportTable.set(rowIndex, columnNames.get(columnIndex).getFirst(), datum.getNumMismatches());
@ -414,7 +414,7 @@ public class RecalUtils {
}
// add the optional covariates to the delta table
for (int i = RecalibrationTables.TableType.OPTIONAL_COVARIATE_TABLES_START.index; i < requestedCovariates.length; i++) {
for (int i = RecalibrationTables.TableType.OPTIONAL_COVARIATE_TABLES_START.ordinal(); i < requestedCovariates.length; i++) {
final NestedIntegerArray<RecalDatum> covTable = recalibrationTables.getTable(i);
for (final NestedIntegerArray.Leaf leaf : covTable.getAllLeaves()) {
final int[] covs = new int[4];
@ -458,9 +458,9 @@ public class RecalUtils {
private static List<Object> generateValuesFromKeys(final List<Object> keys, final Covariate[] covariates, final Map<Covariate, String> covariateNameMap) {
final List<Object> values = new ArrayList<Object>(4);
values.add(covariates[RecalibrationTables.TableType.READ_GROUP_TABLE.index].formatKey((Integer)keys.get(0)));
values.add(covariates[RecalibrationTables.TableType.READ_GROUP_TABLE.ordinal()].formatKey((Integer)keys.get(0)));
final int covariateIndex = (Integer)keys.get(1);
final Covariate covariate = covariateIndex == covariates.length ? covariates[RecalibrationTables.TableType.QUALITY_SCORE_TABLE.index] : covariates[covariateIndex];
final Covariate covariate = covariateIndex == covariates.length ? covariates[RecalibrationTables.TableType.QUALITY_SCORE_TABLE.ordinal()] : covariates[covariateIndex];
final int covariateKey = (Integer)keys.get(2);
values.add(covariate.formatKey(covariateKey));
values.add(covariateNameMap.get(covariate));
@ -793,4 +793,48 @@ public class RecalUtils {
myDatum.combine(row.value);
}
}
/**
* Increments the RecalDatum at the specified position in the specified table, or put a new item there
* if there isn't already one.
*
* Does this in a thread-safe way WITHOUT being synchronized: relies on the behavior of NestedIntegerArray.put()
* to return false if another thread inserts a new item at our position in the middle of our put operation.
*
* @param table the table that holds/will hold our item
* @param qual qual for this event
* @param isError error value for this event
* @param keys location in table of our item
*/
public static void incrementDatumOrPutIfNecessary( final NestedIntegerArray<RecalDatum> table,
final byte qual,
final double isError,
final int... keys ) {
final RecalDatum existingDatum = table.get(keys);
if ( existingDatum == null ) {
// No existing item, try to put a new one
if ( ! table.put(createDatumObject(qual, isError), keys) ) {
// Failed to put a new item because another thread came along and put an item here first.
// Get the newly-put item and increment it (item is guaranteed to exist at this point)
table.get(keys).increment(1.0, isError);
}
}
else {
// Easy case: already an item here, so increment it
existingDatum.increment(1.0, isError);
}
}
/**
* creates a datum object with one observation and one or zero error
*
* @param reportedQual the quality score reported by the instrument for this base
* @param isError whether or not the observation is an error
* @return a new RecalDatum object with the observation and the error
*/
private static RecalDatum createDatumObject(final byte reportedQual, final double isError) {
return new RecalDatum(1, isError, reportedQual);
}
}

View File

@ -139,12 +139,12 @@ public class RecalibrationReport {
final String covName = (String)reportTable.get(i, RecalUtils.COVARIATE_NAME_COLUMN_NAME);
final int covIndex = optionalCovariateIndexes.get(covName);
final Object covValue = reportTable.get(i, RecalUtils.COVARIATE_VALUE_COLUMN_NAME);
tempCOVarray[2] = requestedCovariates[RecalibrationTables.TableType.OPTIONAL_COVARIATE_TABLES_START.index + covIndex].keyFromValue(covValue);
tempCOVarray[2] = requestedCovariates[RecalibrationTables.TableType.OPTIONAL_COVARIATE_TABLES_START.ordinal() + covIndex].keyFromValue(covValue);
final EventType event = EventType.eventFrom((String)reportTable.get(i, RecalUtils.EVENT_TYPE_COLUMN_NAME));
tempCOVarray[3] = event.ordinal();
recalibrationTables.getTable(RecalibrationTables.TableType.OPTIONAL_COVARIATE_TABLES_START.index + covIndex).put(getRecalDatum(reportTable, i, false), tempCOVarray);
recalibrationTables.getTable(RecalibrationTables.TableType.OPTIONAL_COVARIATE_TABLES_START.ordinal() + covIndex).put(getRecalDatum(reportTable, i, false), tempCOVarray);
}
}

View File

@ -42,15 +42,9 @@ import java.util.ArrayList;
public final class RecalibrationTables {
public enum TableType {
READ_GROUP_TABLE(0),
QUALITY_SCORE_TABLE(1),
OPTIONAL_COVARIATE_TABLES_START(2);
public final int index;
private TableType(final int index) {
this.index = index;
}
READ_GROUP_TABLE,
QUALITY_SCORE_TABLE,
OPTIONAL_COVARIATE_TABLES_START;
}
private final ArrayList<NestedIntegerArray<RecalDatum>> tables;
@ -60,7 +54,7 @@ public final class RecalibrationTables {
private final PrintStream log;
public RecalibrationTables(final Covariate[] covariates) {
this(covariates, covariates[TableType.READ_GROUP_TABLE.index].maximumKeyValue() + 1, null);
this(covariates, covariates[TableType.READ_GROUP_TABLE.ordinal()].maximumKeyValue() + 1, null);
}
public RecalibrationTables(final Covariate[] covariates, final int numReadGroups) {
@ -72,31 +66,31 @@ public final class RecalibrationTables {
for ( int i = 0; i < covariates.length; i++ )
tables.add(i, null); // initialize so we can set below
qualDimension = covariates[TableType.QUALITY_SCORE_TABLE.index].maximumKeyValue() + 1;
qualDimension = covariates[TableType.QUALITY_SCORE_TABLE.ordinal()].maximumKeyValue() + 1;
this.numReadGroups = numReadGroups;
this.log = log;
tables.set(TableType.READ_GROUP_TABLE.index,
tables.set(TableType.READ_GROUP_TABLE.ordinal(),
log == null ? new NestedIntegerArray<RecalDatum>(numReadGroups, eventDimension) :
new LoggingNestedIntegerArray<RecalDatum>(log, "READ_GROUP_TABLE", numReadGroups, eventDimension));
tables.set(TableType.QUALITY_SCORE_TABLE.index, makeQualityScoreTable());
tables.set(TableType.QUALITY_SCORE_TABLE.ordinal(), makeQualityScoreTable());
for (int i = TableType.OPTIONAL_COVARIATE_TABLES_START.index; i < covariates.length; i++)
for (int i = TableType.OPTIONAL_COVARIATE_TABLES_START.ordinal(); i < covariates.length; i++)
tables.set(i,
log == null ? new NestedIntegerArray<RecalDatum>(numReadGroups, qualDimension, covariates[i].maximumKeyValue()+1, eventDimension) :
new LoggingNestedIntegerArray<RecalDatum>(log, String.format("OPTIONAL_COVARIATE_TABLE_%d", i - TableType.OPTIONAL_COVARIATE_TABLES_START.index + 1),
new LoggingNestedIntegerArray<RecalDatum>(log, String.format("OPTIONAL_COVARIATE_TABLE_%d", i - TableType.OPTIONAL_COVARIATE_TABLES_START.ordinal() + 1),
numReadGroups, qualDimension, covariates[i].maximumKeyValue()+1, eventDimension));
}
@Ensures("result != null")
public NestedIntegerArray<RecalDatum> getReadGroupTable() {
return getTable(TableType.READ_GROUP_TABLE.index);
return getTable(TableType.READ_GROUP_TABLE.ordinal());
}
@Ensures("result != null")
public NestedIntegerArray<RecalDatum> getQualityScoreTable() {
return getTable(TableType.QUALITY_SCORE_TABLE.index);
return getTable(TableType.QUALITY_SCORE_TABLE.ordinal());
}
@Ensures("result != null")

View File

@ -94,8 +94,8 @@ public class RecalibrationReportUnitTest {
qualTable.put(createRandomRecalDatum(randomMax, 10), covariates[0], covariates[1], errorMode.ordinal());
nKeys += 2;
for (int j = 0; j < optionalCovariates.size(); j++) {
final NestedIntegerArray<RecalDatum> covTable = recalibrationTables.getTable(RecalibrationTables.TableType.OPTIONAL_COVARIATE_TABLES_START.index + j);
final int covValue = covariates[RecalibrationTables.TableType.OPTIONAL_COVARIATE_TABLES_START.index + j];
final NestedIntegerArray<RecalDatum> covTable = recalibrationTables.getTable(RecalibrationTables.TableType.OPTIONAL_COVARIATE_TABLES_START.ordinal() + j);
final int covValue = covariates[RecalibrationTables.TableType.OPTIONAL_COVARIATE_TABLES_START.ordinal() + j];
if ( covValue >= 0 ) {
covTable.put(createRandomRecalDatum(randomMax, 10), covariates[0], covariates[1], covValue, errorMode.ordinal());
nKeys++;

View File

@ -29,15 +29,46 @@ import org.broadinstitute.sting.BaseTest;
import org.broadinstitute.sting.utils.collections.NestedIntegerArray;
import org.broadinstitute.sting.utils.recalibration.covariates.*;
import org.testng.Assert;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;
import java.util.Arrays;
import java.util.List;
public final class RecalibrationTablesUnitTest extends BaseTest {
private RecalibrationTables tables;
private Covariate[] covariates;
private int numReadGroups = 6;
final byte qualByte = 1;
final List<Integer> combineStates = Arrays.asList(0, 1, 2);
@BeforeMethod
private void makeTables() {
covariates = RecalibrationTestUtils.makeInitializedStandardCovariates();
tables = new RecalibrationTables(covariates, numReadGroups);
fillTable(tables);
}
private void fillTable(final RecalibrationTables tables) {
for ( int iterations = 0; iterations < 10; iterations++ ) {
for ( final EventType et : EventType.values() ) {
for ( final int rg : combineStates) {
final double error = rg % 2 == 0 ? 1 : 0;
RecalUtils.incrementDatumOrPutIfNecessary(tables.getReadGroupTable(), qualByte, error, rg, et.ordinal());
for ( final int qual : combineStates) {
RecalUtils.incrementDatumOrPutIfNecessary(tables.getQualityScoreTable(), qualByte, error, rg, qual, et.ordinal());
for ( final int cycle : combineStates)
RecalUtils.incrementDatumOrPutIfNecessary(tables.getTable(2), qualByte, error, rg, qual, cycle, et.ordinal());
for ( final int context : combineStates)
RecalUtils.incrementDatumOrPutIfNecessary(tables.getTable(3), qualByte, error, rg, qual, context, et.ordinal());
}
}
}
}
}
@Test
public void basicTest() {
final Covariate[] covariates = RecalibrationTestUtils.makeInitializedStandardCovariates();
final int numReadGroups = 6;
final RecalibrationTables tables = new RecalibrationTables(covariates, numReadGroups);
final Covariate qualCov = covariates[1];
final Covariate cycleCov = covariates[2];
final Covariate contextCov = covariates[3];
@ -45,11 +76,11 @@ public final class RecalibrationTablesUnitTest extends BaseTest {
Assert.assertEquals(tables.numTables(), covariates.length);
Assert.assertNotNull(tables.getReadGroupTable());
Assert.assertEquals(tables.getReadGroupTable(), tables.getTable(RecalibrationTables.TableType.READ_GROUP_TABLE.index));
Assert.assertEquals(tables.getReadGroupTable(), tables.getTable(RecalibrationTables.TableType.READ_GROUP_TABLE.ordinal()));
testDimensions(tables.getReadGroupTable(), numReadGroups);
Assert.assertNotNull(tables.getQualityScoreTable());
Assert.assertEquals(tables.getQualityScoreTable(), tables.getTable(RecalibrationTables.TableType.QUALITY_SCORE_TABLE.index));
Assert.assertEquals(tables.getQualityScoreTable(), tables.getTable(RecalibrationTables.TableType.QUALITY_SCORE_TABLE.ordinal()));
testDimensions(tables.getQualityScoreTable(), numReadGroups, qualCov.maximumKeyValue() + 1);
Assert.assertNotNull(tables.getTable(2));
@ -72,13 +103,74 @@ public final class RecalibrationTablesUnitTest extends BaseTest {
@Test
public void basicMakeQualityScoreTable() {
final Covariate[] covariates = RecalibrationTestUtils.makeInitializedStandardCovariates();
final int numReadGroups = 6;
final RecalibrationTables tables = new RecalibrationTables(covariates, numReadGroups);
final Covariate qualCov = covariates[1];
final NestedIntegerArray<RecalDatum> copy = tables.makeQualityScoreTable();
testDimensions(copy, numReadGroups, qualCov.maximumKeyValue()+1);
Assert.assertEquals(copy.getAllValues().size(), 0);
}
@Test
public void testCombine1() {
final RecalibrationTables merged = new RecalibrationTables(covariates, numReadGroups);
fillTable(merged);
merged.combine(tables);
for ( int i = 0; i < tables.numTables(); i++ ) {
NestedIntegerArray<RecalDatum> table = tables.getTable(i);
NestedIntegerArray<RecalDatum> mergedTable = merged.getTable(i);
Assert.assertEquals(table.getAllLeaves().size(), mergedTable.getAllLeaves().size());
for ( final NestedIntegerArray.Leaf<RecalDatum> leaf : table.getAllLeaves() ) {
final RecalDatum mergedValue = mergedTable.get(leaf.keys);
Assert.assertNotNull(mergedValue);
Assert.assertEquals(mergedValue.getNumObservations(), leaf.value.getNumObservations() * 2);
Assert.assertEquals(mergedValue.getNumMismatches(), leaf.value.getNumMismatches() * 2);
}
}
}
@Test
public void testCombineEmptyOther() {
final RecalibrationTables merged = new RecalibrationTables(covariates, numReadGroups);
merged.combine(tables);
for ( int i = 0; i < tables.numTables(); i++ ) {
NestedIntegerArray<RecalDatum> table = tables.getTable(i);
NestedIntegerArray<RecalDatum> mergedTable = merged.getTable(i);
Assert.assertEquals(table.getAllLeaves().size(), mergedTable.getAllLeaves().size());
for ( final NestedIntegerArray.Leaf<RecalDatum> leaf : table.getAllLeaves() ) {
final RecalDatum mergedValue = mergedTable.get(leaf.keys);
Assert.assertNotNull(mergedValue);
Assert.assertEquals(mergedValue.getNumObservations(), leaf.value.getNumObservations());
Assert.assertEquals(mergedValue.getNumMismatches(), leaf.value.getNumMismatches());
}
}
}
@Test
public void testCombinePartial() {
final RecalibrationTables merged = new RecalibrationTables(covariates, numReadGroups);
for ( final int rg : combineStates) {
RecalUtils.incrementDatumOrPutIfNecessary(merged.getTable(3), qualByte, 1, rg, 0, 0, 0);
}
merged.combine(tables);
for ( int i = 0; i < tables.numTables(); i++ ) {
NestedIntegerArray<RecalDatum> table = tables.getTable(i);
NestedIntegerArray<RecalDatum> mergedTable = merged.getTable(i);
Assert.assertEquals(table.getAllLeaves().size(), mergedTable.getAllLeaves().size());
for ( final NestedIntegerArray.Leaf<RecalDatum> leaf : table.getAllLeaves() ) {
final RecalDatum mergedValue = mergedTable.get(leaf.keys);
Assert.assertNotNull(mergedValue);
final int delta = i == 3 && leaf.keys[1] == 0 && leaf.keys[2] == 0 && leaf.keys[3] == 0 ? 1 : 0;
Assert.assertEquals(mergedValue.getNumObservations(), leaf.value.getNumObservations() + delta);
Assert.assertEquals(mergedValue.getNumMismatches(), leaf.value.getNumMismatches() + delta);
}
}
}
}