package org.broadinstitute.sting.utils.baq;

import net.sf.samtools.SAMRecord;
import net.sf.samtools.CigarElement;
import net.sf.samtools.CigarOperator;
import net.sf.picard.reference.IndexedFastaSequenceFile;
import net.sf.picard.reference.ReferenceSequence;
import org.broadinstitute.sting.utils.exceptions.ReviewedStingException;
import org.broadinstitute.sting.utils.BaseUtils;

import java.util.Arrays;

/*
  The topology of the profile HMM:

           /\             /\        /\             /\
           I[1]           I[k-1]    I[k]           I[L]
            ^   \      \    ^    \   ^   \      \   ^
            |    \      \   |     \  |    \      \  |
    M[0]   M[1] -> ... -> M[k-1] -> M[k] -> ... -> M[L]   M[L+1]
                \      \/        \/      \/      /
                 \     /\        /\      /\     /
                       -> D[k-1] -> D[k] ->

   M[0] points to every {M,I}[k] and every {M,I}[k] points M[L+1].

   On input, _ref is the reference sequence and _query is the query
   sequence. Both are sequences of 0/1/2/3/4 where 4 stands for an
   ambiguous residue. iqual is the base quality. c sets the gap open
   probability, gap extension probability and band width.

   On output, state and q are arrays of length l_query. The higher 30
   bits give the reference position the query base is matched to and the
   lower two bits can be 0 (an alignment match) or 1 (an
   insertion). q[i] gives the phred scaled posterior probability of
   state[i] being wrong.
 */
public class BAQ {
    private final static boolean DEBUG = false;

    public enum CalculationMode {
        NONE,                       // don't apply a BAQ at all, the default
        CALCULATE_AS_NECESSARY,     // do HMM BAQ calculation on the fly, as necessary, if there's no tag
        RECALCULATE                 // do HMM BAQ calculation on the fly, regardless of whether there's a tag present
    }

    /** these are features that only the walker can override */
    public enum QualityMode {
        ADD_TAG,                    // calculate the BAQ, but write it into the reads as the BAQ tag, leaving QUAL field alone
        OVERWRITE_QUALS,            // overwrite the quality field directly
        DONT_MODIFY                 // do the BAQ, but don't modify the quality scores themselves, just return them in the function.
    }

    public enum ApplicationTime {
        FORBIDDEN,                  // Walker does not tolerate BAQ input
        ON_INPUT,                   // apply the BAQ calculation to the incoming reads, the default
        ON_OUTPUT,                  // apply the BAQ calculation to outgoing read streams
        HANDLED_IN_WALKER           // the walker will deal with the BAQ calculation status itself
    }

    public static final String BAQ_TAG = "BQ";

    /** Lookup table: qual2prob[q] == 10^(-q/10) for q in [0, 255]. */
    private static double[] qual2prob = new double[256];
    static {
        for (int i = 0; i < 256; ++i)
            qual2prob[i] = Math.pow(10, -i/10.);
    }

    public static double DEFAULT_GOP = 1e-3; // todo -- make me final private

    public double cd = -1;      // gap open probability [1e-3]
    private double ce = 0.1;    // gap extension probability [0.1]
    private int cb = 7;         // band width [7]

    public byte getMinBaseQual() {
        return minBaseQual;
    }

    /**
     * Any bases with Q < MIN_BASE_QUAL are raised up to this base quality
     */
    private byte minBaseQual = 4;

    public double getGapOpenProb() { return cd; }
    public double getGapExtensionProb() { return ce; }
    public int getBandWidth() { return cb; }

    /**
     * Use defaults for everything
     */
    public BAQ() {
        cd = DEFAULT_GOP;
        initializeCachedData();
    }

    /**
     * Create a new HmmGlocal object with specified parameters
     *
     * @param d gap open prob.
     * @param e gap extension prob.
     * @param b band width
     * @param minBaseQual All bases with Q < minBaseQual are up'd to this value
     */
    public BAQ(final double d, final double e, final int b, final byte minBaseQual) {
        cd = d; ce = e; cb = b;
        this.minBaseQual = minBaseQual;
        initializeCachedData();
    }

    private final static double EM = 0.33333333333;
    private final static double EI = 0.25;

