Skip to content

Commit 35e29df

Browse files
[NFC][SYCL][Graph] Add successors/predecessors views + cleanup
Part of refactoring to get rid of most (all?) `std::weak_ptr<node_impl>` and some of `std::shared_ptr<node_impl>` started in intel#19295. Use `nodes_range` from that PR to implement `successors`/`predecessors` views and update read-only accesses to the successors/predecessors to go through them. I'm not changing the data members `MSuccessors`/`MPredecessors` yet because it would affect unittests. I'd prefer to refactor most of the code in future PRs before making that change and updating unittests in one go. I'm updating some APIs to accept `node_impl &` instead of `std::shared_ptr` where the change is mostly localized to the callers iterating over successors/predecessors and doesn't spoil into other code too much. For those that weren't updated here we (temporarily) use `shared_from_this()` but I expect to eliminate those unnecessary copies when those interfaces will be updated in the subsequent PRs.
1 parent 322cd13 commit 35e29df

File tree

5 files changed

+79
-80
lines changed

5 files changed

+79
-80
lines changed

sycl/source/detail/graph/graph_impl.cpp

Lines changed: 50 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -100,17 +100,16 @@ void sortTopological(std::set<std::weak_ptr<node_impl>,
100100
Source.pop();
101101
SortedNodes.push_back(Node);
102102

103-
for (auto &SuccWP : Node->MSuccessors) {
104-
auto Succ = SuccWP.lock();
103+
for (node_impl &Succ : Node->successors()) {
105104

106-
if (PartitionBounded && (Succ->MPartitionNum != Node->MPartitionNum)) {
105+
if (PartitionBounded && (Succ.MPartitionNum != Node->MPartitionNum)) {
107106
continue;
108107
}
109108

110-
auto &TotalVisitedEdges = Succ->MTotalVisitedEdges;
109+
auto &TotalVisitedEdges = Succ.MTotalVisitedEdges;
111110
++TotalVisitedEdges;
112-
if (TotalVisitedEdges == Succ->MPredecessors.size()) {
113-
Source.push(Succ);
111+
if (TotalVisitedEdges == Succ.MPredecessors.size()) {
112+
Source.push(Succ.weak_from_this());
114113
}
115114
}
116115
}
@@ -127,14 +126,14 @@ void sortTopological(std::set<std::weak_ptr<node_impl>,
127126
/// a node with a smaller partition number.
128127
/// @param Node Node to assign to the partition.
129128
/// @param PartitionNum Number to propagate.
130-
void propagatePartitionUp(std::shared_ptr<node_impl> Node, int PartitionNum) {
131-
if (((Node->MPartitionNum != -1) && (Node->MPartitionNum <= PartitionNum)) ||
132-
(Node->MCGType == sycl::detail::CGType::CodeplayHostTask)) {
129+
void propagatePartitionUp(node_impl &Node, int PartitionNum) {
130+
if (((Node.MPartitionNum != -1) && (Node.MPartitionNum <= PartitionNum)) ||
131+
(Node.MCGType == sycl::detail::CGType::CodeplayHostTask)) {
133132
return;
134133
}
135-
Node->MPartitionNum = PartitionNum;
136-
for (auto &Predecessor : Node->MPredecessors) {
137-
propagatePartitionUp(Predecessor.lock(), PartitionNum);
134+
Node.MPartitionNum = PartitionNum;
135+
for (node_impl &Predecessor : Node.predecessors()) {
136+
propagatePartitionUp(Predecessor, PartitionNum);
138137
}
139138
}
140139

@@ -146,17 +145,18 @@ void propagatePartitionUp(std::shared_ptr<node_impl> Node, int PartitionNum) {
146145
/// @param HostTaskList List of host tasks that have already been processed and
147146
/// are encountered as successors to the node Node.
148147
void propagatePartitionDown(
149-
const std::shared_ptr<node_impl> &Node, int PartitionNum,
148+
node_impl &Node, int PartitionNum,
150149
std::list<std::shared_ptr<node_impl>> &HostTaskList) {
151-
if (Node->MCGType == sycl::detail::CGType::CodeplayHostTask) {
152-
if (Node->MPartitionNum != -1) {
153-
HostTaskList.push_front(Node);
150+
if (Node.MCGType == sycl::detail::CGType::CodeplayHostTask) {
151+
if (Node.MPartitionNum != -1) {
152+
HostTaskList.push_front(Node.shared_from_this());
154153
}
155154
return;
156155
}
157-
Node->MPartitionNum = PartitionNum;
158-
for (auto &Successor : Node->MSuccessors) {
159-
propagatePartitionDown(Successor.lock(), PartitionNum, HostTaskList);
156+
Node.MPartitionNum = PartitionNum;
157+
for (node_impl &Successor : Node.successors()) {
158+
propagatePartitionDown(Successor, PartitionNum,
159+
HostTaskList);
160160
}
161161
}
162162

@@ -165,8 +165,8 @@ void propagatePartitionDown(
165165
/// @param Node node to test
166166
/// @return True is `Node` is a root of its partition
167167
bool isPartitionRoot(std::shared_ptr<node_impl> Node) {
168-
for (auto &Predecessor : Node->MPredecessors) {
169-
if (Predecessor.lock()->MPartitionNum == Node->MPartitionNum) {
168+
for (node_impl &Predecessor : Node->predecessors()) {
169+
if (Predecessor.MPartitionNum == Node->MPartitionNum) {
170170
return false;
171171
}
172172
}
@@ -221,15 +221,15 @@ void exec_graph_impl::makePartitions() {
221221
auto Node = HostTaskList.front();
222222
HostTaskList.pop_front();
223223
CurrentPartition++;
224-
for (auto &Predecessor : Node->MPredecessors) {
225-
propagatePartitionUp(Predecessor.lock(), CurrentPartition);
224+
for (node_impl &Predecessor : Node->predecessors()) {
225+
propagatePartitionUp(Predecessor, CurrentPartition);
226226
}
227227
CurrentPartition++;
228228
Node->MPartitionNum = CurrentPartition;
229229
CurrentPartition++;
230230
auto TmpSize = HostTaskList.size();
231-
for (auto &Successor : Node->MSuccessors) {
232-
propagatePartitionDown(Successor.lock(), CurrentPartition, HostTaskList);
231+
for (node_impl &Successor : Node->successors()) {
232+
propagatePartitionDown(Successor, CurrentPartition, HostTaskList);
233233
}
234234
if (HostTaskList.size() > TmpSize) {
235235
// At least one HostTask has been re-numbered so group merge opportunities
@@ -290,9 +290,9 @@ void exec_graph_impl::makePartitions() {
290290
for (const auto &Partition : MPartitions) {
291291
for (auto const &Root : Partition->MRoots) {
292292
auto RootNode = Root.lock();
293-
for (const auto &Dep : RootNode->MPredecessors) {
294-
auto NodeDep = Dep.lock();
295-
auto &Predecessor = MPartitions[MPartitionNodes[NodeDep]];
293+
for (node_impl &NodeDep : RootNode->predecessors()) {
294+
auto &Predecessor =
295+
MPartitions[MPartitionNodes[NodeDep.shared_from_this()]];
296296
Partition->MPredecessors.push_back(Predecessor.get());
297297
Predecessor->MSuccessors.push_back(Partition.get());
298298
}
@@ -424,8 +424,8 @@ std::set<std::shared_ptr<node_impl>> graph_impl::getCGEdges(
424424
bool ShouldAddDep = true;
425425
// If any of this node's successors have this requirement then we skip
426426
// adding the current node as a dependency.
427-
for (auto &Succ : Node->MSuccessors) {
428-
if (Succ.lock()->hasRequirementDependency(Req)) {
427+
for (node_impl &Succ : Node->successors()) {
428+
if (Succ.hasRequirementDependency(Req)) {
429429
ShouldAddDep = false;
430430
break;
431431
}
@@ -774,17 +774,17 @@ void graph_impl::beginRecording(sycl::detail::queue_impl &Queue) {
774774
// predecessors until we find the real dependency.
775775
void exec_graph_impl::findRealDeps(
776776
std::vector<ur_exp_command_buffer_sync_point_t> &Deps,
777-
std::shared_ptr<node_impl> CurrentNode, int ReferencePartitionNum) {
778-
if (!CurrentNode->requiresEnqueue()) {
779-
for (auto &N : CurrentNode->MPredecessors) {
780-
auto NodeImpl = N.lock();
777+
node_impl &CurrentNode, int ReferencePartitionNum) {
778+
if (!CurrentNode.requiresEnqueue()) {
779+
for (node_impl &NodeImpl : CurrentNode.predecessors()) {
781780
findRealDeps(Deps, NodeImpl, ReferencePartitionNum);
782781
}
783782
} else {
783+
auto CurrentNodePtr = CurrentNode.shared_from_this();
784784
// Verify if CurrentNode belong the the same partition
785-
if (MPartitionNodes[CurrentNode] == ReferencePartitionNum) {
785+
if (MPartitionNodes[CurrentNodePtr] == ReferencePartitionNum) {
786786
// Verify that the sync point has actually been set for this node.
787-
auto SyncPoint = MSyncPoints.find(CurrentNode);
787+
auto SyncPoint = MSyncPoints.find(CurrentNodePtr);
788788
assert(SyncPoint != MSyncPoints.end() &&
789789
"No sync point has been set for node dependency.");
790790
// Check if the dependency has already been added.
@@ -802,8 +802,8 @@ exec_graph_impl::enqueueNodeDirect(const sycl::context &Ctx,
802802
ur_exp_command_buffer_handle_t CommandBuffer,
803803
std::shared_ptr<node_impl> Node) {
804804
std::vector<ur_exp_command_buffer_sync_point_t> Deps;
805-
for (auto &N : Node->MPredecessors) {
806-
findRealDeps(Deps, N.lock(), MPartitionNodes[Node]);
805+
for (node_impl &N : Node->predecessors()) {
806+
findRealDeps(Deps, N, MPartitionNodes[Node]);
807807
}
808808
ur_exp_command_buffer_sync_point_t NewSyncPoint;
809809
ur_exp_command_buffer_command_handle_t NewCommand = 0;
@@ -858,8 +858,8 @@ exec_graph_impl::enqueueNode(ur_exp_command_buffer_handle_t CommandBuffer,
858858
std::shared_ptr<node_impl> Node) {
859859

860860
std::vector<ur_exp_command_buffer_sync_point_t> Deps;
861-
for (auto &N : Node->MPredecessors) {
862-
findRealDeps(Deps, N.lock(), MPartitionNodes[Node]);
861+
for (node_impl &N : Node->predecessors()) {
862+
findRealDeps(Deps, N, MPartitionNodes[Node]);
863863
}
864864

865865
sycl::detail::EventImplPtr Event =
@@ -1328,8 +1328,8 @@ void exec_graph_impl::duplicateNodes() {
13281328
auto NodeCopy = NewNodes[i];
13291329
// Look through all the original node successors, find their copies and
13301330
// register those as successors with the current copied node
1331-
for (auto &NextNode : OriginalNode->MSuccessors) {
1332-
auto Successor = NodesMap.at(NextNode.lock());
1331+
for (node_impl &NextNode : OriginalNode->successors()) {
1332+
auto Successor = NodesMap.at(NextNode.shared_from_this());
13331333
NodeCopy->registerSuccessor(Successor);
13341334
}
13351335
}
@@ -1370,8 +1370,8 @@ void exec_graph_impl::duplicateNodes() {
13701370
auto SubgraphNode = SubgraphNodes[i];
13711371
auto NodeCopy = NewSubgraphNodes[i];
13721372

1373-
for (auto &NextNode : SubgraphNode->MSuccessors) {
1374-
auto Successor = SubgraphNodesMap.at(NextNode.lock());
1373+
for (node_impl &NextNode : SubgraphNode->successors()) {
1374+
auto Successor = SubgraphNodesMap.at(NextNode.shared_from_this());
13751375
NodeCopy->registerSuccessor(Successor);
13761376
}
13771377
}
@@ -1392,9 +1392,8 @@ void exec_graph_impl::duplicateNodes() {
13921392
// original subgraph node
13931393

13941394
// Predecessors
1395-
for (auto &PredNodeWeak : NewNode->MPredecessors) {
1396-
auto PredNode = PredNodeWeak.lock();
1397-
auto &Successors = PredNode->MSuccessors;
1395+
for (node_impl &PredNode : NewNode->predecessors()) {
1396+
auto &Successors = PredNode.MSuccessors;
13981397

13991398
// Remove the subgraph node from this nodes successors
14001399
Successors.erase(std::remove_if(Successors.begin(), Successors.end(),
@@ -1406,14 +1405,13 @@ void exec_graph_impl::duplicateNodes() {
14061405
// Add all input nodes from the subgraph as successors for this node
14071406
// instead
14081407
for (auto &Input : Inputs) {
1409-
PredNode->registerSuccessor(Input);
1408+
PredNode.registerSuccessor(Input);
14101409
}
14111410
}
14121411

14131412
// Successors
1414-
for (auto &SuccNodeWeak : NewNode->MSuccessors) {
1415-
auto SuccNode = SuccNodeWeak.lock();
1416-
auto &Predecessors = SuccNode->MPredecessors;
1413+
for (node_impl &SuccNode : NewNode->successors()) {
1414+
auto &Predecessors = SuccNode.MPredecessors;
14171415

14181416
// Remove the subgraph node from this nodes successors
14191417
Predecessors.erase(std::remove_if(Predecessors.begin(),
@@ -1426,7 +1424,7 @@ void exec_graph_impl::duplicateNodes() {
14261424
// Add all Output nodes from the subgraph as predecessors for this node
14271425
// instead
14281426
for (auto &Output : Outputs) {
1429-
Output->registerSuccessor(SuccNode);
1427+
Output->registerSuccessor(SuccNode.shared_from_this());
14301428
}
14311429
}
14321430

sycl/source/detail/graph/graph_impl.hpp

Lines changed: 7 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -352,19 +352,17 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
352352
/// @param NodeA pointer to the first node for comparison
353353
/// @param NodeB pointer to the second node for comparison
354354
/// @return true is same structure found, false otherwise
355-
static bool checkNodeRecursive(const std::shared_ptr<node_impl> &NodeA,
356-
const std::shared_ptr<node_impl> &NodeB) {
355+
static bool checkNodeRecursive(node_impl &NodeA, node_impl &NodeB) {
357356
size_t FoundCnt = 0;
358-
for (std::weak_ptr<node_impl> &SuccA : NodeA->MSuccessors) {
359-
for (std::weak_ptr<node_impl> &SuccB : NodeB->MSuccessors) {
360-
if (NodeA->isSimilar(*NodeB) &&
361-
checkNodeRecursive(SuccA.lock(), SuccB.lock())) {
357+
for (node_impl &SuccA : NodeA.successors()) {
358+
for (node_impl &SuccB : NodeB.successors()) {
359+
if (NodeA.isSimilar(NodeB) && checkNodeRecursive(SuccA, SuccB)) {
362360
FoundCnt++;
363361
break;
364362
}
365363
}
366364
}
367-
if (FoundCnt != NodeA->MSuccessors.size()) {
365+
if (FoundCnt != NodeA.MSuccessors.size()) {
368366
return false;
369367
}
370368

@@ -434,7 +432,7 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
434432
auto NodeBLocked = NodeB.lock();
435433

436434
if (NodeALocked->isSimilar(*NodeBLocked)) {
437-
if (checkNodeRecursive(NodeALocked, NodeBLocked)) {
435+
if (checkNodeRecursive(*NodeALocked, *NodeBLocked)) {
438436
RootsFound++;
439437
break;
440438
}
@@ -829,8 +827,7 @@ class exec_graph_impl {
829827
/// SyncPoint for CurrentNode, otherwise we need to
830828
/// synchronize on the host with the completion of previous partitions.
831829
void findRealDeps(std::vector<ur_exp_command_buffer_sync_point_t> &Deps,
832-
std::shared_ptr<node_impl> CurrentNode,
833-
int ReferencePartitionNum);
830+
node_impl &CurrentNode, int ReferencePartitionNum);
834831

835832
/// Duplicate nodes from the modifiable graph associated with this executable
836833
/// graph and store them locally. Any subgraph nodes in the modifiable graph

sycl/source/detail/graph/memory_pool.cpp

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -116,49 +116,44 @@ graph_mem_pool::tryReuseExistingAllocation(
116116
// free nodes. We do this in a breadth-first approach because we want to find
117117
// the shortest path to a reusable allocation.
118118

119-
std::queue<std::weak_ptr<node_impl>> NodesToCheck;
119+
std::queue<node_impl *> NodesToCheck;
120120

121121
// Add all the dependent nodes to the queue, they will be popped first
122122
for (auto &Dep : DepNodes) {
123-
NodesToCheck.push(Dep);
123+
NodesToCheck.push(&*Dep);
124124
}
125125

126126
// Called when traversing over nodes to check if the current node is a free
127127
// node for one of the available allocations. If it is we populate AllocInfo
128128
// with the allocation to be reused.
129129
auto CheckNodeEqual =
130-
[&CompatibleAllocs](const std::shared_ptr<node_impl> &CurrentNode)
131-
-> std::optional<alloc_info> {
130+
[&CompatibleAllocs](node_impl &CurrentNode) -> std::optional<alloc_info> {
132131
for (auto &Alloc : CompatibleAllocs) {
133-
const auto &AllocFreeNode = Alloc.LastFreeNode;
134-
// Compare control blocks without having to lock AllocFreeNode to check
135-
// for node equality
136-
if (!CurrentNode.owner_before(AllocFreeNode) &&
137-
!AllocFreeNode.owner_before(CurrentNode)) {
132+
if (&CurrentNode == Alloc.LastFreeNode) {
138133
return Alloc;
139134
}
140135
}
141136
return std::nullopt;
142137
};
143138

144139
while (!NodesToCheck.empty()) {
145-
auto CurrentNode = NodesToCheck.front().lock();
140+
node_impl &CurrentNode = *NodesToCheck.front();
146141

147-
if (CurrentNode->MTotalVisitedEdges > 0) {
142+
if (CurrentNode.MTotalVisitedEdges > 0) {
148143
continue;
149144
}
150145

151146
// Check if the node is a free node and, if so, check if it is a free node
152147
// for any of the allocations which are free for reuse. We should not bother
153148
// checking nodes that are not free nodes, so we continue and check their
154149
// predecessors.
155-
if (CurrentNode->MNodeType == node_type::async_free) {
150+
if (CurrentNode.MNodeType == node_type::async_free) {
156151
std::optional<alloc_info> AllocFound = CheckNodeEqual(CurrentNode);
157152
if (AllocFound) {
158153
// Reset visited nodes tracking
159154
MGraph.resetNodeVisitedEdges();
160155
// Reset last free node for allocation
161-
MAllocations.at(AllocFound.value().Ptr).LastFreeNode.reset();
156+
MAllocations.at(AllocFound.value().Ptr).LastFreeNode = nullptr;
162157
// Remove found allocation from the free list
163158
MFreeAllocations.erase(std::find(MFreeAllocations.begin(),
164159
MFreeAllocations.end(),
@@ -168,12 +163,12 @@ graph_mem_pool::tryReuseExistingAllocation(
168163
}
169164

170165
// Add CurrentNode predecessors to queue
171-
for (auto &Pred : CurrentNode->MPredecessors) {
172-
NodesToCheck.push(Pred);
166+
for (node_impl &Pred : CurrentNode.predecessors()) {
167+
NodesToCheck.push(&Pred);
173168
}
174169

175170
// Mark node as visited
176-
CurrentNode->MTotalVisitedEdges = 1;
171+
CurrentNode.MTotalVisitedEdges = 1;
177172
NodesToCheck.pop();
178173
}
179174

@@ -183,7 +178,7 @@ graph_mem_pool::tryReuseExistingAllocation(
183178
void graph_mem_pool::markAllocationAsAvailable(
184179
void *Ptr, const std::shared_ptr<node_impl> &FreeNode) {
185180
MFreeAllocations.push_back(Ptr);
186-
MAllocations.at(Ptr).LastFreeNode = FreeNode;
181+
MAllocations.at(Ptr).LastFreeNode = FreeNode.get();
187182
}
188183

189184
} // namespace detail

sycl/source/detail/graph/memory_pool.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ class graph_mem_pool {
4444
// Should the allocation be zero initialized during initial allocation
4545
bool ZeroInit = false;
4646
// Last free node for this allocation in the graph
47-
std::weak_ptr<node_impl> LastFreeNode = {};
47+
node_impl *LastFreeNode = nullptr;
4848
};
4949

5050
public:

sycl/source/detail/graph/node_impl.hpp

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ class node;
3131
namespace detail {
3232
// Forward declarations
3333
class node_impl;
34+
class nodes_range;
3435
class exec_graph_impl;
3536

3637
/// Takes a vector of weak_ptrs to node_impls and returns a vector of node
@@ -116,6 +117,10 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
116117
/// cannot be used to find out the partion of a node outside of this process.
117118
int MPartitionNum = -1;
118119

120+
// Out-of-class as need "complete" `nodes_range`:
121+
inline nodes_range successors() const;
122+
inline nodes_range predecessors() const;
123+
119124
/// Add successor to the node.
120125
/// @param Node Node to add as a successor.
121126
void registerSuccessor(const std::shared_ptr<node_impl> &Node) {
@@ -830,6 +835,10 @@ class nodes_range {
830835
size_t size() const { return Size; }
831836
bool empty() const { return Size == 0; }
832837
};
838+
839+
inline nodes_range node_impl::successors() const { return MSuccessors; }
840+
inline nodes_range node_impl::predecessors() const { return MPredecessors; }
841+
833842
} // namespace detail
834843
} // namespace experimental
835844
} // namespace oneapi

0 commit comments

Comments
 (0)