Initial commit of the delocalized BQSR written as a read walker.

This commit is contained in:
Ryan Poplin 2012-08-28 15:24:20 -04:00
parent e74c527d47
commit 18eca3544e
7 changed files with 83 additions and 21 deletions

View File

@ -34,17 +34,20 @@ import org.broadinstitute.sting.utils.recalibration.EventType;
import org.broadinstitute.sting.utils.recalibration.ReadCovariates;
import org.broadinstitute.sting.utils.recalibration.RecalDatum;
import org.broadinstitute.sting.utils.recalibration.RecalibrationTables;
import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
public class AdvancedRecalibrationEngine extends StandardRecalibrationEngine implements ProtectedPackageSource {
// optimizations: don't reallocate an array each time
// Scratch buffers reused across calls: initialize() sizes each one to EventType.values().length
// and updateDataForRead addresses the slots by EventType.index.

// reported base quality for the base currently being processed, one slot per event type
private byte[] tempQualArray;
// one boolean slot per event type; not touched by updateDataForRead
// NOTE(review): presumably filled by the pileup-based update path — confirm
private boolean[] tempErrorArray;
// fractional (double-valued) error for the base currently being processed, one slot per event type
private double[] tempFractionalErrorArray;
/**
 * Initializes the engine and allocates the per-event-type scratch arrays.
 *
 * Delegates to the superclass initializer first, then creates one slot per
 * {@link EventType} value in each of the three temporary arrays.
 *
 * @param covariates          the covariates, forwarded unchanged to the superclass
 * @param recalibrationTables the recalibration tables, forwarded unchanged to the superclass
 */
public void initialize(final Covariate[] covariates, final RecalibrationTables recalibrationTables) {
    super.initialize(covariates, recalibrationTables);

    // every scratch array gets exactly one entry per event type
    final int numEventTypes = EventType.values().length;
    tempQualArray = new byte[numEventTypes];
    tempErrorArray = new boolean[numEventTypes];
    tempFractionalErrorArray = new double[numEventTypes];
}
/**
@ -56,6 +59,7 @@ public class AdvancedRecalibrationEngine extends StandardRecalibrationEngine imp
* @param pileupElement The pileup element to update
* @param refBase The reference base at this locus
*/
@Override
public synchronized void updateDataForPileupElement(final PileupElement pileupElement, final byte refBase) {
final int offset = pileupElement.getOffset();
final ReadCovariates readCovariates = covariateKeySetFrom(pileupElement.getRead());
@ -100,4 +104,51 @@ public class AdvancedRecalibrationEngine extends StandardRecalibrationEngine imp
}
}
}
/**
 * Updates the recalibration tables with every base of a read, using fractional error values.
 *
 * For each read offset and each {@link EventType}, the base's reported quality and the
 * matching error value are folded into the read-group table, the quality-score table and
 * every additional covariate table whose key at that offset is non-negative.
 *
 * @param read            the read whose bases are tallied
 * @param snpErrors       error value per read offset for base substitutions
 * @param insertionErrors error value per read offset for base insertions
 * @param deletionErrors  error value per read offset for base deletions
 */
