2023-10-05 10:38:21 +08:00
|
|
|
|
#include <mex.h>
#include <mat.h>

#include <iostream>
#include <algorithm>
#include <vector>
#include <string>
#include <unordered_set>
#include <ctime>
#include <chrono>
#include <climits>
#include <cmath>
#include <cstring>
#include <immintrin.h>
#include <zmmintrin.h>

// thread-pool dependencies
#include <vector>
#include <queue>
#include <memory>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <future>
#include <functional>
#include <stdexcept>

// #include "CommonLib/kthread.h"
// #include "CommonLib/thread_pool.h"

using std::cout;
using std::endl;
using namespace std;
|
|
|
|
|
|
|
2023-10-05 23:12:02 +08:00
|
|
|
|
// Fixed-size pool of worker threads that execute queued tasks in FIFO order.
// Submit work with enqueue(); the destructor lets the workers finish every task
// that is still queued and then joins them.
class ThreadPool {
public:
    // Starts `threads` worker threads. If starting one of them fails, the workers
    // that are already running are stopped and joined before the exception
    // (e.g. std::system_error from std::thread) propagates, instead of letting a
    // joinable std::thread be destroyed (which would call std::terminate).
    ThreadPool(size_t threads);

    // Queues f(args...) for execution on a worker and returns a future for its
    // result; an exception thrown by the task is stored in that future.
    // Throws std::runtime_error if the pool is already stopping.
    template<class F, class... Args>
    auto enqueue(F&& f, Args&&... args)
        -> std::future<typename std::result_of<F(Args...)>::type>;

    // Stops accepting work, lets the workers drain the queue, and joins them.
    ~ThreadPool();

private:
    // Body of every worker thread: runs tasks until `stop` is set and the queue is empty.
    void workerLoop();

    // Sets `stop`, wakes all workers and joins every thread in `workers`.
    void stopAndJoin();

    // need to keep track of threads so we can join them
    std::vector< std::thread > workers;
    // the task queue
    std::queue< std::function<void()> > tasks;

    // synchronization
    std::mutex queue_mutex;
    std::condition_variable condition;
    bool stop;  // guarded by queue_mutex
};

// the constructor just launches some amount of workers
inline ThreadPool::ThreadPool(size_t threads)
    : stop(false)
{
    try {
        // Reserving up front means emplace_back never reallocates while threads run.
        workers.reserve(threads);
        for (size_t i = 0; i < threads; ++i)
            workers.emplace_back([this] { workerLoop(); });
    } catch (...) {
        // The destructor will not run for a half-constructed object, so shut down
        // the threads that did start here; destroying them while joinable would terminate.
        stopAndJoin();
        throw;
    }
}

inline void ThreadPool::workerLoop()
{
    for (;;)
    {
        std::function<void()> task;

        {
            std::unique_lock<std::mutex> lock(queue_mutex);
            condition.wait(lock, [this] { return stop || !tasks.empty(); });
            // Only exit once the queue is drained, so queued work is never dropped.
            if (stop && tasks.empty())
                return;
            task = std::move(tasks.front());
            tasks.pop();
        }

        // Run outside the lock so other workers can pick up tasks concurrently.
        task();
    }
}

inline void ThreadPool::stopAndJoin()
{
    {
        std::unique_lock<std::mutex> lock(queue_mutex);
        stop = true;
    }
    condition.notify_all();
    for (std::thread& worker : workers)
        worker.join();
}

// add new work item to the pool
template<class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args)
    -> std::future<typename std::result_of<F(Args...)>::type>
{
    using return_type = typename std::result_of<F(Args...)>::type;

    // packaged_task is move-only; holding it in a shared_ptr lets the copyable
    // std::function<void()> in the queue refer to it.
    auto task = std::make_shared< std::packaged_task<return_type()> >(
        std::bind(std::forward<F>(f), std::forward<Args>(args)...)
    );

    std::future<return_type> res = task->get_future();
    {
        std::unique_lock<std::mutex> lock(queue_mutex);

        // don't allow enqueueing after stopping the pool
        if (stop)
            throw std::runtime_error("enqueue on stopped ThreadPool");

        tasks.emplace([task]() { (*task)(); });
    }
    condition.notify_one();
    return res;
}

