加入线程池
This commit is contained in:
parent
f9e3d038c1
commit
7ced32b592
|
|
@ -109,6 +109,7 @@
|
|||
<ItemGroup>
|
||||
<ClInclude Include="gmm.h" />
|
||||
<ClInclude Include="kmeans.h" />
|
||||
<ClInclude Include="thread_pool.h" />
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClCompile Include="gmm.cpp" />
|
||||
|
|
|
|||
|
|
@ -21,6 +21,9 @@
|
|||
<ClInclude Include="gmm.h">
|
||||
<Filter>Header Files</Filter>
|
||||
</ClInclude>
|
||||
<ClInclude Include="thread_pool.h">
|
||||
<Filter>Header Files</Filter>
|
||||
</ClInclude>
|
||||
</ItemGroup>
|
||||
<ItemGroup>
|
||||
<ClCompile Include="main.cpp">
|
||||
|
|
|
|||
141
GMM/main.cpp
141
GMM/main.cpp
|
|
@ -21,6 +21,7 @@
|
|||
#endif
|
||||
#include <mat.h>
|
||||
#include "gmm.h"
|
||||
#include "thread_pool.h"
|
||||
using namespace std;
|
||||
using std::cout;
|
||||
using std::vector;
|
||||
|
|
@ -113,7 +114,6 @@ void PutXtoBin(double* x, int xSize, double binWidth, vector<double>& vXBin, vec
|
|||
int binIdx = (int)((x[i] + binWidth / 2) / binWidth);
|
||||
if (binIdx >= binSize) binIdx = binSize - 1;
|
||||
vYBin[binIdx] += 1;
|
||||
// vXBin[i] = binIdx * binWidth;
|
||||
}
|
||||
// 按大小顺序将修改后的x数值存储在vXBin中,点的顺序不同,训练出的高斯混合模型参数会有一些不同。
|
||||
int xIdx = 0;
|
||||
|
|
@ -148,10 +148,8 @@ void GMMToFactorEY(GMM& gmm, double binWidth, vector<double> &vYBin, vector<doub
|
|||
for (int i = 0; i < topM; ++i) {
|
||||
pair<double, double> topEle = pqTopM.top();
|
||||
pqTopM.pop();
|
||||
// cout << topEle.first << '\t' << topEle.second << endl;
|
||||
zoomFactorSum += topEle.first / topEle.second;
|
||||
}
|
||||
// cout << endl;
|
||||
|
||||
double zoomFactor = zoomFactorSum / topM;
|
||||
|
||||
|
|
@ -198,122 +196,9 @@ double CorrelationDistance(vector<double>& vX, vector<double>& vY) {
|
|||
double vv = SquareAverage(vY);
|
||||
|
||||
double dist = 1.0 - uv / sqrt(uu * vv);
|
||||
|
||||
return abs(dist);
|
||||
}
|
||||
|
||||
/* 处理matlab的mat文件中包含的待拟合的数据 */
|
||||
void processMatData(const string& filePath) {
|
||||
|
||||
double* hs = nullptr;
|
||||
int rowNum = 0;
|
||||
int colNum = 0;
|
||||
|
||||
clock_t begin, finish;
|
||||
double total_cov = 0;
|
||||
double total_cov2 = 0;
|
||||
|
||||
begin = clock();
|
||||
hs = ReadMatlabMat<double>(filePath, "hs", &rowNum, &colNum);
|
||||
ofstream gmmOfs("mat_gmm.debug");
|
||||
ofstream gmmOfs2("mat_gmm2.debug");
|
||||
ofstream xyOfs("xy_cpp.debug");
|
||||
ofstream brOfs("br.debug");
|
||||
vector<double>vXBin;
|
||||
vector<double>vYBin;
|
||||
vector<double>vEY;
|
||||
vector<double>vFactor;
|
||||
/* 用来保存数据,存入mat文件 */
|
||||
vector<double>vDist(rowNum);
|
||||
vector<double>vFactorAll;
|
||||
for (int i = 0; i < rowNum; ++i) {
|
||||
PutXtoBin(hs + i * colNum, colNum, 0.2, vXBin, vYBin);
|
||||
// for (int m = 0; m < vYBin.size(); ++m) xyOfs<< fixed << setprecision(1) << 0.2 * m << ' ';
|
||||
// xyOfs << endl;
|
||||
// for (int m = 0; m < vYBin.size(); ++m) xyOfs << (int)vYBin[m] << ' ';
|
||||
// xyOfs << endl;
|
||||
|
||||
GMM gmm(1, 2); // 1维, 2个高斯模型
|
||||
gmm.Train(vXBin.data(), vXBin.size());
|
||||
total_cov += *gmm.Variance(0);
|
||||
gmmOfs << gmm << endl;
|
||||
|
||||
GMMToFactorEY(gmm, 0.2, vYBin, vFactor, vEY);
|
||||
vDist[i] = CorrelationDistance(vYBin, vEY);
|
||||
vFactorAll.insert(vFactorAll.end(), vFactor.begin(), vFactor.end());
|
||||
|
||||
brOfs << CorrelationDistance(vYBin, vEY) << endl;
|
||||
for (int j = 0; j < vFactor.size(); ++j) brOfs << vFactor[j] << ", ";
|
||||
|
||||
GMM gmm2(1, 2);
|
||||
gmm2.Train(hs + i * colNum, colNum);
|
||||
total_cov2 += *gmm2.Variance(0);
|
||||
gmmOfs2 << gmm2 << endl;
|
||||
}
|
||||
/* 写入matlab文件 */
|
||||
MATFile* pMatFile = matOpen("D:\\save_br.mat", "w");
|
||||
SaveMatrix<double>(vFactorAll.data(), pMatFile, "factor", rowNum, 6);
|
||||
SaveMatrix<double>(vDist.data(), pMatFile, "correlation", rowNum, 1);
|
||||
matClose(pMatFile);
|
||||
|
||||
gmmOfs.close();
|
||||
gmmOfs2.close();
|
||||
xyOfs.close();
|
||||
brOfs.close();
|
||||
finish = clock();
|
||||
cout << "Total cov: " << total_cov << endl;
|
||||
cout << "Total cov2: " << total_cov2 << endl;
|
||||
cout << "Total time:" << (double)(finish - begin) / CLOCKS_PER_SEC << endl;
|
||||
|
||||
//MATFile* pMatFile = matOpen("D:\\save_hs.mat", "w");
|
||||
//SaveMatrix<double>(hs, pMatFile, "hs_saved", rowNum, colNum);
|
||||
//matClose(pMatFile);
|
||||
|
||||
delete[] hs;
|
||||
}
|
||||
|
||||
/* 处理已经转换成txt的文本数据 */
|
||||
void processTxtData(const string& filePath) {
|
||||
clock_t begin, finish;
|
||||
double total_cov = 0;
|
||||
ifstream ifs(filePath, ios::in);
|
||||
|
||||
begin = clock();
|
||||
ofstream gmmOfs("txt_gmm.debug");
|
||||
while (!ifs.eof()) {
|
||||
vector<double> vec_point;
|
||||
string x_str, y_str;
|
||||
if (!getline(ifs, x_str)) break;
|
||||
if (!getline(ifs, y_str)) break;
|
||||
// cout << x_str << endl << y_str << endl;
|
||||
|
||||
stringstream ss_x(x_str);
|
||||
stringstream ss_y(y_str);
|
||||
|
||||
float x, y;
|
||||
|
||||
while (ss_x >> x && ss_y >> y) {
|
||||
vec_point.resize(vec_point.size() + y);
|
||||
for (int i = vec_point.size() - y; i < vec_point.size(); ++i)
|
||||
vec_point[i] = x;
|
||||
}
|
||||
if (vec_point.size() == 0) continue;
|
||||
|
||||
GMM gmm(1, 2); // 1维, 2个高斯模型
|
||||
gmm.Train(vec_point.data(), vec_point.size());
|
||||
// cout << *gmm.Mean(0) << endl;
|
||||
total_cov += *gmm.Variance(0);
|
||||
gmmOfs << gmm << endl;
|
||||
}
|
||||
gmmOfs.close();
|
||||
finish = clock();
|
||||
cout << "Total cov: " << total_cov << endl;
|
||||
cout << "Total time:" << (double)(finish - begin) / CLOCKS_PER_SEC << endl;
|
||||
|
||||
if (ifs.is_open())
|
||||
ifs.close();
|
||||
}
|
||||
|
||||
/* 处理一个知识颗粒 */
|
||||
struct ThreadParam {
|
||||
fs::path matFilePath;
|
||||
|
|
@ -322,8 +207,6 @@ struct ThreadParam {
|
|||
void ThreadProcessData(const ThreadParam& param) {
|
||||
const fs::path& matFilePath = param.matFilePath;
|
||||
const fs::path& outFilePath = param.outFilePath;
|
||||
// cout << parrentPath.string() << '\t' << matFilePath.filename().string() << endl;
|
||||
cout << outFilePath.string() << endl;
|
||||
double* hs = nullptr;
|
||||
int rowNum = 0;
|
||||
int colNum = 0;
|
||||
|
|
@ -356,42 +239,34 @@ void ThreadProcessData(const ThreadParam& param) {
|
|||
delete[] hs;
|
||||
}
|
||||
|
||||
/* 程序入口 */
|
||||
int main(int argc, char** argv) {
|
||||
|
||||
if (argc != 4) {
|
||||
cerr << "This program should take 3 arguments(1.parrent Dir; 2. mat file suffix; 3. out mat filename)!" << endl;
|
||||
if (argc != 5) {
|
||||
cerr << "This program should take 4 arguments(1.parrent Dir; 2. mat file suffix; 3. out mat filename; 4. thread number)!" << endl;
|
||||
return 1;
|
||||
}
|
||||
string parrentDir(argv[1]); // 知识颗粒的父目录名称
|
||||
string hsMatSuffix(argv[2]); // hs矩阵对应的mat文件的后缀名(可以是全文件名,可以是文件名后缀,必须保证唯一)
|
||||
fs::path outFileName(argv[3]);
|
||||
vector<thread> vThread;
|
||||
ThreadPool thPool(8);
|
||||
clock_t begin, finish;
|
||||
begin = clock();
|
||||
|
||||
/* 遍历所有的知识颗粒目录,注意进行处理 */
|
||||
/* 遍历所有的知识颗粒目录,逐一进行处理 */
|
||||
for (auto& childDir : fs::directory_iterator(parrentDir)) {
|
||||
// cout << childDir.path().string() << endl;
|
||||
fs::path outFilePath = childDir / outFileName;
|
||||
for (auto& file : fs::directory_iterator(childDir)) {
|
||||
// cout << file.path().filename().string() << endl;
|
||||
const string& fileName = file.path().filename().string();
|
||||
auto rPos = fileName.rfind(hsMatSuffix);
|
||||
if (rPos != string::npos && fileName.size() - rPos == hsMatSuffix.size()) {
|
||||
ThreadParam tParam = { file, outFilePath };
|
||||
vThread.push_back(thread(ThreadProcessData, tParam));
|
||||
// ThreadProcessData(tParam);
|
||||
thPool.enqueue(ThreadProcessData, tParam);
|
||||
}
|
||||
}
|
||||
}
|
||||
for (auto& thread : vThread) {
|
||||
thread.join();
|
||||
}
|
||||
thPool.~ThreadPool();
|
||||
finish = clock();
|
||||
cout << "Total time:" << (double)(finish - begin) / CLOCKS_PER_SEC << endl;
|
||||
// processMatData(argv[1]);
|
||||
// processMatData("D:\\Twirls\\runtime\\ALS_test\\1775\\twirls_id_abs2class_hs.mat");
|
||||
// processTxtData("D:\\Twirls\\backup\\xy.txt");
|
||||
|
||||
return 0;
|
||||
}
|
||||
|
|
@ -0,0 +1,98 @@
|
|||
#ifndef THREAD_POOL_H
|
||||
#define THREAD_POOL_H
|
||||
|
||||
#include <vector>
|
||||
#include <queue>
|
||||
#include <memory>
|
||||
#include <thread>
|
||||
#include <mutex>
|
||||
#include <condition_variable>
|
||||
#include <future>
|
||||
#include <functional>
|
||||
#include <stdexcept>
|
||||
|
||||
class ThreadPool {
|
||||
public:
|
||||
ThreadPool(size_t);
|
||||
template<class F, class... Args>
|
||||
auto enqueue(F&& f, Args&&... args)
|
||||
->std::future<typename std::result_of<F(Args...)>::type>;
|
||||
~ThreadPool();
|
||||
private:
|
||||
// need to keep track of threads so we can join them
|
||||
std::vector< std::thread > workers;
|
||||
// the task queue
|
||||
std::queue< std::function<void()> > tasks;
|
||||
|
||||
// synchronization
|
||||
std::mutex queue_mutex;
|
||||
std::condition_variable condition;
|
||||
bool stop;
|
||||
};
|
||||
|
||||
// the constructor just launches some amount of workers
|
||||
inline ThreadPool::ThreadPool(size_t threads)
|
||||
: stop(false)
|
||||
{
|
||||
for (size_t i = 0;i < threads;++i)
|
||||
workers.emplace_back(
|
||||
[this]
|
||||
{
|
||||
for (;;)
|
||||
{
|
||||
std::function<void()> task;
|
||||
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(this->queue_mutex);
|
||||
this->condition.wait(lock,
|
||||
[this] { return this->stop || !this->tasks.empty(); });
|
||||
if (this->stop && this->tasks.empty())
|
||||
return;
|
||||
task = std::move(this->tasks.front());
|
||||
this->tasks.pop();
|
||||
}
|
||||
|
||||
task();
|
||||
}
|
||||
}
|
||||
);
|
||||
}
|
||||
|
||||
// add new work item to the pool
|
||||
template<class F, class... Args>
|
||||
auto ThreadPool::enqueue(F&& f, Args&&... args)
|
||||
-> std::future<typename std::result_of<F(Args...)>::type>
|
||||
{
|
||||
using return_type = typename std::result_of<F(Args...)>::type;
|
||||
|
||||
auto task = std::make_shared< std::packaged_task<return_type()> >(
|
||||
std::bind(std::forward<F>(f), std::forward<Args>(args)...)
|
||||
);
|
||||
|
||||
std::future<return_type> res = task->get_future();
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(queue_mutex);
|
||||
|
||||
// don't allow enqueueing after stopping the pool
|
||||
if (stop)
|
||||
throw std::runtime_error("enqueue on stopped ThreadPool");
|
||||
|
||||
tasks.emplace([task]() { (*task)(); });
|
||||
}
|
||||
condition.notify_one();
|
||||
return res;
|
||||
}
|
||||
|
||||
// the destructor joins all threads
|
||||
inline ThreadPool::~ThreadPool()
|
||||
{
|
||||
{
|
||||
std::unique_lock<std::mutex> lock(queue_mutex);
|
||||
stop = true;
|
||||
}
|
||||
condition.notify_all();
|
||||
for (std::thread& worker : workers)
|
||||
worker.join();
|
||||
}
|
||||
|
||||
#endif
|
||||
Loading…
Reference in New Issue