/*
 * twirls/MexFunc/CalcEntropy.cpp
 *
 * NOTE: the following lines were copied from a web repository viewer along
 * with the source; they are not part of the program:
 *   237 lines
 *   8.6 KiB
 *   C++
 *   Raw Blame History
 *   This file contains ambiguous Unicode characters!
 *   This file contains ambiguous Unicode characters that may be confused with
 *   others in your current locale. If your use case is intentional and
 *   legitimate, you can safely ignore this warning. Use the Escape button to
 *   highlight these characters.
 */

#include <mex.h>
#include <mat.h>
#include <algorithm>
#include <cctype>
#include <cmath>
#include <ctime>
#include <iostream>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>
using std::cout;
using std::endl;
using namespace std;
// Capacity (in bytes) of the scratch buffer passed to mxGetString when a
// MATLAB char array is copied into a C string.
#define STRING_BUF_SIZE 204800
/* Helpers for reading a two-level cell array (a cell array whose elements are
   themselves arrays); used for the ds strings and fr values.

   OUTER_FOR_BEGIN / OUTER_FOR_END iterate over the cell array `pMxArray` in
   row-major order (i = row, j = column). MATLAB stores arrays column-major,
   so element (i, j) is at linear index j * rowNum + i. The enclosing scope
   must declare `pMxArray` and `int rowNum, colNum` (assigned by the macro).
   The loop body sees `pCell` (the current element) and its dimensions
   `childRowNum` / `childColNum`.

   NOTE(review): the pointer returned by mxGetCell is used without a NULL
   check. Comments cannot be placed inside the macro bodies because every
   line there ends with a backslash line continuation. */
#define OUTER_FOR_BEGIN \
rowNum = (int)mxGetM(pMxArray); \
colNum = (int)mxGetN(pMxArray); \
for (int i = 0; i < rowNum; ++i) { \
for (int j = 0; j < colNum; ++j) { \
mxArray* pCell = mxGetCell(pMxArray, j * rowNum + i); \
int childRowNum = (int)mxGetM(pCell); \
int childColNum = (int)mxGetN(pCell);
#define OUTER_FOR_END \
} \
}
/* INNTER_FOR_BEGIN / INNTER_FOR_END (sic, "INNER") are used inside the outer
   loop. They iterate over `pCell` in row-major order (ii = row, jj = column)
   and expose each element as `pChildCell`. */
#define INNTER_FOR_BEGIN \
for (int ii = 0; ii < childRowNum; ii++) { \
for (int jj = 0; jj < childColNum; jj++) { \
mxArray *pChildCell = mxGetCell(pCell, jj * childRowNum + ii);
#define INNTER_FOR_END \
} \
}
// Converts MATLAB storage order to C storage order: copies the column-major
// rowNum x colNum array `src` into the row-major array `dst`. The macro
// arguments are not parenthesized in the body, so pass simple expressions.
#define TRANS_ROW_COL(dst, src, rowNum, colNum) \
for (int rowI = 0; rowI < rowNum; ++rowI) { \
for (int colJ = 0; colJ < colNum; ++colJ) { \
dst[rowI * colNum + colJ] = src[colJ * rowNum + rowI]; \
} \
}
// Maps a 2-D (row, col) position to its offset in a row-major buffer in which
// every row holds `colNum` elements.
inline int Get1DIndex(int colNum, int row, int col) {
    const int rowStart = colNum * row;
    return rowStart + col;
}
// 读取G结构体中的ds和fr
void GetFrDs(const mxArray* pMxParent, vector<vector<string> >& vvDs, vector<vector<double> >& vvFr) {
// 读取ds字符串
int rowNum, colNum;
char *strBuf = new char[STRING_BUF_SIZE];
mxArray* pMxArray = mxGetField(pMxParent, 0, "ds"); // ds
OUTER_FOR_BEGIN
vvDs.push_back(vector<string>());
vvDs.back().resize(childRowNum * childColNum);
INNTER_FOR_BEGIN
if (mxGetString(pChildCell, strBuf, STRING_BUF_SIZE) != 0) {
cout << "String is too large to fit in the buffer! " << i + 1 << '\t' << j + 1 << endl;
delete[]strBuf;
return;
}
vvDs.back()[ii * childColNum + jj] = strBuf;
auto& lastStr = vvDs.back()[ii * childColNum + jj];
transform(lastStr.begin(), lastStr.end(), lastStr.begin(), ::toupper); // 转成大写
INNTER_FOR_END
OUTER_FOR_END
// 读取fr数值
pMxArray = mxGetField(pMxParent, 0, "fr"); // fr
OUTER_FOR_BEGIN
vvFr.push_back(vector<double>());
vvFr.back().resize(childRowNum * childColNum);
double* pVal = (double*)mxGetData(pCell); //获取指针
TRANS_ROW_COL(vvFr.back(), pVal, childRowNum, childColNum); // 行列存储方式转换
OUTER_FOR_END
delete[]strBuf;
}
/* Reads the abstracts from the cell array pMxAbs (one char array per cell).
   vAbs is resized to the number of cells and filled row-major:
   vAbs[i * colNum + j] holds cell (i, j).
   A cell element for which mxGetCell returns NULL yields an empty string.
   If a string cannot be copied (it does not fit in STRING_BUF_SIZE, or
   mxGetString otherwise fails), an error is printed and reading stops.
   Entries not yet read keep the values they had after the resize. */
