362 lines
15 KiB
C
362 lines
15 KiB
C
#include "ksw_align_avx.h"
|
||
|
||
/*
 * Lane counts of one 512-bit vector: 64 lanes of 8 bits, 32 lanes of 16 bits.
 * Any earlier definitions (for example from a header written for a narrower
 * vector width) are dropped first so that this file always uses the 512-bit
 * values.
 */
#undef SIMD_WIDTH8
#undef SIMD_WIDTH16
#define SIMD_WIDTH8 64
#define SIMD_WIDTH16 32

// Default penalty applied when a base is not one of A/C/G/T (ambiguous base).
static const int8_t w_ambig = -1;

/*
 * One dynamic-programming cell update for 64 independent alignments, using
 * unsigned saturating 8-bit arithmetic.
 *
 * Macro arguments:
 *   s1, s2  - vectors of base codes for the current row / column
 *   h00     - H value of the diagonal predecessor
 *   h11     - (out) H value of this cell
 *   e11     - (in/out) E value carried along the row
 *   f11     - F value entering this cell
 *   f21     - (out) F value for the cell below
 *   max512  - not referenced by the macro body (kept for call-site compatibility)
 *   sft512  - the score bias that was added to every entry of permSft512
 *
 * Besides its arguments the macro reads and writes these variables of the
 * expanding scope: permSft512, five512, zero512, imax512, iqe512, l512,
 * oe_ins512, e_ins512, oe_del512, e_del512.
 *
 * Steps:
 *   1. Look up the (biased) substitution score in permSft512, indexed by
 *      s1 XOR s2.
 *   2. Where s2 equals five512 the score is replaced by the bias itself,
 *      i.e. a net score of zero.
 *   3. Add to h00, force the sum to zero in lanes where the top bit of
 *      (s1 | s2) is set, then subtract the bias again.
 *      NOTE(review): top-bit-set codes presumably mark padding past the end
 *      of a sequence - confirm against the code that packs seq1SoA/seq2SoA.
 *   4. h11 = max(diagonal, e11, f11); lanes that beat the running row
 *      maximum imax512 record the current column counter l512 in iqe512.
 *   5. Update E and F with the affine gap penalties.
 *
 * The body is a plain `{ ... }` block, so the macro must be expanded where a
 * compound statement followed by `;` is valid.
 */
#define MAIN_SAM_CODE8_OPT(s1, s2, h00, h11, e11, f11, f21, max512, sft512) \
    {                                                                       \
        __m512i sbt11, xor11, or11;                                         \
        xor11 = _mm512_xor_si512(s1, s2);                                   \
        sbt11 = _mm512_shuffle_epi8(permSft512, xor11);                     \
        __mmask64 cmpq = _mm512_cmpeq_epu8_mask(s2, five512);               \
        sbt11 = _mm512_mask_blend_epi8(cmpq, sbt11, sft512);                \
        or11 = _mm512_or_si512(s1, s2);                                     \
        __mmask64 cmp = _mm512_movepi8_mask(or11);                          \
        __m512i m11 = _mm512_adds_epu8(h00, sbt11);                         \
        m11 = _mm512_mask_blend_epi8(cmp, m11, zero512);                    \
        m11 = _mm512_subs_epu8(m11, sft512);                                \
        h11 = _mm512_max_epu8(m11, e11);                                    \
        h11 = _mm512_max_epu8(h11, f11);                                    \
        __mmask64 cmp0 = _mm512_cmpgt_epu8_mask(h11, imax512);              \
        imax512 = _mm512_max_epu8(imax512, h11);                            \
        iqe512 = _mm512_mask_blend_epi8(cmp0, iqe512, l512);                \
        __m512i gapE512 = _mm512_subs_epu8(h11, oe_ins512);                 \
        e11 = _mm512_subs_epu8(e11, e_ins512);                              \
        e11 = _mm512_max_epu8(gapE512, e11);                                \
        __m512i gapD512 = _mm512_subs_epu8(h11, oe_del512);                 \
        f21 = _mm512_subs_epu8(f11, e_del512);                              \
        f21 = _mm512_max_epu8(gapD512, f21);                                \
    }

