twirls/MexFunc/IsWordInDic.cpp

204 lines
5.5 KiB
C++
Raw Normal View History

#include <mex.h>
#include <mat.h>

#include <algorithm>
#include <cctype>
#include <cmath>
#include <condition_variable>
#include <cstddef>
#include <ctime>
#include <fstream>
#include <functional>
#include <future>
#include <iostream>
#include <memory>
#include <mutex>
#include <queue>
#include <random>
#include <set>
#include <stdexcept>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <vector>
using std::cout;
using std::endl;
using namespace std;
#define STRING_BUF_SIZE 204800
// Read a 1-D cell array of strings and convert them to upper case
inline bool Read1DWord(const mxArray* pMxArray, vector<string>& vStr) {
mxArray* pCell = nullptr;
int rowNum, colNum;
char* strBuf = new char[STRING_BUF_SIZE];
rowNum = (int)mxGetM(pMxArray);
colNum = (int)mxGetN(pMxArray);
vStr.resize(rowNum * colNum);
for (int i = 0; i < rowNum; ++i) {
for (int j = 0; j < colNum; ++j) {
pCell = mxGetCell(pMxArray, j * rowNum + i);
if (mxGetString(pCell, strBuf, STRING_BUF_SIZE) != 0) {
cout << "String is too large to fit in the buffer! " << i + 1 << '\t' << j + 1 << endl;
return false;
}
vStr[i * colNum + j] = strBuf;
auto& lastStr = vStr[i * colNum + j];
transform(lastStr.cbegin(), lastStr.cend(), lastStr.begin(), ::toupper); // ת<>ɴ<EFBFBD>д
}
}
delete[]strBuf;
return true;
}
// Read a 2-D cell array of strings (a cell array of cell arrays) and convert them to upper case
inline bool Read2DWord(const mxArray* pMxArray, vector<vector<string>>& vvStr) {
mxArray* pCell = nullptr;
int rowNum, colNum;
char* strBuf = new char[STRING_BUF_SIZE];
rowNum = (int)mxGetM(pMxArray);
colNum = (int)mxGetN(pMxArray);
for (int i = 0; i < rowNum; ++i) {
for (int j = 0; j < colNum; ++j) {
pCell = mxGetCell(pMxArray, j * rowNum + i);
int childRowNum = (int)mxGetM(pCell);
int childColNum = (int)mxGetN(pCell);
vvStr.push_back(vector<string>());
Read1DWord(pCell, vvStr.back());
}
}
delete[]strBuf;
return true;
}
// Write data into an mxArray, to be used as a return value of the MEX function
/**
 * Creates a real double matrix of size rowNum x colNum and copies
 * rowNum * colNum values from data into it, element for element in storage
 * order (callers in this file index the buffer as col * rowNum + row).
 *
 * @param data   Source buffer with at least rowNum * colNum doubles; may be
 *               null when the matrix is empty.
 * @param rowNum Number of rows (must be non-negative).
 * @param colNum Number of columns (must be non-negative).
 * @return The newly created mxArray, or nullptr if a dimension is negative or
 *         mxCreateDoubleMatrix returned nullptr.
 */
