Skip to content

[SPIR-V] Fix some GEP legalization #150943

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Aug 1, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 177 additions & 2 deletions llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,6 +194,42 @@ class SPIRVEmitIntrinsics

void useRoundingMode(ConstrainedFPIntrinsic *FPI, IRBuilder<> &B);

// Tries to walk the type accessed by the given GEP instruction.
// For each nested type access, one of the 2 callbacks is called:
// - OnLiteralIndexing when the index is a known constant value.
// Parameters:
// PointedType: the pointed type resulting of this indexing.
// If the parent type is an array, this is the index in the array.
// If the parent type is a struct, this is the field index.
// Index: index of the element in the parent type.
// - OnDynamnicIndexing when the index is a non-constant value.
// This callback is only called when indexing into an array.
// Parameters:
// ElementType: the type of the elements stored in the parent array.
// Offset: the Value* containing the byte offset into the array.
// Return true if an error occured during the walk, false otherwise.
bool walkLogicalAccessChain(
GetElementPtrInst &GEP,
const std::function<void(Type *PointedType, uint64_t Index)>
&OnLiteralIndexing,
const std::function<void(Type *ElementType, Value *Offset)>
&OnDynamicIndexing);

// Returns the type accessed using the given GEP instruction by relying
// on the GEP type.
// FIXME: GEP types are not supposed to be used to retrieve the pointed
// type. This must be fixed.
Type *getGEPType(GetElementPtrInst *GEP);

// Returns the type accessed using the given GEP instruction by walking
// the source type using the GEP indices.
// FIXME: without help from the frontend, this method cannot reliably retrieve
// the stored type, nor can robustly determine the depth of the type
// we are accessing.
Type *getGEPTypeLogical(GetElementPtrInst *GEP);

Instruction *buildLogicalAccessChainFromGEP(GetElementPtrInst &GEP);

public:
static char ID;
SPIRVEmitIntrinsics(SPIRVTargetMachine *TM = nullptr)
Expand Down Expand Up @@ -246,6 +282,17 @@ bool expectIgnoredInIRTranslation(const Instruction *I) {
}
}

// Returns the source pointer from `I` ignoring intermediate ptrcast.
Value *getPointerRoot(Value *I) {
if (auto *II = dyn_cast<IntrinsicInst>(I)) {
if (II->getIntrinsicID() == Intrinsic::spv_ptrcast) {
Value *V = II->getArgOperand(0);
return getPointerRoot(V);
}
}
return I;
}

} // namespace

char SPIRVEmitIntrinsics::ID = 0;
Expand Down Expand Up @@ -555,7 +602,111 @@ void SPIRVEmitIntrinsics::maybeAssignPtrType(Type *&Ty, Value *Op, Type *RefTy,
Ty = RefTy;
}

