bwa_perf/fmt_index.cpp

979 lines
32 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include <iostream>
#include <stdint.h>
#include <stdlib.h>
#include <vector>
#include <sys/time.h>
#include <string>
#include <stdio.h>
#include <algorithm>
#include <string.h>
#include "common.h"
using namespace std;
///* For general OCC_INTERVAL, the following is correct:
// bwt_bwt(b, k): the 32-bit word of the interleaved array (b)->bwt that holds position k.
// The index arithmetic treats every OCC_INTERVAL block as sizeof(bwtint_t)/4*4 counter
// words followed by OCC_INTERVAL/16 data words, and skips the counters of k's own block.
#define bwt_bwt(b, k) ((b)->bwt[(k) / OCC_INTERVAL * (OCC_INTERVAL / (sizeof(uint32_t) * 8 / 2) + sizeof(bwtint_t) / 4 * 4) + sizeof(bwtint_t) / 4 * 4 + (k) % OCC_INTERVAL / 16])
// bwt_occ_intv(b, k): pointer to the first word (the counters) of the block containing position k.
#define bwt_occ_intv(b, k) ((b)->bwt + (k) / OCC_INTERVAL * (OCC_INTERVAL / (sizeof(uint32_t) * 8 / 2) + sizeof(bwtint_t) / 4 * 4))
//*/
// The following two lines are ONLY correct when OCC_INTERVAL==0x80
// #define bwt_bwt(b, k) ((b)->bwt[((k) >> 7 << 4) + sizeof(bwtint_t) + (((k) & 0x7f) >> 4)])
// #define bwt_occ_intv(b, k) ((b)->bwt + ((k) >> 7 << 4))
/* retrieve a character from the $-removed BWT string. Note that
* bwt_t::bwt is not exactly the BWT string and therefore this macro is
* called bwt_B0 instead of bwt_B */
// The 2-bit code is extracted from the word returned by bwt_bwt; position k & 0xf == 0
// sits in the two most significant bits of the word.
#define bwt_B0(b, k) (bwt_bwt(b, k) >> ((~(k) & 0xf) << 1) & 3)
// Initialise interval ik for the single base c: x[0] = L2[c] + 1, x[2] = L2[c + 1] - L2[c],
// x[1] = L2[3 - c] + 1 and info = 0. c must be in 0..3, otherwise L2 is indexed out of range.
#define bwt_set_intv(bwt, c, ik) ((ik).x[0] = (bwt)->L2[(int)(c)] + 1, (ik).x[2] = (bwt)->L2[(int)(c) + 1] - (bwt)->L2[(int)(c)], (ik).x[1] = (bwt)->L2[3 - (c)] + 1, (ik).info = 0)
// Sum of the (bwt)->cnt_table entries of the four bytes of the 32-bit value b.
// b is evaluated four times, so it must be free of side effects.
#define __occ_aux4(bwt, b) \
((bwt)->cnt_table[(b) & 0xff] + (bwt)->cnt_table[(b) >> 8 & 0xff] + (bwt)->cnt_table[(b) >> 16 & 0xff] + (bwt)->cnt_table[(b) >> 24])
// Letter for each 2-bit base code (0..3).
const char BASE[4] = {'A', 'C', 'G', 'T'};
// Convert a base character to its 2-bit code: A=0, C=1, G=2, T=3.
// Any other character (including lower-case letters) yields 4.
inline int bval(char b)
{
    switch (b)
    {
    case 'A':
        return 0;
    case 'C':
        return 1;
    case 'G':
        return 2;
    case 'T':
        return 3;
    default:
        return 4;
    }
}
// 2-bit code of the complementary base (A<->T, C<->G), computed as 3 - bval(b).
// For a character that is not A/C/G/T the result is -1 (3 - 4).
inline int cbval(char b)
{
    const int code = bval(b);
    return 3 - code;
}
// Bi-directional interval used by the search/extension routines in this file.
struct bwtintv_t
{
bwtint_t x[3], info; // x[0]: position on the forward strand, x[1]: position on the complementary strand, x[2]: interval size, info: start/end position within the read
};
// The original FM-index structure
struct bwt_t
{
bwtint_t primary; // S^{-1}(0), or the primary index of BWT
bwtint_t L2[5]; // C(), cumulative count
bwtint_t seq_len; // sequence length
bwtint_t bwt_size; // size of bwt, about seq_len/4
uint32_t *bwt; // BWT
// occurance array, separated to two parts
uint32_t cnt_table[256]; // filled by bwt_gen_cnt_table(): per-byte base counts, one 8-bit lane per base
// suffix array
int sa_intv;
bwtint_t n_sa;
uint8_t *sa;
};
// fm-index, twice calc in one memory access
struct FMTIndex {
bwtint_t primary; // S^{-1}(0), or the primary index of BWT
bwtint_t sec_primary; // second primary line
bwtint_t L2[5]; // C(), cumulative count
bwtint_t seq_len; // sequence length
bwtint_t bwt_size; // size of bwt, about seq_len/4
uint32_t *bwt; // BWT
// occurance array, separated to two parts
uint32_t cnt_table[5][256]; // [4] corresponds to the original cnt_table; [0],[1],[2],[3] are the tables used when extending with that base
int sec_bcp; // base couple for sec primary line, AA=>0, AC=>1 ... TT=>15
int first_base; // first base of the sequence as a 2-bit code stored in an int: 0,1,2,3
int last_base; // the base that the dollar sign is converted to
// suffix array
int sa_intv;
bwtint_t n_sa;
uint8_t *sa;
};
// Report a fatal error and terminate the process.
// Writes "[func] msg Abort!" to stderr and then calls abort(); it never returns.
void _err_fatal_simple_core(const char *func, const char *msg)
{
fprintf(stderr, "[%s] %s Abort!\n", func, msg);
abort();
}
// Read up to `size` bytes from fp into a, issuing fread calls of at most 16 MiB each.
// Mac/Darwin has a bug when reading data longer than 2GB in one call; chunking avoids it.
// Returns the number of bytes actually read, which is smaller than `size` when fread
// returns 0 (end of file or error) before everything was read.
static bwtint_t fread_fix(FILE *fp, bwtint_t size, void *a)
{
    const int chunk_max = 0x1000000; // 16M block
    uint8_t *dst = (uint8_t *)a;
    bwtint_t done = 0;
    for (bwtint_t left = size; left != 0;)
    {
        int want = chunk_max < left ? chunk_max : left;
        int got = fread(dst + done, 1, want, fp);
        if (got == 0)
            break;
        left -= got;
        done += got;
    }
    return done;
}
// Return the reverse complement of seq (A<->T, C<->G, order reversed).
// Characters other than A/C/G/T are emitted as the placeholder character '0'.
// seq itself is not modified.
std::string calc_reverse_seq(std::string &seq)
{
    const size_t len = seq.size();
    std::string rseq(len, '0');
    for (size_t src = 0; src < len; ++src)
    {
        char comp;
        switch (seq[src])
        {
        case 'A':
            comp = 'T';
            break;
        case 'C':
            comp = 'G';
            break;
        case 'G':
            comp = 'C';
            break;
        case 'T':
            comp = 'A';
            break;
        default:
            comp = '0'; // unknown base keeps the placeholder
            break;
        }
        // write straight into the mirrored slot instead of reversing afterwards
        rseq[len - 1 - src] = comp;
    }
    return rseq;
}
// Count how many of the 32 two-bit bases packed in y equal the 2-bit code c.
static inline int __occ_aux(uint64_t y, int c)
{
    // reduce nucleotide counting to bits counting: after this step the low bit of
    // every 2-bit cell is set exactly when that cell equals c
    const uint64_t high_ok = (c & 2) ? y : ~y;
    const uint64_t low_ok = (c & 1) ? y : ~y;
    uint64_t hits = (high_ok >> 1) & low_ok & 0x5555555555555555ull;
    // count the number of 1s: fold into 4-bit sums, then 8-bit sums, then add all bytes
    hits = (hits & 0x3333333333333333ull) + ((hits >> 2) & 0x3333333333333333ull);
    hits = (hits + (hits >> 4)) & 0x0f0f0f0f0f0f0f0full;
    return (hits * 0x0101010101010101ull) >> 56;
}
// Cumulative count of base c in the rows up to and including row k.
// Only computes correctly when the interval is greater than or equal to 32.
// Special cases: k == seq_len returns the total count of c (L2[c + 1] - L2[c]);
// k == (bwtint_t)(-1) returns 0.
bwtint_t bwt_occ(const bwt_t *bwt, bwtint_t k, uint8_t c)
{
bwtint_t n;
uint32_t *p, *end;
if (k == bwt->seq_len)
return bwt->L2[c + 1] - bwt->L2[c];
if (k == (bwtint_t)(-1))
return 0;
k -= (k >= bwt->primary); // because $ is not in bwt
// retrieve Occ at k/OCC_INTERVAL
n = ((bwtint_t *)(p = bwt_occ_intv(bwt, k)))[c];
// cout << "bwt_occ - 1: " << (int)c << '\t' << k << '\t' << n << endl;
p += sizeof(bwtint_t); // jump to the start of the first BWT cell
// calculate Occ up to the last k/32
end = p + (((k >> 5) - ((k & ~OCC_INTV_MASK) >> 5)) << 1);
// each step consumes two 32-bit words, i.e. 32 bases, as one 64-bit value
for (; p < end; p += 2)
n += __occ_aux((uint64_t)p[0] << 32 | p[1], c);
//cout << "bwt_occ - 2: " << (int)c << '\t' << k << '\t' << n << endl;
// calculate Occ
// the positions after k inside the last 64-bit value are masked to zero before counting
n += __occ_aux(((uint64_t)p[0] << 32 | p[1]) & ~((1ull << ((~k & 31) << 1)) - 1), c);
if (c == 0)
n -= ~k & 31; // corrected for the masked bits
//cout << "bwt_occ - 3: " << (int)c << '\t' << k << '\t' << n << endl;
return n;
}
// Cumulative counts of all four bases in the rows up to and including row k, written to cnt[0..3].
// The k here is a row of the BWT matrix, which has one more row than the BWT string.
// k == (bwtint_t)(-1) yields all zeros.
void bwt_occ4(const bwt_t *bwt, bwtint_t k, bwtint_t cnt[4])
{
bwtint_t x;
uint32_t *p, tmp, *end;
if (k == (bwtint_t)(-1))
{
memset(cnt, 0, 4 * sizeof(bwtint_t));
return;
}
k -= (k >= bwt->primary); // because $ is not in bwt
p = bwt_occ_intv(bwt, k);
// start from the four counters stored at the head of k's block
memcpy(cnt, p, 4 * sizeof(bwtint_t));
p += sizeof(bwtint_t); // sizeof(bwtint_t) = 4*(sizeof(bwtint_t)/sizeof(uint32_t))
end = p + ((k >> 4) - ((k & ~OCC_INTV_MASK) >> 4)); // this is the end point of the following loop
// x accumulates the four per-base counts in four 8-bit lanes
for (x = 0; p < end; ++p)
x += __occ_aux4(bwt, *p);
// positions after k in the last word are masked to zero before the table lookup
tmp = *p & ~((1U << ((~k & 15) << 1)) - 1);
x += __occ_aux4(bwt, tmp) - (~k & 15); // the masked positions were counted as A, so subtract them
cnt[0] += x & 0xff;
cnt[1] += x >> 8 & 0xff;
cnt[2] += x >> 16 & 0xff;
cnt[3] += x >> 24;
}
// Build the BWT matrix of seq and print it to stdout.
// The text is seq with '$' appended; the matrix consists of all cyclic rotations of
// that text sorted lexicographically (byte order, so '$' sorts before the letters).
// Output: a "seq size: N" line (N = seq.size() + 1) followed by one rotation per line.
// The rotations live in a std::vector rather than a variable-length array of
// std::string, which is a non-standard extension and is limited by the stack size.
void create_bwt_mtx(std::string &seq)
{
    std::cout << "seq size: " << seq.size() + 1 << std::endl;
    const std::string text = seq + '$';
    std::vector<std::string> rows;
    rows.reserve(text.size());
    rows.push_back(text);
    for (size_t shift = 1; shift < text.size(); ++shift)
        rows.push_back(text.substr(shift) + text.substr(0, shift));
    std::sort(rows.begin(), rows.end());
    // bwt matrix
    for (size_t r = 0; r < rows.size(); ++r)
        std::cout << rows[r] << std::endl;
}
// Fill bwt->cnt_table: for every byte value (four packed 2-bit bases) store how often
// each base occurs in it. The count of base j occupies bits [8*j, 8*j+7] of the entry;
// a count is at most 4, so the lanes never overflow into each other.
void bwt_gen_cnt_table(bwt_t *bwt)
{
    for (int byte = 0; byte < 256; ++byte)
    {
        uint32_t packed = 0;
        for (int pos = 0; pos < 4; ++pos)
        {
            const int base = byte >> (pos * 2) & 3;
            packed += 1u << (base << 3); // bump the lane of the base found at this position
        }
        bwt->cnt_table[byte] = packed;
    }
}
// Count tables of the fmt-index.
// A byte holds four 2-bit bases: counting from the right, the 1st and 3rd are bwt bases,
// the 2nd and 4th are the matching pre-bwt bases (each pre-bwt base sits directly above
// its bwt base). Every table entry packs four counts, one 8-bit lane per base
// (from the most significant lane down: T, G, C, A).
//   cnt_table[4][byte]: number of bwt bases in the byte, per base.
//   cnt_table[k][byte]: number of pre-bwt bases, per base, restricted to the pairs whose
//                       bwt base equals k.
void fmt_gen_cnt_table(FMTIndex *fmt)
{
    for (int byte = 0; byte != 256; ++byte) // every possible byte value
    {
        const int bwt_lo = byte & 3;
        const int pre_lo = byte >> 2 & 3;
        const int bwt_hi = byte >> 4 & 3;
        const int pre_hi = byte >> 6 & 3;
        fmt->cnt_table[4][byte] = (1u << (bwt_lo << 3)) + (1u << (bwt_hi << 3));
        for (int k = 0; k < 4; ++k) // bwt base
        {
            uint32_t packed = 0; // for [A,C,G,T][A,C,G,T]
            if (bwt_hi == k)
                packed += 1u << (pre_hi << 3);
            if (bwt_lo == k)
                packed += 1u << (pre_lo << 3);
            fmt->cnt_table[k][byte] = packed;
        }
    }
}
// Print the sixteen 2-bit bases packed in p as eight lines of two letters each,
// starting with the two most significant pairs.
void print_base_uint32(uint32_t p)
{
    for (int pair = 0; pair < 8; ++pair)
    {
        const int hi_shift = 30 - 4 * pair;
        const char first = BASE[p >> hi_shift & 3];
        const char second = BASE[p >> (hi_shift - 2) & 3];
        cout << first << second << endl;
    }
}
// Load a raw 2-bit-packed BWT file.
// File layout as read here: primary (one bwtint_t), L2[1..4] (four bwtint_t), then the
// packed BWT string. bwt_size is derived from the file size, in 32-bit words.
// Returns a newly calloc'ed bwt_t whose cnt_table is already generated; the caller owns
// it and its bwt buffer. Any I/O or allocation failure aborts through
// _err_fatal_simple_core() instead of dereferencing a null pointer or using
// partially read data.
bwt_t *restore_bwt_str(const char *fn)
{
    FILE *fp = fopen(fn, "rb");
    if (fp == NULL)
        _err_fatal_simple_core("restore_bwt_str", "fail to open the BWT file.");
    if (fseek(fp, 0, SEEK_END) != 0)
        _err_fatal_simple_core("restore_bwt_str", "fail to seek in the BWT file.");
    const long file_size = ftell(fp);
    const long header_size = (long)(sizeof(bwtint_t) * 5);
    // also catches ftell() failure (-1); without this the size computation would wrap around
    if (file_size < header_size)
        _err_fatal_simple_core("restore_bwt_str", "the BWT file is truncated.");

    bwt_t *bwt = (bwt_t *)calloc(1, sizeof(bwt_t));
    if (bwt == NULL)
        _err_fatal_simple_core("restore_bwt_str", "out of memory.");
    bwt->bwt_size = (file_size - header_size) >> 2; // size counted in 32-bit words
    bwt->bwt = (uint32_t *)calloc(bwt->bwt_size, 4);
    if (bwt->bwt == NULL && bwt->bwt_size != 0)
        _err_fatal_simple_core("restore_bwt_str", "out of memory.");

    if (fseek(fp, 0, SEEK_SET) != 0)
        _err_fatal_simple_core("restore_bwt_str", "fail to seek in the BWT file.");
    if (fread(&bwt->primary, sizeof(bwtint_t), 1, fp) != 1 ||
        fread(bwt->L2 + 1, sizeof(bwtint_t), 4, fp) != 4)
        _err_fatal_simple_core("restore_bwt_str", "fail to read the BWT header.");
    if (fread_fix(fp, bwt->bwt_size << 2, bwt->bwt) != bwt->bwt_size << 2)
        _err_fatal_simple_core("restore_bwt_str", "fail to read the BWT string.");
    bwt->seq_len = bwt->L2[4];
    fclose(fp);
    bwt_gen_cnt_table(bwt); // per-byte cumulative counts of each base
    return bwt;
}
// Convert the plain packed BWT string in bwt->bwt into the interleaved layout:
// in front of every OCC_INTERVAL bases the running counts of A/C/G/T are stored as
// four bwtint_t values, and one final set of counts follows the last base.
// bwt->bwt_size is enlarged accordingly, the old buffer is freed and replaced.
void create_interval_occ_bwt(bwt_t *bwt)
{
    const bwtint_t n_occ = (bwt->seq_len + OCC_INTERVAL - 1) / OCC_INTERVAL + 1;
    bwt->bwt_size += n_occ * sizeof(bwtint_t); // the new size
    uint32_t *buf = (uint32_t *)calloc(bwt->bwt_size, 4); // will be the new bwt
    bwtint_t c[4] = {0, 0, 0, 0}; // running occurrence counts
    bwtint_t k = 0;               // write offset into buf, in 32-bit words
    bwtint_t i = 0;
    while (i < bwt->seq_len)
    {
        if (i % OCC_INTERVAL == 0)
        {
            memcpy(buf + k, c, sizeof(bwtint_t) * 4);
            k += sizeof(bwtint_t); // in fact: sizeof(bwtint_t)=4*(sizeof(bwtint_t)/4), the number of 32-bit words the four counters occupy
        }
        if (i % 16 == 0)
            buf[k++] = bwt->bwt[i / 16]; // 16 == sizeof(uint32_t)/2, two bits encode one base
        ++c[bwt_B00(bwt, i)];
        ++i;
    }
    // the last element
    memcpy(buf + k, c, sizeof(bwtint_t) * 4);
    xassert(k + sizeof(bwtint_t) == bwt->bwt_size, "inconsistent bwt_size");
    // update bwt
    free(bwt->bwt);
    bwt->bwt = buf;
}
// Interleave eight consecutive 2-bit bases of pre_word and bwt_word into one 32-bit word.
// The bases are taken from bit positions top_shift, top_shift - 2, ..., top_shift - 14;
// each pre-bwt base is stored directly above its bwt base, most significant pair first.
static uint32_t fmt_interleave8(uint32_t pre_word, uint32_t bwt_word, int top_shift)
{
    uint32_t packed = 0;
    for (int t = 0; t < 8; ++t)
    {
        const int src = top_shift - 2 * t;
        packed |= ((pre_word >> src) & 3u) << (30 - 4 * t);
        packed |= ((bwt_word >> src) & 3u) << (28 - 4 * t);
    }
    return packed;
}

