Skip to content

[mlir] Introduce trailingNDimsContiguous for MemRefs #78247

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Feb 17, 2024

Conversation

banach-space
Copy link
Contributor

@banach-space banach-space commented Jan 16, 2024

Extracts logic from vector::isContiguousSlice to check whether
the trailing dim of a memref are contiguous into a dedicated hook
in BuiitinTypes.{h|cpp}.

Follow-up for #76848.

@llvmbot
Copy link
Member

llvmbot commented Jan 16, 2024

@llvm/pr-subscribers-mlir-vector
@llvm/pr-subscribers-mlir-core

@llvm/pr-subscribers-mlir

Author: Andrzej Warzyński (banach-space)

Changes

Extracts logic to check whether the trailing dim of a memref are
contiguous into a dedicated hook in BuiitinTypes.{h|cpp}.

Follow-up for #76848.


Full diff: https://github.com/llvm/llvm-project/pull/78247.diff

3 Files Affected:

  • (modified) mlir/include/mlir/IR/BuiltinTypes.h (+10)
  • (modified) mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp (+4-29)
  • (modified) mlir/lib/IR/BuiltinTypes.cpp (+28)
diff --git a/mlir/include/mlir/IR/BuiltinTypes.h b/mlir/include/mlir/IR/BuiltinTypes.h
index 92ce053ad5c829..2361cf1371237b 100644
--- a/mlir/include/mlir/IR/BuiltinTypes.h
+++ b/mlir/include/mlir/IR/BuiltinTypes.h
@@ -518,6 +518,16 @@ bool isStrided(MemRefType t);
 /// stride. Also return "true" for types with no strides.
 bool isLastMemrefDimUnitStride(MemRefType type);
 
+/// Return "true" if the last N dimensions of the given type are contiguous.
+///
+/// Examples:
+///   - memref<5x4x3x2xi8, strided<[24, 6, 2, 1]> is contiguous when
+///   considering both _all_ and _only_ the trailing 3 dims,
+///   - memref<5x4x3x2xi8, strided<[48, 6, 2, 1]> is _only_ contiguous when
+///   considering the trailing 3 dims.
+///
+bool trailingNDimsContiguous(MemRefType type, int64_t n);
+
 } // namespace mlir
 
 #endif // MLIR_IR_BUILTINTYPES_H
diff --git a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
index 377f3d8c557474..7e62f0bfe88647 100644
--- a/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
+++ b/mlir/lib/Dialect/Vector/Utils/VectorUtils.cpp
@@ -257,38 +257,13 @@ bool vector::isContiguousSlice(MemRefType memrefType, VectorType vectorType) {
   ArrayRef<int64_t> vectorShape = vectorType.getShape();
   auto vecRank = vectorType.getRank();
 
-  // Extract the trailing dims and strides of the input memref
-  auto memrefShape = memrefType.getShape().take_back(vecRank);
-  int64_t offset;
-  SmallVector<int64_t> stridesFull;
-  if (!succeeded(getStridesAndOffset(memrefType, stridesFull, offset)))
-    return false;
-  auto strides = ArrayRef<int64_t>(stridesFull).take_back(vecRank);
-  memrefType.getLayout().isIdentity();
-
-  // TODO: Add support for memref with trailing dynamic shapes. Memrefs
-  // with leading dynamic dimensions are already supported.
-  if (ShapedType::isDynamicShape(memrefShape))
-    return false;
-
-  // Cond 1: Check whether `memrefType` is contiguous.
-  if (!strides.empty()) {
-    // Cond 1.1: A contiguous memref will always have a unit trailing stride.
-    if (strides.back() != 1)
+  if (!trailingNDimsContiguous(memrefType, vecRank))
       return false;
 
-    // Cond 1.2: Strides of a contiguous memref have to match the flattened
-    // dims.
-    strides = strides.drop_back(1);
-    SmallVector<int64_t> flattenedDims;
-    for (size_t i = 1; i < memrefShape.size(); i++)
-      flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i)));
-
-    if (!llvm::equal(strides, llvm::reverse(flattenedDims)))
-      return false;
-  }
+  // Extract the trailing dims and strides of the input memref
+  auto memrefShape = memrefType.getShape().take_back(vecRank);
 
-  // Cond 2: Compare the dims of `vectorType` against `memrefType` (in reverse).
+  // Compare the dims of `vectorType` against `memrefType` (in reverse).
   // In the most basic case, all dims will match.
   auto firstNonMatchingDim =
       std::mismatch(vectorShape.rbegin(), vectorShape.rend(),
diff --git a/mlir/lib/IR/BuiltinTypes.cpp b/mlir/lib/IR/BuiltinTypes.cpp
index 9b8ee3d4528035..795fa3e626d42b 100644
--- a/mlir/lib/IR/BuiltinTypes.cpp
+++ b/mlir/lib/IR/BuiltinTypes.cpp
@@ -8,6 +8,7 @@
 
 #include "mlir/IR/BuiltinTypes.h"
 #include "TypeDetail.h"
+#include "mlir/Dialect/Utils/IndexingUtils.h"
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/BuiltinAttributes.h"
@@ -967,3 +968,30 @@ bool mlir::isLastMemrefDimUnitStride(MemRefType type) {
   auto successStrides = getStridesAndOffset(type, strides, offset);
   return succeeded(successStrides) && (strides.empty() || strides.back() == 1);
 }
+
+bool mlir::trailingNDimsContiguous(MemRefType type, int64_t n) {
+  if (!isLastMemrefDimUnitStride(type))
+      return false;
+
+  auto memrefShape = type.getShape().take_back(n);
+  int64_t offset;
+  SmallVector<int64_t> stridesFull;
+  if (!succeeded(getStridesAndOffset(type, stridesFull, offset)))
+    return false;
+  auto strides = ArrayRef<int64_t>(stridesFull).take_back(n);
+
+  if (ShapedType::isDynamicShape(memrefShape))
+    return false;
+
+  if (strides.empty())
+    return true;
+
+  // Strides of a contiguous memref have to match the flattened
+  // dims.
+  strides = strides.drop_back(1);
+  SmallVector<int64_t> flattenedDims;
+  for (size_t i = 1; i < memrefShape.size(); i++)
+    flattenedDims.push_back(mlir::computeProduct(memrefShape.take_back(i)));
+
+  return llvm::equal(strides, llvm::reverse(flattenedDims));
+}

Copy link

github-actions bot commented Jan 16, 2024

✅ With the latest revision this PR passed the C/C++ code formatter.

Extracts logic to check whether the trailing dim of a memref are
contiguous into a dedicated hook in BuiitinTypes.{h|cpp}.

Follow-up for llvm#76848.
@banach-space banach-space force-pushed the andrzej/refactor_is_contiguous branch from 7eaa746 to 86e8587 Compare January 16, 2024 11:11
auto memrefShape = type.getShape().take_back(n);
int64_t offset;
SmallVector<int64_t> stridesFull;
if (!succeeded(getStridesAndOffset(type, stridesFull, offset)))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the memref layout being taken into account anywhere? That's one of the differences between vector and memref.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I've just added a check that leverages isIdentity.

Is the memref layout being taken into account anywhere?

I thought that that's what getStridesAndOffset does (i.e. returns the layout as offsets and strides)? I am not 100% sure though, this is mostly just moving things around.

@banach-space
Copy link
Contributor Author

Kind ping :)

Copy link
Contributor

@dcaballe dcaballe left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Just a minor comment

@banach-space banach-space merged commit 9478bf0 into llvm:main Feb 17, 2024
@banach-space banach-space deleted the andrzej/refactor_is_contiguous branch March 3, 2024 16:15
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants