1
+ import logging
1
2
from types import FunctionType
2
3
from typing import Tuple
3
4
8
9
from torch ._subclasses .fake_tensor import FakeTensorMode
9
10
from torch .fx .experimental .symbolic_shapes import DimDynamic , ShapeEnv
10
11
12
+ _LOGGER : logging .Logger = logging .getLogger (__name__ )
13
+
11
14
12
15
def mksym (shape_env , value , source , dynamic_dim ):
13
16
return shape_env .create_symintnode (
@@ -42,6 +45,16 @@ def generate_signature(torch_op):
42
45
for arg in schema .arguments :
43
46
arg_list .append (arg .name )
44
47
48
+ # TODO: Torch types need to be converted to python primitive types here
49
+ # Some other types are not handled:
50
+ # - torch._C.ListType.ofT(<type>)
51
+ # - torch._C.TupleType.get()
52
+ # - torch._C.DictType.get(<key_type>, <value_type>)
53
+ # - torch._C.OptionalType.ofT(<type>)
54
+ # - torch._C.DeviceObjType.get()
55
+ # - torch._C.FunctionType.get()
56
+ # - torch._C.ClassType
57
+
45
58
if arg .type .isSubtypeOf (torch ._C .TensorType .get ()):
46
59
tensor_args .append (arg )
47
60
register_func_annotation [arg .name ] = trtp .TensorDesc
@@ -52,6 +65,12 @@ def generate_signature(torch_op):
52
65
elif arg .type .isSubtypeOf (torch ._C .IntType .get ()):
53
66
register_func_annotation [arg .name ] = int
54
67
impl_func_annotation [arg .name ] = int
68
+ elif arg .type .isSubtypeOf (torch ._C .Booltype .get ()):
69
+ register_func_annotation [arg .name ] = bool
70
+ impl_func_annotation [arg .name ] = bool
71
+ elif arg .type .isSubtypeOf (torch ._C .Stringtype .get ()):
72
+ register_func_annotation [arg .name ] = str
73
+ impl_func_annotation [arg .name ] = str
55
74
else :
56
75
raise ValueError ("arg type is not handled" )
57
76
@@ -94,12 +113,6 @@ def generate_signature(torch_op):
94
113
register_func_annotation ,
95
114
impl_func_annotation ,
96
115
) = generate_signature (torch_op )
97
- print (args_input )
98
- print (kwargs_input )
99
- print (plugin_signature )
100
- print (plugin_impl_signature )
101
- print (register_func_annotation )
102
- print (impl_func_annotation )
103
116
104
117
def _generic_plugin_desc (* args , ** kwargs ) -> Tuple [trtp .TensorDesc ]:
105
118
shape_env = ShapeEnv ()
@@ -141,6 +154,8 @@ def _generic_plugin_desc(*args, **kwargs) -> Tuple[trtp.TensorDesc]:
141
154
return _generic_plugin_desc({ args_input } , { kwargs_input } )
142
155
"""
143
156
157
+ _LOGGER .warning (f"Plugin registration function: \n { codegen_plugin } " )
158
+
144
159
plugin_code = compile (codegen_plugin , "<string>" , "exec" )
145
160
146
161
globals ()["_generic_plugin_desc" ] = _generic_plugin_desc
@@ -167,6 +182,8 @@ def _generic_plugin_impl(outputs, stream, *args, **kwargs):
167
182
_generic_plugin_impl(outputs, stream, { args_input } , { kwargs_input } )
168
183
"""
169
184
185
+ _LOGGER .warning (f"Plugin implementation function: \n { plugin_impl_func } " )
186
+
170
187
plugin_impl_code = compile (plugin_impl_func , "<string>" , "exec" )
171
188
172
189
globals ()["_generic_plugin_impl" ] = _generic_plugin_impl
@@ -175,15 +192,6 @@ def _generic_plugin_impl(outputs, stream, *args, **kwargs):
175
192
176
193
plugin_impl .__annotations__ = impl_func_annotation
177
194
178
- import inspect
179
-
180
- sig = inspect .signature (plugin_impl )
181
-
182
- # input arg annotations are optional, but we will validate if provided
183
- for name , param in sig .parameters .items ():
184
- print (name )
185
- print (param .annotation )
186
-
187
195
trtp .impl (plugin_name )(plugin_impl )
188
196
189
197
return plugin
0 commit comments