Fix last commit

This commit is contained in:
2026-05-25 21:59:56 -07:00
parent ba8c94d971
commit 5d631b594f
11 changed files with 167 additions and 113 deletions
+25
View File
@@ -0,0 +1,25 @@
---
BasedOnStyle: WebKit
IndentWidth: 4
---
Language: Cpp
AlignConsecutiveMacros: true
AlignEscapedNewlines: Right
AlignOperands: true
AllowShortBlocksOnASingleLine: Never
AllowShortCaseLabelsOnASingleLine: true
AllowShortFunctionsOnASingleLine: None
AllowShortIfStatementsOnASingleLine: Never
ColumnLimit: 100
IncludeBlocks: Regroup
IndentCaseLabels: false
IndentWrappedFunctionNames: true
PointerAlignment: Right
ReflowComments: false
SortIncludes: false
SpaceAfterCStyleCast: true
SpaceAfterLogicalNot: false
SpaceBeforeAssignmentOperators: true
SpaceAfterControlStatementKeyword: true
BreakBeforeBraces: Attach
---
+2 -2
View File
@@ -11,9 +11,9 @@ httpserver: $(OBJS)
$(CC) $(CFLAGS) -o $@ $^ $(CC) $(CFLAGS) -o $@ $^
# Auto-rebuild if Makefile changes # Auto-rebuild if Makefile changes
bin/%.o: src/%.c Makefile bin/%.o: %.c Makefile
@mkdir -p bin/deps/ @mkdir -p bin/deps/
$(CC) $(CFLAGS) -MD -MF $(patsubst src/%.c,bin/deps/%.d,$<) -c -o $@ $< $(CC) $(CFLAGS) -MD -MF $(patsubst %.c,bin/deps/%.d,$<) -c -o $@ $<
include $(SRCS:%.c/bin/deps/%.d) include $(SRCS:%.c/bin/deps/%.d)
+46 -38
View File
@@ -5,11 +5,12 @@
#include <string.h> #include <string.h>
#include <unistd.h> #include <unistd.h>
static const char *EMPTY_STRING = ""; static const char* EMPTY_STRING = "";
HTTPHeaderList *http_header_list_push(HTTPHeaderList *list, const char *key, HTTPHeaderList* http_header_list_push(HTTPHeaderList* list, const char* key,
const char *value) { const char* value)
HTTPHeaderList *new = malloc(sizeof(HTTPHeaderList)); {
HTTPHeaderList* new = malloc(sizeof(HTTPHeaderList));
if (!new) { if (!new) {
return NULL; return NULL;
} }
@@ -32,9 +33,10 @@ HTTPHeaderList *http_header_list_push(HTTPHeaderList *list, const char *key,
return new; return new;
} }
void free_http_header_list(HTTPHeaderList *list) { void free_http_header_list(HTTPHeaderList* list)
{
while (list) { while (list) {
HTTPHeaderList *next = list->next; HTTPHeaderList* next = list->next;
free(list->key); free(list->key);
free(list->value); free(list->value);
free(list); free(list);
@@ -42,8 +44,9 @@ void free_http_header_list(HTTPHeaderList *list) {
} }
} }
const char *http_header_list_search(HTTPHeaderList *list, const char *key, const char* http_header_list_search(HTTPHeaderList* list, const char* key,
const char *def) { const char* def)
{
while (list) { while (list) {
if (strcmp(list->key, key) == 0) { if (strcmp(list->key, key) == 0) {
return list->value; return list->value;
@@ -53,11 +56,11 @@ const char *http_header_list_search(HTTPHeaderList *list, const char *key,
return def; return def;
} }
#define MAX_REQUEST_LENGTH 16384 #define MAX_REQUEST_LENGTH 16384
#define MAX_METHOD_LENGTH 16 #define MAX_METHOD_LENGTH 16
#define MAX_URI_LENGTH 256 #define MAX_URI_LENGTH 256
#define MAX_VERSION_LENGTH 3 #define MAX_VERSION_LENGTH 3
#define MAX_HEADER_KEY_LENGTH 2048 #define MAX_HEADER_KEY_LENGTH 2048
#define MAX_HEADER_VALUE_LENGTH 2048 #define MAX_HEADER_VALUE_LENGTH 2048
#define RETURN_IF_READ_ERROR(s) \ #define RETURN_IF_READ_ERROR(s) \
@@ -66,9 +69,10 @@ const char *http_header_list_search(HTTPHeaderList *list, const char *key,
} }
// if an error occured, req->uri is not allocated // if an error occured, req->uri is not allocated
static HTTPRequestParseResult parse_method_uri_line(FILE *stream, static HTTPRequestParseResult parse_method_uri_line(FILE* stream,
size_t *restrict bytes_read, size_t* restrict bytes_read,
HTTPRequest *restrict req) { HTTPRequest* restrict req)
{
// allow for some leeway in passing incorrect methods // allow for some leeway in passing incorrect methods
char method[MAX_METHOD_LENGTH + 1]; char method[MAX_METHOD_LENGTH + 1];
char uri[MAX_URI_LENGTH + 1]; char uri[MAX_URI_LENGTH + 1];
@@ -76,7 +80,7 @@ static HTTPRequestParseResult parse_method_uri_line(FILE *stream,
char whitespace[4]; char whitespace[4];
ssize_t signed_bytes_read; ssize_t signed_bytes_read;
#define S1(s) #s #define S1(s) #s
#define S(s) S1(s) #define S(s) S1(s)
int nconv = fscanf( int nconv = fscanf(
stream, stream,
// clang-format off // clang-format off
@@ -113,9 +117,10 @@ static HTTPRequestParseResult parse_method_uri_line(FILE *stream,
// return true if there are more headers and no error occurred, false otherwise // return true if there are more headers and no error occurred, false otherwise
// this will *not* free LIST if an error occurs // this will *not* free LIST if an error occurs
static bool next_header(FILE *stream, size_t *restrict bytes_read, static bool next_header(FILE* stream, size_t* restrict bytes_read,
HTTPRequestParseResult *restrict res, HTTPRequestParseResult* restrict res,
HTTPHeaderList *restrict *restrict list) { HTTPHeaderList* restrict* restrict list)
{
char c = fgetc(stream); char c = fgetc(stream);
RETURN_IF_READ_ERROR(stream); RETURN_IF_READ_ERROR(stream);
if (c == '\r') { if (c == '\r') {
@@ -135,18 +140,17 @@ static bool next_header(FILE *stream, size_t *restrict bytes_read,
char whitespace[3]; char whitespace[3];
ssize_t signed_bytes_read; ssize_t signed_bytes_read;
#define S1(s) #s #define S1(s) #s
#define S(s) S1(s) #define S(s) S1(s)
int nconv = int nconv = fscanf(stream,
fscanf(stream, // clang-format off
// clang-format off
"%" S(MAX_HEADER_KEY_LENGTH) "[a-zA-Z0-9.-]" "%" S(MAX_HEADER_KEY_LENGTH) "[a-zA-Z0-9.-]"
":%c" ":%c"
"%" S(MAX_HEADER_VALUE_LENGTH) "[ -~]" "%" S(MAX_HEADER_VALUE_LENGTH) "[ -~]"
"%c%c" "%c%c"
"%zn", "%zn",
// clang-format on // clang-format on
key, &whitespace[0], value, &whitespace[1], &whitespace[2], key, &whitespace[0], value, &whitespace[1], &whitespace[2],
&signed_bytes_read); &signed_bytes_read);
#undef S #undef S
*bytes_read = signed_bytes_read; *bytes_read = signed_bytes_read;
RETURN_IF_READ_ERROR(stream); RETURN_IF_READ_ERROR(stream);
@@ -161,8 +165,9 @@ static bool next_header(FILE *stream, size_t *restrict bytes_read,
return true; return true;
} }
HTTPRequestParseResult parse_http_request(FILE *stream, HTTPRequestParseResult parse_http_request(FILE* stream,
HTTPRequest *restrict out) { HTTPRequest* restrict out)
{
out->uri = EMPTY_STRING; out->uri = EMPTY_STRING;
out->path = EMPTY_STRING; out->path = EMPTY_STRING;
out->method = EMPTY_STRING; out->method = EMPTY_STRING;
@@ -182,23 +187,25 @@ HTTPRequestParseResult parse_http_request(FILE *stream,
return res; return res;
} }
void free_http_request(HTTPRequest *restrict req) { void free_http_request(HTTPRequest* restrict req)
{
if (req->method != EMPTY_STRING) { if (req->method != EMPTY_STRING) {
free((char *) req->method); free((char*)req->method);
} }
if (req->uri != EMPTY_STRING) { if (req->uri != EMPTY_STRING) {
free((char *) req->uri); free((char*)req->uri);
} }
free_http_header_list(req->headers); free_http_header_list(req->headers);
} }
const char *status_code_to_message(int status, size_t *restrict length) { const char* status_code_to_message(int status, size_t* restrict length)
{
static const struct { static const struct {
int code; int code;
const char *msg; const char* msg;
size_t size; size_t size;
} CODES[] = { } CODES[] = {
#define P(s, m) {s, m, sizeof(m) - 1} #define P(s, m) { s, m, sizeof(m) - 1 }
P(200, "OK"), P(200, "OK"),
P(201, "Created"), P(201, "Created"),
P(400, "Bad Request"), P(400, "Bad Request"),
@@ -221,12 +228,13 @@ const char *status_code_to_message(int status, size_t *restrict length) {
return NULL; return NULL;
} }
void format_http_response(FILE *stream, HTTPResponse *restrict resp) { void format_http_response(FILE* stream, HTTPResponse* restrict resp)
{
assert(status_code_to_message(resp->status, NULL)); assert(status_code_to_message(resp->status, NULL));
fprintf(stream, "HTTP/1.1 %d %s\r\n", resp->status, fprintf(stream, "HTTP/1.1 %d %s\r\n", resp->status,
status_code_to_message(resp->status, NULL)); status_code_to_message(resp->status, NULL));
fprintf(stream, "Content-Length: %zu\r\n", resp->body_length); fprintf(stream, "Content-Length: %zu\r\n", resp->body_length);
for (HTTPHeaderList *h = resp->headers; h; h = h->next) { for (HTTPHeaderList* h = resp->headers; h; h = h->next) {
fwrite(h->key, 1, h->key_length, stream); fwrite(h->key, 1, h->key_length, stream);
fwrite(": ", 1, 2, stream); fwrite(": ", 1, 2, stream);
fwrite(h->value, 1, h->value_length, stream); fwrite(h->value, 1, h->value_length, stream);
+3 -6
View File
@@ -14,13 +14,11 @@ struct _HTTPHeaderList {
HTTPHeaderList *next; HTTPHeaderList *next;
}; };
HTTPHeaderList *http_header_list_push(HTTPHeaderList *list, const char *key, HTTPHeaderList *http_header_list_push(HTTPHeaderList *list, const char *key, const char *value);
const char *value);
void free_http_header_list(HTTPHeaderList *list); void free_http_header_list(HTTPHeaderList *list);
const char *http_header_list_search(HTTPHeaderList *list, const char *key, const char *http_header_list_search(HTTPHeaderList *list, const char *key, const char *def);
const char *def);
typedef enum { typedef enum {
HRPR_OK, HRPR_OK,
@@ -39,8 +37,7 @@ typedef struct {
HTTPHeaderList *headers; HTTPHeaderList *headers;
} HTTPRequest; } HTTPRequest;
HTTPRequestParseResult parse_http_request(FILE *stream, HTTPRequestParseResult parse_http_request(FILE *stream, HTTPRequest *restrict out);
HTTPRequest *restrict out);
void free_http_request(HTTPRequest *restrict req); void free_http_request(HTTPRequest *restrict req);
+49 -37
View File
@@ -28,7 +28,7 @@ typedef struct {
bool help_flag; bool help_flag;
uint32_t port; uint32_t port;
size_t parallelism; size_t parallelism;
const char *address; const char* address;
} GlobalFlags; } GlobalFlags;
static const GlobalFlags DEFAULT_FLAGS = { static const GlobalFlags DEFAULT_FLAGS = {
@@ -40,20 +40,21 @@ static const GlobalFlags DEFAULT_FLAGS = {
}; };
// Return false on failure, true on success // Return false on failure, true on success
static bool parse_uint(const char *str, uintmax_t min, uintmax_t max, static bool parse_uint(const char* str, uintmax_t min, uintmax_t max,
uintmax_t *output) { uintmax_t* output)
{
if (isspace(*str) || *str == '+' || *str == '-') { if (isspace(*str) || *str == '+' || *str == '-') {
log_error("malformed number: \"%s\"", str); log_error("malformed number: \"%s\"", str);
return false; return false;
} }
errno = 0; errno = 0;
char *endptr; char* endptr;
uintmax_t conv = strtoumax(str, &endptr, 10); uintmax_t conv = strtoumax(str, &endptr, 10);
if (!*str || *endptr) { if (!*str || *endptr) {
log_error("malformed number: \"%s\"", str); log_error("malformed number: \"%s\"", str);
return false; return false;
} else if ((conv == UINTMAX_MAX && errno == ERANGE) || conv < min } else if ((conv == UINTMAX_MAX && errno == ERANGE) || conv < min
|| conv > max) { || conv > max) {
log_error("out of range: %s", str); log_error("out of range: %s", str);
return false; return false;
} }
@@ -61,13 +62,14 @@ static bool parse_uint(const char *str, uintmax_t min, uintmax_t max,
return true; return true;
} }
static void parse_cli_options(int argc, const char **argv, GlobalFlags *flags) { static void parse_cli_options(int argc, const char** argv, GlobalFlags* flags)
{
constexpr size_t MAX_PARALLELISM = SIZE_MAX; constexpr size_t MAX_PARALLELISM = SIZE_MAX;
constexpr uint32_t MAX_PORT = 65535; constexpr uint32_t MAX_PORT = 65535;
opterr = false; opterr = false;
int c; int c;
char pretty_flag[8]; char pretty_flag[8];
while ((c = getopt(argc, (char *const *) argv, ":ha:p:")) >= 0) { while ((c = getopt(argc, (char* const*)argv, ":ha:p:")) >= 0) {
if (isprint(c)) { if (isprint(c)) {
snprintf(pretty_flag, sizeof(pretty_flag), "%c", optopt); snprintf(pretty_flag, sizeof(pretty_flag), "%c", optopt);
} else { } else {
@@ -118,9 +120,10 @@ static void parse_cli_options(int argc, const char **argv, GlobalFlags *flags) {
} }
} }
static void print_help(FILE *file) { static void print_help(FILE* file)
{
fprintf(file, fprintf(file,
"usage: httpserver [-h] [-a ADDRESS] [-p PARALLELISM] <PORT>\n"); "usage: httpserver [-h] [-a ADDRESS] [-p PARALLELISM] <PORT>\n");
fprintf(file, " -h print this message, then exit\n"); fprintf(file, " -h print this message, then exit\n");
fprintf( fprintf(
file, file,
@@ -128,19 +131,20 @@ static void print_help(FILE *file) {
fprintf(file, " -a bind to ADDRESS (default: 127.0.0.1)\n"); fprintf(file, " -a bind to ADDRESS (default: 127.0.0.1)\n");
} }
constexpr int WORKER_BLOCKED_SIGNALS[] = {SIGTERM, SIGINT, SIGHUP}; constexpr int WORKER_BLOCKED_SIGNALS[] = { SIGTERM, SIGINT, SIGHUP };
constexpr size_t N_WORKER_BLOCKED_SIGNALS = constexpr size_t N_WORKER_BLOCKED_SIGNALS = sizeof(WORKER_BLOCKED_SIGNALS) / sizeof(int);
sizeof(WORKER_BLOCKED_SIGNALS) / sizeof(int);
static bool shutdown_flag = false; static bool shutdown_flag = false;
static void signal_handler(int signal) { static void signal_handler(int signal)
{
if (shutdown_flag) { if (shutdown_flag) {
_exit(signal == SIGTERM ? 0 : EXIT_FAILURE); _exit(signal == SIGTERM ? 0 : EXIT_FAILURE);
} }
shutdown_flag = true; shutdown_flag = true;
} }
static void setup_signals() { static void setup_signals()
{
struct sigaction act = { struct sigaction act = {
.sa_handler = signal_handler, .sa_handler = signal_handler,
.sa_flags = 0, .sa_flags = 0,
@@ -154,7 +158,8 @@ static void setup_signals() {
} }
} }
static int parse_result_to_status(HTTPRequestParseResult res) { static int parse_result_to_status(HTTPRequestParseResult res)
{
switch (res) { switch (res) {
case HRPR_OK: case HRPR_OK:
return 200; return 200;
@@ -170,9 +175,10 @@ static int parse_result_to_status(HTTPRequestParseResult res) {
} }
} }
static void send_simple_response(FILE *stream, int status) { static void send_simple_response(FILE* stream, int status)
{
size_t status_msg_len; size_t status_msg_len;
const char *status_msg = status_code_to_message(status, &status_msg_len); const char* status_msg = status_code_to_message(status, &status_msg_len);
HTTPResponse resp = { HTTPResponse resp = {
.status = status, .status = status,
.headers = NULL, .headers = NULL,
@@ -183,16 +189,18 @@ static void send_simple_response(FILE *stream, int status) {
fputc('\n', stream); fputc('\n', stream);
} }
static void write_audit_log_entry(HTTPRequest *restrict req, int status) { static void write_audit_log_entry(HTTPRequest* restrict req, int status)
{
fprintf(stderr, "%s,%s,%d,%s\n", req->method, req->uri, status, fprintf(stderr, "%s,%s,%d,%s\n", req->method, req->uri, status,
http_header_list_search(req->headers, "Request-ID", "0")); http_header_list_search(req->headers, "Request-ID", "0"));
} }
static void handle_get_request(FILE *conn, HTTPRequest *restrict req) { static void handle_get_request(FILE* conn, HTTPRequest* restrict req)
{
int status = 200; int status = 200;
struct stat statbuf; struct stat statbuf;
flockfile(stderr); flockfile(stderr);
FILE *file_handle = fopen(req->path, "r"); FILE* file_handle = fopen(req->path, "r");
if (!file_handle) { if (!file_handle) {
if (errno == ENOENT) { if (errno == ENOENT) {
status = 404; status = 404;
@@ -234,15 +242,15 @@ static void handle_get_request(FILE *conn, HTTPRequest *restrict req) {
fclose(file_handle); fclose(file_handle);
} }
static ssize_t get_content_length(HTTPRequest *restrict req) { static ssize_t get_content_length(HTTPRequest* restrict req)
const char *text = {
http_header_list_search(req->headers, "Content-Length", NULL); const char* text = http_header_list_search(req->headers, "Content-Length", NULL);
if (!text) { if (!text) {
return 0; return 0;
} else if (isspace(*text) || *text == '-' || *text == '+') { } else if (isspace(*text) || *text == '-' || *text == '+') {
return -1; return -1;
} }
char *endptr; char* endptr;
uintmax_t conv = strtoumax(text, &endptr, 10); uintmax_t conv = strtoumax(text, &endptr, 10);
if (*endptr || (conv == UINTMAX_MAX && errno == ERANGE) if (*endptr || (conv == UINTMAX_MAX && errno == ERANGE)
|| conv > SIZE_MAX) { || conv > SIZE_MAX) {
@@ -251,7 +259,8 @@ static ssize_t get_content_length(HTTPRequest *restrict req) {
return conv; return conv;
} }
static void handle_put_request(FILE *conn, HTTPRequest *restrict req) { static void handle_put_request(FILE* conn, HTTPRequest* restrict req)
{
ssize_t content_length = get_content_length(req); ssize_t content_length = get_content_length(req);
if (content_length < 0) { if (content_length < 0) {
write_audit_log_entry(req, 400); write_audit_log_entry(req, 400);
@@ -269,10 +278,10 @@ static void handle_put_request(FILE *conn, HTTPRequest *restrict req) {
char read_buff[4096]; char read_buff[4096];
while (content_length) { while (content_length) {
ssize_t read_size = fread(read_buff, 1, ssize_t read_size = fread(read_buff, 1,
(size_t) content_length < sizeof(read_buff) (size_t)content_length < sizeof(read_buff)
? (size_t) content_length ? (size_t)content_length
: sizeof(read_buff), : sizeof(read_buff),
conn); conn);
if (ferror(conn) || write(temp_fd, read_buff, read_size) < 0) { if (ferror(conn) || write(temp_fd, read_buff, read_size) < 0) {
write_audit_log_entry(req, 500); write_audit_log_entry(req, 500);
send_simple_response(conn, 500); send_simple_response(conn, 500);
@@ -313,8 +322,9 @@ write_status_and_unlock:
} }
} }
static void handle_connection(void *arg) { static void handle_connection(void* arg)
FILE *conn = arg; {
FILE* conn = arg;
HTTPRequest req; HTTPRequest req;
HTTPRequestParseResult res = parse_http_request(conn, &req); HTTPRequestParseResult res = parse_http_request(conn, &req);
if (res != HRPR_OK) { if (res != HRPR_OK) {
@@ -336,11 +346,13 @@ static void handle_connection(void *arg) {
fclose(conn); fclose(conn);
} }
static void fclose_free_func(void *file) { static void fclose_free_func(void* file)
{
fclose(file); fclose(file);
} }
int main(int argc, const char **argv) { int main(int argc, const char** argv)
{
setenv("POSIXLY_CORRECT", "1", true); setenv("POSIXLY_CORRECT", "1", true);
GlobalFlags flags = DEFAULT_FLAGS; GlobalFlags flags = DEFAULT_FLAGS;
parse_cli_options(argc, argv, &flags); parse_cli_options(argc, argv, &flags);
@@ -355,11 +367,11 @@ int main(int argc, const char **argv) {
for (size_t i = 0; i < N_WORKER_BLOCKED_SIGNALS; ++i) { for (size_t i = 0; i < N_WORKER_BLOCKED_SIGNALS; ++i) {
sigaddset(&worker_block_set, WORKER_BLOCKED_SIGNALS[i]); sigaddset(&worker_block_set, WORKER_BLOCKED_SIGNALS[i]);
} }
Server *server = make_server(flags.address, flags.port); Server* server = make_server(flags.address, flags.port);
if (!server) { if (!server) {
return EXIT_FAILURE; return EXIT_FAILURE;
} }
ThreadPool *pool = make_thread_pool(flags.parallelism, worker_block_set); ThreadPool* pool = make_thread_pool(flags.parallelism, worker_block_set);
if (!pool) { if (!pool) {
destroy_server(server); destroy_server(server);
return EXIT_FAILURE; return EXIT_FAILURE;
@@ -375,7 +387,7 @@ int main(int argc, const char **argv) {
} else if (conn_fd < 0) { } else if (conn_fd < 0) {
continue; continue;
} }
FILE *conn = fdopen(conn_fd, "w+"); FILE* conn = fdopen(conn_fd, "w+");
if (!conn) { if (!conn) {
close(conn_fd); close(conn_fd);
continue; continue;
+8 -5
View File
@@ -12,7 +12,8 @@ struct _Server {
int socket; int socket;
}; };
Server *make_server(const char *text_addr, uint32_t port) { Server* make_server(const char* text_addr, uint32_t port)
{
struct sockaddr_in addr = { struct sockaddr_in addr = {
.sin_family = AF_INET, .sin_family = AF_INET,
.sin_port = htons(port), .sin_port = htons(port),
@@ -21,7 +22,7 @@ Server *make_server(const char *text_addr, uint32_t port) {
log_error("bad IPv4 address: \"%s\"", text_addr); log_error("bad IPv4 address: \"%s\"", text_addr);
return NULL; return NULL;
} }
Server *server = malloc_safe(sizeof(Server)); Server* server = malloc_safe(sizeof(Server));
server->socket = socket(AF_INET, SOCK_STREAM, 0); server->socket = socket(AF_INET, SOCK_STREAM, 0);
if (server->socket < 0) { if (server->socket < 0) {
log_errno("socket"); log_errno("socket");
@@ -36,7 +37,7 @@ Server *make_server(const char *text_addr, uint32_t port) {
free(server); free(server);
return NULL; return NULL;
} }
if (bind(server->socket, (struct sockaddr *) &addr, sizeof(addr)) < 0) { if (bind(server->socket, (struct sockaddr*)&addr, sizeof(addr)) < 0) {
log_errno("bind"); log_errno("bind");
close(server->socket); close(server->socket);
free(server); free(server);
@@ -51,11 +52,13 @@ Server *make_server(const char *text_addr, uint32_t port) {
return server; return server;
} }
void destroy_server(Server *server) { void destroy_server(Server* server)
{
close(server->socket); close(server->socket);
free(server); free(server);
} }
int server_accept(Server *server) { int server_accept(Server* server)
{
return accept(server->socket, NULL, NULL); return accept(server->socket, NULL, NULL);
} }
View File
+21 -17
View File
@@ -7,24 +7,25 @@
struct thread_pool_queue { struct thread_pool_queue {
Task task; Task task;
void *arg; void* arg;
FreeFunc ff; FreeFunc ff;
struct thread_pool_queue *next; struct thread_pool_queue* next;
}; };
struct _ThreadPool { struct _ThreadPool {
bool running; bool running;
size_t nthreads; size_t nthreads;
sigset_t thread_sig_mask; sigset_t thread_sig_mask;
pthread_t *threads; pthread_t* threads;
pthread_cond_t queue_cnd; pthread_cond_t queue_cnd;
pthread_mutex_t queue_mtx; pthread_mutex_t queue_mtx;
struct thread_pool_queue *queue; struct thread_pool_queue* queue;
}; };
// return false if we need to stop // return false if we need to stop
static bool get_task(ThreadPool *pool, Task *task, void **task_arg) { static bool get_task(ThreadPool* pool, Task* task, void** task_arg)
{
pthread_mutex_lock(&pool->queue_mtx); pthread_mutex_lock(&pool->queue_mtx);
if (!pool->running) { if (!pool->running) {
pthread_mutex_unlock(&pool->queue_mtx); pthread_mutex_unlock(&pool->queue_mtx);
@@ -36,7 +37,7 @@ static bool get_task(ThreadPool *pool, Task *task, void **task_arg) {
pthread_mutex_unlock(&pool->queue_mtx); pthread_mutex_unlock(&pool->queue_mtx);
return false; return false;
} }
struct thread_pool_queue *ent = pool->queue; struct thread_pool_queue* ent = pool->queue;
if (ent) { if (ent) {
pool->queue = pool->queue->next; pool->queue = pool->queue->next;
pthread_mutex_unlock(&pool->queue_mtx); pthread_mutex_unlock(&pool->queue_mtx);
@@ -49,19 +50,21 @@ static bool get_task(ThreadPool *pool, Task *task, void **task_arg) {
abort(); abort();
} }
static void *pool_thread_function(void *arg) { static void* pool_thread_function(void* arg)
ThreadPool *pool = arg; {
ThreadPool* pool = arg;
pthread_sigmask(SIG_SETMASK, &pool->thread_sig_mask, NULL); pthread_sigmask(SIG_SETMASK, &pool->thread_sig_mask, NULL);
Task task; Task task;
void *task_arg; void* task_arg;
while (get_task(pool, &task, &task_arg)) { while (get_task(pool, &task, &task_arg)) {
task(task_arg); task(task_arg);
} }
return NULL; return NULL;
} }
ThreadPool *make_thread_pool(size_t parallelism, sigset_t sig_mask) { ThreadPool* make_thread_pool(size_t parallelism, sigset_t sig_mask)
ThreadPool *pool = malloc_safe(sizeof(ThreadPool)); {
ThreadPool* pool = malloc_safe(sizeof(ThreadPool));
pthread_mutex_init(&pool->queue_mtx, NULL); pthread_mutex_init(&pool->queue_mtx, NULL);
pthread_cond_init(&pool->queue_cnd, NULL); pthread_cond_init(&pool->queue_cnd, NULL);
pool->running = true; pool->running = true;
@@ -81,7 +84,8 @@ ThreadPool *make_thread_pool(size_t parallelism, sigset_t sig_mask) {
return pool; return pool;
} }
void destroy_thread_pool(ThreadPool *pool) { void destroy_thread_pool(ThreadPool* pool)
{
pthread_mutex_lock(&pool->queue_mtx); pthread_mutex_lock(&pool->queue_mtx);
pool->running = false; pool->running = false;
pthread_cond_broadcast(&pool->queue_cnd); pthread_cond_broadcast(&pool->queue_cnd);
@@ -92,9 +96,9 @@ void destroy_thread_pool(ThreadPool *pool) {
free(pool->threads); free(pool->threads);
pthread_mutex_destroy(&pool->queue_mtx); pthread_mutex_destroy(&pool->queue_mtx);
pthread_cond_destroy(&pool->queue_cnd); pthread_cond_destroy(&pool->queue_cnd);
struct thread_pool_queue *queue = pool->queue; struct thread_pool_queue* queue = pool->queue;
while (queue) { while (queue) {
struct thread_pool_queue *next = queue->next; struct thread_pool_queue* next = queue->next;
if (queue->ff) { if (queue->ff) {
queue->ff(queue->arg); queue->ff(queue->arg);
} }
@@ -104,10 +108,10 @@ void destroy_thread_pool(ThreadPool *pool) {
free(pool); free(pool);
} }
void thread_pool_enqueue(ThreadPool *pool, Task task, void *arg, FreeFunc ff) { void thread_pool_enqueue(ThreadPool* pool, Task task, void* arg, FreeFunc ff)
{
pthread_mutex_lock(&pool->queue_mtx); pthread_mutex_lock(&pool->queue_mtx);
struct thread_pool_queue *new = struct thread_pool_queue* new = malloc_safe(sizeof(struct thread_pool_queue));
malloc_safe(sizeof(struct thread_pool_queue));
new->task = task; new->task = task;
new->arg = arg; new->arg = arg;
new->ff = ff; new->ff = ff;
View File
+11 -6
View File
@@ -7,9 +7,10 @@
#include <string.h> #include <string.h>
#include <time.h> #include <time.h>
void *realloc_safe(void *oldptr, size_t size) { void* realloc_safe(void* oldptr, size_t size)
{
static const char OOM_MSG[] = "fatal: out of memory\n"; static const char OOM_MSG[] = "fatal: out of memory\n";
void *ptr = realloc(oldptr, size); void* ptr = realloc(oldptr, size);
if (size && !ptr) { if (size && !ptr) {
fwrite(OOM_MSG, 1, sizeof(OOM_MSG) - 1, stderr); fwrite(OOM_MSG, 1, sizeof(OOM_MSG) - 1, stderr);
abort(); abort();
@@ -17,12 +18,14 @@ void *realloc_safe(void *oldptr, size_t size) {
return ptr; return ptr;
} }
void *malloc_safe(size_t size) { void* malloc_safe(size_t size)
{
return realloc_safe(NULL, size); return realloc_safe(NULL, size);
} }
// asprintf is not POSIX // asprintf is not POSIX
int alloc_sprintf(char *restrict *restrict out, const char *restrict fmt, ...) { int alloc_sprintf(char* restrict* restrict out, const char* restrict fmt, ...)
{
va_list args; va_list args;
va_start(args, fmt); va_start(args, fmt);
va_list args2; va_list args2;
@@ -35,7 +38,8 @@ int alloc_sprintf(char *restrict *restrict out, const char *restrict fmt, ...) {
return written; return written;
} }
void log_error(const char *restrict fmt, ...) { void log_error(const char* restrict fmt, ...)
{
time_t cur_time = time(NULL); time_t cur_time = time(NULL);
struct tm tm; struct tm tm;
localtime_r(&cur_time, &tm); localtime_r(&cur_time, &tm);
@@ -49,6 +53,7 @@ void log_error(const char *restrict fmt, ...) {
fputc('\n', stderr); fputc('\n', stderr);
} }
void log_errno(const char *detail) { void log_errno(const char* detail)
{
log_error("%s: %s", detail, strerror(errno)); log_error("%s: %s", detail, strerror(errno));
} }
+2 -2
View File
@@ -4,9 +4,9 @@
#include <stddef.h> #include <stddef.h>
#if __has_attribute(format) #if __has_attribute(format)
# define PRINTF_LIKE(i, j) __attribute__((format(printf, i, j))) #define PRINTF_LIKE(i, j) __attribute__((format(printf, i, j)))
#else #else
# define PRINTF_LIKE(i, j) #define PRINTF_LIKE(i, j)
#endif #endif
void *realloc_safe(void *oldptr, size_t size); void *realloc_safe(void *oldptr, size_t size);