#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include using std::cout; using std::endl; using namespace std; class ThreadPool { public: ThreadPool(size_t); template auto enqueue(F&& f, Args&&... args) ->std::future::type>; ~ThreadPool(); private: // need to keep track of threads so we can join them std::vector< std::thread > workers; // the task queue std::queue< std::function > tasks; // synchronization std::mutex queue_mutex; std::condition_variable condition; bool stop; }; // the constructor just launches some amount of workers inline ThreadPool::ThreadPool(size_t threads) : stop(false) { for (size_t i = 0;i < threads;++i) workers.emplace_back( [this] { for (;;) { std::function task; { std::unique_lock lock(this->queue_mutex); this->condition.wait(lock, [this] { return this->stop || !this->tasks.empty(); }); if (this->stop && this->tasks.empty()) return; task = std::move(this->tasks.front()); this->tasks.pop(); } task(); } } ); } // add new work item to the pool template auto ThreadPool::enqueue(F&& f, Args&&... args) -> std::future::type> { using return_type = typename std::result_of::type; auto task = std::make_shared< std::packaged_task >( std::bind(std::forward(f), std::forward(args)...) ); std::future res = task->get_future(); { std::unique_lock lock(queue_mutex); // don't allow enqueueing after stopping the pool if (stop) throw std::runtime_error("enqueue on stopped ThreadPool"); tasks.emplace([task]() { (*task)(); }); } condition.notify_one(); return res; } // the destructor joins all threads inline ThreadPool::~ThreadPool() { { std::unique_lock lock(queue_mutex); stop = true; } condition.notify_all(); for (std::thread& worker : workers) worker.join(); } // 线程参数 struct TPCorDist { vector>* pvvX; vector* pvDist; vector* pvSq; int rowIdxStart; int rowIdxEnd; int rowNum; int colNum; }; // 线程函数 void ThreadCalcDist(TPCorDist& param) { vector>& vvX = *param.pvvX; vector& vDist = *param.pvDist; vector& vSq = *param.pvSq; int rowIdxStart = param.rowIdxStart; int rowIdxEnd = param.rowIdxEnd; int rowNum = param.rowNum; int colNum = param.colNum; double uv = 0; clock_t begin = clock(), finish; __m256 vec_zero = _mm256_set1_ps(0); for (int i = rowIdxStart; i < rowIdxEnd; ++i) { const int baseIdx = i * (rowNum * 2 - i - 1) / 2; for (int cur = i + 1, idx = 0; cur < rowNum; ++cur, ++idx) { double uv = 0; //for (int j = 0; j < colNum; ++j) { uv += vvX[i][j] * vvX[cur][j]; } __m256 vec_uv = vec_zero; int colNum8 = colNum / 8 * 8; int remain = colNum - colNum8; for (int j = 0; j < colNum8; j += 8) { __m256 vec_u = _mm256_loadu_ps(vvX[i].data() + j); __m256 vec_v = _mm256_loadu_ps(vvX[cur].data() + j); vec_uv = _mm256_add_ps(_mm256_mul_ps(vec_u, vec_v), vec_uv); } for (int j = colNum8; j < colNum; ++j) { uv += vvX[i][j] * vvX[cur][j]; } float* pVec = (float*)&vec_uv; for (int j = 0; j < 8; ++j) uv += pVec[j]; uv /= colNum; const double dist = abs(1.0 - uv / sqrt(vSq[i] * vSq[cur])); vDist[baseIdx + idx] = dist; } } finish = clock(); //cout << rowIdxEnd << " Thread time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl; } /* 入口函数 */ /* 输入: 1. x: 二维。 [2]. numThread: 线程数。 [3]. numGroup: 每次线程函数处理的数据量。 输出: 1. d: 相关距离 */ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { if (nrhs < 1) { cout << "At least 1 arguments should be given for this function!" << endl; return; } clock_t begin = clock(), mid, finish; int rowNum = (int)mxGetM(prhs[0]); int colNum = (int)mxGetN(prhs[0]); double* pData = (double*)mxGetData(prhs[0]); int numThread = 6; if (nrhs > 1) { double* pNumThread = (double*)mxGetData(prhs[1]); numThread = (int)pNumThread[0]; if (numThread < 1) numThread = 1; } int numGroup = 1; if (nrhs > 2) { double* pNumGroup = (double*)mxGetData(prhs[2]); numGroup = (int)pNumGroup[0]; if (numGroup < 1) numGroup = 1; } //cout << numThread << '\t' << numGroup << endl; //for (int i = 0; i < rowNum; ++i) { // for (int j = 0; j < colNum; ++j) { // cout << pData[j * rowNum + i] << '\t'; // } // cout << endl; //} // int rowNum = 2; // int colNum = 5; // vector pData = { 1,1,2,2,3,9,4,4,5,4 }; //cout << rowNum << '\t' << colNum << endl; /* 计算每一行的平均数 */ mid = clock(); vector vMean(rowNum); for (int j = 0; j < colNum; ++j) { for (int i = 0; i < rowNum; ++i) { vMean[i] += pData[j * rowNum + i]; } } for (int i = 0; i < rowNum; ++i) { vMean[i] /= colNum; } finish = clock(); //cout << "计算平均数: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl; /* 减去平均值, 计算平方 */ mid = clock(); vector> vvX(rowNum, vector(colNum)); vector vSq(rowNum); for (int i = 0; i < rowNum; ++i) { for (int j = 0; j < colNum; ++j) { vvX[i][j] = pData[j * rowNum + i] - vMean[i]; vSq[i] += vvX[i][j] * vvX[i][j]; } } for (auto& val : vSq) { val /= colNum; } finish = clock(); //cout << "计算平方: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl; /* 计算相关距离 */ // clock_t mid0 = clock(); // const int distSize = rowNum * (rowNum - 1) / 2; // const int row2 = 2 * rowNum; // vector vDist(distSize, 0.0); // //__m256d vec_zero = _mm256_set1_pd(0); // __m256 vec_zero = _mm256_set1_ps(0); // for (int i = 0; i < rowNum; ++i) { // if (i % 100 == 0) { // finish = clock(); // cout << "time " << i << ": " << (double)(finish - mid) / CLOCKS_PER_SEC << " s; " << // (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl; // mid = clock(); // } // const int baseIdx = i * (row2 - i - 1) / 2; // for (int cur = i + 1, idx = 0; cur < rowNum; ++cur, ++idx) { // double uv = 0; // //for (int j = 0; j < colNum; ++j) { uv += vvX[i][j] * vvX[cur][j]; } // //__m256d vec_uv = vec_zero; // //int colNum4 = colNum / 4 * 4; // //int remain = colNum - colNum4; // //for (int j = 0; j < colNum4; j += 4) { // // __m256d vec_u = _mm256_loadu_pd(vvX[i].data() + j); // // __m256d vec_v = _mm256_loadu_pd(vvX[cur].data() + j); // // vec_uv = _mm256_add_pd(_mm256_mul_pd(vec_u, vec_v), vec_uv); // //} // //for (int j = colNum4; j < colNum; ++j) { // // uv += vvX[i][j] * vvX[cur][j]; // //} // //double* pVec = (double*)&vec_uv; // //for (int j = 0; j < 4; ++j) uv += pVec[j]; // // __m256 vec_uv = vec_zero; // int colNum8 = colNum / 8 * 8; // int remain = colNum - colNum8; // for (int j = 0; j < colNum8; j += 8) { // __m256 vec_u = _mm256_loadu_ps(vvX[i].data() + j); // __m256 vec_v = _mm256_loadu_ps(vvX[cur].data() + j); // vec_uv = _mm256_add_ps(_mm256_mul_ps(vec_u, vec_v), vec_uv); // } // for (int j = colNum8; j < colNum; ++j) { // uv += vvX[i][j] * vvX[cur][j]; // } // float* pVec = (float*)&vec_uv; // for (int j = 0; j < 8; ++j) uv += pVec[j]; // // uv /= colNum; // const double dist = abs(1.0 - uv / sqrt(vSq[i] * vSq[cur])); // vDist[baseIdx + idx] = dist; // } // } // finish = clock(); // cout << "计算相关距离: " << (double)(finish - mid0) / CLOCKS_PER_SEC << " s" << endl; // for (auto& val : vDist) {cout << val << endl;} /* 多线程计算相关距离 */ const int distSize = rowNum * (rowNum - 1) / 2; const int row2 = 2 * rowNum; vector vDist(distSize, 0.0); mid = clock(); // vector vTP; ThreadPool thPool(numThread); int span = numGroup; int rowNumSpan = rowNum / span * span; int i = 0; for (i = 0; i < rowNumSpan; i += span) { // vTP.push_back({&vvX, &vDist, &vSq, i, rowNum, colNum}); TPCorDist tp = { &vvX, &vDist, &vSq, i, i + span, rowNum, colNum }; thPool.enqueue(ThreadCalcDist, tp); //vTP.push_back(tp); } if (i < rowNum) { TPCorDist tp = { &vvX, &vDist, &vSq, i, rowNum, rowNum, colNum }; thPool.enqueue(ThreadCalcDist, tp); //vTP.push_back(tp); } thPool.~ThreadPool(); //kt_for(6, ThreadCalcDist, vTP); finish = clock(); //cout << "多线程计算相关距离: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl; // for (auto& val : vDist) {cout << val << endl;} /* 写入结果 */ if (nlhs > 0) { // b mxArray* pWriteArray = NULL; //创建一个rowNum*colNum的矩阵 pWriteArray = mxCreateDoubleMatrix(1, distSize, mxREAL); memcpy((void*)(mxGetPr(pWriteArray)), (void*)vDist.data(), sizeof(double) * distSize); plhs[0] = pWriteArray; // 赋值给返回值 } finish = clock(); cout << "Correlation Dist Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl; } /* 供main调试调用 */ void mexFunctionWrap(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { mexFunction(nlhs, plhs, nrhs, prhs); }