twirls/GMM/main.cpp

212 lines
6.2 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

/*********************************************************************************************
Description: 用高斯混合模型进行曲线的拟合
Copyright : All right reserved by ZheYuan.BJ
Author : Zhang Zhonghai
Date : 2023/09/12
***********************************************************************************************/
#include <algorithm>
#include <chrono>
#include <cmath>
#include <filesystem>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <queue>
#include <random>
#include <sstream>
#include <stdexcept>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>
#include <time.h>
#include <omp.h>
#ifdef _WIN32
#include <io.h>
#include <process.h>
#define F_OK 0
#else
#include <unistd.h>
#endif
#include <mat.h>
#include "gmm.h"
#include "thread_pool.h"
#include "CommonLib/matlab_io.h"
using namespace std;
using std::cout;
using std::vector;
namespace fs = std::filesystem;
/*
 * Put the samples x[0..xSize) into buckets of width binWidth (similar to MATLAB's hist()).
 *
 * Bin i is centred on i * binWidth, and every sample goes to the bin whose centre is nearest.
 * Bins always start at 0. The number of bins is chosen so that the last bin centre lies in
 * [max(x) - binWidth, max(x)]. Samples outside the bin range are clamped into the first or
 * last bin.
 *
 * @param x        Pointer to xSize samples.
 * @param xSize    Number of samples (0 is allowed).
 * @param binWidth Width of each bin; expected to be > 0.
 * @param vXBin    Output, resized to xSize: each sample replaced by its bin centre, in
 *                 ascending bin order. The point order slightly affects the parameters of
 *                 a GMM trained on this vector.
 * @param vYBin    Output: number of samples in each bin (always at least one bin).
 */
void PutXtoBin(double* x, int xSize, double binWidth, std::vector<double>& vXBin, std::vector<double>& vYBin) {
    double maxX = 0.0;  // bins start at 0, so the running maximum starts at 0 as well
    for (int i = 0; i < xSize; ++i) {
        if (maxX < x[i]) maxX = x[i];
    }
    int binSize = static_cast<int>((maxX + binWidth / 2) / binWidth + 1);
    const double binMaxVal = (binSize - 1) * binWidth;
    if (binMaxVal > maxX) {  // keep the last bin centre no larger than maxX and no smaller than maxX - binWidth
        binSize -= 1;
    }
    vXBin.resize(xSize);
    vYBin.assign(binSize, 0.0);
    for (int i = 0; i < xSize; ++i) {
        int binIdx = static_cast<int>((x[i] + binWidth / 2) / binWidth);
        // Large samples overflow the last bin, and negative samples would yield a negative
        // index (out-of-bounds write); clamp both into the nearest edge bin.
        binIdx = std::clamp(binIdx, 0, binSize - 1);
        vYBin[binIdx] += 1;
    }
    // Rebuild the samples as bin centres, in ascending order.
    int xIdx = 0;
    for (int i = 0; i < binSize; ++i) {
        const int count = static_cast<int>(vYBin[i]);
        for (int j = 0; j < count; ++j) {
            vXBin[xIdx++] = i * binWidth;
        }
    }
}
/* Ordering for (bin count, model value) pairs that compares the count only, so a
   std::priority_queue using it pops the pair with the largest count first. */
struct cmpFunc {
    bool operator()(const std::pair<double, double>& a, const std::pair<double, double>& b) const {
        return a.first < b.first;
    }
};
void GMMToFactorEY(GMM& gmm, double binWidth, vector<double> &vYBin, vector<double>& vFactor, vector<double>& vEY) {
/* 需要调整曲线的权重,来拟合高斯曲线,而不是用概率密度 */
double zoomFactorSum = 0.0;
vEY.resize(vYBin.size());
int topM = (int)(vYBin.size() / 4);
if (topM < 1) topM = 1;
/* 用堆排序的方式取前topM个最大值, 用来计算缩放参数*/
priority_queue<pair<double, double>, vector<pair<double, double> >, cmpFunc> pqTopM;
for (int i = 0; i < vYBin.size(); ++i) {
double xVal = i * binWidth;
double probVal = gmm.GetProbability(&xVal);
vEY[i] = probVal;
pqTopM.push(make_pair(vYBin[i], probVal));
}
for (int i = 0; i < topM; ++i) {
pair<double, double> topEle = pqTopM.top();
pqTopM.pop();
zoomFactorSum += topEle.first / topEle.second;
}
double zoomFactor = zoomFactorSum / topM;
for (int i = 0; i < vEY.size(); ++i) {
vEY[i] *= zoomFactor;
}
vFactor.clear();
vFactor.push_back(zoomFactor * gmm.Prior(0) / sqrt(2 * M_PI * *gmm.Variance(0)));
vFactor.push_back(*gmm.Mean(0));
vFactor.push_back(sqrt(2 * *gmm.Variance(0)));
vFactor.push_back(zoomFactor * gmm.Prior(1) / sqrt(2 * M_PI * *gmm.Variance(1)));
vFactor.push_back(*gmm.Mean(1));
vFactor.push_back(sqrt(2 * *gmm.Variance(1)));
}
/* Arithmetic mean of the elements of vVal; returns T(0) for an empty vector. */
template <typename T>
T Average(std::vector<T>& vVal) {
    if (vVal.empty()) {
        return T(0);  // avoid 0/0 (NaN for floating point, UB for integers)
    }
    T sumVal = T(0);
    for (const T& v : vVal) {
        sumVal += v;
    }
    // Divide by the size converted to T: dividing a signed integer sum by a size_t would
    // promote the sum to unsigned and corrupt negative averages.
    return sumVal / static_cast<T>(vVal.size());
}
/* Mean of the squared elements of vVal (mean(v .^ 2)); returns T(0) for an empty vector. */
template <typename T>
T SquareAverage(std::vector<T>& vVal) {
    if (vVal.empty()) {
        return T(0);  // avoid 0/0
    }
    // Accumulate the squares directly instead of materialising a temporary vector.
    T sumSq = T(0);
    for (const T& v : vVal) {
        sumSq += v * v;
    }
    return sumSq / static_cast<T>(vVal.size());
}
/*
 * Distance between vectors x and y of equal length: |1 - mean(x.*y) / sqrt(mean(x.^2) * mean(y.^2))|,
 * i.e. one minus the cosine of the angle between them.
 * NOTE(review): the vectors are not mean-centred, so this is the uncentred form (MATLAB pdist
 * 'cosine'), not pdist 'correlation' — confirm this is intended.
 * The absolute value presumably guards against tiny negative results from rounding.
 * Returns NaN when either vector is empty or all zeros.
 * @throws std::invalid_argument if the vectors have different sizes.
 */
