#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include using std::cout; using std::endl; using namespace std; #define STRING_BUF_SIZE 204800 // 读取一维cell字符串并转换成大写 inline bool Read1DWord(const mxArray* pMxArray, vector& vStr) { mxArray* pCell = nullptr; int rowNum, colNum; char* strBuf = new char[STRING_BUF_SIZE]; rowNum = (int)mxGetM(pMxArray); colNum = (int)mxGetN(pMxArray); vStr.resize(rowNum * colNum); for (int i = 0; i < rowNum; ++i) { for (int j = 0; j < colNum; ++j) { pCell = mxGetCell(pMxArray, j * rowNum + i); if (mxGetString(pCell, strBuf, STRING_BUF_SIZE) != 0) { cout << "String is too large to fit in the buffer! " << i + 1 << '\t' << j + 1 << endl; return false; } vStr[i * colNum + j] = strBuf; auto& lastStr = vStr[i * colNum + j]; transform(lastStr.cbegin(), lastStr.cend(), lastStr.begin(), ::toupper); // 转成大写 } } delete[]strBuf; return true; } // 读取二维cell字符串并转换成大写 inline bool Read2DWord(const mxArray* pMxArray, vector>& vvStr) { mxArray* pCell = nullptr; int rowNum, colNum; char* strBuf = new char[STRING_BUF_SIZE]; rowNum = (int)mxGetM(pMxArray); colNum = (int)mxGetN(pMxArray); for (int i = 0; i < rowNum; ++i) { for (int j = 0; j < colNum; ++j) { pCell = mxGetCell(pMxArray, j * rowNum + i); int childRowNum = (int)mxGetM(pCell); int childColNum = (int)mxGetN(pCell); vvStr.push_back(vector()); Read1DWord(pCell, vvStr.back()); } } delete[]strBuf; return true; } // 将结果写入mxArray, 作为后续的返回值 mxArray* writeToMatDouble(const double *data, int rowNum, int colNum) { mxArray* pWriteArray = NULL;//matlab格式矩阵 int len = rowNum * colNum; //创建一个rowNum*colNum的矩阵 pWriteArray = mxCreateDoubleMatrix(rowNum, colNum, mxREAL); //把data的值赋给pWriteArray指针 memcpy((void*)(mxGetPr(pWriteArray)), (void*)data, sizeof(double) * len); return pWriteArray; // 赋值给返回值 } /* 入口函数 */ /* 三个参数,一个返回值 输入: 1. wd 文献摘要中的单词,二维cell 2. dic 字典,按字母序排序的,一维cell [3]. flagPrint 输出: x 一维int(double)类型,表示在wd的每一行单词中,dic中是否有单词匹配上(匹配后,对应的坐标为设为1,否则为0),所有匹配的个数(统计每一行的匹配个数) dic每一个单词在wd所有行中出现的次数 */ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { if (nrhs < 2) { cout << "At least 2 arguments should be given for this function!" << endl; return; } clock_t begin = clock(), mid, finish; vector vDic; vector> vvWd; Read2DWord(prhs[0], vvWd); Read1DWord(prhs[1], vDic); int rowNum = vvWd.size(); int colNum = vDic.size(); int flagPrint = 0; if (nrhs > 2) { double* pData = (double*)mxGetData(prhs[2]); flagPrint = (int)pData[0]; } finish = clock(); if (flagPrint == 2) cout << "Load data time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl; vector vX(rowNum * colNum); // 一维表示二维 /* 统计dicr字典中,每个单词在wd中出现的次数 */ mid = clock(); unordered_map> umWordPos; for (int i = 0; i < vDic.size(); ++i) umWordPos[vDic[i]].push_back(i); // dic中可能存在重复,记录单词位置 unordered_set usPos; // 多次出现在wd中的单词,只统计一次,这是原matlab代码的功能,是否需要修改? vector> vusX(rowNum); // 保存每一行中非零元的坐标 int row = 0; // vector vSum(colNum); for (auto& vWd : vvWd) { auto& usPos = vusX[row]; for (auto& word : vWd) { auto itr = umWordPos.find(word); if (itr != umWordPos.end()) { for (auto pos : itr->second) usPos.insert(pos); } } for (auto idx : usPos) { vX[idx * rowNum + row] = 1; // matlab 列优先存储模式 // vSum[idx] += 1; } ++row; } finish = clock(); if (flagPrint == 2) cout << "Calc word occurrence time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl; // for (auto& w : vDic) cout << umWordPos[w] << endl; // for (auto& w : vvWd[260]) cout << w << '\t' << w.size() << endl; //for (auto& w : vDic) cout << w << '\t' << w.size() << endl; // vector vSum(colNum); // for (int i = 0; i < rowNum; ++i) { // for (int j = 0; j < colNum; ++j) { // vSum[j] += vX[j * rowNum + i]; // } // } // for (auto val : vSum) cout << val << endl; // 测试输出 // cout << rowNum << '\t' << colNum << endl; // ofstream ofs1("d:\\result_xsum.txt"); // for (auto& val : vXSum) { // ofs1 << val << endl; // } // ofs1.close(); // // ofstream ofs2("d:\\result_xs.txt"); // for (int i = 0; i < rowNum; ++i) { // for (int j = 0; j < colNum; ++j) { // if (vXsData[j * rowNum + i] > 0) { // ofs2 << j + 1 << '\t'; // } // } // ofs2 << endl; // } // ofs2.close(); /* 写入结果 */ mid = clock(); if (nlhs > 0) { // vvX plhs[0] = writeToMatDouble(vX.data(), rowNum, colNum); } finish = clock(); if (flagPrint == 2) cout << "Write result time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl; if (flagPrint) cout << "Calc word occurrence in Dic Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl; } // 供c++调试用 void mexFunctionWrap(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { return mexFunction(nlhs, plhs, nrhs, prhs); }