twirls/MexFunc/RandSim.cpp

381 lines
10 KiB
C++
Raw Permalink Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include <mex.h>
#include <mat.h>
#include <iostream>
#include <algorithm>
#include <string>
#include <unordered_set>
#include <ctime>
#include <vector>
#include <queue>
#include <memory>
#include <thread>
#include <mutex>
#include <condition_variable>
#include <future>
#include <functional>
#include <stdexcept>
#include <unordered_map>
#include <set>
#include <fstream>
#include <algorithm>
#include <random>
#include <cmath>
#include <cctype>
#include <chrono>
#include <cstddef>
#include <cstring>
#include <sstream>
#include <utility>
using std::cout;
using std::endl;
using namespace std;
// Capacity, in chars and including the terminating NUL, of the scratch buffers
// handed to mxGetString. Typed constant instead of a macro so it is scoped and
// type-checked.
constexpr std::size_t STRING_BUF_SIZE = 204800;
// Fixed-size pool of worker threads that run queued callables in FIFO order.
//
// enqueue() binds a callable to its arguments with std::bind (the arguments are
// copied/moved into the task) and returns a std::future for the result. An
// exception thrown by the callable is stored in that future rather than
// escaping the worker. The destructor lets the workers drain every task that is
// already queued, then joins them.
class ThreadPool {
private:
    // Result type of the task enqueue(f, args...) really runs: invoking the
    // std::bind object built from f and args. Replaces std::result_of, which is
    // deprecated in C++17 and removed in C++20.
    template<class F, class... Args>
    using BoundResult = decltype(std::bind(std::declval<F>(), std::declval<Args>()...)());

public:
    // Starts `threads` worker threads. With 0 threads, queued tasks never run.
    // If starting a thread throws, the workers already running are stopped and
    // joined before the exception propagates; otherwise destroying a joinable
    // std::thread would call std::terminate.
    explicit ThreadPool(std::size_t threads)
        : stop(false)
    {
        try {
            for (std::size_t i = 0; i < threads; ++i)
                workers.emplace_back([this] { RunWorker(); });
        } catch (...) {
            StopAndJoin();
            throw;
        }
    }

    // Owns threads and a mutex: not copyable.
    ThreadPool(const ThreadPool&) = delete;
    ThreadPool& operator=(const ThreadPool&) = delete;

    // Queues f(args...) and returns a future for its result.
    // Throws std::runtime_error if the pool is already stopping.
    template<class F, class... Args>
    auto enqueue(F&& f, Args&&... args) -> std::future<BoundResult<F, Args...>>
    {
        using return_type = BoundResult<F, Args...>;
        auto task = std::make_shared<std::packaged_task<return_type()>>(
            std::bind(std::forward<F>(f), std::forward<Args>(args)...));
        std::future<return_type> res = task->get_future();
        {
            std::unique_lock<std::mutex> lock(queue_mutex);
            // Don't allow enqueueing after stopping the pool.
            if (stop)
                throw std::runtime_error("enqueue on stopped ThreadPool");
            tasks.emplace([task]() { (*task)(); });
        }
        condition.notify_one();
        return res;
    }

    // Runs every task still queued, then joins all workers.
    ~ThreadPool() { StopAndJoin(); }

private:
    // Worker body: wait for a task (or stop), run it outside the lock, repeat.
    // Returns only once stop is set AND the queue is empty, so pending tasks
    // are always drained before shutdown.
    void RunWorker()
    {
        for (;;) {
            std::function<void()> task;
            {
                std::unique_lock<std::mutex> lock(queue_mutex);
                condition.wait(lock, [this] { return stop || !tasks.empty(); });
                if (stop && tasks.empty())
                    return;
                task = std::move(tasks.front());
                tasks.pop();
            }
            task();
        }
    }

    // Sets the stop flag, wakes every worker and joins those still joinable.
    void StopAndJoin()
    {
        {
            std::unique_lock<std::mutex> lock(queue_mutex);
            stop = true;
        }
        condition.notify_all();
        for (std::thread& worker : workers)
            if (worker.joinable())
                worker.join();
    }

