/*
 * Design Doc:
 * The server is broken into the following files:
 * - main.c: Argument parsing and high-level control
 * - http.c: HTTP request parsing and response formatting
 * - server.c: Abstraction over sockets
 * - threadpool.c: Thread pool and work queue implementation
 * - util.c: Misc. utility functions
 */

#include "http.h"
#include "server.h"
#include "threadpool.h"
#include "util.h"

#include <ctype.h>
#include <errno.h>
#include <fcntl.h>
#include <inttypes.h>
#include <limits.h>
#include <pthread.h>
#include <signal.h>
#include <stdbool.h>
#include <stdint.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <sys/socket.h>
#include <sys/stat.h>
#include <sys/types.h>
#include <unistd.h>

/* Command-line configuration filled in by parse_cli_options(). */
typedef struct {
    bool opt_error;      // an invalid flag or argument was seen
    bool help_flag;      // -h was given
    uint32_t port;       // positional PORT argument, in [1, MAX_PORT]
    size_t parallelism;  // -p PARALLELISM: threads used to process requests
    const char* address; // -a ADDRESS: address to bind to
} GlobalFlags;

static const GlobalFlags DEFAULT_FLAGS = {
    .opt_error = false,
    .help_flag = false,
    .port = 0,
    .parallelism = 1,
    .address = "127.0.0.1",
};

/* Inclusive upper bounds for numeric command-line arguments. */
static const uintmax_t MAX_PARALLELISM = SIZE_MAX;
static const uintmax_t MAX_PORT = 65535;

/*
 * Parse `str` as a base-10 unsigned integer in [min, max] and store it in
 * *output.
 *
 * Leading whitespace and an explicit sign are rejected up front because
 * strtoumax() would otherwise accept them (and silently wrap "-1" to
 * UINTMAX_MAX). Empty strings and trailing garbage are rejected as well.
 *
 * Returns true on success. On failure an error is reported, false is
 * returned and *output is left untouched.
 */
static bool parse_uint(const char* str, uintmax_t min, uintmax_t max, uintmax_t* output) {
    // <ctype.h> functions require a value representable as unsigned char.
    if (isspace((unsigned char)*str) || *str == '+' || *str == '-') {
#ifdef BAD_ERROR_REPORTING_FOR_AUTOGRADER
        fprintf(stderr, "Invalid Port\n");
#else
        log_error("malformed number: \"%s\"", str);
#endif
        return false;
    }

    errno = 0;
    char* endptr;
    uintmax_t conv = strtoumax(str, &endptr, 10);
    if (!*str || *endptr) {
#ifdef BAD_ERROR_REPORTING_FOR_AUTOGRADER
        fprintf(stderr, "Invalid Port\n");
#else
        log_error("malformed number: \"%s\"", str);
#endif
        return false;
    }
    if ((conv == UINTMAX_MAX && errno == ERANGE) || conv < min || conv > max) {
#ifdef BAD_ERROR_REPORTING_FOR_AUTOGRADER
        fprintf(stderr, "Invalid Port\n");
#else
        log_error("out of range: %s", str);
#endif
        return false;
    }

    *output = conv;
    return true;
}

/*
 * Render option character `opt` for an error message into buf: the character
 * itself when printable, otherwise its byte value as hex (e.g. "0x01").
 *
 * The value is normalised to unsigned char first. A getopt() implementation
 * may report bytes >= 0x80 as negative ints where char is signed. Passing a
 * negative value to isprint() is undefined, and "%02x" of a negative int
 * would not fit a small buffer. After normalisation the hex form is always
 * exactly 4 characters.
 */
static void format_flag(char* buf, size_t size, int opt) {
    unsigned char ch = (unsigned char)opt;
    if (isprint(ch)) {
        snprintf(buf, size, "%c", ch);
    } else {
        snprintf(buf, size, "0x%02x", (unsigned)ch);
    }
}

/*
 * Parse argv into *flags (which the caller pre-fills with defaults).
 *
 * Accepted syntax: [-h] [-a ADDRESS] [-p PARALLELISM] PORT
 *
 * On -h this sets help_flag and returns immediately, without looking at the
 * remaining arguments. Any error is logged and recorded in opt_error.
 * Fields are only overwritten with successfully parsed values.
 */
