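线程池 —— 仅依赖 C++ 标准库(std::thread、std::mutex、std::condition_variable 等)。注意:代码用到了 C++14 的 std::make_unique,并且按值返回不可拷贝/不可移动的 Result(其成员 Semaphore 内含 std::mutex)依赖 C++17 的强制拷贝省略,因此实际编译需要 C++17。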
目标 用C++ 11 的多线程 实现一个 带返回值且接收可变参数的 线程池
组件:信号量类(C++11 标准库没有提供信号量,C++20 才引入 std::counting_semaphore,因此需要自己实现)、用于存放任意类型返回值的 Any 类、返回值类 Result、线程类 Thread、线程池类 ThreadPool、任务类 Task
架构设计
main函数执行过程
基本组件
信号量
// Counting semaphore assembled from a mutex and a condition variable,
// because the standard library has no semaphore before C++20.
class Semaphore
{
public:
    // limit: number of permits available right after construction (default 0).
    Semaphore(int limit = 0)
        : resLimit_(limit)
    {}

    ~Semaphore() = default;

    // Block until a permit is available, then take it.
    void wait()
    {
        std::unique_lock<std::mutex> guard(mtx_);
        while (resLimit_ <= 0)
        {
            cond_.wait(guard);  // loop re-checks the count after any wakeup
        }
        --resLimit_;
    }

    // Return one permit and wake all waiters so they can re-check the count.
    void post()
    {
        std::unique_lock<std::mutex> guard(mtx_);
        ++resLimit_;
        cond_.notify_all();
    }

private:
    int resLimit_;                   // permits currently available
    std::mutex mtx_;                 // protects resLimit_
    std::condition_variable cond_;   // signalled by post()
};
返回值类型Result设计
// Handle returned by ThreadPool::submitTask; the submitter reads the task's
// return value through get().
// Not copyable or movable: Semaphore holds a std::mutex, and the destructor is
// user-declared. Returning it by value therefore relies on C++17 guaranteed
// copy elision.
class Result
{
public:
    // task: the submitted task this result belongs to (shared ownership).
    // isValid: submitTask passes false when the queue stayed full and the task
    // was not enqueued.
    // NOTE(review): presumably get() then returns immediately instead of
    // blocking; the definition is not shown, so confirm.
    Result(std::shared_ptr<Task> task, bool isValid = true);
    ~Result() = default;
    // Stores the task's return value.
    // NOTE(review): presumably called from the worker thread after run() and
    // posts sem_; confirm against the definition.
    void setVal(Any any);
    // Returns the stored value.
    // NOTE(review): presumably waits on sem_ until setVal() has run; confirm.
    Any get();
private:
    Any any_;                     // the task's return value once available
    Semaphore sem_;               // presumably signals "value is ready"
    std::shared_ptr<Task> task_;  // shared ownership of the submitted task
    std::atomic_bool isValid_;    // false if the submission failed
};
Any 类(可接收任意类型的值)—— 用于 Result 类存储任务返回值
实现思路:用基类指针 std::unique_ptr<Base> base_ 持有模板派生类 Derive<T> 的对象,从而把具体类型 T 的确定推迟。存入时(任务执行完成、构造 Any 时)由模板参数推导确定 T;取出时由调用方通过 cast_<T>() 指定 T,并用 dynamic_cast 校验类型是否与存入时一致。
// Type-erased, move-only holder for a single value of arbitrary type.
// The value is stored in a heap-allocated Derive<T> owned through a Base
// pointer. The concrete type is named again when reading via cast_<T>(),
// and dynamic_cast verifies that it matches the stored type exactly.
class Any
{
public:
    Any() = default;
    ~Any() = default;
    Any(const Any&) = delete;
    Any& operator=(const Any&) = delete;
    Any(Any&&) = default;
    Any& operator=(Any&&) = default;

    // Wrap `data`. Deliberately implicit so that a task's run() can simply
    // `return value;`. The argument is moved into the heap node rather than
    // copied twice, which also allows move-only payloads.
    template<typename T>
    Any(T data) : base_(std::make_unique<Derive<T>>(std::move(data)))
    {}

    // Return a copy of the stored value.
    // Throws const char* ("type is unmatch!") if the Any is empty or T is not
    // exactly the stored type.
    template<typename T>
    T cast_()
    {
        Derive<T>* pd = dynamic_cast<Derive<T>*>(base_.get());
        if (pd == nullptr)
        {
            throw "type is unmatch!";
        }
        return pd->data_;
    }

private:
    // Polymorphic anchor; the virtual destructor lets base_ destroy any Derive<T>.
    class Base
    {
    public:
        virtual ~Base() = default;
    };

    // Concrete storage for a value of type T.
    template<typename T>
    class Derive : public Base
    {
    public:
        explicit Derive(T data) : data_(std::move(data))
        {}
        T data_;
    };

private:
    std::unique_ptr<Base> base_;  // empty for a default-constructed Any
};
任务类
// Abstract unit of work submitted to the ThreadPool. Users derive from Task and
// implement run(); worker threads call exec() on the queued task.
class Task
{
public:
    Task();
    // Virtual because Task is a polymorphic base (pure virtual run()). Deleting
    // a derived task through Task* or std::unique_ptr<Task> must reach the
    // derived destructor; otherwise the behaviour is undefined.
    virtual ~Task() = default;
    // Entry point used by the pool's worker threads.
    // NOTE(review): presumably runs run() and hands the value to result_; the
    // definition is not shown, so confirm.
    void exec();
    // Attach the Result that will receive this task's return value.
    void setResult(Result* res);
    // User-supplied work; the returned Any carries the task's result.
    virtual Any run() = 0;
private:
    // Non-owning back-pointer to the associated Result; defaults to nullptr so
    // it is never read uninitialised.
    Result* result_ = nullptr;
};
线程类型
// Thin wrapper that holds a worker function and an integer id; the pool uses
// the id (getId()) as this thread's key in ThreadPool::threads_.
class Thread
{
public:
    // Worker entry point. Presumably receives this Thread's id
    // (ThreadPool::threadFunc takes it as `threadid`).
    using ThreadFunc = std::function<void(int)>;
    Thread(ThreadFunc func);
    ~Thread();
    // Launches the worker.
    // NOTE(review): the definition is not shown; presumably starts a
    // std::thread running func_(threadId_). Confirm whether it detaches or
    // joins.
    void start();
    int getId()const;
private:
    ThreadFunc func_;         // function executed by the launched thread
    // Id source shared by all Thread objects.
    // NOTE(review): a plain int, so creating Threads from several threads at
    // once would race unless the callers serialise it.
    static int generateId_;
    int threadId_;            // this thread's id
};
线程池类型
// Pool of worker threads consuming a bounded queue of shared Task objects.
// In PoolMode::MODE_CACHED it can spawn extra workers (up to
// threadSizeThreshHold_) and retires the extras after they have been idle long
// enough. Not copyable.
class ThreadPool
{
public:
    ThreadPool();
    // Blocks until every worker has exited. Workers only exit once the queue
    // is empty, so tasks already queued are still executed.
    ~ThreadPool();
    // Selects fixed or cached mode.
    // NOTE(review): presumably only meaningful before start(); confirm.
    void setMode(PoolMode mode);
    // Sets the queue capacity (presumably stored in taskQueMaxThreshHold_).
    void setTaskQueMaxThreshHold(int threshhold);
    // Sets the cached-mode worker cap (presumably stored in threadSizeThreshHold_).
    void setThreadSizeThreshHold(int threshhold);
    // Enqueues a task. Waits up to 1 s for free queue space; on timeout it
    // returns an invalid Result and the task is not enqueued.
    Result submitTask(std::shared_ptr<Task> sp);
    // Presumably creates and starts initThreadSize workers.
    // NOTE(review): hardware_concurrency() may return 0; confirm how start()
    // handles that.
    void start(int initThreadSize = std::thread::hardware_concurrency());
    ThreadPool(const ThreadPool&) = delete;
    ThreadPool& operator=(const ThreadPool&) = delete;
private:
    // Worker loop: take a task, run it, repeat until shutdown or idle timeout.
    void threadFunc(int threadid);
    // Presumably reports isPoolRunning_; the definition is not shown.
    bool checkRunningState() const;
private:
    std::unordered_map<int, std::unique_ptr<Thread>> threads_; // id -> worker; modified under taskQueMtx_ in the visible code
    int initThreadSize_;                  // cached mode never shrinks below this
    int threadSizeThreshHold_;            // max workers in cached mode
    std::atomic_int curThreadSize_;       // current number of workers
    std::atomic_int idleThreadSize_;      // workers not currently running a task
    std::queue<std::shared_ptr<Task>> taskQue_; // pending tasks, guarded by taskQueMtx_
    std::atomic_int taskSize_;            // number of pending tasks
    int taskQueMaxThreshHold_;            // queue capacity checked by submitTask
    std::mutex taskQueMtx_;               // guards taskQue_ and threads_
    std::condition_variable notFull_;     // signalled when a task is dequeued
    std::condition_variable notEmpty_;    // signalled when a task is queued or on shutdown
    std::condition_variable exitCond_;    // signalled by exiting workers; awaited by the destructor
    PoolMode poolMode_;                   // fixed or cached
    std::atomic_bool isPoolRunning_;      // cleared by the destructor to stop workers
};
重点代码分析 submitTask 提交任务
// Enqueue a task for the worker threads.
// Waits at most one second for room in the bounded queue. On timeout an error
// is printed and an invalid Result is returned; the task is not enqueued.
// In cached mode a new worker is spawned when pending tasks outnumber idle
// workers and the worker cap has not been reached.
Result ThreadPool::submitTask(std::shared_ptr<Task> sp)
{
    std::unique_lock<std::mutex> queueLock(taskQueMtx_);

    auto hasRoom = [&]() -> bool {
        return taskQue_.size() < static_cast<size_t>(taskQueMaxThreshHold_);
    };
    const bool gotSlot = notFull_.wait_for(queueLock, std::chrono::seconds(1), hasRoom);
    if (!gotSlot)
    {
        std::cerr << "task queue is full, submit task fail." << std::endl;
        return Result(sp, false);
    }

    taskQue_.emplace(sp);
    ++taskSize_;
    notEmpty_.notify_all();

    const bool cachedMode = (poolMode_ == PoolMode::MODE_CACHED);
    if (cachedMode
        && taskSize_ > idleThreadSize_
        && curThreadSize_ < threadSizeThreshHold_)
    {
        std::cout << ">>> create new thread..." << std::endl;
        auto worker = std::make_unique<Thread>(
            std::bind(&ThreadPool::threadFunc, this, std::placeholders::_1));
        const int id = worker->getId();
        // Use the iterator from emplace instead of looking the id up again.
        auto slot = threads_.emplace(id, std::move(worker)).first;
        slot->second->start();
        ++curThreadSize_;
        ++idleThreadSize_;
    }

    return Result(sp);
}
重点代码分析 threadFunc 获取任务、执行任务
// Worker loop executed by every pool thread.
// Repeatedly takes one task from the queue while holding taskQueMtx_, then runs
// it outside the lock. A worker leaves only while the queue is empty, and only
// in one of two cases:
//   * the pool is shutting down (isPoolRunning_ == false), so every queued task
//     is still executed before shutdown completes; or
//   * in cached mode, it has been idle for at least THREAD_MAX_IDLE_TIME seconds
//     and more than initThreadSize_ workers exist.
// On either exit the worker removes itself from threads_, adjusts the counters
// and notifies exitCond_, which ~ThreadPool waits on.
// NOTE(review): an exception escaping task->exec() is not caught here; confirm
// whether exec() guards against throwing run() implementations.
void ThreadPool::threadFunc(int threadid)
{
    // steady_clock: idle-time measurement must not jump when the wall clock is
    // adjusted (high_resolution_clock may be an alias of system_clock).
    auto lastTime = std::chrono::steady_clock::now();
    for (;;)
    {
        std::shared_ptr<Task> task;
        {
            std::unique_lock<std::mutex> lock(taskQueMtx_);
            std::cout << "tid:" << std::this_thread::get_id()
                << "尝试获取任务..." << std::endl;
            while (taskQue_.size() == 0)
            {
                if (!isPoolRunning_)
                {
                    // Shutdown with nothing left to do: deregister and tell the destructor.
                    threads_.erase(threadid);
                    curThreadSize_--;
                    idleThreadSize_--;
                    std::cout << "threadid:" << std::this_thread::get_id() << " exit!"
                        << std::endl;
                    exitCond_.notify_all();
                    return;
                }
                if (poolMode_ == PoolMode::MODE_CACHED)
                {
                    // Wake at least once per second to evaluate the idle timeout.
                    if (std::cv_status::timeout ==
                        notEmpty_.wait_for(lock, std::chrono::seconds(1)))
                    {
                        auto now = std::chrono::steady_clock::now();
                        auto dur = std::chrono::duration_cast<std::chrono::seconds>(now - lastTime);
                        if (dur.count() >= THREAD_MAX_IDLE_TIME
                            && curThreadSize_ > initThreadSize_)
                        {
                            threads_.erase(threadid);
                            curThreadSize_--;
                            idleThreadSize_--;
                            std::cout << "threadid:" << std::this_thread::get_id() << " exit!"
                                << std::endl;
                            // The destructor may already be waiting for threads_ to
                            // become empty; without this notify it could block forever.
                            exitCond_.notify_all();
                            return;
                        }
                    }
                }
                else
                {
                    notEmpty_.wait(lock);
                }
            }
            idleThreadSize_--;
            std::cout << "tid:" << std::this_thread::get_id()
                << "获取任务成功..." << std::endl;
            task = taskQue_.front();
            taskQue_.pop();
            taskSize_--;
            // Other tasks are still waiting: let another idle worker pick them up.
            if (taskQue_.size() > 0)
            {
                notEmpty_.notify_all();
            }
            // A slot was freed for submitters blocked in submitTask.
            notFull_.notify_all();
        }
        // Run the task without holding the queue lock.
        if (task != nullptr)
        {
            task->exec();
        }
        idleThreadSize_++;
        lastTime = std::chrono::steady_clock::now();
    }
}
重点代码分析 线程池析构 —— 保证线程池中已提交的任务全部执行完毕后,才回收资源
// Shut the pool down. Clear the running flag, wake every worker blocked on an
// empty queue, then block until all workers have removed themselves from
// threads_. Workers only exit once the queue is empty, so tasks that were
// already queued still run before this returns.
ThreadPool::~ThreadPool()
{
    isPoolRunning_.store(false);

    std::unique_lock<std::mutex> guard(taskQueMtx_);
    notEmpty_.notify_all();
    exitCond_.wait(guard, [this]() -> bool { return threads_.empty(); });
}
|