    /**
     * Emission probability cache indexed as [ref base][read base][base quality].
     * Only qualities 0..63 are tabulated.
     */
    private double[][][] EPSILONS = new double[256][256][64];

    /**
     * Fills the EPSILONS cache.  Every entry defaults to 1.0; entries where both the ref and
     * read bases are one of A/C/G/T (either case) get 1 - P(error) on a case-insensitive match
     * and P(error) * EM on a mismatch, where P(error) is derived from the quality after raising
     * it to minBaseQual.
     */
    private void initializeCachedData() {
        // default emission for any pair involving a base outside ACGTacgt
        for ( int i = 0; i < 256; i++ )
            for ( int j = 0; j < 256; j++ )
                Arrays.fill(EPSILONS[i][j], 1.0);

        for ( char b1 : "ACGTacgt".toCharArray() ) {
            for ( char b2 : "ACGTacgt".toCharArray() ) {
                for ( int q = 0; q < 64; q++ ) {
                    double qual = qual2prob[q < minBaseQual ? minBaseQual : q];
                    double e = Character.toLowerCase(b1) == Character.toLowerCase(b2) ? 1 - qual : qual * EM;
                    EPSILONS[(byte)b1][(byte)b2][q] = e;
                }
            }
        }
    }

    /**
     * Looks up the cached emission probability for a ref base / read base / quality triple.
     * The arguments are used directly as array indices into EPSILONS.
     */
    private double calcEpsilon( byte ref, byte read, byte qualB ) {
        return EPSILONS[ref][read][qualB];
    }

    // Uncached reference implementation, for the 0-4 base encoding:
    // private double calcEpsilon( byte ref, byte read, byte qualB ) {
    //     double qual = qual2prob[qualB < minBaseQual ? minBaseQual : qualB];
    //     return (ref > 3 || read > 3)? 1. : ref == read ? 1. - qual : qual * EM;
    // }

    // ####################################################################################################
    //
    // NOTE -- THIS CODE IS SYNCHRONIZED WITH CODE IN THE SAMTOOLS REPOSITORY.  CHANGES TO THIS CODE SHOULD BE
    // NOTE -- PUSHED BACK TO HENG LI
    //
    // Note that _ref and _query are in the special 0-4 encoding [see above for docs]
    //
    // ####################################################################################################

