Status: Closed
Labels: component: dynamo (Issues relating to the `torch.compile` or `torch._dynamo.export` paths), feature request (New feature or request), No Activity
Description
Context
Currently, elementwise operators are not automatically type-promoted in the FX path the way they are in TorchScript (TS). This leads to bugs such as #1995, where mismatched operand types cause TensorRT to throw an error.
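For reference, eager PyTorch already promotes mixed operand dtypes automatically; the proposal is to replicate that behavior when the TensorRT network is built. A quick illustration of the promotion rule:

import torch

a = torch.ones(2, dtype=torch.int32)
b = torch.ones(2, dtype=torch.float32)
(a + b).dtype                                    # torch.float32
torch.promote_types(torch.int32, torch.float32)  # torch.float32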
Feature Proposal
Using the TS type-promotion code as a starting point:
TensorRT/core/conversion/converters/converter_util.cpp, lines 72 to 78 at 82631fa:

nvinfer1::DataType promote_types(nvinfer1::DataType type_a, nvinfer1::DataType type_b) {
  auto torch_type_a = util::TRTDataTypeToScalarType(type_a);
  auto torch_type_b = util::TRTDataTypeToScalarType(type_b);
  auto promo_type = at::promote_types(torch_type_a, torch_type_b);
  auto trt_promo_type = util::ScalarTypeToTRTDataType(promo_type);
  return trt_promo_type;
}
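A rough Python analogue of that helper, as a minimal sketch: it assumes a hand-written mapping between TensorRT and torch dtypes, and the name promote_trt_types is illustrative, not an existing converter API.

import tensorrt as trt
import torch

# Illustrative TRT <-> torch dtype mapping; only a subset of dtypes is covered.
_TRT_TO_TORCH = {
    trt.DataType.FLOAT: torch.float32,
    trt.DataType.HALF: torch.float16,
    trt.DataType.INT32: torch.int32,
    trt.DataType.INT8: torch.int8,
    trt.DataType.BOOL: torch.bool,
}
_TORCH_TO_TRT = {v: k for k, v in _TRT_TO_TORCH.items()}

def promote_trt_types(type_a: trt.DataType, type_b: trt.DataType) -> trt.DataType:
    # Defer to torch's promotion rules, mirroring the TS helper above.
    promoted = torch.promote_types(_TRT_TO_TORCH[type_a], _TRT_TO_TORCH[type_b])
    return _TORCH_TO_TRT[promoted]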
Implement a similar type-promotion scheme in convert_binary_elementwise on the converter_reorg_elementwise branch, here: TensorRT/py/torch_tensorrt/fx/converters/impl/elementwise/base.py, lines 40 to 48 at 546f975:
def convert_binary_elementwise(
    network: TRTNetwork,
    target: Target,
    source_ir: Optional[SourceIR],
    name: str,
    op_type: trt.ElementWiseOperation,
    lhs_val: Union[int, float, TRTTensor, torch.Tensor],
    rhs_val: Union[int, float, TRTTensor, torch.Tensor],
) -> TRTTensor:
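A minimal sketch of how the promotion could be wired into convert_binary_elementwise, assuming a promote_trt_types helper like the one above; cast_trt_tensor is a hypothetical helper for this sketch, not an existing function in the converter library:

import tensorrt as trt
# Assumed type aliases, as used in the signature above.
from torch_tensorrt.fx.types import TRTNetwork, TRTTensor

def cast_trt_tensor(
    network: TRTNetwork, tensor: TRTTensor, dtype: trt.DataType, name: str
) -> TRTTensor:
    # An identity layer with a forced output type acts as a cast in TensorRT.
    layer = network.add_identity(tensor)
    layer.set_output_type(0, dtype)
    layer.name = f"{name}_cast"
    return layer.get_output(0)

# Inside convert_binary_elementwise, once lhs_val and rhs_val are both TRTTensors:
#     promoted = promote_trt_types(lhs_val.dtype, rhs_val.dtype)
#     if lhs_val.dtype != promoted:
#         lhs_val = cast_trt_tensor(network, lhs_val, promoted, f"{name}_lhs")
#     if rhs_val.dtype != promoted:
#         rhs_val = cast_trt_tensor(network, rhs_val, promoted, f"{name}_rhs")

This way the cast is inserted only for the operand whose dtype differs from the promoted type, matching the behavior of the TS converter.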