//! TCP load balancer: each accepted client connection is handled on its own
//! thread and relayed over a pooled, persistent upstream connection chosen
//! round-robin across the configured upstream servers. Requests are sent
//! upstream prefixed with a single 0 byte; upstream responses are expected to
//! end with a 0 byte, which is stripped before the response reaches the client.

const std = @import("std");
const log = std.log;
const Server = std.Server;
const http = std.http;
const net = std.net;
const testing = std.testing;
const expect = testing.expect;
const Address = net.Address;
const Stream = net.Stream;

/// Size of the per-connection request and response buffers, in bytes.
const BUFFER_SIZE = 1024;
/// Number of connection slots in each upstream server's pool.
const MAX_CONNECTION_UPSTREAM = 1520;

/// Framed request size (0 marker byte included) at which a request that matched
/// neither fast-path check in `requestComplete` is forwarded upstream.
/// NOTE(review): presumably the maximum size of a body-carrying request; the
/// individual summands are not explained anywhere — TODO confirm.
const request_flush_len = 25 + 22 + 30 + 20 + 32 + 2 + 70;

comptime {
    // `requestComplete` must fire before the request buffer is full; otherwise a
    // read into the exhausted buffer returns 0 and the request is treated as a
    // client close.
    std.debug.assert(request_flush_len < BUFFER_SIZE);
}

/// Reports whether the buffered request should be forwarded upstream now.
/// `framed` is the 0 marker byte followed by every client byte read so far for
/// the current request. Only bytes that were actually read are inspected, so
/// nothing left over from a previous request (or uninitialised memory) can
/// trigger a flush.
fn requestComplete(framed: []const u8) bool {
    // Requests whose first byte is 'G' are assumed to arrive in a single read.
    if (framed.len > 1 and framed[1] == 'G') return true;
    // NOTE(review): presumably matches one specific body-less route whose 8th
    // byte is 'u' — TODO confirm.
    if (framed.len > 8 and framed[8] == 'u') return true;
    return framed.len >= request_flush_len;
}

/// Lifecycle of a pooled upstream slot.
/// `inactive`: no usable stream (not yet warmed up, or retired after an upstream failure).
/// `available`: holds an idle stream that may be handed to a client.
/// `occupied`: currently owned by a client handler thread.
/// Backed by `u8` so the state can be read and written atomically.
const UpstreamConnectionState = enum(u8) { inactive, available, occupied };

/// One slot of an upstream server's connection pool.
/// `state` is shared between the warmup thread, the accept loop and handler
/// threads; access it through `getState`/`setState` whenever other threads may
/// be running.
const UpstreamConnection = struct {
    state: UpstreamConnectionState,
    stream: Stream,

    /// Opens a TCP connection to `address` and returns it as an `available` slot.
    /// Panics (after logging the error) if the connection cannot be established.
    pub fn init(address: net.Address) UpstreamConnection {
        const stream = net.tcpConnectToAddress(address) catch |err| {
            log.err("Error when connecting to upstream {}", .{err});
            @panic("cannot connect to upstream server");
        };
        return UpstreamConnection{
            .stream = stream,
            .state = .available,
        };
    }

    /// Atomically reads the slot state; the acquire ordering makes a stream
    /// published by `setState` visible to the reader.
    fn getState(self: *const UpstreamConnection) UpstreamConnectionState {
        return @atomicLoad(UpstreamConnectionState, &self.state, .acquire);
    }

    /// Atomically publishes a new slot state; the release ordering orders any
    /// earlier write to `stream` before the state change.
    fn setState(self: *UpstreamConnection, new_state: UpstreamConnectionState) void {
        @atomicStore(UpstreamConnectionState, &self.state, new_state, .release);
    }

    inline fn occupy(self: *UpstreamConnection) void {
        self.setState(.occupied);
    }

    inline fn free(self: *UpstreamConnection) void {
        self.setState(.available);
    }

    /// Closes the stream and takes the slot out of rotation. Used when the stream
    /// can no longer be trusted to sit on a request/response boundary. Nothing in
    /// this file reopens a retired slot, so the pool shrinks by one.
    fn retire(self: *UpstreamConnection) void {
        self.stream.close();
        self.setState(.inactive);
    }
};

