// hyb-align/ksw_align_avx512.c

#include "ksw_align_avx.h"
#undef SIMD_WIDTH8
#undef SIMD_WIDTH16
#define SIMD_WIDTH8 64
#define SIMD_WIDTH16 32
// Default score applied to non-ACGT (ambiguous) bases.
static const int8_t w_ambig = -1;
/*
 * MAIN_SAM_CODE8_OPT: one DP cell update over 64 unsigned 8-bit lanes.
 *
 * Arguments:
 *   s1, s2   - packed symbol vectors for the current row / column.
 *   h00      - diagonal H value (read).
 *   h11      - resulting H value for this cell (written).
 *   e11      - E (gap) value carried along the row (read and updated).
 *   f11      - F (gap) value from the previous row (read).
 *   f21      - F value to hand to the next row (written).
 *   max512   - not referenced by the macro body.
 *   sft512   - bias that was added to every substitution score.
 *
 * Besides its arguments the macro reads and/or writes these names from the
 * enclosing scope: permSft512, five512, zero512, imax512, iqe512, l512,
 * oe_ins512, e_ins512, oe_del512, e_del512.
 *
 * Steps, in order:
 *   1. s1 ^ s2 is used as a byte-shuffle index into permSft512 to obtain the
 *      biased substitution score; lanes where s2 == five512 get sft512
 *      instead.
 *   2. m11 = saturating h00 + score; lanes where the top bit of (s1 | s2) is
 *      set are forced to zero512; then sft512 is subtracted with saturation.
 *   3. h11 = max(m11, e11, f11).
 *   4. Lanes where h11 is strictly greater than imax512 copy l512 into
 *      iqe512; imax512 becomes max(imax512, h11).
 *   5. e11 = max(h11 - oe_ins512, e11 - e_ins512) and
 *      f21 = max(h11 - oe_del512, f11 - e_del512), all saturating at 0.
 */
#define MAIN_SAM_CODE8_OPT(s1, s2, h00, h11, e11, f11, f21, max512, sft512) \
{ \
__m512i sbt11, xor11, or11; \
xor11 = _mm512_xor_si512(s1, s2); \
sbt11 = _mm512_shuffle_epi8(permSft512, xor11); \
__mmask64 cmpq = _mm512_cmpeq_epu8_mask(s2, five512); \
sbt11 = _mm512_mask_blend_epi8(cmpq, sbt11, sft512); \
or11 = _mm512_or_si512(s1, s2); \
__mmask64 cmp = _mm512_movepi8_mask(or11); \
__m512i m11 = _mm512_adds_epu8(h00, sbt11); \
m11 = _mm512_mask_blend_epi8(cmp, m11, zero512); \
m11 = _mm512_subs_epu8(m11, sft512); \
h11 = _mm512_max_epu8(m11, e11); \
h11 = _mm512_max_epu8(h11, f11); \
__mmask64 cmp0 = _mm512_cmpgt_epu8_mask(h11, imax512); \
imax512 = _mm512_max_epu8(imax512, h11); \
iqe512 = _mm512_mask_blend_epi8(cmp0, iqe512, l512); \
__m512i gapE512 = _mm512_subs_epu8(h11, oe_ins512); \
e11 = _mm512_subs_epu8(e11, e_ins512); \
e11 = _mm512_max_epu8(gapE512, e11); \
__m512i gapD512 = _mm512_subs_epu8(h11, oe_del512); \
f21 = _mm512_subs_epu8(f11, e_del512); \
f21 = _mm512_max_epu8(gapD512, f21); \
}
/*
 * MAIN_SAM_CODE16_OPT: one DP cell update over signed 16-bit lanes.
 *
 * Arguments have the same roles as in MAIN_SAM_CODE8_OPT, minus the bias
 * vector; max512 is not referenced by the macro body.
 *
 * Besides its arguments the macro reads and/or writes these names from the
 * enclosing scope: perm512, zero512, imax512, iqe512, l512, oe_ins512,
 * e_ins512, oe_del512, e_del512.
 * NOTE(review): perm512 is not defined anywhere in the visible part of this
 * file and the macro is not invoked here - confirm it is still needed.
 *
 * Steps, in order:
 *   1. The substitution score is produced by
 *      _mm512_permutexvar_epi16(s1 ^ s2, perm512) and added to h00.
 *   2. m11 is zeroed per BYTE (movepi8 mask / blend_epi8) wherever the top
 *      bit of the corresponding byte of (s1 | s2) is set.
 *   3. h11 = max(m11, e11, f11, 0) using signed 16-bit compares.
 *   4. Lanes where h11 is strictly greater than imax512 copy l512 into
 *      iqe512; imax512 becomes max(imax512, h11).
 *   5. e11 = max(h11 - oe_ins512, e11 - e_ins512) and
 *      f21 = max(h11 - oe_del512, f11 - e_del512) (non-saturating).
 */
