diff --git a/GMM/GMM.vcxproj b/GMM/GMM.vcxproj
index 42bdea4..5a76361 100644
--- a/GMM/GMM.vcxproj
+++ b/GMM/GMM.vcxproj
@@ -109,6 +109,7 @@
+
diff --git a/GMM/GMM.vcxproj.filters b/GMM/GMM.vcxproj.filters
index c29d2e2..49c0664 100644
--- a/GMM/GMM.vcxproj.filters
+++ b/GMM/GMM.vcxproj.filters
@@ -21,6 +21,9 @@
Header Files
+
+ Header Files
+
diff --git a/GMM/main.cpp b/GMM/main.cpp
index ace5ff6..21163f3 100644
--- a/GMM/main.cpp
+++ b/GMM/main.cpp
@@ -21,6 +21,7 @@
#endif
#include
#include "gmm.h"
+#include "thread_pool.h"
using namespace std;
using std::cout;
using std::vector;
@@ -113,7 +114,6 @@ void PutXtoBin(double* x, int xSize, double binWidth, vector& vXBin, vec
int binIdx = (int)((x[i] + binWidth / 2) / binWidth);
if (binIdx >= binSize) binIdx = binSize - 1;
vYBin[binIdx] += 1;
- // vXBin[i] = binIdx * binWidth;
}
// 按大小顺序将修改后的x数值存储在vXBin中,点的顺序不同,训练出的高斯混合模型参数会有一些不同。
int xIdx = 0;
@@ -148,10 +148,8 @@ void GMMToFactorEY(GMM& gmm, double binWidth, vector &vYBin, vector topEle = pqTopM.top();
pqTopM.pop();
- // cout << topEle.first << '\t' << topEle.second << endl;
zoomFactorSum += topEle.first / topEle.second;
}
- // cout << endl;
double zoomFactor = zoomFactorSum / topM;
@@ -198,122 +196,9 @@ double CorrelationDistance(vector& vX, vector& vY) {
double vv = SquareAverage(vY);
double dist = 1.0 - uv / sqrt(uu * vv);
-
return abs(dist);
}
-/* 处理matlab的mat文件中包含的待拟合的数据 */
-void processMatData(const string& filePath) {
-
- double* hs = nullptr;
- int rowNum = 0;
- int colNum = 0;
-
- clock_t begin, finish;
- double total_cov = 0;
- double total_cov2 = 0;
-
- begin = clock();
- hs = ReadMatlabMat(filePath, "hs", &rowNum, &colNum);
- ofstream gmmOfs("mat_gmm.debug");
- ofstream gmmOfs2("mat_gmm2.debug");
- ofstream xyOfs("xy_cpp.debug");
- ofstream brOfs("br.debug");
- vectorvXBin;
- vectorvYBin;
- vectorvEY;
- vectorvFactor;
- /* 用来保存数据,存入mat文件 */
- vectorvDist(rowNum);
- vectorvFactorAll;
- for (int i = 0; i < rowNum; ++i) {
- PutXtoBin(hs + i * colNum, colNum, 0.2, vXBin, vYBin);
- // for (int m = 0; m < vYBin.size(); ++m) xyOfs<< fixed << setprecision(1) << 0.2 * m << ' ';
- // xyOfs << endl;
- // for (int m = 0; m < vYBin.size(); ++m) xyOfs << (int)vYBin[m] << ' ';
- // xyOfs << endl;
-
- GMM gmm(1, 2); // 1维, 2个高斯模型
- gmm.Train(vXBin.data(), vXBin.size());
- total_cov += *gmm.Variance(0);
- gmmOfs << gmm << endl;
-
- GMMToFactorEY(gmm, 0.2, vYBin, vFactor, vEY);
- vDist[i] = CorrelationDistance(vYBin, vEY);
- vFactorAll.insert(vFactorAll.end(), vFactor.begin(), vFactor.end());
-
- brOfs << CorrelationDistance(vYBin, vEY) << endl;
- for (int j = 0; j < vFactor.size(); ++j) brOfs << vFactor[j] << ", ";
-
- GMM gmm2(1, 2);
- gmm2.Train(hs + i * colNum, colNum);
- total_cov2 += *gmm2.Variance(0);
- gmmOfs2 << gmm2 << endl;
- }
- /* 写入matlab文件 */
- MATFile* pMatFile = matOpen("D:\\save_br.mat", "w");
- SaveMatrix(vFactorAll.data(), pMatFile, "factor", rowNum, 6);
- SaveMatrix(vDist.data(), pMatFile, "correlation", rowNum, 1);
- matClose(pMatFile);
-
- gmmOfs.close();
- gmmOfs2.close();
- xyOfs.close();
- brOfs.close();
- finish = clock();
- cout << "Total cov: " << total_cov << endl;
- cout << "Total cov2: " << total_cov2 << endl;
- cout << "Total time:" << (double)(finish - begin) / CLOCKS_PER_SEC << endl;
-
- //MATFile* pMatFile = matOpen("D:\\save_hs.mat", "w");
- //SaveMatrix(hs, pMatFile, "hs_saved", rowNum, colNum);
- //matClose(pMatFile);
-
- delete[] hs;
-}
-
-/* 处理已经转换成txt的文本数据 */
-void processTxtData(const string& filePath) {
- clock_t begin, finish;
- double total_cov = 0;
- ifstream ifs(filePath, ios::in);
-
- begin = clock();
- ofstream gmmOfs("txt_gmm.debug");
- while (!ifs.eof()) {
- vector vec_point;
- string x_str, y_str;
- if (!getline(ifs, x_str)) break;
- if (!getline(ifs, y_str)) break;
- // cout << x_str << endl << y_str << endl;
-
- stringstream ss_x(x_str);
- stringstream ss_y(y_str);
-
- float x, y;
-
- while (ss_x >> x && ss_y >> y) {
- vec_point.resize(vec_point.size() + y);
- for (int i = vec_point.size() - y; i < vec_point.size(); ++i)
- vec_point[i] = x;
- }
- if (vec_point.size() == 0) continue;
-
- GMM gmm(1, 2); // 1维, 2个高斯模型
- gmm.Train(vec_point.data(), vec_point.size());
- // cout << *gmm.Mean(0) << endl;
- total_cov += *gmm.Variance(0);
- gmmOfs << gmm << endl;
- }
- gmmOfs.close();
- finish = clock();
- cout << "Total cov: " << total_cov << endl;
- cout << "Total time:" << (double)(finish - begin) / CLOCKS_PER_SEC << endl;
-
- if (ifs.is_open())
- ifs.close();
-}
-
/* 处理一个知识颗粒 */
struct ThreadParam {
fs::path matFilePath;
@@ -322,8 +207,6 @@ struct ThreadParam {
void ThreadProcessData(const ThreadParam& param) {
const fs::path& matFilePath = param.matFilePath;
const fs::path& outFilePath = param.outFilePath;
- // cout << parrentPath.string() << '\t' << matFilePath.filename().string() << endl;
- cout << outFilePath.string() << endl;
double* hs = nullptr;
int rowNum = 0;
int colNum = 0;
@@ -356,42 +239,34 @@ void ThreadProcessData(const ThreadParam& param) {
delete[] hs;
}
+/* 程序入口 */
int main(int argc, char** argv) {
- if (argc != 4) {
- cerr << "This program should take 3 arguments(1.parrent Dir; 2. mat file suffix; 3. out mat filename)!" << endl;
+ if (argc != 5) {
+ cerr << "This program should take 4 arguments(1.parrent Dir; 2. mat file suffix; 3. out mat filename; 4. thread number)!" << endl;
return 1;
}
string parrentDir(argv[1]); // 知识颗粒的父目录名称
string hsMatSuffix(argv[2]); // hs矩阵对应的mat文件的后缀名(可以是全文件名,可以是文件名后缀,必须保证唯一)
fs::path outFileName(argv[3]);
- vector vThread;
+ ThreadPool thPool(8);
clock_t begin, finish;
begin = clock();
- /* 遍历所有的知识颗粒目录,注意进行处理 */
+ /* 遍历所有的知识颗粒目录,逐一进行处理 */
for (auto& childDir : fs::directory_iterator(parrentDir)) {
- // cout << childDir.path().string() << endl;
fs::path outFilePath = childDir / outFileName;
for (auto& file : fs::directory_iterator(childDir)) {
- // cout << file.path().filename().string() << endl;
const string& fileName = file.path().filename().string();
auto rPos = fileName.rfind(hsMatSuffix);
if (rPos != string::npos && fileName.size() - rPos == hsMatSuffix.size()) {
ThreadParam tParam = { file, outFilePath };
- vThread.push_back(thread(ThreadProcessData, tParam));
- // ThreadProcessData(tParam);
+ thPool.enqueue(ThreadProcessData, tParam);
}
}
}
- for (auto& thread : vThread) {
- thread.join();
- }
+ thPool.~ThreadPool();
finish = clock();
cout << "Total time:" << (double)(finish - begin) / CLOCKS_PER_SEC << endl;
- // processMatData(argv[1]);
- // processMatData("D:\\Twirls\\runtime\\ALS_test\\1775\\twirls_id_abs2class_hs.mat");
- // processTxtData("D:\\Twirls\\backup\\xy.txt");
-
return 0;
}
\ No newline at end of file
diff --git a/GMM/thread_pool.h b/GMM/thread_pool.h
new file mode 100644
index 0000000..d9dc583
--- /dev/null
+++ b/GMM/thread_pool.h
@@ -0,0 +1,98 @@
+#ifndef THREAD_POOL_H
+#define THREAD_POOL_H
+
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+#include
+
+class ThreadPool {
+public:
+ ThreadPool(size_t);
+ template
+ auto enqueue(F&& f, Args&&... args)
+ ->std::future::type>;
+ ~ThreadPool();
+private:
+ // need to keep track of threads so we can join them
+ std::vector< std::thread > workers;
+ // the task queue
+ std::queue< std::function > tasks;
+
+ // synchronization
+ std::mutex queue_mutex;
+ std::condition_variable condition;
+ bool stop;
+};
+
+// the constructor just launches some amount of workers
+inline ThreadPool::ThreadPool(size_t threads)
+ : stop(false)
+{
+ for (size_t i = 0;i < threads;++i)
+ workers.emplace_back(
+ [this]
+ {
+ for (;;)
+ {
+ std::function task;
+
+ {
+ std::unique_lock lock(this->queue_mutex);
+ this->condition.wait(lock,
+ [this] { return this->stop || !this->tasks.empty(); });
+ if (this->stop && this->tasks.empty())
+ return;
+ task = std::move(this->tasks.front());
+ this->tasks.pop();
+ }
+
+ task();
+ }
+ }
+ );
+}
+
+// add new work item to the pool
+template
+auto ThreadPool::enqueue(F&& f, Args&&... args)
+-> std::future::type>
+{
+ using return_type = typename std::result_of::type;
+
+ auto task = std::make_shared< std::packaged_task >(
+ std::bind(std::forward(f), std::forward(args)...)
+ );
+
+ std::future res = task->get_future();
+ {
+ std::unique_lock lock(queue_mutex);
+
+ // don't allow enqueueing after stopping the pool
+ if (stop)
+ throw std::runtime_error("enqueue on stopped ThreadPool");
+
+ tasks.emplace([task]() { (*task)(); });
+ }
+ condition.notify_one();
+ return res;
+}
+
+// the destructor joins all threads
+inline ThreadPool::~ThreadPool()
+{
+ {
+ std::unique_lock lock(queue_mutex);
+ stop = true;
+ }
+ condition.notify_all();
+ for (std::thread& worker : workers)
+ worker.join();
+}
+
+#endif
\ No newline at end of file