twirls/MexFunc/CorrelationDist.cpp

324 lines
9.1 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include <mex.h>
#include <mat.h>

#include <immintrin.h>
#include <zmmintrin.h>

#include <algorithm>
#include <cmath>
#include <condition_variable>
#include <ctime>
#include <functional>
#include <future>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <stdexcept>
#include <string>
#include <thread>
#include <unordered_set>
#include <vector>
// #include "CommonLib/kthread.h"
// #include "CommonLib/thread_pool.h"
using std::cout;
using std::endl;
using namespace std;
/**
 * Fixed-size pool of worker threads that execute queued callables in FIFO
 * order.
 *
 * The constructor starts the requested number of workers. enqueue() wraps a
 * callable and its arguments in a std::packaged_task, pushes it on the shared
 * queue and returns the matching std::future. The destructor raises the stop
 * flag, lets the workers drain whatever is still queued, and joins them.
 */
class ThreadPool {
public:
    /// Starts `threads` workers, each running the worker loop until shutdown.
    ThreadPool(size_t threads)
        : stop(false)
    {
        for (size_t n = 0; n != threads; ++n) {
            workers.emplace_back([this] { this->runWorker(); });
        }
    }

    /**
     * Queues `f(args...)` for execution on one of the workers.
     *
     * @return a future that yields the callable's result (or rethrows the
     *         exception it raised).
     * @throws std::runtime_error if the pool has already been stopped.
     */
    template<class F, class... Args>
    auto enqueue(F&& f, Args&&... args)
        -> std::future<typename std::result_of<F(Args...)>::type>
    {
        typedef typename std::result_of<F(Args...)>::type result_t;
        typedef std::packaged_task<result_t()> job_t;

        // The queue stores copyable std::function objects, while a
        // packaged_task is move-only, so the task is shared through a pointer.
        std::shared_ptr<job_t> job = std::make_shared<job_t>(
            std::bind(std::forward<F>(f), std::forward<Args>(args)...));
        std::future<result_t> pending = job->get_future();

        {
            std::lock_guard<std::mutex> guard(queue_mutex);
            // Refuse new work once shutdown has begun.
            if (stop) {
                throw std::runtime_error("enqueue on stopped ThreadPool");
            }
            tasks.emplace([job]() { (*job)(); });
        }
        condition.notify_one();
        return pending;
    }

    /// Signals shutdown, wakes every worker and joins all of them.
    ~ThreadPool()
    {
        {
            std::lock_guard<std::mutex> guard(queue_mutex);
            stop = true;
        }
        condition.notify_all();
        for (size_t k = 0; k < workers.size(); ++k) {
            workers[k].join();
        }
    }

private:
    /// Body of every worker thread: pop one task at a time and run it.
    void runWorker()
    {
        while (true) {
            std::function<void()> job;
            {
                std::unique_lock<std::mutex> lock(queue_mutex);
                while (!stop && tasks.empty()) {
                    condition.wait(lock);
                }
                // The wait only ends with an empty queue when stop is set,
                // so an empty queue here means the worker should exit.
                if (tasks.empty()) {
                    return;
                }
                job = std::move(tasks.front());
                tasks.pop();
            }
            // Run the task outside the lock so other workers can dequeue.
            job();
        }
    }

    // worker threads, kept so the destructor can join them
    std::vector<std::thread> workers;
    // pending tasks in submission order
    std::queue<std::function<void()> > tasks;
    // synchronization of the queue and the stop flag
    std::mutex queue_mutex;
    std::condition_variable condition;
    bool stop;
};
// Thread parameters
/**
 * Work description for one ThreadCalcDist call.
 *
 * Plain aggregate: the field order below is the order used for brace
 * initialization. The three pointers are non-owning.
 */
struct TPCorDist {
    /// Row-major sample matrix: one inner vector per row.
    std::vector<std::vector<float> >* pvvX;
    /// Output buffer receiving the pairwise distances.
    std::vector<double>* pvDist;
    /// Per-row normalization terms used in the denominator.
    std::vector<double>* pvSq;
    /// First row handled by this task (inclusive).
    int rowIdxStart;
    /// One past the last row handled by this task.
    int rowIdxEnd;
    /// Total number of rows in the matrix.
    int rowNum;
    /// Number of columns (values per row).
    int colNum;
};
// Thread function
/**
 * Computes the correlation distance for every row pair (i, cur) with
 * rowIdxStart <= i < rowIdxEnd and i < cur < rowNum.
 *
 * For each pair the dot product of the two rows is accumulated (8 floats at a
 * time with AVX, plus a scalar tail), divided by colNum, normalized by
 * sqrt(vSq[i] * vSq[cur]) and stored as |1 - r|.
 *
 * Results are written in condensed upper-triangular order: pair (i, cur)
 * goes to index i * (2 * rowNum - i - 1) / 2 + (cur - i - 1) of *pvDist.
 * Distinct (i, cur) pairs map to distinct indices, so tasks covering disjoint
 * row ranges never write the same element.
 *
 * The caller in this file passes mean-centred rows in *pvvX and per-row mean
 * squares in *pvSq; a row with zero variance yields a non-finite distance.
 *
 * @param param work description; the pointed-to containers must outlive the
 *              call and *pvDist must already have its final size.
 */
