twirls/CommonLib/kthread.h

146 lines
3.2 KiB
C
Raw Permalink Normal View History

#ifndef KTHREAD_H
#define KTHREAD_H
#include <stdlib.h>
#include <limits.h>
#include <thread>
#include <vector>
#include <atomic>
using std::atomic;
using std::thread;
using std::vector;
/************
* kt_for() *
************/
// Callback signature that receives the whole data vector, the index of the
// element to handle, and an int. kt_for() passes the calling worker's number
// as that int (0 when it runs without threads).
template <class T>
using FuncType3Arg = void (*)(std::vector<T>&, long, int);

// Callback signature that receives one element of the data vector by
// reference.
template <class T>
using FuncType1Arg = void (*)(T&);
// Forward declaration: every worker slot points back at the job it belongs to.
template <typename T>
struct kt_for_t;

// Bookkeeping for one worker thread of a kt_for() run.
template <class T>
struct ktf_worker_t
{
    // Job descriptor this worker belongs to.
    kt_for_t<T>* t;
    // Next index this worker will claim. Atomic because steal_work() lets
    // other workers advance it as well.
    std::atomic<long> i;
};
// Job descriptor for one kt_for() run. A single instance is filled in by
// kt_for() and reached by every worker through ktf_worker_t::t.
template <class T>
struct kt_for_t
{
    // Number of worker slots in `w`; also the stride each worker uses when
    // claiming indices.
    int n_threads;
    // Number of indices to process (kt_for() sets this to the data size).
    long n;
    // Array of `n_threads` worker slots.
    ktf_worker_t<T>* w;
    // Per-element callback; read only by ktf_worker_1_arg().
    FuncType1Arg<T> func1Arg;
    // (vector, index, worker number) callback; read only by
    // ktf_worker_3_arg().
    FuncType3Arg<T> func3Arg;
    // The vector whose elements are being processed.
    vector<T>* data;
};
// Claims one index from the worker whose counter is currently the lowest,
// i.e. the worker that is furthest behind.
//
// Returns the claimed index, or -1 when there is nothing left to claim.
//
// Changes from the previous version:
//  * Each counter is loaded exactly once per scan, so the value that is
//    compared is the value that is recorded.
//  * If no victim was found (no worker slots, or every counter is already
//    LONG_MAX) the function returns -1 instead of indexing w[-1].
//  * If the lowest counter seen is already >= n the function returns -1
//    without touching the counter. Counters only ever grow, so the
//    fetch_add would have produced a value >= n anyway; skipping it keeps
//    exhausted counters from being pushed further towards overflow.
template <class T>
static inline long steal_work(kt_for_t<T>* t)
{
    int victim = -1;
    long lowest = LONG_MAX;
    for (int i = 0; i < t->n_threads; ++i)
    {
        const long next = t->w[i].i.load();
        if (next < lowest)
        {
            lowest = next;
            victim = i;
        }
    }
    if (victim < 0 || lowest >= t->n)
        return -1;

    // Another thread may have advanced the victim's counter since the scan,
    // so the claimed value still has to be checked against n.
    const long claimed = t->w[victim].i.fetch_add(t->n_threads);
    return claimed >= t->n ? -1 : claimed;
}
template <class T>
static void ktf_worker_1_arg(void* data)
{
ktf_worker_t<T>* w = (ktf_worker_t<T> *)data;
long i;
for (;;)
{
i = w->i.fetch_add(w->t->n_threads);
if (i >= w->t->n)
break;
w->t->func1Arg(( * w->t->data)[i]);
}
while ((i = steal_work<T>(w->t)) >= 0)
w->t->func1Arg((*w->t->data)[i]);
}
template <class T>
static void ktf_worker_3_arg(void* data)
{
ktf_worker_t<T>* w = (ktf_worker_t<T> *)data;
long i;
for (;;)
{
i = w->i.fetch_add(w->t->n_threads);
if (i >= w->t->n)
break;
w->t->func3Arg(*w->t->data, i, w - w->t->w);
}
while ((i = steal_work<T>(w->t)) >= 0)
w->t->func3Arg(*w->t->data, i, w - w->t->w);
}
template <typename T>
void kt_for(int n_threads, FuncType3Arg<T> func, vector<T>& vData)
{
const long n = (long)vData.size();
if (n_threads > 1)
{
int i;
kt_for_t<T> t;
t.func3Arg = func, t.data = &vData, t.n_threads = n_threads, t.n = n;
t.w = (ktf_worker_t<T> *)alloca(n_threads * sizeof(ktf_worker_t<T>));
vector<thread> vThread;
for (i = 0; i < n_threads; ++i)
t.w[i].t = &t, t.w[i].i.store(i);
for (i = 0; i < n_threads; ++i)
vThread.push_back(thread(ktf_worker_3_arg<T>, &t.w[i]));
for (i = 0; i < n_threads; ++i)
vThread[i].join();
}
else
{
long j;
for (j = 0; j < n; ++j)
func(vData, j, 0);
}
}
template <typename T>
void kt_for(int n_threads, FuncType1Arg<T> func, vector<T>& vData)
{
const long n = (long)vData.size();
if (n_threads > 1)
{
int i;
kt_for_t<T> t;
t.func1Arg = func, t.data = &vData, t.n_threads = n_threads, t.n = n;
t.w = (ktf_worker_t<T> *)alloca(n_threads * sizeof(ktf_worker_t<T>));
vector<thread> vThread;
for (i = 0; i < n_threads; ++i)
t.w[i].t = &t, t.w[i].i.store(i);
for (i = 0; i < n_threads; ++i)
vThread.push_back(thread(ktf_worker_1_arg<T>, &t.w[i]));
for (i = 0; i < n_threads; ++i)
vThread[i].join();
}
else
{
long j;
for (j = 0; j < n; ++j)
func(vData[j]);
}
}
#endif