-
Notifications
You must be signed in to change notification settings - Fork 14.7k
[mlir] Introduce trailingNDimsContiguous
for MemRefs
#78247
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[mlir] Introduce trailingNDimsContiguous
for MemRefs
#78247
Conversation
@llvm/pr-subscribers-mlir-vector @llvm/pr-subscribers-mlir Author: Andrzej Warzyński (banach-space) ChangesExtracts logic to check whether the trailing dim of a memref are Follow-up for #76848. Full diff: https://github.com/llvm/llvm-project/pull/78247.diff 3 Files Affected:
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 92ce053ad5c829..2361cf1371237b 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -518,6 +518,16 @@ bool isStrided(MemRefType t);
/// stride. Also return "true" for types with no strides.
bool isLastMemrefDimUnitStride(MemRefType type);
+/// Return "true" if the last N dimensions of the given type are contiguous.
+///
+/// Examples:
+/// - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]> is contiguous when
+/// considering both _all_ and _only_ the trailing 3 dims,
+/// - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]> is _only_ contiguous when
+/// considering the trailing 3 dims.
+///
+bool trailingNDimsContiguous(MemRefType type, int64_t n);
+
} // namespace mlir
#endif // MLIR_IR_BUILTINTYPES_H
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 377f3d8c557474..7e62f0bfe88647 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -257,38 +257,13 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
ArrayRef<int64_t> vectorShape = vectorType.getShape();
auto vecRank = vectorType.getRank();
- // Extract the trailing dims and strides of the input memref
- auto memrefShape = memrefType.getShape().take_back(vecRank);
- int64_t offset;
- SmallVector<int64_t> stridesFull;
- if (!succeeded(getStridesAndOffset(memrefType, stridesFull, offset)))
- return false;
- auto strides = ArrayRef<int64_t>(stridesFull).take_back(vecRank);
- memrefType.getLayout().isIdentity();
-
- // TODO: Add support for memref with trailing dynamic shapes. Memrefs
- // with leading dynamic dimensions are already supported.
- if (ShapedType::isDynamicShape(memrefShape))
- return false;
-
- // Cond 1: Check whether `memrefType` is contiguous.
- if (!strides.empty()) {
- // Cond 1.1: A contiguous memref will always have a unit trailing stride.
- if (strides.back() != 1)
+ if (!trailingNDimsContiguous(memrefType, vecRank))
return false;
- // Cond 1.2: Strides of a contiguous memref have to match the flattened
- // dims.
- strides = strides.drop_back(1);
- SmallVector<int64_t> flattenedDims;
- for (size_t i = 1; i < memrefShape.size(); i++)
- flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i)));
-
- if (!llvm::equal(strides, llvm::reverse(flattenedDims)))
- return false;
- }
+ // Extract the trailing dims and strides of the input memref
+ auto memrefShape = memrefType.getShape().take_back(vecRank);
- // Cond 2: Compare the dims of `vectorType` against `memrefType` (in reverse).
+ // Compare the dims of `vectorType` against `memrefType` (in reverse).
// In the most basic case, all dims will match.
auto firstNonMatchingDim =
std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 9b8ee3d4528035..795fa3e626d42b 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -8,6 +8,7 @@
#include "mlir/IR/BuiltinTypes.h"
#include "TypeDetail.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/AffineExpr.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/BuiltinAttributes.h"
@@ -967,3 +968,30 @@ bool mlir::isLastMemrefDimUnitStride(MemRefType type) {
auto successStrides = getStridesAndOffset(type, strides, offset);
return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
}
+
+bool mlir::trailingNDimsContiguous(MemRefType type, int64_t n) {
+ if (!isLastMemrefDimUnitStride(type))
+ return false;
+
+ auto memrefShape = type.getShape().take_back(n);
+ int64_t offset;
+ SmallVector<int64_t> stridesFull;
+ if (!succeeded(getStridesAndOffset(type, stridesFull, offset)))
+ return false;
+ auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
+
+ if (ShapedType::isDynamicShape(memrefShape))
+ return false;
+
+ if (strides.empty())
+ return true;
+
+ // Strides of a contiguous memref have to match the flattened
+ // dims.
+ strides = strides.drop_back(1);
+ SmallVector<int64_t> flattenedDims;
+ for (size_t i = 1; i < memrefShape.size(); i++)
+ flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i)));
+
+ return llvm::equal(strides, llvm::reverse(flattenedDims));
+}
|
✅ With the latest revision this PR passed the C/C++ code formatter. |
Extracts logic to check whether the trailing dim of a memref are contiguous into a dedicated hook in BuiitinTypes.{h|cpp}. Follow-up for llvm#76848.
7eaa746
to
86e8587
Compare
Fix Windows build
auto memrefShape = type.getShape().take_back(n); | ||
int64_t offset; | ||
SmallVector<int64_t> stridesFull; | ||
if (!succeeded(getStridesAndOffset(type, stridesFull, offset))) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is the memref layout being taken into account anywhere? That's one of the differences between vector and memref.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've just added a check that leverages isIdentity.
Is the memref layout being taken into account anywhere?
I thought that that's what getStridesAndOffset
does (i.e. returns the layout as offsets and strides)? I am not 100% sure though, this is mostly just moving things around.
Use memref's getLayout().isIdentity()
Kind ping :) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM! Just a minor comment
Extracts logic from
vector::isContiguousSlice
to check whetherthe trailing dim of a memref are contiguous into a dedicated hook
in BuiitinTypes.{h|cpp}.
Follow-up for #76848.