From 1005650ea61d0c50bf360a4b657042e53840cf2f Mon Sep 17 00:00:00 2001 From: Alexander Rosenberg Date: Mon, 25 May 2026 17:25:50 -0700 Subject: [PATCH] Finish initial implementation --- Makefile | 9 +- src/http.c | 236 ++++++++++++++++++++++++++++++++++++ src/http.h | 57 +++++++++ src/main.c | 309 +++++++++++++++++++++++++++++++++++++++++++---- src/server.c | 61 ++++++++++ src/server.h | 18 +++ src/threadpool.c | 95 +++++++++++++-- src/threadpool.h | 7 +- src/util.c | 57 ++++++--- src/util.h | 18 +-- 10 files changed, 804 insertions(+), 63 deletions(-) create mode 100644 src/http.c create mode 100644 src/http.h create mode 100644 src/server.c create mode 100644 src/server.h diff --git a/Makefile b/Makefile index 5593fee..e6153ef 100644 --- a/Makefile +++ b/Makefile @@ -1,7 +1,7 @@ CC=gcc -CFLAGS=-Wall -Wextra -Wpedantic -std=c23 -D_POSIX_C_SOURCE=200112L -O2 +CFLAGS=-Wall -Wextra -Wpedantic -std=c23 -D_POSIX_C_SOURCE=200809L -pthread -O2 CFLAGS+=-Og -g -fsanitize=address,undefined -SRCS=main.c threadpool.c util.c +SRCS=main.c threadpool.c util.c server.c http.c OBJS=$(SRCS:%.c=bin/%.o) @@ -10,9 +10,10 @@ all: httpserver httpserver: $(OBJS) $(CC) $(CFLAGS) -o $@ $^ -bin/%.o: src/%.c +# Auto-rebuild if Makefile changes +bin/%.o: src/%.c Makefile @mkdir -p bin/deps/ - $(CC) $(CFLAGS) -MD -MF $(patsubst src/%.c,bin/deps/%.d,$^) -c -o $@ $^ + $(CC) $(CFLAGS) -MD -MF $(patsubst src/%.c,bin/deps/%.d,$<) -c -o $@ $< include $(SRCS:%.c/bin/deps/%.d) diff --git a/src/http.c b/src/http.c new file mode 100644 index 0000000..8daea48 --- /dev/null +++ b/src/http.c @@ -0,0 +1,236 @@ +#include "http.h" + +#include +#include +#include +#include + +static const char *EMPTY_STRING = ""; + +HTTPHeaderList *http_header_list_push(HTTPHeaderList *list, const char *key, + const char *value) { + HTTPHeaderList *new = malloc(sizeof(HTTPHeaderList)); + if (!new) { + return NULL; + } + new->key_length = strlen(key); + new->key = malloc(new->key_length + 1); + if (!new->key) { + free(new); + return NULL; + } + memcpy(new->key, key, new->key_length + 1); + new->value_length = strlen(value); + new->value = malloc(new->value_length + 1); + if (!new->value) { + free(new->key); + free(new); + return NULL; + } + memcpy(new->value, value, new->value_length + 1); + new->next = list; + return new; +} + +void free_http_header_list(HTTPHeaderList *list) { + while (list) { + HTTPHeaderList *next = list->next; + free(list->key); + free(list->value); + free(list); + list = next; + } +} + +const char *http_header_list_search(HTTPHeaderList *list, const char *key, + const char *def) { + while (list) { + if (strcmp(list->key, key) == 0) { + return list->value; + } + list = list->next; + } + return def; +} + +#define MAX_REQUEST_LENGTH 16384 +#define MAX_METHOD_LENGTH 16 +#define MAX_URI_LENGTH 256 +#define MAX_VERSION_LENGTH 3 +#define MAX_HEADER_KEY_LENGTH 2048 +#define MAX_HEADER_VALUE_LENGTH 2048 + +#define RETURN_IF_READ_ERROR(s) \ + if (ferror((s))) { \ + return HRPR_READ_FAILED; \ + } + +// if an error occured, req->uri is not allocated +static HTTPRequestParseResult parse_method_uri_line(FILE *stream, + size_t *restrict bytes_read, + HTTPRequest *restrict req) { + // allow for some leeway in passing incorrect methods + char method[MAX_METHOD_LENGTH + 1]; + char uri[MAX_URI_LENGTH + 1]; + char version_str[MAX_VERSION_LENGTH + 1]; + char whitespace[4]; + ssize_t signed_bytes_read; +#define S1(s) #s +#define S(s) S1(s) + int nconv = fscanf( + stream, + // clang-format off + "%" S(MAX_METHOD_LENGTH) "[A-Z]" + "%c" + "%" S(MAX_URI_LENGTH) "[^ \n\r]" + "%c" + "HTTP/%" S(MAX_VERSION_LENGTH) "[0-9.]" + "%c%c" + "%zn", + // clang-format on + method, &whitespace[0], uri, &whitespace[1], version_str, + &whitespace[2], &whitespace[3], &signed_bytes_read); +#undef S + *bytes_read = signed_bytes_read; + RETURN_IF_READ_ERROR(stream); + if (nconv != 7 || memcmp(whitespace, " \r\n", 4) != 0) { + return HRPR_BAD_FORMAT; + } + req->method = strdup(method); + if (!req->method) { + return HRPR_NO_MEM; + } + req->uri = strdup(uri); + if (!req->uri) { + return HRPR_NO_MEM; + } + req->path = req->uri; + while (*req->path == '/') { + ++req->path; + } + return strcmp(version_str, "1.1") == 0 ? HRPR_OK : HRPR_BAD_VERSION; +} + +// return true if there are more headers and no error occurred, false otherwise +// this will *not* free LIST if an error occurs +static bool next_header(FILE *stream, size_t *restrict bytes_read, + HTTPRequestParseResult *restrict res, + HTTPHeaderList *restrict *restrict list) { + char c = fgetc(stream); + RETURN_IF_READ_ERROR(stream); + if (c == '\r') { + c = fgetc(stream); + RETURN_IF_READ_ERROR(stream); + if (c != '\n') { + *res = HRPR_BAD_FORMAT; + return false; + } + *bytes_read = 2; + *res = HRPR_OK; + return false; + } + ungetc(c, stream); + char key[MAX_HEADER_KEY_LENGTH + 1]; + char value[MAX_HEADER_VALUE_LENGTH + 1]; + char whitespace[3]; + ssize_t signed_bytes_read; +#define S1(s) #s +#define S(s) S1(s) + int nconv = + fscanf(stream, + // clang-format off + "%" S(MAX_HEADER_KEY_LENGTH) "[a-zA-Z0-9.-]" + ":%c" + "%" S(MAX_HEADER_VALUE_LENGTH) "[ -~]" + "%c%c" + "%zn", + // clang-format on + key, &whitespace[0], value, &whitespace[1], &whitespace[2], + &signed_bytes_read); +#undef S + *bytes_read = signed_bytes_read; + RETURN_IF_READ_ERROR(stream); + if (nconv != 5 || memcmp(whitespace, " \r\n", 3) != 0) { + return HRPR_BAD_FORMAT; + } + *list = http_header_list_push(*list, key, value); + if (!*list) { + *res = HRPR_NO_MEM; + return false; + } + return true; +} + +HTTPRequestParseResult parse_http_request(FILE *stream, + HTTPRequest *restrict out) { + out->uri = EMPTY_STRING; + out->path = EMPTY_STRING; + out->method = EMPTY_STRING; + size_t total_bytes; + HTTPRequestParseResult res; + if ((res = parse_method_uri_line(stream, &total_bytes, out)) != HRPR_OK) { + return res; + } + out->headers = NULL; + size_t bytes_read; + while (next_header(stream, &bytes_read, &res, &out->headers)) { + if ((total_bytes += bytes_read) > MAX_REQUEST_LENGTH) { + res = HRPR_BAD_FORMAT; + break; + } + } + return res; +} + +void free_http_request(HTTPRequest *restrict req) { + if (req->method != EMPTY_STRING) { + free((char *) req->method); + } + if (req->uri != EMPTY_STRING) { + free((char *) req->uri); + } + free_http_header_list(req->headers); +} + +const char *status_code_to_message(int status, size_t *restrict length) { + static const struct { + int code; + const char *msg; + size_t size; + } CODES[] = { +#define P(s, m) {s, m, sizeof(m) - 1} + P(200, "OK"), + P(201, "Created"), + P(400, "Bad Request"), + P(403, "Forbidden"), + P(404, "Not Found"), + P(500, "Internal Server Error"), + P(501, "Not Implemented"), + P(505, "Version Not Supported"), +#undef P + }; + static const size_t NCODES = sizeof(CODES) / sizeof(CODES[0]); + for (size_t i = 0; i < NCODES; ++i) { + if (status == CODES[i].code) { + if (length) { + *length = CODES[i].size; + } + return CODES[i].msg; + } + } + return NULL; +} + +void format_http_response(FILE *stream, HTTPResponse *restrict resp) { + assert(status_code_to_message(resp->status, NULL)); + fprintf(stream, "HTTP/1.1 %d %s\r\n", resp->status, + status_code_to_message(resp->status, NULL)); + fprintf(stream, "Content-Length: %zu\r\n", resp->body_length); + for (HTTPHeaderList *h = resp->headers; h; h = h->next) { + fwrite(h->key, 1, h->key_length, stream); + fwrite(": ", 1, 2, stream); + fwrite(h->value, 1, h->value_length, stream); + fwrite("\r\n", 1, 2, stream); + } + fwrite("\r\n", 1, 2, stream); +} diff --git a/src/http.h b/src/http.h new file mode 100644 index 0000000..4aedabd --- /dev/null +++ b/src/http.h @@ -0,0 +1,57 @@ +#ifndef INCLUDED_HTTP_H +#define INCLUDED_HTTP_H + +#include +#include + +typedef struct _HTTPHeaderList HTTPHeaderList; + +struct _HTTPHeaderList { + char *key; + size_t key_length; + char *value; + size_t value_length; + HTTPHeaderList *next; +}; + +HTTPHeaderList *http_header_list_push(HTTPHeaderList *list, const char *key, + const char *value); + +void free_http_header_list(HTTPHeaderList *list); + +const char *http_header_list_search(HTTPHeaderList *list, const char *key, + const char *def); + +typedef enum { + HRPR_OK, + HRPR_NO_MEM, + HRPR_BAD_VERSION, + HRPR_BAD_FORMAT, + HRPR_READ_FAILED, +} HTTPRequestParseResult; + +typedef struct { + const char *method; + const char *uri; + // relative (shares memory with uri) + const char *path; + + HTTPHeaderList *headers; +} HTTPRequest; + +HTTPRequestParseResult parse_http_request(FILE *stream, + HTTPRequest *restrict out); + +void free_http_request(HTTPRequest *restrict req); + +typedef struct { + int status; + HTTPHeaderList *headers; + size_t body_length; +} HTTPResponse; + +const char *status_code_to_message(int status, size_t *restrict length); + +void format_http_response(FILE *stream, HTTPResponse *restrict resp); + +#endif diff --git a/src/main.c b/src/main.c index bab0de9..4088372 100644 --- a/src/main.c +++ b/src/main.c @@ -1,17 +1,25 @@ +#include "http.h" +#include "server.h" +#include "threadpool.h" #include "util.h" #include #include -#include +#include +#include +#include #include #include +#include +#include #include typedef struct { bool opt_error; bool help_flag; - int port; - int parallelism; + uint32_t port; + size_t parallelism; + const char *address; } GlobalFlags; static const GlobalFlags DEFAULT_FLAGS = { @@ -19,18 +27,24 @@ static const GlobalFlags DEFAULT_FLAGS = { .help_flag = false, .port = 0, .parallelism = 1, + .address = "127.0.0.1", }; // Return false on failure, true on success -static bool parse_int(const char *str, int min, int max, int *output) { +static bool parse_uint(const char *str, uintmax_t min, uintmax_t max, + uintmax_t *output) { + if (isspace(*str) || *str == '+' || *str == '-') { + log_error("malformed number: \"%s\"", str); + return false; + } errno = 0; char *endptr; - long conv = strtol(str, &endptr, 10); + uintmax_t conv = strtoumax(str, &endptr, 10); if (!*str || *endptr) { log_error("malformed number: \"%s\"", str); return false; - } else if (((conv == LONG_MIN || conv == LONG_MAX) && errno == ERANGE) - || conv < min || conv > max) { + } else if ((conv == UINTMAX_MAX && errno == ERANGE) || conv < min + || conv > max) { log_error("out of range: %s", str); return false; } @@ -38,31 +52,36 @@ static bool parse_int(const char *str, int min, int max, int *output) { return true; } -void parse_cli_options(int argc, const char **argv, GlobalFlags *flags) { - constexpr unsigned long MAX_PARALLELISM = INT_MAX; - constexpr unsigned long MAX_PORT = 65535; +static void parse_cli_options(int argc, const char **argv, GlobalFlags *flags) { + constexpr size_t MAX_PARALLELISM = SIZE_MAX; + constexpr uint32_t MAX_PORT = 65535; opterr = false; int c; char pretty_flag[8]; - while ((c = getopt(argc, (char *const *) argv, "hp:")) >= 0) { + while ((c = getopt(argc, (char *const *) argv, ":ha:p:")) >= 0) { if (isprint(c)) { - snprintf(pretty_flag, sizeof(pretty_flag), "%c", c); + snprintf(pretty_flag, sizeof(pretty_flag), "%c", optopt); } else { #pragma GCC diagnostic push #pragma GCC diagnostic ignored "-Wformat-truncation" // `c` is actually in the range [0,255], and so will be at most 2 // hex digits - snprintf(pretty_flag, sizeof(pretty_flag), "0x%02x", c); + snprintf(pretty_flag, sizeof(pretty_flag), "0x%02x", optopt); #pragma GCC diagnostic pop } switch (c) { case 'h': flags->help_flag = true; return; - case 'p': - if (!parse_int(optarg, 1, MAX_PARALLELISM, &flags->parallelism)) { + case 'p': { + uintmax_t conv; + if (!parse_uint(optarg, 1, MAX_PARALLELISM, &conv)) { flags->opt_error = true; } + flags->parallelism = conv; + } break; + case 'a': + flags->address = optarg; break; case ':': flags->opt_error = true; @@ -74,28 +93,242 @@ void parse_cli_options(int argc, const char **argv, GlobalFlags *flags) { break; } } - if (optind == argc) { + uintmax_t port_conv; + if (flags->opt_error) { + return; + } else if (optind == argc) { flags->opt_error = true; log_error("no port provided"); } else if (optind != argc - 1) { flags->opt_error = true; log_error("extra positional arguments after port"); - } else if (!parse_int(argv[optind], 1, MAX_PORT, &flags->port)) { + } else if (!parse_uint(argv[optind], 1, MAX_PORT, &port_conv)) { flags->opt_error = true; + } else { + flags->port = port_conv; } } -void print_help(FILE *file) { - fprintf(file, "usage: httpserver [-h] [-p PARALLELISM] \n"); +static void print_help(FILE *file) { + fprintf(file, + "usage: httpserver [-h] [-a ADDRESS] [-p PARALLELISM] \n"); fprintf(file, " -h print this message, then exit\n"); fprintf( file, " -p use PARALLELISM threads for processing requests (default: 1)\n"); + fprintf(file, " -a bind to ADDRESS (default: 127.0.0.1)\n"); +} + +constexpr int WORKER_BLOCKED_SIGNALS[] = {SIGTERM, SIGINT, SIGHUP}; +constexpr size_t N_WORKER_BLOCKED_SIGNALS = + sizeof(WORKER_BLOCKED_SIGNALS) / sizeof(int); +static bool shutdown_flag = false; + +static void signal_handler(int signal) { + if (shutdown_flag) { + _exit(signal == SIGTERM ? 0 : EXIT_FAILURE); + } + shutdown_flag = true; +} + +static void setup_signals() { + struct sigaction act = { + .sa_handler = signal_handler, + .sa_flags = 0, + }; + sigemptyset(&act.sa_mask); + for (size_t i = 0; i < N_WORKER_BLOCKED_SIGNALS; ++i) { + if (sigaction(WORKER_BLOCKED_SIGNALS[i], &act, NULL) < 0) { + log_errno("sigaction"); + exit(EXIT_FAILURE); + } + } +} + +static int parse_result_to_status(HTTPRequestParseResult res) { + switch (res) { + case HRPR_OK: + return 200; + case HRPR_NO_MEM: + case HRPR_READ_FAILED: + return 500; + case HRPR_BAD_VERSION: + return 505; + case HRPR_BAD_FORMAT: + return 400; + default: + abort(); + } +} + +static void send_simple_response(FILE *stream, int status) { + size_t status_msg_len; + const char *status_msg = status_code_to_message(status, &status_msg_len); + HTTPResponse resp = { + .status = status, + .headers = NULL, + .body_length = status_msg_len + 1, + }; + format_http_response(stream, &resp); + fwrite(status_msg, 1, status_msg_len, stream); + fputc('\n', stream); +} + +static void write_audit_log_entry(HTTPRequest *restrict req, int status) { + fprintf(stderr, "%s,%s,%d,%s\n", req->method, req->uri, status, + http_header_list_search(req->headers, "Request-ID", "0")); +} + +static void handle_get_request(FILE *conn, HTTPRequest *restrict req) { + int status = 200; + struct stat statbuf; + flockfile(stderr); + FILE *file_handle = fopen(req->path, "r"); + if (!file_handle) { + if (errno == ENOENT) { + status = 404; + } else if (errno == EACCES) { + status = 403; + } else { + status = 500; + } + } + if (file_handle && fstat(fileno(file_handle), &statbuf) < 0) { + status = errno == EACCES ? 403 : 500; + } + write_audit_log_entry(req, status); + funlockfile(stderr); + if (status != 200) { + if (file_handle) { + fclose(file_handle); + } + send_simple_response(conn, status); + return; + } + HTTPResponse resp = { + .status = 200, + .headers = NULL, + .body_length = statbuf.st_size, + }; + format_http_response(conn, &resp); + char read_buf[4096]; + while (!feof(file_handle)) { + if (ferror(file_handle)) { + // can't really do anything about it here, just break + break; + } + size_t count = fread(read_buf, 1, sizeof(read_buf), file_handle); + if (count) { + fwrite(read_buf, 1, count, conn); + } + } + fclose(file_handle); +} + +static ssize_t get_content_length(HTTPRequest *restrict req) { + const char *text = + http_header_list_search(req->headers, "Content-Length", NULL); + if (!text) { + return 0; + } else if (isspace(*text) || *text == '-' || *text == '+') { + return -1; + } + char *endptr; + uintmax_t conv = strtoumax(text, &endptr, 10); + if (*endptr || (conv == UINTMAX_MAX && errno == ERANGE) + || conv > SIZE_MAX) { + return -1; + } + return conv; +} + +static void handle_put_request(FILE *conn, HTTPRequest *restrict req) { + ssize_t content_length = get_content_length(req); + if (content_length < 0) { + write_audit_log_entry(req, 400); + send_simple_response(conn, 400); + fclose(conn); + return; + } + char temp_file[] = "temp_fileXXXXXX"; + int temp_fd = mkstemp(temp_file); + if (temp_fd < 0) { + write_audit_log_entry(req, 500); + send_simple_response(conn, 500); + return; + } + char read_buff[4096]; + while (content_length) { + ssize_t read_size = fread(read_buff, 1, + (size_t) content_length < sizeof(read_buff) + ? (size_t) content_length + : sizeof(read_buff), + conn); + if (ferror(conn) || write(temp_fd, read_buff, read_size) < 0) { + write_audit_log_entry(req, 500); + send_simple_response(conn, 500); + close(temp_fd); + unlink(temp_file); + return; + } + content_length -= read_size; + } + close(temp_fd); + int status; + bool need_unlink = true; + flockfile(stderr); + int creat_fd = open(req->path, O_RDONLY | O_CREAT | O_EXCL, 0644); + if (creat_fd < 0 && errno == EEXIST) { + // file existed + status = 200; + } else if (creat_fd <= 0) { + // actual error + status = errno == EACCES ? 403 : 500; + goto write_status_and_unlock; + } else { + // created the file + status = 201; + close(creat_fd); + } + if (rename(temp_file, req->path) < 0) { + status = errno == EACCES ? 403 : 500; + goto write_status_and_unlock; + } + need_unlink = false; +write_status_and_unlock: + write_audit_log_entry(req, status); + funlockfile(stderr); + send_simple_response(conn, status); + if (need_unlink) { + unlink(temp_file); + } +} + +static void handle_connection(void *arg) { + FILE *conn = arg; + HTTPRequest req; + HTTPRequestParseResult res = parse_http_request(conn, &req); + if (res != HRPR_OK) { + int status = parse_result_to_status(res); + write_audit_log_entry(&req, status); + send_simple_response(conn, status); + } else if (strchr(req.path, '/')) { + write_audit_log_entry(&req, 400); + send_simple_response(conn, 400); + } else if (strcmp(req.method, "GET") == 0) { + handle_get_request(conn, &req); + } else if (strcmp(req.method, "PUT") == 0) { + handle_put_request(conn, &req); + } else { + write_audit_log_entry(&req, 501); + send_simple_response(conn, 501); + } + free_http_request(&req); + fclose(conn); } int main(int argc, const char **argv) { setenv("POSIXLY_CORRECT", "1", true); - set_exec_name(argv[0]); GlobalFlags flags = DEFAULT_FLAGS; parse_cli_options(argc, argv, &flags); if (flags.help_flag) { @@ -104,7 +337,39 @@ int main(int argc, const char **argv) { } else if (flags.opt_error) { return EXIT_FAILURE; } - printf("Port: %d\n", flags.port); - printf("Parallelism: %d\n", flags.parallelism); + sigset_t worker_block_set; + sigemptyset(&worker_block_set); + for (size_t i = 0; i < N_WORKER_BLOCKED_SIGNALS; ++i) { + sigaddset(&worker_block_set, WORKER_BLOCKED_SIGNALS[i]); + } + Server *server = make_server(flags.address, flags.port); + if (!server) { + return EXIT_FAILURE; + } + ThreadPool *pool = make_thread_pool(flags.parallelism, worker_block_set); + if (!pool) { + destroy_server(server); + return EXIT_FAILURE; + } + setup_signals(); + while (true) { + int conn_fd = server_accept(server); + if (shutdown_flag) { + if (conn_fd >= 0) { + close(conn_fd); + } + break; + } else if (conn_fd < 0) { + continue; + } + FILE *conn = fdopen(conn_fd, "w+"); + if (!conn) { + close(conn_fd); + continue; + } + thread_pool_enqueue(pool, handle_connection, conn); + } + destroy_thread_pool(pool); + destroy_server(server); return 0; } diff --git a/src/server.c b/src/server.c new file mode 100644 index 0000000..4772de2 --- /dev/null +++ b/src/server.c @@ -0,0 +1,61 @@ +#include "server.h" + +#include "util.h" + +#include +#include +#include +#include +#include + +struct _Server { + int socket; +}; + +Server *make_server(const char *text_addr, uint32_t port) { + struct sockaddr_in addr = { + .sin_family = AF_INET, + .sin_port = htons(port), + }; + if (inet_pton(AF_INET, text_addr, &addr.sin_addr) != 1) { + log_error("bad IPv4 address: \"%s\"", text_addr); + return NULL; + } + Server *server = malloc_safe(sizeof(Server)); + server->socket = socket(AF_INET, SOCK_STREAM, 0); + if (server->socket < 0) { + log_errno("socket"); + free(server); + return NULL; + } + int on = 1; + if (setsockopt(server->socket, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on)) + < 0) { + log_errno("setsockopt SO_REUSEADDR"); + close(server->socket); + free(server); + return NULL; + } + if (bind(server->socket, (struct sockaddr *) &addr, sizeof(addr)) < 0) { + log_errno("bind"); + close(server->socket); + free(server); + return NULL; + } + if (listen(server->socket, SERVER_BACKLOG) < 0) { + log_errno("listen"); + close(server->socket); + free(server); + return NULL; + } + return server; +} + +void destroy_server(Server *server) { + close(server->socket); + free(server); +} + +int server_accept(Server *server) { + return accept(server->socket, NULL, NULL); +} diff --git a/src/server.h b/src/server.h new file mode 100644 index 0000000..310aa11 --- /dev/null +++ b/src/server.h @@ -0,0 +1,18 @@ +#ifndef INCLUDED_SERVER_H +#define INCLUDED_SERVER_H + +#include +#include + +constexpr int SERVER_BACKLOG = 32; + +typedef struct _Server Server; + +Server *make_server(const char *text_addr, uint32_t port); + +void destroy_server(Server *server); + +// same return value as accept(2) +int server_accept(Server *server); + +#endif diff --git a/src/threadpool.c b/src/threadpool.c index 012e632..1333001 100644 --- a/src/threadpool.c +++ b/src/threadpool.c @@ -1,7 +1,9 @@ #include "threadpool.h" +#include "util.h" + +#include #include -#include struct thread_pool_queue { Task task; @@ -10,17 +12,96 @@ struct thread_pool_queue { }; struct _ThreadPool { + bool running; size_t nthreads; - thrd_t *threads; + sigset_t thread_sig_mask; + pthread_t *threads; + pthread_cond_t queue_cnd; + pthread_mutex_t queue_mtx; struct thread_pool_queue *queue; }; -ThreadPool *make_thread_pool(int parallelism) { - ThreadPool *pool = malloc(sizeof(ThreadPool)); - pool->nthreads = parallelism; +// return false if we need to stop +static bool get_task(ThreadPool *pool, Task *task, void **task_arg) { + pthread_mutex_lock(&pool->queue_mtx); + if (!pool->running) { + pthread_mutex_unlock(&pool->queue_mtx); + return false; + } + while (true) { + pthread_cond_wait(&pool->queue_cnd, &pool->queue_mtx); + if (!pool->running) { + pthread_mutex_unlock(&pool->queue_mtx); + return false; + } + struct thread_pool_queue *ent = pool->queue; + if (ent) { + pool->queue = pool->queue->next; + pthread_mutex_unlock(&pool->queue_mtx); + *task = ent->task; + *task_arg = ent->arg; + free(ent); + return true; + } + } + abort(); } -void destroy_thread_pool(ThreadPool *pool) {} +static void *pool_thread_function(void *arg) { + ThreadPool *pool = arg; + pthread_sigmask(SIG_SETMASK, &pool->thread_sig_mask, NULL); + Task task; + void *task_arg; + while (get_task(pool, &task, &task_arg)) { + task(task_arg); + } + return NULL; +} -void thread_pool_enqueue(Task task, void *arg) {} +ThreadPool *make_thread_pool(size_t parallelism, sigset_t sig_mask) { + ThreadPool *pool = malloc_safe(sizeof(ThreadPool)); + pthread_mutex_init(&pool->queue_mtx, NULL); + pthread_cond_init(&pool->queue_cnd, NULL); + pool->running = true; + pool->queue = NULL; + pool->nthreads = parallelism; + pool->thread_sig_mask = sig_mask; + pool->threads = malloc_safe(sizeof(pthread_t) * parallelism); + // create don't race with any received signals + sigset_t sset_full; + sigfillset(&sset_full); + sigset_t sset_save; + pthread_sigmask(SIG_SETMASK, &sset_full, &sset_save); + for (size_t i = 0; i < parallelism; ++i) { + pthread_create(&pool->threads[i], NULL, &pool_thread_function, pool); + } + pthread_sigmask(SIG_SETMASK, &sset_save, NULL); + return pool; +} + +void destroy_thread_pool(ThreadPool *pool) { + pthread_mutex_lock(&pool->queue_mtx); + pool->running = false; + pthread_cond_broadcast(&pool->queue_cnd); + pthread_mutex_unlock(&pool->queue_mtx); + for (size_t i = 0; i < pool->nthreads; ++i) { + pthread_join(pool->threads[i], NULL); + } + free(pool->threads); + pthread_mutex_destroy(&pool->queue_mtx); + pthread_cond_destroy(&pool->queue_cnd); + free(pool); +} + +void thread_pool_enqueue(ThreadPool *pool, Task task, void *arg) { + pthread_mutex_lock(&pool->queue_mtx); + struct thread_pool_queue *new = + malloc_safe(sizeof(struct thread_pool_queue)); + new->task = task; + new->arg = arg; + new->next = pool->queue; + pool->queue = new; + pthread_cond_signal(&pool->queue_cnd); + pthread_mutex_unlock(&pool->queue_mtx); +} diff --git a/src/threadpool.h b/src/threadpool.h index 14efa9a..4007d56 100644 --- a/src/threadpool.h +++ b/src/threadpool.h @@ -1,13 +1,16 @@ #ifndef INCLUDED_THREAD_POOL_H #define INCLUDED_THREAD_POOL_H +#include // IWYU pragma: keep +#include + typedef struct _ThreadPool ThreadPool; typedef void (*Task)(void *); -ThreadPool *make_thread_pool(int parallelism); +ThreadPool *make_thread_pool(size_t parallelism, sigset_t sig_mask); void destroy_thread_pool(ThreadPool *pool); -void thread_pool_enqueue(Task task, void *arg); +void thread_pool_enqueue(ThreadPool *pool, Task task, void *arg); #endif diff --git a/src/util.c b/src/util.c index 23bd715..e2c5039 100644 --- a/src/util.c +++ b/src/util.c @@ -1,35 +1,54 @@ #include "util.h" +#include #include #include #include #include +#include -#define DEFAULT_EXEC_NAME "httpserver" - -const char *EXEC_NAME = DEFAULT_EXEC_NAME; -static bool need_free_out_of_memory_msg = false; -NO_SANITIZE -static const char *out_of_memory_msg = - DEFAULT_EXEC_NAME ": error: out of memory\n"; - -void set_exec_name(const char *argv0) { - const char *slash = strrchr(argv0, '/'); - if (!slash) { - EXEC_NAME = argv0; - } else { - EXEC_NAME = slash + 1; - } - if (need_free_out_of_memory_msg) { - free((char *) out_of_memory_msg); +void *realloc_safe(void *oldptr, size_t size) { + static const char OOM_MSG[] = "fatal: out of memory\n"; + void *ptr = realloc(oldptr, size); + if (size && !ptr) { + fwrite(OOM_MSG, 1, sizeof(OOM_MSG) - 1, stderr); + abort(); } + return ptr; } -void log_error(const char *fmt, ...) { - fprintf(stderr, "%s: error: ", EXEC_NAME); +void *malloc_safe(size_t size) { + return realloc_safe(NULL, size); +} + +// asprintf is not POSIX +int alloc_sprintf(char *restrict *restrict out, const char *restrict fmt, ...) { + va_list args; + va_start(args, fmt); + va_list args2; + va_copy(args2, args); + int need = vsnprintf(NULL, 0, fmt, args2); + va_end(args2); + *out = realloc_safe(*out, need + 1); + int written = vsnprintf(*out, need + 1, fmt, args); + va_end(args); + return written; +} + +void log_error(const char *restrict fmt, ...) { + time_t cur_time = time(NULL); + struct tm tm; + localtime_r(&cur_time, &tm); + char time_str[32]; + strftime(time_str, sizeof(time_str), "%c", &tm); + fprintf(stderr, "[%s] error: ", time_str); va_list args; va_start(args, fmt); vfprintf(stderr, fmt, args); va_end(args); fputc('\n', stderr); } + +void log_errno(const char *detail) { + log_error("%s: %s", detail, strerror(errno)); +} diff --git a/src/util.h b/src/util.h index 18abbfd..65570c0 100644 --- a/src/util.h +++ b/src/util.h @@ -1,23 +1,23 @@ #ifndef INCLUDED_UTIL_H #define INCLUDED_UTIL_H +#include + #if __has_attribute(format) # define PRINTF_LIKE(i, j) __attribute__((format(printf, i, j))) #else # define PRINTF_LIKE(i, j) #endif -#if __has_attribute(no_sanitize_address) -# define NO_SANITIZE __attribute__((no_sanitize_address)) -#else -# define NO_SANITIZE -#endif +void *realloc_safe(void *oldptr, size_t size); +void *malloc_safe(size_t size); -extern const char *EXEC_NAME; - -void set_exec_name(const char *argv0); +// asprintf is not POSIX +PRINTF_LIKE(2, 3) +int alloc_sprintf(char *restrict *restrict out, const char *restrict fmt, ...); PRINTF_LIKE(1, 2) -void log_error(const char *fmt, ...); +void log_error(const char *restrict fmt, ...); +void log_errno(const char *detail); #endif