diff --git a/GMM/GMM.vcxproj b/GMM/GMM.vcxproj index 42bdea4..5a76361 100644 --- a/GMM/GMM.vcxproj +++ b/GMM/GMM.vcxproj @@ -109,6 +109,7 @@ + diff --git a/GMM/GMM.vcxproj.filters b/GMM/GMM.vcxproj.filters index c29d2e2..49c0664 100644 --- a/GMM/GMM.vcxproj.filters +++ b/GMM/GMM.vcxproj.filters @@ -21,6 +21,9 @@ Header Files + + Header Files + diff --git a/GMM/main.cpp b/GMM/main.cpp index ace5ff6..21163f3 100644 --- a/GMM/main.cpp +++ b/GMM/main.cpp @@ -21,6 +21,7 @@ #endif #include #include "gmm.h" +#include "thread_pool.h" using namespace std; using std::cout; using std::vector; @@ -113,7 +114,6 @@ void PutXtoBin(double* x, int xSize, double binWidth, vector& vXBin, vec int binIdx = (int)((x[i] + binWidth / 2) / binWidth); if (binIdx >= binSize) binIdx = binSize - 1; vYBin[binIdx] += 1; - // vXBin[i] = binIdx * binWidth; } // 按大小顺序将修改后的x数值存储在vXBin中,点的顺序不同,训练出的高斯混合模型参数会有一些不同。 int xIdx = 0; @@ -148,10 +148,8 @@ void GMMToFactorEY(GMM& gmm, double binWidth, vector &vYBin, vector topEle = pqTopM.top(); pqTopM.pop(); - // cout << topEle.first << '\t' << topEle.second << endl; zoomFactorSum += topEle.first / topEle.second; } - // cout << endl; double zoomFactor = zoomFactorSum / topM; @@ -198,122 +196,9 @@ double CorrelationDistance(vector& vX, vector& vY) { double vv = SquareAverage(vY); double dist = 1.0 - uv / sqrt(uu * vv); - return abs(dist); } -/* 处理matlab的mat文件中包含的待拟合的数据 */ -void processMatData(const string& filePath) { - - double* hs = nullptr; - int rowNum = 0; - int colNum = 0; - - clock_t begin, finish; - double total_cov = 0; - double total_cov2 = 0; - - begin = clock(); - hs = ReadMatlabMat(filePath, "hs", &rowNum, &colNum); - ofstream gmmOfs("mat_gmm.debug"); - ofstream gmmOfs2("mat_gmm2.debug"); - ofstream xyOfs("xy_cpp.debug"); - ofstream brOfs("br.debug"); - vectorvXBin; - vectorvYBin; - vectorvEY; - vectorvFactor; - /* 用来保存数据,存入mat文件 */ - vectorvDist(rowNum); - vectorvFactorAll; - for (int i = 0; i < rowNum; ++i) { - PutXtoBin(hs + i * colNum, colNum, 0.2, vXBin, vYBin); - // for (int m = 0; m < vYBin.size(); ++m) xyOfs<< fixed << setprecision(1) << 0.2 * m << ' '; - // xyOfs << endl; - // for (int m = 0; m < vYBin.size(); ++m) xyOfs << (int)vYBin[m] << ' '; - // xyOfs << endl; - - GMM gmm(1, 2); // 1维, 2个高斯模型 - gmm.Train(vXBin.data(), vXBin.size()); - total_cov += *gmm.Variance(0); - gmmOfs << gmm << endl; - - GMMToFactorEY(gmm, 0.2, vYBin, vFactor, vEY); - vDist[i] = CorrelationDistance(vYBin, vEY); - vFactorAll.insert(vFactorAll.end(), vFactor.begin(), vFactor.end()); - - brOfs << CorrelationDistance(vYBin, vEY) << endl; - for (int j = 0; j < vFactor.size(); ++j) brOfs << vFactor[j] << ", "; - - GMM gmm2(1, 2); - gmm2.Train(hs + i * colNum, colNum); - total_cov2 += *gmm2.Variance(0); - gmmOfs2 << gmm2 << endl; - } - /* 写入matlab文件 */ - MATFile* pMatFile = matOpen("D:\\save_br.mat", "w"); - SaveMatrix(vFactorAll.data(), pMatFile, "factor", rowNum, 6); - SaveMatrix(vDist.data(), pMatFile, "correlation", rowNum, 1); - matClose(pMatFile); - - gmmOfs.close(); - gmmOfs2.close(); - xyOfs.close(); - brOfs.close(); - finish = clock(); - cout << "Total cov: " << total_cov << endl; - cout << "Total cov2: " << total_cov2 << endl; - cout << "Total time:" << (double)(finish - begin) / CLOCKS_PER_SEC << endl; - - //MATFile* pMatFile = matOpen("D:\\save_hs.mat", "w"); - //SaveMatrix(hs, pMatFile, "hs_saved", rowNum, colNum); - //matClose(pMatFile); - - delete[] hs; -} - -/* 处理已经转换成txt的文本数据 */ -void processTxtData(const string& filePath) { - clock_t begin, finish; - double total_cov = 0; - ifstream ifs(filePath, ios::in); - - begin = clock(); - ofstream gmmOfs("txt_gmm.debug"); - while (!ifs.eof()) { - vector vec_point; - string x_str, y_str; - if (!getline(ifs, x_str)) break; - if (!getline(ifs, y_str)) break; - // cout << x_str << endl << y_str << endl; - - stringstream ss_x(x_str); - stringstream ss_y(y_str); - - float x, y; - - while (ss_x >> x && ss_y >> y) { - vec_point.resize(vec_point.size() + y); - for (int i = vec_point.size() - y; i < vec_point.size(); ++i) - vec_point[i] = x; - } - if (vec_point.size() == 0) continue; - - GMM gmm(1, 2); // 1维, 2个高斯模型 - gmm.Train(vec_point.data(), vec_point.size()); - // cout << *gmm.Mean(0) << endl; - total_cov += *gmm.Variance(0); - gmmOfs << gmm << endl; - } - gmmOfs.close(); - finish = clock(); - cout << "Total cov: " << total_cov << endl; - cout << "Total time:" << (double)(finish - begin) / CLOCKS_PER_SEC << endl; - - if (ifs.is_open()) - ifs.close(); -} - /* 处理一个知识颗粒 */ struct ThreadParam { fs::path matFilePath; @@ -322,8 +207,6 @@ struct ThreadParam { void ThreadProcessData(const ThreadParam& param) { const fs::path& matFilePath = param.matFilePath; const fs::path& outFilePath = param.outFilePath; - // cout << parrentPath.string() << '\t' << matFilePath.filename().string() << endl; - cout << outFilePath.string() << endl; double* hs = nullptr; int rowNum = 0; int colNum = 0; @@ -356,42 +239,34 @@ void ThreadProcessData(const ThreadParam& param) { delete[] hs; } +/* 程序入口 */ int main(int argc, char** argv) { - if (argc != 4) { - cerr << "This program should take 3 arguments(1.parrent Dir; 2. mat file suffix; 3. out mat filename)!" << endl; + if (argc != 5) { + cerr << "This program should take 4 arguments(1.parrent Dir; 2. mat file suffix; 3. out mat filename; 4. thread number)!" << endl; return 1; } string parrentDir(argv[1]); // 知识颗粒的父目录名称 string hsMatSuffix(argv[2]); // hs矩阵对应的mat文件的后缀名(可以是全文件名,可以是文件名后缀,必须保证唯一) fs::path outFileName(argv[3]); - vector vThread; + ThreadPool thPool(8); clock_t begin, finish; begin = clock(); - /* 遍历所有的知识颗粒目录,注意进行处理 */ + /* 遍历所有的知识颗粒目录,逐一进行处理 */ for (auto& childDir : fs::directory_iterator(parrentDir)) { - // cout << childDir.path().string() << endl; fs::path outFilePath = childDir / outFileName; for (auto& file : fs::directory_iterator(childDir)) { - // cout << file.path().filename().string() << endl; const string& fileName = file.path().filename().string(); auto rPos = fileName.rfind(hsMatSuffix); if (rPos != string::npos && fileName.size() - rPos == hsMatSuffix.size()) { ThreadParam tParam = { file, outFilePath }; - vThread.push_back(thread(ThreadProcessData, tParam)); - // ThreadProcessData(tParam); + thPool.enqueue(ThreadProcessData, tParam); } } } - for (auto& thread : vThread) { - thread.join(); - } + thPool.~ThreadPool(); finish = clock(); cout << "Total time:" << (double)(finish - begin) / CLOCKS_PER_SEC << endl; - // processMatData(argv[1]); - // processMatData("D:\\Twirls\\runtime\\ALS_test\\1775\\twirls_id_abs2class_hs.mat"); - // processTxtData("D:\\Twirls\\backup\\xy.txt"); - return 0; } \ No newline at end of file diff --git a/GMM/thread_pool.h b/GMM/thread_pool.h new file mode 100644 index 0000000..d9dc583 --- /dev/null +++ b/GMM/thread_pool.h @@ -0,0 +1,98 @@ +#ifndef THREAD_POOL_H +#define THREAD_POOL_H + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +class ThreadPool { +public: + ThreadPool(size_t); + template + auto enqueue(F&& f, Args&&... args) + ->std::future::type>; + ~ThreadPool(); +private: + // need to keep track of threads so we can join them + std::vector< std::thread > workers; + // the task queue + std::queue< std::function > tasks; + + // synchronization + std::mutex queue_mutex; + std::condition_variable condition; + bool stop; +}; + +// the constructor just launches some amount of workers +inline ThreadPool::ThreadPool(size_t threads) + : stop(false) +{ + for (size_t i = 0;i < threads;++i) + workers.emplace_back( + [this] + { + for (;;) + { + std::function task; + + { + std::unique_lock lock(this->queue_mutex); + this->condition.wait(lock, + [this] { return this->stop || !this->tasks.empty(); }); + if (this->stop && this->tasks.empty()) + return; + task = std::move(this->tasks.front()); + this->tasks.pop(); + } + + task(); + } + } + ); +} + +// add new work item to the pool +template +auto ThreadPool::enqueue(F&& f, Args&&... args) +-> std::future::type> +{ + using return_type = typename std::result_of::type; + + auto task = std::make_shared< std::packaged_task >( + std::bind(std::forward(f), std::forward(args)...) + ); + + std::future res = task->get_future(); + { + std::unique_lock lock(queue_mutex); + + // don't allow enqueueing after stopping the pool + if (stop) + throw std::runtime_error("enqueue on stopped ThreadPool"); + + tasks.emplace([task]() { (*task)(); }); + } + condition.notify_one(); + return res; +} + +// the destructor joins all threads +inline ThreadPool::~ThreadPool() +{ + { + std::unique_lock lock(queue_mutex); + stop = true; + } + condition.notify_all(); + for (std::thread& worker : workers) + worker.join(); +} + +#endif \ No newline at end of file