Commit 3a20e38

[mlir][Bufferization] Fix to_buffer(tensor.cast) folder
Previously, this folder ignored the layout and memory space on the to_buffer op's result type and reset them to the defaults. This change makes the pattern retain both fields from the existing memref type while incorporating the static shape information from the tensor.cast.
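
In IR terms, the folder turns to_buffer(tensor.cast) into a to_buffer of the cast's source followed by a memref.cast. A minimal sketch of the fixed behavior on a result type with a strided layout and memory space 1 (mirroring the new test case added below):

// Before: the to_buffer result carries a non-default layout and memory space.
%0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
%1 = bufferization.to_buffer %0 : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>

// After: the static shape comes from the cast's source, while the strided
// layout and memory space are kept from the original result type.
%m = bufferization.to_buffer %arg0 : tensor<4x6x16x32xi8> to memref<4x6x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
%1 = memref.cast %m : memref<4x6x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1> to memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>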
1 parent: ea7d813

2 files changed: +28 −4 lines changed
mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp

Lines changed: 10 additions & 2 deletions
@@ -805,10 +805,18 @@ struct ToBufferOfCast : public OpRewritePattern<ToBufferOp> {
         tensorCastOperand.getOperand().getType());
     if (!srcTensorType)
       return failure();
+    auto currentOutputMemRefType =
+        dyn_cast<MemRefType>(toBuffer.getResult().getType());
+    if (!currentOutputMemRefType)
+      return failure();
+
     auto memrefType = MemRefType::get(srcTensorType.getShape(),
-                                      srcTensorType.getElementType());
+                                      srcTensorType.getElementType(),
+                                      currentOutputMemRefType.getLayout(),
+                                      currentOutputMemRefType.getMemorySpace());
     Value memref = ToBufferOp::create(rewriter, toBuffer.getLoc(), memrefType,
-                                      tensorCastOperand.getOperand());
+                                      tensorCastOperand.getOperand(),
+                                      toBuffer.getReadOnly());
     rewriter.replaceOpWithNewOp<memref::CastOp>(toBuffer, toBuffer.getType(),
                                                 memref);
     return success();
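
Besides preserving the result type's layout and memory space, the pattern now also forwards the op's read_only flag to the replacement to_buffer. A minimal sketch of that fold, matching the updated first test below:

// Before:
%0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
%1 = bufferization.to_buffer %0 read_only : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8>

// After: read_only is carried over onto the new to_buffer.
%m = bufferization.to_buffer %arg0 read_only : tensor<4x6x16x32xi8> to memref<4x6x16x32xi8>
%1 = memref.cast %m : memref<4x6x16x32xi8> to memref<?x?x16x32xi8>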

mlir/test/Dialect/Bufferization/canonicalize.mlir

Lines changed: 18 additions & 2 deletions
@@ -255,16 +255,32 @@ func.func @clone_and_preceding_dealloc(%arg0: memref<?xf32>) -> memref<32xf32> {
 func.func @tensor_cast_to_buffer(%arg0 : tensor<4x6x16x32xi8>) ->
   memref<?x?x16x32xi8> {
   %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
-  %1 = bufferization.to_buffer %0 : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8>
+  %1 = bufferization.to_buffer %0 read_only : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8>
   return %1 : memref<?x?x16x32xi8>
 }
-// CHECK: %[[M:.+]] = bufferization.to_buffer %[[ARG0]] : tensor<4x6x16x32xi8>
+// CHECK: %[[M:.+]] = bufferization.to_buffer %[[ARG0]] read_only : tensor<4x6x16x32xi8>
 // CHECK: %[[M1:.+]] = memref.cast %[[M]]
 // CHECK-SAME: memref<4x6x16x32xi8> to memref<?x?x16x32xi8>
 // CHECK: return %[[M1]] : memref<?x?x16x32xi8>

 // -----

+// CHECK-LABEL: func @tensor_cast_to_buffer
+// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8>
+func.func @tensor_cast_to_buffer_layout_and_memspace(%arg0 : tensor<4x6x16x32xi8>) ->
+  memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1> {
+  %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<?x?x16x32xi8>
+  %1 = bufferization.to_buffer %0 : tensor<?x?x16x32xi8> to memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+  return %1 : memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+}
+// CHECK: %[[M:.+]] = bufferization.to_buffer %[[ARG0]] : tensor<4x6x16x32xi8>
+// CHECK: %[[M1:.+]] = memref.cast %[[M]]
+// CHECK-SAME: memref<4x6x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+// CHECK-SAME: to memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+// CHECK: return %[[M1]] : memref<?x?x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1>
+
+// -----
+
 // Folding of memref.load(to_buffer(%v, %idxs)) -> tensor.extract(%v, %idx)
 // CHECK-LABEL: func @load_from_buffer_cast(
 func.func @load_from_buffer_cast(%arg0: index, %arg1: index,
func.func @load_from_buffer_cast(%arg0: index, %arg1: index,