#define MAIN_SAM_CODE16_OPT(s1, s2, h00, h11, e11, f11, f21, max512) \
{ \
__m512i sbt11, xor11, or11; \
xor11 = _mm512_xor_si512(s1, s2); \
sbt11 = _mm512_permutexvar_epi16(xor11, perm512); \
__m512i m11 = _mm512_add_epi16(h00, sbt11); \
or11 = _mm512_or_si512(s1, s2); \
__mmask64 cmp = _mm512_movepi8_mask(or11); \
m11 = _mm512_mask_blend_epi8(cmp, m11, zero512); \
h11 = _mm512_max_epi16(m11, e11); \
h11 = _mm512_max_epi16(h11, f11); \
h11 = _mm512_max_epi16(h11, zero512); \
__mmask32 cmp0 = _mm512_cmpgt_epi16_mask(h11, imax512); \
imax512 = _mm512_max_epi16(imax512, h11); \
iqe512 = _mm512_mask_blend_epi16(cmp0, iqe512, l512); \
__m512i gapE512 = _mm512_sub_epi16(h11, oe_ins512); \
e11 = _mm512_sub_epi16(e11, e_ins512); \
e11 = _mm512_max_epi16(gapE512, e11); \
__m512i gapD512 = _mm512_sub_epi16(h11, oe_del512); \
f21 = _mm512_sub_epi16(f11, e_del512); \
f21 = _mm512_max_epi16(gapD512, f21); \
}
/**
 * Batched alignment kernel: processes SIMD_WIDTH8 (64) sequence pairs at once,
 * one pair per unsigned 8-bit AVX-512 lane.
 *
 * Scores are kept unsigned by adding a bias (`shift`) to every substitution
 * score and removing it again after the diagonal add (see MAIN_SAM_CODE8_OPT).
 *
 * phase == 0 (forward): fills alns[l].score / te / qe, and for lanes whose
 *   score did not hit 255 also alns[l].score2 / te2 (best row maximum lying
 *   outside the [te - val, te + val] window around the primary hit).
 * phase != 0 (reverse): for lanes whose recomputed score equals alns[l].score,
 *   derives alns[l].tb / qb from the stored te / qe and returns.
 *
 * cache->H0, H1, Hmax, F and rowMax are used as scratch; the intrinsics used
 * on them are the aligned load/store variants, so they must be 64-byte
 * aligned. xtras, rlenA and alns are indexed for all SIMD_WIDTH8 lanes.
 */
