去掉了一些调试代码,fmt增加了小间隔同时计算的代码,但是效果不明显,后续可能会删掉

This commit is contained in:
zzh 2024-02-02 12:53:34 +08:00
parent 5f18167703
commit 5ff04ec853
2 changed files with 135 additions and 72 deletions

View File

@ -198,9 +198,9 @@ void bwt_occ4(const bwt_t *bwt, bwtint_t k, bwtint_t cnt[4])
void bwt_2occ4(const bwt_t *bwt, bwtint_t k, bwtint_t l, bwtint_t cntk[4], bwtint_t cntl[4]) void bwt_2occ4(const bwt_t *bwt, bwtint_t k, bwtint_t l, bwtint_t cntk[4], bwtint_t cntl[4])
{ {
// double tm_t = realtime(); // double tm_t = realtime();
bwt_occ4(bwt, k, cntk); // bwt_occ4(bwt, k, cntk);
bwt_occ4(bwt, l, cntl); // bwt_occ4(bwt, l, cntl);
return; // return;
bwtint_t _k, _l; bwtint_t _k, _l;
_k = k - (k >= bwt->primary); _k = k - (k >= bwt->primary);
_l = l - (l >= bwt->primary); _l = l - (l >= bwt->primary);

View File

@ -365,8 +365,8 @@ void fmt_e2_occ(const FMTIndex *fmt, bwtint_t k, int b1, int b2, bwtint_t cnt[4]
{ {
uint32_t x = 0; uint32_t x = 0;
uint32_t *p, *q, tmp; uint32_t *p, *q, tmp;
bwtint_t bwt_k_line = k, bwt_k_base_line = k >> FMT_OCC_INTV_SHIFT << FMT_OCC_INTV_SHIFT; bwtint_t str_line = k, cp_line = k & (~FMT_OCC_INTV_MASK);
int i, ti; int i, ti = b1 << 2 | b2;
cnt[0] = 0; cnt[0] = 0;
cnt[1] = 0; cnt[1] = 0;
cnt[2] = 0; cnt[2] = 0;
@ -377,10 +377,8 @@ void fmt_e2_occ(const FMTIndex *fmt, bwtint_t k, int b1, int b2, bwtint_t cnt[4]
cnt[3] = p[b2]; cnt[3] = p[b2];
return; return;
} }
ti = b1 << 2 | b2;
k -= (k >= fmt->primary); // k由bwt矩阵对应的行转换成bwt字符串对应的行去掉了$,所以大于$的行都减掉1 k -= (k >= fmt->primary); // k由bwt矩阵对应的行转换成bwt字符串对应的行去掉了$,所以大于$的行都减掉1
p = fmt_occ_intv(fmt, k); p = fmt_occ_intv(fmt, k);
//cout << "k-base: " << k << "; occ: " << p[0] << ' ' << p[1] << ' ' << p[2] << ' ' << p[3] << endl;
for (i = b1 + 1; i < 4; ++i) cnt[0] += p[i]; // 大于b1的碱基的occ之和 for (i = b1 + 1; i < 4; ++i) cnt[0] += p[i]; // 大于b1的碱基的occ之和
cnt[1] = p[b1]; // b1的occ cnt[1] = p[b1]; // b1的occ
q = p + 4 + b1 * 4; q = p + 4 + b1 * 4;
@ -388,31 +386,10 @@ void fmt_e2_occ(const FMTIndex *fmt, bwtint_t k, int b1, int b2, bwtint_t cnt[4]
cnt[3] = q[b2]; // b2的occ cnt[3] = q[b2]; // b2的occ
p += 20; p += 20;
//if (k == 3965453116 || k == 3965453672 || k == 2668688087 || k == 2668688550) {
// cout << "sec base-occ: " << q[0] << ' ' << q[1] << ' ' << q[2] << ' ' << q[3] << endl;
//}
#ifdef FMT_MID_INTERVAL #ifdef FMT_MID_INTERVAL
// 使用mid interval信息 // 使用mid interval信息
int mk = k % FMT_OCC_INTERVAL; int mk = k % FMT_OCC_INTERVAL;
int n_mintv = mk >> FMT_MID_INTV_SHIFT; int n_mintv = mk >> FMT_MID_INTV_SHIFT;
//if (k == 3137454504)
//{
// for (i = 0; i < n_mintv; ++i)
// {
// q = p + i * 6;
// print_base_uint32(*q);
// print_base_uint32(*(q + 1));
// x = *(q + 2);
// cout << ((x) >> 24 & 0xff) << ' ' << ((x) >> 16 & 0xff) << ' ' << ((x) >> 8 & 0xff) << ' ' << ((x) & 0xff) << ' ' << __fmt_mid_sum(x) << endl;
// x = *(q + 3);
// cout << ((x) >> 24 & 0xff) << ' ' << ((x) >> 16 & 0xff) << ' ' << ((x) >> 8 & 0xff) << ' ' << ((x) & 0xff) << ' ' << __fmt_mid_sum(x) << endl;
// x = *(q + 4);
// cout << ((x) >> 24 & 0xff) << ' ' << ((x) >> 16 & 0xff) << ' ' << ((x) >> 8 & 0xff) << ' ' << ((x) & 0xff) << ' ' << __fmt_mid_sum(x) << endl;
// x = *(q + 5);
// cout << ((x) >> 24 & 0xff) << ' ' << ((x) >> 16 & 0xff) << ' ' << ((x) >> 8 & 0xff) << ' ' << ((x) & 0xff) << ' ' << __fmt_mid_sum(x) << endl;
// }
// x = 0;
//}
if (n_mintv > 0) // 至少超过了第一个mid interval if (n_mintv > 0) // 至少超过了第一个mid interval
{ {
p += n_mintv * (4 + (FMT_MID_INTERVAL >> 3)) - 4; // 对应的mid interval check point的首地址即A C G T的局部累积量 p += n_mintv * (4 + (FMT_MID_INTERVAL >> 3)) - 4; // 对应的mid interval check point的首地址即A C G T的局部累积量
@ -421,12 +398,6 @@ void fmt_e2_occ(const FMTIndex *fmt, bwtint_t k, int b1, int b2, bwtint_t cnt[4]
x += p[i]; // 大于b1的碱基的occ之和 x += p[i]; // 大于b1的碱基的occ之和
cnt[0] += __fmt_mid_sum(x); cnt[0] += __fmt_mid_sum(x);
x = *q; x = *q;
// if (k == 3137454504)
// {
// cout << ((x) >> 24 & 0xff) << '\t' << ((x) >> 16 & 0xff)
// << '\t' << ((x) >> 8 & 0xff) << '\t' << ((x) & 0xff) << endl;
// cout << __fmt_mid_sum(x) << endl;
// }
cnt[1] += __fmt_mid_sum(x); // b1的occ cnt[1] += __fmt_mid_sum(x); // b1的occ
for (i = 3; i > b2; --i) for (i = 3; i > b2; --i)
cnt[2] += x >> (i << 3) & 0xff; // 大于b2的occ之和 cnt[2] += x >> (i << 3) & 0xff; // 大于b2的occ之和
@ -434,67 +405,156 @@ void fmt_e2_occ(const FMTIndex *fmt, bwtint_t k, int b1, int b2, bwtint_t cnt[4]
x = 0; x = 0;
p += 4; p += 4;
} }
// cout << "mid-occ: " << cnt[0] << ' ' << cnt[1] << ' ' << cnt[2] << ' ' << cnt[3] << endl;
// cout << "k: " << k << ' ' << cnt[1] << endl;
#if FMT_MID_INTERVAL == 16 #if FMT_MID_INTERVAL == 16
if ((mk & FMT_MID_INTV_MASK) >> 3) if ((mk & FMT_MID_INTV_MASK) >> 3)
{ {
x += __fmt_occ_e2_aux2(fmt, ti, *p); x += __fmt_occ_e2_aux2(fmt, ti, *p);
// print_base_uint32(*p);
++p; ++p;
} }
#endif #elif FMT_MID_INTERVAL > 16 // 该地址是bwt和pre_bwt字符串数据的首地址
tmp = *p & ~((1U << ((~k & 7) << 2)) - 1); uint32_t *end = p + ((k >> 3) - ((k & ~FMT_OCC_INTV_MASK) >> 3));
x += __fmt_occ_e2_aux2(fmt, ti, tmp);
//if (k == 3965453116 || k == 3965453672 || k == 2668688087 || k == 2668688550)
//{
// print_base_uint32(tmp);
// cout << "多的: " << (~k & 7) << endl;
// cout << (x >> 24 & 0xff) << ' ' << (x >> 16 & 0xff) << ' ' << (x >> 8 & 0xff) << ' ' << (x & 0xff) << endl;
//}
#else // 该地址是bwt和pre_bwt字符串数据的首地址
uint32_t *end = p + ((k >> 3) - ((k & ~FMT_OCC_INTV_MASK) >> 3)); // this is the end point of the following loop
// p = end - (end - p) / 8;
// cout << "k - kbase: " << k - bwt_k_base_line << endl;
for (; p < end; ++p) for (; p < end; ++p)
{ {
x += __fmt_occ_e2_aux2(fmt, ti, *p); x += __fmt_occ_e2_aux2(fmt, ti, *p);
// print_base_uint32(*p);
} }
tmp = *p & ~((1U << ((~k & 7) << 2)) - 1); tmp = *p & ~((1U << ((~k & 7) << 2)) - 1);
// print_base_uint32(tmp);
x += __fmt_occ_e2_aux2(fmt, ti, tmp); x += __fmt_occ_e2_aux2(fmt, ti, tmp);
#endif #endif
#else
uint32_t *end = p + ((k >> 3) - ((k & ~FMT_OCC_INTV_MASK) >> 3));
for (; p < end; ++p)
{
x += __fmt_occ_e2_aux2(fmt, ti, *p);
}
tmp = *p & ~((1U << ((~k & 7) << 2)) - 1);
x += __fmt_occ_e2_aux2(fmt, ti, tmp);
#endif
tmp = *p & ~((1U << ((~k & 7) << 2)) - 1);
x += __fmt_occ_e2_aux2(fmt, ti, tmp);
if (b1 == 0) if (b1 == 0)
{ {
x -= (~k & 7) << 8; x -= (~k & 7) << 8;
if (b2 == 0) if (b2 == 0)
x -= (~k & 7) << 24; x -= (~k & 7) << 24;
// cout << "here-1" << endl;
} }
// 如果跨过了second_primary,那么可能需要减掉一次累积值 // 如果跨过了second_primary,那么可能需要减掉一次累积值
if (b1 == fmt->first_base && bwt_k_base_line < fmt->sec_primary && bwt_k_line >= fmt->sec_primary) if (b1 == fmt->first_base && cp_line < fmt->sec_primary && str_line >= fmt->sec_primary)
{ {
if (b2 < fmt->last_base) if (b2 < fmt->last_base)
cnt[2] -= 1; cnt[2] -= 1;
else if (b2 == fmt->last_base) else if (b2 == fmt->last_base)
cnt[3] -= 1; cnt[3] -= 1;
// cout << "here-2" << endl;
} }
// if (k == 3965453116 || k == 3965453672 || k == 2668688087 || k == 2668688550)
// {
// cout << BASE[b1] << BASE[b2] << endl;
// cout << "final-x: " << (x >> 24 & 0xff) << ' ' << (x >> 16 & 0xff) << ' ' << (x >> 8 & 0xff) << ' ' << (x & 0xff) << endl;
// }
cnt[0] += x & 0xff; cnt[0] += x & 0xff;
cnt[1] += x >> 8 & 0xff; cnt[1] += x >> 8 & 0xff;
cnt[2] += x >> 16 & 0xff; cnt[2] += x >> 16 & 0xff;
cnt[3] += x >> 24 & 0xff; cnt[3] += x >> 24 & 0xff;
// if (k == 3965453116 || k == 3965453672 || k == 2668688087 || k == 2668688550) }
// cout << "final-occ: " << cnt[0] << ' ' << cnt[1] << ' ' << cnt[2] << ' ' << cnt[3] << endl;
void calc_intv_occ(const FMTIndex *fmt, bwtint_t k, bwtint_t str_line, bwtint_t cp_line,
uint32_t *p, int b1, int b2, bwtint_t cnt[4])
{
uint32_t x = 0;
uint32_t *q, tmp;
int i, ti = b1 << 2 | b2;
#ifdef FMT_MID_INTERVAL
// 使用mid interval信息
int mk = k % FMT_OCC_INTERVAL;
int n_mintv = mk >> FMT_MID_INTV_SHIFT;
if (n_mintv > 0) // 至少超过了第一个mid interval
{
p += n_mintv * (4 + (FMT_MID_INTERVAL >> 3)) - 4; // 对应的mid interval check point的首地址即A C G T的局部累积量
q = p + b1;
for (i = b1 + 1; i < 4; ++i)
x += p[i]; // 大于b1的碱基的occ之和
cnt[0] += __fmt_mid_sum(x);
x = *q;
cnt[1] += __fmt_mid_sum(x); // b1的occ
for (i = 3; i > b2; --i)
cnt[2] += x >> (i << 3) & 0xff; // 大于b2的occ之和
cnt[3] += x >> (b2 << 3) & 0xff; // b2的occ
x = 0;
p += 4;
}
#if FMT_MID_INTERVAL == 16
if ((mk & FMT_MID_INTV_MASK) >> 3)
{
x += __fmt_occ_e2_aux2(fmt, ti, *p);
++p;
}
#elif FMT_MID_INTERVAL > 16 // 该地址是bwt和pre_bwt字符串数据的首地址
uint32_t *end = p + ((k >> 3) - ((k & ~FMT_OCC_INTV_MASK) >> 3));
for (; p < end; ++p)
{
x += __fmt_occ_e2_aux2(fmt, ti, *p);
}
tmp = *p & ~((1U << ((~k & 7) << 2)) - 1);
x += __fmt_occ_e2_aux2(fmt, ti, tmp);
#endif
#else
uint32_t *end = p + ((k >> 3) - ((k & ~FMT_OCC_INTV_MASK) >> 3));
for (; p < end; ++p)
{
x += __fmt_occ_e2_aux2(fmt, ti, *p);
}
tmp = *p & ~((1U << ((~k & 7) << 2)) - 1);
x += __fmt_occ_e2_aux2(fmt, ti, tmp);
#endif
tmp = *p & ~((1U << ((~k & 7) << 2)) - 1);
x += __fmt_occ_e2_aux2(fmt, ti, tmp);
if (b1 == 0)
{
x -= (~k & 7) << 8;
if (b2 == 0)
x -= (~k & 7) << 24;
}
// 如果跨过了second_primary,那么可能需要减掉一次累积值
if (b1 == fmt->first_base && cp_line < fmt->sec_primary && str_line >= fmt->sec_primary)
{
if (b2 < fmt->last_base)
cnt[2] -= 1;
else if (b2 == fmt->last_base)
cnt[3] -= 1;
}
cnt[0] += x & 0xff;
cnt[1] += x >> 8 & 0xff;
cnt[2] += x >> 16 & 0xff;
cnt[3] += x >> 24 & 0xff;
}
void fmt_e2_2occ(const FMTIndex *fmt, bwtint_t k, bwtint_t l, int b1, int b2, bwtint_t tk[4], bwtint_t tl[4])
{
bwtint_t _k, _l;
_k = k - (k >= fmt->primary);
_l = l - (l >= fmt->primary);
if (_l >> FMT_OCC_INTV_SHIFT != _k >> FMT_OCC_INTV_SHIFT || k == (bwtint_t)(-1) || l == (bwtint_t)(-1))
{
fmt_e2_occ(fmt, k, b1, b2, tk);
fmt_e2_occ(fmt, l, b1, b2, tl);
}
else
{
uint32_t *p, *q;
bwtint_t cp_line = _k >> FMT_OCC_INTV_SHIFT << FMT_OCC_INTV_SHIFT;
int i;
tk[0] = 0;
tk[2] = 0;
p = fmt_occ_intv(fmt, _k);
for (i = b1 + 1; i < 4; ++i)
tk[0] += p[i]; // 大于b1的碱基的occ之和
tk[1] = p[b1]; // b1的occ
q = p + 4 + b1 * 4;
for (i = b2 + 1; i < 4; ++i)
tk[2] += q[i]; // 大于b2的occ之和
tk[3] = q[b2]; // b2的occ
p += 20;
memcpy(tl, tk, 32);
calc_intv_occ(fmt, _k, k, cp_line, p, b1, b2, tk);
calc_intv_occ(fmt, _l, l, cp_line, p, b1, b2, tl);
}
} }
// 扩展两个碱基 // 扩展两个碱基
@ -502,8 +562,10 @@ void fmt_extend2(const FMTIndex *fmt, bwtintv_t *ik, bwtintv_t *ok, int is_back,
{ {
bwtint_t tk[4], tl[4], first_pos; bwtint_t tk[4], tl[4], first_pos;
// tk表示在k行之前所有各个碱基累积出现次数tl表示在l行之前的累积 // tk表示在k行之前所有各个碱基累积出现次数tl表示在l行之前的累积
fmt_e2_occ(fmt, ik->x[!is_back] - 1, b1, b2, tk); //fmt_e2_occ(fmt, ik->x[!is_back] - 1, b1, b2, tk);
fmt_e2_occ(fmt, ik->x[!is_back] - 1 + ik->x[2], b1, b2, tl); //fmt_e2_occ(fmt, ik->x[!is_back] - 1 + ik->x[2], b1, b2, tl);
fmt_e2_2occ(fmt, ik->x[!is_back] - 1, ik->x[!is_back] - 1 + ik->x[2], b1, b2, tk, tl);
//fmt_e2_occ_2way(fmt, ik->x[!is_back] - 1, b1, b2, tk); //fmt_e2_occ_2way(fmt, ik->x[!is_back] - 1, b1, b2, tk);
//fmt_e2_occ_2way(fmt, ik->x[!is_back] - 1 + ik->x[2], b1, b2, tl); //fmt_e2_occ_2way(fmt, ik->x[!is_back] - 1 + ik->x[2], b1, b2, tl);
// 这里是反向扩展 // 这里是反向扩展
@ -525,8 +587,9 @@ void fmt_extend1(const FMTIndex *fmt, bwtintv_t *ik, bwtintv_t *ok, int is_back,
bwtint_t tk[4], tl[4]; bwtint_t tk[4], tl[4];
int b2 = 3; int b2 = 3;
// tk表示在k行之前所有各个碱基累积出现次数tl表示在l行之前的累积 // tk表示在k行之前所有各个碱基累积出现次数tl表示在l行之前的累积
fmt_e2_occ(fmt, ik->x[!is_back] - 1, b1, b2, tk); //fmt_e2_occ(fmt, ik->x[!is_back] - 1, b1, b2, tk);
fmt_e2_occ(fmt, ik->x[!is_back] - 1 + ik->x[2], b1, b2, tl); //fmt_e2_occ(fmt, ik->x[!is_back] - 1 + ik->x[2], b1, b2, tl);
fmt_e2_2occ(fmt, ik->x[!is_back] - 1, ik->x[!is_back] - 1 + ik->x[2], b1, b2, tk, tl);
//fmt_e2_occ_2way(fmt, ik->x[!is_back] - 1, b1, b2, tk); //fmt_e2_occ_2way(fmt, ik->x[!is_back] - 1, b1, b2, tk);
//fmt_e2_occ_2way(fmt, ik->x[!is_back] - 1 + ik->x[2], b1, b2, tl); //fmt_e2_occ_2way(fmt, ik->x[!is_back] - 1 + ik->x[2], b1, b2, tl);
// 这里是反向扩展 // 这里是反向扩展
@ -654,7 +717,7 @@ int main_fmtidx(int argc, char **argv)
t1 = realtime(); t1 = realtime();
for (int i = 0; i < (int)seed_arr.size(); ++i) for (int i = 0; i < (int)seed_arr.size(); ++i)
seed_arr[i] = generate_rand_seq(7); seed_arr[i] = generate_rand_seq(29);
t1 = realtime() - t1; t1 = realtime() - t1;
cout << "[time gen seed:] " << t1 << "s" << endl; cout << "[time gen seed:] " << t1 << "s" << endl;