Closed
Labels
question: Further information is requested
Description
TensorRT/py/torch_tensorrt/dynamo/conversion/aten_ops_converters.py
Lines 1030 to 1037 in a662411
    allowed_casts = {
        torch.float,
        torch.int32,
        torch.int64,
        torch.bool,
        torch.int8,
        torch.float16,
    }
Is there a specific reason why `torch.bfloat16` is not included in the `allowed_casts` set used by the `to_copy_dtype_validator` function?
Because it is missing, a `torch.ops.aten._to_copy` operation that casts to `torch.bfloat16` is not converted, and the graph is partitioned around it. Could this hurt performance?
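To illustrate the concern, here is a minimal toy model of how one unsupported cast can split a graph into several segments. It is only a sketch and does not use the Torch-TensorRT API: `ALLOWED_CASTS`, `to_copy_supported`, and `partition` are hypothetical names, and dtypes are represented as plain strings so the example runs without torch installed.

```python
from itertools import groupby

# Stand-in for the allowed_casts set quoted above (assumption: torch.float
# is written as "float32"; strings are used instead of torch dtypes).
ALLOWED_CASTS = {"float32", "int32", "int64", "bool", "int8", "float16"}


def to_copy_supported(target_dtype: str) -> bool:
    """Toy validator: a cast is convertible only if its target dtype is allowed."""
    return target_dtype in ALLOWED_CASTS


def partition(nodes):
    """Group consecutive nodes by whether they are supported.

    `nodes` is a list of (name, supported) pairs in execution order.
    Returns a list of (backend, [names]) segments.
    """
    segments = []
    for supported, group in groupby(nodes, key=lambda node: node[1]):
        backend = "tensorrt" if supported else "pytorch"
        segments.append((backend, [name for name, _ in group]))
    return segments


# A three-node graph with a bfloat16 cast in the middle.
graph = [
    ("linear", True),
    ("_to_copy(bfloat16)", to_copy_supported("bfloat16")),
    ("relu", True),
]
segments = partition(graph)
print(segments)
```

In this toy model the bfloat16 cast falls back to PyTorch and splits the graph into three segments, whereas a cast to `float16` would leave a single segment. Whether the extra engine boundaries cost measurable time in a real model depends on the workload.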