230 lines
6.2 KiB
Zig
230 lines
6.2 KiB
Zig
const std = @import("std");
|
|
|
|
const log = std.log;
|
|
const Server = std.Server;
|
|
const http = std.http;
|
|
const net = std.net;
|
|
|
|
const testing = std.testing;
|
|
const expect = testing.expect;
|
|
|
|
const Address = net.Address;
|
|
const Stream = net.Stream;
|
|
|
|
/// Capacity, in bytes, of each request/response buffer that
/// `LoadBalancer.handleConnection` uses while proxying.
const BUFFER_SIZE = 1 << 10;

/// Pool state of one upstream connection.
const UpstreamConnectionState = enum {
    /// Never assigned in this file; presumably reserved for broken/closed
    /// connections. NOTE(review): confirm intent.
    inactive,
    /// Connected and free to be handed to a client.
    available,
    /// Currently serving a client.
    occupied,
};
|
|
|
|
/// One persistent TCP connection to an upstream server, plus its pool state.
const UpstreamConnection = struct {
    state: UpstreamConnectionState,
    stream: Stream,

    /// Open a TCP connection to `address` and return it as `.available`.
    ///
    /// Panics if the connection cannot be established. A failed connect is a
    /// real runtime condition, so it must not be `unreachable` (which is
    /// undefined behavior in ReleaseFast/ReleaseSmall builds).
    pub fn init(address: net.Address) UpstreamConnection {
        const stream = net.tcpConnectToAddress(address) catch |err| {
            log.err("Error when connecting to upstream {}", .{err});
            @panic("upstream connection failed");
        };

        return .{
            .state = .available,
            .stream = stream,
        };
    }

    /// Mark this connection as serving a client so the pool skips it.
    fn occupy(self: *UpstreamConnection) void {
        self.state = .occupied;
    }

    // NOTE(review): `LoadBalancer.handleConnection` calls `free` from a worker
    // thread while the accept loop reads `state` with no synchronization.
    // Consider atomic load/store (after giving the enum an explicit u8 tag).
    /// Return this connection to the pool as `.available`.
    fn free(self: *UpstreamConnection) void {
        self.state = .available;
    }
};
|
|
|
|
/// A resolved upstream endpoint with a fixed pool of open connections to it.
pub const UpstreamServer = struct {
    pool: [pool_size]UpstreamConnection,
    address: net.Address,

    /// Number of persistent connections opened per upstream server.
    pub const pool_size = 300;

    /// Resolve `host:port`, pick the first resolved address that accepts a
    /// TCP connection, and open `pool_size` connections to it.
    ///
    /// Panics when name resolution fails, when no resolved address is
    /// reachable, or (via `UpstreamConnection.init`) when a pooled connect
    /// fails. These are runtime failures, so they are reported with a panic
    /// instead of `unreachable`/`assert`, which are UB in release builds.
    pub fn init(host: []const u8, port: u16) UpstreamServer {
        const address = resolveReachable(host, port);

        var connections: [pool_size]UpstreamConnection = undefined;
        for (&connections) |*conn| {
            conn.* = UpstreamConnection.init(address);
        }

        return .{
            .address = address,
            .pool = connections,
        };
    }

    /// Return the first resolved address of `host:port` that accepts a TCP
    /// connection. Panics if resolution fails or nothing is reachable.
    fn resolveReachable(host: []const u8, port: u16) net.Address {
        // The address list only lives for this call, so a stack buffer is enough.
        var buf: [4096]u8 = undefined;
        var fba = std.heap.FixedBufferAllocator.init(&buf);

        const addr_list = net.getAddressList(fba.allocator(), host, port) catch |err| {
            std.debug.panic("failed to resolve upstream {s}:{d}: {s}", .{ host, port, @errorName(err) });
        };
        defer addr_list.deinit();

        for (addr_list.addrs) |addr| {
            // Probe with a throwaway connection and skip addresses that refuse it.
            const probe = net.tcpConnectToAddress(addr) catch continue;
            probe.close();
            return addr;
        }

        std.debug.panic("no reachable address for upstream {s}:{d}", .{ host, port });
    }

    /// Scan the pool in order and return the first `.available` connection,
    /// or `null` if every connection is occupied or inactive.
    pub fn getAvailableConnection(self: *UpstreamServer) ?*UpstreamConnection {
        for (&self.pool) |*conn| {
            if (conn.state == .available) return conn;
        }
        return null;
    }
};
|
|
|
|
/// TCP front end that hands each accepted client to a pooled upstream
/// connection, choosing servers round-robin, and proxies on a worker thread.
pub const LoadBalancer = struct {
    address: net.Address,
    address_options: net.Address.ListenOptions,
    servers: []UpstreamServer,

    /// Build a load balancer that will listen on `ip_map:port`.
    ///
    /// `ip_map` may be an IPv4 or IPv6 literal. `servers` is stored as a
    /// borrowed slice (not copied), so it must outlive the load balancer.
    pub fn init(ip_map: []const u8, port: u16, servers: []UpstreamServer) !LoadBalancer {
        return .{
            .address = try net.Address.parseIp(ip_map, port),
            .address_options = .{},
            .servers = servers,
        };
    }

    /// Accept clients forever. Each client is paired with a free upstream
    /// connection and served on a detached worker thread.
    ///
    /// Returns `error.NoUpstreamServers` when `servers` is empty (round-robin
    /// selection needs at least one server), or the error from `listen`.
    pub fn start(self: *LoadBalancer) !void {
        if (self.servers.len == 0) return error.NoUpstreamServers;

        var server = try self.address.listen(self.address_options);
        defer server.deinit();
        std.debug.print("Listening load balancer http://{}\n", .{self.address});

        var next_server: usize = 0;
        while (true) {
            const conn = server.accept() catch |err| {
                log.err("Error socket {}\n", .{err});
                continue;
            };

            const upstream = self.acquireUpstream(&next_server);

            const thread = std.Thread.spawn(.{ .stack_size = 1024 * 16 }, handleConnection, .{ self, conn, upstream }) catch |err| {
                log.err("Creating thread error: {}\n", .{err});
                conn.stream.close();
                // No worker owns the slot, so give it back or it stays occupied forever.
                upstream.free();
                continue;
            };
            thread.detach();
        }
    }

    /// Walk `servers` round-robin from `cursor.*` until one has an available
    /// connection, then mark it occupied and return it. After a full pass
    /// finds nothing free, yield the CPU before retrying instead of spinning.
    fn acquireUpstream(self: *LoadBalancer, cursor: *usize) *UpstreamConnection {
        std.debug.assert(self.servers.len != 0);

        var misses: usize = 0;
        while (true) {
            const upstream_server = &self.servers[cursor.* % self.servers.len];
            cursor.* +%= 1;

            if (upstream_server.getAvailableConnection()) |upstream| {
                upstream.occupy();
                return upstream;
            }

            misses += 1;
            if (misses % self.servers.len == 0) {
                // Best effort: if yielding fails we simply retry immediately.
                std.Thread.yield() catch {};
            }
        }
    }

    /// Proxy requests from `conn` to `upstream` until the client disconnects
    /// or any I/O fails. Then close the client and release `upstream`.
    ///
    /// Framing used with the upstream: every forwarded request is prefixed
    /// with a single 0 byte, and the upstream's response is read until a
    /// trailing 0 byte, which is stripped before relaying to the client.
    fn handleConnection(
        _: *LoadBalancer,
        conn: net.Server.Connection,
        upstream: *UpstreamConnection,
    ) void {
        // NOTE(review): the connection goes back to the pool even after an
        // upstream I/O error; consider marking it `.inactive` instead.
        defer upstream.free();
        defer conn.stream.close();

        var buffer_request: [BUFFER_SIZE]u8 = undefined;
        var buffer_response: [BUFFER_SIZE]u8 = undefined;

        // Byte 0 is the fixed 0 prefix; client bytes are read in after it.
        buffer_request[0] = 0;

        while (true) {
            var req_len: usize = 1;

            while (true) {
                const aux_len = conn.stream.read(buffer_request[req_len..]) catch |err| {
                    log.err("Error when read from connection {}\n", .{err});
                    return;
                };
                // Client closed the connection (or the request buffer is full).
                if (aux_len == 0) return;
                req_len += aux_len;

                // NOTE(review): protocol-specific heuristics for "request is
                // complete": first client byte 'G' (GET, presumably) or 'u' at
                // client offset 7. Index 8 is only inspected once it was read.
                if (buffer_request[1] == 'G' or (req_len > 8 and buffer_request[8] == 'u')) break;

                if (req_len >= 25 + 22 + 30 + 20 + 32 + 2 + 70) break;
            }

            upstream.stream.writeAll(buffer_request[0..req_len]) catch |err| {
                log.err("Error when writing to upstream {}\n", .{err});
                return;
            };

            var res_len: usize = 0;
            while (res_len == 0 or buffer_response[res_len - 1] != 0) {
                if (res_len == buffer_response.len) {
                    // No room for more bytes; reading into an empty slice would spin forever.
                    log.err("Upstream response exceeds {d} bytes", .{BUFFER_SIZE});
                    return;
                }
                const got = upstream.stream.read(buffer_response[res_len..]) catch |err| {
                    log.err("Error when reading from upstream {}\n", .{err});
                    return;
                };
                if (got == 0) {
                    // Upstream hit EOF before the terminator; do not loop forever.
                    log.err("Upstream closed the connection mid-response", .{});
                    return;
                }
                res_len += got;
            }

            // writeAll: a single write() may send only part of the response.
            conn.stream.writeAll(buffer_response[0 .. res_len - 1]) catch |err| {
                log.err("Error when write from connection {}\n", .{err});
                return;
            };
        }
    }
};
|
|
|
|
/// Gate for the integration tests in this file, which need a live upstream on
/// localhost:5001. Set to `false` to run them.
const skip_tests: bool = true;
|
|
|
|
// Integration test: needs an upstream accepting connections on localhost:5001.
// Reports as skipped (not passed) while `skip_tests` is set.
test "expect connect to upstream server" {
    if (skip_tests) return error.SkipZigTest;

    const upstream_server = UpstreamServer.init("localhost", 5001);
    // Release every pooled socket the server opened.
    defer {
        for (upstream_server.pool) |conn| conn.stream.close();
    }

    for (upstream_server.pool) |conn| {
        try expect(conn.state == .available);
    }
}
|
|
|
|
// Integration test: needs an upstream accepting connections on localhost:5001.
// Checks that getAvailableConnection skips occupied slots and returns the
// first available one in pool order.
test "expect get the first available connection upstream server" {
    if (skip_tests) return error.SkipZigTest;

    var upstream_server = UpstreamServer.init("localhost", 5001);
    defer {
        for (upstream_server.pool) |conn| conn.stream.close();
    }

    upstream_server.pool[0].state = .occupied;
    upstream_server.pool[1].state = .occupied;
    upstream_server.pool[3].state = .occupied;

    try expect(upstream_server.getAvailableConnection().? == &upstream_server.pool[2]);
}
|
|
|
|
// Integration test: needs an upstream on localhost:5001 and port 9999 free.
// NOTE(review): `start()` loops forever, so when enabled this test only
// finishes if `start` fails; it is effectively a manual smoke run.
test "expect initiate load balancer" {
    if (skip_tests) return error.SkipZigTest;

    const upstream_server = UpstreamServer.init("localhost", 5001);
    var servers = [_]UpstreamServer{upstream_server};

    var lb = try LoadBalancer.init("127.0.0.1", 9999, &servers);
    try lb.start();
}
|