// the destructor joins all threads (after they have drained the task queue)
inline ThreadPool::~ThreadPool()
{
    stopAndJoin();
}
|
|
|
|
|
|
|
2023-10-05 10:38:21 +08:00
|
|
|
|
// Thread parameters: the work description handed to one ThreadCalcDist task.
|
|
|
|
|
|
struct TPCorDist {
    std::vector<std::vector<float>>* pvvX;  // mean-centered rows (one std::vector<float> per row), shared by all tasks
    std::vector<double>* pvDist;            // output: pairwise distances, upper triangle stored row by row
    std::vector<double>* pvSq;              // per-row mean of the squared centered values
    int rowIdxStart;                        // first row this task handles (inclusive)
    int rowIdxEnd;                          // end of this task's row range (exclusive)
    int rowNum;                             // total number of rows in *pvvX
    int colNum;                             // number of columns in every row
};
|
|
|
|
|
|
|
|
|
|
|
|
// Thread function: computes the correlation distances for one block of rows.
|
|
|
|
|
|
// For every pair (i, cur) with rowIdxStart <= i < rowIdxEnd and i < cur < rowNum,
// writes
//     |1 - mean_j(X[i][j] * X[cur][j]) / sqrt(vSq[i] * vSq[cur])|
// into *param.pvDist at the pair's position in the upper triangle stored row by
// row: (0,1), (0,2), ..., (0,rowNum-1), (1,2), ...  Tasks given non-overlapping
// row ranges write disjoint parts of the output, so they can run concurrently
// without locking (the output vector is never resized here).
//
// The dot product handles 8 floats per step with AVX intrinsics plus a scalar
// loop for the remaining colNum % 8 columns, so the CPU must support AVX.
void ThreadCalcDist(TPCorDist& param) {
    const std::vector<std::vector<float>>& vvX = *param.pvvX;
    std::vector<double>& vDist = *param.pvDist;
    const std::vector<double>& vSq = *param.pvSq;
    const int rowIdxStart = param.rowIdxStart;
    const int rowIdxEnd = param.rowIdxEnd;
    const int rowNum = param.rowNum;
    const int colNum = param.colNum;

    // Columns [0, colNum8) go through the AVX loop, [colNum8, colNum) through the scalar tail.
    const int colNum8 = colNum / 8 * 8;
    const std::size_t rows = static_cast<std::size_t>(rowNum);

    for (int i = rowIdxStart; i < rowIdxEnd; ++i) {
        const std::size_t ui = static_cast<std::size_t>(i);
        // Rows 0..i-1 own (rows-1) + (rows-2) + ... + (rows-i) = i*(2*rows-i-1)/2 entries.
        // Evaluated in size_t: in int the product overflows once rowNum exceeds ~32768.
        const std::size_t baseIdx = ui * (2 * rows - ui - 1) / 2;
        const float* u = vvX[i].data();

        for (int cur = i + 1; cur < rowNum; ++cur) {
            const float* v = vvX[cur].data();

            __m256 vec_uv = _mm256_setzero_ps();
            for (int j = 0; j < colNum8; j += 8) {
                const __m256 vec_u = _mm256_loadu_ps(u + j);
                const __m256 vec_v = _mm256_loadu_ps(v + j);
                vec_uv = _mm256_add_ps(_mm256_mul_ps(vec_u, vec_v), vec_uv);
            }

            // Accumulate the scalar tail first, then the 8 AVX lanes.
            double uv = 0;
            for (int j = colNum8; j < colNum; ++j) {
                uv += u[j] * v[j];
            }
            float lanes[8];
            _mm256_storeu_ps(lanes, vec_uv);
            for (int k = 0; k < 8; ++k) uv += lanes[k];

            uv /= colNum;
            vDist[baseIdx + static_cast<std::size_t>(cur - i - 1)] =
                std::fabs(1.0 - uv / std::sqrt(vSq[i] * vSq[cur]));
        }
    }
}
|
|
|
|
|
|
|
|
|
|
|
|
/* MEX entry function */
|
2023-10-05 23:12:02 +08:00
|
|
|
|
// void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) {
|
|
|
|
|
|
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
|
|
|
|
|
|
if (nrhs < 1) {
|
|
|
|
|
|
cout << "At least 1 arguments should be given for this function!" << endl;
|
|
|
|
|
|
return;
|
|
|
|
|
|
}
|
2023-10-05 10:38:21 +08:00
|
|
|
|
clock_t begin = clock(), mid, finish;
|
|
|
|
|
|
|
|
|
|
|
|
int rowNum = (int)mxGetM(prhs[0]);
|
|
|
|
|
|
int colNum = (int)mxGetN(prhs[0]);
|
|
|
|
|
|
double* pData = (double*)mxGetData(prhs[0]);
|
2023-10-05 23:12:02 +08:00
|
|
|
|
|
|
|
|
|
|
int numThread = 6;
|
|
|
|
|
|
if (nrhs > 1) {
|
|
|
|
|
|
double* pNumThread = (double*)mxGetData(prhs[1]);
|
|
|
|
|
|
numThread = (int)pNumThread[0];
|
|
|
|
|
|
if (numThread < 1) numThread = 1;
|
|
|
|
|
|
}
|
|
|
|
|
|
int numGroup = 1;
|
|
|
|
|
|
if (nrhs > 2) {
|
|
|
|
|
|
double* pNumGroup = (double*)mxGetData(prhs[2]);
|
|
|
|
|
|
numGroup = (int)pNumGroup[0];
|
|
|
|
|
|
if (numGroup < 1) numGroup = 1;
|
|
|
|
|
|
}
|
2023-10-05 10:38:21 +08:00
|
|
|
|
|
2023-10-05 23:12:02 +08:00
|
|
|
|
//cout << numThread << '\t' << numGroup << endl;
|
2023-10-05 10:38:21 +08:00
|
|
|
|
|
|
|
|
|
|
//for (int i = 0; i < rowNum; ++i) {
|
|
|
|
|
|
// for (int j = 0; j < colNum; ++j) {
|
|
|
|
|
|
// cout << pData[j * rowNum + i] << '\t';
|
|
|
|
|
|
// }
|
|
|
|
|
|
// cout << endl;
|
|
|
|
|
|
//}
|
|
|
|
|
|
|
|
|
|
|
|
// int rowNum = 2;
|
|
|
|
|
|
// int colNum = 5;
|
|
|
|
|
|
// vector<double> pData = { 1,1,2,2,3,9,4,4,5,4 };
|
|
|
|
|
|
|
2023-10-05 23:12:02 +08:00
|
|
|
|
//cout << rowNum << '\t' << colNum << endl;
|
2023-10-05 10:38:21 +08:00
|
|
|
|
|
|
|
|
|
|
/* <20><><EFBFBD><EFBFBD>ÿһ<C3BF>е<EFBFBD>ƽ<EFBFBD><C6BD><EFBFBD><EFBFBD> */
|
|
|
|
|
|
mid = clock();
|
|
|
|
|
|
vector<double> vMean(rowNum);
|
|
|
|
|
|
for (int j = 0; j < colNum; ++j) {
|
|
|
|
|
|
for (int i = 0; i < rowNum; ++i) {
|
|
|
|
|
|
vMean[i] += pData[j * rowNum + i];
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
for (int i = 0; i < rowNum; ++i) {
|
|
|
|
|
|
vMean[i] /= colNum;
|
|
|
|
|
|
}
|
|
|
|
|
|
finish = clock();
|
2023-10-05 23:12:02 +08:00
|
|
|
|
//cout << "<22><><EFBFBD><EFBFBD>ƽ<EFBFBD><C6BD><EFBFBD><EFBFBD>: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
|
2023-10-05 10:38:21 +08:00
|
|
|
|
/* <20><>ȥƽ<C8A5><C6BD>ֵ, <20><><EFBFBD><EFBFBD>ƽ<EFBFBD><C6BD> */
|
|
|
|
|
|
mid = clock();
|
|
|
|
|
|
vector<vector<float>> vvX(rowNum, vector<float>(colNum));
|
|
|
|
|
|
vector<double> vSq(rowNum);
|
|
|
|
|
|
for (int i = 0; i < rowNum; ++i) {
|
|
|
|
|
|
for (int j = 0; j < colNum; ++j) {
|
|
|
|
|
|
vvX[i][j] = pData[j * rowNum + i] - vMean[i];
|
|
|
|
|
|
vSq[i] += vvX[i][j] * vvX[i][j];
|
|
|
|
|
|
}
|
|
|
|
|
|
}
|
|
|
|
|
|
for (auto& val : vSq) { val /= colNum; }
|
|
|
|
|
|
finish = clock();
|
2023-10-05 23:12:02 +08:00
|
|
|
|
//cout << "<22><><EFBFBD><EFBFBD>ƽ<EFBFBD><C6BD>: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
|
2023-10-05 10:38:21 +08:00
|
|
|
|
|
|
|
|
|
|
/* <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ؾ<EFBFBD><D8BE><EFBFBD> */
|
|
|
|
|
|
// clock_t mid0 = clock();
|
|
|
|
|
|
// const int distSize = rowNum * (rowNum - 1) / 2;
|
|
|
|
|
|
// const int row2 = 2 * rowNum;
|
|
|
|
|
|
// vector<double> vDist(distSize, 0.0);
|
|
|
|
|
|
// //__m256d vec_zero = _mm256_set1_pd(0);
|
|
|
|
|
|
// __m256 vec_zero = _mm256_set1_ps(0);
|
|
|
|
|
|
// for (int i = 0; i < rowNum; ++i) {
|
|
|
|
|
|
// if (i % 100 == 0) {
|
|
|
|
|
|
// finish = clock();
|
|
|
|
|
|
// cout << "time " << i << ": " << (double)(finish - mid) / CLOCKS_PER_SEC << " s; " <<
|
|
|
|
|
|
// (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
|
|
|
|
|
|
// mid = clock();
|
|
|
|
|
|
// }
|
|
|
|
|
|
// const int baseIdx = i * (row2 - i - 1) / 2;
|
|
|
|
|
|
// for (int cur = i + 1, idx = 0; cur < rowNum; ++cur, ++idx) {
|
|
|
|
|
|
// double uv = 0;
|
|
|
|
|
|
// //for (int j = 0; j < colNum; ++j) { uv += vvX[i][j] * vvX[cur][j]; }
|
|
|
|
|
|
// //__m256d vec_uv = vec_zero;
|
|
|
|
|
|
// //int colNum4 = colNum / 4 * 4;
|
|
|
|
|
|
// //int remain = colNum - colNum4;
|
|
|
|
|
|
// //for (int j = 0; j < colNum4; j += 4) {
|
|
|
|
|
|
// // __m256d vec_u = _mm256_loadu_pd(vvX[i].data() + j);
|
|
|
|
|
|
// // __m256d vec_v = _mm256_loadu_pd(vvX[cur].data() + j);
|
|
|
|
|
|
// // vec_uv = _mm256_add_pd(_mm256_mul_pd(vec_u, vec_v), vec_uv);
|
|
|
|
|
|
// //}
|
|
|
|
|
|
// //for (int j = colNum4; j < colNum; ++j) {
|
|
|
|
|
|
// // uv += vvX[i][j] * vvX[cur][j];
|
|
|
|
|
|
// //}
|
|
|
|
|
|
// //double* pVec = (double*)&vec_uv;
|
|
|
|
|
|
// //for (int j = 0; j < 4; ++j) uv += pVec[j];
|
|
|
|
|
|
//
|
|
|
|
|
|
// __m256 vec_uv = vec_zero;
|
|
|
|
|
|
// int colNum8 = colNum / 8 * 8;
|
|
|
|
|
|
// int remain = colNum - colNum8;
|
|
|
|
|
|
// for (int j = 0; j < colNum8; j += 8) {
|
|
|
|
|
|
// __m256 vec_u = _mm256_loadu_ps(vvX[i].data() + j);
|
|
|
|
|
|
// __m256 vec_v = _mm256_loadu_ps(vvX[cur].data() + j);
|
|
|
|
|
|
// vec_uv = _mm256_add_ps(_mm256_mul_ps(vec_u, vec_v), vec_uv);
|
|
|
|
|
|
// }
|
|
|
|
|
|
// for (int j = colNum8; j < colNum; ++j) {
|
|
|
|
|
|
// uv += vvX[i][j] * vvX[cur][j];
|
|
|
|
|
|
// }
|
|
|
|
|
|
// float* pVec = (float*)&vec_uv;
|
|
|
|
|
|
// for (int j = 0; j < 8; ++j) uv += pVec[j];
|
|
|
|
|
|
//
|
|
|
|
|
|
// uv /= colNum;
|
|
|
|
|
|
// const double dist = abs(1.0 - uv / sqrt(vSq[i] * vSq[cur]));
|
|
|
|
|
|
// vDist[baseIdx + idx] = dist;
|
|
|
|
|
|
// }
|
|
|
|
|
|
// }
|
|
|
|
|
|
// finish = clock();
|
|
|
|
|
|
// cout << "<22><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ؾ<EFBFBD><D8BE><EFBFBD>: " << (double)(finish - mid0) / CLOCKS_PER_SEC << " s" << endl;
|
|
|
|
|
|
// for (auto& val : vDist) {cout << val << endl;}
|
|
|
|
|
|
|
|
|
|
|
|
/* <20><><EFBFBD>̼߳<DFB3><CCBC><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ؾ<EFBFBD><D8BE><EFBFBD> */
|
|
|
|
|
|
const int distSize = rowNum * (rowNum - 1) / 2;
|
|
|
|
|
|
const int row2 = 2 * rowNum;
|
|
|
|
|
|
vector<double> vDist(distSize, 0.0);
|
|
|
|
|
|
mid = clock();
|
2023-10-05 23:12:02 +08:00
|
|
|
|
// vector<TPCorDist> vTP;
|
|
|
|
|
|
ThreadPool thPool(numThread);
|
|
|
|
|
|
int span = numGroup;
|
2023-10-05 10:38:21 +08:00
|
|
|
|
int rowNumSpan = rowNum / span * span;
|
|
|
|
|
|
int i = 0;
|
|
|
|
|
|
for (i = 0; i < rowNumSpan; i += span) {
|
|
|
|
|
|
// vTP.push_back({&vvX, &vDist, &vSq, i, rowNum, colNum});
|
|
|
|
|
|
TPCorDist tp = { &vvX, &vDist, &vSq, i, i + span, rowNum, colNum };
|
|
|
|
|
|
thPool.enqueue(ThreadCalcDist, tp);
|
|
|
|
|
|
//vTP.push_back(tp);
|
|
|
|
|
|
}
|
|
|
|
|
|
if (i < rowNum) {
|
|
|
|
|
|
TPCorDist tp = { &vvX, &vDist, &vSq, i, rowNum, rowNum, colNum };
|
|
|
|
|
|
thPool.enqueue(ThreadCalcDist, tp);
|
|
|
|
|
|
//vTP.push_back(tp);
|
|
|
|
|
|
}
|
|
|
|
|
|
thPool.~ThreadPool();
|
|
|
|
|
|
//kt_for(6, ThreadCalcDist, vTP);
|
|
|
|
|
|
|
|
|
|
|
|
finish = clock();
|
2023-10-05 23:12:02 +08:00
|
|
|
|
//cout << "<22><><EFBFBD>̼߳<DFB3><CCBC><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ؾ<EFBFBD><D8BE><EFBFBD>: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
|
2023-10-05 10:38:21 +08:00
|
|
|
|
// for (auto& val : vDist) {cout << val << endl;}
|
|
|
|
|
|
|
|
|
|
|
|
/* д<><D0B4><EFBFBD><EFBFBD><EFBFBD><EFBFBD> */
|
|
|
|
|
|
if (nlhs > 0) { // b
|
|
|
|
|
|
mxArray* pWriteArray = NULL;
|
|
|
|
|
|
//<2F><><EFBFBD><EFBFBD>һ<EFBFBD><D2BB>rowNum*colNum<75>ľ<EFBFBD><C4BE><EFBFBD>
|
|
|
|
|
|
pWriteArray = mxCreateDoubleMatrix(1, distSize, mxREAL);
|
2023-10-05 23:12:02 +08:00
|
|
|
|
memcpy((void*)(mxGetPr(pWriteArray)), (void*)vDist.data(), sizeof(double) * distSize);
|
2023-10-05 10:38:21 +08:00
|
|
|
|
plhs[0] = pWriteArray; // <20><>ֵ<EFBFBD><D6B5><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ֵ
|
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
|
|
finish = clock();
|
|
|
|
|
|
cout << "Correlation Dist Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
|
|
|
|
|
|
}
|