diff --git a/README.md b/README.md index 75782fc..765a21e 100644 --- a/README.md +++ b/README.md @@ -4,4 +4,6 @@ This library is pretty much the [TigerBeetle](https://github.com/tigerbeetle/tig IO library with the TigerBeetle specifics removed, to make it a more general purpose async IO library for Zig programs. It works on Linux, macOS, and Windows. -See the [tests](src/test.zig) for sample usage. +The tests in [tests](src/test.zig) show some sample usage, and there's a full sample +[multithreaded web server](src/sample_web_server.zig) too. You can build and run the sample +web server with `zig build run server`. diff --git a/build.zig b/build.zig index 39a5184..01809b1 100644 --- a/build.zig +++ b/build.zig @@ -4,8 +4,10 @@ pub fn build(b: *std.Build) void { const target = b.standardTargetOptions(.{}); const optimize = b.standardOptimizeOption(.{}); - _ = b.addModule("io", .{ .root_source_file = .{ .path = "src/io.zig" } }); + // Main module for projects to use. + const io_module = b.addModule("io", .{ .root_source_file = .{ .path = "src/io.zig" } }); + // Library unit tests. const main_tests = b.addTest(.{ .root_source_file = .{ .path = "src/test.zig" }, .target = target, @@ -15,4 +17,17 @@ pub fn build(b: *std.Build) void { const run_main_tests = b.addRunArtifact(main_tests); const test_step = b.step("test", "Run library tests"); test_step.dependOn(&run_main_tests.step); + + // Sample web server. + const server_exe = b.addExecutable(.{ + .name = "server", + .root_source_file = .{ .path = "src/sample_web_server.zig" }, + .target = target, + .optimize = optimize, + }); + server_exe.root_module.addImport("io", io_module); + + const run_server = b.addRunArtifact(server_exe); + const server_step = b.step("server", "Run sample web server."); + server_step.dependOn(&run_server.step); } diff --git a/src/sample_web_server.zig b/src/sample_web_server.zig new file mode 100644 index 0000000..3e8376d --- /dev/null +++ b/src/sample_web_server.zig @@ -0,0 +1,269 @@ +const std = @import("std"); + +const Atomic = std.atomic.Value; +const debug = std.debug; +const fifo = std.fifo; +const heap = std.heap; +const log = std.log.scoped(.server); +const mem = std.mem; +const os = std.os; +const Thread = std.Thread; + +const IO = @import("io").IO; + +pub const Config = struct { + pub const kernel_backlog = 16; + pub const io_entries = 128; + pub const server_ip = "127.0.0.1"; + pub const server_port = 8080; + pub const worker_threads = 5; + + pub const accept_buf_len = 128; + pub const client_buf_len = 1024 * 512; + pub const recv_buf_len = 1024; +}; + +const Queue = fifo.LinearFifo(os.socket_t, .{ .Static = Config.accept_buf_len }); + +var running = Atomic(bool).init(true); + +const Acceptor = struct { + accepting: bool, + mutex: *Thread.Mutex, + queue: *Queue, +}; + +pub fn main() !void { + // Handle OS signals for graceful shutdown. + try addSignalHandlers(); + + // Cross-platform IO setup. + var io = try IO.init(Config.io_entries, 0); + defer io.deinit(); + + // Listener setup + const address = try std.net.Address.parseIp4(Config.server_ip, Config.server_port); + const listener = try io.open_socket(address.any.family, os.SOCK.STREAM, os.IPPROTO.TCP); + defer os.closeSocket(listener); + try os.setsockopt(listener, os.SOL.SOCKET, os.SO.REUSEADDR, &std.mem.toBytes(@as(c_int, 1))); + try os.bind(listener, &address.any, address.getOsSockLen()); + try os.listen(listener, Config.kernel_backlog); + + log.info("server listening on IP {s} port {}. CTRL+C to shutdown.", .{ Config.server_ip, Config.server_port }); + + // Client socket queue setup. + var queue_mutex = Thread.Mutex{}; + var queue = Queue.init(); + + // Setup thread pool. + var handles: [Config.worker_threads]Thread = undefined; + defer for (handles) |h| h.join(); + for (&handles) |*h| h.* = try Thread.spawn(.{}, logError, .{ &queue_mutex, &queue }); + + // Accept context setup. + var acceptor = Acceptor{ + .accepting = true, + .mutex = &queue_mutex, + .queue = &queue, + }; + + while (running.load(.Monotonic)) { + // Start accepting. + var acceptor_completion: IO.Completion = undefined; + io.accept(*Acceptor, &acceptor, acceptCallback, &acceptor_completion, listener); + + // Wait while accepting. + while (acceptor.accepting) try io.tick(); + + // Reset accepting flag. + acceptor.accepting = true; + } +} + +fn acceptCallback( + acceptor_ptr: *Acceptor, + completion: *IO.Completion, + result: IO.AcceptError!os.socket_t, +) void { + _ = completion; + const accepted_sock = result catch @panic("accept error"); + log.debug("Accepted socket fd: {}", .{accepted_sock}); + + // Dispatch + { + acceptor_ptr.mutex.lock(); + defer acceptor_ptr.mutex.unlock(); + acceptor_ptr.queue.writeItem(accepted_sock) catch @panic("queue.writeItem"); + } + + // Signal we're done. + acceptor_ptr.accepting = false; +} + +fn logError(queue_mutex: *Thread.Mutex, queue: *Queue) void { + handleClient(queue_mutex, queue) catch |err| { + std.log.err("error handling client request: {}", .{err}); + }; +} + +const Client = struct { + allocator: mem.Allocator, + completion: IO.Completion, + io: *IO, + recv_buf: [Config.recv_buf_len]u8, + socket: os.socket_t, + thread_id: Thread.Id, +}; + +fn handleClient(queue_mutex: *Thread.Mutex, queue: *Queue) !void { + const thread_id = Thread.getCurrentId(); + + // Cross-platform IO setup. + var io = try IO.init(Config.io_entries, 0); + defer io.deinit(); + + // Per thread allocator + var client_buf: [Config.client_buf_len]u8 = undefined; + var fba = heap.FixedBufferAllocator.init(&client_buf); + const allocator = fba.allocator(); + + while (running.load(.Monotonic)) { + // Get next accepted client socket. + const maybe_socket: ?os.socket_t = blk: { + queue_mutex.lock(); + defer queue_mutex.unlock(); + break :blk queue.readItem(); + }; + + if (maybe_socket) |socket| { + // Allocate and init new client. + var client_ptr = try allocator.create(Client); + client_ptr.allocator = allocator; + client_ptr.io = &io; + client_ptr.completion = undefined; + client_ptr.recv_buf = undefined; + client_ptr.socket = socket; + client_ptr.thread_id = thread_id; + + // Receive from client. + io.recv( + *Client, + client_ptr, + recvCallback, + &client_ptr.completion, + socket, + &client_ptr.recv_buf, + ); + } + + try io.tick(); + } +} + +fn recvCallback( + client_ptr: *Client, + completion: *IO.Completion, + result: IO.RecvError!usize, +) void { + const received = result catch |err| blk: { + log.err("recvCallback error: {}", .{err}); + break :blk 0; + }; + log.debug("{}: Received {} bytes from {}", .{ client_ptr.thread_id, received, client_ptr.socket }); + + if (received == 0) { + // Client connection closed. + client_ptr.io.close( + *Client, + client_ptr, + closeCallback, + completion, + client_ptr.socket, + ); + + return; + } + + const response = + \\HTTP/1.1 200 OK + \\Connection: Keep-Alive + \\Keep-Alive: timeout=1 + \\Content-Type: text/plain + \\Content-Length: 6 + \\Server: server/0.1.0 + \\ + \\Hello + \\ + ; + + client_ptr.io.send( + *Client, + client_ptr, + sendCallback, + completion, + client_ptr.socket, + response, + ); +} + +fn sendCallback( + client_ptr: *Client, + completion: *IO.Completion, + result: IO.SendError!usize, +) void { + const sent = result catch @panic("sendCallback"); + log.debug("{}: Sent {} to {}", .{ client_ptr.thread_id, sent, client_ptr.socket }); + + // Try to receive from client again (keep-alive). + client_ptr.io.recv( + *Client, + client_ptr, + recvCallback, + completion, + client_ptr.socket, + &client_ptr.recv_buf, + ); +} + +fn closeCallback( + client_ptr: *Client, + completion: *IO.Completion, + result: IO.CloseError!void, +) void { + _ = completion; + _ = result catch @panic("closeCallback"); + log.debug("{}: Closed {}", .{ client_ptr.thread_id, client_ptr.socket }); + client_ptr.allocator.destroy(client_ptr); +} + +fn addSignalHandlers() !void { + // Ignore broken pipes + { + var act = os.Sigaction{ + .handler = .{ + .handler = os.SIG.IGN, + }, + .mask = os.empty_sigset, + .flags = 0, + }; + try os.sigaction(os.SIG.PIPE, &act, null); + } + + // Catch SIGINT/SIGTERM for proper shutdown + { + var act = os.Sigaction{ + .handler = .{ + .handler = struct { + fn wrapper(sig: c_int) callconv(.C) void { + log.info("Caught signal {d}; Shutting down...", .{sig}); + running.store(false, .Release); + } + }.wrapper, + }, + .mask = os.empty_sigset, + .flags = 0, + }; + try os.sigaction(os.SIG.TERM, &act, null); + try os.sigaction(os.SIG.INT, &act, null); + } +}