twirls/CommonLib/matlab_io.cpp

205 lines
6.5 KiB
C++
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

/*********************************************************************************************
Description: Basic MATLAB .mat file read/write helpers (提供基本matlab读写数据功能)
Copyright : All rights reserved by ZheYuan.BJ
Author : Zhang Zhonghai
Date : 2023/09/18
***********************************************************************************************/
#include <cstring>
#include <iostream>
#include <memory>
#include <string>
#include <vector>
#include <mat.h>
#include "matlab_io.h"
using namespace std;
// Convert a column-major (MATLAB) buffer `src` holding rowNum x colNum elements into
// row-major (C) order in `dst`: dst[r * colNum + c] = src[c * rowNum + r].
// `dst` and `src` may be raw pointers or anything indexable with operator[]; they must not
// alias each other (an in-place call would overwrite elements before they are read).
// Arguments are parenthesized so expressions such as `n + 1` are safe, and the body is
// wrapped in do { } while (0) so the macro behaves as one statement (e.g. inside if/else).
// rowNum and colNum are evaluated more than once, so they must not have side effects.
#define TRANS_ROW_COL(dst, src, rowNum, colNum) \
    do { \
        for (int rowI = 0; rowI < (rowNum); ++rowI) { \
            for (int colJ = 0; colJ < (colNum); ++colJ) { \
                (dst)[rowI * (colNum) + colJ] = (src)[colJ * (rowNum) + rowI]; \
            } \
        } \
    } while (0)
/* 读取结构体中的二维字符串矩阵(一维的cell每个cell又有一层cell每个cell是一个字符串)*/
bool ReadChildString2D(const string& filePath, const string& parentName, const string& selfName, vector<vector<string> >& vvStr) {
MATFile* pMatFile = nullptr;
mxArray* pMxArray = nullptr;
mxArray* pCell = nullptr;
int rowNum, colNum;
char *strBuf = new char[STRING_BUF_SIZE];
pMatFile = matOpen(filePath.c_str(), "r"); //打开.mat文件
if (pMatFile == nullptr) {
cout << "filePath is error!" << endl;
return false;
}
mxArray* pMxParent = matGetVariable(pMatFile, parentName.c_str()); //获取G变量
// 读取字符串
pMxArray = mxGetField(pMxParent, 0, selfName.c_str()); // ds
rowNum = (int)mxGetM(pMxArray);
colNum = (int)mxGetN(pMxArray);
for (int i = 0; i < rowNum; ++i) {
for (int j = 0; j < colNum; ++j) {
pCell = mxGetCell(pMxArray, j * rowNum + i);
int childRowNum = (int)mxGetM(pCell);
int childColNum = (int)mxGetN(pCell);
vvStr.push_back(vector<string>());
vvStr.back().resize(childRowNum * childColNum);
for (int ii = 0; ii < childRowNum; ii++) {
for (int jj = 0; jj < childColNum; jj++) {
mxArray* pChildCell = mxGetCell(pCell, jj * childRowNum + ii);
if (mxGetString(pChildCell, strBuf, STRING_BUF_SIZE) != 0) {
cout << "String is too large to fit in the buffer! " << i + 1 << '\t' << j + 1 << endl;
return false;
}
vvStr.back()[ii * childColNum + jj] = strBuf;
// auto& lastStr = vvStr.back()[ii * childColNum + jj];
// transform(lastStr.begin(), lastStr.end(), lastStr.begin(), ::toupper); // 转成大写
}
}
}
}
mxDestroyArray(pMxArray);
delete[]strBuf;
return true;
}
/* 读取结构体中的二维double矩阵(一维的cell每个cell又有一层cell每个cell是一个字符串)*/
bool ReadChildDouble2D(const string& filePath, const string& parentName, const string& selfName, vector<vector<double> >& vvDouble) {
MATFile* pMatFile = nullptr;
mxArray* pMxArray = nullptr;
mxArray* pCell = nullptr;
int rowNum, colNum;
pMatFile = matOpen(filePath.c_str(), "r"); //打开.mat文件
if (pMatFile == nullptr) {
cout << "filePath is error!" << endl;
return false;
}
mxArray* pMxParent = matGetVariable(pMatFile, parentName.c_str()); //获取G变量
// 读取double数据
pMxArray = mxGetField(pMxParent, 0, selfName.c_str()); // ds
rowNum = (int)mxGetM(pMxArray);
colNum = (int)mxGetN(pMxArray);
for (int i = 0; i < rowNum; ++i) {
for (int j = 0; j < colNum; ++j) {
pCell = mxGetCell(pMxArray, j * rowNum + i);
int childRowNum = (int)mxGetM(pCell);
int childColNum = (int)mxGetN(pCell);
vvDouble.push_back(vector<double>());
vvDouble.back().resize(childRowNum * childColNum);
double* pVal = (double*)mxGetData(pCell); //获取指针
TRANS_ROW_COL(vvDouble.back(), pVal, childRowNum, childColNum); // 行列存储方式转换
}
}
mxDestroyArray(pMxArray);
return true;
}
/* 读取字符串矩阵 */
bool ReadMtxString(const string& filePath, const string& mtxName,
vector<string>& vStr, int* pRowNum, int* pColNum) {
MATFile* pMatFile = nullptr;
mxArray* pMxArray = nullptr;
mxArray* pCell = nullptr;
int rowNum, colNum;
char strBuf[STRING_BUF_SIZE];
pMatFile = matOpen(filePath.c_str(), "r"); //打开.mat文件
if (pMatFile == nullptr) {
cout << "filePath is error!" << endl;
return false;
}
pMxArray = matGetVariable(pMatFile, mtxName.c_str()); //获取.mat文件里面名为matrixName的矩阵
rowNum = (int)mxGetM(pMxArray);
colNum = (int)mxGetN(pMxArray);
vStr.resize(rowNum * colNum);
for (int i = 0; i < rowNum; ++i) {
for (int j = 0; j < colNum; ++j) {
pCell = mxGetCell(pMxArray, j * rowNum + i);
if (mxGetString(pCell, strBuf, STRING_BUF_SIZE) != 0) {
cout << "String is too large to fit in the buffer! " << i + 1 << '\t' << j + 1 << endl;
return false;
}
vStr[i * colNum + j] = strBuf;
}
}
*pRowNum = rowNum;
*pColNum = colNum;
return true;
}
/* Read the double matrix `mtxName` from the .mat file `filePath`, converting each element to T.
 *
 * Returns a newly allocated array of rows*cols elements in row-major order (MATLAB's
 * column-major layout is transposed). The caller owns it and must release it with delete[].
 * *pRowNum / *pColNum receive the dimensions; either pointer may be null to skip it.
 * Returns nullptr (after printing a message to stdout) if the file cannot be opened or the
 * variable is missing, not double, or sparse.
 */