    /**
     * Runs the banded forward/backward profile HMM over query[qstart .. qstart+l_query-1] against ref
     * and writes, for each of those query positions, the most probable state and its quality.
     *
     * @param ref      reference bases; must not be null
     * @param query    read bases; must not be null and must have the same length as _iqual
     * @param qstart   offset in query / _iqual / state / q of the first base to process
     * @param l_query  number of query bases to process
     * @param _iqual   read base qualities; must not be null
     * @param state    output; if non-null, state[qstart+i-1] receives the encoded best state
     *                 (see {@link #stateAlignedPosition(int)} and {@link #stateIsIndel(int)})
     * @param q        output; if non-null, q[qstart+i-1] receives the phred-scaled quality of that state
     * @return always 0
     * @throws ReviewedStingException if ref, query or _iqual is null, or query and _iqual differ in length
     */
    public int hmm_glocal(final byte[] ref, final byte[] query, int qstart, int l_query, final byte[] _iqual, int[] state, byte[] q) {
        if ( ref == null ) throw new ReviewedStingException("BUG: ref sequence is null");
        if ( query == null ) throw new ReviewedStingException("BUG: query sequence is null");
        if ( _iqual == null ) throw new ReviewedStingException("BUG: query quality vector is null");
        if ( query.length != _iqual.length ) throw new ReviewedStingException("BUG: read sequence length != qual length");
        //if ( q != null && q.length != state.length ) throw new ReviewedStingException("BUG: BAQ quality length != read sequence length");
        //if ( state != null && state.length != l_query ) throw new ReviewedStingException("BUG: state length != read sequence length");

        int i, k;

        /*** initialization ***/
        // change coordinates
        int l_ref = ref.length;
        //int l_query = query.length;

        // set band width
        int bw2, bw = l_ref > l_query? l_ref : l_query;
        if (bw > cb) bw = cb;
        if (bw < Math.abs(l_ref - l_query)) bw = Math.abs(l_ref - l_query);
        bw2 = bw * 2 + 1;

        // allocate the forward and backward matrices f[][] and b[][] and the scaling array s[]
        double[][] f = new double[l_query+1][bw2*3 + 6];
        double[][] b = new double[l_query+1][bw2*3 + 6];
        double[] s = new double[l_query+2];

        // initialize transition probabilities
        double sM, sI, bM, bI;
        sM = sI = 1. / (2 * l_query + 2);
        bM = (1 - cd) / l_query; bI = cd / l_query; // (bM+bI)*l_query==1

        double[] m = new double[9];
        m[0*3+0] = (1 - cd - cd) * (1 - sM); m[0*3+1] = m[0*3+2] = cd * (1 - sM);
        m[1*3+0] = (1 - ce) * (1 - sI); m[1*3+1] = ce * (1 - sI); m[1*3+2] = 0.;
        m[2*3+0] = 1 - ce; m[2*3+1] = 0.; m[2*3+2] = ce;

        /*** forward ***/
        // f[0]
        f[0][set_u(bw, 0, 0)] = s[0] = 1.;
        { // f[1]
            double[] fi = f[1];
            double sum;
            int beg = 1, end = l_ref < bw + 1? l_ref : bw + 1, _beg, _end;
            for (k = beg, sum = 0.; k <= end; ++k) {
                int u;
                double e = calcEpsilon(ref[k-1], query[qstart], _iqual[qstart]);
                u = set_u(bw, 1, k);
                fi[u+0] = e * bM; fi[u+1] = EI * bI;
                sum += fi[u] + fi[u+1];
            }
            // rescale
            s[1] = sum;
            _beg = set_u(bw, 1, beg); _end = set_u(bw, 1, end); _end += 2;
            for (k = _beg; k <= _end; ++k) fi[k] /= sum;
        }

        // f[2..l_query]
        for (i = 2; i <= l_query; ++i) {
            double[] fi = f[i], fi1 = f[i-1];
            double sum;
            int beg = 1, end = l_ref, x, _beg, _end;
            byte qyi = query[qstart+i-1];
            x = i - bw; beg = beg > x? beg : x; // band start
            x = i + bw; end = end < x? end : x; // band end
            for (k = beg, sum = 0.; k <= end; ++k) {
                int u, v11, v01, v10;
                double e = calcEpsilon(ref[k-1], qyi, _iqual[qstart+i-1]);
                u = set_u(bw, i, k); v11 = set_u(bw, i-1, k-1); v10 = set_u(bw, i-1, k); v01 = set_u(bw, i, k-1);
                fi[u+0] = e * (m[0] * fi1[v11+0] + m[3] * fi1[v11+1] + m[6] * fi1[v11+2]);
                fi[u+1] = EI * (m[1] * fi1[v10+0] + m[4] * fi1[v10+1]);
                fi[u+2] = m[2] * fi[v01+0] + m[8] * fi[v01+2];
                sum += fi[u] + fi[u+1] + fi[u+2];
                //System.out.println("("+i+","+k+";"+u+"): "+fi[u]+","+fi[u+1]+","+fi[u+2]);
            }
            // rescale
            s[i] = sum;
            _beg = set_u(bw, i, beg); _end = set_u(bw, i, end); _end += 2;
            for (k = _beg, sum = 1./sum; k <= _end; ++k) fi[k] *= sum;
        }
        { // f[l_query+1]
            double sum;
            for (k = 1, sum = 0.; k <= l_ref; ++k) {
                int u = set_u(bw, l_query, k);
                if (u < 3 || u >= bw2*3+3) continue;
                sum += f[l_query][u+0] * sM + f[l_query][u+1] * sI;
            }
            s[l_query+1] = sum; // the last scaling factor
        }

        /*** backward ***/
        // b[l_query] (b[l_query+1][0]=1 and thus \tilde{b}[][]=1/s[l_query+1]; this is where s[l_query+1] comes from)
        for (k = 1; k <= l_ref; ++k) {
            int u = set_u(bw, l_query, k);
            double[] bi = b[l_query];
            if (u < 3 || u >= bw2*3+3) continue;
            bi[u+0] = sM / s[l_query] / s[l_query+1]; bi[u+1] = sI / s[l_query] / s[l_query+1];
        }
        // b[l_query-1..1]
        for (i = l_query - 1; i >= 1; --i) {
            int beg = 1, end = l_ref, x, _beg, _end;
            double[] bi = b[i], bi1 = b[i+1];
            double y = (i > 1)? 1. : 0.;
            byte qyi1 = query[qstart+i];
            x = i - bw; beg = beg > x? beg : x;
            x = i + bw; end = end < x? end : x;
            for (k = end; k >= beg; --k) {
                int u, v11, v01, v10;
                u = set_u(bw, i, k); v11 = set_u(bw, i+1, k+1); v10 = set_u(bw, i+1, k); v01 = set_u(bw, i, k+1);
                double e = (k >= l_ref? 0 : calcEpsilon(ref[k], qyi1, _iqual[qstart+i])) * bi1[v11];
                bi[u+0] = e * m[0] + EI * m[1] * bi1[v10+1] + m[2] * bi[v01+2]; // bi1[v11] has been folded into e.
                bi[u+1] = e * m[3] + EI * m[4] * bi1[v10+1];
                bi[u+2] = (e * m[6] + m[8] * bi[v01+2]) * y;
            }
            // rescale
            _beg = set_u(bw, i, beg); _end = set_u(bw, i, end); _end += 2;
            for (k = _beg, y = 1./s[i]; k <= _end; ++k) bi[k] *= y;
        }

        // TODO -- this appears to be a null operation overall.  For debugging only?
        double pb;
        { // b[0]
            int beg = 1, end = l_ref < bw + 1? l_ref : bw + 1;
            double sum = 0.;
            for (k = end; k >= beg; --k) {
                int u = set_u(bw, 1, k);
                double e = calcEpsilon(ref[k-1], query[qstart], _iqual[qstart]);
                if (u < 3 || u >= bw2*3+3) continue;
                sum += e * b[1][u+0] * bM + EI * b[1][u+1] * bI;
            }
            pb = b[0][set_u(bw, 0, 0)] = sum / s[0]; // if everything works as is expected, pb == 1.0
        }

        /*** MAP ***/
        for (i = 1; i <= l_query; ++i) {
            double sum = 0., max = 0.;
            double[] fi = f[i], bi = b[i];
            int beg = 1, end = l_ref, x, max_k = -1;
            x = i - bw; beg = beg > x? beg : x;
            x = i + bw; end = end < x? end : x;
            for (k = beg; k <= end; ++k) {
                int u = set_u(bw, i, k);
                double z;
                sum += (z = fi[u+0] * bi[u+0]); if (z > max) { max = z; max_k = (k-1)<<2 | 0; }
                sum += (z = fi[u+1] * bi[u+1]); if (z > max) { max = z; max_k = (k-1)<<2 | 1; }
            }
            max /= sum; sum *= s[i]; // if everything works as is expected, sum == 1.0
            if (state != null) state[qstart+i-1] = max_k;
            if (q != null) {
                k = (int)(-4.343 * Math.log(1. - max) + .499);
                q[qstart+i-1] = (byte)(k > 100? 99 : (k < minBaseQual ? minBaseQual : k));
            }
            //System.out.println("("+pb+","+sum+")"+" ("+(i-1)+","+(max_k>>2)+","+(max_k&3)+","+max+")");
        }

        return 0;
    }

