twirls/GMM/main.cpp

397 lines
12 KiB
C++
Raw Normal View History

#include <algorithm>
#include <chrono>
#include <cmath>
#include <cstring>
#include <filesystem>
#include <fstream>
#include <iomanip>
#include <iostream>
#include <limits>
#include <memory>
#include <queue>
#include <random>
#include <sstream>
#include <string>
#include <thread>
#include <unordered_map>
#include <vector>
#include <time.h>
#include <omp.h>
#ifdef _WIN32
#include <io.h>
#include <process.h>
#define F_OK 0
#else
#include <unistd.h>
#endif
#include <mat.h>
#include "gmm.h"
using namespace std;
using std::cout;
using std::vector;
namespace fs = std::filesystem;
/* Read the named matrix from a .mat file and report its row/column counts. */
template<typename T>
T* ReadMatlabMat(const string &filePath, const string &mtxName, int *pRowNum, int *pColNum) {
T* dst = nullptr;
MATFile* pMatFile = nullptr;
mxArray* pMxArray = nullptr;
int rowNum, colNum;
double* matData;
pMatFile = matOpen(filePath.c_str(), "r"); //<2F><><EFBFBD><EFBFBD>.mat<61>ļ<EFBFBD>
if (pMatFile == nullptr) {
cerr << "filePath is error!" << endl;
return nullptr;
}
pMxArray = matGetVariable(pMatFile, mtxName.c_str()); //<2F><>ȡ.mat<61>ļ<EFBFBD><C4BC><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ΪmatrixName<6D>ľ<EFBFBD><C4BE><EFBFBD>
rowNum = mxGetM(pMxArray);
colNum = mxGetN(pMxArray);
// cout << rowNum << " " << colNum << endl;
matData = (double*)mxGetData(pMxArray); //<2F><>ȡָ<C8A1><D6B8>
dst = new T[rowNum * colNum];
for (int i = 0; i < rowNum; ++i) {
for (int j = 0; j < colNum; ++j) {
dst[i * colNum + j] = T(matData[j * rowNum + i]);
}
}
mxDestroyArray(pMxArray); //<2F>ͷ<EFBFBD><CDB7>ڴ<EFBFBD>
matClose(pMatFile); // <20>ر<EFBFBD><D8B1>ļ<EFBFBD>
*pRowNum = rowNum;
*pColNum = colNum;
return dst;
}
/* Write a matrix into an open .mat file under the given variable name. */
/**
 * Writes a row-major rowNum x colNum matrix into an open MAT-file as a double matrix.
 *
 * @param src        Row-major source data; element (i, j) is src[i * colNum + j].
 * @param pMatFile   MAT-file opened for writing (not closed by this function).
 * @param matrixName Variable name to store the matrix under.
 * @param rowNum     Number of rows.
 * @param colNum     Number of columns.
 * @return true if the variable was written, false on a null file, allocation
 *         failure, or a matPutVariable error.
 */
template<typename T>
bool SaveMatrix(T* src, MATFile* pMatFile, string matrixName, int rowNum, int colNum)
{
    // Validate before allocating anything so the early return cannot leak.
    if (pMatFile == nullptr)
    {
        cerr << "mat file pointer is error!" << endl;
        return false;
    }
    mxArray* pWriteArray = mxCreateDoubleMatrix(rowNum, colNum, mxREAL);
    if (pWriteArray == nullptr)
    {
        cerr << "cannot allocate matrix " << matrixName << endl;
        return false;
    }
    // Fill the MATLAB buffer directly, transposing row-major input into
    // MATLAB's column-major layout (no intermediate copy needed).
    double* mtxData = mxGetPr(pWriteArray);
    for (int i = 0; i < rowNum; i++)
    {
        for (int j = 0; j < colNum; j++)
        {
            mtxData[static_cast<size_t>(j) * rowNum + i] = double(src[static_cast<size_t>(i) * colNum + j]);
        }
    }
    // matPutVariable returns 0 on success.
    const bool ok = matPutVariable(pMatFile, matrixName.c_str(), pWriteArray) == 0;
    if (!ok)
    {
        cerr << "cannot write matrix " << matrixName << endl;
    }
    mxDestroyArray(pWriteArray);  // matPutVariable copies the data, so release ours
    return ok;
}
/* Bucket the values of x into bins of width binWidth, similar to MATLAB's hist. */
/**
 * Histograms x into bins of width binWidth centred on 0, binWidth, 2*binWidth, ...
 *
 * @param x        Input samples (xSize values). binWidth must be > 0.
 * @param xSize    Number of samples.
 * @param binWidth Width of each bin.
 * @param vXBin    Out: xSize values; every sample replaced by its bin centre
 *                 (i * binWidth), emitted in ascending bin order. The sorted order is
 *                 deliberate: training the GMM on differently ordered data yields
 *                 slightly different parameters.
 * @param vYBin    Out: sample count per bin.
 *
 * Samples beyond the last bin are counted in the last bin. Samples below the first
 * bin (x < -binWidth/2) are counted in bin 0; previously they indexed out of bounds.
 */
