44
55#pragma once
66
7- #include " runtime.hpp"
87#include < optional>
98
9+ #include " runtime.hpp"
10+
1011namespace CUDA {
1112
1213class GraphCapture ;
1314class CaptureInfo ;
1415
15- class Graph : public Handle <cudaGraph_t> {
16+ class Graph : public Handle <cudaGraph_t> {
1617public:
1718 Graph (unsigned int flags);
1819
@@ -23,14 +24,14 @@ class Graph: public Handle<cudaGraph_t> {
2324private:
2425 Graph (cudaGraph_t graph);
2526
26- static cudaError_t createFromNative (cudaGraph_t * pGraph, cudaGraph_t anotherGraph);
27+ static cudaError_t createFromNative (cudaGraph_t* pGraph, cudaGraph_t anotherGraph);
2728
2829 static cudaGraph_t createNativeWithFlags (unsigned int flags);
2930};
3031
3132bool operator ==(const Graph& rhs, const Graph& lhs);
3233
33- class GraphExec : public Handle <cudaGraphExec_t> {
34+ class GraphExec : public Handle <cudaGraphExec_t> {
3435public:
3536 GraphExec (const Graph& g);
3637
@@ -73,16 +74,18 @@ class GraphCapture {
7374
7475private:
7576 Stream stream_;
76- cudaGraph_t cudaGraph_ {};
77- cudaError_t capturedError_ {cudaSuccess};
78- std::optional<Graph> graph_ {};
77+ cudaGraph_t cudaGraph_{};
78+ cudaError_t capturedError_{cudaSuccess};
79+ std::optional<Graph> graph_{};
7980};
8081
8182class UploadNode {
8283 friend CaptureInfo;
84+
8385public:
8486 void update_src (const GraphExec& exec, const void * src);
8587 bool operator ==(const UploadNode& rhs) const ;
88+
8689private:
8790 UploadNode (cudaGraphNode_t node, CUDA::DevicePointer<void *> dst, const void * src, std::size_t size);
8891 cudaGraphNode_t node_;
@@ -93,9 +96,11 @@ class UploadNode {
9396
9497class DownloadNode {
9598 friend CaptureInfo;
99+
96100public:
97101 void update_dst (const GraphExec& exec, void * dst);
98102 bool operator ==(const DownloadNode& rhs) const ;
103+
99104private:
100105 DownloadNode (cudaGraphNode_t node, void * dst, CUDA::DevicePointer<const void *> src, std::size_t size);
101106 cudaGraphNode_t node_;
@@ -118,4 +123,4 @@ class CaptureInfo {
118123 size_t depCount_;
119124};
120125
121- }// namespace CUDA
126+ } // namespace CUDA
0 commit comments