bwa_perf/bwt.cpp

362 lines
12 KiB
C++
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include <stdint.h>
#include <stdlib.h>
#include <stdio.h>
#include <string.h>
#include <string>
#include <iostream>
#include "util.h"
#include "bwt.h"
using namespace std;
const static char BASE[4] = {'A', 'C', 'G', 'T'};

// Debug helper: print the 16 bases packed two bits each into p, one base per
// line. The most significant pair comes first, the same order the packed BWT
// words use (base i of a word sits at bits (15 - i) * 2).
static void print_base_uint32(uint32_t p)
{
    // i must reach 0: the lowest pair is the 16th base.
    for (int i = 30; i >= 0; i -= 2)
    {
        int b1 = p >> i & 3;
        std::cout << BASE[b1] << std::endl;
    }
}
// Save a BWT to fn in the layout restore_bwt() reads back:
//   primary (one bwtint_t), L2[1..4] (four bwtint_t),
//   then bwt_size 32-bit words of packed BWT data.
void dump_bwt(const char *fn, const bwt_t *bwt)
{
    FILE *out = xopen(fn, "wb");
    err_fwrite(&bwt->primary, sizeof(bwtint_t), 1, out);
    err_fwrite(&bwt->L2[1], sizeof(bwtint_t), 4, out);
    err_fwrite(bwt->bwt, 4, bwt->bwt_size, out);
    err_fflush(out);
    err_fclose(out);
}
// Fill bwt->cnt_table: for every byte value (four 2-bit base codes), count
// how many of its four codes equal each code 0..3. The count for code c is
// stored in bits [8c, 8c + 8) of the 32-bit entry, so from the high byte down
// the entry holds the counts of codes 3, 2, 1, 0.
void bwt_gen_cnt_table(bwt_t *bwt)
{
    for (int byte = 0; byte < 256; ++byte)
    {
        uint32_t packed = 0;
        for (int code = 0; code < 4; ++code)
        {
            uint32_t hits = 0;
            for (int shift = 0; shift < 8; shift += 2)
                hits += ((byte >> shift) & 3) == code;
            packed |= hits << (code << 3);
        }
        bwt->cnt_table[byte] = packed;
    }
}
// Report a fatal error while loading a BWT file and terminate the process.
static void restore_bwt_fail(const char *fn, const char *what)
{
    fprintf(stderr, "[restore_bwt] %s: %s\n", fn, what);
    exit(EXIT_FAILURE);
}