/// An upstream server address plus a fixed-size pool of persistent connections to it.
pub const UpstreamServer = struct {
    pool: [MAX_CONNECTION_UPSTREAM]UpstreamConnection,
    address: net.Address,

    /// Resolves `host:port` and selects the first resolved address that accepts a
    /// TCP connection (the probe connection is closed immediately).
    /// All pool slots start `inactive`; call `warmup` to open them.
    /// Panics if resolution fails or no resolved address is reachable.
    pub fn init(host: []const u8, port: u16) UpstreamServer {
        var buf: [4096]u8 = undefined;
        var fba = std.heap.FixedBufferAllocator.init(&buf);
        const addr_list = net.getAddressList(fba.allocator(), host, port) catch |err| {
            log.err("Error when resolving upstream {s}:{d} {}", .{ host, port, err });
            @panic("cannot resolve upstream server address");
        };
        defer addr_list.deinit();

        const address = for (addr_list.addrs) |addr| {
            const probe = net.tcpConnectToAddress(addr) catch continue;
            probe.close();
            break addr;
        } else @panic("no reachable address for upstream server");

        return UpstreamServer{
            .address = address,
            // Every slot gets a defined state so pool scans never read
            // uninitialised memory before warmup has filled it.
            .pool = [_]UpstreamConnection{.{ .state = .inactive, .stream = undefined }} ** MAX_CONNECTION_UPSTREAM,
        };
    }

    /// Opens every pool connection on a detached background thread. Slots become
    /// `available` one by one as they connect. `self` must stay alive and must not
    /// move while that thread runs. Panics if the thread cannot be spawned.
    pub fn warmup(self: *UpstreamServer) void {
        const thread = std.Thread.spawn(.{ .stack_size = 1024 * 24 }, _warmup, .{self}) catch |err| {
            log.err("Error when spawning upstream warmup thread {}", .{err});
            @panic("cannot spawn upstream warmup thread");
        };
        thread.detach();
    }

    fn _warmup(self: *UpstreamServer) void {
        for (&self.pool) |*slot| {
            const fresh = UpstreamConnection.init(self.address);
            // Write the stream first, then publish the slot, so a reader that
            // observes `available` also observes the stream.
            slot.stream = fresh.stream;
            slot.setState(.available);
        }
    }

    /// Returns the first pool slot currently `available`, or null if none is.
    /// The caller is expected to `occupy` the returned slot itself.
    pub fn getAvailableConnection(self: *UpstreamServer) ?*UpstreamConnection {
        for (&self.pool) |*conn| {
            if (conn.getState() == .available) return conn;
        }
        return null;
    }
};

