diff --git a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
index 7d9a5e6ca7596..46003ed846869 100644
--- a/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
+++ b/mlir/include/mlir/Dialect/MemRef/Utils/MemRefUtils.h
@@ -64,6 +64,35 @@ getLinearizedMemRefOffsetAndSize(OpBuilder &builder, Location loc, int srcBits,
 // it means both the allocations and associated stores can be removed.
 void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp);
 
+/// Given a set of sizes, return the suffix product.
+///
+/// When applied to slicing, this is the calculation needed to derive the
+/// strides (i.e. the number of linear indices to skip along the (k-1) most
+/// minor dimensions to get the next k-slice).
+///
+/// This is the basis to linearize an n-D offset confined to `[0 ... sizes]`.
+///
+/// Assuming `sizes` is `[s0, .. sn]`, return the vector
+///   `[s1 * ... * sn, s2 * ... * sn, ..., sn, 1]`.
+///
+/// It is the caller's responsibility to provide valid OpFoldResult type values
+/// and construct valid IR in the end.
+///
+/// `sizes` elements are asserted to be non-negative.
+///
+/// Return an empty vector if `sizes` is empty.
+///
+/// The function emits an IR block which computes the suffix product for the
+/// provided sizes.
+SmallVector<OpFoldResult>
+computeSuffixProductIRBlock(Location loc, OpBuilder &builder,
+                            ArrayRef<OpFoldResult> sizes);
+inline SmallVector<OpFoldResult>
+computeStridesIRBlock(Location loc, OpBuilder &builder,
+                      ArrayRef<OpFoldResult> sizes) {
+  return computeSuffixProductIRBlock(loc, builder, sizes);
+}
+
 } // namespace memref
 } // namespace mlir
 
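Note (editor's illustration, not part of the patch): for fully dynamic sizes
[%sz0, %sz1, %sz2], computeSuffixProductIRBlock returns the strides
[%sz1 * %sz2, %sz2, 1]. The trailing unit stride stays an index attribute and
the 1 * %sz2 product folds to %sz2, so only the leading stride needs an op. A
minimal sketch of the kind of IR this is expected to materialize (the exact
folded form depends on affine::makeComposedFoldedAffineApply; static sizes are
folded into the map instead of being passed as operands):

  func.func @suffix_product_sketch(%sz0: index, %sz1: index, %sz2: index) -> index {
    // Stride of the outermost dimension: %sz1 * %sz2. Note that the first
    // size never contributes to a suffix product.
    %stride0 = affine.apply affine_map<()[s0, s1] -> (s0 * s1)>()[%sz2, %sz1]
    return %stride0 : index
  }
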
diff --git a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
index aa44455ada7f9..29a5bc9a7ae5c 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/FoldMemRefAliasOps.cpp
@@ -19,6 +19,7 @@
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
+#include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h"
 #include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -63,39 +64,85 @@ resolveSourceIndicesExpandShape(Location loc, PatternRewriter &rewriter,
                                 memref::ExpandShapeOp expandShapeOp,
                                 ValueRange indices,
                                 SmallVectorImpl<Value> &sourceIndices) {
-  // The below implementation uses computeSuffixProduct method, which only
-  // allows int64_t values (i.e., static shape). Bail out if it has dynamic
-  // shapes.
-  if (!expandShapeOp.getResultType().hasStaticShape())
+  // Record the rewriter context for constructing ops later.
+  MLIRContext *ctx = rewriter.getContext();
+
+  // Capture expand_shape's input dimensions as `SmallVector<OpFoldResult>`.
+  // This is done for the purpose of inferring the output shape via
+  // `inferExpandShapeOutputShape`, which will in turn be used for the suffix
+  // product calculation later.
+  SmallVector<OpFoldResult> srcShape;
+  MemRefType srcType = expandShapeOp.getSrcType();
+
+  for (int64_t i = 0, e = srcType.getRank(); i < e; ++i) {
+    if (srcType.isDynamicDim(i)) {
+      srcShape.push_back(
+          rewriter.create<memref::DimOp>(loc, expandShapeOp.getSrc(), i)
+              .getResult());
+    } else {
+      srcShape.push_back(rewriter.getIndexAttr(srcType.getShape()[i]));
+    }
+  }
+
+  auto outputShape = inferExpandShapeOutputShape(
+      rewriter, loc, expandShapeOp.getResultType(),
+      expandShapeOp.getReassociationIndices(), srcShape);
+  if (!outputShape.has_value())
     return failure();
 
-  MLIRContext *ctx = rewriter.getContext();
+  // Traverse all reassociation groups to determine the appropriate indices
+  // corresponding to each one of them post op folding.
   for (ArrayRef<int64_t> groups : expandShapeOp.getReassociationIndices()) {
     assert(!groups.empty() && "association indices groups cannot be empty");
+    // Number of output dimensions that fold into a single source dimension in
+    // the current reassociation group.
     int64_t groupSize = groups.size();
 
-    // Construct the expression for the index value w.r.t to expand shape op
-    // source corresponding the indices wrt to expand shape op result.
-    SmallVector<int64_t> sizes(groupSize);
-    for (int64_t i = 0; i < groupSize; ++i)
-      sizes[i] = expandShapeOp.getResultType().getDimSize(groups[i]);
-    SmallVector<int64_t> suffixProduct = computeSuffixProduct(sizes);
-    SmallVector<AffineExpr> dims(groupSize);
-    bindDimsList(ctx, MutableArrayRef{dims});
-    AffineExpr srcIndexExpr = linearize(ctx, dims, suffixProduct);
+    // Collect the output dimensions used by this reassociation group for the
+    // suffix product calculation.
+    SmallVector<OpFoldResult> sizesVal(groupSize);
+    for (int64_t i = 0; i < groupSize; ++i) {
+      sizesVal[i] = (*outputShape)[groups[i]];
+    }
+
+    // Calculate the suffix product of the relevant output dimension sizes.
+    SmallVector<OpFoldResult> suffixProduct =
+        memref::computeSuffixProductIRBlock(loc, rewriter, sizesVal);
+
+    // Create affine expression variables for dimensions and symbols in the
+    // newly constructed affine map.
+    SmallVector<AffineExpr> dims(groupSize), symbols(groupSize);
+    bindDimsList<AffineExpr>(ctx, dims);
+    bindSymbolsList<AffineExpr>(ctx, symbols);
 
-    /// Apply permutation and create AffineApplyOp.
+    // Linearize the bound dimensions and symbols to construct the resulting
+    // affine expression for this index.
+    AffineExpr srcIndexExpr = linearize(ctx, dims, symbols);
+
+    // Record the load index corresponding to each dimension in the
+    // reassociation group. These are later supplied as operands to the affine
+    // map used for calculating the relevant index post op folding.
     SmallVector<OpFoldResult> dynamicIndices(groupSize);
     for (int64_t i = 0; i < groupSize; i++)
       dynamicIndices[i] = indices[groups[i]];
 
-    // Creating maximally folded and composd affine.apply composes better with
-    // other transformations without interleaving canonicalization passes.
+    // Supply the suffix product results followed by the load op indices as
+    // operands to the map.
+    SmallVector<OpFoldResult> mapOperands;
+    llvm::append_range(mapOperands, suffixProduct);
+    llvm::append_range(mapOperands, dynamicIndices);
+
+    // Creating maximally folded and composed affine.apply composes better
+    // with other transformations without interleaving canonicalization
+    // passes.
     OpFoldResult ofr = affine::makeComposedFoldedAffineApply(
         rewriter, loc,
         AffineMap::get(/*numDims=*/groupSize,
-                       /*numSymbols=*/0, srcIndexExpr),
-        dynamicIndices);
+                       /*numSymbols=*/groupSize, /*expression=*/srcIndexExpr),
+        mapOperands);
+
+    // Push the index value corresponding to this reassociation group into the
+    // source indices of the folded op.
     sourceIndices.push_back(
         getValueOrCreateConstantIndexOp(rewriter, loc, ofr));
   }
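Note (editor's illustration, not part of the patch): for one reassociation
group whose expanded dimensions have sizes [%sz0, %sz1, %sz2] and are indexed
by [%i0, %i1, %i2], the rewritten resolveSourceIndicesExpandShape computes the
source index as the indices dotted with the suffix product of the sizes,
i.e. %i0 * (%sz1 * %sz2) + %i1 * %sz2 + %i2. Spelled out as a single
affine.apply (the pass itself emits a maximally composed and folded form, as
the updated CHECK lines below show):

  func.func @linearize_group_sketch(%i0: index, %i1: index, %i2: index,
                                    %sz1: index, %sz2: index) -> index {
    // index = i0 * (sz1 * sz2) + i1 * sz2 + i2; %sz0 does not appear because
    // the outermost size never contributes to a suffix product.
    %idx = affine.apply
        affine_map<()[s0, s1, s2, s3, s4] -> (s0 * (s3 * s4) + s1 * s4 + s2)>
        ()[%i0, %i1, %i2, %sz1, %sz2]
    return %idx : index
  }
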
diff --git a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
index 556a82de2166f..c93e5a9dcd39f 100644
--- a/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
+++ b/mlir/lib/Dialect/MemRef/Utils/MemRefUtils.cpp
@@ -15,6 +15,7 @@
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "llvm/ADT/STLExtras.h"
 
 namespace mlir {
 namespace memref {
@@ -155,5 +156,27 @@ void eraseDeadAllocAndStores(RewriterBase &rewriter, Operation *parentOp) {
     rewriter.eraseOp(op);
 }
 
+static SmallVector<OpFoldResult>
+computeSuffixProductIRBlockImpl(Location loc, OpBuilder &builder,
+                                ArrayRef<OpFoldResult> sizes,
+                                OpFoldResult unit) {
+  SmallVector<OpFoldResult> strides(sizes.size(), unit);
+  AffineExpr s0, s1;
+  bindSymbols(builder.getContext(), s0, s1);
+
+  for (int64_t r = strides.size() - 1; r > 0; --r) {
+    strides[r - 1] = affine::makeComposedFoldedAffineApply(
+        builder, loc, s0 * s1, {strides[r], sizes[r]});
+  }
+  return strides;
+}
+
+SmallVector<OpFoldResult>
+computeSuffixProductIRBlock(Location loc, OpBuilder &builder,
+                            ArrayRef<OpFoldResult> sizes) {
+  OpFoldResult unit = builder.getIndexAttr(1);
+  return computeSuffixProductIRBlockImpl(loc, builder, sizes, unit);
+}
+
 } // namespace memref
 } // namespace mlir
diff --git a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
index 254cd4015eed9..99b5f78b03fba 100644
--- a/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
+++ b/mlir/test/Dialect/MemRef/fold-memref-alias-ops.mlir
@@ -468,23 +468,66 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_3d(%ar
 
 // -----
 
-// CHECK-LABEL: fold_dynamic_subview_with_memref_load_store_expand_shape
-// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[SZ0:.*]]: index)
-func.func @fold_dynamic_subview_with_memref_load_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0: index) -> f32 {
+// CHECK-LABEL: fold_dynamic_subview_with_memref_load_expand_shape
+// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index) -> f32
+func.func @fold_dynamic_subview_with_memref_load_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0: index) -> f32 {
   %c0 = arith.constant 0 : index
   %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   %0 = memref.load %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
   return %0 : f32
 }
-// CHECK: %[[C0:.*]] = arith.constant 0 : index
-// CHECK: %[[EXPAND_SHAPE:.*]] = memref.expand_shape %[[ARG0]] {{\[\[}}0, 1], [2, 3]] output_shape [1, 16, %[[SZ0]], 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
-// CHECK: %[[VAL_0:.*]] = memref.load %[[EXPAND_SHAPE]][%[[C0]], %[[ARG1]], %[[ARG2]], %[[C0]]] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
-// CHECK: return %[[VAL_0]] : f32
+// CHECK-NEXT: %[[VAL1:.*]] = memref.load %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: return %[[VAL1]] : f32
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1024 + d1)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-LABEL: fold_dynamic_subview_with_memref_store_expand_shape
+// CHECK-SAME: (%[[ARG0:.*]]: memref<16x?xf32, strided<[16, 1]>>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index)
+func.func @fold_dynamic_subview_with_memref_store_expand_shape(%arg0 : memref<16x?xf32, strided<[16, 1]>>, %arg1 : index, %arg2 : index, %sz0 : index) {
+  %c0 = arith.constant 0 : index
+  %c1f32 = arith.constant 1.0 : f32
+  %expand_shape = memref.expand_shape %arg0 [[0, 1], [2, 3]] output_shape [1, 16, %sz0, 1] : memref<16x?xf32, strided<[16, 1]>> into memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+  memref.store %c1f32, %expand_shape[%c0, %arg1, %arg2, %c0] : memref<1x16x?x1xf32, strided<[256, 16, 1, 1]>>
+  return
+}
+// CHECK: %[[C1F32:.*]] = arith.constant 1.000000e+00 : f32
+// CHECK-NEXT: memref.store %[[C1F32]], %[[ARG0]][%[[ARG1]], %[[ARG2]]] : memref<16x?xf32, strided<[16, 1]>>
+// CHECK-NEXT: return
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0] -> (s0 * 3)>
+// CHECK-LABEL: fold_memref_alias_expand_shape_subview_load_store_dynamic_dim
+// CHECK-SAME: (%[[ARG0:.*]]: memref<2048x16xf32>, %[[ARG1:.*]]: index, %[[ARG2:.*]]: index, %[[ARG3:.*]]: index, %[[ARG4:.*]]: index)
+func.func @fold_memref_alias_expand_shape_subview_load_store_dynamic_dim(%alloc: memref<2048x16xf32>, %c10: index, %c5: index, %c0: index, %sz0: index) {
+  %subview = memref.subview %alloc[%c5, 0] [%c10, 16] [1, 1] : memref<2048x16xf32> to memref<?x16xf32, strided<[16, 1], offset: ?>>
+  %expand_shape = memref.expand_shape %subview [[0], [1, 2, 3]] output_shape [%sz0, 1, 8, 2] : memref<?x16xf32, strided<[16, 1], offset: ?>> into memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+  %dim = memref.dim %expand_shape, %c0 : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+
+  affine.for %arg6 = 0 to %dim step 64 {
+    affine.for %arg7 = 0 to 16 step 16 {
+      %dummy_load = affine.load %expand_shape[%arg6, 0, %arg7, %arg7] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+      affine.store %dummy_load, %subview[%arg6, %arg7] : memref<?x16xf32, strided<[16, 1], offset: ?>>
+    }
+  }
+  return
+}
+// CHECK-NEXT: memref.subview
+// CHECK-NEXT: %[[EXPAND_SHAPE:.*]] = memref.expand_shape
+// CHECK-NEXT: %[[DIM:.*]] = memref.dim %[[EXPAND_SHAPE]], %[[ARG3]] : memref<?x1x8x2xf32, strided<[16, 16, 2, 1], offset: ?>>
+// CHECK-NEXT: affine.for %[[ARG4:.*]] = 0 to %[[DIM]] step 64 {
+// CHECK-NEXT:   affine.for %[[ARG5:.*]] = 0 to 16 step 16 {
+// CHECK-NEXT:     %[[VAL0:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
+// CHECK-NEXT:     %[[VAL1:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]]]
+// CHECK-NEXT:     %[[VAL2:.*]] = affine.load %[[ARG0]][%[[VAL0]], %[[VAL1]]] : memref<2048x16xf32>
+// CHECK-NEXT:     %[[VAL3:.*]] = affine.apply #[[$MAP0]]()[%[[ARG2]], %[[ARG4]]]
+// CHECK-NEXT:     affine.store %[[VAL2]], %[[ARG0]][%[[VAL3]], %[[ARG5]]] : memref<2048x16xf32>
+
+// -----
+
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0, s1] -> (s0 * 1024 + s1)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape
 // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
@@ -506,14 +549,14 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape(%arg0:
 // CHECK-NEXT:    affine.for %[[ARG4:.*]] = 0 to 1024 {
 // CHECK-NEXT:     affine.for %[[ARG5:.*]] = 0 to 1020 {
 // CHECK-NEXT:      affine.for %[[ARG6:.*]] = 0 to 1 {
-// CHECK-NEXT:       %[[IDX1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
-// CHECK-NEXT:       %[[IDX2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
+// CHECK-NEXT:       %[[IDX1:.*]] = affine.apply #[[$MAP0]]()[%[[ARG3]], %[[ARG4]]]
+// CHECK-NEXT:       %[[IDX2:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
 // CHECK-NEXT:       affine.load %[[ARG0]][%[[IDX1]], %[[IDX2]]] : memref<1024x1024xf32>
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 1025 + d1)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0, d1)[s0] -> (d0 + d1 + s0 * 1024)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression
 // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_access_index_is_an_expression(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
@@ -535,14 +578,14 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_when_a
 // CHECK-NEXT:  affine.for %[[ARG4:.*]] = 0 to 1024 {
 // CHECK-NEXT:   affine.for %[[ARG5:.*]] = 0 to 1020 {
 // CHECK-NEXT:    affine.for %[[ARG6:.*]] = 0 to 1 {
-// CHECK-NEXT:     %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])
-// CHECK-NEXT:     %[[TMP3:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
+// CHECK-NEXT:     %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]], %[[ARG4]])[%[[ARG3]]]
+// CHECK-NEXT:     %[[TMP3:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
 // CHECK-NEXT:     affine.load %[[ARG0]][%[[TMP1]], %[[TMP3]]] : memref<1024x1024xf32>
 
 // -----
 
-// CHECK-DAG: #[[$MAP0:.*]] = affine_map<(d0) -> (d0 * 1024)>
-// CHECK-DAG: #[[$MAP1:.*]] = affine_map<(d0, d1) -> (d0 + d1)>
+// CHECK-DAG: #[[$MAP0:.*]] = affine_map<()[s0] -> (s0 * 1024)>
+// CHECK-DAG: #[[$MAP1:.*]] = affine_map<()[s0, s1] -> (s0 + s1)>
 // CHECK-LABEL: fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index
 // CHECK-SAME: (%[[ARG0:.*]]: memref<1024x1024xf32>, %[[ARG1:.*]]: memref<1xf32>, %[[ARG2:.*]]: index)
 func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_constant_access_index(%arg0: memref<1024x1024xf32>, %arg1: memref<1xf32>, %arg2: index) -> f32 {
@@ -565,8 +608,8 @@ func.func @fold_static_stride_subview_with_affine_load_store_expand_shape_with_c
 // CHECK-NEXT:   affine.for %[[ARG4:.*]] = 0 to 1024 {
 // CHECK-NEXT:    affine.for %[[ARG5:.*]] = 0 to 1020 {
 // CHECK-NEXT:     affine.for %[[ARG6:.*]] = 0 to 1 {
-// CHECK-NEXT:      %[[TMP1:.*]] = affine.apply #[[$MAP0]](%[[ARG3]])
-// CHECK-NEXT:      %[[TMP2:.*]] = affine.apply #[[$MAP1]](%[[ARG5]], %[[ARG6]])
+// CHECK-NEXT:      %[[TMP1:.*]] = affine.apply #[[$MAP0]]()[%[[ARG3]]]
+// CHECK-NEXT:      %[[TMP2:.*]] = affine.apply #[[$MAP1]]()[%[[ARG5]], %[[ARG6]]]
 // CHECK-NEXT:      memref.load %[[ARG0]][%[[TMP1]], %[[TMP2]]] : memref<1024x1024xf32>
 
 // -----