twirls/MexFunc/ClusterRandSim.cpp

323 lines
8.2 KiB
C++
Raw Blame History

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include <mex.h>
#include <mat.h>

#include <algorithm>
#include <cmath>
#include <condition_variable>
#include <ctime>
#include <fstream>
#include <functional>
#include <future>
#include <iostream>
#include <limits>
#include <memory>
#include <mutex>
#include <queue>
#include <random>
#include <set>
#include <stdexcept>
#include <string>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <vector>
using std::cout;
using std::endl;
using namespace std;
// Buffer size constant. NOTE(review): not referenced anywhere in the visible part of this file.
#define STRING_BUF_SIZE 204800
/**
 * Fixed-size pool of worker threads that run queued tasks in FIFO order.
 *
 * Tasks are submitted with enqueue(), which returns a std::future for the
 * task's result. The destructor lets the workers drain every task that is
 * still queued and then joins them, so leaving the pool's scope is the way to
 * wait for all outstanding work. Never call the destructor explicitly on an
 * automatic object: it would then run a second time at scope exit.
 *
 * A pool created with zero threads accepts tasks but never runs them.
 */
class ThreadPool {
public:
    /**
     * Launches `threads` workers.
     *
     * If starting a worker throws (for example std::system_error when the
     * system cannot create another thread), the workers that were already
     * started are stopped and joined before the exception is rethrown.
     * Without that, destroying the still-joinable std::thread objects would
     * call std::terminate.
     */
    explicit ThreadPool(std::size_t threads)
        : stop(false)
    {
        try {
            for (std::size_t i = 0; i < threads; ++i)
                workers.emplace_back([this] { run_worker(); });
        } catch (...) {
            shut_down();
            throw;
        }
    }

    // The pool owns threads and a mutex; copying it is not meaningful.
    ThreadPool(const ThreadPool&) = delete;
    ThreadPool& operator=(const ThreadPool&) = delete;

    /**
     * Adds a new work item to the pool.
     *
     * `f` and `args` are bound with std::bind (so the arguments are copied or
     * moved into the task) and invoked later on a worker thread. The returned
     * future yields the result of that call, or rethrows the exception the
     * call exited with.
     *
     * The result type is taken from the bound call itself rather than from
     * std::result_of, which is deprecated in C++17 and removed in C++20.
     *
     * @throws std::runtime_error if the pool is already stopping.
     */
    template<class F, class... Args>
    auto enqueue(F&& f, Args&&... args)
        -> std::future<decltype(std::bind(std::forward<F>(f), std::forward<Args>(args)...)())>
    {
        using return_type = decltype(std::bind(std::forward<F>(f), std::forward<Args>(args)...)());

        // packaged_task is move-only while std::function needs a copyable
        // callable, hence the shared_ptr indirection.
        auto task = std::make_shared< std::packaged_task<return_type()> >(
            std::bind(std::forward<F>(f), std::forward<Args>(args)...)
        );
        std::future<return_type> res = task->get_future();
        {
            std::unique_lock<std::mutex> lock(queue_mutex);
            // don't allow enqueueing after stopping the pool
            if (stop)
                throw std::runtime_error("enqueue on stopped ThreadPool");
            tasks.emplace([task]() { (*task)(); });
        }
        condition.notify_one();
        return res;
    }

    /** Runs all tasks that are still queued, then joins every worker. */
    ~ThreadPool()
    {
        shut_down();
    }

private:
    // Worker body: pop and run tasks until the pool is stopping and the queue is empty.
    void run_worker()
    {
        for (;;)
        {
            std::function<void()> task;
            {
                std::unique_lock<std::mutex> lock(queue_mutex);
                condition.wait(lock,
                    [this] { return stop || !tasks.empty(); });
                if (stop && tasks.empty())
                    return;
                task = std::move(tasks.front());
                tasks.pop();
            }
            // Run outside the lock so other workers can fetch tasks meanwhile.
            task();
        }
    }

    // Raises the stop flag, wakes all workers and joins those that are joinable.
    void shut_down() noexcept
    {
        {
            std::unique_lock<std::mutex> lock(queue_mutex);
            stop = true;
        }
        condition.notify_all();
        for (std::thread& worker : workers)
            if (worker.joinable())
                worker.join();
    }

