Skip to content

Commit c2eaf0f

Browse files
[NFCI][SYCL][Graph] Refactor graph_impl::add (#19351)
Part of the refactoring to eliminate `std::weak_ptr<node_impl>` and reduce usage of `std::shared_ptr<node_impl>` by preferring raw pointers/references. Previous PRs in the series: #19295, #19332, #19334, #19350.

* Accept `Deps` as `nodes_range` in `graph_impl::add`.
* Return `node_impl &` from `graph_impl::add`.
* Add `node` support in `nodes_range` and use that together with the modified `graph_impl::add` when creating new `node_impl`s based on `std::vector<node> Deps`, to avoid creating temporary `DepImpls` storage.
* Also update `registerSuccessor`/`registerPredecessor` and `addEventForNode`/`addDepsToNode` to accept a raw `node_impl &`, as the changes above resulted in having raw references at the call sites.
1 parent 00ae3dd commit c2eaf0f

File tree

4 files changed

+75
-86
lines changed

4 files changed

+75
-86
lines changed

sycl/source/detail/graph/graph_impl.cpp

Lines changed: 33 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -409,22 +409,19 @@ void graph_impl::markCGMemObjs(
409409
}
410410
}
411411

412-
std::shared_ptr<node_impl> graph_impl::add(nodes_range Deps) {
413-
const std::shared_ptr<node_impl> &NodeImpl = std::make_shared<node_impl>();
414-
415-
MNodeStorage.push_back(NodeImpl);
412+
node_impl &graph_impl::add(nodes_range Deps) {
413+
node_impl &NodeImpl = createNode();
416414

417415
addDepsToNode(NodeImpl, Deps);
418416
// Add an event associated with this explicit node for mixed usage
419417
addEventForNode(sycl::detail::event_impl::create_completed_host_event(),
420-
*NodeImpl);
418+
NodeImpl);
421419
return NodeImpl;
422420
}
423421

