#ifndef KTHREAD_H
#define KTHREAD_H

#include <atomic>
#include <climits>
#include <cstdlib>
#include <thread>
#include <vector>

using std::atomic;
using std::thread;
using std::vector;

/************
 * kt_for() *
 ************/

/// Callback form that receives the whole vector, the index of the item to
/// process, and the id of the worker making the call (0 on the serial path).
template <typename T>
using FuncType3Arg = void (*)(vector<T>&, long, int);

/// Callback form that receives a single item by reference.
template <typename T>
using FuncType1Arg = void (*)(T&);

template <typename T>
struct kt_for_t;

/// Per-worker state: a back-pointer to the shared job and the next index this
/// worker will claim.
template <typename T>
struct ktf_worker_t {
    kt_for_t<T>* t;
    atomic<long> i;
};

/// State shared by all workers of one kt_for() call. Only one of func1Arg /
/// func3Arg is used by a given call.
template <typename T>
struct kt_for_t {
    int n_threads;
    long n;
    ktf_worker_t<T>* w;
    FuncType1Arg<T> func1Arg;
    FuncType3Arg<T> func3Arg;
    vector<T>* data;
};

/// Claims one index from the worker whose counter is currently the lowest.
///
/// @return the claimed index, or -1 when that counter has already reached
///         t->n (or no worker slot could be selected).
template <typename T>
static inline long steal_work(kt_for_t<T>* t)
{
    int min_i = -1;
    long min = LONG_MAX;
    for (int i = 0; i < t->n_threads; ++i) {
        // Load once so the comparison and the recorded minimum are the same value.
        const long cur = t->w[i].i.load();
        if (cur < min) {
            min = cur;
            min_i = i;
        }
    }
    if (min_i < 0) return -1;  // never index w[-1]
    const long k = t->w[min_i].i.fetch_add(t->n_threads);
    return k >= t->n ? -1 : k;
}

/// Thread entry point for the one-argument callback.
///
/// @param data pointer to this worker's ktf_worker_t<T>.
///
/// Walks the worker's own indices (start, start + n_threads, ...) and then
/// keeps claiming indices from other workers until steal_work() returns -1.
template <typename T>
static void ktf_worker_1_arg(void* data)
{
    ktf_worker_t<T>* w = static_cast<ktf_worker_t<T>*>(data);
    kt_for_t<T>* t = w->t;
    long i;
    for (;;) {
        i = w->i.fetch_add(t->n_threads);
        if (i >= t->n) break;
        t->func1Arg((*t->data)[i]);
    }
    while ((i = steal_work(t)) >= 0)
        t->func1Arg((*t->data)[i]);
}

/// Thread entry point for the three-argument callback.
///
/// @param data pointer to this worker's ktf_worker_t<T>.
///
/// Same scheduling as ktf_worker_1_arg(); the callback additionally receives
/// the item index and this worker's position in the worker array.
template <typename T>
static void ktf_worker_3_arg(void* data)
{
    ktf_worker_t<T>* w = static_cast<ktf_worker_t<T>*>(data);
    kt_for_t<T>* t = w->t;
    const int tid = static_cast<int>(w - t->w);
    long i;
    for (;;) {
        i = w->i.fetch_add(t->n_threads);
        if (i >= t->n) break;
        t->func3Arg(*t->data, i, tid);
    }
    while ((i = steal_work(t)) >= 0)
        t->func3Arg(*t->data, i, tid);
}

/// Private helper shared by both kt_for() overloads: creates the worker
/// slots, starts t.n_threads threads running `worker`, and joins them.
///
/// @param t      job description; t.w is pointed at storage owned by this call
///               and must not be used after it returns.
/// @param worker thread entry point (ktf_worker_1_arg<T> or ktf_worker_3_arg<T>).
template <typename T>
static void ktf_run(kt_for_t<T>& t, void (*worker)(void*))
{
    // Properly constructed, heap-backed slots (the element type is never
    // copied or moved: the vector is sized once and not resized).
    vector<ktf_worker_t<T>> workers(static_cast<std::size_t>(t.n_threads));
    t.w = workers.data();
    for (int i = 0; i < t.n_threads; ++i) {
        workers[i].t = &t;
        workers[i].i.store(i);
    }

    vector<thread> vThread;
    vThread.reserve(workers.size());
    try {
        for (auto& w : workers)
            vThread.emplace_back(worker, static_cast<void*>(&w));
    } catch (...) {
        // Destroying a joinable std::thread terminates the program, and the
        // running workers point at `t` and `workers`; wait for them first.
        for (auto& th : vThread) th.join();
        throw;
    }
    for (auto& th : vThread) th.join();
}

/// Applies `func(vData, index, worker_id)` to every index of vData.
///
/// @param n_threads number of worker threads; values <= 1 run serially on the
///                  calling thread with worker_id 0.
/// @param func      callback invoked once per index.
/// @param vData     items to process; must stay alive and unresized for the call.
template <typename T>
void kt_for(int n_threads, FuncType3Arg<T> func, vector<T>& vData)
{
    const long n = static_cast<long>(vData.size());
    if (n_threads > 1) {
        kt_for_t<T> t{};  // value-init so the unused callback pointer is null
        t.func3Arg = func;
        t.data = &vData;
        t.n_threads = n_threads;
        t.n = n;
        ktf_run(t, &ktf_worker_3_arg<T>);
    } else {
        for (long j = 0; j < n; ++j) func(vData, j, 0);
    }
}

/// Applies `func(item)` to every element of vData.
///
/// @param n_threads number of worker threads; values <= 1 run serially on the
///                  calling thread.
/// @param func      callback invoked once per element.
/// @param vData     items to process; must stay alive and unresized for the call.
template <typename T>
void kt_for(int n_threads, FuncType1Arg<T> func, vector<T>& vData)
{
    const long n = static_cast<long>(vData.size());
    if (n_threads > 1) {
        kt_for_t<T> t{};  // value-init so the unused callback pointer is null
        t.func1Arg = func;
        t.data = &vData;
        t.n_threads = n_threads;
        t.n = n;
        ktf_run(t, &ktf_worker_1_arg<T>);
    } else {
        for (long j = 0; j < n; ++j) func(vData[j]);
    }
}

#endif