完成了计算相关距离的mex函数

This commit is contained in:
zzh 2023-10-05 23:12:02 +08:00
parent a149240783
commit 1b63656908
4 changed files with 151 additions and 54 deletions

View File

@ -303,7 +303,7 @@ void CalcEntropy(int argc, const char** argv) {
/* 遍历所有的知识颗粒目录,逐一进行处理 */
begin = clock();
//ThreadPool thPool(numThread);
ThreadPool thPool(numThread);
// 查看知识颗粒数量
int numKnowledgeParticle = 0;
FOREACH_PARTICLE_START
@ -311,20 +311,18 @@ void CalcEntropy(int argc, const char** argv) {
FOREACH_PARTICLE_END
// 遍历每个知识颗粒,逐一进行处理
vector<ThreadParamEntropy> vTP;
for (int round = 0; round < 1; ++round) { // 测试用
int i = 0;
FOREACH_PARTICLE_START
//ThreadParam tParam = { file, childDir / outFileName, &vusAbsWord };
//thPool.enqueue(ThreadProcessData, tParam);
vTP.push_back({ file, childDir / outFileName, &vusAbsWord });
i++;
FOREACH_PARTICLE_END
}
kt_for(numThread, ThreadProcessEntropy, vTP);
// vector<ThreadParamEntropy> vTP;
FOREACH_PARTICLE_START
ThreadParamEntropy tParam = { file, childDir / outFileName, &vusAbsWord };
thPool.enqueue(ThreadProcessEntropy, tParam);
//vTP.push_back({ file, childDir / outFileName, &vusAbsWord });
FOREACH_PARTICLE_END
// synchronize
//thPool.~ThreadPool();
thPool.~ThreadPool();
// kt_for(numThread, ThreadProcessEntropy, vTP);
finish = clock();
cout << "thread pool time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;

View File

@ -195,19 +195,18 @@ void ProcessPubmedTxt(int argc, const char** argv) {
int numThread = 1;
if (argc >= 5) numThread = atoi(argv[4]);
if (numThread < 1) numThread = 1;
// ThreadPool thPool(numThread);
ThreadPool thPool(numThread);
vumPaperTagVal.resize(vPaperStartIdx.size()-1);
vector<ThreadParamPubmed> vTP(vumPaperTagVal.size());
// vector<ThreadParamPubmed> vTP(vumPaperTagVal.size());
begin = clock();
for (int i = 0; i < vumPaperTagVal.size(); ++i) {
vTP[i] = { &vumPaperTagVal[i], &vLineTag, &vTgName, vPaperStartIdx[i], vPaperStartIdx[i + 1], &umFullTagToTag, &vStrPubmedTxt };
// ThreadParamPubmed tp = { &vumPaperTagVal[i], &vLineTag, &vTgName, vPaperStartIdx[i], vPaperStartIdx[i + 1], &umFullTagToTag, &vStrPubmedTxt };
//vTP[i] = { &vumPaperTagVal[i], &vLineTag, &vTgName, vPaperStartIdx[i], vPaperStartIdx[i + 1], &umFullTagToTag, &vStrPubmedTxt };
ThreadParamPubmed tp = { &vumPaperTagVal[i], &vLineTag, &vTgName, vPaperStartIdx[i], vPaperStartIdx[i + 1], &umFullTagToTag, &vStrPubmedTxt };
// ThreadProcessArticle(tp);
// thPool.enqueue(ThreadProcessArticle, tp);
thPool.enqueue(ThreadProcessArticle, tp);
}
// thPool.~ThreadPool();
kt_for(numThread, ThreadProcessArticle, vTP);
thPool.~ThreadPool();
//kt_for(numThread, ThreadProcessArticle, vTP);
finish = clock();
cout << "kt for Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;

View File

@ -30,7 +30,7 @@
#endif
#include <mat.h>
#include "gmm.h"
// #include "CommonLib/thread_pool.h"
#include "CommonLib/thread_pool.h"
#include "CommonLib/matlab_io.h"
#include "CommonLib/kthread.h"
using namespace std;
@ -200,10 +200,10 @@ int main(int argc, const char** argv) {
int numThread = 1;
if (argc >= 4) numThread = atoi(argv[4]);
if (numThread < 1) numThread = 1;
//ThreadPool thPool(numThread);
ThreadPool thPool(numThread);
clock_t begin, finish;
begin = clock();
vector<ThreadParamKP> vTP;
//vector<ThreadParamKP> vTP;
/* 遍历所有的知识颗粒目录,逐一进行处理 */
for (auto& childDir : fs::directory_iterator(parrentDir)) {
fs::path outFilePath = childDir / outFileName;
@ -211,14 +211,14 @@ int main(int argc, const char** argv) {
const string& fileName = file.path().filename().string();
auto rPos = fileName.rfind(hsMatSuffix);
if (rPos != string::npos && fileName.size() - rPos == hsMatSuffix.size()) {
//ThreadParam tParam = { file, outFilePath };
//thPool.enqueue(ThreadProcessData, tParam);
vTP.push_back({ file, outFilePath });
ThreadParamKP tParam = { file, outFilePath };
thPool.enqueue(ThreadProcessKP, tParam);
//vTP.push_back({ file, outFilePath });
}
}
}
kt_for(numThread, ThreadProcessKP, vTP);
//thPool.~ThreadPool();
//kt_for(numThread, ThreadProcessKP, vTP);
thPool.~ThreadPool();
finish = clock();
cout << "GMM Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
return 0;

View File

@ -8,12 +8,106 @@
#include <ctime>
#include <immintrin.h>
#include <zmmintrin.h>
#include "CommonLib/kthread.h"
#include "CommonLib/thread_pool.h"
#include <vector>
#include <queue>
#include <memory>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <future>
#include <functional>
#include <stdexcept>
// #include "CommonLib/kthread.h"
// #include "CommonLib/thread_pool.h"
using std::cout;
using std::endl;
using namespace std;
// Fixed-size pool of worker threads that execute queued tasks in FIFO order.
//
// Usage: construct with a worker count, submit work with enqueue(), then either
// let the pool go out of scope or call shutdown() to block until every queued
// task has finished. Do NOT invoke the destructor explicitly on a pool with
// automatic storage duration: it would run a second time at scope exit, which
// is undefined behaviour. shutdown() is the supported way to synchronise early.
class ThreadPool {
public:
    // Launches `threads` workers. A request for 0 workers is raised to 1,
    // because a pool without workers would accept tasks that never run and
    // whose futures would block forever.
    explicit ThreadPool(size_t threads);

    // The pool owns threads and a mutex, so it can be neither copied nor assigned.
    ThreadPool(const ThreadPool&) = delete;
    ThreadPool& operator=(const ThreadPool&) = delete;

    // Queues f(args...) for execution on a worker thread.
    // The callable and its arguments are copied/moved into the task (std::bind
    // semantics), so pass pointers or std::ref for data that must be shared.
    // Returns a future holding the call's result (or the exception it threw).
    // Throws std::runtime_error if the pool has already been stopped.
    template<class F, class... Args>
    auto enqueue(F&& f, Args&&... args)
        -> std::future<typename std::result_of<F(Args...)>::type>;

    // Stops accepting new tasks, lets the workers drain the queue, and joins
    // them. Safe to call more than once from the owning thread; later calls
    // are no-ops. Must not be called concurrently from several threads.
    void shutdown();

    // Equivalent to shutdown(): waits for all queued tasks, then joins.
    ~ThreadPool();

private:
    // Worker threads; kept so they can be joined on shutdown.
    std::vector<std::thread> workers;
    // Pending tasks, consumed front-first by the workers.
    std::queue<std::function<void()>> tasks;
    // Guards `tasks` and `stop`.
    std::mutex queue_mutex;
    // Signalled when a task is queued or the pool is stopped.
    std::condition_variable condition;
    // Set once shutdown has begun; read and written only under queue_mutex.
    bool stop;
};

// Launches the worker threads; each one loops, pulling tasks until shutdown.
inline ThreadPool::ThreadPool(size_t threads)
    : stop(false)
{
    const size_t workerCount = threads == 0 ? 1 : threads;
    workers.reserve(workerCount);
    for (size_t i = 0; i < workerCount; ++i) {
        workers.emplace_back([this] {
            for (;;) {
                std::function<void()> task;
                {
                    std::unique_lock<std::mutex> lock(this->queue_mutex);
                    // Sleep until there is work to do or the pool is stopping.
                    this->condition.wait(lock, [this] {
                        return this->stop || !this->tasks.empty();
                    });
                    // Exit only once the queue is drained, so tasks queued
                    // before shutdown still run.
                    if (this->stop && this->tasks.empty())
                        return;
                    task = std::move(this->tasks.front());
                    this->tasks.pop();
                }
                // Run outside the lock so other workers can dequeue meanwhile.
                task();
            }
        });
    }
}

// Adds a new work item to the pool and returns a future for its result.
template<class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args)
    -> std::future<typename std::result_of<F(Args...)>::type>
{
    using return_type = typename std::result_of<F(Args...)>::type;

    // packaged_task is move-only while std::function requires a copyable
    // callable, hence the shared_ptr wrapper.
    auto task = std::make_shared<std::packaged_task<return_type()>>(
        std::bind(std::forward<F>(f), std::forward<Args>(args)...));
    std::future<return_type> res = task->get_future();
    {
        std::unique_lock<std::mutex> lock(queue_mutex);
        // Don't allow enqueueing after stopping the pool.
        if (stop)
            throw std::runtime_error("enqueue on stopped ThreadPool");
        tasks.emplace([task]() { (*task)(); });
    }
    condition.notify_one();
    return res;
}

// Drains the queue and joins every worker; repeated calls do nothing further.
inline void ThreadPool::shutdown()
{
    {
        std::unique_lock<std::mutex> lock(queue_mutex);
        stop = true;
    }
    condition.notify_all();
    for (std::thread& worker : workers) {
        // joinable() is false for workers already joined by an earlier call.
        if (worker.joinable())
            worker.join();
    }
}

// The destructor waits for outstanding work and joins all threads.
inline ThreadPool::~ThreadPool()
{
    shutdown();
}
// 线程参数
struct TPCorDist {
vector<vector<float>>* pvvX;
@ -66,26 +160,32 @@ void ThreadCalcDist(TPCorDist& param) {
}
/* 入口函数 */
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) {
//cout << "WordSplit" << endl;
//cout << nlhs << '\t' << nrhs << endl;
//if (nrhs < 1) {
// cout << "At least 1 arguments should be given for this function!" << endl;
// return;
//}
// void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) {
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
if (nrhs < 1) {
cout << "At least 1 arguments should be given for this function!" << endl;
return;
}
clock_t begin = clock(), mid, finish;
int rowNum = (int)mxGetM(prhs[0]);
int colNum = (int)mxGetN(prhs[0]);
double* pData = (double*)mxGetData(prhs[0]);
//
//int numThread = 1;
//if (nrhs > 1) {
// double* pNumThread = (double*)mxGetData(prhs[1]);
// numThread = (int)*pNumThread;
// if (numThread < 1) numThread = 1;
//}
int numThread = 6;
if (nrhs > 1) {
double* pNumThread = (double*)mxGetData(prhs[1]);
numThread = (int)pNumThread[0];
if (numThread < 1) numThread = 1;
}
int numGroup = 1;
if (nrhs > 2) {
double* pNumGroup = (double*)mxGetData(prhs[2]);
numGroup = (int)pNumGroup[0];
if (numGroup < 1) numGroup = 1;
}
//cout << numThread << '\t' << numGroup << endl;
//for (int i = 0; i < rowNum; ++i) {
// for (int j = 0; j < colNum; ++j) {
@ -98,7 +198,7 @@ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) {
// int colNum = 5;
// vector<double> pData = { 1,1,2,2,3,9,4,4,5,4 };
cout << rowNum << '\t' << colNum << endl;
//cout << rowNum << '\t' << colNum << endl;
/* 计算每一行的平均数 */
mid = clock();
@ -112,7 +212,7 @@ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) {
vMean[i] /= colNum;
}
finish = clock();
cout << "计算平均数: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
//cout << "计算平均数: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
/* 减去平均值, 计算平方 */
mid = clock();
vector<vector<float>> vvX(rowNum, vector<float>(colNum));
@ -125,7 +225,7 @@ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) {
}
for (auto& val : vSq) { val /= colNum; }
finish = clock();
cout << "计算平方: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
//cout << "计算平方: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
/* 计算相关距离 */
// clock_t mid0 = clock();
@ -187,9 +287,9 @@ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) {
const int row2 = 2 * rowNum;
vector<double> vDist(distSize, 0.0);
mid = clock();
vector<TPCorDist> vTP;
ThreadPool thPool(6);
int span = 32;
// vector<TPCorDist> vTP;
ThreadPool thPool(numThread);
int span = numGroup;
int rowNumSpan = rowNum / span * span;
int i = 0;
for (i = 0; i < rowNumSpan; i += span) {
@ -207,7 +307,7 @@ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) {
//kt_for(6, ThreadCalcDist, vTP);
finish = clock();
cout << "多线程计算相关距离: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
//cout << "多线程计算相关距离: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
// for (auto& val : vDist) {cout << val << endl;}
/* 写入结果 */
@ -215,7 +315,7 @@ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) {
mxArray* pWriteArray = NULL;
//创建一个rowNum*colNum的矩阵
pWriteArray = mxCreateDoubleMatrix(1, distSize, mxREAL);
memcpy((void*)(mxGetPr(pWriteArray)), (void*)vDist.data(), sizeof(double) * 6);
memcpy((void*)(mxGetPr(pWriteArray)), (void*)vDist.data(), sizeof(double) * distSize);
plhs[0] = pWriteArray; // 赋值给返回值
}