void GetAbstract(const mxArray* pMxAbs, vector<string>& vAbs) {
    if (pMxAbs == nullptr) return;  // nothing to read
    const size_t rowNum = mxGetM(pMxAbs);
    const size_t colNum = mxGetN(pMxAbs);
    vector<char> strBuf(STRING_BUF_SIZE);  // released automatically on every return path
    vAbs.resize(rowNum * colNum);
    for (size_t i = 0; i < rowNum; ++i) {
        for (size_t j = 0; j < colNum; ++j) {
            string& dst = vAbs[i * colNum + j];
            // MATLAB stores cells column-major.
            const mxArray* pCell = mxGetCell(pMxAbs, j * rowNum + i);
            if (pCell == nullptr) {
                dst.clear();
                continue;
            }
            if (mxGetString(pCell, strBuf.data(), STRING_BUF_SIZE) != 0) {
                cout << "String is too large to fit in the buffer! " << i + 1 << '\t' << j + 1 << endl;
                return;
            }
            dst.assign(strBuf.data());
        }
    }
}
/*
nlhs输出参数数目(Number Left - hand side),等号左边
plhs指向输出参数的指针(Point Left - hand side),等号左边
nrhs输入参数数目(Number Right - hand side),等号右边
prhs指向输入参数的指针(Point Right - hand side)等号右边。要注意prhs是const的指针数组即不能改变其指向内容。
*/
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
//cout << "MexCalcEntropy" << endl;
//cout << nlhs << '\t' << nrhs << endl;
if (nrhs != 2) {
cout << "2 arguments should be given for this function!" << endl;
return;
}
clock_t begin, finish;
begin = clock();
vector<vector<string> > vvDs; // 每个知识颗粒的ds矩阵词汇矩阵
vector<vector<double> > vvFr; // 词汇对应的频率
GetFrDs(prhs[1], vvDs, vvFr);
vector<string> vAbstract; // 读取abs1, 然后分割成一个一个的单词
GetAbstract(prhs[0], vAbstract);
/* 将摘要信息分割成一个一个的词汇 */
// begin = clock();
unordered_set<char> usWordChars; // 能组成单词的字符要不要考虑数字原版matlab是提取了数字的
for (int i = 65; i <= 90; i++) usWordChars.insert(char(i)); // A - Z
for (int i = 97; i <= 122; i++) usWordChars.insert(char(i)); // a - z
for (int i = 48; i <= 57; i++) usWordChars.insert(char(i)); // 0 - 9
usWordChars.insert('/'); usWordChars.insert('+'); usWordChars.insert('-');
vector<vector<string> > vvWordMtx(vAbstract.size()); // 初始大小为文章的个数
vector<unordered_set<string> > vusAbsWord(vAbstract.size()); // 将每篇文章摘要的单词放入hash表
for (int i = 0; i < vAbstract.size(); i++) {
auto& strAbs = vAbstract[i];
// 遍历摘要字符串的每一个字符,取出每一个单词
vector<string>& vWord = vvWordMtx[i];
if (strAbs.size() == 0) continue; // 摘要信息为空,跳过(一般不会出现这个情况)
int wordStartPos = 0;
while (wordStartPos < strAbs.size() && usWordChars.find(strAbs[wordStartPos]) == usWordChars.end())
wordStartPos++;
for (int curPos = wordStartPos + 1; curPos < strAbs.size(); ++curPos) {
if (usWordChars.find(strAbs[curPos]) == usWordChars.end()) { // 找到了分割符
vWord.push_back(strAbs.substr(wordStartPos, curPos - wordStartPos));
wordStartPos = curPos + 1; // 找下一个词语起始位置
while (wordStartPos < strAbs.size() && usWordChars.find(strAbs[wordStartPos]) == usWordChars.end())
wordStartPos++;
curPos = wordStartPos; // 循环会自动加1
}
}
// 将处理摘要之后的每个词语放入hash表
for (auto& word : vWord) {
string upWord(word);
transform(upWord.begin(), upWord.end(), upWord.begin(), ::toupper);
vusAbsWord[i].insert(upWord);
}
}
// finish = clock();
// cout << "Split abstract time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
// 存放结果,用一维数组存放二维数据
vector<double> hs;
vector<double> hr;
const int numLiterature = vusAbsWord.size(); // pubmed 文件中包含的文献数量
const int numGroup = vvDs.size(); // ds包含的组数
hs.resize(numGroup * numLiterature);
hr.resize(numLiterature * numGroup);
for (int groupIdx = 0; groupIdx < numGroup; ++groupIdx) { // 遍历知识颗粒中的每一组
vector<string>& vDs = vvDs[groupIdx]; // 这一组ds
vector<double>& vFr = vvFr[groupIdx]; // frequency
const int numWord = vDs.size(); // 这一组数据中包含的单词数量
vector<vector<int> > vX(numLiterature, vector<int>(numWord, 0));
// 检查知识颗粒中的词语是否出现在pubmed摘要的词语中
for (int i = 0; i < numLiterature; ++i) {
for (int j = 0; j < numWord; ++j) {
if (vusAbsWord[i].find(vDs[j]) != vusAbsWord[i].end()) { // 这一组单词中的j索引位置的单词在第i个文献中出现过
vX[i][j] = 1;
}
}
}
// 找词汇的最高频率
double maxFr = *max_element(vFr.begin(), vFr.end());
// 将fr的数值规范化到00.368)之间
const double normalMax = 0.368;
for (auto& frVal : vFr) frVal = frVal * normalMax / maxFr;
maxFr = normalMax;
// 对每个知识颗粒每一组数据,计算信息熵
for (int i = 0; i < numLiterature; ++i) {
for (int j = 0; j < numWord; ++j) {
if (vX[i][j] == 1) {
hs[Get1DIndex(numLiterature, groupIdx, i)] -= vFr[j] * log2(vFr[j]);
}
}
}
// 找最高频词汇所在的索引位置
vector<int> vMaxPos;
int idx = 0;
for_each(vFr.begin(), vFr.end(), [&idx, maxFr, &vMaxPos](double val) {
if (val == maxFr) vMaxPos.push_back(idx);
idx++;
});
for (int i = 0; i < numLiterature; ++i) {
int cumulateX = 0; // 计算在最高频词汇处x值的累加结果
for (int j = 0; j < vMaxPos.size(); ++j) cumulateX += vX[i][vMaxPos[j]];
if (cumulateX == vMaxPos.size()) { // 如果频率最高的词汇都出现在了文献中
hr[Get1DIndex(numGroup, i, groupIdx)] = 1; // 应该是表示知识颗粒的这一组数据跟这篇文献相关性比较高
}
}
}
/* 将结果写入返回值 */
if (nlhs > 0) {
int datasize = numGroup * numLiterature;
double* mtxData = new double[datasize];//待存储数据转为double格式
mxArray* pWriteArray = NULL;//matlab格式矩阵
//创建一个rowNum*colNum的矩阵
pWriteArray = mxCreateDoubleMatrix(numGroup, numLiterature, mxREAL);
for (int i = 0; i < numGroup; i++) {
for (int j = 0; j < numLiterature; j++) {
mtxData[j * numGroup + i] = hs[i * numLiterature + j];
}
}
//把data的值赋给pWriteArray指针
memcpy((void*)(mxGetPr(pWriteArray)), (void*)mtxData, sizeof(double) * datasize);
plhs[0] = pWriteArray; // 赋值给返回值
delete[]mtxData;
}
finish = clock();
// cout << "CalcEntropy Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
}
/* Debug entry point, called from a standalone main(): forwards every
   argument unchanged to mexFunction. */
void MexCalcEntropy(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
    return mexFunction(nlhs, plhs, nrhs, prhs);
}