147 lines
3.9 KiB
C++
147 lines
3.9 KiB
C++
#include <mex.h>
|
||
#include <mat.h>
|
||
#include <iostream>
|
||
#include <algorithm>
|
||
#include <vector>
|
||
#include <string>
|
||
#include <unordered_set>
|
||
#include <ctime>
|
||
#include <immintrin.h>
|
||
#include <zmmintrin.h>
|
||
#include <vector>
|
||
#include <queue>
|
||
#include <memory>
|
||
#include <thread>
|
||
#include <mutex>
|
||
#include <condition_variable>
|
||
#include <future>
|
||
#include <functional>
|
||
#include <stdexcept>
|
||
#include <unordered_set>
|
||
#include <set>
|
||
#include <fstream>
|
||
|
||
using std::cout;
|
||
using std::endl;
|
||
using namespace std;
|
||
|
||
#define STRING_BUF_SIZE 204800
|
||
|
||
// 读取字符串并转换成大写, 插入set
|
||
bool ReadInsertWord(const mxArray* pMxArray, unordered_set<string> &sWord) {
|
||
mxArray* pCell = nullptr;
|
||
int rowNum, colNum;
|
||
char* strBuf = new char[STRING_BUF_SIZE];
|
||
|
||
rowNum = (int)mxGetM(pMxArray);
|
||
colNum = (int)mxGetN(pMxArray);
|
||
for (int i = 0; i < rowNum; ++i) {
|
||
for (int j = 0; j < colNum; ++j) {
|
||
pCell = mxGetCell(pMxArray, j * rowNum + i);
|
||
int childRowNum = (int)mxGetM(pCell);
|
||
int childColNum = (int)mxGetN(pCell);
|
||
for (int ii = 0; ii < childRowNum; ii++) {
|
||
for (int jj = 0; jj < childColNum; jj++) {
|
||
mxArray* pChildCell = mxGetCell(pCell, jj * childRowNum + ii);
|
||
if (mxGetString(pChildCell, strBuf, STRING_BUF_SIZE) != 0) {
|
||
cout << "String is too large to fit in the buffer! " << i + 1 << '\t' << j + 1 << endl;
|
||
return false;
|
||
}
|
||
string str(strBuf);
|
||
transform(str.cbegin(), str.cend(), str.begin(), ::toupper); // 转成大写
|
||
sWord.insert(str);
|
||
}
|
||
}
|
||
}
|
||
}
|
||
delete[]strBuf;
|
||
return true;
|
||
}
|
||
|
||
/* 入口函数 */
|
||
/*
|
||
输入:
|
||
1. wd: 文献摘要,由二维cell组成的字符串数组
|
||
[2]. 将字符串保存到文件路径
|
||
[3]. flagPrint 是否输出信息
|
||
输出:
|
||
1. dic: 单词组成的一维cell,包含去重之后的文献摘要所有单词,大写,按字母序排序(只包含字母的单词,去掉数字等)
|
||
*/
|
||
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
|
||
if (nrhs < 1) {
|
||
cout << "At least 1 arguments should be given for this function!" << endl;
|
||
return;
|
||
}
|
||
clock_t begin = clock(), mid, finish;
|
||
|
||
unordered_set<string> usStr;
|
||
ReadInsertWord(prhs[0], usStr);
|
||
// usStr.insert("A");
|
||
// usStr.insert("Z");
|
||
string outputPath;
|
||
if (nrhs > 1) {
|
||
char* strBuf = new char[STRING_BUF_SIZE];
|
||
mxGetString(prhs[1], strBuf, STRING_BUF_SIZE);
|
||
outputPath = strBuf;
|
||
delete[]strBuf;
|
||
}
|
||
|
||
int flagPrint = 0; // 是否打印信息, 1打印简单信息,2打印详细信息
|
||
if (nrhs > 2) {
|
||
double* pData = (double*)mxGetData(prhs[2]);
|
||
flagPrint = (int)pData[0];
|
||
}
|
||
|
||
finish = clock();
|
||
if (flagPrint == 2) cout << "Load data time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
|
||
|
||
/* 排序 */
|
||
mid = clock();
|
||
set<string> sOrderedWord;
|
||
for (auto& word : usStr) {
|
||
sOrderedWord.insert(word);
|
||
}
|
||
finish = clock();
|
||
if (flagPrint == 2) cout << "Sort and deduplicate time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
|
||
|
||
/* 将字符串保存到文件 */
|
||
if (! outputPath.empty()) {
|
||
cout << outputPath << endl;
|
||
ofstream ofs(outputPath);
|
||
for (auto& word : sOrderedWord) ofs << word << endl;
|
||
ofs.close();
|
||
}
|
||
|
||
sOrderedWord.insert("A");
|
||
sOrderedWord.insert("Z");
|
||
|
||
/* 写入结果 */
|
||
mid = clock();
|
||
if (nlhs > 0) {
|
||
int wordSize = 0;
|
||
for (auto& word : sOrderedWord) {
|
||
if (word[0] >= 'A' && word[0] <= 'Z') {
|
||
wordSize++;
|
||
}
|
||
}
|
||
mxArray* pCell = mxCreateCellMatrix(1, wordSize);
|
||
int i = 0;
|
||
for (auto& word : sOrderedWord) {
|
||
if (word[0] >= 'A' && word[0] <= 'Z') {
|
||
mxArray* mxStr = mxCreateString(word.c_str());
|
||
mxSetCell(pCell, i++, mxStr);
|
||
}
|
||
}
|
||
plhs[0] = pCell; // 赋值给返回值
|
||
}
|
||
finish = clock();
|
||
if (flagPrint == 2) cout << "Write back data time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
|
||
|
||
finish = clock();
|
||
if (flagPrint)cout << "Deduplicate and Sort word Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
|
||
}
|
||
|
||
// 供c++调试用
|
||
void mexFunctionWrap(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
|
||
return mexFunction(nlhs, plhs, nrhs, prhs);
|
||
} |