From d27e73c34a8eb87db07f961bc337c2bc05a789a1 Mon Sep 17 00:00:00 2001 From: Jacob Young Date: Sun, 28 Apr 2024 13:13:40 -0400 Subject: [PATCH 1/9] Builder: add `indirectbr` llvm instruction --- src/codegen/llvm/Builder.zig | 97 ++++++++++++++++++++++++++++++------ src/codegen/llvm/ir.zig | 14 ++++++ 2 files changed, 96 insertions(+), 15 deletions(-) diff --git a/src/codegen/llvm/Builder.zig b/src/codegen/llvm/Builder.zig index f0855eee2797..446c8f64a5a8 100644 --- a/src/codegen/llvm/Builder.zig +++ b/src/codegen/llvm/Builder.zig @@ -4080,6 +4080,7 @@ pub const Function = struct { @"icmp ugt", @"icmp ule", @"icmp ult", + indirectbr, insertelement, insertvalue, inttoptr, @@ -4290,6 +4291,7 @@ pub const Function = struct { return switch (wip.instructions.items(.tag)[@intFromEnum(self)]) { .br, .br_cond, + .indirectbr, .ret, .@"ret void", .@"switch", @@ -4304,6 +4306,7 @@ pub const Function = struct { .br, .br_cond, .fence, + .indirectbr, .ret, .@"ret void", .store, @@ -4394,6 +4397,7 @@ pub const Function = struct { .br, .br_cond, .fence, + .indirectbr, .ret, .@"ret void", .store, @@ -4580,6 +4584,7 @@ pub const Function = struct { .br, .br_cond, .fence, + .indirectbr, .ret, .@"ret void", .store, @@ -4750,6 +4755,12 @@ pub const Function = struct { //case_blocks: [cases_len]Block.Index, }; + pub const IndirectBr = struct { + addr: Value, + targets_len: u32, + //targets: [targets_len]Block.Index, + }; + pub const Binary = struct { lhs: Value, rhs: Value, @@ -5186,10 +5197,27 @@ pub const WipFunction = struct { return .{ .index = 0, .instruction = instruction }; } + pub fn indirectbr( + self: *WipFunction, + addr: Value, + targets: []const Block.Index, + ) Allocator.Error!Instruction.Index { + try self.ensureUnusedExtraCapacity(1, Instruction.IndirectBr, targets.len); + const instruction = try self.addInst(null, .{ + .tag = .indirectbr, + .data = self.addExtraAssumeCapacity(Instruction.IndirectBr{ + .addr = addr, + .targets_len = @intCast(targets.len), + }), + }); + _ 
= self.extra.appendSliceAssumeCapacity(@ptrCast(targets)); + for (targets) |target| target.ptr(self).branches += 1; + return instruction; + } + pub fn @"unreachable"(self: *WipFunction) Allocator.Error!Instruction.Index { try self.ensureUnusedExtraCapacity(1, NoExtra, 0); - const instruction = try self.addInst(null, .{ .tag = .@"unreachable", .data = undefined }); - return instruction; + return try self.addInst(null, .{ .tag = .@"unreachable", .data = undefined }); } pub fn un( @@ -6159,8 +6187,7 @@ pub const WipFunction = struct { }); names[@intFromEnum(new_block_index)] = try wip_name.map(current_block.name, ""); for (current_block.instructions.items) |old_instruction_index| { - const new_instruction_index: Instruction.Index = - @enumFromInt(function.instructions.len); + const new_instruction_index: Instruction.Index = @enumFromInt(function.instructions.len); var instruction = self.instructions.get(@intFromEnum(old_instruction_index)); switch (instruction.tag) { .add, @@ -6368,6 +6395,15 @@ pub const WipFunction = struct { }); wip_extra.appendMappedValues(indices, instructions); }, + .indirectbr => { + var extra = self.extraDataTrail(Instruction.IndirectBr, instruction.data); + const targets = extra.trail.next(extra.data.targets_len, Block.Index, self); + instruction.data = wip_extra.addExtra(Instruction.IndirectBr{ + .addr = instructions.map(extra.data.addr), + .targets_len = extra.data.targets_len, + }); + wip_extra.appendSlice(targets); + }, .insertelement => { const extra = self.extraData(Instruction.InsertElement, instruction.data); instruction.data = wip_extra.addExtra(Instruction.InsertElement{ @@ -7411,10 +7447,10 @@ pub const Constant = enum(u32) { .blockaddress => |tag| { const extra = data.builder.constantExtraData(BlockAddress, item.data); const function = extra.function.ptrConst(data.builder); - try writer.print("{s}({}, %{d})", .{ + try writer.print("{s}({}, {})", .{ @tagName(tag), function.global.fmt(data.builder), - @intFromEnum(extra.block), // 
TODO + extra.block.toInst(function).fmt(extra.function, data.builder), }); }, .dso_local_equivalent, @@ -9736,6 +9772,23 @@ pub fn printUnbuffered( index.fmt(function_index, self), }); }, + .indirectbr => |tag| { + var extra = + function.extraDataTrail(Function.Instruction.IndirectBr, instruction.data); + const targets = + extra.trail.next(extra.data.targets_len, Function.Block.Index, &function); + try writer.print(" {s} {%}, [", .{ + @tagName(tag), + extra.data.addr.fmt(function_index, self), + }); + for (0.., targets) |target_index, target| { + if (target_index > 0) try writer.writeAll(", "); + try writer.print("{%}", .{ + target.toInst(&function).fmt(function_index, self), + }); + } + try writer.writeByte(']'); + }, .insertelement => |tag| { const extra = function.extraData(Function.Instruction.InsertElement, instruction.data); @@ -14512,15 +14565,6 @@ pub fn toBitcode(self: *Builder, allocator: Allocator) bitcode_writer.Error![]co .indices = indices, }); }, - .insertvalue => { - var extra = func.extraDataTrail(Function.Instruction.InsertValue, datas[instr_index]); - const indices = extra.trail.next(extra.data.indices_len, u32, &func); - try function_block.writeAbbrev(FunctionBlock.InsertValue{ - .val = adapter.getOffsetValueIndex(extra.data.val), - .elem = adapter.getOffsetValueIndex(extra.data.elem), - .indices = indices, - }); - }, .extractelement => { const extra = func.extraData(Function.Instruction.ExtractElement, datas[instr_index]); try function_block.writeAbbrev(FunctionBlock.ExtractElement{ @@ -14528,6 +14572,20 @@ pub fn toBitcode(self: *Builder, allocator: Allocator) bitcode_writer.Error![]co .index = adapter.getOffsetValueIndex(extra.index), }); }, + .indirectbr => { + var extra = + func.extraDataTrail(Function.Instruction.IndirectBr, datas[instr_index]); + const targets = + extra.trail.next(extra.data.targets_len, Function.Block.Index, &func); + try function_block.writeAbbrevAdapted( + FunctionBlock.IndirectBr{ + .ty = 
extra.data.addr.typeOf(@enumFromInt(func_index), self), + .addr = extra.data.addr, + .targets = targets, + }, + adapter, + ); + }, .insertelement => { const extra = func.extraData(Function.Instruction.InsertElement, datas[instr_index]); try function_block.writeAbbrev(FunctionBlock.InsertElement{ @@ -14536,6 +14594,15 @@ pub fn toBitcode(self: *Builder, allocator: Allocator) bitcode_writer.Error![]co .index = adapter.getOffsetValueIndex(extra.index), }); }, + .insertvalue => { + var extra = func.extraDataTrail(Function.Instruction.InsertValue, datas[instr_index]); + const indices = extra.trail.next(extra.data.indices_len, u32, &func); + try function_block.writeAbbrev(FunctionBlock.InsertValue{ + .val = adapter.getOffsetValueIndex(extra.data.val), + .elem = adapter.getOffsetValueIndex(extra.data.elem), + .indices = indices, + }); + }, .select => { const extra = func.extraData(Function.Instruction.Select, datas[instr_index]); try function_block.writeAbbrev(FunctionBlock.Select{ diff --git a/src/codegen/llvm/ir.zig b/src/codegen/llvm/ir.zig index 6a7c6c6857db..adf9aa16b057 100644 --- a/src/codegen/llvm/ir.zig +++ b/src/codegen/llvm/ir.zig @@ -19,6 +19,7 @@ const LineAbbrev = AbbrevOp{ .vbr = 8 }; const ColumnAbbrev = AbbrevOp{ .vbr = 8 }; const BlockAbbrev = AbbrevOp{ .vbr = 6 }; +const BlockArrayAbbrev = AbbrevOp{ .array_vbr = 6 }; pub const MetadataKind = enum(u1) { dbg = 0, @@ -1132,6 +1133,7 @@ pub const FunctionBlock = struct { Fence, DebugLoc, DebugLocAgain, + IndirectBr, }; pub const DeclareBlocks = struct { @@ -1644,6 +1646,18 @@ pub const FunctionBlock = struct { .{ .literal = 33 }, }; }; + + pub const IndirectBr = struct { + pub const ops = [_]AbbrevOp{ + .{ .literal = 31 }, + .{ .fixed_runtime = Builder.Type }, + ValueAbbrev, + BlockArrayAbbrev, + }; + ty: Builder.Type, + addr: Builder.Value, + targets: []const Builder.Function.Block.Index, + }; }; pub const FunctionValueSymbolTable = struct { From 0d99a13e39cd5a011d35c1405a69c55eea6a9e4b Mon Sep 17 00:00:00 
2001 From: mlugg Date: Wed, 24 Apr 2024 18:52:03 +0100 Subject: [PATCH 2/9] Air: direct representation of ranges in switch cases This commit modifies the representation of the AIR `switch_br` instruction to represent ranges in cases. Previously, Sema emitted different AIR in the case of a range, where the `else` branch of the `switch_br` contained a simple `cond_br` for each such case which did a simple range check (`x > a and x < b`). Not only does this add complexity to Sema, which -- as our secondary bottleneck -- we would like to keep as small as possible, but it also gets in the way of the implementation of #8220. This proposal turns certain `switch` statements into a looping construct, and for optimization purposes, we want to lower this to AIR fairly directly (i.e. without involving a `loop` instruction). That means we would ideally like a single instruction to represent the entire `switch` statement, so that we can dispatch back to it with a different operand as in #8220. This is not really possible to do correctly under the status quo system. For now, the actual lowering of `switch` is identical for the LLVM and C backends. This commit contains a TODO which temporarily regresseses all remaining self-hosted backends in the presence of switch case ranges. This functionality will be restored for at least the x86_64 backend before merge of this branch. 
--- src/Air.zig | 8 +- src/Liveness.zig | 10 +- src/Liveness/Verify.zig | 9 +- src/Sema.zig | 334 +++++++++++++++-------------------- src/arch/aarch64/CodeGen.zig | 3 +- src/arch/arm/CodeGen.zig | 1 + src/arch/wasm/CodeGen.zig | 1 + src/arch/x86_64/CodeGen.zig | 1 + src/codegen/c.zig | 73 ++++++-- src/codegen/llvm.zig | 72 +++++++- src/codegen/spirv.zig | 1 + src/print_air.zig | 17 +- 12 files changed, 297 insertions(+), 233 deletions(-) diff --git a/src/Air.zig b/src/Air.zig index 9554c55561a5..91ffed071120 100644 --- a/src/Air.zig +++ b/src/Air.zig @@ -1132,17 +1132,19 @@ pub const CondBr = struct { }; /// Trailing: -/// * 0. `Case` for each `cases_len` -/// * 1. the else body, according to `else_body_len`. +/// * 0. case: Case // for each `cases_len`. +/// * 1. else_body_inst: Inst.Index // for each `else_body_len`. pub const SwitchBr = struct { cases_len: u32, else_body_len: u32, /// Trailing: /// * item: Inst.Ref // for each `items_len`. - /// * instruction index for each `body_len`. + /// * { range_start: Inst.Ref, range_end: Inst.Ref } // for each `ranges_len`. + /// * body_inst: Inst.Index // for each `body_len`. 
pub const Case = struct { items_len: u32, + ranges_len: u32, body_len: u32, }; }; diff --git a/src/Liveness.zig b/src/Liveness.zig index 4ca28758e222..ee37ea76b6e2 100644 --- a/src/Liveness.zig +++ b/src/Liveness.zig @@ -1681,8 +1681,9 @@ fn analyzeInstSwitchBr( var air_extra_index: usize = switch_br.end; for (0..ncases) |_| { const case = a.air.extraData(Air.SwitchBr.Case, air_extra_index); - const case_body: []const Air.Inst.Index = @ptrCast(a.air.extra[case.end + case.data.items_len ..][0..case.data.body_len]); - air_extra_index = case.end + case.data.items_len + case_body.len; + air_extra_index = case.end + case.data.items_len + 2 * case.data.ranges_len; + const case_body: []const Air.Inst.Index = @ptrCast(a.air.extra[air_extra_index..][0..case.data.body_len]); + air_extra_index += case_body.len; try analyzeBody(a, pass, data, case_body); } { // else @@ -1707,8 +1708,9 @@ fn analyzeInstSwitchBr( var air_extra_index: usize = switch_br.end; for (case_live_sets[0..ncases]) |*live_set| { const case = a.air.extraData(Air.SwitchBr.Case, air_extra_index); - const case_body: []const Air.Inst.Index = @ptrCast(a.air.extra[case.end + case.data.items_len ..][0..case.data.body_len]); - air_extra_index = case.end + case.data.items_len + case_body.len; + air_extra_index = case.end + case.data.items_len + 2 * case.data.ranges_len; + const case_body: []const Air.Inst.Index = @ptrCast(a.air.extra[air_extra_index..][0..case.data.body_len]); + air_extra_index += case_body.len; try analyzeBody(a, pass, data, case_body); live_set.* = data.live_set.move(); } diff --git a/src/Liveness/Verify.zig b/src/Liveness/Verify.zig index 4392f25e101d..ef14033828b2 100644 --- a/src/Liveness/Verify.zig +++ b/src/Liveness/Verify.zig @@ -526,12 +526,9 @@ fn verifyBody(self: *Verify, body: []const Air.Inst.Index) Error!void { while (case_i < switch_br.data.cases_len) : (case_i += 1) { const case = self.air.extraData(Air.SwitchBr.Case, extra_index); - const items = @as( - []const Air.Inst.Ref, - 
@ptrCast(self.air.extra[case.end..][0..case.data.items_len]), - ); - const case_body: []const Air.Inst.Index = @ptrCast(self.air.extra[case.end + items.len ..][0..case.data.body_len]); - extra_index = case.end + items.len + case_body.len; + extra_index = case.end + case.data.items_len + case.data.ranges_len * 2; + const case_body: []const Air.Inst.Index = @ptrCast(self.air.extra[extra_index..][0..case.data.body_len]); + extra_index += case_body.len; self.live.deinit(self.gpa); self.live = try live.clone(self.gpa); diff --git a/src/Sema.zig b/src/Sema.zig index 74cea620e35c..5cd5c2fb860e 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -11415,9 +11415,14 @@ const SwitchProngAnalysis = struct { }; _ = try coerce_block.addBr(capture_block_inst, coerced); - try cases_extra.ensureUnusedCapacity(3 + coerce_block.instructions.items.len); - cases_extra.appendAssumeCapacity(1); // items_len - cases_extra.appendAssumeCapacity(@intCast(coerce_block.instructions.items.len)); // body_len + try cases_extra.ensureUnusedCapacity(@typeInfo(Air.SwitchBr.Case).Struct.fields.len + + 1 + // `item`, no ranges + coerce_block.instructions.items.len); + cases_extra.appendSliceAssumeCapacity(&payloadToExtraItems(Air.SwitchBr.Case{ + .items_len = 1, + .ranges_len = 0, + .body_len = @intCast(coerce_block.instructions.items.len), + })); cases_extra.appendAssumeCapacity(@intFromEnum(case_vals[idx])); // item cases_extra.appendSliceAssumeCapacity(@ptrCast(coerce_block.instructions.items)); // body } @@ -12587,20 +12592,18 @@ fn analyzeSwitchRuntimeBlock( _ = try case_block.addNoOp(.unreach); } - try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); - cases_extra.appendAssumeCapacity(1); // items_len - cases_extra.appendAssumeCapacity(@intCast(case_block.instructions.items.len)); + try cases_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.SwitchBr.Case).Struct.fields.len + + 1 + // item + case_block.instructions.items.len); + 
cases_extra.appendSliceAssumeCapacity(&payloadToExtraItems(Air.SwitchBr.Case{ + .items_len = 1, + .ranges_len = 0, + .body_len = @intCast(case_block.instructions.items.len), + })); cases_extra.appendAssumeCapacity(@intFromEnum(item)); cases_extra.appendSliceAssumeCapacity(@ptrCast(case_block.instructions.items)); } - var is_first = true; - var prev_cond_br: Air.Inst.Index = undefined; - var first_else_body: []const Air.Inst.Index = &.{}; - defer gpa.free(first_else_body); - var prev_then_body: []const Air.Inst.Index = &.{}; - defer gpa.free(prev_then_body); - var cases_len = scalar_cases_len; var case_val_idx: usize = scalar_cases_len; var multi_i: u32 = 0; @@ -12674,9 +12677,14 @@ fn analyzeSwitchRuntimeBlock( info.has_tag_capture, ); - try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); - cases_extra.appendAssumeCapacity(1); // items_len - cases_extra.appendAssumeCapacity(@intCast(case_block.instructions.items.len)); + try cases_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.SwitchBr.Case).Struct.fields.len + + 1 + // item + case_block.instructions.items.len); + cases_extra.appendSliceAssumeCapacity(&payloadToExtraItems(Air.SwitchBr.Case{ + .items_len = 1, + .ranges_len = 0, + .body_len = @intCast(case_block.instructions.items.len), + })); cases_extra.appendAssumeCapacity(@intFromEnum(item_ref)); cases_extra.appendSliceAssumeCapacity(@ptrCast(case_block.instructions.items)); @@ -12724,9 +12732,14 @@ fn analyzeSwitchRuntimeBlock( _ = try case_block.addNoOp(.unreach); } - try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); - cases_extra.appendAssumeCapacity(1); // items_len - cases_extra.appendAssumeCapacity(@intCast(case_block.instructions.items.len)); + try cases_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.SwitchBr.Case).Struct.fields.len + + 1 + // item + case_block.instructions.items.len); + cases_extra.appendSliceAssumeCapacity(&payloadToExtraItems(Air.SwitchBr.Case{ + .items_len = 1, + .ranges_len = 
0, + .body_len = @intCast(case_block.instructions.items.len), + })); cases_extra.appendAssumeCapacity(@intFromEnum(item)); cases_extra.appendSliceAssumeCapacity(@ptrCast(case_block.instructions.items)); } @@ -12735,150 +12748,62 @@ fn analyzeSwitchRuntimeBlock( continue; } - var any_ok: Air.Inst.Ref = .none; - - // If there are any ranges, we have to put all the items into the - // else prong. Otherwise, we can take advantage of multiple items - // mapping to the same body. - if (ranges_len == 0) { - cases_len += 1; - - const analyze_body = if (union_originally) - for (items) |item| { - const item_val = sema.resolveConstDefinedValue(block, .unneeded, item, undefined) catch unreachable; - const field_ty = maybe_union_ty.unionFieldType(item_val, mod).?; - if (field_ty.zigTypeTag(mod) != .NoReturn) break true; - } else false - else - true; - - const body = sema.code.bodySlice(extra_index, info.body_len); - extra_index += info.body_len; - if (err_set and try sema.maybeErrorUnwrap(&case_block, body, operand, operand_src, allow_err_code_unwrap)) { - // nothing to do here - } else if (analyze_body) { - try spa.analyzeProngRuntime( - &case_block, - .normal, - body, - info.capture, - .{ .multi_capture = multi_i }, - items, - .none, - false, - ); - } else { - _ = try case_block.addNoOp(.unreach); - } - - try cases_extra.ensureUnusedCapacity(gpa, 2 + items.len + - case_block.instructions.items.len); - - cases_extra.appendAssumeCapacity(@intCast(items.len)); - cases_extra.appendAssumeCapacity(@intCast(case_block.instructions.items.len)); + cases_len += 1; + extra_index += 2 * ranges_len; + const analyze_body = if (union_originally) for (items) |item| { - cases_extra.appendAssumeCapacity(@intFromEnum(item)); - } + const item_val = sema.resolveConstDefinedValue(block, .unneeded, item, undefined) catch unreachable; + const field_ty = maybe_union_ty.unionFieldType(item_val, mod).?; + if (field_ty.zigTypeTag(mod) != .NoReturn) break true; + } else false + else + true; - 
cases_extra.appendSliceAssumeCapacity(@ptrCast(case_block.instructions.items)); + const body = sema.code.bodySlice(extra_index, info.body_len); + extra_index += info.body_len; + if (err_set and try sema.maybeErrorUnwrap(&case_block, body, operand, operand_src, allow_err_code_unwrap)) { + // nothing to do here + } else if (analyze_body) { + try spa.analyzeProngRuntime( + &case_block, + .normal, + body, + info.capture, + .{ .multi_capture = multi_i }, + items, + .none, + false, + ); } else { - for (items) |item| { - const cmp_ok = try case_block.addBinOp(if (case_block.float_mode == .optimized) .cmp_eq_optimized else .cmp_eq, operand, item); - if (any_ok != .none) { - any_ok = try case_block.addBinOp(.bool_or, any_ok, cmp_ok); - } else { - any_ok = cmp_ok; - } - } - - var range_i: usize = 0; - while (range_i < ranges_len) : (range_i += 1) { - const range_items = case_vals.items[case_val_idx..][0..2]; - extra_index += 2; - case_val_idx += 2; - - const item_first = range_items[0]; - const item_last = range_items[1]; - - // operand >= first and operand <= last - const range_first_ok = try case_block.addBinOp( - if (case_block.float_mode == .optimized) .cmp_gte_optimized else .cmp_gte, - operand, - item_first, - ); - const range_last_ok = try case_block.addBinOp( - if (case_block.float_mode == .optimized) .cmp_lte_optimized else .cmp_lte, - operand, - item_last, - ); - const range_ok = try case_block.addBinOp( - .bool_and, - range_first_ok, - range_last_ok, - ); - if (any_ok != .none) { - any_ok = try case_block.addBinOp(.bool_or, any_ok, range_ok); - } else { - any_ok = range_ok; - } - } - - const new_cond_br = try case_block.addInstAsIndex(.{ .tag = .cond_br, .data = .{ - .pl_op = .{ - .operand = any_ok, - .payload = undefined, - }, - } }); - var cond_body = try case_block.instructions.toOwnedSlice(gpa); - defer gpa.free(cond_body); - - case_block.instructions.shrinkRetainingCapacity(0); - case_block.error_return_trace_index = child_block.error_return_trace_index; - - 
const body = sema.code.bodySlice(extra_index, info.body_len); - extra_index += info.body_len; - if (err_set and try sema.maybeErrorUnwrap(&case_block, body, operand, operand_src, allow_err_code_unwrap)) { - // nothing to do here - } else { - try spa.analyzeProngRuntime( - &case_block, - .normal, - body, - info.capture, - .{ .multi_capture = multi_i }, - items, - .none, - false, - ); - } - - if (is_first) { - is_first = false; - first_else_body = cond_body; - cond_body = &.{}; - } else { - try sema.air_extra.ensureUnusedCapacity( - gpa, - @typeInfo(Air.CondBr).Struct.fields.len + prev_then_body.len + cond_body.len, - ); + _ = try case_block.addNoOp(.unreach); + } - sema.air_instructions.items(.data)[@intFromEnum(prev_cond_br)].pl_op.payload = - sema.addExtraAssumeCapacity(Air.CondBr{ - .then_body_len = @intCast(prev_then_body.len), - .else_body_len = @intCast(cond_body.len), - }); - sema.air_extra.appendSliceAssumeCapacity(@ptrCast(prev_then_body)); - sema.air_extra.appendSliceAssumeCapacity(@ptrCast(cond_body)); - } - gpa.free(prev_then_body); - prev_then_body = try case_block.instructions.toOwnedSlice(gpa); - prev_cond_br = new_cond_br; + try cases_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.SwitchBr.Case).Struct.fields.len + + items.len + + 2 * ranges_len + + case_block.instructions.items.len); + + cases_extra.appendSliceAssumeCapacity(&payloadToExtraItems(Air.SwitchBr.Case{ + .items_len = @intCast(items.len), + .ranges_len = @intCast(ranges_len), + .body_len = @intCast(case_block.instructions.items.len), + })); + for (items) |item| { + cases_extra.appendAssumeCapacity(@intFromEnum(item)); + } + for (0..ranges_len) |_| { + const range_first, const range_last = case_vals.items[case_val_idx..][0..2].*; + case_val_idx += 2; + cases_extra.appendSliceAssumeCapacity(&.{ + @intFromEnum(range_first), + @intFromEnum(range_last), + }); } + cases_extra.appendSliceAssumeCapacity(@ptrCast(case_block.instructions.items)); } - var final_else_body: []const Air.Inst.Index = 
&.{}; - if (special.body.len != 0 or !is_first or case_block.wantSafety()) { + const else_body: []const Air.Inst.Index = if (special.body.len != 0 or case_block.wantSafety()) else_body: { var emit_bb = false; if (special.is_inline) switch (operand_ty.zigTypeTag(mod)) { .Enum => { @@ -12920,9 +12845,14 @@ fn analyzeSwitchRuntimeBlock( _ = try case_block.addNoOp(.unreach); } - try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); - cases_extra.appendAssumeCapacity(1); // items_len - cases_extra.appendAssumeCapacity(@intCast(case_block.instructions.items.len)); + try cases_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.SwitchBr.Case).Struct.fields.len + + 1 + // item + case_block.instructions.items.len); + cases_extra.appendSliceAssumeCapacity(&payloadToExtraItems(Air.SwitchBr.Case{ + .items_len = 1, + .ranges_len = 0, + .body_len = @intCast(case_block.instructions.items.len), + })); cases_extra.appendAssumeCapacity(@intFromEnum(item_ref)); cases_extra.appendSliceAssumeCapacity(@ptrCast(case_block.instructions.items)); } @@ -12962,9 +12892,14 @@ fn analyzeSwitchRuntimeBlock( special.has_tag_capture, ); - try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); - cases_extra.appendAssumeCapacity(1); // items_len - cases_extra.appendAssumeCapacity(@intCast(case_block.instructions.items.len)); + try cases_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.SwitchBr.Case).Struct.fields.len + + 1 + // item + case_block.instructions.items.len); + cases_extra.appendSliceAssumeCapacity(&payloadToExtraItems(Air.SwitchBr.Case{ + .items_len = 1, + .ranges_len = 0, + .body_len = @intCast(case_block.instructions.items.len), + })); cases_extra.appendAssumeCapacity(@intFromEnum(item_ref)); cases_extra.appendSliceAssumeCapacity(@ptrCast(case_block.instructions.items)); } @@ -12993,9 +12928,14 @@ fn analyzeSwitchRuntimeBlock( special.has_tag_capture, ); - try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); - 
cases_extra.appendAssumeCapacity(1); // items_len - cases_extra.appendAssumeCapacity(@intCast(case_block.instructions.items.len)); + try cases_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.SwitchBr.Case).Struct.fields.len + + 1 + // item + case_block.instructions.items.len); + cases_extra.appendSliceAssumeCapacity(&payloadToExtraItems(Air.SwitchBr.Case{ + .items_len = 1, + .ranges_len = 0, + .body_len = @intCast(case_block.instructions.items.len), + })); cases_extra.appendAssumeCapacity(@intFromEnum(item_ref)); cases_extra.appendSliceAssumeCapacity(@ptrCast(case_block.instructions.items)); } @@ -13021,9 +12961,14 @@ fn analyzeSwitchRuntimeBlock( special.has_tag_capture, ); - try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); - cases_extra.appendAssumeCapacity(1); // items_len - cases_extra.appendAssumeCapacity(@intCast(case_block.instructions.items.len)); + try cases_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.SwitchBr.Case).Struct.fields.len + + 1 + // item + case_block.instructions.items.len); + cases_extra.appendSliceAssumeCapacity(&payloadToExtraItems(Air.SwitchBr.Case{ + .items_len = 1, + .ranges_len = 0, + .body_len = @intCast(case_block.instructions.items.len), + })); cases_extra.appendAssumeCapacity(@intFromEnum(Air.Inst.Ref.bool_true)); cases_extra.appendSliceAssumeCapacity(@ptrCast(case_block.instructions.items)); } @@ -13047,9 +12992,14 @@ fn analyzeSwitchRuntimeBlock( special.has_tag_capture, ); - try cases_extra.ensureUnusedCapacity(gpa, 3 + case_block.instructions.items.len); - cases_extra.appendAssumeCapacity(1); // items_len - cases_extra.appendAssumeCapacity(@intCast(case_block.instructions.items.len)); + try cases_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.SwitchBr.Case).Struct.fields.len + + 1 + // item + case_block.instructions.items.len); + cases_extra.appendSliceAssumeCapacity(&payloadToExtraItems(Air.SwitchBr.Case{ + .items_len = 1, + .ranges_len = 0, + .body_len = 
@intCast(case_block.instructions.items.len), + })); cases_extra.appendAssumeCapacity(@intFromEnum(Air.Inst.Ref.bool_false)); cases_extra.appendSliceAssumeCapacity(@ptrCast(case_block.instructions.items)); } @@ -13107,33 +13057,19 @@ fn analyzeSwitchRuntimeBlock( } } - if (is_first) { - final_else_body = case_block.instructions.items; - } else { - try sema.air_extra.ensureUnusedCapacity(gpa, prev_then_body.len + - @typeInfo(Air.CondBr).Struct.fields.len + case_block.instructions.items.len); - - sema.air_instructions.items(.data)[@intFromEnum(prev_cond_br)].pl_op.payload = - sema.addExtraAssumeCapacity(Air.CondBr{ - .then_body_len = @intCast(prev_then_body.len), - .else_body_len = @intCast(case_block.instructions.items.len), - }); - sema.air_extra.appendSliceAssumeCapacity(@ptrCast(prev_then_body)); - sema.air_extra.appendSliceAssumeCapacity(@ptrCast(case_block.instructions.items)); - final_else_body = first_else_body; - } - } + break :else_body case_block.instructions.items; + } else &.{}; try sema.air_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.SwitchBr).Struct.fields.len + - cases_extra.items.len + final_else_body.len); + cases_extra.items.len + else_body.len); const payload_index = sema.addExtraAssumeCapacity(Air.SwitchBr{ .cases_len = @intCast(cases_len), - .else_body_len = @intCast(final_else_body.len), + .else_body_len = @intCast(else_body.len), }); sema.air_extra.appendSliceAssumeCapacity(@ptrCast(cases_extra.items)); - sema.air_extra.appendSliceAssumeCapacity(@ptrCast(final_else_body)); + sema.air_extra.appendSliceAssumeCapacity(@ptrCast(else_body)); return try child_block.addInst(.{ .tag = .switch_br, @@ -37443,15 +37379,21 @@ pub fn addExtra(sema: *Sema, extra: anytype) Allocator.Error!u32 { } pub fn addExtraAssumeCapacity(sema: *Sema, extra: anytype) u32 { - const fields = std.meta.fields(@TypeOf(extra)); const result: u32 = @intCast(sema.air_extra.items.len); - inline for (fields) |field| { - sema.air_extra.appendAssumeCapacity(switch (field.type) { - 
u32 => @field(extra, field.name), - i32 => @bitCast(@field(extra, field.name)), - Air.Inst.Ref, InternPool.Index => @intFromEnum(@field(extra, field.name)), + sema.air_extra.appendSliceAssumeCapacity(&payloadToExtraItems(extra)); + return result; +} + +fn payloadToExtraItems(data: anytype) [@typeInfo(@TypeOf(data)).Struct.fields.len]u32 { + const fields = @typeInfo(@TypeOf(data)).Struct.fields; + var result: [fields.len]u32 = undefined; + inline for (&result, fields) |*val, field| { + val.* = switch (field.type) { + u32 => @field(data, field.name), + i32 => @bitCast(@field(data, field.name)), + Air.Inst.Ref, InternPool.Index => @intFromEnum(@field(data, field.name)), else => @compileError("bad field type: " ++ @typeName(field.type)), - }); + }; } return result; } diff --git a/src/arch/aarch64/CodeGen.zig b/src/arch/aarch64/CodeGen.zig index ddde72345efe..11efb49cef9b 100644 --- a/src/arch/aarch64/CodeGen.zig +++ b/src/arch/aarch64/CodeGen.zig @@ -5091,7 +5091,8 @@ fn airSwitch(self: *Self, inst: Air.Inst.Index) !void { var case_i: u32 = 0; while (case_i < switch_br.data.cases_len) : (case_i += 1) { const case = self.air.extraData(Air.SwitchBr.Case, extra_index); - const items = @as([]const Air.Inst.Ref, @ptrCast(self.air.extra[case.end..][0..case.data.items_len])); + if (case.data.ranges_len > 0) return self.fail("TODO: switch with ranges", .{}); + const items: []const Air.Inst.Ref = @ptrCast(self.air.extra[case.end..][0..case.data.items_len]); assert(items.len > 0); const case_body: []const Air.Inst.Index = @ptrCast(self.air.extra[case.end + items.len ..][0..case.data.body_len]); extra_index = case.end + items.len + case_body.len; diff --git a/src/arch/arm/CodeGen.zig b/src/arch/arm/CodeGen.zig index 86d4e8f7fdd6..6186b495e90a 100644 --- a/src/arch/arm/CodeGen.zig +++ b/src/arch/arm/CodeGen.zig @@ -5022,6 +5022,7 @@ fn airSwitch(self: *Self, inst: Air.Inst.Index) !void { var case_i: u32 = 0; while (case_i < switch_br.data.cases_len) : (case_i += 1) { const case = 
self.air.extraData(Air.SwitchBr.Case, extra_index); + if (case.data.ranges_len > 0) return self.fail("TODO: switch with ranges", .{}); const items: []const Air.Inst.Ref = @ptrCast(self.air.extra[case.end..][0..case.data.items_len]); assert(items.len > 0); const case_body: []const Air.Inst.Index = @ptrCast(self.air.extra[case.end + items.len ..][0..case.data.body_len]); diff --git a/src/arch/wasm/CodeGen.zig b/src/arch/wasm/CodeGen.zig index fe94c061365f..d639d780c2b2 100644 --- a/src/arch/wasm/CodeGen.zig +++ b/src/arch/wasm/CodeGen.zig @@ -4015,6 +4015,7 @@ fn airSwitchBr(func: *CodeGen, inst: Air.Inst.Index) InnerError!void { var highest_maybe: ?i32 = null; while (case_i < switch_br.data.cases_len) : (case_i += 1) { const case = func.air.extraData(Air.SwitchBr.Case, extra_index); + if (case.data.ranges_len != 0) return func.fail("TODO: switch with ranges", .{}); const items: []const Air.Inst.Ref = @ptrCast(func.air.extra[case.end..][0..case.data.items_len]); const case_body: []const Air.Inst.Index = @ptrCast(func.air.extra[case.end + items.len ..][0..case.data.body_len]); extra_index = case.end + items.len + case_body.len; diff --git a/src/arch/x86_64/CodeGen.zig b/src/arch/x86_64/CodeGen.zig index 7faf3d6d50c2..1edab02f9d6d 100644 --- a/src/arch/x86_64/CodeGen.zig +++ b/src/arch/x86_64/CodeGen.zig @@ -13464,6 +13464,7 @@ fn airSwitchBr(self: *Self, inst: Air.Inst.Index) !void { while (case_i < switch_br.data.cases_len) : (case_i += 1) { const case = self.air.extraData(Air.SwitchBr.Case, extra_index); + if (case.data.ranges_len > 0) return self.fail("TODO: switch with ranges", .{}); const items: []const Air.Inst.Ref = @ptrCast(self.air.extra[case.end..][0..case.data.items_len]); const case_body: []const Air.Inst.Index = diff --git a/src/codegen/c.zig b/src/codegen/c.zig index 9514b826eaa8..a8c80d92b91a 100644 --- a/src/codegen/c.zig +++ b/src/codegen/c.zig @@ -5049,14 +5049,16 @@ fn airSwitchBr(f: *Function, inst: Air.Inst.Index) !CValue { const liveness = try 
f.liveness.getSwitchBr(gpa, inst, switch_br.data.cases_len + 1); defer gpa.free(liveness.deaths); - // On the final iteration we do not need to fix any state. This is because, like in the `else` - // branch of a `cond_br`, our parent has to do it for this entire body anyway. - const last_case_i = switch_br.data.cases_len - @intFromBool(switch_br.data.else_body_len == 0); - + var any_range_cases = false; var extra_index: usize = switch_br.end; for (0..switch_br.data.cases_len) |case_i| { const case = f.air.extraData(Air.SwitchBr.Case, extra_index); - const items = @as([]const Air.Inst.Ref, @ptrCast(f.air.extra[case.end..][0..case.data.items_len])); + if (case.data.ranges_len != 0) { + any_range_cases = true; + extra_index = case.end + case.data.items_len + case.data.ranges_len * 2 + case.data.body_len; + continue; + } + const items: []const Air.Inst.Ref = @ptrCast(f.air.extra[case.end..][0..case.data.items_len]); const case_body: []const Air.Inst.Index = @ptrCast(f.air.extra[case.end + items.len ..][0..case.data.body_len]); extra_index = case.end + case.data.items_len + case_body.len; @@ -5079,30 +5081,69 @@ fn airSwitchBr(f: *Function, inst: Air.Inst.Index) !CValue { } try writer.writeByte(' '); - if (case_i != last_case_i) { - try genBodyResolveState(f, inst, liveness.deaths[case_i], case_body, false); - } else { - for (liveness.deaths[case_i]) |death| { - try die(f, inst, death.toRef()); - } - try genBody(f, case_body); - } + try genBodyResolveState(f, inst, liveness.deaths[case_i], case_body, false); // The case body must be noreturn so we don't need to insert a break. } const else_body: []const Air.Inst.Index = @ptrCast(f.air.extra[extra_index..][0..switch_br.data.else_body_len]); try f.object.indent_writer.insertNewline(); + + try writer.writeAll("default: "); + if (any_range_cases) { + // We will iterate the cases again to handle those with ranges, and generate + // code using conditionals rather than switch cases for such cases. 
+ extra_index = switch_br.end; + for (0..switch_br.data.cases_len) |case_i| { + const case = f.air.extraData(Air.SwitchBr.Case, extra_index); + if (case.data.ranges_len == 0) { + // No ranges, so handled above - skip this case. + extra_index = case.end + case.data.items_len + case.data.body_len; + continue; + } + extra_index = case.end; + const items: []const Air.Inst.Ref = @ptrCast(f.air.extra[extra_index..][0..case.data.items_len]); + extra_index += items.len; + // TODO: this can be written more cleanly once Sema allows @ptrCast on slices where the length changes. + const ranges: []const [2]Air.Inst.Ref = @as([*]const [2]Air.Inst.Ref, @ptrCast(f.air.extra[extra_index..].ptr))[0..case.data.ranges_len]; + extra_index += ranges.len * 2; + const case_body: []const Air.Inst.Index = @ptrCast(f.air.extra[extra_index..][0..case.data.body_len]); + extra_index += case_body.len; + try writer.writeAll("if ("); + for (items, 0..) |item, item_i| { + if (item_i != 0) try writer.writeAll(" || "); + try f.writeCValue(writer, condition, .Other); + try writer.writeAll(" == "); + try f.object.dg.renderValue(writer, (try f.air.value(item, zcu)).?, .Other); + } + for (ranges, 0..) |range, range_i| { + if (items.len != 0 or range_i != 0) try writer.writeAll(" || "); + // "(x >= lower && x <= upper)" + try writer.writeByte('('); + try f.writeCValue(writer, condition, .Other); + try writer.writeAll(" >= "); + try f.object.dg.renderValue(writer, (try f.air.value(range[0], zcu)).?, .Other); + try writer.writeAll(" && "); + try f.writeCValue(writer, condition, .Other); + try writer.writeAll(" <= "); + try f.object.dg.renderValue(writer, (try f.air.value(range[1], zcu)).?, .Other); + try writer.writeByte(')'); + } + try writer.writeAll(") "); + try genBodyResolveState(f, inst, liveness.deaths[case_i], case_body, false); + } + } if (else_body.len > 0) { - // Note that this must be the last case (i.e. 
the `last_case_i` case was not hit above) + // Note that this must be the last case, so we do not need to use `caseBodyResolveState` since + // the parent block will do it (because the case body is noreturn). for (liveness.deaths[liveness.deaths.len - 1]) |death| { try die(f, inst, death.toRef()); } - try writer.writeAll("default: "); try genBody(f, else_body); } else { - try writer.writeAll("default: zig_unreachable();"); + try writer.writeAll("zig_unreachable();"); } + try f.object.indent_writer.insertNewline(); f.object.indent_writer.popIndent(); diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index 9be2316a558f..7956fcc4762e 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -6096,10 +6096,17 @@ pub const FuncGen = struct { cond; var extra_index: usize = switch_br.end; - var case_i: u32 = 0; + var any_range_cases = false; var llvm_cases_len: u32 = 0; - while (case_i < switch_br.data.cases_len) : (case_i += 1) { + for (0..switch_br.data.cases_len) |_| { const case = self.air.extraData(Air.SwitchBr.Case, extra_index); + if (case.data.ranges_len != 0) { + // TODO: for ranges, we could still define any scalar cases in the same prong within + // the switch, just directing it to the same bb as the range check. 
+ any_range_cases = true; + extra_index = case.end + case.data.items_len + case.data.ranges_len * 2 + case.data.body_len; + continue; + } const items: []const Air.Inst.Ref = @ptrCast(self.air.extra[case.end..][0..case.data.items_len]); const case_body = self.air.extra[case.end + items.len ..][0..case.data.body_len]; @@ -6112,9 +6119,12 @@ pub const FuncGen = struct { defer wip_switch.finish(&self.wip); extra_index = switch_br.end; - case_i = 0; - while (case_i < switch_br.data.cases_len) : (case_i += 1) { + for (0..switch_br.data.cases_len) |_| { const case = self.air.extraData(Air.SwitchBr.Case, extra_index); + if (case.data.ranges_len != 0) { + extra_index = case.end + case.data.items_len + case.data.ranges_len * 2 + case.data.body_len; + continue; + } const items: []const Air.Inst.Ref = @ptrCast(self.air.extra[case.end..][0..case.data.items_len]); const case_body: []const Air.Inst.Index = @ptrCast(self.air.extra[case.end + items.len ..][0..case.data.body_len]); @@ -6137,6 +6147,60 @@ pub const FuncGen = struct { self.wip.cursor = .{ .block = else_block }; const else_body: []const Air.Inst.Index = @ptrCast(self.air.extra[extra_index..][0..switch_br.data.else_body_len]); + if (any_range_cases) { + // We will iterate the cases again to handle those with ranges, and generate + // code using conditionals rather than switch cases for such cases. + const cond_ty = self.typeOf(pl_op.operand); + extra_index = switch_br.end; + for (0..switch_br.data.cases_len) |_| { + const case = self.air.extraData(Air.SwitchBr.Case, extra_index); + if (case.data.ranges_len == 0) { + // No ranges, so handled above - skip this case. + extra_index = case.end + case.data.items_len + case.data.body_len; + continue; + } + extra_index = case.end; + const items: []const Air.Inst.Ref = @ptrCast(self.air.extra[extra_index..][0..case.data.items_len]); + extra_index += items.len; + // TODO: this can be written more cleanly once Sema allows @ptrCast on slices where the length changes. 
+ const ranges: []const [2]Air.Inst.Ref = @as([*]const [2]Air.Inst.Ref, @ptrCast(self.air.extra[extra_index..].ptr))[0..case.data.ranges_len]; + extra_index += ranges.len * 2; + const case_body: []const Air.Inst.Index = @ptrCast(self.air.extra[extra_index..][0..case.data.body_len]); + extra_index += case_body.len; + + var range_cond: ?Builder.Value = null; + + for (items) |item| { + const llvm_item = try self.resolveInst(item); + const cond_part = try self.cmp(.normal, .eq, cond_ty, cond, llvm_item); + if (range_cond) |old| { + range_cond = try self.wip.bin(.@"or", old, cond_part, ""); + } else range_cond = cond_part; + } + for (ranges) |range| { + const llvm_min = try self.resolveInst(range[0]); + const llvm_max = try self.resolveInst(range[1]); + const cond_part = try self.wip.bin( + .@"and", + try self.cmp(.normal, .gte, cond_ty, cond, llvm_min), + try self.cmp(.normal, .lte, cond_ty, cond, llvm_max), + "", + ); + if (range_cond) |old| { + range_cond = try self.wip.bin(.@"or", old, cond_part, ""); + } else range_cond = cond_part; + } + + const range_case_block = try self.wip.block(1, "RangeCase"); + const range_else_block = try self.wip.block(1, "RangeDefault"); + + _ = try self.wip.brCond(range_cond.?, range_case_block, range_else_block); + + self.wip.cursor = .{ .block = range_case_block }; + try self.genBodyDebugScope(null, case_body); + self.wip.cursor = .{ .block = range_else_block }; + } + } if (else_body.len != 0) { try self.genBodyDebugScope(null, else_body); } else { diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index ed04ee475bf8..060b2a24f7ee 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -5456,6 +5456,7 @@ const DeclGen = struct { var num_conditions: u32 = 0; for (0..num_cases) |_| { const case = self.air.extraData(Air.SwitchBr.Case, extra_index); + if (case.data.ranges_len != 0) return self.fail("TODO: switch with ranges", .{}); const case_body = self.air.extra[case.end + case.data.items_len 
..][0..case.data.body_len]; extra_index = case.end + case.data.items_len + case_body.len; num_conditions += case.data.items_len; diff --git a/src/print_air.zig b/src/print_air.zig index e61ae9fff004..030a03e5aa9e 100644 --- a/src/print_air.zig +++ b/src/print_air.zig @@ -843,15 +843,26 @@ const Writer = struct { while (case_i < switch_br.data.cases_len) : (case_i += 1) { const case = w.air.extraData(Air.SwitchBr.Case, extra_index); - const items = @as([]const Air.Inst.Ref, @ptrCast(w.air.extra[case.end..][0..case.data.items_len])); - const case_body: []const Air.Inst.Index = @ptrCast(w.air.extra[case.end + items.len ..][0..case.data.body_len]); - extra_index = case.end + case.data.items_len + case_body.len; + extra_index = case.end; + const items: []const Air.Inst.Ref = @ptrCast(w.air.extra[extra_index..][0..case.data.items_len]); + extra_index += items.len; + // TODO: this can be written more cleanly once Sema allows @ptrCast on slices where the length changes. + const ranges: []const [2]Air.Inst.Ref = @as([*]const [2]Air.Inst.Ref, @ptrCast(w.air.extra[extra_index..].ptr))[0..case.data.ranges_len]; + extra_index += case.data.ranges_len * 2; + const case_body: []const Air.Inst.Index = @ptrCast(w.air.extra[extra_index..][0..case.data.body_len]); + extra_index += case_body.len; try s.writeAll(", ["); for (items, 0..) |item, item_i| { if (item_i != 0) try s.writeAll(", "); try w.writeInstRef(s, item, false); } + for (ranges, 0..) 
|range, range_i| { + if (items.len != 0 or range_i != 0) try s.writeAll(", "); + try w.writeInstRef(s, range[0], false); + try s.writeAll(".."); + try w.writeInstRef(s, range[1], false); + } try s.writeAll("] => {\n"); w.indent += 2; From 3fd3bda31f172b8e3f4568801c6129615cbe5998 Mon Sep 17 00:00:00 2001 From: mlugg Date: Thu, 25 Apr 2024 03:46:10 +0100 Subject: [PATCH 3/9] Air: add explicit `repeat` instruction to repeat loops This commit introduces a new AIR instruction, `repeat`, which causes control flow to move back to the start of a given AIR loop. `loop` instructions will no longer automatically perform this operation after control flow reaches the end of the body. The motivation for making this change now was really just consistency with the upcoming implementation of #8220: it wouldn't make sense to have this feature work significantly differently. However, there were already some TODOs kicking around which wanted this feature. It's useful for two key reasons: * It allows loops over AIR instruction bodies to loop precisely until they reach a `noreturn` instruction. This allows for tail calling a few things, and avoiding a range check on each iteration of a hot path, plus gives a nice assertion that validates AIR structure a little. This is a very minor benefit, which this commit does apply to the LLVM and C backends. * It should allow for more compact ZIR and AIR to be emitted by having AstGen emit `repeat` instructions more often rather than having `continue` statements `break` to a `block` which is *followed* by a `repeat`. This is done in status quo because `repeat` instructions only ever cause the direct parent block to repeat. Now that AIR is more flexible, this flexibility can be pretty trivially extended to ZIR, and we can then emit better ZIR. This commit does not implement this. Support for this feature is currently regressed on all self-hosted native backends, including x86_64. 
This support will be added where necessary before this branch is merged. --- src/Air.zig | 15 ++-- src/Liveness.zig | 62 ++++++++++++++-- src/Liveness/Verify.zig | 38 +++++++--- src/Sema.zig | 19 ++++- src/arch/aarch64/CodeGen.zig | 1 + src/arch/arm/CodeGen.zig | 1 + src/arch/riscv64/CodeGen.zig | 1 + src/arch/sparc64/CodeGen.zig | 1 + src/arch/wasm/CodeGen.zig | 1 + src/arch/x86_64/CodeGen.zig | 1 + src/codegen/c.zig | 66 ++++++++++++----- src/codegen/llvm.zig | 133 ++++++++++++++++++++--------------- src/codegen/spirv.zig | 1 + src/print_air.zig | 6 ++ 14 files changed, 250 insertions(+), 96 deletions(-) diff --git a/src/Air.zig b/src/Air.zig index 91ffed071120..8f24f295b79a 100644 --- a/src/Air.zig +++ b/src/Air.zig @@ -272,13 +272,15 @@ pub const Inst = struct { /// is to encounter a `br` that targets this `block`. If the `block` type is `noreturn`, /// then there do not exist any `br` instructions targetting this `block`. block, - /// A labeled block of code that loops forever. At the end of the body it is implied - /// to repeat; no explicit "repeat" instruction terminates loop bodies. + /// A labeled block of code that loops forever. The body must be `noreturn`: loops + /// occur through an explicit `repeat` instruction pointing back to this one. /// Result type is always `noreturn`; no instructions in a block follow this one. - /// The body never ends with a `noreturn` instruction, so the "repeat" operation - /// is always statically reachable. + /// There is always at least one `repeat` instruction referencing the loop. /// Uses the `ty_pl` field. Payload is `Block`. loop, + /// Sends control flow back to the beginning of a parent `loop` body. + /// Uses the `repeat` field. + repeat, /// Return from a block with a result. /// Result type is always noreturn; no instructions in a block follow this one. /// Uses the `br` field. 
@@ -1052,6 +1054,9 @@ pub const Inst = struct { block_inst: Index, operand: Ref, }, + repeat: struct { + loop_inst: Index, + }, pl_op: struct { operand: Ref, payload: u32, @@ -1440,6 +1445,7 @@ pub fn typeOfIndex(air: *const Air, inst: Air.Inst.Index, ip: *const InternPool) => return datas[@intFromEnum(inst)].ty_op.ty.toType(), .loop, + .repeat, .br, .cond_br, .switch_br, @@ -1596,6 +1602,7 @@ pub fn mustLower(air: Air, inst: Air.Inst.Index, ip: *const InternPool) bool { .arg, .block, .loop, + .repeat, .br, .trap, .breakpoint, diff --git a/src/Liveness.zig b/src/Liveness.zig index ee37ea76b6e2..2b56fe728b3a 100644 --- a/src/Liveness.zig +++ b/src/Liveness.zig @@ -70,7 +70,8 @@ pub const Block = struct { const LivenessPass = enum { /// In this pass, we perform some basic analysis of loops to gain information the main pass /// needs. In particular, for every `loop`, we track the following information: - /// * Every block which the loop body contains a `br` to. + /// * Every outer block which the loop body contains a `br` to. + /// * Every outer loop which the loop body contains a `repeat` to. /// * Every operand referenced within the loop body but created outside the loop. /// This gives the main analysis pass enough information to determine the full set of /// instructions which need to be alive when a loop repeats. This data is TEMPORARILY stored in @@ -89,7 +90,8 @@ fn LivenessPassData(comptime pass: LivenessPass) type { return switch (pass) { .loop_analysis => struct { /// The set of blocks which are exited with a `br` instruction at some point within this - /// body and which we are currently within. + /// body and which we are currently within. Also includes `loop`s which are the target + /// of a `repeat` instruction. breaks: std.AutoHashMapUnmanaged(Air.Inst.Index, void) = .{}, /// The set of operands for which we have seen at least one usage but not their birth. 
@@ -102,7 +104,7 @@ fn LivenessPassData(comptime pass: LivenessPass) type { }, .main_analysis => struct { - /// Every `block` currently under analysis. + /// Every `block` and `loop` currently under analysis. block_scopes: std.AutoHashMapUnmanaged(Air.Inst.Index, BlockScope) = .{}, /// The set of instructions currently alive in the current control @@ -114,7 +116,8 @@ fn LivenessPassData(comptime pass: LivenessPass) type { old_extra: std.ArrayListUnmanaged(u32) = .{}, const BlockScope = struct { - /// The set of instructions which are alive upon a `br` to this block. + /// If this is a `block`, these instructions are alive upon a `br` to this block. + /// If this is a `loop`, these instructions are alive upon a `repeat` to this block. live_set: std.AutoHashMapUnmanaged(Air.Inst.Index, void), }; @@ -326,6 +329,7 @@ pub fn categorizeOperand( .ret_ptr, .trap, .breakpoint, + .repeat, .dbg_stmt, .unreach, .ret_addr, @@ -1199,6 +1203,7 @@ fn analyzeInst( }, .br => return analyzeInstBr(a, pass, data, inst), + .repeat => return analyzeInstRepeat(a, pass, data, inst), .assembly => { const extra = a.air.extraData(Air.Asm, inst_datas[@intFromEnum(inst)].ty_pl.payload); @@ -1378,6 +1383,33 @@ fn analyzeInstBr( return analyzeOperands(a, pass, data, inst, .{ br.operand, .none, .none }); } +fn analyzeInstRepeat( + a: *Analysis, + comptime pass: LivenessPass, + data: *LivenessPassData(pass), + inst: Air.Inst.Index, +) !void { + const inst_datas = a.air.instructions.items(.data); + const repeat = inst_datas[@intFromEnum(inst)].repeat; + const gpa = a.gpa; + + switch (pass) { + .loop_analysis => { + try data.breaks.put(gpa, repeat.loop_inst, {}); + }, + + .main_analysis => { + const block_scope = data.block_scopes.get(repeat.loop_inst).?; // we should always be repeating an enclosing loop + + const new_live_set = try block_scope.live_set.clone(gpa); + data.live_set.deinit(gpa); + data.live_set = new_live_set; + }, + } + + return analyzeOperands(a, pass, data, inst, .{ .none, .none, 
.none }); +} + fn analyzeInstBlock( a: *Analysis, comptime pass: LivenessPass, @@ -1400,8 +1432,10 @@ fn analyzeInstBlock( .main_analysis => { log.debug("[{}] %{}: block live set is {}", .{ pass, inst, fmtInstSet(&data.live_set) }); + // We can move the live set because the body should have a noreturn + // instruction which overrides the set. try data.block_scopes.put(gpa, inst, .{ - .live_set = try data.live_set.clone(gpa), + .live_set = data.live_set.move(), }); defer { log.debug("[{}] %{}: popped block scope", .{ pass, inst }); @@ -1469,10 +1503,15 @@ fn analyzeInstLoop( try analyzeBody(a, pass, data, body); + // `loop`s are guaranteed to have at least one matching `repeat`. + // However, we no longer care about repeats of this loop itself. + assert(data.breaks.remove(inst)); + + const extra_index: u32 = @intCast(a.extra.items.len); + const num_breaks = data.breaks.count(); try a.extra.ensureUnusedCapacity(gpa, 1 + num_breaks); - const extra_index = @as(u32, @intCast(a.extra.items.len)); a.extra.appendAssumeCapacity(num_breaks); var it = data.breaks.keyIterator(); @@ -1541,6 +1580,17 @@ fn analyzeInstLoop( } } + // Now, `data.live_set` is the operands which must be alive when the loop repeats. + // Move them into a block scope for corresponding `repeat` instructions to notice. + log.debug("[{}] %{}: loop live set is {}", .{ pass, inst, fmtInstSet(&data.live_set) }); + try data.block_scopes.putNoClobber(gpa, inst, .{ + .live_set = data.live_set.move(), + }); + defer { + log.debug("[{}] %{}: popped loop block scop", .{ pass, inst }); + var scope = data.block_scopes.fetchRemove(inst).?.value; + scope.live_set.deinit(gpa); + } try analyzeBody(a, pass, data, body); }, } diff --git a/src/Liveness/Verify.zig b/src/Liveness/Verify.zig index ef14033828b2..b4d150b64574 100644 --- a/src/Liveness/Verify.zig +++ b/src/Liveness/Verify.zig @@ -1,28 +1,38 @@ -//! Verifies that liveness information is valid. +//! Verifies that Liveness information is valid. 
gpa: std.mem.Allocator, air: Air, liveness: Liveness, live: LiveMap = .{}, blocks: std.AutoHashMapUnmanaged(Air.Inst.Index, LiveMap) = .{}, +loops: std.AutoHashMapUnmanaged(Air.Inst.Index, LiveMap) = .{}, intern_pool: *const InternPool, pub const Error = error{ LivenessInvalid, OutOfMemory }; pub fn deinit(self: *Verify) void { self.live.deinit(self.gpa); - var block_it = self.blocks.valueIterator(); - while (block_it.next()) |block| block.deinit(self.gpa); - self.blocks.deinit(self.gpa); + { + var it = self.blocks.valueIterator(); + while (it.next()) |block| block.deinit(self.gpa); + self.blocks.deinit(self.gpa); + } + { + var it = self.loops.valueIterator(); + while (it.next()) |block| block.deinit(self.gpa); + self.loops.deinit(self.gpa); + } self.* = undefined; } pub fn verify(self: *Verify) Error!void { self.live.clearRetainingCapacity(); self.blocks.clearRetainingCapacity(); + self.loops.clearRetainingCapacity(); try self.verifyBody(self.air.getMainBody()); // We don't care about `self.live` now, because the loop body was noreturn - everything being dead was checked on `ret` etc assert(self.blocks.count() == 0); + assert(self.loops.count() == 0); } const LiveMap = std.AutoHashMapUnmanaged(Air.Inst.Index, void); @@ -429,6 +439,13 @@ fn verifyBody(self: *Verify, body: []const Air.Inst.Index) Error!void { } try self.verifyInst(inst); }, + .repeat => { + const repeat = data[@intFromEnum(inst)].repeat; + const expected_live = self.loops.get(repeat.loop_inst) orelse + return invalid("%{}: loop %{} not in scope", .{ @intFromEnum(inst), @intFromEnum(repeat.loop_inst) }); + + try self.verifyMatchingLiveness(repeat.loop_inst, expected_live); + }, .block, .dbg_inline_block => |tag| { const ty_pl = data[@intFromEnum(inst)].ty_pl; const block_ty = ty_pl.ty.toType(); @@ -474,14 +491,17 @@ fn verifyBody(self: *Verify, body: []const Air.Inst.Index) Error!void { const extra = self.air.extraData(Air.Block, ty_pl.payload); const loop_body: []const Air.Inst.Index = 
@ptrCast(self.air.extra[extra.end..][0..extra.data.body_len]); - var live = try self.live.clone(self.gpa); - defer live.deinit(self.gpa); + // The same stuff should be alive after the loop as before it. + const gop = try self.loops.getOrPut(self.gpa, inst); + defer { + var live = self.loops.fetchRemove(inst).?; + live.value.deinit(self.gpa); + } + if (gop.found_existing) return invalid("%{}: loop already exists", .{@intFromEnum(inst)}); + gop.value_ptr.* = try self.live.clone(self.gpa); try self.verifyBody(loop_body); - // The same stuff should be alive after the loop as before it - try self.verifyMatchingLiveness(inst, live); - try self.verifyInstOperands(inst, .{ .none, .none, .none }); }, .cond_br => { diff --git a/src/Sema.zig b/src/Sema.zig index 5cd5c2fb860e..33e5e87e7c96 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -1520,6 +1520,8 @@ fn analyzeBodyInner( // We are definitely called by `zirLoop`, which will treat the // fact that this body does not terminate `noreturn` as an // implicit repeat. + // TODO: since AIR has `repeat` now, we could change ZIR to generate + // more optimal code utilizing `repeat` instructions across blocks! break; } }, @@ -5956,17 +5958,30 @@ fn zirLoop(sema: *Sema, parent_block: *Block, inst: Zir.Inst.Index) CompileError // Use `analyzeBodyInner` directly to push any comptime control flow up the stack. try sema.analyzeBodyInner(&loop_block, body); + // TODO: since AIR has `repeat` now, we could change ZIR to generate + // more optimal code utilizing `repeat` instructions across blocks! + // For now, if the generated loop body does not terminate `noreturn`, + // then `analyzeBodyInner` is signalling that it ended with `repeat`. + const loop_block_len = loop_block.instructions.items.len; if (loop_block_len > 0 and sema.typeOf(loop_block.instructions.items[loop_block_len - 1].toRef()).isNoReturn(mod)) { // If the loop ended with a noreturn terminator, then there is no way for it to loop, // so we can just use the block instead. 
try child_block.instructions.appendSlice(gpa, loop_block.instructions.items); } else { + _ = try loop_block.addInst(.{ + .tag = .repeat, + .data = .{ .repeat = .{ + .loop_inst = loop_inst, + } }, + }); + // Note that `loop_block_len` is now off by one. + try child_block.instructions.append(gpa, loop_inst); - try sema.air_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.Block).Struct.fields.len + loop_block_len); + try sema.air_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.Block).Struct.fields.len + loop_block_len + 1); sema.air_instructions.items(.data)[@intFromEnum(loop_inst)].ty_pl.payload = sema.addExtraAssumeCapacity( - Air.Block{ .body_len = @intCast(loop_block_len) }, + Air.Block{ .body_len = @intCast(loop_block_len + 1) }, ); sema.air_extra.appendSliceAssumeCapacity(@ptrCast(loop_block.instructions.items)); } diff --git a/src/arch/aarch64/CodeGen.zig b/src/arch/aarch64/CodeGen.zig index 11efb49cef9b..08e8500b6597 100644 --- a/src/arch/aarch64/CodeGen.zig +++ b/src/arch/aarch64/CodeGen.zig @@ -740,6 +740,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .bitcast => try self.airBitCast(inst), .block => try self.airBlock(inst), .br => try self.airBr(inst), + .repeat => return self.fail("TODO implement `repeat`", .{}), .trap => try self.airTrap(), .breakpoint => try self.airBreakpoint(), .ret_addr => try self.airRetAddr(inst), diff --git a/src/arch/arm/CodeGen.zig b/src/arch/arm/CodeGen.zig index 6186b495e90a..cb9a580832cd 100644 --- a/src/arch/arm/CodeGen.zig +++ b/src/arch/arm/CodeGen.zig @@ -726,6 +726,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .bitcast => try self.airBitCast(inst), .block => try self.airBlock(inst), .br => try self.airBr(inst), + .repeat => return self.fail("TODO implement `repeat`", .{}), .trap => try self.airTrap(), .breakpoint => try self.airBreakpoint(), .ret_addr => try self.airRetAddr(inst), diff --git a/src/arch/riscv64/CodeGen.zig b/src/arch/riscv64/CodeGen.zig index 
5abe3afcfd2a..26083f9544c9 100644 --- a/src/arch/riscv64/CodeGen.zig +++ b/src/arch/riscv64/CodeGen.zig @@ -559,6 +559,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .bitcast => try self.airBitCast(inst), .block => try self.airBlock(inst), .br => try self.airBr(inst), + .repeat => return self.fail("TODO implement `repeat`", .{}), .trap => try self.airTrap(), .breakpoint => try self.airBreakpoint(), .ret_addr => try self.airRetAddr(inst), diff --git a/src/arch/sparc64/CodeGen.zig b/src/arch/sparc64/CodeGen.zig index 19c18ec4a6b0..bd3a5b0f5e92 100644 --- a/src/arch/sparc64/CodeGen.zig +++ b/src/arch/sparc64/CodeGen.zig @@ -573,6 +573,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .bitcast => try self.airBitCast(inst), .block => try self.airBlock(inst), .br => try self.airBr(inst), + .repeat => return self.fail("TODO implement `repeat`", .{}), .trap => try self.airTrap(), .breakpoint => try self.airBreakpoint(), .ret_addr => @panic("TODO try self.airRetAddr(inst)"), diff --git a/src/arch/wasm/CodeGen.zig b/src/arch/wasm/CodeGen.zig index d639d780c2b2..5a0478398b3c 100644 --- a/src/arch/wasm/CodeGen.zig +++ b/src/arch/wasm/CodeGen.zig @@ -1897,6 +1897,7 @@ fn genInst(func: *CodeGen, inst: Air.Inst.Index) InnerError!void { .trap => func.airTrap(inst), .breakpoint => func.airBreakpoint(inst), .br => func.airBr(inst), + .repeat => return func.fail("TODO implement `repeat`", .{}), .int_from_bool => func.airIntFromBool(inst), .cond_br => func.airCondBr(inst), .intcast => func.airIntcast(inst), diff --git a/src/arch/x86_64/CodeGen.zig b/src/arch/x86_64/CodeGen.zig index 1edab02f9d6d..2ac50dfa3a07 100644 --- a/src/arch/x86_64/CodeGen.zig +++ b/src/arch/x86_64/CodeGen.zig @@ -2038,6 +2038,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .bitcast => try self.airBitCast(inst), .block => try self.airBlock(inst), .br => try self.airBr(inst), + .repeat => return self.fail("TODO implement 
`repeat`", .{}), .trap => try self.airTrap(), .breakpoint => try self.airBreakpoint(), .ret_addr => try self.airRetAddr(inst), diff --git a/src/codegen/c.zig b/src/codegen/c.zig index a8c80d92b91a..caf9e2097128 100644 --- a/src/codegen/c.zig +++ b/src/codegen/c.zig @@ -3188,11 +3188,9 @@ fn genBodyInner(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, .arg => try airArg(f, inst), - .trap => try airTrap(f, f.object.writer()), .breakpoint => try airBreakpoint(f.object.writer()), .ret_addr => try airRetAddr(f, inst), .frame_addr => try airFrameAddress(f, inst), - .unreach => try airUnreach(f), .fence => try airFence(f, inst), .ptr_add => try airPtrAddSub(f, inst, '+'), @@ -3299,21 +3297,13 @@ fn genBodyInner(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, .alloc => try airAlloc(f, inst), .ret_ptr => try airRetPtr(f, inst), .assembly => try airAsm(f, inst), - .block => try airBlock(f, inst), .bitcast => try airBitcast(f, inst), .intcast => try airIntCast(f, inst), .trunc => try airTrunc(f, inst), .int_from_bool => try airIntFromBool(f, inst), .load => try airLoad(f, inst), - .ret => try airRet(f, inst, false), - .ret_safe => try airRet(f, inst, false), // TODO - .ret_load => try airRet(f, inst, true), .store => try airStore(f, inst, false), .store_safe => try airStore(f, inst, true), - .loop => try airLoop(f, inst), - .cond_br => try airCondBr(f, inst), - .br => try airBr(f, inst), - .switch_br => try airSwitchBr(f, inst), .struct_field_ptr => try airStructFieldPtr(f, inst), .array_to_slice => try airArrayToSlice(f, inst), .cmpxchg_weak => try airCmpxchg(f, inst, "weak"), @@ -3345,14 +3335,8 @@ fn genBodyInner(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, .try_ptr => try airTryPtr(f, inst), .dbg_stmt => try airDbgStmt(f, inst), - .dbg_inline_block => try airDbgInlineBlock(f, inst), .dbg_var_ptr, .dbg_var_val => try airDbgVar(f, inst), - .call => try airCall(f, inst, .auto), - .call_always_tail => .none, - 
.call_never_tail => try airCall(f, inst, .never_tail), - .call_never_inline => try airCall(f, inst, .never_inline), - .float_from_int, .int_from_float, .fptrunc, @@ -3439,6 +3423,39 @@ fn genBodyInner(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, .work_group_size, .work_group_id, => unreachable, + + // Instructions that are known to always be `noreturn` based on their tag. + .br => return airBr(f, inst), + .repeat => return airRepeat(f, inst), + .cond_br => return airCondBr(f, inst), + .switch_br => return airSwitchBr(f, inst), + .loop => return airLoop(f, inst), + .ret => return airRet(f, inst, false), + .ret_safe => return airRet(f, inst, false), // TODO + .ret_load => return airRet(f, inst, true), + .trap => return airTrap(f, f.object.writer()), + .unreach => return airUnreach(f), + + // Instructions which may be `noreturn`. + .block => res: { + const res = try airBlock(f, inst); + if (f.typeOfIndex(inst).isNoReturn(zcu)) return; + break :res res; + }, + .dbg_inline_block => res: { + const res = try airDbgInlineBlock(f, inst); + if (f.typeOfIndex(inst).isNoReturn(zcu)) return; + break :res res; + }, + // TODO: calls should be in this category! The AIR we emit for them is a bit weird. + // The instruction has type `noreturn`, but there are instructions (and maybe a safety + // check) following nonetheless. The `unreachable` or safety check should be emitted by + // backends instead. 
+ .call => try airCall(f, inst, .auto), + .call_always_tail => .none, + .call_never_tail => try airCall(f, inst, .never_tail), + .call_never_inline => try airCall(f, inst, .never_inline), + // zig fmt: on }; if (result_value == .new_local) { @@ -3450,6 +3467,7 @@ fn genBodyInner(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, else => result_value, }); } + unreachable; } fn airSliceField(f: *Function, inst: Air.Inst.Index, is_ptr: bool, field_name: []const u8) !CValue { @@ -4797,6 +4815,13 @@ fn airBr(f: *Function, inst: Air.Inst.Index) !CValue { return .none; } +fn airRepeat(f: *Function, inst: Air.Inst.Index) !CValue { + const repeat = f.air.instructions.items(.data)[@intFromEnum(inst)].repeat; + const writer = f.object.writer(); + try writer.print("goto zig_loop_{d};\n", .{@intFromEnum(repeat.loop_inst)}); + return .none; +} + fn airBitcast(f: *Function, inst: Air.Inst.Index) !CValue { const ty_op = f.air.instructions.items(.data)[@intFromEnum(inst)].ty_op; const inst_ty = f.typeOfIndex(inst); @@ -4979,9 +5004,12 @@ fn airLoop(f: *Function, inst: Air.Inst.Index) !CValue { const body: []const Air.Inst.Index = @ptrCast(f.air.extra[loop.end..][0..loop.data.body_len]); const writer = f.object.writer(); - try writer.writeAll("for (;;) "); - try genBody(f, body); // no need to restore state, we're noreturn - try writer.writeByte('\n'); + // `repeat` instructions matching this loop will branch to + // this label. Since we need a label for arbitrary `repeat` + // anyway, there's actually no need to use a "real" looping + // construct at all! 
+ try writer.print("zig_loop_{d}:\n", .{@intFromEnum(inst)}); + try genBodyInner(f, body); // no need to restore state, we're noreturn return .none; } diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index 7956fcc4762e..0a6adb11db4c 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -1695,6 +1695,7 @@ pub const Object = struct { .arg_index = 0, .func_inst_table = .{}, .blocks = .{}, + .loops = .{}, .sync_scope = if (owner_mod.single_threaded) .singlethread else .system, .file = file, .scope = subprogram, @@ -4795,6 +4796,9 @@ pub const FuncGen = struct { breaks: *BreakList, }), + /// Maps `loop` instructions to the bb to branch to to repeat the loop. + loops: std.AutoHashMapUnmanaged(Air.Inst.Index, Builder.Function.Block.Index), + sync_scope: Builder.SyncScope, const BreakList = union { @@ -4809,6 +4813,7 @@ pub const FuncGen = struct { self.wip.deinit(); self.func_inst_table.deinit(self.gpa); self.blocks.deinit(self.gpa); + self.loops.deinit(self.gpa); } fn todo(self: *FuncGen, comptime format: []const u8, args: anytype) Error { @@ -4979,39 +4984,25 @@ pub const FuncGen = struct { .ret_ptr => try self.airRetPtr(inst), .arg => try self.airArg(inst), .bitcast => try self.airBitCast(inst), - .int_from_bool => try self.airIntFromBool(inst), - .block => try self.airBlock(inst), - .br => try self.airBr(inst), - .switch_br => try self.airSwitchBr(inst), - .trap => try self.airTrap(inst), + .int_from_bool => try self.airIntFromBool(inst), .breakpoint => try self.airBreakpoint(inst), .ret_addr => try self.airRetAddr(inst), .frame_addr => try self.airFrameAddress(inst), - .cond_br => try self.airCondBr(inst), .@"try" => try self.airTry(body[i..]), .try_ptr => try self.airTryPtr(inst), .intcast => try self.airIntCast(inst), .trunc => try self.airTrunc(inst), .fptrunc => try self.airFptrunc(inst), .fpext => try self.airFpext(inst), - .int_from_ptr => try self.airIntFromPtr(inst), + .int_from_ptr => try self.airIntFromPtr(inst), .load => try 
self.airLoad(body[i..]), - .loop => try self.airLoop(inst), .not => try self.airNot(inst), - .ret => try self.airRet(inst, false), - .ret_safe => try self.airRet(inst, true), - .ret_load => try self.airRetLoad(inst), .store => try self.airStore(inst, false), .store_safe => try self.airStore(inst, true), .assembly => try self.airAssembly(inst), .slice_ptr => try self.airSliceField(inst, 0), .slice_len => try self.airSliceField(inst, 1), - .call => try self.airCall(inst, .auto), - .call_always_tail => try self.airCall(inst, .always_tail), - .call_never_tail => try self.airCall(inst, .never_tail), - .call_never_inline => try self.airCall(inst, .never_inline), - .ptr_slice_ptr_ptr => try self.airPtrSliceFieldPtr(inst, 0), .ptr_slice_len_ptr => try self.airPtrSliceFieldPtr(inst, 1), @@ -5096,9 +5087,7 @@ pub const FuncGen = struct { .inferred_alloc, .inferred_alloc_comptime => unreachable, - .unreach => try self.airUnreach(inst), .dbg_stmt => try self.airDbgStmt(inst), - .dbg_inline_block => try self.airDbgInlineBlock(inst), .dbg_var_ptr => try self.airDbgVarPtr(inst), .dbg_var_val => try self.airDbgVarVal(inst), @@ -5110,10 +5099,50 @@ pub const FuncGen = struct { .work_item_id => try self.airWorkItemId(inst), .work_group_size => try self.airWorkGroupSize(inst), .work_group_id => try self.airWorkGroupId(inst), + + // Instructions that are known to always be `noreturn` based on their tag. + .br => return self.airBr(inst), + .repeat => return self.airRepeat(inst), + .cond_br => return self.airCondBr(inst), + .switch_br => return self.airSwitchBr(inst), + .loop => return self.airLoop(inst), + .ret => return self.airRet(inst, false), + .ret_safe => return self.airRet(inst, true), + .ret_load => return self.airRetLoad(inst), + .trap => return self.airTrap(inst), + .unreach => return self.airUnreach(inst), + + // Instructions which may be `noreturn`. 
+ .block => res: { + const res = try self.airBlock(inst); + if (self.typeOfIndex(inst).isNoReturn(mod)) return; + break :res res; + }, + .dbg_inline_block => res: { + const res = try self.airDbgInlineBlock(inst); + if (self.typeOfIndex(inst).isNoReturn(mod)) return; + break :res res; + }, + .call, .call_always_tail, .call_never_tail, .call_never_inline => |tag| res: { + const res = try self.airCall(inst, switch (tag) { + .call => .auto, + .call_always_tail => .always_tail, + .call_never_tail => .never_tail, + .call_never_inline => .never_inline, + else => unreachable, + }); + // TODO: the AIR we emit for calls is a bit weird - the instruction has + // type `noreturn`, but there are instructions (and maybe a safety check) following + // nonetheless. The `unreachable` or safety check should be emitted by backends instead. + //if (self.typeOfIndex(inst).isNoReturn(mod)) return; + break :res res; + }, + // zig fmt: on }; if (val != .none) try self.func_inst_table.putNoClobber(self.gpa, inst.toRef(), val); } + unreachable; } fn genBodyDebugScope(self: *FuncGen, maybe_inline_func: ?InternPool.Index, body: []const Air.Inst.Index) Error!void { @@ -5553,7 +5582,7 @@ pub const FuncGen = struct { _ = try fg.wip.@"unreachable"(); } - fn airRet(self: *FuncGen, inst: Air.Inst.Index, safety: bool) !Builder.Value { + fn airRet(self: *FuncGen, inst: Air.Inst.Index, safety: bool) !void { const o = self.dg.object; const mod = o.module; const un_op = self.air.instructions.items(.data)[@intFromEnum(inst)].un_op; @@ -5586,7 +5615,7 @@ pub const FuncGen = struct { try self.valgrindMarkUndef(self.ret_ptr, len); } _ = try self.wip.retVoid(); - return .none; + return; } const unwrapped_operand = operand.unwrap(); @@ -5595,12 +5624,12 @@ pub const FuncGen = struct { // Return value was stored previously if (unwrapped_operand == .instruction and unwrapped_ret == .instruction and unwrapped_operand.instruction == unwrapped_ret.instruction) { _ = try self.wip.retVoid(); - return .none; + return; 
} try self.store(self.ret_ptr, ptr_ty, operand, .none); _ = try self.wip.retVoid(); - return .none; + return; } const fn_info = mod.typeToFunc(self.dg.decl.typeOf(mod)).?; if (!ret_ty.hasRuntimeBitsIgnoreComptime(mod)) { @@ -5612,7 +5641,7 @@ pub const FuncGen = struct { } else { _ = try self.wip.retVoid(); } - return .none; + return; } const abi_ret_ty = try lowerFnRetTy(o, fn_info); @@ -5636,29 +5665,29 @@ pub const FuncGen = struct { try self.valgrindMarkUndef(rp, len); } _ = try self.wip.ret(try self.wip.load(.normal, abi_ret_ty, rp, alignment, "")); - return .none; + return; } if (isByRef(ret_ty, mod)) { // operand is a pointer however self.ret_ptr is null so that means // we need to return a value. _ = try self.wip.ret(try self.wip.load(.normal, abi_ret_ty, operand, alignment, "")); - return .none; + return; } const llvm_ret_ty = operand.typeOfWip(&self.wip); if (abi_ret_ty == llvm_ret_ty) { _ = try self.wip.ret(operand); - return .none; + return; } const rp = try self.buildAlloca(llvm_ret_ty, alignment); _ = try self.wip.store(.normal, operand, rp, alignment); _ = try self.wip.ret(try self.wip.load(.normal, abi_ret_ty, rp, alignment, "")); - return .none; + return; } - fn airRetLoad(self: *FuncGen, inst: Air.Inst.Index) !Builder.Value { + fn airRetLoad(self: *FuncGen, inst: Air.Inst.Index) !void { const o = self.dg.object; const mod = o.module; const un_op = self.air.instructions.items(.data)[@intFromEnum(inst)].un_op; @@ -5674,17 +5703,17 @@ pub const FuncGen = struct { } else { _ = try self.wip.retVoid(); } - return .none; + return; } if (self.ret_ptr != .none) { _ = try self.wip.retVoid(); - return .none; + return; } const ptr = try self.resolveInst(un_op); const abi_ret_ty = try lowerFnRetTy(o, fn_info); const alignment = ret_ty.abiAlignment(mod).toLlvm(); _ = try self.wip.ret(try self.wip.load(.normal, abi_ret_ty, ptr, alignment, "")); - return .none; + return; } fn airCVaArg(self: *FuncGen, inst: Air.Inst.Index) !Builder.Value { @@ -5944,7 +5973,7 @@ 
pub const FuncGen = struct { } } - fn airBr(self: *FuncGen, inst: Air.Inst.Index) !Builder.Value { + fn airBr(self: *FuncGen, inst: Air.Inst.Index) !void { const o = self.dg.object; const branch = self.air.instructions.items(.data)[@intFromEnum(inst)].br; const block = self.blocks.get(branch.block_inst).?; @@ -5960,10 +5989,16 @@ pub const FuncGen = struct { try block.breaks.list.append(self.gpa, .{ .bb = self.wip.cursor.block, .val = val }); } else block.breaks.len += 1; _ = try self.wip.br(block.parent_bb); - return .none; } - fn airCondBr(self: *FuncGen, inst: Air.Inst.Index) !Builder.Value { + fn airRepeat(self: *FuncGen, inst: Air.Inst.Index) !void { + const repeat = self.air.instructions.items(.data)[@intFromEnum(inst)].repeat; + const loop_bb = self.loops.get(repeat.loop_inst).?; + loop_bb.ptr(&self.wip).incoming += 1; + _ = try self.wip.br(loop_bb); + } + + fn airCondBr(self: *FuncGen, inst: Air.Inst.Index) !void { const pl_op = self.air.instructions.items(.data)[@intFromEnum(inst)].pl_op; const cond = try self.resolveInst(pl_op.operand); const extra = self.air.extraData(Air.CondBr, pl_op.payload); @@ -5981,7 +6016,6 @@ pub const FuncGen = struct { try self.genBodyDebugScope(null, else_body); // No need to reset the insert cursor since this instruction is noreturn. - return .none; } fn airTry(self: *FuncGen, body_tail: []const Air.Inst.Index) !Builder.Value { @@ -6083,7 +6117,7 @@ pub const FuncGen = struct { return fg.wip.extractValue(err_union, &.{offset}, ""); } - fn airSwitchBr(self: *FuncGen, inst: Air.Inst.Index) !Builder.Value { + fn airSwitchBr(self: *FuncGen, inst: Air.Inst.Index) !void { const o = self.dg.object; const pl_op = self.air.instructions.items(.data)[@intFromEnum(inst)].pl_op; const cond = try self.resolveInst(pl_op.operand); @@ -6208,31 +6242,20 @@ pub const FuncGen = struct { } // No need to reset the insert cursor since this instruction is noreturn. 
- return .none; } - fn airLoop(self: *FuncGen, inst: Air.Inst.Index) !Builder.Value { - const o = self.dg.object; - const mod = o.module; + fn airLoop(self: *FuncGen, inst: Air.Inst.Index) !void { const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl; const loop = self.air.extraData(Air.Block, ty_pl.payload); const body: []const Air.Inst.Index = @ptrCast(self.air.extra[loop.end..][0..loop.data.body_len]); - const loop_block = try self.wip.block(2, "Loop"); + const loop_block = try self.wip.block(1, "Loop"); // `airRepeat` will increment incoming each time _ = try self.wip.br(loop_block); + try self.loops.putNoClobber(self.gpa, inst, loop_block); + defer assert(self.loops.remove(inst)); + self.wip.cursor = .{ .block = loop_block }; try self.genBodyDebugScope(null, body); - - // TODO instead of this logic, change AIR to have the property that - // every block is guaranteed to end with a noreturn instruction. - // Then we can simply rely on the fact that a repeat or break instruction - // would have been emitted already. Also the main loop in genBody can - // be while(true) instead of for(body), which will eliminate 1 branch on - // a hot path. 
- if (body.len == 0 or !self.typeOfIndex(body[body.len - 1]).isNoReturn(mod)) { - _ = try self.wip.br(loop_block); - } - return .none; } fn airArrayToSlice(self: *FuncGen, inst: Air.Inst.Index) !Builder.Value { @@ -6717,10 +6740,9 @@ pub const FuncGen = struct { return self.wip.not(operand, ""); } - fn airUnreach(self: *FuncGen, inst: Air.Inst.Index) !Builder.Value { + fn airUnreach(self: *FuncGen, inst: Air.Inst.Index) !void { _ = inst; _ = try self.wip.@"unreachable"(); - return .none; } fn airDbgStmt(self: *FuncGen, inst: Air.Inst.Index) !Builder.Value { @@ -9072,11 +9094,10 @@ pub const FuncGen = struct { return fg.load(ptr, ptr_ty); } - fn airTrap(self: *FuncGen, inst: Air.Inst.Index) !Builder.Value { + fn airTrap(self: *FuncGen, inst: Air.Inst.Index) !void { _ = inst; _ = try self.wip.callIntrinsic(.normal, .none, .trap, &.{}, &.{}, ""); _ = try self.wip.@"unreachable"(); - return .none; } fn airBreakpoint(self: *FuncGen, inst: Air.Inst.Index) !Builder.Value { diff --git a/src/codegen/spirv.zig b/src/codegen/spirv.zig index 060b2a24f7ee..5e8bc6e1fcfa 100644 --- a/src/codegen/spirv.zig +++ b/src/codegen/spirv.zig @@ -2439,6 +2439,7 @@ const DeclGen = struct { .store, .store_safe => return self.airStore(inst), .br => return self.airBr(inst), + .repeat => return self.fail("TODO implement `repeat`", .{}), .breakpoint => return, .cond_br => return self.airCondBr(inst), .loop => return self.airLoop(inst), diff --git a/src/print_air.zig b/src/print_air.zig index 030a03e5aa9e..d90adede978a 100644 --- a/src/print_air.zig +++ b/src/print_air.zig @@ -295,6 +295,7 @@ const Writer = struct { .aggregate_init => try w.writeAggregateInit(s, inst), .union_init => try w.writeUnionInit(s, inst), .br => try w.writeBr(s, inst), + .repeat => try w.writeRepeat(s, inst), .cond_br => try w.writeCondBr(s, inst), .@"try" => try w.writeTry(s, inst), .try_ptr => try w.writeTryPtr(s, inst), @@ -704,6 +705,11 @@ const Writer = struct { try w.writeOperand(s, inst, 0, br.operand); } + fn 
writeRepeat(w: *Writer, s: anytype, inst: Air.Inst.Index) @TypeOf(s).Error!void { + const repeat = w.air.instructions.items(.data)[@intFromEnum(inst)].repeat; + try w.writeInstIndex(s, repeat.loop_inst, false); + } + fn writeTry(w: *Writer, s: anytype, inst: Air.Inst.Index) @TypeOf(s).Error!void { const pl_op = w.air.instructions.items(.data)[@intFromEnum(inst)].pl_op; const extra = w.air.extraData(Air.Try, pl_op.payload); From db6c2160eea3e45166c95127e6c77c9ee97ceed6 Mon Sep 17 00:00:00 2001 From: mlugg Date: Sun, 28 Apr 2024 21:44:57 +0100 Subject: [PATCH 4/9] compiler: implement labeled switch/continue --- lib/std/zig/Ast.zig | 54 ++- lib/std/zig/AstGen.zig | 113 ++++++- lib/std/zig/Parse.zig | 26 +- lib/std/zig/Zir.zig | 13 +- src/Air.zig | 75 +++++ src/Liveness.zig | 265 +++++++++------ src/Liveness/Verify.zig | 31 +- src/Sema.zig | 610 ++++++++++++++++++++++++++-------- src/Value.zig | 1 + src/arch/aarch64/CodeGen.zig | 2 + src/arch/arm/CodeGen.zig | 2 + src/arch/riscv64/CodeGen.zig | 2 + src/arch/sparc64/CodeGen.zig | 2 + src/arch/wasm/CodeGen.zig | 2 + src/arch/x86_64/CodeGen.zig | 2 + src/codegen/c.zig | 155 ++++++--- src/codegen/llvm.zig | 482 ++++++++++++++++++++------- src/print_air.zig | 3 +- src/print_zir.zig | 1 + test/behavior.zig | 1 + test/behavior/switch_loop.zig | 205 ++++++++++++ 21 files changed, 1621 insertions(+), 426 deletions(-) create mode 100644 test/behavior/switch_loop.zig diff --git a/lib/std/zig/Ast.zig b/lib/std/zig/Ast.zig index 20bdba8cf796..a12f49178232 100644 --- a/lib/std/zig/Ast.zig +++ b/lib/std/zig/Ast.zig @@ -1197,14 +1197,7 @@ pub fn lastToken(tree: Ast, node: Node.Index) TokenIndex { n = extra.sentinel; }, - .@"continue" => { - if (datas[n].lhs != 0) { - return datas[n].lhs + end_offset; - } else { - return main_tokens[n] + end_offset; - } - }, - .@"break" => { + .@"continue", .@"break" => { if (datas[n].rhs != 0) { n = datas[n].rhs; } else if (datas[n].lhs != 0) { @@ -1908,6 +1901,15 @@ pub fn taggedUnionEnumTag(tree: 
Ast, node: Node.Index) full.ContainerDecl { }); } +pub fn switchFull(tree: Ast, node: Node.Index) full.Switch { + const data = &tree.nodes.items(.data)[node]; + return tree.fullSwitchComponents(.{ + .switch_token = tree.nodes.items(.main_token)[node], + .condition = data.lhs, + .sub_range = data.rhs, + }); +} + pub fn switchCaseOne(tree: Ast, node: Node.Index) full.SwitchCase { const data = &tree.nodes.items(.data)[node]; const values: *[1]Node.Index = &data.lhs; @@ -2217,6 +2219,21 @@ fn fullContainerDeclComponents(tree: Ast, info: full.ContainerDecl.Components) f return result; } +fn fullSwitchComponents(tree: Ast, info: full.Switch.Components) full.Switch { + const token_tags = tree.tokens.items(.tag); + const tok_i = info.switch_token -| 1; + var result: full.Switch = .{ + .ast = info, + .label_token = null, + }; + if (token_tags[tok_i] == .colon and + token_tags[tok_i -| 1] == .identifier) + { + result.label_token = tok_i - 1; + } + return result; +} + fn fullSwitchCaseComponents(tree: Ast, info: full.SwitchCase.Components, node: Node.Index) full.SwitchCase { const token_tags = tree.tokens.items(.tag); const node_tags = tree.nodes.items(.tag); @@ -2488,6 +2505,13 @@ pub fn fullContainerDecl(tree: Ast, buffer: *[2]Ast.Node.Index, node: Node.Index }; } +pub fn fullSwitch(tree: Ast, node: Node.Index) ?full.Switch { + return switch (tree.nodes.items(.tag)[node]) { + .@"switch", .switch_comma => tree.switchFull(node), + else => null, + }; +} + pub fn fullSwitchCase(tree: Ast, node: Node.Index) ?full.SwitchCase { return switch (tree.nodes.items(.tag)[node]) { .switch_case_one, .switch_case_inline_one => tree.switchCaseOne(node), @@ -2840,6 +2864,17 @@ pub const full = struct { }; }; + pub const Switch = struct { + ast: Components, + label_token: ?TokenIndex, + + pub const Components = struct { + switch_token: TokenIndex, + condition: Node.Index, + sub_range: Node.Index, + }; + }; + pub const SwitchCase = struct { inline_token: ?TokenIndex, /// Points to the first 
token after the `|`. Will either be an identifier or @@ -3294,7 +3329,8 @@ pub const Node = struct { @"suspend", /// `resume lhs`. rhs is unused. @"resume", - /// `continue`. lhs is token index of label if any. rhs is unused. + /// `continue :lhs rhs` + /// both lhs and rhs may be omitted. @"continue", /// `break :lhs rhs` /// both lhs and rhs may be omitted. diff --git a/lib/std/zig/AstGen.zig b/lib/std/zig/AstGen.zig index 9be6a9f60597..e4987b05e360 100644 --- a/lib/std/zig/AstGen.zig +++ b/lib/std/zig/AstGen.zig @@ -1140,7 +1140,7 @@ fn expr(gz: *GenZir, scope: *Scope, ri: ResultInfo, node: Ast.Node.Index) InnerE .error_set_decl => return errorSetDecl(gz, ri, node), .array_access => return arrayAccess(gz, scope, ri, node), .@"comptime" => return comptimeExprAst(gz, scope, ri, node), - .@"switch", .switch_comma => return switchExpr(gz, scope, ri.br(), node), + .@"switch", .switch_comma => return switchExpr(gz, scope, ri.br(), node, tree.fullSwitch(node).?), .@"nosuspend" => return nosuspendExpr(gz, scope, ri, node), .@"suspend" => return suspendExpr(gz, scope, node), @@ -2154,6 +2154,11 @@ fn breakExpr(parent_gz: *GenZir, parent_scope: *Scope, node: Ast.Node.Index) Inn if (break_label != 0) { if (block_gz.label) |*label| { if (try astgen.tokenIdentEql(label.token, break_label)) { + const maybe_switch_tag = astgen.instructions.items(.tag)[@intFromEnum(label.block_inst)]; + switch (maybe_switch_tag) { + .switch_block, .switch_block_ref => return astgen.failNode(node, "cannot break from switch", .{}), + else => {}, + } label.used = true; break :blk label.block_inst; } @@ -2228,6 +2233,11 @@ fn continueExpr(parent_gz: *GenZir, parent_scope: *Scope, node: Ast.Node.Index) const tree = astgen.tree; const node_datas = tree.nodes.items(.data); const break_label = node_datas[node].lhs; + const rhs = node_datas[node].rhs; + + if (break_label == 0 and rhs != 0) { + return astgen.failNode(node, "cannot continue with operand without label", .{}); + } // Look for the label in 
the scope. var scope = parent_scope; @@ -2252,6 +2262,15 @@ fn continueExpr(parent_gz: *GenZir, parent_scope: *Scope, node: Ast.Node.Index) if (break_label != 0) blk: { if (gen_zir.label) |*label| { if (try astgen.tokenIdentEql(label.token, break_label)) { + const maybe_switch_tag = astgen.instructions.items(.tag)[@intFromEnum(label.block_inst)]; + if (rhs != 0) switch (maybe_switch_tag) { + .switch_block, .switch_block_ref => {}, + else => return astgen.failNode(node, "cannot continue loop with operand", .{}), + } else switch (maybe_switch_tag) { + .switch_block, .switch_block_ref => return astgen.failNode(node, "cannot continue switch without operand", .{}), + else => {}, + } + label.used = true; break :blk; } @@ -2259,8 +2278,35 @@ fn continueExpr(parent_gz: *GenZir, parent_scope: *Scope, node: Ast.Node.Index) // found continue but either it has a different label, or no label scope = gen_zir.parent; continue; + } else if (gen_zir.label) |label| { + // This `continue` is unlabeled. If the gz we've found corresponds to a labeled + // `switch`, ignore it and continue to parent scopes. + switch (astgen.instructions.items(.tag)[@intFromEnum(label.block_inst)]) { + .switch_block, .switch_block_ref => { + scope = gen_zir.parent; + continue; + }, + else => {}, + } + } + + if (rhs != 0) { + // We need to figure out the result info to use. 
+ // The type should match + const operand = try reachableExpr(parent_gz, parent_scope, gen_zir.continue_result_info, rhs, node); + + try genDefers(parent_gz, scope, parent_scope, .normal_only); + + // As our last action before the continue, "pop" the error trace if needed + if (!gen_zir.is_comptime) + _ = try parent_gz.addRestoreErrRetIndex(.{ .block = continue_block }, .always, node); + + _ = try parent_gz.addBreakWithSrcNode(.switch_continue, continue_block, operand, rhs); + return Zir.Inst.Ref.unreachable_value; } + try genDefers(parent_gz, scope, parent_scope, .normal_only); + const break_tag: Zir.Inst.Tag = if (gen_zir.is_inline) .break_inline else @@ -2278,12 +2324,7 @@ fn continueExpr(parent_gz: *GenZir, parent_scope: *Scope, node: Ast.Node.Index) }, .local_val => scope = scope.cast(Scope.LocalVal).?.parent, .local_ptr => scope = scope.cast(Scope.LocalPtr).?.parent, - .defer_normal => { - const defer_scope = scope.cast(Scope.Defer).?; - scope = defer_scope.parent; - try parent_gz.addDefer(defer_scope.index, defer_scope.len); - }, - .defer_error => scope = scope.cast(Scope.Defer).?.parent, + .defer_normal, .defer_error => scope = scope.cast(Scope.Defer).?.parent, .namespace => break, .top => unreachable, } @@ -2843,6 +2884,7 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As .panic, .trap, .check_comptime_control_flow, + .switch_continue, => { noreturn_src_node = statement; break :b true; @@ -7569,7 +7611,8 @@ fn switchExpr( parent_gz: *GenZir, scope: *Scope, ri: ResultInfo, - switch_node: Ast.Node.Index, + node: Ast.Node.Index, + switch_full: Ast.full.Switch, ) InnerError!Zir.Inst.Ref { const astgen = parent_gz.astgen; const gpa = astgen.gpa; @@ -7578,14 +7621,14 @@ fn switchExpr( const node_tags = tree.nodes.items(.tag); const main_tokens = tree.nodes.items(.main_token); const token_tags = tree.tokens.items(.tag); - const operand_node = node_datas[switch_node].lhs; - const extra = tree.extraData(node_datas[switch_node].rhs, 
Ast.Node.SubRange); + const operand_node = node_datas[node].lhs; + const extra = tree.extraData(node_datas[node].rhs, Ast.Node.SubRange); const case_nodes = tree.extra_data[extra.start..extra.end]; - const need_rl = astgen.nodes_need_rl.contains(switch_node); + const need_rl = astgen.nodes_need_rl.contains(node); const block_ri: ResultInfo = if (need_rl) ri else .{ .rl = switch (ri.rl) { - .ptr => .{ .ty = (try ri.rl.resultType(parent_gz, switch_node)).? }, + .ptr => .{ .ty = (try ri.rl.resultType(parent_gz, node)).? }, .inferred_ptr => .none, else => ri.rl, }, @@ -7596,11 +7639,16 @@ fn switchExpr( const LocTag = @typeInfo(ResultInfo.Loc).Union.tag_type.?; const need_result_rvalue = @as(LocTag, block_ri.rl) != @as(LocTag, ri.rl); + if (switch_full.label_token) |label_token| { + try astgen.checkLabelRedefinition(scope, label_token); + } + // We perform two passes over the AST. This first pass is to collect information // for the following variables, make note of the special prong AST node index, // and bail out with a compile error if there are multiple special prongs present. var any_payload_is_ref = false; var any_has_tag_capture = false; + var any_non_inline_capture = false; var scalar_cases_len: u32 = 0; var multi_cases_len: u32 = 0; var inline_cases_len: u32 = 0; @@ -7618,6 +7666,15 @@ fn switchExpr( if (token_tags[ident + 1] == .comma) { any_has_tag_capture = true; } + + // If the first capture is ignored, then there is no runtime-known + // capture, as the tag capture must be for an inline prong. + // This check isn't perfect, because for things like enums, the + // first prong *is* comptime-known for inline prongs! But such + // knowledge requires semantic analysis. + if (!mem.eql(u8, tree.tokenSlice(ident), "_")) { + any_non_inline_capture = true; + } } // Check for else/`_` prong. 
if (case.ast.values.len == 0) { @@ -7637,7 +7694,7 @@ fn switchExpr( ); } else if (underscore_src) |some_underscore| { return astgen.failNodeNotes( - switch_node, + node, "else and '_' prong in switch expression", .{}, &[_]u32{ @@ -7678,7 +7735,7 @@ fn switchExpr( ); } else if (else_src) |some_else| { return astgen.failNodeNotes( - switch_node, + node, "else and '_' prong in switch expression", .{}, &[_]u32{ @@ -7727,6 +7784,12 @@ fn switchExpr( const raw_operand = try expr(parent_gz, scope, operand_ri, operand_node); const item_ri: ResultInfo = .{ .rl = .none }; + // If this switch is labeled, it will have `continue`s targeting it, and thus we need the operand type + // to provide a result type. + const raw_operand_ty_ref = if (switch_full.label_token != null) t: { + break :t try parent_gz.addUnNode(.typeof, raw_operand, operand_node); + } else undefined; + // This contains the data that goes into the `extra` array for the SwitchBlock/SwitchBlockMulti, // except the first cases_nodes.len slots are a table that indexes payloads later in the array, with // the special case index coming first, then scalar_case_len indexes, then multi_cases_len indexes @@ -7748,7 +7811,22 @@ fn switchExpr( try emitDbgStmtForceCurrentIndex(parent_gz, operand_lc); // This gets added to the parent block later, after the item expressions. 
const switch_tag: Zir.Inst.Tag = if (any_payload_is_ref) .switch_block_ref else .switch_block; - const switch_block = try parent_gz.makeBlockInst(switch_tag, switch_node); + const switch_block = try parent_gz.makeBlockInst(switch_tag, node); + + if (switch_full.label_token) |label_token| { + block_scope.continue_block = switch_block.toOptional(); + block_scope.continue_result_info = .{ + .rl = if (any_payload_is_ref) + .{ .ref_coerced_ty = raw_operand_ty_ref } + else + .{ .coerced_ty = raw_operand_ty_ref }, + }; + + block_scope.label = .{ + .token = label_token, + .block_inst = switch_block, + }; + } // We re-use this same scope for all cases, including the special prong, if any. var case_scope = parent_gz.makeSubBlock(&block_scope.base); @@ -7969,6 +8047,8 @@ fn switchExpr( .has_else = special_prong == .@"else", .has_under = special_prong == .under, .any_has_tag_capture = any_has_tag_capture, + .any_non_inline_capture = any_non_inline_capture, + .has_continue = switch_full.label_token != null, .scalar_cases_len = @intCast(scalar_cases_len), }, }); @@ -8005,7 +8085,7 @@ fn switchExpr( } if (need_result_rvalue) { - return rvalue(parent_gz, ri, switch_block.toRef(), switch_node); + return rvalue(parent_gz, ri, switch_block.toRef(), node); } else { return switch_block.toRef(); } @@ -11894,6 +11974,7 @@ const GenZir = struct { continue_block: Zir.Inst.OptionalIndex = .none, /// Only valid when setBreakResultInfo is called. break_result_info: AstGen.ResultInfo = undefined, + continue_result_info: AstGen.ResultInfo = undefined, suspend_node: Ast.Node.Index = 0, nosuspend_node: Ast.Node.Index = 0, diff --git a/lib/std/zig/Parse.zig b/lib/std/zig/Parse.zig index 369e2ef125b4..55fbd95c4e1e 100644 --- a/lib/std/zig/Parse.zig +++ b/lib/std/zig/Parse.zig @@ -924,7 +924,6 @@ fn expectContainerField(p: *Parse) !Node.Index { /// / KEYWORD_errdefer Payload? 
BlockExprStatement /// / IfStatement /// / LabeledStatement -/// / SwitchExpr /// / VarDeclExprStatement fn expectStatement(p: *Parse, allow_defer_var: bool) Error!Node.Index { if (p.eatToken(.keyword_comptime)) |comptime_token| { @@ -995,7 +994,6 @@ fn expectStatement(p: *Parse, allow_defer_var: bool) Error!Node.Index { .rhs = try p.expectBlockExprStatement(), }, }), - .keyword_switch => return p.expectSwitchExpr(), .keyword_if => return p.expectIfStatement(), .keyword_enum, .keyword_struct, .keyword_union => { const identifier = p.tok_i + 1; @@ -1238,7 +1236,7 @@ fn expectIfStatement(p: *Parse) !Node.Index { }); } -/// LabeledStatement <- BlockLabel? (Block / LoopStatement) +/// LabeledStatement <- BlockLabel? (Block / LoopStatement / SwitchExpr) fn parseLabeledStatement(p: *Parse) !Node.Index { const label_token = p.parseBlockLabel(); const block = try p.parseBlock(); @@ -1247,6 +1245,9 @@ fn parseLabeledStatement(p: *Parse) !Node.Index { const loop_stmt = try p.parseLoopStatement(); if (loop_stmt != 0) return loop_stmt; + const switch_expr = try p.parseSwitchExpr(); + if (switch_expr != 0) return switch_expr; + if (label_token != 0) { const after_colon = p.tok_i; const node = try p.parseTypeExpr(); @@ -2072,7 +2073,7 @@ fn expectTypeExpr(p: *Parse) Error!Node.Index { /// / KEYWORD_break BreakLabel? Expr? /// / KEYWORD_comptime Expr /// / KEYWORD_nosuspend Expr -/// / KEYWORD_continue BreakLabel? +/// / KEYWORD_continue BreakLabel? Expr? /// / KEYWORD_resume Expr /// / KEYWORD_return Expr? /// / BlockLabel? LoopExpr @@ -2098,7 +2099,7 @@ fn parsePrimaryExpr(p: *Parse) !Node.Index { .main_token = p.nextToken(), .data = .{ .lhs = try p.parseBreakLabel(), - .rhs = undefined, + .rhs = try p.parseExpr(), }, }); }, @@ -2627,7 +2628,6 @@ fn parseSuffixExpr(p: *Parse) !Node.Index { /// / KEYWORD_anyframe /// / KEYWORD_unreachable /// / STRINGLITERAL -/// / SwitchExpr /// /// ContainerDecl <- (KEYWORD_extern / KEYWORD_packed)? 
ContainerDeclAuto /// @@ -2647,6 +2647,7 @@ fn parseSuffixExpr(p: *Parse) !Node.Index { /// LabeledTypeExpr /// <- BlockLabel Block /// / BlockLabel? LoopTypeExpr +/// / BlockLabel? SwitchExpr /// /// LoopTypeExpr <- KEYWORD_inline? (ForTypeExpr / WhileTypeExpr) fn parsePrimaryTypeExpr(p: *Parse) !Node.Index { @@ -2753,6 +2754,10 @@ fn parsePrimaryTypeExpr(p: *Parse) !Node.Index { p.tok_i += 2; return p.parseWhileTypeExpr(); }, + .keyword_switch => { + p.tok_i += 2; + return p.expectSwitchExpr(); + }, .l_brace => { p.tok_i += 2; return p.parseBlock(); @@ -3029,8 +3034,17 @@ fn parseWhileTypeExpr(p: *Parse) !Node.Index { } /// SwitchExpr <- KEYWORD_switch LPAREN Expr RPAREN LBRACE SwitchProngList RBRACE +fn parseSwitchExpr(p: *Parse) !Node.Index { + const switch_token = p.eatToken(.keyword_switch) orelse return null_node; + return p.expectSwitchSuffix(switch_token); +} + fn expectSwitchExpr(p: *Parse) !Node.Index { const switch_token = p.assertToken(.keyword_switch); + return p.expectSwitchSuffix(switch_token); +} + +fn expectSwitchSuffix(p: *Parse, switch_token: TokenIndex) !Node.Index { _ = try p.expectToken(.l_paren); const expr_node = try p.expectExpr(); _ = try p.expectToken(.r_paren); diff --git a/lib/std/zig/Zir.zig b/lib/std/zig/Zir.zig index 64e8a1c8050f..59942fdbd061 100644 --- a/lib/std/zig/Zir.zig +++ b/lib/std/zig/Zir.zig @@ -313,6 +313,9 @@ pub const Inst = struct { /// break instruction in a block, and the target block is the parent. /// Uses the `break` union field. break_inline, + /// Branch from within a switch case to the case specified by the operand. + /// Uses the `break` union field. `block_inst` refers to a `switch_block` or `switch_block_ref`. + switch_continue, /// Checks that comptime control flow does not happen inside a runtime block. /// Uses the `un_node` union field. 
check_comptime_control_flow, @@ -1282,6 +1285,7 @@ pub const Inst = struct { .panic, .trap, .check_comptime_control_flow, + .switch_continue, => true, }; } @@ -1524,6 +1528,7 @@ pub const Inst = struct { .break_inline, .condbr, .condbr_inline, + .switch_continue, .compile_error, .ret_node, .ret_load, @@ -1609,6 +1614,7 @@ pub const Inst = struct { .bool_br_or = .pl_node, .@"break" = .@"break", .break_inline = .@"break", + .switch_continue = .@"break", .check_comptime_control_flow = .un_node, .for_len = .pl_node, .call = .pl_node, @@ -2340,6 +2346,7 @@ pub const Inst = struct { }, @"break": struct { operand: Ref, + /// Index of a `Break` payload. payload_index: u32, }, dbg_stmt: LineColumn, @@ -2951,9 +2958,13 @@ pub const Inst = struct { has_under: bool, /// If true, at least one prong has an inline tag capture. any_has_tag_capture: bool, + /// If true, at least one prong has a capture which may not + /// be comptime-known via `inline`. + any_non_inline_capture: bool, + has_continue: bool, scalar_cases_len: ScalarCasesLen, - pub const ScalarCasesLen = u28; + pub const ScalarCasesLen = u26; pub fn specialProng(bits: Bits) SpecialProng { const has_else: u2 = @intFromBool(bits.has_else); diff --git a/src/Air.zig b/src/Air.zig index 8f24f295b79a..40b38bbb7a6c 100644 --- a/src/Air.zig +++ b/src/Air.zig @@ -427,6 +427,14 @@ pub const Inst = struct { /// Result type is always noreturn; no instructions in a block follow this one. /// Uses the `pl_op` field. Operand is the condition. Payload is `SwitchBr`. switch_br, + /// Switch branch which can dispatch back to itself with a different operand. + /// Result type is always noreturn; no instructions in a block follow this one. + /// Uses the `pl_op` field. Operand is the condition. Payload is `SwitchBr`. + loop_switch_br, + /// Dispatches back to a branch of a parent `loop_switch_br`. + /// Result type is always noreturn; no instructions in a block follow this one. + /// Uses the `br` field. 
`block_inst` is a `loop_switch_br` instruction. + switch_dispatch, /// Given an operand which is an error union, splits control flow. In /// case of error, control flow goes into the block that is part of this /// instruction, which is guaranteed to end with a return instruction @@ -1449,6 +1457,8 @@ pub fn typeOfIndex(air: *const Air, inst: Air.Inst.Index, ip: *const InternPool) .br, .cond_br, .switch_br, + .loop_switch_br, + .switch_dispatch, .ret, .ret_safe, .ret_load, @@ -1612,6 +1622,8 @@ pub fn mustLower(air: Air, inst: Air.Inst.Index, ip: *const InternPool) bool { .call_never_inline, .cond_br, .switch_br, + .loop_switch_br, + .switch_dispatch, .@"try", .try_ptr, .dbg_stmt, @@ -1814,3 +1826,66 @@ pub fn mustLower(air: Air, inst: Air.Inst.Index, ip: *const InternPool) bool { .atomic_load => air.typeOf(data.atomic_load.ptr, ip).isVolatilePtrIp(ip), }; } + +/// This is an iterator over `switch_br` or `loop_switch_br` cases. +/// Call `nextCase` until it returns `null`, then finally call `elseBody`. +pub const SwitchIterator = struct { + // Public, constant fields. + total_cases: u32, + else_body_len: u32, + operand: Air.Inst.Ref, + + // Iterator state. + air: *const Air, + next_case_idx: u32, + extra_index: u32, + + pub const Case = struct { + index: u32, + items: []const Air.Inst.Ref, + ranges: []const [2]Air.Inst.Ref, + body: []const Air.Inst.Index, + }; + + pub fn nextCase(it: *SwitchIterator) ?Case { + if (it.next_case_idx == it.total_cases) return null; + const case_idx = it.next_case_idx; + it.next_case_idx += 1; + + const case = it.air.extraData(SwitchBr.Case, it.extra_index); + const items = it.air.extra[case.end..][0..case.data.items_len]; + it.extra_index = @intCast(case.end + case.data.items_len); + const range_vals = it.air.extra[it.extra_index..][0 .. 
case.data.ranges_len * 2]; + it.extra_index += case.data.ranges_len * 2; + const body = it.air.extra[it.extra_index..][0..case.data.body_len]; + it.extra_index += case.data.body_len; + + return .{ + .index = case_idx, + .items = @ptrCast(items), + .ranges = @as([*]const [2]Air.Inst.Ref, @ptrCast(range_vals.ptr))[0..case.data.ranges_len], + .body = @ptrCast(body), + }; + } + pub fn elseBody(it: *SwitchIterator) []const Air.Inst.Index { + assert(it.next_case_idx == it.total_cases); + return @ptrCast(it.air.extra[it.extra_index..][0..it.else_body_len]); + } +}; + +pub fn switchIterator(air: *const Air, inst: Inst.Index) SwitchIterator { + const inst_info = air.instructions.get(@intFromEnum(inst)); + switch (inst_info.tag) { + .switch_br, .loop_switch_br => {}, + else => unreachable, // assertion failure + } + const switch_br = air.extraData(SwitchBr, inst_info.data.pl_op.payload); + return .{ + .total_cases = switch_br.data.cases_len, + .else_body_len = switch_br.data.else_body_len, + .operand = inst_info.data.pl_op.operand, + .air = air, + .next_case_idx = 0, + .extra_index = @intCast(switch_br.end), + }; +} diff --git a/src/Liveness.zig b/src/Liveness.zig index 2b56fe728b3a..3dabb5695464 100644 --- a/src/Liveness.zig +++ b/src/Liveness.zig @@ -31,6 +31,7 @@ tomb_bits: []usize, /// * `try`, `try_ptr` - points to a `CondBr` in `extra` at this index. The error path (the block /// in the instruction) is considered the "else" path, and the rest of the block the "then". /// * `switch_br` - points to a `SwitchBr` in `extra` at this index. +/// * `loop_switch_br` - points to a `SwitchBr` in `extra` at this index. /// * `block` - points to a `Block` in `extra` at this index. /// * `asm`, `call`, `aggregate_init` - the value is a set of bits which are the extra tomb /// bits of operands. @@ -68,8 +69,8 @@ pub const Block = struct { /// Liveness analysis runs in several passes. Each pass iterates backwards over instructions in /// bodies, and recurses into bodies. 
const LivenessPass = enum { - /// In this pass, we perform some basic analysis of loops to gain information the main pass - /// needs. In particular, for every `loop`, we track the following information: + /// In this pass, we perform some basic analysis of loops to gain information the main pass needs. + /// In particular, for every `loop` and `loop_switch_br`, we track the following information: /// * Every outer block which the loop body contains a `br` to. /// * Every outer loop which the loop body contains a `repeat` to. /// * Every operand referenced within the loop body but created outside the loop. @@ -91,7 +92,8 @@ fn LivenessPassData(comptime pass: LivenessPass) type { .loop_analysis => struct { /// The set of blocks which are exited with a `br` instruction at some point within this /// body and which we are currently within. Also includes `loop`s which are the target - /// of a `repeat` instruction. + /// of a `repeat` instruction, and `loop_switch_br`s which are the target of a + /// `switch_dispatch` instruction. breaks: std.AutoHashMapUnmanaged(Air.Inst.Index, void) = .{}, /// The set of operands for which we have seen at least one usage but not their birth. 
@@ -330,6 +332,7 @@ pub fn categorizeOperand( .trap, .breakpoint, .repeat, + .switch_dispatch, .dbg_stmt, .unreach, .ret_addr, @@ -661,21 +664,15 @@ pub fn categorizeOperand( return .complex; }, - .@"try" => { - return .complex; - }, - .try_ptr => { - return .complex; - }, - .loop => { - return .complex; - }, - .cond_br => { - return .complex; - }, - .switch_br => { - return .complex; - }, + + .@"try", + .try_ptr, + .loop, + .cond_br, + .switch_br, + .loop_switch_br, + => return .complex, + .wasm_memory_grow => { const pl_op = air_datas[@intFromEnum(inst)].pl_op; if (pl_op.operand == operand_ref) return matchOperandSmallIndex(l, inst, 0, .none); @@ -1204,6 +1201,7 @@ fn analyzeInst( .br => return analyzeInstBr(a, pass, data, inst), .repeat => return analyzeInstRepeat(a, pass, data, inst), + .switch_dispatch => return analyzeInstSwitchDispatch(a, pass, data, inst), .assembly => { const extra = a.air.extraData(Air.Asm, inst_datas[@intFromEnum(inst)].ty_pl.payload); @@ -1260,7 +1258,8 @@ fn analyzeInst( .@"try" => return analyzeInstCondBr(a, pass, data, inst, .@"try"), .try_ptr => return analyzeInstCondBr(a, pass, data, inst, .try_ptr), .cond_br => return analyzeInstCondBr(a, pass, data, inst, .cond_br), - .switch_br => return analyzeInstSwitchBr(a, pass, data, inst), + .switch_br => return analyzeInstSwitchBr(a, pass, data, inst, false), + .loop_switch_br => return analyzeInstSwitchBr(a, pass, data, inst, true), .wasm_memory_grow => { const pl_op = inst_datas[@intFromEnum(inst)].pl_op; @@ -1410,6 +1409,35 @@ fn analyzeInstRepeat( return analyzeOperands(a, pass, data, inst, .{ .none, .none, .none }); } +fn analyzeInstSwitchDispatch( + a: *Analysis, + comptime pass: LivenessPass, + data: *LivenessPassData(pass), + inst: Air.Inst.Index, +) !void { + // This happens to be identical to `analyzeInstBr`, but is separated anyway for clarity. 
+ + const inst_datas = a.air.instructions.items(.data); + const br = inst_datas[@intFromEnum(inst)].br; + const gpa = a.gpa; + + switch (pass) { + .loop_analysis => { + try data.breaks.put(gpa, br.block_inst, {}); + }, + + .main_analysis => { + const block_scope = data.block_scopes.get(br.block_inst).?; // we should always be repeating an enclosing loop + + const new_live_set = try block_scope.live_set.clone(gpa); + data.live_set.deinit(gpa); + data.live_set = new_live_set; + }, + } + + return analyzeOperands(a, pass, data, inst, .{ br.operand, .none, .none }); +} + fn analyzeInstBlock( a: *Analysis, comptime pass: LivenessPass, @@ -1480,109 +1508,133 @@ fn analyzeInstBlock( } } -fn analyzeInstLoop( +fn writeLoopInfo( a: *Analysis, - comptime pass: LivenessPass, - data: *LivenessPassData(pass), + data: *LivenessPassData(.loop_analysis), inst: Air.Inst.Index, + old_breaks: std.AutoHashMapUnmanaged(Air.Inst.Index, void), + old_live: std.AutoHashMapUnmanaged(Air.Inst.Index, void), ) !void { - const inst_datas = a.air.instructions.items(.data); - const extra = a.air.extraData(Air.Block, inst_datas[@intFromEnum(inst)].ty_pl.payload); - const body: []const Air.Inst.Index = @ptrCast(a.air.extra[extra.end..][0..extra.data.body_len]); const gpa = a.gpa; - try analyzeOperands(a, pass, data, inst, .{ .none, .none, .none }); + // `loop`s are guaranteed to have at least one matching `repeat`. + // Similarly, `loop_switch_br`s have a matching `switch_dispatch`. + // However, we no longer care about repeats of this loop for resolving + // which operands must live within it. 
+ assert(data.breaks.remove(inst)); - switch (pass) { - .loop_analysis => { - var old_breaks = data.breaks.move(); - defer old_breaks.deinit(gpa); + const extra_index: u32 = @intCast(a.extra.items.len); - var old_live = data.live_set.move(); - defer old_live.deinit(gpa); + const num_breaks = data.breaks.count(); + try a.extra.ensureUnusedCapacity(gpa, 1 + num_breaks); - try analyzeBody(a, pass, data, body); + a.extra.appendAssumeCapacity(num_breaks); - // `loop`s are guaranteed to have at least one matching `repeat`. - // However, we no longer care about repeats of this loop itself. - assert(data.breaks.remove(inst)); + var it = data.breaks.keyIterator(); + while (it.next()) |key| { + const block_inst = key.*; + a.extra.appendAssumeCapacity(@intFromEnum(block_inst)); + } + log.debug("[{}] %{}: includes breaks to {}", .{ LivenessPass.loop_analysis, inst, fmtInstSet(&data.breaks) }); - const extra_index: u32 = @intCast(a.extra.items.len); + // Now we put the live operands from the loop body in too + const num_live = data.live_set.count(); + try a.extra.ensureUnusedCapacity(gpa, 1 + num_live); - const num_breaks = data.breaks.count(); - try a.extra.ensureUnusedCapacity(gpa, 1 + num_breaks); + a.extra.appendAssumeCapacity(num_live); + it = data.live_set.keyIterator(); + while (it.next()) |key| { + const alive = key.*; + a.extra.appendAssumeCapacity(@intFromEnum(alive)); + } + log.debug("[{}] %{}: maintain liveness of {}", .{ LivenessPass.loop_analysis, inst, fmtInstSet(&data.live_set) }); - a.extra.appendAssumeCapacity(num_breaks); + try a.special.put(gpa, inst, extra_index); - var it = data.breaks.keyIterator(); - while (it.next()) |key| { - const block_inst = key.*; - a.extra.appendAssumeCapacity(@intFromEnum(block_inst)); - } - log.debug("[{}] %{}: includes breaks to {}", .{ pass, inst, fmtInstSet(&data.breaks) }); + // Add back operands which were previously alive + it = old_live.keyIterator(); + while (it.next()) |key| { + const alive = key.*; + try 
data.live_set.put(gpa, alive, {}); + } - // Now we put the live operands from the loop body in too - const num_live = data.live_set.count(); - try a.extra.ensureUnusedCapacity(gpa, 1 + num_live); + // And the same for breaks + it = old_breaks.keyIterator(); + while (it.next()) |key| { + const block_inst = key.*; + try data.breaks.put(gpa, block_inst, {}); + } +} - a.extra.appendAssumeCapacity(num_live); - it = data.live_set.keyIterator(); - while (it.next()) |key| { - const alive = key.*; - a.extra.appendAssumeCapacity(@intFromEnum(alive)); - } - log.debug("[{}] %{}: maintain liveness of {}", .{ pass, inst, fmtInstSet(&data.live_set) }); +/// When analyzing a loop in the main pass, sets up `data.live_set` to be the set +/// of operands known to be alive when the loop repeats. +fn resolveLoopLiveSet( + a: *Analysis, + data: *LivenessPassData(.main_analysis), + inst: Air.Inst.Index, +) !void { + const gpa = a.gpa; - try a.special.put(gpa, inst, extra_index); + const extra_idx = a.special.fetchRemove(inst).?.value; + const num_breaks = data.old_extra.items[extra_idx]; + const breaks: []const Air.Inst.Index = @ptrCast(data.old_extra.items[extra_idx + 1 ..][0..num_breaks]); - // Add back operands which were previously alive - it = old_live.keyIterator(); - while (it.next()) |key| { - const alive = key.*; - try data.live_set.put(gpa, alive, {}); - } + const num_loop_live = data.old_extra.items[extra_idx + num_breaks + 1]; + const loop_live: []const Air.Inst.Index = @ptrCast(data.old_extra.items[extra_idx + num_breaks + 2 ..][0..num_loop_live]); - // And the same for breaks - it = old_breaks.keyIterator(); - while (it.next()) |key| { - const block_inst = key.*; - try data.breaks.put(gpa, block_inst, {}); - } - }, + // This is necessarily not in the same control flow branch, because loops are noreturn + data.live_set.clearRetainingCapacity(); - .main_analysis => { - const extra_idx = a.special.fetchRemove(inst).?.value; // remove because this data does not exist after 
analysis + try data.live_set.ensureUnusedCapacity(gpa, @intCast(loop_live.len)); + for (loop_live) |alive| data.live_set.putAssumeCapacity(alive, {}); - const num_breaks = data.old_extra.items[extra_idx]; - const breaks: []const Air.Inst.Index = @ptrCast(data.old_extra.items[extra_idx + 1 ..][0..num_breaks]); + log.debug("[{}] %{}: block live set is {}", .{ LivenessPass.main_analysis, inst, fmtInstSet(&data.live_set) }); - const num_loop_live = data.old_extra.items[extra_idx + num_breaks + 1]; - const loop_live: []const Air.Inst.Index = @ptrCast(data.old_extra.items[extra_idx + num_breaks + 2 ..][0..num_loop_live]); + for (breaks) |block_inst| { + // We might break to this block, so include every operand that the block needs alive + const block_scope = data.block_scopes.get(block_inst).?; - // This is necessarily not in the same control flow branch, because loops are noreturn - data.live_set.clearRetainingCapacity(); + var it = block_scope.live_set.keyIterator(); + while (it.next()) |key| { + const alive = key.*; + try data.live_set.put(gpa, alive, {}); + } + } - try data.live_set.ensureUnusedCapacity(gpa, @intCast(loop_live.len)); - for (loop_live) |alive| { - data.live_set.putAssumeCapacity(alive, {}); - } + log.debug("[{}] %{}: loop live set is {}", .{ LivenessPass.main_analysis, inst, fmtInstSet(&data.live_set) }); +} - log.debug("[{}] %{}: block live set is {}", .{ pass, inst, fmtInstSet(&data.live_set) }); +fn analyzeInstLoop( + a: *Analysis, + comptime pass: LivenessPass, + data: *LivenessPassData(pass), + inst: Air.Inst.Index, +) !void { + const inst_datas = a.air.instructions.items(.data); + const extra = a.air.extraData(Air.Block, inst_datas[@intFromEnum(inst)].ty_pl.payload); + const body: []const Air.Inst.Index = @ptrCast(a.air.extra[extra.end..][0..extra.data.body_len]); + const gpa = a.gpa; - for (breaks) |block_inst| { - // We might break to this block, so include every operand that the block needs alive - const block_scope = 
data.block_scopes.get(block_inst).?; + try analyzeOperands(a, pass, data, inst, .{ .none, .none, .none }); - var it = block_scope.live_set.keyIterator(); - while (it.next()) |key| { - const alive = key.*; - try data.live_set.put(gpa, alive, {}); - } - } + switch (pass) { + .loop_analysis => { + var old_breaks = data.breaks.move(); + defer old_breaks.deinit(gpa); + + var old_live = data.live_set.move(); + defer old_live.deinit(gpa); + + try analyzeBody(a, pass, data, body); + + try writeLoopInfo(a, data, inst, old_breaks, old_live); + }, + + .main_analysis => { + try resolveLoopLiveSet(a, data, inst); // Now, `data.live_set` is the operands which must be alive when the loop repeats. // Move them into a block scope for corresponding `repeat` instructions to notice. - log.debug("[{}] %{}: loop live set is {}", .{ pass, inst, fmtInstSet(&data.live_set) }); try data.block_scopes.putNoClobber(gpa, inst, .{ .live_set = data.live_set.move(), }); @@ -1718,6 +1770,7 @@ fn analyzeInstSwitchBr( comptime pass: LivenessPass, data: *LivenessPassData(pass), inst: Air.Inst.Index, + is_dispatch_loop: bool, ) !void { const inst_datas = a.air.instructions.items(.data); const pl_op = inst_datas[@intFromEnum(inst)].pl_op; @@ -1728,6 +1781,17 @@ fn analyzeInstSwitchBr( switch (pass) { .loop_analysis => { + var old_breaks: std.AutoHashMapUnmanaged(Air.Inst.Index, void) = .{}; + defer old_breaks.deinit(gpa); + + var old_live: std.AutoHashMapUnmanaged(Air.Inst.Index, void) = .{}; + defer old_live.deinit(gpa); + + if (is_dispatch_loop) { + old_breaks = data.breaks.move(); + old_live = data.live_set.move(); + } + var air_extra_index: usize = switch_br.end; for (0..ncases) |_| { const case = a.air.extraData(Air.SwitchBr.Case, air_extra_index); @@ -1740,9 +1804,24 @@ fn analyzeInstSwitchBr( const else_body: []const Air.Inst.Index = @ptrCast(a.air.extra[air_extra_index..][0..switch_br.data.else_body_len]); try analyzeBody(a, pass, data, else_body); } + + if (is_dispatch_loop) { + try 
writeLoopInfo(a, data, inst, old_breaks, old_live); + } }, .main_analysis => { + if (is_dispatch_loop) { + try resolveLoopLiveSet(a, data, inst); + try data.block_scopes.putNoClobber(gpa, inst, .{ + .live_set = data.live_set.move(), + }); + } + defer if (is_dispatch_loop) { + log.debug("[{}] %{}: popped loop block scop", .{ pass, inst }); + var scope = data.block_scopes.fetchRemove(inst).?.value; + scope.live_set.deinit(gpa); + }; // This is, all in all, just a messier version of the `cond_br` logic. If you're trying // to understand it, I encourage looking at `analyzeInstCondBr` first. diff --git a/src/Liveness/Verify.zig b/src/Liveness/Verify.zig index b4d150b64574..bfb4b8c370bc 100644 --- a/src/Liveness/Verify.zig +++ b/src/Liveness/Verify.zig @@ -446,6 +446,16 @@ fn verifyBody(self: *Verify, body: []const Air.Inst.Index) Error!void { try self.verifyMatchingLiveness(repeat.loop_inst, expected_live); }, + .switch_dispatch => { + const br = data[@intFromEnum(inst)].br; + + try self.verifyOperand(inst, br.operand, self.liveness.operandDies(inst, 0)); + + const expected_live = self.loops.get(br.block_inst) orelse + return invalid("%{}: loop %{} not in scope", .{ @intFromEnum(inst), @intFromEnum(br.block_inst) }); + + try self.verifyMatchingLiveness(br.block_inst, expected_live); + }, .block, .dbg_inline_block => |tag| { const ty_pl = data[@intFromEnum(inst)].ty_pl; const block_ty = ty_pl.ty.toType(); @@ -493,11 +503,11 @@ fn verifyBody(self: *Verify, body: []const Air.Inst.Index) Error!void { // The same stuff should be alive after the loop as before it. 
const gop = try self.loops.getOrPut(self.gpa, inst); + if (gop.found_existing) return invalid("%{}: loop already exists", .{@intFromEnum(inst)}); defer { var live = self.loops.fetchRemove(inst).?; live.value.deinit(self.gpa); } - if (gop.found_existing) return invalid("%{}: loop already exists", .{@intFromEnum(inst)}); gop.value_ptr.* = try self.live.clone(self.gpa); try self.verifyBody(loop_body); @@ -527,7 +537,7 @@ fn verifyBody(self: *Verify, body: []const Air.Inst.Index) Error!void { try self.verifyInst(inst); }, - .switch_br => { + .switch_br, .loop_switch_br => { const pl_op = data[@intFromEnum(inst)].pl_op; const switch_br = self.air.extraData(Air.SwitchBr, pl_op.payload); var extra_index = switch_br.end; @@ -541,8 +551,17 @@ fn verifyBody(self: *Verify, body: []const Air.Inst.Index) Error!void { try self.verifyOperand(inst, pl_op.operand, self.liveness.operandDies(inst, 0)); - var live = self.live.move(); - defer live.deinit(self.gpa); + // Excluding the operand (which we just handled), the same stuff should be alive + // after the loop as before it. 
+ { + const gop = try self.loops.getOrPut(self.gpa, inst); + if (gop.found_existing) return invalid("%{}: loop already exists", .{@intFromEnum(inst)}); + gop.value_ptr.* = self.live.move(); + } + defer { + var live = self.loops.fetchRemove(inst).?; + live.value.deinit(self.gpa); + } while (case_i < switch_br.data.cases_len) : (case_i += 1) { const case = self.air.extraData(Air.SwitchBr.Case, extra_index); @@ -551,7 +570,7 @@ fn verifyBody(self: *Verify, body: []const Air.Inst.Index) Error!void { extra_index += case_body.len; self.live.deinit(self.gpa); - self.live = try live.clone(self.gpa); + self.live = try self.loops.get(inst).?.clone(self.gpa); for (switch_br_liveness.deaths[case_i]) |death| try self.verifyDeath(inst, death); try self.verifyBody(case_body); @@ -560,7 +579,7 @@ fn verifyBody(self: *Verify, body: []const Air.Inst.Index) Error!void { const else_body: []const Air.Inst.Index = @ptrCast(self.air.extra[extra_index..][0..switch_br.data.else_body_len]); if (else_body.len > 0) { self.live.deinit(self.gpa); - self.live = try live.clone(self.gpa); + self.live = try self.loops.get(inst).?.clone(self.gpa); for (switch_br_liveness.deaths[case_i]) |death| try self.verifyDeath(inst, death); try self.verifyBody(else_body); diff --git a/src/Sema.zig b/src/Sema.zig index 33e5e87e7c96..ffb197364fe9 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -480,11 +480,21 @@ pub const Block = struct { /// to enable more precise compile errors. /// Same indexes, capacity, length as `results`. src_locs: std.ArrayListUnmanaged(?LazySrcLoc), - - pub fn deinit(merges: *@This(), allocator: mem.Allocator) void { + /// Most blocks do not utilize this field. When it is used, its use is + /// contextual. The possible uses are as follows: + /// * for a `switch_block[_ref]`, this refers to dummy `br` instructions + /// which correspond to `switch_continue` ZIR. The switch logic will + /// rewrite these to appropriate AIR switch dispatches. 
+ extra_insts: std.ArrayListUnmanaged(Air.Inst.Index) = .{}, + /// Same indexes, capacity, length as `extra_insts`. + extra_src_locs: std.ArrayListUnmanaged(LazySrcLoc) = .{}, + + pub fn deinit(merges: *@This(), allocator: Allocator) void { merges.results.deinit(allocator); merges.br_list.deinit(allocator); merges.src_locs.deinit(allocator); + merges.extra_insts.deinit(allocator); + merges.extra_src_locs.deinit(allocator); } }; @@ -913,14 +923,21 @@ fn analyzeInlineBody( error.ComptimeBreak => {}, else => |e| return e, } - const break_inst = sema.comptime_break_inst; - const break_data = sema.code.instructions.items(.data)[@intFromEnum(break_inst)].@"break"; - const extra = sema.code.extraData(Zir.Inst.Break, break_data.payload_index).data; + const break_inst = sema.code.instructions.get(@intFromEnum(sema.comptime_break_inst)); + switch (break_inst.tag) { + .switch_continue => { + // This is handled by separate logic. + return error.ComptimeBreak; + }, + .break_inline, .@"break" => {}, + else => unreachable, + } + const extra = sema.code.extraData(Zir.Inst.Break, break_inst.data.@"break".payload_index).data; if (extra.block_inst != break_target) { // This control flow goes further up the stack. return error.ComptimeBreak; } - return try sema.resolveInst(break_data.operand); + return try sema.resolveInst(break_inst.data.@"break".operand); } /// Like `analyzeInlineBody`, but if the body does not break with a value, returns @@ -1532,6 +1549,13 @@ fn analyzeBodyInner( i = 0; continue; }, + .switch_continue => if (block.is_comptime) { + sema.comptime_break_inst = inst; + return error.ComptimeBreak; + } else { + try sema.zirSwitchContinue(block, inst); + break; + }, .loop => blk: { if (!block.is_comptime) break :blk try sema.zirLoop(block, inst); // Same as `block_inline`. 
TODO https://github.com/ziglang/zig/issues/8220 @@ -6707,6 +6731,56 @@ fn zirBreak(sema: *Sema, start_block: *Block, inst: Zir.Inst.Index) CompileError } } +fn zirSwitchContinue(sema: *Sema, start_block: *Block, inst: Zir.Inst.Index) CompileError!void { + const tracy = trace(@src()); + defer tracy.end(); + + const inst_data = sema.code.instructions.items(.data)[@intFromEnum(inst)].@"break"; + const extra = sema.code.extraData(Zir.Inst.Break, inst_data.payload_index).data; + assert(extra.operand_src_node != Zir.Inst.Break.no_src_node); + const operand_src = LazySrcLoc.nodeOffset(extra.operand_src_node); + const uncoerced_operand = try sema.resolveInst(inst_data.operand); + const switch_inst = extra.block_inst; + + switch (sema.code.instructions.items(.tag)[@intFromEnum(switch_inst)]) { + .switch_block, .switch_block_ref => {}, + else => unreachable, // assertion failure + } + + const switch_payload_index = sema.code.instructions.items(.data)[@intFromEnum(switch_inst)].pl_node.payload_index; + const switch_operand_ref = sema.code.extraData(Zir.Inst.SwitchBlock, switch_payload_index).data.operand; + const switch_operand_ty = sema.typeOf(try sema.resolveInst(switch_operand_ref)); + + const operand = try sema.coerce(start_block, switch_operand_ty, uncoerced_operand, operand_src); + + try sema.validateRuntimeValue(start_block, operand_src, operand); + + // We want to generate a `switch_dispatch` instruction with the switch condition, + // possibly preceded by a store to the stack alloc containing the raw operand. + // However, to avoid too much special-case state in Sema, this is handled by the + // `switch` lowering logic. As such, we will find the `Block` corresponding to the + // parent `switch_block[_ref]` instruction, create a dummy `br`, and add a merge + // to signal to the switch logic to rewrite this into an appropriate dispatch. 
+ + var block = start_block; + while (true) { + if (block.label) |label| { + if (label.zir_block == switch_inst) { + const br_ref = try start_block.addBr(label.merges.block_inst, operand); + try label.merges.extra_insts.append(sema.gpa, br_ref.toIndex().?); + try label.merges.extra_src_locs.append(sema.gpa, operand_src); + block.runtime_index.increment(); + if (block.runtime_cond == null and block.runtime_loop == null) { + block.runtime_cond = start_block.runtime_cond orelse start_block.runtime_loop; + block.runtime_loop = start_block.runtime_loop; + } + return; + } + } + block = block.parent.?; + } +} + fn zirDbgStmt(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!void { if (block.is_comptime or block.ownerModule().strip) return; @@ -11003,12 +11077,7 @@ const SwitchProngAnalysis = struct { sema: *Sema, /// The block containing the `switch_block` itself. parent_block: *Block, - /// The raw switch operand value (*not* the condition). Always defined. - operand: Air.Inst.Ref, - /// May be `undefined` if no prong has a by-ref capture. - operand_ptr: Air.Inst.Ref, - /// The switch condition value. For unions, `operand` is the union and `cond` is its tag. - cond: Air.Inst.Ref, + operand: Operand, /// If this switch is on an error set, this is the type to assign to the /// `else` prong. If `null`, the prong should be unreachable. else_error_ty: ?Type, @@ -11018,6 +11087,34 @@ const SwitchProngAnalysis = struct { /// undefined if no prong has a tag capture. tag_capture_inst: Zir.Inst.Index, + const Operand = union(enum) { + /// This switch will be dispatched only once, with the given operand. + simple: struct { + /// The raw switch operand value. Always defined. + by_val: Air.Inst.Ref, + /// The switch operand *pointer*. Defined only if there is a prong + /// with a by-ref capture. + by_ref: Air.Inst.Ref, + /// The switch condition value. For unions, `operand` is the union + /// and `cond` is its enum tag value. 
+ cond: Air.Inst.Ref, + }, + /// This switch may be dispatched multiple times with `continue` syntax. + /// As such, the operand is stored in an alloc if needed. + loop: struct { + /// The `alloc` containing the `switch` operand for the active dispatch. + /// Each prong must load from this `alloc` to get captures. + /// If there are no captures, this may be undefined. + operand_alloc: Air.Inst.Ref, + /// Whether `operand_alloc` contains a by-val operand or a by-ref + /// operand. + operand_is_ref: bool, + /// The switch condition value for the *initial* dispatch. For + /// unions, this is the enum tag value. + init_cond: Air.Inst.Ref, + }, + }; + /// Resolve a switch prong which is determined at comptime to have no peers. /// Uses `resolveBlockBody`. Sets up captures as needed. fn resolveProngComptime( @@ -11140,7 +11237,15 @@ const SwitchProngAnalysis = struct { ) CompileError!Air.Inst.Ref { const sema = spa.sema; const mod = sema.mod; - const operand_ty = sema.typeOf(spa.operand); + const operand_ty = switch (spa.operand) { + .simple => |s| sema.typeOf(s.by_val), + .loop => |l| ty: { + const alloc_ty = sema.typeOf(l.operand_alloc); + const alloc_child = alloc_ty.childType(mod); + if (l.operand_is_ref) break :ty alloc_child.childType(mod); + break :ty alloc_child; + }, + }; if (operand_ty.zigTypeTag(mod) != .Union) { const zir_datas = sema.code.instructions.items(.data); const switch_node_offset = zir_datas[@intFromEnum(spa.switch_block_inst)].pl_node.src_node; @@ -11175,10 +11280,24 @@ const SwitchProngAnalysis = struct { const zir_datas = sema.code.instructions.items(.data); const switch_node_offset = zir_datas[@intFromEnum(spa.switch_block_inst)].pl_node.src_node; - const operand_ty = sema.typeOf(spa.operand); - const operand_ptr_ty = if (capture_byref) sema.typeOf(spa.operand_ptr) else undefined; const operand_src: LazySrcLoc = .{ .node_offset_switch_operand = switch_node_offset }; + const operand_val, const operand_ptr = switch (spa.operand) { + .simple => 
|s| .{ s.by_val, s.by_ref }, + .loop => |l| op: { + const loaded = try sema.analyzeLoad(block, operand_src, l.operand_alloc, operand_src); + if (l.operand_is_ref) { + const by_val = try sema.analyzeLoad(block, operand_src, loaded, operand_src); + break :op .{ by_val, loaded }; + } else { + break :op .{ loaded, undefined }; + } + }, + }; + + const operand_ty = sema.typeOf(operand_val); + const operand_ptr_ty = if (capture_byref) sema.typeOf(operand_ptr) else undefined; + if (inline_case_capture != .none) { const item_val = sema.resolveConstDefinedValue(block, .unneeded, inline_case_capture, undefined) catch unreachable; if (operand_ty.zigTypeTag(zcu) == .Union) { @@ -11194,16 +11313,16 @@ const SwitchProngAnalysis = struct { .address_space = operand_ptr_ty.ptrAddressSpace(zcu), }, }); - if (try sema.resolveDefinedValue(block, operand_src, spa.operand_ptr)) |union_ptr| { + if (try sema.resolveDefinedValue(block, operand_src, operand_ptr)) |union_ptr| { return Air.internedToRef((try union_ptr.ptrField(field_index, sema)).toIntern()); } - return block.addStructFieldPtr(spa.operand_ptr, field_index, ptr_field_ty); + return block.addStructFieldPtr(operand_ptr, field_index, ptr_field_ty); } else { - if (try sema.resolveDefinedValue(block, operand_src, spa.operand)) |union_val| { + if (try sema.resolveDefinedValue(block, operand_src, operand_val)) |union_val| { const tag_and_val = ip.indexToKey(union_val.toIntern()).un; return Air.internedToRef(tag_and_val.val); } - return block.addStructFieldVal(spa.operand, field_index, field_ty); + return block.addStructFieldVal(operand_val, field_index, field_ty); } } else if (capture_byref) { return anonDeclRef(sema, item_val.toIntern()); @@ -11214,17 +11333,17 @@ const SwitchProngAnalysis = struct { if (is_special_prong) { if (capture_byref) { - return spa.operand_ptr; + return operand_ptr; } switch (operand_ty.zigTypeTag(zcu)) { .ErrorSet => if (spa.else_error_ty) |ty| { - return sema.bitCast(block, ty, spa.operand, operand_src, 
null); + return sema.bitCast(block, ty, operand_val, operand_src, null); } else { try block.addUnreachable(operand_src, false); return .unreachable_value; }, - else => return spa.operand, + else => return operand_val, } } @@ -11333,19 +11452,19 @@ const SwitchProngAnalysis = struct { }; }; - if (try sema.resolveDefinedValue(block, operand_src, spa.operand_ptr)) |op_ptr_val| { + if (try sema.resolveDefinedValue(block, operand_src, operand_ptr)) |op_ptr_val| { if (op_ptr_val.isUndef(zcu)) return zcu.undefRef(capture_ptr_ty); const field_ptr_val = try op_ptr_val.ptrField(first_field_index, sema); return Air.internedToRef((try zcu.getCoerced(field_ptr_val, capture_ptr_ty)).toIntern()); } try sema.requireRuntimeBlock(block, operand_src, null); - return block.addStructFieldPtr(spa.operand_ptr, first_field_index, capture_ptr_ty); + return block.addStructFieldPtr(operand_ptr, first_field_index, capture_ptr_ty); } - if (try sema.resolveDefinedValue(block, operand_src, spa.operand)) |operand_val| { - if (operand_val.isUndef(zcu)) return zcu.undefRef(capture_ty); - const union_val = ip.indexToKey(operand_val.toIntern()).un; + if (try sema.resolveDefinedValue(block, operand_src, operand_val)) |operand_val_val| { + if (operand_val_val.isUndef(zcu)) return zcu.undefRef(capture_ty); + const union_val = ip.indexToKey(operand_val_val.toIntern()).un; if (Value.fromInterned(union_val.tag).isUndef(zcu)) return zcu.undefRef(capture_ty); const uncoerced = Air.internedToRef(union_val.val); return sema.coerce(block, capture_ty, uncoerced, operand_src); @@ -11354,7 +11473,7 @@ const SwitchProngAnalysis = struct { try sema.requireRuntimeBlock(block, operand_src, null); if (same_types) { - return block.addStructFieldVal(spa.operand, first_field_index, capture_ty); + return block.addStructFieldVal(operand_val, first_field_index, capture_ty); } // We may have to emit a switch block which coerces the operand to the capture type. 
@@ -11368,7 +11487,7 @@ const SwitchProngAnalysis = struct { } // All fields are in-memory coercible to the resolved type! // Just take the first field and bitcast the result. - const uncoerced = try block.addStructFieldVal(spa.operand, first_field_index, first_field_ty); + const uncoerced = try block.addStructFieldVal(operand_val, first_field_index, first_field_ty); return block.addBitCast(capture_ty, uncoerced); }; @@ -11416,7 +11535,7 @@ const SwitchProngAnalysis = struct { const field_idx = field_indices[idx]; const field_ty = Type.fromInterned(union_obj.field_types.get(ip)[field_idx]); - const uncoerced = try coerce_block.addStructFieldVal(spa.operand, field_idx, field_ty); + const uncoerced = try coerce_block.addStructFieldVal(operand_val, field_idx, field_ty); const coerced = sema.coerce(&coerce_block, capture_ty, uncoerced, .unneeded) catch |err| switch (err) { error.NeededSourceLocation => { const multi_idx = raw_capture_src.multi_capture; @@ -11450,7 +11569,7 @@ const SwitchProngAnalysis = struct { const first_imc_item_idx = in_mem_coercible.findFirstSet().?; const first_imc_field_idx = field_indices[first_imc_item_idx]; const first_imc_field_ty = Type.fromInterned(union_obj.field_types.get(ip)[first_imc_field_idx]); - const uncoerced = try coerce_block.addStructFieldVal(spa.operand, first_imc_field_idx, first_imc_field_ty); + const uncoerced = try coerce_block.addStructFieldVal(operand_val, first_imc_field_idx, first_imc_field_ty); const coerced = try coerce_block.addBitCast(capture_ty, uncoerced); _ = try coerce_block.addBr(capture_block_inst, coerced); @@ -11466,21 +11585,47 @@ const SwitchProngAnalysis = struct { const switch_br_inst: u32 = @intCast(sema.air_instructions.len); try sema.air_instructions.append(sema.gpa, .{ .tag = .switch_br, - .data = .{ .pl_op = .{ - .operand = spa.cond, - .payload = sema.addExtraAssumeCapacity(Air.SwitchBr{ - .cases_len = @intCast(prong_count), - .else_body_len = @intCast(else_body_len), - }), - } }, + .data = .{ + 
.pl_op = .{ + .operand = undefined, // set by switch below + .payload = sema.addExtraAssumeCapacity(Air.SwitchBr{ + .cases_len = @intCast(prong_count), + .else_body_len = @intCast(else_body_len), + }), + }, + }, }); sema.air_extra.appendSliceAssumeCapacity(cases_extra.items); // Set up block body - sema.air_instructions.items(.data)[@intFromEnum(capture_block_inst)].ty_pl.payload = sema.addExtraAssumeCapacity(Air.Block{ - .body_len = 1, - }); - sema.air_extra.appendAssumeCapacity(switch_br_inst); + switch (spa.operand) { + .simple => |s| { + const air_datas = sema.air_instructions.items(.data); + air_datas[switch_br_inst].pl_op.operand = s.cond; + air_datas[@intFromEnum(capture_block_inst)].ty_pl.payload = sema.addExtraAssumeCapacity(Air.Block{ + .body_len = 1, + }); + sema.air_extra.appendAssumeCapacity(switch_br_inst); + }, + .loop => { + // The block must first extract the tag from the loaded union. + const tag_inst: Air.Inst.Index = @enumFromInt(sema.air_instructions.len); + try sema.air_instructions.append(sema.gpa, .{ + .tag = .get_union_tag, + .data = .{ .ty_op = .{ + .ty = Air.internedToRef(union_obj.enum_tag_ty), + .operand = operand_val, + } }, + }); + const air_datas = sema.air_instructions.items(.data); + air_datas[switch_br_inst].pl_op.operand = tag_inst.toRef(); + air_datas[@intFromEnum(capture_block_inst)].ty_pl.payload = sema.addExtraAssumeCapacity(Air.Block{ + .body_len = 2, + }); + sema.air_extra.appendAssumeCapacity(@intFromEnum(tag_inst)); + sema.air_extra.appendAssumeCapacity(switch_br_inst); + }, + } return capture_block_inst.toRef(); }, @@ -11498,7 +11643,7 @@ const SwitchProngAnalysis = struct { if (case_vals.len == 1) { const item_val = sema.resolveConstDefinedValue(block, .unneeded, case_vals[0], undefined) catch unreachable; const item_ty = try zcu.singleErrorSetType(item_val.getErrorName(zcu).unwrap().?); - return sema.bitCast(block, item_ty, spa.operand, operand_src, null); + return sema.bitCast(block, item_ty, operand_val, operand_src, 
null); } var names: InferredErrorSet.NameMap = .{}; @@ -11508,15 +11653,15 @@ const SwitchProngAnalysis = struct { names.putAssumeCapacityNoClobber(err_val.getErrorName(zcu).unwrap().?, {}); } const error_ty = try zcu.errorSetFromUnsortedNames(names.keys()); - return sema.bitCast(block, error_ty, spa.operand, operand_src, null); + return sema.bitCast(block, error_ty, operand_val, operand_src, null); }, else => { // In this case the capture value is just the passed-through value // of the switch condition. if (capture_byref) { - return spa.operand_ptr; + return operand_ptr; } else { - return spa.operand; + return operand_val; } }, } @@ -11746,9 +11891,13 @@ fn zirSwitchBlockErrUnion(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Comp var spa: SwitchProngAnalysis = .{ .sema = sema, .parent_block = block, - .operand = undefined, // must be set to the unwrapped error code before use - .operand_ptr = .none, - .cond = raw_operand_val, + .operand = .{ + .simple = .{ + .by_val = undefined, // must be set to the unwrapped error code before use + .by_ref = undefined, + .cond = raw_operand_val, + }, + }, .else_error_ty = else_error_ty, .switch_block_inst = inst, .tag_capture_inst = undefined, @@ -11769,13 +11918,13 @@ fn zirSwitchBlockErrUnion(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Comp .name = operand_val.getErrorName(mod).unwrap().?, }, })); - spa.operand = if (extra.data.bits.payload_is_ref) + spa.operand.simple.by_val = if (extra.data.bits.payload_is_ref) try sema.analyzeErrUnionCodePtr(block, switch_operand_src, raw_operand_val) else try sema.analyzeErrUnionCode(block, switch_operand_src, raw_operand_val); if (extra.data.bits.any_uses_err_capture) { - sema.inst_map.putAssumeCapacity(err_capture_inst, spa.operand); + sema.inst_map.putAssumeCapacity(err_capture_inst, spa.operand.simple.by_val); } defer if (extra.data.bits.any_uses_err_capture) assert(sema.inst_map.remove(err_capture_inst)); @@ -11783,7 +11932,7 @@ fn zirSwitchBlockErrUnion(sema: *Sema, block: 
*Block, inst: Zir.Inst.Index) Comp sema, spa, &child_block, - try sema.switchCond(block, switch_operand_src, spa.operand), + try sema.switchCond(block, switch_operand_src, spa.operand.simple.by_val), err_val, operand_err_set_ty, .{ @@ -11836,20 +11985,20 @@ fn zirSwitchBlockErrUnion(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Comp const true_instructions = try sub_block.instructions.toOwnedSlice(gpa); defer gpa.free(true_instructions); - spa.operand = if (extra.data.bits.payload_is_ref) + spa.operand.simple.by_val = if (extra.data.bits.payload_is_ref) try sema.analyzeErrUnionCodePtr(&sub_block, switch_operand_src, raw_operand_val) else try sema.analyzeErrUnionCode(&sub_block, switch_operand_src, raw_operand_val); if (extra.data.bits.any_uses_err_capture) { - sema.inst_map.putAssumeCapacity(err_capture_inst, spa.operand); + sema.inst_map.putAssumeCapacity(err_capture_inst, spa.operand.simple.by_val); } defer if (extra.data.bits.any_uses_err_capture) assert(sema.inst_map.remove(err_capture_inst)); _ = try sema.analyzeSwitchRuntimeBlock( spa, &sub_block, switch_src, - try sema.switchCond(block, switch_operand_src, spa.operand), + try sema.switchCond(block, switch_operand_src, spa.operand.simple.by_val), operand_err_set_ty, switch_operand_src, case_vals, @@ -11908,17 +12057,63 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r const special_prong_src: LazySrcLoc = .{ .node_offset_switch_special_prong = src_node_offset }; const extra = sema.code.extraData(Zir.Inst.SwitchBlock, inst_data.payload_index); - const raw_operand_val: Air.Inst.Ref, const raw_operand_ptr: Air.Inst.Ref = blk: { + const operand: SwitchProngAnalysis.Operand, const raw_operand_ty: Type = op: { const maybe_ptr = try sema.resolveInst(extra.data.operand); - if (operand_is_ref) { - const val = try sema.analyzeLoad(block, src, maybe_ptr, operand_src); - break :blk .{ val, maybe_ptr }; - } else { - break :blk .{ maybe_ptr, undefined }; + const val, const ref = if 
(operand_is_ref) + .{ try sema.analyzeLoad(block, src, maybe_ptr, operand_src), maybe_ptr } + else + .{ maybe_ptr, undefined }; + + const init_cond = try sema.switchCond(block, operand_src, val); + + const operand_ty = sema.typeOf(val); + + if (extra.data.bits.has_continue and !block.is_comptime) { + // Even if the operand is comptime-known, this `switch` is runtime. + if (try sema.typeRequiresComptime(operand_ty)) { + return sema.failWithOwnedErrorMsg(block, msg: { + const msg = try sema.errMsg(block, operand_src, "operand of switch loop has comptime-only type '{}'", .{operand_ty.fmt(mod)}); + errdefer msg.destroy(gpa); + try sema.errNote(block, operand_src, msg, "switch loops are evalauted at runtime outside of comptime scopes", .{}); + break :msg msg; + }); + } + try sema.validateRuntimeValue(block, operand_src, maybe_ptr); + const operand_alloc = if (extra.data.bits.any_non_inline_capture) a: { + const operand_ptr_ty = try mod.singleMutPtrType(sema.typeOf(maybe_ptr)); + const operand_alloc = try block.addTy(.alloc, operand_ptr_ty); + _ = try block.addBinOp(.store, operand_alloc, maybe_ptr); + break :a operand_alloc; + } else undefined; + break :op .{ + .{ .loop = .{ + .operand_alloc = operand_alloc, + .operand_is_ref = operand_is_ref, + .init_cond = init_cond, + } }, + operand_ty, + }; } + + // We always use `simple` in the comptime case, because as far as the dispatching logic + // is concerned, it really is dispatching a single prong. `resolveSwitchComptime` will + // be resposible for recursively resolving different prongs as needed. 
+ break :op .{ + .{ .simple = .{ + .by_val = val, + .by_ref = ref, + .cond = init_cond, + } }, + operand_ty, + }; }; - const operand = try sema.switchCond(block, operand_src, raw_operand_val); + const union_originally = raw_operand_ty.zigTypeTag(mod) == .Union; + const err_set = raw_operand_ty.zigTypeTag(mod) == .ErrorSet; + const cond_ty = switch (raw_operand_ty.zigTypeTag(mod)) { + .Union => raw_operand_ty.unionTagType(mod).?, // validated by `switchCond` above + else => raw_operand_ty, + }; // AstGen guarantees that the instruction immediately preceding // switch_block(_ref) is a dbg_stmt @@ -11968,9 +12163,6 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r }, }; - const maybe_union_ty = sema.typeOf(raw_operand_val); - const union_originally = maybe_union_ty.zigTypeTag(mod) == .Union; - // Duplicate checking variables later also used for `inline else`. var seen_enum_fields: []?Module.SwitchProngSrc = &.{}; var seen_errors = SwitchErrorSet.init(gpa); @@ -11986,13 +12178,10 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r var empty_enum = false; - const operand_ty = sema.typeOf(operand); - const err_set = operand_ty.zigTypeTag(mod) == .ErrorSet; - var else_error_ty: ?Type = null; // Validate usage of '_' prongs. - if (special_prong == .under and (!operand_ty.isNonexhaustiveEnum(mod) or union_originally)) { + if (special_prong == .under and !raw_operand_ty.isNonexhaustiveEnum(mod)) { const msg = msg: { const msg = try sema.errMsg( block, @@ -12021,11 +12210,11 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r } // Validate for duplicate items, missing else prong, and invalid range. 
- switch (operand_ty.zigTypeTag(mod)) { + switch (cond_ty.zigTypeTag(mod)) { .Union => unreachable, // handled in `switchCond` .Enum => { - seen_enum_fields = try gpa.alloc(?Module.SwitchProngSrc, operand_ty.enumFieldCount(mod)); - empty_enum = seen_enum_fields.len == 0 and !operand_ty.isNonexhaustiveEnum(mod); + seen_enum_fields = try gpa.alloc(?Module.SwitchProngSrc, cond_ty.enumFieldCount(mod)); + empty_enum = seen_enum_fields.len == 0 and !cond_ty.isNonexhaustiveEnum(mod); @memset(seen_enum_fields, null); // `range_set` is used for non-exhaustive enum values that do not correspond to any tags. @@ -12043,7 +12232,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r seen_enum_fields, &range_set, item_ref, - operand_ty, + cond_ty, src_node_offset, .{ .scalar = scalar_i }, )); @@ -12068,13 +12257,13 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r seen_enum_fields, &range_set, item_ref, - operand_ty, + cond_ty, src_node_offset, .{ .multi = .{ .prong = multi_i, .item = @intCast(item_i) } }, )); } - try sema.validateSwitchNoRange(block, ranges_len, operand_ty, src_node_offset); + try sema.validateSwitchNoRange(block, ranges_len, cond_ty, src_node_offset); } } const all_tags_handled = for (seen_enum_fields) |seen_src| { @@ -12082,7 +12271,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r } else true; if (special_prong == .@"else") { - if (all_tags_handled and !operand_ty.isNonexhaustiveEnum(mod)) return sema.fail( + if (all_tags_handled and !cond_ty.isNonexhaustiveEnum(mod)) return sema.fail( block, special_prong_src, "unreachable else prong; all cases already handled", @@ -12100,9 +12289,9 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r for (seen_enum_fields, 0..) 
|seen_src, i| { if (seen_src != null) continue; - const field_name = operand_ty.enumFieldName(i, mod); + const field_name = cond_ty.enumFieldName(i, mod); try sema.addFieldErrNote( - operand_ty, + cond_ty, i, msg, "unhandled enumeration value: '{}'", @@ -12110,15 +12299,15 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r ); } try mod.errNoteNonLazy( - operand_ty.declSrcLoc(mod), + cond_ty.declSrcLoc(mod), msg, "enum '{}' declared here", - .{operand_ty.fmt(mod)}, + .{cond_ty.fmt(mod)}, ); break :msg msg; }; return sema.failWithOwnedErrorMsg(block, msg); - } else if (special_prong == .none and operand_ty.isNonexhaustiveEnum(mod) and !union_originally) { + } else if (special_prong == .none and cond_ty.isNonexhaustiveEnum(mod) and !union_originally) { return sema.fail( block, src, @@ -12132,7 +12321,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r block, &seen_errors, &case_vals, - operand_ty, + cond_ty, inst_data, scalar_cases_len, multi_cases_len, @@ -12153,7 +12342,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r block, &range_set, item_ref, - operand_ty, + cond_ty, src_node_offset, .{ .scalar = scalar_i }, )); @@ -12177,7 +12366,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r block, &range_set, item_ref, - operand_ty, + cond_ty, src_node_offset, .{ .multi = .{ .prong = multi_i, .item = @intCast(item_i) } }, )); @@ -12196,7 +12385,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r &range_set, item_first, item_last, - operand_ty, + cond_ty, src_node_offset, .{ .range = .{ .prong = multi_i, .item = range_i } }, ); @@ -12209,9 +12398,9 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r } check_range: { - if (operand_ty.zigTypeTag(mod) == .Int) { - const min_int = try operand_ty.minInt(mod, operand_ty); - const max_int = try operand_ty.maxInt(mod, 
operand_ty); + if (cond_ty.zigTypeTag(mod) == .Int) { + const min_int = try cond_ty.minInt(mod, cond_ty); + const max_int = try cond_ty.maxInt(mod, cond_ty); if (try range_set.spans(min_int.toIntern(), max_int.toIntern())) { if (special_prong == .@"else") { return sema.fail( @@ -12278,7 +12467,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r )); } - try sema.validateSwitchNoRange(block, ranges_len, operand_ty, src_node_offset); + try sema.validateSwitchNoRange(block, ranges_len, cond_ty, src_node_offset); } } switch (special_prong) { @@ -12310,7 +12499,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r block, src, "else prong required when switching on type '{}'", - .{operand_ty.fmt(mod)}, + .{cond_ty.fmt(mod)}, ); } @@ -12331,7 +12520,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r block, &seen_values, item_ref, - operand_ty, + cond_ty, src_node_offset, .{ .scalar = scalar_i }, )); @@ -12355,13 +12544,13 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r block, &seen_values, item_ref, - operand_ty, + cond_ty, src_node_offset, .{ .multi = .{ .prong = multi_i, .item = @intCast(item_i) } }, )); } - try sema.validateSwitchNoRange(block, ranges_len, operand_ty, src_node_offset); + try sema.validateSwitchNoRange(block, ranges_len, cond_ty, src_node_offset); } } }, @@ -12380,16 +12569,14 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r .ComptimeFloat, .Float, => return sema.fail(block, operand_src, "invalid switch operand type '{}'", .{ - operand_ty.fmt(mod), + raw_operand_ty.fmt(mod), }), } const spa: SwitchProngAnalysis = .{ .sema = sema, .parent_block = block, - .operand = raw_operand_val, - .operand_ptr = raw_operand_ptr, - .cond = operand, + .operand = operand, .else_error_ty = else_error_ty, .switch_block_inst = inst, .tag_capture_inst = tag_capture_inst, @@ -12432,23 +12619,6 @@ fn 
zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r defer child_block.instructions.deinit(gpa); defer merges.deinit(gpa); - if (try sema.resolveDefinedValue(&child_block, src, operand)) |operand_val| { - return resolveSwitchComptime( - sema, - spa, - &child_block, - operand, - operand_val, - operand_ty, - special, - case_vals, - scalar_cases_len, - multi_cases_len, - err_set, - empty_enum, - ); - } - if (scalar_cases_len + multi_cases_len == 0 and !special.is_inline) { if (empty_enum) { return .void_value; @@ -12456,51 +12626,86 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r if (special_prong == .none) { return sema.fail(block, src, "switch must handle all possibilities", .{}); } - if (err_set and try sema.maybeErrorUnwrap(block, special.body, operand, operand_src, false)) { - return .unreachable_value; - } - if (mod.backendSupportsFeature(.is_named_enum_value) and block.wantSafety() and operand_ty.zigTypeTag(mod) == .Enum and - (!operand_ty.isNonexhaustiveEnum(mod) or union_originally)) + const init_cond = switch (operand) { + .simple => |s| s.cond, + .loop => |l| l.init_cond, + }; + if (mod.backendSupportsFeature(.is_named_enum_value) and block.wantSafety() and + raw_operand_ty.zigTypeTag(mod) == .Enum and !raw_operand_ty.isNonexhaustiveEnum(mod)) { try sema.zirDbgStmt(block, cond_dbg_node_index); - const ok = try block.addUnOp(.is_named_enum_value, operand); + const ok = try block.addUnOp(.is_named_enum_value, init_cond); try sema.addSafetyCheck(block, src, ok, .corrupt_switch); } + if (err_set and try sema.maybeErrorUnwrap(block, special.body, init_cond, operand_src, false)) { + return .unreachable_value; + } + } - return spa.resolveProngComptime( - &child_block, - .special, - special.body, - special.capture, - .special_capture, - undefined, // case_vals may be undefined for special prongs - .none, - false, - merges, - ); + switch (operand) { + .loop => {}, // always runtime; evaluation in comptime 
scope uses `simple` + .simple => |s| { + if (try sema.resolveDefinedValue(&child_block, src, s.cond)) |cond_val| { + return resolveSwitchComptimeLoop( + sema, + spa, + &child_block, + if (operand_is_ref) + sema.typeOf(s.by_ref) + else + raw_operand_ty, + cond_ty, + cond_val, + special, + case_vals, + scalar_cases_len, + multi_cases_len, + err_set, + empty_enum, + operand_is_ref, + ); + } + + if (scalar_cases_len + multi_cases_len == 0 and !special.is_inline and !extra.data.bits.has_continue) { + return spa.resolveProngComptime( + &child_block, + .special, + special.body, + special.capture, + .special_capture, + undefined, // case_vals may be undefined for special prongs + .none, + false, + merges, + ); + } + }, } if (child_block.is_comptime) { - _ = try sema.resolveConstDefinedValue(&child_block, operand_src, operand, .{ + _ = try sema.resolveConstDefinedValue(&child_block, operand_src, operand.simple.cond, .{ .needed_comptime_reason = "condition in comptime switch must be comptime-known", .block_comptime_reason = child_block.comptime_reason, }); unreachable; } - _ = try sema.analyzeSwitchRuntimeBlock( + const air_switch_ref = try sema.analyzeSwitchRuntimeBlock( spa, &child_block, src, - operand, - operand_ty, + switch (operand) { + .simple => |s| s.cond, + .loop => |l| l.init_cond, + }, + cond_ty, operand_src, case_vals, special, scalar_cases_len, multi_cases_len, union_originally, - maybe_union_ty, + raw_operand_ty, err_set, src_node_offset, special_prong_src, @@ -12513,6 +12718,67 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r false, ); + for (merges.extra_insts.items, merges.extra_src_locs.items) |placeholder_inst, dispatch_src| { + var replacement_block = block.makeSubBlock(); + defer replacement_block.instructions.deinit(gpa); + + assert(sema.air_instructions.items(.tag)[@intFromEnum(placeholder_inst)] == .br); + const new_operand_maybe_ref = sema.air_instructions.items(.data)[@intFromEnum(placeholder_inst)].br.operand; + 
+ if (extra.data.bits.any_non_inline_capture) { + _ = try replacement_block.addBinOp(.store, operand.loop.operand_alloc, new_operand_maybe_ref); + } + + const new_operand_val = if (operand_is_ref) + try sema.analyzeLoad(&replacement_block, dispatch_src, new_operand_maybe_ref, dispatch_src) + else + new_operand_maybe_ref; + + const new_cond = try sema.switchCond(&replacement_block, dispatch_src, new_operand_val); + + if (mod.backendSupportsFeature(.is_named_enum_value) and block.wantSafety() and + cond_ty.zigTypeTag(mod) == .Enum and !cond_ty.isNonexhaustiveEnum(mod) and + !try sema.isComptimeKnown(new_cond)) + { + const ok = try replacement_block.addUnOp(.is_named_enum_value, new_cond); + try sema.addSafetyCheck(&replacement_block, src, ok, .corrupt_switch); + } + + _ = try replacement_block.addInst(.{ + .tag = .switch_dispatch, + .data = .{ .br = .{ + .block_inst = air_switch_ref.toIndex().?, + .operand = new_cond, + } }, + }); + + if (replacement_block.instructions.items.len == 1) { + // Optimization: we don't need a block! + sema.air_instructions.set( + @intFromEnum(placeholder_inst), + sema.air_instructions.get(@intFromEnum(replacement_block.instructions.items[0])), + ); + continue; + } + + // Replace placeholder with a block. + // No `br` is needed as the block is a switch dispatch so necessarily `noreturn`. 
+ try sema.air_extra.ensureUnusedCapacity( + gpa, + @typeInfo(Air.Block).Struct.fields.len + replacement_block.instructions.items.len, + ); + sema.air_instructions.set(@intFromEnum(placeholder_inst), .{ + .tag = .block, + .data = .{ .ty_pl = .{ + .ty = .noreturn_type, + .payload = sema.addExtraAssumeCapacity(Air.Block{ + .body_len = @intCast(replacement_block.instructions.items.len), + }), + } }, + }); + sema.air_extra.appendSliceAssumeCapacity(@ptrCast(replacement_block.instructions.items)); + } + return sema.resolveAnalyzedBlock(block, src, &child_block, merges, false); } @@ -13087,7 +13353,7 @@ fn analyzeSwitchRuntimeBlock( sema.air_extra.appendSliceAssumeCapacity(@ptrCast(else_body)); return try child_block.addInst(.{ - .tag = .switch_br, + .tag = if (spa.operand == .loop) .loop_switch_br else .switch_br, .data = .{ .pl_op = .{ .operand = operand, .payload = payload_index, @@ -13095,6 +13361,75 @@ fn analyzeSwitchRuntimeBlock( }); } +fn resolveSwitchComptimeLoop( + sema: *Sema, + init_spa: SwitchProngAnalysis, + child_block: *Block, + maybe_ptr_operand_ty: Type, + cond_ty: Type, + init_cond_val: Value, + special: SpecialProng, + case_vals: std.ArrayListUnmanaged(Air.Inst.Ref), + scalar_cases_len: u32, + multi_cases_len: u32, + err_set: bool, + empty_enum: bool, + operand_is_ref: bool, +) CompileError!Air.Inst.Ref { + var spa = init_spa; + var cond_val = init_cond_val; + + while (true) { + if (resolveSwitchComptime( + sema, + spa, + child_block, + spa.operand.simple.cond, + cond_val, + cond_ty, + special, + case_vals, + scalar_cases_len, + multi_cases_len, + err_set, + empty_enum, + )) |result| { + return result; + } else |err| switch (err) { + error.ComptimeBreak => { + const break_inst = sema.code.instructions.get(@intFromEnum(sema.comptime_break_inst)); + if (break_inst.tag != .switch_continue) return error.ComptimeBreak; + const extra = sema.code.extraData(Zir.Inst.Break, break_inst.data.@"break".payload_index).data; + if (extra.block_inst != 
spa.switch_block_inst) return error.ComptimeBreak; + // This is a `switch_continue` targeting this block. Change the operand and start over. + const src = LazySrcLoc.nodeOffset(extra.operand_src_node); + const new_operand_uncoerced = try sema.resolveInst(break_inst.data.@"break".operand); + const new_operand = try sema.coerce(child_block, maybe_ptr_operand_ty, new_operand_uncoerced, src); + + try sema.emitBackwardBranch(child_block, src); + + const val, const ref = if (operand_is_ref) + .{ try sema.analyzeLoad(child_block, src, new_operand, src), new_operand } + else + .{ new_operand, undefined }; + + const cond_ref = try sema.switchCond(child_block, src, val); + + cond_val = try sema.resolveConstDefinedValue(child_block, src, cond_ref, .{ + .needed_comptime_reason = "condition in comptime switch must be comptime-known", + .block_comptime_reason = child_block.comptime_reason, + }); + spa.operand = .{ .simple = .{ + .by_val = val, + .by_ref = ref, + .cond = cond_ref, + } }; + }, + else => |e| return e, + } + } +} + fn resolveSwitchComptime( sema: *Sema, spa: SwitchProngAnalysis, @@ -13111,6 +13446,7 @@ fn resolveSwitchComptime( ) CompileError!Air.Inst.Ref { const merges = &child_block.label.?.merges; const resolved_operand_val = try sema.resolveLazyValue(operand_val); + var extra_index: usize = special.end; { var scalar_i: usize = 0; diff --git a/src/Value.zig b/src/Value.zig index 99817d79a971..c37376e97aa7 100644 --- a/src/Value.zig +++ b/src/Value.zig @@ -266,6 +266,7 @@ pub fn getUnsignedIntAdvanced(val: Value, mod: *Module, opt_sema: ?*Sema) !?u64 .none => 0, else => |payload| Value.fromInterned(payload).getUnsignedIntAdvanced(mod, opt_sema), }, + .enum_tag => |enum_tag| return Value.fromInterned(enum_tag.int).getUnsignedIntAdvanced(mod, opt_sema), else => null, }, }; diff --git a/src/arch/aarch64/CodeGen.zig b/src/arch/aarch64/CodeGen.zig index 08e8500b6597..f60ab908fa76 100644 --- a/src/arch/aarch64/CodeGen.zig +++ b/src/arch/aarch64/CodeGen.zig @@ -741,6 
+741,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .block => try self.airBlock(inst), .br => try self.airBr(inst), .repeat => return self.fail("TODO implement `repeat`", .{}), + .switch_dispatch => return self.fail("TODO implement `switch_dispatch`", .{}), .trap => try self.airTrap(), .breakpoint => try self.airBreakpoint(), .ret_addr => try self.airRetAddr(inst), @@ -828,6 +829,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .field_parent_ptr => try self.airFieldParentPtr(inst), .switch_br => try self.airSwitch(inst), + .loop_switch_br => return self.fail("TODO implement `loop_switch_br`", .{}), .slice_ptr => try self.airSlicePtr(inst), .slice_len => try self.airSliceLen(inst), diff --git a/src/arch/arm/CodeGen.zig b/src/arch/arm/CodeGen.zig index cb9a580832cd..5c4d2a9b94e0 100644 --- a/src/arch/arm/CodeGen.zig +++ b/src/arch/arm/CodeGen.zig @@ -727,6 +727,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .block => try self.airBlock(inst), .br => try self.airBr(inst), .repeat => return self.fail("TODO implement `repeat`", .{}), + .switch_dispatch => return self.fail("TODO implement `switch_dispatch`", .{}), .trap => try self.airTrap(), .breakpoint => try self.airBreakpoint(), .ret_addr => try self.airRetAddr(inst), @@ -814,6 +815,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .field_parent_ptr => try self.airFieldParentPtr(inst), .switch_br => try self.airSwitch(inst), + .loop_switch_br => return self.fail("TODO implement `loop_switch_br`", .{}), .slice_ptr => try self.airSlicePtr(inst), .slice_len => try self.airSliceLen(inst), diff --git a/src/arch/riscv64/CodeGen.zig b/src/arch/riscv64/CodeGen.zig index 26083f9544c9..0769f6cd5ec8 100644 --- a/src/arch/riscv64/CodeGen.zig +++ b/src/arch/riscv64/CodeGen.zig @@ -560,6 +560,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .block => try self.airBlock(inst), .br => try 
self.airBr(inst), .repeat => return self.fail("TODO implement `repeat`", .{}), + .switch_dispatch => return self.fail("TODO implement `switch_dispatch`", .{}), .trap => try self.airTrap(), .breakpoint => try self.airBreakpoint(), .ret_addr => try self.airRetAddr(inst), @@ -647,6 +648,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .field_parent_ptr => try self.airFieldParentPtr(inst), .switch_br => try self.airSwitch(inst), + .loop_switch_br => return self.fail("TODO implement `loop_switch_br`", .{}), .slice_ptr => try self.airSlicePtr(inst), .slice_len => try self.airSliceLen(inst), diff --git a/src/arch/sparc64/CodeGen.zig b/src/arch/sparc64/CodeGen.zig index bd3a5b0f5e92..6cdc718dabe3 100644 --- a/src/arch/sparc64/CodeGen.zig +++ b/src/arch/sparc64/CodeGen.zig @@ -574,6 +574,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .block => try self.airBlock(inst), .br => try self.airBr(inst), .repeat => return self.fail("TODO implement `repeat`", .{}), + .switch_dispatch => return self.fail("TODO implement `switch_dispatch`", .{}), .trap => try self.airTrap(), .breakpoint => try self.airBreakpoint(), .ret_addr => @panic("TODO try self.airRetAddr(inst)"), @@ -661,6 +662,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .field_parent_ptr => @panic("TODO try self.airFieldParentPtr(inst)"), .switch_br => try self.airSwitch(inst), + .loop_switch_br => return self.fail("TODO implement `loop_switch_br`", .{}), .slice_ptr => try self.airSlicePtr(inst), .slice_len => try self.airSliceLen(inst), diff --git a/src/arch/wasm/CodeGen.zig b/src/arch/wasm/CodeGen.zig index 5a0478398b3c..7fe42c3d7497 100644 --- a/src/arch/wasm/CodeGen.zig +++ b/src/arch/wasm/CodeGen.zig @@ -1898,6 +1898,7 @@ fn genInst(func: *CodeGen, inst: Air.Inst.Index) InnerError!void { .breakpoint => func.airBreakpoint(inst), .br => func.airBr(inst), .repeat => return func.fail("TODO implement `repeat`", .{}), + .switch_dispatch => 
return func.fail("TODO implement `switch_dispatch`", .{}), .int_from_bool => func.airIntFromBool(inst), .cond_br => func.airCondBr(inst), .intcast => func.airIntcast(inst), @@ -1975,6 +1976,7 @@ fn genInst(func: *CodeGen, inst: Air.Inst.Index) InnerError!void { .field_parent_ptr => func.airFieldParentPtr(inst), .switch_br => func.airSwitchBr(inst), + .loop_switch_br => return func.fail("TODO implement `loop_switch_br`", .{}), .trunc => func.airTrunc(inst), .unreach => func.airUnreachable(inst), diff --git a/src/arch/x86_64/CodeGen.zig b/src/arch/x86_64/CodeGen.zig index 2ac50dfa3a07..64bf40af75c9 100644 --- a/src/arch/x86_64/CodeGen.zig +++ b/src/arch/x86_64/CodeGen.zig @@ -2039,6 +2039,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .block => try self.airBlock(inst), .br => try self.airBr(inst), .repeat => return self.fail("TODO implement `repeat`", .{}), + .switch_dispatch => return self.fail("TODO implement `switch_dispatch`", .{}), .trap => try self.airTrap(), .breakpoint => try self.airBreakpoint(), .ret_addr => try self.airRetAddr(inst), @@ -2124,6 +2125,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .field_parent_ptr => try self.airFieldParentPtr(inst), .switch_br => try self.airSwitchBr(inst), + .loop_switch_br => return self.fail("TODO implement `loop_switch_br`", .{}), .slice_ptr => try self.airSlicePtr(inst), .slice_len => try self.airSliceLen(inst), diff --git a/src/codegen/c.zig b/src/codegen/c.zig index caf9e2097128..9e44026bdfdd 100644 --- a/src/codegen/c.zig +++ b/src/codegen/c.zig @@ -329,6 +329,9 @@ pub const Function = struct { /// by type alignment. /// The value is whether the alloc needs to be emitted in the header. allocs: std.AutoArrayHashMapUnmanaged(LocalIndex, bool) = .{}, + /// Maps from `loop_switch_br` instructions to the allocated local used + /// for the switch cond. Dispatches should set this local to the new cond. 
+ loop_switch_conds: std.AutoHashMapUnmanaged(Air.Inst.Index, LocalIndex) = .{}, fn resolveInst(f: *Function, ref: Air.Inst.Ref) !CValue { const gop = try f.value_map.getOrPut(ref); @@ -537,6 +540,7 @@ pub const Function = struct { f.blocks.deinit(gpa); f.value_map.deinit(); f.lazy_fns.deinit(gpa); + f.loop_switch_conds.deinit(gpa); } fn typeOf(f: *Function, inst: Air.Inst.Ref) Type { @@ -3425,16 +3429,18 @@ fn genBodyInner(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, => unreachable, // Instructions that are known to always be `noreturn` based on their tag. - .br => return airBr(f, inst), - .repeat => return airRepeat(f, inst), - .cond_br => return airCondBr(f, inst), - .switch_br => return airSwitchBr(f, inst), - .loop => return airLoop(f, inst), - .ret => return airRet(f, inst, false), - .ret_safe => return airRet(f, inst, false), // TODO - .ret_load => return airRet(f, inst, true), - .trap => return airTrap(f, f.object.writer()), - .unreach => return airUnreach(f), + .br => return airBr(f, inst), + .repeat => return airRepeat(f, inst), + .switch_dispatch => return airSwitchDispatch(f, inst), + .cond_br => return airCondBr(f, inst), + .switch_br => return airSwitchBr(f, inst, false), + .loop_switch_br => return airSwitchBr(f, inst, true), + .loop => return airLoop(f, inst), + .ret => return airRet(f, inst, false), + .ret_safe => return airRet(f, inst, false), // TODO + .ret_load => return airRet(f, inst, true), + .trap => return airTrap(f, f.object.writer()), + .unreach => return airUnreach(f), // Instructions which may be `noreturn`. 
.block => res: { @@ -3782,7 +3788,7 @@ fn airLoad(f: *Function, inst: Air.Inst.Index) !CValue { return local; } -fn airRet(f: *Function, inst: Air.Inst.Index, is_ptr: bool) !CValue { +fn airRet(f: *Function, inst: Air.Inst.Index, is_ptr: bool) !void { const zcu = f.object.dg.zcu; const un_op = f.air.instructions.items(.data)[@intFromEnum(inst)].un_op; const writer = f.object.writer(); @@ -3832,7 +3838,6 @@ fn airRet(f: *Function, inst: Air.Inst.Index, is_ptr: bool) !CValue { // Not even allowed to return void in a naked function. if (!f.object.dg.is_naked_fn) try writer.writeAll("return;\n"); } - return .none; } fn airIntCast(f: *Function, inst: Air.Inst.Index) !CValue { @@ -4792,7 +4797,7 @@ fn lowerTry( return local; } -fn airBr(f: *Function, inst: Air.Inst.Index) !CValue { +fn airBr(f: *Function, inst: Air.Inst.Index) !void { const branch = f.air.instructions.items(.data)[@intFromEnum(inst)].br; const block = f.blocks.get(branch.block_inst).?; const result = block.result; @@ -4812,14 +4817,53 @@ fn airBr(f: *Function, inst: Air.Inst.Index) !CValue { } try writer.print("goto zig_block_{d};\n", .{block.block_id}); - return .none; } -fn airRepeat(f: *Function, inst: Air.Inst.Index) !CValue { +fn airRepeat(f: *Function, inst: Air.Inst.Index) !void { const repeat = f.air.instructions.items(.data)[@intFromEnum(inst)].repeat; const writer = f.object.writer(); try writer.print("goto zig_loop_{d};\n", .{@intFromEnum(repeat.loop_inst)}); - return .none; +} + +fn airSwitchDispatch(f: *Function, inst: Air.Inst.Index) !void { + const zcu = f.object.dg.zcu; + const br = f.air.instructions.items(.data)[@intFromEnum(inst)].br; + const writer = f.object.writer(); + + if (try f.air.value(br.operand, zcu)) |cond_val| { + // Comptime-known dispatch. Iterate the cases to find the correct + // one, and branch directly to the corresponding case. 
+ var it = f.air.switchIterator(br.block_inst); + var next_case_idx: u32 = 0; + const target_case_idx: u32 = target: while (it.nextCase()) |case| { + const case_idx = next_case_idx; + next_case_idx += 1; + for (case.items) |item| { + const val = Value.fromInterned(item.toInterned().?); + if (cond_val.compareHetero(.eq, val, zcu)) break :target case_idx; + } + for (case.ranges) |range| { + const low = Value.fromInterned(range[0].toInterned().?); + const high = Value.fromInterned(range[1].toInterned().?); + if (cond_val.compareHetero(.gte, low, zcu) and + cond_val.compareHetero(.lte, high, zcu)) + { + break :target case_idx; + } + } + } else it.total_cases; + try writer.print("goto zig_switch_{d}_dispatch_{d};\n", .{ @intFromEnum(br.block_inst), target_case_idx }); + return; + } + + // Runtime-known dispatch. Set the switch condition, and branch back. + const cond = try f.resolveInst(br.operand); + const cond_local = f.loop_switch_conds.get(br.block_inst).?; + try f.writeCValue(writer, .{ .local = cond_local }, .Other); + try writer.writeAll(" = "); + try f.writeCValue(writer, cond, .Initializer); + try writer.writeAll(";\n"); + try writer.print("goto zig_switch_{d}_loop;", .{@intFromEnum(br.block_inst)}); } fn airBitcast(f: *Function, inst: Air.Inst.Index) !CValue { @@ -4946,12 +4990,11 @@ fn bitcast(f: *Function, dest_ty: Type, operand: CValue, operand_ty: Type) !CVal return local; } -fn airTrap(f: *Function, writer: anytype) !CValue { +fn airTrap(f: *Function, writer: anytype) !void { // Not even allowed to call trap in a naked function. - if (f.object.dg.is_naked_fn) return .none; + if (f.object.dg.is_naked_fn) return; try writer.writeAll("zig_trap();\n"); - return .none; } fn airBreakpoint(writer: anytype) !CValue { @@ -4990,15 +5033,14 @@ fn airFence(f: *Function, inst: Air.Inst.Index) !CValue { return .none; } -fn airUnreach(f: *Function) !CValue { +fn airUnreach(f: *Function) !void { // Not even allowed to call unreachable in a naked function. 
- if (f.object.dg.is_naked_fn) return .none; + if (f.object.dg.is_naked_fn) return; try f.object.writer().writeAll("zig_unreachable();\n"); - return .none; } -fn airLoop(f: *Function, inst: Air.Inst.Index) !CValue { +fn airLoop(f: *Function, inst: Air.Inst.Index) !void { const ty_pl = f.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl; const loop = f.air.extraData(Air.Block, ty_pl.payload); const body: []const Air.Inst.Index = @ptrCast(f.air.extra[loop.end..][0..loop.data.body_len]); @@ -5010,11 +5052,9 @@ fn airLoop(f: *Function, inst: Air.Inst.Index) !CValue { // construct at all! try writer.print("zig_loop_{d}:\n", .{@intFromEnum(inst)}); try genBodyInner(f, body); // no need to restore state, we're noreturn - - return .none; } -fn airCondBr(f: *Function, inst: Air.Inst.Index) !CValue { +fn airCondBr(f: *Function, inst: Air.Inst.Index) !void { const pl_op = f.air.instructions.items(.data)[@intFromEnum(inst)].pl_op; const cond = try f.resolveInst(pl_op.operand); try reap(f, inst, &.{pl_op.operand}); @@ -5043,19 +5083,36 @@ fn airCondBr(f: *Function, inst: Air.Inst.Index) !CValue { // instance) `br` to a block (label). try genBodyInner(f, else_body); - - return .none; } -fn airSwitchBr(f: *Function, inst: Air.Inst.Index) !CValue { +fn airSwitchBr(f: *Function, inst: Air.Inst.Index, is_dispatch_loop: bool) !void { const zcu = f.object.dg.zcu; + const gpa = f.object.dg.gpa; const pl_op = f.air.instructions.items(.data)[@intFromEnum(inst)].pl_op; - const condition = try f.resolveInst(pl_op.operand); + const init_condition = try f.resolveInst(pl_op.operand); try reap(f, inst, &.{pl_op.operand}); const condition_ty = f.typeOf(pl_op.operand); const switch_br = f.air.extraData(Air.SwitchBr, pl_op.payload); const writer = f.object.writer(); + // For dispatches, we will create a local alloc to contain the condition value. 
+ // This may not result in optimal codegen for switch loops, but it minimizes the + // amount of C code we generate, which is probably more desirable here (and is simpler). + const condition = if (is_dispatch_loop) cond: { + const new_local = try f.allocLocal(inst, condition_ty); + try f.writeCValue(writer, new_local, .Other); + try writer.writeAll(" = "); + try f.writeCValue(writer, init_condition, .Initializer); + try writer.writeAll(";\n"); + try writer.print("zig_switch_{d}_loop:", .{@intFromEnum(inst)}); + try f.loop_switch_conds.put(gpa, inst, new_local.new_local); + break :cond new_local; + } else init_condition; + + defer if (is_dispatch_loop) { + assert(f.loop_switch_conds.remove(inst)); + }; + try writer.writeAll("switch ("); const lowered_condition_ty = if (condition_ty.toIntern() == .bool_type) @@ -5073,7 +5130,6 @@ fn airSwitchBr(f: *Function, inst: Air.Inst.Index) !CValue { try writer.writeAll(") {"); f.object.indent_writer.pushIndent(); - const gpa = f.object.dg.gpa; const liveness = try f.liveness.getSwitchBr(gpa, inst, switch_br.data.cases_len + 1); defer gpa.free(liveness.deaths); @@ -5095,9 +5151,15 @@ fn airSwitchBr(f: *Function, inst: Air.Inst.Index) !CValue { try f.object.indent_writer.insertNewline(); try writer.writeAll("case "); const item_value = try f.air.value(item, zcu); - if (item_value.?.getUnsignedInt(zcu)) |item_int| try writer.print("{}\n", .{ - try f.fmtIntLiteral(try zcu.intValue(lowered_condition_ty, item_int)), - }) else { + // If `item_value` is a pointer with a known integer address, print the address + // with no cast to avoid a warning. 
+ write_val: { + if (condition_ty.isPtrAtRuntime(zcu)) { + if (item_value.?.getUnsignedInt(zcu)) |item_int| { + try writer.print("{}", .{try f.fmtIntLiteral(try zcu.intValue(lowered_condition_ty, item_int))}); + break :write_val; + } + } if (condition_ty.isPtrAtRuntime(zcu)) { try writer.writeByte('('); try f.renderType(writer, Type.usize); @@ -5107,9 +5169,14 @@ fn airSwitchBr(f: *Function, inst: Air.Inst.Index) !CValue { } try writer.writeByte(':'); } - try writer.writeByte(' '); - - try genBodyResolveState(f, inst, liveness.deaths[case_i], case_body, false); + try writer.writeAll(" {\n"); + f.object.indent_writer.pushIndent(); + if (is_dispatch_loop) { + try writer.print("zig_switch_{d}_dispatch_{d}: ", .{ @intFromEnum(inst), case_i }); + } + try genBodyResolveState(f, inst, liveness.deaths[case_i], case_body, true); + f.object.indent_writer.popIndent(); + try writer.writeByte('}'); // The case body must be noreturn so we don't need to insert a break. } @@ -5157,10 +5224,19 @@ fn airSwitchBr(f: *Function, inst: Air.Inst.Index) !CValue { try f.object.dg.renderValue(writer, (try f.air.value(range[1], zcu)).?, .Other); try writer.writeByte(')'); } - try writer.writeAll(") "); - try genBodyResolveState(f, inst, liveness.deaths[case_i], case_body, false); + try writer.writeAll(") {\n"); + f.object.indent_writer.pushIndent(); + if (is_dispatch_loop) { + try writer.print("zig_switch_{d}_dispatch_{d}: ", .{ @intFromEnum(inst), case_i }); + } + try genBodyResolveState(f, inst, liveness.deaths[case_i], case_body, true); + f.object.indent_writer.popIndent(); + try writer.writeByte('}'); } } + if (is_dispatch_loop) { + try writer.print("zig_switch_{d}_dispatch_{d}: ", .{ @intFromEnum(inst), switch_br.data.cases_len }); + } if (else_body.len > 0) { // Note that this must be the last case, so we do not need to use `caseBodyResolveState` since // the parent block will do it (because the case body is noreturn). 
@@ -5176,7 +5252,6 @@ fn airSwitchBr(f: *Function, inst: Air.Inst.Index) !CValue { f.object.indent_writer.popIndent(); try writer.writeAll("}\n"); - return .none; } fn asmInputNeedsLocal(f: *Function, constraint: []const u8, value: CValue) bool { diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index 0a6adb11db4c..6563d8c0b3b9 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -1696,6 +1696,7 @@ pub const Object = struct { .func_inst_table = .{}, .blocks = .{}, .loops = .{}, + .switch_dispatch_info = .{}, .sync_scope = if (owner_mod.single_threaded) .singlethread else .system, .file = file, .scope = subprogram, @@ -4799,8 +4800,34 @@ pub const FuncGen = struct { /// Maps `loop` instructions to the bb to branch to to repeat the loop. loops: std.AutoHashMapUnmanaged(Air.Inst.Index, Builder.Function.Block.Index), + /// Maps `loop_switch_br` instructions to the information required to lower + /// dispatches (`switch_dispatch` instructions). + switch_dispatch_info: std.AutoHashMapUnmanaged(Air.Inst.Index, SwitchDispatchInfo), + sync_scope: Builder.SyncScope, + const SwitchDispatchInfo = struct { + /// These are the blocks corresponding to each switch case. + /// The final element corresponds to the `else` case. + /// Slices allocated into `gpa`. + case_blocks: []Builder.Function.Block.Index, + /// If not `null`, we have manually constructed a jump table to reach the desired block. + /// `table` can be used if the value is between `min` and `max` inclusive. + /// We perform this lowering manually to avoid some questionable behavior from LLVM. + /// See `airSwitchBr` for details. + jmp_table: ?struct { + min: Builder.Constant, + max: Builder.Constant, + /// Pointer to the jump table itself, to be used with `indirectbr`. + /// The index into the jump table is the dispatch condition minus `min`. + /// The table values are `blockaddress` constants corresponding to blocks in `case_blocks`. 
+ table: Builder.Constant, + /// `true` if `table` contains a reference to the `else` block. + /// In this case, the `indirectbr` must include the `else` block in its target list. + table_includes_else: bool, + }, + }; + const BreakList = union { list: std.MultiArrayList(struct { bb: Builder.Function.Block.Index, @@ -4814,6 +4841,11 @@ pub const FuncGen = struct { self.func_inst_table.deinit(self.gpa); self.blocks.deinit(self.gpa); self.loops.deinit(self.gpa); + var it = self.switch_dispatch_info.valueIterator(); + while (it.next()) |info| { + self.gpa.free(info.case_blocks); + } + self.switch_dispatch_info.deinit(self.gpa); } fn todo(self: *FuncGen, comptime format: []const u8, args: anytype) Error { @@ -5101,16 +5133,18 @@ pub const FuncGen = struct { .work_group_id => try self.airWorkGroupId(inst), // Instructions that are known to always be `noreturn` based on their tag. - .br => return self.airBr(inst), - .repeat => return self.airRepeat(inst), - .cond_br => return self.airCondBr(inst), - .switch_br => return self.airSwitchBr(inst), - .loop => return self.airLoop(inst), - .ret => return self.airRet(inst, false), - .ret_safe => return self.airRet(inst, true), - .ret_load => return self.airRetLoad(inst), - .trap => return self.airTrap(inst), - .unreach => return self.airUnreach(inst), + .br => return self.airBr(inst), + .repeat => return self.airRepeat(inst), + .switch_dispatch => return self.airSwitchDispatch(inst), + .cond_br => return self.airCondBr(inst), + .switch_br => return self.airSwitchBr(inst, false), + .loop_switch_br => return self.airSwitchBr(inst, true), + .loop => return self.airLoop(inst), + .ret => return self.airRet(inst, false), + .ret_safe => return self.airRet(inst, true), + .ret_load => return self.airRetLoad(inst), + .trap => return self.airTrap(inst), + .unreach => return self.airUnreach(inst), // Instructions which may be `noreturn`.
.block => res: { @@ -5998,6 +6032,182 @@ pub const FuncGen = struct { _ = try self.wip.br(loop_bb); } + fn lowerSwitchDispatch( + self: *FuncGen, + switch_inst: Air.Inst.Index, + cond_ref: Air.Inst.Ref, + dispatch_info: SwitchDispatchInfo, + ) !void { + const o = self.dg.object; + const zcu = o.module; + const cond_ty = self.typeOf(cond_ref); + + if (try self.air.value(cond_ref, zcu)) |cond_val| { + // Comptime-known dispatch. Iterate the cases to find the correct + // one, and branch to the corresponding element of `case_blocks`. + var it = self.air.switchIterator(switch_inst); + const target_case_idx = target: while (it.nextCase()) |case| { + for (case.items) |item| { + const val = Value.fromInterned(item.toInterned().?); + if (cond_val.compareHetero(.eq, val, zcu)) break :target case.index; + } + for (case.ranges) |range| { + const low = Value.fromInterned(range[0].toInterned().?); + const high = Value.fromInterned(range[1].toInterned().?); + if (cond_val.compareHetero(.gte, low, zcu) and + cond_val.compareHetero(.lte, high, zcu)) + { + break :target case.index; + } + } + } else dispatch_info.case_blocks.len - 1; + const target_block = dispatch_info.case_blocks[target_case_idx]; + target_block.ptr(&self.wip).incoming += 1; + _ = try self.wip.br(target_block); + return; + } + + // Runtime-known dispatch. + const cond = try self.resolveInst(cond_ref); + + if (dispatch_info.jmp_table) |jmp_table| { + // We should use the constructed jump table. + // First, check the bounds to branch to the `else` case if needed. 
+ const inbounds = try self.wip.bin( + .@"and", + try self.cmp(.normal, .gte, cond_ty, cond, jmp_table.min.toValue()), + try self.cmp(.normal, .lte, cond_ty, cond, jmp_table.max.toValue()), + "", + ); + const jmp_table_block = try self.wip.block(1, "Then"); + const else_block = dispatch_info.case_blocks[dispatch_info.case_blocks.len - 1]; + else_block.ptr(&self.wip).incoming += 1; + _ = try self.wip.brCond(inbounds, jmp_table_block, else_block); + + self.wip.cursor = .{ .block = jmp_table_block }; + + // Figure out the list of blocks we might branch to. + // This includes all case blocks, but it might not include the `else` block if + // the table is dense. + const target_blocks_len = dispatch_info.case_blocks.len - @intFromBool(!jmp_table.table_includes_else); + const target_blocks = dispatch_info.case_blocks[0..target_blocks_len]; + + // Make sure to cast the index to a usize so it's not treated as negative! + const table_index = try self.wip.cast( + .zext, + try self.wip.bin(.@"sub nuw", cond, jmp_table.min.toValue(), ""), + try o.lowerType(Type.usize), + "", + ); + const target_ptr_ptr = try self.wip.gep( + .inbounds, + .ptr, + jmp_table.table.toValue(), + &.{table_index}, + "", + ); + const target_ptr = try self.wip.load(.normal, .ptr, target_ptr_ptr, .default, ""); + + // Do the branch! + _ = try self.wip.indirectbr(target_ptr, target_blocks); + + // Mark all target blocks as having one more incoming branch. + for (target_blocks) |case_block| { + case_block.ptr(&self.wip).incoming += 1; + } + + return; + } + + // We must lower to an actual LLVM `switch` instruction. + // The switch prongs will correspond to our scalar cases. Ranges will + // be handled by conditional branches in the `else` prong. 
+ + const llvm_usize = try o.lowerType(Type.usize); + const cond_int = if (cond.typeOfWip(&self.wip).isPointer(&o.builder)) + try self.wip.cast(.ptrtoint, cond, llvm_usize, "") + else + cond; + + var llvm_cases_len: u32 = 0; + var last_range_case: ?u32 = null; + var it = self.air.switchIterator(switch_inst); + while (it.nextCase()) |case| { + if (case.ranges.len > 0) last_range_case = case.index; + llvm_cases_len += @intCast(case.items.len); + } + + // The `else` of the LLVM `switch` is the actual `else` prong only + // if there are no ranges. Otherwise, the `else` will have a + // conditional chain before the "true" `else` prong. + const switch_else_block = if (last_range_case == null) + dispatch_info.case_blocks[dispatch_info.case_blocks.len - 1] + else + try self.wip.block(0, "RangeTest"); + + switch_else_block.ptr(&self.wip).incoming += 1; + + var wip_switch = try self.wip.@"switch"(cond_int, switch_else_block, llvm_cases_len); + defer wip_switch.finish(&self.wip); + + // Construct the actual cases. Set the cursor to the `else` block so + // we can construct ranges at the same time as scalar cases. 
+ self.wip.cursor = .{ .block = switch_else_block }; + + it = self.air.switchIterator(switch_inst); + while (it.nextCase()) |case| { + const case_block = dispatch_info.case_blocks[case.index]; + + for (case.items) |item| { + const llvm_item = (try self.resolveInst(item)).toConst().?; + const llvm_int_item = if (llvm_item.typeOf(&o.builder).isPointer(&o.builder)) + try o.builder.castConst(.ptrtoint, llvm_item, llvm_usize) + else + llvm_item; + try wip_switch.addCase(llvm_int_item, case_block, &self.wip); + } + + case_block.ptr(&self.wip).incoming += @intCast(case.items.len); + + if (case.ranges.len == 0) continue; + + var range_cond: ?Builder.Value = null; + for (case.ranges) |range| { + const llvm_min = try self.resolveInst(range[0]); + const llvm_max = try self.resolveInst(range[1]); + const cond_part = try self.wip.bin( + .@"and", + try self.cmp(.normal, .gte, cond_ty, cond, llvm_min), + try self.cmp(.normal, .lte, cond_ty, cond, llvm_max), + "", + ); + if (range_cond) |old| { + range_cond = try self.wip.bin(.@"or", old, cond_part, ""); + } else range_cond = cond_part; + } + + // If the check fails, we either branch to the "true" `else` case, + // or to the next range condition. + const range_else_block = if (case.index == last_range_case.?) + dispatch_info.case_blocks[dispatch_info.case_blocks.len - 1] + else + try self.wip.block(0, "RangeTest"); + + _ = try self.wip.brCond(range_cond.?, case_block, range_else_block); + case_block.ptr(&self.wip).incoming += 1; + range_else_block.ptr(&self.wip).incoming += 1; + + // Construct the next range conditional (if any) in the false branch. 
+ self.wip.cursor = .{ .block = range_else_block }; + } + } + + fn airSwitchDispatch(self: *FuncGen, inst: Air.Inst.Index) !void { + const br = self.air.instructions.items(.data)[@intFromEnum(inst)].br; + const dispatch_info = self.switch_dispatch_info.get(br.block_inst).?; + return self.lowerSwitchDispatch(br.block_inst, br.operand, dispatch_info); + } + fn airCondBr(self: *FuncGen, inst: Air.Inst.Index) !void { const pl_op = self.air.instructions.items(.data)[@intFromEnum(inst)].pl_op; const cond = try self.resolveInst(pl_op.operand); @@ -6117,131 +6327,169 @@ pub const FuncGen = struct { return fg.wip.extractValue(err_union, &.{offset}, ""); } - fn airSwitchBr(self: *FuncGen, inst: Air.Inst.Index) !void { + fn airSwitchBr(self: *FuncGen, inst: Air.Inst.Index, is_dispatch_loop: bool) !void { const o = self.dg.object; + const zcu = o.module; const pl_op = self.air.instructions.items(.data)[@intFromEnum(inst)].pl_op; - const cond = try self.resolveInst(pl_op.operand); const switch_br = self.air.extraData(Air.SwitchBr, pl_op.payload); - const else_block = try self.wip.block(1, "Default"); - const llvm_usize = try o.lowerType(Type.usize); - const cond_int = if (cond.typeOfWip(&self.wip).isPointer(&o.builder)) - try self.wip.cast(.ptrtoint, cond, llvm_usize, "") - else - cond; - - var extra_index: usize = switch_br.end; - var any_range_cases = false; - var llvm_cases_len: u32 = 0; - for (0..switch_br.data.cases_len) |_| { - const case = self.air.extraData(Air.SwitchBr.Case, extra_index); - if (case.data.ranges_len != 0) { - // TODO: for ranges, we could still define any scalar cases in the same prong within - // the switch, just directing it to the same bb as the range check. 
- any_range_cases = true; - extra_index = case.end + case.data.items_len + case.data.ranges_len * 2 + case.data.body_len; - continue; - } - const items: []const Air.Inst.Ref = - @ptrCast(self.air.extra[case.end..][0..case.data.items_len]); - const case_body = self.air.extra[case.end + items.len ..][0..case.data.body_len]; - extra_index = case.end + case.data.items_len + case_body.len; - - llvm_cases_len += @intCast(items.len); - } - - var wip_switch = try self.wip.@"switch"(cond_int, else_block, llvm_cases_len); - defer wip_switch.finish(&self.wip); - extra_index = switch_br.end; - for (0..switch_br.data.cases_len) |_| { - const case = self.air.extraData(Air.SwitchBr.Case, extra_index); - if (case.data.ranges_len != 0) { - extra_index = case.end + case.data.items_len + case.data.ranges_len * 2 + case.data.body_len; - continue; - } - const items: []const Air.Inst.Ref = - @ptrCast(self.air.extra[case.end..][0..case.data.items_len]); - const case_body: []const Air.Inst.Index = @ptrCast(self.air.extra[case.end + items.len ..][0..case.data.body_len]); - extra_index = case.end + case.data.items_len + case_body.len; + // For `loop_switch_br`, we need these BBs prepared ahead of time to generate dispatches. + // For `switch_br`, they allow us to sometimes generate better IR by sharing a BB between + // scalar and range cases in the same prong. + // +1 for `else` case. This is not the same as the LLVM `else` prong, as that first contains + // conditionals to handle ranges. + const case_blocks = try self.gpa.alloc(Builder.Function.Block.Index, switch_br.data.cases_len + 1); + defer self.gpa.free(case_blocks); + + for (case_blocks[0 .. case_blocks.len - 1]) |*case_block| { + case_block.* = try self.wip.block(0, "Case"); + } + case_blocks[case_blocks.len - 1] = try self.wip.block(0, "Default"); + + // There's a special case here to manually generate a jump table in some cases. + // + // Labeled switch in Zig is intended to follow the "direct threading" pattern. 
We would ideally use a jump + // table, and each `continue` has its own indirect `jmp`, to allow the branch predictor to more accurately + // use data patterns to predict future dispatches. The problem, however, is that LLVM emits fascinatingly + // bad asm for this. Not only does it not share the jump table -- which we really need it to do to prevent + // destroying the cache -- but it also actually generates slightly different jump tables for each case, + // and *a separate conditional branch beforehand* to handle dispatching back to the case we're currently + // within(!!). + // + // This asm is really, really, not what we want. As such, we will construct the jump table manually where + // appropriate (the values are dense and relatively few), and use it when lowering dispatches. + + const dispatch_info: SwitchDispatchInfo = .{ + .case_blocks = case_blocks, + .jmp_table = jmp_table: { + if (!is_dispatch_loop) break :jmp_table null; + // On a 64-bit target, 1024 pointers in our jump table is about 8K of pointers. This seems just + // about acceptable - it won't fill L1d cache on most CPUs. + const max_table_len = 1024; + + const cond_ty = self.typeOf(pl_op.operand); + switch (cond_ty.zigTypeTag(zcu)) { + .Bool, .Pointer => break :jmp_table null, + .Enum, .Int, .ErrorSet => {}, + else => unreachable, + } - const case_block = try self.wip.block(@intCast(items.len), "Case"); + if (cond_ty.intInfo(zcu).signedness == .signed) break :jmp_table null; - for (items) |item| { - const llvm_item = (try self.resolveInst(item)).toConst().?; - const llvm_int_item = if (llvm_item.typeOf(&o.builder).isPointer(&o.builder)) - try o.builder.castConst(.ptrtoint, llvm_item, llvm_usize) - else - llvm_item; - try wip_switch.addCase(llvm_int_item, case_block, &self.wip); - } + // Don't worry about the size of the type -- it's irrelevant, because the prong values could be fairly dense. + // If they are, then we will construct a jump table. 
+ const min, const max = self.switchCaseItemRange(inst); + const min_int = min.getUnsignedInt(zcu) orelse break :jmp_table null; + const max_int = max.getUnsignedInt(zcu) orelse break :jmp_table null; + const table_len = max_int - min_int + 1; + if (table_len > max_table_len) break :jmp_table null; - self.wip.cursor = .{ .block = case_block }; - try self.genBodyDebugScope(null, case_body); - } + const table_elems = try self.gpa.alloc(Builder.Constant, @intCast(table_len)); + defer self.gpa.free(table_elems); - self.wip.cursor = .{ .block = else_block }; - const else_body: []const Air.Inst.Index = @ptrCast(self.air.extra[extra_index..][0..switch_br.data.else_body_len]); - if (any_range_cases) { - // We will iterate the cases again to handle those with ranges, and generate - // code using conditionals rather than switch cases for such cases. - const cond_ty = self.typeOf(pl_op.operand); - extra_index = switch_br.end; - for (0..switch_br.data.cases_len) |_| { - const case = self.air.extraData(Air.SwitchBr.Case, extra_index); - if (case.data.ranges_len == 0) { - // No ranges, so handled above - skip this case. - extra_index = case.end + case.data.items_len + case.data.body_len; - continue; - } - extra_index = case.end; - const items: []const Air.Inst.Ref = @ptrCast(self.air.extra[extra_index..][0..case.data.items_len]); - extra_index += items.len; - // TODO: this can be written more cleanly once Sema allows @ptrCast on slices where the length changes. 
- const ranges: []const [2]Air.Inst.Ref = @as([*]const [2]Air.Inst.Ref, @ptrCast(self.air.extra[extra_index..].ptr))[0..case.data.ranges_len]; - extra_index += ranges.len * 2; - const case_body: []const Air.Inst.Index = @ptrCast(self.air.extra[extra_index..][0..case.data.body_len]); - extra_index += case_body.len; - - var range_cond: ?Builder.Value = null; - - for (items) |item| { - const llvm_item = try self.resolveInst(item); - const cond_part = try self.cmp(.normal, .eq, cond_ty, cond, llvm_item); - if (range_cond) |old| { - range_cond = try self.wip.bin(.@"or", old, cond_part, ""); - } else range_cond = cond_part; - } - for (ranges) |range| { - const llvm_min = try self.resolveInst(range[0]); - const llvm_max = try self.resolveInst(range[1]); - const cond_part = try self.wip.bin( - .@"and", - try self.cmp(.normal, .gte, cond_ty, cond, llvm_min), - try self.cmp(.normal, .lte, cond_ty, cond, llvm_max), - "", + // Set them all to the `else` branch, then iterate over the AIR switch + // and replace all values which correspond to other prongs. 
+ @memset(table_elems, try o.builder.blockAddrConst( + self.wip.function, + case_blocks[case_blocks.len - 1], + )); + var item_count: u32 = 0; + var it = self.air.switchIterator(inst); + while (it.nextCase()) |case| { + const case_block = case_blocks[case.index]; + const case_block_addr = try o.builder.blockAddrConst( + self.wip.function, + case_block, ); - if (range_cond) |old| { - range_cond = try self.wip.bin(.@"or", old, cond_part, ""); - } else range_cond = cond_part; + for (case.items) |item| { + const val = Value.fromInterned(item.toInterned().?); + const table_idx = val.toUnsignedInt(zcu) - min_int; + table_elems[@intCast(table_idx)] = case_block_addr; + item_count += 1; + } + for (case.ranges) |range| { + const low = Value.fromInterned(range[0].toInterned().?); + const high = Value.fromInterned(range[1].toInterned().?); + const low_idx = low.toUnsignedInt(zcu) - min_int; + const high_idx = high.toUnsignedInt(zcu) - min_int; + @memset(table_elems[@intCast(low_idx)..@intCast(high_idx + 1)], case_block_addr); + item_count += @intCast(high_idx + 1 - low_idx); + } } - const range_case_block = try self.wip.block(1, "RangeCase"); - const range_else_block = try self.wip.block(1, "RangeDefault"); + const table_llvm_ty = try o.builder.arrayType(table_elems.len, .ptr); + const table_val = try o.builder.arrayConst(table_llvm_ty, table_elems); - _ = try self.wip.brCond(range_cond.?, range_case_block, range_else_block); + const table_variable = try o.builder.addVariable( + try o.builder.strtabStringFmt("__jmptab_{d}", .{@intFromEnum(inst)}), + table_llvm_ty, + .default, + ); + try table_variable.setInitializer(table_val, &o.builder); + table_variable.setLinkage(.internal, &o.builder); + table_variable.setUnnamedAddr(.unnamed_addr, &o.builder); + + break :jmp_table .{ + .min = try o.lowerValue(min.toIntern()), + .max = try o.lowerValue(max.toIntern()), + .table = table_variable.toConst(&o.builder), + .table_includes_else = item_count != table_len, + }; + }, + }; - 
self.wip.cursor = .{ .block = range_case_block }; - try self.genBodyDebugScope(null, case_body); - self.wip.cursor = .{ .block = range_else_block }; - } + if (is_dispatch_loop) { + try self.switch_dispatch_info.putNoClobber(self.gpa, inst, dispatch_info); + } + defer if (is_dispatch_loop) { + assert(self.switch_dispatch_info.remove(inst)); + }; + + // Generate the initial dispatch. + // If this is a simple `switch_br`, this is the only dispatch. + try self.lowerSwitchDispatch(inst, pl_op.operand, dispatch_info); + + // Iterate the cases and generate their bodies. + var it = self.air.switchIterator(inst); + while (it.nextCase()) |case| { + const case_block = case_blocks[case.index]; + self.wip.cursor = .{ .block = case_block }; + try self.genBodyDebugScope(null, case.body); } - if (else_body.len != 0) { - try self.genBodyDebugScope(null, else_body); + self.wip.cursor = .{ .block = case_blocks[case_blocks.len - 1] }; + const else_body = it.elseBody(); + if (else_body.len > 0) { + try self.genBodyDebugScope(null, it.elseBody()); } else { _ = try self.wip.@"unreachable"(); } + } - // No need to reset the insert cursor since this instruction is noreturn. 
+ fn switchCaseItemRange(self: *FuncGen, inst: Air.Inst.Index) [2]Value { + const zcu = self.dg.object.module; + var it = self.air.switchIterator(inst); + var min: ?Value = null; + var max: ?Value = null; + while (it.nextCase()) |case| { + for (case.items) |item| { + const val = Value.fromInterned(item.toInterned().?); + const low = if (min) |m| val.compareHetero(.lt, m, zcu) else true; + const high = if (max) |m| val.compareHetero(.gt, m, zcu) else true; + if (low) min = val; + if (high) max = val; + } + for (case.ranges) |range| { + const vals: [2]Value = .{ + Value.fromInterned(range[0].toInterned().?), + Value.fromInterned(range[1].toInterned().?), + }; + const low = if (min) |m| vals[0].compareHetero(.lt, m, zcu) else true; + const high = if (max) |m| vals[1].compareHetero(.gt, m, zcu) else true; + if (low) min = vals[0]; + if (high) max = vals[1]; + } + } + return .{ min.?, max.? }; } fn airLoop(self: *FuncGen, inst: Air.Inst.Index) !void { diff --git a/src/print_air.zig b/src/print_air.zig index d90adede978a..0674351734b6 100644 --- a/src/print_air.zig +++ b/src/print_air.zig @@ -295,11 +295,12 @@ const Writer = struct { .aggregate_init => try w.writeAggregateInit(s, inst), .union_init => try w.writeUnionInit(s, inst), .br => try w.writeBr(s, inst), + .switch_dispatch => try w.writeBr(s, inst), .repeat => try w.writeRepeat(s, inst), .cond_br => try w.writeCondBr(s, inst), .@"try" => try w.writeTry(s, inst), .try_ptr => try w.writeTryPtr(s, inst), - .switch_br => try w.writeSwitchBr(s, inst), + .loop_switch_br, .switch_br => try w.writeSwitchBr(s, inst), .cmpxchg_weak, .cmpxchg_strong => try w.writeCmpxchg(s, inst), .fence => try w.writeFence(s, inst), .atomic_load => try w.writeAtomicLoad(s, inst), diff --git a/src/print_zir.zig b/src/print_zir.zig index dfe94d397097..cf0c9454d400 100644 --- a/src/print_zir.zig +++ b/src/print_zir.zig @@ -303,6 +303,7 @@ const Writer = struct { .@"break", .break_inline, + .switch_continue, => try self.writeBreak(stream, 
inst), .slice_start => try self.writeSliceStart(stream, inst), diff --git a/test/behavior.zig b/test/behavior.zig index 3081f6c9f969..6c16d7df206a 100644 --- a/test/behavior.zig +++ b/test/behavior.zig @@ -87,6 +87,7 @@ test { _ = @import("behavior/struct_contains_null_ptr_itself.zig"); _ = @import("behavior/struct_contains_slice_of_itself.zig"); _ = @import("behavior/switch.zig"); + _ = @import("behavior/switch_loop.zig"); _ = @import("behavior/switch_prong_err_enum.zig"); _ = @import("behavior/switch_prong_implicit_cast.zig"); _ = @import("behavior/this.zig"); diff --git a/test/behavior/switch_loop.zig b/test/behavior/switch_loop.zig new file mode 100644 index 000000000000..b88bdfe74f49 --- /dev/null +++ b/test/behavior/switch_loop.zig @@ -0,0 +1,205 @@ +const builtin = @import("builtin"); +const std = @import("std"); +const expect = std.testing.expect; + +test "simple switch loop" { + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; // TODO + + const S = struct { + fn doTheTest() !void { + var start: u32 = undefined; + start = 32; + const result: u32 = s: switch (start) { + 0 => 0, + 1 => 1, + 2 => 2, + 3 => 3, + else => |x| continue :s x / 2, + }; + try expect(result == 2); + } + }; + try S.doTheTest(); + try comptime S.doTheTest(); +} + +test "switch loop with ranges" { + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_arm) return 
error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; // TODO + + const S = struct { + fn doTheTest() !void { + var start: u32 = undefined; + start = 32; + const result = s: switch (start) { + 0...3 => |x| x, + else => |x| continue :s x / 2, + }; + try expect(result == 2); + } + }; + try S.doTheTest(); + try comptime S.doTheTest(); +} + +test "switch loop on enum" { + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; // TODO + + const S = struct { + const E = enum { a, b, c }; + + fn doTheTest() !void { + var start: E = undefined; + start = .a; + const result: u32 = s: switch (start) { + .a => continue :s .b, + .b => continue :s .c, + .c => 123, + }; + try expect(result == 123); + } + }; + try S.doTheTest(); + try comptime S.doTheTest(); +} + +test "switch loop on tagged union" { + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; // TODO + + const S = struct { + const U = union(enum) { + a: u32, + b: f32, + c: f32, + }; + + fn doTheTest() !void { + var start: U = undefined; + start = .{ .a = 80 }; + const result = s: switch (start) { 
+ .a => |x| switch (x) { + 0...49 => continue :s .{ .b = @floatFromInt(x) }, + 50 => continue :s .{ .c = @floatFromInt(x) }, + else => continue :s .{ .a = x / 2 }, + }, + .b => |x| x, + .c => return error.TestFailed, + }; + try expect(result == 40.0); + } + }; + try S.doTheTest(); + try comptime S.doTheTest(); +} + +test "switch loop dispatching instructions" { + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; // TODO + + const S = struct { + const Inst = union(enum) { + set: u32, + add: u32, + sub: u32, + end, + }; + + fn doTheTest() !void { + var insts: [5]Inst = undefined; + @memcpy(&insts, &[5]Inst{ + .{ .set = 123 }, + .{ .add = 100 }, + .{ .sub = 50 }, + .{ .sub = 10 }, + .end, + }); + var i: u32 = 0; + var cur: u32 = undefined; + eval: switch (insts[0]) { + .set => |x| { + cur = x; + i += 1; + continue :eval insts[i]; + }, + .add => |x| { + cur += x; + i += 1; + continue :eval insts[i]; + }, + .sub => |x| { + cur -= x; + i += 1; + continue :eval insts[i]; + }, + .end => {}, + } + try expect(cur == 163); + } + }; + try S.doTheTest(); + try comptime S.doTheTest(); +} + +test "switch loop with pointer capture" { + if (builtin.zig_backend == .stage2_wasm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_x86_64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_sparc64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_spirv64) return 
error.SkipZigTest; // TODO + + const S = struct { + const U = union(enum) { + a: u32, + b: u32, + c: u32, + }; + + fn doTheTest() !void { + var a: U = .{ .a = 100 }; + var b: U = .{ .b = 200 }; + var c: U = .{ .c = 300 }; + inc: switch (a) { + .a => |*x| { + x.* += 1; + continue :inc b; + }, + .b => |*x| { + x.* += 10; + continue :inc c; + }, + .c => |*x| { + x.* += 50; + }, + } + try expect(a.a == 101); + try expect(b.b == 210); + try expect(c.c == 350); + } + }; + try S.doTheTest(); + try comptime S.doTheTest(); +} From aaa83b3e17dd32cec36da437e273373defd23a8e Mon Sep 17 00:00:00 2001 From: mlugg Date: Mon, 29 Apr 2024 02:25:54 +0100 Subject: [PATCH 5/9] std.zig: resolve syntactic ambiguity The parse of `fn foo(a: switch (...) { ... })` was previously handled incorrectly; `a` was treated as both the parameter name and a label. The same issue exists for `for` and `while` expressions -- they should be fixed too, and the grammar amended appropriately. This commit does not do this: it only aims to avoid introducing regressions from labeled switch syntax. 
--- lib/std/zig/Ast.zig | 20 +++++++++++++++----- lib/std/zig/Parse.zig | 18 +++++++++--------- 2 files changed, 24 insertions(+), 14 deletions(-) diff --git a/lib/std/zig/Ast.zig b/lib/std/zig/Ast.zig index a12f49178232..daa036fee6b0 100644 --- a/lib/std/zig/Ast.zig +++ b/lib/std/zig/Ast.zig @@ -1903,11 +1903,20 @@ pub fn taggedUnionEnumTag(tree: Ast, node: Node.Index) full.ContainerDecl { pub fn switchFull(tree: Ast, node: Node.Index) full.Switch { const data = &tree.nodes.items(.data)[node]; - return tree.fullSwitchComponents(.{ - .switch_token = tree.nodes.items(.main_token)[node], - .condition = data.lhs, - .sub_range = data.rhs, - }); + const main_token = tree.nodes.items(.main_token)[node]; + const switch_token: TokenIndex, const label_token: ?TokenIndex = switch (tree.tokens.items(.tag)[main_token]) { + .identifier => .{ main_token + 2, main_token }, + .keyword_switch => .{ main_token, null }, + else => unreachable, + }; + return .{ + .ast = .{ + .switch_token = switch_token, + .condition = data.lhs, + .sub_range = data.rhs, + }, + .label_token = label_token, + }; } pub fn switchCaseOne(tree: Ast, node: Node.Index) full.SwitchCase { @@ -3285,6 +3294,7 @@ pub const Node = struct { /// main_token is the `(`. async_call_comma, /// `switch(lhs) {}`. `SubRange[rhs]`. + /// `main_token` is the identifier of a preceding label, if any; otherwise `switch`. 
@"switch", /// Same as switch except there is known to be a trailing comma /// before the final rbrace diff --git a/lib/std/zig/Parse.zig b/lib/std/zig/Parse.zig index 55fbd95c4e1e..483b397c1db1 100644 --- a/lib/std/zig/Parse.zig +++ b/lib/std/zig/Parse.zig @@ -1245,7 +1245,7 @@ fn parseLabeledStatement(p: *Parse) !Node.Index { const loop_stmt = try p.parseLoopStatement(); if (loop_stmt != 0) return loop_stmt; - const switch_expr = try p.parseSwitchExpr(); + const switch_expr = try p.parseSwitchExpr(label_token != 0); if (switch_expr != 0) return switch_expr; if (label_token != 0) { @@ -2699,7 +2699,7 @@ fn parsePrimaryTypeExpr(p: *Parse) !Node.Index { .builtin => return p.parseBuiltinCall(), .keyword_fn => return p.parseFnProto(), .keyword_if => return p.parseIf(expectTypeExpr), - .keyword_switch => return p.expectSwitchExpr(), + .keyword_switch => return p.expectSwitchExpr(false), .keyword_extern, .keyword_packed, @@ -2756,7 +2756,7 @@ fn parsePrimaryTypeExpr(p: *Parse) !Node.Index { }, .keyword_switch => { p.tok_i += 2; - return p.expectSwitchExpr(); + return p.expectSwitchExpr(true); }, .l_brace => { p.tok_i += 2; @@ -3034,17 +3034,17 @@ fn parseWhileTypeExpr(p: *Parse) !Node.Index { } /// SwitchExpr <- KEYWORD_switch LPAREN Expr RPAREN LBRACE SwitchProngList RBRACE -fn parseSwitchExpr(p: *Parse) !Node.Index { +fn parseSwitchExpr(p: *Parse, is_labeled: bool) !Node.Index { const switch_token = p.eatToken(.keyword_switch) orelse return null_node; - return p.expectSwitchSuffix(switch_token); + return p.expectSwitchSuffix(if (is_labeled) switch_token - 2 else switch_token); } -fn expectSwitchExpr(p: *Parse) !Node.Index { +fn expectSwitchExpr(p: *Parse, is_labeled: bool) !Node.Index { const switch_token = p.assertToken(.keyword_switch); - return p.expectSwitchSuffix(switch_token); + return p.expectSwitchSuffix(if (is_labeled) switch_token - 2 else switch_token); } -fn expectSwitchSuffix(p: *Parse, switch_token: TokenIndex) !Node.Index { +fn expectSwitchSuffix(p: 
*Parse, main_token: TokenIndex) !Node.Index { _ = try p.expectToken(.l_paren); const expr_node = try p.expectExpr(); _ = try p.expectToken(.r_paren); @@ -3055,7 +3055,7 @@ fn expectSwitchSuffix(p: *Parse, switch_token: TokenIndex) !Node.Index { return p.addNode(.{ .tag = if (trailing_comma) .switch_comma else .@"switch", - .main_token = switch_token, + .main_token = main_token, .data = .{ .lhs = expr_node, .rhs = try p.addExtra(Node.SubRange{ From 750b804210be0eec963f46578e18574ba8953858 Mon Sep 17 00:00:00 2001 From: mlugg Date: Mon, 29 Apr 2024 23:31:41 +0100 Subject: [PATCH 6/9] x86_64: un-regress `loop` and `switch_br` This commit fixes codegen for the `loop` and `switch_br` instructions in the self-hosted x86_64 backend, which was previously regressed by this branch. It does *not* yet implement the new `loop_switch_br` instruction. --- src/arch/x86_64/CodeGen.zig | 134 +++++++++++++++++++++++++++--------- 1 file changed, 100 insertions(+), 34 deletions(-) diff --git a/src/arch/x86_64/CodeGen.zig b/src/arch/x86_64/CodeGen.zig index 64bf40af75c9..90edf3706e82 100644 --- a/src/arch/x86_64/CodeGen.zig +++ b/src/arch/x86_64/CodeGen.zig @@ -106,6 +106,13 @@ frame_allocs: std.MultiArrayList(FrameAlloc) = .{}, free_frame_indices: std.AutoArrayHashMapUnmanaged(FrameIndex, void) = .{}, frame_locs: std.MultiArrayList(Mir.FrameLoc) = .{}, +loop_repeat_info: std.AutoHashMapUnmanaged(Air.Inst.Index, struct { + /// The state to restore before branching. + state: State, + /// The branch target. + jmp_target: Mir.Inst.Index, +}) = .{}, + /// Debug field, used to find bugs in the compiler. 
air_bookkeeping: @TypeOf(air_bookkeeping_init) = air_bookkeeping_init, @@ -811,7 +818,7 @@ pub fn generate( const namespace = zcu.namespacePtr(fn_owner_decl.src_namespace); const mod = namespace.file_scope.mod; - var function = Self{ + var function: Self = .{ .gpa = gpa, .air = air, .liveness = liveness, @@ -835,6 +842,7 @@ pub fn generate( function.frame_allocs.deinit(gpa); function.free_frame_indices.deinit(gpa); function.frame_locs.deinit(gpa); + function.loop_repeat_info.deinit(gpa); var block_it = function.blocks.valueIterator(); while (block_it.next()) |block| block.deinit(gpa); function.blocks.deinit(gpa); @@ -2038,7 +2046,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .bitcast => try self.airBitCast(inst), .block => try self.airBlock(inst), .br => try self.airBr(inst), - .repeat => return self.fail("TODO implement `repeat`", .{}), + .repeat => try self.airRepeat(inst), .switch_dispatch => return self.fail("TODO implement `switch_dispatch`", .{}), .trap => try self.airTrap(), .breakpoint => try self.airBreakpoint(), @@ -13395,16 +13403,13 @@ fn airLoop(self: *Self, inst: Air.Inst.Index) !void { self.scope_generation += 1; const state = try self.saveState(); - const jmp_target: Mir.Inst.Index = @intCast(self.mir_instructions.len); - try self.genBody(body); - try self.restoreState(state, &.{}, .{ - .emit_instructions = true, - .update_tracking = false, - .resurrect = false, - .close_scope = true, + try self.loop_repeat_info.putNoClobber(self.gpa, inst, .{ + .state = state, + .jmp_target = @intCast(self.mir_instructions.len), }); - _ = try self.asmJmpReloc(jmp_target); + defer assert(self.loop_repeat_info.remove(inst)); + try self.genBody(body); self.finishAirBookkeeping(); } @@ -13446,13 +13451,20 @@ fn lowerBlock(self: *Self, inst: Air.Inst.Index, body: []const Air.Inst.Index) ! 
} fn airSwitchBr(self: *Self, inst: Air.Inst.Index) !void { + const zcu = self.bin_file.comp.module.?; const pl_op = self.air.instructions.items(.data)[@intFromEnum(inst)].pl_op; const condition = try self.resolveInst(pl_op.operand); const condition_ty = self.typeOf(pl_op.operand); - const switch_br = self.air.extraData(Air.SwitchBr, pl_op.payload); - var extra_index: usize = switch_br.end; - var case_i: u32 = 0; - const liveness = try self.liveness.getSwitchBr(self.gpa, inst, switch_br.data.cases_len + 1); + + const signedness = switch (condition_ty.zigTypeTag(zcu)) { + .Bool, .Pointer => .unsigned, + .Int, .Enum, .ErrorSet => condition_ty.intInfo(zcu).signedness, + else => unreachable, + }; + + var switch_it = self.air.switchIterator(inst); + + const liveness = try self.liveness.getSwitchBr(self.gpa, inst, switch_it.total_cases + 1); defer self.gpa.free(liveness.deaths); // If the condition dies here in this switch instruction, process @@ -13465,20 +13477,12 @@ fn airSwitchBr(self: *Self, inst: Air.Inst.Index) !void { self.scope_generation += 1; const state = try self.saveState(); - while (case_i < switch_br.data.cases_len) : (case_i += 1) { - const case = self.air.extraData(Air.SwitchBr.Case, extra_index); - if (case.data.ranges_len > 0) return self.fail("TODO: switch with ranges", .{}); - const items: []const Air.Inst.Ref = - @ptrCast(self.air.extra[case.end..][0..case.data.items_len]); - const case_body: []const Air.Inst.Index = - @ptrCast(self.air.extra[case.end + items.len ..][0..case.data.body_len]); - extra_index = case.end + items.len + case_body.len; - - var relocs = try self.gpa.alloc(Mir.Inst.Index, items.len); + while (switch_it.nextCase()) |case| { + const relocs = try self.gpa.alloc(Mir.Inst.Index, case.items.len + case.ranges.len); defer self.gpa.free(relocs); try self.spillEflagsIfOccupied(); - for (items, relocs, 0..) 
|item, *reloc, i| { + for (case.items, relocs[0..case.items.len]) |item, *reloc| { const item_mcv = try self.resolveInst(item); const cc: Condition = switch (condition) { .eflags => |cc| switch (item_mcv.immediate) { @@ -13491,13 +13495,63 @@ fn airSwitchBr(self: *Self, inst: Air.Inst.Index) !void { break :cc .e; }, }; - reloc.* = try self.asmJccReloc(if (i < relocs.len - 1) cc else cc.negate(), undefined); + reloc.* = try self.asmJccReloc(cc, undefined); } - for (liveness.deaths[case_i]) |operand| try self.processDeath(operand); + for (case.ranges, relocs[case.items.len..]) |range, *reloc| { + const min_mcv = try self.resolveInst(range[0]); + const max_mcv = try self.resolveInst(range[1]); + // `null` means always false. + const lt_min: ?Condition = switch (condition) { + .eflags => |cc| switch (min_mcv.immediate) { + 0 => null, // condition never <0 + 1 => cc.negate(), + else => unreachable, + }, + else => cc: { + try self.genBinOpMir(.{ ._, .cmp }, condition_ty, condition, min_mcv); + break :cc switch (signedness) { + .unsigned => .b, + .signed => .l, + }; + }, + }; + const lt_min_reloc = if (lt_min) |cc| r: { + break :r try self.asmJccReloc(cc, undefined); + } else null; + // `null` means always true. + const lte_max: ?Condition = switch (condition) { + .eflags => |cc| switch (max_mcv.immediate) { + 0 => cc.negate(), + 1 => null, // condition always <=1 + else => unreachable, + }, + else => cc: { + try self.genBinOpMir(.{ ._, .cmp }, condition_ty, condition, max_mcv); + break :cc switch (signedness) { + .unsigned => .be, + .signed => .le, + }; + }, + }; + // "Success" case is in `reloc`... + if (lte_max) |cc| { + reloc.* = try self.asmJccReloc(cc, undefined); + } else { + reloc.* = try self.asmJmpReloc(undefined); + } + // ...and "fail" case falls through to next checks. + if (lt_min_reloc) |r| self.performReloc(r); + } + + // The jump to skip this case if the conditions all failed. + const skip_case_reloc = try self.asmJmpReloc(undefined); - for (relocs[0 .. 
relocs.len - 1]) |reloc| self.performReloc(reloc); - try self.genBody(case_body); + for (liveness.deaths[case.index]) |operand| try self.processDeath(operand); + + // Relocate all success cases to the body we're about to generate. + for (relocs) |reloc| self.performReloc(reloc); + try self.genBody(case.body); try self.restoreState(state, &.{}, .{ .emit_instructions = false, .update_tracking = true, @@ -13505,13 +13559,12 @@ fn airSwitchBr(self: *Self, inst: Air.Inst.Index) !void { .close_scope = true, }); - self.performReloc(relocs[relocs.len - 1]); + // Relocate the "skip" branch to fall through to the next case. + self.performReloc(skip_case_reloc); } - if (switch_br.data.else_body_len > 0) { - const else_body: []const Air.Inst.Index = - @ptrCast(self.air.extra[extra_index..][0..switch_br.data.else_body_len]); - + const else_body = switch_it.elseBody(); + if (else_body.len > 0) { const else_deaths = liveness.deaths.len - 1; for (liveness.deaths[else_deaths]) |operand| try self.processDeath(operand); @@ -13602,6 +13655,19 @@ fn airBr(self: *Self, inst: Air.Inst.Index) !void { self.finishAirBookkeeping(); } +fn airRepeat(self: *Self, inst: Air.Inst.Index) !void { + const loop_inst = self.air.instructions.items(.data)[@intFromEnum(inst)].repeat.loop_inst; + const repeat_info = self.loop_repeat_info.get(loop_inst).?; + try self.restoreState(repeat_info.state, &.{}, .{ + .emit_instructions = true, + .update_tracking = false, + .resurrect = false, + .close_scope = true, + }); + _ = try self.asmJmpReloc(repeat_info.jmp_target); + self.finishAirBookkeeping(); +} + fn airAsm(self: *Self, inst: Air.Inst.Index) !void { const mod = self.bin_file.comp.module.?; const ty_pl = self.air.instructions.items(.data)[@intFromEnum(inst)].ty_pl; From 2ed552d693c417d1a1fe3c03e5439dad14246fdb Mon Sep 17 00:00:00 2001 From: mlugg Date: Mon, 29 Apr 2024 23:41:08 +0100 Subject: [PATCH 7/9] wasm: un-regress `loop` --- src/arch/wasm/CodeGen.zig | 21 +++++++++++++++++---- 1 file changed, 17 
insertions(+), 4 deletions(-) diff --git a/src/arch/wasm/CodeGen.zig b/src/arch/wasm/CodeGen.zig index 7fe42c3d7497..c3fd323b00f4 100644 --- a/src/arch/wasm/CodeGen.zig +++ b/src/arch/wasm/CodeGen.zig @@ -663,6 +663,8 @@ blocks: std.AutoArrayHashMapUnmanaged(Air.Inst.Index, struct { label: u32, value: WValue, }) = .{}, +/// Maps `loop` instructions to their label. `br` to here repeats the loop. +loops: std.AutoHashMapUnmanaged(Air.Inst.Index, u32) = .{}, /// `bytes` contains the wasm bytecode belonging to the 'code' section. code: *ArrayList(u8), /// The index the next local generated will have @@ -751,6 +753,7 @@ pub fn deinit(func: *CodeGen) void { } func.branches.deinit(func.gpa); func.blocks.deinit(func.gpa); + func.loops.deinit(func.gpa); func.locals.deinit(func.gpa); func.simd_immediates.deinit(func.gpa); func.mir_instructions.deinit(func.gpa); @@ -1897,7 +1900,7 @@ fn genInst(func: *CodeGen, inst: Air.Inst.Index) InnerError!void { .trap => func.airTrap(inst), .breakpoint => func.airBreakpoint(inst), .br => func.airBr(inst), - .repeat => return func.fail("TODO implement `repeat`", .{}), + .repeat => func.airRepeat(inst), .switch_dispatch => return func.fail("TODO implement `switch_dispatch`", .{}), .int_from_bool => func.airIntFromBool(inst), .cond_br => func.airCondBr(inst), @@ -3500,13 +3503,13 @@ fn airLoop(func: *CodeGen, inst: Air.Inst.Index) InnerError!void { const loop = func.air.extraData(Air.Block, ty_pl.payload); const body: []const Air.Inst.Index = @ptrCast(func.air.extra[loop.end..][0..loop.data.body_len]); + try func.loops.putNoClobber(func.gpa, inst, func.block_depth); + defer assert(func.loops.remove(inst)); + // result type of loop is always 'noreturn', meaning we can always // emit the wasm type 'block_empty'. 
try func.startBlock(.loop, wasm.block_empty); try func.genBody(body); - - // breaking to the index of a loop block will continue the loop instead - try func.addLabel(.br, 0); try func.endBlock(); func.finishAir(inst, .none, &.{}); @@ -3720,6 +3723,16 @@ fn airBr(func: *CodeGen, inst: Air.Inst.Index) InnerError!void { func.finishAir(inst, .none, &.{br.operand}); } +fn airRepeat(func: *CodeGen, inst: Air.Inst.Index) InnerError!void { + const repeat = func.air.instructions.items(.data)[@intFromEnum(inst)].repeat; + const loop_label = func.loops.get(repeat.loop_inst).?; + + const idx: u32 = func.block_depth - loop_label; + try func.addLabel(.br, idx); + + func.finishAir(inst, .none, &.{}); +} + fn airNot(func: *CodeGen, inst: Air.Inst.Index) InnerError!void { const ty_op = func.air.instructions.items(.data)[@intFromEnum(inst)].ty_op; From e962ba962b688648dc1517ede33206c07ff4e0c9 Mon Sep 17 00:00:00 2001 From: Luuk de Gram Date: Wed, 1 May 2024 20:35:09 +0200 Subject: [PATCH 8/9] wasm: un-regress `switch_br` and fix `loop` `.loop` is also a block, so the block_depth being stored for the loop was incorrect. By storing the value *after* block creation, we ensure a correct block_depth to jump back to when receiving `.repeat`. This also un-regresses `switch_br` which now correctly handles ranges within cases. It supports it for both jump tables as well as regular conditional branches. 
--- src/arch/wasm/CodeGen.zig | 114 ++++++++++++++++++++++---------------- 1 file changed, 65 insertions(+), 49 deletions(-) diff --git a/src/arch/wasm/CodeGen.zig b/src/arch/wasm/CodeGen.zig index c3fd323b00f4..5f9e7f55cc94 100644 --- a/src/arch/wasm/CodeGen.zig +++ b/src/arch/wasm/CodeGen.zig @@ -3503,12 +3503,13 @@ fn airLoop(func: *CodeGen, inst: Air.Inst.Index) InnerError!void { const loop = func.air.extraData(Air.Block, ty_pl.payload); const body: []const Air.Inst.Index = @ptrCast(func.air.extra[loop.end..][0..loop.data.body_len]); - try func.loops.putNoClobber(func.gpa, inst, func.block_depth); - defer assert(func.loops.remove(inst)); - // result type of loop is always 'noreturn', meaning we can always // emit the wasm type 'block_empty'. try func.startBlock(.loop, wasm.block_empty); + + try func.loops.putNoClobber(func.gpa, inst, func.block_depth); + defer assert(func.loops.remove(inst)); + try func.genBody(body); try func.endBlock(); @@ -4014,11 +4015,11 @@ fn airSwitchBr(func: *CodeGen, inst: Air.Inst.Index) InnerError!void { const liveness = try func.liveness.getSwitchBr(func.gpa, inst, switch_br.data.cases_len + 1); defer func.gpa.free(liveness.deaths); - var extra_index: usize = switch_br.end; - var case_i: u32 = 0; - // a list that maps each value with its value and body based on the order inside the list. 
- const CaseValue = struct { integer: i32, value: Value }; + const CaseValue = union(enum) { + singular: struct { integer: i32, value: Value }, + range: struct { min: i32, min_value: Value, max: i32, max_value: Value }, + }; var case_list = try std.ArrayList(struct { values: []const CaseValue, body: []const Air.Inst.Index, @@ -4029,16 +4030,12 @@ fn airSwitchBr(func: *CodeGen, inst: Air.Inst.Index) InnerError!void { var lowest_maybe: ?i32 = null; var highest_maybe: ?i32 = null; - while (case_i < switch_br.data.cases_len) : (case_i += 1) { - const case = func.air.extraData(Air.SwitchBr.Case, extra_index); - if (case.data.ranges_len != 0) return func.fail("TODO: switch with ranges", .{}); - const items: []const Air.Inst.Ref = @ptrCast(func.air.extra[case.end..][0..case.data.items_len]); - const case_body: []const Air.Inst.Index = @ptrCast(func.air.extra[case.end + items.len ..][0..case.data.body_len]); - extra_index = case.end + items.len + case_body.len; - const values = try func.gpa.alloc(CaseValue, items.len); + var case_it = func.air.switchIterator(inst); + while (case_it.nextCase()) |case| { + const values = try func.gpa.alloc(CaseValue, case.items.len + case.ranges.len); errdefer func.gpa.free(values); - for (items, 0..) |ref, i| { + for (case.items, 0..) |ref, i| { const item_val = (try func.air.value(ref, mod)).?; const int_val = func.valueAsI32(item_val, target_ty); if (lowest_maybe == null or int_val < lowest_maybe.?) { @@ -4047,10 +4044,33 @@ fn airSwitchBr(func: *CodeGen, inst: Air.Inst.Index) InnerError!void { if (highest_maybe == null or int_val > highest_maybe.?) { highest_maybe = int_val; } - values[i] = .{ .integer = int_val, .value = item_val }; + values[i] = .{ .singular = .{ .integer = int_val, .value = item_val } }; } - case_list.appendAssumeCapacity(.{ .values = values, .body = case_body }); + for (case.ranges, 0..) 
|range, i| { + const min_val = (try func.air.value(range[0], mod)).?; + const int_min_val = func.valueAsI32(min_val, target_ty); + + if (lowest_maybe == null or int_min_val < lowest_maybe.?) { + lowest_maybe = int_min_val; + } + + const max_val = (try func.air.value(range[1], mod)).?; + const int_max_val = func.valueAsI32(max_val, target_ty); + + if (highest_maybe == null or int_max_val > highest_maybe.?) { + highest_maybe = int_max_val; + } + + values[i + case.items.len] = .{ .range = .{ + .min = int_min_val, + .min_value = min_val, + .max = int_max_val, + .max_value = max_val, + } }; + } + + case_list.appendAssumeCapacity(.{ .values = values, .body = case.body }); try func.startBlock(.block, blocktype); } @@ -4064,7 +4084,7 @@ fn airSwitchBr(func: *CodeGen, inst: Air.Inst.Index) InnerError!void { // TODO: Benchmark this to find a proper value, LLVM seems to draw the line at '40~45'. const is_sparse = highest - lowest > 50 or target_ty.bitSize(mod) > 32; - const else_body: []const Air.Inst.Index = @ptrCast(func.air.extra[extra_index..][0..switch_br.data.else_body_len]); + const else_body = case_it.elseBody(); const has_else_body = else_body.len != 0; if (has_else_body) { try func.startBlock(.block, blocktype); @@ -4100,59 +4120,55 @@ fn airSwitchBr(func: *CodeGen, inst: Air.Inst.Index) InnerError!void { const idx = blk: { for (case_list.items, 0..) |case, idx| { for (case.values) |case_value| { - if (case_value.integer == value) break :blk @as(u32, @intCast(idx)); + switch (case_value) { + .singular => |val| if (val.integer == value) break :blk @as(u32, @intCast(idx)), + .range => |range_val| if (value >= range_val.min and value <= range_val.max) { + break :blk @as(u32, @intCast(idx)); + }, + } } } // error sets are almost always sparse so we use the default case // for errors that are not present in any branch. 
This is fine as this default // case will never be hit for those cases but we do save runtime cost and size // by using a jump table for this instead of if-else chains. - break :blk if (has_else_body or target_ty.zigTypeTag(mod) == .ErrorSet) case_i else unreachable; + break :blk if (has_else_body or target_ty.zigTypeTag(mod) == .ErrorSet) case_it.total_cases else unreachable; }; func.mir_extra.appendAssumeCapacity(idx); } else if (has_else_body) { - func.mir_extra.appendAssumeCapacity(case_i); // default branch + func.mir_extra.appendAssumeCapacity(case_it.total_cases); // default branch } try func.endBlock(); } - const signedness: std.builtin.Signedness = blk: { - // by default we tell the operand type is unsigned (i.e. bools and enum values) - if (target_ty.zigTypeTag(mod) != .Int) break :blk .unsigned; - - // incase of an actual integer, we emit the correct signedness - break :blk target_ty.intInfo(mod).signedness; - }; - try func.branches.ensureUnusedCapacity(func.gpa, case_list.items.len + @intFromBool(has_else_body)); for (case_list.items, 0..) |case, index| { // when sparse, we use if/else-chain, so emit conditional checks if (is_sparse) { - // for single value prong we can emit a simple if - if (case.values.len == 1) { - try func.emitWValue(target); - const val = try func.lowerConstant(case.values[0].value, target_ty); - try func.emitWValue(val); - const opcode = buildOpcode(.{ - .valtype1 = typeToValtype(target_ty, mod), - .op = .ne, // not equal, because we want to jump out of this block if it does not match the condition. - .signedness = signedness, - }); - try func.addTag(Mir.Inst.Tag.fromOpcode(opcode)); + // for single value prong we can emit a simple condition + if (case.values.len == 1 and case.values[0] == .singular) { + const val = try func.lowerConstant(case.values[0].singular.value, target_ty); + // not equal, because we want to jump out of this block if it does not match the condition. 
+ _ = try func.cmp(target, val, target_ty, .neq); try func.addLabel(.br_if, 0); } else { // in multi-value prongs we must check if any prongs match the target value. try func.startBlock(.block, blocktype); for (case.values) |value| { - try func.emitWValue(target); - const val = try func.lowerConstant(value.value, target_ty); - try func.emitWValue(val); - const opcode = buildOpcode(.{ - .valtype1 = typeToValtype(target_ty, mod), - .op = .eq, - .signedness = signedness, - }); - try func.addTag(Mir.Inst.Tag.fromOpcode(opcode)); + switch (value) { + .singular => |single_val| { + const val = try func.lowerConstant(single_val.value, target_ty); + _ = try func.cmp(target, val, target_ty, .eq); + }, + .range => |range| { + const min_val = try func.lowerConstant(range.min_value, target_ty); + const max_val = try func.lowerConstant(range.max_value, target_ty); + + const gte = try func.cmp(target, min_val, target_ty, .gte); + const lte = try func.cmp(target, max_val, target_ty, .lte); + _ = try func.binOp(gte, lte, Type.bool, .@"and"); + }, + } try func.addLabel(.br_if, 0); } // value did not match any of the prong values From ac6e2e762da13b5729ff81b8cc86222bc2237dfb Mon Sep 17 00:00:00 2001 From: mlugg Date: Tue, 30 Apr 2024 19:20:03 +0100 Subject: [PATCH 9/9] std.zig.render: fix switch rendering --- lib/std/zig/Ast.zig | 5 ++-- lib/std/zig/AstGen.zig | 5 ++-- lib/std/zig/render.zig | 57 ++++++++++++++++++------------------------ 3 files changed, 29 insertions(+), 38 deletions(-) diff --git a/lib/std/zig/Ast.zig b/lib/std/zig/Ast.zig index daa036fee6b0..5562dafc3114 100644 --- a/lib/std/zig/Ast.zig +++ b/lib/std/zig/Ast.zig @@ -1909,11 +1909,12 @@ pub fn switchFull(tree: Ast, node: Node.Index) full.Switch { .keyword_switch => .{ main_token, null }, else => unreachable, }; + const extra = tree.extraData(data.rhs, Ast.Node.SubRange); return .{ .ast = .{ .switch_token = switch_token, .condition = data.lhs, - .sub_range = data.rhs, + .cases = 
tree.extra_data[extra.start..extra.end], }, .label_token = label_token, }; @@ -2880,7 +2881,7 @@ pub const full = struct { pub const Components = struct { switch_token: TokenIndex, condition: Node.Index, - sub_range: Node.Index, + cases: []const Node.Index, }; }; diff --git a/lib/std/zig/AstGen.zig b/lib/std/zig/AstGen.zig index e4987b05e360..bb2298c70787 100644 --- a/lib/std/zig/AstGen.zig +++ b/lib/std/zig/AstGen.zig @@ -7621,9 +7621,8 @@ fn switchExpr( const node_tags = tree.nodes.items(.tag); const main_tokens = tree.nodes.items(.main_token); const token_tags = tree.tokens.items(.tag); - const operand_node = node_datas[node].lhs; - const extra = tree.extraData(node_datas[node].rhs, Ast.Node.SubRange); - const case_nodes = tree.extra_data[extra.start..extra.end]; + const operand_node = switch_full.ast.condition; + const case_nodes = switch_full.ast.cases; const need_rl = astgen.nodes_need_rl.contains(node); const block_ri: ResultInfo = if (need_rl) ri else .{ diff --git a/lib/std/zig/render.zig b/lib/std/zig/render.zig index 336f39421137..3151de130827 100644 --- a/lib/std/zig/render.zig +++ b/lib/std/zig/render.zig @@ -693,39 +693,27 @@ fn renderExpression(r: *Render, node: Ast.Node.Index, space: Space) Error!void { return renderToken(r, datas[node].rhs, space); }, - .@"break" => { + .@"break", .@"continue" => { const main_token = main_tokens[node]; const label_token = datas[node].lhs; const target = datas[node].rhs; if (label_token == 0 and target == 0) { - try renderToken(r, main_token, space); // break keyword + try renderToken(r, main_token, space); // break/continue } else if (label_token == 0 and target != 0) { - try renderToken(r, main_token, .space); // break keyword + try renderToken(r, main_token, .space); // break/continue try renderExpression(r, target, space); } else if (label_token != 0 and target == 0) { - try renderToken(r, main_token, .space); // break keyword - try renderToken(r, label_token - 1, .none); // colon + try renderToken(r, 
main_token, .space); // break/continue + try renderToken(r, label_token - 1, .none); // : try renderIdentifier(r, label_token, space, .eagerly_unquote); // identifier } else if (label_token != 0 and target != 0) { - try renderToken(r, main_token, .space); // break keyword - try renderToken(r, label_token - 1, .none); // colon + try renderToken(r, main_token, .space); // break/continue + try renderToken(r, label_token - 1, .none); // : try renderIdentifier(r, label_token, .space, .eagerly_unquote); // identifier try renderExpression(r, target, space); } }, - .@"continue" => { - const main_token = main_tokens[node]; - const label = datas[node].lhs; - if (label != 0) { - try renderToken(r, main_token, .space); // continue - try renderToken(r, label - 1, .none); // : - return renderIdentifier(r, label, space, .eagerly_unquote); // label - } else { - return renderToken(r, main_token, space); // continue - } - }, - .@"return" => { if (datas[node].lhs != 0) { try renderToken(r, main_tokens[node], .space); @@ -845,26 +833,29 @@ fn renderExpression(r: *Render, node: Ast.Node.Index, space: Space) Error!void { .@"switch", .switch_comma, => { - const switch_token = main_tokens[node]; - const condition = datas[node].lhs; - const extra = tree.extraData(datas[node].rhs, Ast.Node.SubRange); - const cases = tree.extra_data[extra.start..extra.end]; - const rparen = tree.lastToken(condition) + 1; + const full = tree.switchFull(node); - try renderToken(r, switch_token, .space); // switch keyword - try renderToken(r, switch_token + 1, .none); // lparen - try renderExpression(r, condition, .none); // condition expression - try renderToken(r, rparen, .space); // rparen + if (full.label_token) |label_token| { + try renderIdentifier(r, label_token, .none, .eagerly_unquote); // label + try renderToken(r, label_token + 1, .space); // : + } + + const rparen = tree.lastToken(full.ast.condition) + 1; + + try renderToken(r, full.ast.switch_token, .space); // switch + try renderToken(r, 
full.ast.switch_token + 1, .none); // ( + try renderExpression(r, full.ast.condition, .none); // condition expression + try renderToken(r, rparen, .space); // ) ais.pushIndentNextLine(); - if (cases.len == 0) { - try renderToken(r, rparen + 1, .none); // lbrace + if (full.ast.cases.len == 0) { + try renderToken(r, rparen + 1, .none); // { } else { - try renderToken(r, rparen + 1, .newline); // lbrace - try renderExpressions(r, cases, .comma); + try renderToken(r, rparen + 1, .newline); // { + try renderExpressions(r, full.ast.cases, .comma); } ais.popIndent(); - return renderToken(r, tree.lastToken(node), space); // rbrace + return renderToken(r, tree.lastToken(node), space); // } }, .switch_case_one,