From aebb07a8db6b349a8bbfac0bb147cc281222e0c4 Mon Sep 17 00:00:00 2001 From: Krzysztof Drewniak Date: Thu, 5 Sep 2024 22:41:11 +0000 Subject: [PATCH] Refactor how range() annotations are handled for ROCDL intrinsics This commit introduces a ConstantRange attribute to match the ConstantRange attribute type present in LLVM IR. It then refactors the LLVM_IntrOpBase so that the basic part of the intrinsic builder code can be re-used without needing to copy it or get rid of important context. This, along with adding code for handling an optional `range` attribute to that same base, allows us to make the support for range() annotations generic without adding another bit to IntrOpBase. This commit then updates the lowering of index intrinsic operations to use the new ConstantRange attribute and fixes a bug (where we'd be subtracting 1 from upper bounds instead of adding it on operations like gpu.block_dim) along the way. The point of these changes is to enable these range annotations to be used for the corresponding NVVM operations in a future commit. --- .../mlir/Dialect/LLVMIR/LLVMAttrDefs.td | 31 ++++++++++ .../include/mlir/Dialect/LLVMIR/LLVMOpBase.td | 40 ++++++++++-- mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td | 61 +++++++++++-------- .../GPUCommon/IndexIntrinsicsOpLowering.h | 6 +- mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp | 22 +++++++ .../ROCDL/ROCDLToLLVMIRTranslation.cpp | 36 +++++------ .../Conversion/GPUToROCDL/gpu-to-rocdl.mlir | 12 ++-- mlir/test/Target/LLVMIR/rocdl.mlir | 4 +- 8 files changed, 152 insertions(+), 60 deletions(-) diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td index e57be7f760d38..64c69bbe34299 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td @@ -1008,6 +1008,37 @@ def LLVM_TBAATagArrayAttr let constBuilderCall = ?; } +//===----------------------------------------------------------------------===// +// ConstantRangeAttr +//===----------------------------------------------------------------------===// +def LLVM_ConstantRangeAttr : LLVM_Attr<"ConstantRange", "constant_range"> { + let parameters = (ins + "IntegerAttr":$lower, + "IntegerAttr":$upper + ); + let summary = "A range of two integers, corresponding to LLVM's ConstantRange"; + let description = [{ + A pair of two integers, mapping to the ConstantRange structure in LLVM IR, + which is allowed to wrap or be empty. + + The range represented is [Lower, Upper), and is either signed or unsigned + depending on context. + + `lower` and `upper` must have the same width. + }]; + + let builders = [ + AttrBuilder<(ins "uint32_t":$bitWidth, "int64_t":$lower, "int64_t":$upper)> + ]; + + let assemblyFormat = [{ + `<` $lower `,` $upper `>` + }]; + + let genVerifyDecl = 1; +} + + //===----------------------------------------------------------------------===// // VScaleRangeAttr //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td index 7b9a9cf017c53..12811e2b4848f 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -319,17 +319,19 @@ class LLVM_IntrOpBase mlirOperands; SmallVector mlirAttrs; if (failed(moduleImport.convertIntrinsicArguments( @@ -345,9 +347,35 @@ class LLVM_IntrOpBase( $_location, resultTypes, mlirOperands, mlirAttrs); - }] # !if(!gt(requiresFastmath, 0), + }]; + string baseMlirBuilderCoda = !if(!gt(numResults, 0), "$res = op;", "$_op = op;"); + let mlirBuilder = baseMlirBuilder # !if(!gt(requiresFastmath, 0), "moduleImport.setFastmathFlagsAttr(inst, op);", "") - # !if(!gt(numResults, 0), "$res = op;", "$_op = op;"); + # baseMlirBuilderCoda; + + // Code for handling a `range` attribute that holds the constant range of the + // intrinsic's result (if one is specified at the call site). This is intended + // for GPU IDs and other calls where range() is meaningful. It expects + // an optional LLVM_ConstantRangeAttr named `range` to be present on the + // operation. These are included to abstract out common code in several + // dialects. + string setRangeRetAttrCode = [{ + if ($range) { + inst->addRangeRetAttr(::llvm::ConstantRange( + $range->getLower().getValue(), $range->getUpper().getValue())); + } + }]; + string importRangeRetAttrCode = [{ + // Note: we don't want to look in to the declaration here. + auto rangeAttr = inst->getAttributes().getRetAttr(::llvm::Attribute::Range); + if (rangeAttr.isValid()) { + const ::llvm::ConstantRange& value = rangeAttr.getValueAsConstantRange(); + ::mlir::Type intType = ::mlir::IntegerType::get($_ctxt, value.getBitWidth()); + auto lowerAttr = ::mlir::IntegerAttr::get(intType, value.getLower()); + auto upperAttr = ::mlir::IntegerAttr::get(intType, value.getUpper()); + op.setRangeAttr(::mlir::LLVM::ConstantRangeAttr::get($_ctxt, lowerAttr, upperAttr)); + } + }]; } // Base class for LLVM intrinsic operations, should not be used directly. Places diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index e832dfa9d6b80..c4c278c166696 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -98,23 +98,36 @@ class ROCDL_IntrOp overloadedResults, // ROCDL special register op definitions //===----------------------------------------------------------------------===// -class ROCDL_SpecialRegisterOp traits = []> : - ROCDL_Op, - Results<(outs LLVM_Type:$res)>, Arguments<(ins)> { - string llvmBuilder = "$res = createIntrinsicCallWithRange(builder," - # "llvm::Intrinsic::amdgcn_" # !subst(".","_", mnemonic) - # ", op->getAttrOfType<::mlir::DenseI32ArrayAttr>(\"range\"));"; - let assemblyFormat = "attr-dict `:` type($res)"; +class ROCDL_SpecialIdRegisterOp : + ROCDL_IntrPure1Op, + Arguments<(ins OptionalAttr:$range)> { + string llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda; + string mlirBuilder = baseMlirBuilder # importRangeRetAttrCode # baseMlirBuilderCoda; + + let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)"; + + // Temporaly builder until Nvidia ops also support range attributes. + let builders = [ + OpBuilder<(ins "Type":$resultType), [{ + build($_builder, $_state, resultType, ::mlir::LLVM::ConstantRangeAttr{}); + }]> + ]; } -class ROCDL_DeviceFunctionOp traits = []> : ROCDL_Op, - Results<(outs LLVM_Type:$res)>, Arguments<(ins)> { - string llvmBuilder = "$res = createDeviceFunctionCall(builder, \"" + Results<(outs LLVM_Type:$res)>, Arguments<(ins OptionalAttr:$range)> { + string llvmBuilder = "$res = createDimGetterFunctionCall(builder, op, \"" # device_function # "\", " # parameter # ");"; - let assemblyFormat = "attr-dict `:` type($res)"; + let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)"; + + // Temporaly builder until Nvidia ops also support range attributes. + let builders = [ + OpBuilder<(ins "Type":$resultType), [{ + build($_builder, $_state, resultType, ::mlir::LLVM::ConstantRangeAttr{}); + }]> + ]; } //===----------------------------------------------------------------------===// @@ -181,33 +194,33 @@ def ROCDL_BallotOp : //===----------------------------------------------------------------------===// // Thread index and Block index -def ROCDL_ThreadIdXOp : ROCDL_SpecialRegisterOp<"workitem.id.x">; -def ROCDL_ThreadIdYOp : ROCDL_SpecialRegisterOp<"workitem.id.y">; -def ROCDL_ThreadIdZOp : ROCDL_SpecialRegisterOp<"workitem.id.z">; +def ROCDL_ThreadIdXOp : ROCDL_SpecialIdRegisterOp<"workitem.id.x">; +def ROCDL_ThreadIdYOp : ROCDL_SpecialIdRegisterOp<"workitem.id.y">; +def ROCDL_ThreadIdZOp : ROCDL_SpecialIdRegisterOp<"workitem.id.z">; -def ROCDL_BlockIdXOp : ROCDL_SpecialRegisterOp<"workgroup.id.x">; -def ROCDL_BlockIdYOp : ROCDL_SpecialRegisterOp<"workgroup.id.y">; -def ROCDL_BlockIdZOp : ROCDL_SpecialRegisterOp<"workgroup.id.z">; +def ROCDL_BlockIdXOp : ROCDL_SpecialIdRegisterOp<"workgroup.id.x">; +def ROCDL_BlockIdYOp : ROCDL_SpecialIdRegisterOp<"workgroup.id.y">; +def ROCDL_BlockIdZOp : ROCDL_SpecialIdRegisterOp<"workgroup.id.z">; //===----------------------------------------------------------------------===// // Thread range and Block range -def ROCDL_BlockDimXOp : ROCDL_DeviceFunctionOp<"workgroup.dim.x", +def ROCDL_BlockDimXOp : ROCDL_DimGetterFunctionOp<"workgroup.dim.x", "__ockl_get_local_size", 0>; -def ROCDL_BlockDimYOp : ROCDL_DeviceFunctionOp<"workgroup.dim.y", +def ROCDL_BlockDimYOp : ROCDL_DimGetterFunctionOp<"workgroup.dim.y", "__ockl_get_local_size", 1>; -def ROCDL_BlockDimZOp : ROCDL_DeviceFunctionOp<"workgroup.dim.z", +def ROCDL_BlockDimZOp : ROCDL_DimGetterFunctionOp<"workgroup.dim.z", "__ockl_get_local_size", 2>; -def ROCDL_GridDimXOp : ROCDL_DeviceFunctionOp<"grid.dim.x", +def ROCDL_GridDimXOp : ROCDL_DimGetterFunctionOp<"grid.dim.x", "__ockl_get_num_groups", 0>; -def ROCDL_GridDimYOp : ROCDL_DeviceFunctionOp<"grid.dim.y", +def ROCDL_GridDimYOp : ROCDL_DimGetterFunctionOp<"grid.dim.y", "__ockl_get_num_groups", 1>; -def ROCDL_GridDimZOp : ROCDL_DeviceFunctionOp<"grid.dim.z", +def ROCDL_GridDimZOp : ROCDL_DimGetterFunctionOp<"grid.dim.z", "__ockl_get_num_groups", 2>; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h index e4cd24e0380e7..eaf1554a83f89 100644 --- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h +++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h @@ -114,9 +114,9 @@ struct OpLowering : public ConvertOpToLLVMPattern { if (upperBound && intrType != IntrType::None) { int32_t min = (intrType == IntrType::Dim ? 1 : 0); - int32_t max = *upperBound - (intrType == IntrType::Id ? 0 : 1); - newOp->setAttr( - "range", DenseI32ArrayAttr::get(op.getContext(), ArrayRef{min, max})); + int32_t max = *upperBound + (intrType == IntrType::Id ? 0 : 1); + newOp->setAttr("range", LLVM::ConstantRangeAttr::get( + rewriter.getContext(), 32, min, max)); } if (indexBitwidth > 32) { newOp = rewriter.create( diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp index 98a9659735e7e..dd683f2056cfe 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp @@ -215,6 +215,28 @@ DICompositeTypeAttr::getRecSelf(DistinctAttr recId) { {}, DIFlags(), 0, 0, {}, {}, {}, {}, {}); } +//===----------------------------------------------------------------------===// +// ConstantRangeAttr +//===----------------------------------------------------------------------===// +ConstantRangeAttr ConstantRangeAttr::get(MLIRContext *context, + uint32_t bitWidth, int64_t lower, + int64_t upper) { + Type widthType = IntegerType::get(context, bitWidth); + auto lowerAttr = IntegerAttr::get(widthType, lower); + auto upperAttr = IntegerAttr::get(widthType, upper); + return get(context, lowerAttr, upperAttr); +} + +LogicalResult +ConstantRangeAttr::verify(llvm::function_ref emitError, + IntegerAttr lower, IntegerAttr upper) { + if (lower.getType() != upper.getType()) + return emitError() + << "expected lower and upper to have matching types but got " + << lower.getType() << " vs. " << upper.getType(); + return success(); +} + //===----------------------------------------------------------------------===// // TargetFeaturesAttr //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp index c1ee650776356..0ca732b0c4383 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp @@ -26,25 +26,13 @@ using namespace mlir; using namespace mlir::LLVM; using mlir::LLVM::detail::createIntrinsicCall; -static llvm::Value *createIntrinsicCallWithRange(llvm::IRBuilderBase &builder, - llvm::Intrinsic::ID intrinsic, - DenseI32ArrayAttr maybeRange) { - auto *inst = llvm::cast( - createIntrinsicCall(builder, intrinsic, {}, {})); - if (maybeRange) { - llvm::ConstantRange Range(APInt(32, maybeRange[0]), - APInt(32, maybeRange[1])); - inst->addRangeRetAttr(Range); - } - return inst; -} - -// Create a call to ROCm-Device-Library function -// Currently this routine will work only for calling ROCDL functions that -// take a single int32 argument. It is likely that the interface of this -// function will change to make it more generic. -static llvm::Value *createDeviceFunctionCall(llvm::IRBuilderBase &builder, - StringRef fnName, int parameter) { +// Create a call to ROCm-Device-Library function that returns an ID. +// This is intended to specifically call device functions that fetch things like +// block or grid dimensions, and so is limited to functions that take one +// integer parameter. +static llvm::Value *createDimGetterFunctionCall(llvm::IRBuilderBase &builder, + Operation *op, StringRef fnName, + int parameter) { llvm::Module *module = builder.GetInsertBlock()->getModule(); llvm::FunctionType *functionType = llvm::FunctionType::get( llvm::Type::getInt64Ty(module->getContext()), // return type. @@ -54,7 +42,15 @@ static llvm::Value *createDeviceFunctionCall(llvm::IRBuilderBase &builder, module->getOrInsertFunction(fnName, functionType).getCallee()); llvm::Value *fnOp0 = llvm::ConstantInt::get( llvm::Type::getInt32Ty(module->getContext()), parameter); - return builder.CreateCall(fn, ArrayRef(fnOp0)); + auto *call = builder.CreateCall(fn, ArrayRef(fnOp0)); + if (auto rangeAttr = op->getAttrOfType("range")) { + // Zero-extend to 64 bits because the GPU dialect uses 32-bit bounds but + // these ockl functions are defined to be 64-bits + call->addRangeRetAttr( + llvm::ConstantRange(rangeAttr.getLower().getValue().zext(64), + rangeAttr.getUpper().getValue().zext(64))); + } + return call; } namespace { diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir index bf49a42a11577..74f4be32d28f7 100644 --- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir +++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir @@ -77,18 +77,18 @@ gpu.module @test_module { {known_block_size = array, known_grid_size = array} { - // CHECK: rocdl.workitem.id.x {range = array} : i32 + // CHECK: rocdl.workitem.id.x range <0 : i32, 8 : i32> : i32 %tIdX = gpu.thread_id x - // CHECK: rocdl.workitem.id.y {range = array} : i32 + // CHECK: rocdl.workitem.id.y range <0 : i32, 12 : i32> : i32 %tIdY = gpu.thread_id y - // CHECK: rocdl.workitem.id.z {range = array} : i32 + // CHECK: rocdl.workitem.id.z range <0 : i32, 16 : i32> : i32 %tIdZ = gpu.thread_id z - // CHECK: rocdl.workgroup.id.x {range = array} : i32 + // CHECK: rocdl.workgroup.id.x range <0 : i32, 20 : i32> : i32 %bIdX = gpu.block_id x - // CHECK: rocdl.workgroup.id.y {range = array} : i32 + // CHECK: rocdl.workgroup.id.y range <0 : i32, 24 : i32> : i32 %bIdY = gpu.block_id y - // CHECK: rocdl.workgroup.id.z {range = array} : i32 + // CHECK: rocdl.workgroup.id.z range <0 : i32, 28 : i32> : i32 %bIdZ = gpu.block_id z // "Usage" to make the ID calls not die diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index d902a82eeb9ea..7020168bbd7dd 100644 --- a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -28,8 +28,10 @@ llvm.func @rocdl_special_regs() -> i32 { %12 = rocdl.grid.dim.z : i64 // CHECK: call range(i32 0, 64) i32 @llvm.amdgcn.workitem.id.x() - %13 = rocdl.workitem.id.x {range = array} : i32 + %13 = rocdl.workitem.id.x range <0 : i32, 64 : i32> : i32 + // CHECK: call range(i64 1, 65) i64 @__ockl_get_local_size(i32 0) + %14 = rocdl.workgroup.dim.x range <1 : i32, 65 : i32> : i64 llvm.return %1 : i32 }