static void parse_cli_options(int argc, const char** argv, GlobalFlags* flags) {
    // The leading ':' in the optstring makes getopt() return ':' for a
    // missing argument, distinct from '?' for an unknown flag. opterr = 0
    // suppresses getopt's own messages; errors are reported below instead.
    opterr = 0;
    char pretty_flag[8];
    int c;
    while ((c = getopt(argc, (char* const*)argv, ":ha:p:")) >= 0) {
        switch (c) {
        case 'h':
            flags->help_flag = true;
            return;
        case 'p': {
            uintmax_t conv;
            if (parse_uint(optarg, 1, MAX_PARALLELISM, &conv)) {
                flags->parallelism = (size_t)conv;
            } else {
                // `conv` is unset on failure; do not store it.
                flags->opt_error = true;
            }
            break;
        }
        case 'a':
            flags->address = optarg;
            break;
        case ':':
            flags->opt_error = true;
            format_flag(pretty_flag, sizeof(pretty_flag), optopt);
            log_error("flag requires argument: '%s'", pretty_flag);
            break;
        case '?':
            flags->opt_error = true;
            format_flag(pretty_flag, sizeof(pretty_flag), optopt);
            log_error("unknown flag: '%s'", pretty_flag);
            break;
        default:
            // Not reachable for this optstring; ignore defensively.
            break;
        }
    }

    if (flags->opt_error) {
        return;
    }

    // Exactly one positional argument (the port) must remain.
    uintmax_t port_conv;
    if (optind == argc) {
        flags->opt_error = true;
        log_error("no port provided");
    } else if (optind != argc - 1) {
        flags->opt_error = true;
        log_error("extra positional arguments after port");
    } else if (!parse_uint(argv[optind], 1, MAX_PORT, &port_conv)) {
        flags->opt_error = true;
    } else {
        flags->port = (uint32_t)port_conv;
    }
}

/* Write the usage summary to `file`. */
static void print_help(FILE* file) {
    fprintf(file, "usage: httpserver [-h] [-a ADDRESS] [-p PARALLELISM] PORT\n");
    fprintf(file, "  -h  print this message, then exit\n");
    fprintf(file, "  -p  use PARALLELISM threads for processing requests (default: 1)\n");
    fprintf(file, "  -a  bind to ADDRESS (default: 127.0.0.1)\n");
}

/*
 * Signals that request shutdown. They are handled on the main thread and
 * passed to make_thread_pool() as a mask, presumably so that worker threads
 * keep them blocked (TODO confirm in threadpool.c).
 */
static const int WORKER_BLOCKED_SIGNALS[] = {SIGTERM, SIGINT, SIGHUP};
static const size_t N_WORKER_BLOCKED_SIGNALS =
    sizeof(WORKER_BLOCKED_SIGNALS) / sizeof(WORKER_BLOCKED_SIGNALS[0]);

/*
 * Set by signal_handler() and polled by the accept loop in main().
 * volatile sig_atomic_t is the only object type the C standard lets a signal
 * handler portably write.
 */
static volatile sig_atomic_t shutdown_flag = 0;

/*
 * The first shutdown signal only sets shutdown_flag, so main() can stop
 * accepting and tear down cleanly. A second signal terminates immediately via
 * the async-signal-safe _exit(): status 0 for SIGTERM, failure otherwise.
 */
static void signal_handler(int signo) {
    if (shutdown_flag) {
        _exit(signo == SIGTERM ? 0 : EXIT_FAILURE);
    }
    shutdown_flag = 1;
}

/*
 * Ignore SIGPIPE, so writes to a closed peer fail with EPIPE instead of
 * killing the process. Install signal_handler for WORKER_BLOCKED_SIGNALS.
 * Exits the process if sigaction() fails.
 */
