diff --git a/lib/std/compress/deflate.zig b/lib/std/compress/deflate.zig index 2fef0b32bca5..aaa2ac4adbc4 100644 --- a/lib/std/compress/deflate.zig +++ b/lib/std/compress/deflate.zig @@ -45,7 +45,9 @@ const Huffman = struct { min_code_len: u16, - fn construct(self: *Huffman, code_length: []const u16) !void { + const ConstructError = error{ Oversubscribed, IncompleteSet }; + + fn construct(self: *Huffman, code_length: []const u16) ConstructError!void { for (self.count) |*val| { val.* = 0; } @@ -70,7 +72,7 @@ const Huffman = struct { // Make sure the number of codes with this length isn't too high. left -= @as(isize, @bitCast(i16, val)); if (left < 0) - return error.InvalidTree; + return error.Oversubscribed; } // Compute the offset of the first symbol represented by a code of a @@ -125,6 +127,9 @@ const Huffman = struct { self.last_code = codes[PREFIX_LUT_BITS + 1]; self.last_index = offset[PREFIX_LUT_BITS + 1] - self.count[PREFIX_LUT_BITS + 1]; + + if (left > 0) + return error.IncompleteSet; } }; @@ -322,7 +327,13 @@ pub fn InflateStream(comptime ReaderType: type) type { try lencode.construct(len_lengths[0..]); const dist_lengths = [_]u16{5} ** MAXDCODES; - try distcode.construct(dist_lengths[0..]); + distcode.construct(dist_lengths[0..]) catch |err| switch (err) { + // This error is expected because we only compute distance codes + // 0-29, which is fine since "distance codes 30-31 will never actually + // occur in the compressed data" (from section 3.2.6 of RFC1951). + error.IncompleteSet => {}, + else => return err, + }; } self.hlen = &lencode; @@ -357,7 +368,7 @@ pub fn InflateStream(comptime ReaderType: type) type { lengths[val] = @intCast(u16, try self.readBits(3)); } - try lencode.construct(lengths[0..]); + lencode.construct(lengths[0..]) catch return error.InvalidTree; } // Read the length/literal and distance code length tables. @@ -406,8 +417,24 @@ pub fn InflateStream(comptime ReaderType: type) type { if (lengths[256] == 0) return error.MissingEOBCode; - try self.huffman_tables[0].construct(lengths[0..nlen]); - try self.huffman_tables[1].construct(lengths[nlen .. nlen + ndist]); + self.huffman_tables[0].construct(lengths[0..nlen]) catch |err| switch (err) { + error.Oversubscribed => return error.InvalidTree, + error.IncompleteSet => { + // incomplete code ok only for single length 1 code + if (nlen != self.huffman_tables[0].count[0] + self.huffman_tables[0].count[1]) { + return error.InvalidTree; + } + }, + }; + self.huffman_tables[1].construct(lengths[nlen .. nlen + ndist]) catch |err| switch (err) { + error.Oversubscribed => return error.InvalidTree, + error.IncompleteSet => { + // incomplete code ok only for single length 1 code + if (ndist != self.huffman_tables[1].count[0] + self.huffman_tables[1].count[1]) { + return error.InvalidTree; + } + }, + }; self.hlen = &self.huffman_tables[0]; self.hdist = &self.huffman_tables[1]; @@ -675,8 +702,22 @@ test "empty distance alphabet" { test "inflateStream fuzzing" { // see https://github.com/ziglang/zig/issues/9842 - try std.testing.expectError(error.EndOfStream, testInflate("\x950000")); + try std.testing.expectError(error.EndOfStream, testInflate("\x95\x90=o\xc20\x10\x86\xf30")); try std.testing.expectError(error.OutOfCodes, testInflate("\x950\x00\x0000000")); + + // Huffman.construct errors + // lencode + try std.testing.expectError(error.InvalidTree, testInflate("\x950000")); + try std.testing.expectError(error.InvalidTree, testInflate("\x05000")); + // hlen + try std.testing.expectError(error.InvalidTree, testInflate("\x05\xea\x01\t\x00\x00\x00\x01\x00\\\xbf.\t\x00")); + // hdist + try std.testing.expectError(error.InvalidTree, testInflate("\x05\xe0\x01A\x00\x00\x00\x00\x10\\\xbf.")); + + // Huffman.construct -> error.IncompleteSet returns that shouldn't give error.InvalidTree + // (like the "empty distance alphabet" test but for ndist instead of nlen) + try std.testing.expectError(error.EndOfStream, testInflate("\x05\xe0\x01\t\x00\x00\x00\x00\x10\\\xbf\xce")); + try testInflate("\x15\xe0\x01\t\x00\x00\x00\x00\x10\\\xbf.0"); } fn testInflate(data: []const u8) !void {