@Override
public synchronized void updateDataForRead(final GATKSAMRecord read, final double[] snpErrors, final double[] insertionErrors, final double[] deletionErrors ) {
    final int readLength = read.getReadBases().length;
    if (readLength == 0) {
        // no bases to tally; return before any lookup, as the per-base loop would have done
        return;
    }

    // None of the values below depend on the offset or the event type, so they are fetched
    // once per read instead of once per base (or once per base per event type).
    // NOTE(review): assumes these accessors are side-effect free and return the same data on
    // repeated calls — confirm against GATKSAMRecord / RecalibrationTables.
    final ReadCovariates readCovariates = covariateKeySetFrom(read);
    final byte[] baseQuals = read.getBaseQualities();
    final byte[] insertionQuals = read.getBaseInsertionQualities();
    final byte[] deletionQuals = read.getBaseDeletionQualities();
    final NestedIntegerArray<RecalDatum> rgRecalTable = recalibrationTables.getTable(RecalibrationTables.TableType.READ_GROUP_TABLE);
    final NestedIntegerArray<RecalDatum> qualRecalTable = recalibrationTables.getTable(RecalibrationTables.TableType.QUALITY_SCORE_TABLE);

    for (int offset = 0; offset < readLength; offset++) {
        // stage the quality and fractional error of this base for each event type
        tempQualArray[EventType.BASE_SUBSTITUTION.index] = baseQuals[offset];
        tempFractionalErrorArray[EventType.BASE_SUBSTITUTION.index] = snpErrors[offset];
        tempQualArray[EventType.BASE_INSERTION.index] = insertionQuals[offset];
        tempFractionalErrorArray[EventType.BASE_INSERTION.index] = insertionErrors[offset];
        tempQualArray[EventType.BASE_DELETION.index] = deletionQuals[offset];
        tempFractionalErrorArray[EventType.BASE_DELETION.index] = deletionErrors[offset];

        for (final EventType eventType : EventType.values()) {
            final int[] keys = readCovariates.getKeySet(offset, eventType);
            final int eventIndex = eventType.index;
            final byte qual = tempQualArray[eventIndex];
            final double isError = tempFractionalErrorArray[eventIndex];

            // read-group table: keyed by keys[0]; an existing datum is merged via combine()
            final RecalDatum rgPreviousDatum = rgRecalTable.get(keys[0], eventIndex);
            final RecalDatum rgThisDatum = createDatumObject(qual, isError);
            if (rgPreviousDatum == null) { // key doesn't exist yet in the map so make a new bucket and add it
                rgRecalTable.put(rgThisDatum, keys[0], eventIndex);
            } else {
                rgPreviousDatum.combine(rgThisDatum);
            }

            // quality-score table: keyed by keys[0] and keys[1]; an existing datum is incremented
            final RecalDatum qualPreviousDatum = qualRecalTable.get(keys[0], keys[1], eventIndex);
            if (qualPreviousDatum == null) {
                qualRecalTable.put(createDatumObject(qual, isError), keys[0], keys[1], eventIndex);
            } else {
                qualPreviousDatum.increment(1.0, isError);
            }

            // remaining covariates: one table per covariate index, starting at index 2
            for (int i = 2; i < covariates.length; i++) {
                if (keys[i] < 0) {
                    // a negative key means there is nothing to record for this covariate here
                    continue;
                }
                final NestedIntegerArray<RecalDatum> covRecalTable = recalibrationTables.getTable(i);
                final RecalDatum covPreviousDatum = covRecalTable.get(keys[0], keys[1], keys[i], eventIndex);
                if (covPreviousDatum == null) {
                    covRecalTable.put(createDatumObject(qual, isError), keys[0], keys[1], keys[i], eventIndex);
                } else {
                    covPreviousDatum.increment(1.0, isError);
                }
            }
        }
    }
}
}

View File

@ -3,6 +3,7 @@ package org.broadinstitute.sting.gatk.walkers.bqsr;
import org.broadinstitute.sting.utils.recalibration.covariates.Covariate;
import org.broadinstitute.sting.utils.pileup.PileupElement;
import org.broadinstitute.sting.utils.recalibration.RecalibrationTables;
import org.broadinstitute.sting.utils.sam.GATKSAMRecord;
/*
* Copyright (c) 2009 The Broad Institute
@ -34,4 +35,5 @@ public interface RecalibrationEngine {
public void updateDataForPileupElement(final PileupElement pileupElement, final byte refBase);
/**
 * Updates the recalibration data from a whole read rather than from a single pileup element.
 *
 * NOTE(review): the per-array contract below is taken from the AdvancedRecalibrationEngine
 * implementation, which reads one entry per read offset from each array; confirm that this
 * is the intended contract for all implementations. StandardRecalibrationEngine rejects the
 * call with an UnsupportedOperationException.
 *
 * @param read            the read to process
 * @param snpErrors       error values for base substitutions, one per read offset
 * @param insertionErrors error values for base insertions, one per read offset
 * @param deletionErrors  error values for base deletions, one per read offset
 */
public void updateDataForRead(final GATKSAMRecord read, final double[] snpErrors, final double[] insertionErrors, final double[] deletionErrors);
}

View File