    // ---------------------------------------------------------------------------------------------------------------
    //
    // Helper routines
    //
    // ---------------------------------------------------------------------------------------------------------------

    /** decode the bit encoded state array values: true when the low two bits are non-zero */
    public static boolean stateIsIndel(int state) {
        return (state & 3) != 0;
    }

    /** decode the bit encoded state array values: the position stored above the low two bits */
    public static int stateAlignedPosition(int state) {
        return state >> 2;
    }

    /**
     * helper routine for hmm_glocal: maps a (query index, ref index) cell of the band onto the
     * offset of its first slot within a row of the f/b matrices (three slots per cell).
     *
     * @param b band width
     * @param i query index
     * @param k reference index
     * @return (k + 1 - max(i - b, 0)) * 3
     */
    private static int set_u(final int b, final int i, final int k) {
        int x = i - b;
        x = x > 0 ? x : 0;
        return (k + 1 - x) * 3;
    }

    // ---------------------------------------------------------------------------------------------------------------
    //
    // Actually working with the BAQ tag now
    //
    // ---------------------------------------------------------------------------------------------------------------

    /**
     * Get the BAQ attribute from the tag in read.  Returns null if no BAQ tag is present.
     *
     * @param read the read to inspect
     * @return the bytes of the BAQ tag string, or null if the read has no BAQ tag
     */
    public static byte[] getBAQTag(SAMRecord read) {
        String s = read.getStringAttribute(BAQ_TAG);
        return s != null ? s.getBytes() : null;
    }

