twirls/MexFunc/IsWordInDic.cpp

204 lines
5.5 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include <mex.h>
#include <mat.h>
#include <iostream>
#include <algorithm>
#include <string>
#include <unordered_set>
#include <ctime>
#include <vector>
#include <queue>
#include <memory>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <future>
#include <functional>
#include <stdexcept>
#include <unordered_map>
#include <set>
#include <fstream>
#include <algorithm>
#include <random>
#include <cmath>
// Stream helpers used for all diagnostic and timing output in this file.
using std::cout;
using std::endl;
// NOTE(review): this directive already exposes every std name, which makes the
// two using-declarations above redundant.
using namespace std;
// Size, in bytes, of the scratch buffer handed to mxGetString when a single
// cell string is read.
#define STRING_BUF_SIZE 204800
// 读取一维cell字符串并转换成大写
inline bool Read1DWord(const mxArray* pMxArray, vector<string>& vStr) {
mxArray* pCell = nullptr;
int rowNum, colNum;
char* strBuf = new char[STRING_BUF_SIZE];
rowNum = (int)mxGetM(pMxArray);
colNum = (int)mxGetN(pMxArray);
vStr.resize(rowNum * colNum);
for (int i = 0; i < rowNum; ++i) {
for (int j = 0; j < colNum; ++j) {
pCell = mxGetCell(pMxArray, j * rowNum + i);
if (mxGetString(pCell, strBuf, STRING_BUF_SIZE) != 0) {
cout << "String is too large to fit in the buffer! " << i + 1 << '\t' << j + 1 << endl;
return false;
}
vStr[i * colNum + j] = strBuf;
auto& lastStr = vStr[i * colNum + j];
transform(lastStr.cbegin(), lastStr.cend(), lastStr.begin(), ::toupper); // 转成大写
}
}
delete[]strBuf;
return true;
}
// 读取二维cell字符串并转换成大写
inline bool Read2DWord(const mxArray* pMxArray, vector<vector<string>>& vvStr) {
mxArray* pCell = nullptr;
int rowNum, colNum;
char* strBuf = new char[STRING_BUF_SIZE];
rowNum = (int)mxGetM(pMxArray);
colNum = (int)mxGetN(pMxArray);
for (int i = 0; i < rowNum; ++i) {
for (int j = 0; j < colNum; ++j) {
pCell = mxGetCell(pMxArray, j * rowNum + i);
int childRowNum = (int)mxGetM(pCell);
int childColNum = (int)mxGetN(pCell);
vvStr.push_back(vector<string>());
Read1DWord(pCell, vvStr.back());
}
}
delete[]strBuf;
return true;
}
// Copies a buffer of doubles into a newly created real double mxArray that can
// be handed back to MATLAB as a return value.
//
// @param data    source buffer holding rowNum * colNum values in the same
//                element order as the mxArray's data; may be null when the
//                matrix is empty
// @param rowNum  number of rows of the matrix
// @param colNum  number of columns of the matrix
// @return the new mxArray, or nullptr if mxCreateDoubleMatrix returned nullptr
mxArray* writeToMat(const double *data, int rowNum, int colNum) {
    mxArray* pWriteArray = mxCreateDoubleMatrix(rowNum, colNum, mxREAL);
    if (pWriteArray == nullptr) {
        return nullptr;
    }
    // Computed in size_t: the product of two ints can exceed the int range.
    const size_t len = static_cast<size_t>(rowNum) * static_cast<size_t>(colNum);
    // Skip the copy for empty matrices, where data may legitimately be null.
    if (data != nullptr && len > 0) {
        std::copy_n(data, len, mxGetPr(pWriteArray));
    }
    return pWriteArray;
}
/* 入口函数 */
/*
三个参数,一个返回值
输入:
1. wd 文献摘要中的单词二维cell
2. dic 字典按字母序排序的一维cell
3. threshold 保留超过阈值的列
输出:
x 一维int(double)类型表示在wd的每一行单词中dic中是否有单词匹配上匹配后对应的坐标为设为1否则为0所有匹配的个数统计每一行的匹配个数
dic每一个单词在wd所有行中出现的次数
*/
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
if (nrhs < 2) {
cout << "At least 2 arguments should be given for this function!" << endl;
return;
}
clock_t begin = clock(), mid, finish;
vector<string> vDic;
vector<vector<string>> vvWd;
Read2DWord(prhs[0], vvWd);
Read1DWord(prhs[1], vDic);
int rowNum = vvWd.size();
int threshold = 5;
if (nrhs > 2) {
double* pThreshold = (double*)mxGetData(prhs[2]);
threshold = (int)pThreshold[0];
if (threshold < 5) threshold = 5;
}
finish = clock();
cout << "Load data time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
vector<double> vXSum(vDic.size());
/* 统计dicr字典中每个单词在wd中出现的次数 */
mid = clock();
unordered_map<string, int> umWordPos;
for (int i = 0; i < vDic.size(); ++i) umWordPos[vDic[i]] = i; // 记录单词位置
unordered_set<int> usPos; // 多次出现在wd中的单词只统计一次这是原matlab代码的功能是否需要修改
vector<unordered_set<int>> vusX(rowNum); // 保存每一行中非零元的坐标
int row = 0;
for (auto& vWd : vvWd) {
auto& usPos = vusX[row++];
for (auto& word : vWd) {
auto itr = umWordPos.find(word);
if (itr != umWordPos.end()) {
usPos.insert(itr->second);
}
}
for (auto idx : usPos) {
vXSum[idx] += 1;
}
}
finish = clock();
cout << "Calc word occurrence time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
/* 计算xs */
mid = clock();
int colNum = 0;
vector<int> vColIdx;
for (int i = 0; i < vXSum.size(); ++i) {
if (vXSum[i] >= threshold) {
vColIdx.push_back(i);
}
}
colNum = vColIdx.size();
vector<double> vXsData(rowNum * colNum);
for (int i = 0; i < rowNum; ++i) {
for (int j = 0; j < colNum; ++j) {
if (vusX[i].find(vColIdx[j]) != vusX[i].end()) {
vXsData[j * rowNum + i] = 1;
}
}
}
finish = clock();
cout << "Calc xs time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
// 测试输出
// cout << rowNum << '\t' << colNum << endl;
// ofstream ofs1("d:\\result_xsum.txt");
// for (auto& val : vXSum) {
// ofs1 << val << endl;
// }
// ofs1.close();
//
// ofstream ofs2("d:\\result_xs.txt");
// for (int i = 0; i < rowNum; ++i) {
// for (int j = 0; j < colNum; ++j) {
// if (vXsData[j * rowNum + i] > 0) {
// ofs2 << j + 1 << '\t';
// }
// }
// ofs2 << endl;
// }
// ofs2.close();
/* 写入结果 */
mid = clock();
if (nlhs > 0) {
plhs[0] = writeToMat(vXSum.data(), 1, vXSum.size());
}
if (nlhs > 1) { // xs
plhs[1] = writeToMat(vXsData.data(), rowNum, colNum);
}
finish = clock();
cout << "Write result time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
cout << "Calc word occurrence in Dic Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
}
// Plain forwarding wrapper around mexFunction, provided for debugging from C++.
void mexFunctionWrap(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
    mexFunction(nlhs, plhs, nrhs, prhs);
}