From fa0efbc4ca9b304e6e67326912e65263a996c56a Mon Sep 17 00:00:00 2001 From: Mark DePristo Date: Mon, 26 Sep 2011 13:28:56 -0400 Subject: [PATCH] Refactoring of PairHMM to support reduced reads --- .../indels/PairHMMIndelErrorModel.java | 295 +++++++++--------- 1 file changed, 151 insertions(+), 144 deletions(-) diff --git a/public/java/src/org/broadinstitute/sting/gatk/walkers/indels/PairHMMIndelErrorModel.java b/public/java/src/org/broadinstitute/sting/gatk/walkers/indels/PairHMMIndelErrorModel.java index 31e9819ab..6e4db9303 100755 --- a/public/java/src/org/broadinstitute/sting/gatk/walkers/indels/PairHMMIndelErrorModel.java +++ b/public/java/src/org/broadinstitute/sting/gatk/walkers/indels/PairHMMIndelErrorModel.java @@ -28,6 +28,7 @@ package org.broadinstitute.sting.gatk.walkers.indels; import net.sf.samtools.Cigar; import net.sf.samtools.CigarElement; import net.sf.samtools.CigarOperator; +import net.sf.samtools.SAMRecord; import org.broadinstitute.sting.gatk.contexts.ReferenceContext; import org.broadinstitute.sting.utils.Haplotype; import org.broadinstitute.sting.utils.MathUtils; @@ -244,31 +245,31 @@ public class PairHMMIndelErrorModel { /** * For each covariate read in a value and parse it. Associate those values with the data itself (num observation and num mismatches) */ - /* - private void addCSVData(final File file, final String line) { - final String[] vals = line.split(","); + /* + private void addCSVData(final File file, final String line) { + final String[] vals = line.split(","); - // Check if the data line is malformed, for example if the read group string contains a comma then it won't be parsed correctly - if( vals.length != requestedCovariates.size() + 3 ) { // +3 because of nObservations, nMismatch, and Qempirical - throw new UserException.MalformedFile(file, "Malformed input recalibration file. Found data line with too many fields: " + line + - " --Perhaps the read group string contains a comma and isn't being parsed correctly."); + // Check if the data line is malformed, for example if the read group string contains a comma then it won't be parsed correctly + if( vals.length != requestedCovariates.size() + 3 ) { // +3 because of nObservations, nMismatch, and Qempirical + throw new UserException.MalformedFile(file, "Malformed input recalibration file. Found data line with too many fields: " + line + + " --Perhaps the read group string contains a comma and isn't being parsed correctly."); + } + + final Object[] key = new Object[requestedCovariates.size()]; + Covariate cov; + int iii; + for( iii = 0; iii < requestedCovariates.size(); iii++ ) { + cov = requestedCovariates.get( iii ); + key[iii] = cov.getValue( vals[iii] ); + } + + // Create a new datum using the number of observations, number of mismatches, and reported quality score + final RecalDatum datum = new RecalDatum( Long.parseLong( vals[iii] ), Long.parseLong( vals[iii + 1] ), Double.parseDouble( vals[1] ), 0.0 ); + // Add that datum to all the collapsed tables which will be used in the sequential calculation + dataManager.addToAllTables( key, datum, PRESERVE_QSCORES_LESS_THAN ); } - final Object[] key = new Object[requestedCovariates.size()]; - Covariate cov; - int iii; - for( iii = 0; iii < requestedCovariates.size(); iii++ ) { - cov = requestedCovariates.get( iii ); - key[iii] = cov.getValue( vals[iii] ); - } - - // Create a new datum using the number of observations, number of mismatches, and reported quality score - final RecalDatum datum = new RecalDatum( Long.parseLong( vals[iii] ), Long.parseLong( vals[iii + 1] ), Double.parseDouble( vals[1] ), 0.0 ); - // Add that datum to all the collapsed tables which will be used in the sequential calculation - dataManager.addToAllTables( key, datum, PRESERVE_QSCORES_LESS_THAN ); - } - -*/ + */ public PairHMMIndelErrorModel(double indelGOP, double indelGCP, boolean deb, boolean doCDP, boolean dovit) { this(indelGOP, indelGCP, deb, doCDP); this.doViterbi = dovit; @@ -588,7 +589,7 @@ public class PairHMMIndelErrorModel { } else { c = currentGOP[jm1]; - d = currentGCP[jm1]; + d = currentGCP[jm1]; } if (indI == X_METRIC_LENGTH-1) c = d = END_GAP_COST; @@ -707,12 +708,12 @@ public class PairHMMIndelErrorModel { } } public synchronized double[] computeReadHaplotypeLikelihoods(ReadBackedPileup pileup, LinkedHashMap haplotypeMap, - ReferenceContext ref, int eventLength, - HashMap> indelLikelihoodMap){ + ReferenceContext ref, int eventLength, + HashMap> indelLikelihoodMap){ int numHaplotypes = haplotypeMap.size(); - double[][] haplotypeLikehoodMatrix = new double[numHaplotypes][numHaplotypes]; - double readLikelihoods[][] = new double[pileup.getReads().size()][numHaplotypes]; + final double readLikelihoods[][] = new double[pileup.size()][numHaplotypes]; + final int readCounts[] = new int[pileup.size()]; int readIdx=0; LinkedHashMap gapOpenProbabilityMap = new LinkedHashMap(); @@ -751,6 +752,9 @@ public class PairHMMIndelErrorModel { } } for (PileupElement p: pileup) { + // > 1 when the read is a consensus read representing multiple independent observations + final boolean isReduced = ReadUtils.isReducedRead(p.getRead()); + readCounts[readIdx] = isReduced ? p.getReducedCount() : 1; // check if we've already computed likelihoods for this pileup element (i.e. for this read at this location) if (indelLikelihoodMap.containsKey(p)) { @@ -762,61 +766,65 @@ public class PairHMMIndelErrorModel { } else { //System.out.format("%d %s\n",p.getRead().getAlignmentStart(), p.getRead().getClass().getName()); - GATKSAMRecord read = ReadUtils.hardClipAdaptorSequence(p.getRead()); + SAMRecord read = ReadUtils.hardClipAdaptorSequence(p.getRead()); if (read == null) continue; + if ( isReduced ) { + read = ReadUtils.reducedReadWithReducedQuals(read); + } + if(ReadUtils.is454Read(read) && !getGapPenaltiesFromFile) { continue; } double[] recalQuals = null; - /* - if (getGapPenaltiesFromFile) { - RecalDataManager.parseSAMRecord( read, RAC ); + /* + if (getGapPenaltiesFromFile) { + RecalDataManager.parseSAMRecord( read, RAC ); - recalQuals = new double[read.getReadLength()]; + recalQuals = new double[read.getReadLength()]; - //compute all covariate values for this read - final Comparable[][] covariateValues_offset_x_covar = - RecalDataManager.computeCovariates((GATKSAMRecord) read, requestedCovariates); - // For each base in the read - for( int offset = 0; offset < read.getReadLength(); offset++ ) { + //compute all covariate values for this read + final Comparable[][] covariateValues_offset_x_covar = + RecalDataManager.computeCovariates((GATKSAMRecord) read, requestedCovariates); + // For each base in the read + for( int offset = 0; offset < read.getReadLength(); offset++ ) { - final Object[] fullCovariateKey = covariateValues_offset_x_covar[offset]; + final Object[] fullCovariateKey = covariateValues_offset_x_covar[offset]; - Byte qualityScore = (Byte) qualityScoreByFullCovariateKey.get(fullCovariateKey); - if(qualityScore == null) - { - qualityScore = performSequentialQualityCalculation( fullCovariateKey ); - qualityScoreByFullCovariateKey.put(qualityScore, fullCovariateKey); - } + Byte qualityScore = (Byte) qualityScoreByFullCovariateKey.get(fullCovariateKey); + if(qualityScore == null) + { + qualityScore = performSequentialQualityCalculation( fullCovariateKey ); + qualityScoreByFullCovariateKey.put(qualityScore, fullCovariateKey); + } - recalQuals[offset] = -((double)qualityScore)/10.0; - } + recalQuals[offset] = -((double)qualityScore)/10.0; + } - // for each read/haplotype combination, compute likelihoods, ie -10*log10(Pr(R | Hi)) - // = sum_j(-10*log10(Pr(R_j | Hi) since reads are assumed to be independent - if (DEBUG) { - System.out.format("\n\nStarting read:%s S:%d US:%d E:%d UE:%d C:%s\n",read.getReadName(), - read.getAlignmentStart(), - read.getUnclippedStart(), read.getAlignmentEnd(), read.getUnclippedEnd(), - read.getCigarString()); + // for each read/haplotype combination, compute likelihoods, ie -10*log10(Pr(R | Hi)) + // = sum_j(-10*log10(Pr(R_j | Hi) since reads are assumed to be independent + if (DEBUG) { + System.out.format("\n\nStarting read:%s S:%d US:%d E:%d UE:%d C:%s\n",read.getReadName(), + read.getAlignmentStart(), + read.getUnclippedStart(), read.getAlignmentEnd(), read.getUnclippedEnd(), + read.getCigarString()); - byte[] bases = read.getReadBases(); - for (int k = 0; k < recalQuals.length; k++) { - System.out.format("%c",bases[k]); - } - System.out.println(); + byte[] bases = read.getReadBases(); + for (int k = 0; k < recalQuals.length; k++) { + System.out.format("%c",bases[k]); + } + System.out.println(); - for (int k = 0; k < recalQuals.length; k++) { - System.out.format("%.0f ",recalQuals[k]); - } - System.out.println(); - } - } */ + for (int k = 0; k < recalQuals.length; k++) { + System.out.format("%.0f ",recalQuals[k]); + } + System.out.println(); + } + } */ // get bases of candidate haplotypes that overlap with reads final int trailingBases = 3; @@ -971,7 +979,7 @@ public class PairHMMIndelErrorModel { System.out.println(new String(haplotypeBases)); } - Double readLikelihood = 0.0; + double readLikelihood = 0.0; if (useAffineGapModel) { double[] currentContextGOP = null; @@ -979,14 +987,14 @@ public class PairHMMIndelErrorModel { if (doContextDependentPenalties) { - if (getGapPenaltiesFromFile) { - readLikelihood = computeReadLikelihoodGivenHaplotypeAffineGaps(haplotypeBases, readBases, readQuals, recalCDP, null); + if (getGapPenaltiesFromFile) { + readLikelihood = computeReadLikelihoodGivenHaplotypeAffineGaps(haplotypeBases, readBases, readQuals, recalCDP, null); - } else { - currentContextGOP = Arrays.copyOfRange(gapOpenProbabilityMap.get(a), (int)indStart, (int)indStop); - currentContextGCP = Arrays.copyOfRange(gapContProbabilityMap.get(a), (int)indStart, (int)indStop); - readLikelihood = computeReadLikelihoodGivenHaplotypeAffineGaps(haplotypeBases, readBases, readQuals, currentContextGOP, currentContextGCP); - } + } else { + currentContextGOP = Arrays.copyOfRange(gapOpenProbabilityMap.get(a), (int)indStart, (int)indStop); + currentContextGCP = Arrays.copyOfRange(gapContProbabilityMap.get(a), (int)indStart, (int)indStop); + readLikelihood = computeReadLikelihoodGivenHaplotypeAffineGaps(haplotypeBases, readBases, readQuals, currentContextGOP, currentContextGCP); + } } } @@ -1004,7 +1012,7 @@ public class PairHMMIndelErrorModel { if (DEBUG) { System.out.println("\nLikelihood summary"); - for (readIdx=0; readIdx < pileup.getReads().size(); readIdx++) { + for (readIdx=0; readIdx < pileup.size(); readIdx++) { System.out.format("Read Index: %d ",readIdx); for (int i=0; i < readLikelihoods[readIdx].length; i++) System.out.format("L%d: %f ",i,readLikelihoods[readIdx][i]); @@ -1012,36 +1020,35 @@ public class PairHMMIndelErrorModel { } } + + return getHaplotypeLikelihoods(numHaplotypes, readCounts, readLikelihoods); + } + + private final static double[] getHaplotypeLikelihoods(final int numHaplotypes, final int readCounts[], final double readLikelihoods[][]) { + final double[][] haplotypeLikehoodMatrix = new double[numHaplotypes][numHaplotypes]; + + // todo: MAD 09/26/11 -- I'm almost certain this calculation can be simplied to just a single loop without the intermediate NxN matrix for (int i=0; i < numHaplotypes; i++) { for (int j=i; j < numHaplotypes; j++){ // combine likelihoods of haplotypeLikelihoods[i], haplotypeLikelihoods[j] // L(Hi, Hj) = sum_reads ( Pr(R|Hi)/2 + Pr(R|Hj)/2) //readLikelihoods[k][j] has log10(Pr(R_k) | H[j] ) - for (readIdx=0; readIdx < pileup.getReads().size(); readIdx++) { - + for (int readIdx = 0; readIdx < readLikelihoods.length; readIdx++) { // Compute log10(10^x1/2 + 10^x2/2) = log10(10^x1+10^x2)-log10(2) // First term is approximated by Jacobian log with table lookup. if (Double.isInfinite(readLikelihoods[readIdx][i]) && Double.isInfinite(readLikelihoods[readIdx][j])) continue; - haplotypeLikehoodMatrix[i][j] += ( MathUtils.softMax(readLikelihoods[readIdx][i], - readLikelihoods[readIdx][j]) + LOG_ONE_HALF); - + final double li = readLikelihoods[readIdx][i]; + final double lj = readLikelihoods[readIdx][j]; + final int readCount = readCounts[readIdx]; + haplotypeLikehoodMatrix[i][j] += readCount * (MathUtils.softMax(li, lj) + LOG_ONE_HALF); } - - } } - return getHaplotypeLikelihoods(haplotypeLikehoodMatrix); - - } - - public static double[] getHaplotypeLikelihoods(double[][] haplotypeLikehoodMatrix) { - int hSize = haplotypeLikehoodMatrix.length; - double[] genotypeLikelihoods = new double[hSize*(hSize+1)/2]; - + final double[] genotypeLikelihoods = new double[numHaplotypes*(numHaplotypes+1)/2]; int k=0; - for (int j=0; j < hSize; j++) { + for (int j=0; j < numHaplotypes; j++) { for (int i=0; i <= j; i++){ genotypeLikelihoods[k++] = haplotypeLikehoodMatrix[i][j]; } @@ -1066,63 +1073,63 @@ public class PairHMMIndelErrorModel { * @param key The list of Comparables that were calculated from the covariates * @return A recalibrated quality score as a byte */ - /* - private byte performSequentialQualityCalculation( final Object... key ) { + /* + private byte performSequentialQualityCalculation( final Object... key ) { - final byte qualFromRead = (byte)Integer.parseInt(key[1].toString()); - final Object[] readGroupCollapsedKey = new Object[1]; - final Object[] qualityScoreCollapsedKey = new Object[2]; - final Object[] covariateCollapsedKey = new Object[3]; + final byte qualFromRead = (byte)Integer.parseInt(key[1].toString()); + final Object[] readGroupCollapsedKey = new Object[1]; + final Object[] qualityScoreCollapsedKey = new Object[2]; + final Object[] covariateCollapsedKey = new Object[3]; - // The global quality shift (over the read group only) - readGroupCollapsedKey[0] = key[0]; - final RecalDatum globalRecalDatum = ((RecalDatum)dataManager.getCollapsedTable(0).get( readGroupCollapsedKey )); - double globalDeltaQ = 0.0; - if( globalRecalDatum != null ) { - final double globalDeltaQEmpirical = globalRecalDatum.getEmpiricalQuality(); - final double aggregrateQReported = globalRecalDatum.getEstimatedQReported(); - globalDeltaQ = globalDeltaQEmpirical - aggregrateQReported; - } - - // The shift in quality between reported and empirical - qualityScoreCollapsedKey[0] = key[0]; - qualityScoreCollapsedKey[1] = key[1]; - final RecalDatum qReportedRecalDatum = ((RecalDatum)dataManager.getCollapsedTable(1).get( qualityScoreCollapsedKey )); - double deltaQReported = 0.0; - if( qReportedRecalDatum != null ) { - final double deltaQReportedEmpirical = qReportedRecalDatum.getEmpiricalQuality(); - deltaQReported = deltaQReportedEmpirical - qualFromRead - globalDeltaQ; - } - - // The shift in quality due to each covariate by itself in turn - double deltaQCovariates = 0.0; - double deltaQCovariateEmpirical; - covariateCollapsedKey[0] = key[0]; - covariateCollapsedKey[1] = key[1]; - for( int iii = 2; iii < key.length; iii++ ) { - covariateCollapsedKey[2] = key[iii]; // The given covariate - final RecalDatum covariateRecalDatum = ((RecalDatum)dataManager.getCollapsedTable(iii).get( covariateCollapsedKey )); - if( covariateRecalDatum != null ) { - deltaQCovariateEmpirical = covariateRecalDatum.getEmpiricalQuality(); - deltaQCovariates += ( deltaQCovariateEmpirical - qualFromRead - (globalDeltaQ + deltaQReported) ); + // The global quality shift (over the read group only) + readGroupCollapsedKey[0] = key[0]; + final RecalDatum globalRecalDatum = ((RecalDatum)dataManager.getCollapsedTable(0).get( readGroupCollapsedKey )); + double globalDeltaQ = 0.0; + if( globalRecalDatum != null ) { + final double globalDeltaQEmpirical = globalRecalDatum.getEmpiricalQuality(); + final double aggregrateQReported = globalRecalDatum.getEstimatedQReported(); + globalDeltaQ = globalDeltaQEmpirical - aggregrateQReported; } + + // The shift in quality between reported and empirical + qualityScoreCollapsedKey[0] = key[0]; + qualityScoreCollapsedKey[1] = key[1]; + final RecalDatum qReportedRecalDatum = ((RecalDatum)dataManager.getCollapsedTable(1).get( qualityScoreCollapsedKey )); + double deltaQReported = 0.0; + if( qReportedRecalDatum != null ) { + final double deltaQReportedEmpirical = qReportedRecalDatum.getEmpiricalQuality(); + deltaQReported = deltaQReportedEmpirical - qualFromRead - globalDeltaQ; + } + + // The shift in quality due to each covariate by itself in turn + double deltaQCovariates = 0.0; + double deltaQCovariateEmpirical; + covariateCollapsedKey[0] = key[0]; + covariateCollapsedKey[1] = key[1]; + for( int iii = 2; iii < key.length; iii++ ) { + covariateCollapsedKey[2] = key[iii]; // The given covariate + final RecalDatum covariateRecalDatum = ((RecalDatum)dataManager.getCollapsedTable(iii).get( covariateCollapsedKey )); + if( covariateRecalDatum != null ) { + deltaQCovariateEmpirical = covariateRecalDatum.getEmpiricalQuality(); + deltaQCovariates += ( deltaQCovariateEmpirical - qualFromRead - (globalDeltaQ + deltaQReported) ); + } + } + + final double newQuality = qualFromRead + globalDeltaQ + deltaQReported + deltaQCovariates; + return QualityUtils.boundQual( (int)Math.round(newQuality), (byte)MAX_QUALITY_SCORE ); + + // Verbose printouts used to validate with old recalibrator + //if(key.contains(null)) { + // System.out.println( key + String.format(" => %d + %.2f + %.2f + %.2f + %.2f = %d", + // qualFromRead, globalDeltaQ, deltaQReported, deltaQPos, deltaQDinuc, newQualityByte)); + //} + //else { + // System.out.println( String.format("%s %s %s %s => %d + %.2f + %.2f + %.2f + %.2f = %d", + // key.get(0).toString(), key.get(3).toString(), key.get(2).toString(), key.get(1).toString(), qualFromRead, globalDeltaQ, deltaQReported, deltaQPos, deltaQDinuc, newQualityByte) ); + //} + + //return newQualityByte; + } - - final double newQuality = qualFromRead + globalDeltaQ + deltaQReported + deltaQCovariates; - return QualityUtils.boundQual( (int)Math.round(newQuality), (byte)MAX_QUALITY_SCORE ); - - // Verbose printouts used to validate with old recalibrator - //if(key.contains(null)) { - // System.out.println( key + String.format(" => %d + %.2f + %.2f + %.2f + %.2f = %d", - // qualFromRead, globalDeltaQ, deltaQReported, deltaQPos, deltaQDinuc, newQualityByte)); - //} - //else { - // System.out.println( String.format("%s %s %s %s => %d + %.2f + %.2f + %.2f + %.2f = %d", - // key.get(0).toString(), key.get(3).toString(), key.get(2).toString(), key.get(1).toString(), qualFromRead, globalDeltaQ, deltaQReported, deltaQPos, deltaQDinuc, newQualityByte) ); - //} - - //return newQualityByte; - - } -*/ + */ }