    /**
     * Encodes BAQ qualities as a BQ tag string relative to the read's base qualities.
     *
     * @param read the read whose base qualities are the reference point
     * @param baq  the BAQ qualities, one per read base
     * @return the encoded tag value, of the same length as baq
     */
    public static String encodeBQTag(SAMRecord read, byte[] baq) {
        // Offset to base alignment quality (BAQ), of the same length as the read sequence.
        // At the i-th read base, BAQi = Qi - (BQi - 64) where Qi is the i-th base quality.
        // so BQi = Qi - BAQi + 64
        final byte[] quals = read.getBaseQualities();
        byte[] bqTag = new byte[baq.length];
        for ( int i = 0; i < bqTag.length; i++)
            bqTag[i] = (byte)(((int)quals[i] + 64) - baq[i]);
        return new String(bqTag);
    }

    /** Encodes baq against the read's qualities and stores it on the read under BAQ_TAG. */
    public static void addBAQTag(SAMRecord read, byte[] baq) {
        read.setAttribute(BAQ_TAG, encodeBQTag(read, baq));
    }

    /**
     * Returns true if the read has a BAQ tag, or false otherwise
     *
     * @param read the read to inspect
     * @return true if the BAQ_TAG string attribute is non-null
     */
    public static boolean hasBAQTag(SAMRecord read) {
        return read.getStringAttribute(BAQ_TAG) != null;
    }

    /**
     * Returns a new qual array for read that includes the BAQ adjusted.  Does not support on-the-fly BAQ calculation
     *
     * @param read the SAMRecord to operate on
     * @param overwriteOriginalQuals If true, we replace the original qualities scores in the read with their BAQ'd version
     * @param useRawQualsIfNoBAQTag If useRawQualsIfNoBAQTag is true, then if there's no BAQ annotation we just use the raw quality scores.  Throws IllegalStateException is false and no BAQ tag is present
     * @return the BAQ-adjusted qualities, or the raw qualities when there is no tag and useRawQualsIfNoBAQTag is true
     * @throws IllegalStateException if there is no BAQ tag and useRawQualsIfNoBAQTag is false
     */
    public static byte[] calcBAQFromTag(SAMRecord read, boolean overwriteOriginalQuals, boolean useRawQualsIfNoBAQTag) {
        byte[] rawQuals = read.getBaseQualities();
        byte[] newQuals = rawQuals;
        byte[] baq = getBAQTag(read);

        if ( baq != null ) {
            // Offset to base alignment quality (BAQ), of the same length as the read sequence.
            // At the i-th read base, BAQi = Qi - (BQi - 64) where Qi is the i-th base quality.
            newQuals = overwriteOriginalQuals ? rawQuals : new byte[rawQuals.length];
            for ( int i = 0; i < rawQuals.length; i++) {
                int val = rawQuals[i] - (baq[i] - 64);
                newQuals[i] = val < 0 ? 0 : (byte)val;
            }
        } else if ( ! useRawQualsIfNoBAQTag ) {
            throw new IllegalStateException("Required BAQ tag to be present, but none was on read " + read.getReadName());
        }

        return newQuals;
    }

    /** Holds the inputs (ref bases, read bases, raw quals) and outputs (bq, state) of one HMM run. */
    public static class BAQCalculationResult {
        public byte[] refBases, rawQuals, readBases, bq;
        public int[] state;

        public BAQCalculationResult(SAMRecord read, byte[] ref) {
            // bases first, then quals, matching the (bases, quals, ref) constructor below
            this(read.getReadBases(), read.getBaseQualities(), ref);
        }

        public BAQCalculationResult(byte[] bases, byte[] quals, byte[] ref) {
            // prepares data for calculation
            rawQuals = quals;
            readBases = bases;

            // now actually prepare the data structures, and fire up the hmm
            bq = new byte[rawQuals.length];
            state = new int[rawQuals.length];
            this.refBases = ref;
        }
    }