    // need to keep track of threads so we can join them
    std::vector< std::thread > workers;
    // the task queue
    std::queue< std::function<void()> > tasks;
    // synchronization
    std::mutex queue_mutex;
    std::condition_variable condition;
    bool stop;
};
/*
 * Reads a MATLAB array into a flat vector of doubles.
 *
 * Copies mxGetM() * mxGetN() values from the array's data buffer into vDat,
 * in buffer order; any previous content of vDat is replaced.
 *
 * The data buffer is reinterpreted as double without a type check, so the
 * caller must only pass arrays that really store doubles.
 *
 * The element count is computed in size_t: the earlier int arithmetic
 * truncated the dimensions and could overflow when multiplying them.
 */
void Read1DDouble(const mxArray* pMxArray, vector<double>& vDat) {
    const size_t rowNum = mxGetM(pMxArray);
    const size_t colNum = mxGetN(pMxArray);
    const size_t count = rowNum * colNum;
    const double* matData = static_cast<const double*>(mxGetData(pMxArray));
    if (matData == nullptr || count == 0) {
        // Nothing to copy (empty array or no data buffer).
        vDat.clear();
        return;
    }
    vDat.assign(matData, matData + count);
}
/* <20><>ȡ<EFBFBD><C8A1>άdouble<6C><65><EFBFBD><EFBFBD> */
void Read2DDouble(const mxArray* pMxArray, vector<vector<double>>& vvDat) {
int rowNum, colNum;
double* matData;
rowNum = (int)mxGetM(pMxArray);
colNum = (int)mxGetN(pMxArray);
vvDat.resize(rowNum);
matData = (double*)mxGetData(pMxArray); //<2F><>ȡָ<C8A1><D6B8>
for (int i = 0; i < rowNum; ++i) {
vvDat[i].resize(colNum);
for (int j = 0; j < colNum; ++j) {
vvDat[i][j] = matData[j * rowNum + i];
}
}
}
// Arguments of one ThreadRandSim() run. All pointers are non-owning and must
// stay valid for the duration of the call.
struct TPRandSim {
    std::vector<double>* pvTr;                    // in/out: sums are added to it, then every entry is divided by numPositive
    std::vector<int>* pvRandPos;                  // in/out: row-index permutation buffer, reshuffled in place by each run
    std::vector<double>* pvH;                     // in: labels; an entry equal to 1 selects a row
    std::vector<std::vector<double>>* pvvX;       // in: data rows
    int numPositive;                              // divisor applied to the accumulated column sums
};

// Creates the random engine used by the calling thread.
// The seed combines std::random_device with the thread id and the current
// time, so that threads do not end up with identical streams on platforms
// where std::random_device is deterministic.
static std::mt19937 MakeRandSimEngine() {
    std::random_device rd;
    const unsigned long long tid = static_cast<unsigned long long>(
        std::hash<std::thread::id>()(std::this_thread::get_id()));
    std::seed_seq seq{
        static_cast<unsigned long long>(rd()),
        static_cast<unsigned long long>(rd()),
        tid & 0xFFFFFFFFull,
        tid >> 32,
        static_cast<unsigned long long>(std::time(nullptr)),
        static_cast<unsigned long long>(std::clock())
    };
    return std::mt19937(seq);
}