mxArray* writeToMat(const double *data, int rowNum, int colNum) {
    if (rowNum < 0 || colNum < 0) {
        return nullptr;
    }
    mxArray* pWriteArray = mxCreateDoubleMatrix(rowNum, colNum, mxREAL);
    // Compute the element count in size_t so large matrices cannot overflow int.
    const std::size_t len = static_cast<std::size_t>(rowNum) * static_cast<std::size_t>(colNum);
    if (pWriteArray != nullptr && data != nullptr && len > 0) {
        std::copy(data, data + len, mxGetPr(pWriteArray));
    }
    return pWriteArray;
}
/* Entry function */
/*
Inputs:
1. wd        words to look up, 2-D cell (a cell array whose elements are cell arrays of strings)
2. dic       dictionary words, 1-D cell of strings
3. threshold optional minimum occurrence count (values below 5 are raised to 5)
Outputs:
1. for every dic word, the number of wd rows in which it occurs (1 x numel(dic) double)
2. xs, a 0/1 matrix with one row per wd row and one column per dic word whose
   occurrence count reaches the threshold; 1 means the word occurs in that row
*/
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
if (nrhs < 2) {
cout << "At least 2 arguments should be given for this function!" << endl;
return;
}
clock_t begin = clock(), mid, finish;
vector<string> vDic;
vector<vector<string>> vvWd;
Read2DWord(prhs[0], vvWd);
Read1DWord(prhs[1], vDic);
int rowNum = vvWd.size();
int threshold = 5;
if (nrhs > 2) {
double* pThreshold = (double*)mxGetData(prhs[2]);
threshold = (int)pThreshold[0];
if (threshold < 5) threshold = 5;
}
finish = clock();
cout << "Load data time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
vector<double> vXSum(vDic.size());
/* ͳ<><CDB3>dicr<63>ֵ<EFBFBD><D6B5>У<EFBFBD>ÿ<EFBFBD><C3BF><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>wd<77>г<EFBFBD><D0B3>ֵĴ<D6B5><C4B4><EFBFBD> */
mid = clock();
unordered_map<string, int> umWordPos;
for (int i = 0; i < vDic.size(); ++i) umWordPos[vDic[i]] = i; // <20><>¼<EFBFBD><C2BC><EFBFBD><EFBFBD>λ<EFBFBD><CEBB>
unordered_set<int> usPos; // <20><><EFBFBD>γ<EFBFBD><CEB3><EFBFBD><EFBFBD><EFBFBD>wd<77>еĵ<D0B5><C4B5>ʣ<EFBFBD>ֻͳ<D6BB><CDB3>һ<EFBFBD>Σ<EFBFBD><CEA3><EFBFBD><EFBFBD><EFBFBD>ԭmatlab<61><62><EFBFBD><EFBFBD><EFBFBD>Ĺ<EFBFBD><C4B9>ܣ<EFBFBD><DCA3>Ƿ<EFBFBD><C7B7><EFBFBD>Ҫ<EFBFBD>޸ģ<DEB8>
vector<unordered_set<int>> vusX(rowNum); // <20><><EFBFBD><EFBFBD>ÿһ<C3BF><D2BB><EFBFBD>з<EFBFBD><D0B7><EFBFBD>Ԫ<EFBFBD><D4AA><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
int row = 0;
for (auto& vWd : vvWd) {
auto& usPos = vusX[row++];
for (auto& word : vWd) {
auto itr = umWordPos.find(word);
if (itr != umWordPos.end()) {
usPos.insert(itr->second);
}
}
for (auto idx : usPos) {
vXSum[idx] += 1;
}
}
finish = clock();
cout << "Calc word occurrence time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
/* <20><><EFBFBD><EFBFBD>xs */
mid = clock();
int colNum = 0;
vector<int> vColIdx;
for (int i = 0; i < vXSum.size(); ++i) {
if (vXSum[i] >= threshold) {
vColIdx.push_back(i);
}
}
colNum = vColIdx.size();
vector<double> vXsData(rowNum * colNum);
for (int i = 0; i < rowNum; ++i) {
for (int j = 0; j < colNum; ++j) {
if (vusX[i].find(vColIdx[j]) != vusX[i].end()) {
vXsData[j * rowNum + i] = 1;
}
}
}
finish = clock();
cout << "Calc xs time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
// <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
// cout << rowNum << '\t' << colNum << endl;
// ofstream ofs1("d:\\result_xsum.txt");
// for (auto& val : vXSum) {
// ofs1 << val << endl;
// }
// ofs1.close();
//
// ofstream ofs2("d:\\result_xs.txt");
// for (int i = 0; i < rowNum; ++i) {
// for (int j = 0; j < colNum; ++j) {
// if (vXsData[j * rowNum + i] > 0) {
// ofs2 << j + 1 << '\t';
// }
// }
// ofs2 << endl;
// }
// ofs2.close();
/* д<><D0B4><EFBFBD><EFBFBD><EFBFBD><EFBFBD> */
mid = clock();
if (nlhs > 0) {
plhs[0] = writeToMat(vXSum.data(), 1, vXSum.size());
}
if (nlhs > 1) { // xs
plhs[1] = writeToMat(vXsData.data(), rowNum, colNum);
}
finish = clock();
cout << "Write result time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
cout << "Calc word occurrence in Dic Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
}
// Wrapper that forwards to mexFunction (presumably so C++ code can call it under another name)
/**
 * Forwards all arguments unchanged to mexFunction.
 *
 * @param nlhs Number of requested outputs.
 * @param plhs Output array pointers, filled in by mexFunction.
 * @param nrhs Number of inputs.
 * @param prhs Input array pointers.
 */
void mexFunctionWrap(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
    mexFunction(nlhs, plhs, nrhs, prhs);
}