/*
 * One dynamic-programming cell update using signed 16-bit lanes.
 *
 * Macro arguments:
 *   s1, s2  - vectors of base codes for the current row / column
 *   h00     - H value of the diagonal predecessor
 *   h11     - (out) H value of this cell, clamped to be >= 0
 *   e11     - (in/out) E value carried along the row
 *   f11     - F value entering this cell
 *   f21     - (out) F value for the cell below
 *   max512  - not referenced by the macro body (kept for call-site compatibility)
 *
 * Besides its arguments the macro reads and writes these variables of the
 * expanding scope: perm512, zero512, imax512, iqe512, l512, oe_ins512,
 * e_ins512, oe_del512, e_del512.
 *
 * The substitution score is fetched from perm512 using s1 XOR s2 as the
 * 16-bit permutation index. No bias is involved here: the arithmetic is
 * plain (non-saturating) signed 16-bit and the result is clamped at zero.
 * The diagonal sum is zeroed byte-wise wherever a byte of (s1 | s2) has its
 * top bit set.
 * NOTE(review): the zeroing mask is taken per byte although the lanes are
 * 16 bits wide - confirm this matches how the 16-bit sequences are padded.
 *
 * The body is a plain `{ ... }` block, so the macro must be expanded where a
 * compound statement followed by `;` is valid.
 */
#define MAIN_SAM_CODE16_OPT(s1, s2, h00, h11, e11, f11, f21, max512) \
    {                                                                \
        __m512i sbt11, xor11, or11;                                  \
        xor11 = _mm512_xor_si512(s1, s2);                            \
        sbt11 = _mm512_permutexvar_epi16(xor11, perm512);            \
        __m512i m11 = _mm512_add_epi16(h00, sbt11);                  \
        or11 = _mm512_or_si512(s1, s2);                              \
        __mmask64 cmp = _mm512_movepi8_mask(or11);                   \
        m11 = _mm512_mask_blend_epi8(cmp, m11, zero512);             \
        h11 = _mm512_max_epi16(m11, e11);                            \
        h11 = _mm512_max_epi16(h11, f11);                            \
        h11 = _mm512_max_epi16(h11, zero512);                        \
        __mmask32 cmp0 = _mm512_cmpgt_epi16_mask(h11, imax512);      \
        imax512 = _mm512_max_epi16(imax512, h11);                    \
        iqe512 = _mm512_mask_blend_epi16(cmp0, iqe512, l512);        \
        __m512i gapE512 = _mm512_sub_epi16(h11, oe_ins512);          \
        e11 = _mm512_sub_epi16(e11, e_ins512);                       \
        e11 = _mm512_max_epi16(gapE512, e11);                        \
        __m512i gapD512 = _mm512_sub_epi16(h11, oe_del512);          \
        f21 = _mm512_sub_epi16(f11, e_del512);                       \
        f21 = _mm512_max_epi16(gapD512, f21);                        \
    }

