bwa_perf/sa.cpp

141 lines
4.8 KiB
C++
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include <iostream>
#include <stdint.h>
#include <stdlib.h>
#include <string.h>
#include <vector>
#include "util.h"
#include "sa.h"
using namespace std;
// 由33bit表示的bwt行对应的所在序列位置信息
inline void bwt_set_sa_33(uint8_t *sa_arr, bwtint_t k, bwtint_t val)
{
const bwtint_t block_idx = (k >> 3) * 33; // 8个数为一组共享33个字节
const int val_idx_in_block = k & 7;
const bwtint_t start_byte_idx = block_idx + (val_idx_in_block << 2);
bwtint_t *sa_addr = (bwtint_t *)(sa_arr + start_byte_idx);
// *sa_addr &= (1 << val_idx_in_block) - 1; // 如果开辟内存的时候清零了,这一步可以省略,会清除后面的数据,只适合按递增顺序赋值
*sa_addr |= val << val_idx_in_block;
}
// Return the 33-bit suffix-array value of BWT row k from a packed array in
// which every 8 consecutive values share one 33-byte group; value j of a group
// (j = k & 7) starts at byte 4*j of the group with a bit offset of j and is
// stored little-endian.
inline bwtint_t bwt_get_sa_33(uint8_t *sa_arr, bwtint_t k)
{
    const uint64_t mask33 = ((uint64_t)1 << 33) - 1;
    const int val_idx_in_block = (int)(k & 7);
    const uint8_t *p = sa_arr + (k >> 3) * 33 + ((bwtint_t)val_idx_in_block << 2);
    // Assemble only the 5 bytes that hold the field (bits j..j+32, j <= 7):
    // an 8-byte load here would be misaligned and would read past the end of
    // the group for the last slot.
    uint64_t window = 0;
    for (int b = 0; b < 5; ++b) window |= (uint64_t)p[b] << (8 * b);
    return (bwtint_t)((window >> val_idx_in_block) & mask33);
}
// Store the suffix-array value of BWT row k as a 5-byte (40-bit) little-endian
// record at byte offset 5*k. Bits of val above bit 39 are discarded.
inline void bwt_set_sa_40(uint8_t *sa_arr, bwtint_t k, bwtint_t val)
{
    uint8_t *p = sa_arr + k * 5;
    // Records are not aligned, so write byte by byte rather than through a
    // uint32_t* (misaligned access and type punning are undefined behavior).
    for (int b = 0; b < 5; ++b) p[b] = (uint8_t)(val >> (8 * b));
}
// Return the suffix-array value of BWT row k from its 5-byte (40-bit)
// little-endian record at byte offset 5*k. All 40 stored bits are returned.
inline bwtint_t bwt_get_sa_40(uint8_t *sa_arr, bwtint_t k)
{
    const uint8_t *p = sa_arr + k * 5;
    // Read exactly the 5 record bytes: an 8-byte load would be misaligned and
    // would run 3 bytes past the final record of the array. No mask is needed
    // because nothing beyond the record is read.
    bwtint_t val = 0;
    for (int b = 0; b < 5; ++b) val |= (bwtint_t)p[b] << (8 * b);
    return val;
}
// Macro form of bwt_get_sa_40: read the SA value of row k from a 5-byte-per-
// record array. Delegating keeps each argument evaluated once and correctly
// grouped (e.g. k = i + 1 yields record i + 1, not byte offset i + 5).
#define GET_SA_40(sa_arr, k) bwt_get_sa_40((sa_arr), (k))
// Store a 40-bit value at record idx of an array of 5-byte records.
// Record layout: byte 0 holds bits 32..39, bytes 1..4 hold the low 32 bits in
// host byte order. Bits of val above bit 39 are discarded.
static inline void set_sa_val_40(uint8_t *sa_arr, uint64_t idx, uint64_t val) {
    uint8_t *rec = sa_arr + idx * 5;
    const uint32_t low = (uint32_t)val;
    rec[0] = (uint8_t)(val >> 32);
    // rec + 1 is generally not 4-byte aligned; memcpy performs the same store
    // without the undefined behavior of writing through a uint32_t*.
    memcpy(rec + 1, &low, sizeof low);
}
// Read the 40-bit value at record idx of an array of 5-byte records
// (byte 0 = bits 32..39, bytes 1..4 = low 32 bits in host byte order).
static inline uint64_t get_sa_val_40(uint8_t *sa_arr, uint64_t idx) {
    const uint8_t *rec = sa_arr + idx * 5;
    uint32_t low;
    // rec + 1 is generally not 4-byte aligned; memcpy avoids the undefined
    // behavior of loading through a uint32_t*.
    memcpy(&low, rec + 1, sizeof low);
    return ((uint64_t)rec[0] << 32) | low;
}
#define TEST_RAND_READ 1
int main_sa(int argc, char **argv)
{
double timeRead40, timeWrite40,
timeRead33, timeWrite33,
timeRead401, timeWrite401,
timeRead64, timeWrite64;
double timeStart;
int saLen = 1 << 25; // 对应于bwt字符串长度
bwtint_t diffPos = 0; // 用来显示结果不一致时候的出错位置
vector<bwtint_t> valArr(saLen);
vector<int> ri(saLen);
uint8_t *sa33 = (uint8_t*)calloc(SA_BYTES_33(saLen), 1);
uint8_t *sa40 = (uint8_t*)calloc(SA_BYTES_40(saLen), 1);
uint8_t *sa401 = (uint8_t*)calloc(SA_BYTES_40(saLen), 1);
bwtint_t *sa64 = (bwtint_t*)calloc(saLen, sizeof(bwtint_t));
for (int i=0; i< saLen; ++i) {
valArr[i] = rand();
valArr[i] <<= 1;
#if TEST_RAND_READ
ri[i] = rand() % saLen;
#else
ri[i] = i;
#endif
}
// 33 test
timeStart = realtime();
for (int i=0; i<saLen; ++i) bwt_set_sa_33(sa33, i, valArr[i]);
timeWrite33 = realtime() - timeStart;
timeStart = realtime();
for (int i=0; i<saLen; ++i) if (bwt_get_sa_33(sa33, ri[i]) != valArr[i]) diffPos = i;
timeRead33 = realtime() - timeStart;
// 40 test
timeStart = realtime();
for (int i=0; i<saLen; ++i) set_sa_val_40(sa40, i, valArr[i]);
timeWrite40 = realtime() - timeStart;
timeStart = realtime();
for (int i=0; i<saLen; ++i) if (get_sa_val_40(sa40, ri[i]) != valArr[i]) diffPos = i;
timeRead40 = realtime() - timeStart;
// 401 test
timeStart = realtime();
for (int i=0; i<saLen; ++i) bwt_set_sa_40(sa401, i, valArr[i]);
timeWrite401 = realtime() - timeStart;
timeStart = realtime();
for (int i=0; i<saLen; ++i) if (bwt_get_sa_40(sa401, ri[i]) != valArr[i]) diffPos = i;
timeRead401 = realtime() - timeStart;
// 64 test
timeStart = realtime();
for (int i=0; i<saLen; ++i) sa64[i] = valArr[i];
timeWrite64 = realtime() - timeStart;
timeStart = realtime();
for (int i=0; i<saLen; ++i) if (sa64[ri[i]] != valArr[i]) diffPos = i;
timeRead64 = realtime() - timeStart;
cout << "33 write time: " << timeWrite33 << " s. read time: " << timeRead33 << " s." << endl;
cout << "40 write time: " << timeWrite40 << " s. read time: " << timeRead40 << " s." << endl;
cout << "401 write time: " << timeWrite401 << " s. read time: " << timeRead401 << " s." << endl;
cout << "64 write time: " << timeWrite64 << " s. read time: " << timeRead64 << " s." << endl;
cout << "diff pos: " << diffPos << endl;
return 0;
}