Skip to content

Commit ab6210b

Browse files
[NFC][SYCL][Graph] Store raw node_impl * in MPredecessors/MSuccessors (#19438)
Continuation of #19295 and #19332.
1 parent 7e1ea22 commit ab6210b

File tree

8 files changed

+195
-211
lines changed

8 files changed

+195
-211
lines changed

sycl/source/detail/graph/graph_impl.cpp

Lines changed: 6 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1321,11 +1321,9 @@ void exec_graph_impl::duplicateNodes() {
13211321
auto &Successors = PredNode.MSuccessors;
13221322

13231323
// Remove the subgraph node from this nodes successors
1324-
Successors.erase(std::remove_if(Successors.begin(), Successors.end(),
1325-
[NewNode](auto WeakNode) {
1326-
return WeakNode.lock() == NewNode;
1327-
}),
1328-
Successors.end());
1324+
Successors.erase(
1325+
std::remove(Successors.begin(), Successors.end(), NewNode.get()),
1326+
Successors.end());
13291327

13301328
// Add all input nodes from the subgraph as successors for this node
13311329
// instead
@@ -1339,12 +1337,9 @@ void exec_graph_impl::duplicateNodes() {
13391337
auto &Predecessors = SuccNode.MPredecessors;
13401338

13411339
// Remove the subgraph node from this nodes successors
1342-
Predecessors.erase(std::remove_if(Predecessors.begin(),
1343-
Predecessors.end(),
1344-
[NewNode](auto WeakNode) {
1345-
return WeakNode.lock() == NewNode;
1346-
}),
1347-
Predecessors.end());
1340+
Predecessors.erase(
1341+
std::remove(Predecessors.begin(), Predecessors.end(), NewNode.get()),
1342+
Predecessors.end());
13481343

13491344
// Add all Output nodes from the subgraph as predecessors for this node
13501345
// instead

sycl/source/detail/graph/node_impl.hpp

Lines changed: 13 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -86,11 +86,11 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
8686
/// Unique identifier for this node.
8787
id_type MID = getNextNodeID();
8888
/// List of successors to this node.
89-
std::vector<std::weak_ptr<node_impl>> MSuccessors;
89+
std::vector<node_impl *> MSuccessors;
9090
/// List of predecessors to this node.
9191
///
9292
/// Using weak_ptr here to prevent circular references between nodes.
93-
std::vector<std::weak_ptr<node_impl>> MPredecessors;
93+
std::vector<node_impl *> MPredecessors;
9494
/// Type of the command-group for the node.
9595
sycl::detail::CGType MCGType = sycl::detail::CGType::None;
9696
/// User facing type of the node.
@@ -116,26 +116,22 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
116116
/// Add successor to the node.
117117
/// @param Node Node to add as a successor.
118118
void registerSuccessor(node_impl &Node) {
119-
if (std::find_if(MSuccessors.begin(), MSuccessors.end(),
120-
[&Node](const std::weak_ptr<node_impl> &Ptr) {
121-
return Ptr.lock().get() == &Node;
122-
}) != MSuccessors.end()) {
119+
if (std::find(MSuccessors.begin(), MSuccessors.end(), &Node) !=
120+
MSuccessors.end()) {
123121
return;
124122
}
125-
MSuccessors.push_back(Node.weak_from_this());
123+
MSuccessors.push_back(&Node);
126124
Node.registerPredecessor(*this);
127125
}
128126

129127
/// Add predecessor to the node.
130128
/// @param Node Node to add as a predecessor.
131129
void registerPredecessor(node_impl &Node) {
132-
if (std::find_if(MPredecessors.begin(), MPredecessors.end(),
133-
[&Node](const std::weak_ptr<node_impl> &Ptr) {
134-
return Ptr.lock().get() == &Node;
135-
}) != MPredecessors.end()) {
130+
if (std::find(MPredecessors.begin(), MPredecessors.end(), &Node) !=
131+
MPredecessors.end()) {
136132
return;
137133
}
138-
MPredecessors.push_back(Node.weak_from_this());
134+
MPredecessors.push_back(&Node);
139135
}
140136

141137
/// Construct an empty node.
@@ -386,15 +382,13 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
386382
Visited.push_back(this);
387383

388384
printDotCG(Stream, Verbose);
389-
for (const auto &Dep : MPredecessors) {
390-
auto NodeDep = Dep.lock();
391-
Stream << " \"" << NodeDep.get() << "\" -> \"" << this << "\""
392-
<< std::endl;
385+
for (node_impl *Pred : MPredecessors) {
386+
Stream << " \"" << Pred << "\" -> \"" << this << "\"" << std::endl;
393387
}
394388

395-
for (std::weak_ptr<node_impl> Succ : MSuccessors) {
396-
if (MPartitionNum == Succ.lock()->MPartitionNum)
397-
Succ.lock()->printDotRecursive(Stream, Visited, Verbose);
389+
for (node_impl *Succ : MSuccessors) {
390+
if (MPartitionNum == Succ->MPartitionNum)
391+
Succ->printDotRecursive(Stream, Visited, Verbose);
398392
}
399393
}
400394

0 commit comments

Comments
 (0)