#include "threadpool.h"
#include "util.h"

#include <pthread.h>
#include <signal.h>
#include <stdbool.h>
#include <stdlib.h>

/* One pending task in the pool's singly linked FIFO work list. */
struct thread_pool_queue {
    Task task;
    void *arg;
    struct thread_pool_queue *next;
};

struct _ThreadPool {
    bool running;                   /* cleared by destroy_thread_pool; guarded by queue_mtx */
    size_t nthreads;                /* number of started worker threads in `threads` */
    sigset_t thread_sig_mask;       /* signal mask each worker installs on startup */
    pthread_t *threads;
    pthread_cond_t queue_cnd;       /* signalled on enqueue, broadcast on shutdown */
    pthread_mutex_t queue_mtx;      /* guards running, queue and queue_tail */
    struct thread_pool_queue *queue;      /* head: next task to run (NULL when empty) */
    struct thread_pool_queue *queue_tail; /* tail: most recently enqueued task (NULL when empty) */
};

/*
 * Block until a task is available or the pool is shutting down.
 *
 * On success, stores the dequeued task and its argument in *task / *task_arg
 * and returns true. Returns false once the pool has stopped running; tasks
 * still queued at that point are left for destroy_thread_pool to discard.
 */
static bool get_task(ThreadPool *pool, Task *task, void **task_arg) {
    pthread_mutex_lock(&pool->queue_mtx);

    /* Check the predicate BEFORE waiting: a signal sent while this worker was
     * busy running a task is not queued by the condition variable, so waiting
     * unconditionally would sleep despite pending work. The loop also absorbs
     * spurious wakeups. */
    while (pool->running && pool->queue == NULL) {
        pthread_cond_wait(&pool->queue_cnd, &pool->queue_mtx);
    }

    if (!pool->running) {
        pthread_mutex_unlock(&pool->queue_mtx);
        return false;
    }

    struct thread_pool_queue *ent = pool->queue;
    pool->queue = ent->next;
    if (pool->queue == NULL) {
        pool->queue_tail = NULL;
    }
    pthread_mutex_unlock(&pool->queue_mtx);

    *task = ent->task;
    *task_arg = ent->arg;
    free(ent);
    return true;
}

/* Worker entry point: install the pool's signal mask, then run tasks until
 * get_task reports shutdown. */
static void *pool_thread_function(void *arg) {
    ThreadPool *pool = arg;
    pthread_sigmask(SIG_SETMASK, &pool->thread_sig_mask, NULL);

    Task task;
    void *task_arg;
    while (get_task(pool, &task, &task_arg)) {
        task(task_arg);
    }
    return NULL;
}

/*
 * Create a pool of `parallelism` worker threads. Each worker runs with
 * `sig_mask` as its signal mask.
 *
 * Aborts if a worker thread cannot be created: this API has no error return,
 * and a partially started pool would later have destroy_thread_pool join
 * uninitialized thread handles.
 */
ThreadPool *make_thread_pool(size_t parallelism, sigset_t sig_mask) {
    ThreadPool *pool = malloc_safe(sizeof *pool);
    pthread_mutex_init(&pool->queue_mtx, NULL);
    pthread_cond_init(&pool->queue_cnd, NULL);
    pool->running = true;
    pool->queue = NULL;
    pool->queue_tail = NULL;
    pool->nthreads = parallelism;
    pool->thread_sig_mask = sig_mask;
    pool->threads = malloc_safe(sizeof *pool->threads * parallelism);

    /* New threads inherit the creator's signal mask. Block every signal while
     * spawning so no worker can receive a signal before it installs
     * thread_sig_mask in pool_thread_function. */
    sigset_t all_signals;
    sigfillset(&all_signals);
    sigset_t saved_mask;
    pthread_sigmask(SIG_SETMASK, &all_signals, &saved_mask);

    for (size_t i = 0; i < parallelism; ++i) {
        if (pthread_create(&pool->threads[i], NULL, &pool_thread_function, pool) != 0) {
            abort();
        }
    }

    pthread_sigmask(SIG_SETMASK, &saved_mask, NULL);
    return pool;
}

/*
 * Stop all workers, wait for them to exit, and free the pool.
 *
 * Workers finish the task they are currently running, but tasks still queued
 * are discarded without being run. Their queue nodes are freed here; their
 * `arg` pointers are not touched (ownership of task arguments is not visible
 * in this file).
 */
void destroy_thread_pool(ThreadPool *pool) {
    pthread_mutex_lock(&pool->queue_mtx);
    pool->running = false;
    pthread_cond_broadcast(&pool->queue_cnd);
    pthread_mutex_unlock(&pool->queue_mtx);

    for (size_t i = 0; i < pool->nthreads; ++i) {
        pthread_join(pool->threads[i], NULL);
    }
    free(pool->threads);

    /* All workers have exited, so the queue can be drained without locking. */
    struct thread_pool_queue *ent = pool->queue;
    while (ent != NULL) {
        struct thread_pool_queue *next = ent->next;
        free(ent);
        ent = next;
    }

    pthread_mutex_destroy(&pool->queue_mtx);
    pthread_cond_destroy(&pool->queue_cnd);
    free(pool);
}

/* Append `task(arg)` to the tail of the work queue and wake one idle worker.
 * Tasks are dequeued in FIFO order. */
void thread_pool_enqueue(ThreadPool *pool, Task task, void *arg) {
    /* Allocate outside the critical section to keep the lock hold time short. */
    struct thread_pool_queue *ent = malloc_safe(sizeof *ent);
    ent->task = task;
    ent->arg = arg;
    ent->next = NULL;

    pthread_mutex_lock(&pool->queue_mtx);
    if (pool->queue_tail != NULL) {
        pool->queue_tail->next = ent;
    } else {
        pool->queue = ent;
    }
    pool->queue_tail = ent;
    pthread_cond_signal(&pool->queue_cnd);
    pthread_mutex_unlock(&pool->queue_mtx);
}