static void setup_signals(void) {
    struct sigaction pipe_act = {
        .sa_flags = 0,
        .sa_handler = SIG_IGN,
    };
    sigemptyset(&pipe_act.sa_mask);
    if (sigaction(SIGPIPE, &pipe_act, NULL) < 0) {
        log_errno("sigaction");
        exit(EXIT_FAILURE);
    }

    // SA_RESTART is deliberately not set. Presumably this lets a blocking
    // accept() inside server_accept() return early so main() sees
    // shutdown_flag promptly.
    struct sigaction act = {
        .sa_handler = signal_handler,
        .sa_flags = 0,
    };
    sigemptyset(&act.sa_mask);
    for (size_t i = 0; i < N_WORKER_BLOCKED_SIGNALS; ++i) {
        if (sigaction(WORKER_BLOCKED_SIGNALS[i], &act, NULL) < 0) {
            log_errno("sigaction");
            exit(EXIT_FAILURE);
        }
    }
}

/* Map a request-parse result to the HTTP status sent to the client. */
static int parse_result_to_status(HTTPRequestParseResult res) {
    switch (res) {
    case HRPR_OK:
        return 200;
    case HRPR_NO_MEM:
    case HRPR_READ_FAILED:
        return 500;
    case HRPR_BAD_VERSION:
        return 505;
    case HRPR_BAD_FORMAT:
        return 400;
    default:
        abort(); // unknown enumerator: programming error
    }
}

/*
 * Write all `len` bytes of `buf` to `fd`. This handles short writes and
 * retries on EINTR. Returns false on any other error (e.g. EPIPE after the
 * peer disconnected).
 */
static bool write_all(int fd, const void* buf, size_t len) {
    const char* p = buf;
    while (len > 0) {
        ssize_t n = write(fd, p, len);
        if (n < 0) {
            if (errno == EINTR) {
                continue;
            }
            return false;
        }
        if (n == 0) {
            return false;
        }
        p += n;
        len -= (size_t)n;
    }
    return true;
}

/*
 * Send a response for `status` whose body is the status message plus a
 * trailing newline. Best effort: write failures are not reported, since the
 * peer is presumably gone.
 */
static void send_simple_response(FILE* stream, int status) {
    size_t status_msg_len;
    const char* status_msg = status_code_to_message(status, &status_msg_len);
    HTTPResponse resp = {
        .status = status,
        .headers = NULL,
        .body_length = status_msg_len + 1,
    };
    format_http_response(stream, &resp);
    // The body is written with raw write() on the fd. Flush any header bytes
    // still buffered in `stream` first, so they cannot arrive after the body.
    fflush(stream);

    int fd = fileno(stream);
    if (write_all(fd, status_msg, status_msg_len)) {
        write_all(fd, "\n", 1);
    }
}

/*
 * Emit one audit line to stderr: "METHOD,URI,STATUS,REQUEST-ID". "0" is used
 * when the request carries no Request-ID header.
 */
static void write_audit_log_entry(HTTPRequest* restrict req, int status) {
    fprintf(stderr, "%s,%s,%d,%s\n", req->method, req->uri, status,
            http_header_list_search(req->headers, "Request-ID", "0"));
}

/*
 * Serialises the open/stat step of GET and the create/rename step of PUT,
 * together with their audit-log lines. The log order therefore matches the
 * order in which those operations took effect.
 */
static pthread_mutex_t file_access_mutex = PTHREAD_MUTEX_INITIALIZER;

/*
 * Serve GET for req->path.
 *
 * Responses:
 * - 200 with the file contents.
 * - 404 if the path does not exist.
 * - 403 on permission errors or a non-regular file.
 * - 500 otherwise.
 *
 * The file is opened under file_access_mutex. Its contents are streamed
 * after the lock is released, from the already-open handle.
 */