// Entry point of one random-simulation run.
//
// Shuffles *pvRandPos in place, then for every data row i looks up the label
// at the shuffled index pvRandPos[i]; when that label equals 1 the row is
// added column-wise to *pvTr. Finally every entry of *pvTr is divided by
// numPositive. *pvTr is not reset first, so pass it zero-filled to obtain the
// per-column mean of the selected rows.
//
// The column count is taken from the first data row; with no rows nothing is
// accumulated. The function writes to *pvTr and *pvRandPos, so calls that run
// concurrently must not share those two objects.
//
// Throws std::invalid_argument when *pvRandPos has fewer entries than there
// are rows, when *pvTr has fewer entries than there are columns, or when a
// selected row is shorter than the first row; throws std::out_of_range when a
// shuffled index does not address an existing label. *pvTr may be partially
// updated when an exception is thrown.
void ThreadRandSim(TPRandSim& param) {
    std::vector<double>& vTr = *param.pvTr;
    std::vector<int>& vRandPos = *param.pvRandPos;
    const std::vector<std::vector<double>>& vvX = *param.pvvX;
    const std::vector<double>& vH = *param.pvH;
    const int numPositive = param.numPositive;

    const std::size_t rowNum = vvX.size();
    const std::size_t colNum = vvX.empty() ? 0 : vvX.front().size();

    if (vRandPos.size() < rowNum)
        throw std::invalid_argument("ThreadRandSim: permutation buffer is shorter than the number of rows");
    if (vTr.size() < colNum)
        throw std::invalid_argument("ThreadRandSim: result vector is shorter than the number of columns");

    // One engine per thread, seeded on first use; cheaper than constructing a
    // std::random_device and a fresh engine for every run.
    thread_local std::mt19937 engine = MakeRandSimEngine();
    std::shuffle(vRandPos.begin(), vRandPos.end(), engine);

    for (std::size_t i = 0; i < rowNum; ++i) {
        const int hRowIdx = vRandPos[i]; // label index assigned to row i by the shuffle
        if (hRowIdx < 0 || static_cast<std::size_t>(hRowIdx) >= vH.size())
            throw std::out_of_range("ThreadRandSim: shuffled index has no matching label");
        if (vH[hRowIdx] != 1)
            continue;
        const std::vector<double>& row = vvX[i];
        if (row.size() < colNum)
            throw std::invalid_argument("ThreadRandSim: data row is shorter than the first row");
        for (std::size_t j = 0; j < colNum; ++j) {
            vTr[j] += row[j];
        }
    }
    for (double& val : vTr) val /= numPositive;
}
/* <20><><EFBFBD>ں<EFBFBD><DABA><EFBFBD> */
/*
<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>һ<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ֵ
<EFBFBD><EFBFBD><EFBFBD>
1. x <20><>ά<EFBFBD><CEAC><EFBFBD>ݣ<EFBFBD>double<6C><65><EFBFBD>ͣ<EFBFBD><CDA3><EFBFBD><EFBFBD><EFBFBD>Ϊ<EFBFBD><CEAA><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>Ϊ<EFBFBD>ֵ䳤<D6B5>ȣ<EFBFBD>ÿ<EFBFBD><C3BF><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>г<EFBFBD><D0B3>ֵĴ<D6B5><C4B4><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>5<EFBFBD><35>
2. h <20><><EFBFBD><EFBFBD>Ϊ<EFBFBD><CEAA><EFBFBD>׸<EFBFBD><D7B8><EFBFBD><EFBFBD><EFBFBD>ֵΪ1<CEAA><31><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ڸ<EFBFBD>֪ʶ<D6AA><CAB6><EFBFBD><EFBFBD><EFBFBD><EFBFBD>Ӧ<EFBFBD><D3A6><EFBFBD>ǣ<EFBFBD><C7A3><EFBFBD>Ϊ0<CEAA><30><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
3. numThread
<EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
vs z score,<2C><><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ָ<EFBFBD><D6B8><EFBFBD><EFBFBD>һά
ps <20><>vs<76><73><EFBFBD><EFBFBD>һ<EFBFBD><D2BB>
*/
void mexFunction(int nlhs, mxArray* plhs[], int nrhs, mxArray** prhs) {
//void mexFunction(int nlhs, mxArray* plhs[], int nrhs, const mxArray* prhs[]) {
if (nrhs < 2) {
cout << "At least 2 arguments should be given for this function!" << endl;
return;
}
clock_t begin = clock(), mid, finish;
vector<double> vH;
vector<vector<double>> vvX;
Read2DDouble(prhs[0], vvX);
Read1DDouble(prhs[1], vH);
int rowNum = vvX.size();
int colNum = vvX[0].size();
cout << vH.size() << '\t' << vvX.size() << endl;
int numThread = 1;
int loopNum = 1000;
if (nrhs > 2) {
double* pNumThread = (double*)mxGetData(prhs[2]);
numThread = (int)pNumThread[0];
if (numThread < 1) numThread = 1;
}
if (nrhs > 3) {
double* pLoopNum = (double*)mxGetData(prhs[3]);
loopNum = (int)pLoopNum[0];
if (loopNum < 1000) loopNum = 1000;
}
/* <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ģ<EFBFBD><C4A3> */
mid = clock();
vector<double> vTs(colNum); // <20><>ʼ<EFBFBD><CABC><EFBFBD>ݣ<EFBFBD><DDA3><EFBFBD>¼vH<76><48>labelΪ1<CEAA><31><EFBFBD>е<EFBFBD><D0B5>о<EFBFBD>ֵ
int numPositive = 0;
for (int i = 0; i < rowNum; ++i) {
if (vH[i] == 1) {
++numPositive;
for (int j = 0; j < colNum; ++j) {
vTs[j] += vvX[i][j];
}
}
}
for (auto& val : vTs) val /= numPositive;
vector<vector<double>> vvTr(loopNum, vector<double>(colNum, 0)); // ģ<><C4A3><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
vector<vector<int>> vvRandPos(numThread, vector<int>(rowNum));
for (int i = 0; i < rowNum; ++i) {
for (auto& vRandPos : vvRandPos) {
vRandPos[i] = i;
}
}
ThreadPool thPool(numThread);
int tid = 0;
for (int i = 0; i < loopNum; ++i) {
TPRandSim tParam = { &vvTr[i], &vvRandPos[tid++ % numThread], &vH, &vvX, numPositive };
thPool.enqueue(ThreadRandSim, tParam);
//ThreadRandSim(tParam);
}
thPool.~ThreadPool();
finish = clock();
cout << "Random simulation time: " << (double)(finish - mid) / CLOCKS_PER_SEC << " s" << endl;
/* <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD><EFBFBD> */
vector<double> vVs(colNum);
vector<double> vPs(colNum);
// <20><><EFBFBD>м<EFBFBD><D0BC><EFBFBD>ƽ<EFBFBD><C6BD>ֵ
vector<double> vMean(colNum);
vector<double> vStd(colNum);
for (int i = 0; i < vvTr.size(); ++i) {
for (int j = 0; j < vvTr[i].size(); ++j) {
vMean[j] += vvTr[i][j];
}
}
for (auto& val : vMean) { val /= loopNum; } // <20><>ֵ
for (int i = 0; i < vvTr.size(); ++i) {
for (int j = 0; j < vvTr[i].size(); ++j) {
const double diff = vvTr[i][j] - vMean[j];
vStd[j] += diff * diff;
}
}
for (auto& val : vStd) { val = sqrt(val / (loopNum - 1)); } // <20><><EFBFBD><EFBFBD><EFBFBD><EFBFBD>
// <20><><EFBFBD><EFBFBD>vs
for (int i = 0; i < vVs.size(); ++i) {
vVs[i] = (vTs[i] - vMean[i]) / vStd[i];
}
// <20><><EFBFBD><EFBFBD>ps
vector<double> vSumGreater(colNum);
vector<double> vSumLess(colNum);
for (int i = 0; i < loopNum; ++i) {
for (int j = 0; j < colNum; ++j) {
if (vvTr[i][j] >= vTs[j]) vSumGreater[j] ++;
if (vvTr[i][j] <= vTs[j]) vSumLess[j] ++;
}
}
for (auto& val : vSumGreater) val /= loopNum;
for (auto& val : vSumLess) val /= loopNum;
for (int i = 0; i < colNum; ++i) {
vPs[i] = min(vSumGreater[i], vSumLess[i]);
}
ofstream ofs("d:\\result.txt");
for (int i = 0; i < colNum; ++i) {
ofs << vVs[i] << '\t' << vPs[i] << endl;
}
ofs.close();
/* д<><D0B4><EFBFBD><EFBFBD><EFBFBD><EFBFBD> */
if (nlhs > 0) { // vs
mxArray* pWriteArray = NULL;//matlab<61><62>ʽ<EFBFBD><CABD><EFBFBD><EFBFBD>
//<2F><><EFBFBD><EFBFBD>һ<EFBFBD><D2BB>rowNum*colNum<75>ľ<EFBFBD><C4BE><EFBFBD>
pWriteArray = mxCreateDoubleMatrix(1, vVs.size(), mxREAL);
//<2F><>data<74><61>ֵ<EFBFBD><D6B5><EFBFBD><EFBFBD>pWriteArrayָ<79><D6B8>
memcpy((void*)(mxGetPr(pWriteArray)), (void*)vVs.data(), sizeof(double) * vVs.size());
plhs[0] = pWriteArray; // <20><>ֵ<EFBFBD><D6B5><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ֵ
}
if (nlhs > 1) { // ps
mxArray* pWriteArray = NULL;//matlab<61><62>ʽ<EFBFBD><CABD><EFBFBD><EFBFBD>
//<2F><><EFBFBD><EFBFBD>һ<EFBFBD><D2BB>rowNum*colNum<75>ľ<EFBFBD><C4BE><EFBFBD>
pWriteArray = mxCreateDoubleMatrix(1, vPs.size(), mxREAL);
//<2F><>data<74><61>ֵ<EFBFBD><D6B5><EFBFBD><EFBFBD>pWriteArrayָ<79><D6B8>
memcpy((void*)(mxGetPr(pWriteArray)), (void*)vPs.data(), sizeof(double)* vPs.size());
plhs[1] = pWriteArray; // <20><>ֵ<EFBFBD><D6B5><EFBFBD><EFBFBD><EFBFBD><EFBFBD>ֵ
}
finish = clock();
cout << "Cluster Random simulation Total time: " << (double)(finish - begin) / CLOCKS_PER_SEC << " s" << endl;
}