// Build the fmt-index from an interval-occ BWT (see create_interval_occ_bwt()).
// Layout of the produced fmt->bwt, per OCC_INTERVAL block: 4 uint32 counts of A/C/G/T,
// 16 uint32 second-level counts (AA..TT), then the data words, each holding 8
// (pre-bwt, bwt) base pairs. A final group of 20 counters follows the last data word.
// Also records sec_primary / sec_bcp / first_base / last_base for the row whose
// predecessor row is the primary ($) row.
// Returns a calloc'ed FMTIndex owned by the caller; aborts when memory runs out.
FMTIndex *create_fmt_from_bwt(bwt_t *bwt)
{
    // Target of the (currently disabled) debug dump. It is still opened so that fmt.txt
    // is created/truncated exactly as before, and it is closed before returning.
    FILE *fmt_out = fopen("fmt.txt", "w");
    FMTIndex *fmt = (FMTIndex *)calloc(1, sizeof(FMTIndex));
    if (fmt == NULL)
        _err_fatal_simple_core("create_fmt_from_bwt", "out of memory.");
    fmt_gen_cnt_table(fmt);
    bwtint_t i, j, k, m, n, n_occ, cnt[4], cnt2[4];
    uint32_t c[4], c2[16] /* holds AA..TT */;
    uint32_t *buf;
    fmt->seq_len = bwt->seq_len;
    for (i = 0; i < 5; ++i)
        fmt->L2[i] = bwt->L2[i];
    fmt->primary = bwt->primary;
    n_occ = (bwt->seq_len + OCC_INTERVAL - 1) / OCC_INTERVAL + 1; // number of check points
    fmt->bwt_size = (fmt->seq_len * 2 + 15) >> 4;                 // the last two columns of bases are stored
    fmt->bwt_size += n_occ * 20;                                  // A,C,G,T and AA,AC,...,TG,TT: 20 counters
    buf = (uint32_t *)calloc(fmt->bwt_size, 4);                   // buffer the fmt is built in
    if (buf == NULL && fmt->bwt_size != 0)
        _err_fatal_simple_core("create_fmt_from_bwt", "out of memory.");
    c[0] = c[1] = c[2] = c[3] = 0;
    // For the first row, c2 holds for each of A,C,G,T the occ at the row just before
    // that base's own first row.
    for (i = 0; i < 4; ++i)
    {
        bwt_occ4(bwt, fmt->L2[i], cnt);
        for (m = 0; m < 4; ++m)
            c2[i * 4 + m] = cnt[m];
    }
    // k is the write offset into buf
    for (i = k = 0; i < bwt->seq_len; ++i)
    {
        // record the occ check point
        if (i % OCC_INTERVAL == 0)
        {
            memcpy(buf + k, c, sizeof(uint32_t) * 4); // first-level occ
            k += 4;
            memcpy(buf + k, c2, sizeof(uint32_t) * 16); // occ of the second extension
            k += 16;
        }
        // A 32-bit word of the source holds 16 bases, so 16 bases are handled at a time
        // (which is also why the interval cannot be smaller than 16).
        if (i % 16 != 0)
            continue;
        const uint32_t bwt_16_seq = bwt_bwt(bwt, i); // the 16 bwt bases starting at position i
        uint32_t pre_bwt_16_seq = 0;
        for (j = 0; j < 16; ++j)
        {
            bwtint_t cur_line = i + j;
            if (cur_line >= bwt->seq_len) // the bwt string has no $, so it is one shorter than the matrix
                break;
            const uint8_t bwt_base = bwt_B0(bwt, cur_line); // bwt base of this row
            // find the row of the first column this base maps to
            if (cur_line >= bwt->primary) // rows past the $ row are shifted by one in the matrix
                cur_line += 1;
            const bwtint_t origin_base_line = bwt->L2[bwt_base] + 1 + bwt_occ(bwt, cur_line - 1, bwt_base); // matrix row
            bwtint_t base_line = origin_base_line;
            if (base_line >= bwt->primary) // convert the matrix row back to a position in the $-free string
                base_line -= 1;
            const uint32_t pre_bwt_base = bwt_B0(bwt, base_line); // base preceding the bwt base
            if (origin_base_line == bwt->primary)
            {
                // compute sec_bcp
                fmt->sec_bcp = pre_bwt_base << 2 | bwt_base; // because $ is handled as A
                fmt->sec_primary = cur_line;
                fmt->first_base = bwt_base;
                fmt->last_base = pre_bwt_base;
            }
            // stash the pre-bwt base at the same slot as its bwt base
            pre_bwt_16_seq |= pre_bwt_base << (15 - j) * 2;
            // c: cumulative counts up to and including cur_line
            bwt_occ4(bwt, cur_line, cnt);
            for (m = 0; m < 4; ++m)
                c[m] = (uint32_t)cnt[m];
            // c2: only the four entries belonging to bwt_base change on this row; they are
            // the occ at the row bwt_base maps to, using the count before cur_line.
            bwtint_t m_first_line = bwt->L2[bwt_base] + 1 + (cnt[bwt_base] - 1);
            if (m_first_line >= bwt->seq_len)
                m_first_line = bwt->seq_len;
            bwt_occ4(bwt, m_first_line, cnt2);
            for (n = 0; n < 4; ++n)
                c2[bwt_base << 2 | n] = (uint32_t)cnt2[n];
        }
        // store bwt and pre_bwt interleaved: the first 8 pairs always, the remaining 8
        // only when more than 8 bases were present in this group
        buf[k++] = fmt_interleave8(pre_bwt_16_seq, bwt_16_seq, 30);
        if (j > 8)
            buf[k++] = fmt_interleave8(pre_bwt_16_seq, bwt_16_seq, 14);
    }
    // the last element
    memcpy(buf + k, c, sizeof(uint32_t) * 4);
    k += 4;
    memcpy(buf + k, c2, sizeof(uint32_t) * 16);
    k += 16;
    xassert(k == fmt->bwt_size, "inconsistent bwt_size");
    // update fmt
    fmt->bwt = buf;
    if (fmt_out != NULL)
        fclose(fmt_out);
    return fmt;
}
// an analogy to bwt_occ4() but more efficient, requiring k <= l
// Computes the cumulative base counts for rows k and l into cntk and cntl.
// When both rows fall into the same occ block the shared prefix is scanned once;
// otherwise (or when either row is (bwtint_t)(-1)) it falls back to two bwt_occ4() calls.
void bwt_2occ4(const bwt_t *bwt, bwtint_t k, bwtint_t l, bwtint_t cntk[4], bwtint_t cntl[4])
{
bwtint_t _k, _l;
_k = k - (k >= bwt->primary);
_l = l - (l >= bwt->primary);
if (_l >> OCC_INTV_SHIFT != _k >> OCC_INTV_SHIFT || k == (bwtint_t)(-1) || l == (bwtint_t)(-1))
{
bwt_occ4(bwt, k, cntk);
bwt_occ4(bwt, l, cntl);
}
else
{
bwtint_t x, y;
uint32_t *p, tmp, *endk, *endl; // note: this local endl hides std::endl inside the branch
k -= (k >= bwt->primary); // because $ is not in bwt
l -= (l >= bwt->primary);
p = bwt_occ_intv(bwt, k);
memcpy(cntk, p, 4 * sizeof(bwtint_t));
p += sizeof(bwtint_t); // sizeof(bwtint_t) = 4*(sizeof(bwtint_t)/sizeof(uint32_t))
// prepare cntk[]
endk = p + ((k >> 4) - ((k & ~OCC_INTV_MASK) >> 4));
endl = p + ((l >> 4) - ((l & ~OCC_INTV_MASK) >> 4));
for (x = 0; p < endk; ++p)
x += __occ_aux4(bwt, *p);
y = x; // l's scan continues from the words already summed for k
tmp = *p & ~((1U << ((~k & 15) << 1)) - 1);
x += __occ_aux4(bwt, tmp) - (~k & 15);
// calculate cntl[] and finalize cntk[]
for (; p < endl; ++p)
y += __occ_aux4(bwt, *p);
tmp = *p & ~((1U << ((~l & 15) << 1)) - 1);
y += __occ_aux4(bwt, tmp) - (~l & 15);
// both results start from the same block counters, copied before cntk is finalized
memcpy(cntl, cntk, 4 * sizeof(bwtint_t));
cntk[0] += x & 0xff;
cntk[1] += x >> 8 & 0xff;
cntk[2] += x >> 16 & 0xff;
cntk[3] += x >> 24;
cntl[0] += y & 0xff;
cntl[1] += y >> 8 & 0xff;
cntl[2] += y >> 16 & 0xff;
cntl[3] += y >> 24;
}
}
void bwt_extend(const bwt_t *bwt, const bwtintv_t *ik, bwtintv_t ok[4], int is_back)
{
bwtint_t tk[4], tl[4];
int i;
bwt_2occ4(bwt, ik->x[!is_back] - 1, ik->x[!is_back] - 1 + ik->x[2], tk, tl); // tk表示在k行之前所有各个碱基累积出现次数tl表示在l行之前的累积
// 这里是反向扩展
for (i = 0; i != 4; ++i)
{
ok[i].x[!is_back] = bwt->L2[i] + 1 + tk[i]; // 起始行位置,互补链
ok[i].x[2] = tl[i] - tk[i]; // 间隔
}
// 因为计算的是互补碱基所以3对应着0,2对应1下边是正向扩展
ok[3].x[is_back] = ik->x[is_back] + (ik->x[!is_back] <= bwt->primary && ik->x[!is_back] + ik->x[2] - 1 >= bwt->primary);
ok[2].x[is_back] = ok[3].x[is_back] + ok[3].x[2];
ok[1].x[is_back] = ok[2].x[is_back] + ok[2].x[2];
ok[0].x[is_back] = ok[1].x[is_back] + ok[1].x[2];
}
// Search seed q with the plain bwt, one base per step; a complete search only needs to
// run in one direction. Prints the interval (x[0], x[1], x[2]) after the first base and
// after every extension. Non-A/C/G/T characters after the first base are skipped.
// The first base must be A/C/G/T: any other value would index L2 out of range, so such
// a query (or an empty one) is rejected with a message on stderr.
void bwt_search(bwt_t *bwt, const string &q)
{
    if (q.empty() || bval(q[0]) > 3)
    {
        fprintf(stderr, "[bwt_search] query must start with A/C/G/T\n");
        return;
    }
    bwtintv_t ik, ok[4];
    bwt_set_intv(bwt, bval(q[0]), ik);
    ik.info = 1;
    cout << ik.x[0] << '\t' << ik.x[1] << '\t' << ik.x[2] << endl;
    for (size_t i = 1; i < q.size(); ++i)
    {
        if (bval(q[i]) > 3)
            continue;
        bwt_extend(bwt, &ik, ok, 0);
        ik = ok[cbval(q[i])];
        ik.info = i + 1;
        cout << "bwt-1: " << ik.x[0] << '\t' << ik.x[1] << '\t' << ik.x[2] << endl;
    }
}
// 扩展两次
void bwt_extend2(const bwt_t *bwt, bwtintv_t *ik, bwtintv_t ok[4], int is_back, int c1)
{
bwtint_t tk[4], tl[4];
int i;
bwt_2occ4(bwt, ik->x[!is_back] - 1, ik->x[!is_back] - 1 + ik->x[2], tk, tl); // tk表示在k行之前所有各个碱基累积出现次数tl表示在l行之前的累积
// 这里是反向扩展
for (i = 0; i != 4; ++i)
{
ok[i].x[!is_back] = bwt->L2[i] + 1 + tk[i]; // 起始行位置,互补链
ok[i].x[2] = tl[i] - tk[i]; // 间隔
}
// 因为计算的是互补碱基所以3对应着0,2对应1下边是正向扩展
ok[3].x[is_back] = ik->x[is_back] + (ik->x[!is_back] <= bwt->primary && ik->x[!is_back] + ik->x[2] - 1 >= bwt->primary);
ok[2].x[is_back] = ok[3].x[is_back] + ok[3].x[2];
ok[1].x[is_back] = ok[2].x[is_back] + ok[2].x[2];
ok[0].x[is_back] = ok[1].x[is_back] + ok[1].x[2];
*ik = ok[c1];
bwt_2occ4(bwt, ik->x[!is_back] - 1, ik->x[!is_back] - 1 + ik->x[2], tk, tl);
for (i = 0; i != 4; ++i)
{
ok[i].x[!is_back] = bwt->L2[i] + 1 + tk[i]; // 起始行位置,互补链
ok[i].x[2] = tl[i] - tk[i]; // 间隔
}
// 因为计算的是互补碱基所以3对应着0,2对应1下边是正向扩展
ok[3].x[is_back] = ik->x[is_back] + (ik->x[!is_back] <= bwt->primary && ik->x[!is_back] + ik->x[2] - 1 >= bwt->primary);
ok[2].x[is_back] = ok[3].x[is_back] + ok[3].x[2];
ok[1].x[is_back] = ok[2].x[is_back] + ok[2].x[2];
ok[0].x[is_back] = ok[1].x[is_back] + ok[1].x[2];
}
// Search seed q with the plain bwt, consuming two bases per step through bwt_extend2()
// and finishing with a single-base step when one base is left over. Prints the interval
// after the first base and after every step. The pair loop stops at the first pair that
// contains a non-A/C/G/T character.
// The first base must be A/C/G/T: any other value would index L2 out of range, so such
// a query (or an empty one) is rejected with a message on stderr.
void bwt_search2(bwt_t *bwt, const string &q)
{
    if (q.empty() || bval(q[0]) > 3)
    {
        fprintf(stderr, "[bwt_search2] query must start with A/C/G/T\n");
        return;
    }
    bwtintv_t ik, ok[4];
    bwt_set_intv(bwt, bval(q[0]), ik);
    ik.info = 1;
    cout << ik.x[0] << '\t' << ik.x[1] << '\t' << ik.x[2] << endl;
    size_t i = 1;
    for (; i + 1 < q.size(); i += 2)
    {
        if (bval(q[i]) > 3 || bval(q[i + 1]) > 3)
            break;
        const int c1 = cbval(q[i]);
        const int c2 = cbval(q[i + 1]);
        bwt_extend2(bwt, &ik, ok, 0, c1);
        ik = ok[c2];
        ik.info = i + 2; // bases i and i + 1 have both been consumed
        cout << "bwt-2: " << ik.x[0] << '\t' << ik.x[1] << '\t' << ik.x[2] << endl;
    }
    if (i < q.size() && bval(q[i]) < 4)
    { // the last extension
        bwt_extend(bwt, &ik, ok, 0);
        ik = ok[cbval(q[i])];
        ik.info = i + 1;
        cout << "bwt-2: " << ik.x[0] << '\t' << ik.x[1] << '\t' << ik.x[2] << endl;
    }
}
// Initialise interval ik for the single base c from (fmt)->L2: x[0] = L2[c] + 1,
// x[2] = L2[c + 1] - L2[c], x[1] = L2[3 - c] + 1 and info = 0.
// c must be in 0..3, otherwise L2 is indexed out of range.
#define fmt_set_intv(fmt, c, ik) ((ik).x[0] = (fmt)->L2[(int)(c)] + 1, (ik).x[2] = (fmt)->L2[(int)(c) + 1] - (fmt)->L2[(int)(c)], (ik).x[1] = (fmt)->L2[3 - (c)] + 1, (ik).info = 0)
// Pointer to the first word of the block containing position k. A block spans
// OCC_INTERVAL / 8 + 20 uint32 words (20 counter words followed by the data words).
#define fmt_occ_intv(b, k) ((b)->bwt + (k) / OCC_INTERVAL * (OCC_INTERVAL / 8 + 20))
// Debug helper: fetch the check-point counters of the block containing row k and print them.
// cnt1 receives the four first-level counters, cnt2 the four second-level counters of
// base b. k == (bwtint_t)(-1) zeroes both outputs and prints nothing.
// NOTE(review): the scan of the data words between the check point and k is not
// implemented here, so the outputs are check-point values only; fmt_e2_occ4() appears
// to be the complete variant.
void fmt_occ4(const FMTIndex *fmt, bwtint_t k, int b, uint32_t cnt1[4], uint32_t cnt2[4])
{
    const size_t bytes = 4 * sizeof(uint32_t);
    if (k == (bwtint_t)(-1))
    {
        memset(cnt1, 0, bytes);
        memset(cnt2, 0, bytes);
        return;
    }
    // k is a row of the bwt matrix; convert it to a position in the bwt string, which
    // has no $, so every row at or past the $ row moves down by one
    if (k >= fmt->primary)
        --k;
    const uint32_t *checkpoint = fmt_occ_intv(fmt, k);
    cout << "base: " << BASE[b] << endl;
    cout << "k: " << k << "; p: " << (uint64_t)checkpoint << endl;
    memcpy(cnt1, checkpoint, bytes);
    memcpy(cnt2, checkpoint + 4 + b * 4, bytes);
    cout << "cnt1: " << cnt1[0] << '\t' << cnt1[1] << '\t' << cnt1[2] << '\t' << cnt1[3] << endl;
    cout << "cnt2: " << cnt2[0] << '\t' << cnt2[1] << '\t' << cnt2[2] << '\t' << cnt2[3] << endl;
}
// Fetch the counters for rows k and l through fmt_occ4():
// cntk1/cntk2 receive the first-/second-level counters for row k, cntl1/cntl2 those for
// row l; b selects the base whose second-level counters are read.
// The second call used to write its second-level result into cntk1, which clobbered the
// row-k counters and left cntl2 unset; it now fills cntl2.
// NOTE(review): no same-block fast path (as in bwt_2occ4()) is implemented here.
void fmt_2occ4(const FMTIndex *fmt, bwtint_t k, bwtint_t l, int b,
               uint32_t cntk1[4], uint32_t cntl1[4], uint32_t cntk2[4], uint32_t cntl2[4])
{
    fmt_occ4(fmt, k, b, cntk1, cntk2);
    fmt_occ4(fmt, l, b, cntl1, cntl2);
}
// Sum of the (fmt)->cnt_table[b] entries of the four bytes of the 32-bit value val.
// The result packs four counters, one per 8-bit lane. Every use of b is parenthesised;
// b and val are evaluated four times each, so they must be free of side effects.
#define __fmt_occ_e2_aux4(fmt, b, val) \
    ((fmt)->cnt_table[(b)][(val) & 0xff] + (fmt)->cnt_table[(b)][(val) >> 8 & 0xff] + (fmt)->cnt_table[(b)][(val) >> 16 & 0xff] + (fmt)->cnt_table[(b)][(val) >> 24])
