修改了IsWordInDic的bug

This commit is contained in:
zzh 2023-10-13 22:33:55 +08:00
parent f96d9cf4a2
commit c68aeefa37
2 changed files with 47 additions and 54 deletions

View File

@ -91,7 +91,7 @@ mxArray* writeToMatDouble(const double *data, int rowNum, int colNum) {
1. wd cell 1. wd cell
2. dic cell 2. dic cell
3. threshold [3]. flagPrint
x int(double)wddic10 x int(double)wddic10
dicwd dicwd
@ -108,63 +108,56 @@ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
Read2DWord(prhs[0], vvWd); Read2DWord(prhs[0], vvWd);
Read1DWord(prhs[1], vDic); Read1DWord(prhs[1], vDic);
int rowNum = vvWd.size(); int rowNum = vvWd.size();
int colNum = vDic.size();
int threshold = 5; int flagPrint = 0;
if (nrhs > 2) { if (nrhs > 2) {
double* pThreshold = (double*)mxGetData(prhs[2]); double* pData = (double*)mxGetData(prhs[2]);
threshold = (int)pThreshold[0]; flagPrint = (int)pData[0];
if (threshold < 5) threshold = 5;
} }
finish = clock(); finish = clock();
cout << "Load data time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl; if (flagPrint == 2) cout << "Load data time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
vector<double> vXSum(vDic.size()); vector<double> vX(rowNum * colNum); // 一维表示二维
/* 统计dicr字典中每个单词在wd中出现的次数 */ /* 统计dicr字典中每个单词在wd中出现的次数 */
mid = clock(); mid = clock();
unordered_map<string, int> umWordPos; unordered_map<string, vector<int>> umWordPos;
for (int i = 0; i < vDic.size(); ++i) umWordPos[vDic[i]] = i; // 记录单词位置 for (int i = 0; i < vDic.size(); ++i) umWordPos[vDic[i]].push_back(i); // dic中可能存在重复,记录单词位置
unordered_set<int> usPos; // 多次出现在wd中的单词只统计一次这是原matlab代码的功能是否需要修改 unordered_set<int> usPos; // 多次出现在wd中的单词只统计一次这是原matlab代码的功能是否需要修改
vector<unordered_set<int>> vusX(rowNum); // 保存每一行中非零元的坐标 vector<unordered_set<int>> vusX(rowNum); // 保存每一行中非零元的坐标
int row = 0; int row = 0;
// vector<double> vSum(colNum);
for (auto& vWd : vvWd) { for (auto& vWd : vvWd) {
auto& usPos = vusX[row++]; auto& usPos = vusX[row];
for (auto& word : vWd) { for (auto& word : vWd) {
auto itr = umWordPos.find(word); auto itr = umWordPos.find(word);
if (itr != umWordPos.end()) { if (itr != umWordPos.end()) {
usPos.insert(itr->second); for (auto pos : itr->second)
usPos.insert(pos);
} }
} }
for (auto idx : usPos) { for (auto idx : usPos) {
vXSum[idx] += 1; vX[idx * rowNum + row] = 1; // matlab 列优先存储模式
// vSum[idx] += 1;
} }
++row;
} }
finish = clock(); finish = clock();
cout << "Calc word occurrence time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl; if (flagPrint == 2) cout << "Calc word occurrence time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
/* 计算xs */ // for (auto& w : vDic) cout << umWordPos[w] << endl;
mid = clock();
int colNum = 0;
vector<int> vColIdx;
for (int i = 0; i < vXSum.size(); ++i) {
if (vXSum[i] >= threshold) {
vColIdx.push_back(i);
}
}
colNum = vColIdx.size();
vector<double> vXsData(rowNum * colNum); // for (auto& w : vvWd[260]) cout << w << '\t' << w.size() << endl;
//for (auto& w : vDic) cout << w << '\t' << w.size() << endl;
for (int i = 0; i < rowNum; ++i) {
for (int j = 0; j < colNum; ++j) {
if (vusX[i].find(vColIdx[j]) != vusX[i].end()) {
vXsData[j * rowNum + i] = 1;
}
}
}
finish = clock();
cout << "Calc xs time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
// vector<double> vSum(colNum);
// for (int i = 0; i < rowNum; ++i) {
// for (int j = 0; j < colNum; ++j) {
// vSum[j] += vX[j * rowNum + i];
// }
// }
// for (auto val : vSum) cout << val << endl;
// 测试输出 // 测试输出
// cout << rowNum << '\t' << colNum << endl; // cout << rowNum << '\t' << colNum << endl;
// ofstream ofs1("d:\\result_xsum.txt"); // ofstream ofs1("d:\\result_xsum.txt");
@ -186,16 +179,14 @@ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
/* 写入结果 */ /* 写入结果 */
mid = clock(); mid = clock();
if (nlhs > 0) { if (nlhs > 0) { // vvX
plhs[0] = writeToMatDouble(vXSum.data(), 1, vXSum.size()); plhs[0] = writeToMatDouble(vX.data(), rowNum, colNum);
} }
if (nlhs > 1) { // xs
plhs[1] = writeToMatDouble(vXsData.data(), rowNum, colNum);
}
finish = clock();
cout << "Write result time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
cout << "Calc word occurrence in Dic Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl; finish = clock();
if (flagPrint == 2) cout << "Write result time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
if (flagPrint) cout << "Calc word occurrence in Dic Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
} }
// 供c++调试用 // 供c++调试用

