Lower nvvm.elect.sync through the llvm.nvvm.elect.sync intrinsic.

NVVM_ElectSyncOp stops emitting inline PTX via getPtx() and instead gets an
llvmBuilder that calls llvm::Intrinsic::nvvm_elect_sync with a fixed
0xFFFFFFFF member mask and extracts element 1 of the returned { i32, i1 }
as the predicate. The NVVMToLLVM conversion test now expects the op to be
left untouched, and a new translation test checks the emitted LLVM IR.

NOTE(review): this patch text was restored from a copy whose line breaks and
runs of spaces had been collapsed. Line structure matches the hunk headers,
but leading indentation and intra-line spacing of context/removed lines are
best-effort; regenerate from the commit before applying.

diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 5806295cedb19..7cb4b5c346ad9 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -783,24 +783,27 @@ def NVVM_SyncWarpOp :
   let assemblyFormat = "$mask attr-dict `:` type($mask)";
 }
 
-
-def NVVM_ElectSyncOp : NVVM_Op<"elect.sync",
-  [DeclareOpInterfaceMethods<BasicPtxBuilderOpInterface>]>
+def NVVM_ElectSyncOp : NVVM_Op<"elect.sync">
 {
+  let summary = "Elect one leader thread";
+  let description = [{
+    The `elect.sync` instruction elects one predicated active leader
+    thread from among a set of threads specified in membermask.
+    The membermask is set to `0xFFFFFFFF` for the current version
+    of this Op. The predicate result is set to `True` for the
+    leader thread, and `False` for all other threads.
+
+    [For more information, see PTX ISA]
+    (https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#parallel-synchronization-and-communication-instructions-elect-sync)
+  }];
+
   let results = (outs I1:$pred);
   let assemblyFormat = "attr-dict `->` type(results)";
-  let extraClassDefinition = [{
-    std::string $cppClass::getPtx() {
-      return std::string(
-        "{ \n"
-        ".reg .u32 rx; \n"
-        ".reg .pred px; \n"
-        " mov.pred %0, 0; \n"
-        " elect.sync rx | px, 0xFFFFFFFF;\n"
-        "@px mov.pred %0, 1; \n"
-        "}\n"
-      );
-    }
+  string llvmBuilder = [{
+    auto *resultTuple = createIntrinsicCall(builder,
+        llvm::Intrinsic::nvvm_elect_sync, {builder.getInt32(0xFFFFFFFF)});
+    // Extract the second value into $pred
+    $pred = builder.CreateExtractValue(resultTuple, 1);
   }];
 }
 
diff --git a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
index 375e2951a037c..66b736c18718f 100644
--- a/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
+++ b/mlir/test/Conversion/NVVMToLLVM/nvvm-to-llvm.mlir
@@ -579,13 +579,7 @@ func.func @wgmma_f32_e5m2_e4m3(%descA : i64, %descB : i64) -> !mat32f32 {
 // -----
 
 func.func @elect_one_leader_sync() {
-  // CHECK: llvm.inline_asm has_side_effects asm_dialect = att "{
-  // CHECK-SAME: .reg .u32 rx;
-  // CHECK-SAME: .reg .pred px;
-  // CHECK-SAME: mov.pred $0, 0;
-  // CHECK-SAME: elect.sync rx | px, 0xFFFFFFFF;
-  // CHECK-SAME: @px mov.pred $0, 1;
-  // CHECK-SAME: "=b" : () -> i1
+  // CHECK: %[[RES:.*]] = nvvm.elect.sync -> i1
   %cnd = nvvm.elect.sync -> i1
   return
 }
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 0471e5faf8457..75ce958b43fd3 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -259,6 +259,15 @@ llvm.func @nvvm_vote(%0 : i32, %1 : i1) -> i32 {
   llvm.return %3 : i32
 }
 
+// CHECK-LABEL: @nvvm_elect_sync
+llvm.func @nvvm_elect_sync() -> i1 {
+  // CHECK: %[[RES:.*]] = call { i32, i1 } @llvm.nvvm.elect.sync(i32 -1)
+  // CHECK-NEXT: %[[PRED:.*]] = extractvalue { i32, i1 } %[[RES]], 1
+  // CHECK-NEXT: ret i1 %[[PRED]]
+  %0 = nvvm.elect.sync -> i1
+  llvm.return %0 : i1
+}
+
 // CHECK-LABEL: @nvvm_mma_mn8n8k4_row_col_f32_f32
 llvm.func @nvvm_mma_mn8n8k4_row_col_f32_f32(%a0 : vector<2xf16>, %a1 : vector<2xf16>,
                     %b0 : vector<2xf16>, %b1 : vector<2xf16>,