// twirls/MexFunc/CorrelationDist.cpp (C++, 334 lines, 9.2 KiB)
// Revision: 2023-10-05 10:38:21 +08:00
#include <mex.h>
#include <mat.h>

#include <algorithm>
#include <chrono>
#include <cmath>
#include <condition_variable>
#include <ctime>
#include <functional>
#include <future>
#include <iostream>
#include <limits>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <string>
#include <thread>
#include <unordered_set>
#include <vector>
#include <vector>

#include <immintrin.h>
#include <zmmintrin.h>
// Revision: 2023-10-05 10:38:21 +08:00
using std::cout;
using std::endl;
using namespace std;
// Fixed-size pool of worker threads that run queued callables in FIFO order.
// Destroying the pool lets the workers finish every task still queued, then
// joins them.
class ThreadPool {
public:
    // Starts `threads` worker threads. If starting any of them fails, the
    // workers already started are stopped and joined before the exception
    // propagates (destroying a joinable std::thread calls std::terminate).
    explicit ThreadPool(size_t);
    // Queues f(args...) and returns a future for its result. An exception
    // thrown by the task is stored in that future. Throws std::runtime_error
    // when the pool is already stopping.
    // NOTE(review): std::result_of is deprecated in C++17 and removed in C++20;
    // switch to std::invoke_result once the build targets C++17 or later.
    template<class F, class... Args>
    auto enqueue(F&& f, Args&&... args)
        -> std::future<typename std::result_of<F(Args...)>::type>;
    // Runs the remaining queued tasks, then joins all workers.
    ~ThreadPool();

    // The pool owns threads and synchronization primitives; it is not copyable.
    ThreadPool(const ThreadPool&) = delete;
    ThreadPool& operator=(const ThreadPool&) = delete;

private:
    // Raises the stop flag, wakes every worker and joins them.
    void stopAndJoin();

    // need to keep track of threads so we can join them
    std::vector<std::thread> workers;
    // the task queue
    std::queue<std::function<void()>> tasks;
    // synchronization: queue_mutex guards `tasks` and `stop`
    std::mutex queue_mutex;
    std::condition_variable condition;
    bool stop;
};

// Launches `threads` workers; each one pops tasks until the pool is stopped
// and the queue is empty.
inline ThreadPool::ThreadPool(size_t threads)
    : stop(false)
{
    workers.reserve(threads);
    try {
        for (size_t i = 0; i < threads; ++i)
            workers.emplace_back(
                [this]
                {
                    for (;;)
                    {
                        std::function<void()> task;
                        {
                            std::unique_lock<std::mutex> lock(this->queue_mutex);
                            this->condition.wait(lock,
                                [this] { return this->stop || !this->tasks.empty(); });
                            // Exit only when stopped AND drained, so queued work is never dropped.
                            if (this->stop && this->tasks.empty())
                                return;
                            task = std::move(this->tasks.front());
                            this->tasks.pop();
                        }
                        task();  // run outside the lock so other workers can dequeue
                    }
                }
            );
    } catch (...) {
        // Threads that did start are joinable; they must be joined before the
        // member vector is destroyed, otherwise std::terminate is called.
        stopAndJoin();
        throw;
    }
}

// add new work item to the pool
template<class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args)
    -> std::future<typename std::result_of<F(Args...)>::type>
{
    using return_type = typename std::result_of<F(Args...)>::type;
    // packaged_task is move-only; share it so the copyable std::function can hold it.
    auto task = std::make_shared< std::packaged_task<return_type()> >(
        std::bind(std::forward<F>(f), std::forward<Args>(args)...)
    );
    std::future<return_type> res = task->get_future();
    {
        std::unique_lock<std::mutex> lock(queue_mutex);
        // don't allow enqueueing after stopping the pool
        if (stop)
            throw std::runtime_error("enqueue on stopped ThreadPool");
        tasks.emplace([task]() { (*task)(); });
    }
    condition.notify_one();
    return res;
}

// Shared shutdown path used by the destructor and by a failed constructor.
inline void ThreadPool::stopAndJoin()
{
    {
        std::unique_lock<std::mutex> lock(queue_mutex);
        stop = true;
    }
    condition.notify_all();
    for (std::thread& worker : workers)
        worker.join();
}

