Skip to content

Commit e9b099a

Browse files
author
Chen Fu
committed
skipped unnecessary broadcast
1 parent ac17aba commit e9b099a

File tree

1 file changed

+8
-2
lines changed
  • py/torch_tensorrt/dynamo/conversion/impl/elementwise

1 file changed

+8
-2
lines changed

py/torch_tensorrt/dynamo/conversion/impl/elementwise/base.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1+
import logging
12
import operator
23
import warnings
34
from typing import Any, Callable, Optional, Union
45

5-
import numpy as np
66
import tensorrt as trt
77
import torch
88
from torch.fx.node import Target
@@ -20,6 +20,8 @@
2020
)
2121
from torch_tensorrt.dynamo.types import TRTElementWiseOp, TRTTensor
2222

23+
logger = logging.getLogger(__name__)
24+
2325

2426
def get_python_op_from_trt_elementwise_op(
2527
trt_op: TRTElementWiseOp,
@@ -148,7 +150,11 @@ def convert_binary_elementwise(
148150
ctx, rhs_val, trt_promoted_type, f"{name}_cast_rhs_val", target, source_ir
149151
)
150152

151-
if has_dynamic_shape(lhs_val.shape) or has_dynamic_shape(rhs_val.shape):
153+
if len(lhs_val.shape) == len(rhs_val.shape) and all(
154+
a == b or a == 1 or b == 1 for a, b in zip(lhs_val.shape, rhs_val.shape)
155+
):
156+
logger.info(f"skip broadcast for {name}")
157+
elif has_dynamic_shape(lhs_val.shape) or has_dynamic_shape(rhs_val.shape):
152158
lhs_val, rhs_val = broadcast(
153159
ctx, lhs_val, rhs_val, f"{name}_broadcast_lhs", f"{name}_broadcast_rhs"
154160
)

0 commit comments

Comments
 (0)