void PutXtoBin(double* x, int xSize, double binWidth, std::vector<double>& vXBin, std::vector<double>& vYBin) {
    // The bin range always starts at 0, so the running maximum starts at 0 too.
    double maxX = 0.0;
    for (int i = 0; i < xSize; ++i) {
        if (maxX < x[i]) maxX = x[i];
    }
    const double halfWidth = binWidth / 2;
    int binSize = static_cast<int>((maxX + halfWidth) / binWidth + 1);
    // Drop the top bin if its centre lies beyond the largest sample.
    if ((binSize - 1) * binWidth > maxX) {
        binSize -= 1;
    }

    vXBin.resize(xSize);
    vYBin.assign(binSize, 0.0);
    for (int i = 0; i < xSize; ++i) {
        const int binIdx = static_cast<int>((x[i] + halfWidth) / binWidth);
        vYBin[std::clamp(binIdx, 0, binSize - 1)] += 1;
    }

    // Rebuild the samples as bin centres, in ascending bin order.
    int xIdx = 0;
    for (int i = 0; i < binSize; ++i) {
        for (int j = 0; j < vYBin[i]; ++j) {
            vXBin[xIdx++] = i * binWidth;
        }
    }
}
/* Convert the trained Gaussian-mixture parameters into curve coefficients and compute the fitted Y values. */
/// Orders (count, probability) pairs by count only, so a std::priority_queue using
/// this comparator yields the pair with the largest count first.
/// operator() is const so the comparator is usable through const references.
struct cmpFunc {
    bool operator()(const std::pair<double, double>& a, const std::pair<double, double>& b) const {
        return a.first < b.first;
    }
};
/**
 * Turns a trained 1-D, two-component GMM into curve coefficients and a fitted histogram.
 *
 * @param gmm      Trained model; components 0 and 1 are read (every caller in this file
 *                 constructs it as GMM(1, 2)).
 * @param binWidth Bin width used to build vYBin; bin i is evaluated at x = i * binWidth.
 * @param vYBin    Histogram counts per bin (as produced by PutXtoBin).
 * @param vFactor  Out: six values {a1, b1, c1, a2, b2, c2} with
 *                 a_k = s * prior_k / sqrt(2*pi*var_k), b_k = mean_k, c_k = sqrt(2*var_k),
 *                 so a_k * exp(-((x - b_k) / c_k)^2) == s * prior_k * N(x; mean_k, var_k).
 * @param vEY      Out: gmm.GetProbability at each bin position (presumably the mixture
 *                 density), multiplied by s.
 *
 * The scale s maps density to counts. It is the mean of count / density over the
 * max(vYBin.size() / 4, 1) most populated bins. Bins whose density is not positive are
 * left out of that mean, because they would make it inf/NaN; if none remain, s = 0.
 */