// the destructor drains the queue and joins all threads
inline ThreadPool::~ThreadPool()
{
    stopAndJoin();
}
// Revision: 2023-10-05 10:38:21 +08:00
// Thread parameters: the work item handed to ThreadCalcDist by the thread pool.
// NOTE: call sites build this with brace (aggregate) initialization, so the
// member order below is part of its interface and must not change.
struct TPCorDist {
    std::vector<std::vector<float>>* pvvX;  // rows to correlate (mean-centered by mexFunction); not owned
    std::vector<double>* pvDist;            // condensed output distance vector; not owned
    std::vector<double>* pvSq;              // per-row mean of squared centered values; not owned
    int rowIdxStart;                        // first row processed by this task (inclusive)
    int rowIdxEnd;                          // end of this task's row range (exclusive)
    int rowNum;                             // total number of rows in *pvvX
    int colNum;                             // number of columns in each row
};
// Thread function: computes the correlation distance between every row i in
// [param.rowIdxStart, param.rowIdxEnd) and every later row cur (i < cur < rowNum):
//
//   r(i, cur) = mean_j(x_i[j] * x_cur[j]) / sqrt(vSq[i] * vSq[cur])
//   dist      = |1 - r(i, cur)|
//
// where x_i are the rows of *param.pvvX and vSq the per-row values in
// *param.pvSq (mexFunction fills them with mean-centered rows and their mean
// squares, which makes r the Pearson correlation). The result for pair
// (i, cur) is written to
//   (*param.pvDist)[i * (2 * rowNum - i - 1) / 2 + (cur - i - 1)],
// so tasks with disjoint row ranges write disjoint elements.
//
// The dot product uses AVX (8 floats per step) plus a scalar tail, so the
// build target must have AVX enabled.
void ThreadCalcDist(TPCorDist& param) {
    const std::vector<std::vector<float>>& vvX = *param.pvvX;
    std::vector<double>& vDist = *param.pvDist;
    const std::vector<double>& vSq = *param.pvSq;
    const int rowNum = param.rowNum;
    const int colNum = param.colNum;
    // Columns covered by full 8-wide AVX steps; invariant for the whole task.
    const int colNum8 = colNum / 8 * 8;
    const std::size_t rows = static_cast<std::size_t>(rowNum);

    for (int i = param.rowIdxStart; i < param.rowIdxEnd; ++i) {
        const float* u = vvX[i].data();
        // Offset of pair (i, i + 1). Computed in size_t because the int product
        // overflows once rowNum exceeds roughly 46,000. The product is always
        // even, so the division is exact.
        const std::size_t base =
            static_cast<std::size_t>(i) * (2 * rows - static_cast<std::size_t>(i) - 1) / 2;
        for (int cur = i + 1; cur < rowNum; ++cur) {
            const float* v = vvX[cur].data();

            __m256 vec_uv = _mm256_setzero_ps();
            for (int j = 0; j < colNum8; j += 8) {
                const __m256 vec_u = _mm256_loadu_ps(u + j);
                const __m256 vec_v = _mm256_loadu_ps(v + j);
                vec_uv = _mm256_add_ps(_mm256_mul_ps(vec_u, vec_v), vec_uv);
            }

            // Accumulate the scalar tail first, then the 8 AVX lanes in lane order.
            double uv = 0;
            for (int j = colNum8; j < colNum; ++j) uv += u[j] * v[j];
            float lanes[8];
            _mm256_storeu_ps(lanes, vec_uv);  // well-defined lane extraction (no pointer punning)
            for (float lane : lanes) uv += lane;

            uv /= colNum;
            vDist[base + static_cast<std::size_t>(cur - i - 1)] =
                std::fabs(1.0 - uv / std::sqrt(vSq[i] * vSq[cur]));
        }
    }
}
/* <20><><EFBFBD>ں<EFBFBD><DABA><EFBFBD> */
/*
<EFBFBD><EFBFBD><EFBFBD>
1. x: <EFBFBD><EFBFBD>ά<EFBFBD><EFBFBD>
[2]. numThread: <EFBFBD>߳<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
[3]. numGroup: ÿ<EFBFBD><EFBFBD><EFBFBD>̺߳<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
1. d: <EFBFBD><EFBFBD><EFBFBD>ؾ<EFBFBD><EFBFBD><EFBFBD>
*/
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
if (nrhs < 1) {
cout << "At least 1 arguments should be given for this function!" << endl;
return;
}
2023-10-05 10:38:21 +08:00
clock_t begin = clock(), mid, finish;
int rowNum = (int)mxGetM(prhs[0]);
int colNum = (int)mxGetN(prhs[0]);
double* pData = (double*)mxGetData(prhs[0]);
int numThread = 6;
if (nrhs > 1) {
double* pNumThread = (double*)mxGetData(prhs[1]);
numThread = (int)pNumThread[0];
if (numThread < 1) numThread = 1;
}
int numGroup = 1;
if (nrhs > 2) {
double* pNumGroup = (double*)mxGetData(prhs[2]);
numGroup = (int)pNumGroup[0];
if (numGroup < 1) numGroup = 1;
}
2023-10-05 10:38:21 +08:00
//cout << numThread << '\t' << numGroup << endl;
2023-10-05 10:38:21 +08:00
//for (int i = 0; i < rowNum; ++i) {
// for (int j = 0; j < colNum; ++j) {
// cout << pData[j * rowNum + i] << '\t';
// }
// cout << endl;
//}
// int rowNum = 2;
// int colNum = 5;
// vector<double> pData = { 1,1,2,2,3,9,4,4,5,4 };
//cout << rowNum << '\t' << colNum << endl;
2023-10-05 10:38:21 +08:00
/* <20><><EFBFBD><EFBFBD>ÿһ<C3BF>е<EFBFBD>ƽ<EFBFBD><C6BD><EFBFBD><EFBFBD> */
mid = clock();
vector<double> vMean(rowNum);
for (int j = 0; j < colNum; ++j) {
for (int i = 0; i < rowNum; ++i) {
vMean[i] += pData[j * rowNum + i];
}
}
for (int i = 0; i < rowNum; ++i) {
vMean[i] /= colNum;
}
finish = clock();
//cout << "<22><><EFBFBD><EFBFBD>ƽ<EFBFBD><C6BD><EFBFBD><EFBFBD>: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
2023-10-05 10:38:21 +08:00
/* <20><>ȥƽ<C8A5><C6BD>ֵ, <20><><EFBFBD><EFBFBD>ƽ<EFBFBD><C6BD> */
mid = clock();
vector<vector<float>> vvX(rowNum, vector<float>(colNum));
vector<double> vSq(rowNum);
for (int i = 0; i < rowNum; ++i) {
for (int j = 0; j < colNum; ++j) {
vvX[i][j] = pData[j * rowNum + i] - vMean[i];
vSq[i] += vvX[i][j] * vvX[i][j];
}
}
for (auto& val : vSq) { val /= colNum; }
finish = clock();
//cout << "<22><><EFBFBD><EFBFBD>ƽ<EFBFBD><C6BD>: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
2023-10-05 10:38:21 +08:00
/* <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ؾ<EFBFBD><D8BE><EFBFBD> */
// clock_t mid0 = clock();
// const int distSize = rowNum * (rowNum - 1) / 2;
// const int row2 = 2 * rowNum;
// vector<double> vDist(distSize, 0.0);
// //__m256d vec_zero = _mm256_set1_pd(0);
// __m256 vec_zero = _mm256_set1_ps(0);
// for (int i = 0; i < rowNum; ++i) {
// if (i % 100 == 0) {
// finish = clock();
// cout << "time " << i << ": " << (double)(finish - mid) / CLOCKS_PER_SEC << " s; " <<
// (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
// mid = clock();
// }
// const int baseIdx = i * (row2 - i - 1) / 2;
// for (int cur = i + 1, idx = 0; cur < rowNum; ++cur, ++idx) {
// double uv = 0;
// //for (int j = 0; j < colNum; ++j) { uv += vvX[i][j] * vvX[cur][j]; }
// //__m256d vec_uv = vec_zero;
// //int colNum4 = colNum / 4 * 4;
// //int remain = colNum - colNum4;
// //for (int j = 0; j < colNum4; j += 4) {
// // __m256d vec_u = _mm256_loadu_pd(vvX[i].data() + j);
// // __m256d vec_v = _mm256_loadu_pd(vvX[cur].data() + j);
// // vec_uv = _mm256_add_pd(_mm256_mul_pd(vec_u, vec_v), vec_uv);
// //}
// //for (int j = colNum4; j < colNum; ++j) {
// // uv += vvX[i][j] * vvX[cur][j];
// //}
// //double* pVec = (double*)&vec_uv;
// //for (int j = 0; j < 4; ++j) uv += pVec[j];
//
// __m256 vec_uv = vec_zero;
// int colNum8 = colNum / 8 * 8;
// int remain = colNum - colNum8;
// for (int j = 0; j < colNum8; j += 8) {
// __m256 vec_u = _mm256_loadu_ps(vvX[i].data() + j);
// __m256 vec_v = _mm256_loadu_ps(vvX[cur].data() + j);
// vec_uv = _mm256_add_ps(_mm256_mul_ps(vec_u, vec_v), vec_uv);
// }
// for (int j = colNum8; j < colNum; ++j) {
// uv += vvX[i][j] * vvX[cur][j];
// }
// float* pVec = (float*)&vec_uv;
// for (int j = 0; j < 8; ++j) uv += pVec[j];
//
// uv /= colNum;
// const double dist = abs(1.0 - uv / sqrt(vSq[i] * vSq[cur]));
// vDist[baseIdx + idx] = dist;
// }
// }
// finish = clock();
// cout << "<22><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ؾ<EFBFBD><D8BE><EFBFBD>: " << (double)(finish - mid0) / CLOCKS_PER_SEC << " s" << endl;
// for (auto& val : vDist) {cout << val << endl;}
/* <20><><EFBFBD>̼߳<DFB3><CCBC><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ؾ<EFBFBD><D8BE><EFBFBD> */
const int distSize = rowNum * (rowNum - 1) / 2;
const int row2 = 2 * rowNum;
vector<double> vDist(distSize, 0.0);
mid = clock();
// vector<TPCorDist> vTP;
ThreadPool thPool(numThread);
int span = numGroup;
2023-10-05 10:38:21 +08:00
int rowNumSpan = rowNum / span * span;
int i = 0;
for (i = 0; i < rowNumSpan; i += span) {
// vTP.push_back({&vvX, &vDist, &vSq, i, rowNum, colNum});
TPCorDist tp = { &vvX, &vDist, &vSq, i, i + span, rowNum, colNum };
thPool.enqueue(ThreadCalcDist, tp);
//vTP.push_back(tp);
}
if (i < rowNum) {
TPCorDist tp = { &vvX, &vDist, &vSq, i, rowNum, rowNum, colNum };
thPool.enqueue(ThreadCalcDist, tp);
//vTP.push_back(tp);
}
thPool.~ThreadPool();
//kt_for(6, ThreadCalcDist, vTP);
finish = clock();
//cout << "<22><><EFBFBD>̼߳<DFB3><CCBC><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ؾ<EFBFBD><D8BE><EFBFBD>: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
2023-10-05 10:38:21 +08:00
// for (auto& val : vDist) {cout << val << endl;}
/* д<><D0B4><EFBFBD><EFBFBD><EFBFBD><EFBFBD> */
if (nlhs > 0) { // b
mxArray* pWriteArray = NULL;
//<2F><><EFBFBD><EFBFBD>һ<EFBFBD><D2BB>rowNum*colNum<75>ľ<EFBFBD><C4BE><EFBFBD>
pWriteArray = mxCreateDoubleMatrix(1, distSize, mxREAL);
memcpy((void*)(mxGetPr(pWriteArray)), (void*)vDist.data(), sizeof(double) * distSize);
2023-10-05 10:38:21 +08:00
plhs[0] = pWriteArray; // <20><>ֵ<EFBFBD><D6B5><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ֵ
}
finish = clock();
cout << "Correlation Dist Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
}
/* Thin forwarding wrapper around the MEX gateway. Presumably it lets a
   standalone test main() drive mexFunction directly (the original, garbled
   comment mentions main) — verify against the test harness. */
void mexFunctionWrap(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
    return mexFunction(nlhs, plhs, nrhs, prhs);
}