Skip to content

Commit 58e2467

Browse files
committed
pr-feedback: simplify code and add more testing.
Fixes the index inconsistencies
1 parent 4f47d9c commit 58e2467

File tree

3 files changed

+168
-40
lines changed

3 files changed

+168
-40
lines changed

llvm/lib/Target/SPIRV/SPIRVEmitIntrinsics.cpp

Lines changed: 68 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -196,13 +196,24 @@ class SPIRVEmitIntrinsics
196196

197197
// Tries to walk the type accessed by the given GEP instruction.
198198
// For each nested type access, one of the 2 callbacks is called:
199-
// - OnStaticIndex when the index is a known constant value.
199+
// - OnLiteralIndexing when the index is a known constant value.
200+
// Parameters:
201+
// PointedType: the pointed type resulting of this indexing.
202+
// If the parent type is an array, this is the index in the array.
203+
// If the parent type is a struct, this is the field index.
204+
// Index: index of the element in the parent type.
200205
// - OnDynamnicIndexing when the index is a non-constant value.
206+
// This callback is only called when indexing into an array.
207+
// Parameters:
208+
// ElementType: the type of the elements stored in the parent array.
209+
// Offset: the Value* containing the byte offset into the array.
201210
// Return true if an error occured during the walk, false otherwise.
202211
bool walkLogicalAccessChain(
203212
GetElementPtrInst &GEP,
204-
const std::function<void(Type *, uint64_t)> &OnStaticIndexing,
205-
const std::function<void(Type *, Value *)> &OnDynamicIndexing);
213+
const std::function<void(Type *PointedType, uint64_t Index)>
214+
&OnLiteralIndexing,
215+
const std::function<void(Type *ElementType, Value *Offset)>
216+
&OnDynamicIndexing);
206217

