From 6e1dd08fb646db2a0629c882fc16c375c3d6dabb Mon Sep 17 00:00:00 2001 From: zzh Date: Fri, 23 Feb 2024 01:09:08 +0800 Subject: [PATCH] =?UTF-8?q?=E5=B0=86seed=E5=92=8Cextend=E9=83=A8=E5=88=86?= =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=88=90=E4=BA=86batch=E6=A8=A1=E5=BC=8F?= =?UTF-8?q?=EF=BC=8C=E5=A5=BD=E5=83=8F=E6=B2=A1=E5=95=A5=E6=95=88=E6=9E=9C?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .vscode/launch.json | 4 +- .vscode/settings.json | 3 +- Makefile | 4 +- bwa.c | 4 +- bwamem.c | 540 +++++++++++++++++++++--------------------- bwamem.h | 1 + fastmap.c | 15 +- fmt_idx.c | 36 +-- ksw.h | 2 + ksw_extend2_avx2.c | 494 ++++++++++++++++++++++++++++++++++++++ ksw_extend2_avx2_u8.c | 370 +++++++++++++++++++++++++++++ run.sh | 14 +- utils.h | 8 +- 13 files changed, 1188 insertions(+), 307 deletions(-) create mode 100644 ksw_extend2_avx2.c create mode 100644 ksw_extend2_avx2_u8.c diff --git a/.vscode/launch.json b/.vscode/launch.json index 2fef6c8..abe1e3d 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -18,8 +18,8 @@ "-R", "'@RG\\tID:normal\\tSM:normal\\tPL:illumina\\tLB:normal\\tPG:bwa'", "~/reference/human_g1k_v37_decoy.fasta", - "~/fastq/diff_r1.fq", - "~/fastq/diff_r2.fq", + "~/fastq/ssn_r1.fq", + "~/fastq/ssn_r2.fq", "-o", "/dev/null" ], diff --git a/.vscode/settings.json b/.vscode/settings.json index 77e1001..5d14020 100644 --- a/.vscode/settings.json +++ b/.vscode/settings.json @@ -9,6 +9,7 @@ "istream": "c", "limits": "c", "bit": "c", - "numeric": "c" + "numeric": "c", + "typeinfo": "c" } } \ No newline at end of file diff --git a/Makefile b/Makefile index 263ee60..9679d71 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ CC= gcc #CC= clang --analyze # CFLAGS= -g -Wall -Wno-unused-function -O2 -CFLAGS= -g -Wall -Wno-unused-function -O2 +CFLAGS= -g -Wall -Wno-unused-function -mavx2 -O2 WRAP_MALLOC=-DUSE_MALLOC_WRAPPERS SHOW_PERF= -DSHOW_PERF AR= ar @@ -12,7 +12,7 @@ AOBJS= bwashm.o bwase.o bwaseqio.o bwtgap.o bwtaln.o bamlite.o \ bwape.o kopen.o pemerge.o maxk.o \ bwtsw2_core.o bwtsw2_main.o bwtsw2_aux.o bwt_lite.o \ bwtsw2_chain.o fastmap.o bwtsw2_pair.o \ - fmt_idx.o + fmt_idx.o ksw_extend2_avx2.o ksw_extend2_avx2_u8.o PROG= bwa INCLUDES= LIBS= -lm -lz -lpthread -ldl diff --git a/bwa.c b/bwa.c index ba14350..26c3e4c 100644 --- a/bwa.c +++ b/bwa.c @@ -339,9 +339,9 @@ bwaidx_t *bwa_idx_load_from_disk(const char *hint, int which) return 0; } idx = calloc(1, sizeof(bwaidx_t)); - if (which & BWA_IDX_BWT) idx->bwt = bwa_idx_load_bwt(hint); if (which & BWA_IDX_BWT) idx->fmt = bwa_idx_load_fmt(hint); - idx->bwt->kmer_hash = idx->fmt->kmer_hash; + //if (which & BWA_IDX_BWT) idx->bwt = bwa_idx_load_bwt(hint); + //idx->bwt->kmer_hash = idx->fmt->kmer_hash; if (which & BWA_IDX_BNS) { diff --git a/bwamem.c b/bwamem.c index 3a264ff..8e5d465 100644 --- a/bwamem.c +++ b/bwamem.c @@ -118,163 +118,147 @@ mem_opt_t *mem_opt_init() KSORT_INIT(mem_intv, bwtintv_t, intv_lt) typedef struct { - bwtintv_v mem, mem1, *tmpv[2]; + int *full_match; + bwtintv_v *mem, *mem1, *tmpv[2]; } smem_aux_t; -static smem_aux_t *smem_aux_init() +static smem_aux_t *smem_aux_init(int batch_size) { smem_aux_t *a; a = calloc(1, sizeof(smem_aux_t)); + a->mem = calloc(batch_size, sizeof(bwtintv_v)); + a->mem1 = calloc(batch_size, sizeof(bwtintv_v)); a->tmpv[0] = calloc(1, sizeof(bwtintv_v)); a->tmpv[1] = calloc(1, sizeof(bwtintv_v)); + a->full_match = calloc(batch_size, sizeof(int)); return a; } -static void smem_aux_destroy(smem_aux_t *a) +static void smem_aux_destroy(smem_aux_t *a, int batch_size) { + int i; free(a->tmpv[0]->a); free(a->tmpv[0]); free(a->tmpv[1]->a); free(a->tmpv[1]); - free(a->mem.a); free(a->mem1.a); + for (i = 0; i < batch_size; ++i) { + free(a->mem[i].a); free(a->mem1[i].a); + } + free(a->mem); free(a->mem1); + free(a->full_match); free(a); } #define USE_FMT 1 -static void mem_collect_intv(const mem_opt_t *opt, const bwt_t *bwt, const FMTIndex *fmt, int len, const uint8_t *seq, smem_aux_t *a) +static void mem_collect_intv(const mem_opt_t *opt, const bwt_t *bwt, const FMTIndex *fmt, bseq1_t *seq_arr, int nseq, smem_aux_t *a) { - int i, k, x = 0, old_n; + int si, i, k, x, old_n, len, slen, start, end; int start_width = 1; int split_len = (int)(opt->min_seed_len * opt->split_factor + .499); - int max_seed_len = 0; - int start_N_num = 0, start_flag = 1; - a->mem.n = 0; + int max_seed_len, start_N_num, start_flag; + uint8_t *seq; + bwtintv_v *mem1, *mem; + bwtintv_t *p; - // first pass: find all SMEMs - fprintf(fp1, "seq: %ld\n", dn++); - // fprintf(stderr, "seq: %ld\n", dn++); - // dn ++; - // goto third_seed; - - while (x < len) { - if (seq[x] < 4) { - start_flag = 0; + // 1. first pass: find all SMEMs #ifdef SHOW_PERF - int64_t tmp_time = realtime_msec(); + int64_t tmp_time = realtime_msec(); #endif -#if USE_FMT - x = fmt_smem(fmt, len, seq, x, start_width, opt->min_seed_len, &a->mem1, a->tmpv[0]); -#else - x = bwt_smem1(bwt, len, seq, x, start_width, &a->mem1, a->tmpv); -#endif -#ifdef SHOW_PERF - tmp_time = realtime_msec() - tmp_time; - __sync_fetch_and_add(&time_seed_1, tmp_time); -#endif - s1n += a->mem1.n; - for (i = 0; i < a->mem1.n; ++i) { - bwtintv_t *p = &a->mem1.a[i]; - //fprintf(fp1, "%ld %ld %d\n", p->x[2], p->info >> 32, (uint32_t)p->info); - //fprintf(fp1, "%ld %ld %ld %ld %d\n", p->x[0], p->x[1], p->x[2], p->info >> 32, (uint32_t)p->info); - //fprintf(stderr, "%ld %ld %ld %ld %d\n", p->x[0], p->x[1], p->x[2], p->info >> 32, (uint32_t)p->info); - int slen = (uint32_t)p->info - (p->info >> 32); // seed length - s1l += slen; - max_seed_len = fmax(max_seed_len, slen); - if (slen >= opt->min_seed_len) { - //fprintf(fp1, "%ld %ld %d\n", p->x[2], p->info >> 32, (uint32_t)p->info); - kv_push(bwtintv_t, a->mem, *p); - } - } - } else { - ++x; - if (start_flag) - ++start_N_num; - } - } - // second pass: find MEMs inside a long SMEM - //if (max_seed_len == len - start_N_num) - // goto collect_intv_end; - //goto third_seed; - - old_n = a->mem.n; - for (k = 0; k < old_n; ++k) { - bwtintv_t *p = &a->mem.a[k]; - int start = p->info>>32, end = (int32_t)p->info; - if (end - start < split_len || p->x[2] > opt->split_width) continue; -#ifdef SHOW_PERF - int64_t tmp_time = realtime_msec(); -#endif -#if USE_FMT - fmt_smem(fmt, len, seq, (start + end) >> 1, p->x[2] + 1, opt->min_seed_len, &a->mem1, a->tmpv[0]); -#else - bwt_smem1(bwt, len, seq, (start + end) >> 1, p->x[2] + 1, &a->mem1, a->tmpv); -#endif -#ifdef SHOW_PERF - tmp_time = realtime_msec() - tmp_time; - __sync_fetch_and_add(&time_seed_2, tmp_time); -#endif - s2n += a->mem1.n; - - for (i = 0; i < a->mem1.n; ++i) { - bwtintv_t *p = &a->mem1.a[i]; - //fprintf(fp1, "%ld %ld %d\n", p->x[2], p->info >> 32, (uint32_t)p->info); - // fprintf(fp1, "%ld %ld %ld %ld %d\n", p->x[0], p->x[1], p->x[2], p->info >> 32, (uint32_t)p->info); - // fprintf(stderr, "%ld %ld %ld %ld %d\n", p->x[0], p->x[1], p->x[2], p->info >> 32, (uint32_t)p->info); - int slen = (uint32_t)p->info - (p->info >> 32); - s2l += slen; - if (slen >= opt->min_seed_len) { - g_num_smem2 += 1; - fprintf(fp1, "%ld %ld %d\n", p->x[2], p->info >> 32, (uint32_t)p->info); - kv_push(bwtintv_t, a->mem, a->mem1.a[i]); - } - } - } - //if (max_seed_len == len - start_N_num) - // goto collect_intv_end; - //goto collect_intv_end; -third_seed: - // third pass: LAST-like - if (opt->max_mem_intv > 0) { - x = 0; + for (si = 0; si < nseq; ++si) { + x = 0; max_seed_len = 0; start_N_num = 0; start_flag = 1; + len = seq_arr[si].l_seq; seq = (uint8_t*) seq_arr[si].seq; + mem1 = &a->mem1[si]; mem = &a->mem[si]; mem->n = 0; while (x < len) { if (seq[x] < 4) { - if (1) { - bwtintv_t m; + start_flag = 0; +#if USE_FMT + x = fmt_smem(fmt, len, seq, x, start_width, opt->min_seed_len, mem1, a->tmpv[0]); +#else + x = bwt_smem1(bwt, len, seq, x, start_width, mem1, a->tmpv); +#endif + for (i = 0; i < mem1->n; ++i) { + p = &mem1->a[i]; + slen = (uint32_t)p->info - (p->info >> 32); // seed length + max_seed_len = MAX(max_seed_len, slen); + if (slen >= opt->min_seed_len) { + kv_push(bwtintv_t, *mem, *p); + } + } + } else { + ++x; + if (start_flag) ++start_N_num; + } + } + if (max_seed_len == len - start_N_num) a->full_match[si] = 1; + } #ifdef SHOW_PERF - int64_t tmp_time = realtime_msec(); + tmp_time = realtime_msec() - tmp_time; + __sync_fetch_and_add(&time_seed_1, tmp_time); #endif +#ifdef SHOW_PERF + tmp_time = realtime_msec(); +#endif + // 2. second pass: find MEMs inside a long SMEM + for (si = 0; si < nseq; ++si) { + len = seq_arr[si].l_seq; seq = (uint8_t*) seq_arr[si].seq; + mem1 = &a->mem1[si]; mem = &a->mem[si]; + old_n = mem->n; +// if (a->full_match[si]) continue; + for (k = 0; k < old_n; ++k) { + p = &mem->a[k]; + start = p->info >> 32; end = (int32_t)p->info; + if (end - start < split_len || p->x[2] > opt->split_width) continue; +#if USE_FMT + fmt_smem(fmt, len, seq, (start + end) >> 1, p->x[2] + 1, opt->min_seed_len, mem1, a->tmpv[0]); +#else + bwt_smem1(bwt, len, seq, (start + end) >> 1, p->x[2] + 1, mem1, a->tmpv); +#endif + for (i = 0; i < mem1->n; ++i) { + p = &mem1->a[i]; + slen = (uint32_t)p->info - (p->info >> 32); + if (slen >= opt->min_seed_len) { + kv_push(bwtintv_t, *mem, *p); + } + } + } + } +#ifdef SHOW_PERF + tmp_time = realtime_msec() - tmp_time; + __sync_fetch_and_add(&time_seed_2, tmp_time); +#endif + +#ifdef SHOW_PERF + tmp_time = realtime_msec(); +#endif + // 3. third pass: LAST-like + if (opt->max_mem_intv > 0) { + for (si = 0; si < nseq; ++si) { + // if (a->full_match[si]) continue; + len = seq_arr[si].l_seq; seq = (uint8_t*) seq_arr[si].seq; + x = 0; mem = &a->mem[si]; + while (x < len) { + if (seq[x] < 4) { + bwtintv_t m; #if USE_FMT x = fmt_seed_strategy1(fmt, len, seq, x, opt->min_seed_len, opt->max_mem_intv, &m); #else x = bwt_seed_strategy1(bwt, len, seq, x, opt->min_seed_len, opt->max_mem_intv, &m); #endif -#ifdef SHOW_PERF - tmp_time = realtime_msec() - tmp_time; - __sync_fetch_and_add(&time_seed_3, tmp_time); -#endif - s3n += 1; - s3l += (uint32_t)m.info - (m.info >> 32); - // bwtintv_t *p = &m; - // fprintf(fp1, "%ld %ld %ld %ld %d\n", p->x[0], p->x[1], p->x[2], p->info >> 32, (uint32_t)p->info); - if (m.x[2] > 0) { - kv_push(bwtintv_t, a->mem, m); - //bwtintv_t *p = &m; - //fprintf(fp1, "%ld %ld %d\n", p->x[2], p->info >> 32, (uint32_t)p->info); - //fprintf(fp1, "%ld %ld %ld %ld %d\n", p->x[0], p->x[1], p->x[2], p->info >> 32, (uint32_t)p->info); + kv_push(bwtintv_t, *mem, m); } - } else { // for now, we never come to this block which is slower - x = bwt_smem1a(bwt, len, seq, x, start_width, opt->max_mem_intv, &a->mem1, a->tmpv); - for (i = 0; i < a->mem1.n; ++i) - kv_push(bwtintv_t, a->mem, a->mem1.a[i]); - } - } else ++x; + } else ++x; + } } } - -//collect_intv_end: - // sort - ks_introsort(mem_intv, a->mem.n, a->mem.a); +#ifdef SHOW_PERF + tmp_time = realtime_msec() - tmp_time; + __sync_fetch_and_add(&time_seed_3, tmp_time); +#endif + // 4. sort + for (si = 0; si < nseq; ++si) { + ks_introsort(mem_intv, a->mem[si].n, a->mem[si].a); + } } /************ @@ -364,130 +348,107 @@ void mem_print_chain(const bntseq_t *bns, mem_chain_v *chn) } } -mem_chain_v mem_chain(const mem_opt_t *opt, const bwt_t *bwt, const FMTIndex *fmt, const bntseq_t *bns, int len, const uint8_t *seq, void *buf) +void mem_chain(const mem_opt_t *opt, const bwt_t *bwt, const FMTIndex *fmt, const bntseq_t *bns, bseq1_t *seq_arr, int nseq, mem_chain_v *chns, void *buf) { - int i, b, e, l_rep; - int64_t l_pac = bns->l_pac; - mem_chain_v chain; + int si, i, b, e, l_rep, len; + int64_t l_pac = bns->l_pac, k; + mem_chain_v *chain; kbtree_t(chn) *tree; smem_aux_t *aux; + bwtintv_v *mem; - kv_init(chain); - if (len < opt->min_seed_len) return chain; // if the query is shorter than the seed length, no match tree = kb_init(chn, KB_DEFAULT_SIZE); + aux = (smem_aux_t*)buf; + // 1. find smem + mem_collect_intv(opt, bwt, fmt, seq_arr, nseq, aux); - aux = buf? (smem_aux_t*)buf : smem_aux_init(); - mem_collect_intv(opt, bwt, fmt, len, seq, aux); - for (i = 0, b = e = l_rep = 0; i < aux->mem.n; ++i) { // compute frac_rep - bwtintv_t *p = &aux->mem.a[i]; - int sb = (p->info>>32), se = (uint32_t)p->info; - if (p->x[2] <= opt->max_occ) continue; - if (sb > e) l_rep += e - b, b = sb, e = se; - else e = e > se? e : se; + // 2. chain +#define CHECK_ADD_CHAIN(tmp, lower, upper) \ + int rid, to_add = 0; \ + mem_chain_t tmp, *lower, *upper; \ + tmp.pos = s.rbeg; \ + s.qbeg = p->info >> 32; \ + s.score = s.len = slen; \ + rid = bns_intv2rid(bns, s.rbeg, s.rbeg + s.len); \ + if (rid < 0) \ + continue; \ + if (kb_size(tree)) \ + { \ + kb_intervalp(chn, tree, &tmp, &lower, &upper); \ + if (!lower || !test_and_merge(opt, l_pac, lower, &s, rid)) \ + to_add = 1; \ + } \ + else \ + to_add = 1; \ + if (to_add) \ + { \ + tmp.n = 1; \ + tmp.m = 4; \ + tmp.seeds = calloc(tmp.m, sizeof(mem_seed_t)); \ + tmp.seeds[0] = s; \ + tmp.rid = rid; \ + tmp.is_alt = !!bns->anns[rid].is_alt; \ + kb_putp(chn, tree, &tmp); \ } - l_rep += e - b; - for (i = 0; i < aux->mem.n; ++i) { - bwtintv_t *p = &aux->mem.a[i]; - int step, count, slen = (uint32_t)p->info - (p->info>>32); // seed length - int64_t k; - // if (slen < opt->min_seed_len) continue; // ignore if too short or too repetitive - //if (0) { - if (p->num_match > 0) { - //continue; - for (k = 0; k < p->num_match; ++k) { - mem_chain_t tmp, *lower, *upper; - mem_seed_t s; - int rid, to_add = 0; - s.rbeg = p->rm[k].rs; - if (p->rm[k].reverse) { - s.rbeg = (fmt->l_pac << 1) - 1 - s.rbeg; - } - tmp.pos = s.rbeg; - s.qbeg = p->info >> 32; - s.score = s.len = slen; - rid = bns_intv2rid(bns, s.rbeg, s.rbeg + s.len); - if (rid < 0) continue; - if (kb_size(tree)) - { - kb_intervalp(chn, tree, &tmp, &lower, &upper); // find the closest chain - if (!lower || !test_and_merge(opt, l_pac, lower, &s, rid)) - to_add = 1; - } - else - to_add = 1; - if (to_add) - { // add the seed as a new chain - tmp.n = 1; - tmp.m = 4; - tmp.seeds = calloc(tmp.m, sizeof(mem_seed_t)); - tmp.seeds[0] = s; - tmp.rid = rid; - tmp.is_alt = !!bns->anns[rid].is_alt; - kb_putp(chn, tree, &tmp); - } - } - continue; + + for (si = 0; si < nseq; ++si) { + tree->n_keys = 0; + tree->n_nodes = 1; + tree->root->n = 0; + tree->root->is_internal = 0; + //tree = kb_init(chn, KB_DEFAULT_SIZE); + + len = seq_arr[si].l_seq; + mem = &aux->mem[si]; + for (i = 0, b = e = l_rep = 0; i < mem->n; ++i) { // compute frac_rep + bwtintv_t *p = &mem->a[i]; + int sb = (p->info>>32), se = (uint32_t)p->info; + if (p->x[2] <= opt->max_occ) continue; + if (sb > e) l_rep += e - b, b = sb, e = se; + else e = e > se? e : se; } - step = p->x[2] > opt->max_occ? p->x[2] / opt->max_occ : 1; - for (k = count = 0; k < p->x[2] && count < opt->max_occ; k += step, ++count) { - mem_chain_t tmp, *lower, *upper; - mem_seed_t s; - int rid, to_add = 0; + l_rep += e - b; + + for (i = 0; i < mem->n; ++i) { + bwtintv_t *p = &mem->a[i]; + int step, count, slen = (uint32_t)p->info - (p->info >> 32); // seed length + if (p->num_match > 0) { + mem_seed_t s; + s.rbeg = p->rm[0].rs; + if (p->rm[0].reverse) s.rbeg = (fmt->l_pac << 1) - 1 - s.rbeg; + CHECK_ADD_CHAIN(tmp, lower, upper); + } else { + step = p->x[2] > opt->max_occ ? p->x[2] / opt->max_occ : 1; + for (k = count = 0; k < p->x[2] && count < opt->max_occ; k += step, ++count) { + mem_seed_t s; #ifdef SHOW_PERF - int64_t tmp_time = realtime_msec(); + int64_t tmp_time = realtime_msec(); #endif #if USE_FMT - s.rbeg = tmp.pos = fmt_sa(fmt, p->x[0] + k); - // uint64_t tpos = bwt_sa(bwt, p->x[0] + k); - // if (s.rbeg != tpos) { - // fprintf(stderr, "diff: %ld, %ld %ld\n", p->x[0] + k, tmp.pos, tpos); - // } + s.rbeg = fmt_sa(fmt, p->x[0] + k); #else - s.rbeg = tmp.pos = bwt_sa(bwt, p->x[0] + k); // this is the base coordinate in the forward-reverse reference + s.rbeg = bwt_sa(bwt, p->x[0] + k); // this is the base coordinate in the forward-reverse reference #endif #ifdef SHOW_PERF - tmp_time = realtime_msec() - tmp_time; - __sync_fetch_and_add(&time_bwt_sa, tmp_time); - __sync_fetch_and_add(&num_sa, 1); + tmp_time = realtime_msec() - tmp_time; + __sync_fetch_and_add(&time_bwt_sa, tmp_time); #endif - s.qbeg = p->info >> 32; - s.score= s.len = slen; -#ifdef SHOW_PERF - tmp_time = realtime_msec(); -#endif - rid = bns_intv2rid(bns, s.rbeg, s.rbeg + s.len); -#ifdef SHOW_PERF - tmp_time = realtime_msec() - tmp_time; - __sync_fetch_and_add(&time_bns, tmp_time); -#endif - if (rid < 0) continue; // bridging multiple reference sequences or the forward-reverse boundary; TODO: split the seed; don't discard it!!! - if (kb_size(tree)) { - kb_intervalp(chn, tree, &tmp, &lower, &upper); // find the closest chain - if (!lower || !test_and_merge(opt, l_pac, lower, &s, rid)) to_add = 1; - } else to_add = 1; - if (to_add) { // add the seed as a new chain - tmp.n = 1; tmp.m = 4; - tmp.seeds = calloc(tmp.m, sizeof(mem_seed_t)); - tmp.seeds[0] = s; - tmp.rid = rid; - tmp.is_alt = !!bns->anns[rid].is_alt; - kb_putp(chn, tree, &tmp); + CHECK_ADD_CHAIN(tmp, lower, upper); + } } } + + chain = &chns[si]; + kv_resize(mem_chain_t, *chain, kb_size(tree)); +#define traverse_func(p_) (chain->a[chain->n++] = *(p_)) + __kb_traverse(mem_chain_t, tree, traverse_func); +#undef traverse_func + + for (i = 0; i < chain->n; ++i) chain->a[i].frac_rep = (float)l_rep / len; + if (bwa_verbose >= 4) printf("* fraction of repetitive seeds: %.3f\n", (float)l_rep / len); +// kb_destroy(chn, tree); } - if (buf == 0) smem_aux_destroy(aux); - - kv_resize(mem_chain_t, chain, kb_size(tree)); - - #define traverse_func(p_) (chain.a[chain.n++] = *(p_)) - __kb_traverse(mem_chain_t, tree, traverse_func); - #undef traverse_func - - for (i = 0; i < chain.n; ++i) chain.a[i].frac_rep = (float)l_rep / len; - if (bwa_verbose >= 4) printf("* fraction of repetitive seeds: %.3f\n", (float)l_rep / len); - kb_destroy(chn, tree); - return chain; } /******************** @@ -907,7 +868,8 @@ void mem_chain2aln(const mem_opt_t *opt, const bntseq_t *bns, const uint8_t *pac #ifdef SHOW_PERF int64_t tmp_time = realtime_msec(); #endif - a->score = ksw_extend2(s->qbeg, qs, tmp, rs, 5, opt->mat, opt->o_del, opt->e_del, opt->o_ins, opt->e_ins, aw[0], opt->pen_clip5, opt->zdrop, s->len * opt->a, &qle, &tle, >le, &gscore, &max_off[0]); +// a->score = ksw_extend2(s->qbeg, qs, tmp, rs, 5, opt->mat, opt->o_del, opt->e_del, opt->o_ins, opt->e_ins, aw[0], opt->pen_clip5, opt->zdrop, s->len * opt->a, &qle, &tle, >le, &gscore, &max_off[0]); + a->score = ksw_extend2_avx2(s->qbeg, query, tmp, rseq, 1, 5, opt->mat, opt->o_del, opt->e_del, opt->o_ins, opt->e_ins, opt->a, opt->b, aw[0], opt->pen_clip5, opt->zdrop, s->len * opt->a, &qle, &tle, >le, &gscore, &max_off[0]); #ifdef SHOW_PERF tmp_time = realtime_msec() - tmp_time; __sync_fetch_and_add(&time_ksw_extend2, tmp_time); @@ -942,7 +904,8 @@ void mem_chain2aln(const mem_opt_t *opt, const bntseq_t *bns, const uint8_t *pac #ifdef SHOW_PERF int64_t tmp_time = realtime_msec(); #endif - a->score = ksw_extend2(l_query - qe, query + qe, rmax[1] - rmax[0] - re, rseq + re, 5, opt->mat, opt->o_del, opt->e_del, opt->o_ins, opt->e_ins, aw[1], opt->pen_clip3, opt->zdrop, sc0, &qle, &tle, >le, &gscore, &max_off[1]); + //a->score = ksw_extend2(l_query - qe, query + qe, rmax[1] - rmax[0] - re, rseq + re, 5, opt->mat, opt->o_del, opt->e_del, opt->o_ins, opt->e_ins, aw[1], opt->pen_clip3, opt->zdrop, sc0, &qle, &tle, >le, &gscore, &max_off[1]); + a->score = ksw_extend2_avx2(l_query - qe, query + qe, rmax[1] - rmax[0] - re, rseq + re, 0, 5, opt->mat, opt->o_del, opt->e_del, opt->o_ins, opt->e_ins, opt->a, opt->b, aw[1], opt->pen_clip3, opt->zdrop, sc0, &qle, &tle, >le, &gscore, &max_off[1]); #ifdef SHOW_PERF tmp_time = realtime_msec() - tmp_time; __sync_fetch_and_add(&time_ksw_extend2, tmp_time); @@ -1242,42 +1205,51 @@ void mem_reg2sam(const mem_opt_t *opt, const bntseq_t *bns, const uint8_t *pac, } } -mem_alnreg_v mem_align1_core(const mem_opt_t *opt, const bwt_t *bwt, const FMTIndex *fmt, const bntseq_t *bns, const uint8_t *pac, int l_seq, char *seq, void *buf) +void mem_align1_core(const mem_opt_t *opt, const bwt_t *bwt, const FMTIndex *fmt, const bntseq_t *bns, const uint8_t *pac, bseq1_t *seq_arr, int nseq, mem_chain_v *chns, mem_alnreg_v *regs, void *buf) { - int i; - mem_chain_v chn; - mem_alnreg_v regs; - - for (i = 0; i < l_seq; ++i) // convert to 2-bit encoding if we have not done so - seq[i] = seq[i] < 4? seq[i] : nst_nt4_table[(int)seq[i]]; - - chn = mem_chain(opt, bwt, fmt, bns, l_seq, (uint8_t*)seq, buf); - chn.n = mem_chain_flt(opt, chn.n, chn.a); - mem_flt_chained_seeds(opt, bns, pac, l_seq, (uint8_t*)seq, chn.n, chn.a); - if (bwa_verbose >= 4) mem_print_chain(bns, &chn); - - kv_init(regs); - for (i = 0; i < chn.n; ++i) { - mem_chain_t *p = &chn.a[i]; - if (bwa_verbose >= 4) err_printf("* ---> Processing chain(%d) <---\n", i); - mem_chain2aln(opt, bns, pac, l_seq, (uint8_t*)seq, p, ®s); - free(chn.a[i].seeds); + int si, i, j; + // 1. 将seq都转成0,1,2,3 + for (si = 0; si < nseq; ++si) { // convert to 2-bit encoding if we have not done so + const int l_seq = seq_arr[si].l_seq; + char *seq = seq_arr[si].seq; + for (j = 0; j < l_seq; ++j) seq[j] = seq[j] < 4 ? seq[j] : nst_nt4_table[(int)seq[j]]; } - free(chn.a); - regs.n = mem_sort_dedup_patch(opt, bns, pac, (uint8_t*)seq, regs.n, regs.a); - if (bwa_verbose >= 4) { - err_printf("* %ld chains remain after removing duplicated chains\n", regs.n); - for (i = 0; i < regs.n; ++i) { - mem_alnreg_t *p = ®s.a[i]; - printf("** %d, [%d,%d) <=> [%ld,%ld)\n", p->score, p->qb, p->qe, (long)p->rb, (long)p->re); + + // 2. find smem and chain + mem_chain(opt, bwt, fmt, bns, seq_arr, nseq, chns, buf); + + // 3. filter chain + for (si = 0; si < nseq; ++si) { + chns[si].n = mem_chain_flt(opt, chns[si].n, chns[si].a); + mem_flt_chained_seeds(opt, bns, pac, seq_arr[si].l_seq, (uint8_t *)seq_arr[si].seq, chns[si].n, chns[si].a); + if (bwa_verbose >= 4) mem_print_chain(bns, &chns[si]); + } + + // 4. extend + for (si = 0; si < nseq; ++si) { + mem_chain_v *chn = &chns[si]; + for (i = 0; i < chn->n; ++i) { + mem_chain_t *p = &chn->a[i]; + if (bwa_verbose >= 4) err_printf("* ---> Processing chain(%d) <---\n", i); + mem_chain2aln(opt, bns, pac, seq_arr[si].l_seq, (uint8_t *)seq_arr[si].seq, p, ®s[si]); + free(chn->a[i].seeds); + } + free(chn->a); + regs[si].n = mem_sort_dedup_patch(opt, bns, pac, (uint8_t *)seq_arr[si].seq, regs[si].n, regs[si].a); + + if (bwa_verbose >= 4) { + err_printf("* %ld chains remain after removing duplicated chains\n", regs[si].n); + for (i = 0; i < regs[si].n; ++i) { + mem_alnreg_t *p = ®s[si].a[i]; + printf("** %d, [%d,%d) <=> [%ld,%ld)\n", p->score, p->qb, p->qe, (long)p->rb, (long)p->re); + } + } + for (i = 0; i < regs[si].n; ++i) { + mem_alnreg_t *p = ®s[si].a[i]; + if (p->rid >= 0 && bns->anns[p->rid].is_alt) + p->is_alt = 1; } } - for (i = 0; i < regs.n; ++i) { - mem_alnreg_t *p = ®s.a[i]; - if (p->rid >= 0 && bns->anns[p->rid].is_alt) - p->is_alt = 1; - } - return regs; } mem_aln_t mem_reg2aln(const mem_opt_t *opt, const bntseq_t *bns, const uint8_t *pac, int l_query, const char *query_, const mem_alnreg_t *ar) @@ -1353,6 +1325,7 @@ mem_aln_t mem_reg2aln(const mem_opt_t *opt, const bntseq_t *bns, const uint8_t * } typedef struct { + int num_reads; const mem_opt_t *opt; const bwt_t *bwt; const FMTIndex *fmt; @@ -1361,6 +1334,7 @@ typedef struct { const mem_pestat_t *pes; smem_aux_t **aux; bseq1_t *seqs; + mem_chain_v *chns; mem_alnreg_v *regs; int64_t n_processed; } worker_t; @@ -1368,15 +1342,19 @@ typedef struct { static void worker1(void *data, int i, int tid) { worker_t *w = (worker_t*)data; - if (!(w->opt->flag&MEM_F_PE)) { - if (bwa_verbose >= 4) printf("=====> Processing read '%s' <=====\n", w->seqs[i].name); - w->regs[i] = mem_align1_core(w->opt, w->bwt, w->fmt, w->bns, w->pac, w->seqs[i].l_seq, w->seqs[i].seq, w->aux[tid]); - } else { - if (bwa_verbose >= 4) printf("=====> Processing read '%s'/1 <=====\n", w->seqs[i<<1|0].name); - w->regs[i<<1|0] = mem_align1_core(w->opt, w->bwt, w->fmt, w->bns, w->pac, w->seqs[i<<1|0].l_seq, w->seqs[i<<1|0].seq, w->aux[tid]); - if (bwa_verbose >= 4) printf("=====> Processing read '%s'/2 <=====\n", w->seqs[i<<1|1].name); - w->regs[i<<1|1] = mem_align1_core(w->opt, w->bwt, w->fmt, w->bns, w->pac, w->seqs[i<<1|1].l_seq, w->seqs[i<<1|1].seq, w->aux[tid]); - } + int start = i * w->opt->batch_size; + int end = MIN(start + w->opt->batch_size, w->num_reads); + mem_align1_core(w->opt, w->bwt, w->fmt, w->bns, w->pac, w->seqs + start, end - start, w->chns + start, w->regs + start, w->aux[tid]); + //if (!(w->opt->flag&MEM_F_PE)) { + // if (bwa_verbose >= 4) printf("=====> Processing read '%s' <=====\n", w->seqs[i].name); + // w->regs[i] = mem_align1_core(w->opt, w->bwt, w->fmt, w->bns, w->pac, w->seqs[i].l_seq, w->seqs[i].seq, w->aux[tid]); + // + //} else { + // if (bwa_verbose >= 4) printf("=====> Processing read '%s'/1 <=====\n", w->seqs[i<<1|0].name); + // w->regs[i<<1|0] = mem_align1_core(w->opt, w->bwt, w->fmt, w->bns, w->pac, w->seqs[i<<1|0].l_seq, w->seqs[i<<1|0].seq, w->aux[tid]); + // if (bwa_verbose >= 4) printf("=====> Processing read '%s'/2 <=====\n", w->seqs[i<<1|1].name); + // w->regs[i<<1|1] = mem_align1_core(w->opt, w->bwt, w->fmt, w->bns, w->pac, w->seqs[i<<1|1].l_seq, w->seqs[i<<1|1].seq, w->aux[tid]); + //} } static void worker2(void *data, int i, int tid) @@ -1399,31 +1377,49 @@ static void worker2(void *data, int i, int tid) void mem_process_seqs(const mem_opt_t *opt, const bwt_t *bwt, const FMTIndex *fmt, const bntseq_t *bns, const uint8_t *pac, int64_t n_processed, int n, bseq1_t *seqs, const mem_pestat_t *pes0) { +#ifdef SHOW_PERF + int64_t tmp_time = realtime_msec(); +#endif extern void kt_for(int n_threads, void (*func)(void*,int,int), void *data, int n); worker_t w; mem_pestat_t pes[4]; double ctime, rtime; int i; - + int n_batch = (n + opt->batch_size - 1) / opt->batch_size; + w.num_reads = n; + ctime = cputime(); rtime = realtime(); global_bns = bns; - w.regs = malloc(n * sizeof(mem_alnreg_v)); + w.regs = calloc(n, sizeof(mem_alnreg_v)); + w.chns = calloc(n, sizeof(mem_chain_v)); w.opt = opt; w.bwt = bwt; w.bns = bns; w.pac = pac; w.seqs = seqs; w.n_processed = n_processed; w.pes = &pes[0]; w.fmt = fmt; w.aux = malloc(opt->n_threads * sizeof(smem_aux_t)); for (i = 0; i < opt->n_threads; ++i) - w.aux[i] = smem_aux_init(); - kt_for(opt->n_threads, worker1, &w, (opt->flag&MEM_F_PE)? n>>1 : n); // find mapping positions + w.aux[i] = smem_aux_init(opt->batch_size); + + //kt_for(opt->n_threads, worker1, &w, (opt->flag&MEM_F_PE)? n>>1 : n); // find mapping positions + + kt_for(opt->n_threads, worker1, &w, n_batch); // find mapping positions + for (i = 0; i < opt->n_threads; ++i) - smem_aux_destroy(w.aux[i]); + smem_aux_destroy(w.aux[i], opt->batch_size); free(w.aux); + free(w.chns); + if (opt->flag&MEM_F_PE) { // infer insert sizes if not provided if (pes0) memcpy(pes, pes0, 4 * sizeof(mem_pestat_t)); // if pes0 != NULL, set the insert-size distribution as pes0 else mem_pestat(opt, bns->l_pac, n, w.regs, pes); // otherwise, infer the insert size distribution from data } + kt_for(opt->n_threads, worker2, &w, (opt->flag&MEM_F_PE)? n>>1 : n); // generate alignment free(w.regs); + if (bwa_verbose >= 3) fprintf(stderr, "[M::%s] Processed %d reads in %.3f CPU sec, %.3f real sec\n", __func__, n, cputime() - ctime, realtime() - rtime); +#ifdef SHOW_PERF + tmp_time = realtime_msec() - tmp_time; + __sync_fetch_and_add(&time_core_process, tmp_time); +#endif } diff --git a/bwamem.h b/bwamem.h index 6f92c5e..d6e2730 100644 --- a/bwamem.h +++ b/bwamem.h @@ -71,6 +71,7 @@ typedef struct { int max_occ; // skip a seed if its occurence is larger than this value int max_chain_gap; // do not chain seed if it is max_chain_gap-bp away from the closest seed int n_threads; // number of threads + int batch_size; // batch size of seqs to process at one time int chunk_size; // process chunk_size-bp sequences in a batch float mask_level; // regard a hit as redundant if the overlap with another better hit is over mask_level times the min length of the two hits float drop_ratio; // drop a chain if its seed coverage is below drop_ratio times the seed coverage of a better chain overlapping with the small chain diff --git a/fastmap.c b/fastmap.c index 6a3aa35..7eb89a3 100644 --- a/fastmap.c +++ b/fastmap.c @@ -56,7 +56,8 @@ int64_t time_ksw_extend2 = 0, time_bwt_occ4 = 0, time_bwt_sa = 0, time_bwt_sa_read = 0, - time_bns = 0; + time_bns = 0, + time_core_process = 0; int64_t dn = 0, n16 = 0, n17 = 0, n18 = 0, n19 = 0, nall = 0, num_sa = 0; int64_t s1n = 0, s2n = 0, s3n = 0, s1l = 0, s2l = 0, s3l = 0; @@ -184,7 +185,7 @@ int main_mem(int argc, char *argv[]) aux.opt = opt = mem_opt_init(); memset(&opt0, 0, sizeof(mem_opt_t)); - while ((c = getopt(argc, argv, "51qpaMCSPVYjuk:c:v:s:r:t:R:A:B:O:E:U:w:L:d:T:Q:D:m:I:N:o:f:W:x:G:h:y:K:X:H:F:z:")) >= 0) { + while ((c = getopt(argc, argv, "51qpaMCSPVYjuk:c:v:s:r:t:b:R:A:B:O:E:U:w:L:d:T:Q:D:m:I:N:o:f:W:x:G:h:y:K:X:H:F:z:")) >= 0) { if (c == 'k') opt->min_seed_len = atoi(optarg), opt0.min_seed_len = 1; else if (c == '1') no_mt_io = 1; else if (c == 'x') mode = optarg; @@ -194,6 +195,7 @@ int main_mem(int argc, char *argv[]) else if (c == 'T') opt->T = atoi(optarg), opt0.T = 1; else if (c == 'U') opt->pen_unpaired = atoi(optarg), opt0.pen_unpaired = 1; else if (c == 't') opt->n_threads = atoi(optarg), opt->n_threads = opt->n_threads > 1? opt->n_threads : 1; + else if (c == 'b') opt->batch_size = atoi(optarg) >> 1 << 1, opt->batch_size = opt->batch_size > 1? opt->batch_size : 512; else if (c == 'P') opt->flag |= MEM_F_NOPAIRING; else if (c == 'a') opt->flag |= MEM_F_ALL; else if (c == 'p') opt->flag |= MEM_F_PE | MEM_F_SMARTPE; @@ -292,11 +294,13 @@ int main_mem(int argc, char *argv[]) } if (opt->n_threads < 1) opt->n_threads = 1; + if (opt->batch_size < 1) opt->batch_size = 512; if (optind + 1 >= argc || optind + 3 < argc) { fprintf(stderr, "\n"); fprintf(stderr, "Usage: bwa mem [options] [in2.fq]\n\n"); fprintf(stderr, "Algorithm options:\n\n"); fprintf(stderr, " -t INT number of threads [%d]\n", opt->n_threads); + fprintf(stderr, " -b INT batch size of reads to process at one time [%d]\n", opt->batch_size); fprintf(stderr, " -k INT minimum seed length [%d]\n", opt->min_seed_len); fprintf(stderr, " -w INT band width for banded alignment [%d]\n", opt->w); fprintf(stderr, " -d INT off-diagonal X-dropoff [%d]\n", opt->zdrop); @@ -439,13 +443,14 @@ int main_mem(int argc, char *argv[]) fprintf(stderr, "time_seed_3: %f s\n", time_seed_3 / 1000.0 / opt->n_threads); fprintf(stderr, "time_bwt_sa: %f s\n", time_bwt_sa / 1000.0 / opt->n_threads); fprintf(stderr, "time_ksw_extend2: %f s\n", time_ksw_extend2 / 1000.0 / opt->n_threads); + fprintf(stderr, "time_core_process: %f s\n", time_core_process / 1000.0 / opt->n_threads); fprintf(stderr, "time_bns: %f s\n", time_bns / 1000.0 / opt->n_threads); fprintf(stderr, "s1 num: %ld\n", s1n); fprintf(stderr, "s2 num: %ld\n", s2n); fprintf(stderr, "s3 num: %ld\n", s3n); - fprintf(stderr, "s1 len: %ld\n", s1l / s1n); - fprintf(stderr, "s2 len: %ld\n", s2l / s2n); - fprintf(stderr, "s3 len: %ld\n", s3l / s3n); +// fprintf(stderr, "s1 len: %ld\n", s1l / s1n); +// fprintf(stderr, "s2 len: %ld\n", s2l / s2n); +// fprintf(stderr, "s3 len: %ld\n", s3l / s3n); fprintf(stderr, "get sa num: %ld\n", num_sa); fprintf(stderr, "seed 2 num: %ld\n", g_num_smem2); fprintf(stderr, "\n"); diff --git a/fmt_idx.c b/fmt_idx.c index 95578b4..36f5245 100644 --- a/fmt_idx.c +++ b/fmt_idx.c @@ -565,19 +565,21 @@ int fmt_smem(const FMTIndex *fmt, int len, const uint8_t *q, int x, int min_intv if (q[i] < 4 && q[i + 1] < 4) { fmt_extend2(fmt, &ik, &ok1, &ok2, 0, 3 - q[i], 3 - q[i + 1]); -// __builtin_prefetch(fmt_occ_intv(fmt, ok2.x[1] - 1), 0, 2); -// __builtin_prefetch(fmt_occ_intv(fmt, ok2.x[1] - 1 + ok2.x[2]), 0, 2); + __builtin_prefetch(fmt_occ_intv(fmt, ok2.x[1] - 1), 0, 2); + __builtin_prefetch(fmt_occ_intv(fmt, ok2.x[1] - 1 + ok2.x[2]), 0, 2); CHECK_INTV_CHANGE(ik, ok1, i + 1); CHECK_INTV_CHANGE(ik, ok2, i + 2); - // 在这里进行判断是否只有一个候选了 - //if (min_intv == 1 && ok2.x[2] == min_intv) - //{ - // direct_extend(fmt, len, q, x, i + 2, ok2.x[0], &mt); - // kv_push(bwtintv_t, *mem, mt); - // ret = (uint32_t)mt.info; - // if (only_forward || mt.rm[0].qs == 0 || q[mt.rm[0].qs - 1] > 3) goto fmt_smem_end; - // goto backward_search; - //} + +#if 1 // 间隔为1的时候直接与reference比对 + if (min_intv == 1 && ok2.x[2] == min_intv) // 在这里进行判断是否只有一个候选了 + { + direct_extend(fmt, len, q, x, i + 2, ok2.x[0], &mt); + kv_push(bwtintv_t, *mem, mt); + ret = (uint32_t)mt.info; + if (only_forward || mt.rm[0].qs == 0 || q[mt.rm[0].qs - 1] > 3) goto fmt_smem_end; + goto backward_search; + } +#endif } else if (q[i] < 4) // q[i+1] >= 4 { fmt_extend1(fmt, &ik, &ok1, 0, 3 - q[i]); @@ -621,6 +623,7 @@ backward_search: bwtintv_t *p = &curr->a[j]; // 前向扩展的种子 // __builtin_prefetch(fmt_occ_intv(fmt, p->x[0] - 1), 0, 2); // __builtin_prefetch(fmt_occ_intv(fmt, p->x[0] - 1 + p->x[2]), 0, 2); +#if 1 if (!only_forward && p->info - x < HASH_KMER_LEN) { if (last_kmer_start && kmer_len == HASH_KMER_LEN && p->info == last_kmer_start && p->info - kmer_len > 0 && q[p->info - kmer_len] < 4) qbit = ((qbit << 2) | (3 - q[p->info - kmer_len])) & ((1L << (kmer_len << 1)) - 1); // 创建反向kmer @@ -634,13 +637,16 @@ backward_search: } else { i = x - 1; } +#else + i = x - 1; +#endif for (; i > 0; i -= 2) { if (q[i] < 4 && q[i - 1] < 4) // 两个都可以扩展 { fmt_extend2(fmt, p, &ok1, &ok2, 1, q[i], q[i - 1]); -// __builtin_prefetch(fmt_occ_intv(fmt, ok2.x[0] - 1), 0, 2); -// __builtin_prefetch(fmt_occ_intv(fmt, ok2.x[0] - 1 + ok2.x[2]), 0, 2); + __builtin_prefetch(fmt_occ_intv(fmt, ok2.x[0] - 1), 0, 2); + __builtin_prefetch(fmt_occ_intv(fmt, ok2.x[0] - 1 + ok2.x[2]), 0, 2); CHECK_INTV_ADD_MEM(ok1, i + 1, *p, mem); ok1.info = p->info; CHECK_INTV_ADD_MEM(ok2, i, ok1, mem); @@ -713,8 +719,8 @@ int fmt_seed_strategy1(const FMTIndex *fmt, int len, const uint8_t *q, int x, in if (q[i] < 4 && q[i + 1] < 4) { fmt_extend2(fmt, &ik, &ok1, &ok2, 0, 3 - q[i], 3 - q[i + 1]); -// __builtin_prefetch(fmt_occ_intv(fmt, ok2.x[1] - 1), 0, 2); -// __builtin_prefetch(fmt_occ_intv(fmt, ok2.x[1] - 1 + ok2.x[2]), 0, 2); + __builtin_prefetch(fmt_occ_intv(fmt, ok2.x[1] - 1), 0, 2); + __builtin_prefetch(fmt_occ_intv(fmt, ok2.x[1] - 1 + ok2.x[2]), 0, 2); COND_SET_RETURN(ok1, *mem, x, i, max_intv, min_len); COND_SET_RETURN(ok2, *mem, x, i + 1, max_intv, min_len); ik = ok2; diff --git a/ksw.h b/ksw.h index 5d45a67..f5b28cf 100644 --- a/ksw.h +++ b/ksw.h @@ -106,6 +106,8 @@ extern "C" { */ int ksw_extend(int qlen, const uint8_t *query, int tlen, const uint8_t *target, int m, const int8_t *mat, int gapo, int gape, int w, int end_bonus, int zdrop, int h0, int *qle, int *tle, int *gtle, int *gscore, int *max_off); int ksw_extend2(int qlen, const uint8_t *query, int tlen, const uint8_t *target, int m, const int8_t *mat, int o_del, int e_del, int o_ins, int e_ins, int w, int end_bonus, int zdrop, int h0, int *qle, int *tle, int *gtle, int *gscore, int *max_off); + int ksw_extend2_avx2(int qlen, const uint8_t *query, int tlen, const uint8_t *target, int is_left, int m, const int8_t *mat, int o_del, int e_del, + int o_ins, int e_ins, int a, int b, int w, int end_bonus, int zdrop, int h0, int *_qle, int *_tle, int *_gtle, int *_gscore, int *_max_off); #ifdef __cplusplus } diff --git a/ksw_extend2_avx2.c b/ksw_extend2_avx2.c new file mode 100644 index 0000000..7a0026b --- /dev/null +++ b/ksw_extend2_avx2.c @@ -0,0 +1,494 @@ +#include +#include +#include +#include +#include +#include +#include + +#ifdef __GNUC__ +#define LIKELY(x) __builtin_expect((x),1) +#define UNLIKELY(x) __builtin_expect((x),0) +#else +#define LIKELY(x) (x) +#define UNLIKELY(x) (x) +#endif + +#undef MAX +#undef MIN +#define MAX(x, y) ((x) > (y) ? (x) : (y)) +#define MIN(x, y) ((x) < (y) ? (x) : (y)) +#define SIMD_WIDTH 16 + +extern int ksw_extend2_avx2_u8(int qlen, const uint8_t *query, int tlen, const uint8_t *target, int is_left, int m, const int8_t *mat, int o_del, int e_del, + int o_ins, int e_ins, int a, int b, int w, int end_bonus, int zdrop, int h0, int *_qle, int *_tle, int *_gtle, int *_gscore, int *_max_off); + +int ksw_extend2_origin(int qlen, const uint8_t *query, int tlen, const uint8_t *target, int is_left, int m, const int8_t *mat, int o_del, int e_del, + int o_ins, int e_ins, int w, int end_bonus, int zdrop, int h0, int *_qle, int *_tle, int *_gtle, int *_gscore, int *_max_off); + + +static const uint16_t h_vec_int_mask[SIMD_WIDTH][SIMD_WIDTH] = { + {0xffff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xffff, 0xffff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xffff, 0xffff, 0xffff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xffff, 0xffff, 0xffff, 0xffff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0, 0, 0, 0, 0, 0, 0}, + {0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0, 0, 0, 0, 0, 0}, + {0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0, 0, 0, 0, 0}, + {0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0, 0, 0, 0}, + {0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0, 0, 0}, + {0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0, 0}, + {0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0}, + {0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff} +}; + +#define permute_mask _MM_SHUFFLE(0, 1, 2, 3) + +// 初始化变量 +#define SIMD_INIT \ + int oe_del = o_del + e_del, oe_ins = o_ins + e_ins; \ + __m256i zero_vec; \ + __m256i max_vec; \ + __m256i oe_del_vec; \ + __m256i oe_ins_vec; \ + __m256i e_del_vec; \ + __m256i e_ins_vec; \ + __m256i h_vec_mask[SIMD_WIDTH]; \ + zero_vec = _mm256_setzero_si256(); \ + oe_del_vec = _mm256_set1_epi16(-oe_del); \ + oe_ins_vec = _mm256_set1_epi16(-oe_ins); \ + e_del_vec = _mm256_set1_epi16(-e_del); \ + e_ins_vec = _mm256_set1_epi16(-e_ins); \ + __m256i match_sc_vec = _mm256_set1_epi16(a); \ + __m256i mis_sc_vec = _mm256_set1_epi16(-b); \ + __m256i amb_sc_vec = _mm256_set1_epi16(-1); \ + __m256i amb_vec = _mm256_set1_epi16(4); \ + for (i=0; i 0) { \ + for(j=beg, i=iend; j<=end; j+=SIMD_WIDTH, i-=SIMD_WIDTH) { \ + __m256i h2_vec = _mm256_loadu_si256((__m256i*) (&hA2[j])); \ + __m256i vcmp = _mm256_cmpeq_epi16(h2_vec, max_vec); \ + uint32_t mask = _mm256_movemask_epi8(vcmp); \ + if (mask > 0) { \ + int pos = SIMD_WIDTH - 1 - (( __builtin_clz(mask)) >> 1); \ + mj = j - 1 + pos; \ + mi = i - 1 - pos; \ + } \ + } \ + } + +// 每轮迭代后,交换数组 +#define SWAP_DATA_POINTER \ + int16_t * tmp=hA0; \ + hA0 = hA1; hA1 = hA2; hA2 = tmp; \ + tmp = eA1; eA1 = eA2; eA2 = tmp; \ + tmp = fA1; fA1 = fA2; fA2 = tmp; \ + tmp = mA1; mA1 = mA2; mA2 = tmp; + + +int ksw_extend2_avx2(int qlen, // query length 待匹配段碱基的query长度 + const uint8_t *query, // read碱基序列 + int tlen, // target length reference的长度 + const uint8_t *target, // reference序列 + int is_left, // 是不是向左扩展 + int m, // 碱基种类 (5) + const int8_t *mat, // 每个位置的query和target的匹配得分 m*m + int o_del, // deletion 错配开始的惩罚系数 + int e_del, // deletion extension的惩罚系数 + int o_ins, // insertion 错配开始的惩罚系数 + int e_ins, // insertion extension的惩罚系数SIMD_BTYES + int a, // 碱基match时的分数 + int b, // 碱基mismatch时的惩罚分数(正数) + int w, // 提前剪枝系数,w =100 匹配位置和beg的最大距离 + int end_bonus, + int zdrop, + int h0, // 该seed的初始得分(完全匹配query的碱基数) + int *_qle, // 匹配得到全局最大得分的碱基在query的位置 + int *_tle, // 匹配得到全局最大得分的碱基在reference的位置 + int *_gtle, // query全部匹配上的target的长度 + int *_gscore, // query的端到端匹配得分 + int *_max_off) // 取得最大得分时在query和reference上位置差的 最大值 +{ + //return ksw_extend2_origin(qlen, query, tlen, target, is_left, m, mat, o_del, e_del, o_ins, e_ins, w, end_bonus, zdrop, h0, _qle, _tle, _gtle, _gscore, _max_off); +// fprintf(stderr, "qlen: %d, tlen: %d\n", qlen, tlen); + + if (qlen * a + h0 < 255) return ksw_extend2_avx2_u8(qlen, query, tlen, target, is_left, m, mat, o_del, e_del, o_ins, e_ins, a, b, w, end_bonus, zdrop, h0, _qle, _tle, _gtle, _gscore, _max_off); + + int16_t *mA,*hA, *eA, *fA, *mA1, *mA2, *hA0, *hA1, *eA1, *fA1, *hA2, *eA2, *fA2; // hA0保存上上个col的H,其他的保存上个H E F M + int16_t *seq, *ref; + uint8_t *mem; + int16_t *qtmem, *vmem; + int seq_size = qlen + SIMD_WIDTH, ref_size = tlen + SIMD_WIDTH; + int i, iStart, D, j, k, beg, end, max, max_i, max_j, max_ins, max_del, max_ie, gscore, max_off; + int Dloop = tlen + qlen; // 循环跳出条件 + int span, beg1, end1; // 边界条件计算 + int col_size = qlen + 2 + SIMD_WIDTH; + int val_mem_size = (col_size * 9 * 2 + 31) >> 5 << 5; // 32字节的整数倍 + int mem_size = (seq_size + ref_size) * 2 + val_mem_size; + + SIMD_INIT; // 初始化simd用的数据 + + assert(h0 > 0); + + // allocate memory + mem = malloc(mem_size); + qtmem = (int16_t*)&mem[0]; + seq=&qtmem[0]; ref=&qtmem[seq_size]; + if (is_left) { + for (i=0; i>1); i+=SIMD_WIDTH) { + _mm256_storeu_si256((__m256i*)&vmem[i], zero_vec); + } + hA = &vmem[0]; + mA = &vmem[col_size * 3]; + eA = &vmem[col_size * 5]; + fA = &vmem[col_size * 7]; + + hA0 = &hA[0]; hA1 = &hA[col_size]; hA2 = &hA1[col_size]; + mA1 = &mA[0]; mA2 = &mA[col_size]; + eA1 = &eA[0]; eA2 = &eA[col_size]; + fA1 = &fA[0]; fA2 = &fA[col_size]; + + // adjust $w if it is too large + k = m * m; + // get the max score + for (i = 0, max = 0; i < k; ++i) max = max > mat[i]? max : mat[i]; + max_ins = (int)((double)(qlen * max + end_bonus - o_ins) / e_ins + 1.); + max_ins = max_ins > 1? max_ins : 1; + w = w < max_ins? w : max_ins; + max_del = (int)((double)(qlen * max + end_bonus - o_del) / e_del + 1.); + max_del = max_del > 1? max_del : 1; + w = w < max_del? w : max_del; // TODO: is this necessary? + if (tlen < qlen) w = MIN(tlen - 1, w); + + // DP loop + max = h0, max_i = max_j = -1; max_ie = -1, gscore = -1;; + max_off = 0; + beg = 1; end = qlen; + // init h0 + hA0[0] = h0; // 左上角 + + if (qlen == 0 || tlen == 0) Dloop = 0; // 防止意外情况 + if (w >= qlen) { max_ie = 0; gscore = 0; } + + int m_last=0; + int iend; + + for (D = 1; LIKELY(D < Dloop); ++D) { + // 边界条件一定要注意! tlen 大于,等于,小于 qlen时的情况 + if (D > tlen) { + span = MIN(Dloop-D, w); + beg1 = MAX(D-tlen+1, ((D-w) / 2) + 1); + } else { + span = MIN(D-1, w); + beg1 = MAX(1, ((D-w) / 2) + 1); + } + end1 = MIN(qlen, beg1+span); + + if (beg < beg1) beg = beg1; + if (end > end1) end = end1; + if (beg > end) break; // 不用计算了,直接跳出,否则hA2没有被赋值,里边是上一轮hA0的值,会出bug + + iend = D - (beg - 1); // ref开始计算的位置,倒序 + span = end - beg; + iStart = iend - span - 1; // 0开始的ref索引位置 + + // 每一轮需要记录的数据 + int m = 0, mj = -1, mi = -1; + max_vec = zero_vec; + + // 要处理边界 + // 左边界 处理f (insert) + if (iStart == 0) { hA1[end] = MAX(0, h0 - (o_ins + e_ins * end)); } + // 上边界 + if (beg == 1) { hA1[0] = MAX(0, h0 - (o_del + e_del * iend)); } + else { hA1[beg-1] = 0; eA1[beg-1] = 0; } + + for (j=beg, i=iend; j<=end+1-SIMD_WIDTH; j+=SIMD_WIDTH, i-=SIMD_WIDTH) { + // 取数据 + SIMD_LOAD; + // 比对seq,计算罚分 + SIMD_CMP_SEQ; + // 计算 + SIMD_COMPUTE; + // 存储结果 + SIMD_STORE; + } + // 剩下的计算单元 + if (j <= end) { + // 取数据 + SIMD_LOAD; + // 比对seq,计算罚分 + SIMD_CMP_SEQ; + // 计算 + SIMD_COMPUTE; + // 去除多余计算的部分 + SIMD_REMOVE_EXTRA; + // 存储结果 + SIMD_STORE; + } + + SIMD_FIND_MAX; + + // 注意最后跳出循环j的值 + j = end + 1; + + if (j == qlen + 1) { + max_ie = gscore > hA2[qlen] ? max_ie : iStart; + gscore = gscore > hA2[qlen] ? gscore : hA2[qlen]; + } + if (m == 0 && m_last==0) break; // 一定要注意,斜对角遍历和按列遍历的不同点 + if (m > max) { + max = m, max_i = mi, max_j = mj; + max_off = max_off > abs(mj - mi)? max_off : abs(mj - mi); + } + else if (zdrop > 0) { + if (mi - max_i > mj - max_j) { + if (max - m - ((mi - max_i) - (mj - max_j)) * e_del > zdrop) break; + } else { + if (max - m - ((mj - max_j) - (mi - max_i)) * e_ins > zdrop) break; + } + } + + // 调整计算的边界 + for (j = beg; LIKELY(j <= end); ++j) { int has_val = hA1[j-1] | hA2[j]; if (has_val) break; } + beg = j; + for (j = end+1; LIKELY(j >= beg); --j) { int has_val = hA1[j-1] | hA2[j]; if (has_val) break; else hA0[j-1]=0; } + end = j + 1 <= qlen? j + 1 : qlen; + + m_last = m; + // swap m, h, e, f + SWAP_DATA_POINTER; + } + + free(mem); + if (_qle) *_qle = max_j + 1; + if (_tle) *_tle = max_i + 1; + if (_gtle) *_gtle = max_ie + 1; + if (_gscore) *_gscore = gscore; + if (_max_off) *_max_off = max_off; + return max; +} + +typedef struct { + int32_t h, e; +} eh_t; + +int ksw_extend2_origin(int qlen, // query length 待匹配段碱基的query长度 + const uint8_t *query, // read碱基序列 + int tlen, // target length reference的长度 + const uint8_t *target, // reference序列 + int is_left, // 是不是向左扩展 + int m, // 碱基种类 (5) + const int8_t *mat, // 每个位置的query和target的匹配得分 m*m + int o_del, // deletion 错配开始的惩罚系数 + int e_del, // deletion extension的惩罚系数 + int o_ins, // insertion 错配开始的惩罚系数 + int e_ins, // insertion extension的惩罚系数 + int w, // 提前剪枝系数,w =100 匹配位置和beg的最大距离 + int end_bonus, + int zdrop, + int h0, // 该seed的初始得分(完全匹配query的碱基数) + int *_qle, // 匹配得到全局最大得分的碱基在query的位置 + int *_tle, // 匹配得到全局最大得分的碱基在reference的位置 + int *_gtle, // query全部匹配上的target的长度 + int *_gscore, // query的端到端匹配得分 + int *_max_off) // 取得最大得分时在query和reference上位置差的 最大值 +{ + eh_t *eh; // score array + int8_t *qp; // query profile + int i, j, k, oe_del = o_del + e_del, oe_ins = o_ins + e_ins, beg, end, max, max_i, max_j, max_ins, max_del, max_ie, gscore, max_off; + uint8_t *qmem, *ref, *seq; + assert(h0 > 0); + // allocate memory + qp = malloc(qlen * m); + eh = calloc(qlen + 1, 8); + qmem = malloc(qlen + tlen); + seq=(uint8_t*)&qmem[0]; ref=(uint8_t*)&qmem[qlen]; + if (is_left) { + for (i=0; i oe_ins? h0 - oe_ins : 0; + for (j = 2; j <= qlen && eh[j-1].h > e_ins; ++j) + eh[j].h = eh[j-1].h - e_ins; + // adjust $w if it is too large + k = m * m; + for (i = 0, max = 0; i < k; ++i) // get the max score + max = max > mat[i]? max : mat[i]; + max_ins = (int)((double)(qlen * max + end_bonus - o_ins) / e_ins + 1.); + max_ins = max_ins > 1? max_ins : 1; + w = w < max_ins? w : max_ins; + max_del = (int)((double)(qlen * max + end_bonus - o_del) / e_del + 1.); + max_del = max_del > 1? max_del : 1; + w = w < max_del? w : max_del; // TODO: is this necessary? +// printf("%d\n", w); + // DP loop + max = h0, max_i = max_j = -1; max_ie = -1, gscore = -1; + max_off = 0; + beg = 0, end = qlen; + + for (i = 0; LIKELY(i < tlen); ++i) { + int t, f = 0, h1, m = 0, mj = -1; + int8_t *q = &qp[ref[i] * qlen]; + // apply the band and the constraint (if provided) + if (beg < i - w) beg = i - w; + if (end > i + w + 1) end = i + w + 1; + //if (end > qlen) end = qlen; 没用 + // compute the first column + if (beg == 0) { + h1 = h0 - (o_del + e_del * (i + 1)); + if (h1 < 0) h1 = 0; + } else h1 = 0; + for (j = beg; LIKELY(j < end); ++j) { + // At the beginning of the loop: eh[j] = { H(i-1,j-1), E(i,j) }, f = F(i,j) and h1 = H(i,j-1) + // Similar to SSE2-SW, cells are computed in the following order: + // H(i,j) = max{H(i-1,j-1)+S(i,j), E(i,j), F(i,j)} + // E(i+1,j) = max{H(i,j)-gapo, E(i,j)} - gape + // F(i,j+1) = max{H(i,j)-gapo, F(i,j)} - gape + eh_t *p = &eh[j]; + int h, M = p->h, e = p->e; // get H(i-1,j-1) and E(i-1,j) + p->h = h1; // set H(i,j-1) for the next row + M = M? M + q[j] : 0;// separating H and M to disallow a cigar like "100M3I3D20M" + h = M > e? M : e; // e and f are guaranteed to be non-negative, so h>=0 even if M<0 + h = h > f? h : f; + h1 = h; // save H(i,j) to h1 for the next column + mj = m > h? mj : j; // record the position where max score is achieved + m = m > h? m : h; // m is stored at eh[mj+1] + t = M - oe_del; + t = t > 0? t : 0; + e -= e_del; + e = e > t? e : t; // computed E(i+1,j) + p->e = e; // save E(i+1,j) for the next row + t = M - oe_ins; + t = t > 0? t : 0; + f -= e_ins; + f = f > t? f : t; // computed F(i,j+1) + } + eh[end].h = h1; eh[end].e = 0; + if (j == qlen) { + max_ie = gscore > h1? max_ie : i; + gscore = gscore > h1? gscore : h1; + } + if (m == 0) break; + if (m > max) { + max = m, max_i = i, max_j = mj; + max_off = max_off > abs(mj - i)? max_off : abs(mj - i); + } else if (zdrop > 0) { + if (i - max_i > mj - max_j) { + if (max - m - ((i - max_i) - (mj - max_j)) * e_del > zdrop) break; + } else { + if (max - m - ((mj - max_j) - (i - max_i)) * e_ins > zdrop) break; + } + } + // update beg and end for the next round + for (j = beg; LIKELY(j < end) && eh[j].h == 0 && eh[j].e == 0; ++j); + beg = j; + for (j = end; LIKELY(j >= beg) && eh[j].h == 0 && eh[j].e == 0; --j); + end = j + 2 < qlen? j + 2 : qlen; + //beg = 0; end = qlen; // uncomment this line for debugging + } + + free(eh); free(qp); free(qmem); + if (_qle) *_qle = max_j + 1; + if (_tle) *_tle = max_i + 1; + if (_gtle) *_gtle = max_ie + 1; + if (_gscore) *_gscore = gscore; + if (_max_off) *_max_off = max_off; + return max; +} diff --git a/ksw_extend2_avx2_u8.c b/ksw_extend2_avx2_u8.c new file mode 100644 index 0000000..9311c19 --- /dev/null +++ b/ksw_extend2_avx2_u8.c @@ -0,0 +1,370 @@ +#include +#include +#include +#include +#include +#include +#include + +#ifdef __GNUC__ +#define LIKELY(x) __builtin_expect((x),1) +#define UNLIKELY(x) __builtin_expect((x),0) +#else +#define LIKELY(x) (x) +#define UNLIKELY(x) (x) +#endif + +#undef MAX +#undef MIN +#define MAX(x, y) ((x) > (y) ? (x) : (y)) +#define MIN(x, y) ((x) < (y) ? (x) : (y)) +#define SIMD_WIDTH 32 + +static const uint8_t h_vec_int_mask[SIMD_WIDTH][SIMD_WIDTH] = { + {0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0}, + {0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff} +}; + +//static const uint8_t reverse_mask[SIMD_WIDTH] = {1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14, 1, 0, 3, 2, 5, 4, 7, 6, 9, 8, 11, 10, 13, 12, 15, 14}; +static const uint8_t reverse_mask[SIMD_WIDTH] = {7,6,5,4,3,2,1,0,15,14,13,12,11,10,9,8,7,6,5,4,3,2,1,0,15,14,13,12,11,10,9,8}; +#define permute_mask _MM_SHUFFLE(0, 1, 2, 3) +//const int permute_mask = _MM_SHUFFLE(0, 1, 2, 3); +// 初始化变量 +#define SIMD_INIT \ + int oe_del = o_del + e_del, oe_ins = o_ins + e_ins; \ + __m256i zero_vec; \ + __m256i max_vec; \ + __m256i oe_del_vec; \ + __m256i oe_ins_vec; \ + __m256i e_del_vec; \ + __m256i e_ins_vec; \ + __m256i h_vec_mask[SIMD_WIDTH]; \ + __m256i reverse_mask_vec; \ + zero_vec = _mm256_setzero_si256(); \ + oe_del_vec = _mm256_set1_epi8(oe_del); \ + oe_ins_vec = _mm256_set1_epi8(oe_ins); \ + e_del_vec = _mm256_set1_epi8(e_del); \ + e_ins_vec = _mm256_set1_epi8(e_ins); \ + __m256i match_sc_vec = _mm256_set1_epi8(a); \ + __m256i mis_sc_vec = _mm256_set1_epi8(b); \ + __m256i amb_sc_vec = _mm256_set1_epi8(1); \ + __m256i amb_vec = _mm256_set1_epi8(4); \ + reverse_mask_vec = _mm256_loadu_si256((__m256i*) (reverse_mask)); \ + for (i=0; i 0) { \ + for(j=beg, i=iend; j<=end; j+=SIMD_WIDTH, i-=SIMD_WIDTH) { \ + __m256i h2_vec = _mm256_loadu_si256((__m256i*) (&hA2[j])); \ + __m256i vcmp = _mm256_cmpeq_epi8(h2_vec, max_vec); \ + uint32_t mask = _mm256_movemask_epi8(vcmp); \ + if (mask > 0) { \ + int pos = SIMD_WIDTH - 1 - __builtin_clz(mask); \ + mj = j - 1 + pos; \ + mi = i - 1 - pos; \ + } \ + } \ + } + +// 每轮迭代后,交换数组 +#define SWAP_DATA_POINTER \ + uint8_t * tmp=hA0; \ + hA0 = hA1; hA1 = hA2; hA2 = tmp; \ + tmp = eA1; eA1 = eA2; eA2 = tmp; \ + tmp = fA1; fA1 = fA2; fA2 = tmp; \ + tmp = mA1; mA1 = mA2; mA2 = tmp; + + +int ksw_extend2_avx2_u8(int qlen, // query length 待匹配段碱基的query长度 + const uint8_t *query, // read碱基序列 + int tlen, // target length reference的长度 + const uint8_t *target, // reference序列 + int is_left, // 是不是向左扩展 + int m, // 碱基种类 (5) + const int8_t *mat, // 每个位置的query和target的匹配得分 m*m + int o_del, // deletion 错配开始的惩罚系数 + int e_del, // deletion extension的惩罚系数 + int o_ins, // insertion 错配开始的惩罚系数 + int e_ins, // insertion extension的惩罚系数 + int a, // 碱基match时的分数 + int b, // 碱基mismatch时的惩罚分数(正数) + int w, // 提前剪枝系数,w =100 匹配位置和beg的最大距离 + int end_bonus, + int zdrop, + int h0, // 该seed的初始得分(完全匹配query的碱基数) + int *_qle, // 匹配得到全局最大得分的碱基在query的位置 + int *_tle, // 匹配得到全局最大得分的碱基在reference的位置 + int *_gtle, // query全部匹配上的target的长度 + int *_gscore, // query的端到端匹配得分 + int *_max_off) // 取得最大得分时在query和reference上位置差的 最大值 +{ + uint8_t *mA,*hA, *eA, *fA, *mA1, *mA2, *hA0, *hA1, *eA1, *fA1, *hA2, *eA2, *fA2; // hA0保存上上个col的H,其他的保存上个H E F M + uint8_t *seq, *ref; + uint8_t *mem, *qtmem, *vmem; + int seq_size = qlen + SIMD_WIDTH, ref_size = tlen + SIMD_WIDTH; + int i, iStart, D, j, k, beg, end, max, max_i, max_j, max_ins, max_del, max_ie, gscore, max_off; + int Dloop = tlen + qlen; // 循环跳出条件 + int span, beg1, end1; // 边界条件计算 + int col_size = qlen + 2 + SIMD_WIDTH; + int val_mem_size = (col_size * 9 + 31) >> 5 << 5; // 32字节的整数倍 + int mem_size = seq_size + ref_size + val_mem_size; + + SIMD_INIT; // 初始化simd用的数据 + + assert(h0 > 0); + + // allocate memory + mem = malloc(mem_size); + qtmem = &mem[0]; + seq=(uint8_t*)&qtmem[0]; ref=(uint8_t*)&qtmem[seq_size]; + if (is_left) { + for (i=0; i mat[i]? max : mat[i]; + max_ins = (int)((double)(qlen * max + end_bonus - o_ins) / e_ins + 1.); + max_ins = max_ins > 1? max_ins : 1; + w = w < max_ins? w : max_ins; + max_del = (int)((double)(qlen * max + end_bonus - o_del) / e_del + 1.); + max_del = max_del > 1? max_del : 1; + w = w < max_del? w : max_del; // TODO: is this necessary? + if (tlen < qlen) w = MIN(tlen - 1, w); + + // DP loop + max = h0, max_i = max_j = -1; max_ie = -1, gscore = -1;; + max_off = 0; + beg = 1; end = qlen; + // init h0 + hA0[0] = h0; // 左上角 + + if (qlen == 0 || tlen == 0) Dloop = 0; // 防止意外情况 + if (w >= qlen) { max_ie = 0; gscore = 0; } + + int m_last=0; + int iend; + + for (D = 1; LIKELY(D < Dloop); ++D) { + // 边界条件一定要注意! tlen 大于,等于,小于 qlen时的情况 + if (D > tlen) { + span = MIN(Dloop-D, w); + beg1 = MAX(D-tlen+1, ((D-w) / 2) + 1); + } else { + span = MIN(D-1, w); + beg1 = MAX(1, ((D-w) / 2) + 1); + } + end1 = MIN(qlen, beg1+span); + + if (beg < beg1) beg = beg1; + if (end > end1) end = end1; + if (beg > end) break; // 不用计算了,直接跳出,否则hA2没有被赋值,里边是上一轮hA0的值,会出bug + + iend = D - (beg - 1); // ref开始计算的位置,倒序 + span = end - beg; + iStart = iend - span - 1; // 0开始的ref索引位置 + + // 每一轮需要记录的数据 + int m = 0, mj = -1, mi = -1; + max_vec = zero_vec; + + // 要处理边界 + // 左边界 处理f (insert) + if (iStart == 0) { hA1[end] = MAX(0, h0 - (o_ins + e_ins * end)); } + // 上边界 + if (beg == 1) { hA1[0] = MAX(0, h0 - (o_del + e_del * iend)); } + else { hA1[beg-1] = 0; eA1[beg-1] = 0; } + + for (j=beg, i=iend; j<=end+1-SIMD_WIDTH; j+=SIMD_WIDTH, i-=SIMD_WIDTH) { + // 取数据 + SIMD_LOAD; + // 比对seq,计算罚分 + SIMD_CMP_SEQ; + // 计算 + SIMD_COMPUTE; + // 存储结果 + SIMD_STORE; + } + // 剩下的计算单元 + if (j <= end) { + // 取数据 + SIMD_LOAD; + // 比对seq,计算罚分 + SIMD_CMP_SEQ; + // 计算 + SIMD_COMPUTE; + // 去除多余计算的部分 + SIMD_REMOVE_EXTRA; + // 存储结果 + SIMD_STORE; + } + + SIMD_FIND_MAX; + + // 注意最后跳出循环j的值 + j = end + 1; + + if (j == qlen + 1) { + max_ie = gscore > hA2[qlen] ? max_ie : iStart; + gscore = gscore > hA2[qlen] ? gscore : hA2[qlen]; + } + if (m == 0 && m_last==0) break; // 一定要注意,斜对角遍历和按列遍历的不同点 + if (m > max) { + max = m, max_i = mi, max_j = mj; + max_off = max_off > abs(mj - mi)? max_off : abs(mj - mi); + } + else if (zdrop > 0) { + if (mi - max_i > mj - max_j) { + if (max - m - ((mi - max_i) - (mj - max_j)) * e_del > zdrop) break; + } else { + if (max - m - ((mj - max_j) - (mi - max_i)) * e_ins > zdrop) break; + } + } + + // 调整计算的边界 + for (j = beg; LIKELY(j <= end); ++j) { int has_val = hA1[j-1] | hA2[j]; if (has_val) break; } + beg = j; + for (j = end+1; LIKELY(j >= beg); --j) { int has_val = hA1[j-1] | hA2[j]; if (has_val) break; else hA0[j-1]=0; } + end = j + 1 <= qlen? j + 1 : qlen; + + m_last = m; + // swap m, h, e, f + SWAP_DATA_POINTER; + } + + free(mem); + if (_qle) *_qle = max_j + 1; + if (_tle) *_tle = max_i + 1; + if (_gtle) *_gtle = max_ie + 1; + if (_gscore) *_gscore = gscore; + if (_max_off) *_max_off = max_off; + return max; +} diff --git a/run.sh b/run.sh index 961dc81..c8da9d1 100755 --- a/run.sh +++ b/run.sh @@ -1,13 +1,13 @@ thread=1 -#n_r1=~/data/fastq/ZY2105177532213000/n_r1.fq -#n_r2=~/data/fastq/ZY2105177532213000/n_r2.fq +n_r1=~/data/fastq/ZY2105177532213000/n_r1.fq +n_r2=~/data/fastq/ZY2105177532213000/n_r2.fq #n_r1=~/data/fastq/ZY2105177532213000/sn_r1.fq #n_r2=~/data/fastq/ZY2105177532213000/sn_r2.fq #reference=~/data/reference/human_g1k_v37_decoy.fasta #n_r1=~/fastq/sn_r1.fq #n_r2=~/fastq/sn_r2.fq -n_r1=~/fastq/ssn_r1.fq -n_r2=~/fastq/ssn_r2.fq +#n_r1=~/fastq/ssn_r1.fq +#n_r2=~/fastq/ssn_r2.fq #n_r1=~/fastq/tiny_n_r1.fq #n_r2=~/fastq/tiny_n_r2.fq #n_r1=~/fastq/diff_r1.fq @@ -15,9 +15,9 @@ n_r2=~/fastq/ssn_r2.fq #n_r1=~/fastq/d_r1.fq #n_r2=~/fastq/d_r2.fq reference=~/reference/human_g1k_v37_decoy.fasta -out=./ssn.sam +#out=./ssn.sam #out=./out.sam -#out=/dev/null +out=/dev/null #time ./bwa mem -t 12 -M -R @RG\\tID:normal\\tSM:normal\\tPL:illumina\\tLB:normal\\tPG:bwa \ # /home/zzh/data/reference/human_g1k_v37_decoy.fasta \ # /home/zzh/data/fastq/nm1.fq \ @@ -29,7 +29,7 @@ out=./ssn.sam # /mnt/d/data/fastq/ZY2105177532213000/ZY2105177532213010_L4_2.fq.gz \ # -o /dev/null -time ./bwa mem -t $thread -M -R @RG\\tID:normal\\tSM:normal\\tPL:illumina\\tLB:normal\\tPG:bwa \ +time ./bwa mem -b 64 -t $thread -M -R @RG\\tID:normal\\tSM:normal\\tPL:illumina\\tLB:normal\\tPG:bwa \ $reference \ $n_r1 \ $n_r2 \ diff --git a/utils.h b/utils.h index 5722787..b54ffe0 100644 --- a/utils.h +++ b/utils.h @@ -45,7 +45,8 @@ extern int64_t time_ksw_extend2, time_bwt_occ4, time_bwt_sa, time_bwt_sa_read, - time_bns; + time_bns, + time_core_process; extern int64_t dn, n16, n17, n18, n19, nall, num_sa; extern int64_t s1n, s2n, s3n, s1l, s2l, s3l; @@ -54,6 +55,11 @@ extern FILE *fp1; #endif +#undef MAX +#undef MIN +#define MAX(x, y) ((x) > (y) ? (x) : (y)) +#define MIN(x, y) ((x) < (y) ? (x) : (y)) + #ifdef __GNUC__ // Tell GCC to validate printf format string and args #define ATTRIBUTE(list) __attribute__ (list)