BWA-FastAlign/ksw_align_avx512.c

362 lines
15 KiB
C
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include "ksw_align_avx.h"
// Drop any earlier definitions of the lane-count macros (presumably coming
// from ksw_align_avx.h -- confirm) and redefine them for 512-bit registers:
// 512 / 8 = 64 byte lanes, 512 / 16 = 32 word lanes.
#undef SIMD_WIDTH8
#undef SIMD_WIDTH16
#define SIMD_WIDTH8 64
#define SIMD_WIDTH16 32
// Default penalty applied to non-ACGT (ambiguous) bases.
static const int8_t w_ambig = -1;
/*
 * MAIN_SAM_CODE8_OPT -- one DP cell update for 64 unsigned 8-bit lanes.
 *
 * Arguments:
 *   s1, s2  : the two base-code vectors being compared in this cell.
 *   h00     : H value added to the substitution score (loaded by the caller).
 *   h11     : (out) new H value for this cell.
 *   e11     : (in/out) E value; updated for the next cell.
 *   f11     : (in) F value for this cell.
 *   f21     : (out) F value written for the following row.
 *   max512  : not referenced by the macro body.
 *   sft512  : bias added to every score so the arithmetic can stay unsigned.
 *
 * Variables captured from the enclosing scope (must exist at the call site):
 *   permSft512, five512, zero512, imax512, iqe512, l512,
 *   oe_ins512, e_ins512, oe_del512, e_del512.
 *
 * Steps, in order:
 *   1. xor11 = s1 ^ s2 is used as a per-lane index into permSft512
 *      (byte shuffle) to fetch the biased substitution score.
 *   2. Lanes where s2 == five512 take sft512 as their score instead.
 *   3. m11 = sat(h00 + score); lanes whose (s1 | s2) byte has its top bit set
 *      are forced to zero; then sft512 is removed with a saturating subtract.
 *   4. h11 = max(m11, e11, f11).
 *   5. Where h11 exceeds the running row maximum imax512, iqe512 takes the
 *      value of l512 for that lane; imax512 becomes max(imax512, h11).
 *   6. e11 = max(h11 - oe_ins512, e11 - e_ins512)   (saturating at 0).
 *   7. f21 = max(h11 - oe_del512, f11 - e_del512)   (saturating at 0).
 *
 * The expansion is a bare brace block that declares locals (sbt11, xor11,
 * or11, cmpq, cmp, m11, cmp0, gapE512, gapD512); the caller must not have
 * conflicting expectations about those names. Do not put "//" comments on the
 * continuation lines: line splicing would make them swallow the next line.
 */
