#include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include #include using std::cout; using std::endl; using namespace std; #define STRING_BUF_SIZE 204800 // 读取一维cell字符串并转换成大写 inline bool Read1DWord(const mxArray* pMxArray, vector& vStr) { mxArray* pCell = nullptr; int rowNum, colNum; char* strBuf = new char[STRING_BUF_SIZE]; rowNum = (int)mxGetM(pMxArray); colNum = (int)mxGetN(pMxArray); vStr.resize(rowNum * colNum); for (int i = 0; i < rowNum; ++i) { for (int j = 0; j < colNum; ++j) { pCell = mxGetCell(pMxArray, j * rowNum + i); if (mxGetString(pCell, strBuf, STRING_BUF_SIZE) != 0) { cout << "String is too large to fit in the buffer! " << i + 1 << '\t' << j + 1 << endl; return false; } vStr[i * colNum + j] = strBuf; auto& lastStr = vStr[i * colNum + j]; transform(lastStr.cbegin(), lastStr.cend(), lastStr.begin(), ::toupper); // 转成大写 } } delete[]strBuf; return true; } // 读取二维cell字符串并转换成大写 inline bool Read2DWord(const mxArray* pMxArray, vector>& vvStr) { mxArray* pCell = nullptr; int rowNum, colNum; char* strBuf = new char[STRING_BUF_SIZE]; rowNum = (int)mxGetM(pMxArray); colNum = (int)mxGetN(pMxArray); for (int i = 0; i < rowNum; ++i) { for (int j = 0; j < colNum; ++j) { pCell = mxGetCell(pMxArray, j * rowNum + i); int childRowNum = (int)mxGetM(pCell); int childColNum = (int)mxGetN(pCell); vvStr.push_back(vector()); Read1DWord(pCell, vvStr.back()); } } delete[]strBuf; return true; } // 将结果写入mxArray, 作为后续的返回值 mxArray* writeToMatDouble(const double *data, int rowNum, int colNum) { mxArray* pWriteArray = NULL;//matlab格式矩阵 int len = rowNum * colNum; //创建一个rowNum*colNum的矩阵 pWriteArray = mxCreateDoubleMatrix(rowNum, colNum, mxREAL); //把data的值赋给pWriteArray指针 memcpy((void*)(mxGetPr(pWriteArray)), (void*)data, sizeof(double) * len); return pWriteArray; // 赋值给返回值 } /* 入口函数 */ /* 三个参数,一个返回值 输入: 1. wd 文献摘要中的单词,二维cell 2. dic 字典,按字母序排序的,一维cell 3. threshold 保留超过阈值的列 输出: x 一维int(double)类型,表示在wd的每一行单词中,dic中是否有单词匹配上(匹配后,对应的坐标为设为1,否则为0),所有匹配的个数(统计每一行的匹配个数) dic每一个单词在wd所有行中出现的次数 */ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { if (nrhs < 2) { cout << "At least 2 arguments should be given for this function!" << endl; return; } clock_t begin = clock(), mid, finish; vector vDic; vector> vvWd; Read2DWord(prhs[0], vvWd); Read1DWord(prhs[1], vDic); int rowNum = vvWd.size(); int threshold = 5; if (nrhs > 2) { double* pThreshold = (double*)mxGetData(prhs[2]); threshold = (int)pThreshold[0]; if (threshold < 5) threshold = 5; } finish = clock(); cout << "Load data time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl; vector vXSum(vDic.size()); /* 统计dicr字典中,每个单词在wd中出现的次数 */ mid = clock(); unordered_map umWordPos; for (int i = 0; i < vDic.size(); ++i) umWordPos[vDic[i]] = i; // 记录单词位置 unordered_set usPos; // 多次出现在wd中的单词,只统计一次,这是原matlab代码的功能,是否需要修改? vector> vusX(rowNum); // 保存每一行中非零元的坐标 int row = 0; for (auto& vWd : vvWd) { auto& usPos = vusX[row++]; for (auto& word : vWd) { auto itr = umWordPos.find(word); if (itr != umWordPos.end()) { usPos.insert(itr->second); } } for (auto idx : usPos) { vXSum[idx] += 1; } } finish = clock(); cout << "Calc word occurrence time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl; /* 计算xs */ mid = clock(); int colNum = 0; vector vColIdx; for (int i = 0; i < vXSum.size(); ++i) { if (vXSum[i] >= threshold) { vColIdx.push_back(i); } } colNum = vColIdx.size(); vector vXsData(rowNum * colNum); for (int i = 0; i < rowNum; ++i) { for (int j = 0; j < colNum; ++j) { if (vusX[i].find(vColIdx[j]) != vusX[i].end()) { vXsData[j * rowNum + i] = 1; } } } finish = clock(); cout << "Calc xs time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl; // 测试输出 // cout << rowNum << '\t' << colNum << endl; // ofstream ofs1("d:\\result_xsum.txt"); // for (auto& val : vXSum) { // ofs1 << val << endl; // } // ofs1.close(); // // ofstream ofs2("d:\\result_xs.txt"); // for (int i = 0; i < rowNum; ++i) { // for (int j = 0; j < colNum; ++j) { // if (vXsData[j * rowNum + i] > 0) { // ofs2 << j + 1 << '\t'; // } // } // ofs2 << endl; // } // ofs2.close(); /* 写入结果 */ mid = clock(); if (nlhs > 0) { plhs[0] = writeToMatDouble(vXSum.data(), 1, vXSum.size()); } if (nlhs > 1) { // xs plhs[1] = writeToMatDouble(vXsData.data(), rowNum, colNum); } finish = clock(); cout << "Write result time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl; cout << "Calc word occurrence in Dic Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl; } // 供c++调试用 void mexFunctionWrap(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) { return mexFunction(nlhs, plhs, nrhs, prhs); }