T* ReadMtxDouble(const std::string& filePath, const std::string& mtxName, int* pRowNum, int* pColNum) {
    // RAII owners: every return path closes the file and frees the variable copy.
    auto closeMat = [](MATFile* f) { matClose(f); };
    auto destroyArray = [](mxArray* a) { mxDestroyArray(a); };

    std::unique_ptr<MATFile, decltype(closeMat)> matFile(matOpen(filePath.c_str(), "r"), closeMat);
    if (!matFile) {
        std::cout << "filePath is error!" << std::endl;
        return nullptr;
    }
    // matGetVariable returns a copy owned by the caller; it must be destroyed when done.
    std::unique_ptr<mxArray, decltype(destroyArray)> mtx(
        matGetVariable(matFile.get(), mtxName.c_str()), destroyArray);
    // Reading the data as double* is only valid for full double matrices.
    if (!mtx || !mxIsDouble(mtx.get()) || mxIsSparse(mtx.get())) {
        std::cout << "Variable is missing or not a double matrix! " << mtxName << std::endl;
        return nullptr;
    }

    const int rowNum = static_cast<int>(mxGetM(mtx.get()));
    const int colNum = static_cast<int>(mxGetN(mtx.get()));
    const double* matData = static_cast<const double*>(mxGetData(mtx.get()));
    T* dst = new T[static_cast<std::size_t>(rowNum) * colNum];
    for (int i = 0; i < rowNum; ++i) {
        for (int j = 0; j < colNum; ++j) {
            // Column-major source index j * rowNum + i -> row-major destination i * colNum + j.
            dst[static_cast<std::size_t>(i) * colNum + j] =
                static_cast<T>(matData[static_cast<std::size_t>(j) * rowNum + i]);
        }
    }
    if (pRowNum != nullptr) {
        *pRowNum = rowNum;
    }
    if (pColNum != nullptr) {
        *pColNum = colNum;
    }
    return dst;
}
/* 将数据写入mat文件中用给定的名称命名 */
bool SaveMtxDouble(T* src, MATFile* pMatFile, string matrixName, int rowNum, int colNum)
{
//转置存储
int datasize = colNum * rowNum;
double* mtxData = new double[datasize];//待存储数据转为double格式
// memset(mtxData, 0, datasize * sizeof(double));
if (pMatFile == nullptr)
{
cout << "mat file pointer is error!" << endl;
return false;
}
for (int i = 0; i < rowNum; i++)
{
for (int j = 0; j < colNum; j++)
{
mtxData[j * rowNum + i] = double(src[i * colNum + j]);
// *(mtxData + j * rowNum + i) = (double)src[i * colNum + j]; 可消除警告
}
}
mxArray* pWriteArray = NULL;//matlab格式矩阵
//创建一个rowNum*colNum的矩阵
pWriteArray = mxCreateDoubleMatrix(rowNum, colNum, mxREAL);
//把data的值赋给pWriteArray指针
memcpy((void*)(mxGetPr(pWriteArray)), (void*)mtxData, sizeof(double) * datasize);
//给矩阵命名为matrixName
matPutVariable(pMatFile, matrixName.c_str(), pWriteArray);
mxDestroyArray(pWriteArray);//release resource
delete[]mtxData;//release resource
return true;
}