Skip to content

[CIR] Add get_element operation for computing pointer to array element #1748

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 9 commits into from
Jul 25, 2025
Merged
52 changes: 52 additions & 0 deletions clang/include/clang/CIR/Dialect/IR/CIROps.td
Original file line number Diff line number Diff line change
Expand Up @@ -3132,6 +3132,58 @@ def CIR_GetMethodOp : CIR_Op<"get_method"> {
let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// GetElementOp
//===----------------------------------------------------------------------===//

def CIR_GetElementOp : CIR_Op<"get_element"> {
let summary = "Get the address of an array element";
let description = [{
The `cir.get_element` operation gets the address of a particular element
from the `base` array.

It expects a pointer to the `base` array and the `index` of the element.

Example:
```mlir
// Suppose we have a array.
!s32i = !cir.int<s, 32>
!arr_ty = !cir.array<!s32i x 4>

// Get the address of the element at index 1.
%elem_1 = cir.get_element %0[1] : (!cir.ptr<!array_ty>, !s32i) -> !cir.ptr<!s32i>

// Get the address of the element at index %i.
%i = ...
%elem_i = cir.get_element %0[%i] : (!cir.ptr<!array_ty>, !s32i) -> !cir.ptr<!s32i>
```
}];

let arguments = (ins
Arg<CIR_PtrToArray, "the base address of the array ">:$base,
Arg<CIR_AnyFundamentalIntType, "the index of the element">:$index
);

let results = (outs CIR_PointerType:$result);

let assemblyFormat = [{
$base `[` $index `]` `:` `(` qualified(type($base)) `,` qualified(type($index)) `)`
`->` qualified(type($result)) attr-dict
}];

let extraClassDeclaration = [{
// Get the type of the element.
mlir::Type getElementType() {
return getType().getPointee();
}
cir::PointerType getBaseType() {
return mlir::cast<cir::PointerType>(getBase().getType());
}
}];

let hasVerifier = 1;
}

//===----------------------------------------------------------------------===//
// VecInsertOp
//===----------------------------------------------------------------------===//
Expand Down
39 changes: 38 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenBuilder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,11 +28,48 @@ mlir::Value CIRGenBuilderTy::maybeBuildArrayDecay(mlir::Location loc,
return arrayPtr;
}

mlir::Value CIRGenBuilderTy::getArrayElement(mlir::Location arrayLocBegin,
mlir::Value CIRGenBuilderTy::promoteArrayIndex(const clang::TargetInfo &ti,
mlir::Location loc,
mlir::Value index) {
// Get the array index type.
auto arrayIndexWidth = ti.getTypeWidth(clang::TargetInfo::IntType::SignedInt);
mlir::Type arrayIndexType = getSIntNTy(arrayIndexWidth);

// If this is a boolean, zero-extend it to the array index type.
if (auto boolTy = mlir::dyn_cast<cir::BoolType>(index.getType()))
return create<cir::CastOp>(loc, arrayIndexType, cir::CastKind::bool_to_int,
index);

// If this an integer, ensure that it is at least as width as the array index
// type.
if (auto intTy = mlir::dyn_cast<cir::IntType>(index.getType())) {
if (intTy.getWidth() < arrayIndexWidth)
return create<cir::CastOp>(loc, arrayIndexType, cir::CastKind::integral,
index);
}

return index;
}

mlir::Value CIRGenBuilderTy::getArrayElement(const clang::TargetInfo &ti,
mlir::Location arrayLocBegin,
mlir::Location arrayLocEnd,
mlir::Value arrayPtr,
mlir::Type eltTy, mlir::Value idx,
bool shouldDecay) {
auto arrayPtrTy = mlir::dyn_cast<cir::PointerType>(arrayPtr.getType());
assert(arrayPtrTy && "expected pointer type");

// If the array pointer is not decayed, emit a GetElementOp.
auto arrayTy = mlir::dyn_cast<cir::ArrayType>(arrayPtrTy.getPointee());
if (shouldDecay && arrayTy && arrayTy == eltTy) {
auto eltPtrTy =
getPointerTo(arrayTy.getElementType(), arrayPtrTy.getAddrSpace());
return create<cir::GetElementOp>(arrayLocEnd, eltPtrTy, arrayPtr,
promoteArrayIndex(ti, arrayLocBegin, idx));
}

// If we don't have sufficient type information, emit a PtrStrideOp.
mlir::Value basePtr = arrayPtr;
if (shouldDecay)
basePtr = maybeBuildArrayDecay(arrayLocBegin, arrayPtr, eltTy);
Expand Down
8 changes: 7 additions & 1 deletion clang/lib/CIR/CodeGen/CIRGenBuilder.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include "clang/AST/Decl.h"
#include "clang/AST/Type.h"
#include "clang/Basic/TargetInfo.h"
#include "clang/CIR/Dialect/Builder/CIRBaseBuilder.h"
#include "clang/CIR/Dialect/IR/CIRAttrs.h"
#include "clang/CIR/Dialect/IR/CIRDataLayout.h"
Expand Down Expand Up @@ -1030,10 +1031,15 @@ class CIRGenBuilderTy : public cir::CIRBaseBuilderTy {
return create<cir::GetRuntimeMemberOp>(loc, resultTy, objectPtr, memberPtr);
}

/// Promote a value for use as an array index.
mlir::Value promoteArrayIndex(const clang::TargetInfo &TargetInfo,
mlir::Location loc, mlir::Value index);

/// Create a cir.ptr_stride operation to get access to an array element.
/// idx is the index of the element to access, shouldDecay is true if the
/// result should decay to a pointer to the element type.
mlir::Value getArrayElement(mlir::Location arrayLocBegin,
mlir::Value getArrayElement(const clang::TargetInfo &targetInfo,
mlir::Location arrayLocBegin,
mlir::Location arrayLocEnd, mlir::Value arrayPtr,
mlir::Type eltTy, mlir::Value idx,
bool shouldDecay);
Expand Down
4 changes: 2 additions & 2 deletions clang/lib/CIR/CodeGen/CIRGenExpr.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1710,8 +1710,8 @@ emitArraySubscriptPtr(CIRGenFunction &CGF, mlir::Location beginLoc,
// that would enhance tracking this later in CIR?
if (inbounds)
assert(!cir::MissingFeatures::emitCheckedInBoundsGEP() && "NYI");
return CGM.getBuilder().getArrayElement(beginLoc, endLoc, ptr, eltTy, idx,
shouldDecay);
return CGM.getBuilder().getArrayElement(CGF.getTarget(), beginLoc, endLoc,
ptr, eltTy, idx, shouldDecay);
}

static QualType getFixedSizeElementType(const ASTContext &astContext,
Expand Down
6 changes: 3 additions & 3 deletions clang/lib/CIR/CodeGen/CIRGenExprAgg.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -954,9 +954,9 @@ void AggExprEmitter::VisitCXXStdInitializerListExpr(
ArrayType->getElementType()) &&
"Expected std::initializer_list second field to be const E *");

auto ArrayEnd =
Builder.getArrayElement(loc, loc, ArrayPtr.getPointer(),
ArrayPtr.getElementType(), Size, false);
auto ArrayEnd = Builder.getArrayElement(
CGF.getTarget(), loc, loc, ArrayPtr.getPointer(),
ArrayPtr.getElementType(), Size, false);
CGF.emitStoreThroughLValue(RValue::get(ArrayEnd), EndOrLength);
}
}
Expand Down
12 changes: 12 additions & 0 deletions clang/lib/CIR/Dialect/IR/CIRDialect.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -3837,6 +3837,18 @@ LogicalResult cir::GetMethodOp::verify() {
return mlir::success();
}

//===----------------------------------------------------------------------===//
// GetMemberOp Definitions
//===----------------------------------------------------------------------===//

LogicalResult cir::GetElementOp::verify() {
auto arrayTy = mlir::cast<cir::ArrayType>(getBaseType().getPointee());
if (getElementType() != arrayTy.getElementType())
return emitError() << "element type mismatch";

return mlir::success();
}

//===----------------------------------------------------------------------===//
// InlineAsmOp Definitions
//===----------------------------------------------------------------------===//
Expand Down
9 changes: 7 additions & 2 deletions clang/lib/CIR/Dialect/Transforms/LifetimeCheck.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1254,6 +1254,12 @@ void LifetimeCheckPass::updatePointsTo(mlir::Value addr, mlir::Value data,
return;
}

if (auto getElemOp = mlir::dyn_cast<cir::GetElementOp>(dataSrcOp)) {
getPmap()[addr].clear();
getPmap()[addr].insert(State::getLocalValue(getElemOp.getBase()));
return;
}

// Initializes ptr types out of known lib calls marked with pointer
// attributes. TODO: find a better way to tag this.
if (auto callOp = dyn_cast<CallOp>(dataSrcOp)) {
Expand Down Expand Up @@ -1945,8 +1951,7 @@ void LifetimeCheckPass::dumpPmap(PMapType &pmap) {
int entry = 0;
for (auto &mapEntry : pmap) {
llvm::errs() << " " << entry << ": " << getVarNameFromValue(mapEntry.first)
<< " "
<< "=> ";
<< " => ";
printPset(mapEntry.second);
llvm::errs() << "\n";
entry++;
Expand Down
129 changes: 96 additions & 33 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -951,6 +951,51 @@ static mlir::Value getLLVMIntCast(mlir::ConversionPatternRewriter &rewriter,
return rewriter.create<mlir::LLVM::TruncOp>(loc, llvmDstIntTy, llvmSrc);
}

static mlir::Value promoteIndex(mlir::ConversionPatternRewriter &rewriter,
mlir::Value index, uint64_t layoutWidth,
bool isUnsigned) {
auto indexOp = index.getDefiningOp();
if (!indexOp)
return index;

auto indexType = mlir::cast<mlir::IntegerType>(index.getType());
auto width = indexType.getWidth();
if (layoutWidth == width)
return index;

// If the index definition is a unary minus (index = sub 0, x), then we need
// to
bool rewriteSub = false;
auto sub = mlir::dyn_cast<mlir::LLVM::SubOp>(indexOp);
if (sub) {
if (auto lhsConst = dyn_cast<mlir::LLVM::ConstantOp>(
sub.getOperand(0).getDefiningOp())) {
auto lhsConstInt = mlir::dyn_cast<mlir::IntegerAttr>(lhsConst.getValue());
if (lhsConstInt && lhsConstInt.getValue() == 0) {
rewriteSub = true;
index = sub.getOperand(1);
}
}
}

// Handle the cast
auto llvmDstType = mlir::IntegerType::get(rewriter.getContext(), layoutWidth);
index = getLLVMIntCast(rewriter, index, llvmDstType, isUnsigned, width,
layoutWidth);

if (rewriteSub) {
index = rewriter.create<mlir::LLVM::SubOp>(
index.getLoc(),
rewriter.create<mlir::LLVM::ConstantOp>(index.getLoc(), index.getType(),
0),
index);
// TODO: check if the sub is trivially dead now.
rewriter.eraseOp(sub);
}

return index;
}

mlir::LogicalResult CIRToLLVMPtrStrideOpLowering::matchAndRewrite(
cir::PtrStrideOp ptrStrideOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
Expand All @@ -964,50 +1009,67 @@ mlir::LogicalResult CIRToLLVMPtrStrideOpLowering::matchAndRewrite(
// make it i8 instead.
if (mlir::isa<mlir::LLVM::LLVMVoidType>(elementTy) ||
mlir::isa<mlir::LLVM::LLVMFunctionType>(elementTy))
elementTy = mlir::IntegerType::get(elementTy.getContext(), 8,
mlir::IntegerType::Signless);
elementTy = mlir::IntegerType::get(ctx, 8, mlir::IntegerType::Signless);

// Zero-extend, sign-extend or trunc the pointer value.
auto index = adaptor.getStride();
auto width = mlir::cast<mlir::IntegerType>(index.getType()).getWidth();
mlir::DataLayout LLVMLayout(ptrStrideOp->getParentOfType<mlir::ModuleOp>());
auto layoutWidth =
LLVMLayout.getTypeIndexBitwidth(adaptor.getBase().getType());
auto indexOp = index.getDefiningOp();
if (indexOp && layoutWidth && width != *layoutWidth) {
// If the index comes from a subtraction, make sure the extension happens
// before it. To achieve that, look at unary minus, which already got
// lowered to "sub 0, x".
auto sub = dyn_cast<mlir::LLVM::SubOp>(indexOp);
auto unary = dyn_cast_if_present<cir::UnaryOp>(
ptrStrideOp.getStride().getDefiningOp());
bool rewriteSub =
unary && unary.getKind() == cir::UnaryOpKind::Minus && sub;
if (rewriteSub)
index = indexOp->getOperand(1);

// Handle the cast
auto llvmDstType = mlir::IntegerType::get(ctx, *layoutWidth);
index = getLLVMIntCast(rewriter, index, llvmDstType,
ptrStrideOp.getStride().getType().isUnsigned(),
width, *layoutWidth);

// Rewrite the sub in front of extensions/trunc
if (rewriteSub) {
index = rewriter.create<mlir::LLVM::SubOp>(
index.getLoc(),
rewriter.create<mlir::LLVM::ConstantOp>(index.getLoc(),
index.getType(), 0),
index);
rewriter.eraseOp(sub);
}
if (auto layoutWidth =
LLVMLayout.getTypeIndexBitwidth(adaptor.getBase().getType())) {
bool isUnsigned = false;
if (auto strideTy =
mlir::dyn_cast<cir::IntType>(ptrStrideOp.getOperand(1).getType()))
isUnsigned = strideTy.isUnsigned();
index = promoteIndex(rewriter, index, *layoutWidth, isUnsigned);
}

rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(
ptrStrideOp, resultTy, elementTy, adaptor.getBase(), index);
return mlir::success();
}

mlir::LogicalResult CIRToLLVMGetElementOpLowering::matchAndRewrite(
cir::GetElementOp op, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {

if (auto arrayTy =
mlir::dyn_cast<cir::ArrayType>(op.getBaseType().getPointee())) {
auto *tc = getTypeConverter();
const auto llResultTy = tc->convertType(op.getType());
auto elementTy = convertTypeForMemory(*tc, dataLayout, op.getElementType());
auto *ctx = elementTy.getContext();

// void and function types doesn't really have a layout to use in GEPs,
// make it i8 instead.
if (mlir::isa<mlir::LLVM::LLVMVoidType>(elementTy) ||
mlir::isa<mlir::LLVM::LLVMFunctionType>(elementTy))
elementTy = mlir::IntegerType::get(ctx, 8, mlir::IntegerType::Signless);

// Zero-extend, sign-extend or trunc the index value.
auto index = adaptor.getIndex();
mlir::DataLayout LLVMLayout(op->getParentOfType<mlir::ModuleOp>());
if (auto layoutWidth =
LLVMLayout.getTypeIndexBitwidth(adaptor.getBase().getType())) {
bool isUnsigned = false;
if (auto strideTy = dyn_cast<cir::IntType>(op.getOperand(1).getType()))
isUnsigned = strideTy.isUnsigned();
index = promoteIndex(rewriter, index, *layoutWidth, isUnsigned);
}

// Since the base address is a pointer to an aggregate, the first
// offset is always zero. The second offset tell us which member it
// will access.
const auto llArrayTy = getTypeConverter()->convertType(arrayTy);
llvm::SmallVector<mlir::LLVM::GEPArg, 2> offset{0, index};
rewriter.replaceOpWithNewOp<mlir::LLVM::GEPOp>(op, llResultTy, llArrayTy,
adaptor.getBase(), offset);

return mlir::success();
}

llvm_unreachable("NYI, GetElementOp lowering to LLVM for non-Array");
}

mlir::LogicalResult CIRToLLVMBaseClassAddrOpLowering::matchAndRewrite(
cir::BaseClassAddrOp baseClassOp, OpAdaptor adaptor,
mlir::ConversionPatternRewriter &rewriter) const {
Expand Down Expand Up @@ -4388,6 +4450,7 @@ void populateCIRToLLVMConversionPatterns(
patterns.add<
// clang-format off
CIRToLLVMPtrStrideOpLowering,
CIRToLLVMGetElementOpLowering,
CIRToLLVMInlineAsmOpLowering
// clang-format on
>(converter, patterns.getContext(), dataLayout);
Expand Down
16 changes: 16 additions & 0 deletions clang/lib/CIR/Lowering/DirectToLLVM/LowerToLLVM.h
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,22 @@ class CIRToLLVMPtrStrideOpLowering
mlir::ConversionPatternRewriter &) const override;
};

class CIRToLLVMGetElementOpLowering
: public mlir::OpConversionPattern<cir::GetElementOp> {
mlir::DataLayout const &dataLayout;

public:
CIRToLLVMGetElementOpLowering(const mlir::TypeConverter &typeConverter,
mlir::MLIRContext *context,
mlir::DataLayout const &dataLayout)
: OpConversionPattern(typeConverter, context), dataLayout(dataLayout) {}
using mlir::OpConversionPattern<cir::GetElementOp>::OpConversionPattern;

mlir::LogicalResult
matchAndRewrite(cir::GetElementOp op, OpAdaptor,
mlir::ConversionPatternRewriter &) const override;
};

class CIRToLLVMBaseClassAddrOpLowering
: public mlir::OpConversionPattern<cir::BaseClassAddrOp> {
public:
Expand Down
Loading
Loading