-
Notifications
You must be signed in to change notification settings - Fork 368
Moving normalization core to impl - softmax (FX Converter Refactor [12/N]) <Target: converter_reorg_elementwise> #1909
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
There are some changes that do not conform to Python style guidelines:
--- py/torch_tensorrt/fx/converters/aten_ops_converters.py 2023-05-10 06:06:19.166184 +0000
+++ py/torch_tensorrt/fx/converters/aten_ops_converters.py 2023-05-10 06:06:48.425711 +0000
@@ -359,17 +359,11 @@
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
- return softmax(
- network,
- target,
- SourceIR.ATEN,
- name,
- kwargs["input"],
- kwargs["dim"])
+ return softmax(network, target, SourceIR.ATEN, name, kwargs["input"], kwargs["dim"])
@tensorrt_converter(torch.ops.aten.cat.default)
def aten_ops_cat(
network: TRTNetwork,
--- py/torch_tensorrt/fx/converters/impl/normalization/__init__.py 2023-05-10 06:06:19.166184 +0000
+++ py/torch_tensorrt/fx/converters/impl/normalization/__init__.py 2023-05-10 06:06:48.631521 +0000
@@ -1 +1 @@
-from .ops import *
\ No newline at end of file
+from .ops import *
--- py/torch_tensorrt/fx/converters/impl/normalization/ops.py 2023-05-10 06:06:19.166184 +0000
+++ py/torch_tensorrt/fx/converters/impl/normalization/ops.py 2023-05-10 06:06:48.739746 +0000
@@ -12,20 +12,21 @@
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor
from torch_tensorrt.fx.converters.converter_utils import (
SourceIR,
set_layer_name,
- get_positive_dim
+ get_positive_dim,
)
+
def softmax(
network: TRTNetwork,
target: Target,
source_ir: Optional[SourceIR],
name: str,
input: TRTTensor,
- dim: Optional[Any] = None
+ dim: Optional[Any] = None,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
input_ranks = len(input.shape) + (1 if network.has_implicit_batch_dimension else 0) # type: ignore[union-attr]
if not isinstance(input, TRTTensor):
raise RuntimeError(
@@ -53,6 +54,5 @@
layer = network.add_softmax(input)
layer.axes = 1 << dim
set_layer_name(layer, target, name)
return layer.get_output(0)
-
--- py/torch_tensorrt/fx/converters/acc_ops_converters.py 2023-05-10 06:06:19.166184 +0000
+++ py/torch_tensorrt/fx/converters/acc_ops_converters.py 2023-05-10 06:06:51.503546 +0000
@@ -857,18 +857,11 @@
target: Target,
args: Tuple[Argument, ...],
kwargs: Dict[str, Argument],
name: str,
) -> Union[TRTTensor, Sequence[TRTTensor]]:
- return softmax(
- network,
- target,
- SourceIR.ACC,
- name,
- kwargs["input"],
- kwargs["dim"]
- )
+ return softmax(network, target, SourceIR.ACC, name, kwargs["input"], kwargs["dim"])
@tensorrt_converter(acc_ops.tile)
def acc_ops_tile(
network: TRTNetwork,
--- py/torch_tensorrt/fx/test/converters/aten_op/test_softmax_aten.py 2023-05-10 06:06:19.174185 +0000
+++ py/torch_tensorrt/fx/test/converters/aten_op/test_softmax_aten.py 2023-05-10 06:06:51.919309 +0000
@@ -39,6 +39,6 @@
TestModule(), input_specs, expected_ops={torch.ops.aten._softmax.default}
)
if __name__ == "__main__":
- run_tests()
\ No newline at end of file
+ run_tests()There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
0d4e6d7 to
9611d67
Compare
2769db2 to
2f23743
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
9611d67 to
93846ed
Compare
Signed-off-by: Naren Dasan <[email protected]> new file: ../converters/impl/unary/base.py
93846ed to
83986cd
Compare
… <Target: converter_reorg_elementwise> (#1905)
2f23743 to
f423be1
Compare
softmax linting error fix
f423be1 to
75b1a2a
Compare
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to C++ style guidelines
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code conforms to Python style guidelines
546f975 to
c8a9559
Compare
18e503b to
2caac76
Compare
|
Closed with #2070 merged with main |
No description provided.