#include "ksw_align_avx.h"

/* Vector widths for AVX-512: 64 parallel lanes of 8-bit scores,
 * or 32 lanes of 16-bit values (used for end positions). */
#undef SIMD_WIDTH8
#undef SIMD_WIDTH16
#define SIMD_WIDTH8 64
#define SIMD_WIDTH16 32

// Default penalty applied when a base is not one of A/C/G/T (ambiguous base).
static const int8_t w_ambig = -1;

/* One inner-loop DP step for the 8-bit (unsigned, biased) scoring path.
 * Computes, per lane: the match/mismatch score via a shuffle-based lookup
 * (permSft512 indexed by s1 XOR s2), the new H cell h11 = max(M, E, F),
 * the running per-row maximum imax512 with its column iqe512, and the
 * next E (insertion) and F (deletion) values using affine gap penalties.
 * Relies on enclosing-scope names: permSft512, five512, sft512, zero512,
 * imax512, iqe512, l512, oe_ins512, e_ins512, oe_del512, e_del512.
 * Scores are biased by `shift` so all arithmetic stays in unsigned
 * saturating 8-bit range; the bias is removed again via subs_epu8.
 * NOTE(review): the max512 parameter is unused by the macro body. */
#define MAIN_SAM_CODE8_OPT(s1, s2, h00, h11, e11, f11, f21, max512, sft512)  \
    {                                                                        \
        __m512i sbt11, xor11, or11;                                          \
        /* substitution score lookup: index = s1 XOR s2 (0 means match) */   \
        xor11 = _mm512_xor_si512(s1, s2);                                    \
        sbt11 = _mm512_shuffle_epi8(permSft512, xor11);                      \
        /* lanes where the query byte equals the DUMMY5 sentinel get the */  \
        /* neutral (bias-only) score instead of the table value */           \
        __mmask64 cmpq = _mm512_cmpeq_epu8_mask(s2, five512);                \
        sbt11 = _mm512_mask_blend_epi8(cmpq, sbt11, sft512);                 \
        /* a set sign bit in either sequence byte marks an out-of-range */   \
        /* (padding) lane; force its M value to zero */                      \
        or11 = _mm512_or_si512(s1, s2);                                      \
        __mmask64 cmp = _mm512_movepi8_mask(or11);                           \
        __m512i m11 = _mm512_adds_epu8(h00, sbt11);                         \
        m11 = _mm512_mask_blend_epi8(cmp, m11, zero512);                     \
        m11 = _mm512_subs_epu8(m11, sft512); /* remove the score bias */     \
        /* H = max(M, E, F); local alignment floor at 0 comes from the */    \
        /* unsigned saturating subtraction above */                          \
        h11 = _mm512_max_epu8(m11, e11);                                     \
        h11 = _mm512_max_epu8(h11, f11);                                     \
        /* track the row maximum and the column (l512) where it occurred */  \
        __mmask64 cmp0 = _mm512_cmpgt_epu8_mask(h11, imax512);               \
        imax512 = _mm512_max_epu8(imax512, h11);                             \
        iqe512 = _mm512_mask_blend_epi8(cmp0, iqe512, l512);                 \
        /* affine gaps: E' = max(H - (o_ins+e_ins), E - e_ins) */            \
        __m512i gapE512 = _mm512_subs_epu8(h11, oe_ins512);                  \
        e11 = _mm512_subs_epu8(e11, e_ins512);                               \
        e11 = _mm512_max_epu8(gapE512, e11);                                 \
        /* F' = max(H - (o_del+e_del), F - e_del) */                         \
        __m512i gapD512 = _mm512_subs_epu8(h11, oe_del512);                  \
        f21 = _mm512_subs_epu8(f11, e_del512);                               \
        f21 = _mm512_max_epu8(gapD512, f21);                                 \
    }

/* 16-bit signed variant of the DP step (32 lanes). Same structure as the
 * 8-bit macro but without the unsigned bias: scores are plain epi16 and
 * the local-alignment floor is applied explicitly via max with zero512.
 * Relies on enclosing-scope names: perm512, zero512, imax512, iqe512,
 * l512, oe_ins512, e_ins512, oe_del512, e_del512.
 * NOTE(review): the padding-lane blend uses a byte-granular epi8 blend on
 * 16-bit data — presumably correct for the packed encoding used here, but
 * confirm against the 8-bit path / the sequence packer. max512 is unused. */
