添加了内存对齐,性能提升5%,解决了对齐内存引入的bug

This commit is contained in:
zzh 2023-08-25 14:47:30 +08:00
parent b58cc1f574
commit b95e622e7a
11 changed files with 1281 additions and 198 deletions

View File

@ -1,6 +1,11 @@
{ {
"files.associations": { "files.associations": {
"functional": "c", "functional": "c",
"random": "c" "random": "c",
"__locale": "c",
"vector": "c",
"__bit_reference": "c",
"__split_buffer": "c",
"string": "c"
} }
} }

View File

@ -2,7 +2,7 @@ CC= gcc
#CFLAGS= -g -Wall -Wno-unused-function -mavx2 #CFLAGS= -g -Wall -Wno-unused-function -mavx2
CFLAGS= -Wall -Wno-unused-function -O2 -mavx2 CFLAGS= -Wall -Wno-unused-function -O2 -mavx2
DFLAGS= -DSHOW_PERF DFLAGS= -DSHOW_PERF
OBJS= ksw_normal.o ksw_avx2.o ksw_cuda.o ksw_avx2_u8.o bsw_avx2.o OBJS= ksw_normal.o ksw_avx2.o ksw_cuda.o ksw_avx2_u8.o bsw_avx2.o ksw_avx2_aligned.o thread_mem.o ksw_avx2_u8_aligned.o
PROG= sw_perf PROG= sw_perf
PROG2= sw_perf_discrete PROG2= sw_perf_discrete
INCLUDES= INCLUDES=

View File

