Skip to content

Commit 38c2b6a

Browse files
gitoleglanza
authored andcommitted
[CIR][ABI][Lowering] covers return struct case with coercion through memory (#1059)
This PR covers one more case for return values of struct type, where `memcpy` is emitted.
1 parent 97e764c commit 38c2b6a

File tree

2 files changed

+68
-17
lines changed

2 files changed

+68
-17
lines changed

clang/lib/CIR/Dialect/Transforms/TargetLowering/LowerFunction.cpp

Lines changed: 42 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -257,6 +257,23 @@ Value emitAddressAtOffset(LowerFunction &LF, Value addr,
257257
return addr;
258258
}
259259

260+
mlir::cir::AllocaOp findAlloca(Operation *op) {
261+
if (!op)
262+
return {};
263+
264+
if (auto al = dyn_cast<mlir::cir::AllocaOp>(op)) {
265+
return al;
266+
} else if (auto ret = dyn_cast<mlir::cir::ReturnOp>(op)) {
267+
auto vals = ret.getInput();
268+
if (vals.size() == 1)
269+
return findAlloca(vals[0].getDefiningOp());
270+
} else if (auto load = dyn_cast<mlir::cir::LoadOp>(op)) {
271+
return findAlloca(load.getAddr().getDefiningOp());
272+
}
273+
274+
return {};
275+
}
276+
260277
/// After the calling convention is lowered, an ABI-agnostic type might have to
261278
/// be loaded back to its ABI-aware couterpart so it may be returned. If they
262279
/// differ, we have to do a coerced load. A coerced load, which means to load a
@@ -305,6 +322,31 @@ Value castReturnValue(Value Src, Type Ty, LowerFunction &LF) {
305322
return LF.getRewriter().create<LoadOp>(Src.getLoc(), Cast);
306323
}
307324

325+
// Otherwise do coercion through memory.
326+
if (auto addr = findAlloca(Src.getDefiningOp())) {
327+
auto &rewriter = LF.getRewriter();
328+
auto *ctxt = LF.LM.getMLIRContext();
329+
auto ptrTy = PointerType::get(ctxt, Ty);
330+
auto voidPtr = PointerType::get(ctxt, mlir::cir::VoidType::get(ctxt));
331+
332+
// insert alloca near the previuos one
333+
auto point = rewriter.saveInsertionPoint();
334+
rewriter.setInsertionPointAfter(addr);
335+
auto align = LF.LM.getDataLayout().getABITypeAlign(Ty);
336+
auto alignAttr = rewriter.getI64IntegerAttr(align.value());
337+
auto tmp =
338+
rewriter.create<AllocaOp>(Src.getLoc(), ptrTy, Ty, "tmp", alignAttr);
339+
rewriter.restoreInsertionPoint(point);
340+
341+
auto srcVoidPtr = createBitcast(addr, voidPtr, LF);
342+
auto dstVoidPtr = createBitcast(tmp, voidPtr, LF);
343+
auto i64Ty = IntType::get(ctxt, 64, false);
344+
auto len = rewriter.create<ConstantOp>(
345+
Src.getLoc(), IntAttr::get(i64Ty, SrcSize.getFixedValue()));
346+
rewriter.create<MemCpyOp>(Src.getLoc(), dstVoidPtr, srcVoidPtr, len);
347+
return rewriter.create<LoadOp>(Src.getLoc(), tmp.getResult());
348+
}
349+
308350
cir_cconv_unreachable("NYI");
309351
}
310352

@@ -532,23 +574,6 @@ LowerFunction::buildFunctionProlog(const LowerFunctionInfo &FI, FuncOp Fn,
532574
return success();
533575
}
534576

535-
mlir::cir::AllocaOp findAlloca(Operation *op) {
536-
if (!op)
537-
return {};
538-
539-
if (auto al = dyn_cast<mlir::cir::AllocaOp>(op)) {
540-
return al;
541-
} else if (auto ret = dyn_cast<mlir::cir::ReturnOp>(op)) {
542-
auto vals = ret.getInput();
543-
if (vals.size() == 1)
544-
return findAlloca(vals[0].getDefiningOp());
545-
} else if (auto load = dyn_cast<mlir::cir::LoadOp>(op)) {
546-
return findAlloca(load.getAddr().getDefiningOp());
547-
}
548-
549-
return {};
550-
}
551-
552577
LogicalResult LowerFunction::buildFunctionEpilog(const LowerFunctionInfo &FI) {
553578
// NOTE(cir): no-return, naked, and no result functions should be handled in
554579
// CIRGen.

clang/test/CIR/CallConvLowering/AArch64/aarch64-cc-structs.c

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,32 @@ GT_128 ret_gt_128() {
7474
return x;
7575
}
7676

77+
typedef struct {
78+
int a;
79+
int b;
80+
int c;
81+
} S;
82+
83+
// CHECK: cir.func {{.*@retS}}() -> !cir.array<!u64i x 2>
84+
// CHECK: %[[#V0:]] = cir.alloca !ty_S, !cir.ptr<!ty_S>, ["__retval"] {alignment = 4 : i64}
85+
// CHECK: %[[#V1:]] = cir.alloca !cir.array<!u64i x 2>, !cir.ptr<!cir.array<!u64i x 2>>, ["tmp"] {alignment = 8 : i64}
86+
// CHECK: %[[#V2:]] = cir.cast(bitcast, %[[#V0]] : !cir.ptr<!ty_S>), !cir.ptr<!void>
87+
// CHECK: %[[#V3:]] = cir.cast(bitcast, %[[#V1]] : !cir.ptr<!cir.array<!u64i x 2>>), !cir.ptr<!void>
88+
// CHECK: %[[#V4:]] = cir.const #cir.int<12> : !u64i
89+
// CHECK: cir.libc.memcpy %[[#V4]] bytes from %[[#V2]] to %[[#V3]] : !u64i, !cir.ptr<!void> -> !cir.ptr<!void>
90+
// CHECK: %[[#V5:]] = cir.load %[[#V1]] : !cir.ptr<!cir.array<!u64i x 2>>, !cir.array<!u64i x 2>
91+
// CHECK: cir.return %[[#V5]] : !cir.array<!u64i x 2>
92+
93+
// LLVM: [2 x i64] @retS()
94+
// LLVM: %[[#V1:]] = alloca %struct.S, i64 1, align 4
95+
// LLVM: %[[#V2:]] = alloca [2 x i64], i64 1, align 8
96+
// LLVM: call void @llvm.memcpy.p0.p0.i64(ptr %[[#V2]], ptr %[[#V1]], i64 12, i1 false)
97+
// LLVM: %[[#V3:]] = load [2 x i64], ptr %[[#V2]], align 8
98+
// LLVM: ret [2 x i64] %[[#V3]]
99+
S retS() {
100+
S s;
101+
return s;
102+
}
77103
// CHECK: cir.func {{.*@pass_lt_64}}(%arg0: !u64
78104
// CHECK: %[[#V0:]] = cir.alloca !ty_LT_64_, !cir.ptr<!ty_LT_64_>
79105
// CHECK: %[[#V1:]] = cir.cast(integral, %arg0 : !u64i), !u16i

0 commit comments

Comments
 (0)