zig-pay/src/load_balancer.zig

230 lines
6.2 KiB
Zig

const std = @import("std");
const log = std.log;
const Server = std.Server;
const http = std.http;
const net = std.net;
const testing = std.testing;
const expect = testing.expect;
const Address = net.Address;
const Stream = net.Stream;
const BUFFER_SIZE = 1024;
/// Bookkeeping state of one pooled upstream connection.
const UpstreamConnectionState = enum {
    /// Declared but never assigned anywhere in this file.
    inactive,
    /// Free to be handed to a client handler.
    available,
    /// Currently claimed by a client handler.
    occupied,
};
/// One pre-opened TCP connection to an upstream server plus the flag that
/// says whether a client handler currently owns it.
///
/// NOTE(review): `state` is plain (non-atomic) data. In this file it is
/// written from spawned handler threads (`free`) and read from the accept
/// loop, so confirm whether atomic access is required.
const UpstreamConnection = struct {
    state: UpstreamConnectionState,
    stream: Stream,

    /// Connects to `address` and returns the connection marked `.available`.
    ///
    /// The signature has no error channel, so a failed connect is logged and
    /// then aborts the process with a panic. The previous `unreachable` was
    /// undefined behaviour in ReleaseFast/ReleaseSmall builds for a failure
    /// that can genuinely happen (upstream down, refused, unreachable).
    pub fn init(address: net.Address) UpstreamConnection {
        const stream = net.tcpConnectToAddress(address) catch |err| {
            log.err("Error when connecting to upstream {}", .{err});
            std.debug.panic("cannot connect to upstream: {}", .{err});
        };
        return .{
            .stream = stream,
            .state = .available,
        };
    }

    /// Marks the connection as claimed by a handler.
    inline fn occupy(self: *UpstreamConnection) void {
        self.state = .occupied;
    }

    /// Marks the connection as reusable again.
    inline fn free(self: *UpstreamConnection) void {
        self.state = .available;
    }
};
/// A backend server together with a fixed-size pool of already-connected
/// TCP streams to it.
pub const UpstreamServer = struct {
    /// Number of connections opened eagerly per upstream server.
    pub const pool_size = 300;

    pool: [pool_size]UpstreamConnection,
    address: net.Address,

    /// Resolves `host:port`, probes every resolved address with a throwaway
    /// TCP connect, and fills the pool with connections to the last address
    /// that accepted the probe.
    ///
    /// There is no error channel in the signature, so resolution failure or
    /// "no address reachable" aborts with a panic. These used to be
    /// `catch unreachable` / `std.debug.assert`, both of which are undefined
    /// behaviour in release builds even though the failures are ordinary
    /// runtime conditions.
    pub fn init(host: []const u8, port: u16) UpstreamServer {
        // Resolution scratch space lives on the stack; no heap allocator is used.
        var buf: [4096]u8 = undefined;
        var fba = std.heap.FixedBufferAllocator.init(&buf);
        const addr_list = net.getAddressList(fba.allocator(), host, port) catch |err| {
            std.debug.panic("cannot resolve upstream {s}:{d}: {}", .{ host, port, err });
        };
        defer addr_list.deinit();

        var final_addr: ?Address = null;
        for (addr_list.addrs) |addr| {
            const ping = net.tcpConnectToAddress(addr) catch continue;
            ping.close();
            // No early exit: a later reachable address overrides an earlier one.
            final_addr = addr;
        }
        const address = final_addr orelse
            std.debug.panic("no reachable address for upstream {s}:{d}", .{ host, port });

        var connections: [pool_size]UpstreamConnection = undefined;
        for (&connections) |*conn| {
            conn.* = UpstreamConnection.init(address);
        }
        return UpstreamServer{
            .address = address,
            .pool = connections,
        };
    }

    /// Closes every pooled stream and marks each slot `.inactive`.
    /// The server must not be used for proxying afterwards.
    pub fn deinit(self: *UpstreamServer) void {
        for (&self.pool) |*conn| {
            conn.stream.close();
            conn.state = .inactive;
        }
    }

    /// Returns a pointer to the first pool slot whose state is `.available`,
    /// or `null` when every slot is taken. The slot is NOT marked occupied
    /// here; the caller does that.
    pub fn getAvailableConnection(self: *UpstreamServer) ?*UpstreamConnection {
        for (&self.pool) |*conn| {
            if (conn.state == .available) return conn;
        }
        return null;
    }
};
/// Accepts client TCP connections and relays each one over a pooled upstream
/// connection, picking upstream servers in round-robin order.
pub const LoadBalancer = struct {
    address: net.Address,
    address_options: net.Address.ListenOptions,
    servers: []UpstreamServer,

    /// Builds a balancer that will listen on the IPv4 address `ip_map:port`
    /// and forward to `servers`. Fails if `ip_map` is not a valid IPv4 literal.
    /// `servers` is borrowed, not copied; it must outlive the balancer.
    pub fn init(ip_map: []const u8, port: u16, servers: []UpstreamServer) !LoadBalancer {
        const address = try net.Address.parseIp4(ip_map, port);
        return LoadBalancer{
            .address = address,
            .address_options = net.Address.ListenOptions{},
            .servers = servers,
        };
    }

    /// Listens on `self.address` and serves forever, one detached thread per
    /// accepted client. Returns only on error: `error.NoUpstreamServers` when
    /// `servers` is empty, or whatever `listen` reports.
    pub fn start(self: *LoadBalancer) !void {
        // `lb % 0` below would be a division by zero, so reject this up front.
        if (self.servers.len == 0) return error.NoUpstreamServers;

        var server = try self.address.listen(self.address_options);
        defer server.deinit();
        std.debug.print("Listening load balancer http://{}\n", .{self.address});

        // Index of the next server to try; always kept below `servers.len`.
        var lb: usize = 0;
        while (true) {
            const conn = server.accept() catch |err| {
                log.err("Error socket {}\n", .{err});
                continue;
            };

            // Spin over the servers until one of them has a free pooled connection.
            var upstream: ?*UpstreamConnection = null;
            while (upstream == null) {
                upstream = self.servers[lb].getAvailableConnection();
                lb = (lb + 1) % self.servers.len;
            }
            const claimed = upstream.?;
            claimed.occupy();

            const thread = std.Thread.spawn(
                .{ .stack_size = 1024 * 16 },
                handleConnection,
                .{ self, conn, claimed },
            ) catch |err| {
                log.err("Creating thread error: {}\n", .{err});
                // No handler will run, so nobody else would release the slot.
                claimed.free();
                conn.stream.close();
                continue;
            };
            thread.detach();
        }
    }

    /// Relays requests from `conn` to `upstream` and responses back until
    /// the client disconnects or any I/O step fails. On exit the client
    /// stream is closed and the upstream slot is marked available again.
    ///
    /// Framing visible in this code: every forwarded request is prefixed
    /// with a single 0 byte, and an upstream response is considered complete
    /// when its last received byte is 0 (that byte is not sent to the client).
    fn handleConnection(
        _: *LoadBalancer,
        conn: net.Server.Connection,
        upstream: *UpstreamConnection,
    ) void {
        defer upstream.free();
        defer conn.stream.close();

        // Requests matching neither early-exit marker are forwarded once this
        // many bytes (leading 0 byte included) are buffered. The sum is kept
        // exactly as originally written; presumably each term is the size of
        // one fixed part of the expected request — TODO confirm.
        const forward_threshold = 25 + 22 + 30 + 20 + 32 + 2 + 70;

        var buffer_request: [BUFFER_SIZE]u8 = undefined;
        var buffer_response: [BUFFER_SIZE]u8 = undefined;
        buffer_request[0] = 0;

        while (true) {
            // Byte 0 is the fixed prefix; client data starts at index 1.
            var req_len: usize = 1;
            while (true) {
                const aux_len = conn.stream.read(buffer_request[req_len..]) catch |err| {
                    log.err("Error when read from connection {}\n", .{err});
                    return;
                };
                if (aux_len == 0) {
                    return;
                }
                req_len += aux_len;

                // First request byte 'G': forward immediately
                // (presumably a GET without a body — TODO confirm).
                if (buffer_request[1] == 'G') break;
                // Only inspect index 8 once it has actually been received;
                // otherwise it holds undefined or previous-request data.
                if (req_len > 8 and buffer_request[8] == 'u') break;
                if (req_len >= forward_threshold) break;
            }

            upstream.stream.writeAll(buffer_request[0..req_len]) catch |err| {
                log.err("Error when writing to upstream {}\n", .{err});
                return;
            };

            var res_len: usize = 0;
            while (res_len == 0 or buffer_response[res_len - 1] != 0) {
                const got = upstream.stream.read(buffer_response[res_len..]) catch |err| {
                    log.err("Error when reading from upstream {}\n", .{err});
                    return;
                };
                if (got == 0) {
                    // Either the upstream closed the stream or the buffer is
                    // full (reading into an empty slice yields 0). Looping
                    // again could never make progress.
                    log.err("Upstream closed or response exceeded {d} bytes\n", .{BUFFER_SIZE});
                    return;
                }
                res_len += got;
            }

            // writeAll: a short write must not truncate the client's reply.
            conn.stream.writeAll(buffer_response[0 .. res_len - 1]) catch |err| {
                log.err("Error when write from connection {}\n", .{err});
                return;
            };
        }
    }
};
// When true, every test below bails out before doing any work. The tests
// connect to an upstream at localhost:5001, so they need a live server.
const skip_tests = true;
// Needs a server listening on localhost:5001. When disabled it is reported
// as skipped rather than silently counted as passed.
test "expect connect to upstream server" {
    if (skip_tests) return error.SkipZigTest;

    const upstream_server = UpstreamServer.init("localhost", 5001);
    // Release the pooled sockets so the test does not leak descriptors.
    defer {
        for (upstream_server.pool) |conn| conn.stream.close();
    }

    for (upstream_server.pool) |conn| {
        try expect(conn.state == .available);
    }
}
// Needs a server listening on localhost:5001. With slots 0, 1 and 3 marked
// occupied, the first available slot must be index 2.
test "expect get the first available connection upstream server" {
    if (skip_tests) return error.SkipZigTest;

    var upstream_server = UpstreamServer.init("localhost", 5001);
    // Release the pooled sockets so the test does not leak descriptors.
    defer {
        for (upstream_server.pool) |conn| conn.stream.close();
    }

    upstream_server.pool[0].state = .occupied;
    upstream_server.pool[1].state = .occupied;
    upstream_server.pool[3].state = .occupied;
    try expect(upstream_server.getAvailableConnection().? == &upstream_server.pool[2]);
}
// Manual smoke test: needs a server on localhost:5001 and binds 127.0.0.1:9999.
// `start` only returns on error, so when enabled this test runs until the
// process is stopped or listening fails.
test "expect initiate load balancer" {
    if (skip_tests) return error.SkipZigTest;

    const upstream_server = UpstreamServer.init("localhost", 5001);
    var servers = [_]UpstreamServer{upstream_server};
    // Reached only if `init`/`start` return an error; close the pooled sockets then.
    defer {
        for (servers[0].pool) |conn| conn.stream.close();
    }

    var lb = try LoadBalancer.init("127.0.0.1", 9999, &servers);
    try lb.start();
}