void ksw_align_avx512_u8(int8_t w_match, // match分数,正数
|
||
int8_t w_mismatch, // 错配罚分,负数
|
||
int8_t o_ins, // 开始一个insert罚分,正数
|
||
int8_t e_ins, // 延续一个insert罚分,正数
|
||
int8_t o_del, // 开始一个delete罚分,正数
|
||
int8_t e_del, // 延续一个delete罚分,正数
|
||
msw_buf_t* cache, // 计算用到的一些数据
|
||
uint8_t* seq1SoA, // ref序列,已经pack好了
|
||
uint8_t* seq2SoA, // seq序列
|
||
int16_t nrow, // 最长的行数,对应ref长度
|
||
int16_t ncol, // 最长的列数,对应seq长度
|
||
int* xtras, // 每个seq对应一个xtra
|
||
int* rlenA, // ref真实长度
|
||
kswr_avx_t* alns, // 存放结果
|
||
int phase) { // 正向阶段0,反向阶段1
|
||
|
||
int g_qmax = max_(w_match, w_mismatch);
|
||
g_qmax = max_(g_qmax, w_ambig);
|
||
|
||
uint8_t minsc[SIMD_WIDTH8] __attribute__((aligned(64))) = {0}; // min score ?
|
||
uint8_t endsc[SIMD_WIDTH8] __attribute__((aligned(64))) = {0}; // ending position score ?
|
||
|
||
__m512i zero512 = _mm512_setzero_si512();
|
||
__m512i one512 = _mm512_set1_epi8(1);
|
||
|
||
int8_t temp[SIMD_WIDTH8] __attribute((aligned(64))) = {0}; // 应该是用来根据比较结果赋分值的
|
||
|
||
uint8_t shift = 127, mdiff = 0;
|
||
mdiff = max_(w_match, (int8_t)w_mismatch);
|
||
mdiff = max_(mdiff, (int8_t)w_ambig);
|
||
shift = min_(w_match, (int8_t)w_mismatch);
|
||
shift = min_((int8_t)shift, w_ambig);
|
||
|
||
shift = 256 - (uint8_t)shift;
|
||
mdiff += shift;
|
||
|
||
temp[0] = w_match; // states: 1. matches
|
||
temp[1] = temp[2] = temp[3] = w_mismatch; // 2. mis-matches
|
||
temp[4] = temp[5] = temp[6] = temp[7] = w_ambig; // 3. beyond boundary
|
||
temp[8] = temp[9] = temp[10] = temp[11] = w_ambig; // 4. 0 - sse2 region
|
||
temp[12] = w_ambig; // 5. ambig
|
||
|
||
for (int i = 0; i < 16; i++) // for shuffle_epi8
|
||
temp[i] += shift;
|
||
|
||
int pos = 0;
|
||
for (int i = 16; i < SIMD_WIDTH8; i++) {
|
||
temp[i] = temp[pos++];
|
||
if (pos % 16 == 0)
|
||
pos = 0;
|
||
}
|
||
|
||
__m512i permSft512 = _mm512_load_si512(temp);
|
||
__m512i sft512 = _mm512_set1_epi8(shift);
|
||
__m512i cmax512 = _mm512_set1_epi8(255);
|
||
|
||
// __m512i minsc512, endsc512;
|
||
__mmask64 minsc_msk_a = 0x0000, endsc_msk_a = 0x0000;
|
||
int val = 0;
|
||
for (int i = 0; i < SIMD_WIDTH8; i++) {
|
||
int xtra = xtras[i];
|
||
val = (xtra & KSW_XSUBO) ? xtra & 0xffff : 0x10000;
|
||
if (val <= 255) {
|
||
minsc[i] = val;
|
||
minsc_msk_a |= (0x1L << i);
|
||
}
|
||
// msc_mask;
|
||
val = (xtra & KSW_XSTOP) ? xtra & 0xffff : 0x10000;
|
||
if (val <= 255) {
|
||
endsc[i] = val;
|
||
endsc_msk_a |= (0x1L << i);
|
||
}
|
||
}
|
||
|
||
__m512i minsc512 = _mm512_load_si512((__m512i*)minsc);
|
||
__m512i endsc512 = _mm512_load_si512((__m512i*)endsc);
|
||
|
||
__m512i e_del512 = _mm512_set1_epi8(e_del);
|
||
__m512i oe_del512 = _mm512_set1_epi8(o_del + e_del);
|
||
__m512i e_ins512 = _mm512_set1_epi8(e_ins);
|
||
__m512i oe_ins512 = _mm512_set1_epi8(o_ins + e_ins);
|
||
__m512i five512 = _mm512_set1_epi8(DUMMY5); // ambig mapping element
|
||
__m512i gmax512 = zero512; // exit1 = zero512;
|
||
__m512i te512 = _mm512_set1_epi16(-1); // changed to -1
|
||
__m512i te512_ = _mm512_set1_epi16(-1); // changed to -1
|
||
|
||
__mmask64 exit0 = 0xFFFFFFFFFFFFFFFF;
|
||
|
||
// 计算过程用到的一些数据,用cache预先开辟的空间
|
||
uint8_t* H0 = cache->H0;
|
||
uint8_t* H1 = cache->H1;
|
||
uint8_t* Hmax = cache->Hmax;
|
||
uint8_t* F = cache->F;
|
||
uint8_t* rowMax = cache->rowMax;
|
||
|
||
for (int i = 0; i <= ncol; i++) {
|
||
_mm512_store_si512((__m512*)(H0 + i * SIMD_WIDTH8), zero512);
|
||
_mm512_store_si512((__m512*)(Hmax + i * SIMD_WIDTH8), zero512);
|
||
_mm512_store_si512((__m512*)(F + i * SIMD_WIDTH8), zero512);
|
||
}
|
||
#if 1
|
||
__m512i max512 = zero512, imax512, pimax512 = zero512;
|
||
__mmask64 mask512 = 0x0000;
|
||
__mmask64 minsc_msk = 0x0000;
|
||
|
||
__m512i qe512 = _mm512_set1_epi8(0);
|
||
_mm512_store_si512((__m512i*)(H0), zero512);
|
||
_mm512_store_si512((__m512i*)(H1), zero512);
|
||
#endif
|
||
#if 1
|
||
int i, limit = nrow;
|
||
for (i = 0; i < nrow; i++) {
|
||
__m512i e11 = zero512;
|
||
__m512i h00, h11, s1;
|
||
__m512i i512 = _mm512_set1_epi16(i);
|
||
int j;
|
||
|
||
s1 = _mm512_load_si512((__m512i*)(seq1SoA + (i + 0) * SIMD_WIDTH8));
|
||
imax512 = zero512;
|
||
__m512i iqe512 = _mm512_set1_epi8(-1);
|
||
|
||
__m512i l512 = zero512;
|
||
for (j = 0; j < ncol; j++) {
|
||
__m512i f11, s2, f21;
|
||
h00 = _mm512_load_si512((__m512i*)(H0 + j * SIMD_WIDTH8)); // check for col "0"
|
||
s2 = _mm512_load_si512((__m512i*)(seq2SoA + (j)*SIMD_WIDTH8));
|
||
f11 = _mm512_load_si512((__m512i*)(F + (j + 1) * SIMD_WIDTH8));
|
||
|
||
MAIN_SAM_CODE8_OPT(s1, s2, h00, h11, e11, f11, f21, max512, sft512);
|
||
|
||
_mm512_store_si512((__m512i*)(H1 + (j + 1) * SIMD_WIDTH8), h11); // check for col "0"
|
||
_mm512_store_si512((__m512i*)(F + (j + 1) * SIMD_WIDTH8), f21);
|
||
l512 = _mm512_add_epi8(l512, one512);
|
||
}
|
||
|
||
// Block I,从第二行开始,需要和前一行比较,来计算max score
|
||
if (i > 0) {
|
||
__mmask64 msk64 = _mm512_cmpgt_epu8_mask(imax512, pimax512);
|
||
msk64 |= mask512;
|
||
pimax512 = _mm512_mask_blend_epi8(msk64, pimax512, zero512);
|
||
pimax512 = _mm512_mask_blend_epi8(minsc_msk, zero512, pimax512);
|
||
pimax512 = _mm512_mask_blend_epi8(exit0, zero512, pimax512);
|
||
|
||
_mm512_store_si512((__m512i*)(rowMax + (i - 1) * SIMD_WIDTH8), pimax512);
|
||
mask512 = ~msk64;
|
||
}
|
||
pimax512 = imax512;
|
||
minsc_msk = _mm512_cmpge_epu8_mask(imax512, minsc512);
|
||
minsc_msk &= minsc_msk_a;
|
||
|
||
// Block II: gmax, te
|
||
__mmask64 cmp0 = _mm512_cmpgt_epu8_mask(imax512, gmax512);
|
||
cmp0 &= exit0;
|
||
gmax512 = _mm512_mask_blend_epi8(cmp0, gmax512, imax512);
|
||
te512 = _mm512_mask_blend_epi16((__mmask32)cmp0, te512, i512);
|
||
te512_ = _mm512_mask_blend_epi16((__mmask32)(cmp0 >> SIMD_WIDTH16), te512_, i512);
|
||
qe512 = _mm512_mask_blend_epi8(cmp0, qe512, iqe512);
|
||
|
||
cmp0 = _mm512_cmpge_epu8_mask(gmax512, endsc512);
|
||
cmp0 &= endsc_msk_a;
|
||
|
||
__m512i left512 = _mm512_adds_epu8(gmax512, sft512);
|
||
__mmask64 cmp2 = _mm512_cmpge_epu8_mask(left512, cmax512);
|
||
|
||
exit0 = (~(cmp0 | cmp2)) & exit0;
|
||
if (exit0 == 0) {
|
||
limit = i++;
|
||
break;
|
||
}
|
||
|
||
uint8_t* S = H1;
|
||
H1 = H0;
|
||
H0 = S;
|
||
i512 = _mm512_add_epi16(i512, one512);
|
||
} // for nrow
|
||
|
||
pimax512 = _mm512_mask_blend_epi8(mask512, pimax512, zero512);
|
||
pimax512 = _mm512_mask_blend_epi8(minsc_msk, zero512, pimax512);
|
||
pimax512 = _mm512_mask_blend_epi8(exit0, zero512, pimax512);
|
||
_mm512_store_si512((__m512i*)(rowMax + (i - 1) * SIMD_WIDTH8), pimax512);
|
||
|
||
/******************* DP loop over *****************************/
|
||
/**************** Partial output setting **********************/
|
||
uint8_t score[SIMD_WIDTH8] __attribute((aligned(64)));
|
||
int16_t te1[SIMD_WIDTH8] __attribute((aligned(64)));
|
||
uint8_t qe[SIMD_WIDTH8] __attribute((aligned(64)));
|
||
int16_t low[SIMD_WIDTH8] __attribute((aligned(64)));
|
||
int16_t high[SIMD_WIDTH8] __attribute((aligned(64)));
|
||
|
||
_mm512_store_si512((__m512i*)score, gmax512);
|
||
_mm512_store_si512((__m512i*)te1, te512);
|
||
_mm512_store_si512((__m512i*)(te1 + SIMD_WIDTH16), te512_);
|
||
_mm512_store_si512((__m512i*)qe, qe512);
|
||
|
||
int live = 0;
|
||
|
||
for (int l = 0; l < SIMD_WIDTH8; l++) {
|
||
int16_t* te;
|
||
if (i < SIMD_WIDTH16)
|
||
te = te1;
|
||
else
|
||
te = te1;
|
||
if (phase) { // 第二阶段,反向比对
|
||
if (alns[l].score == score[l]) {
|
||
alns[l].tb = alns[l].te - te[l];
|
||
alns[l].qb = alns[l].qe - qe[l];
|
||
}
|
||
} else { // 第一阶段,正向比对
|
||
alns[l].score = score[l] + shift < 255 ? score[l] : 255;
|
||
alns[l].te = te[l];
|
||
alns[l].qe = qe[l];
|
||
if (alns[l].score != 255) {
|
||
qe[l] = 1;
|
||
live++;
|
||
} else
|
||
qe[l] = 0;
|
||
}
|
||
}
|
||
|
||
if (phase)
|
||
return;
|
||
|
||
if (live == 0)
|
||
return;
|
||
|
||
/*************** Score2 and te2 *******************/
|
||
int qmax = g_qmax;
|
||
int maxl = 0, minh = nrow;
|
||
for (int i = 0; i < SIMD_WIDTH8; i++) {
|
||
int val = (score[i] + qmax - 1) / qmax;
|
||
int16_t* te = te1;
|
||
low[i] = te[i] - val;
|
||
high[i] = te[i] + val;
|
||
if (qe[i]) {
|
||
maxl = maxl < low[i] ? low[i] : maxl;
|
||
minh = minh > high[i] ? high[i] : minh;
|
||
}
|
||
}
|
||
|
||
max512 = zero512;
|
||
te512 = _mm512_set1_epi16(-1);
|
||
te512_ = _mm512_set1_epi16(-1);
|
||
__m512i low512 = _mm512_load_si512((__m512i*)low); // make it int16
|
||
__m512i high512 = _mm512_load_si512((__m512i*)high); // int16
|
||
__m512i low512_ = _mm512_load_si512((__m512i*)(low + SIMD_WIDTH16)); // make it int16
|
||
__m512i high512_ = _mm512_load_si512((__m512i*)(high + SIMD_WIDTH16)); // int16
|
||
|
||
__m512i rmax512;
|
||
for (int i = 0; i < maxl; i++) {
|
||
__m512i i512 = _mm512_set1_epi16(i);
|
||
rmax512 = _mm512_load_si512((__m512i*)(rowMax + i * SIMD_WIDTH8));
|
||
|
||
__mmask64 mask11 = _mm512_cmpgt_epi16_mask(low512, i512);
|
||
__mmask64 mask12 = _mm512_cmpgt_epi16_mask(low512_, i512);
|
||
__mmask64 mask2 = _mm512_cmpgt_epu8_mask(rmax512, max512);
|
||
__mmask64 mask1 = mask11 | (mask12 << SIMD_WIDTH16);
|
||
mask2 &= mask1;
|
||
max512 = _mm512_mask_blend_epi8(mask2, max512, rmax512);
|
||
te512 = _mm512_mask_blend_epi16(mask2, te512, i512);
|
||
te512_ = _mm512_mask_blend_epi16(mask2 >> SIMD_WIDTH16, te512_, i512);
|
||
}
|
||
|
||
// Added new block -- due to bug
|
||
int16_t rlen[SIMD_WIDTH8] __attribute((aligned(64)));
|
||
for (int i = 0; i < SIMD_WIDTH8; i++) rlen[i] = rlenA[i];
|
||
__m512i rlen512 = _mm512_load_si512(rlen);
|
||
__m512i rlen512_ = _mm512_load_si512(rlen + SIMD_WIDTH16);
|
||
|
||
for (int i = minh + 1; i < limit; i++) {
|
||
__m512i i512 = _mm512_set1_epi16(i);
|
||
rmax512 = _mm512_load_si512((__m512i*)(rowMax + i * SIMD_WIDTH8));
|
||
__mmask64 mask11 = _mm512_cmpgt_epi16_mask(i512, high512);
|
||
__mmask64 mask12 = _mm512_cmpgt_epi16_mask(i512, high512_);
|
||
__mmask64 mask2 = _mm512_cmpgt_epu8_mask(rmax512, max512);
|
||
__mmask64 mask1 = mask11 | (mask12 << SIMD_WIDTH16);
|
||
mask2 &= mask1;
|
||
// new, bug
|
||
__mmask64 mask11_ = _mm512_cmpgt_epi16_mask(rlen512, i512);
|
||
__mmask64 mask12_ = _mm512_cmpgt_epi16_mask(rlen512_, i512);
|
||
__mmask64 mask1_ = mask11_ | (mask12_ << SIMD_WIDTH16);
|
||
mask2 &= mask1_;
|
||
max512 = _mm512_mask_blend_epi8(mask2, max512, rmax512);
|
||
te512 = _mm512_mask_blend_epi16(mask2, te512, i512);
|
||
te512_ = _mm512_mask_blend_epi16(mask2 >> SIMD_WIDTH16, te512_, i512);
|
||
}
|
||
|
||
int16_t temp4[SIMD_WIDTH8] __attribute((aligned(64)));
|
||
_mm512_store_si512((__m512i*)temp, max512);
|
||
_mm512_store_si512((__m512i*)temp4, te512);
|
||
_mm512_store_si512((__m512i*)(temp4 + SIMD_WIDTH16), te512_);
|
||
|
||
for (int i = 0; i < SIMD_WIDTH8; i++) {
|
||
int16_t* te2;
|
||
te2 = temp4;
|
||
if (qe[i]) {
|
||
alns[i].score2 = (temp[i] == 0 ? (int)-1 : (uint8_t)temp[i]);
|
||
alns[i].te2 = te2[i];
|
||
} else {
|
||
alns[i].score2 = -1;
|
||
alns[i].te2 = -1;
|
||
}
|
||
}
|
||
#endif
|
||
} |