twirls/MexFunc/AllEntropyMean.cpp

312 lines
8.9 KiB
C++
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include <mex.h>
#include <mat.h>
#include <iostream>
#include <algorithm>
#include <string>
#include <unordered_set>
#include <ctime>
#include <vector>
#include <queue>
#include <memory>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <future>
#include <functional>
#include <stdexcept>
#include <unordered_map>
#include <set>
#include <fstream>
#include <random>
#include <cmath>
#include <stdlib.h>
#include <limits.h>
#include <atomic>
using std::cout;
using std::endl;
using namespace std;
// Size in bytes of the scratch buffer used by ReadWord1DCell to copy one MATLAB string.
#define STRING_BUF_SIZE 204800
// Convert MATLAB (column-major) storage into C (row-major) storage:
//   dst[r * colNum + c] = src[c * rowNum + r]  for r < rowNum, c < colNum.
// Every argument is parenthesized so expressions such as `n + 1` expand
// correctly, and the body is wrapped in do { } while (0) so the macro acts as
// a single statement (safe inside an unbraced if/else). Arguments are
// evaluated repeatedly, so do not pass expressions with side effects.
#define TRANS_ROW_COL(dst, src, rowNum, colNum) \
do { \
    for (int rowI = 0; rowI < (rowNum); ++rowI) { \
        for (int colJ = 0; colJ < (colNum); ++colJ) { \
            (dst)[rowI * (colNum) + colJ] = (src)[colJ * (rowNum) + rowI]; \
        } \
    } \
} while (0)
// A fixed-size pool of worker threads that run queued tasks in FIFO order.
// The destructor lets the workers drain the queue and then joins them, so it
// must run exactly once (never call ~ThreadPool() explicitly on a local).
class ThreadPool {
public:
    // Start `threads` workers. With 0 workers, queued tasks never run.
    explicit ThreadPool(size_t);
    // Not copyable: the pool owns its threads and its mutex.
    ThreadPool(const ThreadPool&) = delete;
    ThreadPool& operator=(const ThreadPool&) = delete;
    // Queue f(args...) and return a future for its result. An exception thrown
    // by the task is stored in that future (std::packaged_task semantics).
    // Throws std::runtime_error if the pool is already stopping.
    template<class F, class... Args>
    auto enqueue(F&& f, Args&&... args)
        ->std::future<typename std::result_of<F(Args...)>::type>;
    // Run every task still queued, then join the workers.
    ~ThreadPool();
private:
    // need to keep track of threads so we can join them
    std::vector< std::thread > workers;
    // the task queue
    std::queue< std::function<void()> > tasks;
    // synchronization
    std::mutex queue_mutex;
    std::condition_variable condition;
    bool stop;
};
// Launch the worker threads. If starting one of them fails, the workers that
// are already running are stopped and joined before the exception propagates,
// because destroying a still-joinable std::thread calls std::terminate.
inline ThreadPool::ThreadPool(size_t threads)
    : stop(false)
{
    try {
        workers.reserve(threads);
        for (size_t i = 0; i < threads; ++i)
            workers.emplace_back(
                [this]
                {
                    for (;;)
                    {
                        std::function<void()> task;
                        {
                            std::unique_lock<std::mutex> lock(this->queue_mutex);
                            // Sleep until there is work or the pool is stopping.
                            this->condition.wait(lock,
                                [this] { return this->stop || !this->tasks.empty(); });
                            // Exit only after the queue has been drained.
                            if (this->stop && this->tasks.empty())
                                return;
                            task = std::move(this->tasks.front());
                            this->tasks.pop();
                        }
                        task();  // run outside the lock so other workers can dequeue
                    }
                }
            );
    }
    catch (...) {
        {
            std::unique_lock<std::mutex> lock(queue_mutex);
            stop = true;
        }
        condition.notify_all();
        for (std::thread& worker : workers)
            worker.join();
        throw;
    }
}
// add new work item to the pool
template<class F, class... Args>
auto ThreadPool::enqueue(F && f, Args&&... args)
    -> std::future<typename std::result_of<F(Args...)>::type>
{
    using return_type = typename std::result_of<F(Args...)>::type;
    // shared_ptr because std::function requires a copyable callable, while
    // packaged_task is move-only.
    auto task = std::make_shared< std::packaged_task<return_type()> >(
        std::bind(std::forward<F>(f), std::forward<Args>(args)...)
    );
    std::future<return_type> res = task->get_future();
    {
        std::unique_lock<std::mutex> lock(queue_mutex);
        // don't allow enqueueing after stopping the pool
        if (stop)
            throw std::runtime_error("enqueue on stopped ThreadPool");
        tasks.emplace([task]() { (*task)(); });
    }
    condition.notify_one();
    return res;
}
// Signal the workers to stop; each one keeps running tasks until the queue is
// empty and then exits, after which it is joined.
inline ThreadPool::~ThreadPool()
{
    {
        std::unique_lock<std::mutex> lock(queue_mutex);
        stop = true;
    }
    condition.notify_all();
    for (std::thread& worker : workers)
        if (worker.joinable())
            worker.join();
}
// 读取一维cell字符串并转换成大写
inline bool ReadWord1DCell(const mxArray* pMxArray, vector<string>& vStr) {
mxArray* pCell = nullptr;
int rowNum, colNum;
char* strBuf = new char[STRING_BUF_SIZE];
rowNum = (int)mxGetM(pMxArray);
colNum = (int)mxGetN(pMxArray);
vStr.resize(rowNum * colNum);
for (int i = 0; i < rowNum; ++i) {
for (int j = 0; j < colNum; ++j) {
pCell = mxGetCell(pMxArray, j * rowNum + i);
if (mxGetString(pCell, strBuf, STRING_BUF_SIZE) != 0) {
cout << "String is too large to fit in the buffer! " << i + 1 << '\t' << j + 1 << endl;
return false;
}
vStr[i * colNum + j] = strBuf;
auto& lastStr = vStr[i * colNum + j];
transform(lastStr.cbegin(), lastStr.cend(), lastStr.begin(), ::toupper); // 转成大写
}
}
delete[]strBuf;
return true;
}
// 读取二维cell字符串并转换成大写
inline bool ReadWord2DCell(const mxArray* pMxArray, vector<vector<string>>& vvStr) {
mxArray* pCell = nullptr;
int rowNum, colNum;
rowNum = (int)mxGetM(pMxArray);
colNum = (int)mxGetN(pMxArray);
for (int i = 0; i < rowNum; ++i) {
for (int j = 0; j < colNum; ++j) {
pCell = mxGetCell(pMxArray, j * rowNum + i);
vvStr.push_back(vector<string>());
ReadWord1DCell(pCell, vvStr.back());
}
}
return true;
}
// 读取由一维cell包裹的double数据每个cell是一个一维的double数组
inline void ReadDoulbe1DCell(const mxArray* pMxArray, vector<vector<double> >& vvData) {
// 读取fr数值
int rowNum = (int)mxGetM(pMxArray);
int colNum = (int)mxGetN(pMxArray);
for (int i = 0; i < rowNum; ++i) {
for (int j = 0; j < colNum; ++j) {
mxArray* pCell = mxGetCell(pMxArray, j * rowNum + i);
int childRowNum = (int)mxGetM(pCell);
int childColNum = (int)mxGetN(pCell);
vvData.push_back(vector<double>());
vvData.back().resize(childRowNum * childColNum);
double* pVal = (double*)mxGetData(pCell); //获取数据指针
TRANS_ROW_COL(vvData.back(), pVal, childRowNum, childColNum); // 行列存储方式转换
}
}
}
// Create a rowNum x colNum real double mxArray holding a copy of data and
// return it (mexFunction hands these arrays back to MATLAB through plhs).
// data must already be in MATLAB column-major order and hold rowNum * colNum
// values; it is not read when the matrix is empty, so it may then be null.
mxArray* writeToMatDouble(const double* data, int rowNum, int colNum) {
    mxArray* pWriteArray = mxCreateDoubleMatrix(rowNum, colNum, mxREAL);
    // Skip the copy for empty matrices: data (e.g. an empty vector's data())
    // and mxGetPr may be null, and memcpy/copy on null pointers is undefined.
    if (rowNum > 0 && colNum > 0) {
        // size_t arithmetic: rowNum * colNum can exceed INT_MAX for large outputs.
        const size_t len = (size_t)rowNum * (size_t)colNum;
        std::copy(data, data + len, mxGetPr(pWriteArray));
    }
    return pWriteArray;
}
/* Multi-threaded entropy computation */
// Arguments for one ThreadCalcEntropyMean task. All pointers are non-owning
// and must stay valid until the task has finished.
struct TPEntropyMean {
    vector<string>* pvDs;                        // words of one group
    vector<double>* pvFr;                        // frequency of each word in *pvDs, same order
    vector<unordered_set<string>>* pvusAbsWord;  // word set of every abstract (only read here)
    vector<double>* pvHs;                        // out: one entropy value per abstract
    vector<double>* pvHd;                        // out: *pvHs divided by each abstract's distinct word count
};
// For every abstract i: subtract fr * log2(fr) from Hs[i] for each group word
// that occurs in abstract i, then set Hd[i] = Hs[i] / (distinct words in abstract i).
// Hs is accumulated onto its incoming value (the caller passes zeros).
// Hs and Hd must have at least as many entries as there are abstracts.
// - A zero frequency contributes 0 (the usual 0 * log 0 = 0 convention);
//   previously it produced 0 * -inf = NaN and poisoned the whole sum.
// - Only the first min(|ds|, |fr|) words are used, so a short frequency
//   vector no longer causes an out-of-bounds read.
// - An abstract with no words yields Hd = 0 / 0 = NaN, as before.
void ThreadCalcEntropyMean(TPEntropyMean& param) {
    const vector<string>& vDs = *param.pvDs;
    const vector<double>& vFr = *param.pvFr;
    const vector<unordered_set<string>>& vusAbsWord = *param.pvusAbsWord;
    vector<double>& vHs = *param.pvHs;
    vector<double>& vHd = *param.pvHd;
    const size_t numAbs = vusAbsWord.size();
    const size_t numDsWord = vDs.size() < vFr.size() ? vDs.size() : vFr.size();
    // fr * log2(fr) does not depend on the abstract: compute it once per word.
    vector<double> vTerm(numDsWord, 0.0);
    for (size_t j = 0; j < numDsWord; ++j) {
        if (vFr[j] != 0.0) {
            vTerm[j] = vFr[j] * std::log2(vFr[j]);
        }
    }
    // Check which group words occur among the words of each abstract.
    for (size_t i = 0; i < numAbs; ++i) {
        const unordered_set<string>& absWords = vusAbsWord[i];
        for (size_t j = 0; j < numDsWord; ++j) {
            if (absWords.find(vDs[j]) != absWords.end()) {
                vHs[i] -= vTerm[j];
            }
        }
        vHd[i] = vHs[i] / absWords.size();
    }
}
/*
输入:
1. ds: 知识颗粒中的信息,应该也是摘要,分割成了字符串,大写的。
2. frr: ds中每个单词对应的频率。
3. ws: 文献的摘要,被切割成特定的长度了,字符串已经分割好了。
[4]. numThread
输出:
1. hs: 信息熵,二维[len(ds)][len(ws)]
2. hd: 每个单词的平均信息熵维度同hs
*/
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
if (nrhs < 3) {
cout << "At least 3 arguments should be given for this function!" << endl;
return;
}
clock_t begin = clock(), mid, finish;
vector<vector<string> > vvDs; // 每个知识颗粒的ds矩阵词汇矩阵
vector<vector<double> > vvFr; // 词汇对应的频率
ReadWord2DCell(prhs[0], vvDs);
ReadDoulbe1DCell(prhs[1], vvFr);
vector<vector<string>> vvWs;
ReadWord2DCell(prhs[2], vvWs); // 文献摘要的字符串数组
int numThread = 1;
if (nrhs > 3) {
double* pNumThread = (double*)mxGetData(prhs[3]);
numThread = (int)pNumThread[0];
if (numThread < 1) numThread = 1;
}
vector<unordered_set<string>> vusAbsWord(vvWs.size()); // 将每篇文章摘要的单词放入hash表
// 将处理摘要之后的每个词语放入hash表
for (int i=0; i<vvWs.size(); ++i) {
for (auto& word : vvWs[i]) {
vusAbsWord[i].insert(word);
}
}
finish = clock();
cout << "Load Data time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl << flush;
int numGroup = vvDs.size();
int numAbs = vvWs.size(); // 摘要个数
// 计算结果
vector<vector<double>> vvHs(numGroup, vector<double>(numAbs));
vector<vector<double>> vvHd(numGroup, vector<double>(numAbs));
vector<string>* pvDs;
vector<double>* pvFr;
vector<unordered_set<string>>* pvusAbsWord;
vector<double>* pvHs;
vector<double>* pvHd;
/* 多线程计算信息熵 */
ThreadPool thPool(numThread);
for (int groupIdx = 0; groupIdx < numGroup; ++groupIdx) { // 遍历知识颗粒中的每一组
TPEntropyMean tp = { &vvDs[groupIdx], &vvFr[groupIdx], &vusAbsWord, &vvHs[groupIdx], &vvHd[groupIdx] };
thPool.enqueue(ThreadCalcEntropyMean, tp);
}
thPool.~ThreadPool();
// ofstream ofs("d:\\result_hs.txt");
// for (int i = 0; i < numGroup; ++i) {
// for (int j = 0; j < numAbs; ++j) {
// ofs << vvHs[i][j] << ' ';
// }
// ofs << endl;
// }
// ofs.close();
/* 将结果写入返回值 */
if (nlhs > 0) {
vector<double> vData(numGroup * numAbs);
for (int i = 0; i < numGroup; ++i) for (int j = 0; j < numAbs; ++j) vData[j * numGroup + i] = vvHs[i][j];
plhs[0] = writeToMatDouble(vData.data(), numGroup, numAbs);
}
if (nlhs > 1) {
vector<double> vData(numGroup * numAbs);
for (int i = 0; i < numGroup; ++i) for (int j = 0; j < numAbs; ++j) vData[j * numGroup + i] = vvHd[i][j];
plhs[1] = writeToMatDouble(vData.data(), numGroup, numAbs);
}
finish = clock();
cout << "Calc Entropy and Mean Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
}
/* Debug entry point for a standalone main(): forwards all arguments unchanged to mexFunction. */
void mexFunctionWrap(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
    return mexFunction(nlhs, plhs, nrhs, prhs);
}