完成了计算相关距离的mex函数

This commit is contained in:
zzh 2023-10-05 23:12:02 +08:00
parent a149240783
commit 1b63656908
4 changed files with 151 additions and 54 deletions

View File

@ -303,7 +303,7 @@ void CalcEntropy(int argc, const char** argv) {
/* 遍历所有的知识颗粒目录,逐一进行处理 */ /* 遍历所有的知识颗粒目录,逐一进行处理 */
begin = clock(); begin = clock();
//ThreadPool thPool(numThread); ThreadPool thPool(numThread);
// 查看知识颗粒数量 // 查看知识颗粒数量
int numKnowledgeParticle = 0; int numKnowledgeParticle = 0;
FOREACH_PARTICLE_START FOREACH_PARTICLE_START
@ -311,20 +311,18 @@ void CalcEntropy(int argc, const char** argv) {
FOREACH_PARTICLE_END FOREACH_PARTICLE_END
// 遍历每个知识颗粒,逐一进行处理 // 遍历每个知识颗粒,逐一进行处理
vector<ThreadParamEntropy> vTP; // vector<ThreadParamEntropy> vTP;
for (int round = 0; round < 1; ++round) { // ²âÊÔÓÃ
int i = 0; FOREACH_PARTICLE_START
FOREACH_PARTICLE_START ThreadParamEntropy tParam = { file, childDir / outFileName, &vusAbsWord };
//ThreadParam tParam = { file, childDir / outFileName, &vusAbsWord }; thPool.enqueue(ThreadProcessEntropy, tParam);
//thPool.enqueue(ThreadProcessData, tParam); //vTP.push_back({ file, childDir / outFileName, &vusAbsWord });
vTP.push_back({ file, childDir / outFileName, &vusAbsWord }); FOREACH_PARTICLE_END
i++;
FOREACH_PARTICLE_END
}
kt_for(numThread, ThreadProcessEntropy, vTP);
// synchronize // synchronize
//thPool.~ThreadPool(); thPool.~ThreadPool();
// kt_for(numThread, ThreadProcessEntropy, vTP);
finish = clock(); finish = clock();
cout << "thread pool time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl; cout << "thread pool time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;

View File

@ -195,19 +195,18 @@ void ProcessPubmedTxt(int argc, const char** argv) {
int numThread = 1; int numThread = 1;
if (argc >= 5) numThread = atoi(argv[4]); if (argc >= 5) numThread = atoi(argv[4]);
if (numThread < 1) numThread = 1; if (numThread < 1) numThread = 1;
// ThreadPool thPool(numThread); ThreadPool thPool(numThread);
vumPaperTagVal.resize(vPaperStartIdx.size()-1); vumPaperTagVal.resize(vPaperStartIdx.size()-1);
vector<ThreadParamPubmed> vTP(vumPaperTagVal.size()); // vector<ThreadParamPubmed> vTP(vumPaperTagVal.size());
begin = clock(); begin = clock();
for (int i = 0; i < vumPaperTagVal.size(); ++i) { for (int i = 0; i < vumPaperTagVal.size(); ++i) {
vTP[i] = { &vumPaperTagVal[i], &vLineTag, &vTgName, vPaperStartIdx[i], vPaperStartIdx[i + 1], &umFullTagToTag, &vStrPubmedTxt }; //vTP[i] = { &vumPaperTagVal[i], &vLineTag, &vTgName, vPaperStartIdx[i], vPaperStartIdx[i + 1], &umFullTagToTag, &vStrPubmedTxt };
// ThreadParamPubmed tp = { &vumPaperTagVal[i], &vLineTag, &vTgName, vPaperStartIdx[i], vPaperStartIdx[i + 1], &umFullTagToTag, &vStrPubmedTxt }; ThreadParamPubmed tp = { &vumPaperTagVal[i], &vLineTag, &vTgName, vPaperStartIdx[i], vPaperStartIdx[i + 1], &umFullTagToTag, &vStrPubmedTxt };
// ThreadProcessArticle(tp); // ThreadProcessArticle(tp);
// thPool.enqueue(ThreadProcessArticle, tp); thPool.enqueue(ThreadProcessArticle, tp);
} }
// thPool.~ThreadPool(); thPool.~ThreadPool();
//kt_for(numThread, ThreadProcessArticle, vTP);
kt_for(numThread, ThreadProcessArticle, vTP);
finish = clock(); finish = clock();
cout << "kt for Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl; cout << "kt for Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;

View File

@ -30,7 +30,7 @@
#endif #endif
#include <mat.h> #include <mat.h>
#include "gmm.h" #include "gmm.h"
// #include "CommonLib/thread_pool.h" #include "CommonLib/thread_pool.h"
#include "CommonLib/matlab_io.h" #include "CommonLib/matlab_io.h"
#include "CommonLib/kthread.h" #include "CommonLib/kthread.h"
using namespace std; using namespace std;
@ -200,10 +200,10 @@ int main(int argc, const char** argv) {
int numThread = 1; int numThread = 1;
if (argc >= 4) numThread = atoi(argv[4]); if (argc >= 4) numThread = atoi(argv[4]);
if (numThread < 1) numThread = 1; if (numThread < 1) numThread = 1;
//ThreadPool thPool(numThread); ThreadPool thPool(numThread);
clock_t begin, finish; clock_t begin, finish;
begin = clock(); begin = clock();
vector<ThreadParamKP> vTP; //vector<ThreadParamKP> vTP;
/* 遍历所有的知识颗粒目录,逐一进行处理 */ /* 遍历所有的知识颗粒目录,逐一进行处理 */
for (auto& childDir : fs::directory_iterator(parrentDir)) { for (auto& childDir : fs::directory_iterator(parrentDir)) {
fs::path outFilePath = childDir / outFileName; fs::path outFilePath = childDir / outFileName;
@ -211,14 +211,14 @@ int main(int argc, const char** argv) {
const string& fileName = file.path().filename().string(); const string& fileName = file.path().filename().string();
auto rPos = fileName.rfind(hsMatSuffix); auto rPos = fileName.rfind(hsMatSuffix);
if (rPos != string::npos && fileName.size() - rPos == hsMatSuffix.size()) { if (rPos != string::npos && fileName.size() - rPos == hsMatSuffix.size()) {
//ThreadParam tParam = { file, outFilePath }; ThreadParamKP tParam = { file, outFilePath };
//thPool.enqueue(ThreadProcessData, tParam); thPool.enqueue(ThreadProcessKP, tParam);
vTP.push_back({ file, outFilePath }); //vTP.push_back({ file, outFilePath });
} }
} }
} }
kt_for(numThread, ThreadProcessKP, vTP); //kt_for(numThread, ThreadProcessKP, vTP);
//thPool.~ThreadPool(); thPool.~ThreadPool();
finish = clock(); finish = clock();
cout << "GMM Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl; cout << "GMM Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
return 0; return 0;

View File

@ -8,12 +8,106 @@
#include <ctime> #include <ctime>
#include <immintrin.h> #include <immintrin.h>
#include <zmmintrin.h> #include <zmmintrin.h>
#include "CommonLib/kthread.h" #include <vector>
#include "CommonLib/thread_pool.h" #include <queue>
#include <memory>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <future>
#include <functional>
#include <stdexcept>
// #include "CommonLib/kthread.h"
// #include "CommonLib/thread_pool.h"
using std::cout; using std::cout;
using std::endl; using std::endl;
using namespace std; using namespace std;
class ThreadPool {
public:
ThreadPool(size_t);
template<class F, class... Args>
auto enqueue(F&& f, Args&&... args)
->std::future<typename std::result_of<F(Args...)>::type>;
~ThreadPool();
private:
// need to keep track of threads so we can join them
std::vector< std::thread > workers;
// the task queue
std::queue< std::function<void()> > tasks;
// synchronization
std::mutex queue_mutex;
std::condition_variable condition;
bool stop;
};
// the constructor just launches some amount of workers
inline ThreadPool::ThreadPool(size_t threads)
: stop(false)
{
for (size_t i = 0;i < threads;++i)
workers.emplace_back(
[this]
{
for (;;)
{
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(this->queue_mutex);
this->condition.wait(lock,
[this] { return this->stop || !this->tasks.empty(); });
if (this->stop && this->tasks.empty())
return;
task = std::move(this->tasks.front());
this->tasks.pop();
}
task();
}
}
);
}
// add new work item to the pool
template<class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type>
{
using return_type = typename std::result_of<F(Args...)>::type;
auto task = std::make_shared< std::packaged_task<return_type()> >(
std::bind(std::forward<F>(f), std::forward<Args>(args)...)
);
std::future<return_type> res = task->get_future();
{
std::unique_lock<std::mutex> lock(queue_mutex);
// don't allow enqueueing after stopping the pool
if (stop)
throw std::runtime_error("enqueue on stopped ThreadPool");
tasks.emplace([task]() { (*task)(); });
}
condition.notify_one();
return res;
}
// the destructor joins all threads
inline ThreadPool::~ThreadPool()
{
{
std::unique_lock<std::mutex> lock(queue_mutex);
stop = true;
}
condition.notify_all();
for (std::thread& worker : workers)
worker.join();
}
// 线程参数 // 线程参数
struct TPCorDist { struct TPCorDist {
vector<vector<float>>* pvvX; vector<vector<float>>* pvvX;
@ -66,26 +160,32 @@ void ThreadCalcDist(TPCorDist& param) {
} }
/* 入口函数 */ /* 入口函数 */
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) { // void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) {
//cout << "WordSplit" << endl; void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
//cout << nlhs << '\t' << nrhs << endl; if (nrhs < 1) {
//if (nrhs < 1) { cout << "At least 1 arguments should be given for this function!" << endl;
// cout << "At least 1 arguments should be given for this function!" << endl; return;
// return; }
//}
clock_t begin = clock(), mid, finish; clock_t begin = clock(), mid, finish;
int rowNum = (int)mxGetM(prhs[0]); int rowNum = (int)mxGetM(prhs[0]);
int colNum = (int)mxGetN(prhs[0]); int colNum = (int)mxGetN(prhs[0]);
double* pData = (double*)mxGetData(prhs[0]); double* pData = (double*)mxGetData(prhs[0]);
//
//int numThread = 1; int numThread = 6;
//if (nrhs > 1) { if (nrhs > 1) {
// double* pNumThread = (double*)mxGetData(prhs[1]); double* pNumThread = (double*)mxGetData(prhs[1]);
// numThread = (int)*pNumThread; numThread = (int)pNumThread[0];
// if (numThread < 1) numThread = 1; if (numThread < 1) numThread = 1;
//} }
int numGroup = 1;
if (nrhs > 2) {
double* pNumGroup = (double*)mxGetData(prhs[2]);
numGroup = (int)pNumGroup[0];
if (numGroup < 1) numGroup = 1;
}
//cout << numThread << '\t' << numGroup << endl;
//for (int i = 0; i < rowNum; ++i) { //for (int i = 0; i < rowNum; ++i) {
// for (int j = 0; j < colNum; ++j) { // for (int j = 0; j < colNum; ++j) {
@ -98,7 +198,7 @@ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) {
// int colNum = 5; // int colNum = 5;
// vector<double> pData = { 1,1,2,2,3,9,4,4,5,4 }; // vector<double> pData = { 1,1,2,2,3,9,4,4,5,4 };
cout << rowNum << '\t' << colNum << endl; //cout << rowNum << '\t' << colNum << endl;
/* 计算每一行的平均数 */ /* 计算每一行的平均数 */
mid = clock(); mid = clock();
@ -112,7 +212,7 @@ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) {
vMean[i] /= colNum; vMean[i] /= colNum;
} }
finish = clock(); finish = clock();
cout << "计算平均数: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl; //cout << "计算平均数: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
/* 减去平均值, 计算平方 */ /* 减去平均值, 计算平方 */
mid = clock(); mid = clock();
vector<vector<float>> vvX(rowNum, vector<float>(colNum)); vector<vector<float>> vvX(rowNum, vector<float>(colNum));
@ -125,7 +225,7 @@ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) {
} }
for (auto& val : vSq) { val /= colNum; } for (auto& val : vSq) { val /= colNum; }
finish = clock(); finish = clock();
cout << "计算平方: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl; //cout << "计算平方: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
/* 计算相关距离 */ /* 计算相关距离 */
// clock_t mid0 = clock(); // clock_t mid0 = clock();
@ -187,9 +287,9 @@ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) {
const int row2 = 2 * rowNum; const int row2 = 2 * rowNum;
vector<double> vDist(distSize, 0.0); vector<double> vDist(distSize, 0.0);
mid = clock(); mid = clock();
vector<TPCorDist> vTP; // vector<TPCorDist> vTP;
ThreadPool thPool(6); ThreadPool thPool(numThread);
int span = 32; int span = numGroup;
int rowNumSpan = rowNum / span * span; int rowNumSpan = rowNum / span * span;
int i = 0; int i = 0;
for (i = 0; i < rowNumSpan; i += span) { for (i = 0; i < rowNumSpan; i += span) {
@ -207,7 +307,7 @@ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) {
//kt_for(6, ThreadCalcDist, vTP); //kt_for(6, ThreadCalcDist, vTP);
finish = clock(); finish = clock();
cout << "多线程计算相关距离: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl; //cout << "多线程计算相关距离: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
// for (auto& val : vDist) {cout << val << endl;} // for (auto& val : vDist) {cout << val << endl;}
/* 写入结果 */ /* 写入结果 */
@ -215,7 +315,7 @@ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) {
mxArray* pWriteArray = NULL; mxArray* pWriteArray = NULL;
//创建一个rowNum*colNum的矩阵 //创建一个rowNum*colNum的矩阵
pWriteArray = mxCreateDoubleMatrix(1, distSize, mxREAL); pWriteArray = mxCreateDoubleMatrix(1, distSize, mxREAL);
memcpy((void*)(mxGetPr(pWriteArray)), (void*)vDist.data(), sizeof(double) * 6); memcpy((void*)(mxGetPr(pWriteArray)), (void*)vDist.data(), sizeof(double) * distSize);
plhs[0] = pWriteArray; // 赋值给返回值 plhs[0] = pWriteArray; // 赋值给返回值
} }