Skip to content

Commit 1b34b27

Browse files
committed
std.crypto.tls.Client: fix verify_data for batched handshakes
1 parent 0908ddb commit 1b34b27

File tree

2 files changed

+56
-43
lines changed

2 files changed

+56
-43
lines changed

lib/std/crypto/tls.zig

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -221,6 +221,12 @@ pub const CipherSuite = enum(u16) {
221221
_,
222222
};
223223

224+
pub const CertificateType = enum(u8) {
225+
X509 = 0,
226+
RawPublicKey = 2,
227+
_,
228+
};
229+
224230
pub fn CipherParamsT(comptime AeadType: type, comptime HashType: type) type {
225231
return struct {
226232
pub const AEAD = AeadType;
@@ -237,7 +243,6 @@ pub fn CipherParamsT(comptime AeadType: type, comptime HashType: type) type {
237243
client_handshake_iv: [AEAD.nonce_length]u8,
238244
server_handshake_iv: [AEAD.nonce_length]u8,
239245
transcript_hash: Hash,
240-
finished_digest: [Hash.digest_length]u8,
241246
};
242247
}
243248

lib/std/crypto/tls/Client.zig

Lines changed: 50 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -62,12 +62,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
6262
.rsa_pss_rsae_sha384,
6363
.rsa_pss_rsae_sha512,
6464
.ed25519,
65-
.ed448,
66-
.rsa_pss_pss_sha256,
67-
.rsa_pss_pss_sha384,
68-
.rsa_pss_pss_sha512,
69-
.rsa_pkcs1_sha1,
70-
.ecdsa_sha1,
7165
})) ++ tls.extension(.supported_groups, enum_array(tls.NamedGroup, &.{
7266
.secp256r1,
7367
.x25519,
@@ -98,24 +92,21 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
9892
int2(legacy_compression_methods) ++
9993
extensions_header;
10094

101-
const handshake =
95+
const out_handshake =
10296
[_]u8{@enumToInt(HandshakeType.client_hello)} ++
10397
int3(@intCast(u24, client_hello.len + host_len)) ++
10498
client_hello;
10599

106-
const hello_header = [_]u8{
107-
// Plaintext header
100+
const plaintext_header = [_]u8{
108101
@enumToInt(ContentType.handshake),
109102
0x03, 0x01, // legacy_record_version
110-
} ++
111-
int2(@intCast(u16, handshake.len + host_len)) ++
112-
handshake;
103+
} ++ int2(@intCast(u16, out_handshake.len + host_len)) ++ out_handshake;
113104

114105
{
115106
var iovecs = [_]std.os.iovec_const{
116107
.{
117-
.iov_base = &hello_header,
118-
.iov_len = hello_header.len,
108+
.iov_base = &plaintext_header,
109+
.iov_len = plaintext_header.len,
119110
},
120111
.{
121112
.iov_base = host.ptr,
@@ -125,7 +116,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
125116
try stream.writevAll(&iovecs);
126117
}
127118

128-
const client_hello_bytes1 = hello_header[5..];
119+
const client_hello_bytes1 = plaintext_header[5..];
129120

130121
var cipher_params: CipherParams = undefined;
131122

@@ -176,7 +167,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
176167
const cipher_suite_int = mem.readIntBig(u16, frag[i..][0..2]);
177168
i += 2;
178169
const cipher_suite_tag = @intToEnum(CipherSuite, cipher_suite_int);
179-
std.debug.print("server wants cipher suite {any}\n", .{cipher_suite_tag});
180170
const legacy_compression_method = frag[i];
181171
i += 1;
182172
_ = legacy_compression_method;
@@ -243,12 +233,7 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
243233
if (!have_shared_key) return error.TlsIllegalParameter;
244234
const tls_version = if (supported_version == 0) legacy_version else supported_version;
245235
switch (tls_version) {
246-
@enumToInt(tls.ProtocolVersion.tls_1_2) => {
247-
std.debug.print("server wants TLS v1.2\n", .{});
248-
},
249-
@enumToInt(tls.ProtocolVersion.tls_1_3) => {
250-
std.debug.print("server wants TLS v1.3\n", .{});
251-
},
236+
@enumToInt(tls.ProtocolVersion.tls_1_3) => {},
252237
else => return error.TlsIllegalParameter,
253238
}
254239

@@ -270,7 +255,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
270255
.client_handshake_iv = undefined,
271256
.server_handshake_iv = undefined,
272257
.transcript_hash = P.Hash.init(.{}),
273-
.finished_digest = undefined,
274258
});
275259
const p = &@field(cipher_params, @tagName(tag));
276260
p.transcript_hash.update(client_hello_bytes1); // Client Hello part 1
@@ -361,7 +345,6 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
361345
const ad = handshake_buf[end_hdr - 5 ..][0..5];
362346
P.AEAD.decrypt(cleartext, ciphertext, auth_tag, ad, nonce, p.server_handshake_key) catch
363347
return error.TlsBadRecordMac;
364-
p.transcript_hash.update(cleartext[0 .. cleartext.len - 1]);
365348
break :c cleartext;
366349
},
367350
};
@@ -378,17 +361,22 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
378361
const next_handshake_i = ct_i + handshake_len;
379362
if (next_handshake_i > cleartext.len - 1)
380363
return error.TlsBadLength;
364+
const wrapped_handshake = cleartext[ct_i - 4 .. next_handshake_i];
365+
const handshake = cleartext[ct_i..next_handshake_i];
381366
switch (handshake_type) {
382367
@enumToInt(HandshakeType.encrypted_extensions) => {
383-
const total_ext_size = mem.readIntBig(u16, cleartext[ct_i..][0..2]);
384-
ct_i += 2;
385-
const end_ext_i = ct_i + total_ext_size;
386-
while (ct_i < end_ext_i) {
387-
const et = mem.readIntBig(u16, cleartext[ct_i..][0..2]);
388-
ct_i += 2;
389-
const ext_size = mem.readIntBig(u16, cleartext[ct_i..][0..2]);
390-
ct_i += 2;
391-
const next_ext_i = ct_i + ext_size;
368+
switch (cipher_params) {
369+
inline else => |*p| p.transcript_hash.update(wrapped_handshake),
370+
}
371+
const total_ext_size = mem.readIntBig(u16, handshake[0..2]);
372+
var hs_i: usize = 2;
373+
const end_ext_i = 2 + total_ext_size;
374+
while (hs_i < end_ext_i) {
375+
const et = mem.readIntBig(u16, handshake[hs_i..][0..2]);
376+
hs_i += 2;
377+
const ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]);
378+
hs_i += 2;
379+
const next_ext_i = hs_i + ext_size;
392380
switch (et) {
393381
@enumToInt(tls.ExtensionType.server_name) => {},
394382
else => {
@@ -397,19 +385,38 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
397385
});
398386
},
399387
}
400-
ct_i = next_ext_i;
388+
hs_i = next_ext_i;
401389
}
402390
},
403391
@enumToInt(HandshakeType.certificate) => {
404-
std.debug.print("cool certificate bro\n", .{});
392+
switch (cipher_params) {
393+
inline else => |*p| p.transcript_hash.update(wrapped_handshake),
394+
}
395+
var hs_i: usize = 0;
396+
const cert_req_ctx_len = handshake[hs_i];
397+
hs_i += 1;
398+
if (cert_req_ctx_len != 0) return error.TlsIllegalParameter;
399+
const certs_size = mem.readIntBig(u24, handshake[hs_i..][0..3]);
400+
hs_i += 3;
401+
const end_certs = hs_i + certs_size;
402+
while (hs_i < end_certs) {
403+
const cert_size = mem.readIntBig(u24, handshake[hs_i..][0..3]);
404+
hs_i += 3;
405+
hs_i += cert_size;
406+
const total_ext_size = mem.readIntBig(u16, handshake[hs_i..][0..2]);
407+
hs_i += 2;
408+
hs_i += total_ext_size;
409+
410+
std.debug.print("received certificate of size {d} bytes with {d} bytes of extensions\n", .{
411+
cert_size, total_ext_size,
412+
});
413+
}
405414
},
406415
@enumToInt(HandshakeType.certificate_verify) => {
407-
std.debug.print("the certificate came with a fancy signature\n", .{});
408416
switch (cipher_params) {
409-
inline else => |*p| {
410-
p.finished_digest = p.transcript_hash.peek();
411-
},
417+
inline else => |*p| p.transcript_hash.update(wrapped_handshake),
412418
}
419+
std.debug.print("ignoring certificate_verify\n", .{});
413420
},
414421
@enumToInt(HandshakeType.finished) => {
415422
// This message is to trick buggy proxies into behaving correctly.
@@ -422,9 +429,10 @@ pub fn init(stream: net.Stream, host: []const u8) !Client {
422429
const app_cipher = switch (cipher_params) {
423430
inline else => |*p, tag| c: {
424431
const P = @TypeOf(p.*);
425-
const expected_server_verify_data = tls.hmac(P.Hmac, &p.finished_digest, p.server_finished_key);
426-
const actual_server_verify_data = cleartext[ct_i..][0..handshake_len];
427-
if (!mem.eql(u8, &expected_server_verify_data, actual_server_verify_data))
432+
const finished_digest = p.transcript_hash.peek();
433+
p.transcript_hash.update(wrapped_handshake);
434+
const expected_server_verify_data = tls.hmac(P.Hmac, &finished_digest, p.server_finished_key);
435+
if (!mem.eql(u8, &expected_server_verify_data, handshake))
428436
return error.TlsDecryptError;
429437
const handshake_hash = p.transcript_hash.finalResult();
430438
const verify_data = tls.hmac(P.Hmac, &handshake_hash, p.client_finished_key);

0 commit comments

Comments
 (0)