加入线程池

This commit is contained in:
zzh 2023-09-18 02:38:24 +08:00
parent f9e3d038c1
commit 7ced32b592
4 changed files with 110 additions and 133 deletions

View File

@ -109,6 +109,7 @@
<ItemGroup>
<ClInclude Include="gmm.h" />
<ClInclude Include="kmeans.h" />
<ClInclude Include="thread_pool.h" />
</ItemGroup>
<ItemGroup>
<ClCompile Include="gmm.cpp" />

View File

@ -21,6 +21,9 @@
<ClInclude Include="gmm.h">
<Filter>Header Files</Filter>
</ClInclude>
<ClInclude Include="thread_pool.h">
<Filter>Header Files</Filter>
</ClInclude>
</ItemGroup>
<ItemGroup>
<ClCompile Include="main.cpp">

View File

@ -21,6 +21,7 @@
#endif
#include <mat.h>
#include "gmm.h"
#include "thread_pool.h"
using namespace std;
using std::cout;
using std::vector;
@ -113,7 +114,6 @@ void PutXtoBin(double* x, int xSize, double binWidth, vector<double>& vXBin, vec
int binIdx = (int)((x[i] + binWidth / 2) / binWidth);
if (binIdx >= binSize) binIdx = binSize - 1;
vYBin[binIdx] += 1;
// vXBin[i] = binIdx * binWidth;
}
// 按大小顺序将修改后的x数值存储在vXBin中点的顺序不同训练出的高斯混合模型参数会有一些不同。
int xIdx = 0;
@ -148,10 +148,8 @@ void GMMToFactorEY(GMM& gmm, double binWidth, vector<double> &vYBin, vector<doub
for (int i = 0; i < topM; ++i) {
pair<double, double> topEle = pqTopM.top();
pqTopM.pop();
// cout << topEle.first << '\t' << topEle.second << endl;
zoomFactorSum += topEle.first / topEle.second;
}
// cout << endl;
double zoomFactor = zoomFactorSum / topM;
@ -198,122 +196,9 @@ double CorrelationDistance(vector<double>& vX, vector<double>& vY) {
double vv = SquareAverage(vY);
double dist = 1.0 - uv / sqrt(uu * vv);
return abs(dist);
}
/* 处理matlab的mat文件中包含的待拟合的数据 */
void processMatData(const string& filePath) {
double* hs = nullptr;
int rowNum = 0;
int colNum = 0;
clock_t begin, finish;
double total_cov = 0;
double total_cov2 = 0;
begin = clock();
hs = ReadMatlabMat<double>(filePath, "hs", &rowNum, &colNum);
ofstream gmmOfs("mat_gmm.debug");
ofstream gmmOfs2("mat_gmm2.debug");
ofstream xyOfs("xy_cpp.debug");
ofstream brOfs("br.debug");
vector<double>vXBin;
vector<double>vYBin;
vector<double>vEY;
vector<double>vFactor;
/* 用来保存数据存入mat文件 */
vector<double>vDist(rowNum);
vector<double>vFactorAll;
for (int i = 0; i < rowNum; ++i) {
PutXtoBin(hs + i * colNum, colNum, 0.2, vXBin, vYBin);
// for (int m = 0; m < vYBin.size(); ++m) xyOfs<< fixed << setprecision(1) << 0.2 * m << ' ';
// xyOfs << endl;
// for (int m = 0; m < vYBin.size(); ++m) xyOfs << (int)vYBin[m] << ' ';
// xyOfs << endl;
GMM gmm(1, 2); // 1维 2个高斯模型
gmm.Train(vXBin.data(), vXBin.size());
total_cov += *gmm.Variance(0);
gmmOfs << gmm << endl;
GMMToFactorEY(gmm, 0.2, vYBin, vFactor, vEY);
vDist[i] = CorrelationDistance(vYBin, vEY);
vFactorAll.insert(vFactorAll.end(), vFactor.begin(), vFactor.end());
brOfs << CorrelationDistance(vYBin, vEY) << endl;
for (int j = 0; j < vFactor.size(); ++j) brOfs << vFactor[j] << ", ";
GMM gmm2(1, 2);
gmm2.Train(hs + i * colNum, colNum);
total_cov2 += *gmm2.Variance(0);
gmmOfs2 << gmm2 << endl;
}
/* 写入matlab文件 */
MATFile* pMatFile = matOpen("D:\\save_br.mat", "w");
SaveMatrix<double>(vFactorAll.data(), pMatFile, "factor", rowNum, 6);
SaveMatrix<double>(vDist.data(), pMatFile, "correlation", rowNum, 1);
matClose(pMatFile);
gmmOfs.close();
gmmOfs2.close();
xyOfs.close();
brOfs.close();
finish = clock();
cout << "Total cov: " << total_cov << endl;
cout << "Total cov2: " << total_cov2 << endl;
cout << "Total time:" << (double)(finish - begin) / CLOCKS_PER_SEC << endl;
//MATFile* pMatFile = matOpen("D:\\save_hs.mat", "w");
//SaveMatrix<double>(hs, pMatFile, "hs_saved", rowNum, colNum);
//matClose(pMatFile);
delete[] hs;
}
/* 处理已经转换成txt的文本数据 */
void processTxtData(const string& filePath) {
clock_t begin, finish;
double total_cov = 0;
ifstream ifs(filePath, ios::in);
begin = clock();
ofstream gmmOfs("txt_gmm.debug");
while (!ifs.eof()) {
vector<double> vec_point;
string x_str, y_str;
if (!getline(ifs, x_str)) break;
if (!getline(ifs, y_str)) break;
// cout << x_str << endl << y_str << endl;
stringstream ss_x(x_str);
stringstream ss_y(y_str);
float x, y;
while (ss_x >> x && ss_y >> y) {
vec_point.resize(vec_point.size() + y);
for (int i = vec_point.size() - y; i < vec_point.size(); ++i)
vec_point[i] = x;
}
if (vec_point.size() == 0) continue;
GMM gmm(1, 2); // 1维 2个高斯模型
gmm.Train(vec_point.data(), vec_point.size());
// cout << *gmm.Mean(0) << endl;
total_cov += *gmm.Variance(0);
gmmOfs << gmm << endl;
}
gmmOfs.close();
finish = clock();
cout << "Total cov: " << total_cov << endl;
cout << "Total time:" << (double)(finish - begin) / CLOCKS_PER_SEC << endl;
if (ifs.is_open())
ifs.close();
}
/* 处理一个知识颗粒 */
struct ThreadParam {
fs::path matFilePath;
@ -322,8 +207,6 @@ struct ThreadParam {
void ThreadProcessData(const ThreadParam& param) {
const fs::path& matFilePath = param.matFilePath;
const fs::path& outFilePath = param.outFilePath;
// cout << parrentPath.string() << '\t' << matFilePath.filename().string() << endl;
cout << outFilePath.string() << endl;
double* hs = nullptr;
int rowNum = 0;
int colNum = 0;
@ -356,42 +239,34 @@ void ThreadProcessData(const ThreadParam& param) {
delete[] hs;
}
/* 程序入口 */
int main(int argc, char** argv) {
if (argc != 4) {
cerr << "This program should take 3 arguments(1.parrent Dir; 2. mat file suffix; 3. out mat filename)!" << endl;
if (argc != 5) {
cerr << "This program should take 4 arguments(1.parrent Dir; 2. mat file suffix; 3. out mat filename; 4. thread number)!" << endl;
return 1;
}
string parrentDir(argv[1]); // 知识颗粒的父目录名称
string hsMatSuffix(argv[2]); // hs矩阵对应的mat文件的后缀名可以是全文件名可以是文件名后缀必须保证唯一
fs::path outFileName(argv[3]);
vector<thread> vThread;
ThreadPool thPool(8);
clock_t begin, finish;
begin = clock();
/* 遍历所有的知识颗粒目录,注意进行处理 */
/* 遍历所有的知识颗粒目录,逐一进行处理 */
for (auto& childDir : fs::directory_iterator(parrentDir)) {
// cout << childDir.path().string() << endl;
fs::path outFilePath = childDir / outFileName;
for (auto& file : fs::directory_iterator(childDir)) {
// cout << file.path().filename().string() << endl;
const string& fileName = file.path().filename().string();
auto rPos = fileName.rfind(hsMatSuffix);
if (rPos != string::npos && fileName.size() - rPos == hsMatSuffix.size()) {
ThreadParam tParam = { file, outFilePath };
vThread.push_back(thread(ThreadProcessData, tParam));
// ThreadProcessData(tParam);
thPool.enqueue(ThreadProcessData, tParam);
}
}
}
for (auto& thread : vThread) {
thread.join();
}
thPool.~ThreadPool();
finish = clock();
cout << "Total time:" << (double)(finish - begin) / CLOCKS_PER_SEC << endl;
// processMatData(argv[1]);
// processMatData("D:\\Twirls\\runtime\\ALS_test\\1775\\twirls_id_abs2class_hs.mat");
// processTxtData("D:\\Twirls\\backup\\xy.txt");
return 0;
}

98
GMM/thread_pool.h 100644
View File

@ -0,0 +1,98 @@
#ifndef THREAD_POOL_H
#define THREAD_POOL_H
#include <vector>
#include <queue>
#include <memory>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <future>
#include <functional>
#include <stdexcept>
class ThreadPool {
public:
ThreadPool(size_t);
template<class F, class... Args>
auto enqueue(F&& f, Args&&... args)
->std::future<typename std::result_of<F(Args...)>::type>;
~ThreadPool();
private:
// need to keep track of threads so we can join them
std::vector< std::thread > workers;
// the task queue
std::queue< std::function<void()> > tasks;
// synchronization
std::mutex queue_mutex;
std::condition_variable condition;
bool stop;
};
// the constructor just launches some amount of workers
inline ThreadPool::ThreadPool(size_t threads)
: stop(false)
{
for (size_t i = 0;i < threads;++i)
workers.emplace_back(
[this]
{
for (;;)
{
std::function<void()> task;
{
std::unique_lock<std::mutex> lock(this->queue_mutex);
this->condition.wait(lock,
[this] { return this->stop || !this->tasks.empty(); });
if (this->stop && this->tasks.empty())
return;
task = std::move(this->tasks.front());
this->tasks.pop();
}
task();
}
}
);
}
// add new work item to the pool
template<class F, class... Args>
auto ThreadPool::enqueue(F&& f, Args&&... args)
-> std::future<typename std::result_of<F(Args...)>::type>
{
using return_type = typename std::result_of<F(Args...)>::type;
auto task = std::make_shared< std::packaged_task<return_type()> >(
std::bind(std::forward<F>(f), std::forward<Args>(args)...)
);
std::future<return_type> res = task->get_future();
{
std::unique_lock<std::mutex> lock(queue_mutex);
// don't allow enqueueing after stopping the pool
if (stop)
throw std::runtime_error("enqueue on stopped ThreadPool");
tasks.emplace([task]() { (*task)(); });
}
condition.notify_one();
return res;
}
// the destructor joins all threads
inline ThreadPool::~ThreadPool()
{
{
std::unique_lock<std::mutex> lock(queue_mutex);
stop = true;
}
condition.notify_all();
for (std::thread& worker : workers)
worker.join();
}
#endif