Refactoring of PairHMM to support reduced reads

Mark DePristo 2011-09-26 13:28:56 -04:00
parent a6b65d6347
commit fa0efbc4ca
1 changed file with 151 additions and 144 deletions

org/broadinstitute/sting/gatk/walkers/indels/PairHMMIndelErrorModel.java

@@ -28,6 +28,7 @@ package org.broadinstitute.sting.gatk.walkers.indels;
import net.sf.samtools.Cigar;
import net.sf.samtools.CigarElement;
import net.sf.samtools.CigarOperator;
+ import net.sf.samtools.SAMRecord;
import org.broadinstitute.sting.gatk.contexts.ReferenceContext;
import org.broadinstitute.sting.utils.Haplotype;
import org.broadinstitute.sting.utils.MathUtils;
@@ -244,31 +245,31 @@ public class PairHMMIndelErrorModel {
/**
* For each covariate read in a value and parse it. Associate those values with the data itself (num observation and num mismatches)
*/
/*
private void addCSVData(final File file, final String line) {
final String[] vals = line.split(",");
// Check if the data line is malformed, for example if the read group string contains a comma then it won't be parsed correctly
if( vals.length != requestedCovariates.size() + 3 ) { // +3 because of nObservations, nMismatch, and Qempirical
throw new UserException.MalformedFile(file, "Malformed input recalibration file. Found data line with too many fields: " + line +
" --Perhaps the read group string contains a comma and isn't being parsed correctly.");
}
final Object[] key = new Object[requestedCovariates.size()];
Covariate cov;
int iii;
for( iii = 0; iii < requestedCovariates.size(); iii++ ) {
cov = requestedCovariates.get( iii );
key[iii] = cov.getValue( vals[iii] );
}
// Create a new datum using the number of observations, number of mismatches, and reported quality score
final RecalDatum datum = new RecalDatum( Long.parseLong( vals[iii] ), Long.parseLong( vals[iii + 1] ), Double.parseDouble( vals[1] ), 0.0 );
// Add that datum to all the collapsed tables which will be used in the sequential calculation
dataManager.addToAllTables( key, datum, PRESERVE_QSCORES_LESS_THAN );
}
*/
public PairHMMIndelErrorModel(double indelGOP, double indelGCP, boolean deb, boolean doCDP, boolean dovit) {
this(indelGOP, indelGCP, deb, doCDP);
this.doViterbi = dovit;
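The commented-out addCSVData above documents the flat CSV layout the old recalibration tables were loaded from: one column per covariate, followed by the observation count, the mismatch count, and the empirical quality, with the reported quality sitting in column 1. The following is a minimal standalone sketch of parsing such a line; it is not part of this commit, and Covariate, RecalDatum and dataManager from the real code are replaced by plain locals.

import java.util.Arrays;

public class RecalCsvLineSketch {
    // Parse one line: leading covariate columns, then nObservations, nMismatches, Qempirical.
    static void parseLine(final String line, final int numCovariates) {
        final String[] vals = line.split(",");
        if (vals.length != numCovariates + 3)   // +3 because of nObservations, nMismatches, Qempirical
            throw new IllegalArgumentException("Malformed recalibration line: " + line);

        final String[] covariateKey = Arrays.copyOfRange(vals, 0, numCovariates); // table lookup key
        final long nObservations = Long.parseLong(vals[numCovariates]);
        final long nMismatches   = Long.parseLong(vals[numCovariates + 1]);
        final double reportedQ   = Double.parseDouble(vals[1]);  // reported quality sits in column 1

        System.out.printf("key=%s obs=%d mismatches=%d reportedQ=%.1f%n",
                Arrays.toString(covariateKey), nObservations, nMismatches, reportedQ);
    }

    public static void main(String[] args) {
        // covariates: read group, reported quality, dinucleotide; then counts and empirical Q
        parseLine("rg1,30,AC,1000,25,29.5", 3);
    }
}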
@@ -588,7 +589,7 @@ public class PairHMMIndelErrorModel {
}
else {
c = currentGOP[jm1];
d = currentGCP[jm1];
}
if (indI == X_METRIC_LENGTH-1)
c = d = END_GAP_COST;
@@ -707,12 +708,12 @@ public class PairHMMIndelErrorModel {
}
}
public synchronized double[] computeReadHaplotypeLikelihoods(ReadBackedPileup pileup, LinkedHashMap<Allele,Haplotype> haplotypeMap,
ReferenceContext ref, int eventLength,
HashMap<PileupElement, LinkedHashMap<Allele,Double>> indelLikelihoodMap){
int numHaplotypes = haplotypeMap.size();
- double[][] haplotypeLikehoodMatrix = new double[numHaplotypes][numHaplotypes];
- double readLikelihoods[][] = new double[pileup.getReads().size()][numHaplotypes];
+ final double readLikelihoods[][] = new double[pileup.size()][numHaplotypes];
+ final int readCounts[] = new int[pileup.size()];
int readIdx=0;
LinkedHashMap<Allele,double[]> gapOpenProbabilityMap = new LinkedHashMap<Allele,double[]>();
@@ -751,6 +752,9 @@ public class PairHMMIndelErrorModel {
}
}
for (PileupElement p: pileup) {
+ // > 1 when the read is a consensus read representing multiple independent observations
+ final boolean isReduced = ReadUtils.isReducedRead(p.getRead());
+ readCounts[readIdx] = isReduced ? p.getReducedCount() : 1;
// check if we've already computed likelihoods for this pileup element (i.e. for this read at this location)
if (indelLikelihoodMap.containsKey(p)) {
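readCounts is the bookkeeping this commit adds so that a reduced (consensus) read can stand in for several original reads when the per-read likelihoods are summed later on. A minimal sketch of that weighting, not part of this commit and using plain arrays instead of the GATK pileup types: because reads are treated as independent, a consensus read seen N times simply contributes its log10 likelihood N times.

public class ReducedReadWeightingSketch {
    // Sum per-read log10 likelihoods against one haplotype, weighting each entry by how many
    // original reads it represents (1 for a normal read, >1 for a reduced/consensus read).
    static double weightedLog10Sum(final double[] readLog10Likelihoods, final int[] readCounts) {
        double sum = 0.0;
        for (int r = 0; r < readLog10Likelihoods.length; r++) {
            sum += readCounts[r] * readLog10Likelihoods[r];
        }
        return sum;
    }

    public static void main(String[] args) {
        final double[] lik = { -0.5, -2.0 };  // log10 Pr(read | haplotype)
        final int[] counts = { 3, 1 };        // first entry is a consensus of 3 observations
        System.out.println(weightedLog10Sum(lik, counts)); // -3.5 = 3*(-0.5) + 1*(-2.0)
    }
}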
@@ -762,61 +766,65 @@ public class PairHMMIndelErrorModel {
}
else {
//System.out.format("%d %s\n",p.getRead().getAlignmentStart(), p.getRead().getClass().getName());
- GATKSAMRecord read = ReadUtils.hardClipAdaptorSequence(p.getRead());
+ SAMRecord read = ReadUtils.hardClipAdaptorSequence(p.getRead());
if (read == null)
continue;
+ if ( isReduced ) {
+ read = ReadUtils.reducedReadWithReducedQuals(read);
+ }
if(ReadUtils.is454Read(read) && !getGapPenaltiesFromFile) {
continue;
}
double[] recalQuals = null;
/*
if (getGapPenaltiesFromFile) {
RecalDataManager.parseSAMRecord( read, RAC );
recalQuals = new double[read.getReadLength()];
//compute all covariate values for this read
final Comparable[][] covariateValues_offset_x_covar =
RecalDataManager.computeCovariates((GATKSAMRecord) read, requestedCovariates);
// For each base in the read
for( int offset = 0; offset < read.getReadLength(); offset++ ) {
final Object[] fullCovariateKey = covariateValues_offset_x_covar[offset];
Byte qualityScore = (Byte) qualityScoreByFullCovariateKey.get(fullCovariateKey);
if(qualityScore == null)
{
qualityScore = performSequentialQualityCalculation( fullCovariateKey );
qualityScoreByFullCovariateKey.put(qualityScore, fullCovariateKey);
}
recalQuals[offset] = -((double)qualityScore)/10.0;
}
// for each read/haplotype combination, compute likelihoods, ie -10*log10(Pr(R | Hi))
// = sum_j(-10*log10(Pr(R_j | Hi) since reads are assumed to be independent
if (DEBUG) {
System.out.format("\n\nStarting read:%s S:%d US:%d E:%d UE:%d C:%s\n",read.getReadName(),
read.getAlignmentStart(),
read.getUnclippedStart(), read.getAlignmentEnd(), read.getUnclippedEnd(),
read.getCigarString());
byte[] bases = read.getReadBases();
for (int k = 0; k < recalQuals.length; k++) {
System.out.format("%c",bases[k]);
}
System.out.println();
for (int k = 0; k < recalQuals.length; k++) {
System.out.format("%.0f ",recalQuals[k]);
}
System.out.println();
}
} */
// get bases of candidate haplotypes that overlap with reads
final int trailingBases = 3;
@@ -971,7 +979,7 @@ public class PairHMMIndelErrorModel {
System.out.println(new String(haplotypeBases));
}
- Double readLikelihood = 0.0;
+ double readLikelihood = 0.0;
if (useAffineGapModel) {
double[] currentContextGOP = null;
@@ -979,14 +987,14 @@ public class PairHMMIndelErrorModel {
if (doContextDependentPenalties) {
if (getGapPenaltiesFromFile) {
readLikelihood = computeReadLikelihoodGivenHaplotypeAffineGaps(haplotypeBases, readBases, readQuals, recalCDP, null);
} else {
currentContextGOP = Arrays.copyOfRange(gapOpenProbabilityMap.get(a), (int)indStart, (int)indStop);
currentContextGCP = Arrays.copyOfRange(gapContProbabilityMap.get(a), (int)indStart, (int)indStop);
readLikelihood = computeReadLikelihoodGivenHaplotypeAffineGaps(haplotypeBases, readBases, readQuals, currentContextGOP, currentContextGCP);
}
}
}
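When context-dependent penalties are in use and no external penalty file is given, the per-haplotype gap-open and gap-continuation arrays are cut down with Arrays.copyOfRange to the window [indStart, indStop) that the read overlaps before being handed to the HMM. A small self-contained illustration of that slicing, with invented penalty values; the map, indices and variable names above are the real ones, nothing below is from the commit.

import java.util.Arrays;

public class GapPenaltyWindowSketch {
    public static void main(String[] args) {
        // log10 gap-open penalties along a haplotype, one entry per base (made-up values)
        final double[] gapOpenForHaplotype = { -4.0, -4.0, -2.5, -2.5, -4.0, -4.0, -4.0 };
        final int indStart = 2, indStop = 5;   // window of the haplotype covered by the read

        // same call the model uses to restrict the penalties to the read's window
        final double[] currentContextGOP = Arrays.copyOfRange(gapOpenForHaplotype, indStart, indStop);
        System.out.println(Arrays.toString(currentContextGOP)); // [-2.5, -2.5, -4.0]
    }
}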
@@ -1004,7 +1012,7 @@ public class PairHMMIndelErrorModel {
if (DEBUG) {
System.out.println("\nLikelihood summary");
- for (readIdx=0; readIdx < pileup.getReads().size(); readIdx++) {
+ for (readIdx=0; readIdx < pileup.size(); readIdx++) {
System.out.format("Read Index: %d ",readIdx);
for (int i=0; i < readLikelihoods[readIdx].length; i++)
System.out.format("L%d: %f ",i,readLikelihoods[readIdx][i]);
@@ -1012,36 +1020,35 @@ public class PairHMMIndelErrorModel {
}
}
+ return getHaplotypeLikelihoods(numHaplotypes, readCounts, readLikelihoods);
+ }
+ private final static double[] getHaplotypeLikelihoods(final int numHaplotypes, final int readCounts[], final double readLikelihoods[][]) {
+ final double[][] haplotypeLikehoodMatrix = new double[numHaplotypes][numHaplotypes];
+ // todo: MAD 09/26/11 -- I'm almost certain this calculation can be simplied to just a single loop without the intermediate NxN matrix
for (int i=0; i < numHaplotypes; i++) {
for (int j=i; j < numHaplotypes; j++){
// combine likelihoods of haplotypeLikelihoods[i], haplotypeLikelihoods[j]
// L(Hi, Hj) = sum_reads ( Pr(R|Hi)/2 + Pr(R|Hj)/2)
//readLikelihoods[k][j] has log10(Pr(R_k) | H[j] )
- for (readIdx=0; readIdx < pileup.getReads().size(); readIdx++) {
+ for (int readIdx = 0; readIdx < readLikelihoods.length; readIdx++) {
// Compute log10(10^x1/2 + 10^x2/2) = log10(10^x1+10^x2)-log10(2)
// First term is approximated by Jacobian log with table lookup.
if (Double.isInfinite(readLikelihoods[readIdx][i]) && Double.isInfinite(readLikelihoods[readIdx][j]))
continue;
- haplotypeLikehoodMatrix[i][j] += ( MathUtils.softMax(readLikelihoods[readIdx][i],
- readLikelihoods[readIdx][j]) + LOG_ONE_HALF);
+ final double li = readLikelihoods[readIdx][i];
+ final double lj = readLikelihoods[readIdx][j];
+ final int readCount = readCounts[readIdx];
+ haplotypeLikehoodMatrix[i][j] += readCount * (MathUtils.softMax(li, lj) + LOG_ONE_HALF);
}
}
}
- return getHaplotypeLikelihoods(haplotypeLikehoodMatrix);
- }
- public static double[] getHaplotypeLikelihoods(double[][] haplotypeLikehoodMatrix) {
- int hSize = haplotypeLikehoodMatrix.length;
- double[] genotypeLikelihoods = new double[hSize*(hSize+1)/2];
+ final double[] genotypeLikelihoods = new double[numHaplotypes*(numHaplotypes+1)/2];
int k=0;
- for (int j=0; j < hSize; j++) {
+ for (int j=0; j < numHaplotypes; j++) {
for (int i=0; i <= j; i++){
genotypeLikelihoods[k++] = haplotypeLikehoodMatrix[i][j];
}
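The pair likelihood above uses the identity noted in the code comments, log10(10^x1/2 + 10^x2/2) = log10(10^x1 + 10^x2) + log10(1/2), where the first term is MathUtils.softMax (approximated in GATK by a Jacobian-log table lookup); the new readCount factor repeats that per-read contribution once per underlying observation of a reduced read. A standalone sketch of the per-read update, not from the commit, with softMax reimplemented naively.

public class HaplotypePairLikelihoodSketch {
    private static final double LOG_ONE_HALF = Math.log10(0.5);

    // log10(10^a + 10^b); the GATK version (MathUtils.softMax) approximates this with a table lookup
    static double softMax(final double a, final double b) {
        final double m = Math.max(a, b);
        return m + Math.log10(Math.pow(10.0, a - m) + Math.pow(10.0, b - m));
    }

    public static void main(String[] args) {
        final double li = -1.0;  // log10 Pr(read | haplotype i)
        final double lj = -3.0;  // log10 Pr(read | haplotype j)
        final int readCount = 4; // consensus read representing 4 underlying observations

        // per-read term of L(Hi, Hj) = sum_reads log10( Pr(R|Hi)/2 + Pr(R|Hj)/2 )
        final double perRead = softMax(li, lj) + LOG_ONE_HALF;
        System.out.printf("per read: %.4f  weighted: %.4f%n", perRead, readCount * perRead);
    }
}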
@@ -1066,63 +1073,63 @@ public class PairHMMIndelErrorModel {
* @param key The list of Comparables that were calculated from the covariates
* @return A recalibrated quality score as a byte
*/
/*
private byte performSequentialQualityCalculation( final Object... key ) {
final byte qualFromRead = (byte)Integer.parseInt(key[1].toString());
final Object[] readGroupCollapsedKey = new Object[1];
final Object[] qualityScoreCollapsedKey = new Object[2];
final Object[] covariateCollapsedKey = new Object[3];
// The global quality shift (over the read group only)
readGroupCollapsedKey[0] = key[0];
final RecalDatum globalRecalDatum = ((RecalDatum)dataManager.getCollapsedTable(0).get( readGroupCollapsedKey ));
double globalDeltaQ = 0.0;
if( globalRecalDatum != null ) {
final double globalDeltaQEmpirical = globalRecalDatum.getEmpiricalQuality();
final double aggregrateQReported = globalRecalDatum.getEstimatedQReported();
globalDeltaQ = globalDeltaQEmpirical - aggregrateQReported;
}
// The shift in quality between reported and empirical
qualityScoreCollapsedKey[0] = key[0];
qualityScoreCollapsedKey[1] = key[1];
final RecalDatum qReportedRecalDatum = ((RecalDatum)dataManager.getCollapsedTable(1).get( qualityScoreCollapsedKey ));
double deltaQReported = 0.0;
if( qReportedRecalDatum != null ) {
final double deltaQReportedEmpirical = qReportedRecalDatum.getEmpiricalQuality();
deltaQReported = deltaQReportedEmpirical - qualFromRead - globalDeltaQ;
}
// The shift in quality due to each covariate by itself in turn
double deltaQCovariates = 0.0;
double deltaQCovariateEmpirical;
covariateCollapsedKey[0] = key[0];
covariateCollapsedKey[1] = key[1];
for( int iii = 2; iii < key.length; iii++ ) {
covariateCollapsedKey[2] = key[iii]; // The given covariate
final RecalDatum covariateRecalDatum = ((RecalDatum)dataManager.getCollapsedTable(iii).get( covariateCollapsedKey ));
if( covariateRecalDatum != null ) {
deltaQCovariateEmpirical = covariateRecalDatum.getEmpiricalQuality();
deltaQCovariates += ( deltaQCovariateEmpirical - qualFromRead - (globalDeltaQ + deltaQReported) );
}
}
final double newQuality = qualFromRead + globalDeltaQ + deltaQReported + deltaQCovariates;
return QualityUtils.boundQual( (int)Math.round(newQuality), (byte)MAX_QUALITY_SCORE );
// Verbose printouts used to validate with old recalibrator
//if(key.contains(null)) {
// System.out.println( key + String.format(" => %d + %.2f + %.2f + %.2f + %.2f = %d",
// qualFromRead, globalDeltaQ, deltaQReported, deltaQPos, deltaQDinuc, newQualityByte));
//}
//else {
// System.out.println( String.format("%s %s %s %s => %d + %.2f + %.2f + %.2f + %.2f = %d",
// key.get(0).toString(), key.get(3).toString(), key.get(2).toString(), key.get(1).toString(), qualFromRead, globalDeltaQ, deltaQReported, deltaQPos, deltaQDinuc, newQualityByte) );
//}
//return newQualityByte;
}
*/
}
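For reference, the disabled performSequentialQualityCalculation above builds the recalibrated quality as the reported quality plus a read-group shift, plus a reported-quality shift, plus one shift per remaining covariate, each delta measured net of what the earlier terms already explain. A small numeric sketch of that accumulation with invented empirical qualities; no GATK types are involved.

public class SequentialRecalSketch {
    public static void main(String[] args) {
        final double qualFromRead = 30.0;

        // read-group level shift: empirical quality of the read group vs. its aggregate reported quality
        final double globalDeltaQ = 28.0 - 30.0;                           // -2.0

        // reported-quality level shift, net of what the read-group shift already explains
        final double deltaQReported = 29.5 - qualFromRead - globalDeltaQ;  // +1.5

        // one shift per additional covariate (e.g. cycle, dinucleotide), net of the terms above
        final double deltaQCycle = 27.0 - qualFromRead - (globalDeltaQ + deltaQReported);  // -2.5
        final double deltaQDinuc = 30.0 - qualFromRead - (globalDeltaQ + deltaQReported);  // +0.5
        final double deltaQCovariates = deltaQCycle + deltaQDinuc;

        final double newQuality = qualFromRead + globalDeltaQ + deltaQReported + deltaQCovariates;
        System.out.println(newQuality); // 27.5
    }
}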