diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td index 49e54df3436ff..2da45eba77655 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMAttrDefs.td @@ -1034,6 +1034,40 @@ def LLVM_TBAATagArrayAttr let constBuilderCall = ?; } +//===----------------------------------------------------------------------===// +// ConstantRangeAttr +//===----------------------------------------------------------------------===// +def LLVM_ConstantRangeAttr : LLVM_Attr<"ConstantRange", "constant_range"> { + let parameters = (ins + "::llvm::APInt":$lower, + "::llvm::APInt":$upper + ); + let summary = "A range of two integers, corresponding to LLVM's ConstantRange"; + let description = [{ + A pair of two integers, mapping to the ConstantRange structure in LLVM IR, + which is allowed to wrap or be empty. + + The range represented is [Lower, Upper), and is either signed or unsigned + depending on context. + + `lower` and `upper` must have the same width. 
+ + Syntax: + ``` + `<` `i`(width($lower)) $lower `,` $upper `>` + }]; + + let builders = [ + AttrBuilder<(ins "uint32_t":$bitWidth, "int64_t":$lower, "int64_t":$upper), [{ + return $_get($_ctxt, ::llvm::APInt(bitWidth, lower), ::llvm::APInt(bitWidth, upper)); + }]> + ]; + + let hasCustomAssemblyFormat = 1; + let genVerifyDecl = 1; +} + + //===----------------------------------------------------------------------===// // VScaleRangeAttr //===----------------------------------------------------------------------===// diff --git a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td index 7b9a9cf017c53..c3d352d8d0dd4 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td +++ b/mlir/include/mlir/Dialect/LLVMIR/LLVMOpBase.td @@ -319,17 +319,19 @@ class LLVM_IntrOpBase mlirOperands; SmallVector mlirAttrs; if (failed(moduleImport.convertIntrinsicArguments( @@ -345,9 +347,32 @@ class LLVM_IntrOpBase( $_location, resultTypes, mlirOperands, mlirAttrs); - }] # !if(!gt(requiresFastmath, 0), + }]; + string baseMlirBuilderCoda = !if(!gt(numResults, 0), "$res = op;", "$_op = op;"); + let mlirBuilder = baseMlirBuilder # !if(!gt(requiresFastmath, 0), "moduleImport.setFastmathFlagsAttr(inst, op);", "") - # !if(!gt(numResults, 0), "$res = op;", "$_op = op;"); + # baseMlirBuilderCoda; + + // Code for handling a `range` attribute that holds the constant range of the + // intrinsic's result (if one is specified at the call site). This is intended + // for GPU IDs and other calls where range() is meaningful. It expects + // an optional LLVM_ConstantRangeAttr named `range` to be present on the + // operation. These are included to abstract out common code in several + // dialects. + string setRangeRetAttrCode = [{ + if ($range) { + inst->addRangeRetAttr(::llvm::ConstantRange( + $range->getLower(), $range->getUpper())); + } + }]; + string importRangeRetAttrCode = [{ + // Note: we don't want to look in to the declaration here. 
+ auto rangeAttr = inst->getAttributes().getRetAttr(::llvm::Attribute::Range); + if (rangeAttr.isValid()) { + const ::llvm::ConstantRange& value = rangeAttr.getValueAsConstantRange(); + op.setRangeAttr(::mlir::LLVM::ConstantRangeAttr::get($_builder.getContext(), value.getLower(), value.getUpper())); + } + }]; } // Base class for LLVM intrinsic operations, should not be used directly. Places diff --git a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td index 35fd8270ca693..de23246255650 100644 --- a/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td +++ b/mlir/include/mlir/Dialect/LLVMIR/ROCDLOps.td @@ -98,23 +98,36 @@ class ROCDL_IntrOp overloadedResults, // ROCDL special register op definitions //===----------------------------------------------------------------------===// -class ROCDL_SpecialRegisterOp traits = []> : - ROCDL_Op, - Results<(outs LLVM_Type:$res)>, Arguments<(ins)> { - string llvmBuilder = "$res = createIntrinsicCallWithRange(builder," - # "llvm::Intrinsic::amdgcn_" # !subst(".","_", mnemonic) - # ", op->getAttrOfType<::mlir::DenseI32ArrayAttr>(\"range\"));"; - let assemblyFormat = "attr-dict `:` type($res)"; +class ROCDL_SpecialIdRegisterOp : + ROCDL_IntrPure1Op, + Arguments<(ins OptionalAttr<LLVM_ConstantRangeAttr>:$range)> { + string llvmBuilder = baseLlvmBuilder # setRangeRetAttrCode # baseLlvmBuilderCoda; + string mlirBuilder = baseMlirBuilder # importRangeRetAttrCode # baseMlirBuilderCoda; + + let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)"; + + // Temporary builder until Nvidia ops also support range attributes. 
+ let builders = [ + OpBuilder<(ins "Type":$resultType), [{ + build($_builder, $_state, resultType, ::mlir::LLVM::ConstantRangeAttr{}); + }]> + ]; } -class ROCDL_DeviceFunctionOp traits = []> : ROCDL_Op, - Results<(outs LLVM_Type:$res)>, Arguments<(ins)> { - string llvmBuilder = "$res = createDeviceFunctionCall(builder, \"" + Results<(outs LLVM_Type:$res)>, Arguments<(ins OptionalAttr<LLVM_ConstantRangeAttr>:$range)> { + string llvmBuilder = "$res = createDimGetterFunctionCall(builder, op, \"" # device_function # "\", " # parameter # ");"; - let assemblyFormat = "attr-dict `:` type($res)"; + let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)"; + + // Temporary builder until Nvidia ops also support range attributes. + let builders = [ + OpBuilder<(ins "Type":$resultType), [{ + build($_builder, $_state, resultType, ::mlir::LLVM::ConstantRangeAttr{}); + }]> + ]; } //===----------------------------------------------------------------------===// @@ -181,33 +194,33 @@ def ROCDL_BallotOp : //===----------------------------------------------------------------------===// // Thread index and Block index -def ROCDL_ThreadIdXOp : ROCDL_SpecialRegisterOp<"workitem.id.x">; -def ROCDL_ThreadIdYOp : ROCDL_SpecialRegisterOp<"workitem.id.y">; -def ROCDL_ThreadIdZOp : ROCDL_SpecialRegisterOp<"workitem.id.z">; +def ROCDL_ThreadIdXOp : ROCDL_SpecialIdRegisterOp<"workitem.id.x">; +def ROCDL_ThreadIdYOp : ROCDL_SpecialIdRegisterOp<"workitem.id.y">; +def ROCDL_ThreadIdZOp : ROCDL_SpecialIdRegisterOp<"workitem.id.z">; -def ROCDL_BlockIdXOp : ROCDL_SpecialRegisterOp<"workgroup.id.x">; -def ROCDL_BlockIdYOp : ROCDL_SpecialRegisterOp<"workgroup.id.y">; -def ROCDL_BlockIdZOp : ROCDL_SpecialRegisterOp<"workgroup.id.z">; +def ROCDL_BlockIdXOp : ROCDL_SpecialIdRegisterOp<"workgroup.id.x">; +def ROCDL_BlockIdYOp : ROCDL_SpecialIdRegisterOp<"workgroup.id.y">; +def ROCDL_BlockIdZOp : ROCDL_SpecialIdRegisterOp<"workgroup.id.z">; //===----------------------------------------------------------------------===// 
// Thread range and Block range -def ROCDL_BlockDimXOp : ROCDL_DeviceFunctionOp<"workgroup.dim.x", +def ROCDL_BlockDimXOp : ROCDL_DimGetterFunctionOp<"workgroup.dim.x", "__ockl_get_local_size", 0>; -def ROCDL_BlockDimYOp : ROCDL_DeviceFunctionOp<"workgroup.dim.y", +def ROCDL_BlockDimYOp : ROCDL_DimGetterFunctionOp<"workgroup.dim.y", "__ockl_get_local_size", 1>; -def ROCDL_BlockDimZOp : ROCDL_DeviceFunctionOp<"workgroup.dim.z", +def ROCDL_BlockDimZOp : ROCDL_DimGetterFunctionOp<"workgroup.dim.z", "__ockl_get_local_size", 2>; -def ROCDL_GridDimXOp : ROCDL_DeviceFunctionOp<"grid.dim.x", +def ROCDL_GridDimXOp : ROCDL_DimGetterFunctionOp<"grid.dim.x", "__ockl_get_num_groups", 0>; -def ROCDL_GridDimYOp : ROCDL_DeviceFunctionOp<"grid.dim.y", +def ROCDL_GridDimYOp : ROCDL_DimGetterFunctionOp<"grid.dim.y", "__ockl_get_num_groups", 1>; -def ROCDL_GridDimZOp : ROCDL_DeviceFunctionOp<"grid.dim.z", +def ROCDL_GridDimZOp : ROCDL_DimGetterFunctionOp<"grid.dim.z", "__ockl_get_num_groups", 2>; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h index e4cd24e0380e7..eaf1554a83f89 100644 --- a/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h +++ b/mlir/lib/Conversion/GPUCommon/IndexIntrinsicsOpLowering.h @@ -114,9 +114,9 @@ struct OpLowering : public ConvertOpToLLVMPattern { if (upperBound && intrType != IntrType::None) { int32_t min = (intrType == IntrType::Dim ? 1 : 0); - int32_t max = *upperBound - (intrType == IntrType::Id ? 0 : 1); - newOp->setAttr( - "range", DenseI32ArrayAttr::get(op.getContext(), ArrayRef{min, max})); + int32_t max = *upperBound + (intrType == IntrType::Id ? 
0 : 1); + newOp->setAttr("range", LLVM::ConstantRangeAttr::get( + rewriter.getContext(), 32, min, max)); } if (indexBitwidth > 32) { newOp = rewriter.create( diff --git a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp index 491dcc7f01e73..6047c4a7ef515 100644 --- a/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp +++ b/mlir/lib/Dialect/LLVMIR/IR/LLVMAttrs.cpp @@ -232,6 +232,47 @@ DIRecursiveTypeAttrInterface DISubprogramAttr::getRecSelf(DistinctAttr recId) { {}, {}, {}, {}, {}, 0, 0, {}, {}, {}, {}); } +//===----------------------------------------------------------------------===// +// ConstantRangeAttr +//===----------------------------------------------------------------------===// + +Attribute ConstantRangeAttr::parse(AsmParser &parser, Type odsType) { + llvm::SMLoc loc = parser.getCurrentLocation(); + IntegerType widthType; + if (parser.parseLess() || parser.parseType(widthType) || + parser.parseComma()) { + return Attribute{}; + } + unsigned bitWidth = widthType.getWidth(); + APInt lower(bitWidth, 0); + APInt upper(bitWidth, 0); + if (parser.parseInteger(lower) || parser.parseComma() || + parser.parseInteger(upper) || parser.parseGreater()) + return Attribute{}; + // For some reason, 0 is always parsed as 64-bits, fix that if needed. + if (lower.isZero()) + lower = lower.sextOrTrunc(bitWidth); + if (upper.isZero()) + upper = upper.sextOrTrunc(bitWidth); + return parser.getChecked<ConstantRangeAttr>(loc, parser.getContext(), lower, + upper); +} + +void ConstantRangeAttr::print(AsmPrinter &printer) const { + printer << "<i" << getLower().getBitWidth() << ", " << getLower() << ", " + << getUpper() << ">"; +} + +LogicalResult +ConstantRangeAttr::verify(llvm::function_ref<InFlightDiagnostic()> emitError, + APInt lower, APInt upper) { + if (lower.getBitWidth() != upper.getBitWidth()) + return emitError() + << "expected lower and upper to have matching bitwidths but got " + << lower.getBitWidth() << " vs. 
" << upper.getBitWidth(); + return success(); +} + //===----------------------------------------------------------------------===// // TargetFeaturesAttr //===----------------------------------------------------------------------===// diff --git a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp index c1ee650776356..ec21fbf714c24 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/ROCDL/ROCDLToLLVMIRTranslation.cpp @@ -26,25 +26,13 @@ using namespace mlir; using namespace mlir::LLVM; using mlir::LLVM::detail::createIntrinsicCall; -static llvm::Value *createIntrinsicCallWithRange(llvm::IRBuilderBase &builder, - llvm::Intrinsic::ID intrinsic, - DenseI32ArrayAttr maybeRange) { - auto *inst = llvm::cast( - createIntrinsicCall(builder, intrinsic, {}, {})); - if (maybeRange) { - llvm::ConstantRange Range(APInt(32, maybeRange[0]), - APInt(32, maybeRange[1])); - inst->addRangeRetAttr(Range); - } - return inst; -} - -// Create a call to ROCm-Device-Library function -// Currently this routine will work only for calling ROCDL functions that -// take a single int32 argument. It is likely that the interface of this -// function will change to make it more generic. -static llvm::Value *createDeviceFunctionCall(llvm::IRBuilderBase &builder, - StringRef fnName, int parameter) { +// Create a call to ROCm-Device-Library function that returns an ID. +// This is intended to specifically call device functions that fetch things like +// block or grid dimensions, and so is limited to functions that take one +// integer parameter. +static llvm::Value *createDimGetterFunctionCall(llvm::IRBuilderBase &builder, + Operation *op, StringRef fnName, + int parameter) { llvm::Module *module = builder.GetInsertBlock()->getModule(); llvm::FunctionType *functionType = llvm::FunctionType::get( llvm::Type::getInt64Ty(module->getContext()), // return type. 
@@ -54,7 +42,14 @@ static llvm::Value *createDeviceFunctionCall(llvm::IRBuilderBase &builder, module->getOrInsertFunction(fnName, functionType).getCallee()); llvm::Value *fnOp0 = llvm::ConstantInt::get( llvm::Type::getInt32Ty(module->getContext()), parameter); - return builder.CreateCall(fn, ArrayRef(fnOp0)); + auto *call = builder.CreateCall(fn, ArrayRef(fnOp0)); + if (auto rangeAttr = op->getAttrOfType("range")) { + // Zero-extend to 64 bits because the GPU dialect uses 32-bit bounds but + // these ockl functions are defined to be 64-bits + call->addRangeRetAttr(llvm::ConstantRange(rangeAttr.getLower().zext(64), + rangeAttr.getUpper().zext(64))); + } + return call; } namespace { diff --git a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir index b6fb08522ae1f..56b65beb03695 100644 --- a/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir +++ b/mlir/test/Conversion/GPUToROCDL/gpu-to-rocdl.mlir @@ -77,18 +77,18 @@ gpu.module @test_module { {known_block_size = array, known_grid_size = array} { - // CHECK: rocdl.workitem.id.x {range = array} : i32 + // CHECK: rocdl.workitem.id.x range : i32 %tIdX = gpu.thread_id x - // CHECK: rocdl.workitem.id.y {range = array} : i32 + // CHECK: rocdl.workitem.id.y range : i32 %tIdY = gpu.thread_id y - // CHECK: rocdl.workitem.id.z {range = array} : i32 + // CHECK: rocdl.workitem.id.z range : i32 %tIdZ = gpu.thread_id z - // CHECK: rocdl.workgroup.id.x {range = array} : i32 + // CHECK: rocdl.workgroup.id.x range : i32 %bIdX = gpu.block_id x - // CHECK: rocdl.workgroup.id.y {range = array} : i32 + // CHECK: rocdl.workgroup.id.y range : i32 %bIdY = gpu.block_id y - // CHECK: rocdl.workgroup.id.z {range = array} : i32 + // CHECK: rocdl.workgroup.id.z range : i32 %bIdZ = gpu.block_id z // "Usage" to make the ID calls not die diff --git a/mlir/test/Target/LLVMIR/rocdl.mlir b/mlir/test/Target/LLVMIR/rocdl.mlir index 97b505746fc75..0f0c2412e5ec2 100644 --- 
a/mlir/test/Target/LLVMIR/rocdl.mlir +++ b/mlir/test/Target/LLVMIR/rocdl.mlir @@ -28,8 +28,10 @@ llvm.func @rocdl_special_regs() -> i32 { %12 = rocdl.grid.dim.z : i64 // CHECK: call range(i32 0, 64) i32 @llvm.amdgcn.workitem.id.x() - %13 = rocdl.workitem.id.x {range = array<i32: 0, 64>} : i32 + %13 = rocdl.workitem.id.x range <i32, 0, 64> : i32 + // CHECK: call range(i64 1, 65) i64 @__ockl_get_local_size(i32 0) + %14 = rocdl.workgroup.dim.x range <i32, 1, 65> : i64 llvm.return %1 : i32 }