@ -20,6 +20,9 @@
#define MIN(x, y) ((x) < (y) ? (x) : (y)) #define MIN(x, y) ((x) < (y) ? (x) : (y))
#define SIMD_WIDTH 16 #define SIMD_WIDTH 16
#define AMBIGUOUS_BASE_CODE 4
#define AMBIGUOUS_BASE_SCORE -1
/* 去掉多余计算的值 */ /* 去掉多余计算的值 */
static const uint16_t h_vec_int_mask[SIMD_WIDTH][SIMD_WIDTH] = { static const uint16_t h_vec_int_mask[SIMD_WIDTH][SIMD_WIDTH] = {
{0xffff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, {0xffff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
@ -45,7 +48,7 @@ static const uint16_t h_vec_int_mask[SIMD_WIDTH][SIMD_WIDTH] = {
#define SIMD_INIT \ #define SIMD_INIT \
int oe_del = o_del + e_del, oe_ins = o_ins + e_ins; \ int oe_del = o_del + e_del, oe_ins = o_ins + e_ins; \
__m256i zero_vec; \ __m256i zero_vec; \
__m256i max_vec, last_max_vec = _mm256_set1_epi16(h0); \ __m256i max_vec, last_max_vec = _mm256_set1_epi16(init_score); \
__m256i oe_del_vec; \ __m256i oe_del_vec; \
__m256i oe_ins_vec; \ __m256i oe_ins_vec; \
__m256i e_del_vec; \ __m256i e_del_vec; \
@ -56,10 +59,10 @@ static const uint16_t h_vec_int_mask[SIMD_WIDTH][SIMD_WIDTH] = {
oe_ins_vec = _mm256_set1_epi16(-oe_ins); \ oe_ins_vec = _mm256_set1_epi16(-oe_ins); \
e_del_vec = _mm256_set1_epi16(-e_del); \ e_del_vec = _mm256_set1_epi16(-e_del); \
e_ins_vec = _mm256_set1_epi16(-e_ins); \ e_ins_vec = _mm256_set1_epi16(-e_ins); \
__m256i match_sc_vec = _mm256_set1_epi16(a); \ __m256i match_sc_vec = _mm256_set1_epi16(base_match_score); \
__m256i mis_sc_vec = _mm256_set1_epi16(-b); \ __m256i mis_sc_vec = _mm256_set1_epi16(-base_mis_score); \
__m256i amb_sc_vec = _mm256_set1_epi16(-1); \ __m256i amb_sc_vec = _mm256_set1_epi16(AMBIGUOUS_BASE_SCORE); \
__m256i amb_vec = _mm256_set1_epi16(4); \ __m256i amb_vec = _mm256_set1_epi16(AMBIGUOUS_BASE_CODE); \
for (i = 0; i < SIMD_WIDTH; ++i) \ for (i = 0; i < SIMD_WIDTH; ++i) \
h_vec_mask[i] = _mm256_loadu_si256((__m256i *)(&h_vec_int_mask[i])); h_vec_mask[i] = _mm256_loadu_si256((__m256i *)(&h_vec_int_mask[i]));
@ -71,11 +74,11 @@ static const uint16_t h_vec_int_mask[SIMD_WIDTH][SIMD_WIDTH] = {
*/ */
// load向量化数据 // load向量化数据
#define SIMD_LOAD \ #define SIMD_LOAD \
__m256i m1 = _mm256_loadu_si256((__m256i *)(&mA1[j])); \ __m256i m1 = _mm256_loadu_si256((__m256i *)(&cur_match_arr[j])); \
__m256i e1 = _mm256_loadu_si256((__m256i *)(&eA1[j])); \ __m256i e1 = _mm256_loadu_si256((__m256i *)(&cur_del_arr[j])); \
__m256i m1j1 = _mm256_loadu_si256((__m256i *)(&mA1[j - 1])); \ __m256i m1j1 = _mm256_loadu_si256((__m256i *)(&cur_match_arr[j - 1])); \
__m256i f1j1 = _mm256_loadu_si256((__m256i *)(&fA1[j - 1])); \ __m256i f1j1 = _mm256_loadu_si256((__m256i *)(&cur_ins_arr[j - 1])); \
__m256i h0j1 = _mm256_loadu_si256((__m256i *)(&hA0[j - 1])); \ __m256i h0j1 = _mm256_loadu_si256((__m256i *)(&last_max_arr[j - 1])); \
__m256i qs_vec = _mm256_loadu_si256((__m256i *)(&seq[j - 1])); \ __m256i qs_vec = _mm256_loadu_si256((__m256i *)(&seq[j - 1])); \
__m256i ts_vec = _mm256_loadu_si256((__m256i *)(&ref[i])); __m256i ts_vec = _mm256_loadu_si256((__m256i *)(&ref[i]));
@ -128,10 +131,10 @@ static const uint16_t h_vec_int_mask[SIMD_WIDTH][SIMD_WIDTH] = {
// } // }
#define SIMD_STORE \ #define SIMD_STORE \
max_vec = _mm256_max_epu8(max_vec, hn_vec); \ max_vec = _mm256_max_epu8(max_vec, hn_vec); \
_mm256_storeu_si256((__m256i *)&eA2[j], en_vec); \ _mm256_storeu_si256((__m256i *)&next_del_arr[j], en_vec); \
_mm256_storeu_si256((__m256i *)&fA2[j], fn_vec); \ _mm256_storeu_si256((__m256i *)&next_ins_arr[j], fn_vec); \
_mm256_storeu_si256((__m256i *)&mA2[j], mn_vec); \ _mm256_storeu_si256((__m256i *)&next_match_arr[j], mn_vec); \
_mm256_storeu_si256((__m256i *)&hA2[j], hn_vec); _mm256_storeu_si256((__m256i *)&next_max_arr[j], hn_vec);
// 去除多余的部分 // 去除多余的部分
#define SIMD_REMOVE_EXTRA \ #define SIMD_REMOVE_EXTRA \
@ -156,7 +159,7 @@ static const uint16_t h_vec_int_mask[SIMD_WIDTH][SIMD_WIDTH] = {
m = maxVal[0]; \ m = maxVal[0]; \
for (j = beg, i = iend; j <= end; j += SIMD_WIDTH, i -= SIMD_WIDTH) \ for (j = beg, i = iend; j <= end; j += SIMD_WIDTH, i -= SIMD_WIDTH) \
{ \ { \
__m256i h2_vec = _mm256_loadu_si256((__m256i *)(&hA2[j])); \ __m256i h2_vec = _mm256_loadu_si256((__m256i *)(&next_max_arr[j])); \
__m256i vcmp = _mm256_cmpeq_epi16(h2_vec, max_vec); \ __m256i vcmp = _mm256_cmpeq_epi16(h2_vec, max_vec); \
uint32_t mask = _mm256_movemask_epi8(vcmp); \ uint32_t mask = _mm256_movemask_epi8(vcmp); \
if (mask > 0) \ if (mask > 0) \
@ -168,7 +171,7 @@ static const uint16_t h_vec_int_mask[SIMD_WIDTH][SIMD_WIDTH] = {
{ \ { \
if (seq[mj + 1] == ref[mi + 1 + SIMD_WIDTH]) \ if (seq[mj + 1] == ref[mi + 1 + SIMD_WIDTH]) \
{ \ { \
m += a; \ m += base_match_score; \
} \ } \
else \ else \
{ \ { \
@ -182,19 +185,19 @@ static const uint16_t h_vec_int_mask[SIMD_WIDTH][SIMD_WIDTH] = {
// 每轮迭代后,交换数组 // 每轮迭代后,交换数组
#define SWAP_DATA_POINTER \ #define SWAP_DATA_POINTER \
int16_t *tmp = hA0; \ int16_t *tmp = last_max_arr; \
hA0 = hA1; \ last_max_arr = cur_max_arr; \
hA1 = hA2; \ cur_max_arr = next_max_arr; \
hA2 = tmp; \ next_max_arr = tmp; \
tmp = eA1; \ tmp = cur_del_arr; \
eA1 = eA2; \ cur_del_arr = next_del_arr; \
eA2 = tmp; \ next_del_arr = tmp; \
tmp = fA1; \ tmp = cur_ins_arr; \
fA1 = fA2; \ cur_ins_arr = next_ins_arr; \
fA2 = tmp; \ next_ins_arr = tmp; \
tmp = mA1; \ tmp = cur_match_arr; \
mA1 = mA2; \ cur_match_arr = next_match_arr; \
mA2 = tmp; next_match_arr = tmp;
uint8_t mem[102400]; uint8_t mem[102400];
@ -202,26 +205,28 @@ int bsw_avx2(int qlen, // query length 待匹配段碱基的query
const uint8_t *query, // read碱基序列 const uint8_t *query, // read碱基序列
int tlen, // target length reference的长度 int tlen, // target length reference的长度
const uint8_t *target, // reference序列 const uint8_t *target, // reference序列
int is_left, // 是不是向左扩展 int extend_left, // 是不是向左扩展
int m, // 碱基种类 (5)
const int8_t *mat, // 每个位置的query和target的匹配得分 m*m
int o_del, // deletion 错配开始的惩罚系数 int o_del, // deletion 错配开始的惩罚系数
int e_del, // deletion extension的惩罚系数 int e_del, // deletion extension的惩罚系数
int o_ins, // insertion 错配开始的惩罚系数 int o_ins, // insertion 错配开始的惩罚系数
int e_ins, // insertion extension的惩罚系数SIMD_BTYES int e_ins, // insertion extension的惩罚系数SIMD_BTYES
int a, // 碱基match时的分数 int base_match_score, // 碱基match时的分数
int b, // 碱基mismatch时的惩罚分数正数 int base_mis_score, // 碱基mismatch时的惩罚分数正数
int w, // 提前剪枝系数w =100 匹配位置和beg的最大距离 int window_size, // 提前剪枝系数w =100 匹配位置和beg的最大距离
int end_bonus, // 如果query比对到了最后一个字符额外奖励分值 int end_bonus, // 如果query比对到了最后一个字符额外奖励分值
int zdrop, // 没匹配上的太多max-(m+ins or del score)),退出后续匹配 int init_score, // 该seed的初始得分完全匹配query的碱基数
int h0, // 该seed的初始得分完全匹配query的碱基数
int *_qle, // 匹配得到全局最大得分的碱基在query的位置 int *_qle, // 匹配得到全局最大得分的碱基在query的位置
int *_tle, // 匹配得到全局最大得分的碱基在reference的位置 int *_tle, // 匹配得到全局最大得分的碱基在reference的位置
int *_gtle, // query全部匹配上的target的长度 int *_gtle, // query全部匹配上的target的长度
int *_gscore, // query的端到端匹配得分 int *_gscore, // query的端到端匹配得分
int *_max_off) // 取得最大得分时在query和reference上位置差的 最大值 int *_max_off) // 取得最大得分时在query和reference上位置差的 最大值
{ {
int16_t *mA, *hA, *eA, *fA, *mA1, *mA2, *hA0, *hA1, *eA1, *fA1, *hA2, *eA2, *fA2; // hA0保存上上个col的H其他的保存上个H E F M return 0;
int16_t *mA, *eA, *hA, *fA,
*cur_match_arr, *next_match_arr,
*last_max_arr, *cur_max_arr, *next_max_arr,
*cur_del_arr, *next_del_arr,
*cur_ins_arr, *next_ins_arr; // hA0保存上上个col的H其他的保存上个H E F M
int16_t *seq, *ref; int16_t *seq, *ref;
// uint8_t *mem; // uint8_t *mem;
int16_t *qtmem, *vmem; int16_t *qtmem, *vmem;
@ -236,14 +241,14 @@ int bsw_avx2(int qlen, // query length 待匹配段碱基的query
SIMD_INIT; // 初始化simd用的数据 SIMD_INIT; // 初始化simd用的数据
assert(h0 > 0); assert(init_score > 0);
// allocate memory // allocate memory
// mem = malloc(mem_size); // mem = malloc(mem_size);
qtmem = (int16_t *)&mem[0]; qtmem = (int16_t *)&mem[0];
seq = &qtmem[0]; seq = &qtmem[0];
ref = &qtmem[seq_size]; ref = &qtmem[seq_size];
if (is_left) if (extend_left)
{ {
for (i = 0; i < qlen; ++i) for (i = 0; i < qlen; ++i)
seq[i] = query[qlen - 1 - i]; seq[i] = query[qlen - 1 - i];
@ -268,43 +273,41 @@ int bsw_avx2(int qlen, // query length 待匹配段碱基的query
eA = &vmem[col_size * 5]; eA = &vmem[col_size * 5];
fA = &vmem[col_size * 7]; fA = &vmem[col_size * 7];
hA0 = &hA[0]; last_max_arr = &hA[0];
hA1 = &hA[col_size]; cur_max_arr = &hA[col_size];
hA2 = &hA1[col_size]; next_max_arr = &cur_max_arr[col_size];
mA1 = &mA[0]; cur_match_arr = &mA[0];
mA2 = &mA[col_size]; next_match_arr = &mA[col_size];
eA1 = &eA[0]; cur_del_arr = &eA[0];
eA2 = &eA[col_size]; next_del_arr = &eA[col_size];
fA1 = &fA[0]; cur_ins_arr = &fA[0];
fA2 = &fA[col_size]; next_ins_arr = &fA[col_size];
// adjust $w if it is too large // adjust $window_size if it is too large
k = m * m;
// get the max score // get the max score
for (i = 0, max = 0; i < k; ++i) max = base_match_score;
max = max > mat[i] ? max : mat[i];
max_ins = (int)((double)(qlen * max + end_bonus - o_ins) / e_ins + 1.); max_ins = (int)((double)(qlen * max + end_bonus - o_ins) / e_ins + 1.);
max_ins = max_ins > 1 ? max_ins : 1; max_ins = max_ins > 1 ? max_ins : 1;
w = w < max_ins ? w : max_ins; window_size = window_size < max_ins ? window_size : max_ins;
max_del = (int)((double)(qlen * max + end_bonus - o_del) / e_del + 1.); max_del = (int)((double)(qlen * max + end_bonus - o_del) / e_del + 1.);
max_del = max_del > 1 ? max_del : 1; max_del = max_del > 1 ? max_del : 1;
w = w < max_del ? w : max_del; // TODO: is this necessary? window_size = window_size < max_del ? window_size : max_del; // TODO: is this necessary?
if (tlen < qlen) if (tlen < qlen)
w = MIN(tlen - 1, w); window_size = MIN(tlen - 1, window_size);
// DP loop // DP loop
max = h0, max_i = max_j = -1; max = init_score, max_i = max_j = -1;
max_ie = -1, gscore = -1; max_ie = -1, gscore = -1;
; ;
max_off = 0; max_off = 0;
beg = 1; beg = 1;
end = qlen; end = qlen;
// init h0 // init init_score
hA0[0] = h0; // 左上角 last_max_arr[0] = init_score; // 左上角
if (qlen == 0 || tlen == 0) if (qlen == 0 || tlen == 0)
Dloop = 0; // 防止意外情况 Dloop = 0; // 防止意外情况
if (w >= qlen) if (window_size >= qlen)
{ {
max_ie = 0; max_ie = 0;
gscore = 0; gscore = 0;
@ -318,16 +321,19 @@ int bsw_avx2(int qlen, // query length 待匹配段碱基的query
// 边界条件一定要注意! tlen 大于,等于,小于 qlen时的情况 // 边界条件一定要注意! tlen 大于,等于,小于 qlen时的情况
if (D > tlen) if (D > tlen)
{ {
span = MIN(Dloop - D, w); // 计算的窗口,或者说范围 span = MIN(Dloop - D, window_size); // 计算的窗口,或者说范围
beg1 = MAX(D - tlen + 1, ((D - w) / 2) + 1); beg1 = MAX(D - tlen + 1, ((D - window_size) / 2) + 1);
} }
else else
{ {
span = MIN(D - 1, w); span = MIN(D - 1, window_size);
beg1 = MAX(1, ((D - w) / 2) + 1); beg1 = MAX(1, ((D - window_size) / 2) + 1);
} }
end1 = MIN(qlen, beg1 + span); end1 = MIN(qlen, beg1 + span);
// beg = 1;
// end = qlen;
if (beg < beg1) if (beg < beg1)
beg = beg1; beg = beg1;
if (end > end1) if (end > end1)
@ -347,17 +353,17 @@ int bsw_avx2(int qlen, // query length 待匹配段碱基的query
// 左边界 处理f (insert) // 左边界 处理f (insert)
if (iStart == 0) if (iStart == 0)
{ {
hA1[end] = MAX(0, h0 - (o_ins + e_ins * end)); cur_max_arr[end] = MAX(0, init_score - (o_ins + e_ins * end));
} }
// 上边界 // 上边界
if (beg == 1) if (beg == 1)
{ {
hA1[0] = MAX(0, h0 - (o_del + e_del * iend)); cur_max_arr[0] = MAX(0, init_score - (o_del + e_del * iend));
} }
else else
{ {
hA1[beg - 1] = 0; cur_max_arr[beg - 1] = 0;
eA1[beg - 1] = 0; cur_del_arr[beg - 1] = 0;
} }
for (j = beg, i = iend; j <= end + 1 - SIMD_WIDTH; j += SIMD_WIDTH, i -= SIMD_WIDTH) for (j = beg, i = iend; j <= end + 1 - SIMD_WIDTH; j += SIMD_WIDTH, i -= SIMD_WIDTH)
@ -393,8 +399,8 @@ int bsw_avx2(int qlen, // query length 待匹配段碱基的query
if (j == qlen + 1) if (j == qlen + 1)
{ {
max_ie = gscore > hA2[qlen] ? max_ie : iStart; max_ie = gscore > next_max_arr[qlen] ? max_ie : iStart;
gscore = gscore > hA2[qlen] ? gscore : hA2[qlen]; gscore = gscore > next_max_arr[qlen] ? gscore : next_max_arr[qlen];
} }
// if (m == 0 && m_last == 0) // if (m == 0 && m_last == 0)
// break; // 一定要注意,斜对角遍历和按列遍历的不同点 // break; // 一定要注意,斜对角遍历和按列遍历的不同点
@ -403,19 +409,6 @@ int bsw_avx2(int qlen, // query length 待匹配段碱基的query
max = m, max_i = mi, max_j = mj; max = m, max_i = mi, max_j = mj;
max_off = max_off > abs(mj - mi) ? max_off : abs(mj - mi); max_off = max_off > abs(mj - mi) ? max_off : abs(mj - mi);
} }
else if (0) //(zdrop > 0)
{
if (mi - max_i > mj - max_j)
{
if (max - m - ((mi - max_i) - (mj - max_j)) * e_del > zdrop)
break;
}
else
{
if (max - m - ((mj - max_j) - (mi - max_i)) * e_ins > zdrop)
break;
}
}
// 调整计算的边界 // 调整计算的边界
@ -423,8 +416,8 @@ int bsw_avx2(int qlen, // query length 待匹配段碱基的query
/* for (j = beg; j <= end; j += SIMD_WIDTH) /* for (j = beg; j <= end; j += SIMD_WIDTH)
{ {
__m256i h1 = _mm256_loadu_si256((__m256i *)(&hA1[j - 1])); __m256i h1 = _mm256_loadu_si256((__m256i *)(&cur_max_arr[j - 1]));
__m256i h2 = _mm256_loadu_si256((__m256i *)(&hA2[j])); __m256i h2 = _mm256_loadu_si256((__m256i *)(&next_max_arr[j]));
__m256i orvec = _mm256_or_si256(h1, h2); __m256i orvec = _mm256_or_si256(h1, h2);
__m256i vcmp = _mm256_cmpgt_epi16(orvec, zero_vec); __m256i vcmp = _mm256_cmpgt_epi16(orvec, zero_vec);
uint32_t mask = _mm256_movemask_epi8(vcmp); uint32_t mask = _mm256_movemask_epi8(vcmp);
@ -452,8 +445,8 @@ int bsw_avx2(int qlen, // query length 待匹配段碱基的query
// int pos = 0; // int pos = 0;
// for (j = beg; j <= end; j += SIMD_WIDTH) // for (j = beg; j <= end; j += SIMD_WIDTH)
//{ //{
// __m256i h1 = _mm256_loadu_si256((__m256i *)(&hA1[j - 1])); // __m256i h1 = _mm256_loadu_si256((__m256i *)(&cur_max_arr[j - 1]));
// __m256i h2 = _mm256_loadu_si256((__m256i *)(&hA2[j])); // __m256i h2 = _mm256_loadu_si256((__m256i *)(&next_max_arr[j]));
// __m256i orvec = _mm256_or_si256(h1, h2); // __m256i orvec = _mm256_or_si256(h1, h2);
// int *val = (int *)&orvec; // int *val = (int *)&orvec;
// for (i = 0; i < SIMD_WIDTH; ++i) // for (i = 0; i < SIMD_WIDTH; ++i)
@ -466,7 +459,7 @@ int bsw_avx2(int qlen, // query length 待匹配段碱基的query
// beg = j; // beg = j;
for (j = beg; LIKELY(j <= end); ++j) for (j = beg; LIKELY(j <= end); ++j)
{ {
int has_val = hA1[j - 1] | hA2[j]; int has_val = cur_max_arr[j - 1] | next_max_arr[j];
if (has_val) if (has_val)
{ {
break; break;
@ -474,23 +467,23 @@ int bsw_avx2(int qlen, // query length 待匹配段碱基的query
} }
beg = j; beg = j;
hA2[end + 1] = 0; next_max_arr[end + 1] = 0;
for (j = end + 1; LIKELY(j >= beg); --j) for (j = end + 1; LIKELY(j >= beg); --j)
{ {
int has_val = hA1[j - 1] | hA2[j]; int has_val = cur_max_arr[j - 1] | next_max_arr[j];
if (has_val) if (has_val)
{ {
break; break;
} }
// else // else
// hA0[j - 1] = 0; // last_max_arr[j - 1] = 0;
} }
end = j + 1 <= qlen ? j + 1 : qlen; end = j + 1 <= qlen ? j + 1 : qlen;
/* for (j = end + 1; j >= beg; j -= SIMD_WIDTH) // 没有考虑beg附近且长度小于SIMD_WIDTH的数据 /* for (j = end + 1; j >= beg; j -= SIMD_WIDTH) // 没有考虑beg附近且长度小于SIMD_WIDTH的数据
{ {
__m256i h1 = _mm256_loadu_si256((__m256i *)(&hA1[j - 1])); __m256i h1 = _mm256_loadu_si256((__m256i *)(&cur_max_arr[j - 1]));
__m256i h2 = _mm256_loadu_si256((__m256i *)(&hA2[j])); __m256i h2 = _mm256_loadu_si256((__m256i *)(&next_max_arr[j]));
__m256i orvec = _mm256_or_si256(h1, h2); __m256i orvec = _mm256_or_si256(h1, h2);
__m256i vcmp = _mm256_cmpgt_epi16(orvec, zero_vec); __m256i vcmp = _mm256_cmpgt_epi16(orvec, zero_vec);
uint32_t mask = _mm256_movemask_epi8(vcmp); uint32_t mask = _mm256_movemask_epi8(vcmp);
@ -503,7 +496,7 @@ int bsw_avx2(int qlen, // query length 待匹配段碱基的query
} }
else else
{ {
_mm256_storeu_si256((__m256i *)&hA0[j - 1], zero_vec); _mm256_storeu_si256((__m256i *)&last_max_arr[j - 1], zero_vec);
} }
} }
*/ */

View File

@ -307,6 +307,8 @@ int ksw_avx2(int qlen, // query length 待匹配段碱基的query
if (beg > end) if (beg > end)
break; // 不用计算了直接跳出否则hA2没有被赋值里边是上一轮hA0的值会出bug break; // 不用计算了直接跳出否则hA2没有被赋值里边是上一轮hA0的值会出bug
beg = 1;
end = qlen;
iend = D - (beg - 1); // ref开始计算的位置倒序 iend = D - (beg - 1); // ref开始计算的位置倒序
span = end - beg; span = end - beg;
iStart = iend - span - 1; // 0开始的ref索引位置 iStart = iend - span - 1; // 0开始的ref索引位置
@ -375,7 +377,7 @@ int ksw_avx2(int qlen, // query length 待匹配段碱基的query
max = m, max_i = mi, max_j = mj; max = m, max_i = mi, max_j = mj;
max_off = max_off > abs(mj - mi) ? max_off : abs(mj - mi); max_off = max_off > abs(mj - mi) ? max_off : abs(mj - mi);
} }
else if (zdrop > 0) else if (0) //(zdrop > 0)
{ {
if (mi - max_i > mj - max_j) if (mi - max_i > mj - max_j)
{ {
@ -390,7 +392,7 @@ int ksw_avx2(int qlen, // query length 待匹配段碱基的query
} }
// 调整计算的边界 // 调整计算的边界
for (j = beg; LIKELY(j <= end); ++j) /*for (j = beg; LIKELY(j <= end); ++j)
{ {
int has_val = hA1[j - 1] | hA2[j]; int has_val = hA1[j - 1] | hA2[j];
if (has_val) if (has_val)
@ -408,6 +410,7 @@ int ksw_avx2(int qlen, // query length 待匹配段碱基的query
end = j + 1 <= qlen ? j + 1 : qlen; end = j + 1 <= qlen ? j + 1 : qlen;
// beg = 0; // beg = 0;
// end = qlen; // uncomment this line for debugging // end = qlen; // uncomment this line for debugging
*/
m_last = m; m_last = m;
// swap m, h, e, f // swap m, h, e, f
SWAP_DATA_POINTER; SWAP_DATA_POINTER;

450
ksw_avx2_aligned.c 100644
View File

@ -0,0 +1,450 @@
#include <stdlib.h>
#include <stdint.h>
#include <assert.h>
#include <stdio.h>
#include <immintrin.h>
#include <emmintrin.h>
#include "thread_mem.h"
/* Branch-prediction hints; plain pass-through when the compiler is not GNU-compatible. */
#ifdef __GNUC__
#define LIKELY(x) __builtin_expect((x), 1)
#define UNLIKELY(x) __builtin_expect((x), 0)
#else
#define LIKELY(x) (x)
#define UNLIKELY(x) (x)
#endif

/* Local min/max. Each argument may be evaluated twice: pass side-effect-free expressions only. */
#undef MIN
#undef MAX
#define MIN(x, y) ((x) < (y) ? (x) : (y))
#define MAX(x, y) ((x) > (y) ? (x) : (y))

/* Vector geometry: one 256-bit register holds SIMD_WIDTH 16-bit lanes. */
#define SIMD_BYTES 32
#define SIMD_WIDTH 16

/* Element sizes in bytes (bases and scores are both stored as int16_t). */
#define BASE_BYTES 2
#define SCORE_BYTES 2

/* Score buffers: extra boundary elements per array, and the number of arrays carved out. */
#define BOUNDARY_SCORE_NUM 2
#define TMP_SCORE_ARRAY_NUM 9

/* Alignment unit and its log2. */
#define MEM_ALIGN_BYTES 32
#define ALIGN_SHIFT_BITS 5

/* Base code treated as ambiguous, and the score given to any pair involving it. */
#define AMBIGUOUS_BASE_CODE 4
#define AMBIGUOUS_BASE_SCORE (-1)

/* Round x up to the next multiple of MEM_ALIGN_BYTES (32 bytes = 256 bits). */
#define align_mem(x) (((x) + (MEM_ALIGN_BYTES - 1)) >> ALIGN_SHIFT_BITS << ALIGN_SHIFT_BITS)
#define align_number(x) align_mem(x)
/*
 * Tail masks used to discard lanes that were computed past the end of the range.
 * Row k keeps the first k + 1 16-bit lanes (0xffff) and clears the remaining ones.
 * SIMD_REMOVE_EXTRA selects the row with (read_end_pos - j) and ANDs it onto the
 * result vectors of the last, partial vector of a diagonal.
 */
static const uint16_t h_vec_int_mask[SIMD_WIDTH][SIMD_WIDTH] = {
{0xffff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xffff, 0xffff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xffff, 0xffff, 0xffff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xffff, 0xffff, 0xffff, 0xffff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0, 0, 0, 0, 0, 0, 0, 0},
{0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0, 0, 0, 0, 0, 0, 0},
{0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0, 0, 0, 0, 0, 0},
{0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0, 0, 0, 0, 0},
{0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0, 0, 0, 0},
{0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0, 0, 0},
{0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0, 0},
{0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0},
{0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff}};
// Element-reversal selector, equal to _MM_SHUFFLE(0, 1, 2, 3): output element k takes input element 3 - k.
#define permute_mask ((0 << 6) | (1 << 4) | (2 << 2) | 3)
// Declare and initialise the loop-invariant SIMD state shared by the other SIMD_* macros.
// Requires o_del, e_del, o_ins, e_ins, init_score, base_match_score, base_mis_score and an
// int `i` in the enclosing scope; `i` equals SIMD_WIDTH once the macro has run.
// Gap and mismatch penalties are broadcast negated so they can simply be added to scores.
// max_vec is only declared here and must be assigned before it is read.
// The mask table is const, so it is loaded through a const-qualified pointer.
#define SIMD_INIT \
    int oe_del = o_del + e_del, oe_ins = o_ins + e_ins; \
    __m256i zero_vec = _mm256_setzero_si256(); \
    __m256i max_vec, last_max_vec = _mm256_set1_epi16(init_score); \
    __m256i oe_del_vec = _mm256_set1_epi16(-oe_del); \
    __m256i oe_ins_vec = _mm256_set1_epi16(-oe_ins); \
    __m256i e_del_vec = _mm256_set1_epi16(-e_del); \
    __m256i e_ins_vec = _mm256_set1_epi16(-e_ins); \
    __m256i h_vec_mask[SIMD_WIDTH]; \
    __m256i match_sc_vec = _mm256_set1_epi16(base_match_score); \
    __m256i mis_sc_vec = _mm256_set1_epi16(-base_mis_score); \
    __m256i amb_sc_vec = _mm256_set1_epi16(AMBIGUOUS_BASE_SCORE); \
    __m256i amb_vec = _mm256_set1_epi16(AMBIGUOUS_BASE_CODE); \
    for (i = 0; i < SIMD_WIDTH; ++i) \
        h_vec_mask[i] = _mm256_loadu_si256((const __m256i *)(&h_vec_int_mask[i]));
/*
 * Vector naming used by the SIMD_* macros:
 *   e - deletion scores  (cur_del_arr / next_del_arr)
 *   f - insertion scores (cur_ins_arr / next_ins_arr)
 *   m - match scores     (cur_match_arr / next_match_arr)
 *   h - best scores      (last_max_arr / cur_max_arr / next_max_arr)
 */
// Fetch the inputs for the SIMD_WIDTH cells starting at query index j:
// previous-diagonal scores at j and j - 1, the best scores from two diagonals back at j - 1,
// the query bases at j and the (still forward-ordered) reference bases at i.
#define SIMD_LOAD \
    __m256i m1 = _mm256_loadu_si256((const __m256i *)(cur_match_arr + j)); \
    __m256i e1 = _mm256_loadu_si256((const __m256i *)(cur_del_arr + j)); \
    __m256i m1j1 = _mm256_loadu_si256((const __m256i *)(cur_match_arr + (j - 1))); \
    __m256i f1j1 = _mm256_loadu_si256((const __m256i *)(cur_ins_arr + (j - 1))); \
    __m256i h0j1 = _mm256_loadu_si256((const __m256i *)(last_max_arr + (j - 1))); \
    __m256i qs_vec = _mm256_loadu_si256((const __m256i *)(read_seq + j)); \
    __m256i ts_vec = _mm256_loadu_si256((const __m256i *)(ref_seq + i));
// Compare the query and reference bases lane by lane and build score_vec:
// match_sc_vec where the bases are equal, mis_sc_vec where they differ, and
// amb_sc_vec wherever either base equals the ambiguous code (this overrides the other two).
// Expects qs_vec and ts_vec from SIMD_LOAD; ts_vec is reversed in place.
#define SIMD_CMP_SEQ \
/* reverse the order of the 16 target lanes: swap the 64-bit quarters, then the words inside each quarter */ \
ts_vec = _mm256_permute4x64_epi64(ts_vec, permute_mask); \
ts_vec = _mm256_shufflelo_epi16(ts_vec, permute_mask); \
ts_vec = _mm256_shufflehi_epi16(ts_vec, permute_mask); \
__m256i match_mask_vec = _mm256_cmpeq_epi16(qs_vec, ts_vec); /* compare query and target bases */ \
__m256i mis_score_vec = _mm256_andnot_si256(match_mask_vec, mis_sc_vec); /* mismatch score on the lanes that differ */ \
__m256i score_vec = _mm256_and_si256(match_sc_vec, match_mask_vec); /* match score on the lanes that agree */ \
score_vec = _mm256_or_si256(score_vec, mis_score_vec); \
/* score for lanes holding an ambiguous base (N) */ \
__m256i q_amb_mask_vec = _mm256_cmpeq_epi16(qs_vec, amb_vec); \
__m256i t_amb_mask_vec = _mm256_cmpeq_epi16(ts_vec, amb_vec); \
__m256i amb_mask_vec = _mm256_or_si256(q_amb_mask_vec, t_amb_mask_vec); \
score_vec = _mm256_andnot_si256(amb_mask_vec, score_vec); \
__m256i amb_score_vec = _mm256_and_si256(amb_mask_vec, amb_sc_vec); \
score_vec = _mm256_or_si256(score_vec, amb_score_vec);
// Vectorised recurrence for one vector of cells (all values are signed 16-bit lanes):
//   en = max(m1 - oe_del, e1 - e_del)          deletion
//   fn = max(m1j1 - oe_ins, f1j1 - e_ins)      insertion
//   mn = h0j1 + score where h0j1 > 0, else 0   match / mismatch
//   hn = max(en, fn, mn)
// The penalty vectors already hold negated values, hence the additions below.
// Finally every result is clamped to be non-negative.
#define SIMD_COMPUTE \
__m256i en_vec0 = _mm256_add_epi16(m1, oe_del_vec); \
__m256i en_vec1 = _mm256_add_epi16(e1, e_del_vec); \
__m256i en_vec = _mm256_max_epi16(en_vec0, en_vec1); \
__m256i fn_vec0 = _mm256_add_epi16(m1j1, oe_ins_vec); \
__m256i fn_vec1 = _mm256_add_epi16(f1j1, e_ins_vec); \
__m256i fn_vec = _mm256_max_epi16(fn_vec0, fn_vec1); \
__m256i mn_vec0 = _mm256_add_epi16(h0j1, score_vec); \
__m256i mn_mask = _mm256_cmpgt_epi16(h0j1, zero_vec); \
__m256i mn_vec = _mm256_and_si256(mn_vec0, mn_mask); \
__m256i hn_vec0 = _mm256_max_epi16(en_vec, fn_vec); \
__m256i hn_vec = _mm256_max_epi16(hn_vec0, mn_vec); \
en_vec = _mm256_max_epi16(en_vec, zero_vec); \
fn_vec = _mm256_max_epi16(fn_vec, zero_vec); \
mn_vec = _mm256_max_epi16(mn_vec, zero_vec); \
hn_vec = _mm256_max_epi16(hn_vec, zero_vec);
// Write one vector of freshly computed cells to the next_* arrays at index j and fold
// its best scores into the running per-lane maximum max_vec.
// Scores are 16-bit lanes, so the maximum must be taken per 16-bit lane: a byte-wise
// maximum would combine the high byte of one score with the low byte of another as
// soon as any score reaches 256, producing a value that exists in neither operand.
// hn_vec is clamped to be non-negative, so the signed 16-bit maximum is exact.
#define SIMD_STORE \
    max_vec = _mm256_max_epi16(max_vec, hn_vec); \
    _mm256_storeu_si256((__m256i *)&next_del_arr[j], en_vec); \
    _mm256_storeu_si256((__m256i *)&next_ins_arr[j], fn_vec); \
    _mm256_storeu_si256((__m256i *)&next_match_arr[j], mn_vec); \
    _mm256_storeu_si256((__m256i *)&next_max_arr[j], hn_vec);
// Zero the lanes of the last, partial vector that lie beyond read_end_pos.
// Row (read_end_pos - j) of h_vec_mask keeps the first read_end_pos - j + 1 lanes.
#define SIMD_REMOVE_EXTRA \
    do \
    { \
        const __m256i tail_keep_mask = h_vec_mask[read_end_pos - j]; \
        en_vec = _mm256_and_si256(en_vec, tail_keep_mask); \
        fn_vec = _mm256_and_si256(fn_vec, tail_keep_mask); \
        mn_vec = _mm256_and_si256(mn_vec, tail_keep_mask); \
        hn_vec = _mm256_and_si256(hn_vec, tail_keep_mask); \
    } while (0)
// Find this diagonal's best score and where it occurred.
// Runs only when some lane of max_vec exceeds last_max_vec. It then
//   1. reduces max_vec horizontally so every lane holds the diagonal maximum, stored in m;
//   2. scans next_max_arr in vectors from aligned_read_start_pos to read_end_pos for lanes
//      equal to that maximum, taking the highest matching lane of a vector as the position
//      (mj on the query, mi on the reference);
//   3. from that position, while mj + 1 < qlen and mi + 1 < tlen, adds base_match_score to m
//      for every further pair of equal bases and stops at the first unequal pair;
//   4. broadcasts m into last_max_vec.
// Overwrites j and i, and expects m, mj, mi to be declared by the caller.
// NOTE(review): the scan does not stop at the first vector containing the maximum, so a later
// matching vector overwrites mj/mi and repeats step 3 on the already increased m - confirm intended.
#define SIMD_FIND_MAX \
__m256i cmp_max = _mm256_cmpgt_epi16(max_vec, last_max_vec); \
uint32_t cmp_result = _mm256_movemask_epi8(cmp_max); \
if (cmp_result > 0) \
{ \
max_vec = _mm256_max_epu16(max_vec, _mm256_alignr_epi8(max_vec, max_vec, 2)); \
max_vec = _mm256_max_epu16(max_vec, _mm256_alignr_epi8(max_vec, max_vec, 4)); \
max_vec = _mm256_max_epu16(max_vec, _mm256_alignr_epi8(max_vec, max_vec, 6)); \
max_vec = _mm256_max_epu16(max_vec, _mm256_alignr_epi8(max_vec, max_vec, 8)); \
max_vec = _mm256_max_epu16(max_vec, _mm256_permute2x128_si256(max_vec, max_vec, 0x01)); /* combine the two 128-bit halves */ \
int16_t *maxVal = (int16_t *)&max_vec; \
m = maxVal[0]; \
for (j = aligned_read_start_pos, i = aligned_ref_end_pos; j <= read_end_pos; j += SIMD_WIDTH, i -= SIMD_WIDTH) \
{ \
__m256i h2_vec = _mm256_loadu_si256((__m256i *)(&next_max_arr[j])); \
__m256i vcmp = _mm256_cmpeq_epi16(h2_vec, max_vec); \
uint32_t mask = _mm256_movemask_epi8(vcmp); \
if (mask > 0) \
{ \
int pos = SIMD_WIDTH - 1 - ((__builtin_clz(mask)) >> 1); /* two mask bits per lane: index of the highest matching lane */ \
mj = j - 1 + pos; \
mi = i - 1 - pos; \
for (; mj + 1 < qlen && mi + 1 < tlen; mj++, mi++) \
{ \
if (read_seq[mj + 2] == ref_seq[mi + 1 + SIMD_WIDTH]) \
{ \
m += base_match_score; \
} \
else \
{ \
break; \
} \
} \
} \
} \
last_max_vec = _mm256_set1_epi16(m); \
}
// After each diagonal: rotate the three best-score buffers (last <- cur <- next <- old last)
// and exchange the cur/next buffers of the deletion, insertion and match scores.
#define SWAP_DATA_POINTER \
    do \
    { \
        int16_t *const recycled_max = last_max_arr; \
        int16_t *const finished_del = cur_del_arr; \
        int16_t *const finished_ins = cur_ins_arr; \
        int16_t *const finished_match = cur_match_arr; \
        last_max_arr = cur_max_arr; \
        cur_max_arr = next_max_arr; \
        next_max_arr = recycled_max; \
        cur_del_arr = next_del_arr; \
        next_del_arr = finished_del; \
        cur_ins_arr = next_ins_arr; \
        next_ins_arr = finished_ins; \
        cur_match_arr = next_match_arr; \
        next_match_arr = finished_match; \
    } while (0)
// uint8_t mem_addr[102400];
/*
 * Banded extension alignment of `query` against `target` with affine gap penalties,
 * computed anti-diagonal by anti-diagonal on 16-bit AVX2 lanes.
 *
 * Working memory is obtained with thread_mem_request(tmem, ...) and handed back with
 * thread_mem_release(tmem, ...) before returning. Inside that buffer the sequences and the
 * nine score arrays are placed so that element 1 of the read and of the first score array
 * falls on a 32-byte boundary (element 0 is a boundary slot one element to the left).
 *
 * Returns the best score found (init_score if no cell exceeds it). The optional output
 * pointers may be NULL; each one is written only when non-NULL.
 *
 * NOTE(review): the pointer returned by thread_mem_request is used without a NULL check -
 * confirm that the allocator cannot fail.
 */
int ksw_avx2_aligned(thread_mem_t *tmem,
int qlen, // query length: number of bases in the read segment to align
const uint8_t *query, // read base sequence
int tlen, // target length: length of the reference segment
const uint8_t *target, // reference base sequence
int extend_left, // non-zero when extending to the left (both sequences are copied reversed)
int o_del, // deletion gap-open penalty
int e_del, // deletion gap-extension penalty
int o_ins, // insertion gap-open penalty
int e_ins, // insertion gap-extension penalty
int base_match_score, // score for a matching base
int base_mis_score, // penalty for a mismatching base (positive value)
int window_size, // band width (original note: w = 100): maximum distance between the aligned position and the band start
int end_bonus, // extra bonus when the query aligns through its last base (here only used to bound window_size)
int init_score, // initial score of the seed (number of exactly matched query bases)
int *_qle, // out: query position of the base giving the overall maximum score
int *_tle, // out: reference position of the base giving the overall maximum score
int *_gtle, // out: target length consumed when the whole query is aligned
int *_gscore, // out: end-to-end alignment score of the query
int *_max_off) // out: largest |query position - reference position| recorded when a new maximum was taken
{
int16_t *cur_match_arr, *next_match_arr,
*last_max_arr, *cur_max_arr, *next_max_arr,
*cur_del_arr, *next_del_arr,
*cur_ins_arr, *next_ins_arr; // last_max_arr keeps H from two diagonals back; the others keep the previous diagonal's H, E, F, M
int16_t *read_seq, *ref_seq;
uint8_t *mem_addr;
int read_size = align_number(qlen * BASE_BYTES + MEM_ALIGN_BYTES);
int ref_size = align_number((tlen + SIMD_WIDTH) * BASE_BYTES);
int back_diagnal_num = tlen + qlen; // loop bound: the diagonal index di starts at 1
int score_array_size = align_number((qlen + BOUNDARY_SCORE_NUM) * SCORE_BYTES);
int score_element_num = score_array_size / SCORE_BYTES;
int score_mem_size = score_array_size * TMP_SCORE_ARRAY_NUM;
int request_mem_size = read_size + ref_size + score_mem_size + MEM_ALIGN_BYTES * 3; // slack for: aligning the start address + shifting the data left by one element + SIMD padding at the end
int i, ref_start_pos, di, j, read_start_pos, read_end_pos, max, max_i, max_j, max_ins, max_del, max_ie, gscore, max_off;
int span, beg1, end1; // band boundary computation
int aligned_read_start_pos, aligned_ref_end_pos;
int ref_end_pos;
SIMD_INIT; // initialise the SIMD constants
assert(init_score > 0);
// allocate memory
mem_addr = thread_mem_request(tmem, request_mem_size);
mem_addr = (void *)align_mem((uint64_t)mem_addr);
// the reference is stored after SIMD_WIDTH padding elements; the read starts one element before a 32-byte boundary
ref_seq = (int16_t *)&mem_addr[0];
read_seq = (int16_t *)(mem_addr + ref_size + SIMD_BYTES - BASE_BYTES);
if (extend_left)
{
for (i = 0; i < qlen; ++i)
read_seq[i + 1] = query[qlen - 1 - i];
for (i = 0; i < tlen; ++i)
ref_seq[i + SIMD_WIDTH] = target[tlen - 1 - i];
}
else
{
for (i = 0; i < qlen; ++i)
read_seq[i + 1] = query[i];
for (i = 0; i < tlen; ++i)
ref_seq[i + SIMD_WIDTH] = target[i];
}
// the score arrays follow the sequences; zero all of them
mem_addr += read_size + ref_size + (SIMD_BYTES - SCORE_BYTES);
for (i = 0; i < score_mem_size; i += SIMD_BYTES)
{
_mm256_storeu_si256((__m256i *)&mem_addr[i], zero_vec);
}
last_max_arr = (int16_t *)&mem_addr[0];
cur_max_arr = &last_max_arr[score_element_num];
next_max_arr = &cur_max_arr[score_element_num];
cur_match_arr = &next_max_arr[score_element_num];
next_match_arr = &cur_match_arr[score_element_num];
cur_del_arr = &next_match_arr[score_element_num];
next_del_arr = &cur_del_arr[score_element_num];
cur_ins_arr = &next_del_arr[score_element_num];
next_ins_arr = &cur_ins_arr[score_element_num];
// adjust $window_size if it is too large
// get the max score
max = base_match_score;
max_ins = (int)((double)(qlen * max + end_bonus - o_ins) / e_ins + 1.);
max_ins = max_ins > 1 ? max_ins : 1;
window_size = window_size < max_ins ? window_size : max_ins;
max_del = (int)((double)(qlen * max + end_bonus - o_del) / e_del + 1.);
max_del = max_del > 1 ? max_del : 1;
window_size = window_size < max_del ? window_size : max_del; // TODO: is this necessary?
if (tlen < qlen)
window_size = MIN(tlen - 1, window_size);
// DP loop
max = init_score, max_i = max_j = -1;
max_ie = -1, gscore = -1;
;
max_off = 0;
read_start_pos = 1;
read_end_pos = qlen;
// init init_score
last_max_arr[0] = init_score; // top-left corner
if (qlen == 0 || tlen == 0)
back_diagnal_num = 0; // guard against empty input
if (window_size >= qlen)
{
max_ie = 0;
gscore = 0;
}
for (di = 1; LIKELY(di < back_diagnal_num); ++di)
{
// Take care with the boundaries: tlen may be greater than, equal to, or less than qlen
if (di > tlen)
{
span = MIN(back_diagnal_num - di, window_size); // width of the computed window (range)
beg1 = MAX(di - tlen + 1, ((di - window_size) / 2) + 1);
}
else
{
span = MIN(di - 1, window_size);
beg1 = MAX(1, ((di - window_size) / 2) + 1);
}
end1 = MIN(qlen, beg1 + span);
if (read_start_pos < beg1)
read_start_pos = beg1;
if (read_end_pos > end1)
read_end_pos = end1;
if (read_start_pos > read_end_pos)
break; // nothing left to compute: leave now, otherwise next_max_arr is never written and still holds the previous round's last_max_arr values, which causes a bug
// read_start_pos = 1;
// read_end_pos = qlen;
ref_end_pos = di - (read_start_pos - 1); // reference position where this diagonal starts (traversed in reverse)
span = read_end_pos - read_start_pos;
ref_start_pos = ref_end_pos - span - 1; // 0-based reference index
// per-diagonal bookkeeping
int m = 0, mj = -1, mi = -1;
max_vec = zero_vec;
// handle the boundaries
// left boundary: insertion (f)
if (ref_start_pos == 0)
{
cur_max_arr[read_end_pos] = MAX(0, init_score - (o_ins + e_ins * read_end_pos));
}
// top boundary: deletion
if (read_start_pos == 1)
{
cur_max_arr[0] = MAX(0, init_score - (o_del + e_del * ref_end_pos));
}
else
{
cur_max_arr[read_start_pos - 1] = 0;
cur_del_arr[read_start_pos - 1] = 0;
}
// aligned_read_start_pos = (read_start_pos >> ALIGN_SHIFT_BITS << ALIGN_SHIFT_BITS) + 1;
// aligned_ref_end_pos = ref_end_pos + (read_start_pos - aligned_read_start_pos);
aligned_read_start_pos = read_start_pos;
aligned_ref_end_pos = ref_end_pos;
// fprintf(stderr, "%d\t%d\n", read_start_pos, aligned_read_start_pos);
for (j = aligned_read_start_pos, i = aligned_ref_end_pos; j <= read_end_pos + 1 - SIMD_WIDTH; j += SIMD_WIDTH, i -= SIMD_WIDTH)
{
// load the data
SIMD_LOAD;
// compare the sequences and compute the substitution scores
SIMD_CMP_SEQ;
// compute
SIMD_COMPUTE;
// store the results
SIMD_STORE;
}
// remaining cells (last, partial vector)
if (j <= read_end_pos)
{
// load the data
SIMD_LOAD;
// compare the sequences and compute the substitution scores
SIMD_CMP_SEQ;
// compute
SIMD_COMPUTE;
// zero the lanes computed beyond the range
SIMD_REMOVE_EXTRA;
// store the results
SIMD_STORE;
}
SIMD_FIND_MAX;
// note: set j to the value it has when the loop above runs to completion
j = read_end_pos + 1;
if (j == qlen + 1) // the last query base was reached; next_max_arr[qlen] is then the end-to-end score of this diagonal
{
max_ie = gscore > next_max_arr[qlen] ? max_ie : ref_start_pos;
gscore = gscore > next_max_arr[qlen] ? gscore : next_max_arr[qlen];
}
if (m > max)
{
max = m, max_i = mi, max_j = mj;
max_off = max_off > abs(mj - mi) ? max_off : abs(mj - mi);
}
// shrink the band for the next diagonal: skip leading and trailing cells with no score
for (j = read_start_pos; LIKELY(j <= read_end_pos); ++j)
{
int has_val = cur_max_arr[j - 1] | next_max_arr[j];
if (has_val)
{
break;
}
}
read_start_pos = j;
next_max_arr[read_end_pos + 1] = 0;
for (j = read_end_pos + 1; LIKELY(j >= read_start_pos); --j)
{
int has_val = cur_max_arr[j - 1] | next_max_arr[j];
if (has_val)
{
break;
}
}
read_end_pos = j + 1 <= qlen ? j + 1 : qlen;
// swap m, h, e, f
SWAP_DATA_POINTER;
}
thread_mem_release(tmem, request_mem_size);
if (_qle)
*_qle = max_j + 1;
if (_tle)
*_tle = max_i + 1;
if (_gtle)
*_gtle = max_ie + 1;
if (_gscore)
*_gscore = gscore;
if (_max_off)
*_max_off = max_off;
return max;
}

View File

@ -366,6 +366,8 @@ int ksw_avx2_u8(int qlen, // query length 待匹配段碱基的que
if (beg > end) if (beg > end)
break; // 不用计算了直接跳出否则hA2没有被赋值里边是上一轮hA0的值会出bug break; // 不用计算了直接跳出否则hA2没有被赋值里边是上一轮hA0的值会出bug
// beg = 1;
// end = qlen;
iend = D - (beg - 1); // ref开始计算的位置倒序 iend = D - (beg - 1); // ref开始计算的位置倒序
span = end - beg; span = end - beg;
iStart = iend - span - 1; // 0开始的ref索引位置 iStart = iend - span - 1; // 0开始的ref索引位置

View File

@ -0,0 +1,477 @@
#include <stdlib.h>
#include <stdint.h>
#include <assert.h>
#include <stdio.h>
#include <immintrin.h>
#include <emmintrin.h>
#include "thread_mem.h"
// Branch-prediction hints: map to __builtin_expect on GNU-compatible compilers
// and to the plain expression elsewhere.
#ifdef __GNUC__
#define LIKELY(x) __builtin_expect((x), 1)
#define UNLIKELY(x) __builtin_expect((x), 0)
#else
#define LIKELY(x) (x)
#define UNLIKELY(x) (x)
#endif
// Local MAX/MIN replace any earlier definition. Both evaluate their arguments
// more than once, so do not pass expressions with side effects.
#undef MAX
#undef MIN
#define MAX(x, y) ((x) > (y) ? (x) : (y))
#define MIN(x, y) ((x) < (y) ? (x) : (y))
// Number of uint8_t lanes processed per 256-bit vector operation.
#define SIMD_WIDTH 32
// Bytes used to store one base / one score value in the work buffers.
#define BASE_BYTES 1
#define SCORE_BYTES 1
// Extra score slots added to qlen when sizing each score array.
#define BOUNDARY_SCORE_NUM 2
// Number of score arrays carved out of the work buffer (hA0..hA2, mA1, mA2, eA1, eA2, fA1, fA2).
#define TMP_SCORE_ARRAY_NUM 9
// Alignment granularity in bytes; ALIGN_SHIFT_BITS is its log2 (1 << 5 == 32).
#define MEM_ALIGN_BYTES 32
#define ALIGN_SHIFT_BITS 5
// Size in bytes of one 256-bit vector.
#define SIMD_BYTES 32
// Base code and score associated with an ambiguous base.
#define AMBIGUOUS_BASE_CODE 4
#define AMBIGUOUS_BASE_SCORE -1
// Round x up to the next multiple of 32 (32-byte / 256-bit alignment).
#define align_mem(x) (((x) + 31) >> 5 << 5)
#define align_number(x) align_mem(x)
// Lane-keep masks for the partial (tail) vector of an anti-diagonal.
// Row r has its first r + 1 bytes set to 0xff and the remaining bytes set to 0.
// SIMD_REMOVE_EXTRA ANDs the result vectors with row (end - j), which keeps the
// lanes for query positions j..end and zeroes the lanes computed past end.
static const uint8_t h_vec_int_mask[SIMD_WIDTH][SIMD_WIDTH] = {
{0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0},
{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff}};
// Byte-shuffle control used by SIMD_CMP_SEQ to reverse the reference bytes.
// Within each 128-bit half the indices {7..0, 15..8} reverse the bytes of each
// 64-bit quadword. Combined with a 64-bit quadword permutation using
// permute_mask, which reverses the order of the four quadwords, the whole
// 32-byte vector ends up byte-reversed.
// Earlier variant kept for reference (swapped adjacent byte pairs):
// static const uint8_t reverse_mask[SIMD_WIDTH] = {1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14, 1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14};
static const uint8_t reverse_mask[SIMD_WIDTH] = {7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8, 7, 6, 5, 4, 3, 2, 1, 0, 15, 14, 13, 12, 11, 10, 9, 8};
// permute_mask is the immediate for _mm256_permute4x64_epi64; 27 == 0b00011011,
// the value of _MM_SHUFFLE(0, 1, 2, 3), spelled as a literal.
// const int permute_mask = _MM_SHUFFLE(0, 1, 2, 3);
// #define permute_mask _MM_SHUFFLE(0, 1, 2, 3)
#define permute_mask 27
// Declare and initialise the vector constants shared by the other SIMD_* macros.
// Expands inside the alignment function and relies on its locals and parameters:
// i, h0, a, b, o_del, e_del, o_ins, e_ins.
//  - oe_del / oe_ins        : gap open + extension penalties (scalar)
//  - zero_vec               : all-zero vector
//  - max_vec, last_max_vec  : running maximum of the current anti-diagonal and
//                             the best score so far (starts at h0)
//  - *_del_vec / *_ins_vec  : gap penalties broadcast to every byte lane
//  - match_sc_vec/mis_sc_vec: match score a and mismatch penalty b, broadcast
//  - amb_vec / amb_sc_vec   : ambiguous base code and the magnitude of its
//                             penalty, taken from AMBIGUOUS_BASE_CODE and
//                             AMBIGUOUS_BASE_SCORE instead of bare literals
//  - reverse_mask_vec       : byte-shuffle control loaded from reverse_mask
//  - h_vec_mask[]           : the rows of h_vec_int_mask loaded as vectors
// All values are narrowed to 8 bits by _mm256_set1_epi8.
#define SIMD_INIT                                                                   \
    int oe_del = o_del + e_del, oe_ins = o_ins + e_ins;                             \
    __m256i zero_vec;                                                               \
    __m256i max_vec, last_max_vec = _mm256_set1_epi8(h0);                           \
    __m256i oe_del_vec;                                                             \
    __m256i oe_ins_vec;                                                             \
    __m256i e_del_vec;                                                              \
    __m256i e_ins_vec;                                                              \
    __m256i h_vec_mask[SIMD_WIDTH];                                                 \
    __m256i reverse_mask_vec;                                                       \
    zero_vec = _mm256_setzero_si256();                                              \
    oe_del_vec = _mm256_set1_epi8(oe_del);                                          \
    oe_ins_vec = _mm256_set1_epi8(oe_ins);                                          \
    e_del_vec = _mm256_set1_epi8(e_del);                                            \
    e_ins_vec = _mm256_set1_epi8(e_ins);                                            \
    __m256i match_sc_vec = _mm256_set1_epi8(a);                                     \
    __m256i mis_sc_vec = _mm256_set1_epi8(b);                                       \
    __m256i amb_sc_vec = _mm256_set1_epi8(-(AMBIGUOUS_BASE_SCORE));                 \
    __m256i amb_vec = _mm256_set1_epi8(AMBIGUOUS_BASE_CODE);                        \
    reverse_mask_vec = _mm256_loadu_si256((__m256i *)(reverse_mask));               \
    for (i = 0; i < SIMD_WIDTH; ++i)                                                \
        h_vec_mask[i] = _mm256_loadu_si256((__m256i *)(&h_vec_int_mask[i]));
/*
 * Score arrays used by the SIMD_* macros:
 *   e - gap score computed with the deletion penalties (gap along the reference)
 *   f - gap score computed with the insertion penalties (gap along the read)
 *   m - score of cells reached by a match/mismatch step
 *   h - best score of a cell (max of e, f and m)
 */
// Load one vector of inputs for query positions j..j+31 of the current anti-diagonal:
//   m1, e1     : M and E of the previous anti-diagonal at position j
//   m1j1, f1j1 : M and F of the previous anti-diagonal at position j - 1
//   h0j1       : H from two anti-diagonals back at position j - 1
//   qs_vec     : read bases starting at j
//   ts_vec     : reference bases starting at i (still in forward order here;
//                SIMD_CMP_SEQ reverses them)
// All loads are unaligned loads.
#define SIMD_LOAD \
__m256i m1 = _mm256_loadu_si256((__m256i *)(&mA1[j])); \
__m256i e1 = _mm256_loadu_si256((__m256i *)(&eA1[j])); \
__m256i m1j1 = _mm256_loadu_si256((__m256i *)(&mA1[j - 1])); \
__m256i f1j1 = _mm256_loadu_si256((__m256i *)(&fA1[j - 1])); \
__m256i h0j1 = _mm256_loadu_si256((__m256i *)(&hA0[j - 1])); \
__m256i qs_vec = _mm256_loadu_si256((__m256i *)(&read_seq[j])); \
__m256i ts_vec = _mm256_loadu_si256((__m256i *)(&ref_seq[i]));
// Compare read and reference bases lane by lane and derive the per-lane scores.
// The reference vector is first byte-reversed (quadword permute + in-lane byte
// shuffle) so that lane p pairs read position j + p with the reference base
// that lies on the same anti-diagonal.
// Results:
//   match_score_vec : a in lanes where the bases are equal, 0 elsewhere
//   mis_score_vec   : b in lanes where the bases differ, 0 elsewhere
// Lanes where either base equals the ambiguous code are then overridden:
// match_score_vec becomes 0 and mis_score_vec becomes the ambiguous penalty.
#define SIMD_CMP_SEQ \
ts_vec = _mm256_permute4x64_epi64(ts_vec, permute_mask); \
ts_vec = _mm256_shuffle_epi8(ts_vec, reverse_mask_vec); \
__m256i match_mask_vec = _mm256_cmpeq_epi8(qs_vec, ts_vec); \
__m256i mis_score_vec = _mm256_andnot_si256(match_mask_vec, mis_sc_vec); \
__m256i match_score_vec = _mm256_and_si256(match_sc_vec, match_mask_vec); \
__m256i q_amb_mask_vec = _mm256_cmpeq_epi8(qs_vec, amb_vec); \
__m256i t_amb_mask_vec = _mm256_cmpeq_epi8(ts_vec, amb_vec); \
__m256i amb_mask_vec = _mm256_or_si256(q_amb_mask_vec, t_amb_mask_vec); \
__m256i amb_score_vec = _mm256_and_si256(amb_mask_vec, amb_sc_vec); \
mis_score_vec = _mm256_andnot_si256(amb_mask_vec, mis_score_vec); \
mis_score_vec = _mm256_or_si256(amb_score_vec, mis_score_vec); \
match_score_vec = _mm256_andnot_si256(amb_mask_vec, match_score_vec);
// Vectorised recurrence for E, F, M and H on unsigned 8-bit lanes.
// Each "max(x, c) then saturating subtract c" pair computes max(x - c, 0).
//   en_vec = max(m1 - oe_del, e1 - e_del)
//   fn_vec = max(m1j1 - oe_ins, f1j1 - e_ins)
//   mn_vec = max(h0j1 + match_score - mis_score, 0), forced to 0 in lanes
//            where h0j1 is 0 (a zero cell is not extended diagonally)
//   hn_vec = max(en_vec, fn_vec, mn_vec)
// The addition saturates at 255 (_mm256_adds_epu8).
#define SIMD_COMPUTE \
__m256i en_vec0 = _mm256_max_epu8(m1, oe_del_vec); \
en_vec0 = _mm256_subs_epu8(en_vec0, oe_del_vec); \
__m256i en_vec1 = _mm256_max_epu8(e1, e_del_vec); \
en_vec1 = _mm256_subs_epu8(en_vec1, e_del_vec); \
__m256i en_vec = _mm256_max_epu8(en_vec0, en_vec1); \
__m256i fn_vec0 = _mm256_max_epu8(m1j1, oe_ins_vec); \
fn_vec0 = _mm256_subs_epu8(fn_vec0, oe_ins_vec); \
__m256i fn_vec1 = _mm256_max_epu8(f1j1, e_ins_vec); \
fn_vec1 = _mm256_subs_epu8(fn_vec1, e_ins_vec); \
__m256i fn_vec = _mm256_max_epu8(fn_vec0, fn_vec1); \
__m256i mn_vec0 = _mm256_adds_epu8(h0j1, match_score_vec); \
mn_vec0 = _mm256_max_epu8(mn_vec0, mis_score_vec); \
mn_vec0 = _mm256_subs_epu8(mn_vec0, mis_score_vec); \
__m256i mn_mask = _mm256_cmpeq_epi8(h0j1, zero_vec); \
__m256i mn_vec = _mm256_andnot_si256(mn_mask, mn_vec0); \
__m256i hn_vec0 = _mm256_max_epu8(en_vec, fn_vec); \
__m256i hn_vec = _mm256_max_epu8(hn_vec0, mn_vec);
// Fold hn_vec into the running per-lane maximum (max_vec) and write the new
// E, F, M and H vectors to eA2/fA2/mA2/hA2 at position j (unaligned stores).
// Each store always writes a full 32 bytes starting at index j.
#define SIMD_STORE \
max_vec = _mm256_max_epu8(max_vec, hn_vec); \
_mm256_storeu_si256((__m256i *)&eA2[j], en_vec); \
_mm256_storeu_si256((__m256i *)&fA2[j], fn_vec); \
_mm256_storeu_si256((__m256i *)&mA2[j], mn_vec); \
_mm256_storeu_si256((__m256i *)&hA2[j], hn_vec);
// Zero the lanes that lie past `end` in the last, partial vector of an
// anti-diagonal. h_vec_mask[end - j] keeps the first (end - j + 1) lanes,
// i.e. query positions j..end, and clears the rest of E, F, M and H.
#define SIMD_REMOVE_EXTRA \
en_vec = _mm256_and_si256(en_vec, h_vec_mask[end - j]); \
fn_vec = _mm256_and_si256(fn_vec, h_vec_mask[end - j]); \
mn_vec = _mm256_and_si256(mn_vec, h_vec_mask[end - j]); \
hn_vec = _mm256_and_si256(hn_vec, h_vec_mask[end - j]);
// Leftover experiments kept for reference:
// cmp_max = _mm256_xor_si256(last_max_vec, cmp_max);
// last_max_vec = _mm256_set1_epi8(m);
// Find the maximum H of the current anti-diagonal and its position.
// Step 1: compare max_vec with last_max_vec lane by lane. The slow path below
//         runs only when at least one lane exceeds last_max_vec
//         (cmp_result != 0xFFFFFFFF).
// Step 2: reduce max_vec horizontally. The byte rotations (_mm256_alignr_epi8)
//         work within each 128-bit half, and the final 128-bit swap merges the
//         two halves, so every lane ends up holding the overall maximum, which
//         is read into m through maxVal[0].
// Step 3: rescan hA2 from beg to end in steps of SIMD_WIDTH for lanes equal to
//         the maximum. pos is the index of the highest set bit of the compare
//         mask; mj and mi are derived from j, i and pos.
// Step 4: starting at (mj, mi), while both stay in range and the next read and
//         reference bases are equal, add a to m and advance both positions.
// Finally last_max_vec is set to m (narrowed to 8 bits by _mm256_set1_epi8).
// Uses j and i as scratch loop variables, so callers must reset j afterwards.
// NOTE(review): the scan in step 3 has no early exit. If more than one vector
//               contains the maximum, step 4 runs for each of them (adding to m
//               each time) and the last hit's position is kept - confirm this
//               is intended.
#define SIMD_FIND_MAX \
__m256i cmp_max = _mm256_max_epu8(max_vec, last_max_vec); \
cmp_max = _mm256_xor_si256(cmp_max, last_max_vec); \
cmp_max = _mm256_cmpeq_epi8(cmp_max, zero_vec); \
uint32_t cmp_result = _mm256_movemask_epi8(cmp_max); \
if (cmp_result != 4294967295) \
{ \
uint8_t *maxVal = (uint8_t *)&max_vec; \
max_vec = _mm256_max_epu8(max_vec, _mm256_alignr_epi8(max_vec, max_vec, 1)); \
max_vec = _mm256_max_epu8(max_vec, _mm256_alignr_epi8(max_vec, max_vec, 2)); \
max_vec = _mm256_max_epu8(max_vec, _mm256_alignr_epi8(max_vec, max_vec, 3)); \
max_vec = _mm256_max_epu8(max_vec, _mm256_alignr_epi8(max_vec, max_vec, 4)); \
max_vec = _mm256_max_epu8(max_vec, _mm256_alignr_epi8(max_vec, max_vec, 5)); \
max_vec = _mm256_max_epu8(max_vec, _mm256_alignr_epi8(max_vec, max_vec, 6)); \
max_vec = _mm256_max_epu8(max_vec, _mm256_alignr_epi8(max_vec, max_vec, 7)); \
max_vec = _mm256_max_epu8(max_vec, _mm256_alignr_epi8(max_vec, max_vec, 8)); \
max_vec = _mm256_max_epu8(max_vec, _mm256_permute2x128_si256(max_vec, max_vec, 0x01)); \
m = maxVal[0]; \
for (j = beg, i = iend; j <= end; j += SIMD_WIDTH, i -= SIMD_WIDTH) \
{ \
__m256i h2_vec = _mm256_loadu_si256((__m256i *)(&hA2[j])); \
__m256i vcmp = _mm256_cmpeq_epi8(h2_vec, max_vec); \
uint32_t mask = _mm256_movemask_epi8(vcmp); \
if (mask > 0) \
{ \
int pos = SIMD_WIDTH - 1 - __builtin_clz(mask); \
mj = j - 1 + pos; \
mi = i - 1 - pos; \
for (; mj + 1 < qlen && mi + 1 < tlen; mj++, mi++) \
{ \
if (read_seq[mj + 2] == ref_seq[mi + 1 + SIMD_WIDTH]) \
{ \
m += a; \
} \
else \
{ \
break; \
} \
} \
} \
} \
last_max_vec = _mm256_set1_epi8(m); \
}
// Rotate the score buffers after each anti-diagonal.
// H uses three buffers: hA1 becomes hA0, hA2 becomes hA1, and the old hA0 is
// recycled as the next hA2. E, F and M each use two buffers that are swapped,
// so the values just written (..2) become the "previous" values (..1).
#define SWAP_DATA_POINTER \
uint8_t *tmp = hA0; \
hA0 = hA1; \
hA1 = hA2; \
hA2 = tmp; \
tmp = eA1; \
eA1 = eA2; \
eA2 = tmp; \
tmp = fA1; \
fA1 = fA2; \
fA2 = tmp; \
tmp = mA1; \
mA1 = mA2; \
mA2 = tmp;
/*
 * Banded extension alignment of `query` against `target`, computed one
 * anti-diagonal (D) at a time with 8-bit AVX2 lanes, using work memory taken
 * from `tmem`.
 *
 * Outline:
 *   1. Size the work buffers, request them from tmem and round the returned
 *      address up to a 32-byte boundary.
 *   2. Copy the reference to ref_seq (data starts at offset SIMD_WIDTH) and the
 *      read to read_seq (data starts at index 1); both are copied in reverse
 *      order when is_left is non-zero.
 *   3. Zero the score area and carve it into nine arrays
 *      (hA0, hA1, hA2, mA1, mA2, eA1, eA2, fA1, fA2).
 *   4. Limit the band width w using the largest entry of mat and the gap
 *      penalties.
 *   5. For each anti-diagonal: clamp [beg, end], set the boundary cells, run
 *      the SIMD_* macros, update the best score and the end-to-end score,
 *      shrink [beg, end] to the cells that still carry a score, and rotate the
 *      buffers.
 *   6. Return the work memory to tmem and write the outputs.
 *
 * Returns the best score found (never below h0, which must be > 0). Every
 * output pointer is optional and is written only when non-NULL.
 *
 * NOTE(review): the pointer returned by thread_mem_request is used without a
 *               NULL check.
 * NOTE(review): scores are stored in uint8_t and broadcast with
 *               _mm256_set1_epi8, so this presumably relies on the caller
 *               keeping scores within 0..255 - confirm.
 * NOTE(review): zdrop, col_size, val_mem_size, m_last and the locals mA, hA,
 *               eA, fA are not used by the code below.
 */
int ksw_avx2_u8_aligned(thread_mem_t *tmem,
int qlen, // query length: number of bases in the query segment to align
const uint8_t *query, // read base sequence
int tlen, // target length: length of the reference
const uint8_t *target, // reference sequence
int is_left, // whether this is a leftward extension
int m, // number of base kinds (5)
const int8_t *mat, // m*m matrix of query/target match scores
int o_del, // deletion gap-open penalty
int e_del, // deletion gap-extension penalty
int o_ins, // insertion gap-open penalty
int e_ins, // insertion gap-extension penalty
int a, // score for a base match
int b, // penalty for a base mismatch (positive number)
int w, // band width used for early pruning (w = 100): max distance between the match position and beg
int end_bonus,
int zdrop,
int h0, // initial score of the seed (number of query bases matched exactly)
int *_qle, // query position of the base where the best score was reached
int *_tle, // reference position of the base where the best score was reached
int *_gtle, // target length consumed when the whole query is aligned
int *_gscore, // end-to-end alignment score of the query
int *_max_off) // largest query/reference position difference seen when a new best score was recorded
{
uint8_t *mA, *hA, *eA, *fA, *mA1, *mA2, *hA0, *hA1, *eA1, *fA1, *hA2, *eA2, *fA2; // hA0 holds H from two anti-diagonals back; the others hold the previous/current H, E, F, M
uint8_t *read_seq, *ref_seq;
int i, iStart, D, j, k, beg, end, max, max_i, max_j, max_ins, max_del, max_ie, gscore, max_off;
int span, beg1, end1; // used for the boundary computation
int col_size = qlen + 2 + SIMD_WIDTH;
int val_mem_size = (col_size * 9 + 31) >> 5 << 5; // multiple of 32 bytes
// int mem_size = seq_size + ref_size + val_mem_size;
uint8_t *mem_addr;
int read_size = align_number(qlen * BASE_BYTES + MEM_ALIGN_BYTES);
int ref_size = align_number((tlen + SIMD_WIDTH) * BASE_BYTES);
int back_diagnal_num = tlen + qlen; // loop bound; D is iterated starting from 1
int score_array_size = align_number((qlen + BOUNDARY_SCORE_NUM) * SCORE_BYTES);
int score_element_num = score_array_size / SCORE_BYTES;
int score_mem_size = score_array_size * TMP_SCORE_ARRAY_NUM;
// The extra MEM_ALIGN_BYTES * 3 leaves room for rounding the base address up
// and for the offset/overhang of the last score array.
int request_mem_size = read_size + ref_size + score_mem_size + MEM_ALIGN_BYTES * 3;
SIMD_INIT; // initialise the data used by the SIMD macros
assert(h0 > 0);
mem_addr = thread_mem_request(tmem, request_mem_size);
mem_addr = (void *)align_mem((uint64_t)mem_addr);
ref_seq = (uint8_t *)&mem_addr[0];
// Offset by SIMD_BYTES - BASE_BYTES so that read_seq[1] falls on a 32-byte boundary.
read_seq = (uint8_t *)(mem_addr + ref_size + SIMD_BYTES - BASE_BYTES);
if (is_left)
{
for (i = 0; i < qlen; ++i)
read_seq[i + 1] = query[qlen - 1 - i];
for (i = 0; i < tlen; ++i)
ref_seq[i + SIMD_WIDTH] = target[tlen - 1 - i];
}
else
{
for (i = 0; i < qlen; ++i)
read_seq[i + 1] = query[i];
for (i = 0; i < tlen; ++i)
ref_seq[i + SIMD_WIDTH] = target[i];
}
mem_addr += read_size + ref_size;
// Zero the score area; the inclusive bound writes one extra vector, which
// covers the SIMD_BYTES - SCORE_BYTES offset applied to hA0 below.
for (i = 0; i <= score_mem_size; i += SIMD_BYTES)
{
_mm256_storeu_si256((__m256i *)&mem_addr[i], zero_vec);
}
// Offset so that index 1 of the first score array falls on a 32-byte boundary.
mem_addr += SIMD_BYTES - SCORE_BYTES;
hA0 = (uint8_t *)&mem_addr[0];
hA1 = &hA0[score_element_num];
hA2 = &hA1[score_element_num];
mA1 = &hA2[score_element_num];
mA2 = &mA1[score_element_num];
eA1 = &mA2[score_element_num];
eA2 = &eA1[score_element_num];
fA1 = &eA2[score_element_num];
fA2 = &fA1[score_element_num];
// adjust $w if it is too large
k = m * m;
// get the max score
for (i = 0, max = 0; i < k; ++i)
max = max > mat[i] ? max : mat[i];
max_ins = (int)((double)(qlen * max + end_bonus - o_ins) / e_ins + 1.);
max_ins = max_ins > 1 ? max_ins : 1;
w = w < max_ins ? w : max_ins;
max_del = (int)((double)(qlen * max + end_bonus - o_del) / e_del + 1.);
max_del = max_del > 1 ? max_del : 1;
w = w < max_del ? w : max_del; // TODO: is this necessary?
if (tlen < qlen)
w = MIN(tlen - 1, w);
// DP loop
max = h0, max_i = max_j = -1;
max_ie = -1, gscore = -1;
;
max_off = 0;
beg = 1;
end = qlen;
// init h0
hA0[0] = h0; // top-left corner
if (qlen == 0 || tlen == 0)
back_diagnal_num = 0; // guard against unexpected input
if (w >= qlen)
{
max_ie = 0;
gscore = 0;
}
int m_last = 0;
int iend;
for (D = 1; LIKELY(D < back_diagnal_num); ++D)
{
// Take care with the boundary conditions: tlen greater than, equal to, or less than qlen
if (D > tlen)
{
span = MIN(back_diagnal_num - D, w);
beg1 = MAX(D - tlen + 1, ((D - w) / 2) + 1);
}
else
{
span = MIN(D - 1, w);
beg1 = MAX(1, ((D - w) / 2) + 1);
}
end1 = MIN(qlen, beg1 + span);
if (beg < beg1)
beg = beg1;
if (end > end1)
end = end1;
if (beg > end)
break; // nothing left to compute; without this break hA2 would not be written and would still hold the previous round's hA0 values, which causes a bug
// beg = 1;
// end = qlen;
iend = D - (beg - 1); // ref position where this round starts (traversed in reverse order)
span = end - beg;
iStart = iend - span - 1; // 0-based ref index
// per-round bookkeeping (this m shadows the parameter m from here on)
int m = 0, mj = -1, mi = -1;
max_vec = zero_vec;
// handle the boundaries
// left boundary: handle f (insert)
if (iStart == 0)
{
hA1[end] = MAX(0, h0 - (o_ins + e_ins * end));
}
// top boundary
if (beg == 1)
{
hA1[0] = MAX(0, h0 - (o_del + e_del * iend));
}
else
{
hA1[beg - 1] = 0;
eA1[beg - 1] = 0;
}
for (j = beg, i = iend; j <= end + 1 - SIMD_WIDTH; j += SIMD_WIDTH, i -= SIMD_WIDTH)
{
// load the data
SIMD_LOAD;
// compare the sequences and compute the scores/penalties
SIMD_CMP_SEQ;
// compute
SIMD_COMPUTE;
// store the results
SIMD_STORE;
}
// remaining cells (partial vector)
if (j <= end)
{
// load the data
SIMD_LOAD;
// compare the sequences and compute the scores/penalties
SIMD_CMP_SEQ;
// compute
SIMD_COMPUTE;
// clear the lanes computed past end
SIMD_REMOVE_EXTRA;
// store the results
SIMD_STORE;
}
SIMD_FIND_MAX;
// j was clobbered by the loops/macros above; reset it to one past end
j = end + 1;
if (j == qlen + 1)
{
max_ie = gscore > hA2[qlen] ? max_ie : iStart;
gscore = gscore > hA2[qlen] ? gscore : hA2[qlen];
}
if (m > max)
{
max = m, max_i = mi, max_j = mj;
max_off = max_off > abs(mj - mi) ? max_off : abs(mj - mi);
}
// adjust the range to compute in the next round
for (j = beg; LIKELY(j <= end); ++j)
{
int has_val = hA1[j - 1] | hA2[j];
if (has_val)
break;
}
beg = j;
hA2[end + 1] = 0;
for (j = end + 1; LIKELY(j >= beg); --j)
{
int has_val = hA1[j - 1] | hA2[j];
if (has_val)
break;
}
end = j + 1 <= qlen ? j + 1 : qlen;
// beg = 0;
// end = qlen;
m_last = m;
// swap m, h, e, f
SWAP_DATA_POINTER;
}
thread_mem_release(tmem, request_mem_size);
if (_qle)
*_qle = max_j + 1;
if (_tle)
*_tle = max_i + 1;
if (_gtle)
*_gtle = max_ie + 1;
if (_gscore)
*_gscore = gscore;
if (_max_off)
*_max_off = max_off;
return max;
}

View File

@ -18,7 +18,7 @@ typedef struct
int ksw_normal(int qlen, const uint8_t *query, int tlen, const uint8_t *target, int m, const int8_t *mat, int o_del, int e_del, int o_ins, int e_ins, int w, int end_bonus, int zdrop, int h0, int *_qle, int *_tle, int *_gtle, int *_gscore, int *_max_off) int ksw_normal(int qlen, const uint8_t *query, int tlen, const uint8_t *target, int m, const int8_t *mat, int o_del, int e_del, int o_ins, int e_ins, int w, int end_bonus, int zdrop, int h0, int *_qle, int *_tle, int *_gtle, int *_gscore, int *_max_off)
{ {
// return h0; return h0;
eh_t *eh; // score array eh_t *eh; // score array
int8_t *qp; // query profile int8_t *qp; // query profile
int i, j, k, oe_del = o_del + e_del, oe_ins = o_ins + e_ins, beg, end, max, max_i, max_j, max_ins, max_del, max_ie, gscore, max_off; int i, j, k, oe_del = o_del + e_del, oe_ins = o_ins + e_ins, beg, end, max, max_i, max_j, max_ins, max_del, max_ie, gscore, max_off;
@ -58,11 +58,13 @@ int ksw_normal(int qlen, const uint8_t *query, int tlen, const uint8_t *target,
int t, f = 0, h1, m = 0, mj = -1; int t, f = 0, h1, m = 0, mj = -1;
int8_t *q = &qp[target[i] * qlen]; // 对于target第i个字符query中每个字符的分值只有匹配和不匹配 int8_t *q = &qp[target[i] * qlen]; // 对于target第i个字符query中每个字符的分值只有匹配和不匹配
// apply the band and the constraint (if provided) // apply the band and the constraint (if provided)
if (beg < i - w) // 检查开始点是否可以缩小一些 // if (beg < i - w) // 检查开始点是否可以缩小一些
beg = i - w; // beg = i - w;
if (end > i + w + 1) // 检查终点是否可以缩小,使得整体的遍历范围缩小 // if (end > i + w + 1) // 检查终点是否可以缩小,使得整体的遍历范围缩小
end = i + w + 1; // end = i + w + 1;
if (end > qlen) // 终点不超过query长度 // if (end > qlen) // 终点不超过query长度
// end = qlen;
beg = 0;
end = qlen; end = qlen;
// compute the first column // compute the first column
if (beg == 0) if (beg == 0)
@ -107,14 +109,14 @@ int ksw_normal(int qlen, const uint8_t *query, int tlen, const uint8_t *target,
max_ie = gscore > h1 ? max_ie : i; // max_ie表示取得全局最大分值时target字符串的位置 max_ie = gscore > h1 ? max_ie : i; // max_ie表示取得全局最大分值时target字符串的位置
gscore = gscore > h1 ? gscore : h1; gscore = gscore > h1 ? gscore : h1;
} }
if (m == 0) // 遍历完query之后当前轮次的最大分值为0则跳出循环 // if (m == 0) // 遍历完query之后当前轮次的最大分值为0则跳出循环
break; // break;
if (m > max) // 当前轮最大分值大于之前的最大分值 if (m > max) // 当前轮最大分值大于之前的最大分值
{ {
max = m, max_i = i, max_j = mj; // 更新取得最大值的target和query的位置 max = m, max_i = i, max_j = mj; // 更新取得最大值的target和query的位置
max_off = max_off > abs(mj - i) ? max_off : abs(mj - i); // 取得最大分值时query和target对应字符串坐标的差值 max_off = max_off > abs(mj - i) ? max_off : abs(mj - i); // 取得最大分值时query和target对应字符串坐标的差值
} }
else if (zdrop > 0) // 当前轮匹配之后取得的最大分值没有大于之前的最大值而且zdrop值大于0 else if (0) //(zdrop > 0) // 当前轮匹配之后取得的最大分值没有大于之前的最大值而且zdrop值大于0
{ {
if (i - max_i > mj - max_j) if (i - max_i > mj - max_j)
{ {
@ -128,12 +130,12 @@ int ksw_normal(int qlen, const uint8_t *query, int tlen, const uint8_t *target,
} }
} }
// update beg and end for the next round // update beg and end for the next round
for (j = beg; LIKELY(j < end) && eh[j].h == 0 && eh[j].e == 0; ++j) // for (j = beg; LIKELY(j < end) && eh[j].h == 0 && eh[j].e == 0; ++j)
; // ;
beg = j; // beg = j;
for (j = end; LIKELY(j >= beg) && eh[j].h == 0 && eh[j].e == 0; --j) // for (j = end; LIKELY(j >= beg) && eh[j].h == 0 && eh[j].e == 0; --j)
; // ;
end = j + 2 < qlen ? j + 2 : qlen; // 剪枝没考虑f即insert // end = j + 2 < qlen ? j + 2 : qlen; // 剪枝没考虑f即insert
// beg = 0, end = qlen; // uncomment this line for debugging // beg = 0, end = qlen; // uncomment this line for debugging
// fprintf(stderr, "\n"); // fprintf(stderr, "\n");
// fprintf(stderr, "%d\n", end); // fprintf(stderr, "%d\n", end);

136
main.c
View File

@ -5,6 +5,7 @@
#include <assert.h> #include <assert.h>
#include <time.h> #include <time.h>
#include "sys/time.h" #include "sys/time.h"
#include "thread_mem.h"
#define SW_NORMAL 0 #define SW_NORMAL 0
#define SW_AVX2 1 #define SW_AVX2 1
@ -31,6 +32,7 @@ int64_t get_mseconds()
int64_t time_sw_normal = 0, int64_t time_sw_normal = 0,
time_sw_avx2 = 0, time_sw_avx2 = 0,
time_sw_avx2_u8 = 0, time_sw_avx2_u8 = 0,
time_sw_avx2_u8_aligned = 0,
time_bsw_avx2 = 0, time_bsw_avx2 = 0,
time_bsw_init = 0, time_bsw_init = 0,
time_bsw_main_loop = 0, time_bsw_main_loop = 0,
@ -72,37 +74,15 @@ void convert_char_to_2bit(char *str)
str[i] = nst_nt4_table[str[i]]; str[i] = nst_nt4_table[str[i]];
} }
/* // 读取测试数据
* sw int read_data()
* normal/avx2/cuda {
*/ return 0;
}
// 程序执行入口 // 程序执行入口
int main(int argc, char *argv[]) int main(int argc, char *argv[])
{ {
/*
int sw_algo = SW_NORMAL;
// 判断执行的sw的实现类型
if (argc > 1)
{
if (strcmp(argv[1], "normal") == 0)
{
sw_algo = SW_NORMAL;
}
else if (strcmp(argv[1], "avx2") == 0)
{
sw_algo = SW_AVX2;
}
else if (strcmp(argv[1], "cuda") == 0)
{
sw_algo = SW_CUDA;
}
else
{
sw_algo = SW_ALL;
}
} */
// 初始化一些全局参数 // 初始化一些全局参数
int8_t mat[25] = {1, -4, -4, -4, -1, int8_t mat[25] = {1, -4, -4, -4, -1,
-4, 1, -4, -4, -1, -4, 1, -4, -4, -1,
@ -111,6 +91,10 @@ int main(int argc, char *argv[])
-1, -1, -1, -1, -1}; -1, -1, -1, -1, -1};
int max_off[2]; int max_off[2];
int qle, tle, gtle, gscore; int qle, tle, gtle, gscore;
thread_mem_t tmem, tmem_u8;
init_thread_mem(&tmem);
init_thread_mem(&tmem_u8);
// thread_mem_init_alloc(&tmem_u8, 10960);
// 读取测试数据 // 读取测试数据
char *query_arr = (char *)malloc(SEQ_BUF_SIZE); char *query_arr = (char *)malloc(SEQ_BUF_SIZE);
@ -130,12 +114,12 @@ int main(int argc, char *argv[])
// const char *qf_path = "/home/zzh/data/sw/q_m.fa"; // const char *qf_path = "/home/zzh/data/sw/q_m.fa";
// const char *tf_path = "/home/zzh/data/sw/t_m.fa"; // const char *tf_path = "/home/zzh/data/sw/t_m.fa";
// const char *if_path = "/home/zzh/data/sw/i_m.txt"; // const char *if_path = "/home/zzh/data/sw/i_m.txt";
// const char *qf_path = "/home/zzh/data/sw/q_l.fa"; const char *qf_path = "/home/zzh/data/sw/q_l.fa";
// const char *tf_path = "/home/zzh/data/sw/t_l.fa"; const char *tf_path = "/home/zzh/data/sw/t_l.fa";
// const char *if_path = "/home/zzh/data/sw/i_l.txt"; const char *if_path = "/home/zzh/data/sw/i_l.txt";
const char *qf_path = "/home/zzh/data/sw/query.fa"; // const char *qf_path = "/home/zzh/data/sw/query.fa";
const char *tf_path = "/home/zzh/data/sw/target.fa"; // const char *tf_path = "/home/zzh/data/sw/target.fa";
const char *if_path = "/home/zzh/data/sw/info.txt"; // const char *if_path = "/home/zzh/data/sw/info.txt";
query_f = fopen(qf_path, "r"); query_f = fopen(qf_path, "r");
target_f = fopen(tf_path, "r"); target_f = fopen(tf_path, "r");
info_f = fopen(if_path, "r"); info_f = fopen(if_path, "r");
@ -168,6 +152,7 @@ int main(int argc, char *argv[])
int score_normal = 0, score_avx2 = 0, score_avx2_u8 = 0, score_bsw_avx2 = 0; int score_normal = 0, score_avx2 = 0, score_avx2_u8 = 0, score_bsw_avx2 = 0;
int score_normal_total = 0, score_avx2_total = 0, score_avx2_u8_total = 0, score_bsw_avx2_total = 0; int score_normal_total = 0, score_avx2_total = 0, score_avx2_u8_total = 0, score_bsw_avx2_total = 0;
int score_avx2_u8_aligned = 0, score_avx2_u8_aligned_total = 0;
while (!feof(target_f)) while (!feof(target_f))
{ {
@ -253,46 +238,57 @@ int main(int argc, char *argv[])
// fprintf(normal_out_f, "%d %d\n", info_arr[i][2], score_normal); // fprintf(normal_out_f, "%d %d\n", info_arr[i][2], score_normal);
// fprintf(stderr, "%d %d %d %d %d %d %d\n", info_arr[i][2], score_normal, qle, tle, gtle, gscore, max_off[0]); // fprintf(stderr, "%d %d %d %d %d %d %d\n", info_arr[i][2], score_normal, qle, tle, gtle, gscore, max_off[0]);
#ifdef SHOW_PERF // #ifdef SHOW_PERF
start_time = get_mseconds(); // start_time = get_mseconds();
#endif // #endif
score_bsw_avx2 = bsw_avx2( // score_bsw_avx2 = ksw_avx2_aligned(
info_arr[i][0], // &tmem,
(uint8_t *)query_arr + cur_query_pos, // info_arr[i][0],
info_arr[i][1], // (uint8_t *)query_arr + cur_query_pos,
(uint8_t *)target_arr + cur_target_pos, // info_arr[i][1],
0, 5, mat, 6, 1, 6, 1, // (uint8_t *)target_arr + cur_target_pos,
1, 4, // 0, 6, 1, 6, 1,
100, 5, 100, // 1, 4,
info_arr[i][2], // 100, 5,
&qle, &tle, &gtle, &gscore, &max_off[0]); // info_arr[i][2],
#ifdef SHOW_PERF // &qle, &tle, &gtle, &gscore, &max_off[0]);
time_bsw_avx2 += get_mseconds() - start_time; // #ifdef SHOW_PERF
#endif // time_bsw_avx2 += get_mseconds() - start_time;
score_bsw_avx2_total += score_bsw_avx2; // #endif
// score_bsw_avx2_total += score_bsw_avx2;
// fprintf(avx2_out_f, "%d %d\n", info_arr[i][2], score_avx2); // fprintf(avx2_out_f, "%d %d\n", info_arr[i][2], score_avx2);
// fprintf(stderr, "%d %d %d %d %d %d %d\n", info_arr[i][2], score_bsw_avx2_total, qle, tle, gtle, gscore, max_off[0]); // fprintf(stderr, "%d %d %d %d %d %d %d\n", info_arr[i][2], score_bsw_avx2_total, qle, tle, gtle, gscore, max_off[0]);
/* /**/
#ifdef SHOW_PERF #ifdef SHOW_PERF
start_time = get_mseconds(); start_time = get_mseconds();
#endif #endif
score_avx2 = ksw_avx2( score_avx2 = bsw_avx2(
info_arr[i][0], info_arr[i][0],
(uint8_t *)query_arr + cur_query_pos, (uint8_t *)query_arr + cur_query_pos,
info_arr[i][1], info_arr[i][1],
(uint8_t *)target_arr + cur_target_pos, (uint8_t *)target_arr + cur_target_pos,
0, 5, mat, 6, 1, 6, 1, 0, 6, 1, 6, 1,
1, 4, 1, 4,
100, 5, 100, 100, 5,
info_arr[i][2], info_arr[i][2],
&qle, &tle, &gtle, &gscore, &max_off[0]); &qle, &tle, &gtle, &gscore, &max_off[0]);
// score_avx2 = ksw_avx2(
// info_arr[i][0],
// (uint8_t *)query_arr + cur_query_pos,
// info_arr[i][1],
// (uint8_t *)target_arr + cur_target_pos,
// 0, 5, mat, 6, 1, 6, 1,
// 1, 4,
// 100, 5, 100,
// info_arr[i][2],
// &qle, &tle, &gtle, &gscore, &max_off[0]);
#ifdef SHOW_PERF #ifdef SHOW_PERF
time_sw_avx2 += get_mseconds() - start_time; time_sw_avx2 += get_mseconds() - start_time;
#endif #endif
score_avx2_total += score_avx2; score_avx2_total += score_avx2;
// fprintf(avx2_out_f, "%d %d\n", info_arr[i][2], score_avx2); // fprintf(avx2_out_f, "%d %d\n", info_arr[i][2], score_avx2);
fprintf(stderr, "%d %d %d %d %d %d %d\n", info_arr[i][2], score_avx2, qle, tle, gtle, gscore, max_off[0]); // fprintf(stderr, "%d %d %d %d %d %d %d\n", info_arr[i][2], score_avx2, qle, tle, gtle, gscore, max_off[0]);
*/
#ifdef SHOW_PERF #ifdef SHOW_PERF
start_time = get_mseconds(); start_time = get_mseconds();
#endif #endif
@ -311,6 +307,27 @@ int main(int argc, char *argv[])
#endif #endif
score_avx2_u8_total += score_avx2_u8; score_avx2_u8_total += score_avx2_u8;
// fprintf(avx2_u8_out_f, "%d %d %d %d %d %d\n", score_avx2_u8, qle, tle, gtle, gscore, max_off[0]); // fprintf(avx2_u8_out_f, "%d %d %d %d %d %d\n", score_avx2_u8, qle, tle, gtle, gscore, max_off[0]);
#ifdef SHOW_PERF
start_time = get_mseconds();
#endif
score_avx2_u8_aligned = ksw_avx2_u8_aligned(
&tmem_u8,
info_arr[i][0],
(uint8_t *)query_arr + cur_query_pos,
info_arr[i][1],
(uint8_t *)target_arr + cur_target_pos,
0, 5, mat, 6, 1, 6, 1,
1, 4,
100, 5, 100,
info_arr[i][2],
&qle, &tle, &gtle, &gscore, &max_off[0]);
#ifdef SHOW_PERF
time_sw_avx2_u8_aligned += get_mseconds() - start_time;
#endif
score_avx2_u8_aligned_total += score_avx2_u8_aligned;
// fprintf(avx2_u8_out_f, "%d %d %d %d %d %d\n", score_avx2_u8, qle, tle, gtle, gscore, max_off[0]);
// 更新query和target位置信息 // 更新query和target位置信息
cur_query_pos += info_arr[i][0]; cur_query_pos += info_arr[i][0];
cur_target_pos += info_arr[i][1]; cur_target_pos += info_arr[i][1];
@ -379,9 +396,10 @@ int main(int argc, char *argv[])
#ifdef SHOW_PERF #ifdef SHOW_PERF
fprintf(stderr, "time_sw_normal: %f s; score: %d\n", time_sw_normal / DIVIDE_BY, score_normal_total); fprintf(stderr, "time_sw_normal: %f s; score: %d\n", time_sw_normal / DIVIDE_BY, score_normal_total);
fprintf(stderr, "time_bsw_avx2: %f s; score: %d\n", time_bsw_avx2 / DIVIDE_BY, score_bsw_avx2_total); fprintf(stderr, "time_bsw_avx2: %f s; score: %d\n", time_bsw_avx2 / DIVIDE_BY, score_bsw_avx2_total);
// fprintf(stderr, "time_sw_avx2: %f s; score: %d\n", time_sw_avx2 / DIVIDE_BY, score_avx2_total); fprintf(stderr, "time_sw_avx2: %f s; score: %d\n", time_sw_avx2 / DIVIDE_BY, score_avx2_total);
fprintf(stderr, "time_sw_avx2_u8: %f s; score: %d\n", time_sw_avx2_u8 / DIVIDE_BY, score_avx2_u8_total); fprintf(stderr, "time_sw_avx2_u8: %f s; score: %d\n", time_sw_avx2_u8 / DIVIDE_BY, score_avx2_u8_total);
fprintf(stderr, "time_sw_avx2_u8_aligned: %f s; score: %d\n", time_sw_avx2_u8_aligned / DIVIDE_BY, score_avx2_u8_aligned_total);
fprintf(stderr, "thread mem capacity: %d\t%d\n", tmem.capacity, tmem_u8.capacity);
fprintf(stderr, "time_bsw_init: %f s\n", time_bsw_init / DIVIDE_BY); fprintf(stderr, "time_bsw_init: %f s\n", time_bsw_init / DIVIDE_BY);
// fprintf(stderr, "time_bsw_main_loop: %f s\n", (time_bsw_main_loop) / DIVIDE_BY); // fprintf(stderr, "time_bsw_main_loop: %f s\n", (time_bsw_main_loop) / DIVIDE_BY);
// fprintf(stderr, "time_bsw_find_max: %f s\n", (time_bsw_find_max) / DIVIDE_BY); // fprintf(stderr, "time_bsw_find_max: %f s\n", (time_bsw_find_max) / DIVIDE_BY);

89
thread_mem.c 100644
View File

@ -0,0 +1,89 @@
/*********************************************************************************************
Description: In-thread memory allocation with boundary aligned
Copyright : All right reserved by NCIC.ICT
Author : Zhang Zhonghai
Date : 2023/08/23
***********************************************************************************************/
#include "thread_mem.h"
#include <stdio.h>
// Create an empty thread_mem_t on the heap.
// Returns NULL when the allocation fails. The result must eventually be
// passed to destroy_thread_mem().
thread_mem_t *create_thread_mem(void)
{
    thread_mem_t *tmem = malloc(sizeof *tmem);
    if (tmem == NULL)
    {
        return NULL; // out of memory: nothing to initialise
    }
    tmem->occupied = 0;
    tmem->capacity = 0;
    tmem->mem = NULL;
    return tmem;
}
// Put a caller-owned thread_mem_t into the empty state: no backing memory,
// zero capacity, nothing handed out. Any pointer previously stored in mem is
// overwritten without being freed.
void init_thread_mem(thread_mem_t *tmem)
{
    tmem->mem = 0;
    tmem->capacity = 0;
    tmem->occupied = 0;
}
// Initialise tmem and pre-allocate byte_cnt bytes of backing memory.
// If byte_cnt is 0 or the allocation fails, tmem is left empty
// (mem == NULL, capacity == 0) so that a later thread_mem_request() can
// allocate on demand instead of handing out offsets from a NULL base.
void thread_mem_init_alloc(thread_mem_t *tmem, size_t byte_cnt)
{
    tmem->occupied = 0;
    // Skip malloc(0): its result is implementation-defined and could be leaked later.
    tmem->mem = byte_cnt > 0 ? malloc(byte_cnt) : NULL;
    tmem->capacity = tmem->mem != NULL ? byte_cnt : 0;
}
// Hand out byte_cnt bytes from the arena, growing the backing allocation when
// the free space is insufficient.
// Returns a pointer to the start of the new region, or NULL when tmem is NULL,
// the size computation would overflow, or the allocation fails. On failure the
// arena is left unchanged.
// Growing uses realloc, which may move the backing memory: pointers returned
// by earlier requests must not be used after a request that had to grow.
void *thread_mem_request(thread_mem_t *tmem, size_t byte_cnt)
{
    if (tmem == NULL)
    {
        return NULL;
    }
    if (byte_cnt > SIZE_MAX - tmem->occupied)
    {
        return NULL; // occupied + byte_cnt would wrap around
    }

    size_t needed = tmem->occupied + byte_cnt;
    if (needed > tmem->capacity)
    {
        // realloc(NULL, n) behaves like malloc(n), so this also covers the
        // first allocation. Keep the old pointer until the call has succeeded.
        void *grown = realloc(tmem->mem, needed);
        if (grown == NULL)
        {
            return NULL;
        }
        tmem->mem = grown;
        tmem->capacity = needed;
    }
    if (tmem->mem == NULL)
    {
        return NULL; // zero-byte request on an empty arena
    }

    // Byte arithmetic through unsigned char *; arithmetic on void * is a compiler extension.
    void *ret_mem = (unsigned char *)tmem->mem + tmem->occupied;
    tmem->occupied = needed;
    return ret_mem;
}
// Give byte_cnt bytes back to the arena (the most recently requested bytes are
// returned first; the backing memory itself is kept for reuse).
// A NULL tmem is ignored. Releasing more than is currently occupied clamps
// occupied to 0 instead of wrapping the unsigned counter around.
void thread_mem_release(thread_mem_t *tmem, size_t byte_cnt)
{
    if (tmem == NULL)
    {
        return;
    }
    tmem->occupied = byte_cnt < tmem->occupied ? tmem->occupied - byte_cnt : 0;
}
// Free the backing allocation and return the arena to the empty state.
// The thread_mem_t object itself is not freed.
void thread_mem_free(thread_mem_t *tmem)
{
    void *backing = tmem->mem;

    tmem->occupied = 0;
    tmem->capacity = 0;
    tmem->mem = 0;
    free(backing);
}
// Destroy a heap-allocated thread_mem_t: free its backing memory, then the
// object itself. A NULL argument is ignored.
void destroy_thread_mem(thread_mem_t *tmem)
{
    if (tmem == 0)
    {
        return;
    }
    thread_mem_free(tmem);
    free(tmem);
}

44
thread_mem.h 100644
View File

@ -0,0 +1,44 @@
/*********************************************************************************************
Description: In-thread memory allocation with boundary aligned
Copyright : All right reserved by NCIC.ICT
Author : Zhang Zhonghai
Date : 2023/08/23
***********************************************************************************************/
#ifndef THREAD_MEM_H
#define THREAD_MEM_H

#include <stdlib.h>
#include <stdint.h>
#include <string.h>

// Alignment unit in bytes and the matching shift (1 << MEM_MOVE_BIT == MEM_ALIGN_BYTE).
#define MEM_ALIGN_BYTE 8
#define MEM_MOVE_BIT 3

// A simple per-thread memory arena: one backing allocation from which regions
// are handed out by advancing `occupied` and returned by decreasing it.
typedef struct
{
    size_t occupied; // bytes currently handed out (original note says "aligned" - TODO confirm who aligns it)
    size_t capacity; // total size of the backing allocation in bytes
    void *mem;       // start address of the backing allocation
} thread_mem_t;

// Create a thread_mem_t on the heap; release it with destroy_thread_mem().
thread_mem_t *create_thread_mem(void);

// Initialise a caller-owned thread_mem_t to the empty state.
void init_thread_mem(thread_mem_t *tmem);

// Initialise tmem and pre-allocate byte_cnt bytes of backing memory.
void thread_mem_init_alloc(thread_mem_t *tmem, size_t byte_cnt);

// Request byte_cnt bytes from the arena; returns the start of the region.
void *thread_mem_request(thread_mem_t *tmem, size_t byte_cnt);

// Return byte_cnt bytes that are no longer needed to the arena.
void thread_mem_release(thread_mem_t *tmem, size_t byte_cnt);

// Free the backing memory; tmem itself stays valid and empty.
void thread_mem_free(thread_mem_t *tmem);

// Free the backing memory and the thread_mem_t object itself.
void destroy_thread_mem(thread_mem_t *tmem);

#endif /* THREAD_MEM_H */