void GMMToFactorEY(GMM& gmm, double binWidth, vector<double> &vYBin, vector<double>& vFactor, vector<double>& vEY) {
    const double kPi = 3.14159265358979323846;  // M_PI is not standard C++

    // Evaluate the model at each bin centre; queue (count, density) pairs by count.
    vEY.resize(vYBin.size());
    priority_queue<pair<double, double>, vector<pair<double, double> >, cmpFunc> pqTopM;
    for (size_t i = 0; i < vYBin.size(); ++i) {
        double xVal = i * binWidth;
        const double probVal = gmm.GetProbability(&xVal);
        vEY[i] = probVal;
        pqTopM.push(make_pair(vYBin[i], probVal));
    }

    // Average count/density over the topM most populated bins (heap selection).
    size_t topM = vYBin.size() / 4;
    if (topM < 1) topM = 1;
    double zoomFactorSum = 0.0;
    int valNum = 0;
    for (size_t i = 0; i < topM && !pqTopM.empty(); ++i) {
        const pair<double, double> topEle = pqTopM.top();
        pqTopM.pop();
        if (topEle.second > 0.0) {
            zoomFactorSum += topEle.first / topEle.second;
            ++valNum;
        }
    }
    const double zoomFactor = valNum > 0 ? zoomFactorSum / valNum : 0.0;
    for (double& ey : vEY) {
        ey *= zoomFactor;
    }

    // Always emit exactly six coefficients; callers append them per row.
    vFactor.clear();
    for (int k = 0; k < 2; ++k) {
        const double variance = *gmm.Variance(k);
        vFactor.push_back(zoomFactor * gmm.Prior(k) / std::sqrt(2 * kPi * variance));
        vFactor.push_back(*gmm.Mean(k));
        vFactor.push_back(std::sqrt(2 * variance));
    }
}
/* Compute the arithmetic mean. */
/**
 * Arithmetic mean of vVal.
 *
 * The element count is converted to T before dividing. Dividing by the raw size_t
 * would convert a negative integer sum to unsigned and give a wrong result.
 * Returns T(0) for an empty vector instead of dividing by zero.
 */
template <typename T>
T Average(const std::vector<T>& vVal) {
    if (vVal.empty()) {
        return T(0);
    }
    T sumVal = T(0);
    for (const T& v : vVal) {
        sumVal += v;
    }
    return sumVal / static_cast<T>(vVal.size());
}
/* Compute the mean of the squared values. */
/**
 * Mean of the squared elements of vVal.
 *
 * Accumulates the squares in a single pass instead of materialising a temporary
 * vector. Returns T(0) for an empty vector instead of dividing by zero.
 */
template <typename T>
T SquareAverage(const std::vector<T>& vVal) {
    if (vVal.empty()) {
        return T(0);
    }
    T sumSquares = T(0);
    for (const T& v : vVal) {
        sumSquares += v * v;
    }
    return sumSquares / static_cast<T>(vVal.size());
}
/* Compute the correlation distance between vectors x and y; both must have the same dimension. */
/**
 * Uncentred correlation distance |1 - mean(x*y) / sqrt(mean(x^2) * mean(y^2))|.
 *
 * @param vX First vector.
 * @param vY Second vector; must have the same length as vX.
 * @return The distance; NaN if the lengths differ (previously vY was read out of
 *         bounds when shorter), or if either vector is empty or all zeros.
 */
double CorrelationDistance(const std::vector<double>& vX, const std::vector<double>& vY) {
    if (vX.size() != vY.size()) {
        return std::numeric_limits<double>::quiet_NaN();
    }
    // One pass; sums and divisions happen in the same order as separate averages would.
    double sumXY = 0.0;
    double sumXX = 0.0;
    double sumYY = 0.0;
    for (size_t i = 0; i < vX.size(); ++i) {
        sumXY += vX[i] * vY[i];
        sumXX += vX[i] * vX[i];
        sumYY += vY[i] * vY[i];
    }
    const double n = static_cast<double>(vX.size());
    const double uv = sumXY / n;
    const double uu = sumXX / n;
    const double vv = sumYY / n;
    return std::fabs(1.0 - uv / std::sqrt(uu * vv));
}
/* Process the data stored in a MATLAB .mat file. */
/**
 * Batch/debug driver for one MAT-file: reads variable "hs" and processes every row.
 *
 * For each row it:
 *   - bins the row (bin width 0.2) with PutXtoBin;
 *   - fits a 1-D, 2-component GMM to the binned values;
 *   - derives six curve coefficients and the correlation distance between the
 *     histogram and the fitted curve;
 *   - fits a second GMM directly on the raw row for comparison.
 *
 * Diagnostics go to mat_gmm.debug, mat_gmm2.debug and br.debug (one distance line,
 * then one comma-separated factor line per row). The "factor" (rowNum x 6) and
 * "correlation" (rowNum x 1) matrices go to D:\save_br.mat.
 */
