diff --git a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
index 7bbf18fe0106f..152715f281088 100644
--- a/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
+++ b/mlir/include/mlir/Dialect/LLVMIR/NVVMOps.td
@@ -123,52 +123,69 @@ class NVVM_SpecialRegisterOp<string mnemonic, list<Trait> traits = []> :
   let assemblyFormat = "attr-dict `:` type($res)";
 }
 
+class NVVM_SpecialRangeableRegisterOp<string mnemonic, list<Trait> traits = []> :
+  NVVM_SpecialRegisterOp<mnemonic, traits> {
+  let arguments = (ins OptionalAttr<LLVM_ConstantRangeAttr>:$range);
+  let assemblyFormat = "(`range` $range^)? attr-dict `:` type($res)";
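+  // The range is emitted as an LLVM `range` return-value attribute on the
+  // intrinsic call when exporting, and recovered from it when importing.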
NVVM_SpecialRegisterOp<"read.ptx.sreg.clusterid.x">; -def NVVM_ClusterIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clusterid.y">; -def NVVM_ClusterIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.clusterid.z">; -def NVVM_ClusterDimXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nclusterid.x">; -def NVVM_ClusterDimYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nclusterid.y">; -def NVVM_ClusterDimZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.nclusterid.z">; +def NVVM_ClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.x">; +def NVVM_ClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.y">; +def NVVM_ClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.clusterid.z">; +def NVVM_ClusterDimXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.x">; +def NVVM_ClusterDimYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.y">; +def NVVM_ClusterDimZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.nclusterid.z">; //===----------------------------------------------------------------------===// // CTA index and range within Cluster -def NVVM_BlockInClusterIdXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctaid.x">; -def NVVM_BlockInClusterIdYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctaid.y">; -def NVVM_BlockInClusterIdZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctaid.z">; -def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctaid.x">; -def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctaid.y">; -def NVVM_ClusterDimBlocksZOp : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctaid.z">; +def NVVM_BlockInClusterIdXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.x">; +def NVVM_BlockInClusterIdYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.y">; +def NVVM_BlockInClusterIdZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctaid.z">; +def NVVM_ClusterDimBlocksXOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.x">; +def NVVM_ClusterDimBlocksYOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.y">; +def NVVM_ClusterDimBlocksZOp : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctaid.z">; //===----------------------------------------------------------------------===// // CTA index and across Cluster dimensions -def NVVM_ClusterId : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.ctarank">; -def NVVM_ClusterDim : NVVM_SpecialRegisterOp<"read.ptx.sreg.cluster.nctarank">; +def NVVM_ClusterId : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.ctarank">; +def NVVM_ClusterDim : NVVM_SpecialRangeableRegisterOp<"read.ptx.sreg.cluster.nctarank">; //===----------------------------------------------------------------------===// // Clock registers diff --git a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp index 9b1be198f77a8..164622d77e6b6 100644 --- a/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp +++ b/mlir/lib/Conversion/GPUToNVVM/LowerGpuOpsToNVVMOps.cpp @@ -29,6 +29,7 @@ #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/Dialect/Math/IR/Math.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/GreedyPatternRewriteDriver.h" @@ -209,7 +210,15 @@ struct GPULaneIdOpToNVVM : ConvertOpToLLVMPattern { ConversionPatternRewriter &rewriter) const override { auto loc = op->getLoc(); MLIRContext *context = 
-    Value newOp = rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type());
+    LLVM::ConstantRangeAttr bounds = nullptr;
+    if (std::optional<APInt> upperBound = op.getUpperBound())
+      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
+          /*bitWidth=*/32, /*lower=*/0, upperBound->getZExtValue());
+    else
+      bounds = rewriter.getAttr<LLVM::ConstantRangeAttr>(
+          /*bitWidth=*/32, /*lower=*/0, /*upper=*/kWarpSize);
+    Value newOp =
+        rewriter.create<NVVM::LaneIdOp>(loc, rewriter.getI32Type(), bounds);
     // Truncate or extend the result depending on the index bitwidth specified
     // by the LLVMTypeConverter options.
     const unsigned indexBitwidth = getTypeConverter()->getIndexTypeBitwidth();
@@ -340,27 +349,40 @@ void mlir::populateGpuSubgroupReduceOpLoweringPattern(
 
 void mlir::populateGpuToNVVMConversionPatterns(LLVMTypeConverter &converter,
                                                RewritePatternSet &patterns) {
+  using gpu::index_lowering::IndexKind;
+  using gpu::index_lowering::IntrType;
   populateWithGenerated(patterns);
   patterns.add<GPUPrintfOpToVPrintfLowering>(converter);
-  patterns.add<
-      gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
-                                      NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>,
-      gpu::index_lowering::OpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
-                                      NVVM::BlockDimYOp, NVVM::BlockDimZOp>,
-      gpu::index_lowering::OpLowering<
-          gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
-          NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>,
-      gpu::index_lowering::OpLowering<gpu::ClusterDimOp, NVVM::ClusterDimXOp,
-                                      NVVM::ClusterDimYOp, NVVM::ClusterDimZOp>,
-      gpu::index_lowering::OpLowering<gpu::BlockIdOp, NVVM::BlockIdXOp,
-                                      NVVM::BlockIdYOp, NVVM::BlockIdZOp>,
-      gpu::index_lowering::OpLowering<gpu::GridDimOp, NVVM::GridDimXOp,
-                                      NVVM::GridDimYOp, NVVM::GridDimZOp>,
-      GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(converter);
+  patterns.add<
+      gpu::index_lowering::OpLowering<gpu::ThreadIdOp, NVVM::ThreadIdXOp,
+                                      NVVM::ThreadIdYOp, NVVM::ThreadIdZOp>>(
+      converter, IndexKind::Block, IntrType::Id);
+  patterns.add<
+      gpu::index_lowering::OpLowering<gpu::BlockDimOp, NVVM::BlockDimXOp,
+                                      NVVM::BlockDimYOp, NVVM::BlockDimZOp>>(
+      converter, IndexKind::Block, IntrType::Dim);
+  patterns.add<
+      gpu::index_lowering::OpLowering<gpu::ClusterIdOp, NVVM::ClusterIdXOp,
+                                      NVVM::ClusterIdYOp, NVVM::ClusterIdZOp>>(
+      converter, IndexKind::Other, IntrType::Id);
+  patterns.add<gpu::index_lowering::OpLowering<
+      gpu::ClusterDimOp, NVVM::ClusterDimXOp, NVVM::ClusterDimYOp,
+      NVVM::ClusterDimZOp>>(converter, IndexKind::Other, IntrType::Dim);
+  patterns.add<gpu::index_lowering::OpLowering<
+      gpu::ClusterBlockIdOp, NVVM::BlockInClusterIdXOp,
+      NVVM::BlockInClusterIdYOp, NVVM::BlockInClusterIdZOp>>(
+      converter, IndexKind::Other, IntrType::Id);
+  patterns.add<gpu::index_lowering::OpLowering<
+      gpu::ClusterDimBlocksOp, NVVM::ClusterDimBlocksXOp,
+      NVVM::ClusterDimBlocksYOp, NVVM::ClusterDimBlocksZOp>>(
+      converter, IndexKind::Other, IntrType::Dim);
+  patterns.add<gpu::index_lowering::OpLowering<
+      gpu::BlockIdOp, NVVM::BlockIdXOp, NVVM::BlockIdYOp, NVVM::BlockIdZOp>>(
+      converter, IndexKind::Grid, IntrType::Id);
+  patterns.add<gpu::index_lowering::OpLowering<
+      gpu::GridDimOp, NVVM::GridDimXOp, NVVM::GridDimYOp, NVVM::GridDimZOp>>(
+      converter, IndexKind::Grid, IntrType::Dim);
+  patterns.add<GPULaneIdOpToNVVM, GPUShuffleOpLowering, GPUReturnOpLowering>(
+      converter);
   patterns.add<GPUDynamicSharedMemoryOpLowering>(
       converter, NVVM::kSharedMemoryAlignmentBit);
diff --git a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
index 855abc12a909e..bc830a77f3c58 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/NVVM/LLVMIRToNVVMTranslation.cpp
@@ -14,6 +14,7 @@
 #include "mlir/Dialect/LLVMIR/NVVMDialect.h"
 #include "mlir/Target/LLVMIR/ModuleImport.h"
+#include "llvm/IR/ConstantRange.h"
 #include "llvm/IR/IntrinsicsNVPTX.h"
 
 using namespace mlir;
diff --git a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
index 8f2ec289c9252..66ad1e307fc3a 100644
--- a/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
+++ b/mlir/test/Conversion/GPUToNVVM/gpu-to-nvvm.mlir
@@ -50,7 +50,7 @@ gpu.module @test_module_0 {
   %gDimZ = gpu.grid_dim z
 
-  // CHECK: = nvvm.read.ptx.sreg.laneid : i32
+  // CHECK: = nvvm.read.ptx.sreg.laneid range <i32, 0, 32> : i32
   // CHECK: = llvm.sext %{{.*}} : i32 to i64
   %laneId = gpu.lane_id
 
@@ -699,9 +699,21 @@ gpu.module @test_module_32 {
 }
 
 gpu.module @test_module_33 {
-// CHECK-LABEL: func @kernel_with_block_size()
-// CHECK: attributes {gpu.kernel, gpu.known_block_size = array<i32: 32, 4, 2>, nvvm.kernel, nvvm.maxntid = array<i32: 32, 4, 2>}
-  gpu.func @kernel_with_block_size() kernel attributes {known_block_size = array<i32: 32, 4, 2>} {
+// CHECK-LABEL: func @kernel_with_block_size(
+// CHECK: attributes {gpu.kernel, gpu.known_block_size = 
array<i32: 32, 4, 2>, nvvm.kernel, nvvm.maxntid = array<i32: 32, 4, 2>}
+  gpu.func @kernel_with_block_size(%arg0: !llvm.ptr) kernel attributes {known_block_size = array<i32: 32, 4, 2>} {
+    // CHECK: = nvvm.read.ptx.sreg.tid.x range <i32, 0, 32> : i32
+    %0 = gpu.thread_id x
+    // CHECK: = nvvm.read.ptx.sreg.tid.y range <i32, 0, 4> : i32
+    %1 = gpu.thread_id y
+    // CHECK: = nvvm.read.ptx.sreg.tid.z range <i32, 0, 2> : i32
+    %2 = gpu.thread_id z
+
+    // Fake usage to prevent dead code elimination
+    %3 = arith.addi %0, %1 : index
+    %4 = arith.addi %3, %2 : index
+    %5 = arith.index_cast %4 : index to i64
+    llvm.store %5, %arg0 : i64, !llvm.ptr
     gpu.return
   }
 }
 
@@ -917,6 +929,20 @@ gpu.module @test_module_48 {
   }
 }
 
+gpu.module @test_module_49 {
+// CHECK-LABEL: func @explicit_id_bounds()
+  func.func @explicit_id_bounds() -> (index, index, index) {
+    // CHECK: = nvvm.read.ptx.sreg.tid.x range <i32, 0, 32> : i32
+    %0 = gpu.thread_id x upper_bound 32
+    // CHECK: = nvvm.read.ptx.sreg.ntid.x range <i32, 1, 33> : i32
+    %1 = gpu.block_dim x upper_bound 32
+    // CHECK: = nvvm.read.ptx.sreg.laneid range <i32, 0, 16> : i32
+    %2 = gpu.lane_id upper_bound 16
+
+    return %0, %1, %2 : index, index, index
+  }
+}
+
 module attributes {transform.with_named_sequence} {
   transform.named_sequence @__transform_main(%toplevel_module: !transform.any_op {transform.readonly}) {
     %gpu_module = transform.structured.match ops{["gpu.module"]} in %toplevel_module
diff --git a/mlir/test/Target/LLVMIR/Import/nvvmir.ll b/mlir/test/Target/LLVMIR/Import/nvvmir.ll
index e4a8773e2dd80..131e9065b2d88 100644
--- a/mlir/test/Target/LLVMIR/Import/nvvmir.ll
+++ b/mlir/test/Target/LLVMIR/Import/nvvmir.ll
@@ -58,6 +58,9 @@ define i32 @nvvm_special_regs() {
   %27 = call i32 @llvm.nvvm.read.ptx.sreg.cluster.ctarank()
   ; CHECK: = nvvm.read.ptx.sreg.cluster.nctarank : i32
   %28 = call i32 @llvm.nvvm.read.ptx.sreg.cluster.nctarank()
+
+  ; CHECK: = nvvm.read.ptx.sreg.tid.x range <i32, 0, 64> : i32
+  %29 = call range(i32 0, 64) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
   ret i32 %1
 }
diff --git a/mlir/test/Target/LLVMIR/nvvmir.mlir b/mlir/test/Target/LLVMIR/nvvmir.mlir
index 88ffb1c7bfdf7..7fd082a5eb3c7 100644
--- a/mlir/test/Target/LLVMIR/nvvmir.mlir
+++ b/mlir/test/Target/LLVMIR/nvvmir.mlir
@@ -62,7 +62,10 @@ llvm.func @nvvm_special_regs() -> i32 {
   %29 = nvvm.read.ptx.sreg.clock : i32
   // CHECK: call i64 @llvm.nvvm.read.ptx.sreg.clock64
   %30 = nvvm.read.ptx.sreg.clock64 : i64
-
+
+  // CHECK: %31 = call range(i32 0, 64) i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  %31 = nvvm.read.ptx.sreg.tid.x range <i32, 0, 64> : i32
+
   llvm.return %1 : i32
 }