    std::vector<std::thread> workers;          // owned worker threads
    std::queue<std::function<void()>> tasks;   // pending tasks, FIFO
    std::mutex queue_mutex;                    // guards tasks and stop
    std::condition_variable condition;         // signalled on new task / stop
    bool stop;                                 // set once, by StopAndJoin
};
// Reads a cell array of strings and converts every string to upper case.
//
// pMxArray : cell array whose elements are char arrays.
// vStr     : resized to rows*cols; cell (i, j) is stored at vStr[i * cols + j]
//            (row-major, whereas MATLAB stores cells column-major).
// Returns false, after printing the 1-based position of the offending cell, when
// mxGetString cannot copy a cell (not a char array, or longer than the buffer).
// Unpopulated cells (mxGetCell returns NULL) are read as empty strings.
inline bool Read1DWord(const mxArray* pMxArray, std::vector<std::string>& vStr) {
    const int rowNum = static_cast<int>(mxGetM(pMxArray));
    const int colNum = static_cast<int>(mxGetN(pMxArray));
    // RAII buffer: released on every return path, including the error return.
    std::vector<char> strBuf(STRING_BUF_SIZE);
    vStr.resize(static_cast<std::size_t>(rowNum) * colNum);
    for (int i = 0; i < rowNum; ++i) {
        for (int j = 0; j < colNum; ++j) {
            const mxArray* pCell = mxGetCell(pMxArray, j * rowNum + i);  // column-major index
            std::string& dst = vStr[static_cast<std::size_t>(i) * colNum + j];
            if (pCell == nullptr) {
                // Unpopulated cell: passing NULL to mxGetString would crash.
                dst.clear();
                continue;
            }
            if (mxGetString(pCell, strBuf.data(), STRING_BUF_SIZE) != 0) {
                std::cout << "String is too large to fit in the buffer! " << i + 1 << '\t' << j + 1 << std::endl;
                return false;
            }
            dst = strBuf.data();
            // Go through unsigned char: toupper on a negative char is undefined.
            std::transform(dst.begin(), dst.end(), dst.begin(),
                           [](unsigned char c) { return static_cast<char>(std::toupper(c)); });
        }
    }
    return true;
}
// Reads a cell array whose elements are themselves cell arrays of strings
// (one word list per element), upper-casing every word via Read1DWord.
//
// Outer cells are visited row by row; each appends exactly one entry to vvStr
// (an unpopulated outer cell gives an empty word list).
// Returns false if any inner cell could not be read. Reading still continues,
// so vvStr always receives one entry per outer cell. Returns true otherwise.
inline bool Read2DWord(const mxArray* pMxArray, std::vector<std::vector<std::string>>& vvStr) {
    const int rowNum = static_cast<int>(mxGetM(pMxArray));
    const int colNum = static_cast<int>(mxGetN(pMxArray));
    vvStr.reserve(vvStr.size() + static_cast<std::size_t>(rowNum) * colNum);
    bool allRead = true;
    for (int i = 0; i < rowNum; ++i) {
        for (int j = 0; j < colNum; ++j) {
            const mxArray* pCell = mxGetCell(pMxArray, j * rowNum + i);  // column-major index
            vvStr.emplace_back();
            if (pCell == nullptr)
                continue;  // unpopulated cell: keep the empty word list
            if (!Read1DWord(pCell, vvStr.back()))
                allRead = false;
        }
    }
    return allRead;
}
// Reads words from a text file and converts them to upper case.
//
// Each line of the file becomes one entry appended to vvStr; an empty line
// yields an empty entry. Words are separated by whitespace, so tabs and the
// '\r' of CRLF line endings never end up inside a word.
// If the file cannot be opened, an error is printed and vvStr is left unchanged.
inline void ReadWordFromFile(const std::string& filePath, std::vector<std::vector<std::string>>& vvStr) {
    std::ifstream input(filePath);
    if (!input) {
        std::cout << "FilePath error: " << filePath << std::endl;
        return;
    }
    std::string line;
    while (std::getline(input, line)) {
        vvStr.emplace_back();
        std::vector<std::string>& words = vvStr.back();
        std::istringstream tokens(line);
        std::string word;
        while (tokens >> word) {
            // Go through unsigned char: toupper on a negative char is undefined.
            std::transform(word.begin(), word.end(), word.begin(),
                           [](unsigned char c) { return static_cast<char>(std::toupper(c)); });
            words.push_back(std::move(word));
        }
    }
}
// Parameters for one random-simulation task (see ThreadRandSim).
// Every pointer is non-owning: the pointed-to objects belong to the caller and
// must stay alive until the task has finished.
struct TPRandSim {
    std::vector<int>* pvZr;                              // output counts, one slot per dictionary word
    std::vector<int>* pvRandPos;                         // pool of document indices into *pvvWd2
    std::unordered_map<std::string, int>* pumDicWordPos; // dictionary word -> index into *pvZr
    std::vector<std::vector<std::string>>* pvvWd2;       // word list of every background document
    int wdSize;                                          // number of documents drawn per simulation
};