static void handle_get_request(FILE* conn, HTTPRequest* restrict req) {
    int status = 200;
    struct stat statbuf;

    pthread_mutex_lock(&file_access_mutex);
    FILE* file_handle = fopen(req->path, "r");
    if (!file_handle) {
        if (errno == ENOENT) {
            status = 404;
        } else if (errno == EACCES) {
            status = 403;
        } else {
            status = 500;
        }
    } else if (fstat(fileno(file_handle), &statbuf) < 0) {
        status = errno == EACCES ? 403 : 500;
    } else if (!S_ISREG(statbuf.st_mode)) {
        // Directories (e.g. "." or "..") open successfully but cannot be
        // read as a byte stream. Serving them would advertise st_size bytes
        // and then send none.
        status = 403;
    }
    write_audit_log_entry(req, status);
    pthread_mutex_unlock(&file_access_mutex);

    if (status != 200) {
        if (file_handle) {
            fclose(file_handle);
        }
        send_simple_response(conn, status);
        return;
    }

    HTTPResponse resp = {
        .status = 200,
        .headers = NULL,
        .body_length = statbuf.st_size,
    };
    format_http_response(conn, &resp);
    fflush(conn); // headers must precede the raw write()s below

    // Once the headers are out, a read or write failure cannot be reported
    // to the client. Stop copying; the caller closes the connection.
    char read_buf[4096];
    size_t count;
    while ((count = fread(read_buf, 1, sizeof(read_buf), file_handle)) > 0) {
        if (!write_all(fileno(conn), read_buf, count)) {
            break;
        }
    }
    fclose(file_handle);
}

/*
 * Return the request's Content-Length.
 *
 * - Returns 0 when the header is absent.
 * - Returns -1 when the value is not a plain, non-empty decimal number that
 *   fits in ssize_t.
 */
static ssize_t get_content_length(HTTPRequest* restrict req) {
    const char* text = http_header_list_search(req->headers, "Content-Length", NULL);
    if (!text) {
        return 0;
    }
    if (!*text || isspace((unsigned char)*text) || *text == '-' || *text == '+') {
        return -1;
    }

    errno = 0; // strtoumax only sets errno on error; clear any stale value
    char* endptr;
    uintmax_t conv = strtoumax(text, &endptr, 10);
    if (*endptr || errno == ERANGE || conv > (uintmax_t)SSIZE_MAX) {
        return -1;
    }
    return (ssize_t)conv;
}

/*
 * Serve PUT for req->path.
 *
 * The body (Content-Length bytes) is first spooled into a temp file in the
 * working directory. The temp file then replaces the target with rename(),
 * which is atomic on POSIX, so a concurrent GET never observes a partial
 * file.
 *
 * Responses:
 * - 201 if the file was created, 200 if it already existed.
 * - 400 for a bad Content-Length or a body shorter than announced.
 * - 403/500 for filesystem errors.
 *
 * `conn` is owned by the caller (handle_connection) and must not be closed
 * here.
 */
static void handle_put_request(FILE* conn, HTTPRequest* restrict req) {
    ssize_t content_length = get_content_length(req);
    if (content_length < 0) {
        write_audit_log_entry(req, 400);
        send_simple_response(conn, 400);
        return;
    }

    char temp_file[] = "temp_fileXXXXXX";
    int temp_fd = mkstemp(temp_file);
    if (temp_fd < 0) {
        write_audit_log_entry(req, 500);
        send_simple_response(conn, 500);
        return;
    }

    // Copy exactly content_length bytes from the connection to the temp file.
    char read_buf[4096];
    size_t remaining = (size_t)content_length;
    int body_status = 0; // 0 = body received successfully
    while (remaining > 0) {
        size_t want = remaining < sizeof(read_buf) ? remaining : sizeof(read_buf);
        size_t got = fread(read_buf, 1, want, conn);
        if (got == 0) {
            // fread() returns 0 on both error and EOF. Without this check a
            // client that sends fewer bytes than announced would spin here
            // forever.
            body_status = ferror(conn) ? 500 : 400;
            break;
        }
        if (!write_all(temp_fd, read_buf, got)) {
            body_status = 500;
            break;
        }
        remaining -= got;
    }
    close(temp_fd);

    if (body_status != 0) {
        write_audit_log_entry(req, body_status);
        send_simple_response(conn, body_status);
        unlink(temp_file);
        return;
    }

    int status;
    bool created = false;
    bool renamed = false;

    pthread_mutex_lock(&file_access_mutex);
    // O_CREAT|O_EXCL only distinguishes "created" (201) from "already
    // existed" (200); the content arrives via rename() below.
    int creat_fd = open(req->path, O_RDONLY | O_CREAT | O_EXCL, 0644);
    if (creat_fd >= 0) {
        created = true;
        status = 201;
        close(creat_fd);
    } else if (errno == EEXIST) {
        status = 200;
    } else {
        status = errno == EACCES ? 403 : 500;
    }

    if (status == 200 || status == 201) {
        if (rename(temp_file, req->path) < 0) {
            status = errno == EACCES ? 403 : 500;
            if (created) {
                // Undo the empty placeholder so a failed PUT leaves no trace.
                unlink(req->path);
            }
        } else {
            renamed = true;
        }
    }
    write_audit_log_entry(req, status);
    pthread_mutex_unlock(&file_access_mutex);

    send_simple_response(conn, status);
    if (!renamed) {
        unlink(temp_file);
    }
}