View File

@ -13,12 +13,12 @@ int main(int argc, const char** argv)
const mxArray* prhs[argReserveNum]; const mxArray* prhs[argReserveNum];
/* SortDedup */ /* SortDedup */
int nlhs = 1, nrhs = 2; // int nlhs = 1, nrhs = 2;
MATFile* pwdMat = matOpen("D:\\tmp\\wd_small.mat", "r"); // MATFile* pwdMat = matOpen("D:\\tmp\\wd_small.mat", "r");
prhs[0] = matGetVariable(pwdMat, "wd"); // prhs[0] = matGetVariable(pwdMat, "wd");
prhs[1] = mxCreateString("D:\\Twirls\\runtime\\output_1.dat"); // prhs[1] = mxCreateString("D:\\Twirls\\runtime\\output_1.dat");
prhs[2] = mxCreateDoubleMatrix(1, 1, mxREAL); // prhs[2] = mxCreateDoubleMatrix(1, 1, mxREAL);
*mxGetPr(prhs[2]) = 2; // *mxGetPr(prhs[2]) = 2;
/* CalcEntropy */ /* CalcEntropy */
// int nlhs = 2, nrhs = 4; // int nlhs = 2, nrhs = 4;
@ -32,12 +32,14 @@ int main(int argc, const char** argv)
// *mxGetPr(prhs[3]) = 2; // *mxGetPr(prhs[3]) = 2;
/* IsWordInDic */ /* IsWordInDic */
// MATFile* pwdMat, * pdicMat; MATFile* pwdMat, * pdicMat;
// int nlhs = 2, nrhs = 2; int nlhs = 1, nrhs = 3;
// pwdMat = matOpen("D:\\tmp\\wd_large.mat", "r"); pwdMat = matOpen("D:\\tmp\\ws_small.mat", "r");
// pdicMat = matOpen("D:\\tmp\\G_dc_large.mat", "r"); pdicMat = matOpen("D:\\tmp\\x_small.mat", "r");
// prhs[0] = matGetVariable(pwdMat, "wd"); //获取.mat文件里面名为matrixName的矩阵 prhs[0] = matGetVariable(pwdMat, "ws"); //获取.mat文件里面名为matrixName的矩阵
// prhs[1] = matGetVariable(pdicMat, "dc"); prhs[1] = matGetVariable(pdicMat, "x");
prhs[2] = mxCreateDoubleMatrix(1, 1, mxREAL);
*mxGetPr(prhs[2]) = 2;
/* ClusterRandSim */ /* ClusterRandSim */
// int nlhs = 2, nrhs = 4; // int nlhs = 2, nrhs = 4;