From 1b63656908bc8428485f751dff510b00c0eaede1 Mon Sep 17 00:00:00 2001 From: zzh Date: Thu, 5 Oct 2023 23:12:02 +0800 Subject: [PATCH] =?UTF-8?q?=E5=AE=8C=E6=88=90=E4=BA=86=E8=AE=A1=E7=AE=97?= =?UTF-8?q?=E7=9B=B8=E5=85=B3=E8=B7=9D=E7=A6=BB=E7=9A=84mex=E5=87=BD?= =?UTF-8?q?=E6=95=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- CppRun/calc_entropy.cpp | 26 +++--- CppRun/process_pubmed_txt.cpp | 15 ++-- GMM/main.cpp | 16 ++-- MexFunc/CorrelationDist.cpp | 148 ++++++++++++++++++++++++++++------ 4 files changed, 151 insertions(+), 54 deletions(-) diff --git a/CppRun/calc_entropy.cpp b/CppRun/calc_entropy.cpp index d00d3cc..08ebed7 100644 --- a/CppRun/calc_entropy.cpp +++ b/CppRun/calc_entropy.cpp @@ -303,7 +303,7 @@ void CalcEntropy(int argc, const char** argv) { /* 遍历所有的知识颗粒目录,逐一进行处理 */ begin = clock(); - //ThreadPool thPool(numThread); + ThreadPool thPool(numThread); // 查看知识颗粒数量 int numKnowledgeParticle = 0; FOREACH_PARTICLE_START @@ -311,20 +311,18 @@ void CalcEntropy(int argc, const char** argv) { FOREACH_PARTICLE_END // 遍历每个知识颗粒,逐一进行处理 - vector vTP; - for (int round = 0; round < 1; ++round) { // 测试用 - int i = 0; - FOREACH_PARTICLE_START - //ThreadParam tParam = { file, childDir / outFileName, &vusAbsWord }; - //thPool.enqueue(ThreadProcessData, tParam); - vTP.push_back({ file, childDir / outFileName, &vusAbsWord }); - i++; - FOREACH_PARTICLE_END - } - kt_for(numThread, ThreadProcessEntropy, vTP); - + // vector vTP; + + FOREACH_PARTICLE_START + ThreadParamEntropy tParam = { file, childDir / outFileName, &vusAbsWord }; + thPool.enqueue(ThreadProcessEntropy, tParam); + //vTP.push_back({ file, childDir / outFileName, &vusAbsWord }); + FOREACH_PARTICLE_END + // synchronize - //thPool.~ThreadPool(); + thPool.~ThreadPool(); + // kt_for(numThread, ThreadProcessEntropy, vTP); + finish = clock(); cout << "thread pool time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl; diff --git a/CppRun/process_pubmed_txt.cpp b/CppRun/process_pubmed_txt.cpp index 3f780dd..4e21bc1 100644 --- a/CppRun/process_pubmed_txt.cpp +++ b/CppRun/process_pubmed_txt.cpp @@ -195,19 +195,18 @@ void ProcessPubmedTxt(int argc, const char** argv) { int numThread = 1; if (argc >= 5) numThread = atoi(argv[4]); if (numThread < 1) numThread = 1; - // ThreadPool thPool(numThread); + ThreadPool thPool(numThread); vumPaperTagVal.resize(vPaperStartIdx.size()-1); - vector vTP(vumPaperTagVal.size()); + // vector vTP(vumPaperTagVal.size()); begin = clock(); for (int i = 0; i < vumPaperTagVal.size(); ++i) { - vTP[i] = { &vumPaperTagVal[i], &vLineTag, &vTgName, vPaperStartIdx[i], vPaperStartIdx[i + 1], &umFullTagToTag, &vStrPubmedTxt }; - // ThreadParamPubmed tp = { &vumPaperTagVal[i], &vLineTag, &vTgName, vPaperStartIdx[i], vPaperStartIdx[i + 1], &umFullTagToTag, &vStrPubmedTxt }; + //vTP[i] = { &vumPaperTagVal[i], &vLineTag, &vTgName, vPaperStartIdx[i], vPaperStartIdx[i + 1], &umFullTagToTag, &vStrPubmedTxt }; + ThreadParamPubmed tp = { &vumPaperTagVal[i], &vLineTag, &vTgName, vPaperStartIdx[i], vPaperStartIdx[i + 1], &umFullTagToTag, &vStrPubmedTxt }; // ThreadProcessArticle(tp); - // thPool.enqueue(ThreadProcessArticle, tp); + thPool.enqueue(ThreadProcessArticle, tp); } - // thPool.~ThreadPool(); - - kt_for(numThread, ThreadProcessArticle, vTP); + thPool.~ThreadPool(); + //kt_for(numThread, ThreadProcessArticle, vTP); finish = clock(); cout << "kt for Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl; diff --git a/GMM/main.cpp b/GMM/main.cpp index 8f134cf..ed008f5 100644 --- a/GMM/main.cpp +++ b/GMM/main.cpp @@ -30,7 +30,7 @@ #endif #include #include "gmm.h" -// #include "CommonLib/thread_pool.h" +#include "CommonLib/thread_pool.h" #include "CommonLib/matlab_io.h" #include "CommonLib/kthread.h" using namespace std; @@ -200,10 +200,10 @@ int main(int argc, const char** argv) { int numThread = 1; if (argc >= 4) numThread = atoi(argv[4]); if (numThread < 1) numThread = 1; - //ThreadPool thPool(numThread); + ThreadPool thPool(numThread); clock_t begin, finish; begin = clock(); - vector vTP; + //vector vTP; /* 遍历所有的知识颗粒目录,逐一进行处理 */ for (auto& childDir : fs::directory_iterator(parrentDir)) { fs::path outFilePath = childDir / outFileName; @@ -211,14 +211,14 @@ int main(int argc, const char** argv) { const string& fileName = file.path().filename().string(); auto rPos = fileName.rfind(hsMatSuffix); if (rPos != string::npos && fileName.size() - rPos == hsMatSuffix.size()) { - //ThreadParam tParam = { file, outFilePath }; - //thPool.enqueue(ThreadProcessData, tParam); - vTP.push_back({ file, outFilePath }); + ThreadParamKP tParam = { file, outFilePath }; + thPool.enqueue(ThreadProcessKP, tParam); + //vTP.push_back({ file, outFilePath }); } } } - kt_for(numThread, ThreadProcessKP, vTP); - //thPool.~ThreadPool(); + //kt_for(numThread, ThreadProcessKP, vTP); + thPool.~ThreadPool(); finish = clock(); cout << "GMM Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl; return 0; diff --git a/MexFunc/CorrelationDist.cpp b/MexFunc/CorrelationDist.cpp index 7746785..efb919b 100644 --- a/MexFunc/CorrelationDist.cpp +++ b/MexFunc/CorrelationDist.cpp @@ -8,12 +8,106 @@ #include #include #include -#include "CommonLib/kthread.h" -#include "CommonLib/thread_pool.h" +#include +#include +#include +#include +#include +#include +#include +#include +#include + +// #include "CommonLib/kthread.h" +// #include "CommonLib/thread_pool.h" using std::cout; using std::endl; using namespace std; +class ThreadPool { +public: + ThreadPool(size_t); + template + auto enqueue(F&& f, Args&&... args) + ->std::future::type>; + ~ThreadPool(); +private: + // need to keep track of threads so we can join them + std::vector< std::thread > workers; + // the task queue + std::queue< std::function > tasks; + + // synchronization + std::mutex queue_mutex; + std::condition_variable condition; + bool stop; +}; + +// the constructor just launches some amount of workers +inline ThreadPool::ThreadPool(size_t threads) + : stop(false) +{ + for (size_t i = 0;i < threads;++i) + workers.emplace_back( + [this] + { + for (;;) + { + std::function task; + + { + std::unique_lock lock(this->queue_mutex); + this->condition.wait(lock, + [this] { return this->stop || !this->tasks.empty(); }); + if (this->stop && this->tasks.empty()) + return; + task = std::move(this->tasks.front()); + this->tasks.pop(); + } + + task(); + } + } + ); +} + +// add new work item to the pool +template +auto ThreadPool::enqueue(F&& f, Args&&... args) +-> std::future::type> +{ + using return_type = typename std::result_of::type; + + auto task = std::make_shared< std::packaged_task >( + std::bind(std::forward(f), std::forward(args)...) + ); + + std::future res = task->get_future(); + { + std::unique_lock lock(queue_mutex); + + // don't allow enqueueing after stopping the pool + if (stop) + throw std::runtime_error("enqueue on stopped ThreadPool"); + + tasks.emplace([task]() { (*task)(); }); + } + condition.notify_one(); + return res; +} + +// the destructor joins all threads +inline ThreadPool::~ThreadPool() +{ + { + std::unique_lock lock(queue_mutex); + stop = true; + } + condition.notify_all(); + for (std::thread& worker : workers) + worker.join(); +} + // 线程参数 struct TPCorDist { vector>* pvvX; @@ -66,26 +160,32 @@ void ThreadCalcDist(TPCorDist& param) { } /* 入口函数 */ -void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) { - //cout << "WordSplit" << endl; - //cout << nlhs << '\t' << nrhs << endl; - //if (nrhs < 1) { - // cout << "At least 1 arguments should be given for this function!" << endl; - // return; - //} +// void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) { +void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { + if (nrhs < 1) { + cout << "At least 1 arguments should be given for this function!" << endl; + return; + } clock_t begin = clock(), mid, finish; int rowNum = (int)mxGetM(prhs[0]); int colNum = (int)mxGetN(prhs[0]); double* pData = (double*)mxGetData(prhs[0]); - // - //int numThread = 1; - //if (nrhs > 1) { - // double* pNumThread = (double*)mxGetData(prhs[1]); - // numThread = (int)*pNumThread; - // if (numThread < 1) numThread = 1; - //} + + int numThread = 6; + if (nrhs > 1) { + double* pNumThread = (double*)mxGetData(prhs[1]); + numThread = (int)pNumThread[0]; + if (numThread < 1) numThread = 1; + } + int numGroup = 1; + if (nrhs > 2) { + double* pNumGroup = (double*)mxGetData(prhs[2]); + numGroup = (int)pNumGroup[0]; + if (numGroup < 1) numGroup = 1; + } + //cout << numThread << '\t' << numGroup << endl; //for (int i = 0; i < rowNum; ++i) { // for (int j = 0; j < colNum; ++j) { @@ -98,7 +198,7 @@ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) { // int colNum = 5; // vector pData = { 1,1,2,2,3,9,4,4,5,4 }; - cout << rowNum << '\t' << colNum << endl; + //cout << rowNum << '\t' << colNum << endl; /* 计算每一行的平均数 */ mid = clock(); @@ -112,7 +212,7 @@ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) { vMean[i] /= colNum; } finish = clock(); - cout << "计算平均数: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl; + //cout << "计算平均数: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl; /* 减去平均值, 计算平方 */ mid = clock(); vector> vvX(rowNum, vector(colNum)); @@ -125,7 +225,7 @@ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) { } for (auto& val : vSq) { val /= colNum; } finish = clock(); - cout << "计算平方: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl; + //cout << "计算平方: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl; /* 计算相关距离 */ // clock_t mid0 = clock(); @@ -187,9 +287,9 @@ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) { const int row2 = 2 * rowNum; vector vDist(distSize, 0.0); mid = clock(); - vector vTP; - ThreadPool thPool(6); - int span = 32; + // vector vTP; + ThreadPool thPool(numThread); + int span = numGroup; int rowNumSpan = rowNum / span * span; int i = 0; for (i = 0; i < rowNumSpan; i += span) { @@ -207,7 +307,7 @@ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) { //kt_for(6, ThreadCalcDist, vTP); finish = clock(); - cout << "多线程计算相关距离: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl; + //cout << "多线程计算相关距离: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl; // for (auto& val : vDist) {cout << val << endl;} /* 写入结果 */ @@ -215,7 +315,7 @@ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) { mxArray* pWriteArray = NULL; //创建一个rowNum*colNum的矩阵 pWriteArray = mxCreateDoubleMatrix(1, distSize, mxREAL); - memcpy((void*)(mxGetPr(pWriteArray)), (void*)vDist.data(), sizeof(double) * 6); + memcpy((void*)(mxGetPr(pWriteArray)), (void*)vDist.data(), sizeof(double) * distSize); plhs[0] = pWriteArray; // 赋值给返回值 }