Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -258,20 +258,23 @@ struct CallOpInterface
return failure();
Value buffer = *maybeBuffer;

// Caller / callee type mismatch is handled with a CastOp.
// Caller / callee type mismatch is handled with castOrReallocMemRefValue.
auto memRefType = funcType.getInput(opOperand.getOperandNumber());
// Since we don't yet have a clear layout story, to_memref may
// conservatively turn tensors into more dynamic memref than necessary.
// If the memref type of the callee fails, introduce an extra memref.cast
// that will either canonicalize away or fail compilation until we can do
// something better.
// something better. Insert a reallocation + copy if it cannot be
// statically guaranteed that a direct cast would be valid.
if (buffer.getType() != memRefType) {
assert(
memref::CastOp::areCastCompatible(buffer.getType(), memRefType) &&
"CallOp::bufferize: cast incompatible");
Value castBuffer = rewriter.create<memref::CastOp>(callOp.getLoc(),
memRefType, buffer);
buffer = castBuffer;
auto memrefDstType = dyn_cast<MemRefType>(memRefType);
assert(memrefDstType &&
"buffer layout not supported on unranked tensors");
FailureOr<Value> replacement = bufferization::castOrReallocMemRefValue(
rewriter, buffer, memrefDstType, options);
if (failed(replacement))
return failure();
buffer = *replacement;
}
newOperands.push_back(buffer);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,30 @@ func.func @return_extract_slice(%idx: index, %sz: index) -> (tensor<2x?xf32>)

// -----

// CHECK-NO-LAYOUT-MAP-LABEL: func.func @foo(
// CHECK-NO-LAYOUT-MAP-SAME: %[[VAL_0:.*]]: memref<3x8xf16>) -> memref<3x8xf16> {
// CHECK-NO-LAYOUT-MAP: return %[[VAL_0]] : memref<3x8xf16>
// CHECK-NO-LAYOUT-MAP: }
func.func @foo(%arg0: tensor<3x8xf16>) -> tensor<3x8xf16> {
return %arg0 : tensor<3x8xf16>
}

// CHECK-NO-LAYOUT-MAP-LABEL: func.func @call_extract_slice(
// CHECK-NO-LAYOUT-MAP-SAME: %[[VAL_0:.*]]: memref<4x8xf16>) -> memref<3x8xf16> {
// CHECK-NO-LAYOUT-MAP: %[[VAL_1:.*]] = memref.subview %[[VAL_0]][1, 0] [3, 8] [1, 1] : memref<4x8xf16> to memref<3x8xf16, strided<[8, 1], offset: 8>>
// CHECK-NO-LAYOUT-MAP: %[[VAL_2:.*]] = memref.alloc() {alignment = 64 : i64} : memref<3x8xf16>
// CHECK-NO-LAYOUT-MAP: memref.copy %[[VAL_1]], %[[VAL_2]] : memref<3x8xf16, strided<[8, 1], offset: 8>> to memref<3x8xf16>
// CHECK-NO-LAYOUT-MAP: %[[VAL_3:.*]] = call @foo(%[[VAL_2]]) : (memref<3x8xf16>) -> memref<3x8xf16>
// CHECK-NO-LAYOUT-MAP: return %[[VAL_3]] : memref<3x8xf16>
// CHECK-NO-LAYOUT-MAP: }
func.func @call_extract_slice(%arg0: tensor<4x8xf16>) -> (tensor<3x8xf16>) {
%0 = tensor.extract_slice %arg0[1, 0] [3, 8] [1, 1] : tensor<4x8xf16> to tensor<3x8xf16>
%1 = call @foo(%0) : (tensor<3x8xf16>) -> tensor<3x8xf16>
return %1 : tensor<3x8xf16>
}

// -----

// CHECK-LABEL: func private @private_func
// CHECK-NO-LAYOUT-MAP-LABEL: func private @private_func(memref<?xf32>) -> f32
func.func private @private_func(tensor<?xf32>) -> (f32)
Expand Down