// Viewer metadata: 224 lines, 6.8 KiB, C++
#include <mex.h>
#include <mat.h>

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstring>
#include <ctime>
#include <iostream>
#include <string>
#include <unordered_set>
#include <vector>

#include <immintrin.h>
#include <zmmintrin.h>

#include "CommonLib/kthread.h"
#include "CommonLib/thread_pool.h"

using std::cout;
using std::endl;
using namespace std;

// Thread parameters: one work item handed to ThreadCalcDist.
// The struct is a plain aggregate (it is brace-initialized in field order) and
// has no destructor, so the pointed-to vectors must be kept alive by whoever
// builds the work item until the task has finished.
struct TPCorDist {
    std::vector<std::vector<float>>* pvvX;  // input rows, indexed as [row][column]
    std::vector<double>* pvDist;            // output vector of pairwise distances
    std::vector<double>* pvSq;              // per-row value read as [row] in the distance denominator
    int rowIdxStart;                        // first row handled by this work item (inclusive)
    int rowIdxEnd;                          // one past the last row handled by this work item
    int rowNum;                             // total number of rows
    int colNum;                             // number of columns per row
};

// 线程函数
|
|
void ThreadCalcDist(TPCorDist& param) {
|
|
vector<vector<float>>& vvX = *param.pvvX;
|
|
vector<double>& vDist = *param.pvDist;
|
|
vector<double>& vSq = *param.pvSq;
|
|
int rowIdxStart = param.rowIdxStart;
|
|
int rowIdxEnd = param.rowIdxEnd;
|
|
int rowNum = param.rowNum;
|
|
int colNum = param.colNum;
|
|
double uv = 0;
|
|
clock_t begin = clock(), finish;
|
|
__m256 vec_zero = _mm256_set1_ps(0);
|
|
for (int i = rowIdxStart; i < rowIdxEnd; ++i) {
|
|
const int baseIdx = i * (rowNum * 2 - i - 1) / 2;
|
|
for (int cur = i + 1, idx = 0; cur < rowNum; ++cur, ++idx) {
|
|
double uv = 0;
|
|
//for (int j = 0; j < colNum; ++j) { uv += vvX[i][j] * vvX[cur][j]; }
|
|
__m256 vec_uv = vec_zero;
|
|
int colNum8 = colNum / 8 * 8;
|
|
int remain = colNum - colNum8;
|
|
for (int j = 0; j < colNum8; j += 8) {
|
|
__m256 vec_u = _mm256_loadu_ps(vvX[i].data() + j);
|
|
__m256 vec_v = _mm256_loadu_ps(vvX[cur].data() + j);
|
|
vec_uv = _mm256_add_ps(_mm256_mul_ps(vec_u, vec_v), vec_uv);
|
|
}
|
|
for (int j = colNum8; j < colNum; ++j) {
|
|
uv += vvX[i][j] * vvX[cur][j];
|
|
}
|
|
float* pVec = (float*)&vec_uv;
|
|
for (int j = 0; j < 8; ++j) uv += pVec[j];
|
|
|
|
uv /= colNum;
|
|
const double dist = abs(1.0 - uv / sqrt(vSq[i] * vSq[cur]));
|
|
vDist[baseIdx + idx] = dist;
|
|
}
|
|
}
|
|
finish = clock();
|
|
//cout << rowIdxEnd << " Thread time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
|
|
}
|
|
|
|
/* 入口函数 */
|
|
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) {
|
|
//cout << "WordSplit" << endl;
|
|
//cout << nlhs << '\t' << nrhs << endl;
|
|
//if (nrhs < 1) {
|
|
// cout << "At least 1 arguments should be given for this function!" << endl;
|
|
// return;
|
|
//}
|
|
clock_t begin = clock(), mid, finish;
|
|
|
|
int rowNum = (int)mxGetM(prhs[0]);
|
|
int colNum = (int)mxGetN(prhs[0]);
|
|
double* pData = (double*)mxGetData(prhs[0]);
|
|
//
|
|
//int numThread = 1;
|
|
//if (nrhs > 1) {
|
|
// double* pNumThread = (double*)mxGetData(prhs[1]);
|
|
// numThread = (int)*pNumThread;
|
|
// if (numThread < 1) numThread = 1;
|
|
//}
|
|
|
|
|
|
//for (int i = 0; i < rowNum; ++i) {
|
|
// for (int j = 0; j < colNum; ++j) {
|
|
// cout << pData[j * rowNum + i] << '\t';
|
|
// }
|
|
// cout << endl;
|
|
//}
|
|
|
|
// int rowNum = 2;
|
|
// int colNum = 5;
|
|
// vector<double> pData = { 1,1,2,2,3,9,4,4,5,4 };
|
|
|
|
cout << rowNum << '\t' << colNum << endl;
|
|
|
|
/* 计算每一行的平均数 */
|
|
mid = clock();
|
|
vector<double> vMean(rowNum);
|
|
for (int j = 0; j < colNum; ++j) {
|
|
for (int i = 0; i < rowNum; ++i) {
|
|
vMean[i] += pData[j * rowNum + i];
|
|
}
|
|
}
|
|
for (int i = 0; i < rowNum; ++i) {
|
|
vMean[i] /= colNum;
|
|
}
|
|
finish = clock();
|
|
cout << "计算平均数: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
|
|
/* 减去平均值, 计算平方 */
|
|
mid = clock();
|
|
vector<vector<float>> vvX(rowNum, vector<float>(colNum));
|
|
vector<double> vSq(rowNum);
|
|
for (int i = 0; i < rowNum; ++i) {
|
|
for (int j = 0; j < colNum; ++j) {
|
|
vvX[i][j] = pData[j * rowNum + i] - vMean[i];
|
|
vSq[i] += vvX[i][j] * vvX[i][j];
|
|
}
|
|
}
|
|
for (auto& val : vSq) { val /= colNum; }
|
|
finish = clock();
|
|
cout << "计算平方: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
|
|
|
|
/* 计算相关距离 */
|
|
// clock_t mid0 = clock();
|
|
// const int distSize = rowNum * (rowNum - 1) / 2;
|
|
// const int row2 = 2 * rowNum;
|
|
// vector<double> vDist(distSize, 0.0);
|
|
// //__m256d vec_zero = _mm256_set1_pd(0);
|
|
// __m256 vec_zero = _mm256_set1_ps(0);
|
|
// for (int i = 0; i < rowNum; ++i) {
|
|
// if (i % 100 == 0) {
|
|
// finish = clock();
|
|
// cout << "time " << i << ": " << (double)(finish - mid) / CLOCKS_PER_SEC << " s; " <<
|
|
// (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
|
|
// mid = clock();
|
|
// }
|
|
// const int baseIdx = i * (row2 - i - 1) / 2;
|
|
// for (int cur = i + 1, idx = 0; cur < rowNum; ++cur, ++idx) {
|
|
// double uv = 0;
|
|
// //for (int j = 0; j < colNum; ++j) { uv += vvX[i][j] * vvX[cur][j]; }
|
|
// //__m256d vec_uv = vec_zero;
|
|
// //int colNum4 = colNum / 4 * 4;
|
|
// //int remain = colNum - colNum4;
|
|
// //for (int j = 0; j < colNum4; j += 4) {
|
|
// // __m256d vec_u = _mm256_loadu_pd(vvX[i].data() + j);
|
|
// // __m256d vec_v = _mm256_loadu_pd(vvX[cur].data() + j);
|
|
// // vec_uv = _mm256_add_pd(_mm256_mul_pd(vec_u, vec_v), vec_uv);
|
|
// //}
|
|
// //for (int j = colNum4; j < colNum; ++j) {
|
|
// // uv += vvX[i][j] * vvX[cur][j];
|
|
// //}
|
|
// //double* pVec = (double*)&vec_uv;
|
|
// //for (int j = 0; j < 4; ++j) uv += pVec[j];
|
|
//
|
|
// __m256 vec_uv = vec_zero;
|
|
// int colNum8 = colNum / 8 * 8;
|
|
// int remain = colNum - colNum8;
|
|
// for (int j = 0; j < colNum8; j += 8) {
|
|
// __m256 vec_u = _mm256_loadu_ps(vvX[i].data() + j);
|
|
// __m256 vec_v = _mm256_loadu_ps(vvX[cur].data() + j);
|
|
// vec_uv = _mm256_add_ps(_mm256_mul_ps(vec_u, vec_v), vec_uv);
|
|
// }
|
|
// for (int j = colNum8; j < colNum; ++j) {
|
|
// uv += vvX[i][j] * vvX[cur][j];
|
|
// }
|
|
// float* pVec = (float*)&vec_uv;
|
|
// for (int j = 0; j < 8; ++j) uv += pVec[j];
|
|
//
|
|
// uv /= colNum;
|
|
// const double dist = abs(1.0 - uv / sqrt(vSq[i] * vSq[cur]));
|
|
// vDist[baseIdx + idx] = dist;
|
|
// }
|
|
// }
|
|
// finish = clock();
|
|
// cout << "计算相关距离: " << (double)(finish - mid0) / CLOCKS_PER_SEC << " s" << endl;
|
|
// for (auto& val : vDist) {cout << val << endl;}
|
|
|
|
/* 多线程计算相关距离 */
|
|
const int distSize = rowNum * (rowNum - 1) / 2;
|
|
const int row2 = 2 * rowNum;
|
|
vector<double> vDist(distSize, 0.0);
|
|
mid = clock();
|
|
vector<TPCorDist> vTP;
|
|
ThreadPool thPool(6);
|
|
int span = 32;
|
|
int rowNumSpan = rowNum / span * span;
|
|
int i = 0;
|
|
for (i = 0; i < rowNumSpan; i += span) {
|
|
// vTP.push_back({&vvX, &vDist, &vSq, i, rowNum, colNum});
|
|
TPCorDist tp = { &vvX, &vDist, &vSq, i, i + span, rowNum, colNum };
|
|
thPool.enqueue(ThreadCalcDist, tp);
|
|
//vTP.push_back(tp);
|
|
}
|
|
if (i < rowNum) {
|
|
TPCorDist tp = { &vvX, &vDist, &vSq, i, rowNum, rowNum, colNum };
|
|
thPool.enqueue(ThreadCalcDist, tp);
|
|
//vTP.push_back(tp);
|
|
}
|
|
thPool.~ThreadPool();
|
|
//kt_for(6, ThreadCalcDist, vTP);
|
|
|
|
finish = clock();
|
|
cout << "多线程计算相关距离: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
|
|
// for (auto& val : vDist) {cout << val << endl;}
|
|
|
|
/* 写入结果 */
|
|
if (nlhs > 0) { // b
|
|
mxArray* pWriteArray = NULL;
|
|
//创建一个rowNum*colNum的矩阵
|
|
pWriteArray = mxCreateDoubleMatrix(1, distSize, mxREAL);
|
|
memcpy((void*)(mxGetPr(pWriteArray)), (void*)vDist.data(), sizeof(double) * 6);
|
|
plhs[0] = pWriteArray; // 赋值给返回值
|
|
}
|
|
|
|
finish = clock();
|
|
cout << "Correlation Dist Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
|
|
} |