
Conversation

@qedawkins
Contributor

Previously this folder would ignore the layout and memory space on the to_buffer op and reset them to the defaults. This changes the pattern to retain both fields from the existing memref type while incorporating the static shape information from the tensor cast.
The `read_only` attribute was also dropped by the pattern and is now retained as well.

@llvmbot llvmbot added the mlir and mlir:bufferization (Bufferization infrastructure) labels Jul 24, 2025
@llvmbot
Member

llvmbot commented Jul 24, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-bufferization

Author: Quinn Dawkins (qedawkins)

Changes

Previously this folder would ignore the layout and memory space on the to_buffer op and reset them to the defaults. This changes the pattern to retain both fields from the existing memref type while incorporating the static shape information from the tensor cast.
The `read_only` attribute was also dropped by the pattern and is now retained as well.


Full diff: https://github.com/llvm/llvm-project/pull/150511.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp (+10-2)
  • (modified) mlir/test/Dialect/Bufferization/canonicalize.mlir (+18-2)
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 875a06546c9f0..c3b5476029ee5 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -804,10 +804,18 @@ struct ToBufferOfCast : public OpRewritePattern<ToBufferOp> {
         tensorCastOperand.getOperand().getType());
     if (!srcTensorType)
       return failure();
+    auto currentOutputMemRefType =
+        dyn_cast<MemRefType>(toBuffer.getResult().getType());
+    if (!currentOutputMemRefType)
+      return failure();
+
     auto memrefType = MemRefType::get(srcTensorType.getShape(),
-                                      srcTensorType.getElementType());
+                                      srcTensorType.getElementType(),
+                                      currentOutputMemRefType.getLayout(),
+                                      currentOutputMemRefType.getMemorySpace());
     Value memref = rewriter.create<ToBufferOp>(toBuffer.getLoc(), memrefType,
-                                               tensorCastOperand.getOperand());
+                                               tensorCastOperand.getOperand(),
+                                               toBuffer.getReadOnly());
     rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, toBuffer.getType(),
                                                 memref);
     return success();
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index f44e29071796d..2acd19453a04d 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -255,16 +255,32 @@ func.func @clone_and_preceding_dealloc(%arg0: memref<?xf32>) -> memref<32xf32> {
 func.func @tensor_cast_to_buffer(%arg0 : tensor<4x6x16x32xi8>) ->
   memref<?x?x16x32xi8> {
   %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
-  %1 = bufferization.to_buffer %0 : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8>
+  %1 = bufferization.to_buffer %0 read_only : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8>
   return %1 : memref<?x?x16x32xi8>
 }
-// CHECK:   %[[M:.+]] = bufferization.to_buffer %[[ARG0]] : tensor<4x6x16x32xi8>
+// CHECK:   %[[M:.+]] = bufferization.to_buffer %[[ARG0]] read_only : tensor<4x6x16x32xi8>
 // CHECK:   %[[M1:.+]] = memref.cast %[[M]]
 // CHECK-SAME: memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
 // CHECK:   return %[[M1]] : memref<?x?x16x32xi8>
 
 // -----
 
+// CHECK-LABEL: func @tensor_cast_to_buffer
+//  CHECK-SAME:   %[[ARG0:.+]]: tensor<4x6x16x32xi8>
+func.func @tensor_cast_to_buffer_layout_and_memspace(%arg0 : tensor<4x6x16x32xi8>) ->
+  memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1> {
+  %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
+  %1 = bufferization.to_buffer %0 : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+  return %1 : memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+}
+// CHECK:   %[[M:.+]] = bufferization.to_buffer %[[ARG0]] : tensor<4x6x16x32xi8>
+// CHECK:   %[[M1:.+]] = memref.cast %[[M]]
+// CHECK-SAME: memref<4x6x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+// CHECK-SAME: to memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+// CHECK:   return %[[M1]] : memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+
+// -----
+
 // Folding of memref.load(to_buffer(%v, %idxs)) -> tensor.extract(%v, %idx)
 // CHECK-LABEL: func @load_from_buffer_cast(
 func.func @load_from_buffer_cast(%arg0: index, %arg1: index,

@qedawkins qedawkins force-pushed the fix_bufferization_canon branch from dd967db to 3a20e38 on July 24, 2025 21:15
@github-actions

github-actions bot commented Jul 24, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

@qedawkins qedawkins merged commit fd8f69d into llvm:main Jul 24, 2025
9 checks passed
@qedawkins qedawkins deleted the fix_bufferization_canon branch July 24, 2025 21:50
mahesh-attarde pushed a commit to mahesh-attarde/llvm-project that referenced this pull request Jul 28, 2025
return failure();
auto currentOutputMemRefType =
dyn_cast<MemRefType>(toBuffer.getResult().getType());
if (!currentOutputMemRefType)
Member

I don't know what to make of this, but it looks like some TensorFlow tests started failing due to this early return. If I change it to the following code and remove the `return failure()`, the test passes again:

Type memrefType;
if (currentOutputMemRefType) {
  memrefType = MemRefType::get(srcTensorType.getShape(),
                               srcTensorType.getElementType(),
                               currentOutputMemRefType.getLayout(),
                               currentOutputMemRefType.getMemorySpace());
} else {
  memrefType = MemRefType::get(srcTensorType.getShape(),
                               srcTensorType.getElementType());
}

Member


I'm not sure if that's what's happening here, but it looks like this PR changes the pattern so that it no longer applies to IR such as:

%0 = tensor.cast %src : ranked -> unranked
%1 = bufferization.to_buffer %0

Should we be looking for BaseMemRefType here instead of MemRefType?

Member


@matthias-springer Thanks for the hint, I tried it out and it works. We can use BaseMemRefType::cloneWith to generate the new MemRefType that will preserve layout and memory space if the output type is MemRefType. I will prepare a PR.

akuegel added a commit that referenced this pull request Aug 6, 2025
#152257)

#150511 changed the
canonicalization pattern to not allow casts from ranked to unranked
anymore. This patch restores this functionality, while still keeping the
fix to preserve memory space and layout.
llvm-sync bot pushed a commit to arm/arm-toolchain that referenced this pull request Aug 6, 2025
… in canonic… (#152257)