#define MAIN_SAM_CODE16_OPT(s1, s2, h00, h11, e11, f11, f21, max512)         \
    {                                                                        \
        __m512i sbt11, xor11, or11;                                          \
        /* substitution score lookup via 16-bit permute */                   \
        xor11 = _mm512_xor_si512(s1, s2);                                    \
        sbt11 = _mm512_permutexvar_epi16(xor11, perm512);                    \
        __m512i m11 = _mm512_add_epi16(h00, sbt11);                         \
        /* zero out padding lanes (sign bit set in either sequence byte) */  \
        or11 = _mm512_or_si512(s1, s2);                                      \
        __mmask64 cmp = _mm512_movepi8_mask(or11);                           \
        m11 = _mm512_mask_blend_epi8(cmp, m11, zero512);                     \
        /* H = max(M, E, F, 0) */                                            \
        h11 = _mm512_max_epi16(m11, e11);                                    \
        h11 = _mm512_max_epi16(h11, f11);                                    \
        h11 = _mm512_max_epi16(h11, zero512);                                \
        /* track row maximum and its column */                               \
        __mmask32 cmp0 = _mm512_cmpgt_epi16_mask(h11, imax512);              \
        imax512 = _mm512_max_epi16(imax512, h11);                            \
        iqe512 = _mm512_mask_blend_epi16(cmp0, iqe512, l512);                \
        /* affine gap updates, as in the 8-bit path */                       \
        __m512i gapE512 = _mm512_sub_epi16(h11, oe_ins512);                  \
        e11 = _mm512_sub_epi16(e11, e_ins512);                               \
        e11 = _mm512_max_epi16(gapE512, e11);                                \
        __m512i gapD512 = _mm512_sub_epi16(h11, oe_del512);                  \
        f21 = _mm512_sub_epi16(f11, e_del512);                               \
        f21 = _mm512_max_epi16(gapD512, f21);                                \
    }

/* Batched local (Smith-Waterman-style) alignment of 64 ref/query pairs in
 * parallel using AVX-512, with 8-bit biased-unsigned scores. Sequences are
 * packed structure-of-arrays: byte k of row i belongs to pair k. Fills
 * alns[] with score/te/qe (phase 0, forward) or tb/qb (phase 1, reverse),
 * and in phase 0 additionally computes the second-best score (score2/te2)
 * outside a band around the best hit. */
void ksw_align_avx512_u8(int8_t w_match,      // match score, positive
                         int8_t w_mismatch,   // mismatch penalty, negative
                         int8_t o_ins,        // insertion-open penalty, positive
                         int8_t e_ins,        // insertion-extend penalty, positive
                         int8_t o_del,        // deletion-open penalty, positive
                         int8_t e_del,        // deletion-extend penalty, positive
                         matesw_buf_t* cache, // pre-allocated scratch buffers for the DP
                         uint8_t* seq1SoA,    // reference sequences, already packed (SoA)
                         uint8_t* seq2SoA,    // query sequences, packed (SoA)
                         int16_t nrow,        // max row count == longest ref length
                         int16_t ncol,        // max column count == longest query length
                         int* xtras,          // one xtra flag word per query (KSW_XSUBO/KSW_XSTOP)
                         int* rlenA,          // actual reference lengths per pair
                         kswr_avx_t* alns,    // output alignment results
                         int phase) {         // 0 = forward pass, 1 = reverse pass
    // Largest single-base score, used later to bound the band width for score2.
    int g_qmax = max_(w_match, w_mismatch);
    g_qmax = max_(g_qmax, w_ambig);
    uint8_t minsc[SIMD_WIDTH8] __attribute__((aligned(64))) = {0}; // per-lane minimum score of interest (KSW_XSUBO)
    uint8_t endsc[SIMD_WIDTH8] __attribute__((aligned(64))) = {0}; // per-lane early-termination score (KSW_XSTOP)
    __m512i zero512 = _mm512_setzero_si512();
    __m512i one512 = _mm512_set1_epi8(1);
    // Scratch used to build the substitution-score lookup table for shuffle_epi8.
    int8_t temp[SIMD_WIDTH8] __attribute((aligned(64))) = {0};
    uint8_t shift = 127, mdiff = 0;
    mdiff = max_(w_match, (int8_t)w_mismatch);
    mdiff = max_(mdiff, (int8_t)w_ambig);
    shift = min_(w_match, (int8_t)w_mismatch);
    shift = min_((int8_t)shift, w_ambig);
    // Bias that makes every substitution score non-negative, so the DP can
    // run entirely in unsigned saturating 8-bit arithmetic (mod-256 negate).
    shift = 256 - (uint8_t)shift;
    mdiff += shift;
    // Lookup table indexed by (base1 XOR base2): 0 => match; 1..3 => mismatch
    // among ACGT; higher indices => ambiguous / out-of-range encodings.
    temp[0] = w_match;                              // state 1: match
    temp[1] = temp[2] = temp[3] = w_mismatch;       // state 2: mismatch
    temp[4] = temp[5] = temp[6] = temp[7] = w_ambig;    // state 3: beyond boundary
    temp[8] = temp[9] = temp[10] = temp[11] = w_ambig;  // state 4: padding region
    temp[12] = w_ambig;                             // state 5: ambiguous base
    for (int i = 0; i < 16; i++) // bias the table entries for shuffle_epi8
        temp[i] += shift;
    // Replicate the 16-byte table into each 128-bit lane of the 512-bit register,
    // since shuffle_epi8 looks up within 128-bit lanes.
    int pos = 0;
    for (int i = 16; i < SIMD_WIDTH8; i++) {
        temp[i] = temp[pos++];
        if (pos % 16 == 0) pos = 0;
    }
    __m512i permSft512 = _mm512_load_si512(temp);
    __m512i sft512 = _mm512_set1_epi8(shift);
    __m512i cmax512 = _mm512_set1_epi8(255); // saturation ceiling for overflow detection
    // Decode per-lane xtra flags into the minsc/endsc thresholds and the masks
    // of lanes for which each threshold is active (fits in 8 bits).
    __mmask64 minsc_msk_a = 0x0000, endsc_msk_a = 0x0000;
    int val = 0;
    for (int i = 0; i < SIMD_WIDTH8; i++) {
        int xtra = xtras[i];
        val = (xtra & KSW_XSUBO) ? xtra & 0xffff : 0x10000;
        if (val <= 255) { minsc[i] = val; minsc_msk_a |= (0x1L << i); }
        val = (xtra & KSW_XSTOP) ? xtra & 0xffff : 0x10000;
        if (val <= 255) { endsc[i] = val; endsc_msk_a |= (0x1L << i); }
    }
    __m512i minsc512 = _mm512_load_si512((__m512i*)minsc);
    __m512i endsc512 = _mm512_load_si512((__m512i*)endsc);
    // Broadcast affine gap penalties (gap open is charged as open+extend).
    __m512i e_del512 = _mm512_set1_epi8(e_del);
    __m512i oe_del512 = _mm512_set1_epi8(o_del + e_del);
    __m512i e_ins512 = _mm512_set1_epi8(e_ins);
    __m512i oe_ins512 = _mm512_set1_epi8(o_ins + e_ins);
    __m512i five512 = _mm512_set1_epi8(DUMMY5); // sentinel code for ambiguous mapping element
    __m512i gmax512 = zero512;                  // global best score per lane
    // End-row trackers, split into two epi16 halves covering lanes 0-31 and 32-63.
    __m512i te512 = _mm512_set1_epi16(-1);
    __m512i te512_ = _mm512_set1_epi16(-1);
    __mmask64 exit0 = 0xFFFFFFFFFFFFFFFF; // lanes still active (not yet terminated)
    // Scratch buffers from the cache (pre-allocated by the caller).
    uint8_t* H0 = cache->H0;       // previous DP row
    uint8_t* H1 = cache->H1;       // current DP row
    uint8_t* Hmax = cache->Hmax;
    uint8_t* F = cache->F;         // deletion (vertical gap) scores per column
    uint8_t* rowMax = cache->rowMax; // per-row maxima, consumed by the score2 pass
    for (int i = 0; i <= ncol; i++) {
        _mm512_store_si512((__m512*)(H0 + i * SIMD_WIDTH8), zero512);
        _mm512_store_si512((__m512*)(Hmax + i * SIMD_WIDTH8), zero512);
        _mm512_store_si512((__m512*)(F + i * SIMD_WIDTH8), zero512);
    }
#if 1
    __m512i max512 = zero512, imax512, pimax512 = zero512;
    __mmask64 mask512 = 0x0000;   // lanes whose previous-row max was superseded
    __mmask64 minsc_msk = 0x0000; // lanes whose previous-row max met the minsc threshold
    __m512i qe512 = _mm512_set1_epi8(0); // query end position of the best score
    _mm512_store_si512((__m512i*)(H0), zero512);
    _mm512_store_si512((__m512i*)(H1), zero512);
#endif
#if 1
    /* ---------------- main DP loop over reference rows ---------------- */
    int i, limit = nrow;
    for (i = 0; i < nrow; i++) {
        __m512i e11 = zero512; // E (horizontal gap) restarts each row
        __m512i h00, h11, s1;
        __m512i i512 = _mm512_set1_epi16(i);
        int j;
        s1 = _mm512_load_si512((__m512i*)(seq1SoA + (i + 0) * SIMD_WIDTH8));
        imax512 = zero512;                       // best score in this row
        __m512i iqe512 = _mm512_set1_epi8(-1);   // column of that best score
        __m512i l512 = zero512;                  // running column counter
        for (j = 0; j < ncol; j++) {
            __m512i f11, s2, f21;
            h00 = _mm512_load_si512((__m512i*)(H0 + j * SIMD_WIDTH8)); // check for col "0"
            s2 = _mm512_load_si512((__m512i*)(seq2SoA + (j)*SIMD_WIDTH8));
            f11 = _mm512_load_si512((__m512i*)(F + (j + 1) * SIMD_WIDTH8));
            MAIN_SAM_CODE8_OPT(s1, s2, h00, h11, e11, f11, f21, max512, sft512);
            _mm512_store_si512((__m512i*)(H1 + (j + 1) * SIMD_WIDTH8), h11); // check for col "0"
            _mm512_store_si512((__m512i*)(F + (j + 1) * SIMD_WIDTH8), f21);
            l512 = _mm512_add_epi8(l512, one512);
        }
        // Block I: from the second row on, record the previous row's maximum
        // into rowMax (used later for score2), but only where it is a local
        // peak (not beaten by this row), meets minsc, and the lane is active.
        if (i > 0) {
            __mmask64 msk64 = _mm512_cmpgt_epu8_mask(imax512, pimax512);
            msk64 |= mask512;
            pimax512 = _mm512_mask_blend_epi8(msk64, pimax512, zero512);
            pimax512 = _mm512_mask_blend_epi8(minsc_msk, zero512, pimax512);
            pimax512 = _mm512_mask_blend_epi8(exit0, zero512, pimax512);
            _mm512_store_si512((__m512i*)(rowMax + (i - 1) * SIMD_WIDTH8), pimax512);
            mask512 = ~msk64;
        }
        pimax512 = imax512;
        minsc_msk = _mm512_cmpge_epu8_mask(imax512, minsc512);
        minsc_msk &= minsc_msk_a;
        // Block II: update the global best (gmax), its end row (te, split into
        // two epi16 halves) and end column (qe), for still-active lanes only.
        __mmask64 cmp0 = _mm512_cmpgt_epu8_mask(imax512, gmax512);
        cmp0 &= exit0;
        gmax512 = _mm512_mask_blend_epi8(cmp0, gmax512, imax512);
        te512 = _mm512_mask_blend_epi16((__mmask32)cmp0, te512, i512);
        te512_ = _mm512_mask_blend_epi16((__mmask32)(cmp0 >> SIMD_WIDTH16), te512_, i512);
        qe512 = _mm512_mask_blend_epi8(cmp0, qe512, iqe512);
        // Deactivate lanes that reached their stop score (endsc) or whose
        // biased score would saturate the 8-bit range.
        cmp0 = _mm512_cmpge_epu8_mask(gmax512, endsc512);
        cmp0 &= endsc_msk_a;
        __m512i left512 = _mm512_adds_epu8(gmax512, sft512);
        __mmask64 cmp2 = _mm512_cmpge_epu8_mask(left512, cmax512);
        exit0 = (~(cmp0 | cmp2)) & exit0;
        if (exit0 == 0) {
            limit = i++;
            break;
        }
        // Swap DP rows: current becomes previous for the next iteration.
        uint8_t* S = H1;
        H1 = H0;
        H0 = S;
        i512 = _mm512_add_epi16(i512, one512);
    } // for nrow
    // Flush the final row's maximum into rowMax, same masking as Block I.
    pimax512 = _mm512_mask_blend_epi8(mask512, pimax512, zero512);
    pimax512 = _mm512_mask_blend_epi8(minsc_msk, zero512, pimax512);
    pimax512 = _mm512_mask_blend_epi8(exit0, zero512, pimax512);
    _mm512_store_si512((__m512i*)(rowMax + (i - 1) * SIMD_WIDTH8), pimax512);
    /******************* DP loop over *****************************/
    /**************** Partial output setting **********************/
    uint8_t score[SIMD_WIDTH8] __attribute((aligned(64)));
    int16_t te1[SIMD_WIDTH8] __attribute((aligned(64))); // holds both epi16 halves => one int16 per lane
    uint8_t qe[SIMD_WIDTH8] __attribute((aligned(64)));
    int16_t low[SIMD_WIDTH8] __attribute((aligned(64)));
    int16_t high[SIMD_WIDTH8] __attribute((aligned(64)));
    _mm512_store_si512((__m512i*)score, gmax512);
    _mm512_store_si512((__m512i*)te1, te512);
    _mm512_store_si512((__m512i*)(te1 + SIMD_WIDTH16), te512_);
    _mm512_store_si512((__m512i*)qe, qe512);
    int live = 0;
    for (int l = 0; l < SIMD_WIDTH8; l++) {
        int16_t* te;
        // NOTE(review): both branches assign te1, so this conditional is dead
        // code (and tests `i`, not `l`) — presumably a leftover; te1 already
        // holds all 64 end rows contiguously. Confirm before cleaning up.
        if (i < SIMD_WIDTH16) te = te1;
        else te = te1;
        if (phase) { // phase 1: reverse alignment — derive begin coordinates
            if (alns[l].score == score[l]) {
                alns[l].tb = alns[l].te - te[l];
                alns[l].qb = alns[l].qe - qe[l];
            }
        } else { // phase 0: forward alignment — record score and end coordinates
            // 255 marks a lane whose score saturated the 8-bit range.
            alns[l].score = score[l] + shift < 255 ? score[l] : 255;
            alns[l].te = te[l];
            alns[l].qe = qe[l];
            // Reuse qe[] as a "lane needs score2" flag from here on.
            if (alns[l].score != 255) {
                qe[l] = 1;
                live++;
            } else
                qe[l] = 0;
        }
    }
    if (phase) return;   // reverse pass: no score2 computation
    if (live == 0) return; // every lane saturated: nothing more to do
    /*************** Score2 and te2 *******************/
    // Exclusion band around the best hit: rows closer than score/qmax to te
    // could belong to the same alignment, so the second-best score is searched
    // only in rows below `low` and above `high`.
    int qmax = g_qmax;
    int maxl = 0, minh = nrow;
    for (int i = 0; i < SIMD_WIDTH8; i++) {
        int val = (score[i] + qmax - 1) / qmax; // ceil(score / qmax) = band half-width
        int16_t* te = te1;
        low[i] = te[i] - val;
        high[i] = te[i] + val;
        if (qe[i]) {
            maxl = maxl < low[i] ? low[i] : maxl;   // highest lower bound over live lanes
            minh = minh > high[i] ? high[i] : minh; // lowest upper bound over live lanes
        }
    }
    max512 = zero512;
    te512 = _mm512_set1_epi16(-1);
    te512_ = _mm512_set1_epi16(-1);
    __m512i low512 = _mm512_load_si512((__m512i*)low);                    // lanes 0-31, int16
    __m512i high512 = _mm512_load_si512((__m512i*)high);                  // lanes 0-31, int16
    __m512i low512_ = _mm512_load_si512((__m512i*)(low + SIMD_WIDTH16));  // lanes 32-63, int16
    __m512i high512_ = _mm512_load_si512((__m512i*)(high + SIMD_WIDTH16)); // lanes 32-63, int16
    __m512i rmax512;
    // Scan rows strictly below each lane's band (i < low): take the max rowMax.
    for (int i = 0; i < maxl; i++) {
        __m512i i512 = _mm512_set1_epi16(i);
        rmax512 = _mm512_load_si512((__m512i*)(rowMax + i * SIMD_WIDTH8));
        __mmask64 mask11 = _mm512_cmpgt_epi16_mask(low512, i512);
        __mmask64 mask12 = _mm512_cmpgt_epi16_mask(low512_, i512);
        __mmask64 mask2 = _mm512_cmpgt_epu8_mask(rmax512, max512);
        __mmask64 mask1 = mask11 | (mask12 << SIMD_WIDTH16);
        mask2 &= mask1;
        max512 = _mm512_mask_blend_epi8(mask2, max512, rmax512);
        te512 = _mm512_mask_blend_epi16(mask2, te512, i512);
        te512_ = _mm512_mask_blend_epi16(mask2 >> SIMD_WIDTH16, te512_, i512);
    }
    // Added new block -- due to bug: rows past a lane's real reference length
    // must not contribute to score2, so clamp against rlenA per lane.
    int16_t rlen[SIMD_WIDTH8] __attribute((aligned(64)));
    for (int i = 0; i < SIMD_WIDTH8; i++) rlen[i] = rlenA[i];
    __m512i rlen512 = _mm512_load_si512(rlen);
    __m512i rlen512_ = _mm512_load_si512(rlen + SIMD_WIDTH16);
    // Scan rows strictly above each lane's band (i > high, and i < rlen).
    for (int i = minh + 1; i < limit; i++) {
        __m512i i512 = _mm512_set1_epi16(i);
        rmax512 = _mm512_load_si512((__m512i*)(rowMax + i * SIMD_WIDTH8));
        __mmask64 mask11 = _mm512_cmpgt_epi16_mask(i512, high512);
        __mmask64 mask12 = _mm512_cmpgt_epi16_mask(i512, high512_);
        __mmask64 mask2 = _mm512_cmpgt_epu8_mask(rmax512, max512);
        __mmask64 mask1 = mask11 | (mask12 << SIMD_WIDTH16);
        mask2 &= mask1;
        // new, bug: restrict to rows within the actual reference length
        __mmask64 mask11_ = _mm512_cmpgt_epi16_mask(rlen512, i512);
        __mmask64 mask12_ = _mm512_cmpgt_epi16_mask(rlen512_, i512);
        __mmask64 mask1_ = mask11_ | (mask12_ << SIMD_WIDTH16);
        mask2 &= mask1_;
        max512 = _mm512_mask_blend_epi8(mask2, max512, rmax512);
        te512 = _mm512_mask_blend_epi16(mask2, te512, i512);
        te512_ = _mm512_mask_blend_epi16(mask2 >> SIMD_WIDTH16, te512_, i512);
    }
    // Spill the second-best results and write them into the output records.
    int16_t temp4[SIMD_WIDTH8] __attribute((aligned(64)));
    _mm512_store_si512((__m512i*)temp, max512);  // reuse temp[] as int8 score2 spill
    _mm512_store_si512((__m512i*)temp4, te512);
    _mm512_store_si512((__m512i*)(temp4 + SIMD_WIDTH16), te512_);
    for (int i = 0; i < SIMD_WIDTH8; i++) {
        int16_t* te2;
        te2 = temp4;
        if (qe[i]) { // lane flagged live above
            // score2 of 0 means "no secondary hit found" => report -1.
            alns[i].score2 = (temp[i] == 0 ? (int)-1 : (uint8_t)temp[i]);
            alns[i].te2 = te2[i];
        } else {
            alns[i].score2 = -1;
            alns[i].te2 = -1;
        }
    }
#endif
}