    /** Length of the first CIGAR element if it is an insertion, otherwise 0. */
    private static int getFirstInsertionOffset(SAMRecord read) {
        CigarElement e = read.getCigar().getCigarElement(0);
        if ( e.getOperator() == CigarOperator.I )
            return e.getLength();
        else
            return 0;
    }

    /** Length of the last CIGAR element if it is an insertion, otherwise 0. */
    private static int getLastInsertionOffset(SAMRecord read) {
        CigarElement e = read.getCigar().getCigarElement(read.getCigarLength()-1);
        if ( e.getOperator() == CigarOperator.I )
            return e.getLength();
        else
            return 0;
    }

    /**
     * Fetches the reference window around the read from refReader and runs the HMM on it.
     *
     * @param read      the read to process
     * @param refReader source of reference bases
     * @return the BAQ result, or null if the window would run past the end of the contig
     *         (or the read cannot be handled by the CIGAR-based overload)
     */
    public BAQCalculationResult calcBAQFromHMM(SAMRecord read, IndexedFastaSequenceFile refReader) {
        // start is alignment start - band width / 2 - size of first I element, if there is one.  Stop is similar
        int offset = getBandWidth() / 2;
        // reference coordinates passed to getSubsequenceAt are 1-based, so never go below 1
        long start = Math.max(read.getAlignmentStart() - offset - getFirstInsertionOffset(read), 1);
        long stop = read.getAlignmentEnd() + offset + getLastInsertionOffset(read);

        if ( stop > refReader.getSequenceDictionary().getSequence(read.getReferenceName()).getSequenceLength() ) {
            return null;
        } else {
            // now that we have the start and stop, get the reference sequence covering it
            ReferenceSequence refSeq = refReader.getSubsequenceAt(read.getReferenceName(), start, stop);
            return calcBAQFromHMM(read, refSeq.getBases(), (int)(start - read.getAlignmentStart()));
        }
    }

    /**
     * Runs the HMM on query[queryStart .. queryEnd) against ref.
     *
     * @param ref        reference bases
     * @param query      read bases
     * @param quals      read base qualities
     * @param queryStart offset of the first query base to process
     * @param queryEnd   offset one past the last query base to process
     * @return the result object holding the HMM's bq and state output
     */
    public BAQCalculationResult calcBAQFromHMM(byte[] ref, byte[] query, byte[] quals, int queryStart, int queryEnd ) {
        // note -- assumes ref is offset from the *CLIPPED* start
        BAQCalculationResult baqResult = new BAQCalculationResult(query, quals, ref);
        int queryLen = queryEnd - queryStart;
        hmm_glocal(baqResult.refBases, baqResult.readBases, queryStart, queryLen, baqResult.rawQuals, baqResult.state, baqResult.bq);
        return baqResult;
    }

    /**
     * Runs the HMM on the aligned portion of the read, then walks the CIGAR to cap each base's quality.
     * We need to pad ref by at least the bandwidth / 2 on either side.
     *
     * @param read      the read to process
     * @param ref       reference bases covering the read
     * @param refOffset offset of the first ref base relative to the read's alignment start
     * @return the capped BAQ result, or null if the CIGAR contains an N element
     * @throws ReviewedStingException on a CIGAR operator other than N, H, P, I, S, D or M
     */
    public BAQCalculationResult calcBAQFromHMM(SAMRecord read, byte[] ref, int refOffset) {
        int queryStart = (int)(read.getAlignmentStart() - read.getUnclippedStart());
        int queryEnd = (int)(read.getReadLength() - (read.getUnclippedEnd() - read.getAlignmentEnd()));
        BAQCalculationResult baqResult = calcBAQFromHMM(ref, read.getReadBases(), read.getBaseQualities(), queryStart, queryEnd);

        // cap quals
        int readI = 0, refI = 0;
        for ( CigarElement elt : read.getCigar().getCigarElements() ) {
            int l = elt.getLength();
            switch (elt.getOperator()) {
                case N: // cannot handle these
                    return null;
                case H : case P : // ignore pads and hard clips
                    break;
                case I : case S :
                    // todo -- is it really the case that we want to treat I and S the same?
                    for ( int i = readI; i < readI + l; i++ ) baqResult.bq[i] = baqResult.rawQuals[i];
                    readI += l;
                    break;
                case D : refI += l; break;
                case M :
                    for (int i = readI; i < readI + l; i++) {
                        int expectedPos = refI - refOffset + (i - readI);
                        baqResult.bq[i] = capBaseByBAQ( baqResult.rawQuals[i], baqResult.bq[i], baqResult.state[i], expectedPos );
                    }
                    readI += l; refI += l;
                    break;
                default:
                    throw new ReviewedStingException("BUG: Unexpected CIGAR element " + elt + " in read " + read.getReadName());
            }
        }

        return baqResult;
    }

