From 97ade0568b16254ee1f212ac588d3a3d6c1efe54 Mon Sep 17 00:00:00 2001
From: Prathamesh Tagore
Date: Wed, 8 May 2024 10:10:03 +0530
Subject: [PATCH] [mlir][fold-memref-alias-ops] Add support for folding
 memref.expand_shape involving dynamic dims

The fold-memref-alias-ops pass bails out in the presence of dynamic shapes,
which leads to unwanted propagation of alias types during other
transformations. This can percolate further down and lead to errors that
should not have been created in the first place.
---
 .../mlir/Dialect/MemRef/Utils/MemRefUtils.h   | 29 +++++++
 .../MemRef/Transforms/FoldMemRefAliasOps.cpp  | 85 ++++++++++++++-----
 mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp | 23 +++++
 .../Dialect/MemRef/fold-memref-alias-ops.mlir | 81 +++++++++++++-----
 4 files changed, 180 insertions(+), 38 deletions(-)

diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
index 7d9a5e6ca7596..46003ed846869 100644
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -64,6 +64,35 @@ getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
 // it means both the allocations and associated stores can be removed.
 void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp);
 
+/// Given a set of sizes, return the suffix product.
+///
+/// When applied to slicing, this is the calculation needed to derive the
+/// strides (i.e. the number of linear indices to skip along the (k-1) most
+/// minor dimensions to get the next k-slice).
+///
+/// This is the basis to linearize an n-D offset confined to `[0 ... sizes]`.
+///
+/// Assuming `sizes` is `[s0, .. sn]`, return the vector
+/// `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
+///
+/// It is the caller's responsibility to provide valid OpFoldResult type values
+/// and construct valid IR in the end.
+///
+/// `sizes` elements are asserted to be non-negative.
+///
+/// Return an empty vector if `sizes` is empty.
+///
+/// The function emits an IR block which computes the suffix product for the
+/// provided sizes.
+SmallVector<OpFoldResult>
+computeSuffixProductIRBlock(Location loc, OpBuilder &builder,
+                            ArrayRef<OpFoldResult> sizes);
+inline SmallVector<OpFoldResult>
+computeStridesIRBlock(Location loc, OpBuilder &builder,
+                      ArrayRef<OpFoldResult> sizes) {
+  return computeSuffixProductIRBlock(loc, builder, sizes);
+}
+
 } // namespace memref
 } // namespace mlir
 
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index aa44455ada7f9..29a5bc9a7ae5c 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -63,39 +64,85 @@ resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
                                 memref::ExpandShapeOp expandShapeOp,
                                 ValueRange indices,
                                 SmallVectorImpl<Value> &sourceIndices) {
-  // The below implementation uses computeSuffixProduct method, which only
-  // allows int64_t values (i.e., static shape). Bail out if it has dynamic
-  // shapes.
-  if (!expandShapeOp.getResultType().hasStaticShape())
+  // Record the rewriter context for constructing ops later.
+  MLIRContext *ctx = rewriter.getContext();
+
+  // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
+  // This is done for the purpose of inferring the output shape via
+  // `inferExpandShapeOutputShape`, which will in turn be used for the suffix
+  // product calculation later.
+  SmallVector<OpFoldResult> srcShape;
+  MemRefType srcType = expandShapeOp.getSrcType();
+
+  for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
+    if (srcType.isDynamicDim(i)) {
+      srcShape.push_back(
+          rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
+              .getResult());
+    } else {
+      srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
+    }
+  }
+
+  auto outputShape = inferExpandShapeOutputShape(
+      rewriter, loc, expandShapeOp.getResultType(),
+      expandShapeOp.getReassociationIndices(), srcShape);
+  if (!outputShape.has_value())
     return failure();
 
-  MLIRContext *ctx = rewriter.getContext();
+  // Traverse all reassociation groups to determine the appropriate indices
+  // corresponding to each one of them post op folding.
   for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
     assert(!groups.empty() && "association indices groups cannot be empty");
+    // Number of output dimensions this reassociation group expands the
+    // corresponding source dimension into.
     int64_t groupSize = groups.size();
-    // Construct the expression for the index value w.r.t to expand shape op
-    // source corresponding the indices wrt to expand shape op result.
-    SmallVector<int64_t> sizes(groupSize);
-    for (int64_t i = 0; i < groupSize; ++i)
-      sizes[i] = expandShapeOp.getResultType().getDimSize(groups[i]);
-    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
-    SmallVector<AffineExpr> dims(groupSize);
-    bindDimsList(ctx, MutableArrayRef{dims});
-    AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct);
+    // Collect the output dimension sizes belonging to this reassociation
+    // group for the suffix product calculation.
+    SmallVector<OpFoldResult> sizesVal(groupSize);
+    for (int64_t i = 0; i < groupSize; ++i) {
+      sizesVal[i] = (*outputShape)[groups[i]];
+    }
+
+    // Calculate the suffix product of the relevant output dimension sizes.
+    SmallVector<OpFoldResult> suffixProduct =
+        memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal);
+
+    // Create affine expression variables for dimensions and symbols in the
+    // newly constructed affine map.
+    SmallVector<AffineExpr> dims(groupSize), symbols(groupSize);
+    bindDimsList<AffineExpr>(ctx, dims);
+    bindSymbolsList<AffineExpr>(ctx, symbols);
 
-    /// Apply permutation and create AffineApplyOp.
+    // Linearize the bound dimensions and symbols to construct the resulting
+    // affine expression for this index.
+    AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
+
+    // Record the load index corresponding to each dimension in the
+    // reassociation group. These are later supplied as operands to the affine
+    // map used for calculating the relevant index post op folding.
    SmallVector<Value> dynamicIndices(groupSize);
     for (int64_t i = 0; i < groupSize; i++)
       dynamicIndices[i] = indices[groups[i]];
 
-    // Creating maximally folded and composd affine.apply composes better with
-    // other transformations without interleaving canonicalization passes.
+    // Supply the suffix product results followed by the load op indices as
+    // operands to the map.
+    SmallVector<OpFoldResult> mapOperands;
+    llvm::append_range(mapOperands, suffixProduct);
+    llvm::append_range(mapOperands, dynamicIndices);
+
+    // Creating a maximally folded and composed affine.apply composes better
+    // with other transformations without interleaving canonicalization
+    // passes.
    OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
        rewriter, loc,
        AffineMap::get(/*numDims=*/groupSize,
-                       /*numSymbols=*/0, srcIndexExpr),
-        dynamicIndices);
+                       /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
+        mapOperands);
+
+    // Record the source index corresponding to this reassociation group
+    // after folding.
     sourceIndices.push_back(
         getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
   }
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index 556a82de2166f..c93e5a9dcd39f 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "llvm/ADT/STLExtras.h"
 
 namespace mlir {
 namespace memref {
@@ -155,5 +156,27 @@ void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp) {
     rewriter.eraseOp(op);
 }
 
+static SmallVector<OpFoldResult>
+computeSuffixProductIRBlockImpl(Location loc, OpBuilder &builder,
+                                ArrayRef<OpFoldResult> sizes,
+                                OpFoldResult unit) {
+  SmallVector<OpFoldResult> strides(sizes.size(), unit);
+  AffineExpr s0, s1;
+  bindSymbols(builder.getContext(), s0, s1);
+
+  for (int64_t r = strides.size() - 1; r > 0; --r) {
+    strides[r - 1] = affine::makeComposedFoldedAffineApply(
+        builder, loc, s0 * s1, {strides[r], sizes[r]});
+  }
+  return strides;
+}
+
+SmallVector<OpFoldResult>
+computeSuffixProductIRBlock(Location loc, OpBuilder &builder,
+                            ArrayRef<OpFoldResult> sizes) {
+  OpFoldResult unit = builder.getIndexAttr(1);
+  return computeSuffixProductIRBlockImpl(loc, builder, sizes, unit);
+}
+
 } // namespace memref
 } // namespace mlir
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 254cd4015eed9..99b5f78b03fba 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -468,23 +468,66 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
 
 // -----
 
-// CHECK-LABEL: fold_dynamic_subview_with_memref_load_store_expand_shape
-// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[SZ0:.*]]: index)
-func.func @fold_dynamic_subview_with_memref_load_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0: index) -> f32 {
+// CHECK-LABEL: fold_dynamic_subview_with_memref_load_expand_shape
+// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32
+func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0: index) -> f32 {
   %c0 = arith.constant 0 : index
   %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   return %0 : f32
 }
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[EXPAND_SHAPE:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [1, 16, %[[SZ0]], 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
-// CHECK: %[[VAL_0:.*]] = memref.load %[[EXPAND_SHAPE]][%[[C0]], %[[ARG1]], %[[ARG2]], %[[C0]]] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
-// CHECK: return %[[VAL_0]] : f32
+// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: return %[[VAL1]] : f32
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-LABEL: fold_dynamic_subview_with_memref_store_expand_shape
+// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0 : index) {
+  %c0 = arith.constant 0 : index
+  %c1f32 = arith.constant 1.0 : f32
+  %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+  memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+  return
+}
+// CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: return
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 3)>
+// CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
+func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
+  %subview = memref.subview %alloc[%c5, 0] [%c10, 16] [1, 1] : memref<2048x16xf32> to memref<?x16xf32, strided<[16, 1], offset: ?>>
+  %expand_shape = memref.expand_shape %subview [[0], [1, 2, 3]] output_shape [%sz0, 1, 8, 2] : memref<?x16xf32, strided<[16, 1], offset: ?>> into memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+  %dim = memref.dim %expand_shape, %c0 : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+
+  affine.for %arg6 = 0 to %dim step 64 {
+    affine.for %arg7 = 0 to 16 step 16 {
+      %dummy_load = affine.load %expand_shape[%arg6, 0, %arg7, %arg7] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+      affine.store %dummy_load, %subview[%arg6, %arg7] : memref<?x16xf32, strided<[16, 1], offset: ?>>
+    }
+  }
+  return
+}
+// CHECK-NEXT: memref.subview
+// CHECK-NEXT: %[[EXPAND_SHAPE:.*]] = memref.expand_shape
+// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[EXPAND_SHAPE]], %[[ARG3]] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to %[[DIM]] step 64 {
+// CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 16 step 16 {
+// CHECK-NEXT: %[[VAL0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
+// CHECK-NEXT: %[[VAL1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]]]
+// CHECK-NEXT: %[[VAL2:.*]] = affine.load %[[ARG0]][%[[VAL0]], %[[VAL1]]] : memref<2048x16xf32>
+// CHECK-NEXT: %[[VAL3:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
+// CHECK-NEXT: affine.store %[[VAL2]], %[[ARG0]][%[[VAL3]], %[[ARG5]]] : memref<2048x16xf32>
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 * 1024 + s1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
 // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
@@ -506,14 +549,14 @@ func.func
@fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0: // CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 { // CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 { // CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 { -// CHECK-NEXT: %[[IDX1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]]) -// CHECK-NEXT: %[[IDX2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]]) +// CHECK-NEXT: %[[IDX1:.*]] = affine.apply #[[$MAP0]]()[%[[ARG3]], %[[ARG4]]] +// CHECK-NEXT: %[[IDX2:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]] // CHECK-NEXT: affine.load %[[ARG0]][%[[IDX1]], %[[IDX2]]] : memref<1024x1024xf32> // ----- -// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1025 + d1)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)> +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0 + d1 + s0 * 1024)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)> // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index) func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 { @@ -535,14 +578,14 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_a // CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 { // CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 { // CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 { -// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]]) -// CHECK-NEXT: %[[TMP3:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]]) +// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])[%[[ARG3]]] +// CHECK-NEXT: %[[TMP3:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]] // CHECK-NEXT: affine.load %[[ARG0]][%[[TMP1]], %[[TMP3]]] : memref<1024x1024xf32> // ----- -// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (d0 * 1024)> -// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)> +// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 * 1024)> +// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)> // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index) func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 { @@ -565,8 +608,8 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_c // CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to 1024 { // CHECK-NEXT: affine.for %[[ARG5:.*]] = 0 to 1020 { // CHECK-NEXT: affine.for %[[ARG6:.*]] = 0 to 1 { -// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]]) -// CHECK-NEXT: %[[TMP2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]]) +// CHECK-NEXT: %[[TMP1:.*]] = affine.apply #[[$MAP0]]()[%[[ARG3]]] +// CHECK-NEXT: %[[TMP2:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]] // CHECK-NEXT: memref.load %[[ARG0]][%[[TMP1]], %[[TMP2]]] : memref<1024x1024xf32> // -----
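
For reference, a minimal before/after sketch of the behaviour this patch enables. The function and SSA names below are illustrative (not taken from the test suite above), and the exact affine map printed after composition and constant folding may differ:

  // Input: a load through a memref.expand_shape with a dynamic dimension.
  // Before this patch, resolveSourceIndicesExpandShape bailed out on it.
  func.func @example(%src: memref<16x?xf32>, %i: index, %j: index, %k: index, %sz: index) -> f32 {
    %expanded = memref.expand_shape %src [[0], [1, 2]] output_shape [16, %sz, 4]
        : memref<16x?xf32> into memref<16x?x4xf32>
    %v = memref.load %expanded[%i, %j, %k] : memref<16x?x4xf32>
    return %v : f32
  }

  // After fold-memref-alias-ops, the load addresses %src directly. The group
  // [1, 2] is linearized with the suffix product of its output sizes ([4, 1]
  // here), so the minor source index becomes %j * 4 + %k. The now-dead
  // expand_shape is omitted for brevity.
  func.func @example(%src: memref<16x?xf32>, %i: index, %j: index, %k: index, %sz: index) -> f32 {
    %idx = affine.apply affine_map<()[s0, s1] -> (s0 * 4 + s1)>()[%j, %k]
    %v = memref.load %src[%i, %idx] : memref<16x?xf32>
    return %v : f32
  }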