// twirls/GMM/main.cpp
//
// NOTE: this copy was captured from a web repository viewer, which reported the
// file as 397 lines / 12 KiB of C++ and warned that it contains ambiguous
// Unicode characters that may be confused with others in the current locale.

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstddef>
#include <exception>
#include <filesystem>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <memory>
#include <queue>
#include <random>
#include <sstream>
#include <string>
#include <thread>
#include <unordered_map>
#include <utility>
#include <vector>
#include <omp.h>
#include <time.h>
#ifdef _WIN32
#include <io.h>
#include <process.h>
#define F_OK 0
#else
#include <unistd.h>
#endif
#include <mat.h>
#include "gmm.h"
using namespace std;
using std::cout;
using std::vector;
namespace fs = std::filesystem;
/**
 * Read the matrix named `mtxName` from a MATLAB .mat file.
 *
 * The data is copied out of the MAT-file buffer (indexed column-major as
 * matData[j * rowNum + i]) into a newly allocated row-major array
 * (dst[i * colNum + j]), converting each element to T.
 *
 * @param filePath  Path of the .mat file to open for reading.
 * @param mtxName   Name of the variable to load.
 * @param pRowNum   Receives the number of rows; written only on success.
 * @param pColNum   Receives the number of columns; written only on success.
 * @return Array of rowNum * colNum elements allocated with new[]; the caller
 *         owns it and must release it with delete[]. Returns nullptr when the
 *         file cannot be opened, the variable does not exist, or it has no data.
 *
 * NOTE(review): the payload is reinterpreted as double without checking the
 * variable's class; this assumes the stored matrix is a real double matrix.
 */
template<typename T>
T* ReadMatlabMat(const std::string &filePath, const std::string &mtxName, int *pRowNum, int *pColNum) {
    MATFile* pMatFile = matOpen(filePath.c_str(), "r");
    if (pMatFile == nullptr) {
        std::cerr << "filePath is error!" << std::endl;
        return nullptr;
    }

    // The lookup yields a null pointer when the variable is missing; the
    // original dereferenced it unconditionally.
    mxArray* pMxArray = matGetVariable(pMatFile, mtxName.c_str());
    if (pMxArray == nullptr) {
        std::cerr << "variable \"" << mtxName << "\" not found in " << filePath << std::endl;
        matClose(pMatFile);
        return nullptr;
    }

    const int rowNum = static_cast<int>(mxGetM(pMxArray));
    const int colNum = static_cast<int>(mxGetN(pMxArray));
    // Compute the element count in size_t so large matrices do not overflow int.
    const std::size_t total = static_cast<std::size_t>(rowNum) * static_cast<std::size_t>(colNum);
    const double* matData = static_cast<const double*>(mxGetData(pMxArray));

    T* dst = nullptr;
    if (matData != nullptr || total == 0) {
        dst = new T[total];
        // Transpose: source is column-major, destination is row-major.
        for (int i = 0; i < rowNum; ++i) {
            for (int j = 0; j < colNum; ++j) {
                dst[static_cast<std::size_t>(i) * colNum + j] =
                    static_cast<T>(matData[static_cast<std::size_t>(j) * rowNum + i]);
            }
        }
        *pRowNum = rowNum;
        *pColNum = colNum;
    } else {
        std::cerr << "variable \"" << mtxName << "\" has no data in " << filePath << std::endl;
    }

    mxDestroyArray(pMxArray);  // release the loaded array
    matClose(pMatFile);        // close the file
    return dst;
}
/**
 * Write a row-major rowNum x colNum array into an open MAT-file as a real
 * double matrix named `matrixName`.
 *
 * The elements are converted to double and transposed into the column-major
 * layout of the mxArray buffer (mtxData[j * rowNum + i] = src[i * colNum + j]).
 *
 * @param src         Source data, rowNum * colNum elements in row-major order.
 * @param pMatFile    MAT-file handle opened for writing; must not be null.
 * @param matrixName  Variable name to store the matrix under.
 * @param rowNum      Number of rows.
 * @param colNum      Number of columns.
 * @return true when the variable was written; false when the handle is null,
 *         the matrix could not be created, or the write reported an error.
 */
template<typename T>
bool SaveMatrix(T* src, MATFile* pMatFile, std::string matrixName, int rowNum, int colNum)
{
    // Validate the handle before allocating anything: the original allocated a
    // temporary buffer first and leaked it on this early return.
    if (pMatFile == nullptr)
    {
        std::cerr << "mat file pointer is error!" << std::endl;
        return false;
    }

    // Create a rowNum x colNum real double matrix.
    mxArray* pWriteArray = mxCreateDoubleMatrix(rowNum, colNum, mxREAL);
    if (pWriteArray == nullptr)
    {
        std::cerr << "failed to create matrix \"" << matrixName << "\"" << std::endl;
        return false;
    }

    // Fill the matrix buffer directly (transposed), so no intermediate copy
    // and no memcpy are needed.
    double* mtxData = mxGetPr(pWriteArray);
    for (int i = 0; i < rowNum; i++)
    {
        for (int j = 0; j < colNum; j++)
        {
            mtxData[static_cast<std::size_t>(j) * rowNum + i] =
                static_cast<double>(src[static_cast<std::size_t>(i) * colNum + j]);
        }
    }

    // Store the matrix under the requested name; non-zero status means failure.
    const int status = matPutVariable(pMatFile, matrixName.c_str(), pWriteArray);
    mxDestroyArray(pWriteArray);  // release resource
    if (status != 0)
    {
        std::cerr << "failed to write matrix \"" << matrixName << "\"" << std::endl;
        return false;
    }
    return true;
}
/**
 * Histogram the samples in `x` into bins of width `binWidth` centred on
 * 0, binWidth, 2*binWidth, ... (similar in purpose to MATLAB's hist).
 *
 * @param x         Input samples (xSize values).
 * @param xSize     Number of samples; negative values are treated as 0.
 * @param binWidth  Bin width; must be positive.
 * @param vXBin     Output: the samples quantised to their bin centre, emitted
 *                  in ascending bin order (size xSize). The order matters:
 *                  a GMM trained on differently ordered points yields slightly
 *                  different parameters.
 * @param vYBin     Output: count per bin (size = number of bins, at least 1).
 *
 * Samples beyond the last bin are counted in the last bin. Samples below the
 * first bin (negative values) and NaN are counted in the first bin; the
 * original computed a negative index for them and wrote out of bounds.
 */
void PutXtoBin(double* x, int xSize, double binWidth, std::vector<double>& vXBin, std::vector<double>& vYBin) {
    if (xSize < 0) xSize = 0;
    const double halfWidth = binWidth / 2;

    // Largest sample, never below 0 so that the bins always start at 0.
    double maxX = 0.0;
    for (int i = 0; i < xSize; ++i) {
        if (maxX < x[i]) maxX = x[i];
    }

    int binSize = static_cast<int>((maxX + halfWidth) / binWidth + 1);
    // Keep the last bin centre within (maxX - binWidth, maxX].
    if ((binSize - 1) * binWidth > maxX) {
        binSize -= 1;
    }

    vXBin.resize(xSize);
    vYBin.assign(binSize, 0.0);

    for (int i = 0; i < xSize; ++i) {
        const double pos = (x[i] + halfWidth) / binWidth;
        int binIdx;
        if (!(pos > 0.0)) {
            binIdx = 0;                       // negative or NaN sample
        } else if (pos >= binSize) {
            binIdx = binSize - 1;             // beyond the last bin
        } else {
            binIdx = static_cast<int>(pos);
        }
        vYBin[binIdx] += 1;
    }

    // Emit the quantised samples bin by bin, in ascending order.
    int xIdx = 0;
    for (int i = 0; i < binSize; ++i) {
        const double centre = i * binWidth;
        for (int j = 0; j < vYBin[i]; ++j) {
            vXBin[xIdx++] = centre;
        }
    }
}
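/* (Describes GMMToFactorEY, which follows the cmpFunc helper.) Converts the parameters of a trained standard Gaussian mixture model into custom coefficients and produces the vector of fitted Y values. */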
/**
 * Strict-weak ordering on pairs that compares the first element only.
 * Used as the comparator of a std::priority_queue, which therefore pops the
 * pair with the largest `first` value first.
 */
struct cmpFunc {
    // Declared const so the comparator can be invoked on const instances.
    bool operator()(const std::pair<double, double>& lhs, const std::pair<double, double>& rhs) const {
        return lhs.first < rhs.first;
    }
};
void GMMToFactorEY(GMM& gmm, double binWidth, vector<double> &vYBin, vector<double>& vFactor, vector<double>& vEY) {
/* 需要调整曲线的权重,来拟合高斯曲线,而不是用概率密度 */
double zoomFactorSum = 0.0;
int valNum = 0;
vEY.resize(vYBin.size());
int topM = vYBin.size() / 4;
if (topM < 1) topM = 1;
/* 用堆排序的方式取前topM个最大值, 用来计算缩放参数*/
priority_queue<pair<double, double>, vector<pair<double, double> >, cmpFunc> pqTopM;
for (int i = 0; i < vYBin.size(); ++i) {
double xVal = i * binWidth;
double probVal = gmm.GetProbability(&xVal);
vEY[i] = probVal;
pqTopM.push(make_pair(vYBin[i], probVal));
}
for (int i = 0; i < topM; ++i) {
pair<double, double> topEle = pqTopM.top();
pqTopM.pop();
// cout << topEle.first << '\t' << topEle.second << endl;
zoomFactorSum += topEle.first / topEle.second;
}
// cout << endl;
double zoomFactor = zoomFactorSum / topM;
for (int i = 0; i < vEY.size(); ++i) {
vEY[i] *= zoomFactor;
}
vFactor.clear();
vFactor.push_back(zoomFactor * gmm.Prior(0) / sqrt(2 * M_PI * *gmm.Variance(0)));
vFactor.push_back(*gmm.Mean(0));
vFactor.push_back(sqrt(2 * *gmm.Variance(0)));
vFactor.push_back(zoomFactor * gmm.Prior(1) / sqrt(2 * M_PI * *gmm.Variance(1)));
vFactor.push_back(*gmm.Mean(1));
vFactor.push_back(sqrt(2 * *gmm.Variance(1)));
}
/**
 * Arithmetic mean of the elements of `vVal`.
 *
 * @return sum / count computed in T; T(0) for an empty vector.
 *
 * The element count is converted to T before dividing. The original divided
 * by the unsigned size directly, which converts a negative integer sum to a
 * huge unsigned value and returns garbage, and divides by zero on empty
 * integer input.
 */
template <typename T>
T Average(std::vector<T>& vVal) {
    if (vVal.empty()) {
        return T(0);
    }
    T sumVal = T(0);
    for (const T& value : vVal) {
        sumVal += value;
    }
    return sumVal / static_cast<T>(vVal.size());
}
/**
 * Mean of the squared elements of `vVal`.
 *
 * @return (sum of v*v) / count computed in T; T(0) for an empty vector.
 *
 * The squares are accumulated directly in element order instead of being
 * stored in a temporary vector first, so no allocation is made per call.
 * The count is converted to T before dividing, so signed integer element
 * types are not promoted to unsigned.
 */
template <typename T>
T SquareAverage(std::vector<T>& vVal) {
    if (vVal.empty()) {
        return T(0);
    }
    T sumSquares = T(0);
    for (const T& value : vVal) {
        sumSquares += value * value;
    }
    return sumSquares / static_cast<T>(vVal.size());
}
/**
 * Distance between two equally sized vectors:
 *     | 1 - mean(x*y) / sqrt(mean(x*x) * mean(y*y)) |
 *
 * No mean is subtracted from either vector, so this is the uncentred
 * (cosine-style) form of the correlation distance.
 *
 * @param vX  First vector.
 * @param vY  Second vector; must have the same size as vX.
 * @return The distance; NaN when the sizes differ, when the vectors are empty,
 *         or when either vector is all zeros (0/0).
 *
 * The three means are accumulated in one pass without temporary vectors, and
 * std::fabs is used so the result can never go through an integer abs overload.
 */
double CorrelationDistance(std::vector<double>& vX, std::vector<double>& vY) {
    const std::size_t n = vX.size();
    if (vY.size() != n) {
        // The original indexed vY with vX's size and could read out of bounds.
        return std::nan("");
    }

    double sumXY = 0.0;
    double sumXX = 0.0;
    double sumYY = 0.0;
    for (std::size_t i = 0; i < n; ++i) {
        sumXY += vX[i] * vY[i];
        sumXX += vX[i] * vX[i];
        sumYY += vY[i] * vY[i];
    }

    const double count = static_cast<double>(n);
    const double uv = sumXY / count;
    const double uu = sumXX / count;
    const double vv = sumYY / count;
    return std::fabs(1.0 - uv / std::sqrt(uu * vv));
}
/* 处理matlab的mat文件中包含的待拟合的数据 */
void processMatData(const string& filePath) {
double* hs = nullptr;
int rowNum = 0;
int colNum = 0;
clock_t begin, finish;
double total_cov = 0;
double total_cov2 = 0;
begin = clock();
hs = ReadMatlabMat<double>(filePath, "hs", &rowNum, &colNum);
ofstream gmmOfs("mat_gmm.debug");
ofstream gmmOfs2("mat_gmm2.debug");
ofstream xyOfs("xy_cpp.debug");
ofstream brOfs("br.debug");
vector<double>vXBin;
vector<double>vYBin;
vector<double>vEY;
vector<double>vFactor;
/* 用来保存数据存入mat文件 */
vector<double>vDist(rowNum);
vector<double>vFactorAll;
for (int i = 0; i < rowNum; ++i) {
PutXtoBin(hs + i * colNum, colNum, 0.2, vXBin, vYBin);
// for (int m = 0; m < vYBin.size(); ++m) xyOfs<< fixed << setprecision(1) << 0.2 * m << ' ';
// xyOfs << endl;
// for (int m = 0; m < vYBin.size(); ++m) xyOfs << (int)vYBin[m] << ' ';
// xyOfs << endl;
GMM gmm(1, 2); // 1维 2个高斯模型
gmm.Train(vXBin.data(), vXBin.size());
total_cov += *gmm.Variance(0);
gmmOfs << gmm << endl;
GMMToFactorEY(gmm, 0.2, vYBin, vFactor, vEY);
vDist[i] = CorrelationDistance(vYBin, vEY);
vFactorAll.insert(vFactorAll.end(), vFactor.begin(), vFactor.end());
brOfs << CorrelationDistance(vYBin, vEY) << endl;
for (int j = 0; j < vFactor.size(); ++j) brOfs << vFactor[j] << ", ";
GMM gmm2(1, 2);
gmm2.Train(hs + i * colNum, colNum);
total_cov2 += *gmm2.Variance(0);
gmmOfs2 << gmm2 << endl;
}
/* 写入matlab文件 */
MATFile* pMatFile = matOpen("D:\\save_br.mat", "w");
SaveMatrix<double>(vFactorAll.data(), pMatFile, "factor", rowNum, 6);
SaveMatrix<double>(vDist.data(), pMatFile, "correlation", rowNum, 1);
matClose(pMatFile);
gmmOfs.close();
gmmOfs2.close();
xyOfs.close();
brOfs.close();
finish = clock();
cout << "Total cov: " << total_cov << endl;
cout << "Total cov2: " << total_cov2 << endl;
cout << "Total time:" << (double)(finish - begin) / CLOCKS_PER_SEC << endl;
//MATFile* pMatFile = matOpen("D:\\save_hs.mat", "w");
//SaveMatrix<double>(hs, pMatFile, "hs_saved", rowNum, colNum);
//matClose(pMatFile);
delete[] hs;
}
/**
 * Fit data that has already been exported to a text file.
 *
 * The file is read as pairs of lines: a line of x values followed by a line
 * of counts y. Each x is repeated y times to rebuild the sample set, a 1-D
 * two-component GMM is trained on it, and the model is appended to
 * txt_gmm.debug. Pairs that yield no samples are skipped. The accumulated
 * variance of component 0 and the elapsed clock() time are printed to stdout.
 *
 * Counts are truncated to an integer once; counts below 1 (including negative
 * and NaN) contribute no samples. The original did this arithmetic in float,
 * so a fractional count overwrote previously stored samples and a negative
 * count requested an enormous resize.
 *
 * @param filePath  Path of the text file; an error is reported and nothing is
 *                  processed if it cannot be opened.
 */
void processTxtData(const std::string& filePath) {
    double total_cov = 0;
    std::ifstream ifs(filePath, std::ios::in);
    if (!ifs) {
        std::cerr << "cannot open " << filePath << std::endl;
        return;
    }

    const clock_t begin = clock();
    std::ofstream gmmOfs("txt_gmm.debug");

    std::string x_str, y_str;
    // Stop as soon as a complete (x line, y line) pair can no longer be read.
    while (std::getline(ifs, x_str) && std::getline(ifs, y_str)) {
        std::vector<double> vec_point;
        std::stringstream ss_x(x_str);
        std::stringstream ss_y(y_str);
        float x, y;
        while (ss_x >> x && ss_y >> y) {
            if (!(y >= 1.0f)) continue;  // nothing to add for this x
            vec_point.insert(vec_point.end(), static_cast<std::size_t>(y), x);
        }
        if (vec_point.empty()) continue;

        GMM gmm(1, 2);  // 1 dimension, 2 Gaussian components
        gmm.Train(vec_point.data(), vec_point.size());
        total_cov += *gmm.Variance(0);
        gmmOfs << gmm << std::endl;
    }
    gmmOfs.close();

    const clock_t finish = clock();
    std::cout << "Total cov: " << total_cov << std::endl;
    std::cout << "Total time:" << (double)(finish - begin) / CLOCKS_PER_SEC << std::endl;
    ifs.close();
}
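/* Process one knowledge granule: ThreadParam carries the input and output paths that ThreadProcessData works on. */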
// Arguments handed to one ThreadProcessData call. The member order matters:
// main() fills this struct with positional aggregate initialisation.
struct ThreadParam {
// Input .mat file; ThreadProcessData reads the variable "hs" from it.
fs::path matFilePath;
// Output .mat file that ThreadProcessData opens for writing.
fs::path outFilePath;
};
void ThreadProcessData(const ThreadParam& param) {
const fs::path& matFilePath = param.matFilePath;
const fs::path& outFilePath = param.outFilePath;
// cout << parrentPath.string() << '\t' << matFilePath.filename().string() << endl;
cout << outFilePath.string() << endl;
double* hs = nullptr;
int rowNum = 0;
int colNum = 0;
hs = ReadMatlabMat<double>(matFilePath.string(), "hs", &rowNum, &colNum);
vector<double>vXBin;
vector<double>vYBin;
vector<double>vEY;
vector<double>vFactor;
/* 用来保存数据存入mat文件 */
vector<double>vDist(rowNum);
vector<double>vFactorAll;
for (int i = 0; i < rowNum; ++i) {
PutXtoBin(hs + i * colNum, colNum, 0.2, vXBin, vYBin);
GMM gmm(1, 2); // 1维 2个高斯模型
gmm.Train(vXBin.data(), vXBin.size());
GMMToFactorEY(gmm, 0.2, vYBin, vFactor, vEY);
vDist[i] = CorrelationDistance(vYBin, vEY);
vFactorAll.insert(vFactorAll.end(), vFactor.begin(), vFactor.end());
}
/* 写入matlab文件 */
MATFile* pMatFile = matOpen(outFilePath.string().c_str(), "w");
SaveMatrix<double>(vFactorAll.data(), pMatFile, "factor", rowNum, 6);
SaveMatrix<double>(vDist.data(), pMatFile, "correlation", rowNum, 1);
matClose(pMatFile);
delete[] hs;
}
int main(int argc, char** argv) {
if (argc != 4) {
cerr << "This program should take 3 arguments(1.parrent Dir; 2. mat file suffix; 3. out mat filename)!" << endl;
return 1;
}
string parrentDir(argv[1]); // 知识颗粒的父目录名称
string hsMatSuffix(argv[2]); // hs矩阵对应的mat文件的后缀名可以是全文件名可以是文件名后缀必须保证唯一
fs::path outFileName(argv[3]);
vector<thread> vThread;
clock_t begin, finish;
begin = clock();
/* 遍历所有的知识颗粒目录,注意进行处理 */
for (auto& childDir : fs::directory_iterator(parrentDir)) {
// cout << childDir.path().string() << endl;
fs::path outFilePath = childDir / outFileName;
for (auto& file : fs::directory_iterator(childDir)) {
// cout << file.path().filename().string() << endl;
const string& fileName = file.path().filename().string();
auto rPos = fileName.rfind(hsMatSuffix);
if (rPos != string::npos && fileName.size() - rPos == hsMatSuffix.size()) {
ThreadParam tParam = { file, outFilePath };
vThread.push_back(thread(ThreadProcessData, tParam));
// ThreadProcessData(tParam);
}
}
}
for (auto& thread : vThread) {
thread.join();
}
finish = clock();
cout << "Total time:" << (double)(finish - begin) / CLOCKS_PER_SEC << endl;
// processMatData(argv[1]);
// processMatData("D:\\Twirls\\runtime\\ALS_test\\1775\\twirls_id_abs2class_hs.mat");
// processTxtData("D:\\Twirls\\backup\\xy.txt");
return 0;
}