From 5ff04ec853f9900af8bc4ffcdd4ee9c00f6c2c97 Mon Sep 17 00:00:00 2001 From: zzh Date: Fri, 2 Feb 2024 12:53:34 +0800 Subject: [PATCH] =?UTF-8?q?=E5=8E=BB=E6=8E=89=E4=BA=86=E4=B8=80=E4=BA=9B?= =?UTF-8?q?=E8=B0=83=E8=AF=95=E4=BB=A3=E7=A0=81=EF=BC=8Cfmt=E5=A2=9E?= =?UTF-8?q?=E5=8A=A0=E4=BA=86=E5=B0=8F=E9=97=B4=E9=9A=94=E5=90=8C=E6=97=B6?= =?UTF-8?q?=E8=AE=A1=E7=AE=97=E7=9A=84=E4=BB=A3=E7=A0=81=EF=BC=8C=E4=BD=86?= =?UTF-8?q?=E6=98=AF=E6=95=88=E6=9E=9C=E4=B8=8D=E6=98=8E=E6=98=BE=EF=BC=8C?= =?UTF-8?q?=E5=90=8E=E7=BB=AD=E5=8F=AF=E8=83=BD=E4=BC=9A=E5=88=A0=E6=8E=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- bwt.cpp | 6 +- fmt_index.cpp | 201 +++++++++++++++++++++++++++++++++----------------- 2 files changed, 135 insertions(+), 72 deletions(-) diff --git a/bwt.cpp b/bwt.cpp index b680eda..5454d1e 100644 --- a/bwt.cpp +++ b/bwt.cpp @@ -198,9 +198,9 @@ void bwt_occ4(const bwt_t *bwt, bwtint_t k, bwtint_t cnt[4]) void bwt_2occ4(const bwt_t *bwt, bwtint_t k, bwtint_t l, bwtint_t cntk[4], bwtint_t cntl[4]) { // double tm_t = realtime(); - bwt_occ4(bwt, k, cntk); - bwt_occ4(bwt, l, cntl); - return; + // bwt_occ4(bwt, k, cntk); + // bwt_occ4(bwt, l, cntl); + // return; bwtint_t _k, _l; _k = k - (k >= bwt->primary); _l = l - (l >= bwt->primary); diff --git a/fmt_index.cpp b/fmt_index.cpp index 57031a2..3cb1def 100644 --- a/fmt_index.cpp +++ b/fmt_index.cpp @@ -365,8 +365,8 @@ void fmt_e2_occ(const FMTIndex *fmt, bwtint_t k, int b1, int b2, bwtint_t cnt[4] { uint32_t x = 0; uint32_t *p, *q, tmp; - bwtint_t bwt_k_line = k, bwt_k_base_line = k >> FMT_OCC_INTV_SHIFT << FMT_OCC_INTV_SHIFT; - int i, ti; + bwtint_t str_line = k, cp_line = k & (~FMT_OCC_INTV_MASK); + int i, ti = b1 << 2 | b2; cnt[0] = 0; cnt[1] = 0; cnt[2] = 0; @@ -377,10 +377,8 @@ void fmt_e2_occ(const FMTIndex *fmt, bwtint_t k, int b1, int b2, bwtint_t cnt[4] cnt[3] = p[b2]; return; } - ti = b1 << 2 | b2; k -= (k >= fmt->primary); // k由bwt矩阵对应的行转换成bwt字符串对应的行(去掉了$,所以大于$的行,都减掉1) p = fmt_occ_intv(fmt, k); - //cout << "k-base: " << k << "; occ: " << p[0] << ' ' << p[1] << ' ' << p[2] << ' ' << p[3] << endl; for (i = b1 + 1; i < 4; ++i) cnt[0] += p[i]; // 大于b1的碱基的occ之和 cnt[1] = p[b1]; // b1的occ q = p + 4 + b1 * 4; @@ -388,31 +386,10 @@ void fmt_e2_occ(const FMTIndex *fmt, bwtint_t k, int b1, int b2, bwtint_t cnt[4] cnt[3] = q[b2]; // b2的occ p += 20; - //if (k == 3965453116 || k == 3965453672 || k == 2668688087 || k == 2668688550) { - // cout << "sec base-occ: " << q[0] << ' ' << q[1] << ' ' << q[2] << ' ' << q[3] << endl; - //} #ifdef FMT_MID_INTERVAL // 使用mid interval信息 int mk = k % FMT_OCC_INTERVAL; int n_mintv = mk >> FMT_MID_INTV_SHIFT; - //if (k == 3137454504) - //{ - // for (i = 0; i < n_mintv; ++i) - // { - // q = p + i * 6; - // print_base_uint32(*q); - // print_base_uint32(*(q + 1)); - // x = *(q + 2); - // cout << ((x) >> 24 & 0xff) << ' ' << ((x) >> 16 & 0xff) << ' ' << ((x) >> 8 & 0xff) << ' ' << ((x) & 0xff) << ' ' << __fmt_mid_sum(x) << endl; - // x = *(q + 3); - // cout << ((x) >> 24 & 0xff) << ' ' << ((x) >> 16 & 0xff) << ' ' << ((x) >> 8 & 0xff) << ' ' << ((x) & 0xff) << ' ' << __fmt_mid_sum(x) << endl; - // x = *(q + 4); - // cout << ((x) >> 24 & 0xff) << ' ' << ((x) >> 16 & 0xff) << ' ' << ((x) >> 8 & 0xff) << ' ' << ((x) & 0xff) << ' ' << __fmt_mid_sum(x) << endl; - // x = *(q + 5); - // cout << ((x) >> 24 & 0xff) << ' ' << ((x) >> 16 & 0xff) << ' ' << ((x) >> 8 & 0xff) << ' ' << ((x) & 0xff) << ' ' << __fmt_mid_sum(x) << endl; - // } - // x = 0; - //} if (n_mintv > 0) // 至少超过了第一个mid interval { p += n_mintv * (4 + (FMT_MID_INTERVAL >> 3)) - 4; // 对应的mid interval check point的首地址,即A C G T的局部累积量 @@ -421,80 +398,163 @@ void fmt_e2_occ(const FMTIndex *fmt, bwtint_t k, int b1, int b2, bwtint_t cnt[4] x += p[i]; // 大于b1的碱基的occ之和 cnt[0] += __fmt_mid_sum(x); x = *q; - // if (k == 3137454504) - // { - // cout << ((x) >> 24 & 0xff) << '\t' << ((x) >> 16 & 0xff) - // << '\t' << ((x) >> 8 & 0xff) << '\t' << ((x) & 0xff) << endl; - // cout << __fmt_mid_sum(x) << endl; - // } cnt[1] += __fmt_mid_sum(x); // b1的occ for (i = 3; i > b2; --i) cnt[2] += x >> (i << 3) & 0xff; // 大于b2的occ之和 - cnt[3] += x >> (b2 << 3) & 0xff; // b2的occ + cnt[3] += x >> (b2 << 3) & 0xff; // b2的occ x = 0; p += 4; } - // cout << "mid-occ: " << cnt[0] << ' ' << cnt[1] << ' ' << cnt[2] << ' ' << cnt[3] << endl; - // cout << "k: " << k << ' ' << cnt[1] << endl; #if FMT_MID_INTERVAL == 16 if ((mk & FMT_MID_INTV_MASK) >> 3) { x += __fmt_occ_e2_aux2(fmt, ti, *p); - // print_base_uint32(*p); ++p; } -#endif - tmp = *p & ~((1U << ((~k & 7) << 2)) - 1); - x += __fmt_occ_e2_aux2(fmt, ti, tmp); - //if (k == 3965453116 || k == 3965453672 || k == 2668688087 || k == 2668688550) - //{ - // print_base_uint32(tmp); - // cout << "多的: " << (~k & 7) << endl; - // cout << (x >> 24 & 0xff) << ' ' << (x >> 16 & 0xff) << ' ' << (x >> 8 & 0xff) << ' ' << (x & 0xff) << endl; - //} -#else // 该地址是bwt和pre_bwt字符串数据的首地址 - uint32_t *end = p + ((k >> 3) - ((k & ~FMT_OCC_INTV_MASK) >> 3)); // this is the end point of the following loop - // p = end - (end - p) / 8; - // cout << "k - kbase: " << k - bwt_k_base_line << endl; +#elif FMT_MID_INTERVAL > 16 // 该地址是bwt和pre_bwt字符串数据的首地址 + uint32_t *end = p + ((k >> 3) - ((k & ~FMT_OCC_INTV_MASK) >> 3)); for (; p < end; ++p) { x += __fmt_occ_e2_aux2(fmt, ti, *p); - // print_base_uint32(*p); } tmp = *p & ~((1U << ((~k & 7) << 2)) - 1); - // print_base_uint32(tmp); x += __fmt_occ_e2_aux2(fmt, ti, tmp); #endif - +#else + uint32_t *end = p + ((k >> 3) - ((k & ~FMT_OCC_INTV_MASK) >> 3)); + for (; p < end; ++p) + { + x += __fmt_occ_e2_aux2(fmt, ti, *p); + } + tmp = *p & ~((1U << ((~k & 7) << 2)) - 1); + x += __fmt_occ_e2_aux2(fmt, ti, tmp); +#endif + tmp = *p & ~((1U << ((~k & 7) << 2)) - 1); + x += __fmt_occ_e2_aux2(fmt, ti, tmp); + if (b1 == 0) { x -= (~k & 7) << 8; if (b2 == 0) x -= (~k & 7) << 24; - // cout << "here-1" << endl; } // 如果跨过了second_primary,那么可能需要减掉一次累积值 - if (b1 == fmt->first_base && bwt_k_base_line < fmt->sec_primary && bwt_k_line >= fmt->sec_primary) + if (b1 == fmt->first_base && cp_line < fmt->sec_primary && str_line >= fmt->sec_primary) { if (b2 < fmt->last_base) - cnt[2] -= 1 ; + cnt[2] -= 1; else if (b2 == fmt->last_base) cnt[3] -= 1; - // cout << "here-2" << endl; - } -// if (k == 3965453116 || k == 3965453672 || k == 2668688087 || k == 2668688550) -// { -// cout << BASE[b1] << BASE[b2] << endl; -// cout << "final-x: " << (x >> 24 & 0xff) << ' ' << (x >> 16 & 0xff) << ' ' << (x >> 8 & 0xff) << ' ' << (x & 0xff) << endl; -// } - cnt[0] += x & 0xff; cnt[1] += x >> 8 & 0xff; cnt[2] += x >> 16 & 0xff; cnt[3] += x >> 24 & 0xff; - // if (k == 3965453116 || k == 3965453672 || k == 2668688087 || k == 2668688550) - // cout << "final-occ: " << cnt[0] << ' ' << cnt[1] << ' ' << cnt[2] << ' ' << cnt[3] << endl; +} + +void calc_intv_occ(const FMTIndex *fmt, bwtint_t k, bwtint_t str_line, bwtint_t cp_line, + uint32_t *p, int b1, int b2, bwtint_t cnt[4]) +{ + + uint32_t x = 0; + uint32_t *q, tmp; + int i, ti = b1 << 2 | b2; +#ifdef FMT_MID_INTERVAL + // 使用mid interval信息 + int mk = k % FMT_OCC_INTERVAL; + int n_mintv = mk >> FMT_MID_INTV_SHIFT; + if (n_mintv > 0) // 至少超过了第一个mid interval + { + p += n_mintv * (4 + (FMT_MID_INTERVAL >> 3)) - 4; // 对应的mid interval check point的首地址,即A C G T的局部累积量 + q = p + b1; + for (i = b1 + 1; i < 4; ++i) + x += p[i]; // 大于b1的碱基的occ之和 + cnt[0] += __fmt_mid_sum(x); + x = *q; + cnt[1] += __fmt_mid_sum(x); // b1的occ + for (i = 3; i > b2; --i) + cnt[2] += x >> (i << 3) & 0xff; // 大于b2的occ之和 + cnt[3] += x >> (b2 << 3) & 0xff; // b2的occ + x = 0; + p += 4; + } +#if FMT_MID_INTERVAL == 16 + if ((mk & FMT_MID_INTV_MASK) >> 3) + { + x += __fmt_occ_e2_aux2(fmt, ti, *p); + ++p; + } +#elif FMT_MID_INTERVAL > 16 // 该地址是bwt和pre_bwt字符串数据的首地址 + uint32_t *end = p + ((k >> 3) - ((k & ~FMT_OCC_INTV_MASK) >> 3)); + for (; p < end; ++p) + { + x += __fmt_occ_e2_aux2(fmt, ti, *p); + } + tmp = *p & ~((1U << ((~k & 7) << 2)) - 1); + x += __fmt_occ_e2_aux2(fmt, ti, tmp); +#endif +#else + uint32_t *end = p + ((k >> 3) - ((k & ~FMT_OCC_INTV_MASK) >> 3)); + for (; p < end; ++p) + { + x += __fmt_occ_e2_aux2(fmt, ti, *p); + } + tmp = *p & ~((1U << ((~k & 7) << 2)) - 1); + x += __fmt_occ_e2_aux2(fmt, ti, tmp); +#endif + tmp = *p & ~((1U << ((~k & 7) << 2)) - 1); + x += __fmt_occ_e2_aux2(fmt, ti, tmp); + + if (b1 == 0) + { + x -= (~k & 7) << 8; + if (b2 == 0) + x -= (~k & 7) << 24; + } + // 如果跨过了second_primary,那么可能需要减掉一次累积值 + if (b1 == fmt->first_base && cp_line < fmt->sec_primary && str_line >= fmt->sec_primary) + { + if (b2 < fmt->last_base) + cnt[2] -= 1; + else if (b2 == fmt->last_base) + cnt[3] -= 1; + } + cnt[0] += x & 0xff; + cnt[1] += x >> 8 & 0xff; + cnt[2] += x >> 16 & 0xff; + cnt[3] += x >> 24 & 0xff; +} + +void fmt_e2_2occ(const FMTIndex *fmt, bwtint_t k, bwtint_t l, int b1, int b2, bwtint_t tk[4], bwtint_t tl[4]) +{ + bwtint_t _k, _l; + _k = k - (k >= fmt->primary); + _l = l - (l >= fmt->primary); + if (_l >> FMT_OCC_INTV_SHIFT != _k >> FMT_OCC_INTV_SHIFT || k == (bwtint_t)(-1) || l == (bwtint_t)(-1)) + { + fmt_e2_occ(fmt, k, b1, b2, tk); + fmt_e2_occ(fmt, l, b1, b2, tl); + } + else + { + uint32_t *p, *q; + bwtint_t cp_line = _k >> FMT_OCC_INTV_SHIFT << FMT_OCC_INTV_SHIFT; + int i; + tk[0] = 0; + tk[2] = 0; + p = fmt_occ_intv(fmt, _k); + for (i = b1 + 1; i < 4; ++i) + tk[0] += p[i]; // 大于b1的碱基的occ之和 + tk[1] = p[b1]; // b1的occ + q = p + 4 + b1 * 4; + for (i = b2 + 1; i < 4; ++i) + tk[2] += q[i]; // 大于b2的occ之和 + tk[3] = q[b2]; // b2的occ + p += 20; + memcpy(tl, tk, 32); + calc_intv_occ(fmt, _k, k, cp_line, p, b1, b2, tk); + calc_intv_occ(fmt, _l, l, cp_line, p, b1, b2, tl); + } } // 扩展两个碱基 @@ -502,8 +562,10 @@ void fmt_extend2(const FMTIndex *fmt, bwtintv_t *ik, bwtintv_t *ok, int is_back, { bwtint_t tk[4], tl[4], first_pos; // tk表示在k行之前所有各个碱基累积出现次数,tl表示在l行之前的累积 - fmt_e2_occ(fmt, ik->x[!is_back] - 1, b1, b2, tk); - fmt_e2_occ(fmt, ik->x[!is_back] - 1 + ik->x[2], b1, b2, tl); + //fmt_e2_occ(fmt, ik->x[!is_back] - 1, b1, b2, tk); + //fmt_e2_occ(fmt, ik->x[!is_back] - 1 + ik->x[2], b1, b2, tl); + fmt_e2_2occ(fmt, ik->x[!is_back] - 1, ik->x[!is_back] - 1 + ik->x[2], b1, b2, tk, tl); + //fmt_e2_occ_2way(fmt, ik->x[!is_back] - 1, b1, b2, tk); //fmt_e2_occ_2way(fmt, ik->x[!is_back] - 1 + ik->x[2], b1, b2, tl); // 这里是反向扩展 @@ -525,8 +587,9 @@ void fmt_extend1(const FMTIndex *fmt, bwtintv_t *ik, bwtintv_t *ok, int is_back, bwtint_t tk[4], tl[4]; int b2 = 3; // tk表示在k行之前所有各个碱基累积出现次数,tl表示在l行之前的累积 - fmt_e2_occ(fmt, ik->x[!is_back] - 1, b1, b2, tk); - fmt_e2_occ(fmt, ik->x[!is_back] - 1 + ik->x[2], b1, b2, tl); + //fmt_e2_occ(fmt, ik->x[!is_back] - 1, b1, b2, tk); + //fmt_e2_occ(fmt, ik->x[!is_back] - 1 + ik->x[2], b1, b2, tl); + fmt_e2_2occ(fmt, ik->x[!is_back] - 1, ik->x[!is_back] - 1 + ik->x[2], b1, b2, tk, tl); //fmt_e2_occ_2way(fmt, ik->x[!is_back] - 1, b1, b2, tk); //fmt_e2_occ_2way(fmt, ik->x[!is_back] - 1 + ik->x[2], b1, b2, tl); // 这里是反向扩展 @@ -654,7 +717,7 @@ int main_fmtidx(int argc, char **argv) t1 = realtime(); for (int i = 0; i < (int)seed_arr.size(); ++i) - seed_arr[i] = generate_rand_seq(7); + seed_arr[i] = generate_rand_seq(29); t1 = realtime() - t1; cout << "[time gen seed:] " << t1 << "s" << endl;