2121from ..utils import get_dynamic_dims , torch_dtype_from_trt , torch_dtype_to_trt
2222
2323from .converter_utils import * # noqa: F403
24+ from .activation import *
25+ from .operator import *
26+
2427import torch_tensorrt .fx .tracer .acc_tracer .acc_utils as acc_utils
25- import activation
26- import operator
28+
2729
2830_LOGGER : logging .Logger = logging .getLogger (__name__ )
2931
@@ -40,7 +42,7 @@ def aten_ops_add(
4042 "input" : args [0 ],
4143 "other" : args [1 ],
4244 }
43- return operator . add_add (network , target , None , kwargs_new , name )
45+ return add_add (network , target , None , kwargs_new , name )
4446
4547
4648@tensorrt_converter (torch .ops .aten .mean .dim )
@@ -143,13 +145,13 @@ def aten_ops_div(
143145 }
144146 rounding_mode = kwargs .get ("rounding_mode" )
145147 if rounding_mode is None :
146- return operator . add_div (network , target , None , kwargs_new , name )
148+ return add_div (network , target , None , kwargs_new , name )
147149 elif rounding_mode == "floor" :
148- return operator . add_floor_div (
150+ return add_floor_div (
149151 network , target , None , kwargs_new , name
150152 )
151153 elif rounding_mode == "trunc" :
152- return operator . add_trunc_div (
154+ return add_trunc_div (
153155 network , target , None , kwargs_new , name
154156 )
155157 else :
@@ -170,7 +172,7 @@ def aten_ops_floor_div(
170172 "input" : args [0 ],
171173 "other" : args [1 ],
172174 }
173- return operator . add_floor_div (network , target , None , kwargs_new , name )
175+ return add_floor_div (network , target , None , kwargs_new , name )
174176
175177
176178@tensorrt_converter (torch .ops .aten .fmod .Scalar )
@@ -186,7 +188,7 @@ def aten_ops_fmod(
186188 "input" : args [0 ],
187189 "other" : args [1 ],
188190 }
189- return operator . add_fmod (network , target , None , kwargs_new , name )
191+ return add_fmod (network , target , None , kwargs_new , name )
190192
191193
192194@tensorrt_converter (torch .ops .aten .linear )
@@ -203,7 +205,7 @@ def aten_ops_linear(
203205 "bias" : args [2 ],
204206 }
205207
206- return operator . add_linear (network , target , None , kwargs_new , name )
208+ return add_linear (network , target , None , kwargs_new , name )
207209
208210
209211@tensorrt_converter (torch .ops .aten .max_pool3d )
@@ -252,10 +254,11 @@ def aten_ops_mul(
252254 "input" : args [0 ],
253255 "other" : args [1 ],
254256 }
255- return operator . add_mul (network , target , None , kwargs_new , name )
257+ return add_mul (network , target , None , kwargs_new , name )
256258
257259
258- @tensorrt_converter (torch .ops .aten .matmul .Tensor )
260+ @tensorrt_converter (torch .ops .aten .matmul )
261+ @tensorrt_converter (torch .ops .aten .mm .default )
259262def aten_ops_matmul (
260263 network : TRTNetwork ,
261264 target : Target ,
@@ -267,7 +270,7 @@ def aten_ops_matmul(
267270 "input" : args [0 ],
268271 "other" : args [1 ],
269272 }
270- return operator . add_matmul (network , target , None , kwargs_new , name )
273+ return add_matmul (network , target , None , kwargs_new , name )
271274
272275
273276@tensorrt_converter (torch .ops .aten .pow .Tensor_Scalar )
@@ -283,7 +286,7 @@ def aten_ops_pow(
283286 "input" : args [0 ],
284287 "exponent" : args [1 ],
285288 }
286- return operator . add_pow (network , target , None , kwargs_new , name )
289+ return add_pow (network , target , kwargs_new , name )
287290
288291
289292@tensorrt_converter (torch .ops .aten .relu .default )
@@ -297,7 +300,7 @@ def aten_ops_relu(
297300 kwargs_new = {
298301 "input" : args [0 ],
299302 }
300- return activation . add_relu (network , target , kwargs_new , name )
303+ return add_relu (network , target , kwargs_new , name )
301304
302305@tensorrt_converter (torch .ops .aten .sub .Tensor )
303306def aten_ops_sub (
@@ -311,7 +314,7 @@ def aten_ops_sub(
311314 "input" : args [0 ],
312315 "other" : args [1 ],
313316 }
314- return operator . add_sub (network , target , None , kwargs_new , name )
317+ return add_sub (network , target , None , kwargs_new , name )
315318
316319
317320@tensorrt_converter (torch .ops .aten .view .default )
@@ -378,7 +381,7 @@ def aten_ops_expand(
378381 "input" : args [0 ],
379382 "sizes" : args [1 ],
380383 }
381- return operator . add_expand (network , target , kwargs_new , name )
384+ return add_expand (network , target , kwargs_new , name )
382385
383386
384387@tensorrt_converter (operator .floordiv )
@@ -393,7 +396,7 @@ def aten_ops_operator_floordiv(
393396 "input" : args [0 ],
394397 "other" : args [1 ],
395398 }
396- return operator . add_floor_div (network , target , None , kwargs_new , name )
399+ return add_floor_div (network , target , None , kwargs_new , name )
397400
398401
399402@tensorrt_converter (operator .mul )
@@ -408,7 +411,7 @@ def aten_ops_operator_mul(
408411 "input" : args [0 ],
409412 "other" : args [1 ],
410413 }
411- return operator . add_mul (network , target , None , kwargs_new , name )
414+ return add_mul (network , target , None , kwargs_new , name )
412415
413416
414417@tensorrt_converter (operator .add )
@@ -423,7 +426,7 @@ def aten_ops_operator_add(
423426 "input" : args [0 ],
424427 "other" : args [1 ],
425428 }
426- return operator . add_add (network , target , None , kwargs_new , name )
429+ return add_add (network , target , None , kwargs_new , name )
427430
428431
429432@tensorrt_converter (operator .sub )
@@ -438,7 +441,7 @@ def aten_ops_operator_sub(
438441 "input" : args [0 ],
439442 "other" : args [1 ],
440443 }
441- return operator . add_sub (network , target , None , kwargs_new , name )
444+ return add_sub (network , target , None , kwargs_new , name )
442445
443446
444447@tensorrt_converter (torch .ops .aten .sym_numel )
@@ -497,9 +500,10 @@ def aten_ops_slice(
497500 "stop" : args [3 ],
498501 "step" : args [4 ],
499502 }
500- return operator . add_slice (network , target . kwargs_new , name )
503+ return add_slice (network , target . kwargs_new , name )
501504
502- @tensorrt_converter (torch .ops .aten .select .Tensor )
505+
506+ @tensorrt_converter (torch .ops .aten .select )
503507def aten_ops_select (
504508 network : TRTNetwork ,
505509 target : Target ,
@@ -512,7 +516,7 @@ def aten_ops_select(
512516 "dim" : args [1 ],
513517 "index" : args [2 ],
514518 }
515- return operator . add_select (network , target . kwargs_new , name )
519+ return add_select (network , target . kwargs_new , name )
516520
517521
518522@tensorrt_converter (torch .ops .aten .leaky_relu .default )
@@ -526,7 +530,7 @@ def aten_ops_leaky_relu(
526530 kwargs_new = {
527531 "input" : args [0 ],
528532 }
529- return activation . add_leaky_relu (network , target , kwargs_new , name )
533+ return add_leaky_relu (network , target , kwargs_new , name )
530534
531535
532536@tensorrt_converter (torch .ops .aten .elu .default )
@@ -540,7 +544,7 @@ def aten_ops_elu(
540544 kwargs_new = {
541545 "input" : args [0 ],
542546 }
543- return activation . add_elu (network , target , kwargs_new , name )
547+ return add_elu (network , target , kwargs_new , name )
544548
545549
546550@tensorrt_converter (torch .ops .aten .selu .default )
@@ -554,7 +558,7 @@ def aten_ops_selu(
554558 kwargs_new = {
555559 "input" : args [0 ],
556560 }
557- return activation . selu (network , target , kwargs_new , name )
561+ return add_selu (network , target , kwargs_new , name )
558562
559563
560564@tensorrt_converter (torch .ops .aten .gelu .default )
@@ -568,22 +572,7 @@ def aten_ops_gelu(
568572 kwargs_new = {
569573 "input" : args [0 ],
570574 }
571- return activation .add_gelu (network , target , kwargs_new , name )
572-
573-
574- @tensorrt_converter (torch .ops .aten .softsign .default )
575- def aten_ops_softsign (
576- network : TRTNetwork ,
577- target : Target ,
578- args : Tuple [Argument , ...],
579- kwargs : Dict [str , Argument ],
580- name : str ,
581- ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
582- kwargs_new = {
583- "input" : args [0 ],
584- }
585- return activation .add_softsign (network , target , kwargs_new , name )
586-
575+ return add_gelu (network , target , kwargs_new , name )
587576
588577@tensorrt_converter (torch .ops .aten .tanh .default )
589578def aten_ops_tanh (
@@ -596,34 +585,7 @@ def aten_ops_tanh(
596585 kwargs_new = {
597586 "input" : args [0 ],
598587 }
599- return activation .add_tanh (network , target , kwargs_new , name )
600-
601- @tensorrt_converter (torch .ops .aten .softsign .default )
602- def aten_ops_softsign (
603- network : TRTNetwork ,
604- target : Target ,
605- args : Tuple [Argument , ...],
606- kwargs : Dict [str , Argument ],
607- name : str ,
608- ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
609- kwargs_new = {
610- "input" : args [0 ],
611- }
612- return activation .add_softsign (network , target , kwargs_new , name )
613-
614-
615- @tensorrt_converter (torch .ops .aten .softsign .default )
616- def aten_ops_hard_sigmoid (
617- network : TRTNetwork ,
618- target : Target ,
619- args : Tuple [Argument , ...],
620- kwargs : Dict [str , Argument ],
621- name : str ,
622- ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
623- kwargs_new = {
624- "input" : args [0 ],
625- }
626- return activation .add_hard_sigmoid (network , target , kwargs_new , name )
588+ return add_tanh (network , target , kwargs_new , name )
627589
628590
629591@tensorrt_converter (torch .ops .aten .sigmoid .default )
@@ -637,7 +599,7 @@ def aten_ops_hard_tanh(
637599 kwargs_new = {
638600 "input" : args [0 ],
639601 }
640- return activation . add_hard_tanh (network , target , kwargs_new , name )
602+ return add_hard_tanh (network , target , kwargs_new , name )
641603
642604
643605@tensorrt_converter (torch .ops .aten .sigmoid .default )
@@ -651,7 +613,7 @@ def aten_ops_sigmoid(
651613 kwargs_new = {
652614 "input" : args [0 ],
653615 }
654- return activation . add_sigmoid (network , target , kwargs_new , name )
616+ return add_sigmoid (network , target , kwargs_new , name )
655617
656618
657619
0 commit comments