void ThreadCalcDist(TPCorDist& param) {
    const std::vector<std::vector<float> >& vvX = *param.pvvX;
    std::vector<double>& vDist = *param.pvDist;
    const std::vector<double>& vSq = *param.pvSq;
    const int rowNum = param.rowNum;
    const int colNum = param.colNum;
    const int rowIdxStart = param.rowIdxStart;
    // Rows at or beyond rowNum have no partner rows, so never visit them.
    const int rowIdxEnd = param.rowIdxEnd < rowNum ? param.rowIdxEnd : rowNum;
    // Largest multiple of 8 not exceeding colNum: the part handled by AVX.
    const int colNum8 = colNum / 8 * 8;

    for (int i = rowIdxStart; i < rowIdxEnd; ++i) {
        // 64-bit arithmetic: the product exceeds the int range once rowNum
        // grows past roughly 46k rows.
        size_t outIdx = static_cast<size_t>(i)
            * (2 * static_cast<size_t>(rowNum) - static_cast<size_t>(i) - 1) / 2;
        const float* rowI = vvX[i].data();

        for (int cur = i + 1; cur < rowNum; ++cur, ++outIdx) {
            const float* rowCur = vvX[cur].data();

            __m256 vec_uv = _mm256_setzero_ps();
            for (int j = 0; j < colNum8; j += 8) {
                const __m256 vec_u = _mm256_loadu_ps(rowI + j);
                const __m256 vec_v = _mm256_loadu_ps(rowCur + j);
                vec_uv = _mm256_add_ps(_mm256_mul_ps(vec_u, vec_v), vec_uv);
            }

            // Scalar tail first, then the eight vector lanes, which keeps the
            // summation order of the previous implementation.
            double uv = 0;
            for (int j = colNum8; j < colNum; ++j) {
                uv += rowI[j] * rowCur[j];
            }
            // Spill the lanes with a store instead of aliasing the register
            // through a float pointer.
            float lanes[8];
            _mm256_storeu_ps(lanes, vec_uv);
            for (int j = 0; j < 8; ++j) {
                uv += lanes[j];
            }

            uv /= colNum;
            vDist[outIdx] = std::abs(1.0 - uv / std::sqrt(vSq[i] * vSq[cur]));
        }
    }
}
/* Entry function */
// void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) {
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
if (nrhs < 1) {
cout << "At least 1 arguments should be given for this function!" << endl;
return;
}
clock_t begin = clock(), mid, finish;
int rowNum = (int)mxGetM(prhs[0]);
int colNum = (int)mxGetN(prhs[0]);
double* pData = (double*)mxGetData(prhs[0]);
int numThread = 6;
if (nrhs > 1) {
double* pNumThread = (double*)mxGetData(prhs[1]);
numThread = (int)pNumThread[0];
if (numThread < 1) numThread = 1;
}
int numGroup = 1;
if (nrhs > 2) {
double* pNumGroup = (double*)mxGetData(prhs[2]);
numGroup = (int)pNumGroup[0];
if (numGroup < 1) numGroup = 1;
}
//cout << numThread << '\t' << numGroup << endl;
//for (int i = 0; i < rowNum; ++i) {
// for (int j = 0; j < colNum; ++j) {
// cout << pData[j * rowNum + i] << '\t';
// }
// cout << endl;
//}
// int rowNum = 2;
// int colNum = 5;
// vector<double> pData = { 1,1,2,2,3,9,4,4,5,4 };
//cout << rowNum << '\t' << colNum << endl;
/* <20><><EFBFBD><EFBFBD>ÿһ<C3BF>е<EFBFBD>ƽ<EFBFBD><C6BD><EFBFBD><EFBFBD> */
mid = clock();
vector<double> vMean(rowNum);
for (int j = 0; j < colNum; ++j) {
for (int i = 0; i < rowNum; ++i) {
vMean[i] += pData[j * rowNum + i];
}
}
for (int i = 0; i < rowNum; ++i) {
vMean[i] /= colNum;
}
finish = clock();
//cout << "<22><><EFBFBD><EFBFBD>ƽ<EFBFBD><C6BD><EFBFBD><EFBFBD>: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
/* <20><>ȥƽ<C8A5><C6BD>ֵ, <20><><EFBFBD><EFBFBD>ƽ<EFBFBD><C6BD> */
mid = clock();
vector<vector<float>> vvX(rowNum, vector<float>(colNum));
vector<double> vSq(rowNum);
for (int i = 0; i < rowNum; ++i) {
for (int j = 0; j < colNum; ++j) {
vvX[i][j] = pData[j * rowNum + i] - vMean[i];
vSq[i] += vvX[i][j] * vvX[i][j];
}
}
for (auto& val : vSq) { val /= colNum; }
finish = clock();
//cout << "<22><><EFBFBD><EFBFBD>ƽ<EFBFBD><C6BD>: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
/* <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ؾ<EFBFBD><D8BE><EFBFBD> */
// clock_t mid0 = clock();
// const int distSize = rowNum * (rowNum - 1) / 2;
// const int row2 = 2 * rowNum;
// vector<double> vDist(distSize, 0.0);
// //__m256d vec_zero = _mm256_set1_pd(0);
// __m256 vec_zero = _mm256_set1_ps(0);
// for (int i = 0; i < rowNum; ++i) {
// if (i % 100 == 0) {
// finish = clock();
// cout << "time " << i << ": " << (double)(finish - mid) / CLOCKS_PER_SEC << " s; " <<
// (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
// mid = clock();
// }
// const int baseIdx = i * (row2 - i - 1) / 2;
// for (int cur = i + 1, idx = 0; cur < rowNum; ++cur, ++idx) {
// double uv = 0;
// //for (int j = 0; j < colNum; ++j) { uv += vvX[i][j] * vvX[cur][j]; }
// //__m256d vec_uv = vec_zero;
// //int colNum4 = colNum / 4 * 4;
// //int remain = colNum - colNum4;
// //for (int j = 0; j < colNum4; j += 4) {
// // __m256d vec_u = _mm256_loadu_pd(vvX[i].data() + j);
// // __m256d vec_v = _mm256_loadu_pd(vvX[cur].data() + j);
// // vec_uv = _mm256_add_pd(_mm256_mul_pd(vec_u, vec_v), vec_uv);
// //}
// //for (int j = colNum4; j < colNum; ++j) {
// // uv += vvX[i][j] * vvX[cur][j];
// //}
// //double* pVec = (double*)&vec_uv;
// //for (int j = 0; j < 4; ++j) uv += pVec[j];
//
// __m256 vec_uv = vec_zero;
// int colNum8 = colNum / 8 * 8;
// int remain = colNum - colNum8;
// for (int j = 0; j < colNum8; j += 8) {
// __m256 vec_u = _mm256_loadu_ps(vvX[i].data() + j);
// __m256 vec_v = _mm256_loadu_ps(vvX[cur].data() + j);
// vec_uv = _mm256_add_ps(_mm256_mul_ps(vec_u, vec_v), vec_uv);
// }
// for (int j = colNum8; j < colNum; ++j) {
// uv += vvX[i][j] * vvX[cur][j];
// }
// float* pVec = (float*)&vec_uv;
// for (int j = 0; j < 8; ++j) uv += pVec[j];
//
// uv /= colNum;
// const double dist = abs(1.0 - uv / sqrt(vSq[i] * vSq[cur]));
// vDist[baseIdx + idx] = dist;
// }
// }
// finish = clock();
// cout << "<22><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ؾ<EFBFBD><D8BE><EFBFBD>: " << (double)(finish - mid0) / CLOCKS_PER_SEC << " s" << endl;
// for (auto& val : vDist) {cout << val << endl;}
/* <20><><EFBFBD>̼߳<DFB3><CCBC><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ؾ<EFBFBD><D8BE><EFBFBD> */
const int distSize = rowNum * (rowNum - 1) / 2;
const int row2 = 2 * rowNum;
vector<double> vDist(distSize, 0.0);
mid = clock();
// vector<TPCorDist> vTP;
ThreadPool thPool(numThread);
int span = numGroup;
int rowNumSpan = rowNum / span * span;
int i = 0;
for (i = 0; i < rowNumSpan; i += span) {
// vTP.push_back({&vvX, &vDist, &vSq, i, rowNum, colNum});
TPCorDist tp = { &vvX, &vDist, &vSq, i, i + span, rowNum, colNum };
thPool.enqueue(ThreadCalcDist, tp);
//vTP.push_back(tp);
}
if (i < rowNum) {
TPCorDist tp = { &vvX, &vDist, &vSq, i, rowNum, rowNum, colNum };
thPool.enqueue(ThreadCalcDist, tp);
//vTP.push_back(tp);
}
thPool.~ThreadPool();
//kt_for(6, ThreadCalcDist, vTP);
finish = clock();
//cout << "<22><><EFBFBD>̼߳<DFB3><CCBC><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ؾ<EFBFBD><D8BE><EFBFBD>: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
// for (auto& val : vDist) {cout << val << endl;}
/* д<><D0B4><EFBFBD><EFBFBD><EFBFBD><EFBFBD> */
if (nlhs > 0) { // b
mxArray* pWriteArray = NULL;
//<2F><><EFBFBD><EFBFBD>һ<EFBFBD><D2BB>rowNum*colNum<75>ľ<EFBFBD><C4BE><EFBFBD>
pWriteArray = mxCreateDoubleMatrix(1, distSize, mxREAL);
memcpy((void*)(mxGetPr(pWriteArray)), (void*)vDist.data(), sizeof(double) * distSize);
plhs[0] = pWriteArray; // <20><>ֵ<EFBFBD><D6B5><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ֵ
}
finish = clock();
cout << "Correlation Dist Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
}