424-
std::shared_ptr<node_impl>
425-
graph_impl::add(std::function<void(handler &)> CGF,
426-
const std::vector<sycl::detail::ArgDesc> &Args,
427-
std::vector<std::shared_ptr<node_impl>> &Deps) {
422+
node_impl &graph_impl::add(std::function<void(handler &)> CGF,
423+
const std::vector<sycl::detail::ArgDesc> &Args,
424+
nodes_range Deps) {
428425
(void)Args;
429426
#ifdef __INTEL_PREVIEW_BREAKING_CHANGES
430427
detail::handler_impl HandlerImpl{*this};
@@ -435,7 +432,8 @@ graph_impl::add(std::function<void(handler &)> CGF,
435432

436433
// Pass the node deps to the handler so they are available when processing the
437434
// CGF, need for async_malloc nodes.
438-
Handler.impl->MNodeDeps = Deps;
435+
for (node_impl &N : Deps)
436+
Handler.impl->MNodeDeps.push_back(N.shared_from_this());
439437

440438
#if XPTI_ENABLE_INSTRUMENTATION
441439
// Save code location if one was set in TLS.
@@ -471,12 +469,12 @@ graph_impl::add(std::function<void(handler &)> CGF,
471469
: ext::oneapi::experimental::detail::getNodeTypeFromCG(
472470
Handler.getType());
473471

474-
auto NodeImpl =
472+
node_impl &NodeImpl =
475473
this->add(NodeType, std::move(Handler.impl->MGraphNodeCG), Deps);
476474

477475
// Add an event associated with this explicit node for mixed usage
478476
addEventForNode(sycl::detail::event_impl::create_completed_host_event(),
479-
*NodeImpl);
477+
NodeImpl);
480478

481479
// Retrieve any dynamic parameters which have been registered in the CGF and
482480
// register the actual nodes with them.
@@ -489,44 +487,40 @@ graph_impl::add(std::function<void(handler &)> CGF,
489487
}
490488

491489
for (auto &[DynamicParam, ArgIndex] : DynamicParams) {
492-
DynamicParam->registerNode(NodeImpl, ArgIndex);
490+
DynamicParam->registerNode(NodeImpl.shared_from_this(), ArgIndex);
493491
}
494492

495493
return NodeImpl;
496494
}
497495

498-
std::shared_ptr<node_impl>
499-
graph_impl::add(node_type NodeType,
500-
std::shared_ptr<sycl::detail::CG> CommandGroup,
501-
nodes_range Deps) {
496+
node_impl &graph_impl::add(node_type NodeType,
497+
std::shared_ptr<sycl::detail::CG> CommandGroup,
498+
nodes_range Deps) {
502499

503500
// A unique set of dependencies obtained by checking requirements and events
504501
std::set<node_impl *> UniqueDeps = getCGEdges(CommandGroup);
505502

506503
// Track and mark the memory objects being used by the graph.
507504
markCGMemObjs(CommandGroup);
508505

509-
const std::shared_ptr<node_impl> &NodeImpl =
510-
std::make_shared<node_impl>(NodeType, std::move(CommandGroup));
511-
MNodeStorage.push_back(NodeImpl);
506+
node_impl &NodeImpl = createNode(NodeType, std::move(CommandGroup));
512507

513508
// Add any deps determined from requirements and events into the dependency
514509
// list
515510
addDepsToNode(NodeImpl, Deps);
516511
addDepsToNode(NodeImpl, UniqueDeps);
517512

518513
if (NodeType == node_type::async_free) {
519-
auto AsyncFreeCG =
520-
static_cast<CGAsyncFree *>(NodeImpl->MCommandGroup.get());
514+
auto AsyncFreeCG = static_cast<CGAsyncFree *>(NodeImpl.MCommandGroup.get());
521515
// If this is an async free node mark that it is now available for reuse,
522516
// and pass the async free node for tracking.
523-
MGraphMemPool.markAllocationAsAvailable(AsyncFreeCG->getPtr(), *NodeImpl);
517+
MGraphMemPool.markAllocationAsAvailable(AsyncFreeCG->getPtr(), NodeImpl);
524518
}
525519

526520
return NodeImpl;
527521
}
528522

529-
std::shared_ptr<node_impl>
523+
node_impl &
530524
graph_impl::add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
531525
nodes_range Deps) {
532526
// Set of Dependent nodes based on CG event and accessor dependencies.
@@ -550,15 +544,14 @@ graph_impl::add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
550544
const auto &ActiveKernel = DynCGImpl->getActiveCG();
551545
node_type NodeType =
552546
ext::oneapi::experimental::detail::getNodeTypeFromCG(DynCGImpl->MCGType);
553-
std::shared_ptr<detail::node_impl> NodeImpl =
554-
add(NodeType, ActiveKernel, Deps);
547+
detail::node_impl &NodeImpl = add(NodeType, ActiveKernel, Deps);
555548

556549
// Add an event associated with this explicit node for mixed usage
557550
addEventForNode(sycl::detail::event_impl::create_completed_host_event(),
558-
*NodeImpl);
551+
NodeImpl);
559552

560553
// Track the dynamic command-group used inside the node object
561-
DynCGImpl->MNodes.push_back(NodeImpl);
554+
DynCGImpl->MNodes.push_back(NodeImpl.shared_from_this());
562555

563556
return NodeImpl;
564557
}
@@ -651,7 +644,7 @@ void graph_impl::makeEdge(std::shared_ptr<node_impl> Src,
651644
bool DestWasGraphRoot = Dest->MPredecessors.size() == 0;
652645

653646
// We need to add the edges first before checking for cycles
654-
Src->registerSuccessor(Dest);
647+
Src->registerSuccessor(*Dest);
655648

656649
bool DestLostRootStatus = DestWasGraphRoot && Dest->MPredecessors.size() == 1;
657650
if (DestLostRootStatus) {
@@ -1264,7 +1257,7 @@ void exec_graph_impl::duplicateNodes() {
12641257
// Look through all the original node successors, find their copies and
12651258
// register those as successors with the current copied node
12661259
for (node_impl &NextNode : OriginalNode->successors()) {
1267-
auto Successor = NodesMap.at(NextNode.shared_from_this());
1260+
node_impl &Successor = *NodesMap.at(NextNode.shared_from_this());
12681261
NodeCopy->registerSuccessor(Successor);
12691262
}
12701263
}
@@ -1306,7 +1299,8 @@ void exec_graph_impl::duplicateNodes() {
13061299
auto NodeCopy = NewSubgraphNodes[i];
13071300

13081301
for (node_impl &NextNode : SubgraphNode->successors()) {
1309-
auto Successor = SubgraphNodesMap.at(NextNode.shared_from_this());
1302+
node_impl &Successor =
1303+
*SubgraphNodesMap.at(NextNode.shared_from_this());
13101304
NodeCopy->registerSuccessor(Successor);
13111305
}
13121306
}
@@ -1340,7 +1334,7 @@ void exec_graph_impl::duplicateNodes() {
13401334
// Add all input nodes from the subgraph as successors for this node
13411335
// instead
13421336
for (auto &Input : Inputs) {
1343-
PredNode.registerSuccessor(Input);
1337+
PredNode.registerSuccessor(*Input);
13441338
}
13451339
}
13461340

@@ -1359,7 +1353,7 @@ void exec_graph_impl::duplicateNodes() {
13591353
// Add all Output nodes from the subgraph as predecessors for this node
13601354
// instead
13611355
for (auto &Output : Outputs) {
1362-
Output->registerSuccessor(SuccNode.shared_from_this());
1356+
Output->registerSuccessor(SuccNode);
13631357
}
13641358
}
13651359

@@ -1840,38 +1834,25 @@ node modifiable_command_graph::addImpl(dynamic_command_group &DynCGF,
18401834
"dynamic command-group.");
18411835
}
18421836

1843-
std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
1844-
for (auto &D : Deps) {
1845-
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
1846-
}
1847-
18481837
graph_impl::WriteLock Lock(impl->MMutex);
1849-
std::shared_ptr<detail::node_impl> NodeImpl = impl->add(DynCGFImpl, DepImpls);
1850-
return sycl::detail::createSyclObjFromImpl<node>(std::move(NodeImpl));
1838+
detail::node_impl &NodeImpl = impl->add(DynCGFImpl, Deps);
1839+
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
18511840
}
18521841

18531842
node modifiable_command_graph::addImpl(const std::vector<node> &Deps) {
18541843
impl->throwIfGraphRecordingQueue("Explicit API \"Add()\" function");
1855-
std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
1856-
for (auto &D : Deps) {
1857-
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
1858-
}
18591844

18601845
graph_impl::WriteLock Lock(impl->MMutex);
1861-
std::shared_ptr<detail::node_impl> NodeImpl = impl->add(DepImpls);
1862-
return sycl::detail::createSyclObjFromImpl<node>(std::move(NodeImpl));
1846+
detail::node_impl &NodeImpl = impl->add(Deps);
1847+
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
18631848
}
18641849

18651850
node modifiable_command_graph::addImpl(std::function<void(handler &)> CGF,
18661851
const std::vector<node> &Deps) {
18671852
impl->throwIfGraphRecordingQueue("Explicit API \"Add()\" function");
1868-
std::vector<std::shared_ptr<detail::node_impl>> DepImpls;
1869-
for (auto &D : Deps) {
1870-
DepImpls.push_back(sycl::detail::getSyclObjImpl(D));
1871-
}
18721853

1873-
std::shared_ptr<detail::node_impl> NodeImpl = impl->add(CGF, {}, DepImpls);
1874-
return sycl::detail::createSyclObjFromImpl<node>(std::move(NodeImpl));
1854+
detail::node_impl &NodeImpl = impl->add(CGF, {}, Deps);
1855+
return sycl::detail::createSyclObjFromImpl<node>(NodeImpl);
18751856
}
18761857

18771858
void modifiable_command_graph::addGraphLeafDependencies(node Node) {

sycl/source/detail/graph/graph_impl.hpp

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -147,30 +147,30 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
147147
/// @param CommandGroup The CG which stores all information for this node.
148148
/// @param Deps Dependencies of the created node.
149149
/// @return Created node in the graph.
150-
std::shared_ptr<node_impl> add(node_type NodeType,
151-
std::shared_ptr<sycl::detail::CG> CommandGroup,
152-
nodes_range Deps);
150+
node_impl &add(node_type NodeType,
151+
std::shared_ptr<sycl::detail::CG> CommandGroup,
152+
nodes_range Deps);
153153

154154
/// Create a CGF node in the graph.
155155
/// @param CGF Command-group function to create node with.
156156
/// @param Args Node arguments.
157157
/// @param Deps Dependencies of the created node.
158158
/// @return Created node in the graph.
159-
std::shared_ptr<node_impl> add(std::function<void(handler &)> CGF,
160-
const std::vector<sycl::detail::ArgDesc> &Args,
161-
std::vector<std::shared_ptr<node_impl>> &Deps);
159+
node_impl &add(std::function<void(handler &)> CGF,
160+
const std::vector<sycl::detail::ArgDesc> &Args,
161+
nodes_range Deps);
162162

163163
/// Create an empty node in the graph.
164164
/// @param Deps List of predecessor nodes.
165165
/// @return Created node in the graph.
166-
std::shared_ptr<node_impl> add(nodes_range Deps);
166+
node_impl &add(nodes_range Deps);
167167

168168
/// Create a dynamic command-group node in the graph.
169169
/// @param DynCGImpl Dynamic command-group used to create node.
170170
/// @param Deps List of predecessor nodes.
171171
/// @return Created node in the graph.
172-
std::shared_ptr<node_impl>
173-
add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl, nodes_range Deps);
172+
node_impl &add(std::shared_ptr<dynamic_command_group_impl> &DynCGImpl,
173+
nodes_range Deps);
174174

175175
/// Add a queue to the set of queues which are currently recording to this
176176
/// graph.
@@ -511,6 +511,12 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
511511
}
512512

