From 9d72432887387f3f99440c9969871146e391eb5b Mon Sep 17 00:00:00 2001 From: YazZz1k Date: Tue, 20 Feb 2024 13:28:27 +0300 Subject: [PATCH] [CIR][CIRGen] Support for __builtin_expect --- clang/include/clang/CIR/Dialect/IR/CIROps.td | 25 +++++++++ clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp | 34 +++++++++++- .../CodeGen/UnimplementedFeatureGuarding.h | 1 - .../CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp | 52 +++++++++++++++--- clang/test/CIR/CodeGen/pred-info-builtins.c | 33 ++++++++---- clang/test/CIR/Lowering/expect.cir | 54 +++++++++++++++++++ clang/test/CIR/Lowering/if.cir | 30 +++++------ clang/test/CIR/Lowering/switch.cir | 2 +- 8 files changed, 192 insertions(+), 39 deletions(-) create mode 100644 clang/test/CIR/Lowering/expect.cir diff --git a/clang/include/clang/CIR/Dialect/IR/CIROps.td b/clang/include/clang/CIR/Dialect/IR/CIROps.td index 0ea10462c265..f46ed3329645 100644 --- a/clang/include/clang/CIR/Dialect/IR/CIROps.td +++ b/clang/include/clang/CIR/Dialect/IR/CIROps.td @@ -2691,6 +2691,31 @@ def SinOp : UnaryFPToFPBuiltinOp<"sin">; def SqrtOp : UnaryFPToFPBuiltinOp<"sqrt">; def TruncOp : UnaryFPToFPBuiltinOp<"trunc">; +//===----------------------------------------------------------------------===// +// Branch Probability Operations +//===----------------------------------------------------------------------===// + +def ExpectOp : CIR_Op<"expect", + [Pure, AllTypesMatch<["result", "val", "expected"]>]> { + let summary = + "Compute whether expression is likely to evaluate to a specified value"; + let description = [{ + Provides __builtin_expect functionality in Clang IR. + + If $prob is not specified, then behaviour is same as __builtin_expect. + If specified, then behaviour is same as __builtin_expect_with_probability, + where probability = $prob. + }]; + + let arguments = (ins CIR_IntType:$val, + CIR_IntType:$expected, + OptionalAttr:$prob); + let results = (outs CIR_IntType:$result); + let assemblyFormat = [{ + `(` $val`,` $expected (`,` $prob^)? 
`)` `:` type($val) attr-dict + }]; +} + //===----------------------------------------------------------------------===// // Variadic Operations //===----------------------------------------------------------------------===// diff --git a/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp b/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp index 0c351c7ea1b7..6eb18c52ca59 100644 --- a/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp +++ b/clang/lib/CIR/CodeGen/CIRGenBuiltin.cpp @@ -386,10 +386,40 @@ RValue CIRGenFunction::buildBuiltinExpr(const GlobalDecl GD, unsigned BuiltinID, } case Builtin::BI__builtin_expect: - case Builtin::BI__builtin_expect_with_probability: + case Builtin::BI__builtin_expect_with_probability: { + auto ArgValue = buildScalarExpr(E->getArg(0)); + auto ExpectedValue = buildScalarExpr(E->getArg(1)); + + // Don't generate cir.expect on -O0 as the backend won't use it for + // anything. Note, we still IRGen ExpectedValue because it could have + // side-effects. + if (CGM.getCodeGenOpts().OptimizationLevel == 0) + return RValue::get(ArgValue); + + mlir::FloatAttr ProbAttr = {}; + if (BuiltinIDIfNoAsmLabel == Builtin::BI__builtin_expect_with_probability) { + llvm::APFloat Probability(0.0); + const Expr *ProbArg = E->getArg(2); + bool EvalSucceed = + ProbArg->EvaluateAsFloat(Probability, CGM.getASTContext()); + assert(EvalSucceed && "probability should be able to evaluate as float"); + (void)EvalSucceed; + bool LoseInfo = false; + Probability.convert(llvm::APFloat::IEEEdouble(), + llvm::RoundingMode::Dynamic, &LoseInfo); + ProbAttr = mlir::FloatAttr::get( + mlir::FloatType::getF64(builder.getContext()), Probability); + } + + auto result = builder.create( + getLoc(E->getSourceRange()), ArgValue.getType(), ArgValue, + ExpectedValue, ProbAttr); + + return RValue::get(result); + } case Builtin::BI__builtin_unpredictable: { if (CGM.getCodeGenOpts().OptimizationLevel != 0) - assert(!UnimplementedFeature::branchPredictionInfoBuiltin()); + 
assert(!UnimplementedFeature::insertBuiltinUnpredictable()); return RValue::get(buildScalarExpr(E->getArg(0))); } diff --git a/clang/lib/CIR/CodeGen/UnimplementedFeatureGuarding.h b/clang/lib/CIR/CodeGen/UnimplementedFeatureGuarding.h index d6a7e1d89433..68f50950390d 100644 --- a/clang/lib/CIR/CodeGen/UnimplementedFeatureGuarding.h +++ b/clang/lib/CIR/CodeGen/UnimplementedFeatureGuarding.h @@ -141,7 +141,6 @@ struct UnimplementedFeature { static bool armComputeVolatileBitfields() { return false; } static bool setCommonAttributes() { return false; } static bool insertBuiltinUnpredictable() { return false; } - static bool branchPredictionInfoBuiltin() { return false; } static bool createInvariantGroup() { return false; } static bool addAutoInitAnnotation() { return false; } static bool addHeapAllocSiteMetadata() { return false; } diff --git a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp index 1cdfc3e19f5b..925b8154c57e 100644 --- a/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp +++ b/clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp @@ -795,9 +795,27 @@ class CIRIfLowering : public mlir::OpConversionPattern { } rewriter.setInsertionPointToEnd(currentBlock); - auto trunc = rewriter.create(loc, rewriter.getI1Type(), - adaptor.getCondition()); - rewriter.create(loc, trunc.getRes(), thenBeforeBody, + + // FIXME: CIR always lowers !cir.bool to i8 type. + // For this reason CIR CodeGen often emits a redundant zext + trunc + // sequence that prevents lowering of llvm.expect in + // LowerExpectIntrinsicPass. + // We should fix that in a more appropriate way, but as a temporary solution + // just avoid the redundant casts here. 
+ mlir::Value condition; + // getDefiningOp<OpTy>() is null-safe: a block argument has no defining op. + auto zext = adaptor.getCondition().getDefiningOp<mlir::LLVM::ZExtOp>(); + if (zext && zext->getOperand(0).getType() == rewriter.getI1Type()) { + condition = zext->getOperand(0); + if (zext->use_empty()) + rewriter.eraseOp(zext); + } else { + auto trunc = rewriter.create( + loc, rewriter.getI1Type(), adaptor.getCondition()); + condition = trunc.getRes(); + } + + rewriter.create(loc, condition, thenBeforeBody, elseBeforeBody); if (!emptyElse) { @@ -2155,6 +2173,25 @@ class CIRFAbsOpLowering : public mlir::OpConversionPattern { } }; +class CIRExpectOpLowering + : public mlir::OpConversionPattern { +public: + using OpConversionPattern::OpConversionPattern; + + mlir::LogicalResult + matchAndRewrite(mlir::cir::ExpectOp op, OpAdaptor adaptor, + mlir::ConversionPatternRewriter &rewriter) const override { + std::optional prob = op.getProb(); + if (!prob) + rewriter.replaceOpWithNewOp(op, adaptor.getVal(), + adaptor.getExpected()); + else + rewriter.replaceOpWithNewOp( + op, adaptor.getVal(), adaptor.getExpected(), prob.value()); + return mlir::success(); + } +}; + class CIRVTableAddrPointOpLowering : public mlir::OpConversionPattern { public: @@ -2275,10 +2312,11 @@ void populateCIRToLLVMConversionPatterns(mlir::RewritePatternSet &patterns, CIRVACopyLowering, CIRVAArgLowering, CIRBrOpLowering, CIRTernaryOpLowering, CIRGetMemberOpLowering, CIRSwitchOpLowering, CIRPtrDiffOpLowering, CIRCopyOpLowering, CIRMemCpyOpLowering, - CIRFAbsOpLowering, CIRVTableAddrPointOpLowering, CIRVectorCreateLowering, - CIRVectorInsertLowering, CIRVectorExtractLowering, CIRVectorCmpOpLowering, - CIRStackSaveLowering, CIRStackRestoreLowering, CIRUnreachableLowering, - CIRInlineAsmOpLowering>(converter, patterns.getContext()); + CIRFAbsOpLowering, CIRExpectOpLowering, CIRVTableAddrPointOpLowering, + CIRVectorCreateLowering, CIRVectorInsertLowering, + CIRVectorExtractLowering, CIRVectorCmpOpLowering, CIRStackSaveLowering, + CIRStackRestoreLowering, CIRUnreachableLowering, 
CIRInlineAsmOpLowering>( + converter, patterns.getContext()); } namespace { diff --git a/clang/test/CIR/CodeGen/pred-info-builtins.c b/clang/test/CIR/CodeGen/pred-info-builtins.c index 585ff299a2a5..c9c44b73742f 100644 --- a/clang/test/CIR/CodeGen/pred-info-builtins.c +++ b/clang/test/CIR/CodeGen/pred-info-builtins.c @@ -1,4 +1,5 @@ -// RUN: %clang_cc1 -O0 -triple x86_64-unknown-linux-gnu -fclangir-enable -emit-cir %s -o - | FileCheck %s +// RUN: %clang_cc1 -O0 -triple x86_64-unknown-linux-gnu -fclangir-enable -emit-cir %s -o - | FileCheck %s --check-prefix=CIR-O0 +// RUN: %clang_cc1 -O2 -triple x86_64-unknown-linux-gnu -fclangir-enable -emit-cir %s -o - | FileCheck %s --check-prefix=CIR-O2 extern void __attribute__((noinline)) bar(void); @@ -6,22 +7,34 @@ void expect(int x) { if (__builtin_expect(x, 0)) bar(); } -// CHECK: cir.func @expect -// CHECK: cir.if {{%.*}} { -// CHECK: cir.call @bar() : () -> () +// CIR-O0: cir.func @expect +// CIR-O0: cir.if {{%.*}} { +// CIR-O0: cir.call @bar() : () -> () + +// CIR-O2: cir.func @expect +// CIR-O2: [[EXPECT:%.*]] = cir.expect({{.*}}, {{.*}}) : !s64i +// CIR-O2: [[EXPECT_BOOL:%.*]] = cir.cast(int_to_bool, [[EXPECT]] : !s64i), !cir.bool +// CIR-O2: cir.if [[EXPECT_BOOL]] +// CIR-O2: cir.call @bar() : () -> () void expect_with_probability(int x) { if (__builtin_expect_with_probability(x, 1, 0.8)) bar(); } -// CHECK: cir.func @expect_with_probability -// CHECK: cir.if {{%.*}} { -// CHECK: cir.call @bar() : () -> () +// CIR-O0: cir.func @expect_with_probability +// CIR-O0: cir.if {{%.*}} { +// CIR-O0: cir.call @bar() : () -> () + +// CIR-O2: cir.func @expect_with_probability +// CIR-O2: [[EXPECT:%.*]] = cir.expect({{.*}}, {{.*}}, 8.000000e-01) : !s64i +// CIR-O2: [[EXPECT_BOOL:%.*]] = cir.cast(int_to_bool, [[EXPECT]] : !s64i), !cir.bool +// CIR-O2: cir.if [[EXPECT_BOOL]] +// CIR-O2: cir.call @bar() : () -> () void unpredictable(int x) { if (__builtin_unpredictable(x > 1)) bar(); -// CHECK: cir.func @unpredictable -// CHECK: 
cir.if {{%.*}} { -// CHECK: cir.call @bar() : () -> () +// CIR-O0: cir.func @unpredictable +// CIR-O0: cir.if {{%.*}} { +// CIR-O0: cir.call @bar() : () -> () } diff --git a/clang/test/CIR/Lowering/expect.cir b/clang/test/CIR/Lowering/expect.cir new file mode 100644 index 000000000000..a221cca5f3dd --- /dev/null +++ b/clang/test/CIR/Lowering/expect.cir @@ -0,0 +1,54 @@ +// RUN: cir-opt %s -cir-to-llvm | FileCheck %s -check-prefix=MLIR +// RUN: cir-translate %s -cir-to-llvmir | FileCheck %s -check-prefix=LLVM + +!s64i = !cir.int +module { + cir.func @foo(%arg0: !s64i) { + %0 = cir.const(#cir.int<1> : !s64i) : !s64i + %1 = cir.expect(%arg0, %0) : !s64i + %2 = cir.cast(int_to_bool, %1 : !s64i), !cir.bool + cir.if %2 { + cir.yield + } + %3 = cir.expect(%arg0, %0, 1.000000e-01) : !s64i + %4 = cir.cast(int_to_bool, %3 : !s64i), !cir.bool + cir.if %4 { + cir.yield + } + cir.return + } +} + +// MLIR: llvm.func @foo(%arg0: i64) +// MLIR: [[ONE:%.*]] = llvm.mlir.constant(1 : i64) : i64 +// MLIR: [[EXPECT:%.*]] = llvm.intr.expect %arg0, [[ONE]] : i64 +// MLIR: [[ZERO:%.*]] = llvm.mlir.constant(0 : i64) : i64 +// MLIR: [[CMP_NE:%.*]] = llvm.icmp "ne" [[EXPECT]], [[ZERO]] : i64 +// MLIR: llvm.cond_br [[CMP_NE]], ^bb1, ^bb2 +// MLIR: ^bb1: // pred: ^bb0 +// MLIR: llvm.br ^bb2 +// MLIR: ^bb2: // 2 preds: ^bb0, ^bb1 +// MLIR: [[EXPECT_WITH_PROB:%.*]] = llvm.intr.expect.with.probability %arg0, [[ONE]], 1.000000e-01 : i64 +// MLIR: [[ZERO:%.*]] = llvm.mlir.constant(0 : i64) : i64 +// MLIR: [[CMP_NE:%.*]] = llvm.icmp "ne" [[EXPECT_WITH_PROB]], [[ZERO]] : i64 +// MLIR: llvm.cond_br [[CMP_NE]], ^bb3, ^bb4 +// MLIR: ^bb3: // pred: ^bb2 +// MLIR: llvm.br ^bb4 +// MLIR: ^bb4: // 2 preds: ^bb2, ^bb3 +// MLIR: llvm.return + +// LLVM: define void @foo(i64 %0) +// LLVM: [[EXPECT:%.*]] = call i64 @llvm.expect.i64(i64 %0, i64 1) +// LLVM: [[CMP_NE:%.*]] = icmp ne i64 [[EXPECT]], 0 +// LLVM: br i1 [[CMP_NE]], label %4, label %5 +// LLVM: 4: +// LLVM: br label %5 +// LLVM: 5: +// LLVM: 
[[EXPECT_WITH_PROB:%.*]] = call i64 @llvm.expect.with.probability.i64(i64 %0, i64 1, double 1.000000e-01) +// LLVM: [[CMP_NE:%.*]] = icmp ne i64 [[EXPECT_WITH_PROB]], 0 +// LLVM: br i1 [[CMP_NE]], label %8, label %9 +// LLVM: 8: +// LLVM: br label %9 +// LLVM: 9: +// LLVM: ret void + diff --git a/clang/test/CIR/Lowering/if.cir b/clang/test/CIR/Lowering/if.cir index a6dfd8e65900..eac0b5e4467e 100644 --- a/clang/test/CIR/Lowering/if.cir +++ b/clang/test/CIR/Lowering/if.cir @@ -18,32 +18,28 @@ module { // MLIR: llvm.func @foo(%arg0: i32) -> i32 // MLIR-NEXT: %0 = llvm.mlir.constant(0 : i32) : i32 // MLIR-NEXT: %1 = llvm.icmp "ne" %arg0, %0 : i32 -// MLIR-NEXT: %2 = llvm.zext %1 : i1 to i8 -// MLIR-NEXT: %3 = llvm.trunc %2 : i8 to i1 -// MLIR-NEXT: llvm.cond_br %3, ^bb2, ^bb1 +// MLIR-NEXT: llvm.cond_br %1, ^bb2, ^bb1 // MLIR-NEXT: ^bb1: // pred: ^bb0 -// MLIR-NEXT: %4 = llvm.mlir.constant(0 : i32) : i32 -// MLIR-NEXT: llvm.return %4 : i32 +// MLIR-NEXT: %2 = llvm.mlir.constant(0 : i32) : i32 +// MLIR-NEXT: llvm.return %2 : i32 // MLIR-NEXT: ^bb2: // pred: ^bb0 -// MLIR-NEXT: %5 = llvm.mlir.constant(1 : i32) : i32 -// MLIR-NEXT: llvm.return %5 : i32 +// MLIR-NEXT: %3 = llvm.mlir.constant(1 : i32) : i32 +// MLIR-NEXT: llvm.return %3 : i32 // MLIR-NEXT: ^bb3: // no predecessors // MLIR-NEXT: llvm.return %arg0 : i32 // MLIR-NEXT: } // LLVM: define i32 @foo(i32 %0) // LLVM-NEXT: %2 = icmp ne i32 %0, 0 -// LLVM-NEXT: %3 = zext i1 %2 to i8 -// LLVM-NEXT: %4 = trunc i8 %3 to i1 -// LLVM-NEXT: br i1 %4, label %6, label %5 +// LLVM-NEXT: br i1 %2, label %4, label %3 // LLVM-EMPTY: -// LLVM-NEXT: 5: +// LLVM-NEXT: 3: // LLVM-NEXT: ret i32 0 // LLVM-EMPTY: -// LLVM-NEXT: 6: +// LLVM-NEXT: 4: // LLVM-NEXT: ret i32 1 // LLVM-EMPTY: -// LLVM-NEXT: 7: +// LLVM-NEXT: 5: // LLVM-NEXT: ret i32 %0 // LLVM-NEXT: } @@ -59,12 +55,10 @@ module { // MLIR: llvm.func @onlyIf(%arg0: i32) -> i32 // MLIR-NEXT: %0 = llvm.mlir.constant(0 : i32) : i32 // MLIR-NEXT: %1 = llvm.icmp "ne" %arg0, %0 : i32 
- // MLIR-NEXT: %2 = llvm.zext %1 : i1 to i8 - // MLIR-NEXT: %3 = llvm.trunc %2 : i8 to i1 - // MLIR-NEXT: llvm.cond_br %3, ^bb1, ^bb2 + // MLIR-NEXT: llvm.cond_br %1, ^bb1, ^bb2 // MLIR-NEXT: ^bb1: // pred: ^bb0 - // MLIR-NEXT: %4 = llvm.mlir.constant(1 : i32) : i32 - // MLIR-NEXT: llvm.return %4 : i32 + // MLIR-NEXT: %2 = llvm.mlir.constant(1 : i32) : i32 + // MLIR-NEXT: llvm.return %2 : i32 // MLIR-NEXT: ^bb2: // pred: ^bb0 // MLIR-NEXT: llvm.return %arg0 : i32 // MLIR-NEXT: } diff --git a/clang/test/CIR/Lowering/switch.cir b/clang/test/CIR/Lowering/switch.cir index 92f8e4654a40..5931d49de3a4 100644 --- a/clang/test/CIR/Lowering/switch.cir +++ b/clang/test/CIR/Lowering/switch.cir @@ -171,7 +171,7 @@ module { // CHECK: ^bb2: // pred: ^bb1 // CHECK: llvm.br ^bb3 // CHECK: ^bb3: // pred: ^bb2 - // CHECK: llvm.cond_br %14, ^bb4, ^bb5 + // CHECK: llvm.cond_br {{%.*}}, ^bb4, ^bb5 // CHECK: ^bb4: // pred: ^bb3 // CHECK: llvm.br ^bb7 // CHECK: ^bb5: // pred: ^bb3