#define MAIN_SAM_CODE8_OPT(s1, s2, h00, h11, e11, f11, f21, max512, sft512) \
{ \
__m512i sbt11, xor11, or11; \
xor11 = _mm512_xor_si512(s1, s2); \
sbt11 = _mm512_shuffle_epi8(permSft512, xor11); \
__mmask64 cmpq = _mm512_cmpeq_epu8_mask(s2, five512); \
sbt11 = _mm512_mask_blend_epi8(cmpq, sbt11, sft512); \
or11 = _mm512_or_si512(s1, s2); \
__mmask64 cmp = _mm512_movepi8_mask(or11); \
__m512i m11 = _mm512_adds_epu8(h00, sbt11); \
m11 = _mm512_mask_blend_epi8(cmp, m11, zero512); \
m11 = _mm512_subs_epu8(m11, sft512); \
h11 = _mm512_max_epu8(m11, e11); \
h11 = _mm512_max_epu8(h11, f11); \
__mmask64 cmp0 = _mm512_cmpgt_epu8_mask(h11, imax512); \
imax512 = _mm512_max_epu8(imax512, h11); \
iqe512 = _mm512_mask_blend_epi8(cmp0, iqe512, l512); \
__m512i gapE512 = _mm512_subs_epu8(h11, oe_ins512); \
e11 = _mm512_subs_epu8(e11, e_ins512); \
e11 = _mm512_max_epu8(gapE512, e11); \
__m512i gapD512 = _mm512_subs_epu8(h11, oe_del512); \
f21 = _mm512_subs_epu8(f11, e_del512); \
f21 = _mm512_max_epu8(gapD512, f21); \
}
/*
 * MAIN_SAM_CODE16_OPT -- one DP cell update for 32 signed 16-bit lanes.
 *
 * Same argument roles as the 8-bit variant, minus the bias vector: scores are
 * signed here, so no shift is added or removed and h11 is clamped at zero
 * with an explicit max against zero512 instead of unsigned saturation.
 * The max512 argument is not referenced by the macro body.
 *
 * Variables captured from the enclosing scope (must exist at the call site):
 *   perm512, zero512, imax512, iqe512, l512,
 *   oe_ins512, e_ins512, oe_del512, e_del512.
 *
 * Steps, in order:
 *   1. xor11 = s1 ^ s2 indexes perm512 (16-bit cross-lane permute) to fetch
 *      the substitution score.
 *   2. m11 = h00 + score; lanes flagged by the sign-bit mask of (s1 | s2)
 *      are forced to zero.
 *   3. h11 = max(m11, e11, f11, 0).
 *   4. Where h11 exceeds imax512, iqe512 takes l512; imax512 is updated.
 *   5. e11 = max(h11 - oe_ins512, e11 - e_ins512).
 *   6. f21 = max(h11 - oe_del512, f11 - e_del512).
 *
 * NOTE(review): the zeroing mask in step 2 is built and applied at byte
 * granularity (_mm512_movepi8_mask / _mm512_mask_blend_epi8) although the
 * data lanes are 16-bit -- confirm this is intended for the code values used.
 * NOTE(review): no invocation of this macro is visible in this part of the
 * file.
 */
#define MAIN_SAM_CODE16_OPT(s1, s2, h00, h11, e11, f11, f21, max512) \
{ \
__m512i sbt11, xor11, or11; \
xor11 = _mm512_xor_si512(s1, s2); \
sbt11 = _mm512_permutexvar_epi16(xor11, perm512); \
__m512i m11 = _mm512_add_epi16(h00, sbt11); \
or11 = _mm512_or_si512(s1, s2); \
__mmask64 cmp = _mm512_movepi8_mask(or11); \
m11 = _mm512_mask_blend_epi8(cmp, m11, zero512); \
h11 = _mm512_max_epi16(m11, e11); \
h11 = _mm512_max_epi16(h11, f11); \
h11 = _mm512_max_epi16(h11, zero512); \
__mmask32 cmp0 = _mm512_cmpgt_epi16_mask(h11, imax512); \
imax512 = _mm512_max_epi16(imax512, h11); \
iqe512 = _mm512_mask_blend_epi16(cmp0, iqe512, l512); \
__m512i gapE512 = _mm512_sub_epi16(h11, oe_ins512); \
e11 = _mm512_sub_epi16(e11, e_ins512); \
e11 = _mm512_max_epi16(gapE512, e11); \
__m512i gapD512 = _mm512_sub_epi16(h11, oe_del512); \
f21 = _mm512_sub_epi16(f11, e_del512); \
f21 = _mm512_max_epi16(gapD512, f21); \
}
/**
 * Batched 8-bit alignment kernel: processes SIMD_WIDTH8 (64) reference/query
 * pairs at once, one pair per byte lane of a 512-bit register.
 *
 * Inputs are laid out lane-interleaved: row r of every reference is at
 * seq1SoA + r * SIMD_WIDTH8, column c of every query at
 * seq2SoA + c * SIMD_WIDTH8. All vector buffers (the sequences and the
 * cache->H0/H1/Hmax/F/rowMax work areas) are accessed with aligned 64-byte
 * loads/stores. xtras, rlenA and alns are indexed for all 64 lanes.
 *
 * phase == 0 (forward):
 *   - writes alns[l].score (clamped to 255 when score + bias would reach
 *     255), alns[l].te and alns[l].qe for every lane;
 *   - for lanes whose score is not 255, scans the per-row maxima saved in
 *     rowMax for the best value outside the window [te - val, te + val]
 *     (val = ceil(score / qmax)) and writes alns[l].score2 / alns[l].te2;
 *     lanes with no such value, or with score 255, get -1 for both.
 * phase != 0 (reverse):
 *   - for lanes whose stored alns[l].score equals the score just computed,
 *     writes alns[l].tb = te - te' and alns[l].qb = qe - qe', then returns.
 *
 * A lane stops being updated once its running maximum reaches its KSW_XSTOP
 * threshold or once maximum + bias saturates at 255; the row loop ends early
 * when every lane has stopped.
 */
