From 940c0dd50e73b7c93ffec0544ea598d9d1cbbc99 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Wed, 6 Aug 2025 06:41:59 +0000 Subject: [PATCH] [mlir][Bufferization] Support cast from ranked to unranked in canonicalization. --- .../Dialect/Bufferization/IR/BufferizationOps.cpp | 8 +++----- mlir/test/Dialect/Bufferization/canonicalize.mlir | 13 +++++++++++++ 2 files changed, 16 insertions(+), 5 deletions(-) diff --git a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp index 7eb729f349638..f1f12f4bca70e 100644 --- a/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp +++ b/mlir/lib/Dialect/Bufferization/IR/BufferizationOps.cpp @@ -806,14 +806,12 @@ struct ToBufferOfCast : public OpRewritePattern { if (!srcTensorType) return failure(); auto currentOutputMemRefType = - dyn_cast(toBuffer.getResult().getType()); + dyn_cast(toBuffer.getResult().getType()); if (!currentOutputMemRefType) return failure(); - auto memrefType = MemRefType::get(srcTensorType.getShape(), - srcTensorType.getElementType(), - currentOutputMemRefType.getLayout(), - currentOutputMemRefType.getMemorySpace()); + auto memrefType = currentOutputMemRefType.cloneWith( + srcTensorType.getShape(), srcTensorType.getElementType()); Value memref = ToBufferOp::create(rewriter, toBuffer.getLoc(), memrefType, tensorCastOperand.getOperand(), toBuffer.getReadOnly()); diff --git a/mlir/test/Dialect/Bufferization/canonicalize.mlir b/mlir/test/Dialect/Bufferization/canonicalize.mlir index 2acd19453a04d..ae1d1fcfc19dc 100644 --- a/mlir/test/Dialect/Bufferization/canonicalize.mlir +++ b/mlir/test/Dialect/Bufferization/canonicalize.mlir @@ -263,6 +263,19 @@ func.func @tensor_cast_to_buffer(%arg0 : tensor<4x6x16x32xi8>) -> // CHECK-SAME: memref<4x6x16x32xi8> to memref // CHECK: return %[[M1]] : memref +// CHECK-LABEL: func @tensor_cast_to_unranked_buffer +// CHECK-SAME: %[[ARG0:.+]]: tensor<4x6x16x32xi8> +func.func @tensor_cast_to_unranked_buffer(%arg0 : tensor<4x6x16x32xi8>) -> + memref<*xi8> { + %0 = tensor.cast %arg0 : tensor<4x6x16x32xi8> to tensor<*xi8> + %1 = bufferization.to_buffer %0 read_only : tensor<*xi8> to memref<*xi8> + return %1 : memref<*xi8> +} +// CHECK: %[[M:.+]] = bufferization.to_buffer %[[ARG0]] read_only : tensor<4x6x16x32xi8> +// CHECK: %[[M1:.+]] = memref.cast %[[M]] +// CHECK-SAME: memref<4x6x16x32xi8> to memref<*xi8> +// CHECK: return %[[M1]] : memref<*xi8> + // ----- // CHECK-LABEL: func @tensor_cast_to_buffer