diff --git a/lib/std/zig/Ast.zig b/lib/std/zig/Ast.zig index 20bdba8cf796..a12f49178232 100644 --- a/lib/std/zig/Ast.zig +++ b/lib/std/zig/Ast.zig @@ -1197,14 +1197,7 @@ pub fn lastToken(tree: Ast, node: Node.Index) TokenIndex { n = extra.sentinel; }, - .@"continue" => { - if (datas[n].lhs != 0) { - return datas[n].lhs + end_offset; - } else { - return main_tokens[n] + end_offset; - } - }, - .@"break" => { + .@"continue", .@"break" => { if (datas[n].rhs != 0) { n = datas[n].rhs; } else if (datas[n].lhs != 0) { @@ -1908,6 +1901,15 @@ pub fn taggedUnionEnumTag(tree: Ast, node: Node.Index) full.ContainerDecl { }); } +pub fn switchFull(tree: Ast, node: Node.Index) full.Switch { + const data = &tree.nodes.items(.data)[node]; + return tree.fullSwitchComponents(.{ + .switch_token = tree.nodes.items(.main_token)[node], + .condition = data.lhs, + .sub_range = data.rhs, + }); +} + pub fn switchCaseOne(tree: Ast, node: Node.Index) full.SwitchCase { const data = &tree.nodes.items(.data)[node]; const values: *[1]Node.Index = &data.lhs; @@ -2217,6 +2219,21 @@ fn fullContainerDeclComponents(tree: Ast, info: full.ContainerDecl.Components) f return result; } +fn fullSwitchComponents(tree: Ast, info: full.Switch.Components) full.Switch { + const token_tags = tree.tokens.items(.tag); + const tok_i = info.switch_token -| 1; + var result: full.Switch = .{ + .ast = info, + .label_token = null, + }; + if (token_tags[tok_i] == .colon and + token_tags[tok_i -| 1] == .identifier) + { + result.label_token = tok_i - 1; + } + return result; +} + fn fullSwitchCaseComponents(tree: Ast, info: full.SwitchCase.Components, node: Node.Index) full.SwitchCase { const token_tags = tree.tokens.items(.tag); const node_tags = tree.nodes.items(.tag); @@ -2488,6 +2505,13 @@ pub fn fullContainerDecl(tree: Ast, buffer: *[2]Ast.Node.Index, node: Node.Index }; } +pub fn fullSwitch(tree: Ast, node: Node.Index) ?full.Switch { + return switch (tree.nodes.items(.tag)[node]) { + .@"switch", .switch_comma => tree.switchFull(node), + else => null, + }; +} + pub fn fullSwitchCase(tree: Ast, node: Node.Index) ?full.SwitchCase { return switch (tree.nodes.items(.tag)[node]) { .switch_case_one, .switch_case_inline_one => tree.switchCaseOne(node), @@ -2840,6 +2864,17 @@ pub const full = struct { }; }; + pub const Switch = struct { + ast: Components, + label_token: ?TokenIndex, + + pub const Components = struct { + switch_token: TokenIndex, + condition: Node.Index, + sub_range: Node.Index, + }; + }; + pub const SwitchCase = struct { inline_token: ?TokenIndex, /// Points to the first token after the `|`. Will either be an identifier or @@ -3294,7 +3329,8 @@ pub const Node = struct { @"suspend", /// `resume lhs`. rhs is unused. @"resume", - /// `continue`. lhs is token index of label if any. rhs is unused. + /// `continue :lhs rhs` + /// both lhs and rhs may be omitted. @"continue", /// `break :lhs rhs` /// both lhs and rhs may be omitted. diff --git a/lib/std/zig/AstGen.zig b/lib/std/zig/AstGen.zig index 2cab0fe7ca91..3da00e9b5800 100644 --- a/lib/std/zig/AstGen.zig +++ b/lib/std/zig/AstGen.zig @@ -1138,7 +1138,7 @@ fn expr(gz: *GenZir, scope: *Scope, ri: ResultInfo, node: Ast.Node.Index) InnerE .error_set_decl => return errorSetDecl(gz, ri, node), .array_access => return arrayAccess(gz, scope, ri, node), .@"comptime" => return comptimeExprAst(gz, scope, ri, node), - .@"switch", .switch_comma => return switchExpr(gz, scope, ri.br(), node), + .@"switch", .switch_comma => return switchExpr(gz, scope, ri.br(), node, tree.fullSwitch(node).?), .@"nosuspend" => return nosuspendExpr(gz, scope, ri, node), .@"suspend" => return suspendExpr(gz, scope, node), @@ -2226,6 +2226,7 @@ fn continueExpr(parent_gz: *GenZir, parent_scope: *Scope, node: Ast.Node.Index) const tree = astgen.tree; const node_datas = tree.nodes.items(.data); const break_label = node_datas[node].lhs; + const rhs = node_datas[node].rhs; // Look for the label in the scope. var scope = parent_scope; @@ -2250,6 +2251,17 @@ fn continueExpr(parent_gz: *GenZir, parent_scope: *Scope, node: Ast.Node.Index) if (break_label != 0) blk: { if (gen_zir.label) |*label| { if (try astgen.tokenIdentEql(label.token, break_label)) { + const maybe_switch_tag = astgen.instructions.items(.tag)[@intFromEnum(label.block_inst)]; + if (rhs != 0) switch (maybe_switch_tag) { + .switch_block, .switch_block_ref => { + gen_zir.any_dispatch = true; + }, + else => return astgen.failNode(node, "cannot continue with operand", .{}), + } else switch (maybe_switch_tag) { + .switch_block, .switch_block_ref => return astgen.failNode(node, "cannot continue switch without operand", .{}), + else => {}, + } + label.used = true; break :blk; } @@ -2259,6 +2271,17 @@ fn continueExpr(parent_gz: *GenZir, parent_scope: *Scope, node: Ast.Node.Index) continue; } + if (rhs != 0) { + const operand = try reachableExpr(parent_gz, parent_scope, gen_zir.break_result_info, rhs, node); + + // As our last action before the continue, "pop" the error trace if needed + if (!gen_zir.is_comptime) + _ = try parent_gz.addRestoreErrRetIndex(.{ .block = continue_block }, .always, node); + + _ = try parent_gz.addBreakWithSrcNode(.switch_continue, continue_block, operand, rhs); + return Zir.Inst.Ref.unreachable_value; + } + const break_tag: Zir.Inst.Tag = if (gen_zir.is_inline) .break_inline else @@ -2842,6 +2865,7 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As .panic, .trap, .check_comptime_control_flow, + .switch_continue, => { noreturn_src_node = statement; break :b true; @@ -7568,7 +7592,8 @@ fn switchExpr( parent_gz: *GenZir, scope: *Scope, ri: ResultInfo, - switch_node: Ast.Node.Index, + node: Ast.Node.Index, + switch_full: Ast.full.Switch, ) InnerError!Zir.Inst.Ref { const astgen = parent_gz.astgen; const gpa = astgen.gpa; @@ -7577,14 +7602,14 @@ fn switchExpr( const node_tags = tree.nodes.items(.tag); const main_tokens = tree.nodes.items(.main_token); const token_tags = tree.tokens.items(.tag); - const operand_node = node_datas[switch_node].lhs; - const extra = tree.extraData(node_datas[switch_node].rhs, Ast.Node.SubRange); + const operand_node = node_datas[node].lhs; + const extra = tree.extraData(node_datas[node].rhs, Ast.Node.SubRange); const case_nodes = tree.extra_data[extra.start..extra.end]; - const need_rl = astgen.nodes_need_rl.contains(switch_node); + const need_rl = astgen.nodes_need_rl.contains(node); const block_ri: ResultInfo = if (need_rl) ri else .{ .rl = switch (ri.rl) { - .ptr => .{ .ty = (try ri.rl.resultType(parent_gz, switch_node)).? }, + .ptr => .{ .ty = (try ri.rl.resultType(parent_gz, node)).? }, .inferred_ptr => .none, else => ri.rl, }, @@ -7595,6 +7620,10 @@ fn switchExpr( const LocTag = @typeInfo(ResultInfo.Loc).Union.tag_type.?; const need_result_rvalue = @as(LocTag, block_ri.rl) != @as(LocTag, ri.rl); + if (switch_full.label_token) |label_token| { + try astgen.checkLabelRedefinition(scope, label_token); + } + // We perform two passes over the AST. This first pass is to collect information // for the following variables, make note of the special prong AST node index, // and bail out with a compile error if there are multiple special prongs present. @@ -7636,7 +7665,7 @@ fn switchExpr( ); } else if (underscore_src) |some_underscore| { return astgen.failNodeNotes( - switch_node, + node, "else and '_' prong in switch expression", .{}, &[_]u32{ @@ -7677,7 +7706,7 @@ fn switchExpr( ); } else if (else_src) |some_else| { return astgen.failNodeNotes( - switch_node, + node, "else and '_' prong in switch expression", .{}, &[_]u32{ @@ -7747,7 +7776,15 @@ fn switchExpr( try emitDbgStmtForceCurrentIndex(parent_gz, operand_lc); // This gets added to the parent block later, after the item expressions. const switch_tag: Zir.Inst.Tag = if (any_payload_is_ref) .switch_block_ref else .switch_block; - const switch_block = try parent_gz.makeBlockInst(switch_tag, switch_node); + const switch_block = try parent_gz.makeBlockInst(switch_tag, node); + + block_scope.continue_block = switch_block.toOptional(); + if (switch_full.label_token) |label_token| { + block_scope.label = .{ + .token = label_token, + .block_inst = switch_block, + }; + } // We re-use this same scope for all cases, including the special prong, if any. var case_scope = parent_gz.makeSubBlock(&block_scope.base); @@ -7969,6 +8006,7 @@ fn switchExpr( .has_under = special_prong == .under, .any_has_tag_capture = any_has_tag_capture, .scalar_cases_len = @intCast(scalar_cases_len), + .any_dispatch = block_scope.any_dispatch, }, }); @@ -8004,7 +8042,7 @@ fn switchExpr( } if (need_result_rvalue) { - return rvalue(parent_gz, ri, switch_block.toRef(), switch_node); + return rvalue(parent_gz, ri, switch_block.toRef(), node); } else { return switch_block.toRef(); } @@ -11872,6 +11910,7 @@ const GenZir = struct { cur_defer_node: Ast.Node.Index = 0, // Set if this GenZir is a defer or it is inside a defer. any_defer_node: Ast.Node.Index = 0, + any_dispatch: bool = false, const unstacked_top = std.math.maxInt(usize); /// Call unstack before adding any new instructions to containing GenZir. diff --git a/lib/std/zig/Parse.zig b/lib/std/zig/Parse.zig index 369e2ef125b4..55fbd95c4e1e 100644 --- a/lib/std/zig/Parse.zig +++ b/lib/std/zig/Parse.zig @@ -924,7 +924,6 @@ fn expectContainerField(p: *Parse) !Node.Index { /// / KEYWORD_errdefer Payload? BlockExprStatement /// / IfStatement /// / LabeledStatement -/// / SwitchExpr /// / VarDeclExprStatement fn expectStatement(p: *Parse, allow_defer_var: bool) Error!Node.Index { if (p.eatToken(.keyword_comptime)) |comptime_token| { @@ -995,7 +994,6 @@ fn expectStatement(p: *Parse, allow_defer_var: bool) Error!Node.Index { .rhs = try p.expectBlockExprStatement(), }, }), - .keyword_switch => return p.expectSwitchExpr(), .keyword_if => return p.expectIfStatement(), .keyword_enum, .keyword_struct, .keyword_union => { const identifier = p.tok_i + 1; @@ -1238,7 +1236,7 @@ fn expectIfStatement(p: *Parse) !Node.Index { }); } -/// LabeledStatement <- BlockLabel? (Block / LoopStatement) +/// LabeledStatement <- BlockLabel? (Block / LoopStatement / SwitchExpr) fn parseLabeledStatement(p: *Parse) !Node.Index { const label_token = p.parseBlockLabel(); const block = try p.parseBlock(); @@ -1247,6 +1245,9 @@ fn parseLabeledStatement(p: *Parse) !Node.Index { const loop_stmt = try p.parseLoopStatement(); if (loop_stmt != 0) return loop_stmt; + const switch_expr = try p.parseSwitchExpr(); + if (switch_expr != 0) return switch_expr; + if (label_token != 0) { const after_colon = p.tok_i; const node = try p.parseTypeExpr(); @@ -2072,7 +2073,7 @@ fn expectTypeExpr(p: *Parse) Error!Node.Index { /// / KEYWORD_break BreakLabel? Expr? /// / KEYWORD_comptime Expr /// / KEYWORD_nosuspend Expr -/// / KEYWORD_continue BreakLabel? +/// / KEYWORD_continue BreakLabel? Expr? /// / KEYWORD_resume Expr /// / KEYWORD_return Expr? /// / BlockLabel? LoopExpr @@ -2098,7 +2099,7 @@ fn parsePrimaryExpr(p: *Parse) !Node.Index { .main_token = p.nextToken(), .data = .{ .lhs = try p.parseBreakLabel(), - .rhs = undefined, + .rhs = try p.parseExpr(), }, }); }, @@ -2627,7 +2628,6 @@ fn parseSuffixExpr(p: *Parse) !Node.Index { /// / KEYWORD_anyframe /// / KEYWORD_unreachable /// / STRINGLITERAL -/// / SwitchExpr /// /// ContainerDecl <- (KEYWORD_extern / KEYWORD_packed)? ContainerDeclAuto /// @@ -2647,6 +2647,7 @@ fn parseSuffixExpr(p: *Parse) !Node.Index { /// LabeledTypeExpr /// <- BlockLabel Block /// / BlockLabel? LoopTypeExpr +/// / BlockLabel? SwitchExpr /// /// LoopTypeExpr <- KEYWORD_inline? (ForTypeExpr / WhileTypeExpr) fn parsePrimaryTypeExpr(p: *Parse) !Node.Index { @@ -2753,6 +2754,10 @@ fn parsePrimaryTypeExpr(p: *Parse) !Node.Index { p.tok_i += 2; return p.parseWhileTypeExpr(); }, + .keyword_switch => { + p.tok_i += 2; + return p.expectSwitchExpr(); + }, .l_brace => { p.tok_i += 2; return p.parseBlock(); @@ -3029,8 +3034,17 @@ fn parseWhileTypeExpr(p: *Parse) !Node.Index { } /// SwitchExpr <- KEYWORD_switch LPAREN Expr RPAREN LBRACE SwitchProngList RBRACE +fn parseSwitchExpr(p: *Parse) !Node.Index { + const switch_token = p.eatToken(.keyword_switch) orelse return null_node; + return p.expectSwitchSuffix(switch_token); +} + fn expectSwitchExpr(p: *Parse) !Node.Index { const switch_token = p.assertToken(.keyword_switch); + return p.expectSwitchSuffix(switch_token); +} + +fn expectSwitchSuffix(p: *Parse, switch_token: TokenIndex) !Node.Index { _ = try p.expectToken(.l_paren); const expr_node = try p.expectExpr(); _ = try p.expectToken(.r_paren); diff --git a/lib/std/zig/Zir.zig b/lib/std/zig/Zir.zig index 8aa4c0c8c53e..4f048863f628 100644 --- a/lib/std/zig/Zir.zig +++ b/lib/std/zig/Zir.zig @@ -317,6 +317,9 @@ pub const Inst = struct { /// break instruction in a block, and the target block is the parent. /// Uses the `break` union field. break_inline, + /// Branch from within a switch case to the case specified by the operand. + /// Uses the `break` union field. `block_inst` refers to a `switch_block` or `switch_block_ref`. + switch_continue, /// Checks that comptime control flow does not happen inside a runtime block. /// Uses the `un_node` union field. check_comptime_control_flow, @@ -1290,6 +1293,7 @@ pub const Inst = struct { .panic, .trap, .check_comptime_control_flow, + .switch_continue, => true, }; } @@ -1533,6 +1537,7 @@ pub const Inst = struct { .break_inline, .condbr, .condbr_inline, + .switch_continue, .compile_error, .ret_node, .ret_load, @@ -1618,6 +1623,7 @@ pub const Inst = struct { .bool_br_or = .pl_node, .@"break" = .@"break", .break_inline = .@"break", + .switch_continue = .@"break", .check_comptime_control_flow = .un_node, .for_len = .pl_node, .call = .pl_node, @@ -2955,9 +2961,10 @@ pub const Inst = struct { has_under: bool, /// If true, at least one prong has an inline tag capture. any_has_tag_capture: bool, + any_dispatch: bool, scalar_cases_len: ScalarCasesLen, - pub const ScalarCasesLen = u28; + pub const ScalarCasesLen = u27; pub fn specialProng(bits: Bits) SpecialProng { const has_else: u2 = @intFromBool(bits.has_else); diff --git a/lib/std/zig/render.zig b/lib/std/zig/render.zig index c6a6f3ce710d..faa89fb1faee 100644 --- a/lib/std/zig/render.zig +++ b/lib/std/zig/render.zig @@ -693,7 +693,7 @@ fn renderExpression(r: *Render, node: Ast.Node.Index, space: Space) Error!void { return renderToken(r, datas[node].rhs, space); }, - .@"break" => { + .@"break", .@"continue" => { const main_token = main_tokens[node]; const label_token = datas[node].lhs; const target = datas[node].rhs; @@ -714,18 +714,6 @@ fn renderExpression(r: *Render, node: Ast.Node.Index, space: Space) Error!void { } }, - .@"continue" => { - const main_token = main_tokens[node]; - const label = datas[node].lhs; - if (label != 0) { - try renderToken(r, main_token, .space); // continue - try renderToken(r, label - 1, .none); // : - return renderIdentifier(r, label, space, .eagerly_unquote); // label - } else { - return renderToken(r, main_token, space); // continue - } - }, - .@"return" => { if (datas[node].lhs != 0) { try renderToken(r, main_tokens[node], .space); @@ -844,28 +832,7 @@ fn renderExpression(r: *Render, node: Ast.Node.Index, space: Space) Error!void { .@"switch", .switch_comma, - => { - const switch_token = main_tokens[node]; - const condition = datas[node].lhs; - const extra = tree.extraData(datas[node].rhs, Ast.Node.SubRange); - const cases = tree.extra_data[extra.start..extra.end]; - const rparen = tree.lastToken(condition) + 1; - - try renderToken(r, switch_token, .space); // switch keyword - try renderToken(r, switch_token + 1, .none); // lparen - try renderExpression(r, condition, .none); // condition expression - try renderToken(r, rparen, .space); // rparen - - ais.pushIndentNextLine(); - if (cases.len == 0) { - try renderToken(r, rparen + 1, .none); // lbrace - } else { - try renderToken(r, rparen + 1, .newline); // lbrace - try renderExpressions(r, cases, .comma); - } - ais.popIndent(); - return renderToken(r, tree.lastToken(node), space); // rbrace - }, + => return renderSwitch(r, node, tree.fullSwitch(node).?, space), .switch_case_one, .switch_case_inline_one, @@ -1866,6 +1833,39 @@ fn renderFnProto(r: *Render, fn_proto: Ast.full.FnProto, space: Space) Error!voi return renderExpression(r, fn_proto.ast.return_type, space); } +fn renderSwitch( + r: *Render, + switch_node: Ast.Node.Index, + switch_full: Ast.full.Switch, + space: Space, +) Error!void { + const switch_token = switch_full.ast.switch_token; + const condition = switch_full.ast.condition; + const extra = r.tree.extraData(switch_full.ast.sub_range, Ast.Node.SubRange); + const cases = r.tree.extra_data[extra.start..extra.end]; + const rparen = r.tree.lastToken(condition) + 1; + + if (switch_full.label_token) |label| { + try renderIdentifier(r, label, .none, .eagerly_unquote); // label + try renderToken(r, label + 1, .space); // : + } + + try renderToken(r, switch_token, .space); // switch keyword + try renderToken(r, switch_token + 1, .none); // lparen + try renderExpression(r, condition, .none); // condition expression + try renderToken(r, rparen, .space); // rparen + + r.ais.pushIndentNextLine(); + if (cases.len == 0) { + try renderToken(r, rparen + 1, .none); // lbrace + } else { + try renderToken(r, rparen + 1, .newline); // lbrace + try renderExpressions(r, cases, .comma); + } + r.ais.popIndent(); + return renderToken(r, r.tree.lastToken(switch_node), space); // rbrace +} + fn renderSwitchCase( r: *Render, switch_case: Ast.full.SwitchCase, diff --git a/src/Air.zig b/src/Air.zig index 9554c55561a5..b382e6d4ad12 100644 --- a/src/Air.zig +++ b/src/Air.zig @@ -283,6 +283,10 @@ pub const Inst = struct { /// Result type is always noreturn; no instructions in a block follow this one. /// Uses the `br` field. br, + /// Switch dispatch. + /// TODO: description + /// Uses the `br` field. + switch_dispatch, /// Lowers to a trap/jam instruction causing program abortion. /// This may lower to an instruction known to be invalid. /// Sometimes, for the lack of a better instruction, `trap` and `breakpoint` may compile down to the same code. @@ -1134,9 +1138,29 @@ pub const CondBr = struct { /// Trailing: /// * 0. `Case` for each `cases_len` /// * 1. the else body, according to `else_body_len`. +/// * 2. dispatch information. pub const SwitchBr = struct { cases_len: u32, else_body_len: u32, + flags: Flags, + block_inst: u32, + + const Flags = packed struct(u32) { + /// The number of values the `switch`'s operand type can represent and, consequently, the size of the dispatch table. + dispatch_table_len: u30, + dispatch_mode: DispatchMode, + }; + + const DispatchMode = packed struct(u2) { + /// Set if any dispatch instruction has a comptime known operand. + direct: bool, + /// Set if any dispatch instruction has a runtime known operand. + indirect: bool, + + pub fn hasAnyDispatch(self: DispatchMode) bool { + return self.direct | self.indirect; + } + }; /// Trailing: /// * item: Inst.Ref // for each `items_len`. @@ -1441,6 +1465,7 @@ pub fn typeOfIndex(air: *const Air, inst: Air.Inst.Index, ip: *const InternPool) .br, .cond_br, .switch_br, + .switch_dispatch, .ret, .ret_safe, .ret_load, @@ -1603,6 +1628,7 @@ pub fn mustLower(air: Air, inst: Air.Inst.Index, ip: *const InternPool) bool { .call_never_inline, .cond_br, .switch_br, + .switch_dispatch, .@"try", .try_ptr, .dbg_stmt, diff --git a/src/Liveness.zig b/src/Liveness.zig index 4ca28758e222..03e71f104229 100644 --- a/src/Liveness.zig +++ b/src/Liveness.zig @@ -594,11 +594,12 @@ pub fn categorizeOperand( return .write; }, - .br => { + .br, .switch_dispatch => { const br = air_datas[@intFromEnum(inst)].br; if (br.operand == operand_ref) return matchOperandSmallIndex(l, operand, 0, .noret); return .noret; }, + .assembly => { return .complex; }, @@ -1198,7 +1199,7 @@ fn analyzeInst( return analyzeOperands(a, pass, data, inst, .{ pl_op.operand, extra.operand, .none }); }, - .br => return analyzeInstBr(a, pass, data, inst), + .br, .switch_dispatch => return analyzeInstBr(a, pass, data, inst), .assembly => { const extra = a.air.extraData(Air.Asm, inst_datas[@intFromEnum(inst)].ty_pl.payload); diff --git a/src/Liveness/Verify.zig b/src/Liveness/Verify.zig index 4392f25e101d..b64c6ab48b59 100644 --- a/src/Liveness/Verify.zig +++ b/src/Liveness/Verify.zig @@ -417,7 +417,7 @@ fn verifyBody(self: *Verify, body: []const Air.Inst.Index) Error!void { try self.verifyInst(inst); }, - .br => { + .br, .switch_dispatch => { const br = data[@intFromEnum(inst)].br; const gop = try self.blocks.getOrPut(self.gpa, br.block_inst); diff --git a/src/Sema.zig b/src/Sema.zig index d21fed6910ed..4134b3095efa 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -377,6 +377,7 @@ pub const Block = struct { // TODO is_comptime and comptime_reason should probably be merged together. is_comptime: bool, is_typeof: bool = false, + dispatch_mode: Air.SwitchBr.DispatchMode = .{ .direct = false, .indirect = false }, /// Keep track of the active error return trace index around blocks so that we can correctly /// pop the error trace upon block exit. @@ -598,6 +599,20 @@ pub const Block = struct { }); } + fn addSwitchDispatch( + block: *Block, + target_block: Air.Inst.Index, + operand: Air.Inst.Ref, + ) error{OutOfMemory}!Air.Inst.Ref { + return block.addInst(.{ + .tag = .switch_dispatch, + .data = .{ .br = .{ + .block_inst = target_block, + .operand = operand, + } }, + }); + } + fn addBinOp( block: *Block, tag: Air.Inst.Tag, @@ -1510,6 +1525,15 @@ fn analyzeBodyInner( sema.comptime_break_inst = inst; return error.ComptimeBreak; }, + .switch_continue => { + if (block.is_comptime) { + sema.comptime_break_inst = inst; + return error.ComptimeBreak; + } else { + try sema.zirSwitchContinue(block, inst); + break; + } + }, .repeat => { if (block.is_comptime) { // Send comptime control flow back to the beginning of this block. @@ -6622,6 +6646,36 @@ fn zirBreak(sema: *Sema, start_block: *Block, inst: Zir.Inst.Index) CompileError } } +fn zirSwitchContinue(sema: *Sema, start_block: *Block, inst: Zir.Inst.Index) CompileError!void { + const tracy = trace(@src()); + defer tracy.end(); + + const inst_data = sema.code.instructions.items(.data)[@intFromEnum(inst)].@"break"; + const extra = sema.code.extraData(Zir.Inst.Break, inst_data.payload_index).data; + const operand = try sema.resolveInst(inst_data.operand); + const is_comptime_known = try sema.isComptimeKnown(operand); + const zir_block = extra.block_inst; + + var block = start_block; + while (true) { + if (block.label) |label| { + if (label.zir_block == zir_block) { + if (is_comptime_known) + block.dispatch_mode.direct = true + else + block.dispatch_mode.indirect = true; + const br_ref = try start_block.addSwitchDispatch(label.merges.block_inst, operand); + const src_loc = LazySrcLoc.nodeOffset(extra.operand_src_node); + _ = br_ref; + _ = src_loc; + // TODO: check continue operand type with the switch operand one + return; + } + } + block = block.parent.?; + } +} + fn zirDbgStmt(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!void { if (block.is_comptime or block.ownerModule().strip) return; @@ -10864,6 +10918,7 @@ const SwitchProngAnalysis = struct { operand: Air.Inst.Ref, /// May be `undefined` if no prong has a by-ref capture. operand_ptr: Air.Inst.Ref, + operand_is_ref: bool, /// The switch condition value. For unions, `operand` is the union and `cond` is its tag. cond: Air.Inst.Ref, /// If this switch is on an error set, this is the type to assign to the @@ -10875,11 +10930,17 @@ const SwitchProngAnalysis = struct { /// undefined if no prong has a tag capture. tag_capture_inst: Zir.Inst.Index, + const ResolvedProng = union(enum) { + is_comptime: Air.Inst.Ref, + is_runtime: Air.Inst.Ref, + }; /// Resolve a switch prong which is determined at comptime to have no peers. /// Uses `resolveBlockBody`. Sets up captures as needed. + /// TODO: replace `resolveBlockBody` with `analyzeBodyInner` and handle `error.ComptimeBreak` fn resolveProngComptime( spa: SwitchProngAnalysis, child_block: *Block, + operand_ty: Type, prong_type: enum { normal, special }, prong_body: []const Zir.Inst.Index, capture: Zir.Inst.SwitchBlock.ProngInfo.Capture, @@ -10895,7 +10956,7 @@ const SwitchProngAnalysis = struct { /// `inline_case_capture` cannot be `.none`. has_tag_capture: bool, merges: *Block.Merges, - ) CompileError!Air.Inst.Ref { + ) CompileError!ResolvedProng { const sema = spa.sema; const src = sema.code.instructions.items(.data)[@intFromEnum(spa.switch_block_inst)].pl_node.src(); @@ -11326,6 +11387,18 @@ const SwitchProngAnalysis = struct { @typeInfo(Air.Block).Struct.fields.len + 1); + const dispatch_table_len: u30 = + if (block.dispatch_mode.hasAnyDispatch()) + switch (operand_ty.zigTypeTag(sema.mod)) { + .Bool => 2, + .Int => if (std.math.cast(u30, operand_ty.bitSize(sema.mod))) |b| std.math.powi(u30, 2, b) catch null else null, + .Enum => std.math.cast(u30, operand_ty.enumFieldCount(sema.mod)), + .Union => std.math.cast(u30, operand_ty.unionTagType(sema.mod).?.enumFieldCount(sema.mod)), + else => null, + } orelse return sema.fail(block, operand_src, "switch dispatch on type '{}'", .{operand_ty.fmt(mod)}) + else + 0; + const switch_br_inst: u32 = @intCast(sema.air_instructions.len); try sema.air_instructions.append(sema.gpa, .{ .tag = .switch_br, @@ -11334,6 +11407,11 @@ const SwitchProngAnalysis = struct { .payload = sema.addExtraAssumeCapacity(Air.SwitchBr{ .cases_len = @intCast(prong_count), .else_body_len = @intCast(else_body_len), + .flags = .{ + .dispatch_mode = block.dispatch_mode, + .dispatch_table_len = dispatch_table_len, + }, + .block_inst = if (block.label) |label| @intFromEnum(label.merges.block_inst) else 0, }), } }, }); @@ -11611,6 +11689,7 @@ fn zirSwitchBlockErrUnion(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Comp .parent_block = block, .operand = undefined, // must be set to the unwrapped error code before use .operand_ptr = .none, + .operand_is_ref = false, .cond = raw_operand_val, .else_error_ty = else_error_ty, .switch_block_inst = inst, @@ -11644,7 +11723,7 @@ fn zirSwitchBlockErrUnion(sema: *Sema, block: *Block, inst: Zir.Inst.Index) Comp return resolveSwitchComptime( sema, - spa, + &spa, &child_block, try sema.switchCond(block, switch_operand_src, spa.operand), err_val, @@ -11770,6 +11849,7 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r const operand_src: LazySrcLoc = .{ .node_offset_switch_operand = src_node_offset }; const special_prong_src: LazySrcLoc = .{ .node_offset_switch_special_prong = src_node_offset }; const extra = sema.code.extraData(Zir.Inst.SwitchBlock, inst_data.payload_index); + const any_dispatch = extra.data.bits.any_dispatch; const raw_operand_val: Air.Inst.Ref, const raw_operand_ptr: Air.Inst.Ref = blk: { const maybe_ptr = try sema.resolveInst(extra.data.operand); @@ -12247,11 +12327,12 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r }), } - const spa: SwitchProngAnalysis = .{ + var spa: SwitchProngAnalysis = .{ .sema = sema, .parent_block = block, .operand = raw_operand_val, .operand_ptr = raw_operand_ptr, + .operand_is_ref = operand_is_ref, .cond = operand, .else_error_ty = else_error_ty, .switch_block_inst = inst, @@ -12295,52 +12376,64 @@ fn zirSwitchBlock(sema: *Sema, block: *Block, inst: Zir.Inst.Index, operand_is_r defer child_block.instructions.deinit(gpa); defer merges.deinit(gpa); - if (try sema.resolveDefinedValue(&child_block, src, operand)) |operand_val| { - return resolveSwitchComptime( - sema, - spa, - &child_block, - operand, - operand_val, - operand_ty, - special, - case_vals, - scalar_cases_len, - multi_cases_len, - err_set, - empty_enum, - ); - } - - if (scalar_cases_len + multi_cases_len == 0 and !special.is_inline) { - if (empty_enum) { - return .void_value; - } - if (special_prong == .none) { - return sema.fail(block, src, "switch must handle all possibilities", .{}); - } - if (err_set and try sema.maybeErrorUnwrap(block, special.body, operand, operand_src, false)) { - return .unreachable_value; - } - if (mod.backendSupportsFeature(.is_named_enum_value) and block.wantSafety() and operand_ty.zigTypeTag(mod) == .Enum and - (!operand_ty.isNonexhaustiveEnum(mod) or union_originally)) - { - try sema.zirDbgStmt(block, cond_dbg_node_index); - const ok = try block.addUnOp(.is_named_enum_value, operand); - try sema.addSafetyCheck(block, src, ok, .corrupt_switch); + if (!any_dispatch or child_block.is_comptime) { + if (try sema.resolveDefinedValue(&child_block, src, operand)) |operand_val| { + return resolveSwitchComptime( + sema, + &spa, + &child_block, + operand, + operand_val, + operand_ty, + special, + case_vals, + scalar_cases_len, + multi_cases_len, + err_set, + empty_enum, + ); } - return spa.resolveProngComptime( - &child_block, - .special, - special.body, - special.capture, - .special_capture, - undefined, // case_vals may be undefined for special prongs - .none, - false, - merges, - ); + if (scalar_cases_len + multi_cases_len == 0 and !special.is_inline) { + if (empty_enum) { + return .void_value; + } + if (special_prong == .none) { + return sema.fail(block, src, "switch must handle all possibilities", .{}); + } + if (err_set and try sema.maybeErrorUnwrap(block, special.body, operand, operand_src, false)) { + return .unreachable_value; + } + if (mod.backendSupportsFeature(.is_named_enum_value) and block.wantSafety() and operand_ty.zigTypeTag(mod) == .Enum and + (!operand_ty.isNonexhaustiveEnum(mod) or union_originally)) + { + try sema.zirDbgStmt(block, cond_dbg_node_index); + const ok = try block.addUnOp(.is_named_enum_value, operand); + try sema.addSafetyCheck(block, src, ok, .corrupt_switch); + } + + dispatch: while (true) { + const resolved = try spa.resolveProngComptime( + &child_block, + .special, + special.body, + special.capture, + .special_capture, + undefined, // case_vals may be undefined for special prongs + .none, + false, + merges, + ); + if (child_block.is_comptime) { + const break_inst = sema.comptime_break_inst; + const break_tag = sema.code.instructions.items(.tag)[@intFromEnum(break_inst)]; + if (break_tag == .switch_continue) { + try sema.emitBackwardBranch(&child_block, src); + continue :dispatch; + } else return resolved; + } else return resolved; + } + } } if (child_block.is_comptime) { @@ -13010,9 +13103,26 @@ fn analyzeSwitchRuntimeBlock( try sema.air_extra.ensureUnusedCapacity(gpa, @typeInfo(Air.SwitchBr).Struct.fields.len + cases_extra.items.len + final_else_body.len); + const dispatch_table_len: u30 = + if (child_block.dispatch_mode.hasAnyDispatch()) + switch (operand_ty.zigTypeTag(sema.mod)) { + .Bool => 2, + .Int => if (std.math.cast(u30, operand_ty.bitSize(sema.mod))) |b| std.math.powi(u30, 2, b) catch null else null, + .Enum => std.math.cast(u30, operand_ty.enumFieldCount(sema.mod)), + .Union => std.math.cast(u30, operand_ty.unionTagType(sema.mod).?.enumFieldCount(sema.mod)), + else => null, + } orelse return sema.fail(child_block, operand_src, "switch dispatch on type '{}'", .{operand_ty.fmt(mod)}) + else + 0; + const payload_index = sema.addExtraAssumeCapacity(Air.SwitchBr{ .cases_len = @intCast(cases_len), .else_body_len = @intCast(final_else_body.len), + .flags = .{ + .dispatch_mode = child_block.dispatch_mode, + .dispatch_table_len = dispatch_table_len, + }, + .block_inst = if (child_block.label) |label| @intFromEnum(label.merges.block_inst) else 0, }); sema.air_extra.appendSliceAssumeCapacity(@ptrCast(cases_extra.items)); @@ -13029,7 +13139,7 @@ fn analyzeSwitchRuntimeBlock( fn resolveSwitchComptime( sema: *Sema, - spa: SwitchProngAnalysis, + spa: *SwitchProngAnalysis, child_block: *Block, cond_operand: Air.Inst.Ref, operand_val: Value, @@ -13042,115 +13152,181 @@ fn resolveSwitchComptime( empty_enum: bool, ) CompileError!Air.Inst.Ref { const merges = &child_block.label.?.merges; - const resolved_operand_val = try sema.resolveLazyValue(operand_val); - var extra_index: usize = special.end; - { - var scalar_i: usize = 0; - while (scalar_i < scalar_cases_len) : (scalar_i += 1) { - extra_index += 1; - const info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(sema.code.extra[extra_index]); - extra_index += 1; - const body = sema.code.bodySlice(extra_index, info.body_len); - extra_index += info.body_len; - - const item = case_vals.items[scalar_i]; - const item_val = sema.resolveConstDefinedValue(child_block, .unneeded, item, undefined) catch unreachable; - if (operand_val.eql(item_val, operand_ty, sema.mod)) { - if (err_set) try sema.maybeErrorUnwrapComptime(child_block, body, cond_operand); - return spa.resolveProngComptime( - child_block, - .normal, - body, - info.capture, - .{ .scalar_capture = @intCast(scalar_i) }, - &.{item}, - if (info.is_inline) cond_operand else .none, - info.has_tag_capture, - merges, - ); - } - } - } - { - var multi_i: usize = 0; - var case_val_idx: usize = scalar_cases_len; - while (multi_i < multi_cases_len) : (multi_i += 1) { - const items_len = sema.code.extra[extra_index]; - extra_index += 1; - const ranges_len = sema.code.extra[extra_index]; - extra_index += 1; - const info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(sema.code.extra[extra_index]); - extra_index += 1 + items_len; - const body = sema.code.bodySlice(extra_index + 2 * ranges_len, info.body_len); - const items = case_vals.items[case_val_idx..][0..items_len]; - case_val_idx += items_len; + var next_dispatch_val = operand_val; + dispatch: while (true) { + var extra_index: usize = special.end; + { + var scalar_i: usize = 0; + while (scalar_i < scalar_cases_len) : (scalar_i += 1) { + extra_index += 1; + const info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(sema.code.extra[extra_index]); + extra_index += 1; + const body = sema.code.bodySlice(extra_index, info.body_len); + extra_index += info.body_len; - for (items) |item| { - // Validation above ensured these will succeed. + const item = case_vals.items[scalar_i]; const item_val = sema.resolveConstDefinedValue(child_block, .unneeded, item, undefined) catch unreachable; - if (operand_val.eql(item_val, operand_ty, sema.mod)) { + if (next_dispatch_val.eql(item_val, operand_ty, sema.mod)) { if (err_set) try sema.maybeErrorUnwrapComptime(child_block, body, cond_operand); - return spa.resolveProngComptime( + switch (try spa.resolveProngComptime( child_block, + operand_ty, .normal, body, info.capture, - .{ .multi_capture = @intCast(multi_i) }, - items, + .{ .scalar_capture = @intCast(scalar_i) }, + &.{item}, if (info.is_inline) cond_operand else .none, info.has_tag_capture, merges, - ); + )) { + .is_comptime => |val| { + next_dispatch_val = val; + continue :dispatch; + }, + else => |val| return val, + } } } + } + { + var multi_i: usize = 0; + var case_val_idx: usize = scalar_cases_len; + while (multi_i < multi_cases_len) : (multi_i += 1) { + const items_len = sema.code.extra[extra_index]; + extra_index += 1; + const ranges_len = sema.code.extra[extra_index]; + extra_index += 1; + const info: Zir.Inst.SwitchBlock.ProngInfo = @bitCast(sema.code.extra[extra_index]); + extra_index += 1 + items_len; + const body = sema.code.bodySlice(extra_index + 2 * ranges_len, info.body_len); - var range_i: usize = 0; - while (range_i < ranges_len) : (range_i += 1) { - const range_items = case_vals.items[case_val_idx..][0..2]; - extra_index += 2; - case_val_idx += 2; + const items = case_vals.items[case_val_idx..][0..items_len]; + case_val_idx += items_len; - // Validation above ensured these will succeed. - const first_val = sema.resolveConstDefinedValue(child_block, .unneeded, range_items[0], undefined) catch unreachable; - const last_val = sema.resolveConstDefinedValue(child_block, .unneeded, range_items[1], undefined) catch unreachable; - if ((try sema.compareAll(resolved_operand_val, .gte, first_val, operand_ty)) and - (try sema.compareAll(resolved_operand_val, .lte, last_val, operand_ty))) - { - if (err_set) try sema.maybeErrorUnwrapComptime(child_block, body, cond_operand); - return spa.resolveProngComptime( - child_block, - .normal, - body, - info.capture, - .{ .multi_capture = @intCast(multi_i) }, - undefined, // case_vals may be undefined for ranges - if (info.is_inline) cond_operand else .none, - info.has_tag_capture, - merges, - ); + for (items) |item| { + // Validation above ensured these will succeed. + const item_val = sema.resolveConstDefinedValue(child_block, .unneeded, item, undefined) catch unreachable; + if (next_dispatch_val.eql(item_val, operand_ty, sema.mod)) { + if (err_set) try sema.maybeErrorUnwrapComptime(child_block, body, cond_operand); + switch (try spa.resolveProngComptime( + child_block, + operand_ty, + .normal, + body, + info.capture, + .{ .multi_capture = @intCast(multi_i) }, + items, + if (info.is_inline) cond_operand else .none, + info.has_tag_capture, + merges, + )) { + .is_comptime => |val| { + next_dispatch_val = val; + continue :dispatch; + }, + else => |val| return val, + } + } + } + + var range_i: usize = 0; + while (range_i < ranges_len) : (range_i += 1) { + const range_items = case_vals.items[case_val_idx..][0..2]; + extra_index += 2; + case_val_idx += 2; + + // Validation above ensured these will succeed. + const first_val = sema.resolveConstDefinedValue(child_block, .unneeded, range_items[0], undefined) catch unreachable; + const last_val = sema.resolveConstDefinedValue(child_block, .unneeded, range_items[1], undefined) catch unreachable; + + const resolved_operand_val = try sema.resolveLazyValue(next_dispatch_val); + if ((try sema.compareAll(resolved_operand_val, .gte, first_val, operand_ty)) and + (try sema.compareAll(resolved_operand_val, .lte, last_val, operand_ty))) + { + if (err_set) try sema.maybeErrorUnwrapComptime(child_block, body, cond_operand); + switch (try spa.resolveProngComptime( + child_block, + operand_ty, + .normal, + body, + info.capture, + .{ .multi_capture = @intCast(multi_i) }, + undefined, // case_vals may be undefined for ranges + if (info.is_inline) cond_operand else .none, + info.has_tag_capture, + merges, + )) { + .is_comptime => |val| { + next_dispatch_val = val; + continue :dispatch; + }, + else => |val| return val, + } + } } + + extra_index += info.body_len; } + } + if (err_set) try sema.maybeErrorUnwrapComptime(child_block, special.body, cond_operand); + if (empty_enum) { + return .void_value; + } - extra_index += info.body_len; + switch (try spa.resolveProngComptime( + child_block, + operand_ty, + .special, + special.body, + special.capture, + .special_capture, + undefined, // case_vals may be undefined for special prongs + if (special.is_inline) cond_operand else .none, + special.has_tag_capture, + merges, + )) { + .is_comptime => |val| { + next_dispatch_val = val; + continue :dispatch; + }, + else => |val| return val, } } - if (err_set) try sema.maybeErrorUnwrapComptime(child_block, special.body, cond_operand); - if (empty_enum) { - return .void_value; - } +} - return spa.resolveProngComptime( - child_block, - .special, - special.body, - special.capture, - .special_capture, - undefined, // case_vals may be undefined for special prongs - if (special.is_inline) cond_operand else .none, - special.has_tag_capture, - merges, - ); +fn resolveSwitchDispatchComptime( + sema: *Sema, + spa: *SwitchProngAnalysis, + child_block: *Block, + operand_ty: Type, + resolved_prong: Air.Inst.Ref, +) CompileError!?Value { + if (child_block.is_comptime) { + const break_inst = sema.comptime_break_inst; + const break_tag = sema.code.instructions.items(.tag)[@intFromEnum(break_inst)]; + const break_data = sema.code.instructions.items(.data)[@intFromEnum(break_inst)]; + const break_extra = sema.code.extraData(Zir.Inst.Break, break_data.@"break".payload_index); + const operand_src = LazySrcLoc.nodeOffset(break_extra.data.operand_src_node); + if (break_tag == .switch_continue) { + try sema.emitBackwardBranch(child_block, operand_src); + const coerced_resolved = try sema.coerce(child_block, operand_ty, resolved_prong, operand_src); + const operand_val: Air.Inst.Ref, const operand_ptr: Air.Inst.Ref = blk: { + const maybe_ptr = coerced_resolved; + if (spa.operand_is_ref) { + const src = sema.code.instructions.items(.data)[@intFromEnum(spa.switch_block_inst)].pl_node.src(); + const val = try sema.analyzeLoad(child_block, src, maybe_ptr, operand_src); + break :blk .{ val, maybe_ptr }; + } else { + break :blk .{ maybe_ptr, undefined }; + } + }; + spa.operand = operand_val; + spa.operand_ptr = operand_ptr; + return try sema.resolveConstDefinedValue(child_block, operand_src, coerced_resolved, undefined); + } else return null; + } else return null; } const RangeSetUnhandledIterator = struct { diff --git a/src/arch/aarch64/CodeGen.zig b/src/arch/aarch64/CodeGen.zig index b9f8259c05dc..7446315e2b93 100644 --- a/src/arch/aarch64/CodeGen.zig +++ b/src/arch/aarch64/CodeGen.zig @@ -827,6 +827,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .field_parent_ptr => try self.airFieldParentPtr(inst), .switch_br => try self.airSwitch(inst), + .switch_dispatch => return self.fail("TODO: implement switch_dispatch", .{}), .slice_ptr => try self.airSlicePtr(inst), .slice_len => try self.airSliceLen(inst), diff --git a/src/arch/arm/CodeGen.zig b/src/arch/arm/CodeGen.zig index 86d4e8f7fdd6..923f3dd146b1 100644 --- a/src/arch/arm/CodeGen.zig +++ b/src/arch/arm/CodeGen.zig @@ -813,6 +813,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .field_parent_ptr => try self.airFieldParentPtr(inst), .switch_br => try self.airSwitch(inst), + .switch_dispatch => return self.fail("TODO: implement switch_dispatch", .{}), .slice_ptr => try self.airSlicePtr(inst), .slice_len => try self.airSliceLen(inst), diff --git a/src/arch/riscv64/CodeGen.zig b/src/arch/riscv64/CodeGen.zig index 5abe3afcfd2a..9190da7019e8 100644 --- a/src/arch/riscv64/CodeGen.zig +++ b/src/arch/riscv64/CodeGen.zig @@ -646,6 +646,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .field_parent_ptr => try self.airFieldParentPtr(inst), .switch_br => try self.airSwitch(inst), + .switch_dispatch => return self.fail("TODO: implement switch_dispatch", .{}), .slice_ptr => try self.airSlicePtr(inst), .slice_len => try self.airSliceLen(inst), diff --git a/src/arch/sparc64/CodeGen.zig b/src/arch/sparc64/CodeGen.zig index 19c18ec4a6b0..61f1606efda3 100644 --- a/src/arch/sparc64/CodeGen.zig +++ b/src/arch/sparc64/CodeGen.zig @@ -660,6 +660,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .field_parent_ptr => @panic("TODO try self.airFieldParentPtr(inst)"), .switch_br => try self.airSwitch(inst), + .switch_dispatch => return self.fail("TODO: implement switch_dispatch", .{}), .slice_ptr => try self.airSlicePtr(inst), .slice_len => try self.airSliceLen(inst), diff --git a/src/arch/wasm/CodeGen.zig b/src/arch/wasm/CodeGen.zig index 6dc231672406..f73270512c59 100644 --- a/src/arch/wasm/CodeGen.zig +++ b/src/arch/wasm/CodeGen.zig @@ -1974,6 +1974,7 @@ fn genInst(func: *CodeGen, inst: Air.Inst.Index) InnerError!void { .field_parent_ptr => func.airFieldParentPtr(inst), .switch_br => func.airSwitchBr(inst), + .switch_dispatch => return func.fail("TODO: implement switch_dispatch", .{}), .trunc => func.airTrunc(inst), .unreach => func.airUnreachable(inst), diff --git a/src/arch/x86_64/CodeGen.zig b/src/arch/x86_64/CodeGen.zig index 1b584bfe53c1..0b3d4f7af6b9 100644 --- a/src/arch/x86_64/CodeGen.zig +++ b/src/arch/x86_64/CodeGen.zig @@ -2123,6 +2123,7 @@ fn genBody(self: *Self, body: []const Air.Inst.Index) InnerError!void { .field_parent_ptr => try self.airFieldParentPtr(inst), .switch_br => try self.airSwitchBr(inst), + .switch_dispatch => return self.fail("TODO: implement switch_dispatch", .{}), .slice_ptr => try self.airSlicePtr(inst), .slice_len => try self.airSliceLen(inst), diff --git a/src/codegen/c.zig b/src/codegen/c.zig index 7ae5c87ee540..85151c0501b2 100644 --- a/src/codegen/c.zig +++ b/src/codegen/c.zig @@ -3200,6 +3200,7 @@ fn genBodyInner(f: *Function, body: []const Air.Inst.Index) error{ AnalysisFail, .cond_br => try airCondBr(f, inst), .br => try airBr(f, inst), .switch_br => try airSwitchBr(f, inst), + .switch_dispatch => return f.fail("TODO: C backend: implement switch_dispatch", .{}), .struct_field_ptr => try airStructFieldPtr(f, inst), .array_to_slice => try airArrayToSlice(f, inst), .cmpxchg_weak => try airCmpxchg(f, inst, "weak"), diff --git a/src/codegen/llvm.zig b/src/codegen/llvm.zig index 8ddacbe11ca1..0760b2208371 100644 --- a/src/codegen/llvm.zig +++ b/src/codegen/llvm.zig @@ -1702,6 +1702,7 @@ pub const Object = struct { .prev_dbg_line = 0, .prev_dbg_column = 0, .err_ret_trace = err_ret_trace, + .switch_dispatch_tables = .{}, }; defer fg.deinit(); deinit_wip = false; @@ -4796,6 +4797,19 @@ pub const FuncGen = struct { sync_scope: Builder.SyncScope, + switch_dispatch_tables: std.AutoHashMapUnmanaged( + Air.Inst.Index, + DispatchTable, + ), + + const DispatchTable = struct { + // Written if there is any indirectbr. Each position corresponds to some case. There can be duplicates. + aggrt: ?Builder.Constant = null, + addrs: []const Builder.WipFunction.Block.Index, + // Used as argument in indirectbr. All values are unique. + dests: []const Builder.WipFunction.Block.Index, + }; + const BreakList = union { list: std.MultiArrayList(struct { bb: Builder.Function.Block.Index, @@ -4808,6 +4822,7 @@ pub const FuncGen = struct { self.wip.deinit(); self.func_inst_table.deinit(self.gpa); self.blocks.deinit(self.gpa); + self.switch_dispatch_tables.deinit(self.gpa); } fn todo(self: *FuncGen, comptime format: []const u8, args: anytype) Error { @@ -4983,6 +4998,7 @@ pub const FuncGen = struct { .block => try self.airBlock(inst), .br => try self.airBr(inst), .switch_br => try self.airSwitchBr(inst), + .switch_dispatch => try self.airSwitchDispatch(inst), .trap => try self.airTrap(inst), .breakpoint => try self.airBreakpoint(inst), .ret_addr => try self.airRetAddr(inst), @@ -5984,6 +6000,54 @@ pub const FuncGen = struct { return .none; } + fn airSwitchDispatch(self: *FuncGen, inst: Air.Inst.Index) !Builder.Value { + const o = self.dg.object; + var b = o.builder; + const data = self.air.instructions.items(.data)[@intFromEnum(inst)].br; + const switch_inst = data.block_inst; + const switch_table = self.switch_dispatch_tables.get(switch_inst).?; + if (try self.air.value(data.operand, o.module)) |value| { + const operand = value.toUnsignedInt(o.module); + const case_block = switch_table.addrs[operand]; + _ = try self.wip.br(case_block); + self.wip.blocks.items[@intFromEnum(case_block)].incoming += 1; + } else { + // TODO: generate IR like https://godbolt.org/z/zqx6sY5fh with these steps: + // - load the operand and cast it as an index + // - get the blockaddress of the case block by the index + // - jump using the indirectbr instruction and the blockaddress + // NOTE: that link shows a generated IR block called %indirectgoto but it is not + // necessary to create it. + const llvm_item = try self.resolveInst(data.operand); + const aggrt = switch_table.aggrt.?; + const base = aggrt.getBase(&b).toConst().toValue(); + + const aggrt_ty = aggrt.typeOf(&b); + const item_ty = llvm_item.typeOf(self.wip.function, &b); + const loaded_item = try self.wip.load( + .normal, + item_ty, + llvm_item, + Builder.Alignment.fromByteUnits(1), + "", + ); + + const sext_item = try self.wip.cast(.sext, loaded_item, .i64, ""); + const addr = try self.wip.gep(.inbounds, aggrt_ty, base, &.{sext_item}, ""); + const loaded_addr = try self.wip.load( + .normal, + try b.intType(std.math.log2_int(usize, switch_table.addrs.len)), + addr, + Builder.Alignment.fromByteUnits(8), + "", + ); + + _ = try self.wip.indirectBr(loaded_addr, switch_table.dests); + } + + return .none; + } + fn airTry(self: *FuncGen, body_tail: []const Air.Inst.Index) !Builder.Value { const o = self.dg.object; const mod = o.module; @@ -6095,6 +6159,14 @@ pub const FuncGen = struct { else cond; + var wip_addrs = std.ArrayList(Builder.WipFunction.Block.Index).init(self.gpa); + var wip_dests = std.ArrayList(Builder.WipFunction.Block.Index).init(self.gpa); + try wip_dests.ensureTotalCapacity(switch_br.data.cases_len + 1); + if (switch_br.data.flags.dispatch_mode.hasAnyDispatch()) { + try wip_addrs.ensureTotalCapacity(switch_br.data.flags.dispatch_table_len); + wip_addrs.appendNTimesAssumeCapacity(else_block, switch_br.data.flags.dispatch_table_len); + } + var extra_index: usize = switch_br.end; var case_i: u32 = 0; var llvm_cases_len: u32 = 0; @@ -6105,8 +6177,59 @@ pub const FuncGen = struct { const case_body = self.air.extra[case.end + items.len ..][0..case.data.body_len]; extra_index = case.end + case.data.items_len + case_body.len; + const case_block = try self.wip.block(@intCast(items.len), "Case"); + wip_dests.appendAssumeCapacity(case_block); + llvm_cases_len += @intCast(items.len); } + wip_dests.appendAssumeCapacity(else_block); + + if (switch_br.data.flags.dispatch_mode.hasAnyDispatch()) { + extra_index = switch_br.end; + case_i = 0; + while (case_i < switch_br.data.cases_len) : (case_i += 1) { + const case = self.air.extraData(Air.SwitchBr.Case, extra_index); + const items: []const Air.Inst.Ref = + @ptrCast(self.air.extra[case.end..][0..case.data.items_len]); + const case_body: []const Air.Inst.Index = @ptrCast(self.air.extra[case.end + items.len ..][0..case.data.body_len]); + extra_index = case.end + case.data.items_len + case_body.len; + + const case_block = wip_dests.items[case_i]; + + for (items) |item| { + const value = Value.fromInterned(item.toInterned().?).toUnsignedInt(o.module); + wip_addrs.items[value] = case_block; + } + } + } + + var aggregate: ?Builder.Constant = null; + if (switch_br.data.dispatchMode().hasIndirectDispatches()) { + const indices = wip_addrs.items; + var blockaddresses = std.ArrayList(Builder.Value).init(self.gpa); + defer blockaddresses.deinit(); + try blockaddresses.ensureTotalCapacity(indices.len); + for (indices) |i| { + const addr = try o.builder.blockAddrValue(self.wip.function, i); + blockaddresses.appendAssumeCapacity(addr); + } + aggregate = (try self.wip.buildAggregate( + try o.builder.arrayType(indices.len, .ptr), + blockaddresses.items, + "dispatch_table", + )).toConst().?; + const base = aggregate.?.getBase(&o.builder); + base.setUnnamedAddr(.unnamed_addr, &o.builder); + base.setLinkage(.internal, &o.builder); + } + const dispatch_table: DispatchTable = .{ + .aggrt = aggregate, + .addrs = try wip_addrs.toOwnedSlice(), + .dests = try wip_dests.toOwnedSlice(), + }; + if (switch_br.data.flags.dispatch_mode.hasAnyDispatch()) { + try self.switch_dispatch_tables.put(self.gpa, @enumFromInt(switch_br.data.block_inst), dispatch_table); + } var wip_switch = try self.wip.@"switch"(cond_int, else_block, llvm_cases_len); defer wip_switch.finish(&self.wip); @@ -6120,7 +6243,7 @@ pub const FuncGen = struct { const case_body: []const Air.Inst.Index = @ptrCast(self.air.extra[case.end + items.len ..][0..case.data.body_len]); extra_index = case.end + case.data.items_len + case_body.len; - const case_block = try self.wip.block(@intCast(items.len), "Case"); + const case_block = dispatch_table.dests[case_i]; for (items) |item| { const llvm_item = (try self.resolveInst(item)).toConst().?; diff --git a/src/codegen/llvm/Builder.zig b/src/codegen/llvm/Builder.zig index 12ecd2ec38e8..a920884f097d 100644 --- a/src/codegen/llvm/Builder.zig +++ b/src/codegen/llvm/Builder.zig @@ -4120,6 +4120,7 @@ pub const Function = struct { @"sub nuw", @"sub nuw nsw", @"switch", + indirectbr, @"tail call", @"tail call fast", trunc, @@ -4293,6 +4294,7 @@ pub const Function = struct { .ret, .@"ret void", .@"switch", + .indirectbr, .@"unreachable", => true, else => false, @@ -4309,6 +4311,7 @@ pub const Function = struct { .store, .@"store atomic", .@"switch", + .indirectbr, .@"unreachable", .block, => false, @@ -4399,6 +4402,7 @@ pub const Function = struct { .store, .@"store atomic", .@"switch", + .indirectbr, .@"unreachable", => .none, .call, @@ -4585,6 +4589,7 @@ pub const Function = struct { .store, .@"store atomic", .@"switch", + .indirectbr, .@"unreachable", => .none, .call, @@ -4750,6 +4755,11 @@ pub const Function = struct { //case_blocks: [cases_len]Block.Index, }; + pub const IndirectBr = struct { + ptr: Value, + blocks_len: u32, + }; + pub const Binary = struct { lhs: Value, rhs: Value, @@ -5023,7 +5033,7 @@ pub const WipFunction = struct { branches: u32 = 0, instructions: std.ArrayListUnmanaged(Instruction.Index), - const Index = enum(u32) { + pub const Index = enum(u32) { entry, _, @@ -5186,6 +5196,20 @@ pub const WipFunction = struct { return .{ .index = 0, .instruction = instruction }; } + pub fn indirectBr(self: *WipFunction, ptr: Builder.Value, blocks: []const WipFunction.Block.Index) Allocator.Error!Instruction.Index { + assert(blocks.len > 0); + try self.ensureUnusedExtraCapacity(1, Instruction.IndirectBr, blocks.len); + const instruction = try self.addInst(null, .{ + .tag = .indirectbr, + .data = self.addExtraAssumeCapacity(Instruction.IndirectBr{ + .ptr = ptr, + .blocks_len = @intCast(blocks.len), + }), + }); + self.extra.appendSliceAssumeCapacity(@ptrCast(blocks)); + return instruction; + } + pub fn @"unreachable"(self: *WipFunction) Allocator.Error!Instruction.Index { try self.ensureUnusedExtraCapacity(1, NoExtra, 0); const instruction = try self.addInst(null, .{ .tag = .@"unreachable", .data = undefined }); @@ -6449,6 +6473,13 @@ pub const WipFunction = struct { wip_extra.appendSlice(case_vals); wip_extra.appendSlice(case_blocks); }, + .indirectbr => { + const extra = self.extraDataTrail(Instruction.IndirectBr, instruction.data); + instruction.data = wip_extra.addExtra(Instruction.IndirectBr{ + .ptr = instructions.map(extra.data.ptr), + .blocks_len = extra.data.blocks_len, + }); + }, .va_arg => { const extra = self.extraData(Instruction.VaArg, instruction.data); instruction.data = wip_extra.addExtra(Instruction.VaArg{ @@ -10058,6 +10089,14 @@ pub fn printUnbuffered( ); try writer.writeAll(" ]"); }, + .indirectbr => |tag| { + var extra = function.extraDataTrail(Function.Instruction.IndirectBr, instruction.data); + const blocks = extra.trail.next(extra.data.blocks_len, Function.Block.Index, &function); + try writer.print(" {s} {%}, [ ", .{ @tagName(tag), extra.data.ptr.fmt(function_index, self) }); + for (blocks[0 .. blocks.len - 1]) |case_block| + try writer.print("{%}, ", .{case_block.toInst(&function).fmt(function_index, self)}); + try writer.print("{%} ]", .{blocks[blocks.len - 1].toInst(&function).fmt(function_index, self)}); + }, .va_arg => |tag| { const extra = function.extraData(Function.Instruction.VaArg, instruction.data); try writer.print(" %{} = {s} {%}, {%}", .{ @@ -15042,6 +15081,21 @@ pub fn toBitcode(self: *Builder, allocator: Allocator) bitcode_writer.Error![]co try function_block.writeUnabbrev(12, record.items); }, + .indirectbr => { + var extra = func.extraDataTrail(Function.Instruction.IndirectBr, datas[instr_index]); + + try record.ensureUnusedCapacity(self.gpa, 1 + extra.data.blocks_len); + + // Address + record.appendAssumeCapacity(adapter.getOffsetValueIndex(extra.data.ptr)); + + const blocks = extra.trail.next(extra.data.blocks_len, Function.Block.Index, &func); + for (blocks) |block| { + record.appendAssumeCapacity(@intFromEnum(block)); + } + + try function_block.writeUnabbrev(12, record.items); + }, .va_arg => { const extra = func.extraData(Function.Instruction.VaArg, datas[instr_index]); try function_block.writeAbbrev(FunctionBlock.VaArg{ diff --git a/src/codegen/llvm/ir.zig b/src/codegen/llvm/ir.zig index 8e3d20a63a92..768c15c7a77b 100644 --- a/src/codegen/llvm/ir.zig +++ b/src/codegen/llvm/ir.zig @@ -1124,6 +1124,7 @@ pub const FunctionBlock = struct { StoreAtomic, BrUnconditional, BrConditional, + IndirectBr, VaArg, AtomicRmw, CmpXchg, @@ -1522,6 +1523,14 @@ pub const FunctionBlock = struct { condition: u32, }; + pub const IndirectBr = struct { + pub const ops = [_]AbbrevOp{ + .{ .literal = 12 }, + }; + address: u32, + destinations: []const u32, + }; + pub const VaArg = struct { pub const ops = [_]AbbrevOp{ .{ .literal = 23 }, diff --git a/src/print_air.zig b/src/print_air.zig index 12e2825d4ef0..7c99de233a00 100644 --- a/src/print_air.zig +++ b/src/print_air.zig @@ -294,7 +294,7 @@ const Writer = struct { .aggregate_init => try w.writeAggregateInit(s, inst), .union_init => try w.writeUnionInit(s, inst), - .br => try w.writeBr(s, inst), + .br, .switch_dispatch => try w.writeBr(s, inst), .cond_br => try w.writeCondBr(s, inst), .@"try" => try w.writeTry(s, inst), .try_ptr => try w.writeTryPtr(s, inst), diff --git a/src/print_zir.zig b/src/print_zir.zig index e20eff63281e..32b2270fef26 100644 --- a/src/print_zir.zig +++ b/src/print_zir.zig @@ -303,6 +303,7 @@ const Writer = struct { .@"break", .break_inline, + .switch_continue, => try self.writeBreak(stream, inst), .slice_start => try self.writeSliceStart(stream, inst), diff --git a/test/behavior/switch.zig b/test/behavior/switch.zig index ae33e8e9abf6..0feaab22d844 100644 --- a/test/behavior/switch.zig +++ b/test/behavior/switch.zig @@ -958,3 +958,198 @@ test "block error return trace index is reset between prongs" { }; try result; } + +test "comptime direct switch dispatch with numbers" { + if (builtin.zig_backend != .stage2_llvm) return error.SkipZigTest; // TODO + + const val: u7 = 2; + const a = brk: switch (val) { + 0, 1 => continue :brk 3, + 2 => continue :brk 0, + 3 => 4, + else => 5, + }; + + comptime { + try expectEqual(4, a); + } +} + +test "runtime direct switch dispatch with numbers" { + if (builtin.zig_backend != .stage2_llvm) return error.SkipZigTest; // TODO + + var val: u7 = 2; + _ = &val; + const a = brk: switch (val) { + 0, 1 => continue :brk 3, + 2 => continue :brk 0, + 3 => 4, + else => 5, + }; + try expectEqual(4, a); +} + +test "direct switch dispatch with enum" { + if (builtin.zig_backend != .stage2_llvm) return error.SkipZigTest; // TODO + + const E = enum { a, b, c, d, e, f }; + const e = E.f; + + var val: u8 = 3; + val = brk: switch (e) { + .a, .b => continue :brk .d, + .c => break 2, + .d => continue :brk .c, + else => continue :brk .b, + }; + + try expectEqual(2, val); +} + +test "comptime direct switch dispatch with tagged union" { + if (builtin.zig_backend != .stage2_llvm) return error.SkipZigTest; // TODO + + const Val = union(enum) { + a: u8, + b: bool, + c: i32, + }; + + const val = Val{ .a = 3 }; + const b = brk: switch (val) { + .a => |v| { + const c: i32 = @intCast(v); + continue :brk Val{ .c = c }; + }, + .b => |v| break :brk !v, + .c => |v| continue :brk Val{ .b = v < 4 }, + }; + + comptime { + try expectEqual(false, b); + } +} + +test "runtime direct switch dispatch with tagged union" { + if (builtin.zig_backend != .stage2_llvm) return error.SkipZigTest; // TODO + + const Val = union(enum) { + a: u8, + b: bool, + c: i32, + }; + + var val = Val{ .a = 3 }; + _ = &val; + + const b = brk: switch (val) { + .a => |v| { + const c: i32 = @intCast(v); + continue :brk Val{ .c = c }; + }, + .b => |v| break :brk !v, + .c => |v| continue :brk Val{ .b = v < 4 }, + }; + + try expectEqual(false, b); +} + +test "runtime indirect switch dispatch with enum" { + if (builtin.zig_backend != .stage2_llvm) return error.SkipZigTest; // TODO + + const E = enum { a, b, c, d, e, f, g }; + + var val: u8 = 3; + var list = [_]E{ .b, .c, .f, .a, .e, .d, .b, .a, .b, .g }; + _ = &list; + + // 3 -> 4 -> 6 -> 2 -> 4 -> 2 -> 1 -> 2 -> 4 -> 5 + + var i: usize = 0; + brk: switch (list[i]) { + .a => { + val *= 2; + i += 1; + continue :brk list[i]; + }, + .b, .c => |op| { + val += if (op == .b) 1 else 2; + i += 1; + continue :brk list[i]; + }, + .d => { + val -= 1; + i += 1; + continue :brk list[i]; + }, + else => |op| { + if (op == .g) break :brk; + + val /= if (op == .e) 2 else 3; + i += 1; + continue :brk list[i]; + }, + } + + try expectEqual(5, val); +} + +test "runtime direct switch dispatch with referenced number" { + if (builtin.zig_backend != .stage2_llvm) return error.SkipZigTest; // TODO + + var val: u7 = 2; + _ = &val; + const a = brk: switch (val) { + 0, 1 => |*v| { + try expectEqual(2, val); + v.* = 3; + try expectEqual(3, val); + continue :brk v.*; + }, + 2 => continue :brk 0, + 3 => 4, + else => 5, + }; + try expectEqual(4, a); +} + +test "runtime direct switch dispatch with referenced tagged union" { + if (builtin.zig_backend != .stage2_llvm) return error.SkipZigTest; // TODO + + const Val = union(enum) { + a: u8, + b: bool, + c: i32, + }; + + var val = Val{ .a = 3 }; + _ = &val; + + const b = brk: switch (val) { + .a => |v| { + const c: i32 = @intCast(v); + continue :brk Val{ .c = c }; + }, + .b => |*v| { + try expectEqual(true, v.*); + break :brk !v.*; + }, + .c => |v| continue :brk Val{ .b = v < 4 }, + }; + + try expectEqual(false, b); +} + +test "direct switch dispatch with types" { + if (builtin.zig_backend != .stage2_llvm) return error.SkipZigTest; // TODO + + const val: type = u8; + const a = brk: switch (val) { + u8 => continue :brk u7, + bool => []i4, + u7 => continue :brk bool, + else => unreachable, + }; + + try expectEqual([]i4, a); +}