207218
// Returns the type accessed using the given GEP instruction by relying
208219
// on the GEP type.
@@ -593,54 +604,64 @@ void SPIRVEmitIntrinsics::maybeAssignPtrType(Type *&Ty, Value *Op, Type *RefTy,
593604

594605
bool SPIRVEmitIntrinsics::walkLogicalAccessChain(
595606
GetElementPtrInst &GEP,
596-
const std::function<void(Type *, uint64_t)> &OnStaticIndexing,
607+
const std::function<void(Type *, uint64_t)> &OnLiteralIndexing,
597608
const std::function<void(Type *, Value *)> &OnDynamicIndexing) {
609+
// We only rewrite i8* GEP. Other should be left as-is.
610+
// Observation so-far is i8* GEP always have a single index. Making sure
611+
// that's the case.
612+
assert(GEP.getSourceElementType() ==
613+
IntegerType::getInt8Ty(CurrF->getContext()));
614+
assert(GEP.getNumIndices() == 1);
615+
598616
auto &DL = CurrF->getDataLayout();
599617
Value *Src = getPointerRoot(GEP.getPointerOperand());
600618
Type *CurType = deduceElementType(Src, true);
601619

602-
for (Value *V : GEP.indices()) {
603-
if (ConstantInt *CI = dyn_cast<ConstantInt>(V)) {
604-
uint64_t Offset = CI->getZExtValue();
605-
606-
do {
607-
if (ArrayType *AT = dyn_cast<ArrayType>(CurType)) {
608-
uint32_t EltTypeSize = DL.getTypeSizeInBits(AT->getElementType()) / 8;
609-
assert(Offset < AT->getNumElements() * EltTypeSize);
610-
uint64_t Index = Offset / EltTypeSize;
611-
Offset = Offset - (Index * EltTypeSize);
612-
CurType = AT->getElementType();
613-
OnStaticIndexing(CurType, Index);
614-
} else if (StructType *ST = dyn_cast<StructType>(CurType)) {
615-
uint32_t StructSize = DL.getTypeSizeInBits(ST) / 8;
616-
assert(Offset < StructSize);
617-
const auto &STL = DL.getStructLayout(ST);
618-
unsigned Element = STL->getElementContainingOffset(Offset);
619-
Offset -= STL->getElementOffset(Element);
620-
CurType = ST->getElementType(Element);
621-
OnStaticIndexing(CurType, Element);
622-
} else {
623-
// Vector type indexing should not use GEP.
624-
// So if we have an index left, something is wrong. Giving up.
625-
return true;
626-
}
627-
} while (Offset > 0);
628-
629-
} else if (ArrayType *AT = dyn_cast<ArrayType>(CurType)) {
630-
// Index is not constant. Either we have an array and accept it, or we
631-
// give up.
620+
Value *Operand = *GEP.idx_begin();
621+
ConstantInt *CI = dyn_cast<ConstantInt>(Operand);
622+
if (!CI) {
623+
ArrayType *AT = dyn_cast<ArrayType>(CurType);
624+
// Operand is not constant. Either we have an array and accept it, or we
625+
// give up.
626+
if (AT)
627+
OnDynamicIndexing(AT->getElementType(), Operand);
628+
return AT == nullptr;
629+
}
630+
631+
assert(CI);
632+
uint64_t Offset = CI->getZExtValue();
633+
634+
do {
635+
if (ArrayType *AT = dyn_cast<ArrayType>(CurType)) {
636+
uint32_t EltTypeSize = DL.getTypeSizeInBits(AT->getElementType()) / 8;
637+
assert(Offset < AT->getNumElements() * EltTypeSize);
638+
uint64_t Index = Offset / EltTypeSize;
639+
Offset = Offset - (Index * EltTypeSize);
632640
CurType = AT->getElementType();
633-
OnDynamicIndexing(CurType, V);
634-
} else
641+
OnLiteralIndexing(CurType, Index);
642+
} else if (StructType *ST = dyn_cast<StructType>(CurType)) {
643+
uint32_t StructSize = DL.getTypeSizeInBits(ST) / 8;
644+
assert(Offset < StructSize);
645+
const auto &STL = DL.getStructLayout(ST);
646+
unsigned Element = STL->getElementContainingOffset(Offset);
647+
Offset -= STL->getElementOffset(Element);
648+
CurType = ST->getElementType(Element);
649+
OnLiteralIndexing(CurType, Element);
650+
} else {
651+
// Vector type indexing should not use GEP.
652+
// So if we have an index left, something is wrong. Giving up.
635653
return true;
636-
}
654+
}
655+
} while (Offset > 0);
637656

638657
return false;
639658
}
640659

641660
Instruction *
642661
SPIRVEmitIntrinsics::buildLogicalAccessChainFromGEP(GetElementPtrInst &GEP) {
662+
auto &DL = CurrF->getDataLayout();
643663
IRBuilder<> B(GEP.getParent());
664+
B.SetInsertPoint(&GEP);
644665

645666
std::vector<Value *> Indices;
646667
Indices.push_back(ConstantInt::get(
@@ -651,9 +672,14 @@ SPIRVEmitIntrinsics::buildLogicalAccessChainFromGEP(GetElementPtrInst &GEP) {
651672
Indices.push_back(
652673
ConstantInt::get(B.getInt64Ty(), Index, /* Signed= */ false));
653674
},
654-
[&Indices](Type *EltType, Value *Index) { Indices.push_back(Index); });
675+
[&Indices, &B, &DL](Type *EltType, Value *Offset) {
676+
uint32_t EltTypeSize = DL.getTypeSizeInBits(EltType) / 8;
677+
Value *Index = B.CreateUDiv(
678+
Offset, ConstantInt::get(Offset->getType(), EltTypeSize,
679+
/* Signed= */ false));
680+
Indices.push_back(Index);
681+
});
655682

656-
B.SetInsertPoint(&GEP);
657683
SmallVector<Type *, 2> Types = {GEP.getType(), GEP.getOperand(0)->getType()};
658684
SmallVector<Value *, 4> Args;
659685
Args.push_back(B.getInt1(GEP.isInBounds()));
@@ -1728,7 +1754,9 @@ void SPIRVEmitIntrinsics::insertPtrCastOrAssignTypeInstr(Instruction *I,
17281754
// the alternative type-scavenging method is not working.
17291755
// Physical SPIR-V can work around this, but not logical, hence still
17301756
// try to rely on the broken type scavenging for logical.
1731-
if (TM->getSubtargetImpl()->isLogicalSPIRV()) {
1757+
bool IsRewrittenGEP =
1758+
GEPI->getSourceElementType() == IntegerType::getInt8Ty(I->getContext());
1759+
if (IsRewrittenGEP && TM->getSubtargetImpl()->isLogicalSPIRV()) {
17321760
Value *Src = getPointerRoot(Pointer);
17331761
OpTy = GR->findDeducedElementType(Src);
17341762
}
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
; RUN: llc -verify-machineinstrs -O3 -mtriple=spirv1.6-unknown-vulkan1.3-compute %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O3 -mtriple=spirv1.6-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val %}
3+
4+
%struct.S1 = type { <4 x i32>, [10 x <4 x float>], <4 x float> }
5+
%struct.S2 = type { <4 x float>, <4 x i32> }
6+
7+
@.str = private unnamed_addr constant [3 x i8] c"In\00", align 1
8+
9+
define <4 x float> @main() {
10+
entry:
11+
%0 = tail call target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(i32 0, i32 1, i32 1, i32 0, i1 false, ptr nonnull @.str)
12+
%3 = tail call noundef align 1 dereferenceable(192) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) %0, i32 0)
13+
14+
; CHECK-DAG: %[[#ulong:]] = OpTypeInt 64 0
15+
; CHECK-DAG: %[[#ulong_1:]] = OpConstant %[[#ulong]] 1
16+
; CHECK-DAG: %[[#ulong_3:]] = OpConstant %[[#ulong]] 3
17+
18+
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
19+
; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
20+
; CHECK-DAG: %[[#uint_10:]] = OpConstant %[[#uint]] 10
21+
22+
; CHECK-DAG: %[[#float:]] = OpTypeFloat 32
23+
; CHECK-DAG: %[[#v4f:]] = OpTypeVector %[[#float]] 4
24+
; CHECK-DAG: %[[#arr_v4f:]] = OpTypeArray %[[#v4f]] %[[#uint_10]]
25+
; CHECK-DAG: %[[#S1:]] = OpTypeStruct %[[#]] %[[#arr_v4f]] %[[#]]
26+
; CHECK-DAG: %[[#sb_S1:]] = OpTypePointer StorageBuffer %[[#S1]]
27+
; CHECK-DAG: %[[#sb_v4f:]] = OpTypePointer StorageBuffer %[[#v4f]]
28+
29+
; CHECK: %[[#tmp:]] = OpAccessChain %[[#sb_S1]] %[[#]] %[[#uint_0]] %[[#uint_0]]
30+
; CHECK: %[[#ptr:]] = OpInBoundsAccessChain %[[#sb_v4f]] %[[#tmp]] %[[#ulong_1]] %[[#ulong_3]]
31+
; This rewritten GEP combined all constant indices into a single value.
32+
; We should make sure the correct indices are retrieved.
33+
%arrayidx.i = getelementptr inbounds nuw i8, ptr addrspace(11) %3, i64 64
34+
35+
; CHECK: OpLoad %[[#v4f]] %[[#ptr]]
36+
%4 = load <4 x float>, ptr addrspace(11) %arrayidx.i, align 1
37+
38+
ret <4 x float> %4
39+
}
40+
41+
declare i32 @llvm.spv.flattened.thread.id.in.group()
42+
declare target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(i32, i32, i32, i32, i1, ptr)
43+
declare ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0), i32)
44+
45+
attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
46+
Lines changed: 54 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,54 @@
1+
; RUN: llc -verify-machineinstrs -O3 -mtriple=spirv1.6-unknown-vulkan1.3-compute %s -o - | FileCheck %s
2+
; RUN: %if spirv-tools %{ llc -O3 -mtriple=spirv1.6-unknown-vulkan1.3-compute %s -o - -filetype=obj | spirv-val %}
3+
4+
%struct.S1 = type { <4 x i32>, [10 x <4 x float>], <4 x float> }
5+
%struct.S2 = type { <4 x float>, <4 x i32> }
6+
7+
@.str = private unnamed_addr constant [3 x i8] c"In\00", align 1
8+
9+
define <4 x float> @main(i32 %index) {
10+
entry:
11+
%0 = tail call target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(i32 0, i32 1, i32 1, i32 0, i1 false, ptr nonnull @.str)
12+
%3 = tail call noundef align 1 dereferenceable(192) ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) %0, i32 0)
13+
14+
; CHECK-DAG: %[[#ulong:]] = OpTypeInt 64 0
15+
; CHECK-DAG: %[[#ulong_1:]] = OpConstant %[[#ulong]] 1
16+
17+
; CHECK-DAG: %[[#uint:]] = OpTypeInt 32 0
18+
; CHECK-DAG: %[[#uint_0:]] = OpConstant %[[#uint]] 0
19+
; CHECK-DAG: %[[#uint_10:]] = OpConstant %[[#uint]] 10
20+
; CHECK-DAG: %[[#uint_16:]] = OpConstant %[[#uint]] 16
21+
22+
; CHECK-DAG: %[[#float:]] = OpTypeFloat 32
23+
; CHECK-DAG: %[[#v4f:]] = OpTypeVector %[[#float]] 4
24+
; CHECK-DAG: %[[#arr_v4f:]] = OpTypeArray %[[#v4f]] %[[#uint_10]]
25+
; CHECK-DAG: %[[#S1:]] = OpTypeStruct %[[#]] %[[#arr_v4f]] %[[#]]
26+
; CHECK-DAG: %[[#sb_S1:]] = OpTypePointer StorageBuffer %[[#S1]]
27+
; CHECK-DAG: %[[#sb_arr_v4f:]] = OpTypePointer StorageBuffer %[[#arr_v4f]]
28+
; CHECK-DAG: %[[#sb_v4f:]] = OpTypePointer StorageBuffer %[[#v4f]]
29+
30+
; CHECK: %[[#a:]] = OpAccessChain %[[#sb_S1]] %[[#]] %[[#uint_0]] %[[#uint_0]]
31+
; CHECK: %[[#b:]] = OpInBoundsAccessChain %[[#sb_arr_v4f]] %[[#a]] %[[#ulong_1]]
32+
%4 = getelementptr inbounds nuw i8, ptr addrspace(11) %3, i64 16
33+
34+
; CHECK: %[[#offset:]] = OpIMul %[[#]] %[[#]] %[[#uint_16]]
35+
; Offset is computed in bytes. Make sure we reconvert it back to an index.
36+
%offset = mul i32 %index, 16
37+
38+
; CHECK: %[[#index:]] = OpUDiv %[[#]] %[[#offset]] %[[#uint_16]]
39+
; CHECK: %[[#c:]] = OpInBoundsAccessChain %[[#sb_v4f]] %[[#b]] %[[#index]]
40+
%5 = getelementptr inbounds nuw i8, ptr addrspace(11) %4, i32 %offset
41+
42+
; CHECK: OpLoad %[[#v4f]] %[[#c]]
43+
%6 = load <4 x float>, ptr addrspace(11) %5, align 1
44+
45+
ret <4 x float> %6
46+
}
47+
48+
declare i32 @llvm.spv.flattened.thread.id.in.group()
49+
declare target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0) @llvm.spv.resource.handlefrombinding.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(i32, i32, i32, i32, i1, ptr)
50+
declare ptr addrspace(11) @llvm.spv.resource.getpointer.p11.tspirv.VulkanBuffer_a0s_struct.S1s_12_0t(target("spirv.VulkanBuffer", [0 x %struct.S1], 12, 0), i32)
51+
52+
attributes #0 = { "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" }
53+
54+

0 commit comments

Comments
 (0)