/*
 * Thread-pool job: parse one request from the connection stream `arg`
 * (a FILE*), dispatch it, then shut down and close the connection. This
 * function owns `arg` and always closes it.
 *
 * Request paths containing '/' are rejected with 400, so only names directly
 * in the working directory are served. Methods other than GET/PUT get 501.
 */
static void handle_connection(void* arg) {
    FILE* conn = arg;
    HTTPRequest req;
    HTTPRequestParseResult res = parse_http_request(conn, &req);
    // NOTE(review): on parse failure `req` is still logged and freed below.
    // This assumes parse_http_request() leaves it in a loggable/freeable
    // state even on error; confirm in http.c.
    if (res != HRPR_OK) {
        int status = parse_result_to_status(res);
        write_audit_log_entry(&req, status);
        send_simple_response(conn, status);
    } else if (strchr(req.path, '/')) {
        write_audit_log_entry(&req, 400);
        send_simple_response(conn, 400);
    } else if (strcmp(req.method, "GET") == 0) {
        handle_get_request(conn, &req);
    } else if (strcmp(req.method, "PUT") == 0) {
        handle_put_request(conn, &req);
    } else {
        write_audit_log_entry(&req, 501);
        send_simple_response(conn, 501);
    }

    free_http_request(&req);
    shutdown(fileno(conn), SHUT_RDWR);
    fclose(conn);
}

/*
 * Cleanup callback handed to thread_pool_enqueue() together with each
 * connection. Presumably it is invoked for jobs that never run (TODO confirm
 * in threadpool.c).
 */
static void fclose_free_func(void* file) {
    fclose(file);
}

int main(int argc, const char** argv) {
    // With POSIXLY_CORRECT set, glibc's getopt() stops at the first
    // non-option argument instead of permuting argv.
    setenv("POSIXLY_CORRECT", "1", true);

    GlobalFlags flags = DEFAULT_FLAGS;
    parse_cli_options(argc, argv, &flags);
    if (flags.help_flag) {
        print_help(stdout);
        return flags.opt_error ? EXIT_FAILURE : 0;
    } else if (flags.opt_error) {
        return EXIT_FAILURE;
    }

    sigset_t worker_block_set;
    sigemptyset(&worker_block_set);
    for (size_t i = 0; i < N_WORKER_BLOCKED_SIGNALS; ++i) {
        sigaddset(&worker_block_set, WORKER_BLOCKED_SIGNALS[i]);
    }

    Server* server = make_server(flags.address, flags.port);
    if (!server) {
        return EXIT_FAILURE;
    }

    ThreadPool* pool = make_thread_pool(flags.parallelism, worker_block_set);
    if (!pool) {
        destroy_server(server);
        return EXIT_FAILURE;
    }

    // Handlers are installed only after the pool exists, so signal_handler
    // runs on this (accepting) thread.
    setup_signals();

    while (true) {
        int conn_fd = server_accept(server);
        if (shutdown_flag) {
            if (conn_fd >= 0) {
                close(conn_fd);
            }
            break;
        }
        if (conn_fd < 0) {
            continue;
        }

        FILE* conn = fdopen(conn_fd, "r+");
        if (!conn) {
            close(conn_fd);
            continue;
        }
        // NOTE(review): the return value of thread_pool_enqueue() (if any)
        // is ignored; if it can fail, `conn` would leak. Confirm in
        // threadpool.h.
        thread_pool_enqueue(pool, handle_connection, conn, fclose_free_func);
    }

    destroy_thread_pool(pool);
    destroy_server(server);
    return 0;
}