void ksw_align_avx512_u8(int8_t w_match,     // match score, positive
                         int8_t w_mismatch,  // mismatch penalty, negative
                         int8_t o_ins,       // insertion open penalty, positive
                         int8_t e_ins,       // insertion extension penalty, positive
                         int8_t o_del,       // deletion open penalty, positive
                         int8_t e_del,       // deletion extension penalty, positive
                         msw_buf_t* cache,   // preallocated scratch buffers
                         uint8_t* seq1SoA,   // reference sequences, already packed
                         uint8_t* seq2SoA,   // query sequences
                         int16_t nrow,       // maximum row count (reference length)
                         int16_t ncol,       // maximum column count (query length)
                         int* xtras,         // one xtra flag word per pair
                         int* rlenA,         // true reference length per pair
                         kswr_avx_t* alns,   // results
                         int phase) {        // 0 = forward pass, 1 = reverse pass
    int g_qmax = max_(w_match, w_mismatch);
    g_qmax = max_(g_qmax, w_ambig);

    uint8_t minsc[SIMD_WIDTH8] __attribute__((aligned(64))) = {0};
    uint8_t endsc[SIMD_WIDTH8] __attribute__((aligned(64))) = {0};
    // Substitution lookup table consumed by _mm512_shuffle_epi8 in the macro.
    int8_t temp[SIMD_WIDTH8] __attribute__((aligned(64))) = {0};

    __m512i zero512 = _mm512_setzero_si512();
    __m512i one512 = _mm512_set1_epi8(1);

    // Bias that lifts the smallest substitution score up to zero.
    uint8_t shift = 127;
    shift = min_(w_match, (int8_t)w_mismatch);
    shift = min_((int8_t)shift, w_ambig);
    shift = 256 - (uint8_t)shift;

    temp[0] = w_match;                                   // 1. matches
    temp[1] = temp[2] = temp[3] = w_mismatch;            // 2. mismatches
    temp[4] = temp[5] = temp[6] = temp[7] = w_ambig;     // 3. beyond boundary
    temp[8] = temp[9] = temp[10] = temp[11] = w_ambig;   // 4. 0 - sse2 region
    temp[12] = w_ambig;                                  // 5. ambig
    for (int k = 0; k < 16; k++) {
        temp[k] += shift;
    }
    // shuffle_epi8 works per 128-bit block, so repeat the 16-entry table.
    for (int k = 16; k < SIMD_WIDTH8; k++) {
        temp[k] = temp[k % 16];
    }

    __m512i permSft512 = _mm512_load_si512(temp);
    __m512i sft512 = _mm512_set1_epi8(shift);
    __m512i cmax512 = _mm512_set1_epi8(255);

    // Per-lane thresholds taken from xtras; a mask bit is set only when the
    // flag is present and the threshold fits in 8 bits.
    __mmask64 minsc_msk_a = 0, endsc_msk_a = 0;
    for (int k = 0; k < SIMD_WIDTH8; k++) {
        const int xtra = xtras[k];
        // An unsigned one avoids the undefined shift into the sign bit at k == 63.
        const __mmask64 lane_bit = (__mmask64)1 << k;

        int val = (xtra & KSW_XSUBO) ? xtra & 0xffff : 0x10000;
        if (val <= 255) {
            minsc[k] = val;
            minsc_msk_a |= lane_bit;
        }
        val = (xtra & KSW_XSTOP) ? xtra & 0xffff : 0x10000;
        if (val <= 255) {
            endsc[k] = val;
            endsc_msk_a |= lane_bit;
        }
    }
    __m512i minsc512 = _mm512_load_si512((__m512i*)minsc);
    __m512i endsc512 = _mm512_load_si512((__m512i*)endsc);

    // The names below are referenced directly by MAIN_SAM_CODE8_OPT.
    __m512i e_del512 = _mm512_set1_epi8(e_del);
    __m512i oe_del512 = _mm512_set1_epi8(o_del + e_del);
    __m512i e_ins512 = _mm512_set1_epi8(e_ins);
    __m512i oe_ins512 = _mm512_set1_epi8(o_ins + e_ins);
    __m512i five512 = _mm512_set1_epi8(DUMMY5);  // ambig mapping element

    __m512i gmax512 = zero512;               // best score seen so far per lane
    __m512i te512 = _mm512_set1_epi16(-1);   // row of gmax, lanes 0..31
    __m512i te512_ = _mm512_set1_epi16(-1);  // row of gmax, lanes 32..63
    __mmask64 exit0 = 0xFFFFFFFFFFFFFFFF;    // lanes still being tracked

    // Working arrays live in the space preallocated in cache.
    uint8_t* H0 = cache->H0;
    uint8_t* H1 = cache->H1;
    uint8_t* Hmax = cache->Hmax;
    uint8_t* F = cache->F;
    uint8_t* rowMax = cache->rowMax;
    for (int c = 0; c <= ncol; c++) {
        _mm512_store_si512((__m512i*)(H0 + c * SIMD_WIDTH8), zero512);
        _mm512_store_si512((__m512i*)(Hmax + c * SIMD_WIDTH8), zero512);
        _mm512_store_si512((__m512i*)(F + c * SIMD_WIDTH8), zero512);
    }

    __m512i max512 = zero512, imax512, pimax512 = zero512;
    __mmask64 mask512 = 0;
    __mmask64 minsc_msk = 0;
    __m512i qe512 = _mm512_set1_epi8(0);
    _mm512_store_si512((__m512i*)(H0), zero512);
    _mm512_store_si512((__m512i*)(H1), zero512);

    /************************ DP loop ******************************/
    int i, limit = nrow;
    for (i = 0; i < nrow; i++) {
        __m512i e11 = zero512;
        __m512i h00, h11, s1;
        __m512i i512 = _mm512_set1_epi16(i);

        s1 = _mm512_load_si512((__m512i*)(seq1SoA + (i + 0) * SIMD_WIDTH8));
        imax512 = zero512;                        // row maximum per lane
        __m512i iqe512 = _mm512_set1_epi8(-1);    // column of the row maximum
        __m512i l512 = zero512;                   // current column index, 8-bit

        for (int j = 0; j < ncol; j++) {
            __m512i f11, s2, f21;
            h00 = _mm512_load_si512((__m512i*)(H0 + j * SIMD_WIDTH8));
            s2 = _mm512_load_si512((__m512i*)(seq2SoA + (j)*SIMD_WIDTH8));
            f11 = _mm512_load_si512((__m512i*)(F + (j + 1) * SIMD_WIDTH8));
            MAIN_SAM_CODE8_OPT(s1, s2, h00, h11, e11, f11, f21, max512, sft512);
            _mm512_store_si512((__m512i*)(H1 + (j + 1) * SIMD_WIDTH8), h11);
            _mm512_store_si512((__m512i*)(F + (j + 1) * SIMD_WIDTH8), f21);
            l512 = _mm512_add_epi8(l512, one512);
        }

        // Block I: from the second row on, compare with the previous row to
        // decide what is recorded as that row's maximum in rowMax[i - 1].
        if (i > 0) {
            __mmask64 msk64 = _mm512_cmpgt_epu8_mask(imax512, pimax512);
            msk64 |= mask512;
            // Zeroed where this row is higher (or the lane is masked), where
            // the previous row was below minsc, and where the lane has exited.
            pimax512 = _mm512_mask_blend_epi8(msk64, pimax512, zero512);
            pimax512 = _mm512_mask_blend_epi8(minsc_msk, zero512, pimax512);
            pimax512 = _mm512_mask_blend_epi8(exit0, zero512, pimax512);
            _mm512_store_si512((__m512i*)(rowMax + (i - 1) * SIMD_WIDTH8), pimax512);
            mask512 = ~msk64;
        }
        pimax512 = imax512;
        minsc_msk = _mm512_cmpge_epu8_mask(imax512, minsc512);
        minsc_msk &= minsc_msk_a;

        // Block II: update global maximum and the row/column where it occurred.
        __mmask64 cmp0 = _mm512_cmpgt_epu8_mask(imax512, gmax512);
        cmp0 &= exit0;
        gmax512 = _mm512_mask_blend_epi8(cmp0, gmax512, imax512);
        te512 = _mm512_mask_blend_epi16((__mmask32)cmp0, te512, i512);
        te512_ = _mm512_mask_blend_epi16((__mmask32)(cmp0 >> SIMD_WIDTH16), te512_, i512);
        qe512 = _mm512_mask_blend_epi8(cmp0, qe512, iqe512);

        // A lane exits once gmax reaches its end score or gmax + shift
        // saturates at 255.
        cmp0 = _mm512_cmpge_epu8_mask(gmax512, endsc512);
        cmp0 &= endsc_msk_a;
        __m512i left512 = _mm512_adds_epu8(gmax512, sft512);
        __mmask64 cmp2 = _mm512_cmpge_epu8_mask(left512, cmax512);
        exit0 = (~(cmp0 | cmp2)) & exit0;
        if (exit0 == 0) {
            limit = i++;  // limit = row that ended the DP; i now points past it
            break;
        }

        // Swap the two H rows for the next iteration.
        uint8_t* S = H1;
        H1 = H0;
        H0 = S;
    }

    // Record the last processed row. Skipped when no row was processed
    // (nrow <= 0), where (i - 1) would index before the start of rowMax.
    if (i > 0) {
        pimax512 = _mm512_mask_blend_epi8(mask512, pimax512, zero512);
        pimax512 = _mm512_mask_blend_epi8(minsc_msk, zero512, pimax512);
        pimax512 = _mm512_mask_blend_epi8(exit0, zero512, pimax512);
        _mm512_store_si512((__m512i*)(rowMax + (i - 1) * SIMD_WIDTH8), pimax512);
    }

    /**************** Partial output setting **********************/
    uint8_t score[SIMD_WIDTH8] __attribute__((aligned(64)));
    int16_t te1[SIMD_WIDTH8] __attribute__((aligned(64)));
    uint8_t qe[SIMD_WIDTH8] __attribute__((aligned(64)));
    int16_t low[SIMD_WIDTH8] __attribute__((aligned(64)));
    int16_t high[SIMD_WIDTH8] __attribute__((aligned(64)));
    _mm512_store_si512((__m512i*)score, gmax512);
    _mm512_store_si512((__m512i*)te1, te512);
    _mm512_store_si512((__m512i*)(te1 + SIMD_WIDTH16), te512_);
    _mm512_store_si512((__m512i*)qe, qe512);

    int live = 0;
    for (int l = 0; l < SIMD_WIDTH8; l++) {
        if (phase) {  // second phase: reverse alignment
            if (alns[l].score == score[l]) {
                alns[l].tb = alns[l].te - te1[l];
                alns[l].qb = alns[l].qe - qe[l];
            }
        } else {  // first phase: forward alignment
            alns[l].score = score[l] + shift < 255 ? score[l] : 255;
            alns[l].te = te1[l];
            alns[l].qe = qe[l];
            // From here on qe[] is reused as a "lane is live" flag.
            if (alns[l].score != 255) {
                qe[l] = 1;
                live++;
            } else {
                qe[l] = 0;
            }
        }
    }
    if (phase || live == 0) {
        return;
    }

    /*************** Score2 and te2 *******************/
    // NOTE(review): divides by qmax; assumes the largest of the three
    // substitution scores is non-zero - confirm callers guarantee w_match > 0.
    int qmax = g_qmax;
    int maxl = 0, minh = nrow;
    for (int k = 0; k < SIMD_WIDTH8; k++) {
        int val = (score[k] + qmax - 1) / qmax;
        low[k] = te1[k] - val;
        high[k] = te1[k] + val;
        if (qe[k]) {
            maxl = maxl < low[k] ? low[k] : maxl;
            minh = minh > high[k] ? high[k] : minh;
        }
    }

    max512 = zero512;
    te512 = _mm512_set1_epi16(-1);
    te512_ = _mm512_set1_epi16(-1);
    __m512i low512 = _mm512_load_si512((__m512i*)low);
    __m512i high512 = _mm512_load_si512((__m512i*)high);
    __m512i low512_ = _mm512_load_si512((__m512i*)(low + SIMD_WIDTH16));
    __m512i high512_ = _mm512_load_si512((__m512i*)(high + SIMD_WIDTH16));
    __m512i rmax512;

    // Rows strictly below each lane's window (row < low).
    for (int r = 0; r < maxl; r++) {
        __m512i row512 = _mm512_set1_epi16(r);
        rmax512 = _mm512_load_si512((__m512i*)(rowMax + r * SIMD_WIDTH8));
        __mmask64 in_lo = (__mmask64)_mm512_cmpgt_epi16_mask(low512, row512);
        __mmask64 in_hi = (__mmask64)_mm512_cmpgt_epi16_mask(low512_, row512);
        __mmask64 take = _mm512_cmpgt_epu8_mask(rmax512, max512);
        take &= in_lo | (in_hi << SIMD_WIDTH16);
        max512 = _mm512_mask_blend_epi8(take, max512, rmax512);
        te512 = _mm512_mask_blend_epi16((__mmask32)take, te512, row512);
        te512_ = _mm512_mask_blend_epi16((__mmask32)(take >> SIMD_WIDTH16), te512_, row512);
    }

    // Rows strictly above each lane's window (row > high), additionally
    // restricted to rows inside the lane's true reference length.
    int16_t rlen[SIMD_WIDTH8] __attribute__((aligned(64)));
    for (int k = 0; k < SIMD_WIDTH8; k++) {
        rlen[k] = rlenA[k];
    }
    __m512i rlen512 = _mm512_load_si512(rlen);
    __m512i rlen512_ = _mm512_load_si512(rlen + SIMD_WIDTH16);
    for (int r = minh + 1; r < limit; r++) {
        __m512i row512 = _mm512_set1_epi16(r);
        rmax512 = _mm512_load_si512((__m512i*)(rowMax + r * SIMD_WIDTH8));
        __mmask64 above_lo = (__mmask64)_mm512_cmpgt_epi16_mask(row512, high512);
        __mmask64 above_hi = (__mmask64)_mm512_cmpgt_epi16_mask(row512, high512_);
        __mmask64 take = _mm512_cmpgt_epu8_mask(rmax512, max512);
        take &= above_lo | (above_hi << SIMD_WIDTH16);
        __mmask64 inref_lo = (__mmask64)_mm512_cmpgt_epi16_mask(rlen512, row512);
        __mmask64 inref_hi = (__mmask64)_mm512_cmpgt_epi16_mask(rlen512_, row512);
        take &= inref_lo | (inref_hi << SIMD_WIDTH16);
        max512 = _mm512_mask_blend_epi8(take, max512, rmax512);
        te512 = _mm512_mask_blend_epi16((__mmask32)take, te512, row512);
        te512_ = _mm512_mask_blend_epi16((__mmask32)(take >> SIMD_WIDTH16), te512_, row512);
    }

    uint8_t score2[SIMD_WIDTH8] __attribute__((aligned(64)));
    int16_t te2[SIMD_WIDTH8] __attribute__((aligned(64)));
    _mm512_store_si512((__m512i*)score2, max512);
    _mm512_store_si512((__m512i*)te2, te512);
    _mm512_store_si512((__m512i*)(te2 + SIMD_WIDTH16), te512_);
    for (int k = 0; k < SIMD_WIDTH8; k++) {
        if (qe[k]) {
            // A zero maximum means no secondary hit was found.
            alns[k].score2 = (score2[k] == 0 ? -1 : (int)score2[k]);
            alns[k].te2 = te2[k];
        } else {
            alns[k].score2 = -1;
            alns[k].te2 = -1;
        }
    }
}