[mlir][Bufferization] Fix to_buffer(tensor.cast) folder #150511
Conversation
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-bufferization

Author: Quinn Dawkins (qedawkins)

Changes

Previously this folder would ignore the layout and memory space on the to_buffer op and set them to the defaults. This changes the pattern to retain both fields from the existing memref type while incorporating the static shape information from the tensor cast.

Full diff: https://github.com/llvm/llvm-project/pull/150511.diff

2 Files Affected:
diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
index 875a06546c9f0..c3b5476029ee5 100644
--- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
+++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp
@@ -804,10 +804,18 @@ struct ToBufferOfCast : public OpRewritePattern<ToBufferOp> {
tensorCastOperand.getOperand().getType());
if (!srcTensorType)
return failure();
+ auto currentOutputMemRefType =
+ dyn_cast<MemRefType>(toBuffer.getResult().getType());
+ if (!currentOutputMemRefType)
+ return failure();
+
auto memrefType = MemRefType::get(srcTensorType.getShape(),
- srcTensorType.getElementType());
+ srcTensorType.getElementType(),
+ currentOutputMemRefType.getLayout(),
+ currentOutputMemRefType.getMemorySpace());
Value memref = rewriter.create<ToBufferOp>(toBuffer.getLoc(), memrefType,
- tensorCastOperand.getOperand());
+ tensorCastOperand.getOperand(),
+ toBuffer.getReadOnly());
rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, toBuffer.getType(),
memref);
return success();
diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir
index f44e29071796d..2acd19453a04d 100644
--- a/mlir/test/Dialect/Bufferization/canonicalize.mlir
+++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir
@@ -255,16 +255,32 @@ func.func @clone_and_preceding_dealloc(%arg0: memref<?xf32>) -> memref<32xf32> {
func.func @tensor_cast_to_buffer(%arg0 : tensor<4x6x16x32xi8>) ->
memref<?x?x16x32xi8> {
%0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
- %1 = bufferization.to_buffer %0 : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8>
+ %1 = bufferization.to_buffer %0 read_only : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8>
return %1 : memref<?x?x16x32xi8>
}
-// CHECK: %[[M:.+]] = bufferization.to_buffer %[[ARG0]] : tensor<4x6x16x32xi8>
+// CHECK: %[[M:.+]] = bufferization.to_buffer %[[ARG0]] read_only : tensor<4x6x16x32xi8>
// CHECK: %[[M1:.+]] = memref.cast %[[M]]
// CHECK-SAME: memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
// CHECK: return %[[M1]] : memref<?x?x16x32xi8>
// -----
+// CHECK-LABEL: func @tensor_cast_to_buffer
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8>
+func.func @tensor_cast_to_buffer_layout_and_memspace(%arg0 : tensor<4x6x16x32xi8>) ->
+    memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1> {
+  %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
+  %1 = bufferization.to_buffer %0 : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+  return %1 : memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+}
+// CHECK: %[[M:.+]] = bufferization.to_buffer %[[ARG0]] : tensor<4x6x16x32xi8>
+// CHECK: %[[M1:.+]] = memref.cast %[[M]]
+// CHECK-SAME: memref<4x6x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+// CHECK-SAME: to memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+// CHECK: return %[[M1]] : memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+
+// -----
+
// Folding of memref.load(to_buffer(%v, %idxs)) -> tensor.extract(%v, %idx)
// CHECK-LABEL: func @load_from_buffer_cast(
func.func @load_from_buffer_cast(%arg0: index, %arg1: index,
dd967db to 3a20e38 (force-push)
✅ With the latest revision this PR passed the C/C++ code formatter.
Previously this folder would ignore the layout and memory space on the to_buffer op and set them to the defaults. This changes the pattern to retain both fields from the existing memref type while incorporating the static shape information from the tensor cast. The `read_only` attribute was also dropped by the pattern and is retained now as well.
    return failure();
    auto currentOutputMemRefType =
        dyn_cast<MemRefType>(toBuffer.getResult().getType());
    if (!currentOutputMemRefType)
I don't know what to make of this, but it looks like some TensorFlow tests started failing due to this early return. If I change it to this code and remove the return failure(), the test passes again:
Type memrefType;
if (currentOutputMemRefType) {
  // Ranked result: preserve its layout and memory space.
  memrefType = MemRefType::get(srcTensorType.getShape(),
                               srcTensorType.getElementType(),
                               currentOutputMemRefType.getLayout(),
                               currentOutputMemRefType.getMemorySpace());
} else {
  // Otherwise fall back to the default layout and memory space.
  memrefType = MemRefType::get(srcTensorType.getShape(),
                               srcTensorType.getElementType());
}
I'm not sure if that's what's happening here, but it looks like this PR changes the pattern so that it no longer applies to IR such as:
%0 = tensor.cast %src : ranked -> unranked
%1 = bufferization.to_buffer %0
Should we be looking for BaseMemRefType here instead of MemRefType?
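A minimal sketch of that suggestion, assuming the rest of the pattern stays as in the diff above (the BaseMemRefType check is illustrative, not code from this PR):

// Accept both ranked and unranked buffer results: BaseMemRefType covers
// MemRefType as well as UnrankedMemRefType, so the ranked -> unranked
// case above would still fold.
auto currentOutputType =
    dyn_cast<BaseMemRefType>(toBuffer.getResult().getType());
if (!currentOutputType)
  return failure();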
@matthias-springer Thanks for the hint, I tried it out and it works. We can use BaseMemRefType::cloneWith to generate the new memref type, which preserves the layout and memory space when the output type is a MemRefType. I will prepare a PR.
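For reference, a rough sketch of the cloneWith approach described in this reply, assuming the surrounding pattern from the diff above; this is illustrative, not the actual code of the follow-up PR:

// Clone the current result type with the static shape of the cast's
// source tensor. For a ranked MemRefType this preserves layout and
// memory space; for an UnrankedMemRefType it preserves the memory space.
auto currentOutputType =
    cast<BaseMemRefType>(toBuffer.getResult().getType());
BaseMemRefType newType = currentOutputType.cloneWith(
    srcTensorType.getShape(), srcTensorType.getElementType());
Value memref = rewriter.create<ToBufferOp>(
    toBuffer.getLoc(), newType, tensorCastOperand.getOperand(),
    toBuffer.getReadOnly());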
… in canonic… (#152257): llvm/llvm-project#150511 changed the canonicalization pattern so that it no longer applies to casts from ranked to unranked tensors. This patch restores that functionality while still keeping the fix that preserves memory space and layout.