void processMatData(const string& filePath) {
    const double kBinWidth = 0.2;
    int rowNum = 0;
    int colNum = 0;
    double total_cov = 0;
    double total_cov2 = 0;
    const clock_t begin = clock();
    // Owned via unique_ptr so the buffer is released on every exit path.
    const std::unique_ptr<double[]> hs(ReadMatlabMat<double>(filePath, "hs", &rowNum, &colNum));
    if (!hs) {
        cerr << "cannot read variable hs from " << filePath << endl;
        return;
    }
    ofstream gmmOfs("mat_gmm.debug");
    ofstream gmmOfs2("mat_gmm2.debug");
    ofstream xyOfs("xy_cpp.debug");  // only written by the commented-out histogram dump below
    ofstream brOfs("br.debug");
    vector<double> vXBin;
    vector<double> vYBin;
    vector<double> vEY;
    vector<double> vFactor;
    // Per-row correlation distances and coefficients, saved to the .mat file afterwards.
    vector<double> vDist(rowNum);
    vector<double> vFactorAll;
    vFactorAll.reserve(static_cast<size_t>(rowNum) * 6);
    for (int i = 0; i < rowNum; ++i) {
        double* row = hs.get() + static_cast<size_t>(i) * colNum;
        PutXtoBin(row, colNum, kBinWidth, vXBin, vYBin);
        // for (int m = 0; m < vYBin.size(); ++m) xyOfs<< fixed << setprecision(1) << 0.2 * m << ' ';
        // xyOfs << endl;
        // for (int m = 0; m < vYBin.size(); ++m) xyOfs << (int)vYBin[m] << ' ';
        // xyOfs << endl;
        GMM gmm(1, 2);  // 1-dimensional data, 2 Gaussian components
        gmm.Train(vXBin.data(), vXBin.size());
        total_cov += *gmm.Variance(0);
        gmmOfs << gmm << endl;
        GMMToFactorEY(gmm, kBinWidth, vYBin, vFactor, vEY);
        vDist[i] = CorrelationDistance(vYBin, vEY);
        vFactorAll.insert(vFactorAll.end(), vFactor.begin(), vFactor.end());
        brOfs << vDist[i] << endl;
        for (size_t j = 0; j < vFactor.size(); ++j) brOfs << vFactor[j] << ", ";
        brOfs << endl;  // end the factor line so the next row's distance starts on its own line
        // Reference fit directly on the raw (unbinned) row.
        GMM gmm2(1, 2);
        gmm2.Train(row, colNum);
        total_cov2 += *gmm2.Variance(0);
        gmmOfs2 << gmm2 << endl;
    }
    // Write the MATLAB output file (skipped if it cannot be created).
    MATFile* pMatFile = matOpen("D:\\save_br.mat", "w");
    if (pMatFile == nullptr) {
        cerr << "cannot open D:\\save_br.mat for writing" << endl;
    } else {
        SaveMatrix<double>(vFactorAll.data(), pMatFile, "factor", rowNum, 6);
        SaveMatrix<double>(vDist.data(), pMatFile, "correlation", rowNum, 1);
        matClose(pMatFile);
    }
    gmmOfs.close();
    gmmOfs2.close();
    xyOfs.close();
    brOfs.close();
    const clock_t finish = clock();
    cout << "Total cov: " << total_cov << endl;
    cout << "Total cov2: " << total_cov2 << endl;
    cout << "Total time:" << (double)(finish - begin) / CLOCKS_PER_SEC << endl;
    //MATFile* pMatFile = matOpen("D:\\save_hs.mat", "w");
    //SaveMatrix<double>(hs, pMatFile, "hs_saved", rowNum, colNum);
    //matClose(pMatFile);
}
/* Process data that has already been converted to a text file. */
/**
 * Fits a 1-D, 2-component GMM to each histogram stored in a text file.
 *
 * The file holds pairs of lines: a line of x values followed by a line of counts y.
 * Each pair expands into a sample set in which x is repeated trunc(y) times.
 * Counts that are zero, negative or NaN contribute nothing; previously, fractional
 * or negative counts corrupted the buffer through float-typed size arithmetic.
 * Each fitted model is written to txt_gmm.debug. The summed variance of component 0
 * and the elapsed time are printed.
 */
void processTxtData(const string& filePath) {
    ifstream ifs(filePath, ios::in);
    if (!ifs.is_open()) {
        cerr << "cannot open " << filePath << endl;
        return;
    }
    double total_cov = 0;
    const clock_t begin = clock();
    ofstream gmmOfs("txt_gmm.debug");
    string x_str, y_str;
    // Stop as soon as either line of a pair is missing.
    while (getline(ifs, x_str) && getline(ifs, y_str)) {
        vector<double> vec_point;
        stringstream ss_x(x_str);
        stringstream ss_y(y_str);
        float x, y;
        while (ss_x >> x && ss_y >> y) {
            if (!(y > 0)) continue;  // zero, negative or NaN count: nothing to add
            vec_point.insert(vec_point.end(), static_cast<size_t>(y), static_cast<double>(x));
        }
        if (vec_point.empty()) continue;
        GMM gmm(1, 2);  // 1-dimensional data, 2 Gaussian components
        gmm.Train(vec_point.data(), vec_point.size());
        total_cov += *gmm.Variance(0);
        gmmOfs << gmm << endl;
    }
    gmmOfs.close();
    const clock_t finish = clock();
    cout << "Total cov: " << total_cov << endl;
    cout << "Total time:" << (double)(finish - begin) / CLOCKS_PER_SEC << endl;
}
/* Process the data of one knowledge item (one .mat file per worker thread). */
/// Work item handed to one worker thread.
struct ThreadParam {
    std::filesystem::path matFilePath;  // input MAT-file that holds the "hs" variable
    std::filesystem::path outFilePath;  // MAT-file the worker writes "factor"/"correlation" to
};
/**
 * Worker-thread entry point: processes the "hs" matrix of one MAT-file.
 *
 * For each row it bins the values (bin width 0.2) and fits a 1-D, 2-component GMM.
 * It then stores the six curve coefficients and the histogram/fit correlation
 * distance into param.outFilePath as "factor" (rows x 6) and "correlation" (rows x 1).
 * Errors are reported on stderr. No exception escapes, because an exception leaving
 * a std::thread entry point would terminate the whole process.
 *
 * NOTE(review): main() starts one of these per matching file concurrently; confirm
 * the MAT-file library and GMM are safe to use from several threads at once.
 */
