diff --git a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
index d2ac850a5f70b..d52ff4d4257c7 100644
--- a/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
+++ b/mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp
@@ -298,16 +298,156 @@ struct LegalizeSVEMaskLoadConversion : public OpRewritePattern<memref::LoadOp> {
   }
 };
 
+/// Transforms a `transfer_read` operation so that it reads a vector of a type
+/// that can be mapped to an LLVM type ("LLVM-legal" type). This is done by
+/// collapsing trailing dimensions so we obtain a vector type with a single
+/// scalable dimension in the rightmost position.
+///
+/// Example:
+/// ```
+///   %v = vector.transfer_read %M[%i, %j, %c0, %c0], %c0_i8
+///     {in_bounds = [false, true, true, true]}
+///     : memref<?x?x2x8xi8>, vector<2x[4]x2x8xi8>
+/// ```
+/// is rewritten to
+/// ```
+///   %collapse_shape = memref.collapse_shape %M [[0], [1, 2, 3]]
+///     : memref<?x?x2x8xi8> into memref<?x?xi8>
+///   %0 = vector.transfer_read %collapse_shape[%i, %j], %c0_i8
+///     {in_bounds = [false, true]}
+///     : memref<?x?xi8>, vector<2x[64]xi8>
+///   %1 = vector.shape_cast %0 : vector<2x[64]xi8> to vector<2x[4]x2x8xi8>
+/// ```
+struct LegalizeTransferRead : public OpRewritePattern<vector::TransferReadOp> {
+  using OpRewritePattern::OpRewritePattern;
+
+  LogicalResult matchAndRewrite(vector::TransferReadOp readOp,
+                                PatternRewriter &rewriter) const override {
+
+    // Do not try to transform masked reads. For example, if we have a transfer
+    // to a `vector<[4]x4xi8>` we could have a mask like
+    //    1 1 1 0
+    //    1 1 1 0
+    //    1 1 1 0
+    //    0 0 0 0
+    // Flattening this mask would look like
+    //    1 1 1 0 1 1 1 0 1 1 1 0 0 0 0 0
+    // and we have not yet figured out an efficient way to build such a mask,
+    // neither from the mask operand, nor from the original
+    // `vector.create_mask` operation (if visible at all).
+    if (readOp.isMasked() || readOp.getMask())
+      return rewriter.notifyMatchFailure(readOp,
+                                         "masked transfers not supported");
+
+    // General permutation maps are not supported. The issue is with transpose,
+    // broadcast, and other forms of non-identity mapping in the minor
+    // dimensions, which are impossible to represent after collapsing (at least
+    // because the resulting "collapsed" maps would have a smaller number of
+    // dimension indices).
+    // TODO: We have not yet had the need for it, but some forms of permutation
+    // maps with identity in the minor dimensions could be supported, for
+    // example `(i, j, k, p) -> (j, i, k, p)` where we need to collapse only
+    // `k` and `p`.
+    if (!readOp.getPermutationMap().isMinorIdentity())
+      return rewriter.notifyMatchFailure(readOp, "non-identity permutation");
+
+    // We handle transfers of vectors with rank >= 2 and a single scalable
+    // dimension. This transformation aims to turn an LLVM-illegal type into an
+    // LLVM-legal one, and one-dimensional vectors are already LLVM-legal, even
+    // if scalable. A value of a vector type with more than one scalable
+    // dimension is impossible to represent using a vector type with no
+    // scalable dimensions or a single one. For example, a `vector<[4]x[4]xi8>`
+    // would have `4 * 4 * vscale * vscale` elements and this quantity is
+    // impossible to represent as `N` or `N * vscale` (where `N` is a
+    // constant).
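+    // In contrast, a type with a single scalable dimension, like the
+    // `vector<2x[4]x2x8xi8>` from the example above, has
+    // `2 * 4 * 2 * 8 * vscale` elements, so its trailing dimensions can be
+    // collapsed into a single scalable one, as in `vector<2x[64]xi8>`.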
+    VectorType origVT = readOp.getVectorType();
+    ArrayRef<bool> origScalableDims = origVT.getScalableDims();
+    const int64_t origVRank = origVT.getRank();
+    if (origVRank < 2 || origVT.getNumScalableDims() != 1)
+      return rewriter.notifyMatchFailure(readOp, "wrong dimensions");
+
+    // Number of trailing dimensions to collapse, including the scalable
+    // dimension. Nothing to do if the single scalable dimension is already the
+    // last one.
+    const int64_t numCollapseDims = std::distance(
+        llvm::find(origScalableDims, true), origScalableDims.end());
+    if (numCollapseDims < 2)
+      return rewriter.notifyMatchFailure(readOp,
+                                         "scalable dimension is trailing");
+
+    // We want a simple memref (not a tensor) with contiguous elements for at
+    // least all the trailing dimensions up to and including the scalable one.
+    auto memTy = dyn_cast<MemRefType>(readOp.getBase().getType());
+    if (!(memTy && memTy.areTrailingDimsContiguous(numCollapseDims)))
+      return rewriter.notifyMatchFailure(
+          readOp, "non-contiguous memref dimensions to collapse");
+
+    // The dimensions to collapse (excluding the scalable one) of the vector
+    // and the memref must match. A dynamic memref dimension is considered
+    // non-matching. The transfers from the dimensions to collapse must be
+    // in-bounds (it follows that the corresponding indices would be zero).
+    // This guarantees that the operation transfers a contiguous block and no
+    // padding is necessary.
+    if (!llvm::equal(memTy.getShape().take_back(numCollapseDims - 1),
+                     origVT.getShape().take_back(numCollapseDims - 1)))
+      return rewriter.notifyMatchFailure(
+          readOp, "memref and vector dimensions do not match");
+
+    SmallVector<bool> origInBounds = readOp.getInBoundsValues();
+    if (!llvm::all_of(
+            ArrayRef<bool>(origInBounds).take_back(numCollapseDims - 1),
+            [](bool v) { return v; }))
+      return rewriter.notifyMatchFailure(
+          readOp, "out-of-bounds transfer from a dimension to collapse");
+
+    // Collapse the trailing dimensions of the memref.
+    SmallVector<ReassociationIndices> reassoc;
+    for (int64_t i = 0; i < memTy.getRank() - numCollapseDims + 1; ++i)
+      reassoc.push_back({i});
+    for (int64_t i = memTy.getRank() - numCollapseDims + 1;
+         i < memTy.getRank(); ++i)
+      reassoc.back().push_back(i);
+    if (!memref::CollapseShapeOp::isGuaranteedCollapsible(memTy, reassoc))
+      return failure();
+    Value collapsedMem = rewriter.create<memref::CollapseShapeOp>(
+        readOp.getLoc(), readOp.getBase(), reassoc);
+
+    // Get a vector type with collapsed trailing dimensions.
+    SmallVector<int64_t> shape(origVT.getShape());
+    for (int64_t i = origVRank - numCollapseDims + 1; i < origVRank; ++i)
+      shape[origVRank - numCollapseDims] *= shape[i];
+    shape.pop_back_n(numCollapseDims - 1);
+    auto collapsedVT =
+        VectorType::get(shape, origVT.getElementType(),
+                        origScalableDims.drop_back(numCollapseDims - 1));
+
+    // Drop the extra (zero) indices.
+    auto indices = readOp.getIndices().drop_back(numCollapseDims - 1);
+
+    // Create the new `transfer_read`.
+    auto newReadOp = rewriter.create<vector::TransferReadOp>(
+        readOp.getLoc(), collapsedVT, collapsedMem, indices,
+        ArrayRef<bool>(origInBounds).drop_back(numCollapseDims - 1));
+
+    // Cast back to the original vector type.
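+    // In the example from the pattern documentation above, this is the cast
+    // from `vector<2x[64]xi8>` back to `vector<2x[4]x2x8xi8>`.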
+    auto toOrigShape = rewriter.create<vector::ShapeCastOp>(readOp.getLoc(),
+                                                            origVT, newReadOp);
+
+    rewriter.replaceOp(readOp, toOrigShape);
+    return success();
+  }
+};
+
 } // namespace
 
 void mlir::arm_sve::populateLegalizeVectorStoragePatterns(
     RewritePatternSet &patterns) {
-  patterns.add<RelaxScalableVectorAllocaAlignment,
-               LegalizeSVEMaskAllocation<memref::AllocaOp>,
-               LegalizeSVEMaskAllocation<memref::AllocOp>,
-               LegalizeSVEMaskTypeCastConversion,
-               LegalizeSVEMaskStoreConversion, LegalizeSVEMaskLoadConversion>(
-      patterns.getContext());
+  patterns
+      .add<RelaxScalableVectorAllocaAlignment,
+           LegalizeSVEMaskAllocation<memref::AllocaOp>,
+           LegalizeSVEMaskAllocation<memref::AllocOp>,
+           LegalizeSVEMaskTypeCastConversion, LegalizeSVEMaskStoreConversion,
+           LegalizeSVEMaskLoadConversion, LegalizeTransferRead>(
+          patterns.getContext());
 }
 
 namespace {
diff --git a/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir b/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir
new file mode 100644
index 0000000000000..5f923cdafb956
--- /dev/null
+++ b/mlir/test/Dialect/ArmSVE/legalize-transfer-read.mlir
@@ -0,0 +1,257 @@
+// RUN: mlir-opt --arm-sve-legalize-vector-storage --split-input-file %s | FileCheck %s
+
+
+// Test the `LegalizeTransferRead` pattern
+// (mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp)
+
+// -----
+
+// This is the base case, unremarkable in any way, except that it's our main
+// motivating example and use case.
+
+// CHECK-LABEL: @base_case
+// CHECK-SAME:    %[[I:.+]]: index, %[[J:.+]]: index, %[[M:.+]]:
+// CHECK:         %[[PAD:.+]] = arith.constant 123 : i8
+// CHECK:         %[[C0:.+]] = arith.constant 0 : index
+// CHECK:         %[[COLLAPSE:.+]] = memref.collapse_shape %[[M]]
+// CHECK-SAME{LITERAL}: [[0], [1], [2, 3]]
+// CHECK-SAME:      : memref<?x?x?x8xi8> into memref<?x?x?xi8>
+// CHECK-NEXT:    %[[T0:.+]] = vector.transfer_read %[[COLLAPSE]][%[[I]], %[[J]], %[[C0]]], %[[PAD]] {in_bounds = [true]}
+// CHECK-SAME:      : memref<?x?x?xi8>, vector<[32]xi8>
+// CHECK-NEXT:    %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[32]xi8> to vector<[4]x8xi8>
+// CHECK-NEXT:    return %[[T1]] : vector<[4]x8xi8>
+
+func.func @base_case(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x8xi8> {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 123 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x8xi8>
+
+  return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// Test the case where the scalable dimension is not the second-to-last.
+
+// CHECK-LABEL: @with_3d_vector
+// CHECK-SAME:    %[[I:.+]]: index, %[[J:.+]]: index, %[[M:.+]]:
+// CHECK:         %[[PAD:.+]] = arith.constant 123 : i8
+// CHECK:         %[[COLLAPSED:.+]] = memref.collapse_shape %[[M]]
+// CHECK-SAME{LITERAL}: [[0], [1, 2, 3]]
+// CHECK-SAME:      : memref<?x?x2x8xi8> into memref<?x?xi8>
+// CHECK-NEXT:    %[[T0:.+]] = vector.transfer_read %[[COLLAPSED]][%[[I]], %[[J]]], %[[PAD]] {in_bounds = [true]}
+// CHECK-SAME:      : memref<?x?xi8>, vector<[64]xi8>
+// CHECK-NEXT:    %[[T1:.+]] = vector.shape_cast %[[T0]] : vector<[64]xi8> to vector<[4]x2x8xi8>
+// CHECK-NEXT:    return %[[T1]] : vector<[4]x2x8xi8>
+
+func.func @with_3d_vector(%i : index, %j : index, %M : memref<?x?x2x8xi8>) -> vector<[4]x2x8xi8> {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 123 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %pad {in_bounds = [true, true, true]} : memref<?x?x2x8xi8>, vector<[4]x2x8xi8>
+
+  return %A : vector<[4]x2x8xi8>
+}
+
+// -----
+
+// Test the case when the vector is already LLVM-legal (fixed).
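+// A fixed-size multi-dimensional vector such as `vector<8x8xi8>` already maps
+// to an LLVM type (an array of vectors), so there is nothing to legalize.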
+
+// CHECK-LABEL: @negative_vector_legal_fixed
+// CHECK-NOT: memref.collapse
+
+func.func @negative_vector_legal_fixed(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<8x8xi8> {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 123 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<8x8xi8>
+
+  return %A : vector<8x8xi8>
+}
+
+// -----
+
+// Test the case when the vector is already LLVM-legal (single-dimension
+// scalable).
+
+// CHECK-LABEL: @negative_vector_legal_1d_scalable
+// CHECK-NOT: memref.collapse
+
+func.func @negative_vector_legal_1d_scalable(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[8]xi8> {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 123 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %pad {in_bounds = [true]} : memref<?x?x?x8xi8>, vector<[8]xi8>
+
+  return %A : vector<[8]xi8>
+}
+
+// -----
+
+// Test the case when the vector is already LLVM-legal (single trailing
+// scalable dimension).
+
+// CHECK-LABEL: @negative_vector_legal_trailing_scalable_dim
+// CHECK-NOT: memref.collapse
+
+func.func @negative_vector_legal_trailing_scalable_dim(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<8x[8]xi8> {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 123 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<8x[8]xi8>
+
+  return %A : vector<8x[8]xi8>
+}
+
+// -----
+
+// Test the case of an unsupported vector type (more than one scalable
+// dimension).
+
+// CHECK-LABEL: @negative_vector_type_two_scalable_dims
+// CHECK-NOT: memref.collapse
+
+func.func @negative_vector_type_two_scalable_dims(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[8]x[8]x8xi8> {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 123 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %pad {in_bounds = [true, true, true]} : memref<?x?x?x8xi8>, vector<[8]x[8]x8xi8>
+
+  return %A : vector<[8]x[8]x8xi8>
+}
+
+// -----
+
+// Test the case of reading from a tensor - not supported, since the
+// transform reasons about memory layouts.
+
+// CHECK-LABEL: @negative_tensor_transfer
+// CHECK-NOT: memref.collapse
+
+func.func @negative_tensor_transfer(%i : index, %j : index, %M : tensor<?x?x?x8xi8>) -> vector<[4]x8xi8> {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 123 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %pad {in_bounds = [true, true]} : tensor<?x?x?x8xi8>, vector<[4]x8xi8>
+
+  return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// Test the case when the transfer is discontiguous because the memref
+// is discontiguous.
+// There are other ways to make a memref discontiguous. The transformation
+// is not concerned with the particular reason a memref is discontiguous, but
+// only with the fact itself. Therefore there are no test variations where the
+// memref is made discontiguous by some other mechanism.
+
+// CHECK-LABEL: @negative_discontig_mem
+// CHECK-NOT: memref.collapse
+
+#strides = strided<[?, ?, 16, 1]>
+
+func.func @negative_discontig_mem(%i : index, %j : index, %M : memref<?x?x?x8xi8, #strides>) -> vector<[4]x8xi8> {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 123 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?x?x8xi8, #strides>, vector<[4]x8xi8>
+
+  return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// Test the case when the transformation is not applied because of
+// a non-trivial permutation map (broadcast).
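+// Such maps cannot be expressed after collapsing the trailing dimensions, so
+// only minor-identity permutation maps are accepted by the pattern.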
+
+// CHECK-LABEL: @negative_broadcast
+// CHECK-NOT: memref.collapse
+
+#perm = affine_map<(i, j, k, p) -> (k, 0)>
+
+func.func @negative_broadcast(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x8xi8> {
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 123 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %pad {permutation_map = #perm, in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x8xi8>
+
+  return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// Test the case of a masked read - not supported right now.
+// (see mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp)
+
+// CHECK-LABEL: @negative_masked
+// CHECK-NOT: memref.collapse
+
+func.func @negative_masked(
+    %i : index, %j : index,
+    %M : memref<?x?x?x8xi8>, %mask : vector<[4]x8xi1>) -> vector<[4]x8xi8> {
+
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 123 : i8
+
+  %A = vector.mask %mask {
+    vector.transfer_read %M[%i, %j, %c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x8xi8>
+  } : vector<[4]x8xi1> -> vector<[4]x8xi8>
+
+  return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// Test the case with a mask operand - not supported right now.
+// (see mlir/lib/Dialect/ArmSVE/Transforms/LegalizeVectorStorage.cpp)
+
+// CHECK-LABEL: @negative_with_mask
+// CHECK-NOT: memref.collapse
+
+func.func @negative_with_mask(
+    %i : index, %j : index,
+    %M : memref<?x?x?x8xi8>, %mask : vector<[4]x8xi1>) -> vector<[4]x8xi8> {
+
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 123 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %pad, %mask {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x8xi8>
+
+  return %A : vector<[4]x8xi8>
+}
+
+// -----
+
+// Test the case when the dimensions to collapse (excluding the scalable one)
+// of the vector and the memref do not match (static non-matching dimension).
+
+// CHECK-LABEL: @negative_non_matching_dim_static
+// CHECK-NOT: memref.collapse
+
+func.func @negative_non_matching_dim_static(%i : index, %j : index, %M : memref<?x?x?x8xi8>) -> vector<[4]x4xi8> {
+
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 123 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?x?x8xi8>, vector<[4]x4xi8>
+
+  return %A : vector<[4]x4xi8>
+}
+
+// -----
+
+// Test the case when the dimensions to collapse (excluding the scalable one)
+// of the vector and the memref do not match (dynamic non-matching dimension).
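+// A dynamic memref dimension is conservatively treated as non-matching, since
+// its size is not known at compile time.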
+
+// CHECK-LABEL: @negative_non_matching_dim_dynamic
+// CHECK-NOT: memref.collapse
+
+func.func @negative_non_matching_dim_dynamic(%i : index, %j : index, %M : memref<?x?x?x?xi8>) -> vector<[4]x4xi8> {
+
+  %c0 = arith.constant 0 : index
+  %pad = arith.constant 123 : i8
+
+  %A = vector.transfer_read %M[%i, %j, %c0, %c0], %pad {in_bounds = [true, true]} : memref<?x?x?x?xi8>, vector<[4]x4xi8>
+
+  return %A : vector<[4]x4xi8>
+}
diff --git a/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-non-trailing.mlir b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-non-trailing.mlir
new file mode 100644
index 0000000000000..36fdb60d3e7bf
--- /dev/null
+++ b/mlir/test/Integration/Dialect/Vector/CPU/ArmSVE/transfer-read-scalable-non-trailing.mlir
@@ -0,0 +1,79 @@
+// REQUIRES: arm-emulator
+
+// DEFINE: %{compile} = mlir-opt %s \
+// DEFINE:   --arm-sve-legalize-vector-storage --convert-vector-to-scf --convert-scf-to-cf --convert-vector-to-llvm='enable-arm-sve' \
+// DEFINE:   --expand-strided-metadata --lower-affine --convert-to-llvm --finalize-memref-to-llvm --reconcile-unrealized-casts \
+// DEFINE:   -o %t
+
+// DEFINE: %{entry_point} = main
+
+// DEFINE: %{run} = %mcr_aarch64_cmd %t -e %{entry_point} -entry-point-result=void --march=aarch64 --mattr="+sve" \
+// DEFINE:   -shared-libs=%mlir_runner_utils,%mlir_c_runner_utils,%native_mlir_arm_runner_utils
+
+// RUN: rm -f %t && %{compile} && %{run} | FileCheck %s
+
+// Test a transfer_read of a vector type with a non-trailing scalable
+// dimension, as transformed by the LegalizeTransferRead pattern.
+
+func.func @transfer_read_scalable_non_trailing(%vs : i32, %M : memref<?x8xi8>) {
+  func.call @setArmVLBits(%vs) : (i32) -> ()
+
+  // Read an LLVM-illegal vector.
+  %c0 = arith.constant 0 : index
+  %c0_i8 = arith.constant 0 : i8
+  %A = vector.transfer_read %M[%c0, %c0], %c0_i8 {in_bounds = [true, true]} : memref<?x8xi8>, vector<[4]x8xi8>
+
+  // Print the vector, for verification.
+  %B = vector.shape_cast %A : vector<[4]x8xi8> to vector<[32]xi8>
+  func.call @printVec(%B) : (vector<[32]xi8>) -> ()
+
+  return
+}
+
+func.func @main() {
+
+  %c0 = arith.constant 0 : index
+
+// Prepare an 8x8 buffer with test data. The test performs two reads of a
+// [4]x8 vector from the buffer. One read, with vector length 128 bits, reads
+// the first half of the buffer. The other read, with vector length 256 bits,
+// reads the entire buffer.
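+// With VL=128 the scalable dimension `[4]` holds 4 rows (vscale = 1), i.e.
+// 4 x 8 = 32 bytes; with VL=256 it holds 8 rows (vscale = 2), i.e. the entire
+// 8x8 buffer.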
+  %T = arith.constant dense<[[11, 12, 13, 14, 15, 16, 17, 18],
+                             [21, 22, 23, 24, 25, 26, 27, 28],
+                             [31, 32, 33, 34, 35, 36, 37, 38],
+                             [41, 42, 43, 44, 45, 46, 47, 48],
+                             [51, 52, 53, 54, 55, 56, 57, 58],
+                             [61, 62, 63, 64, 65, 66, 67, 68],
+                             [71, 72, 73, 74, 75, 76, 77, 78],
+                             [81, 82, 83, 84, 85, 86, 87, 88]]> : vector<8x8xi8>
+
+  %M = memref.alloca() : memref<8x8xi8>
+  vector.transfer_write %T, %M[%c0, %c0] : vector<8x8xi8>, memref<8x8xi8>
+  %MM = memref.cast %M : memref<8x8xi8> to memref<?x8xi8>
+
+// CHECK-LABEL: Result(VL128):
+// CHECK: ( 11, 12, 13, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 27, 28 )
+// CHECK: ( 31, 32, 33, 34, 35, 36, 37, 38, 41, 42, 43, 44, 45, 46, 47, 48 )
+  vector.print str "Result(VL128):\n"
+  %c128 = arith.constant 128 : i32
+  func.call @transfer_read_scalable_non_trailing(%c128, %MM) : (i32, memref<?x8xi8>) -> ()
+
+// CHECK-LABEL: Result(VL256):
+// CHECK: ( 11, 12, 13, 14, 15, 16, 17, 18, 21, 22, 23, 24, 25, 26, 27, 28, 31, 32, 33, 34, 35, 36, 37, 38, 41, 42, 43, 44, 45, 46, 47, 48 )
+// CHECK: ( 51, 52, 53, 54, 55, 56, 57, 58, 61, 62, 63, 64, 65, 66, 67, 68, 71, 72, 73, 74, 75, 76, 77, 78, 81, 82, 83, 84, 85, 86, 87, 88 )
+  vector.print str "Result(VL256):\n"
+  %c256 = arith.constant 256 : i32
+  func.call @transfer_read_scalable_non_trailing(%c256, %MM) : (i32, memref<?x8xi8>) -> ()
+
+  return
+}
+
+func.func private @printVec(%v : vector<[32]xi8>) {
+  %v0 = vector.scalable.extract %v[0] : vector<[16]xi8> from vector<[32]xi8>
+  %v1 = vector.scalable.extract %v[16] : vector<[16]xi8> from vector<[32]xi8>
+  vector.print %v0 : vector<[16]xi8>
+  vector.print %v1 : vector<[16]xi8>
+  return
+}
+
+func.func private @setArmVLBits(%bits : i32)