Type *getGEPType(GetElementPtrInst *Ref) {
bool SPIRVEmitIntrinsics::walkLogicalAccessChain(
GetElementPtrInst &GEP,
const std::function<void(Type *, uint64_t)> &OnLiteralIndexing,
const std::function<void(Type *, Value *)> &OnDynamicIndexing) {
// We only rewrite i8* GEP. Other should be left as-is.
// Valid i8* GEP must always have a single index.
assert(GEP.getSourceElementType() ==
IntegerType::getInt8Ty(CurrF->getContext()));
assert(GEP.getNumIndices() == 1);

auto &DL = CurrF->getDataLayout();
Value *Src = getPointerRoot(GEP.getPointerOperand());
Type *CurType = deduceElementType(Src, true);

Value *Operand = *GEP.idx_begin();
ConstantInt *CI = dyn_cast<ConstantInt>(Operand);
if (!CI) {
ArrayType *AT = dyn_cast<ArrayType>(CurType);
// Operand is not constant. Either we have an array and accept it, or we
// give up.
if (AT)
OnDynamicIndexing(AT->getElementType(), Operand);
return AT == nullptr;
}

assert(CI);
uint64_t Offset = CI->getZExtValue();

do {
if (ArrayType *AT = dyn_cast<ArrayType>(CurType)) {
uint32_t EltTypeSize = DL.getTypeSizeInBits(AT->getElementType()) / 8;
assert(Offset < AT->getNumElements() * EltTypeSize);
uint64_t Index = Offset / EltTypeSize;
Offset = Offset - (Index * EltTypeSize);
CurType = AT->getElementType();
OnLiteralIndexing(CurType, Index);
} else if (StructType *ST = dyn_cast<StructType>(CurType)) {
uint32_t StructSize = DL.getTypeSizeInBits(ST) / 8;
assert(Offset < StructSize);
const auto &STL = DL.getStructLayout(ST);
unsigned Element = STL->getElementContainingOffset(Offset);
Offset -= STL->getElementOffset(Element);
CurType = ST->getElementType(Element);
OnLiteralIndexing(CurType, Element);
} else {
// Vector type indexing should not use GEP.
// So if we have an index left, something is wrong. Giving up.
return true;
}
} while (Offset > 0);

return false;
}

Instruction *
SPIRVEmitIntrinsics::buildLogicalAccessChainFromGEP(GetElementPtrInst &GEP) {
auto &DL = CurrF->getDataLayout();
IRBuilder<> B(GEP.getParent());
B.SetInsertPoint(&GEP);

std::vector<Value *> Indices;
Indices.push_back(ConstantInt::get(
IntegerType::getInt32Ty(CurrF->getContext()), 0, /* Signed= */ false));
walkLogicalAccessChain(
GEP,
[&Indices, &B](Type *EltType, uint64_t Index) {
Indices.push_back(
ConstantInt::get(B.getInt64Ty(), Index, /* Signed= */ false));
},
[&Indices, &B, &DL](Type *EltType, Value *Offset) {
uint32_t EltTypeSize = DL.getTypeSizeInBits(EltType) / 8;
Value *Index = B.CreateUDiv(
Offset, ConstantInt::get(Offset->getType(), EltTypeSize,
/* Signed= */ false));
Indices.push_back(Index);
});

SmallVector<Type *, 2> Types = {GEP.getType(), GEP.getOperand(0)->getType()};
SmallVector<Value *, 4> Args;
Args.push_back(B.getInt1(GEP.isInBounds()));
Args.push_back(GEP.getOperand(0));
llvm::append_range(Args, Indices);
auto *NewI = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
replaceAllUsesWithAndErase(B, &GEP, NewI);
return NewI;
}

Type *SPIRVEmitIntrinsics::getGEPTypeLogical(GetElementPtrInst *GEP) {

Type *CurType = GEP->getResultElementType();

bool Interrupted = walkLogicalAccessChain(
*GEP, [&CurType](Type *EltType, uint64_t Index) { CurType = EltType; },
[&CurType](Type *EltType, Value *Index) { CurType = EltType; });

return Interrupted ? GEP->getResultElementType() : CurType;
}

Type *SPIRVEmitIntrinsics::getGEPType(GetElementPtrInst *Ref) {
if (Ref->getSourceElementType() ==
IntegerType::getInt8Ty(CurrF->getContext()) &&
TM->getSubtargetImpl()->isLogicalSPIRV()) {
return getGEPTypeLogical(Ref);
}

Type *Ty = nullptr;
// TODO: not sure if GetElementPtrInst::getTypeAtIndex() does anything
// useful here
Expand Down Expand Up @@ -1395,6 +1546,13 @@ Instruction *SPIRVEmitIntrinsics::visitSwitchInst(SwitchInst &I) {
}

Instruction *SPIRVEmitIntrinsics::visitGetElementPtrInst(GetElementPtrInst &I) {
if (I.getSourceElementType() == IntegerType::getInt8Ty(CurrF->getContext()) &&
TM->getSubtargetImpl()->isLogicalSPIRV()) {
Instruction *Result = buildLogicalAccessChainFromGEP(I);
if (Result)
return Result;
}

IRBuilder<> B(I.getParent());
B.SetInsertPoint(&I);
SmallVector<Type *, 2> Types = {I.getType(), I.getOperand(0)->getType()};
Expand Down Expand Up @@ -1588,7 +1746,24 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
}
if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(I)) {
Value *Pointer = GEPI->getPointerOperand();
Type *OpTy = GEPI->getSourceElementType();
Type *OpTy = nullptr;

// Knowing the accessed type is mandatory for logical SPIR-V. Sadly,
// the GEP source element type should not be used for this purpose, and
// the alternative type-scavenging method is not working.
// Physical SPIR-V can work around this, but not logical, hence still
// try to rely on the broken type scavenging for logical.
bool IsRewrittenGEP =
GEPI->getSourceElementType() == IntegerType::getInt8Ty(I->getContext());
if (IsRewrittenGEP && TM->getSubtargetImpl()->isLogicalSPIRV()) {
Value *Src = getPointerRoot(Pointer);
OpTy = GR->findDeducedElementType(Src);
}

// In all cases, fall back to the GEP type if type scavenging failed.
if (!OpTy)
OpTy = GEPI->getSourceElementType();

replacePointerOperandWithPtrCast(I, Pointer, OpTy, 0, B);
if (isNestedPointer(OpTy))
insertTodoType(Pointer);
Expand Down
4 changes: 2 additions & 2 deletions llvm/test/CodeGen/SPIRV/llvm-intrinsics/lifetime.ll
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ define spir_func void @foo(ptr noundef byval(%tprange) align 8 %_arg_UserRange)
%RoundedRangeKernel = alloca %tprange, align 8
call void @llvm.lifetime.start.p0(i64 72, ptr nonnull %RoundedRangeKernel)
call void @llvm.memcpy.p0.p0.i64(ptr align 8 %RoundedRangeKernel, ptr align 8 %_arg_UserRange, i64 16, i1 false)
%KernelFunc = getelementptr inbounds i8, ptr %RoundedRangeKernel, i64 16
%KernelFunc = getelementptr inbounds i8, ptr %RoundedRangeKernel, i64 8
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are you making this change?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The size of %tprange is 16 bytes. If you do a GEP to +16, you are doing an out-of-bounds access. Because we have a literal index, we'd generate an out-of-bounds OpInBoundsAccessChain
This test is about testing lifetime intrinsics, so seems OK to change the GEP index to something valid.

call void @llvm.lifetime.end.p0(i64 72, ptr nonnull %RoundedRangeKernel)
ret void
}
Expand All @@ -55,7 +55,7 @@ define spir_func void @bar(ptr noundef byval(%tprange) align 8 %_arg_UserRange)
%RoundedRangeKernel = alloca %tprange, align 8
call void @llvm.lifetime.start.p0(i64 -1, ptr nonnull %RoundedRangeKernel)
call void @llvm.memcpy.p0.p0.i64(ptr align 8 %RoundedRangeKernel, ptr align 8 %_arg_UserRange, i64 16, i1 false)
%KernelFunc = getelementptr inbounds i8, ptr %RoundedRangeKernel, i64 16
%KernelFunc = getelementptr inbounds i8, ptr %RoundedRangeKernel, i64 8
call void @llvm.lifetime.end.p0(i64 -1, ptr nonnull %RoundedRangeKernel)
ret void
}
Expand Down
3 changes: 2 additions & 1 deletion llvm/test/CodeGen/SPIRV/logical-struct-access.ll
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
; RUN: llc -O0 -mtriple=spirv-unknown-vulkan1.3-compute %s -o - -print-after-all | FileCheck %s
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val %}