/// Listens on an IPv4 address and relays each client connection through a
/// pooled connection to one of `servers`.
pub const LoadBalancer = struct {
    address: net.Address,
    address_options: net.Address.ListenOptions,
    servers: []UpstreamServer,

    /// Parses `ip_map` as an IPv4 literal. `servers` is borrowed and must outlive
    /// the load balancer.
    pub fn init(ip_map: []const u8, port: u16, servers: []UpstreamServer) !LoadBalancer {
        const address = try net.Address.parseIp4(ip_map, port);
        return LoadBalancer{
            .address = address,
            .address_options = net.Address.ListenOptions{},
            .servers = servers,
        };
    }

    /// Accept loop; only returns on error (including `error.NoUpstreamServers`
    /// when `servers` is empty). Each client gets an occupied upstream slot and
    /// a detached handler thread.
    pub fn start(self: *LoadBalancer) !void {
        if (self.servers.len == 0) return error.NoUpstreamServers;

        var server = try self.address.listen(self.address_options);
        std.debug.print("Listening load balancer http://{}\n", .{self.address});
        defer server.deinit();

        var lb: usize = 0;
        while (true) {
            const conn = server.accept() catch |err| {
                log.err("Error socket {}\n", .{err});
                continue;
            };

            const upstream = self.acquireUpstream(&lb);
            upstream.occupy();

            const thread = std.Thread.spawn(.{ .stack_size = 1024 * 6 }, handleConnection, .{ self, conn, upstream }) catch |err| {
                log.err("Creating thread error: {}\n", .{err});
                conn.stream.close();
                // Without this, the slot stays occupied forever and leaks out of the pool.
                upstream.free();
                continue;
            };
            thread.detach();
        }
    }

    /// Walks the servers round-robin, starting at `cursor.*`, until one reports an
    /// available slot. `cursor.*` is advanced once per server tried. After each
    /// full fruitless sweep the thread yields, so handler threads can free slots.
    fn acquireUpstream(self: *LoadBalancer, cursor: *usize) *UpstreamConnection {
        var misses: usize = 0;
        while (true) {
            const server = &self.servers[cursor.* % self.servers.len];
            cursor.* +%= 1;
            if (server.getAvailableConnection()) |conn| return conn;
            misses += 1;
            // Yielding is only a scheduling hint; failing to yield is harmless.
            if (misses % self.servers.len == 0) std.Thread.yield() catch {};
        }
    }

    /// Handler-thread entry point: relays the client, closes it, then returns the
    /// upstream slot to the pool, or retires it if the upstream stream was left
    /// in an unknown state.
    fn handleConnection(
        _: *LoadBalancer,
        conn: net.Server.Connection,
        upstream: *UpstreamConnection,
    ) void {
        const reusable = relay(conn.stream, upstream.stream);
        conn.stream.close();
        if (reusable) upstream.free() else upstream.retire();
    }

    /// Relays request/response exchanges between `client` and `upstream` until the
    /// client disconnects or an I/O error occurs.
    /// Returns true when `upstream` is left on a request/response boundary and is
    /// safe to reuse. Returns false after an upstream failure: write/read error,
    /// EOF, or a response that overflows the buffer without a 0 terminator.
    fn relay(client: Stream, upstream: Stream) bool {
        var buffer_request: [BUFFER_SIZE]u8 = undefined;
        var buffer_response: [BUFFER_SIZE]u8 = undefined;
        // Byte 0 is the marker sent ahead of every request; client bytes start at index 1.
        buffer_request[0] = 0;

        while (true) {
            var req_len: usize = 1;
            while (true) {
                const n = client.read(buffer_request[req_len..]) catch |err| {
                    log.err("Error when read from connection {}\n", .{err});
                    return true;
                };
                if (n == 0) return true;
                req_len += n;
                if (requestComplete(buffer_request[0..req_len])) break;
            }

            upstream.writeAll(buffer_request[0..req_len]) catch |err| {
                log.err("Error when writing to upstream {}\n", .{err});
                return false;
            };

            // Read until the upstream response ends with its 0 terminator.
            var res_len: usize = 0;
            while (res_len == 0 or buffer_response[res_len - 1] != 0) {
                if (res_len == buffer_response.len) {
                    log.err("Upstream response exceeds {d} bytes\n", .{BUFFER_SIZE});
                    return false;
                }
                const n = upstream.read(buffer_response[res_len..]) catch |err| {
                    log.err("Error when reading from upstream {}\n", .{err});
                    return false;
                };
                if (n == 0) {
                    log.err("Upstream closed the connection\n", .{});
                    return false;
                }
                res_len += n;
            }

            // Forward everything except the terminator; writeAll handles partial writes.
            client.writeAll(buffer_response[0 .. res_len - 1]) catch |err| {
                log.err("Error when write from connection {}\n", .{err});
                return true;
            };
        }
    }
};

const skip_tests = true;

test "expect connect to upstream server" {
    if (skip_tests) {
        return;
    }
    var upstream_server = UpstreamServer.init("localhost", 5001);
    // Slots start inactive; fill the pool synchronously for the assertion.
    upstream_server._warmup();
    for (&upstream_server.pool) |*conn| {
        try expect(conn.state == .available);
    }
}

test "expect get the first available connection upstream server" {
    if (skip_tests) {
        return;
    }
    var upstream_server = UpstreamServer.init("localhost", 5001);
    for (&upstream_server.pool) |*conn| conn.state = .available;
    upstream_server.pool[0].state = .occupied;
    upstream_server.pool[1].state = .occupied;
    upstream_server.pool[3].state = .occupied;
    try expect(upstream_server.getAvailableConnection().? == &upstream_server.pool[2]);
}

test "expect initiate load balancer" {
    if (skip_tests) {
        return;
    }
    const upstream_server = UpstreamServer.init("localhost", 5001);
    var servers = [_]UpstreamServer{upstream_server};
    var lb = try LoadBalancer.init("127.0.0.1", 9999, &servers);
    try lb.start();
}