twirls/MexFunc/SortDedup.cpp

147 lines
3.9 KiB
C++
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include <mex.h>
#include <mat.h>
#include <iostream>
#include <algorithm>
#include <vector>
#include <string>
#include <unordered_set>
#include <ctime>
#include <immintrin.h>
#include <zmmintrin.h>
#include <vector>
#include <queue>
#include <memory>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <future>
#include <functional>
#include <stdexcept>
#include <unordered_set>
#include <set>
#include <fstream>
using std::cout;
using std::endl;
using namespace std;
// Capacity (in chars, including the terminating NUL) of the scratch buffers
// passed to mxGetString when reading the input words and the output path.
#define STRING_BUF_SIZE 204800
// 读取字符串并转换成大写, 插入set
bool ReadInsertWord(const mxArray* pMxArray, unordered_set<string> &sWord) {
mxArray* pCell = nullptr;
int rowNum, colNum;
char* strBuf = new char[STRING_BUF_SIZE];
rowNum = (int)mxGetM(pMxArray);
colNum = (int)mxGetN(pMxArray);
for (int i = 0; i < rowNum; ++i) {
for (int j = 0; j < colNum; ++j) {
pCell = mxGetCell(pMxArray, j * rowNum + i);
int childRowNum = (int)mxGetM(pCell);
int childColNum = (int)mxGetN(pCell);
for (int ii = 0; ii < childRowNum; ii++) {
for (int jj = 0; jj < childColNum; jj++) {
mxArray* pChildCell = mxGetCell(pCell, jj * childRowNum + ii);
if (mxGetString(pChildCell, strBuf, STRING_BUF_SIZE) != 0) {
cout << "String is too large to fit in the buffer! " << i + 1 << '\t' << j + 1 << endl;
return false;
}
string str(strBuf);
transform(str.cbegin(), str.cend(), str.begin(), ::toupper); // 转成大写
sWord.insert(str);
}
}
}
}
delete[]strBuf;
return true;
}
/* 入口函数 */
/*
输入:
1. wd: 文献摘要由二维cell组成的字符串数组
[2]. 将字符串保存到文件路径
[3]. flagPrint 是否输出信息
输出:
1. dic: 单词组成的一维cell包含去重之后的文献摘要所有单词大写按字母序排序(只包含字母的单词,去掉数字等)
*/
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
if (nrhs < 1) {
cout << "At least 1 arguments should be given for this function!" << endl;
return;
}
clock_t begin = clock(), mid, finish;
unordered_set<string> usStr;
ReadInsertWord(prhs[0], usStr);
// usStr.insert("A");
// usStr.insert("Z");
string outputPath;
if (nrhs > 1) {
char* strBuf = new char[STRING_BUF_SIZE];
mxGetString(prhs[1], strBuf, STRING_BUF_SIZE);
outputPath = strBuf;
delete[]strBuf;
}
int flagPrint = 0; // 是否打印信息, 1打印简单信息2打印详细信息
if (nrhs > 2) {
double* pData = (double*)mxGetData(prhs[2]);
flagPrint = (int)pData[0];
}
finish = clock();
if (flagPrint == 2) cout << "Load data time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
/* 排序 */
mid = clock();
set<string> sOrderedWord;
for (auto& word : usStr) {
sOrderedWord.insert(word);
}
finish = clock();
if (flagPrint == 2) cout << "Sort and deduplicate time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
/* 将字符串保存到文件 */
if (! outputPath.empty()) {
cout << outputPath << endl;
ofstream ofs(outputPath);
for (auto& word : sOrderedWord) ofs << word << endl;
ofs.close();
}
sOrderedWord.insert("A");
sOrderedWord.insert("Z");
/* 写入结果 */
mid = clock();
if (nlhs > 0) {
int wordSize = 0;
for (auto& word : sOrderedWord) {
if (word[0] >= 'A' && word[0] <= 'Z') {
wordSize++;
}
}
mxArray* pCell = mxCreateCellMatrix(1, wordSize);
int i = 0;
for (auto& word : sOrderedWord) {
if (word[0] >= 'A' && word[0] <= 'Z') {
mxArray* mxStr = mxCreateString(word.c_str());
mxSetCell(pCell, i++, mxStr);
}
}
plhs[0] = pCell; // 赋值给返回值
}
finish = clock();
if (flagPrint == 2) cout << "Write back data time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
finish = clock();
if (flagPrint)cout << "Deduplicate and Sort word Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
}
// For C++ debugging: a plain C++ entry point that simply forwards all of its
// arguments unchanged to mexFunction.
void mexFunctionWrap(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
    mexFunction(nlhs, plhs, nrhs, prhs);
}