@ -54,6 +54,7 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP
* @param pileupElement The pileup element to update
* @param refBase The reference base at this locus
*/
@Override
public synchronized void updateDataForPileupElement(final PileupElement pileupElement, final byte refBase) {
final int offset = pileupElement.getOffset();
final ReadCovariates readCovariates = covariateKeySetFrom(pileupElement.getRead());
@ -91,6 +92,11 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP
}
}
/**
 * Read-based update entry point; this engine does not implement it.
 *
 * All arguments are ignored: the method unconditionally throws.
 *
 * @param read            unused
 * @param snpErrors       unused
 * @param insertionErrors unused
 * @param deletionErrors  unused
 * @throws UnsupportedOperationException always
 */
@Override
public synchronized void updateDataForRead( final GATKSAMRecord read, final double[] snpErrors, final double[] insertionErrors, final double[] deletionErrors ) {
throw new UnsupportedOperationException("Delocalized BQSR is not available in the GATK-lite version");
}
/**
* creates a datum object with one observation and one or zero error
*
@ -102,6 +108,10 @@ public class StandardRecalibrationEngine implements RecalibrationEngine, PublicP
return new RecalDatum(1, isError ? 1:0, reportedQual);
}
/**
 * creates a datum object with one observation and a fractional number of errors
 *
 * Unlike the boolean overload, which records exactly one or zero errors, the value of
 * isError is handed to the RecalDatum constructor unchanged as the mismatch count.
 *
 * @param reportedQual the quality score passed to the new datum as its reported quality
 * @param isError      the mismatch count for the single observation, as a double
 * @return a new RecalDatum built from (1, isError, reportedQual)
 */
