Finish initial implementation

This commit is contained in:
2026-05-25 17:25:50 -07:00
parent 92acec0893
commit 1005650ea6
10 changed files with 804 additions and 63 deletions
+5 -4
View File
@@ -1,7 +1,7 @@
CC=gcc CC=gcc
CFLAGS=-Wall -Wextra -Wpedantic -std=c23 -D_POSIX_C_SOURCE=200112L -O2 CFLAGS=-Wall -Wextra -Wpedantic -std=c23 -D_POSIX_C_SOURCE=200809L -pthread -O2
CFLAGS+=-Og -g -fsanitize=address,undefined CFLAGS+=-Og -g -fsanitize=address,undefined
SRCS=main.c threadpool.c util.c SRCS=main.c threadpool.c util.c server.c http.c
OBJS=$(SRCS:%.c=bin/%.o) OBJS=$(SRCS:%.c=bin/%.o)
@@ -10,9 +10,10 @@ all: httpserver
httpserver: $(OBJS) httpserver: $(OBJS)
$(CC) $(CFLAGS) -o $@ $^ $(CC) $(CFLAGS) -o $@ $^
bin/%.o: src/%.c # Auto-rebuild if Makefile changes
bin/%.o: src/%.c Makefile
@mkdir -p bin/deps/ @mkdir -p bin/deps/
$(CC) $(CFLAGS) -MD -MF $(patsubst src/%.c,bin/deps/%.d,$^) -c -o $@ $^ $(CC) $(CFLAGS) -MD -MF $(patsubst src/%.c,bin/deps/%.d,$<) -c -o $@ $<
include $(SRCS:%.c/bin/deps/%.d) include $(SRCS:%.c/bin/deps/%.d)
+236
View File
@@ -0,0 +1,236 @@
#include "http.h"
#include <assert.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
static const char *EMPTY_STRING = "";
HTTPHeaderList *http_header_list_push(HTTPHeaderList *list, const char *key,
const char *value) {
HTTPHeaderList *new = malloc(sizeof(HTTPHeaderList));
if (!new) {
return NULL;
}
new->key_length = strlen(key);
new->key = malloc(new->key_length + 1);
if (!new->key) {
free(new);
return NULL;
}
memcpy(new->key, key, new->key_length + 1);
new->value_length = strlen(value);
new->value = malloc(new->value_length + 1);
if (!new->value) {
free(new->key);
free(new);
return NULL;
}
memcpy(new->value, value, new->value_length + 1);
new->next = list;
return new;
}
void free_http_header_list(HTTPHeaderList *list) {
while (list) {
HTTPHeaderList *next = list->next;
free(list->key);
free(list->value);
free(list);
list = next;
}
}
const char *http_header_list_search(HTTPHeaderList *list, const char *key,
const char *def) {
while (list) {
if (strcmp(list->key, key) == 0) {
return list->value;
}
list = list->next;
}
return def;
}
#define MAX_REQUEST_LENGTH 16384
#define MAX_METHOD_LENGTH 16
#define MAX_URI_LENGTH 256
#define MAX_VERSION_LENGTH 3
#define MAX_HEADER_KEY_LENGTH 2048
#define MAX_HEADER_VALUE_LENGTH 2048
#define RETURN_IF_READ_ERROR(s) \
if (ferror((s))) { \
return HRPR_READ_FAILED; \
}
// if an error occured, req->uri is not allocated
static HTTPRequestParseResult parse_method_uri_line(FILE *stream,
size_t *restrict bytes_read,
HTTPRequest *restrict req) {
// allow for some leeway in passing incorrect methods
char method[MAX_METHOD_LENGTH + 1];
char uri[MAX_URI_LENGTH + 1];
char version_str[MAX_VERSION_LENGTH + 1];
char whitespace[4];
ssize_t signed_bytes_read;
#define S1(s) #s
#define S(s) S1(s)
int nconv = fscanf(
stream,
// clang-format off
"%" S(MAX_METHOD_LENGTH) "[A-Z]"
"%c"
"%" S(MAX_URI_LENGTH) "[^ \n\r]"
"%c"
"HTTP/%" S(MAX_VERSION_LENGTH) "[0-9.]"
"%c%c"
"%zn",
// clang-format on
method, &whitespace[0], uri, &whitespace[1], version_str,
&whitespace[2], &whitespace[3], &signed_bytes_read);
#undef S
*bytes_read = signed_bytes_read;
RETURN_IF_READ_ERROR(stream);
if (nconv != 7 || memcmp(whitespace, " \r\n", 4) != 0) {
return HRPR_BAD_FORMAT;
}
req->method = strdup(method);
if (!req->method) {
return HRPR_NO_MEM;
}
req->uri = strdup(uri);
if (!req->uri) {
return HRPR_NO_MEM;
}
req->path = req->uri;
while (*req->path == '/') {
++req->path;
}
return strcmp(version_str, "1.1") == 0 ? HRPR_OK : HRPR_BAD_VERSION;
}
// return true if there are more headers and no error occurred, false otherwise
// this will *not* free LIST if an error occurs
static bool next_header(FILE *stream, size_t *restrict bytes_read,
HTTPRequestParseResult *restrict res,
HTTPHeaderList *restrict *restrict list) {
char c = fgetc(stream);
RETURN_IF_READ_ERROR(stream);
if (c == '\r') {
c = fgetc(stream);
RETURN_IF_READ_ERROR(stream);
if (c != '\n') {
*res = HRPR_BAD_FORMAT;
return false;
}
*bytes_read = 2;
*res = HRPR_OK;
return false;
}
ungetc(c, stream);
char key[MAX_HEADER_KEY_LENGTH + 1];
char value[MAX_HEADER_VALUE_LENGTH + 1];
char whitespace[3];
ssize_t signed_bytes_read;
#define S1(s) #s
#define S(s) S1(s)
int nconv =
fscanf(stream,
// clang-format off
"%" S(MAX_HEADER_KEY_LENGTH) "[a-zA-Z0-9.-]"
":%c"
"%" S(MAX_HEADER_VALUE_LENGTH) "[ -~]"
"%c%c"
"%zn",
// clang-format on
key, &whitespace[0], value, &whitespace[1], &whitespace[2],
&signed_bytes_read);
#undef S
*bytes_read = signed_bytes_read;
RETURN_IF_READ_ERROR(stream);
if (nconv != 5 || memcmp(whitespace, " \r\n", 3) != 0) {
return HRPR_BAD_FORMAT;
}
*list = http_header_list_push(*list, key, value);
if (!*list) {
*res = HRPR_NO_MEM;
return false;
}
return true;
}
HTTPRequestParseResult parse_http_request(FILE *stream,
HTTPRequest *restrict out) {
out->uri = EMPTY_STRING;
out->path = EMPTY_STRING;
out->method = EMPTY_STRING;
size_t total_bytes;
HTTPRequestParseResult res;
if ((res = parse_method_uri_line(stream, &total_bytes, out)) != HRPR_OK) {
return res;
}
out->headers = NULL;
size_t bytes_read;
while (next_header(stream, &bytes_read, &res, &out->headers)) {
if ((total_bytes += bytes_read) > MAX_REQUEST_LENGTH) {
res = HRPR_BAD_FORMAT;
break;
}
}
return res;
}
void free_http_request(HTTPRequest *restrict req) {
if (req->method != EMPTY_STRING) {
free((char *) req->method);
}
if (req->uri != EMPTY_STRING) {
free((char *) req->uri);
}
free_http_header_list(req->headers);
}
const char *status_code_to_message(int status, size_t *restrict length) {
static const struct {
int code;
const char *msg;
size_t size;
} CODES[] = {
#define P(s, m) {s, m, sizeof(m) - 1}
P(200, "OK"),
P(201, "Created"),
P(400, "Bad Request"),
P(403, "Forbidden"),
P(404, "Not Found"),
P(500, "Internal Server Error"),
P(501, "Not Implemented"),
P(505, "Version Not Supported"),
#undef P
};
static const size_t NCODES = sizeof(CODES) / sizeof(CODES[0]);
for (size_t i = 0; i < NCODES; ++i) {
if (status == CODES[i].code) {
if (length) {
*length = CODES[i].size;
}
return CODES[i].msg;
}
}
return NULL;
}
void format_http_response(FILE *stream, HTTPResponse *restrict resp) {
assert(status_code_to_message(resp->status, NULL));
fprintf(stream, "HTTP/1.1 %d %s\r\n", resp->status,
status_code_to_message(resp->status, NULL));
fprintf(stream, "Content-Length: %zu\r\n", resp->body_length);
for (HTTPHeaderList *h = resp->headers; h; h = h->next) {
fwrite(h->key, 1, h->key_length, stream);
fwrite(": ", 1, 2, stream);
fwrite(h->value, 1, h->value_length, stream);
fwrite("\r\n", 1, 2, stream);
}
fwrite("\r\n", 1, 2, stream);
}
+57
View File
@@ -0,0 +1,57 @@
#ifndef INCLUDED_HTTP_H
#define INCLUDED_HTTP_H
#include <stddef.h>
#include <stdio.h>
typedef struct _HTTPHeaderList HTTPHeaderList;
struct _HTTPHeaderList {
char *key;
size_t key_length;
char *value;
size_t value_length;
HTTPHeaderList *next;
};
HTTPHeaderList *http_header_list_push(HTTPHeaderList *list, const char *key,
const char *value);
void free_http_header_list(HTTPHeaderList *list);
const char *http_header_list_search(HTTPHeaderList *list, const char *key,
const char *def);
typedef enum {
HRPR_OK,
HRPR_NO_MEM,
HRPR_BAD_VERSION,
HRPR_BAD_FORMAT,
HRPR_READ_FAILED,
} HTTPRequestParseResult;
typedef struct {
const char *method;
const char *uri;
// relative (shares memory with uri)
const char *path;
HTTPHeaderList *headers;
} HTTPRequest;
HTTPRequestParseResult parse_http_request(FILE *stream,
HTTPRequest *restrict out);
void free_http_request(HTTPRequest *restrict req);
typedef struct {
int status;
HTTPHeaderList *headers;
size_t body_length;
} HTTPResponse;
const char *status_code_to_message(int status, size_t *restrict length);
void format_http_response(FILE *stream, HTTPResponse *restrict resp);
#endif
+287 -22
View File
@@ -1,17 +1,25 @@
#include "http.h"
#include "server.h"
#include "threadpool.h"
#include "util.h" #include "util.h"
#include <ctype.h> #include <ctype.h>
#include <errno.h> #include <errno.h>
#include <limits.h> #include <fcntl.h>
#include <inttypes.h>
#include <signal.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h>
#include <sys/stat.h>
#include <unistd.h> #include <unistd.h>
typedef struct { typedef struct {
bool opt_error; bool opt_error;
bool help_flag; bool help_flag;
int port; uint32_t port;
int parallelism; size_t parallelism;
const char *address;
} GlobalFlags; } GlobalFlags;
static const GlobalFlags DEFAULT_FLAGS = { static const GlobalFlags DEFAULT_FLAGS = {
@@ -19,18 +27,24 @@ static const GlobalFlags DEFAULT_FLAGS = {
.help_flag = false, .help_flag = false,
.port = 0, .port = 0,
.parallelism = 1, .parallelism = 1,
.address = "127.0.0.1",
}; };
// Return false on failure, true on success // Return false on failure, true on success
static bool parse_int(const char *str, int min, int max, int *output) { static bool parse_uint(const char *str, uintmax_t min, uintmax_t max,
uintmax_t *output) {
if (isspace(*str) || *str == '+' || *str == '-') {
log_error("malformed number: \"%s\"", str);
return false;
}
errno = 0; errno = 0;
char *endptr; char *endptr;
long conv = strtol(str, &endptr, 10); uintmax_t conv = strtoumax(str, &endptr, 10);
if (!*str || *endptr) { if (!*str || *endptr) {
log_error("malformed number: \"%s\"", str); log_error("malformed number: \"%s\"", str);
return false; return false;
} else if (((conv == LONG_MIN || conv == LONG_MAX) && errno == ERANGE) } else if ((conv == UINTMAX_MAX && errno == ERANGE) || conv < min
|| conv < min || conv > max) { || conv > max) {
log_error("out of range: %s", str); log_error("out of range: %s", str);
return false; return false;
} }
@@ -38,31 +52,36 @@ static bool parse_int(const char *str, int min, int max, int *output) {
return true; return true;
} }
void parse_cli_options(int argc, const char **argv, GlobalFlags *flags) { static void parse_cli_options(int argc, const char **argv, GlobalFlags *flags) {
constexpr unsigned long MAX_PARALLELISM = INT_MAX; constexpr size_t MAX_PARALLELISM = SIZE_MAX;
constexpr unsigned long MAX_PORT = 65535; constexpr uint32_t MAX_PORT = 65535;
opterr = false; opterr = false;
int c; int c;
char pretty_flag[8]; char pretty_flag[8];
while ((c = getopt(argc, (char *const *) argv, "hp:")) >= 0) { while ((c = getopt(argc, (char *const *) argv, ":ha:p:")) >= 0) {
if (isprint(c)) { if (isprint(c)) {
snprintf(pretty_flag, sizeof(pretty_flag), "%c", c); snprintf(pretty_flag, sizeof(pretty_flag), "%c", optopt);
} else { } else {
#pragma GCC diagnostic push #pragma GCC diagnostic push
#pragma GCC diagnostic ignored "-Wformat-truncation" #pragma GCC diagnostic ignored "-Wformat-truncation"
// `c` is actually in the range [0,255], and so will be at most 2 // `c` is actually in the range [0,255], and so will be at most 2
// hex digits // hex digits
snprintf(pretty_flag, sizeof(pretty_flag), "0x%02x", c); snprintf(pretty_flag, sizeof(pretty_flag), "0x%02x", optopt);
#pragma GCC diagnostic pop #pragma GCC diagnostic pop
} }
switch (c) { switch (c) {
case 'h': case 'h':
flags->help_flag = true; flags->help_flag = true;
return; return;
case 'p': case 'p': {
if (!parse_int(optarg, 1, MAX_PARALLELISM, &flags->parallelism)) { uintmax_t conv;
if (!parse_uint(optarg, 1, MAX_PARALLELISM, &conv)) {
flags->opt_error = true; flags->opt_error = true;
} }
flags->parallelism = conv;
} break;
case 'a':
flags->address = optarg;
break; break;
case ':': case ':':
flags->opt_error = true; flags->opt_error = true;
@@ -74,28 +93,242 @@ void parse_cli_options(int argc, const char **argv, GlobalFlags *flags) {
break; break;
} }
} }
if (optind == argc) { uintmax_t port_conv;
if (flags->opt_error) {
return;
} else if (optind == argc) {
flags->opt_error = true; flags->opt_error = true;
log_error("no port provided"); log_error("no port provided");
} else if (optind != argc - 1) { } else if (optind != argc - 1) {
flags->opt_error = true; flags->opt_error = true;
log_error("extra positional arguments after port"); log_error("extra positional arguments after port");
} else if (!parse_int(argv[optind], 1, MAX_PORT, &flags->port)) { } else if (!parse_uint(argv[optind], 1, MAX_PORT, &port_conv)) {
flags->opt_error = true; flags->opt_error = true;
} else {
flags->port = port_conv;
} }
} }
void print_help(FILE *file) { static void print_help(FILE *file) {
fprintf(file, "usage: httpserver [-h] [-p PARALLELISM] <PORT>\n"); fprintf(file,
"usage: httpserver [-h] [-a ADDRESS] [-p PARALLELISM] <PORT>\n");
fprintf(file, " -h print this message, then exit\n"); fprintf(file, " -h print this message, then exit\n");
fprintf( fprintf(
file, file,
" -p use PARALLELISM threads for processing requests (default: 1)\n"); " -p use PARALLELISM threads for processing requests (default: 1)\n");
fprintf(file, " -a bind to ADDRESS (default: 127.0.0.1)\n");
}
constexpr int WORKER_BLOCKED_SIGNALS[] = {SIGTERM, SIGINT, SIGHUP};
constexpr size_t N_WORKER_BLOCKED_SIGNALS =
sizeof(WORKER_BLOCKED_SIGNALS) / sizeof(int);
static bool shutdown_flag = false;
static void signal_handler(int signal) {
if (shutdown_flag) {
_exit(signal == SIGTERM ? 0 : EXIT_FAILURE);
}
shutdown_flag = true;
}
static void setup_signals() {
struct sigaction act = {
.sa_handler = signal_handler,
.sa_flags = 0,
};
sigemptyset(&act.sa_mask);
for (size_t i = 0; i < N_WORKER_BLOCKED_SIGNALS; ++i) {
if (sigaction(WORKER_BLOCKED_SIGNALS[i], &act, NULL) < 0) {
log_errno("sigaction");
exit(EXIT_FAILURE);
}
}
}
static int parse_result_to_status(HTTPRequestParseResult res) {
switch (res) {
case HRPR_OK:
return 200;
case HRPR_NO_MEM:
case HRPR_READ_FAILED:
return 500;
case HRPR_BAD_VERSION:
return 505;
case HRPR_BAD_FORMAT:
return 400;
default:
abort();
}
}
static void send_simple_response(FILE *stream, int status) {
size_t status_msg_len;
const char *status_msg = status_code_to_message(status, &status_msg_len);
HTTPResponse resp = {
.status = status,
.headers = NULL,
.body_length = status_msg_len + 1,
};
format_http_response(stream, &resp);
fwrite(status_msg, 1, status_msg_len, stream);
fputc('\n', stream);
}
static void write_audit_log_entry(HTTPRequest *restrict req, int status) {
fprintf(stderr, "%s,%s,%d,%s\n", req->method, req->uri, status,
http_header_list_search(req->headers, "Request-ID", "0"));
}
static void handle_get_request(FILE *conn, HTTPRequest *restrict req) {
int status = 200;
struct stat statbuf;
flockfile(stderr);
FILE *file_handle = fopen(req->path, "r");
if (!file_handle) {
if (errno == ENOENT) {
status = 404;
} else if (errno == EACCES) {
status = 403;
} else {
status = 500;
}
}
if (file_handle && fstat(fileno(file_handle), &statbuf) < 0) {
status = errno == EACCES ? 403 : 500;
}
write_audit_log_entry(req, status);
funlockfile(stderr);
if (status != 200) {
if (file_handle) {
fclose(file_handle);
}
send_simple_response(conn, status);
return;
}
HTTPResponse resp = {
.status = 200,
.headers = NULL,
.body_length = statbuf.st_size,
};
format_http_response(conn, &resp);
char read_buf[4096];
while (!feof(file_handle)) {
if (ferror(file_handle)) {
// can't really do anything about it here, just break
break;
}
size_t count = fread(read_buf, 1, sizeof(read_buf), file_handle);
if (count) {
fwrite(read_buf, 1, count, conn);
}
}
fclose(file_handle);
}
static ssize_t get_content_length(HTTPRequest *restrict req) {
const char *text =
http_header_list_search(req->headers, "Content-Length", NULL);
if (!text) {
return 0;
} else if (isspace(*text) || *text == '-' || *text == '+') {
return -1;
}
char *endptr;
uintmax_t conv = strtoumax(text, &endptr, 10);
if (*endptr || (conv == UINTMAX_MAX && errno == ERANGE)
|| conv > SIZE_MAX) {
return -1;
}
return conv;
}
static void handle_put_request(FILE *conn, HTTPRequest *restrict req) {
ssize_t content_length = get_content_length(req);
if (content_length < 0) {
write_audit_log_entry(req, 400);
send_simple_response(conn, 400);
fclose(conn);
return;
}
char temp_file[] = "temp_fileXXXXXX";
int temp_fd = mkstemp(temp_file);
if (temp_fd < 0) {
write_audit_log_entry(req, 500);
send_simple_response(conn, 500);
return;
}
char read_buff[4096];
while (content_length) {
ssize_t read_size = fread(read_buff, 1,
(size_t) content_length < sizeof(read_buff)
? (size_t) content_length
: sizeof(read_buff),
conn);
if (ferror(conn) || write(temp_fd, read_buff, read_size) < 0) {
write_audit_log_entry(req, 500);
send_simple_response(conn, 500);
close(temp_fd);
unlink(temp_file);
return;
}
content_length -= read_size;
}
close(temp_fd);
int status;
bool need_unlink = true;
flockfile(stderr);
int creat_fd = open(req->path, O_RDONLY | O_CREAT | O_EXCL, 0644);
if (creat_fd < 0 && errno == EEXIST) {
// file existed
status = 200;
} else if (creat_fd <= 0) {
// actual error
status = errno == EACCES ? 403 : 500;
goto write_status_and_unlock;
} else {
// created the file
status = 201;
close(creat_fd);
}
if (rename(temp_file, req->path) < 0) {
status = errno == EACCES ? 403 : 500;
goto write_status_and_unlock;
}
need_unlink = false;
write_status_and_unlock:
write_audit_log_entry(req, status);
funlockfile(stderr);
send_simple_response(conn, status);
if (need_unlink) {
unlink(temp_file);
}
}
static void handle_connection(void *arg) {
FILE *conn = arg;
HTTPRequest req;
HTTPRequestParseResult res = parse_http_request(conn, &req);
if (res != HRPR_OK) {
int status = parse_result_to_status(res);
write_audit_log_entry(&req, status);
send_simple_response(conn, status);
} else if (strchr(req.path, '/')) {
write_audit_log_entry(&req, 400);
send_simple_response(conn, 400);
} else if (strcmp(req.method, "GET") == 0) {
handle_get_request(conn, &req);
} else if (strcmp(req.method, "PUT") == 0) {
handle_put_request(conn, &req);
} else {
write_audit_log_entry(&req, 501);
send_simple_response(conn, 501);
}
free_http_request(&req);
fclose(conn);
} }
int main(int argc, const char **argv) { int main(int argc, const char **argv) {
setenv("POSIXLY_CORRECT", "1", true); setenv("POSIXLY_CORRECT", "1", true);
set_exec_name(argv[0]);
GlobalFlags flags = DEFAULT_FLAGS; GlobalFlags flags = DEFAULT_FLAGS;
parse_cli_options(argc, argv, &flags); parse_cli_options(argc, argv, &flags);
if (flags.help_flag) { if (flags.help_flag) {
@@ -104,7 +337,39 @@ int main(int argc, const char **argv) {
} else if (flags.opt_error) { } else if (flags.opt_error) {
return EXIT_FAILURE; return EXIT_FAILURE;
} }
printf("Port: %d\n", flags.port); sigset_t worker_block_set;
printf("Parallelism: %d\n", flags.parallelism); sigemptyset(&worker_block_set);
for (size_t i = 0; i < N_WORKER_BLOCKED_SIGNALS; ++i) {
sigaddset(&worker_block_set, WORKER_BLOCKED_SIGNALS[i]);
}
Server *server = make_server(flags.address, flags.port);
if (!server) {
return EXIT_FAILURE;
}
ThreadPool *pool = make_thread_pool(flags.parallelism, worker_block_set);
if (!pool) {
destroy_server(server);
return EXIT_FAILURE;
}
setup_signals();
while (true) {
int conn_fd = server_accept(server);
if (shutdown_flag) {
if (conn_fd >= 0) {
close(conn_fd);
}
break;
} else if (conn_fd < 0) {
continue;
}
FILE *conn = fdopen(conn_fd, "w+");
if (!conn) {
close(conn_fd);
continue;
}
thread_pool_enqueue(pool, handle_connection, conn);
}
destroy_thread_pool(pool);
destroy_server(server);
return 0; return 0;
} }
+61
View File
@@ -0,0 +1,61 @@
#include "server.h"
#include "util.h"
#include <arpa/inet.h>
#include <netinet/in.h>
#include <stdlib.h>
#include <sys/socket.h>
#include <unistd.h>
struct _Server {
int socket;
};
Server *make_server(const char *text_addr, uint32_t port) {
struct sockaddr_in addr = {
.sin_family = AF_INET,
.sin_port = htons(port),
};
if (inet_pton(AF_INET, text_addr, &addr.sin_addr) != 1) {
log_error("bad IPv4 address: \"%s\"", text_addr);
return NULL;
}
Server *server = malloc_safe(sizeof(Server));
server->socket = socket(AF_INET, SOCK_STREAM, 0);
if (server->socket < 0) {
log_errno("socket");
free(server);
return NULL;
}
int on = 1;
if (setsockopt(server->socket, SOL_SOCKET, SO_REUSEADDR, &on, sizeof(on))
< 0) {
log_errno("setsockopt SO_REUSEADDR");
close(server->socket);
free(server);
return NULL;
}
if (bind(server->socket, (struct sockaddr *) &addr, sizeof(addr)) < 0) {
log_errno("bind");
close(server->socket);
free(server);
return NULL;
}
if (listen(server->socket, SERVER_BACKLOG) < 0) {
log_errno("listen");
close(server->socket);
free(server);
return NULL;
}
return server;
}
void destroy_server(Server *server) {
close(server->socket);
free(server);
}
int server_accept(Server *server) {
return accept(server->socket, NULL, NULL);
}
+18
View File
@@ -0,0 +1,18 @@
#ifndef INCLUDED_SERVER_H
#define INCLUDED_SERVER_H
#include <inttypes.h>
#include <stddef.h>
constexpr int SERVER_BACKLOG = 32;
typedef struct _Server Server;
Server *make_server(const char *text_addr, uint32_t port);
void destroy_server(Server *server);
// same return value as accept(2)
int server_accept(Server *server);
#endif
+88 -7
View File
@@ -1,7 +1,9 @@
#include "threadpool.h" #include "threadpool.h"
#include "util.h"
#include <pthread.h>
#include <stdlib.h> #include <stdlib.h>
#include <threads.h>
struct thread_pool_queue { struct thread_pool_queue {
Task task; Task task;
@@ -10,17 +12,96 @@ struct thread_pool_queue {
}; };
struct _ThreadPool { struct _ThreadPool {
bool running;
size_t nthreads; size_t nthreads;
thrd_t *threads; sigset_t thread_sig_mask;
pthread_t *threads;
pthread_cond_t queue_cnd;
pthread_mutex_t queue_mtx;
struct thread_pool_queue *queue; struct thread_pool_queue *queue;
}; };
ThreadPool *make_thread_pool(int parallelism) { // return false if we need to stop
ThreadPool *pool = malloc(sizeof(ThreadPool)); static bool get_task(ThreadPool *pool, Task *task, void **task_arg) {
pool->nthreads = parallelism; pthread_mutex_lock(&pool->queue_mtx);
if (!pool->running) {
pthread_mutex_unlock(&pool->queue_mtx);
return false;
}
while (true) {
pthread_cond_wait(&pool->queue_cnd, &pool->queue_mtx);
if (!pool->running) {
pthread_mutex_unlock(&pool->queue_mtx);
return false;
}
struct thread_pool_queue *ent = pool->queue;
if (ent) {
pool->queue = pool->queue->next;
pthread_mutex_unlock(&pool->queue_mtx);
*task = ent->task;
*task_arg = ent->arg;
free(ent);
return true;
}
}
abort();
} }
void destroy_thread_pool(ThreadPool *pool) {} static void *pool_thread_function(void *arg) {
ThreadPool *pool = arg;
pthread_sigmask(SIG_SETMASK, &pool->thread_sig_mask, NULL);
Task task;
void *task_arg;
while (get_task(pool, &task, &task_arg)) {
task(task_arg);
}
return NULL;
}
void thread_pool_enqueue(Task task, void *arg) {} ThreadPool *make_thread_pool(size_t parallelism, sigset_t sig_mask) {
ThreadPool *pool = malloc_safe(sizeof(ThreadPool));
pthread_mutex_init(&pool->queue_mtx, NULL);
pthread_cond_init(&pool->queue_cnd, NULL);
pool->running = true;
pool->queue = NULL;
pool->nthreads = parallelism;
pool->thread_sig_mask = sig_mask;
pool->threads = malloc_safe(sizeof(pthread_t) * parallelism);
// create don't race with any received signals
sigset_t sset_full;
sigfillset(&sset_full);
sigset_t sset_save;
pthread_sigmask(SIG_SETMASK, &sset_full, &sset_save);
for (size_t i = 0; i < parallelism; ++i) {
pthread_create(&pool->threads[i], NULL, &pool_thread_function, pool);
}
pthread_sigmask(SIG_SETMASK, &sset_save, NULL);
return pool;
}
void destroy_thread_pool(ThreadPool *pool) {
pthread_mutex_lock(&pool->queue_mtx);
pool->running = false;
pthread_cond_broadcast(&pool->queue_cnd);
pthread_mutex_unlock(&pool->queue_mtx);
for (size_t i = 0; i < pool->nthreads; ++i) {
pthread_join(pool->threads[i], NULL);
}
free(pool->threads);
pthread_mutex_destroy(&pool->queue_mtx);
pthread_cond_destroy(&pool->queue_cnd);
free(pool);
}
void thread_pool_enqueue(ThreadPool *pool, Task task, void *arg) {
pthread_mutex_lock(&pool->queue_mtx);
struct thread_pool_queue *new =
malloc_safe(sizeof(struct thread_pool_queue));
new->task = task;
new->arg = arg;
new->next = pool->queue;
pool->queue = new;
pthread_cond_signal(&pool->queue_cnd);
pthread_mutex_unlock(&pool->queue_mtx);
}
+5 -2
View File
@@ -1,13 +1,16 @@
#ifndef INCLUDED_THREAD_POOL_H #ifndef INCLUDED_THREAD_POOL_H
#define INCLUDED_THREAD_POOL_H #define INCLUDED_THREAD_POOL_H
#include <signal.h> // IWYU pragma: keep
#include <stddef.h>
typedef struct _ThreadPool ThreadPool; typedef struct _ThreadPool ThreadPool;
typedef void (*Task)(void *); typedef void (*Task)(void *);
ThreadPool *make_thread_pool(int parallelism); ThreadPool *make_thread_pool(size_t parallelism, sigset_t sig_mask);
void destroy_thread_pool(ThreadPool *pool); void destroy_thread_pool(ThreadPool *pool);
void thread_pool_enqueue(Task task, void *arg); void thread_pool_enqueue(ThreadPool *pool, Task task, void *arg);
#endif #endif
+38 -19
View File
@@ -1,35 +1,54 @@
#include "util.h" #include "util.h"
#include <errno.h>
#include <stdarg.h> #include <stdarg.h>
#include <stdio.h> #include <stdio.h>
#include <stdlib.h> #include <stdlib.h>
#include <string.h> #include <string.h>
#include <time.h>
#define DEFAULT_EXEC_NAME "httpserver" void *realloc_safe(void *oldptr, size_t size) {
static const char OOM_MSG[] = "fatal: out of memory\n";
const char *EXEC_NAME = DEFAULT_EXEC_NAME; void *ptr = realloc(oldptr, size);
static bool need_free_out_of_memory_msg = false; if (size && !ptr) {
NO_SANITIZE fwrite(OOM_MSG, 1, sizeof(OOM_MSG) - 1, stderr);
static const char *out_of_memory_msg = abort();
DEFAULT_EXEC_NAME ": error: out of memory\n";
void set_exec_name(const char *argv0) {
const char *slash = strrchr(argv0, '/');
if (!slash) {
EXEC_NAME = argv0;
} else {
EXEC_NAME = slash + 1;
}
if (need_free_out_of_memory_msg) {
free((char *) out_of_memory_msg);
} }
return ptr;
} }
void log_error(const char *fmt, ...) { void *malloc_safe(size_t size) {
fprintf(stderr, "%s: error: ", EXEC_NAME); return realloc_safe(NULL, size);
}
// asprintf is not POSIX
int alloc_sprintf(char *restrict *restrict out, const char *restrict fmt, ...) {
va_list args;
va_start(args, fmt);
va_list args2;
va_copy(args2, args);
int need = vsnprintf(NULL, 0, fmt, args2);
va_end(args2);
*out = realloc_safe(*out, need + 1);
int written = vsnprintf(*out, need + 1, fmt, args);
va_end(args);
return written;
}
void log_error(const char *restrict fmt, ...) {
time_t cur_time = time(NULL);
struct tm tm;
localtime_r(&cur_time, &tm);
char time_str[32];
strftime(time_str, sizeof(time_str), "%c", &tm);
fprintf(stderr, "[%s] error: ", time_str);
va_list args; va_list args;
va_start(args, fmt); va_start(args, fmt);
vfprintf(stderr, fmt, args); vfprintf(stderr, fmt, args);
va_end(args); va_end(args);
fputc('\n', stderr); fputc('\n', stderr);
} }
void log_errno(const char *detail) {
log_error("%s: %s", detail, strerror(errno));
}
+9 -9
View File
@@ -1,23 +1,23 @@
#ifndef INCLUDED_UTIL_H #ifndef INCLUDED_UTIL_H
#define INCLUDED_UTIL_H #define INCLUDED_UTIL_H
#include <stddef.h>
#if __has_attribute(format) #if __has_attribute(format)
# define PRINTF_LIKE(i, j) __attribute__((format(printf, i, j))) # define PRINTF_LIKE(i, j) __attribute__((format(printf, i, j)))
#else #else
# define PRINTF_LIKE(i, j) # define PRINTF_LIKE(i, j)
#endif #endif
#if __has_attribute(no_sanitize_address) void *realloc_safe(void *oldptr, size_t size);
# define NO_SANITIZE __attribute__((no_sanitize_address)) void *malloc_safe(size_t size);
#else
# define NO_SANITIZE
#endif
extern const char *EXEC_NAME; // asprintf is not POSIX
PRINTF_LIKE(2, 3)
void set_exec_name(const char *argv0); int alloc_sprintf(char *restrict *restrict out, const char *restrict fmt, ...);
PRINTF_LIKE(1, 2) PRINTF_LIKE(1, 2)
void log_error(const char *fmt, ...); void log_error(const char *restrict fmt, ...);
void log_errno(const char *detail);
#endif #endif