From 87689f16e2c06dcebfef2e02aa97946476a1cc0b Mon Sep 17 00:00:00 2001
From: erman-gurses
Date: Fri, 15 Dec 2023 09:02:41 -0800
Subject: [PATCH 01/12] [mlir][amdgpu] Shared memory access optimization pass

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td |  27 ++
 .../mlir/Dialect/AMDGPU/Transforms/Passes.h   |   4 +
 .../mlir/Dialect/AMDGPU/Transforms/Passes.td  |   8 +
 .../Dialect/AMDGPU/Transforms/Transforms.h    |  54 ++++
 .../mlir/Dialect/AMDGPU/Transforms/Utils.h    |  21 ++
 mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp  |  15 ++
 .../Dialect/AMDGPU/Transforms/CMakeLists.txt  |   2 +
 .../Transforms/OptimizeSharedMemory.cpp       | 252 ++++++++++++++++++
 mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp  |  48 ++++
 .../AMDGPU/optimize_shmem_reads_writes.mlir   |  57 ++++
 10 files changed, 488 insertions(+)
 create mode 100644 mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
 create mode 100644 mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
 create mode 100644 mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
 create mode 100644 mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
 create mode 100644 mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index ffb302fcedd73..324c656f47599 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -29,6 +29,33 @@ def AMDGPU_Dialect : Dialect {
     "gpu::GPUDialect"
   ];
   let useDefaultAttributePrinterParser = 1;
+
+  let extraClassDeclaration = [{
+    /// Return true if the given MemRefType has an integer address
+    /// space that matches the ROCDL shared memory address space or
+    /// is a gpu::AddressSpaceAttr attribute with value `workgroup`.
+    static bool hasSharedMemoryAddressSpace(MemRefType type);
+
+    /// Return true if the given Attribute has an integer address
+    /// space that matches the ROCDL shared memory address space or
+    /// is a gpu::AddressSpaceAttr attribute with value `workgroup`.
+    static bool isSharedMemoryAddressSpace(Attribute type);
+
+    /// Defines the MemRef memory space attribute numeric value that indicates
+    /// a memref is located in global memory. This should correspond to the
+    /// value used in ROCDL.
+    static constexpr unsigned kGlobalMemoryAddressSpace = 1;
+
+    /// Defines the MemRef memory space attribute numeric value that indicates
+    /// a memref is located in private memory. This should correspond to the
+    /// value used in ROCDL.
+    static constexpr unsigned kPrivateMemoryAddressSpace = 2;
+
+    /// Defines the MemRef memory space attribute numeric value that indicates
+    /// a memref is located in shared memory. This should correspond to the
+    /// value used in ROCDL.
+    static constexpr unsigned kSharedMemoryAddressSpace = 3;
+  }];
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
index 8dd5ff1a4b198..752078cd6930e 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
@@ -21,6 +21,10 @@ class ConversionTarget;
 namespace amdgpu {
 
 #define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
+
+/// Create a pass to optimize shared memory reads and writes.
+std::unique_ptr<Pass> createOptimizeSharedMemoryPass();
+
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index e6b27aa842dfc..1c12ca9827112 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -30,4 +30,12 @@ def AmdgpuEmulateAtomicsPass : Pass<"amdgpu-emulate-atomics"> {
                          "Chipset that these operations will run on">];
 }
 
+def OptimizeSharedMemory : Pass<"amdgpu-optimize-shared-memory"> {
+  let summary = "Optimizes accesses to shared memory memrefs in order to reduce bank conflicts.";
+  let constructor = "mlir::amdgpu::createOptimizeSharedMemoryPass()";
+  let dependentDialects = [
+    "memref::MemRefDialect", "vector::VectorDialect"
+  ];
+}
+
 #endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_PASSES_TD_
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
new file mode 100644
index 0000000000000..140bc12deed69
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Transforms.h
@@ -0,0 +1,54 @@
+//===- Transforms.h - AMDGPU Dialect transformations --------------*-
+// C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file declares functions that assist transformations for the amdgpu
+// dialect.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
+#define MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
+
+#include "mlir/IR/Operation.h"
+#include "mlir/Support/LogicalResult.h"
+
+namespace mlir {
+class RewriterBase;
+
+namespace amdgpu {
+
+///
+/// Passes
+///
+
+/// Optimizes vectorized accesses to a shared memory buffer specified by
+/// memrefValue. This transformation assumes the following:
+/// 1) All relevant accesses to `memrefValue` are contained within `parentOp`.
+/// 2) The function will fail precondition checks if any subviews are
+///    taken of `memrefValue`. All reads/writes to `memrefValue` should occur
+///    through `memrefValue` directly.
+///
+/// Shared memory bank conflicts occur when multiple threads attempt to read or
+/// write locations assigned to the same shared memory bank. For `2^N` byte
+/// vectorized accesses, we need to be concerned with conflicts among threads
+/// identified as `(tid) -> tid.floordiv(2^{7-N})`. As such, this transformation
+/// changes any indexed memory access (vector.load, memref.load, etc.) so that
+/// the final dimension's index value is permuted as
+/// `newColIndex = oldColIndex % vectorSize +
+/// perm[rowIndex](oldColIndex/vectorSize, rowIndex)`, where `rowIndex` is the
+/// index for the second-to-last dimension and `perm[rowIndex]` is a permutation
+/// function that depends on the row index. The permutation function is chosen
+/// to ensure that sequential distributed+vectorized reads/writes down a single
+/// dimension of the memref have minimal conflicts.
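+///
+/// As an illustrative (non-normative) example: for a `memref<256x32xf16, 3>`
+/// accessed with the default 64-bit vectors, the permutation reduces to
+/// `newColIndex = oldColIndex xor ((rowIndex mod 8) * 4)`, which is emitted
+/// as `arith.andi`, `arith.shli` and `arith.xori` on the trailing index; see
+/// the test optimize_shmem_reads_writes.mlir below.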
+mlir::LogicalResult optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
+                                                       Value memrefValue);
+
+} // namespace amdgpu
+} // namespace mlir
+
+#endif // MLIR_DIALECT_AMDGPU_TRANSFORMS_TRANSFORMS_H_
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
new file mode 100644
index 0000000000000..bee3af1914fee
--- /dev/null
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
@@ -0,0 +1,21 @@
+//===- Utils.h - Transform utilities -----------------------------*- C++-*-===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Operation.h"
+
+namespace mlir {
+namespace amdgpu {
+
+/// Get the indices that the given load/store operation is operating on.
+Operation::operand_range getIndices(Operation *op);
+
+/// Set the indices that the given load/store operation is operating on.
+void setIndices(Operation *op, ArrayRef<Value> indices);
+
+} // namespace amdgpu
+} // namespace mlir
diff --git a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
index 2575ad4984814..4e72fbf56b80a 100644
--- a/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
+++ b/mlir/lib/Dialect/AMDGPU/IR/AMDGPUDialect.cpp
@@ -43,6 +43,21 @@ void AMDGPUDialect::initialize() {
       >();
 }
 
+bool amdgpu::AMDGPUDialect::isSharedMemoryAddressSpace(Attribute memorySpace) {
+  if (!memorySpace)
+    return false;
+  if (auto intAttr = llvm::dyn_cast<IntegerAttr>(memorySpace))
+    return intAttr.getInt() == AMDGPUDialect::kSharedMemoryAddressSpace;
+  if (auto gpuAttr = llvm::dyn_cast<gpu::AddressSpaceAttr>(memorySpace))
+    return gpuAttr.getValue() == gpu::AddressSpace::Workgroup;
+  return false;
+}
+
+bool amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(MemRefType type) {
+  Attribute memorySpace = type.getMemorySpace();
+  return isSharedMemoryAddressSpace(memorySpace);
+}
+
 //===----------------------------------------------------------------------===//
 // 8-bit float ops
 //===----------------------------------------------------------------------===//
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
index e11b6cc88bf22..a1a91270bc55c 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/CMakeLists.txt
@@ -1,5 +1,7 @@
 add_mlir_dialect_library(MLIRAMDGPUTransforms
   EmulateAtomics.cpp
+  OptimizeSharedMemory.cpp
+  Utils.cpp
 
   ADDITIONAL_HEADER_DIRS
   {$MLIR_MAIN_INCLUDE_DIR}/mlir/Dialect/AMDGPU/Transforms
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
new file mode 100644
index 0000000000000..0a2f04f4e6487
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -0,0 +1,252 @@
+//===- OptimizeSharedMemory.cpp - MLIR AMDGPU pass implementation
+//----------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// This file implements transforms to optimize accesses to shared memory.
+// It is inspired by
+// https://github.com/llvm/llvm-project/blob/main/mlir/lib/Dialect/NVGPU/Transforms/OptimizeSharedMemory.cpp
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/AMDGPU/Transforms/Transforms.h"
+#include "mlir/Dialect/AMDGPU/Transforms/Utils.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/Func/IR/FuncOps.h"
+#include "mlir/Dialect/GPU/IR/GPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/Interfaces/SideEffectInterfaces.h"
+#include "mlir/Support/LogicalResult.h"
+#include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/MathExtras.h"
+
+namespace mlir {
+namespace amdgpu {
+#define GEN_PASS_DEF_OPTIMIZESHAREDMEMORY
+#include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
+} // namespace amdgpu
+} // namespace mlir
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+/// The size of a shared memory line according to AMD documentation.
+/// https://www.amd.com/content/dam/amd/en/documents/instinct-tech-docs/instruction-set-architectures/instinct-mi200-cdna2-instruction-set-architecture.pdf
+constexpr int64_t kSharedMemoryLineSizeBytes = 64;
+/// We optimize for 64-bit accesses, but this can be made an argument in the
+/// future.
+constexpr int64_t kDefaultVectorSizeBits = 64;
+
+/// Uses `srcIndexValue` to permute `tgtIndexValue` via
+/// `result = xor(floordiv(srcIdxVal, permuteEveryN),
+///               floordiv(tgtIdxVal, vectorSize))
+///          + tgtIdxVal % vectorSize`
+/// This is done using an optimized sequence of `arith` operations.
+static Value permuteVectorOffset(OpBuilder &b, Location loc,
+                                 ArrayRef<Value> indices, MemRefType memrefTy,
+                                 int64_t srcDim, int64_t tgtDim) {
+  // Adjust the src index to change how often the permutation changes
+  // if necessary.
+  Value src = indices[srcDim];
+
+  // We only want to permute every N iterations of the target dim where N is
+  // ceil(sharedMemoryLineSizeBytes / dimSizeBytes(tgtDim)).
+  const int64_t permuteEveryN = std::max<int64_t>(
+      1, kSharedMemoryLineSizeBytes / ((memrefTy.getDimSize(tgtDim) *
+                                        memrefTy.getElementTypeBitWidth()) /
+                                       8));
+
+  // clang-format off
+  // Index bit representation (b0 = least significant bit) for dim(1)
+  // of a `memref<?x?xDT>` is as follows:
+  //   N := log2(kDefaultVectorSizeBits / elementSizeBits)
+  //   M := log2(dimSize(1))
+  // then
+  //   bits[0:N] = sub-vector element offset
+  //   bits[N:M] = vector index
+  // clang-format on
+  int64_t n =
+      llvm::Log2_64(kDefaultVectorSizeBits / memrefTy.getElementTypeBitWidth());
+  int64_t m = llvm::Log2_64(memrefTy.getDimSize(tgtDim));
+
+  // Capture bits[0:(M-N)] of src by first creating a (M-N) mask.
+  int64_t mask = (1LL << (m - n)) - 1;
+  if (permuteEveryN > 1)
+    mask = mask << llvm::Log2_64(permuteEveryN);
+  Value srcBits = b.create<arith::ConstantIndexOp>(loc, mask);
+  srcBits = b.create<arith::AndIOp>(loc, src, srcBits);
+
+  // Use the src bits to permute the target bits b[N:M] containing the
+  // vector offset.
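+  // Illustrative walkthrough, assuming f16 elements and the default 64-bit
+  // vectors: n = log2(64/16) = 2 and, for a trailing dim of size 32, m = 5,
+  // so the mask above is (1 << 3) - 1 = 7 and srcBits = src & 7; with
+  // permuteEveryN == 1, the else branch below then shifts srcBits left by
+  // n = 2 before the xor.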
+  if (permuteEveryN > 1) {
+    int64_t shlBits = n - llvm::Log2_64(permuteEveryN);
+    if (shlBits > 0) {
+      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, shlBits);
+      srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
+    } else if (shlBits < 0) {
+      Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, -1 * shlBits);
+      srcBits = b.createOrFold<arith::ShRUIOp>(loc, srcBits, finalShiftVal);
+    }
+  } else {
+    Value finalShiftVal = b.create<arith::ConstantIndexOp>(loc, n);
+    srcBits = b.createOrFold<arith::ShLIOp>(loc, srcBits, finalShiftVal);
+  }
+
+  Value permutedVectorIdx =
+      b.create<arith::XOrIOp>(loc, indices[tgtDim], srcBits);
+  return permutedVectorIdx;
+}
+
+static void transformIndices(OpBuilder &builder, Location loc,
+                             SmallVector<Value, 4> &indices,
+                             MemRefType memrefTy, int64_t srcDim,
+                             int64_t tgtDim) {
+  indices[tgtDim] =
+      permuteVectorOffset(builder, loc, indices, memrefTy, srcDim, tgtDim);
+}
+
+/// Return all operations within `parentOp` that read from or write to
+/// `shmMemRef`.
+static LogicalResult
+getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
+                      SmallVector<Operation *, 16> &readOps,
+                      SmallVector<Operation *, 16> &writeOps) {
+  parentOp->walk([&](Operation *op) {
+    MemoryEffectOpInterface iface = dyn_cast<MemoryEffectOpInterface>(op);
+    if (!iface)
+      return;
+    std::optional<MemoryEffects::EffectInstance> effect =
+        iface.getEffectOnValue<MemoryEffects::Read>(shmMemRef);
+    if (effect) {
+      readOps.push_back(op);
+      return;
+    }
+    effect = iface.getEffectOnValue<MemoryEffects::Write>(shmMemRef);
+    if (effect)
+      writeOps.push_back(op);
+  });
+
+  // Restrict to a supported set of ops. We also require at least 2D access,
+  // although this could be relaxed.
+  if (llvm::any_of(readOps, [](Operation *op) {
+        return !isa<memref::LoadOp, vector::LoadOp, vector::TransferReadOp>(
+                   op) ||
+               amdgpu::getIndices(op).size() < 2;
+      }))
+    return failure();
+  if (llvm::any_of(writeOps, [](Operation *op) {
+        return !isa<memref::StoreOp, vector::StoreOp, vector::TransferWriteOp>(
+                   op) ||
+               amdgpu::getIndices(op).size() < 2;
+      }))
+    return failure();
+
+  return success();
+}
+
+mlir::LogicalResult
+mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
+                                                 Value memrefValue) {
+  auto memRefType = dyn_cast<MemRefType>(memrefValue.getType());
+  if (!memRefType ||
+      !amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(memRefType))
+    return failure();
+
+  // Abort if the given value has any sub-views; we do not do any alias
+  // analysis.
+  bool hasSubView = false;
+  parentOp->walk([&](memref::SubViewOp subView) { hasSubView = true; });
+  if (hasSubView)
+    return failure();
+
+  // Check if this is necessary given the assumption of 64-bit accesses:
+  // If dim[rank-1] is small enough to fit 8 rows in a 64B line.
+  const int64_t rowSize = memRefType.getDimSize(memRefType.getRank() - 1);
+  const int64_t rowsPerLine =
+      (8 * kSharedMemoryLineSizeBytes / memRefType.getElementTypeBitWidth()) /
+      rowSize;
+  const int64_t threadGroupSize =
+      1LL << (7 - llvm::Log2_64(kDefaultVectorSizeBits / 8));
+  if (rowsPerLine >= threadGroupSize)
+    return failure();
+
+  // Get sets of operations within the function that read/write to shared
+  // memory.
+  SmallVector<Operation *, 16> shmReadOps;
+  SmallVector<Operation *, 16> shmWriteOps;
+  if (failed(getShmReadAndWriteOps(parentOp, memrefValue, shmReadOps,
+                                   shmWriteOps)))
+    return failure();
+
+  if (shmReadOps.empty() || shmWriteOps.empty())
+    return failure();
+
+  OpBuilder builder(parentOp->getContext());
+
+  int64_t tgtDim = memRefType.getRank() - 1;
+  int64_t srcDim = memRefType.getRank() - 2;
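+
+  // For example (illustrative): for a store such as
+  //   vector.transfer_write %v, %shm[%row, %col]
+  // only the trailing index %col is rewritten to the permuted value; the
+  // loops below update each access in place.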
+
+  // Transform indices for the ops writing to shared memory.
+  while (!shmWriteOps.empty()) {
+    Operation *shmWriteOp = shmWriteOps.back();
+    shmWriteOps.pop_back();
+    builder.setInsertionPoint(shmWriteOp);
+
+    auto indices = amdgpu::getIndices(shmWriteOp);
+    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
+    transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
+                     memRefType, srcDim, tgtDim);
+    amdgpu::setIndices(shmWriteOp, transformedIndices);
+  }
+
+  // Transform indices for the ops reading from shared memory.
+  while (!shmReadOps.empty()) {
+    Operation *shmReadOp = shmReadOps.back();
+    shmReadOps.pop_back();
+    builder.setInsertionPoint(shmReadOp);
+
+    auto indices = amdgpu::getIndices(shmReadOp);
+    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
+    transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
+                     memRefType, srcDim, tgtDim);
+    amdgpu::setIndices(shmReadOp, transformedIndices);
+  }
+
+  return success();
+}
+
+namespace {
+class OptimizeSharedMemoryPass
+    : public amdgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
+public:
+  OptimizeSharedMemoryPass() = default;
+
+  void runOnOperation() override {
+    Operation *op = getOperation();
+    SmallVector<memref::AllocOp> shmAllocOps;
+    op->walk([&](memref::AllocOp allocOp) {
+      if (!amdgpu::AMDGPUDialect::hasSharedMemoryAddressSpace(
+              allocOp.getType()))
+        return;
+      shmAllocOps.push_back(allocOp);
+    });
+    for (auto allocOp : shmAllocOps) {
+      if (failed(optimizeSharedMemoryReadsAndWrites(getOperation(),
+                                                    allocOp.getMemref())))
+        return;
+    }
+  }
+};
+} // namespace
+
+std::unique_ptr<Pass> mlir::amdgpu::createOptimizeSharedMemoryPass() {
+  return std::make_unique<OptimizeSharedMemoryPass>();
+}
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
new file mode 100644
index 0000000000000..a1dc6cf70e7bf
--- /dev/null
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
@@ -0,0 +1,48 @@
+//===- Utils.cpp - Transform utilities ------------------------------------===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Dialect/AMDGPU/Transforms/Utils.h"
+
+#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/MemRef/IR/MemRef.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+
+using namespace mlir;
+using namespace mlir::amdgpu;
+
+Operation::operand_range amdgpu::getIndices(Operation *op) {
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return loadOp.getIndices();
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return storeOp.getIndices();
+  if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
+    return vectorReadOp.getIndices();
+  if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
+    return vectorStoreOp.getIndices();
+  if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
+    return transferReadOp.getIndices();
+  if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
+    return transferWriteOp.getIndices();
+  llvm_unreachable("unsupported op type");
+}
+
+void amdgpu::setIndices(Operation *op, ArrayRef<Value> indices) {
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return loadOp.getIndicesMutable().assign(indices);
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return storeOp.getIndicesMutable().assign(indices);
+  if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
+    return vectorReadOp.getIndicesMutable().assign(indices);
+  if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
+    return vectorStoreOp.getIndicesMutable().assign(indices);
+  if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
+    return transferReadOp.getIndicesMutable().assign(indices);
+  if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
+    return transferWriteOp.getIndicesMutable().assign(indices);
+  llvm_unreachable("unsupported op type");
+}
diff --git a/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir b/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir
new file mode 100644
index 0000000000000..41111dddda520
--- /dev/null
+++ b/mlir/test/Dialect/AMDGPU/optimize_shmem_reads_writes.mlir
@@ -0,0 +1,57 @@
+// RUN: mlir-opt %s --pass-pipeline='builtin.module(func.func(amdgpu-optimize-shared-memory))' | FileCheck %s
+
+  // CHECK: @optimize_shmem([[arg0:%.+]]: memref<{{.*}}>, [[readRow:%.+]]: index, [[readCol:%.+]]: index, [[writeRow:%.+]]: index, [[writeCol:%.+]]: index, [[fragRow:%.+]]: index, [[fragCol:%.+]]: index, [[fragColPerm:%.+]]: index, [[stRow:%.+]]: index, [[stCol:%.+]]: index)
+  func.func @optimize_shmem(%arg0: memref<4096x4096xf16>,
+                    %readRow: index, %readCol: index,
+                    %writeRow: index, %writeCol: index,
+                    %fragRow: index, %fragCol: index,
+                    %fragColPerm: index,
+                    %stRow: index, %stCol: index) {
+    // CHECK: %[[cst:.+]] = arith.constant 0.000000e+00 : f16
+    %cst = arith.constant 0.000000e+00 : f16
+
+    // CHECK: [[shmA:%.+]] = memref.alloc
+    // CHECK: [[shmB:%.+]] = memref.alloc
+    %shmA = memref.alloc() {alignment = 64 : i64} : memref<128x32xf16, 3>
+    %shmB = memref.alloc() {alignment = 64 : i64} : memref<256x32xf16, 3>
+
+    // CHECK: %[[D0:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
+    %0 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
+    // CHECK: [[c7:%.+]] = arith.constant 7 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
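+    // Worked example of the permuted index (illustrative arithmetic only):
+    // for %stRow = 9 and %stCol = 4: srcBits = 9 & 7 = 1,
+    // xorBits = 1 << 2 = 4, stColPerm = 4 ^ 4 = 0.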
+    // CHECK: vector.transfer_write %[[D0:.+]], [[shmB]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
+    vector.transfer_write %0, %shmB[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<256x32xf16, 3>
+    gpu.barrier
+    gpu.barrier
+    // CHECK: [[c7:%.+]] = arith.constant 7 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
+    // CHECK: vector.load [[shmB:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<256x32xf16, 3>, vector<8xf16>
+    %1 = vector.load %shmB[%fragRow, %fragColPerm] : memref<256x32xf16, 3>, vector<8xf16>
+
+    // CHECK: %[[D2:.+]] = vector.transfer_read [[arg0:%.+]][[[readRow:%.+]], [[readCol:%.+]]], [[cst:.+]] {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
+    %2 = vector.transfer_read %arg0[%readRow, %readCol], %cst {in_bounds = [true, true]} : memref<4096x4096xf16>, vector<1x8xf16>
+    // CHECK: [[c7:%.+]] = arith.constant 7 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[stRow:%.+]], [[c7]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[stColPerm:%.+]] = arith.xori [[stCol:%.+]], [[xorBits]]
+    // CHECK: vector.transfer_write %[[D2:.+]], [[shmA:%.+]][[[writeRow:%.+]], [[writeCol:%.+]]] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
+    vector.transfer_write %2, %shmA[%writeRow, %writeCol] {in_bounds = [true, true]} : vector<1x8xf16>, memref<128x32xf16, 3>
+    gpu.barrier
+    gpu.barrier
+    // CHECK: [[c7:%.+]] = arith.constant 7 : index
+    // CHECK: [[srcBits:%.+]] = arith.andi [[fragRow]], [[c7]]
+    // CHECK: [[c2:%.+]] = arith.constant 2 : index
+    // CHECK: [[xorBits:%.+]] = arith.shli [[srcBits]], [[c2]]
+    // CHECK: [[fragColPerm:%.+]] = arith.xori [[fragCol:%.+]], [[xorBits]]
+    // CHECK: vector.load [[shmA:%.+]][[[fragRow:%.+]], [[fragColPerm]]] : memref<128x32xf16, 3>, vector<8xf16>
+    %3 = vector.load %shmA[%fragRow, %fragColPerm] : memref<128x32xf16, 3>, vector<8xf16>
+    return
+  }
\ No newline at end of file

From c072744c97c328360afbb96eaabd6033665dc965 Mon Sep 17 00:00:00 2001
From: erman-gurses
Date: Sun, 7 Jan 2024 22:21:29 -0800
Subject: [PATCH 02/12] Add a fix for bad line break

---
 mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp | 3 +--
 1 file changed, 1 insertion(+), 2 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index 0a2f04f4e6487..c80beed0ed1d9 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -1,5 +1,4 @@
-//===- OptimizeSharedMemory.cpp - MLIR AMDGPU pass implementation
-//----------===//
+//===- OptimizeSharedMemory.cpp - MLIR AMDGPU pass implementation ---------===//
 //
 // Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
 // See https://llvm.org/LICENSE.txt for license information.
From 5384ebaed5b156ce5f7446da6f054a59a459b668 Mon Sep 17 00:00:00 2001
From: erman-gurses
Date: Mon, 8 Jan 2024 15:35:15 -0500
Subject: [PATCH 03/12] Remove constructor to enable autogeneration

---
 mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h        | 5 +----
 mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td       | 1 -
 mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp | 4 ----
 3 files changed, 1 insertion(+), 9 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
index 752078cd6930e..11d182ba5823e 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.h
@@ -20,10 +20,7 @@ namespace mlir {
 class ConversionTarget;
 
 namespace amdgpu {
-#define GEN_PASS_DECL_AMDGPUEMULATEATOMICSPASS
-
-/// Create a pass to optimize shared memory reads and writes.
-std::unique_ptr<Pass> createOptimizeSharedMemoryPass();
+#define GEN_PASS_DECL
 
 #define GEN_PASS_REGISTRATION
 #include "mlir/Dialect/AMDGPU/Transforms/Passes.h.inc"
diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index 1c12ca9827112..1b1543c2d3897 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -32,7 +32,6 @@ def AmdgpuEmulateAtomicsPass : Pass<"amdgpu-emulate-atomics"> {
 
 def OptimizeSharedMemory : Pass<"amdgpu-optimize-shared-memory"> {
   let summary = "Optimizes accesses to shared memory memrefs in order to reduce bank conflicts.";
-  let constructor = "mlir::amdgpu::createOptimizeSharedMemoryPass()";
   let dependentDialects = [
     "memref::MemRefDialect", "vector::VectorDialect"
   ];
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index c80beed0ed1d9..81d98c9225de4 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -245,7 +245,3 @@ class OptimizeSharedMemoryPass
   }
 };
 } // namespace
-
-std::unique_ptr<Pass> mlir::amdgpu::createOptimizeSharedMemoryPass() {
-  return std::make_unique<OptimizeSharedMemoryPass>();
-}

From d78fd01e7f856caf0ab6fee41692f5c959befe7c Mon Sep 17 00:00:00 2001
From: erman-gurses
Date: Mon, 8 Jan 2024 19:38:02 -0800
Subject: [PATCH 04/12] Remove unused constant expressions

---
 mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td | 10 ----------
 1 file changed, 10 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
index 324c656f47599..b4bf1b5191232 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/IR/AMDGPU.td
@@ -41,16 +41,6 @@ def AMDGPU_Dialect : Dialect {
     /// is a gpu::AddressSpaceAttr attribute with value `workgroup`.
     static bool isSharedMemoryAddressSpace(Attribute type);
 
-    /// Defines the MemRef memory space attribute numeric value that indicates
-    /// a memref is located in global memory. This should correspond to the
-    /// value used in ROCDL.
-    static constexpr unsigned kGlobalMemoryAddressSpace = 1;
-
-    /// Defines the MemRef memory space attribute numeric value that indicates
-    /// a memref is located in private memory. This should correspond to the
-    /// value used in ROCDL.
-    static constexpr unsigned kPrivateMemoryAddressSpace = 2;
-
     /// Defines the MemRef memory space attribute numeric value that indicates
     /// a memref is located in shared memory. This should correspond to the
     /// value used in ROCDL.
     static constexpr unsigned kSharedMemoryAddressSpace = 3;

From 811fea3dd804bcf5f87b16f2c2588479a9a55fcb Mon Sep 17 00:00:00 2001
From: erman-gurses
Date: Tue, 9 Jan 2024 08:17:04 -0800
Subject: [PATCH 05/12] Add simplification for read/write ops initialization

---
 mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp | 6 ++----
 1 file changed, 2 insertions(+), 4 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index 81d98c9225de4..d004d258bebe6 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -195,8 +195,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
 
   // Transform indices for the ops writing to shared memory.
   while (!shmWriteOps.empty()) {
-    Operation *shmWriteOp = shmWriteOps.back();
-    shmWriteOps.pop_back();
+    Operation *shmWriteOp = shmWriteOps.pop_back_val();
     builder.setInsertionPoint(shmWriteOp);
 
     auto indices = amdgpu::getIndices(shmWriteOp);
@@ -208,8 +207,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
 
   // Transform indices for the ops reading from shared memory.
   while (!shmReadOps.empty()) {
-    Operation *shmReadOp = shmReadOps.back();
-    shmReadOps.pop_back();
+    Operation *shmReadOp = shmReadOps.pop_back_val();
     builder.setInsertionPoint(shmReadOp);
 
     auto indices = amdgpu::getIndices(shmReadOp);

From f8b4c06dfcef6af4a84a36af0a15fc0eb0ba6849 Mon Sep 17 00:00:00 2001
From: erman-gurses
Date: Wed, 10 Jan 2024 08:31:30 -0800
Subject: [PATCH 06/12] Add description for utils

---
 mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h | 9 ++++++---
 1 file changed, 6 insertions(+), 3 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
index bee3af1914fee..b39e25d1a8826 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
@@ -11,10 +11,13 @@
 namespace mlir {
 namespace amdgpu {
 
-/// Get the indices that the given load/store operation is operating on.
+/// Get and set the indices that the given load/store operation is operating on.
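+/// For example (illustrative): given `%v = memref.load %m[%r, %c]`,
+/// getIndices returns the operand range {%r, %c}, which setIndices can
+/// overwrite in place with permuted values.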
+/// Preconditions:
+///   - The Op must have memory effects
+///   - Considers memref::LoadOp, vector::LoadOp, vector::TransferReadOp
+///   - Considers memref::StoreOp, vector::StoreOp, vector::TransferWriteOp
+///   - Excludes subview op
 Operation::operand_range getIndices(Operation *op);
-
-/// Set the indices that the given load/store operation is operating on.
 void setIndices(Operation *op, ArrayRef<Value> indices);
 
 } // namespace amdgpu
 } // namespace mlir

From 653c1ae031d9a05b6606c63b525b2ad6f7559e44 Mon Sep 17 00:00:00 2001
From: erman-gurses
Date: Thu, 11 Jan 2024 07:06:24 -0800
Subject: [PATCH 07/12] Add description

---
 mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
index 1b1543c2d3897..c8059e6d316e8 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Passes.td
@@ -32,6 +32,12 @@ def AmdgpuEmulateAtomicsPass : Pass<"amdgpu-emulate-atomics"> {
 
 def OptimizeSharedMemory : Pass<"amdgpu-optimize-shared-memory"> {
   let summary = "Optimizes accesses to shared memory memrefs in order to reduce bank conflicts.";
+  let description = [{
+    This pass adds a transformation and pass to the AMDGPU dialect that
+    attempts to optimize reads/writes from a memref representing GPU shared
+    memory in order to avoid bank conflicts.
+  }];
+
   let dependentDialects = [
     "memref::MemRefDialect", "vector::VectorDialect"
   ];

From fe9f5a957c4e588acff3b2564a5e24fa3ce5e1fe Mon Sep 17 00:00:00 2001
From: erman-gurses
Date: Thu, 11 Jan 2024 07:13:28 -0800
Subject: [PATCH 08/12] Change the pass data type to struct

---
 mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index d004d258bebe6..4a7dc2f20afd9 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -221,7 +221,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
 }
 
 namespace {
-class OptimizeSharedMemoryPass
+struct OptimizeSharedMemoryPass
     : public amdgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
 public:
   OptimizeSharedMemoryPass() = default;

From 983e956e175bbb9c689a53e11555ab3ca17e1612 Mon Sep 17 00:00:00 2001
From: erman-gurses
Date: Thu, 11 Jan 2024 07:19:40 -0800
Subject: [PATCH 09/12] Remove anonymous namespace

---
 mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp | 2 --
 1 file changed, 2 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index 4a7dc2f20afd9..2ce1ed72856dc 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -220,7 +220,6 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
   return success();
 }
 
-namespace {
 struct OptimizeSharedMemoryPass
     : public amdgpu::impl::OptimizeSharedMemoryBase<OptimizeSharedMemoryPass> {
 public:
@@ -242,4 +241,3 @@ struct OptimizeSharedMemoryPass
     }
   }
 };
-} // namespace

From b722dd7bd16bd96aac7f3516b1b85cca44c24ef4 Mon Sep 17 00:00:00 2001
From: erman-gurses
Date: Thu, 11 Jan 2024 11:58:44 -0800
Subject: [PATCH 10/12] Add optional for the util function

---
 mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h    | 2 +-
 .../Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp | 8 ++++----
 mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp           | 5 ++---
 3 files changed, 7 insertions(+), 8 deletions(-)

diff --git a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
index b39e25d1a8826..6be57ca54b15f 100644
--- a/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
+++ b/mlir/include/mlir/Dialect/AMDGPU/Transforms/Utils.h
@@ -17,7 +17,7 @@ namespace amdgpu {
 ///   - Considers memref::LoadOp, vector::LoadOp, vector::TransferReadOp
 ///   - Considers memref::StoreOp, vector::StoreOp, vector::TransferWriteOp
 ///   - Excludes subview op
-Operation::operand_range getIndices(Operation *op);
+std::optional<Operation::operand_range> getIndices(Operation *op);
 void setIndices(Operation *op, ArrayRef<Value> indices);
 
 } // namespace amdgpu
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
index 2ce1ed72856dc..c7001fc6d57d5 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/OptimizeSharedMemory.cpp
@@ -138,13 +138,13 @@ getShmReadAndWriteOps(Operation *parentOp, Value shmMemRef,
   if (llvm::any_of(readOps, [](Operation *op) {
         return !isa<memref::LoadOp, vector::LoadOp, vector::TransferReadOp>(
                    op) ||
-               amdgpu::getIndices(op).size() < 2;
+               amdgpu::getIndices(op)->size() < 2;
       }))
     return failure();
   if (llvm::any_of(writeOps, [](Operation *op) {
         return !isa<memref::StoreOp, vector::StoreOp, vector::TransferWriteOp>(
                    op) ||
-               amdgpu::getIndices(op).size() < 2;
+               amdgpu::getIndices(op)->size() < 2;
       }))
     return failure();
 
@@ -199,7 +199,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
     builder.setInsertionPoint(shmWriteOp);
 
     auto indices = amdgpu::getIndices(shmWriteOp);
-    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
+    SmallVector<Value, 4> transformedIndices(indices->begin(), indices->end());
     transformIndices(builder, shmWriteOp->getLoc(), transformedIndices,
                      memRefType, srcDim, tgtDim);
     amdgpu::setIndices(shmWriteOp, transformedIndices);
@@ -211,7 +211,7 @@ mlir::amdgpu::optimizeSharedMemoryReadsAndWrites(Operation *parentOp,
     builder.setInsertionPoint(shmReadOp);
 
     auto indices = amdgpu::getIndices(shmReadOp);
-    SmallVector<Value, 4> transformedIndices(indices.begin(), indices.end());
+    SmallVector<Value, 4> transformedIndices(indices->begin(), indices->end());
     transformIndices(builder, shmReadOp->getLoc(), transformedIndices,
                      memRefType, srcDim, tgtDim);
     amdgpu::setIndices(shmReadOp, transformedIndices);
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
index a1dc6cf70e7bf..05ac29bfcfaec 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
@@ -15,7 +15,7 @@
 using namespace mlir;
 using namespace mlir::amdgpu;
 
-Operation::operand_range amdgpu::getIndices(Operation *op) {
+std::optional<Operation::operand_range> amdgpu::getIndices(Operation *op) {
   if (auto loadOp = dyn_cast<memref::LoadOp>(op))
     return loadOp.getIndices();
   if (auto storeOp = dyn_cast<memref::StoreOp>(op))
@@ -28,7 +28,7 @@ Operation::operand_range amdgpu::getIndices(Operation *op) {
     return transferReadOp.getIndices();
   if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
     return transferWriteOp.getIndices();
-  llvm_unreachable("unsupported op type");
+  return std::nullopt;
 }
 
 void amdgpu::setIndices(Operation *op, ArrayRef<Value> indices) {
@@ -44,5 +44,4 @@ void amdgpu::setIndices(Operation *op, ArrayRef<Value> indices) {
     return transferReadOp.getIndicesMutable().assign(indices);
   if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
     return transferWriteOp.getIndicesMutable().assign(indices);
-  llvm_unreachable("unsupported op type");
 }

From 085d629ce18f3ee66a0c74f797565af84a61f71f Mon Sep 17 00:00:00 2001
From: erman-gurses
Date: Tue, 16 Jan 2024 16:19:26 -0800
Subject: [PATCH 11/12] Add interface for get and set indices functions

---
 mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp | 93 +++++++++++++-------
 1 file changed, 59 insertions(+), 34 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
index 05ac29bfcfaec..0361a4b983d00 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
@@ -1,11 +1,3 @@
-//===- Utils.cpp - Transform utilities ------------------------------------===//
-//
-// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
-// See https://llvm.org/LICENSE.txt for license information.
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
-//
-//===----------------------------------------------------------------------===//
-
 #include "mlir/Dialect/AMDGPU/Transforms/Utils.h"
 
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
@@ -15,33 +7,66 @@
 using namespace mlir;
 using namespace mlir::amdgpu;
 
+// Define an interface for operations with indices
+class IndicesInterface {
+public:
+  virtual std::optional<Operation::operand_range> getIndices() = 0;
+  virtual void setIndices(ArrayRef<Value> indices) = 0;
+  virtual ~IndicesInterface() = default;
+};
+
+// Implement a generic class that uses IndicesInterface
+class OperationWithIndices : public IndicesInterface {
+private:
+  Operation *op;
+
+  template <typename OpTy>
+  static std::optional<Operation::operand_range> getIndicesImpl(Operation *op) {
+    if (auto specificOp = dyn_cast<OpTy>(op))
+      return specificOp.getIndices();
+    return std::nullopt;
+  }
+
+  template <typename OpTy>
+  static void setIndicesImpl(Operation *op, ArrayRef<Value> indices) {
+    if (auto specificOp = dyn_cast<OpTy>(op))
+      specificOp.getIndicesMutable().assign(indices);
+  }
+
+public:
+  OperationWithIndices(Operation *op) : op(op) {}
+
+  std::optional<Operation::operand_range> getIndices() override {
+    auto result = getIndicesImpl<memref::LoadOp>(op);
+    if (!result)
+      result = getIndicesImpl<memref::StoreOp>(op);
+    if (!result)
+      result = getIndicesImpl<vector::LoadOp>(op);
+    if (!result)
+      result = getIndicesImpl<vector::StoreOp>(op);
+    if (!result)
+      result = getIndicesImpl<vector::TransferReadOp>(op);
+    if (!result)
+      result = getIndicesImpl<vector::TransferWriteOp>(op);
+
+    return result;
+  }
+
+  void setIndices(ArrayRef<Value> indices) override {
+    setIndicesImpl<memref::LoadOp>(op, indices);
+    setIndicesImpl<memref::StoreOp>(op, indices);
+    setIndicesImpl<vector::LoadOp>(op, indices);
+    setIndicesImpl<vector::StoreOp>(op, indices);
+    setIndicesImpl<vector::TransferReadOp>(op, indices);
+    setIndicesImpl<vector::TransferWriteOp>(op, indices);
+  }
+};
+
 std::optional<Operation::operand_range> amdgpu::getIndices(Operation *op) {
-  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
-    return loadOp.getIndices();
-  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
-    return storeOp.getIndices();
-  if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
-    return vectorReadOp.getIndices();
-  if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
-    return vectorStoreOp.getIndices();
-  if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
-    return transferReadOp.getIndices();
-  if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
-    return transferWriteOp.getIndices();
-  return std::nullopt;
+  OperationWithIndices operationWithIndices(op);
+  return operationWithIndices.getIndices();
 }
 
 void amdgpu::setIndices(Operation *op, ArrayRef<Value> indices) {
-  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
-    return loadOp.getIndicesMutable().assign(indices);
-  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
-    return storeOp.getIndicesMutable().assign(indices);
-  if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
-    return vectorReadOp.getIndicesMutable().assign(indices);
-  if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
-    return vectorStoreOp.getIndicesMutable().assign(indices);
-  if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
-    return transferReadOp.getIndicesMutable().assign(indices);
-  if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
-    return transferWriteOp.getIndicesMutable().assign(indices);
-}
+  OperationWithIndices operationWithIndices(op);
+  operationWithIndices.setIndices(indices);
+}
\ No newline at end of file
From 49b043f62dbaa297dc16cb1e0c1e8e4ec733be53 Mon Sep 17 00:00:00 2001
From: erman-gurses
Date: Wed, 17 Jan 2024 13:36:55 -0800
Subject: [PATCH 12/12] Revert to the original get and set indices methods

---
 mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp | 85 ++++++--------------
 1 file changed, 26 insertions(+), 59 deletions(-)

diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
index 0361a4b983d00..8163eeafdf1f0 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/Utils.cpp
@@ -7,66 +7,33 @@
 using namespace mlir;
 using namespace mlir::amdgpu;
 
-// Define an interface for operations with indices
-class IndicesInterface {
-public:
-  virtual std::optional<Operation::operand_range> getIndices() = 0;
-  virtual void setIndices(ArrayRef<Value> indices) = 0;
-  virtual ~IndicesInterface() = default;
-};
-
-// Implement a generic class that uses IndicesInterface
-class OperationWithIndices : public IndicesInterface {
-private:
-  Operation *op;
-
-  template <typename OpTy>
-  static std::optional<Operation::operand_range> getIndicesImpl(Operation *op) {
-    if (auto specificOp = dyn_cast<OpTy>(op))
-      return specificOp.getIndices();
-    return std::nullopt;
-  }
-
-  template <typename OpTy>
-  static void setIndicesImpl(Operation *op, ArrayRef<Value> indices) {
-    if (auto specificOp = dyn_cast<OpTy>(op))
-      specificOp.getIndicesMutable().assign(indices);
-  }
-
-public:
-  OperationWithIndices(Operation *op) : op(op) {}
-
-  std::optional<Operation::operand_range> getIndices() override {
-    auto result = getIndicesImpl<memref::LoadOp>(op);
-    if (!result)
-      result = getIndicesImpl<memref::StoreOp>(op);
-    if (!result)
-      result = getIndicesImpl<vector::LoadOp>(op);
-    if (!result)
-      result = getIndicesImpl<vector::StoreOp>(op);
-    if (!result)
-      result = getIndicesImpl<vector::TransferReadOp>(op);
-    if (!result)
-      result = getIndicesImpl<vector::TransferWriteOp>(op);
-
-    return result;
-  }
-
-  void setIndices(ArrayRef<Value> indices) override {
-    setIndicesImpl<memref::LoadOp>(op, indices);
-    setIndicesImpl<memref::StoreOp>(op, indices);
-    setIndicesImpl<vector::LoadOp>(op, indices);
-    setIndicesImpl<vector::StoreOp>(op, indices);
-    setIndicesImpl<vector::TransferReadOp>(op, indices);
-    setIndicesImpl<vector::TransferWriteOp>(op, indices);
-  }
-};
-
 std::optional<Operation::operand_range> amdgpu::getIndices(Operation *op) {
-  OperationWithIndices operationWithIndices(op);
-  return operationWithIndices.getIndices();
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return loadOp.getIndices();
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return storeOp.getIndices();
+  if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
+    return vectorReadOp.getIndices();
+  if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
+    return vectorStoreOp.getIndices();
+  if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
+    return transferReadOp.getIndices();
+  if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
+    return transferWriteOp.getIndices();
+  return std::nullopt;
 }
 
 void amdgpu::setIndices(Operation *op, ArrayRef<Value> indices) {
-  OperationWithIndices operationWithIndices(op);
-  operationWithIndices.setIndices(indices);
-}
\ No newline at end of file
+  if (auto loadOp = dyn_cast<memref::LoadOp>(op))
+    return loadOp.getIndicesMutable().assign(indices);
+  if (auto storeOp = dyn_cast<memref::StoreOp>(op))
+    return storeOp.getIndicesMutable().assign(indices);
+  if (auto vectorReadOp = dyn_cast<vector::LoadOp>(op))
+    return vectorReadOp.getIndicesMutable().assign(indices);
+  if (auto vectorStoreOp = dyn_cast<vector::StoreOp>(op))
+    return vectorStoreOp.getIndicesMutable().assign(indices);
+  if (auto transferReadOp = dyn_cast<vector::TransferReadOp>(op))
+    return transferReadOp.getIndicesMutable().assign(indices);
+  if (auto transferWriteOp = dyn_cast<vector::TransferWriteOp>(op))
+    return transferWriteOp.getIndicesMutable().assign(indices);
+}