diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp index dbc7d0dd74a00..7eb729f349638 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -805,10 +805,18 @@ struct ToBufferOfCast : public OpRewritePattern { tensorCastOperand.getOperand().getType()); if (!srcTensorType) return failure(); + auto currentOutputMemRefType = + dyn_cast(toBuffer.getResult().getType()); + if (!currentOutputMemRefType) + return failure(); + auto memrefType = MemRefType::get(srcTensorType.getShape(), - srcTensorType.getElementType()); + srcTensorType.getElementType(), + currentOutputMemRefType.getLayout(), + currentOutputMemRefType.getMemorySpace()); Value memref = ToBufferOp::create(rewriter, toBuffer.getLoc(), memrefType, - tensorCastOperand.getOperand()); + tensorCastOperand.getOperand(), + toBuffer.getReadOnly()); rewriter.replaceOpWithNewOp(toBuffer, toBuffer.getType(), memref); return success(); diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir index f44e29071796d..2acd19453a04d 100644 --- a/mlir/test/Dialect/Bufferization/canonicalize.mlir +++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir @@ -255,16 +255,32 @@ func.func @clone_and_preceding_dealloc(%arg0: memref) -> memref<32xf32> { func.func @tensor_cast_to_buffer(%arg0 : tensor<4x6x16x32xi8>) -> memref { %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor - %1 = bufferization.to_buffer %0 : tensor to memref + %1 = bufferization.to_buffer %0 read_only : tensor to memref return %1 : memref } -// CHECK: %[[M:.+]] = bufferization.to_buffer %[[ARG0]] : tensor<4x6x16x32xi8> +// CHECK: %[[M:.+]] = bufferization.to_buffer %[[ARG0]] read_only : tensor<4x6x16x32xi8> // CHECK: %[[M1:.+]] = memref.cast %[[M]] // CHECK-SAME: memref<4x6x16x32xi8> to memref // CHECK: return %[[M1]] : memref // ----- +// CHECK-LABEL: func @tensor_cast_to_buffer +// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8> +func.func @tensor_cast_to_buffer_layout_and_memspace(%arg0 : tensor<4x6x16x32xi8>) -> + memref, 1> { + %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor + %1 = bufferization.to_buffer %0 : tensor to memref, 1> + return %1 : memref, 1> +} +// CHECK: %[[M:.+]] = bufferization.to_buffer %[[ARG0]] : tensor<4x6x16x32xi8> +// CHECK: %[[M1:.+]] = memref.cast %[[M]] +// CHECK-SAME: memref<4x6x16x32xi8, strided<[?, ?, ?, 1], offset: ?>, 1> +// CHECK-SAME: to memref, 1> +// CHECK: return %[[M1]] : memref, 1> + +// ----- + // Folding of memref.load(to_buffer(%v, %idxs)) -> tensor.extract(%v, %idx) // CHECK-LABEL: func @load_from_buffer_cast( func.func @load_from_buffer_cast(%arg0: index, %arg1: index,