double CorrelationDistance(std::vector<double>& vX, std::vector<double>& vY) {
    if (vX.size() != vY.size()) {
        throw std::invalid_argument("CorrelationDistance: vectors must have the same size");
    }
    // Single pass over both vectors, no temporary vectors.
    double sumXY = 0.0;
    double sumXX = 0.0;
    double sumYY = 0.0;
    for (std::size_t i = 0; i < vX.size(); ++i) {
        sumXY += vX[i] * vY[i];
        sumXX += vX[i] * vX[i];
        sumYY += vY[i] * vY[i];
    }
    const double n = static_cast<double>(vX.size());
    const double uv = sumXY / n;
    const double uu = sumXX / n;
    const double vv = sumYY / n;
    const double dist = 1.0 - uv / std::sqrt(uu * vv);
    return std::fabs(dist);
}
/* One unit of work for the thread pool: which MAT file to read and where to write the result. */
struct ThreadParam {
    std::filesystem::path matFilePath;  // input MAT file that holds the "hs" matrix
    std::filesystem::path outFilePath;  // output MAT file for the fitted results
};
void ThreadProcessData(const ThreadParam& param) {
const fs::path& matFilePath = param.matFilePath;
const fs::path& outFilePath = param.outFilePath;
double* hs = nullptr;
int rowNum = 0;
int colNum = 0;
hs = ReadMtxDouble(matFilePath.string(), "hs", &rowNum, &colNum);
vector<double>vXBin;
vector<double>vYBin;
vector<double>vEY;
vector<double>vFactor;
/* 用来保存数据存入mat文件 */
vector<double>vDist(rowNum);
vector<double>vFactorAll;
for (int i = 0; i < rowNum; ++i) {
PutXtoBin(hs + i * colNum, colNum, 0.2, vXBin, vYBin);
GMM gmm(1, 2); // 1维 2个高斯模型
gmm.Train(vXBin.data(), (int)vXBin.size());
GMMToFactorEY(gmm, 0.2, vYBin, vFactor, vEY);
vDist[i] = CorrelationDistance(vYBin, vEY);
vFactorAll.insert(vFactorAll.end(), vFactor.begin(), vFactor.end());
}
/* 写入matlab文件 */
MATFile* pMatFile = matOpen(outFilePath.string().c_str(), "w");
SaveMtxDouble(vFactorAll.data(), pMatFile, "factor", rowNum, 6);
SaveMtxDouble(vDist.data(), pMatFile, "correlation", rowNum, 1);
matClose(pMatFile);
delete[] hs;
}
/* 程序入口 */
int main(int argc, const char** argv) {
if (argc != 5) {
cerr << "This program should take 4 arguments(1.parrent Dir; 2. mat file suffix; 3. out mat filename; 4. thread number)!" << endl;
return 1;
}
string parrentDir(argv[1]); // 知识颗粒的父目录名称
string hsMatSuffix(argv[2]); // hs矩阵对应的mat文件的后缀名可以是全文件名可以是文件名后缀必须保证唯一
fs::path outFileName(argv[3]);
ThreadPool thPool(8);
clock_t begin, finish;
begin = clock();
/* 遍历所有的知识颗粒目录,逐一进行处理 */
for (auto& childDir : fs::directory_iterator(parrentDir)) {
fs::path outFilePath = childDir / outFileName;
for (auto& file : fs::directory_iterator(childDir)) {
const string& fileName = file.path().filename().string();
auto rPos = fileName.rfind(hsMatSuffix);
if (rPos != string::npos && fileName.size() - rPos == hsMatSuffix.size()) {
ThreadParam tParam = { file, outFilePath };
thPool.enqueue(ThreadProcessData, tParam);
}
}
}
thPool.~ThreadPool();
finish = clock();
cout << "GMM Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
return 0;
}