类视图设计
拟实现的功能
类设计

如上图所示,该实现包含四个类,分别是实现线程安全的队列抽线基类,继承自抽象基类的安全队列模板类,线程池类,以及线程任务类。在实际运行过程中,用户首先创建线程池类,并指定使用的最大线程数(默认情况下使用设备支持的最大线程数)。之后创建任务类,在实际工作中,希望能通过任务类来创建和操作消息队列,而消息队列又由线程池管理,因此需要在创建任务类的时候将线程池指针传入。最后,将任务类实例化,并通过线程池的addTask函数将任务加入到线程池中执行。此外,由于addTask使用了可变参数模板和类型萃取技术,可以添加基于函数形式、lambda表达式或者仿函数形式任务。事实上,传入类对象就是基于仿函数形式的任务,此时,需要将任务执行过程放在针对括号运算符的重载函数中。
这样的设计,可以使类图上方两个类(SafeQueueBase和ThreadPool)保持稳定,而下方两个类(SafeQueue和TheadTask)可以变化,隔绝了稳定部分和变换部分,遵循了OOP的开放封闭原则、依赖倒置原则、单一职责原则,提升了程序的扩展能力。
线程安全队列实现
在C++17标准之后,才正式引入支持并发的容器,由于目前工业上常用的C++版本多为C++11与C++14,因此需要手动实现针对队列的线程安全控制,这里提供了一个简易的实现方法,主要是通过使用共享锁实现的:
struct TypeErasedData {
virtual ~TypeErasedData() {}
};
template <typename T>
struct ErasedData : public TypeErasedData {
ErasedData(T& value) : data(value) {}
T& data;
};
class SafeQueueBase
{
public:
SafeQueueBase() = default;
SafeQueueBase(SafeQueueBase&&) = delete;
SafeQueueBase(const SafeQueueBase&) = delete;
SafeQueueBase& operator=(SafeQueueBase&&) = delete;
SafeQueueBase& operator=(const SafeQueueBase&) = delete;
bool empty()
{
std::shared_lock<std::shared_timed_mutex> lck(mux);
return primaryEmpty();
}
size_t size()
{
std::shared_lock<std::shared_timed_mutex> lck(mux);
return primarySize();
}
template <typename T>
void pop(T& _data_output)
{
std::unique_lock<std::shared_timed_mutex> lck(mux);
ErasedData<T> data_output(_data_output);
primaryPop(&data_output);
}
template <typename T>
void push(const T& _data)
{
std::unique_lock<std::shared_timed_mutex> lck(mux);
ErasedData<const T> data(_data);
primaryPush(&data);
}
template <typename T>
void see(std::vector<T>& _data_output)
{
std::shared_lock<std::shared_timed_mutex> lck(mux);
ErasedData<std::vector<T>> data_output(_data_output);
primarySee(&data_output);
}
virtual ~SafeQueueBase(){};
protected:
virtual void primaryPush(TypeErasedData* _data) = 0;
virtual void primaryPop(TypeErasedData* _data) = 0;
virtual void primarySee(TypeErasedData* _data) = 0;
virtual bool primaryEmpty() = 0;
virtual size_t primarySize() = 0;
private:
std::shared_timed_mutex mux; // WARN: need to support c++14
};
该类作为抽象基类,提供了5种操作容器的方法,在对容器进行访问的方法种,使用可移动共享互斥体所有权封装器保证多线程访问的效率;在对容器进行修改的方法种,使用可移动互斥体所有权包装器保证多线程环境下的安全性。此外,使用两个类(TypeErasedData 和ErasedData)实现类型擦除,保证在父类对子类开放统一的接口。在具体实现中,首先使用一个空的TypeErasedData 类作为基类,并将析构函数设为虚函数,方便后期实现多态,之后使用派生类ErasedData实现对多种数据类型的引用。
线程安全队列父类提供了线程安全的操作容器的方法,在这些方法中会调用primary开头的纯虚函数,以实现针对某一数据类型的具体操作,子类则需要逐个覆写:
template <typename T>
class SafeQueue : public SafeQueueBase
{
public:
SafeQueue() = default;
~SafeQueue() = default;
bool primaryEmpty() override
{
return m_queue.empty();
}
virtual size_t primarySize() override
{
return m_queue.size();
}
virtual void primaryPush(TypeErasedData* _data) override
{
auto data = dynamic_cast<ErasedData<const T>*>(_data);
std::cout<<data<<std::endl;
if (data == nullptr) {
throw "Error data type";
}
m_queue.emplace(data->data); // if c++11, use push instead of emplace
}
virtual void primaryPop(TypeErasedData* _data) override
{
auto data = dynamic_cast<ErasedData<T>*>(_data);
if (data == nullptr) {
throw "Error data type";
}
data->data = m_queue.front();
m_queue.pop();
}
virtual void primarySee(TypeErasedData* _data) override
{
auto data = dynamic_cast<ErasedData<std::vector<T>>*>(_data);
if (data == nullptr) {
throw "Error data type";
}
size_t size = m_queue.size();
for (size_t i = 0; i < size; ++i) {
data->data.emplace_back(m_queue.front());
m_queue.pop();
}
}
private:
std::queue<T> m_queue;
};
子类首先继承自父类,并实现了父类中的各种接口函数。注意此处使用运行时的动态类型转换(dynamic_cast),这是实现类型擦除的关键,利用虚表,通过基类找回派生类,从而获得实际的数据类型。在进行动态类型转换的时候如果基类不能转换为指定的派生类,则会返回nullptr,依此判断转换是否成功。如果转换不成功,则说明传入的数据类型与容器数据类型不一致,从而可以引入异常处理的方法。
线程池实现
创建一个线程池类ThreadPool. 该类需要实现以下功能:
- 设定线程池最大运行线程数量
- 创建并维护线程池
- 创建并维护消息队列
- 添加任务
class ThreadPool
{
public:
ThreadPool(std::size_t _thread_count);
ThreadPool(ThreadPool &&) = delete;
ThreadPool(const ThreadPool &) = delete;
ThreadPool &operator=(ThreadPool &&) = delete;
ThreadPool &operator=(const ThreadPool &) = delete;
~ThreadPool();
template <typename T>
void addQueue();
template <typename T>
std::shared_ptr<SafeQueue<T>> getQueue();
template <typename Func, typename... Args>
auto addTask(Func &&f, Args &&...args) -> std::future<typename std::result_of<Func(Args...)>::type>;
private:
using TaskType = typename std::function<void()>;
std::mutex mux;
void startBuildPool();
std::atomic<bool> m_running_state{false};
std::size_t m_thread_count = 0;
std::vector<std::shared_ptr<std::thread>> m_threads;
SafeQueue<TaskType> m_task_queue;
std::vector<std::shared_ptr<SafeQueueBase>> m_message_queue;
std::condition_variable m_cv;
};
以上是线程池类的内容和函数声明,值得注意的是所有拷贝、移动的构造函数和赋值运算符都声明为删除,这是为了防止在拷贝和移动过程中出现潜在的错误。其中,mux是控制任务分配的互斥量、m_running_state是控制线程池运行状态的原子变量、m_threads存放各个线程的智能指针、m_task_queue是存放任务的队列、m_message_queue存放着各种数据类型的消息队列、m_cv作为信号量调节各部分有序运行。下面详细介绍几个函数的具体实现,完整的代码附于文末。
构造函数
inline ThreadPool::ThreadPool(std::size_t _thread_count)
{
auto sys_max_threads = std::thread::hardware_concurrency(); // get system max threads number
if (_thread_count <= 0 || _thread_count >= sys_max_threads) {
m_thread_count = sys_max_threads;
std::cerr << "WARNING: use capable thread count:" << m_thread_count << std::endl;
} else {
m_thread_count = _thread_count;
}
m_threads.reserve(m_thread_count);
m_running_state = true;
startBuildPool();
}
在构造函数中,主要设定了线程池最大运行的线程数,需要注意的是,对于用户输入的线程数,应判断其是否超过硬件支持的最大线程数,这里使用了std::thread::hardware_concurrency();获取硬件最大线程数。在输入的线程数超过硬件支持的最大线程数或者输入线程数小于0时使用最大的线程数。
析构函数
inline ThreadPool::~ThreadPool()
{
m_running_state = false;
m_cv.notify_all(); // break all thread in block state
for (auto &thread : m_threads) {
if (thread->joinable()) {
thread->join();
}
}
}
在析构时,应通知所有等待线程取消等待状态,使用m_cv.notify_all();实现,之后将所有可合并的子线程全部合并到主线程中,也就是依次关闭各个子线程。
任务添加函数
template <typename Func, typename... Args>
inline auto ThreadPool::addTask(Func &&f, Args &&...args) -> std::future<typename std::result_of<Func(Args...)>::type>
{
using return_type = typename std::result_of<Func(Args...)>::type;
auto task = std::make_shared<std::packaged_task<return_type()>>(std::bind(std::forward<Func>(f), std::forward<Args>(args)...));
TaskType thread_task = [task]() { (*task)(); };
m_task_queue.push(thread_task);
m_cv.notify_one();
return task->get_future();
}
任务添加函数是线程池实现的核心,下面逐点解释如何设计任务添加函数使之能兼容有着各种返回值类型、各种数目输入参数的函数、lambda表达式、仿函数。
- 使用可变参数模板捕获函数、lambda表达式、仿函数以及它们的参数。
- 使用std::result_of::type类型萃取器,获得函数、lambda表达式、仿函数的返回值。
- 使用bind将函数与参数进行绑定,在绑定时为了减少不必要的拷贝开支,使用了完美转发将函数参数转发到绑定对象中。
- 打包成packaged_task对象,方便管理promise和future。
- 使用智能指针创建packaged_task对象并自动管理内存。
- 使用一个lambda表达式,将task封装成一个没有返回值的函数指针,注意这里的TaskType 原型是typename std::function; 因为后期通过task的get_future()方法获取线程执行结果,这里就不需要使用返回值了,而且没有返回值的任务对象(准确来说是具有统一的void返回值的任务对象)使得在容器中存储任务和调用执行任务更加方便。
- 将函数指针放入任务队列中,并通知一个空闲线程开始执行任务。
- 任务执行完后获取执行结果,该过程是异步的。
构建线程池并分配任务
inline void ThreadPool::startBuildPool()
{
for (size_t i = 0; i < m_thread_count; ++i) {
auto t = [this, i]() {
while (this->m_running_state) {
TaskType task;
std::unique_lock<std::mutex> lck(mux);
if (this->m_task_queue.empty()) {
m_cv.wait(lck);
}
// need to check again when thread pool destructed can also make wait() pass by notice_all
// and when the condition variable is waiting, it well unlock lck, this may cause many threads are waiting at the same time
if (this->m_running_state == false || m_task_queue.empty()) {
std::this_thread::sleep_for(std::chrono::milliseconds(2));
continue;
}
this->m_task_queue.pop(task);
lck.unlock();
try {
task(); // run task;
} catch (...) {
}
}
};
m_threads.emplace_back(std::make_shared<std::thread>(std::thread(t)));
}
}
构建线程池的函数内部有一个大循环,在执行过程中,程序将立即创建m_thread_count个线程并将线程移入到线程池中,该函数执行完就立即退出了,剩余的工作将有由lambda表达式t完成。
创建完的每个子线程都会运行t函数,而只有在m_running_state为false时,子线程才会退出,否则将无限循环,每一个循环都运行了一个任务。下面详细看每一次循环时候的执行过程。首先,为了防止子线程获取任务过程中与其他子线程形成竟态,需要为获取任务过程加锁,这里使用了基于互斥锁。拿到了互斥锁之后子线程会判断任务队列中是否为空,如果为空,则线程进入阻塞状态,不消耗资源。注意该过程,使用条件变量的wait()方法时,首先条件变量会获取锁,之后再释放锁,并进入阻塞状态,这样可以让线程池中所有线程同时进入阻塞状态。
之后,当任务队列中出现任务时,将唤起其中一个线程,之后重新加锁,进入后面的流程。唤起线程的信号除了有新任务添加到队列以外,还有一种可能,就是外部命令所有子线程程退出,这回导致所有线程的阻塞状态取消,而此时任务队列中不一定还有待执行的任务,因此需要添加一个额外的判断,如果线程池运行状态为false或者任务队列为空,则跳过此轮循环。
在确实是由新添加进任务队列的信号唤起线程时,在任务队列中弹出最先加入的任务,并开始执行任务。注意在执行任务之前一定要先释放掉互斥锁,如果不释放,其他子线程将无法在该任务运行时对任务队列进行操作。此外,为了防止其中某一个子线程在执行任务的过程中出错从而造成所有线程崩溃的情况出现,需要在每一个子线程执行任务的时候使用错误捕获语句,…表示捕获所有错误,可以根据具体情况细化。
运行实例
下面给出3种运行示例,展示了使用本文实现线程池的基本流程:
#include <chrono>
#include <future>
#include <memory>
#include <string>
#include <thread>
#include <vector>
#include "safe_queue_base.h"
#include "thread_pool.h"
using namespace std;
void test_multithreads()
{
ThreadPool tp(12);
auto f = [](int a, int b, int idx) {
cout<<"in sub thread: "<<idx<<endl;
return a+b+idx; };
vector<future<int>> res(16);
for (size_t i = 0; i < res.size(); ++i) {
res[i] = tp.addTask(f, 1, 1, i);
}
this_thread::sleep_for(2s);
for (size_t i = 0; i < res.size(); ++i) {
cout << "res: " << res[i].get() << endl;
}
}
void test_message_queue()
{
ThreadPool tp(8);
tp.addQueue<int>();
tp.addQueue<double>();
auto f = [](ThreadPool* ptr, int s_time) {
auto int_queue = ptr->getQueue<int>();
auto double_queue = ptr->getQueue<double>();
if (int_queue == nullptr || double_queue == nullptr) {
throw "without int/double queue";
}
if (s_time == 1) { // write something
for (int i = 0; i < 10; ++i) {
int_queue->push(i);
double d = (1 + i) / 2.0;
double_queue->push(d);
}
}
if (s_time == 0) {
this_thread::sleep_for(1s);
auto s1 = int_queue->size();
while (s1 != int_queue->size()) {
std::this_thread::sleep_for(10ms);
s1 = int_queue->size();
}
vector<double> all_d;
double_queue->see(all_d);
for (size_t i = 0; i < s1; ++i) {
int i_value;
int_queue->pop(i_value);
cout << "idx:" << i << " int value: " << i_value << " double value:" << all_d[i] << endl;
}
}
};
tp.addTask(f, &tp, 1);
tp.addTask(f, &tp, 0);
this_thread::sleep_for(3s);
/* this_thread::sleep_for(5s); */
}
class MyThreadTask
{
shared_ptr<ThreadPool> m_thread_pool;
public:
MyThreadTask()= default;
~MyThreadTask() = default;
auto operator()(int x, int y)
{
return x * x + y;
}
};
void test_class()
{
ThreadPool tp(-1);
MyThreadTask t;
auto res = tp.addTask(t,2,3);
cout<<res.get()<<endl;
}
int main(int argc, char* argv[])
{
/* test_multithreads(); */
/* test_message_queue(); */
test_class();
return 0;
}
附件:完整程序代码
safe_queue_base.h
#ifndef SAFE_QUEUE_BASE_H
#define SAFE_QUEUE_BASE_H
#include <functional>
#include <iostream>
#include <mutex>
#include <queue>
#include <shared_mutex>
#include <vector>
struct TypeErasedData {
virtual ~TypeErasedData() {}
};
template <typename T>
struct ErasedData : public TypeErasedData {
ErasedData(T& value) : data(value) {}
T& data;
};
class SafeQueueBase
{
public:
SafeQueueBase() = default;
SafeQueueBase(SafeQueueBase&&) = delete;
SafeQueueBase(const SafeQueueBase&) = delete;
SafeQueueBase& operator=(SafeQueueBase&&) = delete;
SafeQueueBase& operator=(const SafeQueueBase&) = delete;
bool empty()
{
std::shared_lock<std::shared_timed_mutex> lck(mux);
return primaryEmpty();
}
size_t size()
{
std::shared_lock<std::shared_timed_mutex> lck(mux);
return primarySize();
}
template <typename T>
void pop(T& _data_output)
{
std::unique_lock<std::shared_timed_mutex> lck(mux);
ErasedData<T> data_output(_data_output);
primaryPop(&data_output);
}
template <typename T>
void push(const T& _data)
{
std::unique_lock<std::shared_timed_mutex> lck(mux);
ErasedData<const T> data(_data);
primaryPush(&data);
}
template <typename T>
void see(std::vector<T>& _data_output)
{
std::shared_lock<std::shared_timed_mutex> lck(mux);
ErasedData<std::vector<T>> data_output(_data_output);
primarySee(&data_output);
}
virtual ~SafeQueueBase(){};
protected:
virtual void primaryPush(TypeErasedData* _data) = 0;
virtual void primaryPop(TypeErasedData* _data) = 0;
virtual void primarySee(TypeErasedData* _data) = 0;
virtual bool primaryEmpty() = 0;
virtual size_t primarySize() = 0;
private:
std::shared_timed_mutex mux; // WARN: need to support c++14
};
template <typename T>
class SafeQueue : public SafeQueueBase
{
public:
SafeQueue() = default;
~SafeQueue() = default;
bool primaryEmpty() override
{
return m_queue.empty();
}
virtual size_t primarySize() override
{
return m_queue.size();
}
virtual void primaryPush(TypeErasedData* _data) override
{
auto data = dynamic_cast<ErasedData<const T>*>(_data);
std::cout<<data<<std::endl;
if (data == nullptr) {
throw "Error data type";
}
m_queue.emplace(data->data); // if c++11, use push instead of emplace
}
virtual void primaryPop(TypeErasedData* _data) override
{
auto data = dynamic_cast<ErasedData<T>*>(_data);
if (data == nullptr) {
throw "Error data type";
}
data->data = m_queue.front();
m_queue.pop();
}
virtual void primarySee(TypeErasedData* _data) override
{
auto data = dynamic_cast<ErasedData<std::vector<T>>*>(_data);
if (data == nullptr) {
throw "Error data type";
}
size_t size = m_queue.size();
for (size_t i = 0; i < size; ++i) {
data->data.emplace_back(m_queue.front());
m_queue.pop();
}
}
private:
std::queue<T> m_queue;
};
#endif // SAFE_QUEUE_BASE_H
thread_pool.h
#ifndef THREAD_POOL_H
#define THREAD_POOL_H
#include <algorithm>
#include <atomic>
#include <chrono>
#include <future>
#include <iostream>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>
#include <string>
/* #define DEBUG */
#include "safe_queue_base.h"
class ThreadPool
{
public:
ThreadPool(std::size_t _thread_count);
ThreadPool(ThreadPool &&) = delete;
ThreadPool(const ThreadPool &) = delete;
ThreadPool &operator=(ThreadPool &&) = delete;
ThreadPool &operator=(const ThreadPool &) = delete;
~ThreadPool();
template <typename T>
void addQueue();
template <typename T>
std::shared_ptr<SafeQueue<T>> getQueue();
template <typename Func, typename... Args>
auto addTask(Func &&f, Args &&...args) -> std::future<typename std::result_of<Func(Args...)>::type>;
private:
using TaskType = typename std::function<void()>;
std::mutex mux;
void startBuildPool();
std::atomic<bool> m_running_state{false};
std::size_t m_thread_count = 0;
std::vector<std::shared_ptr<std::thread>> m_threads;
SafeQueue<TaskType> m_task_queue;
std::vector<std::shared_ptr<SafeQueueBase>> m_message_queue;
std::condition_variable m_cv;
};
inline ThreadPool::ThreadPool(std::size_t _thread_count)
{
auto sys_max_threads = std::thread::hardware_concurrency(); // get system max threads number
if (_thread_count <= 0 || _thread_count >= sys_max_threads) {
m_thread_count = sys_max_threads;
std::cerr << "WARNING: use capable thread count:" << m_thread_count << std::endl;
} else {
m_thread_count = _thread_count;
}
m_threads.reserve(m_thread_count);
m_running_state = true;
startBuildPool();
}
inline ThreadPool::~ThreadPool()
{
m_running_state = false;
m_cv.notify_all(); // break all thread in block state
for (auto &thread : m_threads) {
if (thread->joinable()) {
thread->join();
}
}
}
template <typename T>
inline void ThreadPool::addQueue()
{
std::shared_ptr<SafeQueueBase> queue = std::make_shared<SafeQueue<T>>();
m_message_queue.push_back(std::move(queue));
}
template <typename T>
inline std::shared_ptr<SafeQueue<T>> ThreadPool::getQueue()
{
for (const auto &e : m_message_queue) {
// use dynamic_pointer_cast to transform shared_ptr from base class to subclass
if (auto ptr = std::dynamic_pointer_cast<SafeQueue<T>>(e)) {
return ptr;
}
}
return nullptr;
}
template <typename Func, typename... Args>
inline auto ThreadPool::addTask(Func &&f, Args &&...args) -> std::future<typename std::result_of<Func(Args...)>::type>
{
using return_type = typename std::result_of<Func(Args...)>::type;
auto task = std::make_shared<std::packaged_task<return_type()>>(std::bind(std::forward<Func>(f), std::forward<Args>(args)...));
TaskType thread_task = [task]() { (*task)(); }; // 不需要移动了,使用 shared_ptr
m_task_queue.push(thread_task);
m_cv.notify_one();
return task->get_future();
}
inline void ThreadPool::startBuildPool()
{
for (size_t i = 0; i < m_thread_count; ++i) {
auto t = [this, i]() {
while (this->m_running_state) {
TaskType task;
std::unique_lock<std::mutex> lck(mux);
if (this->m_task_queue.empty()) {
m_cv.wait(lck);
}
// need to check again when thread pool destructed can also make wait() pass by notice_all
// and when the condition variable is waiting, it well unlock lck, this may cause many threads are waiting at the same time
if (this->m_running_state == false || m_task_queue.empty()) {
std::this_thread::sleep_for(std::chrono::milliseconds(2));
continue;
}
this->m_task_queue.pop(task);
lck.unlock();
try {
task(); // run task;
} catch (...) {
}
}
};
m_threads.emplace_back(std::make_shared<std::thread>(std::thread(t)));
}
}
#endif // THREAD_POOL_