@@ -800,6 +800,40 @@ def aten_ops_tile(
800
800
)
801
801
802
802
803
+ def zero_output_validator (node : Node ) -> bool :
804
+ if 0 in node .args [1 ]:
805
+ _LOGGER .debug (
806
+ f"We do not support output tensor { node .args [1 ]} tensors with zero-sized dimensions for this operation."
807
+ )
808
+ return False
809
+ else :
810
+ return True
811
+
812
+
813
+ @dynamo_tensorrt_converter (
814
+ torch .ops .aten .as_strided .default ,
815
+ capability_validator = zero_output_validator ,
816
+ )
817
+ @dynamo_tensorrt_converter (torch .ops .aten .as_strided .default )
818
+ def aten_ops_as_strided (
819
+ ctx : ConversionContext ,
820
+ target : Target ,
821
+ args : Tuple [Argument , ...],
822
+ kwargs : Dict [str , Argument ],
823
+ name : str ,
824
+ ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
825
+ return impl .slice .as_strided (
826
+ ctx ,
827
+ target ,
828
+ source_ir = SourceIR .ATEN ,
829
+ name = name ,
830
+ input = args [0 ],
831
+ size = args [1 ],
832
+ stride = args [2 ],
833
+ storage_offset = args_bounds_check (args , 3 , None ),
834
+ )
835
+
836
+
803
837
@dynamo_tensorrt_converter (torch .ops .aten .permute .default )
804
838
@enforce_tensor_types (
805
839
{
@@ -2186,40 +2220,6 @@ def aten_ops_linear(
2186
2220
)
2187
2221
2188
2222
2189
- def zero_output_validator (node : Node ) -> bool :
2190
- if 0 in node .args [1 ]:
2191
- _LOGGER .debug (
2192
- f"We do not support output tensor { node .args [1 ]} tensors with zero-sized dimensions for this operation."
2193
- )
2194
- return False
2195
- else :
2196
- return True
2197
-
2198
-
2199
- @dynamo_tensorrt_converter (
2200
- torch .ops .aten .as_strided .default ,
2201
- capability_validator = zero_output_validator ,
2202
- )
2203
- @dynamo_tensorrt_converter (torch .ops .aten .as_strided .default )
2204
- def aten_ops_as_strided (
2205
- ctx : ConversionContext ,
2206
- target : Target ,
2207
- args : Tuple [Argument , ...],
2208
- kwargs : Dict [str , Argument ],
2209
- name : str ,
2210
- ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
2211
- return impl .slice .as_strided (
2212
- ctx ,
2213
- target ,
2214
- source_ir = SourceIR .ATEN ,
2215
- name = name ,
2216
- input = args [0 ],
2217
- size = args [1 ],
2218
- stride = args [2 ],
2219
- storage_offset = args_bounds_check (args , 3 , None ),
2220
- )
2221
-
2222
-
2223
2223
def avg_pool_param_validator (pool_node : Node ) -> bool :
2224
2224
ceil_mode = args_bounds_check (pool_node .args , 4 , False )
2225
2225
divisor_override = args_bounds_check (pool_node .args , 6 )
0 commit comments