From 07b926619b3535001c4d785d593073e6d1a7ea91 Mon Sep 17 00:00:00 2001 From: Naren Dasan Date: Fri, 28 Apr 2023 15:14:02 -0700 Subject: [PATCH 1/3] refactor: Moving elementwise and unary core to impl Signed-off-by: Naren Dasan new file: ../converters/impl/unary/base.py --- .../fx/converters/acc_ops_converters.py | 511 +++++++++++++----- .../fx/converters/aten_ops_converters.py | 7 +- .../fx/converters/converter_utils.py | 342 +----------- .../converters/impl/elementwise/__init__.py | 1 + .../fx/converters/impl/elementwise/base.py | 147 +++++ .../fx/converters/impl/elementwise/ops.py | 111 ++++ py/torch_tensorrt/fx/converters/impl/shape.py | 81 +++ .../fx/converters/impl/unary/__init__.py | 1 + .../fx/converters/impl/unary/base.py | 53 ++ .../fx/converters/impl/unary/ops.py | 104 ++++ 10 files changed, 885 insertions(+), 473 deletions(-) create mode 100644 py/torch_tensorrt/fx/converters/impl/elementwise/__init__.py create mode 100644 py/torch_tensorrt/fx/converters/impl/elementwise/base.py create mode 100644 py/torch_tensorrt/fx/converters/impl/elementwise/ops.py create mode 100644 py/torch_tensorrt/fx/converters/impl/shape.py create mode 100644 py/torch_tensorrt/fx/converters/impl/unary/__init__.py create mode 100644 py/torch_tensorrt/fx/converters/impl/unary/base.py create mode 100644 py/torch_tensorrt/fx/converters/impl/unary/ops.py diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index 9532c7072c..fcb9ce9aad 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -26,7 +26,14 @@ trt_transposed_matmul, ) from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous -from torch_tensorrt.fx.converters.impl import activation, convolution +from torch_tensorrt.fx.converters.impl import activation +from torch_tensorrt.fx.converters.impl.elementwise import trunc_div +from torch_tensorrt.fx.converters.impl.unary import sign +from torch_tensorrt.fx.converters.impl.elementwise.base import ( + convert_binary_elementwise, +) +from torch_tensorrt.fx.converters.impl.unary.base import convert_unary +from torch_tensorrt.fx.converters.impl.shape import get_shape_with_dynamic_shape _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -78,13 +85,14 @@ def trt_transposed_linear_converter(network, target, args, kwargs, name): trt.MatrixOperation.NONE, ) set_layer_name(layer, target, f"{name}_mm") - return add_binary_elementwise_layer( + return convert_binary_elementwise( network, - layer.get_output(0), - bias, - trt.ElementWiseOperation.SUM, target, + SourceIR.TORCHTRT_LOWERED, f"{name}_add", + trt.ElementWiseOperation.SUM, + layer.get_output(0), + bias, ) @@ -646,13 +654,14 @@ def layer_norm( set_layer_name(mean_expected_layer, target, f"{name}_mean_expected") # X-E[x] - sub_trt = add_binary_elementwise_layer( + sub_trt = convert_binary_elementwise( network, - input_val, - mean_expected_layer.get_output(0), - trt.ElementWiseOperation.SUB, target, + SourceIR.ACC, f"{name}_sub", + trt.ElementWiseOperation.SUB, + input_val, + mean_expected_layer.get_output(0), ) # Variance = mean(pow(x_sub_mean,2)) pow_tensor = network.add_constant( @@ -660,13 +669,14 @@ def layer_norm( trt.Weights(np.ascontiguousarray([2.0], dtype=np.float32)), ) pow_tensor.name = f"{name}_power" - pow_var = add_binary_elementwise_layer( + pow_var = convert_binary_elementwise( network, - sub_trt, - pow_tensor.get_output(0), - trt.ElementWiseOperation.POW, target, + SourceIR.ACC, f"{name}_pow_var", + trt.ElementWiseOperation.POW, + sub_trt, + pow_tensor.get_output(0), ) mean_trt_layer = network.add_reduce( pow_var, trt.ReduceOperation.AVG, axes, keep_dims=True @@ -678,26 +688,33 @@ def layer_norm( trt.Weights(np.ascontiguousarray([eps], dtype=np.float32)), ) eps_tensor.name = f"{name}_eps" - add_trt = add_binary_elementwise_layer( + add_trt = convert_binary_elementwise( network, - mean_trt_layer.get_output(0), - eps_tensor.get_output(0), - trt.ElementWiseOperation.SUM, target, + SourceIR.ACC, f"{name}_add", + trt.ElementWiseOperation.SUM, + mean_trt_layer.get_output(0), + eps_tensor.get_output(0), ) # SQRT((Var + eps)) - sqrt_trt = add_unary_layer( - network, add_trt, trt.UnaryOperation.SQRT, target, f"{name}_sqrt" + sqrt_trt = convert_unary( + network, + target, + SourceIR.ACC, + f"{name}_sqrt", + trt.UnaryOperation.SQRT, + add_trt, ) # (x - E[x]) / sqrt((var + eps)) - div_trt = add_binary_elementwise_layer( + div_trt = convert_binary_elementwise( network, - sub_trt, - sqrt_trt, - trt.ElementWiseOperation.DIV, target, + SourceIR.ACC, f"{name}_div_trt", + trt.ElementWiseOperation.DIV, + sub_trt, + sqrt_trt, ) assert gamma is not None @@ -707,21 +724,23 @@ def layer_norm( beta_tensor = network.add_constant(gamma.shape, trt.Weights(np.ascontiguousarray(beta))) # type: ignore[attr-defined] beta_tensor.name = f"{name}_beta" # y * gamma + beta - scale_layer = add_binary_elementwise_layer( + scale_layer = convert_binary_elementwise( network, - div_trt, - gamma_tensor.get_output(0), - trt.ElementWiseOperation.PROD, target, + SourceIR.ACC, f"{name}_scale", + trt.ElementWiseOperation.PROD, + div_trt, + gamma_tensor.get_output(0), ) - return add_binary_elementwise_layer( + return convert_binary_elementwise( network, - scale_layer, - beta_tensor.get_output(0), - trt.ElementWiseOperation.SUM, target, + SourceIR.ACC, name, + trt.ElementWiseOperation.SUM, + scale_layer, + beta_tensor.get_output(0), ) @@ -824,13 +843,14 @@ def acc_ops_tile( else: d = get_trt_tensor(network, d, f"{name}_{i}") shape.append(d) - mul = add_binary_elementwise_layer( + mul = convert_binary_elementwise( network, - s, - d, - trt.ElementWiseOperation.PROD, target, + SourceIR.ACC, f"{name}_mul_{i}", + trt.ElementWiseOperation.PROD, + s, + d, ) shapes.append(mul) dims = shape @@ -859,13 +879,14 @@ def acc_ops_tile( dims_tensor = concat_dims_layer.get_output(0) input_shape_layer = network.add_shape(input_val) input_shape_layer.name = f"{name}_slice_input_shape" - slice_shapes_tensor = add_binary_elementwise_layer( + slice_shapes_tensor = convert_binary_elementwise( network, - input_shape_layer.get_output(0), - dims_tensor, - trt.ElementWiseOperation.PROD, target, + SourceIR.ACC, f"{name}_slice_shapes", + trt.ElementWiseOperation.PROD, + input_shape_layer.get_output(0), + dims_tensor, ) layer.set_input(1, starts_tensor) layer.set_input(2, slice_shapes_tensor) @@ -886,9 +907,22 @@ def acc_ops_sign( if trt.__version__ >= "8.2" and not network.has_implicit_batch_dimension: input_val = kwargs["input"] operation_type = trt.UnaryOperation.SIGN - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) - return sign(network, input_val, target, name) + return sign( + network, + target, + SourceIR.ACC, + name, + input_val, + ) @tensorrt_converter(acc_ops.relu) @@ -990,7 +1024,14 @@ def acc_ops_sin( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.SIN - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.cos) @@ -1003,7 +1044,14 @@ def acc_ops_cos( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.COS - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.tan) @@ -1016,7 +1064,14 @@ def acc_ops_tan( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.TAN - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.sinh) @@ -1029,7 +1084,14 @@ def acc_ops_sinh( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.SINH - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.cosh) @@ -1042,7 +1104,14 @@ def acc_ops_cosh( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.COSH - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.tanh) @@ -1072,7 +1141,14 @@ def acc_ops_asin( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.ASIN - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.acos) @@ -1085,7 +1161,14 @@ def acc_ops_acos( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.ACOS - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.atan) @@ -1098,7 +1181,14 @@ def acc_ops_atan( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.ATAN - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.exp) @@ -1111,7 +1201,14 @@ def acc_ops_exp( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.EXP - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.log) @@ -1124,7 +1221,14 @@ def acc_ops_log( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.LOG - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.sqrt) @@ -1137,7 +1241,14 @@ def acc_ops_sqrt( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.SQRT - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.reciprocal) @@ -1150,7 +1261,14 @@ def acc_ops_reciprocal( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.RECIP - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.abs) @@ -1163,7 +1281,14 @@ def acc_ops_abs( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.ABS - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.neg) @@ -1176,7 +1301,14 @@ def acc_ops_neg( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.NEG - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.floor) @@ -1189,7 +1321,14 @@ def acc_ops_floor( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.FLOOR - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.ceil) @@ -1202,7 +1341,14 @@ def acc_ops_ceil( ) -> Union[TRTTensor, Sequence[TRTTensor]]: input_val = kwargs["input"] operation_type = trt.UnaryOperation.CEIL - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.sum) @@ -1377,13 +1523,14 @@ def acc_ops_maximum( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( + return convert_binary_elementwise( network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.MAX, target, + SourceIR.ACC, name, + trt.ElementWiseOperation.MAX, + kwargs["input"], + kwargs["other"], ) @@ -1395,13 +1542,14 @@ def acc_ops_minimum( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( + return convert_binary_elementwise( network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.MIN, target, + SourceIR.ACC, name, + trt.ElementWiseOperation.MIN, + kwargs["input"], + kwargs["other"], ) @@ -1460,7 +1608,14 @@ def acc_ops_logical_not( # cast to bool type if input_val.dtype in (trt.float32, trt.float16, trt.int32): input_val = type_cast(network, target, f"{name}_input", input_val, trt.bool) - return add_unary_layer(network, input_val, operation_type, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + operation_type, + input_val, + ) @tensorrt_converter(acc_ops.logical_and, no_implicit_batch_dim=True) @@ -1506,8 +1661,14 @@ def check_is_bool(input_t): input_t = type_cast(network, target, f"{name}_input", input_t, trt.bool) if other_t.dtype != trt.bool: other_t = type_cast(network, target, f"{name}_other", other_t, trt.bool) - return add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.AND, target, name + return convert_binary_elementwise( + network, + target, + SourceIR.ACC, + name, + trt.ElementWiseOperation.AND, + input_t, + other_t, ) @@ -1531,11 +1692,24 @@ def acc_ops_ne( other_t = get_trt_tensor(network, other_t, f"{name}_other_t") input_t, other_t = dtype_uniform(network, target, name, input_t, other_t) - eq_t = add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name + eq_t = convert_binary_elementwise( + network, + target, + SourceIR.ACC, + name, + trt.ElementWiseOperation.EQUAL, + input_t, + other_t, ) - return add_unary_layer(network, eq_t, trt.UnaryOperation.NOT, target, name) + return convert_unary( + network, + target, + SourceIR.ACC, + name, + trt.UnaryOperation.NOT, + eq_t, + ) @tensorrt_converter(acc_ops.eq, no_implicit_batch_dim=True) @@ -1558,8 +1732,14 @@ def acc_ops_eq( other_t = get_trt_tensor(network, other_t, f"{name}_other_t") input_t, other_t = dtype_uniform(network, target, name, input_t, other_t) - return add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.EQUAL, target, name + return convert_binary_elementwise( + network, + target, + SourceIR.ACC, + name, + trt.ElementWiseOperation.EQUAL, + input_t, + other_t, ) @@ -1583,8 +1763,14 @@ def acc_ops_gt( other_t = get_trt_tensor(network, other_t, f"{name}_other_t") input_t, other_t = dtype_uniform(network, target, name, input_t, other_t) - return add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.GREATER, target, name + return convert_binary_elementwise( + network, + target, + SourceIR.ACC, + name, + trt.ElementWiseOperation.GREATER, + input_t, + other_t, ) @@ -1608,8 +1794,14 @@ def acc_ops_lt( other_t = get_trt_tensor(network, other_t, f"{name}_other_t") input_t, other_t = dtype_uniform(network, target, name, input_t, other_t) - return add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.LESS, target, name + return convert_binary_elementwise( + network, + target, + SourceIR.ACC, + name, + trt.ElementWiseOperation.LESS, + input_t, + other_t, ) @@ -1645,8 +1837,14 @@ def acc_ops_logical_or( set_layer_name(layer_o, target, f"{name}_other_dtype_change") other_t = layer_o.get_output(0) - return add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.OR, target, name + return convert_binary_elementwise( + network, + target, + SourceIR.ACC, + name, + trt.ElementWiseOperation.OR, + input_t, + other_t, ) @@ -1682,8 +1880,14 @@ def acc_ops_logical_xor( set_layer_name(layer_o, target, f"{name}_other_dtype_change") other_t = layer_o.get_output(0) - return add_binary_elementwise_layer( - network, input_t, other_t, trt.ElementWiseOperation.XOR, target, name + return convert_binary_elementwise( + network, + target, + SourceIR.ACC, + name, + trt.ElementWiseOperation.XOR, + input_t, + other_t, ) @@ -1780,23 +1984,30 @@ def acc_ops_fmod( ) -> Union[TRTTensor, Sequence[TRTTensor]]: # NOTE: TRT doesnt currently implement fmod so we need multiple operations to perform it trunc_div_value = trunc_div( - kwargs["input"], kwargs["other"], network, target, name + "_trunc_div" - ) - prod_value = add_binary_elementwise_layer( network, - trunc_div_value, + target, + SourceIR.ACC, + name + "_trunc_div", + kwargs["input"], kwargs["other"], - trt.ElementWiseOperation.PROD, + ) + prod_value = convert_binary_elementwise( + network, target, + SourceIR.ACC, name + "_prod", + trt.ElementWiseOperation.PROD, + trunc_div_value, + kwargs["other"], ) - sub_value = add_binary_elementwise_layer( + sub_value = convert_binary_elementwise( network, - kwargs["input"], - prod_value, - trt.ElementWiseOperation.SUB, target, + SourceIR.ACC, name + "_sub", + trt.ElementWiseOperation.SUB, + kwargs["input"], + prod_value, ) return sub_value @@ -2027,13 +2238,14 @@ def acc_ops_add( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( + return convert_binary_elementwise( network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.SUM, target, + SourceIR.ACC, name, + trt.ElementWiseOperation.SUM, + kwargs["input"], + kwargs["other"], ) @@ -2045,13 +2257,14 @@ def acc_ops_sub( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( + return convert_binary_elementwise( network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.SUB, target, + SourceIR.ACC, name, + trt.ElementWiseOperation.SUB, + kwargs["input"], + kwargs["other"], ) @@ -2063,13 +2276,14 @@ def acc_ops_div( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( + return convert_binary_elementwise( network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.DIV, target, + SourceIR.ACC, name, + trt.ElementWiseOperation.DIV, + kwargs["input"], + kwargs["other"], ) @@ -2081,13 +2295,14 @@ def acc_ops_floor_div( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( + return convert_binary_elementwise( network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.FLOOR_DIV, target, + SourceIR.ACC, name, + trt.ElementWiseOperation.FLOOR_DIV, + kwargs["input"], + kwargs["other"], ) @@ -2099,7 +2314,14 @@ def acc_ops_trunc_div( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return trunc_div(kwargs["input"], kwargs["other"], network, target, name) + return trunc_div( + network, + target, + SourceIR.ACC, + name, + kwargs["input"], + kwargs["other"], + ) @tensorrt_converter(acc_ops.mul) @@ -2110,13 +2332,14 @@ def acc_ops_mul( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( + return convert_binary_elementwise( network, - kwargs["input"], - kwargs["other"], - trt.ElementWiseOperation.PROD, target, + SourceIR.ACC, name, + trt.ElementWiseOperation.PROD, + kwargs["input"], + kwargs["other"], ) @@ -2128,13 +2351,14 @@ def acc_ops_pow( kwargs: Dict[str, Argument], name: str, ) -> Union[TRTTensor, Sequence[TRTTensor]]: - return add_binary_elementwise_layer( + return convert_binary_elementwise( network, - kwargs["input"], - kwargs["exponent"], - trt.ElementWiseOperation.POW, target, + SourceIR.ACC, name, + trt.ElementWiseOperation.POW, + kwargs["input"], + kwargs["exponent"], ) @@ -2440,7 +2664,12 @@ def acc_ops_slice_tensor( if dynamic_shape > 0: output_shape = get_shape_with_dynamic_shape( - network, output_shape, input_val, target, name + network, + target, + SourceIR.ACC, + name, + output_shape, + input_val, ) layer = network.add_slice( input_val, @@ -2684,7 +2913,12 @@ def acc_ops_split( start[dim] = offset if dynamic_shape: shape = get_shape_with_dynamic_shape( - network, shape, input_val, target, f"{name}_shape_{i}" + network, + target, + SourceIR.ACC, + f"{name}_shape_{i}", + shape, + input_val, ) layer = network.add_slice( input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride @@ -2748,13 +2982,14 @@ def acc_ops_linear( if kwargs["bias"] is not None: bias = get_trt_tensor(network, kwargs["bias"], f"{name}_bias") # type: ignore[arg-type] - res = add_binary_elementwise_layer( + res = convert_binary_elementwise( network, - matmul_layer.get_output(0), - bias, - trt.ElementWiseOperation.SUM, target, + SourceIR.ACC, f"{name}_add", + trt.ElementWiseOperation.SUM, + matmul_layer.get_output(0), + bias, ) return res @@ -2945,7 +3180,14 @@ def slice_to_trt_params(py_slice, dim_size): i += 1 if dynamic_shape: - size = get_shape_with_dynamic_shape(network, size, input_val, target, name) + size = get_shape_with_dynamic_shape( + network, + target, + SourceIR.ACC, + name, + size, + input_val, + ) layer = network.add_slice( input=input_val, @@ -3370,7 +3612,12 @@ def acc_ops_chunk( shape[dim] = min(split_size, max_offset - offset) if dynamic_shape: shape = get_shape_with_dynamic_shape( - network, shape, input_val, target, f"{name}_{i}" + network, + target, + SourceIR.ACC, + f"{name}_{i}", + shape, + input_val, ) start[dim] = offset layer = network.add_slice( @@ -3435,13 +3682,14 @@ def acc_ops_cumsum( set_layer_name(running_sum, target, f"{name}_running_sum_1") running_sum_tensor = running_sum.get_output(0) - current_sum = add_binary_elementwise_layer( + current_sum = convert_binary_elementwise( network, - data, - running_sum_tensor, - trt.ElementWiseOperation.SUM, target, + SourceIR.ACC, f"{name}_sum_1", + trt.ElementWiseOperation.SUM, + data, + running_sum_tensor, ) running_sum.set_input(1, current_sum) @@ -3449,13 +3697,14 @@ def acc_ops_cumsum( set_layer_name(running_sum, target, f"{name}_running_sum_2") running_sum_tensor = running_sum.get_output(0) - current_sum = add_binary_elementwise_layer( + current_sum = convert_binary_elementwise( network, - data, - running_sum_tensor, - trt.ElementWiseOperation.SUM, target, + SourceIR.ACC, f"{name}_sum_2", + trt.ElementWiseOperation.SUM, + data, + running_sum_tensor, ) running_sum.set_input(1, current_sum) diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index cf2101ef1a..6b953d43b7 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -20,7 +20,8 @@ from .converter_utils import * # noqa: F403 import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils -from torch_tensorrt.fx.converters.impl import activation, convolution +from torch_tensorrt.fx.converters.impl import activation +from torch_tensorrt.fx.converters.impl.elementwise import trunc_div _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -182,9 +183,7 @@ def aten_ops_div( network, target, None, kwargs_new, name ) elif rounding_mode == "trunc": - return acc_ops_converters.acc_ops_trunc_div( - network, target, None, kwargs_new, name - ) + return trunc_div(network, target, SourceIR.ATEN, name, args[0], args[1]) else: raise RuntimeError( f"Target {target} does not support rounding mode {rounding_mode}" diff --git a/py/torch_tensorrt/fx/converters/converter_utils.py b/py/torch_tensorrt/fx/converters/converter_utils.py index 49bf401f58..e955c7278b 100644 --- a/py/torch_tensorrt/fx/converters/converter_utils.py +++ b/py/torch_tensorrt/fx/converters/converter_utils.py @@ -28,6 +28,7 @@ class SourceIR(Enum): ACC = auto() ATEN = auto() PRIM = auto() + TORCHTRT_LOWERED = auto() UNKNOWN = auto() def __str__(self): @@ -39,6 +40,8 @@ def __str__(self): return "aten" elif self == SourceIR.PRIM: return "prim" + elif self == SourceIR.TORCHTRT_LOWERED: + return "torchtrt_lowered" else: return "unknown_ir" @@ -409,176 +412,7 @@ def broadcast( return a, b -def get_shape_with_dynamic_shape( - network: TRTNetwork, - shape: Union[list, tuple, torch.Tensor], - input_val: TRTTensor, - target: Target, - name: str, -) -> TRTTensor: - """ - Prepare the real output tensor shape for dynamic shape mode tensor input. - How this functions works: - Assuming the input_val has actual shape [2048, 256, 512], expected reduce operation - output shape is [-1, 128, 256], this function should return [2048, 128, 256] as the actual - reduce operation output shape. Steps of calculations are: - 1. get the actual tensor shape of input_val via add_shape layer; - 2. create a all 0 tensor [0, 0, 0]; - 3. run elementwise comparision the [0, 0, 0] and [-1, 128, 256] tensor, get a condition tensor [True, False, False]; - 4. use the condition tensor [True, False, False] to do selection between [2048, 256, 512] and [-1, 128, 256], replace - all -1 dynamic shape dimensions with actual batch_size value; - 5. output shape with actual batch_size as [2048, 128, 256] - - Args: - network (TRTNetwork): TensorRT network object. - shape: calculated shape of the expected output tensor - input_val (TRTTensor): A TensorRT ITensor. - target (Target): Target of fx node. - name (str): The name we want to assign to the created TensorRT layer. - Returns: - TensorRT ITensors that represents the actual shape of the input_val - """ - # Ger real shape info for input_val - input_shape = network.add_shape(input_val).get_output(0) - - scale_layer = network.add_constant( - input_shape.shape, np.ascontiguousarray(shape, dtype=np.int32) - ) - set_layer_name(scale_layer, target, f"{name}_scale") - scale_res = scale_layer.get_output(0) - - length = input_shape.shape[0] - zero_layer = network.add_constant( - input_shape.shape, to_numpy(torch.zeros((length), dtype=torch.int32)) - ) - set_layer_name(zero_layer, target, f"{name}_zeros") - - condition_val = add_binary_elementwise_layer( - network, - scale_res, - zero_layer.get_output(0), - trt.ElementWiseOperation.LESS, - target, - f"{name}_shape", - ) - select_layer = network.add_select(condition_val, input_shape, scale_res) - set_layer_name(select_layer, target, f"{name}_select") - return select_layer.get_output(0) - - -def add_binary_elementwise_layer( - network: TRTNetwork, - lhs_val: Union[int, float, TRTTensor, torch.Tensor], - rhs_val: Union[int, float, TRTTensor, torch.Tensor], - op_type: trt.ElementWiseOperation, - target: Target, - name: str, -) -> TRTTensor: - """ - This function adds a TensorRT elementwise layer. We allow both operands to be - constant (not a trt tensor) because in implicit batch dimension mode, we could - introduce constant via .size() op. Other scenario should be const folded first. - If any operand is not a trt tensor, we make it a trt constant layer while preserve - its dtype. Then we broadcast these two inputs to have the same number of dimensions. - - Limitation: - If we are using implicit batch dim mode, the operand that is not a trt - tensor are not allowed to have larger ranks than the trt tensor operand. - - Args: - network (TRTNetwork): TensorRT network object. - lhs_val (TRTTensor): Left operand of the binary operation. Could - be a TensorRT tensor, a PyTorch tensor or a simple value. - rhs_val (TRTTensor): Right operand of the binary operation. Similar - to lhs_val. - op_type (trt.ElementWiseOperation): Type of the TensorRT elementwise binary operation. - target (Target): Target of fx node. - name (str): The name we want to assign to the created TensorRT layer. - - Returns: - The output of TensorRT Elementwise layer. - """ - lhs_dtype = None - rhs_dtype = None - is_lhs_trt_tensor = False - is_rhs_trt_tensor = False - - if isinstance(lhs_val, TRTTensor): - lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.TORCH) - is_lhs_trt_tensor = True - if isinstance(rhs_val, TRTTensor): - rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.TORCH) - is_rhs_trt_tensor = True - - if not is_lhs_trt_tensor and not is_rhs_trt_tensor: - warnings.warn( - f"Both operands of the binary elementwise op {name} " - "are constant. In this case, please consider constant fold the model first." - ) - return get_python_op_from_trt_elementwise_op(op_type)(lhs_val, rhs_val) - - # If the following conditions are true: - # 1. the network has implicit batch dimension, - # 2. one operand has shape [] (real shape is [batch_size]), - # 3. another operand is a scalar, - # then the result should also have shape [] (real shape is [batch_size]). - # - # In such case, we need to convert the scalar operand to tensor, because - # this way the shape will become [1], and then will be properly squeezed - # into [], meaning that the result will have shape [], which is what we - # expect. - # - # Note that the dtype here is supposed to be the same as the scalar - # dtype but we don't have a way to detect whether it makes sense for the - # scalar to be float or half. Hence we go with the lhs dtype. - if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)): - rhs_val = np.array( - [rhs_val], dtype=unified_dtype_converter(lhs_val.dtype, Frameworks.NUMPY) - ) - if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)): - lhs_val = np.array( - [lhs_val], dtype=unified_dtype_converter(rhs_val.dtype, Frameworks.NUMPY) - ) - - # When lhs is scalar, and rhs has shape [1,], then currently the assert - # will fail because lhs shape has fewer dimensions than rhs shape. This - # happens when using implicit batch dimension, when we removed the 1st - # dimension from input tensor, causing it to have shape [] - a scalar. We - # fix it by reducing the rhs constant with a squeeze_left, so it becomes a - # scalar too. More generally, we squeeze_left on input if it's a constant - # tensor. This is safe because broadcast will pad dimensions on the left - # (prepend) to make lhs and rhs shape compatible. - if network.has_implicit_batch_dimension: - if isinstance(lhs_val, (torch.Tensor, np.ndarray)): - lhs_val = squeeze_left(lhs_val) - if isinstance(rhs_val, (torch.Tensor, np.ndarray)): - rhs_val = squeeze_left(rhs_val) - - lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype) - rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype) - - # Check the limitation in the doc string. - if network.has_implicit_batch_dimension: - if is_lhs_trt_tensor and not is_rhs_trt_tensor: - assert len(lhs_val.shape) >= len( - rhs_val.shape - ), f"{lhs_val.shape} >= {rhs_val.shape}" - elif not is_lhs_trt_tensor and is_rhs_trt_tensor: - assert len(rhs_val.shape) >= len( - lhs_val.shape - ), f"{rhs_val.shape} >= {lhs_val.shape}" - - lhs_val, rhs_val = broadcast( - network, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs" - ) - layer = network.add_elementwise(lhs_val, rhs_val, op_type) - set_layer_name(layer, target, name) - output = layer.get_output(0) - output.name = output.name + "_" + target.__name__ - return output - - -def squeeze_left(const: Union[torch.Tensor, np.ndarray]): +def squeeze_left(const: torch.Tensor): """ Squeeze the size-1 dimensions on the left side of the shape tuple. PyTorch's `squeeze()` doesn't support passing multiple `dim`s at once, so @@ -594,38 +428,6 @@ def squeeze_left(const: Union[torch.Tensor, np.ndarray]): return const -def add_unary_layer( - network: TRTNetwork, - input_val: TRTTensor, - operation_type: trt.UnaryOperation, - target: Target, - name: str, -) -> TRTTensor: - """ - Add a TensorRT Unary layer to `network`. - - Args: - network (TRTNetwork): TensorRT network object. - input_val (TRTTensor): Input to the unary op. Must be a TensorRT tensor. - op_type (trt.ElementWiseOperation): Type of the TensorRT unary operation. - target (Target): Target of fx node. - name (str): The name we want to assign to the created TensorRT layer. - - Returns: - The output of TensorRT Unary layer. - """ - if not isinstance(input_val, TRTTensor): - raise RuntimeError( - f"{operation_type} received input {input_val} that is not part " - "of the TensorRT region!" - ) - layer = network.add_unary(input_val, operation_type) - set_layer_name(layer, target, name) - output = layer.get_output(0) - output.name = output.name + "_" + target.__name__ - return layer.get_output(0) - - def add_reduce_layer( network: TRTNetwork, target: Target, @@ -730,142 +532,6 @@ def get_inputs_from_args_and_kwargs(args, kwargs, input_names): return inputs -def sign( - network: TRTNetwork, input_val: TRTTensor, target: Target, name: str -) -> TRTTensor: - """ - Sign is calculated as below: - x = input - sign = (exp(x) // exp(abs(x))) * 2 - 1 - For positive number and 0, (exp(x) // exp(abs(x))) yield 1; for negative number, (exp(x) // exp(abs(x))) yield 0. - With multiply 2, the value become 2(for pos and 0) and 0(for neg). - Finally minus 1, the value become 1(for pos and 0) and -1(for neg). - - Args: - network (TRTNetwork): TensorRT network object. - input_val (TRTTensor): The input tensor. - target (Target): fx node target. - name (str): Name of the fx node with optional suffix. - - Returns: - A TensorRT tensor represent the result of sign operator. - """ - input_exp_output = add_unary_layer( - network, input_val, trt.UnaryOperation.EXP, target, f"{name}_prod_exp" - ) - input_abs_output = add_unary_layer( - network, input_val, trt.UnaryOperation.ABS, target, f"{name}_prod_abs" - ) - input_abs_exp_output = add_unary_layer( - network, - input_abs_output, - trt.UnaryOperation.EXP, - target, - f"{name}_prod_abs_exp", - ) - floor_div_output = add_binary_elementwise_layer( - network, - input_exp_output, - input_abs_exp_output, - trt.ElementWiseOperation.FLOOR_DIV, - target, - f"{name}_exp_floor_div", - ) - double_floor_div_output = add_binary_elementwise_layer( - network, - floor_div_output, - 2, - trt.ElementWiseOperation.PROD, - target, - f"{name}_floor_div*2", - ) - return add_binary_elementwise_layer( - network, - double_floor_div_output, - 1, - trt.ElementWiseOperation.SUB, - target, - f"{name}_sign", - ) - - -def trunc_div( - input: TRTTensor, other: TRTTensor, network: TRTNetwork, target: Target, name: str -) -> TRTTensor: - """ - Perform trunc divide on Tensor, result of divide will be round toward zero. - This means for positive number, it will be floor round; for negative number, - it will be ceil round. Example: [2.1, 0.8, -3.2] -> [2, 0, -3]. - - Args: - input: divisor. - other: dividend. - network: INetworkDefinition. - target: node target. - name: namespace for the op - - Returns: - A TensorRT tensor represent the result of trunc divide. - """ - prod_output = add_binary_elementwise_layer( - network, input, other, trt.ElementWiseOperation.PROD, target, f"{name}_prod" - ) - sign_output = sign(network, prod_output, target, name) - - # Convert constant input into ITensor for UnaryOperation - if not isinstance(input, trt.tensorrt.ITensor): - input = get_trt_tensor(network, input, f"{name}_input") - if not isinstance(other, trt.tensorrt.ITensor): - other = get_trt_tensor( - network, - other, - f"{name}_other", - dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH), - ) - - abs_input_output = add_unary_layer( - network, input, trt.UnaryOperation.ABS, target, f"{name}_abs_input" - ) - abs_other_output = add_unary_layer( - network, other, trt.UnaryOperation.ABS, target, f"{name}_abs_other" - ) - abs_floor_output = add_binary_elementwise_layer( - network, - abs_input_output, - abs_other_output, - trt.ElementWiseOperation.FLOOR_DIV, - target, - f"{name}_floor_div", - ) - output = add_binary_elementwise_layer( - network, - abs_floor_output, - sign_output, - trt.ElementWiseOperation.PROD, - target, - f"{name}_output", - ) - - return output - - -def get_python_op_from_trt_elementwise_op( - trt_op: TRTElementWiseOp, -) -> Callable[[Any, Any], Any]: - if trt_op == trt.ElementWiseOperation.SUM: - return operator.add - elif trt_op == trt.ElementWiseOperation.PROD: - return operator.mul - elif trt_op == trt.ElementWiseOperation.SUB: - return operator.sub - elif trt_op == trt.ElementWiseOperation.DIV: - return operator.truediv - elif trt_op == trt.ElementWiseOperation.FLOOR_DIV: - return operator.floordiv - else: - raise RuntimeError(f"{trt_op} is not supported yet!") - - def dtype_uniform( network: TRTNetwork, target: Target, name: str, input: TRTTensor, other: TRTTensor ): diff --git a/py/torch_tensorrt/fx/converters/impl/elementwise/__init__.py b/py/torch_tensorrt/fx/converters/impl/elementwise/__init__.py new file mode 100644 index 0000000000..6965f89636 --- /dev/null +++ b/py/torch_tensorrt/fx/converters/impl/elementwise/__init__.py @@ -0,0 +1 @@ +from .ops import * diff --git a/py/torch_tensorrt/fx/converters/impl/elementwise/base.py b/py/torch_tensorrt/fx/converters/impl/elementwise/base.py new file mode 100644 index 0000000000..261e45728f --- /dev/null +++ b/py/torch_tensorrt/fx/converters/impl/elementwise/base.py @@ -0,0 +1,147 @@ +import operator +import warnings +from typing import Union, Callable, Any, Optional + +import numpy as np + +# @manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt +import torch +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, TRTElementWiseOp +from torch_tensorrt.fx.utils import torch_dtype_from_trt +from torch_tensorrt.fx.converters.converter_utils import ( + SourceIR, + set_layer_name, + broadcast, + squeeze_left, + get_trt_tensor, +) + + +def get_python_op_from_trt_elementwise_op( + trt_op: TRTElementWiseOp, +) -> Callable[[Any, Any], Any]: + if trt_op == trt.ElementWiseOperation.SUM: + return operator.add + elif trt_op == trt.ElementWiseOperation.PROD: + return operator.mul + elif trt_op == trt.ElementWiseOperation.SUB: + return operator.sub + elif trt_op == trt.ElementWiseOperation.DIV: + return operator.truediv + elif trt_op == trt.ElementWiseOperation.FLOOR_DIV: + return operator.floordiv + else: + raise RuntimeError(f"{trt_op} is not supported yet!") + + +def convert_binary_elementwise( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + op_type: trt.ElementWiseOperation, + lhs_val: Union[int, float, TRTTensor, torch.Tensor], + rhs_val: Union[int, float, TRTTensor, torch.Tensor], +) -> TRTTensor: + """ + This function adds a TensorRT elementwise layer. We allow both operands to be + constant (not a trt tensor) because in implicit batch dimension mode, we could + introduce constant via .size() op. Other scenario should be const folded first. + If any operand is not a trt tensor, we make it a trt constant layer while preserve + its dtype. Then we broadcast these two inputs to have the same number of dimensions. + + Limitation: + If we are using implicit batch dim mode, the operand that is not a trt + tensor are not allowed to have larger ranks than the trt tensor operand. + + Args: + network (TRTNetwork): TensorRT network object. + target (Target): Target of fx node. + source_ir (SourceIR): The IR that is calling the function. + name (str): The name we want to assign to the created TensorRT layer. + lhs_val (TRTTensor): Left operand of the binary operation. Could + be a TensorRT tensor, a PyTorch tensor or a simple value. + rhs_val (TRTTensor): Right operand of the binary operation. Similar + to lhs_val. + op_type (trt.ElementWiseOperation): Type of the TensorRT elementwise binary operation. + + Returns: + The output of TensorRT Elementwise layer. + """ + lhs_dtype = None + rhs_dtype = None + is_lhs_trt_tensor = False + is_rhs_trt_tensor = False + + if isinstance(lhs_val, TRTTensor): + lhs_dtype = torch_dtype_from_trt(lhs_val.dtype) + is_lhs_trt_tensor = True + if isinstance(rhs_val, TRTTensor): + rhs_dtype = torch_dtype_from_trt(rhs_val.dtype) + is_rhs_trt_tensor = True + + if not is_lhs_trt_tensor and not is_rhs_trt_tensor: + warnings.warn( + f"Both operands of the binary elementwise op {name} " + "are constant. In this case, please consider constant fold the model first." + ) + return get_python_op_from_trt_elementwise_op(op_type)(lhs_val, rhs_val) + + # If the following conditions are true: + # 1. the network has implicit batch dimension, + # 2. one operand has shape [] (real shape is [batch_size]), + # 3. another operand is a scalar, + # then the result should also have shape [] (real shape is [batch_size]). + # + # In such case, we need to convert the scalar operand to tensor, because + # this way the shape will become [1], and then will be properly squeezed + # into [], meaning that the result will have shape [], which is what we + # expect. + # + # Note that the dtype here is supposed to be the same as the scalar + # dtype but we don't have a way to detect whether it makes sense for the + # scalar to be float or half. Hence we go with the lhs dtype. + if is_lhs_trt_tensor and isinstance(rhs_val, (float, int)): + rhs_val = torch.tensor([rhs_val], dtype=lhs_dtype) + if is_rhs_trt_tensor and isinstance(lhs_val, (float, int)): + lhs_val = torch.tensor([lhs_val], dtype=rhs_dtype) + + # When lhs is scalar, and rhs has shape [1,], then currently the assert + # will fail because lhs shape has fewer dimensions than rhs shape. This + # happens when using implicit batch dimension, when we removed the 1st + # dimension from input tensor, causing it to have shape [] - a scalar. We + # fix it by reducing the rhs constant with a squeeze_left, so it becomes a + # scalar too. More generally, we squeeze_left on input if it's a constant + # tensor. This is safe because broadcast will pad dimensions on the left + # (prepend) to make lhs and rhs shape compatible. + if network.has_implicit_batch_dimension: + if isinstance(lhs_val, torch.Tensor): + lhs_val = squeeze_left(lhs_val) + if isinstance(rhs_val, torch.Tensor): + rhs_val = squeeze_left(rhs_val) + + lhs_val = get_trt_tensor(network, lhs_val, f"{name}_lhs", lhs_dtype) + rhs_val = get_trt_tensor(network, rhs_val, f"{name}_rhs", rhs_dtype) + + # Check the limitation in the doc string. + if network.has_implicit_batch_dimension: + if is_lhs_trt_tensor and not is_rhs_trt_tensor: + assert len(lhs_val.shape) >= len( + rhs_val.shape + ), f"{lhs_val.shape} >= {rhs_val.shape}" + elif not is_lhs_trt_tensor and is_rhs_trt_tensor: + assert len(rhs_val.shape) >= len( + lhs_val.shape + ), f"{rhs_val.shape} >= {lhs_val.shape}" + + lhs_val, rhs_val = broadcast( + network, lhs_val, rhs_val, f"{name}_lhs", f"{name}_rhs" + ) + layer = network.add_elementwise(lhs_val, rhs_val, op_type) + set_layer_name(layer, target, name, source_ir) + output = layer.get_output(0) + output.name = output.name + "_" + target.__name__ + return output diff --git a/py/torch_tensorrt/fx/converters/impl/elementwise/ops.py b/py/torch_tensorrt/fx/converters/impl/elementwise/ops.py new file mode 100644 index 0000000000..ae44ce838c --- /dev/null +++ b/py/torch_tensorrt/fx/converters/impl/elementwise/ops.py @@ -0,0 +1,111 @@ +import operator +import warnings +from typing import Union, Callable, Any, Optional + +import numpy as np + +# @manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt +import torch +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, TRTElementWiseOp +from torch_tensorrt.fx.utils import torch_dtype_from_trt +from torch_tensorrt.fx.converters.converter_utils import ( + SourceIR, + get_trt_tensor, +) + +from torch_tensorrt.fx.converters.impl.elementwise.base import ( + convert_binary_elementwise, +) +from torch_tensorrt.fx.converters.impl.unary.base import convert_unary +from torch_tensorrt.fx.converters.impl.unary import sign + + +def trunc_div( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, + other: TRTTensor, +) -> TRTTensor: + """ + Perform trunc divide on Tensor, result of divide will be round toward zero. + This means for positive number, it will be floor round; for negative number, + it will be ceil round. Example: [2.1, 0.8, -3.2] -> [2, 0, -3]. + + Args: + network: INetworkDefinition. + target: node target + source_ir (SourceIR): Source IR calling the function. + name: namespace for the op + input: divisor. + other: dividend. + + Returns: + A TensorRT tensor represent the result of trunc divide. + """ + prod_output = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_prod", + trt.ElementWiseOperation.PROD, + input, + other, + ) + + sign_output = sign( + network, + target, + source_ir, + name, + prod_output, + ) + + # Convert constant input into ITensor for UnaryOperation + if not isinstance(input, trt.tensorrt.ITensor): + input = get_trt_tensor(network, input, f"{name}_input") + if not isinstance(other, trt.tensorrt.ITensor): + other = get_trt_tensor( + network, other, f"{name}_other", dtype=torch_dtype_from_trt(input.dtype) + ) + + abs_input_output = convert_unary( + network, + target, + source_ir, + f"{name}_abs_input", + trt.UnaryOperation.ABS, + input, + ) + abs_other_output = convert_unary( + network, + target, + source_ir, + f"{name}_abs_other", + trt.UnaryOperation.ABS, + other, + ) + abs_floor_output = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_floor_div", + trt.ElementWiseOperation.FLOOR_DIV, + abs_input_output, + abs_other_output, + ) + output = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_output", + trt.ElementWiseOperation.PROD, + abs_floor_output, + sign_output, + ) + + return output diff --git a/py/torch_tensorrt/fx/converters/impl/shape.py b/py/torch_tensorrt/fx/converters/impl/shape.py new file mode 100644 index 0000000000..8667c712b8 --- /dev/null +++ b/py/torch_tensorrt/fx/converters/impl/shape.py @@ -0,0 +1,81 @@ +import operator +import warnings +from typing import Union, Callable, Any, Optional + +import numpy as np + +# @manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt +import torch +from torch.fx.node import Target + +from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, TRTElementWiseOp +from torch_tensorrt.fx.utils import torch_dtype_from_trt +from torch_tensorrt.fx.converters.converter_utils import ( + SourceIR, + set_layer_name, + to_numpy, +) + +from torch_tensorrt.fx.converters.impl.elementwise.base import ( + convert_binary_elementwise, +) + + +def get_shape_with_dynamic_shape( + network: TRTNetwork, + target: Target, + source_ir: SourceIR, + name: str, + shape: Union[list, tuple, torch.Tensor], + input_val: TRTTensor, +) -> TRTTensor: + """ + Prepare the real output tensor shape for dynamic shape mode tensor input. + How this functions works: + Assuming the input_val has actual shape [2048, 256, 512], expected reduce operation + output shape is [-1, 128, 256], this function should return [2048, 128, 256] as the actual + reduce operation output shape. Steps of calculations are: + 1. get the actual tensor shape of input_val via add_shape layer; + 2. create a all 0 tensor [0, 0, 0]; + 3. run elementwise comparision the [0, 0, 0] and [-1, 128, 256] tensor, get a condition tensor [True, False, False]; + 4. use the condition tensor [True, False, False] to do selection between [2048, 256, 512] and [-1, 128, 256], replace + all -1 dynamic shape dimensions with actual batch_size value; + 5. output shape with actual batch_size as [2048, 128, 256] + + Args: + network (TRTNetwork): TensorRT network object. + shape: calculated shape of the expected output tensor + input_val (TRTTensor): A TensorRT ITensor. + target (Target): Target of fx node. + name (str): The name we want to assign to the created TensorRT layer. + Returns: + TensorRT ITensors that represents the actual shape of the input_val + """ + # Ger real shape info for input_val + input_shape = network.add_shape(input_val).get_output(0) + + scale_layer = network.add_constant( + input_shape.shape, np.ascontiguousarray(shape, dtype=np.int32) + ) + set_layer_name(scale_layer, target, f"{name}_scale") + scale_res = scale_layer.get_output(0) + + length = input_shape.shape[0] + zero_layer = network.add_constant( + input_shape.shape, to_numpy(torch.zeros((length), dtype=torch.int32)) + ) + set_layer_name(zero_layer, target, f"{name}_zeros") + + condition_val = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_shape", + trt.ElementWiseOperation.LESS, + scale_res, + zero_layer.get_output(0), + ) + select_layer = network.add_select(condition_val, input_shape, scale_res) + set_layer_name(select_layer, target, f"{name}_select") + return select_layer.get_output(0) diff --git a/py/torch_tensorrt/fx/converters/impl/unary/__init__.py b/py/torch_tensorrt/fx/converters/impl/unary/__init__.py new file mode 100644 index 0000000000..6965f89636 --- /dev/null +++ b/py/torch_tensorrt/fx/converters/impl/unary/__init__.py @@ -0,0 +1 @@ +from .ops import * diff --git a/py/torch_tensorrt/fx/converters/impl/unary/base.py b/py/torch_tensorrt/fx/converters/impl/unary/base.py new file mode 100644 index 0000000000..fea6334170 --- /dev/null +++ b/py/torch_tensorrt/fx/converters/impl/unary/base.py @@ -0,0 +1,53 @@ +import operator +import warnings +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +from enum import Enum, auto + +# @manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt +from torch.fx.node import Target + +from torch_tensorrt.fx.types import ( + TRTNetwork, + TRTTensor, +) + +from torch_tensorrt.fx.converters.converter_utils import SourceIR, set_layer_name + +from torch_tensorrt.fx.converters.impl.elementwise.base import ( + convert_binary_elementwise, +) + + +def convert_unary( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + operation_type: trt.UnaryOperation, + input_val: TRTTensor, +) -> TRTTensor: + """ + Add a TensorRT Unary layer to `network`. + + Args: + network (TRTNetwork): TensorRT network object. + input_val (TRTTensor): Input to the unary op. Must be a TensorRT tensor. + op_type (trt.ElementWiseOperation): Type of the TensorRT unary operation. + target (Target): Target of fx node. + name (str): The name we want to assign to the created TensorRT layer. + + Returns: + The output of TensorRT Unary layer. + """ + if not isinstance(input_val, TRTTensor): + raise RuntimeError( + f"{operation_type} received input {input_val} that is not part " + "of the TensorRT region!" + ) + layer = network.add_unary(input_val, operation_type) + set_layer_name(layer, target, name, source_ir) + output = layer.get_output(0) + output.name = output.name + "_" + target.__name__ + return layer.get_output(0) diff --git a/py/torch_tensorrt/fx/converters/impl/unary/ops.py b/py/torch_tensorrt/fx/converters/impl/unary/ops.py new file mode 100644 index 0000000000..cba760736a --- /dev/null +++ b/py/torch_tensorrt/fx/converters/impl/unary/ops.py @@ -0,0 +1,104 @@ +import operator +import warnings +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union + +from enum import Enum, auto + +# @manual=//deeplearning/trt/python:py_tensorrt +import tensorrt as trt +from torch.fx.node import Target + +from torch_tensorrt.fx.types import ( + TRTNetwork, + TRTTensor, +) + +from torch_tensorrt.fx.converters.converter_utils import ( + SourceIR, +) + +from torch_tensorrt.fx.converters.impl.elementwise.base import ( + convert_binary_elementwise, +) +from torch_tensorrt.fx.converters.impl.unary.base import convert_unary + + +def sign( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input_val: TRTTensor, +) -> TRTTensor: + """ + Sign is calculated as below: + x = input + sign = (exp(x) // exp(abs(x))) * 2 - 1 + For positive number and 0, (exp(x) // exp(abs(x))) yield 1; for negative number, (exp(x) // exp(abs(x))) yield 0. + With multiply 2, the value become 2(for pos and 0) and 0(for neg). + Finally minus 1, the value become 1(for pos and 0) and -1(for neg). + + Args: + network (TRTNetwork): TensorRT network object. + target (Target): fx node target. + source_ir (SourceIR): Source IR calling the function + name (str): Name of the fx node with optional suffix. + input_val (TRTTensor): The input tensor. + + Returns: + A TensorRT tensor represent the result of sign operator. + """ + input_exp_output = convert_unary( + network, + target, + source_ir, + f"{name}_prod_exp", + trt.UnaryOperation.EXP, + input_val, + ) + input_abs_output = convert_unary( + network, + target, + source_ir, + f"{name}_prod_abs", + trt.UnaryOperation.ABS, + input_val, + ) + input_abs_exp_output = convert_unary( + network, + target, + source_ir, + f"{name}_prod_abs_exp", + trt.UnaryOperation.EXP, + input_abs_output, + ) + + floor_div_output = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_exp_floor_div", + trt.ElementWiseOperation.FLOOR_DIV, + input_exp_output, + input_abs_exp_output, + ) + + double_floor_div_output = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_floor_div*2", + trt.ElementWiseOperation.PROD, + floor_div_output, + 2, + ) + + return convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_sign", + trt.ElementWiseOperation.SUB, + double_floor_div_output, + 1, + ) From 12e8aa9c7eb5d5b61d0c47ec97060c424d0b655b Mon Sep 17 00:00:00 2001 From: Apurba Bose <44209735+apbose@users.noreply.github.com> Date: Tue, 30 May 2023 15:15:23 -0700 Subject: [PATCH 2/3] Moving elementwise core to impl - rsqrt (FX Converter Refactor [9/N]) (#1905) --- .../fx/converters/aten_ops_converters.py | 37 +++++++++++++++++++ .../fx/converters/impl/elementwise/ops.py | 30 +++++++++++++++ .../converters/aten_op/test_rsqrt_aten.py | 29 +++++++++++++++ 3 files changed, 96 insertions(+) create mode 100644 py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index 6b953d43b7..ef7b7bf5b7 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -22,6 +22,7 @@ import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils from torch_tensorrt.fx.converters.impl import activation from torch_tensorrt.fx.converters.impl.elementwise import trunc_div +from torch_tensorrt.fx.converters.impl.elementwise import rsqrt _LOGGER: logging.Logger = logging.getLogger(__name__) @@ -389,6 +390,42 @@ def aten_ops_relu( ) +@tensorrt_converter(torch.ops.aten.relu.default) +def aten_ops_relu( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + + return activation.relu( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + +@tensorrt_converter(torch.ops.aten.rsqrt.default) +def aten_ops_rsqrt( + network: TRTNetwork, + target: Target, + args: Tuple[Argument, ...], + kwargs: Dict[str, Argument], + name: str, +) -> Union[TRTTensor, Sequence[TRTTensor]]: + + return rsqrt( + network, + target, + SourceIR.ATEN, + name, + args[0], + ) + + @tensorrt_converter(torch.ops.aten.sub.Tensor) def aten_ops_sub( network: TRTNetwork, diff --git a/py/torch_tensorrt/fx/converters/impl/elementwise/ops.py b/py/torch_tensorrt/fx/converters/impl/elementwise/ops.py index ae44ce838c..8fddb426a6 100644 --- a/py/torch_tensorrt/fx/converters/impl/elementwise/ops.py +++ b/py/torch_tensorrt/fx/converters/impl/elementwise/ops.py @@ -109,3 +109,33 @@ def trunc_div( ) return output + + +def rsqrt( + network: TRTNetwork, + target: Target, + source_ir: Optional[SourceIR], + name: str, + input: TRTTensor, +) -> TRTTensor: + + sqrt_trt_output = convert_unary( + network, + target, + source_ir, + f"{name}_sqrt", + trt.UnaryOperation.SQRT, + input, + ) + + output = convert_binary_elementwise( + network, + target, + source_ir, + f"{name}_output", + trt.ElementWiseOperation.DIV, + 1, + sqrt_trt_output, + ) + + return output diff --git a/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py new file mode 100644 index 0000000000..3fa27af1a0 --- /dev/null +++ b/py/torch_tensorrt/fx/test/converters/aten_op/test_rsqrt_aten.py @@ -0,0 +1,29 @@ +import torch +import torch.nn as nn +from parameterized import parameterized +from torch.testing._internal.common_utils import run_tests +from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec + + +class TestRSqrtConverter(DispatchTestCase): + @parameterized.expand( + [ + ("2d_dim_alpha", (2, 1), 2), + ("3d_dim_alpha", (2, 1, 2), 2), + ] + ) + def test_rsqrt(self, _, x, alpha): + class rsqrt(nn.Module): + def forward(self, input): + return torch.rsqrt(input) + + inputs = [torch.randn(x) + 1] + self.run_test( + rsqrt(), + inputs, + expected_ops={torch.ops.aten.rsqrt.default}, + ) + + +if __name__ == "__main__": + run_tests() From 2caac760843ec56d34c7aed99c18242cc4c8e602 Mon Sep 17 00:00:00 2001 From: apbose Date: Thu, 6 Jul 2023 13:00:27 -0700 Subject: [PATCH 3/3] Changing torch_dtype_to_trt and torch_dtype_from_trt to unified_dtype_converter and adding convolution in acc_ops_converters.py and aten_ops_converters.py --- py/torch_tensorrt/fx/converters/acc_ops_converters.py | 2 +- py/torch_tensorrt/fx/converters/aten_ops_converters.py | 2 +- py/torch_tensorrt/fx/converters/impl/elementwise/base.py | 6 +++--- py/torch_tensorrt/fx/converters/impl/elementwise/ops.py | 7 +++++-- py/torch_tensorrt/fx/converters/impl/shape.py | 2 +- 5 files changed, 11 insertions(+), 8 deletions(-) diff --git a/py/torch_tensorrt/fx/converters/acc_ops_converters.py b/py/torch_tensorrt/fx/converters/acc_ops_converters.py index fcb9ce9aad..222f1d2a5c 100644 --- a/py/torch_tensorrt/fx/converters/acc_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/acc_ops_converters.py @@ -26,7 +26,7 @@ trt_transposed_matmul, ) from torch_tensorrt.fx.tracer.acc_tracer.acc_ops import contiguous -from torch_tensorrt.fx.converters.impl import activation +from torch_tensorrt.fx.converters.impl import activation, convolution from torch_tensorrt.fx.converters.impl.elementwise import trunc_div from torch_tensorrt.fx.converters.impl.unary import sign from torch_tensorrt.fx.converters.impl.elementwise.base import ( diff --git a/py/torch_tensorrt/fx/converters/aten_ops_converters.py b/py/torch_tensorrt/fx/converters/aten_ops_converters.py index ef7b7bf5b7..bee36207da 100644 --- a/py/torch_tensorrt/fx/converters/aten_ops_converters.py +++ b/py/torch_tensorrt/fx/converters/aten_ops_converters.py @@ -20,7 +20,7 @@ from .converter_utils import * # noqa: F403 import torch_tensorrt.fx.tracer.acc_tracer.acc_utils as acc_utils -from torch_tensorrt.fx.converters.impl import activation +from torch_tensorrt.fx.converters.impl import activation, convolution from torch_tensorrt.fx.converters.impl.elementwise import trunc_div from torch_tensorrt.fx.converters.impl.elementwise import rsqrt diff --git a/py/torch_tensorrt/fx/converters/impl/elementwise/base.py b/py/torch_tensorrt/fx/converters/impl/elementwise/base.py index 261e45728f..e79d5048cb 100644 --- a/py/torch_tensorrt/fx/converters/impl/elementwise/base.py +++ b/py/torch_tensorrt/fx/converters/impl/elementwise/base.py @@ -10,7 +10,7 @@ from torch.fx.node import Target from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, TRTElementWiseOp -from torch_tensorrt.fx.utils import torch_dtype_from_trt +from torch_tensorrt.fx.utils import unified_dtype_converter, Frameworks from torch_tensorrt.fx.converters.converter_utils import ( SourceIR, set_layer_name, @@ -77,10 +77,10 @@ def convert_binary_elementwise( is_rhs_trt_tensor = False if isinstance(lhs_val, TRTTensor): - lhs_dtype = torch_dtype_from_trt(lhs_val.dtype) + lhs_dtype = unified_dtype_converter(lhs_val.dtype, Frameworks.TORCH) is_lhs_trt_tensor = True if isinstance(rhs_val, TRTTensor): - rhs_dtype = torch_dtype_from_trt(rhs_val.dtype) + rhs_dtype = unified_dtype_converter(rhs_val.dtype, Frameworks.TORCH) is_rhs_trt_tensor = True if not is_lhs_trt_tensor and not is_rhs_trt_tensor: diff --git a/py/torch_tensorrt/fx/converters/impl/elementwise/ops.py b/py/torch_tensorrt/fx/converters/impl/elementwise/ops.py index 8fddb426a6..79f3de90dd 100644 --- a/py/torch_tensorrt/fx/converters/impl/elementwise/ops.py +++ b/py/torch_tensorrt/fx/converters/impl/elementwise/ops.py @@ -10,7 +10,7 @@ from torch.fx.node import Target from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, TRTElementWiseOp -from torch_tensorrt.fx.utils import torch_dtype_from_trt +from torch_tensorrt.fx.utils import unified_dtype_converter, Frameworks from torch_tensorrt.fx.converters.converter_utils import ( SourceIR, get_trt_tensor, @@ -70,7 +70,10 @@ def trunc_div( input = get_trt_tensor(network, input, f"{name}_input") if not isinstance(other, trt.tensorrt.ITensor): other = get_trt_tensor( - network, other, f"{name}_other", dtype=torch_dtype_from_trt(input.dtype) + network, + other, + f"{name}_other", + dtype=unified_dtype_converter(input.dtype, Frameworks.TORCH), ) abs_input_output = convert_unary( diff --git a/py/torch_tensorrt/fx/converters/impl/shape.py b/py/torch_tensorrt/fx/converters/impl/shape.py index 8667c712b8..6767dbb4a2 100644 --- a/py/torch_tensorrt/fx/converters/impl/shape.py +++ b/py/torch_tensorrt/fx/converters/impl/shape.py @@ -10,7 +10,7 @@ from torch.fx.node import Target from torch_tensorrt.fx.types import TRTNetwork, TRTTensor, TRTElementWiseOp -from torch_tensorrt.fx.utils import torch_dtype_from_trt +from torch_tensorrt.fx.utils import unified_dtype_converter, Frameworks from torch_tensorrt.fx.converters.converter_utils import ( SourceIR, set_layer_name,