void ThreadProcessData(const ThreadParam& param) {
    const double kBinWidth = 0.2;
    const fs::path& matFilePath = param.matFilePath;
    const fs::path& outFilePath = param.outFilePath;
    try {
        cout << outFilePath.string() << endl;
        int rowNum = 0;
        int colNum = 0;
        const std::unique_ptr<double[]> hs(ReadMatlabMat<double>(matFilePath.string(), "hs", &rowNum, &colNum));
        if (!hs) {
            cerr << "cannot read variable hs from " << matFilePath.string() << endl;
            return;
        }
        vector<double> vXBin;
        vector<double> vYBin;
        vector<double> vEY;
        vector<double> vFactor;
        // Per-row correlation distances and coefficients, saved to the .mat file afterwards.
        vector<double> vDist(rowNum);
        vector<double> vFactorAll;
        vFactorAll.reserve(static_cast<size_t>(rowNum) * 6);
        for (int i = 0; i < rowNum; ++i) {
            PutXtoBin(hs.get() + static_cast<size_t>(i) * colNum, colNum, kBinWidth, vXBin, vYBin);
            GMM gmm(1, 2);  // 1-dimensional data, 2 Gaussian components
            gmm.Train(vXBin.data(), vXBin.size());
            GMMToFactorEY(gmm, kBinWidth, vYBin, vFactor, vEY);
            vDist[i] = CorrelationDistance(vYBin, vEY);
            vFactorAll.insert(vFactorAll.end(), vFactor.begin(), vFactor.end());
        }
        // Write the MATLAB output file.
        MATFile* pMatFile = matOpen(outFilePath.string().c_str(), "w");
        if (pMatFile == nullptr) {
            cerr << "cannot open " << outFilePath.string() << " for writing" << endl;
            return;
        }
        SaveMatrix<double>(vFactorAll.data(), pMatFile, "factor", rowNum, 6);
        SaveMatrix<double>(vDist.data(), pMatFile, "correlation", rowNum, 1);
        matClose(pMatFile);
    } catch (const std::exception& e) {
        cerr << "failed to process " << matFilePath.string() << ": " << e.what() << endl;
    }
}
/**
 * Usage: <program> <parent dir> <hs mat-file suffix> <output mat file name>
 *
 * For every sub-directory of <parent dir>, each regular file whose name ends with
 * <hs mat-file suffix> is processed on its own thread. Results go to
 * <sub-directory>/<output mat file name>. The suffix should match only one file per
 * sub-directory; otherwise several threads write the same output file.
 * Prints the wall-clock time spent.
 */
int main(int argc, char** argv) {
    if (argc != 4) {
        cerr << "This program should take 3 arguments(1.parrent Dir; 2. mat file suffix; 3. out mat filename)!" << endl;
        return 1;
    }
    const fs::path parrentDir(argv[1]);   // parent directory of the per-item sub-directories
    const string hsMatSuffix(argv[2]);    // suffix (not a full name) of the hs .mat file; must be unique
    const fs::path outFileName(argv[3]);  // output .mat file name created in each sub-directory

    std::error_code ec;
    if (!fs::is_directory(parrentDir, ec)) {
        cerr << "parent directory not found: " << parrentDir.string() << endl;
        return 1;
    }

    // steady_clock gives wall time; clock() sums CPU time over all threads on POSIX.
    const auto begin = std::chrono::steady_clock::now();
    vector<thread> vThread;
    int exitCode = 0;
    try {
        for (const fs::directory_entry& childDir : fs::directory_iterator(parrentDir)) {
            // Loose files in the parent directory would make directory_iterator throw.
            if (!childDir.is_directory()) continue;
            const fs::path outFilePath = childDir.path() / outFileName;
            for (const fs::directory_entry& file : fs::directory_iterator(childDir.path())) {
                if (!file.is_regular_file()) continue;
                const string fileName = file.path().filename().string();
                const bool hasSuffix = fileName.size() >= hsMatSuffix.size() &&
                    fileName.compare(fileName.size() - hsMatSuffix.size(), hsMatSuffix.size(), hsMatSuffix) == 0;
                if (hasSuffix) {
                    vThread.emplace_back(ThreadProcessData, ThreadParam{ file.path(), outFilePath });
                    // ThreadProcessData(ThreadParam{ file.path(), outFilePath });
                }
            }
        }
    } catch (const std::exception& e) {
        // Keep going so the threads already started are joined; destroying a joinable
        // std::thread would call std::terminate.
        cerr << "directory scan failed: " << e.what() << endl;
        exitCode = 1;
    }
    for (auto& worker : vThread) {
        worker.join();
    }
    const std::chrono::duration<double> elapsed = std::chrono::steady_clock::now() - begin;
    cout << "Total time:" << elapsed.count() << endl;
    // processMatData(argv[1]);
    // processMatData("D:\\Twirls\\runtime\\ALS_test\\1775\\twirls_id_abs2class_hs.mat");
    // processTxtData("D:\\Twirls\\backup\\xy.txt");
    return exitCode;
}