Skip to content

Commit 9a220c0

Browse files
committed
[SPIR-V] Fix some GEP legalization
Pointers and GEP are untyped. SPIR-V requires structured OpAccessChain. This means the backend will have to determine a good way to retrieve the structured access from an untyped GEP. This is not a trivial problem, and needs to be addressed to have a robust compiler. The issue is that other workstreams rely on the access chain deduction to work. So we have 2 options: - pause all dependent work until we have a good chain deduction. - submit this limited fix so we can work on both this and other features in parallel. The choice we want to make is #2: submitting this **knowing** this is not a **good** fix. It only increases the number of patterns we can work with, thus allowing others to continue working on other parts of the backend. This patch as-is has many limitations: - It cannot robustly determine the depth of the structured access from a GEP. Fixing this would require looking ahead at the full GEP chain. - It cannot always figure out the correct access indices, especially with dynamic indices. This will require frontend collaboration. Because we know this is a temporary hack, this patch only impacts the logical SPIR-V target. Physical SPIR-V, which can rely on pointer casts, remains on the old method. Related to #145002
1 parent 50f3a6b commit 9a220c0

File tree

5 files changed

+230
-6
lines changed

5 files changed

+230
-6
lines changed

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 150 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,6 +194,31 @@ class SPIRVEmitIntrinsics
194194

195195
void useRoundingMode(ConstrainedFPIntrinsic *FPI, IRBuilder<> &B);
196196

197+
// Tries to walk the type accessed by the given GEP instruction.
198+
// For each nested type access, one of the 2 callbacks is called:
199+
// - OnStaticIndexing when the index is a known constant value.
200+
// - OnDynamicIndexing when the index is a non-constant value.
201+
// Returns true if an error occurred during the walk, false otherwise.
202+
bool walkLogicalAccessChain(
203+
GetElementPtrInst &GEP,
204+
const std::function<void(Type *, uint64_t)> &OnStaticIndexing,
205+
const std::function<void(Type *, Value *)> &OnDynamicIndexing);
206+
207+
// Returns the type accessed using the given GEP instruction by relying
208+
// on the GEP type.
209+
// FIXME: GEP types are not supposed to be used to retrieve the pointed
210+
// type. This must be fixed.
211+
Type *getGEPType(GetElementPtrInst *GEP);
212+
213+
// Returns the type accessed using the given GEP instruction by walking
214+
// the source type using the GEP indices.
215+
// FIXME: without help from the frontend, this method cannot reliably retrieve
216+
// the stored type, nor can robustly determine the depth of the type
217+
// we are accessing.
218+
Type *getGEPTypeLogical(GetElementPtrInst *GEP);
219+
220+
Instruction *buildLogicalAccessChainFromGEP(GetElementPtrInst &GEP);
221+
197222
public:
198223
static char ID;
199224
SPIRVEmitIntrinsics(SPIRVTargetMachine *TM = nullptr)
@@ -246,6 +271,17 @@ bool expectIgnoredInIRTranslation(const Instruction *I) {
246271
}
247272
}
248273

274+
// Returns the source pointer from `I` ignoring intermediate ptrcast.
275+
Value *getPointerRoot(Value *I) {
276+
if (auto *II = dyn_cast<IntrinsicInst>(I)) {
277+
if (II->getIntrinsicID() == Intrinsic::spv_ptrcast) {
278+
Value *V = II->getArgOperand(0);
279+
return getPointerRoot(V);
280+
}
281+
}
282+
return I;
283+
}
284+
249285
} // namespace
250286

251287
char SPIRVEmitIntrinsics::ID = 0;
@@ -555,7 +591,97 @@ void SPIRVEmitIntrinsics::maybeAssignPtrType(Type *&Ty, Value *Op, Type *RefTy,
555591
Ty = RefTy;
556592
}
557593