513513
private:
514+
template <typename... Ts> node_impl &createNode(Ts &&...Args) {
515+
MNodeStorage.push_back(
516+
std::make_shared<node_impl>(std::forward<Ts>(Args)...));
517+
return *MNodeStorage.back();
518+
}
519+
514520
/// Check the graph for cycles by performing a depth-first search of the
515521
/// graph. If a node is visited more than once in a given path through the
516522
/// graph, a cycle is present and the search ends immediately.
@@ -525,13 +531,13 @@ class graph_impl : public std::enable_shared_from_this<graph_impl> {
525531
/// added as a root node.
526532
/// @param Node The node to add deps for
527533
/// @param Deps List of dependent nodes
528-
void addDepsToNode(const std::shared_ptr<node_impl> &Node, nodes_range Deps) {
534+
void addDepsToNode(node_impl &Node, nodes_range Deps) {
529535
for (node_impl &N : Deps) {
530536
N.registerSuccessor(Node);
531-
this->removeRoot(*Node);
537+
this->removeRoot(Node);
532538
}
533-
if (Node->MPredecessors.empty()) {
534-
this->addRoot(*Node);
539+
if (Node.MPredecessors.empty()) {
540+
this->addRoot(Node);
535541
}
536542
}
537543

sycl/source/detail/graph/node_impl.hpp

Lines changed: 15 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
#include <sycl/detail/cg_types.hpp> // for CGType
1616
#include <sycl/detail/kernel_desc.hpp> // for kernel_param_kind_t
1717

18+
#include <sycl/ext/oneapi/experimental/graph/node.hpp> // for node
19+
1820
#include <cstring>
1921
#include <fstream>
2022
#include <iomanip>
@@ -27,8 +29,6 @@ inline namespace _V1 {
2729
namespace ext {
2830
namespace oneapi {
2931
namespace experimental {
30-
// Forward declarations
31-
class node;
3232

3333
namespace detail {
3434
// Forward declarations
@@ -122,27 +122,27 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
122122

123123
/// Add successor to the node.
124124
/// @param Node Node to add as a successor.
125-
void registerSuccessor(const std::shared_ptr<node_impl> &Node) {
125+
void registerSuccessor(node_impl &Node) {
126126
if (std::find_if(MSuccessors.begin(), MSuccessors.end(),
127-
[Node](const std::weak_ptr<node_impl> &Ptr) {
128-
return Ptr.lock() == Node;
127+
[&Node](const std::weak_ptr<node_impl> &Ptr) {
128+
return Ptr.lock().get() == &Node;
129129
}) != MSuccessors.end()) {
130130
return;
131131
}
132-
MSuccessors.push_back(Node);
133-
Node->registerPredecessor(shared_from_this());
132+
MSuccessors.push_back(Node.weak_from_this());
133+
Node.registerPredecessor(*this);
134134
}
135135

136136
/// Add predecessor to the node.
137137
/// @param Node Node to add as a predecessor.
138-
void registerPredecessor(const std::shared_ptr<node_impl> &Node) {
138+
void registerPredecessor(node_impl &Node) {
139139
if (std::find_if(MPredecessors.begin(), MPredecessors.end(),
140140
[&Node](const std::weak_ptr<node_impl> &Ptr) {
141-
return Ptr.lock() == Node;
141+
return Ptr.lock().get() == &Node;
142142
}) != MPredecessors.end()) {
143143
return;
144144
}
145-
MPredecessors.push_back(Node);
145+
MPredecessors.push_back(Node.weak_from_this());
146146
}
147147

148148
/// Construct an empty node.
@@ -764,12 +764,14 @@ class node_impl : public std::enable_shared_from_this<node_impl> {
764764

765765
struct nodes_deref_impl {
766766
template <typename T> static node_impl &dereference(T &Elem) {
767-
if constexpr (std::is_same_v<std::decay_t<decltype(Elem)>,
768-
std::weak_ptr<node_impl>>) {
767+
using Ty = std::decay_t<decltype(Elem)>;
768+
if constexpr (std::is_same_v<Ty, std::weak_ptr<node_impl>>) {
769769
// This assumes that weak_ptr doesn't actually manage lifetime and
770770
// the object is guaranteed to be alive (which seems to be the
771771
// assumption across all graph code).
772772
return *Elem.lock();
773+
} else if constexpr (std::is_same_v<Ty, node>) {
774+
return *getSyclObjImpl(Elem);
773775
} else {
774776
return *Elem;
775777
}
@@ -791,7 +793,7 @@ using nodes_iterator = nodes_iterator_impl<
791793
//
792794
std::set<std::shared_ptr<node_impl>>, std::set<node_impl *>,
793795
//
794-
std::list<node_impl *>>;
796+
std::list<node_impl *>, std::vector<node>>;
795797

796798
class nodes_range : public iterator_range<nodes_iterator> {
797799
private:

0 commit comments

Comments
 (0)