Skip to content

Commit 0debfb1

Browse files
author
Erich Keane
authored
[SYCL] Fix pathological case of visiting callees of a function. (#4065)
The markdevice rewrite improved the way we were checking recursive functions, however as an oversight didn't 'uniqify' each callee-check. This patch ensures we only visit each callee 1x, even if it is called multiple times. Note that this isn't a 'perfect' fix, we could skip any function we've ever 'seen' before in this kernel, however it results in some reduced diagnostic quality for recursive and attribute-collection issues. This at least reduces the 'pathological' cases that remain to just those that are also mostly pathological for templates in general (though we are still worse-off than template instantiations).
1 parent 7fc8aa0 commit 0debfb1

File tree

2 files changed

+60
-5
lines changed

2 files changed

+60
-5
lines changed

clang/lib/Sema/SemaSYCL.cpp

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -764,10 +764,8 @@ class SingleDeviceFunctionTracker {
764764
return FD->getMostRecentDecl();
765765
}
766766

767-
void VisitCallNode(CallGraphNode *Node,
767+
void VisitCallNode(CallGraphNode *Node, FunctionDecl *CurrentDecl,
768768
llvm::SmallVectorImpl<FunctionDecl *> &CallStack) {
769-
FunctionDecl *CurrentDecl = GetFDFromNode(Node);
770-
771769
// If this isn't a function, I don't think there is anything we can do here.
772770
if (!CurrentDecl)
773771
return;
@@ -842,8 +840,16 @@ class SingleDeviceFunctionTracker {
842840

843841
// Recurse.
844842
CallStack.push_back(CurrentDecl);
843+
llvm::SmallPtrSet<FunctionDecl *, 16> SeenCallees;
845844
for (CallGraphNode *CI : Node->callees()) {
846-
VisitCallNode(CI, CallStack);
845+
FunctionDecl *CurFD = GetFDFromNode(CI);
846+
847+
// Make sure we only visit each callee 1x from this function to avoid very
848+
// time consuming template recursion cases.
849+
if (!llvm::is_contained(SeenCallees, CurFD)) {
850+
VisitCallNode(CI, CurFD, CallStack);
851+
SeenCallees.insert(CurFD);
852+
}
847853
}
848854
CallStack.pop_back();
849855
}
@@ -852,7 +858,7 @@ class SingleDeviceFunctionTracker {
852858
void Init() {
853859
CallGraphNode *KernelNode = Parent.getNodeForKernel(SYCLKernel);
854860
llvm::SmallVector<FunctionDecl *> CallStack;
855-
VisitCallNode(KernelNode, CallStack);
861+
VisitCallNode(KernelNode, GetFDFromNode(KernelNode), CallStack);
856862
}
857863

858864
public:
Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
// RUN: %clang_cc1 -fsycl-is-device -internal-isystem %S/Inputs -sycl-std=2020 -verify -fsyntax-only -std=c++20 %s
2+
3+
// This test validates that this actually makes it through 'MarkDevice'. This
4+
// is a bit of a pathological case where we ended up visiting each call
5+
// individually. There is likely a similar test case that can cause us to hit
6+
// a pathological case in a very similar situation (where the callees aren't
7+
// exactly the same), but that likely causes problems with template
8+
// instantiations first.
9+
10+
// expected-no-diagnostics
11+
12+
#include "sycl.hpp"
13+
14+
template<bool B, typename V = void>
15+
struct enable_if { };
16+
template<typename V>
17+
struct enable_if<true, V> {
18+
using type = V;
19+
};
20+
template<bool B, typename V = void>
21+
using enable_if_t = typename enable_if<B, V>::type;
22+
23+
24+
template<int N, enable_if_t<N == 24, int> = 0>
25+
void mark_device_pathological_case() {
26+
// Do nothing.
27+
}
28+
29+
template<int N, enable_if_t<N < 24, int> = 0>
30+
void mark_device_pathological_case() {
31+
// We were visiting each of these, which caused 9^24 visits.
32+
mark_device_pathological_case<N + 1>();
33+
mark_device_pathological_case<N + 1>();
34+
mark_device_pathological_case<N + 1>();
35+
mark_device_pathological_case<N + 1>();
36+
mark_device_pathological_case<N + 1>();
37+
mark_device_pathological_case<N + 1>();
38+
mark_device_pathological_case<N + 1>();
39+
mark_device_pathological_case<N + 1>();
40+
mark_device_pathological_case<N + 1>();
41+
mark_device_pathological_case<N + 1>();
42+
}
43+
44+
int main() {
45+
sycl::queue q;
46+
q.submit([](sycl::handler &h) {
47+
h.single_task<class kernel>([]() { mark_device_pathological_case<0>(); });
48+
});
49+
}

0 commit comments

Comments
 (0)