// Entry point of one random-simulation task.
//
// Draws min(wdSize, pool size) distinct entries of *param.pvRandPos uniformly at
// random (without replacement). For each drawn document, it adds 1 to
// (*param.pvZr)[k] for every dictionary word k the document contains; a word
// repeated inside one document counts once. wdSize <= 0 draws nothing.
//
// *param.pvRandPos is only read: the draw works on a task-local copy, so
// concurrent tasks may share one index pool. *param.pvZr is written and must
// not be shared between concurrently running tasks. Pool entries are assumed
// to be valid indices into *param.pvvWd2.
void ThreadRandSim(TPRandSim& param) {
    std::vector<int>& vZr = *param.pvZr;
    const std::unordered_map<std::string, int>& umDicWordPos = *param.pumDicWordPos;
    const std::vector<std::vector<std::string>>& vvWd2 = *param.pvvWd2;

    // Shuffling the shared pool in place would race whenever two running tasks
    // are handed the same pool, so permute a private copy instead.
    std::vector<int> vPool(*param.pvRandPos);
    const std::size_t poolSize = vPool.size();
    const std::size_t drawCount = param.wdSize > 0
        ? std::min(static_cast<std::size_t>(param.wdSize), poolSize)
        : 0;

    std::random_device rd;
    std::mt19937 engine(rd());

    std::unordered_set<int> usPos;  // dictionary indices hit by the current document
    for (std::size_t i = 0; i < drawCount; ++i) {
        // Partial Fisher-Yates: only the first drawCount slots are randomised,
        // which is all the simulation consumes.
        std::uniform_int_distribution<std::size_t> pick(i, poolSize - 1);
        std::swap(vPool[i], vPool[pick(engine)]);

        usPos.clear();
        for (const std::string& word : vvWd2[vPool[i]]) {
            const auto itr = umDicWordPos.find(word);
            if (itr != umDicWordPos.end())
                usPos.insert(itr->second);
        }
        for (const int idx : usPos)
            vZr[idx] += 1;
    }
}
/* <20><><EFBFBD>ں<EFBFBD><DABA><EFBFBD> */
/*
<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>һ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ֵ
<EFBFBD><EFBFBD><EFBFBD>
1. wd <20><><EFBFBD><EFBFBD>ժҪ<D5AA>еĵ<D0B5><C4B5>ʣ<EFBFBD><CAA3><EFBFBD>άcell
2. wd2 <20><><EFBFBD><EFBFBD>5w<35><77>ժҪ<D5AA><D2AA>ÿ<EFBFBD><C3BF>ժҪ<D5AA>а<EFBFBD><D0B0><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ʣ<EFBFBD><CAA3><EFBFBD>άcell
3. dicr <20>ɴ<EFBFBD>д<EFBFBD><D0B4>ĸ<EFBFBD><C4B8><EFBFBD>ɵ<EFBFBD><C9B5>ֵ䣬<D6B5><E4A3AC><EFBFBD><EFBFBD>ĸ<EFBFBD><C4B8><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ģ<EFBFBD>һάcell
4. numThread
5. numLoop
<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
vr Ӧ<><D3A6><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ֵ
*/
//void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) {
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
if (nrhs < 1) {
cout << "At least 3 arguments should be given for this function!" << endl;
return;
}
clock_t begin = clock(), mid, finish;
vector<string> vDicr;
vector<vector<string>> vvWd;
vector<vector<string>> vvWd2; // 5w word
Read2DWord(prhs[0], vvWd);
// Read2DWord(prhs[1], vvWd2);
char* strBuf = new char[STRING_BUF_SIZE];
mxGetString(prhs[1], strBuf, STRING_BUF_SIZE);
string wd2FilePath(strBuf);
delete[]strBuf;
ReadWordFromFile(wd2FilePath, vvWd2);
Read1DWord(prhs[2], vDicr);
// char* strBuf = new char[STRING_BUF_SIZE];
// mxGetString(prhs[1], strBuf, STRING_BUF_SIZE);
// string wd2FilePath(strBuf);
// delete[]strBuf;
// cout << wd2FilePath << endl;
// vector<vector<string>> vvWd3;
// ReadWordFromFile("D:\\Twirls\\gat1\\literatures\\temp\\wd2s.txt", vvWd3);
//ofstream ofs("d:\\diff.txt");
//ofs << vvWd2.size() << '\t' << vvWd3.size() << endl;
//for (int i = 0; i < vvWd2.size(); ++i) {
// if (vvWd2[i].size() != vvWd3[i].size())
// ofs << vvWd2[i].size() << '\t' << vvWd3[i].size() << endl;
// //for (int j = 0; j < vvWd2[i].size(); ++j) {
// // if (vvWd2[i][j] != vvWd3[i][j]) {
// // ofs << i+1 << '\t' << j+1 << '\t' << vvWd2[i][j] << '\t' << vvWd3[i][j] << endl;
// // }
// //}
//}
//ofs.close();
int numThread = 1;
int loopNum = 1000;
if (nrhs > 3) {
double* pNumThread = (double*)mxGetData(prhs[3]);
numThread = (int)pNumThread[0];
if (numThread < 1) numThread = 1;
}
if (nrhs > 4) {
double* pLoopNum = (double*)mxGetData(prhs[4]);
loopNum = (int)pLoopNum[0];
if (loopNum < 1000) loopNum = 1000;
}
/* ͳ<><CDB3>dicr<63>ֵ<EFBFBD><D6B5>У<EFBFBD>ÿ<EFBFBD><C3BF><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>wd<77>г<EFBFBD><D0B3>ֵĴ<D6B5><C4B4><EFBFBD> */
mid = clock();
unordered_map<string, int> umWordPos;
for (int i = 0; i < vDicr.size(); ++i) umWordPos[vDicr[i]] = i; // <20><>¼<EFBFBD><C2BC><EFBFBD><EFBFBD>λ<EFBFBD><CEBB>
vector<int> vZs(vDicr.size());
unordered_set<int> usPos; // <20><><EFBFBD>γ<EFBFBD><CEB3><EFBFBD><EFBFBD><EFBFBD>wd<77>еĵ<D0B5><C4B5>ʣ<EFBFBD>ֻͳ<D6BB><CDB3>һ<EFBFBD>Σ<EFBFBD><CEA3><EFBFBD><EFBFBD><EFBFBD>ԭmatlab<61><62><EFBFBD><EFBFBD><EFBFBD>Ĺ<EFBFBD><C4B9>ܣ<EFBFBD><DCA3>Ƿ<EFBFBD><C7B7><EFBFBD>Ҫ<EFBFBD>޸ģ<DEB8>
for (auto & vWd : vvWd) {
usPos.clear();
for (auto & word : vWd) {
auto itr = umWordPos.find(word);
if (itr != umWordPos.end()) {
usPos.insert(itr->second);
}
}
for (auto idx : usPos) {
vZs[idx] += 1;
}
}
finish = clock();
cout << "Calc word occurrence time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
/* <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ģ<EFBFBD><C4A3> */
mid = clock();
vector<vector<int>> vvZr(loopNum, vector<int>(vDicr.size(), 0)); // <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
vector<vector<int>> vvRandPos(numThread, vector<int>(vvWd2.size()));
for (int i = 0; i < vvWd2.size(); ++i) {
for (auto& vRandPos : vvRandPos) {
vRandPos[i] = i;
}
}
ThreadPool thPool(numThread);
int tid = 0;
for (int i = 0; i < loopNum; ++i) {
TPRandSim tParam = { &vvZr[i], &vvRandPos[tid++ % numThread], &umWordPos, &vvWd2, vvWd.size()};
thPool.enqueue(ThreadRandSim, tParam);
//ThreadRandSim(tParam);
}
thPool.~ThreadPool();
finish = clock();
cout << "Random simulation time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
/* <20><><EFBFBD><EFBFBD>vr */
vector<double> vVr(vDicr.size());
// <20><><EFBFBD>м<EFBFBD><D0BC><EFBFBD>ƽ<EFBFBD><C6BD>ֵ
vector<double> vMean(vDicr.size());
vector<double> vStd(vDicr.size());
for (int i = 0; i < vvZr.size(); ++i) {
for (int j = 0; j < vvZr[i].size(); ++j) {
vMean[j] += vvZr[i][j];
}
}
for (auto& val : vMean) { val /= loopNum; } // <20><>ֵ
for (int i = 0; i < vvZr.size(); ++i) {
for (int j = 0; j < vvZr[i].size(); ++j) {
const double diff = vvZr[i][j] - vMean[j];
vStd[j] += diff * diff;
}
}
for (auto& val : vStd) { val = sqrt(val / (loopNum - 1)); } // <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
// <20><><EFBFBD><EFBFBD>vr
for (int i = 0; i < vVr.size(); ++i) {
vVr[i] = (vZs[i] - vMean[i]) / vStd[i];
}
// ofstream ofs("d:\\result.txt");
// int i = 0;
// for (auto& vr : vVr) {
// ofs << vr << endl;
// }
// ofs.close();
/* д<><D0B4><EFBFBD><EFBFBD><EFBFBD><EFBFBD> */
if (nlhs > 0) {
mxArray* pWriteArray = NULL;//matlab<61><62>ʽ<EFBFBD><CABD><EFBFBD><EFBFBD>
//<2F><><EFBFBD><EFBFBD>һ<EFBFBD><D2BB>rowNum*colNum<75>ľ<EFBFBD><C4BE><EFBFBD>
pWriteArray = mxCreateDoubleMatrix(1, vVr.size(), mxREAL);
//<2F><>data<74><61>ֵ<EFBFBD><D6B5><EFBFBD><EFBFBD>pWriteArrayָ<79><D6B8>
memcpy((void*)(mxGetPr(pWriteArray)), (void*)vVr.data(), sizeof(double) * vVr.size());
plhs[0] = pWriteArray; // <20><>ֵ<EFBFBD><D6B5><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ֵ
}
finish = clock();
cout << "Random simulation Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
}