diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 9cdd961d96ff5..108d7237ff703 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -767,4 +767,40 @@ def AMDGPU_WMMAOp :
   let hasVerifier = 1;
 }
 
+def AMDGPU_GatherToLDSOp :
+    AMDGPU_Op<"gather_to_lds", [SameVariadicOperandSize]>,
+    Arguments<(ins
+                   Arg<AnyMemRef, "buffer to gather from", [MemRead]>:$src,
+                   Variadic<Index>:$srcIndices,
+                   Arg<AnyMemRef, "buffer to write to", [MemWrite]>:$dst,
+                   Variadic<Index>:$dstIndices,
+                   TypeAttr:$transferType
+    )>,
+    Results<(outs)> {
+  let summary = "MLIR wrapper for CDNA gather-to-LDS instructions";
+  let description = [{
+    The `amdgpu.gather_to_lds` op is a wrapper around the `global_load_lds` instructions.
+
+    Operands:
+    * `$src`: global memory memref to read from.
+    * `$srcIndices`: indices into `$src` to read from for this thread.
+    * `$dst`: LDS memory memref to write to.
+    * `$dstIndices`: base indices into `$dst` to write to for the subgroup of this thread.
+      The elements gathered by the subgroup are written contiguously, in order of lane ID,
+      starting at `$dst[$dstIndices]`.
+    * `$transferType`: type of the data to be transferred by each thread. This is used to
+      determine the size of the data to be transferred and the number of threads in the
+      subgroup. The transfer type must be a scalar type or a vector of a scalar type.
+
+    The `$dst`, along with its indices, points to the memory location the subgroup of this
+    thread will write to.
+
+    Note: only enabled for gfx942 and later.
+  }];
+  let assemblyFormat = [{
+    $src `[` $srcIndices `]` `,` $dst `[` $dstIndices `]` attr-dict `:` $transferType `,` type($src) `,` type($dst)
+  }];
+  let hasVerifier = 1;
+}
+
 #endif // AMDGPU
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 56d40d6d123bf..5f697bdeef566 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -1010,6 +1010,55 @@ struct WMMAOpLowering : public ConvertOpToLLVMPattern<WMMAOp> {
   }
 };
 
+struct GatherToLDSOpLowering : public ConvertOpToLLVMPattern<GatherToLDSOp> {
+  GatherToLDSOpLowering(const LLVMTypeConverter &converter, Chipset chipset)
+      : ConvertOpToLLVMPattern<GatherToLDSOp>(converter), chipset(chipset) {}
+
+  Chipset chipset;
+
+  LogicalResult
+  matchAndRewrite(GatherToLDSOp op, GatherToLDSOpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    if (chipset < kGfx942)
+      return op.emitOpError("chipset not supported");
+
+    Location loc = op.getLoc();
+
+    auto srcMemRefType = cast<MemRefType>(op.getSrc().getType());
+    auto dstMemRefType = cast<MemRefType>(op.getDst().getType());
+
+    // TODO: instead of only transferring one element per thread, we could
+    // augment it to transfer multiple elements per thread by issuing multiple
+    // `global_load_lds` instructions.
+    Type transferType = op.getTransferType();
+    size_t loadWidth = [&]() -> size_t {
+      if (auto transferVectorType = dyn_cast<VectorType>(transferType)) {
+        return transferVectorType.getNumElements() *
+               (transferVectorType.getElementTypeBitWidth() / 8);
+      } else {
+        return transferType.getIntOrFloatBitWidth() / 8;
+      }
+    }();
+
+    // Currently only 1, 2, and 4 byte loads are supported.
+    if (loadWidth != 1 && loadWidth != 2 && loadWidth != 4)
+      return op.emitOpError("unsupported element size; only 1, 2, and 4 byte loads are supported");
+
+    Value srcPtr = getStridedElementPtr(loc, srcMemRefType, adaptor.getSrc(),
+                                        (adaptor.getSrcIndices()), rewriter);
+    Value dstPtr = getStridedElementPtr(loc, dstMemRefType, adaptor.getDst(),
+                                        (adaptor.getDstIndices()), rewriter);
+
+    rewriter.replaceOpWithNewOp<ROCDL::GlobalLoadLDSOp>(
+        op, srcPtr, dstPtr, createI32Constant(rewriter, loc, loadWidth),
+        createI32Constant(rewriter, loc, 0),
+        createI32Constant(rewriter, loc, 0), ArrayAttr{}, ArrayAttr{},
+        ArrayAttr{});
+
+    return success();
+  }
+};
+
 namespace {
 struct ExtPackedFp8OpLowering final
     : public ConvertOpToLLVMPattern<ExtPackedFp8Op> {
@@ -1393,6 +1442,6 @@ void mlir::populateAMDGPUToROCDLConversionPatterns(LLVMTypeConverter &converter,
            ROCDL::RawPtrBufferAtomicCmpSwap>,
        AMDGPUDPPLowering, LDSBarrierOpLowering, SchedBarrierOpLowering,
        MFMAOpLowering, WMMAOpLowering, ExtPackedFp8OpLowering,
-       PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering>(converter,
-                                                                  chipset);
+       PackedTrunc2xFp8OpLowering, PackedStochRoundFp8OpLowering,
+       GatherToLDSOpLowering>(converter, chipset);
 }
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 1e482515a4ee0..7f286f938ee60 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/IR/Arith.h"
 #include "mlir/Dialect/GPU/IR/GPUDialect.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/BuiltinTypes.h"
 #include "mlir/IR/Diagnostics.h"
@@ -24,6 +25,7 @@
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/IR/TypeUtilities.h"
 #include "llvm/ADT/TypeSwitch.h"
+#include "llvm/IR/DerivedTypes.h"
 #include <limits>
 #include <optional>
 
@@ -112,21 +114,31 @@ LogicalResult FatRawBufferCastOp::verify() {
   return success();
 }
 
+static bool hasGlobalMemorySpace(Attribute memorySpace) {
+  if (!memorySpace)
+    return true;
+  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
+    return intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
+  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+    return gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
+  return false;
+}
+
+static bool hasWorkgroupMemorySpace(Attribute memorySpace) {
+  if (auto intMemorySpace = dyn_cast<IntegerAttr>(memorySpace))
+    return intMemorySpace.getInt() == 3;
+  if (auto gpuMemorySpace = dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+    return gpuMemorySpace.getValue() == gpu::AddressSpace::Workgroup;
+  return false;
+}
+
 //===----------------------------------------------------------------------===//
 // RawBuffer*Op
 //===----------------------------------------------------------------------===//
 template <typename T>
 static LogicalResult verifyRawBufferOp(T &op) {
   MemRefType bufferType = llvm::cast<MemRefType>(op.getMemref().getType());
-  Attribute memorySpace = bufferType.getMemorySpace();
-  bool isGlobal = false;
-  if (!memorySpace)
-    isGlobal = true;
-  else if (auto intMemorySpace = llvm::dyn_cast<IntegerAttr>(memorySpace))
-    isGlobal = intMemorySpace.getInt() == 0 || intMemorySpace.getInt() == 1;
-  else if (auto gpuMemorySpace =
-               llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
-    isGlobal = gpuMemorySpace.getValue() == gpu::AddressSpace::Global;
+  bool isGlobal = hasGlobalMemorySpace(bufferType.getMemorySpace());
 
   if (!isGlobal)
     return op.emitOpError(
@@ -461,6 +473,40 @@ LogicalResult DPPOp::verify() {
   return success();
 }
 
+LogicalResult GatherToLDSOp::verify() {
+  MemRefType srcType = cast<MemRefType>(getSrc().getType());
+  MemRefType dstType = cast<MemRefType>(getDst().getType());
+
+  if (!memref::isStaticShapeAndContiguousRowMajor(dstType))
+    return emitOpError(
+        "destination type must have a static shape and be contiguous");
+
+  auto elemType = srcType.getElementType();
+  // Check that the $src and $dst element types are the same.
+  if (elemType != dstType.getElementType())
+    return emitOpError("source and destination element types must match");
+
+  // The transfer size must be 1, 2, or 4 bytes (8, 16, or 32 bits).
+  auto transferType = getTransferType();
+  size_t transferSize;
+  if (auto vectorTransfer = dyn_cast<VectorType>(transferType)) {
+    transferSize = vectorTransfer.getNumElements() *
+                   vectorTransfer.getElementTypeBitWidth();
+  } else {
+    transferSize = transferType.getIntOrFloatBitWidth();
+  }
+  if (transferSize != 8 && transferSize != 16 && transferSize != 32)
+    return emitOpError("transfer type size must be 8, 16, or 32 bits");
+
+  if (!hasGlobalMemorySpace(srcType.getMemorySpace()))
+    return emitOpError("source memory address space must be Global");
+
+  if (!hasWorkgroupMemorySpace(dstType.getMemorySpace()))
+    return emitOpError("destination memory address space must be Workgroup");
+
+  return success();
+}
+
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUEnums.cpp.inc"
 
 #define GET_ATTRDEF_CLASSES
diff --git a/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
new file mode 100644
index 0000000000000..b1c16bd5db079
--- /dev/null
+++ b/mlir/test/Conversion/AMDGPUToROCDL/load_lds.mlir
@@ -0,0 +1,143 @@
+// RUN: mlir-opt %s -convert-amdgpu-to-rocdl=chipset=gfx942 | FileCheck %s
+
+#gpu_global_addrspace = 1
+#gpu_lds_addrspace = 3
+
+// CHECK-LABEL: func @global_load_to_rocdl_f32
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xf32, 1>)
+func.func @global_load_to_rocdl_f32(%global : memref<128x72xf32, #gpu_global_addrspace>) {
+  %c0 = arith.constant 0 : index
+  %c12 = arith.constant 12 : index
+  %c32 = arith.constant 32 : index
+  %alloc = memref.alloc() : memref<64x64xf32, #gpu_lds_addrspace>
+  // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64
+  // CHECK: %[[C12:.*]] = arith.constant 12 : index
+  // CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
+  // CHECK: %[[C32:.*]] = arith.constant 32 : index
+  // CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
+
+  // CHECK: %[[ALLOC:.*]] = memref.alloc()
+  // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast
+  // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
+
+  // CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
+  // CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
+  // CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
+
+  // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
+  // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+
+  // CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
+  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
+  // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
+
+  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
+  // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]]
+  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
+    : f32, memref<128x72xf32, #gpu_global_addrspace>, memref<64x64xf32, #gpu_lds_addrspace>
+  func.return
+}
+
+// CHECK-LABEL: func @global_load_to_rocdl_i8
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xi8, 1>)
+func.func @global_load_to_rocdl_i8(%global : memref<128x72xi8, #gpu_global_addrspace>) {
+  // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64
+  // CHECK: %[[C12:.*]] = arith.constant 12 : index
+  // CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
+  // CHECK: %[[C32:.*]] = arith.constant 32 : index
+  // CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
+
+  // CHECK: %[[ALLOC:.*]] = memref.alloc()
+  // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]]
+  // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
+
+  // CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
+  // CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
+  // CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
+
+  // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
+  // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+
+  // CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
+  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
+  // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
+
+  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
+  // CHECK: %[[C1:.*]] = llvm.mlir.constant(1 : i32) : i32
+  // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C1]]
+  %c0 = arith.constant 0 : index
+  %c12 = arith.constant 12 : index
+  %c32 = arith.constant 32 : index
+  %alloc = memref.alloc() : memref<64x64xi8, #gpu_lds_addrspace>
+  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
+    : i8, memref<128x72xi8, #gpu_global_addrspace>, memref<64x64xi8, #gpu_lds_addrspace>
+  func.return
+}
+
+// CHECK-LABEL: func @global_load_to_rocdl_vec
+// CHECK-SAME: (%[[ARG0:.*]]: memref<128x72xi16, 1>)
+func.func @global_load_to_rocdl_vec(%global : memref<128x72xi16, #gpu_global_addrspace>) {
+  // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+
+  // CHECK: %[[C0:.*]] = arith.constant 0 : index
+  // CHECK: %[[IC0:.*]] = builtin.unrealized_conversion_cast %c0 : index to i64
+  // CHECK: %[[C12:.*]] = arith.constant 12 : index
+  // CHECK: %[[IC12:.*]] = builtin.unrealized_conversion_cast %[[C12]]
+  // CHECK: %[[C32:.*]] = arith.constant 32 : index
+  // CHECK: %[[IC32:.*]] = builtin.unrealized_conversion_cast %[[C32]]
+
+  // CHECK: %[[ALLOC:.*]] = memref.alloc()
+  // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]]
+  // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
+
+  // CHECK: %[[C72:.*]] = llvm.mlir.constant(72 : index) : i64
+  // CHECK: %[[MUL:.*]] = llvm.mul %[[IC12]], %[[C72]] : i64
+  // CHECK: %[[SRC_OFFSET:.*]] = llvm.add %[[MUL]], %[[IC0]] : i64
+
+  // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRC_OFFSET]]]
+  // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+
+  // CHECK: %[[C72_1:.*]] = llvm.mlir.constant(72 : index) : i64
+  // CHECK: %[[MUL_2:.*]] = llvm.mul %[[IC32]], %[[C72_1]] : i64
+  // CHECK: %[[DST_OFFSET:.*]] = llvm.add %[[MUL_2]], %[[IC0]] : i64
+
+  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DST_OFFSET]]]
+  // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]]
+  %c0 = arith.constant 0 : index
+  %c12 = arith.constant 12 : index
+  %c32 = arith.constant 32 : index
+  %alloc = memref.alloc() : memref<64x128xi16, #gpu_lds_addrspace>
+  amdgpu.gather_to_lds %global[%c12, %c0], %alloc[%c32, %c0]
+    : vector<2 x i16>, memref<128x72xi16, #gpu_global_addrspace>, memref<64x128xi16, #gpu_lds_addrspace>
+  func.return
+}
+
+
+// CHECK-LABEL: func @global_load_to_rocdl_dynamic_indices
+// CHECK-SAME: (%[[ARG0:.*]]: memref<512xi32, 1>, %[[SRC_IDX:.*]]: index, %[[DST_IDX:.*]]: index)
+func.func @global_load_to_rocdl_dynamic_indices(%global : memref<512xi32, #gpu_global_addrspace>, %src_idx : index, %dst_idx : index) {
+  // CHECK: %[[DSTIDX_CAST:.*]] = builtin.unrealized_conversion_cast %[[DST_IDX]]
+  // CHECK: %[[SRCIDX_CAST:.*]] = builtin.unrealized_conversion_cast %[[SRC_IDX]]
+  // CHECK: %[[GLOBAL_DESC:.*]] = builtin.unrealized_conversion_cast %[[ARG0]]
+  // CHECK: %[[ALLOC:.*]] = memref.alloc()
+  // CHECK: %[[LDS_DESC:.*]] = builtin.unrealized_conversion_cast %[[ALLOC]]
+  // CHECK: %[[GLOBAL_BASE:.*]] = llvm.extractvalue %[[GLOBAL_DESC]][1]
+  // CHECK: %[[GLOBAL_PTR:.*]] = llvm.getelementptr %[[GLOBAL_BASE]][%[[SRCIDX_CAST]]]
+  // CHECK: %[[LDS_BASE:.*]] = llvm.extractvalue %[[LDS_DESC]][1]
+  // CHECK: %[[LDS_PTR:.*]] = llvm.getelementptr %[[LDS_BASE]][%[[DSTIDX_CAST]]]
+  // CHECK: %[[C4:.*]] = llvm.mlir.constant(4 : i32) : i32
+  // CHECK: rocdl.global.load.lds %[[GLOBAL_PTR]], %[[LDS_PTR]], %[[C4]]
+  %alloc = memref.alloc() : memref<4x64xi32, #gpu_lds_addrspace>
+  %c0 = arith.constant 0 : index
+  amdgpu.gather_to_lds %global[%src_idx], %alloc[%dst_idx, %c0]
+    : i32, memref<512xi32, #gpu_global_addrspace>, memref<4x64xi32, #gpu_lds_addrspace>
+  func.return
+}
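
For reference, here is a minimal usage sketch of the new op outside the patch. The function name, shapes, and indices are illustrative only (not taken from the change): each lane supplies its own source indices, and the subgroup's gathered elements are written contiguously into LDS starting at the base destination indices.

#gpu_global_addrspace = 1
#gpu_lds_addrspace = 3

// Illustrative only: each thread gathers one vector<2xf16> (4 bytes) from
// global memory; the subgroup's elements land contiguously in LDS starting
// at %lds[%c0, %c0].
func.func @tile_to_lds(%src : memref<128x128xf16, #gpu_global_addrspace>,
                       %lds : memref<64x64xf16, #gpu_lds_addrspace>,
                       %row : index, %col : index) {
  %c0 = arith.constant 0 : index
  amdgpu.gather_to_lds %src[%row, %col], %lds[%c0, %c0]
    : vector<2xf16>, memref<128x128xf16, #gpu_global_addrspace>, memref<64x64xf16, #gpu_lds_addrspace>
  func.return
}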