238 lines
8.5 KiB
C++
238 lines
8.5 KiB
C++
#include <mex.h>
|
||
#include <mat.h>
|
||
#include <iostream>
|
||
#include <algorithm>
|
||
#include <vector>
|
||
#include <string>
|
||
#include <unordered_set>
|
||
#include <ctime>
|
||
using std::cout;
|
||
using std::endl;
|
||
using namespace std;
|
||
|
||
#define STRING_BUF_SIZE 204800
|
||
|
||
/* 读取二层cell包裹的字符串,和数值,ds,fr */
|
||
#define OUTER_FOR_BEGIN \
|
||
rowNum = (int)mxGetM(pMxArray); \
|
||
colNum = (int)mxGetN(pMxArray); \
|
||
for (int i = 0; i < rowNum; ++i) { \
|
||
for (int j = 0; j < colNum; ++j) { \
|
||
mxArray* pCell = mxGetCell(pMxArray, j * rowNum + i); \
|
||
int childRowNum = (int)mxGetM(pCell); \
|
||
int childColNum = (int)mxGetN(pCell);
|
||
|
||
#define OUTER_FOR_END \
|
||
} \
|
||
}
|
||
|
||
#define INNTER_FOR_BEGIN \
|
||
for (int ii = 0; ii < childRowNum; ii++) { \
|
||
for (int jj = 0; jj < childColNum; jj++) { \
|
||
mxArray *pChildCell = mxGetCell(pCell, jj * childRowNum + ii);
|
||
#define INNTER_FOR_END \
|
||
} \
|
||
}
|
||
// 将matlab存储方式转换成c存储方式
|
||
#define TRANS_ROW_COL(dst, src, rowNum, colNum) \
|
||
for (int rowI = 0; rowI < rowNum; ++rowI) { \
|
||
for (int colJ = 0; colJ < colNum; ++colJ) { \
|
||
dst[rowI * colNum + colJ] = src[colJ * rowNum + rowI]; \
|
||
} \
|
||
}
|
||
// 将二维索引转成一维的索引
|
||
inline int Get1DIndex(int colNum, int row, int col) {
|
||
return row * colNum + col;
|
||
}
|
||
|
||
// 读取G结构体中的ds和fr
|
||
void GetFrDs(const mxArray* pMxParent, vector<vector<string> >& vvDs, vector<vector<double> >& vvFr) {
|
||
// 读取ds字符串
|
||
int rowNum, colNum;
|
||
char *strBuf = new char[STRING_BUF_SIZE];
|
||
mxArray* pMxArray = mxGetField(pMxParent, 0, "ds"); // ds
|
||
OUTER_FOR_BEGIN
|
||
vvDs.push_back(vector<string>());
|
||
vvDs.back().resize(childRowNum * childColNum);
|
||
INNTER_FOR_BEGIN
|
||
if (mxGetString(pChildCell, strBuf, STRING_BUF_SIZE) != 0) {
|
||
cout << "String is too large to fit in the buffer! " << i + 1 << '\t' << j + 1 << endl;
|
||
delete[]strBuf;
|
||
return;
|
||
}
|
||
vvDs.back()[ii * childColNum + jj] = strBuf;
|
||
auto& lastStr = vvDs.back()[ii * childColNum + jj];
|
||
transform(lastStr.begin(), lastStr.end(), lastStr.begin(), ::toupper); // 转成大写
|
||
INNTER_FOR_END
|
||
OUTER_FOR_END
|
||
|
||
// 读取fr数值
|
||
pMxArray = mxGetField(pMxParent, 0, "fr"); // fr
|
||
OUTER_FOR_BEGIN
|
||
vvFr.push_back(vector<double>());
|
||
vvFr.back().resize(childRowNum * childColNum);
|
||
double* pVal = (double*)mxGetData(pCell); //获取指针
|
||
TRANS_ROW_COL(vvFr.back(), pVal, childRowNum, childColNum); // 行列存储方式转换
|
||
OUTER_FOR_END
|
||
delete[]strBuf;
|
||
}
|
||
|
||
/* 读取abs */
|
||
void GetAbstract(const mxArray* pMxAbs, vector<string>& vAbs) {
|
||
int rowNum = (int)mxGetM(pMxAbs);
|
||
int colNum = (int)mxGetN(pMxAbs);
|
||
char *strBuf = new char[STRING_BUF_SIZE];
|
||
|
||
vAbs.resize(rowNum * colNum);
|
||
for (int i = 0; i < rowNum; ++i) {
|
||
for (int j = 0; j < colNum; ++j) {
|
||
mxArray* pCell = mxGetCell(pMxAbs, j * rowNum + i);
|
||
if (mxGetString(pCell, strBuf, STRING_BUF_SIZE) != 0) {
|
||
cout << "String is too large to fit in the buffer! " << i + 1 << '\t' << j + 1 << endl;
|
||
delete[]strBuf;
|
||
return;
|
||
}
|
||
vAbs[i * colNum + j] = strBuf;
|
||
}
|
||
}
|
||
delete[]strBuf;
|
||
}
|
||
|
||
/*
|
||
输入:
|
||
1. abs: 待感知的文献的摘要信息。
|
||
2. G: 知识颗粒,包含该程序需要的热词ds以及对应的频率fr。
|
||
输出:
|
||
1. hs: 信息熵,二维[len(知识颗粒)][len(文献)]
|
||
*/
|
||
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
|
||
//cout << "MexCalcEntropy" << endl;
|
||
//cout << nlhs << '\t' << nrhs << endl;
|
||
if (nrhs != 2) {
|
||
cout << "2 arguments should be given for this function!" << endl;
|
||
return;
|
||
}
|
||
clock_t begin, finish;
|
||
begin = clock();
|
||
vector<vector<string> > vvDs; // 每个知识颗粒的ds矩阵(词汇矩阵)
|
||
vector<vector<double> > vvFr; // 词汇对应的频率
|
||
GetFrDs(prhs[1], vvDs, vvFr);
|
||
|
||
vector<string> vAbstract; // 读取abs1, 然后分割成一个一个的单词
|
||
GetAbstract(prhs[0], vAbstract);
|
||
|
||
/* 将摘要信息分割成一个一个的词汇 */
|
||
// begin = clock();
|
||
unordered_set<char> usWordChars; // 能组成单词的字符,要不要考虑数字?原版matlab是提取了数字的
|
||
for (int i = 65; i <= 90; i++) usWordChars.insert(char(i)); // A - Z
|
||
for (int i = 97; i <= 122; i++) usWordChars.insert(char(i)); // a - z
|
||
for (int i = 48; i <= 57; i++) usWordChars.insert(char(i)); // 0 - 9
|
||
usWordChars.insert('/'); usWordChars.insert('+'); usWordChars.insert('-');
|
||
vector<vector<string> > vvWordMtx(vAbstract.size()); // 初始大小为文章的个数
|
||
vector<unordered_set<string> > vusAbsWord(vAbstract.size()); // 将每篇文章摘要的单词放入hash表
|
||
for (int i = 0; i < vAbstract.size(); i++) {
|
||
auto& strAbs = vAbstract[i];
|
||
// 遍历摘要字符串的每一个字符,取出每一个单词
|
||
vector<string>& vWord = vvWordMtx[i];
|
||
if (strAbs.size() == 0) continue; // 摘要信息为空,跳过(一般不会出现这个情况)
|
||
int wordStartPos = 0;
|
||
while (wordStartPos < strAbs.size() && usWordChars.find(strAbs[wordStartPos]) == usWordChars.end())
|
||
wordStartPos++;
|
||
for (int curPos = wordStartPos + 1; curPos < strAbs.size(); ++curPos) {
|
||
if (usWordChars.find(strAbs[curPos]) == usWordChars.end()) { // 找到了分割符
|
||
vWord.push_back(strAbs.substr(wordStartPos, curPos - wordStartPos));
|
||
wordStartPos = curPos + 1; // 找下一个词语起始位置
|
||
while (wordStartPos < strAbs.size() && usWordChars.find(strAbs[wordStartPos]) == usWordChars.end())
|
||
wordStartPos++;
|
||
curPos = wordStartPos; // 循环会自动加1
|
||
}
|
||
}
|
||
// 将处理摘要之后的每个词语放入hash表
|
||
for (auto& word : vWord) {
|
||
string upWord(word);
|
||
transform(upWord.begin(), upWord.end(), upWord.begin(), ::toupper);
|
||
vusAbsWord[i].insert(upWord);
|
||
}
|
||
}
|
||
// finish = clock();
|
||
// cout << "Split abstract time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
|
||
|
||
// 存放结果,用一维数组存放二维数据
|
||
vector<double> hs;
|
||
vector<double> hr;
|
||
const int numLiterature = vusAbsWord.size(); // pubmed 文件中包含的文献数量
|
||
const int numGroup = vvDs.size(); // ds包含的组数
|
||
hs.resize(numGroup * numLiterature);
|
||
hr.resize(numLiterature * numGroup);
|
||
|
||
for (int groupIdx = 0; groupIdx < numGroup; ++groupIdx) { // 遍历知识颗粒中的每一组
|
||
vector<string>& vDs = vvDs[groupIdx]; // 这一组ds
|
||
vector<double>& vFr = vvFr[groupIdx]; // frequency
|
||
const int numWord = vDs.size(); // 这一组数据中包含的单词数量
|
||
vector<vector<int> > vX(numLiterature, vector<int>(numWord, 0));
|
||
// 检查知识颗粒中的词语是否出现在pubmed摘要的词语中
|
||
for (int i = 0; i < numLiterature; ++i) {
|
||
for (int j = 0; j < numWord; ++j) {
|
||
if (vusAbsWord[i].find(vDs[j]) != vusAbsWord[i].end()) { // 这一组单词中的j索引位置的单词在第i个文献中出现过
|
||
vX[i][j] = 1;
|
||
}
|
||
}
|
||
}
|
||
|
||
// 找词汇的最高频率
|
||
double maxFr = *max_element(vFr.begin(), vFr.end());
|
||
// 将fr的数值规范化到(0,0.368)之间
|
||
const double normalMax = 0.368;
|
||
for (auto& frVal : vFr) frVal = frVal * normalMax / maxFr;
|
||
maxFr = normalMax;
|
||
// 对每个知识颗粒每一组数据,计算信息熵
|
||
for (int i = 0; i < numLiterature; ++i) {
|
||
for (int j = 0; j < numWord; ++j) {
|
||
if (vX[i][j] == 1) {
|
||
hs[Get1DIndex(numLiterature, groupIdx, i)] -= vFr[j] * log2(vFr[j]);
|
||
}
|
||
}
|
||
}
|
||
|
||
// 找最高频词汇所在的索引位置
|
||
vector<int> vMaxPos;
|
||
int idx = 0;
|
||
for_each(vFr.begin(), vFr.end(), [&idx, maxFr, &vMaxPos](double val) {
|
||
if (val == maxFr) vMaxPos.push_back(idx);
|
||
idx++;
|
||
});
|
||
|
||
for (int i = 0; i < numLiterature; ++i) {
|
||
int cumulateX = 0; // 计算在最高频词汇处,x值的累加结果
|
||
for (int j = 0; j < vMaxPos.size(); ++j) cumulateX += vX[i][vMaxPos[j]];
|
||
if (cumulateX == vMaxPos.size()) { // 如果频率最高的词汇都出现在了文献中
|
||
hr[Get1DIndex(numGroup, i, groupIdx)] = 1; // 应该是表示知识颗粒的这一组数据跟这篇文献相关性比较高
|
||
}
|
||
}
|
||
}
|
||
|
||
/* 将结果写入返回值 */
|
||
if (nlhs > 0) {
|
||
int datasize = numGroup * numLiterature;
|
||
double* mtxData = new double[datasize];//待存储数据转为double格式
|
||
mxArray* pWriteArray = NULL;//matlab格式矩阵
|
||
//创建一个rowNum*colNum的矩阵
|
||
pWriteArray = mxCreateDoubleMatrix(numGroup, numLiterature, mxREAL);
|
||
for (int i = 0; i < numGroup; i++) {
|
||
for (int j = 0; j < numLiterature; j++) {
|
||
mtxData[j * numGroup + i] = hs[i * numLiterature + j];
|
||
}
|
||
}
|
||
//把data的值赋给pWriteArray指针
|
||
memcpy((void*)(mxGetPr(pWriteArray)), (void*)mtxData, sizeof(double) * datasize);
|
||
plhs[0] = pWriteArray; // 赋值给返回值
|
||
delete[]mtxData;
|
||
}
|
||
finish = clock();
|
||
// cout << "CalcEntropy Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
|
||
}
|
||
|
||
/* 供main调试调用 */
|
||
void mexFunctionWrap(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
|
||
mexFunction(nlhs, plhs, nrhs, prhs);
|
||
} |