@@ -37,7 +37,12 @@ struct MyStackClass : torch::CustomClassHolder {
3737};
3838// END class
3939
40- #ifdef NO_PICKLE
40+ // BEGIN free_function
41+ c10::intrusive_ptr<MyStackClass<std::string>> manipulate_instance (const c10::intrusive_ptr<MyStackClass<std::string>>& instance) {
42+ instance->pop ();
43+ return instance;
44+ }
45+ // END free_function
4146
4247// BEGIN binding
4348// Notice a few things:
@@ -52,94 +57,76 @@ struct MyStackClass : torch::CustomClassHolder {
5257// Python and C++ as `torch.classes.my_classes.MyStackClass`. We call
5358// the first argument the "namespace" and the second argument the
5459// actual class name.
55- static auto testStack =
56- torch:: class_<MyStackClass<std::string>>(" my_classes " , " MyStackClass" )
57- // The following line registers the contructor of our MyStackClass
58- // class that takes a single `std::vector<std::string>` argument,
59- // i.e. it exposes the C++ method `MyStackClass(std::vector<T> init)`.
60- // Currently, we do not support registering overloaded
61- // constructors, so for now you can only `def()` one instance of
62- // `torch::init`.
63- .def(torch::init<std::vector<std::string>>())
64- // The next line registers a stateless (i.e. no captures) C++ lambda
65- // function as a method. Note that a lambda function must take a
66- // `c10::intrusive_ptr<YourClass>` (or some const/ref version of that)
67- // as the first argument. Other arguments can be whatever you want.
68- .def(" top" , [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
69- return self->stack_ .back ();
70- })
71- // The following four lines expose methods of the MyStackClass<std::string>
72- // class as-is. `torch::class_` will automatically examine the
73- // argument and return types of the passed-in method pointers and
74- // expose these to Python and TorchScript accordingly. Finally, notice
75- // that we must take the *address* of the fully-qualified method name,
76- // i.e. use the unary `&` operator, due to C++ typing rules.
77- .def(" push" , &MyStackClass<std::string>::push)
78- .def(" pop" , &MyStackClass<std::string>::pop)
79- .def(" clone" , &MyStackClass<std::string>::clone)
80- .def(" merge" , &MyStackClass<std::string>::merge);
60+ TORCH_LIBRARY (my_classes, m) {
61+ m. class_ <MyStackClass<std::string>>(" MyStackClass" )
62+ // The following line registers the contructor of our MyStackClass
63+ // class that takes a single `std::vector<std::string>` argument,
64+ // i.e. it exposes the C++ method `MyStackClass(std::vector<T> init)`.
65+ // Currently, we do not support registering overloaded
66+ // constructors, so for now you can only `def()` one instance of
67+ // `torch::init`.
68+ .def (torch::init<std::vector<std::string>>())
69+ // The next line registers a stateless (i.e. no captures) C++ lambda
70+ // function as a method. Note that a lambda function must take a
71+ // `c10::intrusive_ptr<YourClass>` (or some const/ref version of that)
72+ // as the first argument. Other arguments can be whatever you want.
73+ .def (" top" , [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
74+ return self->stack_ .back ();
75+ })
76+ // The following four lines expose methods of the MyStackClass<std::string>
77+ // class as-is. `torch::class_` will automatically examine the
78+ // argument and return types of the passed-in method pointers and
79+ // expose these to Python and TorchScript accordingly. Finally, notice
80+ // that we must take the *address* of the fully-qualified method name,
81+ // i.e. use the unary `&` operator, due to C++ typing rules.
82+ .def (" push" , &MyStackClass<std::string>::push)
83+ .def (" pop" , &MyStackClass<std::string>::pop)
84+ .def (" clone" , &MyStackClass<std::string>::clone)
85+ .def (" merge" , &MyStackClass<std::string>::merge)
8186// END binding
87+ #ifndef NO_PICKLE
88+ // BEGIN def_pickle
89+ // class_<>::def_pickle allows you to define the serialization
90+ // and deserialization methods for your C++ class.
91+ // Currently, we only support passing stateless lambda functions
92+ // as arguments to def_pickle
93+ .def_pickle (
94+ // __getstate__
95+ // This function defines what data structure should be produced
96+ // when we serialize an instance of this class. The function
97+ // must take a single `self` argument, which is an intrusive_ptr
98+ // to the instance of the object. The function can return
99+ // any type that is supported as a return value of the TorchScript
100+ // custom operator API. In this instance, we've chosen to return
101+ // a std::vector<std::string> as the salient data to preserve
102+ // from the class.
103+ [](const c10::intrusive_ptr<MyStackClass<std::string>>& self)
104+ -> std::vector<std::string> {
105+ return self->stack_ ;
106+ },
107+ // __setstate__
108+ // This function defines how to create a new instance of the C++
109+ // class when we are deserializing. The function must take a
110+ // single argument of the same type as the return value of
111+ // `__getstate__`. The function must return an intrusive_ptr
112+ // to a new instance of the C++ class, initialized however
113+ // you would like given the serialized state.
114+ [](std::vector<std::string> state)
115+ -> c10::intrusive_ptr<MyStackClass<std::string>> {
116+ // A convenient way to instantiate an object and get an
117+ // intrusive_ptr to it is via `make_intrusive`. We use
118+ // that here to allocate an instance of MyStackClass<std::string>
119+ // and call the single-argument std::vector<std::string>
120+ // constructor with the serialized state.
121+ return c10::make_intrusive<MyStackClass<std::string>>(std::move (state));
122+ });
123+ // END def_pickle
124+ #endif // NO_PICKLE
82125
83- #else
84-
85- // BEGIN pickle_binding
86- static auto testStack =
87- torch::class_<MyStackClass<std::string>>(" my_classes" , " MyStackClass" )
88- .def(torch::init<std::vector<std::string>>())
89- .def(" top" , [](const c10::intrusive_ptr<MyStackClass<std::string>>& self) {
90- return self->stack_ .back ();
91- })
92- .def(" push" , &MyStackClass<std::string>::push)
93- .def(" pop" , &MyStackClass<std::string>::pop)
94- .def(" clone" , &MyStackClass<std::string>::clone)
95- .def(" merge" , &MyStackClass<std::string>::merge)
96- // class_<>::def_pickle allows you to define the serialization
97- // and deserialization methods for your C++ class.
98- // Currently, we only support passing stateless lambda functions
99- // as arguments to def_pickle
100- .def_pickle(
101- // __getstate__
102- // This function defines what data structure should be produced
103- // when we serialize an instance of this class. The function
104- // must take a single `self` argument, which is an intrusive_ptr
105- // to the instance of the object. The function can return
106- // any type that is supported as a return value of the TorchScript
107- // custom operator API. In this instance, we've chosen to return
108- // a std::vector<std::string> as the salient data to preserve
109- // from the class.
110- [](const c10::intrusive_ptr<MyStackClass<std::string>>& self)
111- -> std::vector<std::string> {
112- return self->stack_ ;
113- },
114- // __setstate__
115- // This function defines how to create a new instance of the C++
116- // class when we are deserializing. The function must take a
117- // single argument of the same type as the return value of
118- // `__getstate__`. The function must return an intrusive_ptr
119- // to a new instance of the C++ class, initialized however
120- // you would like given the serialized state.
121- [](std::vector<std::string> state)
122- -> c10::intrusive_ptr<MyStackClass<std::string>> {
123- // A convenient way to instantiate an object and get an
124- // intrusive_ptr to it is via `make_intrusive`. We use
125- // that here to allocate an instance of MyStackClass<std::string>
126- // and call the single-argument std::vector<std::string>
127- // constructor with the serialized state.
128- return c10::make_intrusive<MyStackClass<std::string>>(std::move (state));
129- });
130- // END pickle_binding
131-
132- // BEGIN free_function
133- c10::intrusive_ptr<MyStackClass<std::string>> manipulate_instance (const c10::intrusive_ptr<MyStackClass<std::string>>& instance) {
134- instance->pop ();
135- return instance;
126+ // BEGIN def_free
127+ m.def (
128+ " foo::manipulate_instance(__torch__.torch.classes.my_classes.MyStackClass x) -> __torch__.torch.classes.my_classes.MyStackClass Y" ,
129+ manipulate_instance
130+ );
131+ // END def_free
136132}
137-
138- static auto instance_registry = torch::RegisterOperators().op(
139- torch::RegisterOperators::options ()
140- .schema(
141- " foo::manipulate_instance(__torch__.torch.classes.my_classes.MyStackClass x) -> __torch__.torch.classes.my_classes.MyStackClass Y" )
142- .catchAllKernel<decltype(manipulate_instance), &manipulate_instance>());
143- // END free_function
144-
145- #endif
0 commit comments