Skip to content

Commit 524e0cd

Browse files
committed
std.http: rework connection pool into its own type
1 parent 634e715 commit 524e0cd

File tree

4 files changed

+134
-87
lines changed

4 files changed

+134
-87
lines changed

lib/std/http/Client.zig

Lines changed: 117 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ const testing = std.testing;
1616
pub const Request = @import("Client/Request.zig");
1717
pub const Response = @import("Client/Response.zig");
1818

19+
pub const default_connection_pool_size = 32;
20+
const connection_pool_size = std.options.http_connection_pool_size;
21+
1922
/// Used for tcpConnectToHost and storing HTTP headers when an externally
2023
/// managed buffer is not provided.
2124
allocator: Allocator,
@@ -24,39 +27,115 @@ ca_bundle: std.crypto.Certificate.Bundle = .{},
2427
/// it will first rescan the system for root certificates.
2528
next_https_rescan_certs: bool = true,
2629

27-
connection_mutex: std.Thread.Mutex = .{},
2830
connection_pool: ConnectionPool = .{},
29-
connection_used: ConnectionPool = .{},
3031

31-
pub const ConnectionPool = std.TailQueue(Connection);
32-
pub const ConnectionNode = ConnectionPool.Node;
32+
pub const ConnectionPool = struct {
33+
pub const Criteria = struct {
34+
host: []const u8,
35+
port: u16,
36+
is_tls: bool,
37+
};
3338

34-
/// Acquires an existing connection from the connection pool. This function is threadsafe.
35-
/// If the caller already holds the connection mutex, it should pass `true` for `held`.
36-
pub fn acquire(client: *Client, node: *ConnectionNode, held: bool) void {
37-
if (!held) client.connection_mutex.lock();
38-
defer if (!held) client.connection_mutex.unlock();
39+
const Queue = std.TailQueue(Connection);
40+
pub const Node = Queue.Node;
41+
42+
mutex: std.Thread.Mutex = .{},
43+
used: Queue = .{},
44+
free: Queue = .{},
45+
free_len: usize = 0,
46+
free_size: usize = default_connection_pool_size,
47+
48+
/// Finds and acquires a connection from the connection pool matching the criteria. This function is threadsafe.
49+
/// If no connection is found, null is returned.
50+
pub fn findConnection(pool: *ConnectionPool, criteria: Criteria) ?*Node {
51+
pool.mutex.lock();
52+
defer pool.mutex.unlock();
53+
54+
var next = pool.free.last;
55+
while (next) |node| : (next = node.prev) {
56+
if ((node.data.protocol == .tls) != criteria.is_tls) continue;
57+
if (node.data.port != criteria.port) continue;
58+
if (std.mem.eql(u8, node.data.host, criteria.host)) continue;
59+
60+
pool.acquireUnsafe(node);
61+
return node;
62+
}
3963

40-
client.connection_pool.remove(node);
41-
client.connection_used.append(node);
42-
}
64+
return null;
65+
}
4366

44-
/// Tries to release a connection back to the connection pool. This function is threadsafe.
45-
/// If the connection is marked as closing, it will be closed instead.
46-
pub fn release(client: *Client, node: *ConnectionNode) void {
47-
client.connection_mutex.lock();
48-
defer client.connection_mutex.unlock();
67+
/// Acquires an existing connection from the connection pool. This function is not threadsafe.
68+
pub fn acquireUnsafe(pool: *ConnectionPool, node: *Node) void {
69+
pool.free.remove(node);
70+
pool.free_len -= 1;
4971

50-
client.connection_used.remove(node);
72+
pool.used.append(node);
73+
}
5174

52-
if (node.data.closing) {
53-
node.data.close(client);
75+
/// Acquires an existing connection from the connection pool. This function is threadsafe.
76+
pub fn acquire(pool: *ConnectionPool, node: *Node) void {
77+
pool.mutex.lock();
78+
defer pool.mutex.unlock();
5479

55-
return client.allocator.destroy(node);
80+
return pool.acquireUnsafe(node);
5681
}
5782

58-
client.connection_pool.append(node);
59-
}
83+
/// Tries to release a connection back to the connection pool. This function is threadsafe.
84+
/// If the connection is marked as closing, it will be closed instead.
85+
pub fn release(pool: *ConnectionPool, client: *Client, node: *Node) void {
86+
pool.mutex.lock();
87+
defer pool.mutex.unlock();
88+
89+
pool.used.remove(node);
90+
91+
if (node.data.closing) {
92+
node.data.close(client);
93+
94+
return client.allocator.destroy(node);
95+
}
96+
97+
if (pool.free_len + 1 >= pool.free_size) {
98+
const popped = pool.free.popFirst() orelse unreachable;
99+
100+
popped.data.close(client);
101+
102+
return client.allocator.destroy(popped);
103+
}
104+
105+
pool.free.append(node);
106+
pool.free_len += 1;
107+
}
108+
109+
/// Adds a newly created node to the pool of used connections. This function is threadsafe.
110+
pub fn addUsed(pool: *ConnectionPool, node: *Node) void {
111+
pool.mutex.lock();
112+
defer pool.mutex.unlock();
113+
114+
pool.used.append(node);
115+
}
116+
117+
pub fn deinit(pool: *ConnectionPool, client: *Client) void {
118+
pool.mutex.lock();
119+
120+
var next = pool.free.first;
121+
while (next) |node| {
122+
defer client.allocator.destroy(node);
123+
next = node.next;
124+
125+
node.data.close(client);
126+
}
127+
128+
next = pool.used.first;
129+
while (next) |node| {
130+
defer client.allocator.destroy(node);
131+
next = node.next;
132+
133+
node.data.close(client);
134+
}
135+
136+
pool.* = undefined;
137+
}
138+
};
60139

61140
pub const DeflateDecompressor = std.compress.zlib.ZlibStream(Request.ReaderRaw);
62141
pub const GzipDecompressor = std.compress.gzip.Decompress(Request.ReaderRaw);
@@ -142,62 +221,33 @@ pub const Connection = struct {
142221
};
143222

144223
pub fn deinit(client: *Client) void {
145-
client.connection_mutex.lock();
146-
147-
var next = client.connection_pool.first;
148-
while (next) |node| {
149-
next = node.next;
150-
151-
node.data.close(client);
152-
153-
client.allocator.destroy(node);
154-
}
155-
156-
next = client.connection_used.first;
157-
while (next) |node| {
158-
next = node.next;
159-
160-
node.data.close(client);
161-
162-
client.allocator.destroy(node);
163-
}
224+
client.connection_pool.deinit(client);
164225

165226
client.ca_bundle.deinit(client.allocator);
166227
client.* = undefined;
167228
}
168229

169230
pub const ConnectError = std.mem.Allocator.Error || net.TcpConnectToHostError || std.crypto.tls.Client.InitError(net.Stream);
170231

171-
pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionNode {
172-
{ // Search through the connection pool for a potential connection.
173-
client.connection_mutex.lock();
174-
defer client.connection_mutex.unlock();
175-
176-
var potential = client.connection_pool.last;
177-
while (potential) |node| {
178-
const same_host = mem.eql(u8, node.data.host, host);
179-
const same_port = node.data.port == port;
180-
const same_protocol = node.data.protocol == protocol;
181-
182-
if (same_host and same_port and same_protocol) {
183-
client.acquire(node, true);
184-
return node;
185-
}
186-
187-
potential = node.prev;
188-
}
189-
}
232+
pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connection.Protocol) ConnectError!*ConnectionPool.Node {
233+
if (client.connection_pool.findConnection(.{
234+
.host = host,
235+
.port = port,
236+
.is_tls = protocol == .tls,
237+
})) |node|
238+
return node;
190239

191-
const conn = try client.allocator.create(ConnectionNode);
240+
const conn = try client.allocator.create(ConnectionPool.Node);
192241
errdefer client.allocator.destroy(conn);
242+
conn.* = .{ .data = undefined };
193243

194-
conn.* = .{ .data = .{
244+
conn.data = .{
195245
.stream = try net.tcpConnectToHost(client.allocator, host, port),
196246
.tls_client = undefined,
197247
.protocol = protocol,
198248
.host = try client.allocator.dupe(u8, host),
199249
.port = port,
200-
} };
250+
};
201251

202252
switch (protocol) {
203253
.plain => {},
@@ -210,12 +260,7 @@ pub fn connect(client: *Client, host: []const u8, port: u16, protocol: Connectio
210260
},
211261
}
212262

213-
{
214-
client.connection_mutex.lock();
215-
defer client.connection_mutex.unlock();
216-
217-
client.connection_used.append(conn);
218-
}
263+
client.connection_pool.addUsed(conn);
219264

220265
return conn;
221266
}
@@ -247,8 +292,8 @@ pub fn request(client: *Client, uri: Uri, headers: Request.Headers, options: Req
247292
const host = uri.host orelse return error.UriMissingHost;
248293

249294
if (client.next_https_rescan_certs and protocol == .tls) {
250-
client.connection_mutex.lock(); // TODO: this could be so much better than reusing the connection pool mutex.
251-
defer client.connection_mutex.unlock();
295+
client.connection_pool.mutex.lock(); // TODO: this could be so much better than reusing the connection pool mutex.
296+
defer client.connection_pool.mutex.unlock();
252297

253298
if (client.next_https_rescan_certs) {
254299
try client.ca_bundle.rescan(client.allocator);

lib/std/http/Client/Request.zig

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ const assert = std.debug.assert;
66

77
const Client = @import("../Client.zig");
88
const Connection = Client.Connection;
9-
const ConnectionNode = Client.ConnectionNode;
9+
const ConnectionNode = Client.ConnectionPool.Node;
1010
const Response = @import("Response.zig");
1111

1212
const Request = @This();
@@ -85,7 +85,7 @@ pub fn deinit(req: *Request) void {
8585
if (!req.response.done) {
8686
// If the response wasn't fully read, then we need to close the connection.
8787
req.connection.data.closing = true;
88-
req.client.release(req.connection);
88+
req.client.connection_pool.release(req.client, req.connection);
8989
}
9090

9191
req.arena.deinit();
@@ -135,7 +135,7 @@ fn checkForCompleteHead(req: *Request, buffer: []u8) !usize {
135135
if (req.response.state == .finished) {
136136
req.response.headers = try Response.Headers.parse(req.response.header_bytes.items);
137137

138-
if (req.response.upgrade) |_| {
138+
if (req.response.headers.upgrade) |_| {
139139
req.connection.data.closing = false;
140140
req.response.done = true;
141141
return i;
@@ -226,7 +226,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize {
226226
req.response.next_chunk_length -= can_read;
227227

228228
if (req.response.next_chunk_length == 0) {
229-
req.client.release(req.connection);
229+
req.client.connection_pool.release(req.client, req.connection);
230230
req.connection = undefined;
231231
req.response.done = true;
232232
}
@@ -241,7 +241,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize {
241241
req.read_buffer_start += @intCast(ReadBufferIndex, can_read);
242242

243243
if (req.response.next_chunk_length == 0) {
244-
req.client.release(req.connection);
244+
req.client.connection_pool.release(req.client, req.connection);
245245
req.connection = undefined;
246246
req.response.done = true;
247247
}
@@ -293,7 +293,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize {
293293
.chunk_data => {
294294
if (req.response.next_chunk_length == 0) {
295295
req.response.done = true;
296-
req.client.release(req.connection);
296+
req.client.connection_pool.release(req.client, req.connection);
297297
req.connection = undefined;
298298

299299
return out_index;
@@ -317,7 +317,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize {
317317
req.response.next_chunk_length -= can_read;
318318

319319
if (req.response.next_chunk_length == 0) {
320-
req.client.release(req.connection);
320+
req.client.connection_pool.release(req.client, req.connection);
321321
req.connection = undefined;
322322
req.response.done = true;
323323
continue;
@@ -345,13 +345,7 @@ fn readRawAdvanced(req: *Request, buffer: []u8) !usize {
345345
}
346346
}
347347

348-
pub const ReadError = Client.DeflateDecompressor.Error || Client.GzipDecompressor.Error || Client.ZstdDecompressor.Error || WaitForCompleteHeadError || error{
349-
BadHeader,
350-
InvalidCompression,
351-
StreamTooLong,
352-
InvalidWindowSize,
353-
CompressionNotSupported
354-
};
348+
pub const ReadError = Client.DeflateDecompressor.Error || Client.GzipDecompressor.Error || Client.ZstdDecompressor.Error || WaitForCompleteHeadError || error{ BadHeader, InvalidCompression, StreamTooLong, InvalidWindowSize, CompressionNotSupported };
355349

356350
pub const Reader = std.io.Reader(*Request, ReadError, read);
357351

lib/std/http/Client/Response.zig

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,7 @@ pub const Headers = struct {
3232
transfer_encoding: ?http.TransferEncoding = null,
3333
transfer_compression: ?http.ContentEncoding = null,
3434
connection: http.Connection = .close,
35+
upgrade: ?[]const u8 = null,
3536

3637
number_of_headers: usize = 0,
3738

@@ -93,7 +94,7 @@ pub const Headers = struct {
9394

9495
if (iter.next()) |second| {
9596
if (headers.transfer_compression != null) return error.HttpTransferEncodingUnsupported;
96-
97+
9798
const trimmed = std.mem.trim(u8, second, " ");
9899

99100
if (std.meta.stringToEnum(http.ContentEncoding, trimmed)) |ce| {
@@ -122,6 +123,8 @@ pub const Headers = struct {
122123
} else {
123124
return error.HttpConnectionHeaderUnsupported;
124125
}
126+
} else if (std.ascii.eqlIgnoreCase(header_name, "upgrade")) {
127+
headers.upgrade = header_value;
125128
}
126129
}
127130

lib/std/std.zig

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -185,6 +185,11 @@ pub const options = struct {
185185
options_override.keep_sigpipe
186186
else
187187
false;
188+
189+
pub const http_connection_pool_size = if (@hasDecl(options_override, "http_connection_pool_size"))
190+
options_override.http_connection_pool_size
191+
else
192+
http.Client.default_connection_pool_size;
188193
};
189194

190195
// This forces the start.zig file to be imported, and the comptime logic inside that

0 commit comments

Comments
 (0)