558-
Type *getGEPType(GetElementPtrInst *Ref) {
594+
// Walks the aggregate type reached through `GEP`, turning each GEP index into
// one or more structured accesses.
//
// The walk starts from the element type deduced for the GEP pointer operand
// once any spv_ptrcast wrappers have been stripped. For every nested access,
// exactly one callback is invoked with the type reached by that access:
//  - OnStaticIndexing(EltType, Index) when the GEP index is a ConstantInt.
//    The constant is read as an unsigned byte offset and may expand into
//    several nested array/struct accesses until the offset is consumed.
//  - OnDynamicIndexing(EltType, V) when the GEP index is not a ConstantInt and
//    the current type is an array.
//
// Returns true if the walk had to give up, false otherwise. Callbacks that
// fired before the failure are not undone, so any partial result collected by
// the caller should be discarded when true is returned.
bool SPIRVEmitIntrinsics::walkLogicalAccessChain(
    GetElementPtrInst &GEP,
    const std::function<void(Type *, uint64_t)> &OnStaticIndexing,
    const std::function<void(Type *, Value *)> &OnDynamicIndexing) {
  const auto &DL = CurrF->getDataLayout();
  Value *Src = getPointerRoot(GEP.getPointerOperand());
  Type *CurType = deduceElementType(Src, true);

  for (Value *V : GEP.indices()) {
    if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
      // Note: the index is zero-extended, so a negative constant shows up here
      // as a very large offset.
      uint64_t Offset = CI->getZExtValue();

      do {
        if (ArrayType *AT = dyn_cast<ArrayType>(CurType)) {
          // Successive array elements are separated by the element's
          // allocation size, which includes alignment padding. The raw bit
          // size would give a stride that is too small for padded elements.
          const uint64_t EltStride = DL.getTypeAllocSize(AT->getElementType());
          // A zero-sized element cannot be indexed by byte offset; giving up
          // also avoids dividing by zero below.
          if (EltStride == 0)
            return true;
          assert(Offset < AT->getNumElements() * EltStride);
          const uint64_t Index = Offset / EltStride;
          Offset -= Index * EltStride;
          CurType = AT->getElementType();
          OnStaticIndexing(CurType, Index);
        } else if (StructType *ST = dyn_cast<StructType>(CurType)) {
          uint32_t StructSize = DL.getTypeSizeInBits(ST) / 8;
          assert(Offset < StructSize);
          const auto &STL = DL.getStructLayout(ST);
          // Pick the field covering the offset and keep the remainder for the
          // next nesting level.
          unsigned Element = STL->getElementContainingOffset(Offset);
          Offset -= STL->getElementOffset(Element);
          CurType = ST->getElementType(Element);
          OnStaticIndexing(CurType, Element);
        } else {
          // Vector type indexing should not use GEP.
          // So if we have an index left, something is wrong. Giving up.
          return true;
        }
      } while (Offset > 0);

    } else if (ArrayType *AT = dyn_cast<ArrayType>(CurType)) {
      // Index is not constant. Either we have an array and accept it, or we
      // give up.
      CurType = AT->getElementType();
      OnDynamicIndexing(CurType, V);
    } else
      return true;
  }

  return false;
}
640+
641+
// Replaces `GEP` with an spv_gep intrinsic call whose indices describe the
// structured access found by walkLogicalAccessChain.
//
// Returns the new intrinsic call. When the access chain walk fails, returns
// nullptr and leaves `GEP` untouched so the caller can fall back to another
// lowering instead of emitting a partial index list.
Instruction *
SPIRVEmitIntrinsics::buildLogicalAccessChainFromGEP(GetElementPtrInst &GEP) {
  IRBuilder<> B(GEP.getParent());

  SmallVector<Value *, 4> Indices;
  // Leading zero index; presumably the pointer-level index expected by
  // spv_gep -- TODO confirm.
  Indices.push_back(ConstantInt::get(
      IntegerType::getInt32Ty(CurrF->getContext()), 0, /* Signed= */ false));
  const bool WalkFailed = walkLogicalAccessChain(
      GEP,
      [&Indices, &B](Type *EltType, uint64_t Index) {
        Indices.push_back(
            ConstantInt::get(B.getInt64Ty(), Index, /* Signed= */ false));
      },
      [&Indices](Type *EltType, Value *Index) { Indices.push_back(Index); });

  // The walk reports failure by returning true. Nothing has been inserted or
  // erased at this point, so bailing out is safe.
  if (WalkFailed)
    return nullptr;

  B.SetInsertPoint(&GEP);
  SmallVector<Type *, 2> Types = {GEP.getType(), GEP.getOperand(0)->getType()};
  SmallVector<Value *, 4> Args;
  Args.push_back(B.getInt1(GEP.isInBounds()));
  Args.push_back(GEP.getOperand(0));
  llvm::append_range(Args, Indices);
  auto *NewI = B.CreateIntrinsic(Intrinsic::spv_gep, {Types}, {Args});
  replaceAllUsesWithAndErase(B, &GEP, NewI);
  return NewI;
}
666+
667+
// Returns the type accessed by `GEP`, found by walking the deduced source type
// with the GEP indices. The result is the type reported by the last access of
// the walk. If the walk fails, or reports no access at all, the GEP's own
// result element type is returned instead.
Type *SPIRVEmitIntrinsics::getGEPTypeLogical(GetElementPtrInst *GEP) {
  Type *const Fallback = GEP->getResultElementType();
  Type *Accessed = Fallback;

  // Both callbacks only remember the most recently reached type.
  auto OnStatic = [&Accessed](Type *EltType, uint64_t) { Accessed = EltType; };
  auto OnDynamic = [&Accessed](Type *EltType, Value *) { Accessed = EltType; };

  // walkLogicalAccessChain returns true when it had to give up.
  if (walkLogicalAccessChain(*GEP, OnStatic, OnDynamic))
    return Fallback;
  return Accessed;
}
677+
678+
Type *SPIRVEmitIntrinsics::getGEPType(GetElementPtrInst *Ref) {
679+
if (Ref->getSourceElementType() ==
680+
IntegerType::getInt8Ty(CurrF->getContext()) &&
681+
TM->getSubtargetImpl()->isLogicalSPIRV()) {
682+
return getGEPTypeLogical(Ref);
683+
}
684+
559685
Type *Ty = nullptr;
560686
// TODO: not sure if GetElementPtrInst::getTypeAtIndex() does anything
561687
// useful here
@@ -1395,6 +1521,13 @@ Instruction *SPIRVEmitIntrinsics::visitSwitchInst(SwitchInst &I) {
13951521
}
13961522

13971523
Instruction *SPIRVEmitIntrinsics::visitGetElementPtrInst(GetElementPtrInst &I) {
1524+
if (I.getSourceElementType() == IntegerType::getInt8Ty(CurrF->getContext()) &&
1525+
TM->getSubtargetImpl()->isLogicalSPIRV()) {
1526+
Instruction *Result = buildLogicalAccessChainFromGEP(I);
1527+
if (Result)
1528+
return Result;
1529+
}
1530+
13981531
IRBuilder<> B(I.getParent());
13991532
B.SetInsertPoint(&I);
14001533
SmallVector<Type *, 2> Types = {I.getType(), I.getOperand(0)->getType()};
@@ -1588,7 +1721,22 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
15881721
}
15891722
if (GetElementPtrInst *GEPI = dyn_cast<GetElementPtrInst>(I)) {
15901723
Value *Pointer = GEPI->getPointerOperand();
1591-
Type *OpTy = GEPI->getSourceElementType();
1724+
Type *OpTy = nullptr;
1725+
1726+
// Knowing the accessed type is mandatory for logical SPIR-V. Sadly,
1727+
// the GEP source element type should not be used for this purpose, and
1728+
// the alternative type-scavenging method is not working.
1729+
// Physical SPIR-V can work around this, but not logical, hence still
1730+
// try to rely on the broken type scavenging for logical.
1731+
if (TM->getSubtargetImpl()->isLogicalSPIRV()) {
1732+
Value *Src = getPointerRoot(Pointer);
1733+
OpTy = GR->findDeducedElementType(Src);
1734+
}
1735+
1736+
// In all cases, fall back to the GEP type if type scavenging failed.
1737+
if (!OpTy)
1738+
OpTy = GEPI->getSourceElementType();
1739+
15921740
replacePointerOperandWithPtrCast(I, Pointer, OpTy, 0, B);
15931741
if (isNestedPointer(OpTy))
15941742
insertTodoType(Pointer);

llvm/test/CodeGen/SPIRV/hlsl-resources/StructuredBuffer.ll

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv1.6-vulkan1.3-library %s -o - | FileCheck %s
1+
; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv1.6-vulkan1.3-library %s -o - | FileCheck %s
22
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv1.6-vulkan1.3-library %s -o - -filetype=obj | spirv-val %}
33

44
target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024-n8:16:32:64-G1"

llvm/test/CodeGen/SPIRV/llvm-intrinsics/lifetime.ll

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ define spir_func void @foo(ptr noundef byval(%tprange) align 8 %_arg_UserRange)
3333
%RoundedRangeKernel = alloca %tprange, align 8
3434
call void @llvm.lifetime.start.p0(i64 72, ptr nonnull %RoundedRangeKernel)
3535
call void @llvm.memcpy.p0.p0.i64(ptr align 8 %RoundedRangeKernel, ptr align 8 %_arg_UserRange, i64 16, i1 false)
36-
%KernelFunc = getelementptr inbounds i8, ptr %RoundedRangeKernel, i64 16
36+
%KernelFunc = getelementptr inbounds i8, ptr %RoundedRangeKernel, i64 8
3737
call void @llvm.lifetime.end.p0(i64 72, ptr nonnull %RoundedRangeKernel)
3838
ret void
3939
}
@@ -55,7 +55,7 @@ define spir_func void @bar(ptr noundef byval(%tprange) align 8 %_arg_UserRange)
5555
%RoundedRangeKernel = alloca %tprange, align 8
5656
call void @llvm.lifetime.start.p0(i64 -1, ptr nonnull %RoundedRangeKernel)
5757
call void @llvm.memcpy.p0.p0.i64(ptr align 8 %RoundedRangeKernel, ptr align 8 %_arg_UserRange, i64 16, i1 false)
58-
%KernelFunc = getelementptr inbounds i8, ptr %RoundedRangeKernel, i64 16
58+
%KernelFunc = getelementptr inbounds i8, ptr %RoundedRangeKernel, i64 8
5959
call void @llvm.lifetime.end.p0(i64 -1, ptr nonnull %RoundedRangeKernel)
6060
ret void
6161
}

