204 lines
5.5 KiB
C++
204 lines
5.5 KiB
C++
#include <mex.h>
|
||
#include <mat.h>
|
||
#include <iostream>
|
||
#include <algorithm>
|
||
#include <string>
|
||
#include <unordered_set>
|
||
#include <ctime>
|
||
#include <vector>
|
||
#include <queue>
|
||
#include <memory>
|
||
#include <thread>
|
||
#include <mutex>
|
||
#include <condition_variable>
|
||
#include <future>
|
||
#include <functional>
|
||
#include <stdexcept>
|
||
#include <unordered_map>
|
||
#include <set>
|
||
#include <fstream>
|
||
#include <algorithm>
|
||
#include <random>
|
||
#include <cmath>
|
||
|
||
using std::cout;
|
||
using std::endl;
|
||
using namespace std;
|
||
|
||
// Size, in chars, of the scratch buffer handed to mxGetString when copying a
// single MATLAB string out of a cell element.
#define STRING_BUF_SIZE 204800
|
||
|
||
|
||
// 读取一维cell字符串并转换成大写
|
||
inline bool Read1DWord(const mxArray* pMxArray, vector<string>& vStr) {
|
||
mxArray* pCell = nullptr;
|
||
int rowNum, colNum;
|
||
char* strBuf = new char[STRING_BUF_SIZE];
|
||
|
||
rowNum = (int)mxGetM(pMxArray);
|
||
colNum = (int)mxGetN(pMxArray);
|
||
vStr.resize(rowNum * colNum);
|
||
for (int i = 0; i < rowNum; ++i) {
|
||
for (int j = 0; j < colNum; ++j) {
|
||
pCell = mxGetCell(pMxArray, j * rowNum + i);
|
||
if (mxGetString(pCell, strBuf, STRING_BUF_SIZE) != 0) {
|
||
cout << "String is too large to fit in the buffer! " << i + 1 << '\t' << j + 1 << endl;
|
||
return false;
|
||
}
|
||
vStr[i * colNum + j] = strBuf;
|
||
auto& lastStr = vStr[i * colNum + j];
|
||
transform(lastStr.cbegin(), lastStr.cend(), lastStr.begin(), ::toupper); // 转成大写
|
||
}
|
||
}
|
||
delete[]strBuf;
|
||
return true;
|
||
}
|
||
|
||
// 读取二维cell字符串并转换成大写
|
||
inline bool Read2DWord(const mxArray* pMxArray, vector<vector<string>>& vvStr) {
|
||
mxArray* pCell = nullptr;
|
||
int rowNum, colNum;
|
||
char* strBuf = new char[STRING_BUF_SIZE];
|
||
|
||
rowNum = (int)mxGetM(pMxArray);
|
||
colNum = (int)mxGetN(pMxArray);
|
||
for (int i = 0; i < rowNum; ++i) {
|
||
for (int j = 0; j < colNum; ++j) {
|
||
pCell = mxGetCell(pMxArray, j * rowNum + i);
|
||
int childRowNum = (int)mxGetM(pCell);
|
||
int childColNum = (int)mxGetN(pCell);
|
||
vvStr.push_back(vector<string>());
|
||
Read1DWord(pCell, vvStr.back());
|
||
}
|
||
}
|
||
delete[]strBuf;
|
||
return true;
|
||
}
|
||
|
||
// 将结果写入mxArray, 作为后续的返回值
|
||
mxArray* writeToMatDouble(const double *data, int rowNum, int colNum) {
|
||
mxArray* pWriteArray = NULL;//matlab格式矩阵
|
||
int len = rowNum * colNum;
|
||
//创建一个rowNum*colNum的矩阵
|
||
pWriteArray = mxCreateDoubleMatrix(rowNum, colNum, mxREAL);
|
||
//把data的值赋给pWriteArray指针
|
||
memcpy((void*)(mxGetPr(pWriteArray)), (void*)data, sizeof(double) * len);
|
||
return pWriteArray; // 赋值给返回值
|
||
}
|
||
|
||
/* 入口函数 */
|
||
/*
|
||
三个参数,一个返回值
|
||
输入:
|
||
1. wd 文献摘要中的单词,二维cell
|
||
2. dic 字典,按字母序排序的,一维cell
|
||
3. threshold 保留超过阈值的列
|
||
输出:
|
||
x 一维int(double)类型,表示在wd的每一行单词中,dic中是否有单词匹配上(匹配后,对应的坐标为设为1,否则为0),所有匹配的个数(统计每一行的匹配个数)
|
||
dic每一个单词在wd所有行中出现的次数
|
||
*/
|
||
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
|
||
if (nrhs < 2) {
|
||
cout << "At least 2 arguments should be given for this function!" << endl;
|
||
return;
|
||
}
|
||
clock_t begin = clock(), mid, finish;
|
||
vector<string> vDic;
|
||
vector<vector<string>> vvWd;
|
||
|
||
Read2DWord(prhs[0], vvWd);
|
||
Read1DWord(prhs[1], vDic);
|
||
int rowNum = vvWd.size();
|
||
|
||
int threshold = 5;
|
||
if (nrhs > 2) {
|
||
double* pThreshold = (double*)mxGetData(prhs[2]);
|
||
threshold = (int)pThreshold[0];
|
||
if (threshold < 5) threshold = 5;
|
||
}
|
||
finish = clock();
|
||
cout << "Load data time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
|
||
|
||
vector<double> vXSum(vDic.size());
|
||
|
||
/* 统计dicr字典中,每个单词在wd中出现的次数 */
|
||
mid = clock();
|
||
unordered_map<string, int> umWordPos;
|
||
for (int i = 0; i < vDic.size(); ++i) umWordPos[vDic[i]] = i; // 记录单词位置
|
||
unordered_set<int> usPos; // 多次出现在wd中的单词,只统计一次,这是原matlab代码的功能,是否需要修改?
|
||
vector<unordered_set<int>> vusX(rowNum); // 保存每一行中非零元的坐标
|
||
int row = 0;
|
||
for (auto& vWd : vvWd) {
|
||
auto& usPos = vusX[row++];
|
||
for (auto& word : vWd) {
|
||
auto itr = umWordPos.find(word);
|
||
if (itr != umWordPos.end()) {
|
||
usPos.insert(itr->second);
|
||
}
|
||
}
|
||
for (auto idx : usPos) {
|
||
vXSum[idx] += 1;
|
||
}
|
||
}
|
||
finish = clock();
|
||
cout << "Calc word occurrence time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
|
||
|
||
/* 计算xs */
|
||
mid = clock();
|
||
int colNum = 0;
|
||
vector<int> vColIdx;
|
||
for (int i = 0; i < vXSum.size(); ++i) {
|
||
if (vXSum[i] >= threshold) {
|
||
vColIdx.push_back(i);
|
||
}
|
||
}
|
||
colNum = vColIdx.size();
|
||
|
||
vector<double> vXsData(rowNum * colNum);
|
||
|
||
for (int i = 0; i < rowNum; ++i) {
|
||
for (int j = 0; j < colNum; ++j) {
|
||
if (vusX[i].find(vColIdx[j]) != vusX[i].end()) {
|
||
vXsData[j * rowNum + i] = 1;
|
||
}
|
||
}
|
||
}
|
||
finish = clock();
|
||
cout << "Calc xs time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
|
||
|
||
// 测试输出
|
||
// cout << rowNum << '\t' << colNum << endl;
|
||
// ofstream ofs1("d:\\result_xsum.txt");
|
||
// for (auto& val : vXSum) {
|
||
// ofs1 << val << endl;
|
||
// }
|
||
// ofs1.close();
|
||
//
|
||
// ofstream ofs2("d:\\result_xs.txt");
|
||
// for (int i = 0; i < rowNum; ++i) {
|
||
// for (int j = 0; j < colNum; ++j) {
|
||
// if (vXsData[j * rowNum + i] > 0) {
|
||
// ofs2 << j + 1 << '\t';
|
||
// }
|
||
// }
|
||
// ofs2 << endl;
|
||
// }
|
||
// ofs2.close();
|
||
|
||
/* 写入结果 */
|
||
mid = clock();
|
||
if (nlhs > 0) {
|
||
plhs[0] = writeToMatDouble(vXSum.data(), 1, vXSum.size());
|
||
}
|
||
if (nlhs > 1) { // xs
|
||
plhs[1] = writeToMatDouble(vXsData.data(), rowNum, colNum);
|
||
}
|
||
finish = clock();
|
||
cout << "Write result time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
|
||
|
||
cout << "Calc word occurrence in Dic Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
|
||
}
|
||
|
||
// 供c++调试用
|
||
void mexFunctionWrap(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
|
||
return mexFunction(nlhs, plhs, nrhs, prhs);
|
||
} |