
Commit e809c83

resolve comments
1 parent 212978a commit e809c83

5 files changed: +58 −22 lines changed

examples/dynamo/auto_generate_plugin.py
Lines changed: 12 additions & 7 deletions

@@ -1,10 +1,10 @@
 """
 .. _auto_generate_converters:
 
-Automatically Generate a Converter for a Custom Kernel
+Automatically Generate a Plugin for a Custom Kernel
 ===================================================================
 
-We are going to demonstrate how to automatically generate a converter for a custom kernel using Torch-TensorRT using
+We are going to demonstrate how to automatically generate a plugin for a custom kernel using Torch-TensorRT using
 the new Python based plugin system in TensorRT 10.7.
 
 Torch-TensorRT supports falling back to PyTorch implementations of operations in the case that Torch-TensorRT

@@ -102,16 +102,21 @@ def _(x: torch.Tensor, y: torch.Tensor, b: float = 0.2, a: int = 2) -> torch.Tensor:
 torch_tensorrt.dynamo.conversion.plugins.generate_plugin("torchtrt_ex::elementwise_mul")
 
 
-# %%
-# Generating the Converter
-# -------------------------------------------------------------------
-# Given that we have defined the custom operator in PyTorch and TensorRT, we can now generate the converter for the operation.
-# As long as the namespace and names match, the following function will automatically generate the converter for the operation.
+# # %%
+# # Generating the Converter
+# # -------------------------------------------------------------------
+# # Given that we have defined the custom operator in PyTorch and TensorRT, we can now generate the converter for the operation.
+# # As long as the namespace and names match, the following function will automatically generate the converter for the operation.
 torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
     "torchtrt_ex::elementwise_mul", supports_dynamic_shapes=True
 )
 
 
+# # %%
+# # The above two commands can be replaced with the following single line:
+# torch_tensorrt.dynamo.conversion.plugins.custom_op("torchtrt_ex::elementwise_mul", supports_dynamic_shapes=True)
+
+
 # %%
 # Using our converter with a model
 # -------------------------------------------------------------------
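For context, the net effect of this change is that plugin and converter generation can be collapsed into one call. A minimal sketch, assuming the `torchtrt_ex::elementwise_mul` custom op from the tutorial is already registered with PyTorch:

import torch_tensorrt

# The two-step form used earlier in the example...
# torch_tensorrt.dynamo.conversion.plugins.generate_plugin("torchtrt_ex::elementwise_mul")
# torch_tensorrt.dynamo.conversion.plugins.generate_plugin_converter(
#     "torchtrt_ex::elementwise_mul", supports_dynamic_shapes=True
# )

# ...is equivalent to the single call added by this commit:
torch_tensorrt.dynamo.conversion.plugins.custom_op(
    "torchtrt_ex::elementwise_mul", supports_dynamic_shapes=True
)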

py/torch_tensorrt/dynamo/conversion/plugins/__init__.py
Lines changed: 1 addition & 0 deletions

@@ -1,3 +1,4 @@
+from torch_tensorrt.dynamo.conversion.plugins._custom_op import custom_op
 from torch_tensorrt.dynamo.conversion.plugins._generate_plugin import generate_plugin
 from torch_tensorrt.dynamo.conversion.plugins._generate_plugin_converter import (
     generate_plugin_converter,
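With this re-export, the new helper is importable from the public plugins package rather than the private `_custom_op` module; a minimal sketch:

from torch_tensorrt.dynamo.conversion.plugins import custom_op

custom_op("torchtrt_ex::elementwise_mul", supports_dynamic_shapes=True)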
py/torch_tensorrt/dynamo/conversion/plugins/_custom_op.py
Lines changed: 21 additions & 0 deletions

@@ -0,0 +1,21 @@
+from typing import Callable, Optional
+
+from torch.fx.node import Node
+from torch_tensorrt.dynamo._settings import CompilationSettings
+from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterPriority
+from torch_tensorrt.dynamo.conversion.plugins._generate_plugin import generate_plugin
+from torch_tensorrt.dynamo.conversion.plugins._generate_plugin_converter import (
+    generate_plugin_converter,
+)
+
+
+def custom_op(
+    op_name: str,
+    capability_validator: Optional[Callable[[Node, CompilationSettings], bool]] = None,
+    priority: ConverterPriority = ConverterPriority.STANDARD,
+    supports_dynamic_shapes: bool = False,
+):
+    generate_plugin(op_name)
+    generate_plugin_converter(
+        op_name, capability_validator, priority, supports_dynamic_shapes
+    )
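A short sketch of how the optional parameters might be used; the validator below is hypothetical, but its signature follows the `Callable[[Node, CompilationSettings], bool]` annotation above:

from torch.fx.node import Node
from torch_tensorrt.dynamo._settings import CompilationSettings
from torch_tensorrt.dynamo.conversion._ConverterRegistry import ConverterPriority
from torch_tensorrt.dynamo.conversion.plugins import custom_op


def _validator(node: Node, settings: CompilationSettings) -> bool:
    # Hypothetical capability check: accept every node unconditionally.
    return True


custom_op(
    "torchtrt_ex::elementwise_mul",
    capability_validator=_validator,
    priority=ConverterPriority.STANDARD,
    supports_dynamic_shapes=True,
)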

py/torch_tensorrt/dynamo/conversion/plugins/_generate_plugin.py
Lines changed: 23 additions & 15 deletions

@@ -1,3 +1,4 @@
+import logging
 from types import FunctionType
 from typing import Tuple
 
@@ -8,6 +9,8 @@
 from torch._subclasses.fake_tensor import FakeTensorMode
 from torch.fx.experimental.symbolic_shapes import DimDynamic, ShapeEnv
 
+_LOGGER: logging.Logger = logging.getLogger(__name__)
+
 
 def mksym(shape_env, value, source, dynamic_dim):
     return shape_env.create_symintnode(
@@ -42,6 +45,16 @@ def generate_signature(torch_op):
     for arg in schema.arguments:
         arg_list.append(arg.name)
 
+        # TODO: Torch types need to be converted to python primitive types here
+        # Some other types are not handled:
+        # - torch._C.ListType.ofT(<type>)
+        # - torch._C.TupleType.get()
+        # - torch._C.DictType.get(<key_type>, <value_type>)
+        # - torch._C.OptionalType.ofT(<type>)
+        # - torch._C.DeviceObjType.get()
+        # - torch._C.FunctionType.get()
+        # - torch._C.ClassType
+
         if arg.type.isSubtypeOf(torch._C.TensorType.get()):
             tensor_args.append(arg)
             register_func_annotation[arg.name] = trtp.TensorDesc
@@ -52,6 +65,12 @@
         elif arg.type.isSubtypeOf(torch._C.IntType.get()):
             register_func_annotation[arg.name] = int
             impl_func_annotation[arg.name] = int
+        elif arg.type.isSubtypeOf(torch._C.BoolType.get()):
+            register_func_annotation[arg.name] = bool
+            impl_func_annotation[arg.name] = bool
+        elif arg.type.isSubtypeOf(torch._C.StringType.get()):
+            register_func_annotation[arg.name] = str
+            impl_func_annotation[arg.name] = str
         else:
             raise ValueError("arg type is not handled")
 
@@ -94,12 +113,6 @@ def generate_signature(torch_op):
         register_func_annotation,
         impl_func_annotation,
     ) = generate_signature(torch_op)
-    print(args_input)
-    print(kwargs_input)
-    print(plugin_signature)
-    print(plugin_impl_signature)
-    print(register_func_annotation)
-    print(impl_func_annotation)
 
     def _generic_plugin_desc(*args, **kwargs) -> Tuple[trtp.TensorDesc]:
         shape_env = ShapeEnv()
@@ -141,6 +154,8 @@ def _generic_plugin_desc(*args, **kwargs) -> Tuple[trtp.TensorDesc]:
         return _generic_plugin_desc({args_input}, {kwargs_input})
     """
 
+    _LOGGER.warning(f"Plugin registration function: \n{codegen_plugin}")
+
     plugin_code = compile(codegen_plugin, "<string>", "exec")
 
     globals()["_generic_plugin_desc"] = _generic_plugin_desc
@@ -167,6 +182,8 @@ def _generic_plugin_impl(outputs, stream, *args, **kwargs):
         _generic_plugin_impl(outputs, stream, {args_input}, {kwargs_input})
     """
 
+    _LOGGER.warning(f"Plugin implementation function: \n{plugin_impl_func}")
+
     plugin_impl_code = compile(plugin_impl_func, "<string>", "exec")
 
     globals()["_generic_plugin_impl"] = _generic_plugin_impl
@@ -175,15 +192,6 @@ def _generic_plugin_impl(outputs, stream, *args, **kwargs):
 
     plugin_impl.__annotations__ = impl_func_annotation
 
-    import inspect
-
-    sig = inspect.signature(plugin_impl)
-
-    # input arg annotations are optional, but we will validate if provided
-    for name, param in sig.parameters.items():
-        print(name)
-        print(param.annotation)
-
     trtp.impl(plugin_name)(plugin_impl)
 
     return plugin
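For reference, the subtype checks extended here can be exercised standalone. A minimal sketch mirroring the mapping that generate_signature performs, run against an arbitrary ATen overload's schema (torch.ops.aten.add.Tensor is used purely for illustration, and the trtp.TensorDesc annotation is shown as a string):

import torch

# Walk an op schema and map each argument's JIT type to a Python annotation,
# the way generate_signature does after this commit.
schema = torch.ops.aten.add.Tensor._schema
for arg in schema.arguments:
    if arg.type.isSubtypeOf(torch._C.TensorType.get()):
        print(arg.name, "-> trtp.TensorDesc")
    elif arg.type.isSubtypeOf(torch._C.FloatType.get()):
        print(arg.name, "-> float")
    elif arg.type.isSubtypeOf(torch._C.IntType.get()):
        print(arg.name, "-> int")
    elif arg.type.isSubtypeOf(torch._C.BoolType.get()):
        print(arg.name, "-> bool")
    elif arg.type.isSubtypeOf(torch._C.StringType.get()):
        print(arg.name, "-> str")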

pyproject.toml
Lines changed: 1 addition & 0 deletions

@@ -12,6 +12,7 @@ requires = [
     "torch>=2.7.0.dev,<2.8.0",
     "pybind11==2.6.2",
     "numpy",
+    "sympy",
 ]
 build-backend = "setuptools.build_meta"