llvm/test/CodeGen/SPIRV/logical-struct-access.ll

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
1-
; RUN: llc -O0 -mtriple=spirv-unknown-unknown %s -o - | FileCheck %s
1+
; RUN: llc -O0 -mtriple=spirv-unknown-vulkan1.3-compute %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val %}
23

34
; CHECK-DAG: [[uint:%[0-9]+]] = OpTypeInt 32 0
45

Lines changed: 75 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,75 @@
1+
; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 5
2+
; RUN: llc -verify-machineinstrs -O3 -mtriple=spirv1.6-unknown-vulkan1.3-compute %s -o - | FileCheck %s
3+
; RUN: %if spirv-tools %{ llc -O3 -mtriple=spirv1.6-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val %}
4+
5+
; struct S1 {
6+
; int4 i;
7+
; float4 f;
8+
; };
9+
; struct S2 {
10+
; float4 f;
11+
; int4 i;
12+
; };
13+
;
14+
; StructuredBuffer<S1> In : register(t1);
15+
; RWStructuredBuffer<S2> Out : register(u0);
16+
;
17+
; [numthreads(1,1,1)]
18+
; void main(uint GI : SV_GroupIndex) {
19+
; Out[GI].f = In[GI].f;
20+
; Out[GI].i = In[GI].i;
21+
; }
22+
23+
%struct.S1 = type { <4 x i32>, <4 x float> }
24+
%struct.S2 = type { <4 x float>, <4 x i32> }
25+
26+
@.str = private unnamed_addr constant [3 x i8] c"In\00", align 1
27+
@.str.2 = private unnamed_addr constant [4 x i8] c"Out\00", align 1
28+
29+
define void @main() local_unnamed_addr #0 {
30+
; CHECK-LABEL: main
31+
; CHECK: %43 = OpFunction %2 None %3 ; -- Begin function main
32+
; CHECK-NEXT: %1 = OpLabel
33+
; CHECK-NEXT: %44 = OpVariable %28 Function %38
34+
; CHECK-NEXT: %45 = OpVariable %27 Function %39
35+
; CHECK-NEXT: %46 = OpCopyObject %19 %40
36+
; CHECK-NEXT: %47 = OpCopyObject %16 %41
37+
; CHECK-NEXT: %48 = OpLoad %4 %42
38+
; CHECK-NEXT: %49 = OpAccessChain %13 %46 %29 %48
39+
; CHECK-NEXT: %50 = OpInBoundsAccessChain %9 %49 %31
40+
; CHECK-NEXT: %51 = OpLoad %8 %50 Aligned 1
41+
; CHECK-NEXT: %52 = OpAccessChain %11 %47 %29 %48
42+
; CHECK-NEXT: %53 = OpInBoundsAccessChain %9 %52 %29
43+
; CHECK-NEXT: OpStore %53 %51 Aligned 1
44+
; CHECK-NEXT: %54 = OpAccessChain %6 %49 %29
45+
; CHECK-NEXT: %55 = OpLoad %5 %54 Aligned 1
46+
; CHECK-NEXT: %56 = OpInBoundsAccessChain %6 %52 %31
47+
; CHECK-NEXT: OpStore %56 %55 Aligned 1
48+
; CHECK-NEXT: OpReturn
49+
; CHECK-NEXT: OpFunctionEnd
50+
entry:
51+
%0 = tail call target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(i32 0, i32 1, i32 1, i32 0, i1 false, ptr nonnull @.str)
52+
%1 = tail call target("spirv.VulkanBuffer", [0 x %struct.S2], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.S2s_12_1t(i32 0, i32 0, i32 1, i32 0, i1 false, ptr nonnull @.str.2)
53+
%2 = tail call i32 @llvm.spv.flattened.thread.id.in.group()
54+
%3 = tail call noundef align 1 dereferenceable(32) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) %0, i32 %2)
55+
%f.i = getelementptr inbounds nuw i8, ptr addrspace(11) %3, i64 16
56+
%4 = load <4 x float>, ptr addrspace(11) %f.i, align 1
57+
%5 = tail call noundef align 1 dereferenceable(32) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.S2s_12_1t(target("spirv.VulkanBuffer", [0 x %struct.S2], 12, 1) %1, i32 %2)
58+
store <4 x float> %4, ptr addrspace(11) %5, align 1
59+
%6 = load <4 x i32>, ptr addrspace(11) %3, align 1
60+
%i6.i = getelementptr inbounds nuw i8, ptr addrspace(11) %5, i64 16
61+
store <4 x i32> %6, ptr addrspace(11) %i6.i, align 1
62+
ret void
63+
}
64+
65+
declare i32 @llvm.spv.flattened.thread.id.in.group()
66+
67+
declare target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(i32, i32, i32, i32, i1, ptr)
68+
69+
declare target("spirv.VulkanBuffer", [0 x %struct.S2], 12, 1) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.S2s_12_1t(i32, i32, i32, i32, i1, ptr)
70+
71+
declare ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.S2s_12_1t(target("spirv.VulkanBuffer", [0 x %struct.S2], 12, 1), i32)
72+
73+
declare ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0), i32)
74+
75+
attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }

0 commit comments

Comments
 (0)