// Cumulative counts of the four bwt bases up to and including row k of the fmt-index,
// written to cnt[0..3]. k == (bwtint_t)(-1) yields all zeros.
// The shifts below (k >> 3, (~k & 7) << 2) treat every data word as 8 positions of 4 bits.
void fmt_e1_occ4(const FMTIndex *fmt, bwtint_t k, uint32_t cnt[4])
{
uint32_t x;
uint32_t *p, tmp, *end;
if (k == (bwtint_t)(-1))
{
memset(cnt, 0, 4 * sizeof(uint32_t));
return;
}
k -= (k >= fmt->primary); // k is converted from a row of the bwt matrix to a position in the bwt string; $ was removed, so every row at or past the $ row is reduced by 1
p = fmt_occ_intv(fmt, k);
// start from the four first-level counters stored at the head of the block
memcpy(cnt, p, 4 * sizeof(uint32_t));
p += 20; // this address is the start of the bwt and pre_bwt string data
end = p + ((k >> 3) - ((k & ~OCC_INTV_MASK) >> 3)); // this is the end point of the following loop
// x accumulates the four per-base counts in four 8-bit lanes (table index 4 is used)
for (x = 0; p < end; ++p)
{
x += __fmt_occ_e2_aux4(fmt, 4, *p);
}
// positions after k in the last word are masked to zero; the subtraction removes the
// ~k & 7 masked positions again, which the lookup counted in lane 0
tmp = *p & ~((1U << ((~k & 7) << 2)) - 1);
x += __fmt_occ_e2_aux4(fmt, 4, tmp) - (~k & 7);
cnt[0] += x & 0xff;
cnt[1] += x >> 8 & 0xff;
cnt[2] += x >> 16 & 0xff;
cnt[3] += x >> 24;
}
// Counters needed for a two-base extension at row k, obtained in one pass over the block:
//   cnt1[0..3]: cumulative counts of the four bwt bases up to and including row k;
//   cnt2[0..3]: second-level counters for base b (read from the check point at offset
//               4 + b * 4 and advanced with cnt_table[b]).
// For k == (bwtint_t)(-1), cnt1 is zeroed and cnt2 is copied from the first check point.
void fmt_e2_occ4(const FMTIndex *fmt, bwtint_t k, int b, uint32_t cnt1[4], uint32_t cnt2[4])
{
uint32_t x1, x2;
uint32_t *p, tmp, *end;
// both values are taken from k before it is adjusted for the $ row below;
// bwt_k_base_line is k with its low OCC_INTV_SHIFT bits cleared
// (presumably the first row of k's occ block - TODO confirm against OCC_INTERVAL)
bwtint_t bwt_k_line = k, bwt_k_base_line = k >> OCC_INTV_SHIFT << OCC_INTV_SHIFT;
if (k == (bwtint_t)(-1))
{
p = fmt->bwt + 4 + b * 4;
memset(cnt1, 0, 4 * sizeof(uint32_t));
memcpy(cnt2, p, 4 * sizeof(uint32_t));
return;
}
k -= (k >= fmt->primary); // k is converted from a row of the bwt matrix to a position in the bwt string; $ was removed, so every row at or past the $ row is reduced by 1
p = fmt_occ_intv(fmt, k);
// cout << "base: " << BASE[b] << endl;
// cout << "k: " << k << "; c 0 cnt: " << p[0] << '\t' << p[1] << '\t' << p[2] << '\t' << p[3] << endl;
memcpy(cnt1, p, 4 * sizeof(uint32_t));
memcpy(cnt2, p + 4 + b * 4, 4 * sizeof(uint32_t));
// cout << "[start: ] k: " << k << "; k line cnt: " << cnt[0] << '\t' << cnt[1] << '\t' << cnt[2] << '\t' << cnt[3] << endl;
p += 20; // this address is the start of the bwt and pre_bwt string data
end = p + ((k >> 3) - ((k & ~OCC_INTV_MASK) >> 3)); // this is the end point of the following loop
// x1/x2 accumulate four counts each in four 8-bit lanes
for (x1 = 0, x2 = 0; p < end; ++p)
{
x1 += __fmt_occ_e2_aux4(fmt, 4, *p);
x2 += __fmt_occ_e2_aux4(fmt, b, *p);
}
//{
// x += fmt->cnt_table[b][*p & 0xff]
// + fmt->cnt_table[b][*p >> 8 & 0xff]
// + fmt->cnt_table[b][*p >> 16 & 0xff]
// + fmt->cnt_table[b][*p >> 24 & 0xff];
// // cout << "p: " << *p << endl;
// // print_base_uint32(*p);
// // cout << (fmt->cnt_table[b][*p & 0xff] >> 24) << ' '
// // << fmt->cnt_table[b][*p >> 24 & 0xff]
// // << endl;
//}
// positions after k in the last word are masked to zero; the masked positions are
// then subtracted from lane 0 of x1, and from x2 as well when b == 0
tmp = *p & ~((1U << ((~k & 7) << 2)) - 1);
x1 += __fmt_occ_e2_aux4(fmt, 4, tmp) - (~k & 7);
x2 += __fmt_occ_e2_aux4(fmt, b, tmp);
if (b == 0)
x2 -= ~k & 7;
// if the scan crossed second_primary, one accumulated count may have to be removed
if (b == fmt->first_base && bwt_k_base_line < fmt->sec_primary && bwt_k_line >= fmt->sec_primary)
{
x2 -= 1 << (fmt->last_base << 3);
}
// x += __occ_aux4(bwt, tmp) - (~k & 15);
// cout << "x: " << x << " b:" << b << endl;
cnt1[0] += x1 & 0xff;
cnt1[1] += x1 >> 8 & 0xff;
cnt1[2] += x1 >> 16 & 0xff;
cnt1[3] += x1 >> 24;
cnt2[0] += x2 & 0xff;
cnt2[1] += x2 >> 8 & 0xff;
cnt2[2] += x2 >> 16 & 0xff;
cnt2[3] += x2 >> 24;
// cout << "[end : ]k: " << k << "; k line cnt: " << cnt[0] << '\t' << cnt[1] << '\t' << cnt[2] << '\t' << cnt[3] << endl;
}
void fmt_extend1(const FMTIndex *fmt, bwtintv_t *ik, bwtintv_t ok[4], int is_back, int b1)
{
uint32_t tk[4], tl[4];
int i;
fmt_e1_occ4(fmt, ik->x[!is_back] - 1, tk);
fmt_e1_occ4(fmt, ik->x[!is_back] - 1 + ik->x[2], tl);
for (i = 0; i != 4; ++i)
{
ok[i].x[!is_back] = fmt->L2[i] + 1 + tk[i]; // 起始行位置,互补链
ok[i].x[2] = tl[i] - tk[i]; // 间隔
}
ok[3].x[is_back] = ik->x[is_back] + (ik->x[!is_back] <= fmt->primary && ik->x[!is_back] + ik->x[2] - 1 >= fmt->primary);
for (i = 2; i >= b1; --i)
ok[i].x[is_back] = ok[i + 1].x[is_back] + ok[i + 1].x[2];
*ik = ok[b1];
}
void fmt_extend2(const FMTIndex *fmt, bwtintv_t *ik, bwtintv_t ok[4], int is_back, int b1, int b2)
{
uint32_t tk1[4], tl1[4], tk2[4], tl2[4];
int i;
// fmt_2occ4(fmt, ik->x[!is_back] - 1, ik->x[!is_back] - 1 + ik->x[2], b1, tk1, tl1, tk2, tl2); // tk表示在k行之前所有各个碱基累积出现次数tl表示在l行之前的累积
fmt_e2_occ4(fmt, ik->x[!is_back] - 1, b1, tk1, tk2);
fmt_e2_occ4(fmt, ik->x[!is_back] - 1 + ik->x[2], b1, tl1, tl2);
// fmt_e2_occ(fmt, -1, 0, tk);
// 这里是反向扩展
for (i = 0; i != 4; ++i)
{
ok[i].x[!is_back] = fmt->L2[i] + 1 + tk2[i]; // 起始行位置,互补链
ok[i].x[2] = tl2[i] - tk2[i]; // 间隔
}
// 因为计算的是互补碱基所以3对应着0,2对应1下边是正向扩展
ok[3].x[is_back] = ik->x[is_back] + (ik->x[!is_back] <= fmt->primary && ik->x[!is_back] + ik->x[2] - 1 >= fmt->primary);
for (i = 2; i >= b1; --i)
ok[i].x[is_back] = ok[i + 1].x[is_back] + tl1[i + 1] - tk1[i + 1];
ok[3].x[is_back] = ok[b1].x[is_back] + (ok[b1].x[!is_back] <= fmt->primary && ok[b1].x[!is_back] + ok[b1].x[2] - 1 >= fmt->primary);
for (i = 2; i >= b2; --i)
ok[i].x[is_back] = ok[i + 1].x[is_back] + ok[i + 1].x[2];
*ik = ok[b2];
}
// Search seed q with the fmt-index; a complete search only needs to run in one
// direction. Two bases are consumed per step through fmt_extend2(), with a final
// fmt_extend1() step when one base is left over. Prints the interval after the first
// base and after every step. The pair loop stops at the first pair that contains a
// non-A/C/G/T character.
// The first base must be A/C/G/T: any other value would index L2 out of range, so such
// a query (or an empty one) is rejected with a message on stderr.
void fmt_search(FMTIndex *fmt, const string &q)
{
    if (q.empty() || bval(q[0]) > 3)
    {
        fprintf(stderr, "[fmt_search] query must start with A/C/G/T\n");
        return;
    }
    bwtintv_t ik, ok[4];
    fmt_set_intv(fmt, bval(q[0]), ik);
    ik.info = 1;
    cout << ik.x[0] << '\t' << ik.x[1] << '\t' << ik.x[2] << endl;
    size_t i = 1;
    for (; i + 1 < q.size(); i += 2)
    {
        if (bval(q[i]) > 3 || bval(q[i + 1]) > 3)
            break;
        const int c1 = cbval(q[i]);
        const int c2 = cbval(q[i + 1]);
        fmt_extend2(fmt, &ik, ok, 0, c1, c2);
        ik.info = i + 2; // bases i and i + 1 have both been consumed
        cout << "fmt : " << ik.x[0] << '\t' << ik.x[1] << '\t' << ik.x[2] << endl;
    }
    if (i < q.size() && bval(q[i]) < 4)
    { // the last extension
        fmt_extend1(fmt, &ik, ok, 0, cbval(q[i]));
        ik.info = i + 1;
        cout << "fmt : " << ik.x[0] << '\t' << ik.x[1] << '\t' << ik.x[2] << endl;
    }
}
// Driver: load the raw BWT named by argv[1], build the interval-occ BWT and the
// fmt-index from it, then run the same fixed query through bwt_search(), bwt_search2()
// and fmt_search() so that their printed intervals can be compared.
// Returns 0 on success and 1 when the BWT file argument is missing.
int main_fmtidx(int argc, char **argv)
{
    if (argc < 2 || argv[1] == NULL)
    {
        fprintf(stderr, "Usage: %s <raw.bwt>\n", (argc > 0 && argv[0] != NULL) ? argv[0] : "fmtidx");
        return 1;
    }
    bwt_t *bwt = restore_bwt_str(argv[1]); // read the raw bwt string together with the total cumulative counts of A/C/G/T
    create_interval_occ_bwt(bwt);          // build the bwt with interval occ from the bwt string (bases + A/C/G/T cumulative counts per interval)
    cout << "L2: " << bwt->L2[0] << '\t' << bwt->L2[1] << '\t' << bwt->L2[2] << '\t'
         << bwt->L2[3] << '\t' << bwt->L2[4] << endl;
    string s = "AACCCTAA";
    bwt_search(bwt, s);
    bwt_search2(bwt, s);
    FMTIndex *fmt = create_fmt_from_bwt(bwt);
    fmt_search(fmt, s);
    cout << "sec_: " << fmt->sec_bcp << '\t' << fmt->sec_primary << endl;
    // sanity print of one count-table entry next to the expression it is built from
    uint8_t b8 = 2 << 4 | 2;
    cout << "AGAG: " << fmt->cnt_table[2][b8] << endl;
    cout << (((b8 >> 6) == 0 && (b8 >> 4 & 3) == 2) + ((b8 >> 2 & 3) == 0 && (b8 & 3) == 2)) << endl;
    // release both indexes (they were previously leaked)
    free(fmt->bwt);
    free(fmt);
    free(bwt->bwt);
    free(bwt);
    return 0;
}