    /**
     * Caps a single base quality by its BAQ.
     *
     * @param oq          original base quality
     * @param bq          BAQ quality from the HMM
     * @param state       encoded HMM state for this base
     * @param expectedPos reference position this base is expected to align to
     * @return minBaseQual if the state is an indel or aligns elsewhere, otherwise min(bq, oq)
     */
    public byte capBaseByBAQ( byte oq, byte bq, int state, int expectedPos ) {
        byte b;
        boolean isIndel = stateIsIndel(state);
        int pos = stateAlignedPosition(state);
        if ( isIndel || pos != expectedPos ) // we are an indel or we don't align to our best current position
            b = minBaseQual; // just take b = minBaseQuality
        else
            b = bq < oq ? bq : oq;

        return b;
    }

    /**
     * Modifies read in place so that the base quality scores are capped by the BAQ calculation.  Uses the BAQ
     * tag if present already and alwaysRecalculate is false, otherwise fires up the HMM and does the BAQ on the fly
     * using the refReader to obtain the reference bases as needed.
     *
     * @param read            the read to process
     * @param refReader       source of reference bases for on-the-fly calculation
     * @param calculationType whether and when to run the HMM
     * @param qmode           what to do with the calculated qualities
     * @return BQ qualities for use, in case qmode is DONT_MODIFY
     */
    public byte[] baqRead(SAMRecord read, IndexedFastaSequenceFile refReader, CalculationMode calculationType, QualityMode qmode ) {
        if ( DEBUG ) System.out.printf("BAQ %s read %s%n", calculationType, read.getReadName());

        byte[] BAQQuals = read.getBaseQualities();      // in general we are overwriting quals, so just get a pointer to them
        if ( calculationType == CalculationMode.NONE ) { // we don't want to do anything
            ; // just fall though
        } else if ( excludeReadFromBAQ(read) ) {
            ; // just fall through
        } else {
            if ( calculationType == CalculationMode.RECALCULATE || ! hasBAQTag(read) ) {
                if ( DEBUG ) System.out.printf("  Calculating BAQ on the fly%n");
                BAQCalculationResult hmmResult = calcBAQFromHMM(read, refReader);
                if ( hmmResult != null ) {
                    switch ( qmode ) {
                        case ADD_TAG:         addBAQTag(read, hmmResult.bq); break;
                        case OVERWRITE_QUALS: System.arraycopy(hmmResult.bq, 0, read.getBaseQualities(), 0, hmmResult.bq.length); break;
                        case DONT_MODIFY:     BAQQuals = hmmResult.bq; break;
                        default:              throw new ReviewedStingException("BUG: unexpected qmode " + qmode);
                    }
                }
            } else if ( qmode == QualityMode.OVERWRITE_QUALS ) { // only makes sense if we are overwriting quals
                if ( DEBUG ) System.out.printf("  Taking BAQ from tag%n");
                // this overwrites the original qualities
                calcBAQFromTag(read, true, false);
            }
        }

        return BAQQuals;
    }

    /**
     * Returns true if we don't think this read is eligible for the BAQ calculation.  Examples include non-PF reads,
     * duplicates, or unmapped reads.  Used by baqRead to determine if a read should fall through the calculation.
     *
     * @param read the read to test
     * @return true if the read is unmapped, fails the vendor quality check, or is a duplicate
     */
    public boolean excludeReadFromBAQ(SAMRecord read) {
        // keeping mapped reads, regardless of pairing status, or primary alignment status.
        return read.getReadUnmappedFlag() || read.getReadFailsVendorQualityCheckFlag() || read.getDuplicateReadFlag();
    }
}