protected RecalDatum createDatumObject(final byte reportedQual, final double isError) {
return new RecalDatum(1, isError, reportedQual);
}
/**
* Get the covariate key set from a read
*

View File

@ -39,7 +39,7 @@ public class QuantizationInfo {
for (final RecalDatum value : qualTable.getAllValues()) {
final RecalDatum datum = value;
final int empiricalQual = MathUtils.fastRound(datum.getEmpiricalQuality()); // convert the empirical quality to an integer ( it is already capped by MAX_QUAL )
qualHistogram[empiricalQual] += datum.getNumObservations(); // add the number of observations for every key
qualHistogram[empiricalQual] += (long) datum.getNumObservations(); // add the number of observations for every key
}
empiricalQualCounts = Arrays.asList(qualHistogram); // histogram with the number of observations of the empirical qualities
quantizeQualityScores(quantizationLevels);

View File

@ -28,7 +28,6 @@ package org.broadinstitute.sting.utils.recalibration;
import com.google.java.contract.Ensures;
import com.google.java.contract.Invariant;
import com.google.java.contract.Requires;
import org.broadinstitute.sting.utils.MathUtils;
import org.broadinstitute.sting.utils.QualityUtils;
import java.util.Random;
@ -68,12 +67,12 @@ public class RecalDatum {
/**
* number of bases seen in total
*/
private long numObservations;
private double numObservations;
/**
* number of bases seen that didn't match the reference
*/
private long numMismatches;
private double numMismatches;
/**
* used when calculating empirical qualities to avoid division by zero
@ -93,7 +92,7 @@ public class RecalDatum {
* @param _numMismatches
* @param reportedQuality
*/
public RecalDatum(final long _numObservations, final long _numMismatches, final byte reportedQuality) {
public RecalDatum(final double _numObservations, final double _numMismatches, final byte reportedQuality) {
if ( _numObservations < 0 ) throw new IllegalArgumentException("numObservations < 0");
if ( _numMismatches < 0 ) throw new IllegalArgumentException("numMismatches < 0");
if ( reportedQuality < 0 ) throw new IllegalArgumentException("reportedQuality < 0");
@ -167,9 +166,9 @@ public class RecalDatum {
return 0.0;
else {
// cache the value so we don't call log over and over again
final double doubleMismatches = (double) (numMismatches + SMOOTHING_CONSTANT);
final double doubleMismatches = numMismatches + SMOOTHING_CONSTANT;
// smoothing is one error and one non-error observation, for example
final double doubleObservations = (double) (numObservations + SMOOTHING_CONSTANT + SMOOTHING_CONSTANT);
final double doubleObservations = numObservations + SMOOTHING_CONSTANT + SMOOTHING_CONSTANT;
return doubleMismatches / doubleObservations;
}
}
@ -200,7 +199,7 @@ public class RecalDatum {
@Override
public String toString() {
return String.format("%d,%d,%d", getNumObservations(), getNumMismatches(), (byte) Math.floor(getEmpiricalQuality()));
return String.format("%d,%d,%d", Math.round(getNumObservations()), Math.round(getNumMismatches()), (byte) Math.floor(getEmpiricalQuality()));
}
public String stringForCSV() {
@ -229,42 +228,42 @@ public class RecalDatum {
//
//---------------------------------------------------------------------------------------------------------------
public long getNumObservations() {
public double getNumObservations() {
return numObservations;
}
public synchronized void setNumObservations(final long numObservations) {
public synchronized void setNumObservations(final double numObservations) {
if ( numObservations < 0 ) throw new IllegalArgumentException("numObservations < 0");
this.numObservations = numObservations;
empiricalQuality = UNINITIALIZED;
}
public long getNumMismatches() {
public double getNumMismatches() {
return numMismatches;
}
@Requires({"numMismatches >= 0"})
public synchronized void setNumMismatches(final long numMismatches) {
public synchronized void setNumMismatches(final double numMismatches) {
if ( numMismatches < 0 ) throw new IllegalArgumentException("numMismatches < 0");
this.numMismatches = numMismatches;
empiricalQuality = UNINITIALIZED;
}
@Requires({"by >= 0"})
public synchronized void incrementNumObservations(final long by) {
public synchronized void incrementNumObservations(final double by) {
numObservations += by;
empiricalQuality = UNINITIALIZED;
}
@Requires({"by >= 0"})
public synchronized void incrementNumMismatches(final long by) {
public synchronized void incrementNumMismatches(final double by) {
numMismatches += by;
empiricalQuality = UNINITIALIZED;
}
@Requires({"incObservations >= 0", "incMismatches >= 0"})
@Ensures({"numObservations == old(numObservations) + incObservations", "numMismatches == old(numMismatches) + incMismatches"})
public synchronized void increment(final long incObservations, final long incMismatches) {
public synchronized void increment(final double incObservations, final double incMismatches) {
incrementNumObservations(incObservations);
incrementNumMismatches(incMismatches);
}
@ -300,6 +299,6 @@ public class RecalDatum {
*/
@Ensures("result >= 0.0")
private double calcExpectedErrors() {
return (double) getNumObservations() * QualityUtils.qualToErrorProb(estimatedQReported);
return getNumObservations() * QualityUtils.qualToErrorProb(estimatedQReported);
}
}

View File

@ -263,14 +263,14 @@ public class RecalDatumNode<T extends RecalDatum> {
int i = 0;
for ( final RecalDatumNode<T> subnode : subnodes ) {
// use the yates correction to help avoid all zeros => NaN
counts[i][0] = subnode.getRecalDatum().getNumMismatches() + 1;
counts[i][1] = subnode.getRecalDatum().getNumObservations() + 2;
counts[i][0] = Math.round(subnode.getRecalDatum().getNumMismatches()) + 1L;
counts[i][1] = Math.round(subnode.getRecalDatum().getNumObservations()) + 2L;
i++;
}
try {
final double chi2PValue = new ChiSquareTestImpl().chiSquareTest(counts);
final double penalty = -10 * Math.log10(Math.max(chi2PValue, SMALLEST_CHI2_PVALUE));
final double penalty = -10.0 * Math.log10(Math.max(chi2PValue, SMALLEST_CHI2_PVALUE));
// make sure things are reasonable and fail early if not
if (Double.isInfinite(penalty) || Double.isNaN(penalty))

View File

@ -317,8 +317,8 @@ public class RecalUtils {
reportTable.set(rowIndex, columnNames.get(columnIndex++).getFirst(), datum.getEmpiricalQuality());
if (tableIndex == RecalibrationTables.TableType.READ_GROUP_TABLE.index)
reportTable.set(rowIndex, columnNames.get(columnIndex++).getFirst(), datum.getEstimatedQReported()); // we only add the estimated Q reported in the RG table
reportTable.set(rowIndex, columnNames.get(columnIndex++).getFirst(), datum.getNumObservations());
reportTable.set(rowIndex, columnNames.get(columnIndex).getFirst(), datum.getNumMismatches());
reportTable.set(rowIndex, columnNames.get(columnIndex++).getFirst(), Math.round(datum.getNumObservations()));
reportTable.set(rowIndex, columnNames.get(columnIndex).getFirst(), Math.round(datum.getNumMismatches()));
rowIndex++;
}