修复了 IsWordInDic 的 bug

This commit is contained in:
zzh 2023-10-13 22:33:55 +08:00
parent f96d9cf4a2
commit c68aeefa37
2 changed files with 47 additions and 54 deletions

View File

@ -91,7 +91,7 @@ mxArray* writeToMatDouble(const double *data, int rowNum, int colNum) {
1. wd cell
2. dic cell
3. threshold
[3]. flagPrint
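x: int (stored as double) matrix; an element is 1 if the wd word is in dic, otherwise 0
(dic / wd: the original text of this line was lost to encoding damage; presumably rows follow wd and columns follow dic — TODO confirm)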
@ -108,63 +108,56 @@ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
Read2DWord(prhs[0], vvWd);
Read1DWord(prhs[1], vDic);
int rowNum = vvWd.size();
int colNum = vDic.size();
int threshold = 5;
int flagPrint = 0;
if (nrhs > 2) {
double* pThreshold = (double*)mxGetData(prhs[2]);
threshold = (int)pThreshold[0];
if (threshold < 5) threshold = 5;
double* pData = (double*)mxGetData(prhs[2]);
flagPrint = (int)pData[0];
}
finish = clock();
cout << "Load data time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
if (flagPrint == 2) cout << "Load data time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
vector<double> vXSum(vDic.size());
vector<double> vX(rowNum * colNum); // 一维表示二维
/* 统计dicr字典中每个单词在wd中出现的次数 */
mid = clock();
unordered_map<string, int> umWordPos;
for (int i = 0; i < vDic.size(); ++i) umWordPos[vDic[i]] = i; // 记录单词位置
unordered_map<string, vector<int>> umWordPos;
for (int i = 0; i < vDic.size(); ++i) umWordPos[vDic[i]].push_back(i); // dic中可能存在重复,记录单词位置
unordered_set<int> usPos; // 多次出现在wd中的单词只统计一次这是原matlab代码的功能是否需要修改
vector<unordered_set<int>> vusX(rowNum); // 保存每一行中非零元的坐标
int row = 0;
// vector<double> vSum(colNum);
for (auto& vWd : vvWd) {
auto& usPos = vusX[row++];
auto& usPos = vusX[row];
for (auto& word : vWd) {
auto itr = umWordPos.find(word);
if (itr != umWordPos.end()) {
usPos.insert(itr->second);
for (auto pos : itr->second)
usPos.insert(pos);
}
}
for (auto idx : usPos) {
vXSum[idx] += 1;
vX[idx * rowNum + row] = 1; // matlab 列优先存储模式
// vSum[idx] += 1;
}
++row;
}
finish = clock();
cout << "Calc word occurrence time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
if (flagPrint == 2) cout << "Calc word occurrence time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
/* 计算xs */
mid = clock();
int colNum = 0;
vector<int> vColIdx;
for (int i = 0; i < vXSum.size(); ++i) {
if (vXSum[i] >= threshold) {
vColIdx.push_back(i);
}
}
colNum = vColIdx.size();
// for (auto& w : vDic) cout << umWordPos[w] << endl;
vector<double> vXsData(rowNum * colNum);
for (int i = 0; i < rowNum; ++i) {
for (int j = 0; j < colNum; ++j) {
if (vusX[i].find(vColIdx[j]) != vusX[i].end()) {
vXsData[j * rowNum + i] = 1;
}
}
}
finish = clock();
cout << "Calc xs time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
// for (auto& w : vvWd[260]) cout << w << '\t' << w.size() << endl;
//for (auto& w : vDic) cout << w << '\t' << w.size() << endl;
// vector<double> vSum(colNum);
// for (int i = 0; i < rowNum; ++i) {
// for (int j = 0; j < colNum; ++j) {
// vSum[j] += vX[j * rowNum + i];
// }
// }
// for (auto val : vSum) cout << val << endl;
// 测试输出
// cout << rowNum << '\t' << colNum << endl;
// ofstream ofs1("d:\\result_xsum.txt");
@ -186,16 +179,14 @@ void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
/* 写入结果 */
mid = clock();
if (nlhs > 0) {
plhs[0] = writeToMatDouble(vXSum.data(), 1, vXSum.size());
if (nlhs > 0) { // vvX
plhs[0] = writeToMatDouble(vX.data(), rowNum, colNum);
}
if (nlhs > 1) { // xs
plhs[1] = writeToMatDouble(vXsData.data(), rowNum, colNum);
}
finish = clock();
cout << "Write result time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
cout << "Calc word occurrence in Dic Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
finish = clock();
if (flagPrint == 2) cout << "Write result time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
if (flagPrint) cout << "Calc word occurrence in Dic Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
}
// 供c++调试用

View File

@ -13,12 +13,12 @@ int main(int argc, const char** argv)
const mxArray* prhs[argReserveNum];
/* SortDedup */
int nlhs = 1, nrhs = 2;
MATFile* pwdMat = matOpen("D:\\tmp\\wd_small.mat", "r");
prhs[0] = matGetVariable(pwdMat, "wd");
prhs[1] = mxCreateString("D:\\Twirls\\runtime\\output_1.dat");
prhs[2] = mxCreateDoubleMatrix(1, 1, mxREAL);
*mxGetPr(prhs[2]) = 2;
// int nlhs = 1, nrhs = 2;
// MATFile* pwdMat = matOpen("D:\\tmp\\wd_small.mat", "r");
// prhs[0] = matGetVariable(pwdMat, "wd");
// prhs[1] = mxCreateString("D:\\Twirls\\runtime\\output_1.dat");
// prhs[2] = mxCreateDoubleMatrix(1, 1, mxREAL);
// *mxGetPr(prhs[2]) = 2;
/* CalcEntropy */
// int nlhs = 2, nrhs = 4;
@ -32,12 +32,14 @@ int main(int argc, const char** argv)
// *mxGetPr(prhs[3]) = 2;
/* IsWordInDic */
// MATFile* pwdMat, * pdicMat;
// int nlhs = 2, nrhs = 2;
// pwdMat = matOpen("D:\\tmp\\wd_large.mat", "r");
// pdicMat = matOpen("D:\\tmp\\G_dc_large.mat", "r");
// prhs[0] = matGetVariable(pwdMat, "wd"); //获取.mat文件里面名为matrixName的矩阵
// prhs[1] = matGetVariable(pdicMat, "dc");
MATFile* pwdMat, * pdicMat;
int nlhs = 1, nrhs = 3;
pwdMat = matOpen("D:\\tmp\\ws_small.mat", "r");
pdicMat = matOpen("D:\\tmp\\x_small.mat", "r");
prhs[0] = matGetVariable(pwdMat, "ws"); //获取.mat文件里面名为matrixName的矩阵
prhs[1] = matGetVariable(pdicMat, "x");
prhs[2] = mxCreateDoubleMatrix(1, 1, mxREAL);
*mxGetPr(prhs[2]) = 2;
/* ClusterRandSim */
// int nlhs = 2, nrhs = 4;