Skip to content

[mlir][bufferization] Assertion `memref::CastOp::areCastCompatible(buffer.getType(), memRefType) && "CallOp::bufferize: cast incompatible"' failed #105916

@CoTinker

Description

@CoTinker

test.mlir

func.func @generic(%arg0: tensor<16515072xf16>, %arg1: tensor<16515072xf16>) -> tensor<16515072xf16> {
    return %arg1 : tensor<16515072xf16>
}

func.func @forward(%arg0: tensor<2880x7168xf16>) -> (tensor<7168x2x24x2x24xf16>) {
  %expanded_2120 = tensor.expand_shape %arg0 [[0, 1], [2]] output_shape [5,576,7168] : tensor<2880x7168xf16> into tensor<5x576x7168xf16>
  %extracted_slice_2122 = tensor.extract_slice %expanded_2120[1, 0, 0] [4, 576, 7168] [1, 1, 1] : tensor<5x576x7168xf16> to tensor<4x576x7168xf16>
  %expanded_2123 = tensor.expand_shape %extracted_slice_2122 [[0, 1], [2, 3], [4]] output_shape [2,2,24,24,7168] : tensor<4x576x7168xf16> into tensor<2x2x24x24x7168xf16>
  %2248 = tensor.empty() : tensor<7168x2x24x2x24xf16>
  %collapsed_2124 = tensor.collapse_shape %expanded_2123 [[0, 1, 2, 3, 4]] : tensor<2x2x24x24x7168xf16> into tensor<16515072xf16>
  %collapsed_2125 = tensor.collapse_shape %2248 [[0, 1, 2, 3, 4]] : tensor<7168x2x24x2x24xf16> into tensor<16515072xf16>
  %2249 = call @generic(%collapsed_2124, %collapsed_2125) : (tensor<16515072xf16>, tensor<16515072xf16>) -> tensor<16515072xf16>
  %expanded_2126 = tensor.expand_shape %2249 [[0, 1, 2, 3, 4]] output_shape [7168,2,24,2,24] : tensor<16515072xf16> into tensor<7168x2x24x2x24xf16>
  return %expanded_2126 : tensor<7168x2x24x2x24xf16>
}

%mlir-opt -one-shot-bufferize="bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map" test.mlir

mlir-opt: /home/mls/community/llvm-project/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp:269: llvm::LogicalResult mlir::bufferization::func_ext::CallOpInterface::bufferize(mlir::Operation*, mlir::RewriterBase&, const mlir::bufferization::BufferizationOptions&) const: Assertion `memref::CastOp::areCastCompatible(buffer.getType(), memRefType) && "CallOp::bufferize: cast incompatible"' failed.
PLEASE submit a bug report to https://github.com/llvm/llvm-project/issues/ and include the crash backtrace.
Stack dump:
0.      Program arguments: ../build-llvm/bin/mlir-opt "-one-shot-bufferize=bufferize-function-boundaries function-boundary-type-conversion=identity-layout-map" buffer.mlir
 #0 0x0000558d58056d3a llvm::sys::PrintStackTrace(llvm::raw_ostream&, int) /home/mls/community/llvm-project/llvm/lib/Support/Unix/Signals.inc:723:22
 #1 0x0000558d5805715b PrintStackTraceSignalHandler(void*) /home/mls/community/llvm-project/llvm/lib/Support/Unix/Signals.inc:798:1
 #2 0x0000558d580545ab llvm::sys::RunSignalHandlers() /home/mls/community/llvm-project/llvm/lib/Support/Signals.cpp:105:20
 #3 0x0000558d580565d2 SignalHandler(int) /home/mls/community/llvm-project/llvm/lib/Support/Unix/Signals.inc:413:1
 #4 0x00007fd340ba0520 (/lib/x86_64-linux-gnu/libc.so.6+0x42520)
 #5 0x00007fd340bf49fc __pthread_kill_implementation ./nptl/pthread_kill.c:44:76
 #6 0x00007fd340bf49fc __pthread_kill_internal ./nptl/pthread_kill.c:78:10
 #7 0x00007fd340bf49fc pthread_kill ./nptl/pthread_kill.c:89:10
 #8 0x00007fd340ba0476 gsignal ./signal/../sysdeps/posix/raise.c:27:6
 #9 0x00007fd340b867f3 abort ./stdlib/abort.c:81:7
#10 0x00007fd340b8671b _nl_load_domain ./intl/loadmsgcat.c:1177:9
#11 0x00007fd340b97e96 (/lib/x86_64-linux-gnu/libc.so.6+0x39e96)
#12 0x0000558d58ad1fe3 mlir::bufferization::func_ext::CallOpInterface::bufferize(mlir::Operation*, mlir::RewriterBase&, mlir::bufferization::BufferizationOptions const&) const /home/mls/community/llvm-project/mlir/lib/Dialect/Bufferization/Transforms/FuncBufferizableOpInterfaceImpl.cpp:272:59
#13 0x0000558d58ade8b0 mlir::bufferization::detail::BufferizableOpInterfaceInterfaceTraits::FallbackModel<mlir::bufferization::func_ext::CallOpInterface>::bufferize(mlir::bufferization::detail::BufferizableOpInterfaceInterfaceTraits::Concept const*, mlir::Operation*, mlir::RewriterBase&, mlir::bufferization::BufferizationOptions const&) /home/mls/community/build-llvm/tools/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.h.inc:1053:1
#14 0x0000558d589ec990 mlir::bufferization::BufferizableOpInterface::bufferize(mlir::RewriterBase&, mlir::bufferization::BufferizationOptions const&) /home/mls/community/build-llvm/tools/mlir/include/mlir/Dialect/Bufferization/IR/BufferizableOpInterface.cpp.inc:310:3
#15 0x0000558d58a86517 mlir::bufferization::bufferizeOp(mlir::Operation*, mlir::bufferization::BufferizationOptions const&, mlir::bufferization::BufferizationStatistics*) /home/mls/community/llvm-project/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp:477:15
#16 0x0000558d58afd5cc mlir::bufferization::bufferizeModuleOp(mlir::ModuleOp, mlir::bufferization::OneShotBufferizationOptions const&, mlir::bufferization::BufferizationStatistics*) /home/mls/community/llvm-project/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp:454:17
#17 0x0000558d58afdb13 mlir::bufferization::runOneShotModuleBufferize(mlir::ModuleOp, mlir::bufferization::OneShotBufferizationOptions const&, mlir::bufferization::BufferizationStatistics*) /home/mls/community/llvm-project/mlir/lib/Dialect/Bufferization/Transforms/OneShotModuleBufferize.cpp:509:13
#18 0x0000558d58a859a8 (anonymous namespace)::OneShotBufferizePass::runOnOperation() /home/mls/community/llvm-project/mlir/lib/Dialect/Bufferization/Transforms/Bufferize.cpp:296:17
#19 0x0000558d5d8960ad mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::'lambda0'()::operator()() const /home/mls/community/llvm-project/mlir/lib/Pass/Pass.cpp:524:57
#20 0x0000558d5d899d76 void llvm::function_ref<void ()>::callback_fn<mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int)::'lambda0'()>(long) /home/mls/community/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:46:40
#21 0x0000558d58079f34 llvm::function_ref<void ()>::operator()() const /home/mls/community/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68:62
#22 0x0000558d5d8a125d void mlir::MLIRContext::executeAction<mlir::PassExecutionAction, mlir::Pass&>(llvm::function_ref<void ()>, llvm::ArrayRef<mlir::IRUnit>, mlir::Pass&) /home/mls/community/llvm-project/mlir/include/mlir/IR/MLIRContext.h:276:3
#23 0x0000558d5d8964cf mlir::detail::OpToOpPassAdaptor::run(mlir::Pass*, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int) /home/mls/community/llvm-project/mlir/lib/Pass/Pass.cpp:533:23
#24 0x0000558d5d8967aa mlir::detail::OpToOpPassAdaptor::runPipeline(mlir::OpPassManager&, mlir::Operation*, mlir::AnalysisManager, bool, unsigned int, mlir::PassInstrumentor*, mlir::PassInstrumentation::PipelineParentInfo const*) /home/mls/community/llvm-project/mlir/lib/Pass/Pass.cpp:593:15
#25 0x0000558d5d8986bc mlir::PassManager::runPasses(mlir::Operation*, mlir::AnalysisManager) /home/mls/community/llvm-project/mlir/lib/Pass/Pass.cpp:904:40
#26 0x0000558d5d898514 mlir::PassManager::run(mlir::Operation*) /home/mls/community/llvm-project/mlir/lib/Pass/Pass.cpp:884:69
#27 0x0000558d5d888cfe performActions(llvm::raw_ostream&, std::shared_ptr<llvm::SourceMgr> const&, mlir::MLIRContext*, mlir::MlirOptMainConfig const&) /home/mls/community/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:413:13
#28 0x0000558d5d8893e4 processBuffer(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::MlirOptMainConfig const&, mlir::DialectRegistry&, llvm::ThreadPoolInterface*) /home/mls/community/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:478:26
#29 0x0000558d5d88999b mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::'lambda'(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)::operator()(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&) const /home/mls/community/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:561:25
#30 0x0000558d5d88ab1d llvm::LogicalResult llvm::function_ref<llvm::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>::callback_fn<mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&)::'lambda'(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>(long, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&) /home/mls/community/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:45:52
#31 0x0000558d5da26e65 llvm::function_ref<llvm::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>::operator()(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&) const /home/mls/community/llvm-project/llvm/include/llvm/ADT/STLFunctionalExtras.h:68:12
#32 0x0000558d5da266f6 mlir::splitAndProcessBuffer(std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::function_ref<llvm::LogicalResult (std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, llvm::raw_ostream&)>, llvm::raw_ostream&, llvm::StringRef, llvm::StringRef) /home/mls/community/llvm-project/mlir/lib/Support/ToolUtilities.cpp:27:30
#33 0x0000558d5d889b34 mlir::MlirOptMain(llvm::raw_ostream&, std::unique_ptr<llvm::MemoryBuffer, std::default_delete<llvm::MemoryBuffer>>, mlir::DialectRegistry&, mlir::MlirOptMainConfig const&) /home/mls/community/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:564:31
#34 0x0000558d5d889e2d mlir::MlirOptMain(int, char**, llvm::StringRef, llvm::StringRef, mlir::DialectRegistry&) /home/mls/community/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:605:13
#35 0x0000558d5d88a024 mlir::MlirOptMain(int, char**, llvm::StringRef, mlir::DialectRegistry&) /home/mls/community/llvm-project/mlir/lib/Tools/mlir-opt/MlirOptMain.cpp:621:21
#36 0x0000558d57fb1a4e main /home/mls/community/llvm-project/mlir/tools/mlir-opt/mlir-opt.cpp:317:0
#37 0x00007fd340b87d90 __libc_start_call_main ./csu/../sysdeps/nptl/libc_start_call_main.h:58:16
#38 0x00007fd340b87e40 call_init ./csu/../csu/libc-start.c:128:20
#39 0x00007fd340b87e40 __libc_start_main ./csu/../csu/libc-start.c:379:5
#40 0x0000558d57fb14e5 _start (../build-llvm/bin/mlir-opt+0x1e174e5)
Aborted (core dumped)

// Caller / callee type mismatch is handled with a CastOp.
auto memRefType = funcType.getInput(opOperand.getOperandNumber());
// Since we don't yet have a clear layout story, to_memref may
// conservatively turn tensors into more dynamic memref than necessary.
// If the memref type of the callee fails, introduce an extra memref.cast
// that will either canonicalize away or fail compilation until we can do
// something better.
if (buffer.getType() != memRefType) {
assert(
memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
"CallOp::bufferize: cast incompatible");
Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
memRefType, buffer);

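Here `buffer.getType()` is `memref<16515072xf16, strided<[1], offset: 4128768>>`, while `memRefType` is `memref<16515072xf16>`. `memref.cast` cannot convert between these two types because their static offsets differ (4128768 vs. 0 for the identity layout), so the assertion fails.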

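Maybe we could instead allocate a new buffer with the layout the callee expects and copy the data into it, for example: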

%alloc = memref.alloc() : memref<16515072xf16>
memref.copy %collapse_shape, %alloc : memref<16515072xf16, strided<[1], offset: 4128768>> to memref<16515072xf16>

Is it reasonable to do this? @matthias-springer

Metadata

Metadata

Assignees

Labels

Type

No type

Projects

No projects

Milestone

No milestone

Relationships

None yet

Development

No branches or pull requests

Issue actions