// Load a BWT written by dump_bwt(). The file holds the primary index, L2[1..4]
// (four bwtint_t values), then the 2-bit packed BWT string, 16 bases per
// 32-bit word. The loaded string has no Occ checkpoints; those are added by
// create_interval_occ_bwt(). Also builds the per-byte base-count table.
// The caller owns the returned bwt_t. Any open, seek, read or allocation
// failure terminates the process with a message; NULL is never returned.
bwt_t *restore_bwt(const char *fn)
{
    const long header_bytes = (long)(sizeof(bwtint_t) * 5); // primary + L2[1..4]
    bwt_t *bwt = (bwt_t *)calloc(1, sizeof(bwt_t));
    if (bwt == NULL)
        restore_bwt_fail(fn, "out of memory");
    FILE *fp = fopen(fn, "rb");
    if (fp == NULL)
    {
        perror(fn);
        exit(EXIT_FAILURE);
    }
    if (fseek(fp, 0, SEEK_END) != 0)
        restore_bwt_fail(fn, "cannot seek");
    const long file_bytes = ftell(fp);
    // ftell() returns -1 on error. A file shorter than the header would make
    // the unsigned size computation below wrap around to a huge value.
    if (file_bytes < header_bytes)
        restore_bwt_fail(fn, "file too short for a BWT header");
    bwt->bwt_size = (size_t)(file_bytes - header_bytes) >> 2; // size in 32-bit words
    bwt->bwt = (uint32_t *)calloc(bwt->bwt_size, 4);
    if (bwt->bwt == NULL && bwt->bwt_size != 0)
        restore_bwt_fail(fn, "out of memory");
    if (fseek(fp, 0, SEEK_SET) != 0)
        restore_bwt_fail(fn, "cannot seek");
    if (fread(&bwt->primary, sizeof(bwtint_t), 1, fp) != 1 ||
        fread(bwt->L2 + 1, sizeof(bwtint_t), 4, fp) != 4)
        restore_bwt_fail(fn, "truncated header");
    fread_fix(fp, bwt->bwt_size << 2, bwt->bwt);
    bwt->seq_len = bwt->L2[4];
    fclose(fp);
    bwt_gen_cnt_table(bwt); // counts of each base for every possible byte
    return bwt;
}
// Interleave Occ checkpoints into the plain packed BWT loaded by restore_bwt().
// Every OCC_INTERVAL bases, the running counts of the four bases (four
// uint32_t words) are written, followed by that interval's packed BWT words
// (16 bases per word). One final checkpoint after the last word holds the
// totals. bwt->bwt is replaced by the new array and bwt->bwt_size is grown to
// match.
// NOTE(review): counts are 32-bit; a single base occurring more than
// 2^32 - 1 times would overflow.
void create_interval_occ_bwt(bwt_t *bwt)
{
    uint32_t c[4] = {0, 0, 0, 0}; // running per-base counts
    const bwtint_t n_occ = (bwt->seq_len + OCC_INTERVAL - 1) / OCC_INTERVAL + 1;
    bwt->bwt_size += n_occ * 4; // each checkpoint adds four uint32_t words
    uint32_t *buf = (uint32_t *)calloc(bwt->bwt_size, 4); // will be the new bwt
    xassert(buf != NULL, "failed to allocate the checkpointed bwt");
    bwtint_t i, k;
    for (i = k = 0; i < bwt->seq_len; ++i)
    {
        if (i % OCC_INTERVAL == 0)
        {
            memcpy(buf + k, c, sizeof(c));
            k += 4;
        }
        if (i % 16 == 0) // 16 bases per 32-bit word (2 bits per base)
            buf[k++] = bwt->bwt[i / 16];
        ++c[bwt_B00(bwt, i)]; // bwt->bwt is still the old checkpoint-free array here
    }
    memcpy(buf + k, c, sizeof(c)); // final checkpoint: totals
    xassert(k + 4 == bwt->bwt_size, "inconsistent bwt_size");
    free(bwt->bwt);
    bwt->bwt = buf;
}
// Count how many of the 32 two-bit base codes packed in y equal c (0..3).
static inline int __occ_aux(uint64_t y, int c)
{
    // Invert y wherever c's bit is 0, so a matching bit always reads as 1.
    const uint64_t hi_match = (c & 2) ? y : ~y;
    const uint64_t lo_match = (c & 1) ? y : ~y;
    // Bit 0 of each 2-bit field is set iff both of its bits match c.
    uint64_t hits = (hi_match >> 1) & lo_match & 0x5555555555555555ull;
    // SWAR population count: pairs -> nibbles -> bytes -> total in the top byte.
    hits = (hits & 0x3333333333333333ull) + ((hits >> 2) & 0x3333333333333333ull);
    hits = (hits + (hits >> 4)) & 0x0f0f0f0f0f0f0f0full;
    return (int)((hits * 0x0101010101010101ull) >> 56);
}
// Occ(c, k): number of occurrences of base code c in the BWT string up to and
// including BWT-matrix row k. k == seq_len returns the total count of c, and
// k == (bwtint_t)-1 returns 0. Bases are counted 32 at a time (two words), so
// this assumes OCC_INTERVAL is a multiple of 32.
bwtint_t bwt_occ(const bwt_t *bwt, bwtint_t k, uint8_t c)
{
    if (k == bwt->seq_len)
        return bwt->L2[c + 1] - bwt->L2[c];
    if (k == (bwtint_t)(-1))
        return 0;
    k -= (k >= bwt->primary); // because $ is not in bwt
    // A checkpoint is four uint32_t counts followed by the interval's packed
    // BWT words. This is the layout written by create_interval_occ_bwt() and
    // read by bwt_occ4(), so index the counts as uint32_t and skip 4 words.
    uint32_t *p = bwt_occ_intv(bwt, k);
    bwtint_t n = p[c];
    p += 4; // first BWT word of the interval
    // Count whole 32-base chunks before the one containing k.
    uint32_t *end = p + (((k >> 5) - ((k & ~OCC_INTV_MASK) >> 5)) << 1);
    for (; p < end; p += 2)
        n += __occ_aux((uint64_t)p[0] << 32 | p[1], c);
    // Last chunk: keep only positions up to k (low-order pairs are cleared).
    n += __occ_aux(((uint64_t)p[0] << 32 | p[1]) & ~((1ull << ((~k & 31) << 1)) - 1), c);
    if (c == 0)
        n -= ~k & 31; // cleared pairs read as A; remove them
    return n;
}
// 统计k行bwt mtx行之前4种碱基累积数量这里的k是bwt矩阵里的行比bwt字符串多1
void bwt_occ4(const bwt_t *bwt, bwtint_t k, bwtint_t cnt[4])
{
bwtint_t x = 0;
uint32_t *p, tmp, *end;
// bwtint_t bwt_k_base_line = k >> OCC_INTV_SHIFT << OCC_INTV_SHIFT;
if (k == (bwtint_t)(-1))
{
memset(cnt, 0, 4 * sizeof(bwtint_t));
return;
}
k -= (k >= bwt->primary); // because $ is not in bwt
p = bwt_occ_intv(bwt, k);
cnt[0] = p[0];
cnt[1] = p[1];
cnt[2] = p[2];
cnt[3] = p[3];
p += 4; // check point之后的bwt字符串首地址
end = p + ((k >> 4) - ((k & ~OCC_INTV_MASK) >> 4)); // this is the end point of the following loop
for (; p < end; ++p)
{
x += __occ_aux4(bwt, *p);
}
tmp = *p & ~((1U << ((~k & 15) << 1)) - 1);
x += __occ_aux4(bwt, tmp) - (~k & 15); // 这里多算了A要减去
cnt[0] += x & 0xff;
cnt[1] += x >> 8 & 0xff;
cnt[2] += x >> 16 & 0xff;
cnt[3] += x >> 24;
}
// an analogy to bwt_occ4() but more efficient, requiring k <= l有提升但不大
void bwt_2occ4(const bwt_t *bwt, bwtint_t k, bwtint_t l, bwtint_t cntk[4], bwtint_t cntl[4])
{
// bwt_occ4(bwt, k, cntk);
// bwt_occ4(bwt, l, cntl);
// return;
bwtint_t _k, _l;
_k = k - (k >= bwt->primary);
_l = l - (l >= bwt->primary);
if (_l >> OCC_INTV_SHIFT != _k >> OCC_INTV_SHIFT || k == (bwtint_t)(-1) || l == (bwtint_t)(-1))
{
bwt_occ4(bwt, k, cntk);
bwt_occ4(bwt, l, cntl);
}
else
{
bwtint_t x, y;
uint32_t *p, tmp, *endk, *endl;
k -= (k >= bwt->primary); // because $ is not in bwt
l -= (l >= bwt->primary);
p = bwt_occ_intv(bwt, k);
cntk[0] = p[0];
cntk[1] = p[1];
cntk[2] = p[2];
cntk[3] = p[3];
p += 4;
// prepare cntk[]
endk = p + ((k >> 4) - ((k & ~OCC_INTV_MASK) >> 4));
endl = p + ((l >> 4) - ((l & ~OCC_INTV_MASK) >> 4));
for (x = 0; p < endk; ++p)
x += __occ_aux4(bwt, *p);
y = x;
tmp = *p & ~((1U << ((~k & 15) << 1)) - 1);
x += __occ_aux4(bwt, tmp) - (~k & 15);
// calculate cntl[] and finalize cntk[]
for (; p < endl; ++p)
y += __occ_aux4(bwt, *p);
tmp = *p & ~((1U << ((~l & 15) << 1)) - 1);
y += __occ_aux4(bwt, tmp) - (~l & 15);
memcpy(cntl, cntk, 4 * sizeof(bwtint_t));
cntk[0] += x & 0xff;
cntk[1] += x >> 8 & 0xff;
cntk[2] += x >> 16 & 0xff;
cntk[3] += x >> 24;
cntl[0] += y & 0xff;
cntl[1] += y >> 8 & 0xff;
cntl[2] += y >> 16 & 0xff;
cntl[3] += y >> 24;
}
}
// 设置某一行的排序值sa的索引有效值从1开始0设置为-1, 小端模式)
void bwt_set_sa(uint8_t *sa_arr, bwtint_t k, bwtint_t val)
{
const bwtint_t block_idx = (k >> 3) * 33; // 8个数为一组共享33个字节
const int val_idx_in_block = k & 7;
const bwtint_t start_byte_idx = block_idx + (val_idx_in_block << 2);
bwtint_t *sa_addr = (bwtint_t *)(sa_arr + start_byte_idx);
// *sa_addr &= (1 << val_idx_in_block) - 1; // 如果开辟内存的时候清零了,这一步可以省略,会清除后面的数据,只适合按递增顺序赋值
*sa_addr |= (val & ((1L << 33) - 1)) << val_idx_in_block;
}
// Read the 33-bit suffix-array value of row k from the packed array. Layout:
// 8 values per 33-byte group, little-endian, value j starting at byte 4 * j
// with a bit offset of j, so it spans at most 5 bytes.
bwtint_t bwt_get_sa(uint8_t *sa_arr, bwtint_t k)
{
    const uint64_t mask33 = (1ULL << 33) - 1; // 8589934591
    const bwtint_t group_byte = (k >> 3) * 33;
    const unsigned j = (unsigned)(k & 7);
    const uint8_t *field = sa_arr + group_byte + (j << 2);
    // Load only the 5 bytes the value can span, byte by byte. An unaligned,
    // type-punned 8-byte read would also run up to 3 bytes past the group,
    // beyond the end of the array for the last group unless it is padded.
    uint64_t w = 0;
    for (int b = 0; b < 5; ++b)
        w |= (uint64_t)field[b] << (8 * b);
    return (bwtint_t)((w >> j) & mask33);
}
void bwt_extend(const bwt_t *bwt, const bwtintv_t *ik, bwtintv_t ok[4], int is_back)
{
bwtint_t tk[4], tl[4];
int i;
bwt_2occ4(bwt, ik->x[!is_back] - 1, ik->x[!is_back] - 1 + ik->x[2], tk, tl); // tk表示在k行之前所有各个碱基累积出现次数tl表示在l行之前的累积
// 这里是反向扩展
for (i = 0; i != 4; ++i)
{
ok[i].x[!is_back] = bwt->L2[i] + 1 + tk[i]; // 起始行位置,互补链
ok[i].x[2] = tl[i] - tk[i]; // 间隔
}
// 因为计算的是互补碱基所以3对应着0,2对应1下边是正向扩展
ok[3].x[is_back] = ik->x[is_back] + (ik->x[!is_back] <= bwt->primary && ik->x[!is_back] + ik->x[2] - 1 >= bwt->primary);
ok[2].x[is_back] = ok[3].x[is_back] + ok[3].x[2];
ok[1].x[is_back] = ok[2].x[is_back] + ok[2].x[2];
ok[0].x[is_back] = ok[1].x[is_back] + ok[1].x[2];
}
// Exact-match search for q using one-base extensions (bwt_extend with
// is_back = 0, fed complemented base codes via cbval). Returns the bi-interval
// of the longest prefix of q consisting only of A/C/G/T characters; info is
// set to that prefix's length. If q is empty or starts with a non-ACGT
// character, an all-zero interval (size 0, info 0) is returned.
bwtintv_t bwt_search(bwt_t *bwt, const string &q)
{
    bwtintv_t ik, ok[4];
    memset(&ik, 0, sizeof(ik));
    // bval() >= 4 marks a non-ACGT character; never seed an interval from one.
    if (q.empty() || bval(q[0]) >= 4)
        return ik;
    bwt_set_intv(bwt, bval(q[0]), ik);
    ik.info = 1;
    for (int i = 1; i < (int)q.size(); ++i)
    {
        // Stop at the first ambiguous base. Skipping it would return the
        // interval of a different string than q.
        if (bval(q[i]) >= 4)
            break;
        const int c = cbval(q[i]);
        bwt_extend(bwt, &ik, ok, 0);
        ik = ok[c];
        ik.info = i + 1; // number of bases of q matched so far
    }
    return ik;
}
// 扩展两次
void bwt_extend2(const bwt_t *bwt, bwtintv_t *ik, bwtintv_t ok[4], int is_back, int c1, int c2)
{
bwtint_t tk[4], tl[4];
int i;
bwt_2occ4(bwt, ik->x[!is_back] - 1, ik->x[!is_back] - 1 + ik->x[2], tk, tl); // tk表示在k行之前所有各个碱基累积出现次数tl表示在l行之前的累积
// 这里是反向扩展
for (i = 0; i != 4; ++i)
{
ok[i].x[!is_back] = bwt->L2[i] + 1 + tk[i]; // 起始行位置,互补链
ok[i].x[2] = tl[i] - tk[i]; // 间隔
}
// 因为计算的是互补碱基所以3对应着0,2对应1下边是正向扩展
ok[3].x[is_back] = ik->x[is_back] + (ik->x[!is_back] <= bwt->primary && ik->x[!is_back] + ik->x[2] - 1 >= bwt->primary);
ok[2].x[is_back] = ok[3].x[is_back] + ok[3].x[2];
ok[1].x[is_back] = ok[2].x[is_back] + ok[2].x[2];
ok[0].x[is_back] = ok[1].x[is_back] + ok[1].x[2];
*ik = ok[c1];
bwt_2occ4(bwt, ik->x[!is_back] - 1, ik->x[!is_back] - 1 + ik->x[2], tk, tl);
for (i = 0; i != 4; ++i)
{
ok[i].x[!is_back] = bwt->L2[i] + 1 + tk[i]; // 起始行位置,互补链
ok[i].x[2] = tl[i] - tk[i]; // 间隔
}
// 因为计算的是互补碱基所以3对应着0,2对应1下边是正向扩展
ok[3].x[is_back] = ik->x[is_back] + (ik->x[!is_back] <= bwt->primary && ik->x[!is_back] + ik->x[2] - 1 >= bwt->primary);
ok[2].x[is_back] = ok[3].x[is_back] + ok[3].x[2];
ok[1].x[is_back] = ok[2].x[is_back] + ok[2].x[2];
ok[0].x[is_back] = ok[1].x[is_back] + ok[1].x[2];
}
// Exact-match search for q that extends two bases per step via bwt_extend2(),
// with one single-base bwt_extend() for a leftover base. Returns the
// bi-interval of the longest prefix of q consisting only of A/C/G/T
// characters; info is set to that prefix's length. If q is empty or starts
// with a non-ACGT character, an all-zero interval (size 0, info 0) is returned.
bwtintv_t bwt_search2(bwt_t *bwt, const string &q)
{
    bwtintv_t ik, ok[4];
    memset(&ik, 0, sizeof(ik));
    // bval() >= 4 marks a non-ACGT character; never seed an interval from one.
    if (q.empty() || bval(q[0]) >= 4)
        return ik;
    bwt_set_intv(bwt, bval(q[0]), ik);
    ik.info = 1;
    const int len = (int)q.size();
    int i = 1;
    for (; i + 1 < len; i += 2)
    {
        if (bval(q[i]) >= 4 || bval(q[i + 1]) >= 4)
            break;
        const int c1 = cbval(q[i]);
        const int c2 = cbval(q[i + 1]);
        bwt_extend2(bwt, &ik, ok, 0, c1, c2);
        ik = ok[c2];
        ik.info = i + 2; // q[0..i+1] matched: i + 2 bases
    }
    if (i < len && bval(q[i]) < 4)
    { // one base left: odd length, or q[i + 1] is non-ACGT
        const int c1 = cbval(q[i]);
        bwt_extend(bwt, &ik, ok, 0);
        ik = ok[c1];
        ik.info = i + 1;
    }
    return ik;
}