@@ -678,7 +678,13 @@ def acc_ops_batch_norm(
 
 
 @tensorrt_converter(acc_ops.layer_norm)
-def acc_ops_layer_norm(network, target, args, kwargs, name):
+def acc_ops_layer_norm(
+    network: TRTNetwork,
+    target: Target,
+    args: Tuple[Argument, ...],
+    kwargs: Dict[str, Argument],
+    name: str,
+) -> Union[TRTTensor, Sequence[TRTTensor]]:
     return add_layer_norm(network, target, kwargs, name)
 
 
@@ -690,37 +696,7 @@ def acc_ops_softmax(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    input_ranks = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)  # type: ignore[union-attr]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"softmax received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    # Used to get dim when dim is None. Copied from PyTorch softmax implementation.
-    def get_softmax_dim(ndim: int) -> int:
-        if ndim == 0 or ndim == 1 or ndim == 3:
-            ret = 0
-        else:
-            ret = 1
-        return ret
-
-    if kwargs["dim"] is None:
-        dim = get_softmax_dim(input_ranks)
-    else:
-        dim = cast(int, kwargs["dim"])
-
-    dim = get_positive_dim(dim, input_ranks)
-    if network.has_implicit_batch_dimension:
-        assert dim != 0, "Can't apply softmax on batch dimension when it's implicit."
-        dim -= 1
-
-    layer = network.add_softmax(input_val)
-    layer.axes = 1 << dim
-    set_layer_name(layer, target, name)
-    return layer.get_output(0)
+    return add_softmax(network, target, kwargs, name)
 
 
 @tensorrt_converter(acc_ops.tile)
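
For reference, the converter body removed above presumably moved into the new `add_softmax` helper called at the call site. A condensed sketch under that assumption (the helper's actual module and internals are not shown in this diff; it relies on the same names this file already imports, e.g. `cast`, `get_positive_dim`, `set_layer_name`):

```python
def add_softmax(network, target, kwargs, name):
    # Hypothetical reconstruction, condensed from the body removed above;
    # the isinstance/RuntimeError guard is elided here.
    input_val = kwargs["input"]
    ndim = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)

    # PyTorch's default-dim rule: dim 0 for 0-D/1-D/3-D inputs, else dim 1.
    dim = kwargs["dim"] if kwargs["dim"] is not None else (0 if ndim in (0, 1, 3) else 1)
    dim = get_positive_dim(cast(int, dim), ndim)
    if network.has_implicit_batch_dimension:
        assert dim != 0, "Can't apply softmax on batch dimension when it's implicit."
        dim -= 1  # TensorRT does not count the implicit batch dimension

    layer = network.add_softmax(input_val)
    layer.axes = 1 << dim  # ISoftMaxLayer encodes the softmax axis as a one-hot bitmask
    set_layer_name(layer, target, name)
    return layer.get_output(0)
```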
@@ -956,9 +932,7 @@ def acc_ops_sqrt(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    operation_type = trt.UnaryOperation.SQRT
-    return add_unary_layer(network, input_val, operation_type, target, name)
+    return add_sqrt(network, target, kwargs, name)
 
 
 @tensorrt_converter(acc_ops.reciprocal)
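
The removed sqrt body was already a thin dispatch to the shared `add_unary_layer` utility, so the new `add_sqrt` helper presumably just moves that dispatch behind the common `(network, target, kwargs, name)` signature. A minimal sketch under that assumption:

```python
def add_sqrt(network, target, kwargs, name):
    # Hypothetical reconstruction: same unary-layer dispatch as the body
    # removed above, now keyed off kwargs like the other add_* helpers.
    input_val = kwargs["input"]
    return add_unary_layer(network, input_val, trt.UnaryOperation.SQRT, target, name)
```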
@@ -1619,40 +1593,7 @@ def acc_ops_squeeze(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"squeeze received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    dim = cast(Optional[int], kwargs["dim"] if "dim" in kwargs else None)
-    # Squeeze with dim=None would only work in explicit batch dim mode without any dynamic
-    # dim, which is a very rare case. For now we just claim not supporting dim=None.
-    assert dim is not None, "We don't support dim=None right now for squeeze."
-
-    dim = get_positive_dim(
-        dim, len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
-    )
-    if network.has_implicit_batch_dimension:
-        assert dim != 0, "We don't support squeeze batch dim when it's implicit."
-        dim -= 1
-
-    assert input_val.shape[dim] != -1, "We don't support squeeze dynamic dim."
-    assert (
-        len(get_dynamic_dims(input_val.shape)) <= 1
-    ), "Currently more than one dynamic dim for input to squeeze is not supported."
-
-    output_shape = []
-    for i, s in enumerate(input_val.shape):
-        if i == dim and s == 1:
-            continue
-        output_shape.append(s)
-    layer = network.add_shuffle(input_val)
-    layer.reshape_dims = tuple(output_shape)
-    set_layer_name(layer, target, name)
-    return layer.get_output(0)
+    return add_squeeze(network, target, kwargs, name)
 
 
 @tensorrt_converter(acc_ops.add)
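
Squeeze in TensorRT is a reshape: the removed body drops the size-1 dimension through an IShuffleLayer, and the new `add_squeeze` helper presumably wraps the same logic. A condensed sketch (assumed signature; the dim=None and dynamic-shape assertions from the original are elided):

```python
def add_squeeze(network, target, kwargs, name):
    # Hypothetical reconstruction, condensed from the body removed above.
    input_val = kwargs["input"]
    ndim = len(input_val.shape) + (1 if network.has_implicit_batch_dimension else 0)
    dim = get_positive_dim(cast(int, kwargs["dim"]), ndim)
    if network.has_implicit_batch_dimension:
        dim -= 1  # TensorRT shapes exclude the implicit batch dimension

    # Rebuild the shape without the squeezed size-1 dimension, then reshape.
    output_shape = [s for i, s in enumerate(input_val.shape) if not (i == dim and s == 1)]
    layer = network.add_shuffle(input_val)
    layer.reshape_dims = tuple(output_shape)
    set_layer_name(layer, target, name)
    return layer.get_output(0)
```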
@@ -2022,89 +1963,7 @@ def acc_ops_where(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-
-    condition_t = kwargs["condition"]
-    x_t = kwargs["x"]
-    y_t = kwargs["y"]
-
-    if type(x_t) != TRTTensor:
-        assert type(x_t) is torch.Tensor, f"value {x_t} is not torch.Tensor!"
-
-    if type(y_t) != TRTTensor:
-        assert type(y_t) is torch.Tensor, f"value {y_t} is not torch.Tensor!"
-
-    # get output shape
-
-    x_shape = list(x_t.shape)
-    y_shape = list(y_t.shape)
-    condition_shape = list(condition_t.shape)
-    output_shape = list(torch.broadcast_shapes(condition_shape, x_shape, y_shape))
-
-    # expand shape
-    if type(condition_t) != TRTTensor:
-        assert condition_t.dtype == torch.bool, "condition dtype is not bool"
-        if condition_shape != output_shape:
-            condition_t.expand(output_shape)
-        condition_t = condition_t.to(torch.int32)
-        condition_const = get_trt_tensor(network, condition_t, f"{name}_condition")
-        condition_layer = network.add_identity(condition_const)
-        condition_layer.set_output_type(0, trt.bool)
-        set_layer_name(condition_layer, target, f"{name}_condition")
-        condition_val = condition_layer.get_output(0)
-    else:
-        assert condition_t.dtype == trt.bool, "mask dtype is not bool!"
-        if condition_shape != output_shape:
-            condition_val = acc_ops_expand_tensor(
-                network,
-                target,
-                None,
-                {"input": condition_t, "sizes": output_shape},
-                name=f"{name}_expand",
-            )
-        else:
-            condition_val = condition_t
-
-    if type(x_t) != TRTTensor:
-        if x_shape != output_shape:
-            # special case where 1 element in x_t
-            if len(x_t.shape) == 0:
-                x_t = x_t.unsqueeze(0)
-            x_t = x_t.expand(output_shape)
-        x_val = get_trt_tensor(network, x_t, f"{name}_x")
-    else:
-        x_val = x_t
-        if x_shape != output_shape:
-            x_val = acc_ops_expand_tensor(
-                network,
-                target,
-                None,
-                {"input": x_val, "sizes": output_shape},
-                name=f"{name}_x_expand",
-            )
-
-    if type(y_t) != TRTTensor:
-        if y_shape != output_shape:
-            # special case where 1 element in y_t
-            if len(y_t.shape) == 0:
-                y_t = y_t.unsqueeze(0)
-            y_t = y_t.expand(output_shape)
-        y_val = get_trt_tensor(network, y_t, f"{name}_y")
-    else:
-        y_val = y_t
-        if y_shape != output_shape:
-            y_val = acc_ops_expand_tensor(
-                network,
-                target,
-                None,
-                {"input": y_val, "sizes": output_shape},
-                name=f"{name}_y_expand",
-            )
-
-    select_layer = network.add_select(condition_val, x_val, y_val)
-
-    set_layer_name(select_layer, target, f"{name}_select")
-
-    return select_layer.get_output(0)
+    return add_where(network, target, kwargs, name)
 
 
 @tensorrt_converter(acc_ops.masked_fill, no_implicit_batch_dim=True)
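
`torch.where` maps onto TensorRT's ISelectLayer once condition, x, and y share a common broadcast shape, which is what the long body removed above arranged before selecting. The new `add_where` helper presumably keeps that logic; a heavily condensed sketch covering only the happy path, where all three operands are already ITensors of the output shape:

```python
def add_where(network, target, kwargs, name):
    # Hypothetical reconstruction of only the final step removed above. The
    # full version also materializes torch.Tensor operands via get_trt_tensor
    # and broadcasts mismatched shapes with acc_ops_expand_tensor first.
    condition_val, x_val, y_val = kwargs["condition"], kwargs["x"], kwargs["y"]
    select_layer = network.add_select(condition_val, x_val, y_val)
    set_layer_name(select_layer, target, f"{name}_select")
    return select_layer.get_output(0)
```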
@@ -2721,62 +2580,7 @@ def acc_ops_chunk(
     kwargs: Dict[str, Argument],
     name: str,
 ) -> Union[TRTTensor, Sequence[TRTTensor]]:
-    input_val = kwargs["input"]
-    chunks = cast(int, kwargs["chunks"])
-    dim = cast(int, kwargs["dim"])
-    input_dim_size = len(input_val.shape)  # type: ignore[union-attr]
-
-    if not isinstance(input_val, TRTTensor):
-        raise RuntimeError(
-            f"chunk received input {input_val} that is not part "
-            "of the TensorRT region!"
-        )
-
-    dynamic_shape = has_dynamic_shape(input_val.shape)
-    if network.has_implicit_batch_dimension:
-        input_dim_size += 1
-        dim = get_positive_dim(dim, input_dim_size)
-        assert dim != 0, "Can't chunk on batch dim when it's implicit!"
-        dim -= 1
-    else:
-        if dynamic_shape:
-            assert input_val.shape[dim] != -1, "Can't chunk on dynamic shape dimension!"
-        dim = get_positive_dim(dim, input_dim_size)
-
-    if chunks > input_val.shape[dim]:
-        warnings.warn(
-            f"Asked for {chunks} chunks along dimention "
-            f"{dim} on tensor with size {input_val.shape}, chunks "
-            f"will default to {input_val.shape[dim]}",
-            RuntimeWarning,
-        )
-        chunks = input_val.shape[dim]
-
-    start = [0] * len(input_val.shape)
-    stride = [1] * len(start)
-    offset = 0
-    split_size = (input_val.shape[dim] + chunks - 1) // chunks
-
-    max_offset = input_val.shape[dim]
-    # add slice layers
-    output = []
-    for i in range(chunks):
-        shape = list(input_val.shape)
-        shape[dim] = min(split_size, max_offset - offset)
-        if dynamic_shape:
-            shape = get_shape_with_dynamic_shape(
-                network, shape, input_val, target, f"{name}_{i}"
-            )
-        start[dim] = offset
-        layer = network.add_slice(
-            input_val, start=start, shape=[] if dynamic_shape else shape, stride=stride
-        )
-        if dynamic_shape:
-            layer.set_input(2, shape)
-        offset += split_size
-        set_layer_name(layer, target, f"{name}_{i}")
-        output.append(layer.get_output(0))
-    return output
+    return add_chunk(network, target, kwargs, name)
 
 
 @tensorrt_converter(acc_ops.cumsum, no_implicit_batch_dim=True)
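
`chunk` lowers to one ISliceLayer per chunk, stepping through the split dimension in `split_size` strides, and the removed body (presumably now in `add_chunk`) is what the call site delegates to. A condensed sketch for the static-shape, explicit-batch case (assumed signature; dynamic shapes and the implicit-batch adjustment are elided):

```python
def add_chunk(network, target, kwargs, name):
    # Hypothetical reconstruction, condensed from the body removed above.
    input_val = kwargs["input"]
    chunks = cast(int, kwargs["chunks"])
    dim = get_positive_dim(cast(int, kwargs["dim"]), len(input_val.shape))
    chunks = min(chunks, input_val.shape[dim])  # can't produce more chunks than elements

    start = [0] * len(input_val.shape)
    stride = [1] * len(start)
    split_size = (input_val.shape[dim] + chunks - 1) // chunks  # ceiling division

    output = []
    for i in range(chunks):
        shape = list(input_val.shape)
        shape[dim] = min(split_size, input_val.shape[dim] - start[dim])  # last chunk may be short
        layer = network.add_slice(input_val, start=list(start), shape=shape, stride=stride)
        set_layer_name(layer, target, f"{name}_{i}")
        output.append(layer.get_output(0))
        start[dim] += split_size
    return output
```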