void ksw_align_avx512_u8(int8_t w_match,     // match score, positive
                         int8_t w_mismatch,  // mismatch penalty, negative
                         int8_t o_ins,       // insertion gap-open penalty, positive
                         int8_t e_ins,       // insertion gap-extension penalty, positive
                         int8_t o_del,       // deletion gap-open penalty, positive
                         int8_t e_del,       // deletion gap-extension penalty, positive
                         msw_buf_t* cache,   // scratch buffers used by the computation
                         uint8_t* seq1SoA,   // reference sequences, already packed
                         uint8_t* seq2SoA,   // query sequences
                         int16_t nrow,       // max row count, i.e. reference length
                         int16_t ncol,       // max column count, i.e. query length
                         int* xtras,         // one xtra word per sequence
                         int* rlenA,         // true reference lengths
                         kswr_avx_t* alns,   // output results
                         int phase) {        // 0 = forward phase, 1 = reverse phase
    int g_qmax = max_(w_match, w_mismatch);
    g_qmax = max_(g_qmax, w_ambig);

    uint8_t minsc[SIMD_WIDTH8] __attribute__((aligned(64))) = {0};  // per-lane KSW_XSUBO threshold
    uint8_t endsc[SIMD_WIDTH8] __attribute__((aligned(64))) = {0};  // per-lane KSW_XSTOP threshold
    __m512i zero512 = _mm512_setzero_si512();
    __m512i one512 = _mm512_set1_epi8(1);
    // Score lookup table; loaded into permSft512 and indexed by the macro.
    int8_t temp[SIMD_WIDTH8] __attribute((aligned(64))) = {0};

    // Bias that makes every substitution score non-negative:
    // shift = -(min(w_match, w_mismatch, w_ambig)) reduced modulo 256.
    uint8_t shift = 127;
    shift = min_(w_match, (int8_t)w_mismatch);
    shift = min_((int8_t)shift, w_ambig);
    shift = 256 - (uint8_t)shift;

    temp[0] = w_match;                                   // states: 1. matches
    temp[1] = temp[2] = temp[3] = w_mismatch;            // 2. mis-matches
    temp[4] = temp[5] = temp[6] = temp[7] = w_ambig;     // 3. beyond boundary
    temp[8] = temp[9] = temp[10] = temp[11] = w_ambig;   // 4. 0 - sse2 region
    temp[12] = w_ambig;                                  // 5. ambig
    for (int k = 0; k < 16; k++) {  // for shuffle_epi8
        temp[k] += shift;
    }
    // Replicate the 16-entry table into the rest of the 64-byte vector.
    int pos = 0;
    for (int k = 16; k < SIMD_WIDTH8; k++) {
        temp[k] = temp[pos++];
        if (pos % 16 == 0) {
            pos = 0;
        }
    }
    __m512i permSft512 = _mm512_load_si512(temp);
    __m512i sft512 = _mm512_set1_epi8(shift);
    __m512i cmax512 = _mm512_set1_epi8(255);

    // Per-lane thresholds; a mask bit is set only when the lane requested the
    // feature and its threshold fits in 8 bits.
    __mmask64 minsc_msk_a = 0x0000, endsc_msk_a = 0x0000;
    int val = 0;
    for (int lane = 0; lane < SIMD_WIDTH8; lane++) {
        int xtra = xtras[lane];
        // Use an unsigned 64-bit one: shifting a signed long by up to 63 is
        // undefined (and too narrow where long is 32 bits).
        const __mmask64 lane_bit = (__mmask64)1 << lane;
        val = (xtra & KSW_XSUBO) ? xtra & 0xffff : 0x10000;
        if (val <= 255) {
            minsc[lane] = val;
            minsc_msk_a |= lane_bit;
        }
        val = (xtra & KSW_XSTOP) ? xtra & 0xffff : 0x10000;
        if (val <= 255) {
            endsc[lane] = val;
            endsc_msk_a |= lane_bit;
        }
    }
    __m512i minsc512 = _mm512_load_si512((__m512i*)minsc);
    __m512i endsc512 = _mm512_load_si512((__m512i*)endsc);

    __m512i e_del512 = _mm512_set1_epi8(e_del);
    __m512i oe_del512 = _mm512_set1_epi8(o_del + e_del);
    __m512i e_ins512 = _mm512_set1_epi8(e_ins);
    __m512i oe_ins512 = _mm512_set1_epi8(o_ins + e_ins);
    __m512i five512 = _mm512_set1_epi8(DUMMY5);  // ambig mapping element
    __m512i gmax512 = zero512;
    // Row index of the best score; 16-bit lanes, so lanes 0..31 live in
    // te512 and lanes 32..63 in te512_.
    __m512i te512 = _mm512_set1_epi16(-1);
    __m512i te512_ = _mm512_set1_epi16(-1);
    __mmask64 exit0 = 0xFFFFFFFFFFFFFFFF;  // lanes still being updated

    // Working buffers come from space pre-allocated in cache.
    uint8_t* H0 = cache->H0;
    uint8_t* H1 = cache->H1;
    uint8_t* Hmax = cache->Hmax;
    uint8_t* F = cache->F;
    uint8_t* rowMax = cache->rowMax;
    for (int c = 0; c <= ncol; c++) {
        _mm512_store_si512((__m512i*)(H0 + c * SIMD_WIDTH8), zero512);
        _mm512_store_si512((__m512i*)(Hmax + c * SIMD_WIDTH8), zero512);
        _mm512_store_si512((__m512i*)(F + c * SIMD_WIDTH8), zero512);
    }
#if 1
    __m512i max512 = zero512, imax512, pimax512 = zero512;
    __mmask64 mask512 = 0x0000;
    __mmask64 minsc_msk = 0x0000;
    __m512i qe512 = _mm512_set1_epi8(0);
    _mm512_store_si512((__m512i*)(H0), zero512);
    _mm512_store_si512((__m512i*)(H1), zero512);
#endif
#if 1
    int i, limit = nrow;
    for (i = 0; i < nrow; i++) {
        __m512i e11 = zero512;
        __m512i h00, h11, s1;
        __m512i i512 = _mm512_set1_epi16(i);
        int j;
        s1 = _mm512_load_si512((__m512i*)(seq1SoA + (i + 0) * SIMD_WIDTH8));
        imax512 = zero512;
        __m512i iqe512 = _mm512_set1_epi8(-1);
        __m512i l512 = zero512;  // current column index, 8-bit per lane
        for (j = 0; j < ncol; j++) {
            __m512i f11, s2, f21;
            h00 = _mm512_load_si512((__m512i*)(H0 + j * SIMD_WIDTH8));  // check for col "0"
            s2 = _mm512_load_si512((__m512i*)(seq2SoA + (j)*SIMD_WIDTH8));
            f11 = _mm512_load_si512((__m512i*)(F + (j + 1) * SIMD_WIDTH8));
            MAIN_SAM_CODE8_OPT(s1, s2, h00, h11, e11, f11, f21, max512, sft512);
            _mm512_store_si512((__m512i*)(H1 + (j + 1) * SIMD_WIDTH8), h11);  // check for col "0"
            _mm512_store_si512((__m512i*)(F + (j + 1) * SIMD_WIDTH8), f21);
            l512 = _mm512_add_epi8(l512, one512);
        }
        // Block I: from the second row on, compare with the previous row to
        // decide which previous-row maxima are kept in rowMax.
        if (i > 0) {
            __mmask64 msk64 = _mm512_cmpgt_epu8_mask(imax512, pimax512);
            msk64 |= mask512;
            pimax512 = _mm512_mask_blend_epi8(msk64, pimax512, zero512);
            pimax512 = _mm512_mask_blend_epi8(minsc_msk, zero512, pimax512);
            pimax512 = _mm512_mask_blend_epi8(exit0, zero512, pimax512);
            _mm512_store_si512((__m512i*)(rowMax + (i - 1) * SIMD_WIDTH8), pimax512);
            mask512 = ~msk64;
        }
        pimax512 = imax512;
        minsc_msk = _mm512_cmpge_epu8_mask(imax512, minsc512);
        minsc_msk &= minsc_msk_a;
        // Block II: gmax, te
        __mmask64 cmp0 = _mm512_cmpgt_epu8_mask(imax512, gmax512);
        cmp0 &= exit0;
        gmax512 = _mm512_mask_blend_epi8(cmp0, gmax512, imax512);
        te512 = _mm512_mask_blend_epi16((__mmask32)cmp0, te512, i512);
        te512_ = _mm512_mask_blend_epi16((__mmask32)(cmp0 >> SIMD_WIDTH16), te512_, i512);
        qe512 = _mm512_mask_blend_epi8(cmp0, qe512, iqe512);
        // Retire lanes that hit their stop score or would saturate with the bias.
        cmp0 = _mm512_cmpge_epu8_mask(gmax512, endsc512);
        cmp0 &= endsc_msk_a;
        __m512i left512 = _mm512_adds_epu8(gmax512, sft512);
        __mmask64 cmp2 = _mm512_cmpge_epu8_mask(left512, cmax512);
        exit0 = (~(cmp0 | cmp2)) & exit0;
        if (exit0 == 0) {
            limit = i++;
            break;
        }
        // Swap the current/previous H rows.
        uint8_t* S = H1;
        H1 = H0;
        H0 = S;
    }  // for nrow
    pimax512 = _mm512_mask_blend_epi8(mask512, pimax512, zero512);
    pimax512 = _mm512_mask_blend_epi8(minsc_msk, zero512, pimax512);
    pimax512 = _mm512_mask_blend_epi8(exit0, zero512, pimax512);
    // No row was processed when nrow <= 0; storing at index -1 would write
    // before the start of rowMax.
    if (i > 0) {
        _mm512_store_si512((__m512i*)(rowMax + (i - 1) * SIMD_WIDTH8), pimax512);
    }
    /******************* DP loop over *****************************/
    /**************** Partial output setting **********************/
    uint8_t score[SIMD_WIDTH8] __attribute((aligned(64)));
    int16_t te1[SIMD_WIDTH8] __attribute((aligned(64)));
    uint8_t qe[SIMD_WIDTH8] __attribute((aligned(64)));
    int16_t low[SIMD_WIDTH8] __attribute((aligned(64)));
    int16_t high[SIMD_WIDTH8] __attribute((aligned(64)));
    _mm512_store_si512((__m512i*)score, gmax512);
    _mm512_store_si512((__m512i*)te1, te512);
    _mm512_store_si512((__m512i*)(te1 + SIMD_WIDTH16), te512_);
    _mm512_store_si512((__m512i*)qe, qe512);
    int live = 0;
    for (int l = 0; l < SIMD_WIDTH8; l++) {
        if (phase) {  // second phase: reverse alignment
            if (alns[l].score == score[l]) {
                alns[l].tb = alns[l].te - te1[l];
                alns[l].qb = alns[l].qe - qe[l];
            }
        } else {  // first phase: forward alignment
            alns[l].score = (score[l] + shift < 255) ? score[l] : 255;
            alns[l].te = te1[l];
            alns[l].qe = qe[l];
            // qe[] is reused from here on as a per-lane "live" flag.
            if (alns[l].score != 255) {
                qe[l] = 1;
                live++;
            } else {
                qe[l] = 0;
            }
        }
    }
    if (phase) {
        return;
    }
    if (live == 0) {
        return;
    }
    /*************** Score2 and te2 *******************/
    int qmax = g_qmax;
    int maxl = 0, minh = nrow;
    for (int lane = 0; lane < SIMD_WIDTH8; lane++) {
        int span = (score[lane] + qmax - 1) / qmax;
        low[lane] = te1[lane] - span;
        high[lane] = te1[lane] + span;
        if (qe[lane]) {
            maxl = maxl < low[lane] ? low[lane] : maxl;
            minh = minh > high[lane] ? high[lane] : minh;
        }
    }
    max512 = zero512;
    te512 = _mm512_set1_epi16(-1);
    te512_ = _mm512_set1_epi16(-1);
    __m512i low512 = _mm512_load_si512((__m512i*)low);                       // lanes 0..31, int16
    __m512i high512 = _mm512_load_si512((__m512i*)high);                     // lanes 0..31, int16
    __m512i low512_ = _mm512_load_si512((__m512i*)(low + SIMD_WIDTH16));     // lanes 32..63, int16
    __m512i high512_ = _mm512_load_si512((__m512i*)(high + SIMD_WIDTH16));   // lanes 32..63, int16
    __m512i rmax512;
    // Rows strictly below each lane's window.
    for (int row = 0; row < maxl; row++) {
        __m512i row512 = _mm512_set1_epi16(row);
        rmax512 = _mm512_load_si512((__m512i*)(rowMax + row * SIMD_WIDTH8));
        __mmask64 mask11 = _mm512_cmpgt_epi16_mask(low512, row512);
        __mmask64 mask12 = _mm512_cmpgt_epi16_mask(low512_, row512);
        __mmask64 mask2 = _mm512_cmpgt_epu8_mask(rmax512, max512);
        __mmask64 mask1 = mask11 | (mask12 << SIMD_WIDTH16);
        mask2 &= mask1;
        max512 = _mm512_mask_blend_epi8(mask2, max512, rmax512);
        te512 = _mm512_mask_blend_epi16((__mmask32)mask2, te512, row512);
        te512_ = _mm512_mask_blend_epi16((__mmask32)(mask2 >> SIMD_WIDTH16), te512_, row512);
    }
    // Added new block -- due to bug
    int16_t rlen[SIMD_WIDTH8] __attribute((aligned(64)));
    for (int lane = 0; lane < SIMD_WIDTH8; lane++) {
        rlen[lane] = rlenA[lane];
    }
    __m512i rlen512 = _mm512_load_si512(rlen);
    __m512i rlen512_ = _mm512_load_si512(rlen + SIMD_WIDTH16);
    // Rows strictly above each lane's window, limited to the lane's true
    // reference length and to the rows completed before an early exit.
    for (int row = minh + 1; row < limit; row++) {
        __m512i row512 = _mm512_set1_epi16(row);
        rmax512 = _mm512_load_si512((__m512i*)(rowMax + row * SIMD_WIDTH8));
        __mmask64 mask11 = _mm512_cmpgt_epi16_mask(row512, high512);
        __mmask64 mask12 = _mm512_cmpgt_epi16_mask(row512, high512_);
        __mmask64 mask2 = _mm512_cmpgt_epu8_mask(rmax512, max512);
        __mmask64 mask1 = mask11 | (mask12 << SIMD_WIDTH16);
        mask2 &= mask1;
        // new, bug
        __mmask64 mask11_ = _mm512_cmpgt_epi16_mask(rlen512, row512);
        __mmask64 mask12_ = _mm512_cmpgt_epi16_mask(rlen512_, row512);
        __mmask64 mask1_ = mask11_ | (mask12_ << SIMD_WIDTH16);
        mask2 &= mask1_;
        max512 = _mm512_mask_blend_epi8(mask2, max512, rmax512);
        te512 = _mm512_mask_blend_epi16((__mmask32)mask2, te512, row512);
        te512_ = _mm512_mask_blend_epi16((__mmask32)(mask2 >> SIMD_WIDTH16), te512_, row512);
    }
    int16_t temp4[SIMD_WIDTH8] __attribute((aligned(64)));
    _mm512_store_si512((__m512i*)temp, max512);  // temp is reused to hold the score2 bytes
    _mm512_store_si512((__m512i*)temp4, te512);
    _mm512_store_si512((__m512i*)(temp4 + SIMD_WIDTH16), te512_);
    for (int lane = 0; lane < SIMD_WIDTH8; lane++) {
        if (qe[lane]) {
            alns[lane].score2 = (temp[lane] == 0 ? (int)-1 : (uint8_t)temp[lane]);
            alns[lane].te2 = temp4[lane];
        } else {
            alns[lane].score2 = -1;
            alns[lane].te2 = -1;
        }
    }
#endif
}