; CHECK-DAG: [[uint:%[0-9]+]] = OpTypeInt 32 0

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
; RUN: llc -verify-machineinstrs -O3 -mtriple=spirv1.6-unknown-vulkan1.3-compute %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O3 -mtriple=spirv1.6-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val %}

%struct.S1 = type { <4 x i32>, [10 x <4 x float>], <4 x float> }
%struct.S2 = type { <4 x float>, <4 x i32> }

@.str = private unnamed_addr constant [3 x i8] c"In\00", align 1

define <4 x float> @main() {
entry:
%0 = tail call target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(i32 0, i32 1, i32 1, i32 0, i1 false, ptr nonnull @.str)
%3 = tail call noundef align 1 dereferenceable(192) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) %0, i32 0)

; CHECK-DAG: %[[#ulong:]] = OpTypeInt 64 0
; CHECK-DAG: %[[#ulong_1:]] = OpConstant %[[#ulong]] 1
; CHECK-DAG: %[[#ulong_3:]] = OpConstant %[[#ulong]] 3

; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
; CHECK-DAG: %[[#uint_10:]] = OpConstant %[[#uint]] 10

; CHECK-DAG: %[[#float:]] = OpTypeFloat 32
; CHECK-DAG: %[[#v4f:]] = OpTypeVector %[[#float]] 4
; CHECK-DAG: %[[#arr_v4f:]] = OpTypeArray %[[#v4f]] %[[#uint_10]]
; CHECK-DAG: %[[#S1:]] = OpTypeStruct %[[#]] %[[#arr_v4f]] %[[#]]
; CHECK-DAG: %[[#sb_S1:]] = OpTypePointer StorageBuffer %[[#S1]]
; CHECK-DAG: %[[#sb_v4f:]] = OpTypePointer StorageBuffer %[[#v4f]]

; CHECK: %[[#tmp:]] = OpAccessChain %[[#sb_S1]] %[[#]] %[[#uint_0]] %[[#uint_0]]
; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#sb_v4f]] %[[#tmp]] %[[#ulong_1]] %[[#ulong_3]]
; This rewritten GEP combined all constant indices into a single value.
; We should make sure the correct indices are retrieved.
%arrayidx.i = getelementptr inbounds nuw i8, ptr addrspace(11) %3, i64 64

; CHECK: OpLoad %[[#v4f]] %[[#ptr]]
%4 = load <4 x float>, ptr addrspace(11) %arrayidx.i, align 1

ret <4 x float> %4
}

declare i32 @llvm.spv.flattened.thread.id.in.group()
declare target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(i32, i32, i32, i32, i1, ptr)
declare ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0), i32)

attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }

Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
; RUN: llc -verify-machineinstrs -O3 -mtriple=spirv1.6-unknown-vulkan1.3-compute %s -o - | FileCheck %s
; RUN: %if spirv-tools %{ llc -O3 -mtriple=spirv1.6-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val %}

%struct.S1 = type { <4 x i32>, [10 x <4 x float>], <4 x float> }
%struct.S2 = type { <4 x float>, <4 x i32> }

@.str = private unnamed_addr constant [3 x i8] c"In\00", align 1

define <4 x float> @main(i32 %index) {
entry:
%0 = tail call target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(i32 0, i32 1, i32 1, i32 0, i1 false, ptr nonnull @.str)
%3 = tail call noundef align 1 dereferenceable(192) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) %0, i32 0)

; CHECK-DAG: %[[#ulong:]] = OpTypeInt 64 0
; CHECK-DAG: %[[#ulong_1:]] = OpConstant %[[#ulong]] 1

; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
; CHECK-DAG: %[[#uint_10:]] = OpConstant %[[#uint]] 10
; CHECK-DAG: %[[#uint_16:]] = OpConstant %[[#uint]] 16

; CHECK-DAG: %[[#float:]] = OpTypeFloat 32
; CHECK-DAG: %[[#v4f:]] = OpTypeVector %[[#float]] 4
; CHECK-DAG: %[[#arr_v4f:]] = OpTypeArray %[[#v4f]] %[[#uint_10]]
; CHECK-DAG: %[[#S1:]] = OpTypeStruct %[[#]] %[[#arr_v4f]] %[[#]]
; CHECK-DAG: %[[#sb_S1:]] = OpTypePointer StorageBuffer %[[#S1]]
; CHECK-DAG: %[[#sb_arr_v4f:]] = OpTypePointer StorageBuffer %[[#arr_v4f]]
; CHECK-DAG: %[[#sb_v4f:]] = OpTypePointer StorageBuffer %[[#v4f]]

; CHECK: %[[#a:]] = OpAccessChain %[[#sb_S1]] %[[#]] %[[#uint_0]] %[[#uint_0]]
; CHECK: %[[#b:]] = OpInBoundsAccessChain %[[#sb_arr_v4f]] %[[#a]] %[[#ulong_1]]
%4 = getelementptr inbounds nuw i8, ptr addrspace(11) %3, i64 16

; CHECK: %[[#offset:]] = OpIMul %[[#]] %[[#]] %[[#uint_16]]
; Offset is computed in bytes. Make sure we reconvert it back to an index.
%offset = mul i32 %index, 16

; CHECK: %[[#index:]] = OpUDiv %[[#]] %[[#offset]] %[[#uint_16]]
; CHECK: %[[#c:]] = OpInBoundsAccessChain %[[#sb_v4f]] %[[#b]] %[[#index]]
%5 = getelementptr inbounds nuw i8, ptr addrspace(11) %4, i32 %offset

; CHECK: OpLoad %[[#v4f]] %[[#c]]
%6 = load <4 x float>, ptr addrspace(11) %5, align 1

ret <4 x float> %6
}

declare i32 @llvm.spv.flattened.thread.id.in.group()
declare target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(i32, i32, i32, i32, i1, ptr)
declare ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0), i32)

attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }


Loading
Loading