From 3ba548cca064f3f823352bff0900c3af42226726 Mon Sep 17 00:00:00 2001 From: Veikka Tuominen Date: Mon, 2 Oct 2023 15:44:50 +0300 Subject: [PATCH] Sema: fix issues in `@errorCast` with error unions --- src/Sema.zig | 21 ++++++++++++------- test/behavior/error.zig | 14 +++++++++++-- ...ast error union casted to disjoint set.zig | 20 ++++++++++++++++++ 3 files changed, 46 insertions(+), 9 deletions(-) create mode 100644 test/cases/safety/@errorCast error union casted to disjoint set.zig diff --git a/src/Sema.zig b/src/Sema.zig index 096ebb0589cf..27dd7221fc7b 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -21771,10 +21771,10 @@ fn zirErrorCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData // operand must be defined since it can be an invalid error value const maybe_operand_val = try sema.resolveDefinedValue(block, operand_src, operand); - if (disjoint: { + const disjoint = disjoint: { // Try avoiding resolving inferred error sets if we can - if (!dest_ty.isAnyError(mod) and dest_ty.errorSetNames(mod).len == 0) break :disjoint true; - if (!operand_ty.isAnyError(mod) and operand_ty.errorSetNames(mod).len == 0) break :disjoint true; + if (!dest_ty.isAnyError(mod) and dest_ty.errorSetIsEmpty(mod)) break :disjoint true; + if (!operand_ty.isAnyError(mod) and operand_ty.errorSetIsEmpty(mod)) break :disjoint true; if (dest_ty.isAnyError(mod)) break :disjoint false; if (operand_ty.isAnyError(mod)) break :disjoint false; for (dest_ty.errorSetNames(mod)) |dest_err_name| { @@ -21796,7 +21796,8 @@ fn zirErrorCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData } break :disjoint true; - }) { + }; + if (disjoint and dest_tag != .ErrorUnion) { const msg = msg: { const msg = try sema.errMsg( block, @@ -21850,10 +21851,16 @@ fn zirErrorCast(sema: *Sema, block: *Block, extended: Zir.Inst.Extended.InstData .int = .{ .ty = .u16_type, .storage = .{ .u64 = 0 } }, })); - const has_value = try block.addTyOp(.error_set_has_value, dest_ty, err_code); const is_zero = try block.addBinOp(.cmp_eq, err_int, zero_u16); - const ok = try block.addBinOp(.bit_or, has_value, is_zero); - try sema.addSafetyCheck(block, src, ok, .invalid_error_code); + if (disjoint) { + // Error must be zero. + try sema.addSafetyCheck(block, src, is_zero, .invalid_error_code); + } else { + // Error must be in destination set or zero. + const has_value = try block.addTyOp(.error_set_has_value, dest_ty, err_code); + const ok = try block.addBinOp(.bit_or, has_value, is_zero); + try sema.addSafetyCheck(block, src, ok, .invalid_error_code); + } } else { const err_int_inst = try block.addBitCast(Type.err_int, operand); const ok = try block.addTyOp(.error_set_has_value, dest_ty, err_int_inst); diff --git a/test/behavior/error.zig b/test/behavior/error.zig index 2c3ba3b8c769..c0725f87791a 100644 --- a/test/behavior/error.zig +++ b/test/behavior/error.zig @@ -238,13 +238,23 @@ fn testExplicitErrorSetCast(set1: Set1) !void { test "@errorCast on error unions" { const S = struct { fn doTheTest() !void { - const casted: error{Bad}!i32 = @errorCast(retErrUnion()); - try expect((try casted) == 1234); + { + const casted: error{Bad}!i32 = @errorCast(retErrUnion()); + try expect((try casted) == 1234); + } + { + const casted: error{Bad}!i32 = @errorCast(retInferredErrUnion()); + try expect((try casted) == 5678); + } } fn retErrUnion() anyerror!i32 { return 1234; } + + fn retInferredErrUnion() !i32 { + return 5678; + } }; try S.doTheTest(); diff --git a/test/cases/safety/@errorCast error union casted to disjoint set.zig b/test/cases/safety/@errorCast error union casted to disjoint set.zig new file mode 100644 index 000000000000..c9da2fc636ae --- /dev/null +++ b/test/cases/safety/@errorCast error union casted to disjoint set.zig @@ -0,0 +1,20 @@ +const std = @import("std"); + +pub fn panic(message: []const u8, stack_trace: ?*std.builtin.StackTrace, _: ?usize) noreturn { + _ = stack_trace; + if (std.mem.eql(u8, message, "invalid error code")) { + std.process.exit(0); + } + std.process.exit(1); +} +pub fn main() !void { + const bar: error{Foo}!i32 = @errorCast(foo()); + _ = &bar; + return error.TestFailed; +} +fn foo() anyerror!i32 { + return error.Bar; +} +// run +// backend=llvm +// target=native