From c15f9d4ad6cd7626e8cba071c6cad1c935c5c875 Mon Sep 17 00:00:00 2001
From: Diego Caballero
Date: Wed, 24 Apr 2024 16:59:52 +0000
Subject: [PATCH 1/2] [mlir] Add sub-byte type emulation support for `memref.collapse_shape`

This PR adds support for `memref.collapse_shape` to sub-byte type emulation.
The `memref.collapse_shape` op becomes a no-op given that we flatten the
memref as part of the emulation (i.e., we collapse all the dimensions).
---
 .../MemRef/Transforms/EmulateNarrowType.cpp   | 32 +++++++++++++++++--
 .../Dialect/MemRef/emulate-narrow-type.mlir   | 20 ++++++++++++
 2 files changed, 49 insertions(+), 3 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
index 4449733f0daf0..77c108aab4807 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/EmulateNarrowType.cpp
@@ -13,7 +13,6 @@
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 #include "mlir/Dialect/Arith/Utils/Utils.h"
 #include "mlir/Dialect/MemRef/IR/MemRef.h"
-#include "mlir/Dialect/MemRef/Transforms/Passes.h"
 #include "mlir/Dialect/MemRef/Transforms/Transforms.h"
 #include "mlir/Dialect/MemRef/Utils/MemRefUtils.h"
 #include "mlir/Dialect/Vector/IR/VectorOps.h"
@@ -24,7 +23,6 @@
 #include "mlir/Support/MathExtras.h"
 #include "mlir/Transforms/DialectConversion.h"
 #include "llvm/Support/FormatVariadic.h"
-#include "llvm/Support/MathExtras.h"
 #include <cassert>
 #include <type_traits>
 
@@ -430,6 +428,33 @@ struct ConvertMemRefSubview final : OpConversionPattern<memref::SubViewOp> {
   }
 };
 
+//===----------------------------------------------------------------------===//
+// ConvertMemRefCollapseShape
+//===----------------------------------------------------------------------===//
+
+/// Emulating a `memref.collapse_shape` becomes a no-op after emulation given
+/// that we flatten memrefs to a single dimension as part of the emulation and
+/// there is no dimension to collapse any further.
+struct ConvertMemRefCollapseShape final
+    : OpConversionPattern<memref::CollapseShapeOp> {
+  using OpConversionPattern::OpConversionPattern;
+
+  LogicalResult
+  matchAndRewrite(memref::CollapseShapeOp collapseShapeOp, OpAdaptor adaptor,
+                  ConversionPatternRewriter &rewriter) const override {
+    Value srcVal = adaptor.getSrc();
+    auto newTy = dyn_cast<MemRefType>(srcVal.getType());
+    if (!newTy)
+      return failure();
+
+    if (newTy.getRank() != 1)
+      return failure();
+
+    rewriter.replaceOp(collapseShapeOp, srcVal);
+    return success();
+  }
+};
+
 } // end anonymous namespace
 
 //===----------------------------------------------------------------------===//
@@ -442,7 +467,8 @@ void memref::populateMemRefNarrowTypeEmulationPatterns(
 
   // Populate `memref.*` conversion patterns.
   patterns.add<ConvertMemRefAllocation<memref::AllocOp>,
-               ConvertMemRefAllocation<memref::AllocaOp>, ConvertMemRefLoad,
+               ConvertMemRefAllocation<memref::AllocaOp>,
+               ConvertMemRefCollapseShape, ConvertMemRefLoad,
                ConvertMemrefStore, ConvertMemRefAssumeAlignment,
                ConvertMemRefSubview, ConvertMemRefReinterpretCast>(
       typeConverter, patterns.getContext());
diff --git a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
index fd37b7ff0a271..435dcc944778d 100644
--- a/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
+++ b/mlir/test/Dialect/MemRef/emulate-narrow-type.mlir
@@ -430,3 +430,23 @@ func.func @rank_zero_memref_store(%arg0: i4) -> () {
 // CHECK32: %[[EXTUI:.+]] = arith.extui %[[ARG0]] : i4 to i32
 // CHECK32: %[[WRITE_RMW:.+]] = memref.atomic_rmw assign %[[EXTUI]], %[[ALLOC]][] : (i32, memref<i32>) -> i32
 // CHECK32: return
+
+// -----
+
+func.func @memref_collapse_shape_i4(%idx0 : index, %idx1 : index) -> i4 {
+  %arr = memref.alloc() : memref<32x8x128xi4>
+  %collapse = memref.collapse_shape %arr[[0, 1], [2]] : memref<32x8x128xi4> into memref<256x128xi4>
+  %1 = memref.load %collapse[%idx0, %idx1] : memref<256x128xi4>
+  return %1 : i4
+}
+
+// CHECK-LABEL: func.func @memref_collapse_shape_i4(
+// CHECK: %[[ALLOC:.*]] = memref.alloc() : memref<16384xi8>
+// CHECK-NOT: memref.collapse_shape
+// CHECK: memref.load %[[ALLOC]][%{{.*}}] : memref<16384xi8>
+
+// CHECK32-LABEL: func.func @memref_collapse_shape_i4(
+// CHECK32: %[[ALLOC:.*]] = memref.alloc() : memref<4096xi32>
+// CHECK32-NOT: memref.collapse_shape
+// CHECK32: memref.load %[[ALLOC]][%{{.*}}] : memref<4096xi32>

From f7ecf40b663a3cefa760c1ca904f8aa3937e985a Mon Sep 17 00:00:00 2001
From: Diego Caballero
Date: Wed, 24 Apr 2024 17:07:46 +0000
Subject: [PATCH 2/2] [mlir][MemRef] Add ExtractStridedMetadataOpCollapseShapeFolder

This PR adds a new pattern to the set of patterns used to resolve the
offset, sizes, and strides of a memref. Similar to
`ExtractStridedMetadataOpSubviewFolder`, the new pattern resolves
`extract_strided_metadata(collapse_shape)` directly, without introducing a
`memref.reinterpret_cast` op.
---
 .../Transforms/ExpandStridedMetadata.cpp      | 189 ++++++++++++------
 1 file changed, 130 insertions(+), 59 deletions(-)

diff --git a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
index 96eb7cfd2db69..b5578a58468e9 100644
--- a/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
+++ b/mlir/lib/Dialect/MemRef/Transforms/ExpandStridedMetadata.cpp
@@ -550,6 +550,78 @@ getCollapsedStride(memref::CollapseShapeOp collapseShape, OpBuilder &builder,
   return {makeComposedFoldedAffineMin(builder, collapseShape.getLoc(), minMap,
                                       groupStrides)};
 }
+
+template <typename ReassociativeReshapeLikeOp,
+          SmallVector<OpFoldResult> (*getReshapedSizes)(
+              ReassociativeReshapeLikeOp, OpBuilder &,
+              ArrayRef<OpFoldResult> /*origSizes*/, unsigned /*groupId*/),
+          SmallVector<OpFoldResult> (*getReshapedStrides)(
+              ReassociativeReshapeLikeOp, OpBuilder &,
+              ArrayRef<OpFoldResult> /*origSizes*/,
+              ArrayRef<OpFoldResult> /*origStrides*/, unsigned /*groupId*/)>
+static FailureOr<StridedMetadata>
+resolveReshapeStridedMetadata(RewriterBase &rewriter,
+                              ReassociativeReshapeLikeOp reshape) {
+  // Build a plain extract_strided_metadata(memref) from
+  // extract_strided_metadata(reassociative_reshape_like(memref)).
+  Location origLoc = reshape.getLoc();
+  Value source = reshape.getSrc();
+  auto sourceType = cast<MemRefType>(source.getType());
+  unsigned sourceRank = sourceType.getRank();
+
+  auto newExtractStridedMetadata =
+      rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
+
+  // Collect statically known information.
+  auto [strides, offset] = getStridesAndOffset(sourceType);
+  MemRefType reshapeType = reshape.getResultType();
+  unsigned reshapeRank = reshapeType.getRank();
+
+  OpFoldResult offsetOfr =
+      ShapedType::isDynamic(offset)
+          ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
+          : rewriter.getIndexAttr(offset);
+
+  // Get the special case of 0-D out of the way.
+  if (sourceRank == 0) {
+    SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
+    return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
+                           /*sizes=*/ones, /*strides=*/ones};
+  }
+
+  SmallVector<OpFoldResult> finalSizes;
+  finalSizes.reserve(reshapeRank);
+  SmallVector<OpFoldResult> finalStrides;
+  finalStrides.reserve(reshapeRank);
+
+  // Compute the reshaped strides and sizes from the base strides and sizes.
+  SmallVector<OpFoldResult> origSizes =
+      getAsOpFoldResult(newExtractStridedMetadata.getSizes());
+  SmallVector<OpFoldResult> origStrides =
+      getAsOpFoldResult(newExtractStridedMetadata.getStrides());
+  unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
+  for (; idx != endIdx; ++idx) {
+    SmallVector<OpFoldResult> reshapedSizes =
+        getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
+    SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
+        reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);
+
+    unsigned groupSize = reshapedSizes.size();
+    for (unsigned i = 0; i < groupSize; ++i) {
+      finalSizes.push_back(reshapedSizes[i]);
+      finalStrides.push_back(reshapedStrides[i]);
+    }
+  }
+  assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
+          (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
+         "We should have visited all the input dimensions");
+  assert(finalSizes.size() == reshapeRank &&
+         "We should have populated all the values");
+
+  return StridedMetadata{newExtractStridedMetadata.getBaseBuffer(), offsetOfr,
+                         finalSizes, finalStrides};
+}
+
 /// Replace `baseBuffer, offset, sizes, strides =
 ///              extract_strided_metadata(reshapeLike(memref))`
 /// With
 ///
@@ -580,68 +652,66 @@ struct ReshapeFolder : public OpRewritePattern<ReassociativeReshapeLikeOp> {
   LogicalResult matchAndRewrite(ReassociativeReshapeLikeOp reshape,
                                 PatternRewriter &rewriter) const override {
-    // Build a plain extract_strided_metadata(memref) from
-    // extract_strided_metadata(reassociative_reshape_like(memref)).
-    Location origLoc = reshape.getLoc();
-    Value source = reshape.getSrc();
-    auto sourceType = cast<MemRefType>(source.getType());
-    unsigned sourceRank = sourceType.getRank();
-
-    auto newExtractStridedMetadata =
-        rewriter.create<memref::ExtractStridedMetadataOp>(origLoc, source);
-
-    // Collect statically known information.
-    auto [strides, offset] = getStridesAndOffset(sourceType);
-    MemRefType reshapeType = reshape.getResultType();
-    unsigned reshapeRank = reshapeType.getRank();
-
-    OpFoldResult offsetOfr =
-        ShapedType::isDynamic(offset)
-            ? getAsOpFoldResult(newExtractStridedMetadata.getOffset())
-            : rewriter.getIndexAttr(offset);
-
-    // Get the special case of 0-D out of the way.
-    if (sourceRank == 0) {
-      SmallVector<OpFoldResult> ones(reshapeRank, rewriter.getIndexAttr(1));
-      auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
-          origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
-          offsetOfr, /*sizes=*/ones, /*strides=*/ones);
-      rewriter.replaceOp(reshape, memrefDesc.getResult());
-      return success();
+    FailureOr<StridedMetadata> stridedMetadata =
+        resolveReshapeStridedMetadata<ReassociativeReshapeLikeOp,
+                                      getReshapedSizes, getReshapedStrides>(
+            rewriter, reshape);
+    if (failed(stridedMetadata)) {
+      return rewriter.notifyMatchFailure(reshape,
+                                         "failed to resolve reshape metadata");
     }
 
-    SmallVector<OpFoldResult> finalSizes;
-    finalSizes.reserve(reshapeRank);
-    SmallVector<OpFoldResult> finalStrides;
-    finalStrides.reserve(reshapeRank);
-
-    // Compute the reshaped strides and sizes from the base strides and sizes.
-    SmallVector<OpFoldResult> origSizes =
-        getAsOpFoldResult(newExtractStridedMetadata.getSizes());
-    SmallVector<OpFoldResult> origStrides =
-        getAsOpFoldResult(newExtractStridedMetadata.getStrides());
-    unsigned idx = 0, endIdx = reshape.getReassociationIndices().size();
-    for (; idx != endIdx; ++idx) {
-      SmallVector<OpFoldResult> reshapedSizes =
-          getReshapedSizes(reshape, rewriter, origSizes, /*groupId=*/idx);
-      SmallVector<OpFoldResult> reshapedStrides = getReshapedStrides(
-          reshape, rewriter, origSizes, origStrides, /*groupId=*/idx);
-
-      unsigned groupSize = reshapedSizes.size();
-      for (unsigned i = 0; i < groupSize; ++i) {
-        finalSizes.push_back(reshapedSizes[i]);
-        finalStrides.push_back(reshapedStrides[i]);
-      }
+    rewriter.replaceOpWithNewOp<memref::ReinterpretCastOp>(
+        reshape, reshape.getType(), stridedMetadata->basePtr,
+        stridedMetadata->offset, stridedMetadata->sizes,
+        stridedMetadata->strides);
+    return success();
+  }
+};
+
+/// Pattern to replace `extract_strided_metadata(collapse_shape)`
+/// With
+///
+/// \verbatim
+/// baseBuffer, baseOffset, baseSizes, baseStrides =
+///     extract_strided_metadata(memref)
+/// sizes#i = product(baseSizes#j, for j in group#i)
+/// strides#i = min(baseStrides#j, for j in group#i)
+/// offset = baseOffset
+/// \endverbatim
+///
+/// with `baseBuffer`, `offset`, `sizes` and `strides` being
+/// the replacements for the original `extract_strided_metadata`.
+struct ExtractStridedMetadataOpCollapseShapeFolder
+    : OpRewritePattern<memref::ExtractStridedMetadataOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(memref::ExtractStridedMetadataOp op,
+                                PatternRewriter &rewriter) const override {
+    auto collapseShapeOp =
+        op.getSource().getDefiningOp<memref::CollapseShapeOp>();
+    if (!collapseShapeOp)
+      return failure();
+
+    FailureOr<StridedMetadata> stridedMetadata =
+        resolveReshapeStridedMetadata<memref::CollapseShapeOp,
+                                      getCollapsedSize,
+                                      getCollapsedStride>(rewriter,
+                                                          collapseShapeOp);
+    if (failed(stridedMetadata)) {
+      return rewriter.notifyMatchFailure(
+          op, "failed to resolve metadata in terms of source collapse_shape op");
     }
-    assert(((isa<memref::ExpandShapeOp>(reshape) && idx == sourceRank) ||
-            (isa<memref::CollapseShapeOp>(reshape) && idx == reshapeRank)) &&
-           "We should have visited all the input dimensions");
-    assert(finalSizes.size() == reshapeRank &&
-           "We should have populated all the values");
-    auto memrefDesc = rewriter.create<memref::ReinterpretCastOp>(
-        origLoc, reshapeType, newExtractStridedMetadata.getBaseBuffer(),
-        offsetOfr, finalSizes, finalStrides);
-    rewriter.replaceOp(reshape, memrefDesc.getResult());
+
+    Location loc = collapseShapeOp.getLoc();
+    SmallVector<Value> results;
+    results.push_back(stridedMetadata->basePtr);
+    results.push_back(getValueOrCreateConstantIndexOp(rewriter, loc,
+                                                      stridedMetadata->offset));
+    results.append(
+        getValueOrCreateConstantIndexOp(rewriter, loc, stridedMetadata->sizes));
+    results.append(getValueOrCreateConstantIndexOp(rewriter, loc,
+                                                   stridedMetadata->strides));
+    rewriter.replaceOp(op, results);
     return success();
   }
 };
@@ -1030,6 +1100,7 @@ void memref::populateResolveExtractStridedMetadataPatterns(
     RewritePatternSet &patterns) {
   patterns.add<ExtractStridedMetadataOpAllocFolder<memref::AllocOp>,
                ExtractStridedMetadataOpAllocFolder<memref::AllocaOp>,
+               ExtractStridedMetadataOpCollapseShapeFolder,
                ExtractStridedMetadataOpGetGlobalFolder,
                ExtractStridedMetadataOpSubviewFolder,
                RewriteExtractAlignedPointerAsIndexOfViewLikeOp,
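
Illustrative sketch (the function name and the static 4x8 shape below are made
up for illustration; they do not come from the patch or its tests). With
ExtractStridedMetadataOpCollapseShapeFolder registered, IR such as

  func.func @collapse_metadata(%arg0: memref<4x8xf32>)
      -> (memref<f32>, index, index, index) {
    %collapsed = memref.collapse_shape %arg0 [[0, 1]]
        : memref<4x8xf32> into memref<32xf32>
    %base, %offset, %size, %stride = memref.extract_strided_metadata %collapsed
        : memref<32xf32> -> memref<f32>, index, index, index
    return %base, %offset, %size, %stride : memref<f32>, index, index, index
  }

should resolve, roughly (modulo canonicalization and constant folding), into
metadata computed directly from the source memref, with no intermediate
memref.reinterpret_cast:

  %base, %offset, %sizes:2, %strides:2 = memref.extract_strided_metadata %arg0
      : memref<4x8xf32> -> memref<f32>, index, index, index, index, index
  %c0 = arith.constant 0 : index   // offset of the identity-layout source
  %c32 = arith.constant 32 : index // collapsed size = 4 * 8
  %c1 = arith.constant 1 : index   // collapsed stride = min(8, 1)
  return %base, %c0, %c32, %c1 : memref<f32>, index, index, index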