2121from .converter_utils import * # noqa: F403
2222import torch_tensorrt .fx .tracer .acc_tracer .acc_utils as acc_utils
2323from torch_tensorrt .fx .converters .impl import activation , convolution
24- from torch_tensorrt .fx .converters .impl .elementwise import trunc_div
25- from torch_tensorrt .fx .converters .impl .elementwise import rsqrt
26- from torch_tensorrt .fx .converters .impl .elementwise import fmod
27- from torch_tensorrt .fx .converters .impl .elementwise import rsub
28- from torch_tensorrt .fx .converters .impl .normalization import batch_norm
29- from torch_tensorrt .fx .converters .impl .normalization import layer_norm
30- from torch_tensorrt .fx .converters .impl .normalization import softmax
31- from torch_tensorrt .fx .converters .impl .squeeze import squeeze
32- from torch_tensorrt .fx .converters .impl .select import select
33- from torch_tensorrt .fx .converters .impl .slice import slice_op
34- from torch_tensorrt .fx .converters .impl .matmul import matrix_multiply
35- from torch_tensorrt .fx .converters .impl .condition import where
36- from torch_tensorrt .fx .converters .impl .unsqueeze import unsqueeze
37- from torch_tensorrt .fx .converters .impl .elementwise import clamp
3824
3925_LOGGER : logging .Logger = logging .getLogger (__name__ )
4026
41-
42- def or_none (args , i ):
43- return args [i ] if len (args ) > i else None
44-
45-
4627## converter list in alphabetic order
4728@tensorrt_converter (torch .ops .aten .add .Tensor )
4829def aten_ops_add (
@@ -108,19 +89,18 @@ def aten_ops_batch_norm(
10889 kwargs : Dict [str , Argument ],
10990 name : str ,
11091) -> Union [TRTTensor , Sequence [TRTTensor ]]:
111- return batch_norm (
112- network ,
113- target ,
114- SourceIR .ATEN ,
115- name ,
116- args [0 ],
117- args [1 ],
118- args [2 ],
119- args [3 ],
120- args [4 ],
121- args [5 ],
122- args [6 ],
123- args [7 ],
92+ kwargs_new = {
93+ "input" : args [0 ],
94+ "weight" : args [1 ],
95+ "bias" : args [2 ],
96+ "running_mean" : args [3 ],
97+ "running_var" : args [4 ],
98+ "training" : args [5 ],
99+ "momentum" : args [6 ],
100+ "eps" : args [7 ],
101+ }
102+ return acc_ops_converters .acc_ops_batch_norm (
103+ network , target , None , kwargs_new , name
124104 )
125105
126106
@@ -202,7 +182,9 @@ def aten_ops_div(
202182 network , target , None , kwargs_new , name
203183 )
204184 elif rounding_mode == "trunc" :
205- return trunc_div (network , target , SourceIR .ATEN , name , args [0 ], args [1 ])
185+ return acc_ops_converters .acc_ops_trunc_div (
186+ network , target , None , kwargs_new , name
187+ )
206188 else :
207189 raise RuntimeError (
208190 f"Target { target } does not support rounding mode { rounding_mode } "
@@ -260,7 +242,11 @@ def aten_ops_fmod(
260242 kwargs : Dict [str , Argument ],
261243 name : str ,
262244) -> Union [TRTTensor , Sequence [TRTTensor ]]:
263- return fmod (network , target , SourceIR .ATEN , name , args [0 ], args [1 ])
245+ kwargs_new = {
246+ "input" : args [0 ],
247+ "other" : args [1 ],
248+ }
249+ return acc_ops_converters .acc_ops_fmod (network , target , None , kwargs_new , name )
264250
265251
266252@tensorrt_converter (torch .ops .aten .hardtanh .default )
@@ -271,40 +257,12 @@ def aten_ops_hardtanh(
271257 kwargs : Dict [str , Argument ],
272258 name : str ,
273259) -> Union [TRTTensor , Sequence [TRTTensor ]]:
260+
274261 return activation .hardtanh (
275262 network , target , SourceIR .ATEN , name , args [0 ], args [1 ], args [2 ]
276263 )
277264
278265
279- @tensorrt_converter (torch .ops .aten .gelu .default )
280- def aten_ops_gelu (
281- network : TRTNetwork ,
282- target : Target ,
283- args : Tuple [Argument , ...],
284- kwargs : Dict [str , Argument ],
285- name : str ,
286- ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
287- return activation .gelu (
288- network ,
289- target ,
290- SourceIR .ATEN ,
291- name ,
292- args [0 ],
293- )
294-
295-
296- @tensorrt_converter (torch .ops .aten .matmul )
297- @tensorrt_converter (torch .ops .aten .mm .default )
298- def aten_ops_matmul (
299- network : TRTNetwork ,
300- target : Target ,
301- args : Tuple [Argument , ...],
302- kwargs : Dict [str , Argument ],
303- name : str ,
304- ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
305- return matrix_multiply (network , target , SourceIR .ATEN , name , args [0 ], args [1 ])
306-
307-
308266@tensorrt_converter (torch .ops .aten .fmod .Tensor )
309267def aten_ops_fmod (
310268 network : TRTNetwork ,
@@ -328,28 +286,8 @@ def aten_ops_leaky_relu(
328286 kwargs : Dict [str , Argument ],
329287 name : str ,
330288) -> Union [TRTTensor , Sequence [TRTTensor ]]:
331- return activation .leaky_relu (network , target , SourceIR .ATEN , name , args [0 ], args [1 ])
332-
333289
334- @tensorrt_converter (torch .ops .aten .layer_norm .default )
335- def aten_ops_layernorm (
336- network : TRTNetwork ,
337- target : Target ,
338- args : Tuple [Argument , ...],
339- kwargs : Dict [str , Argument ],
340- name : str ,
341- ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
342- return layer_norm (
343- network ,
344- target ,
345- SourceIR .ATEN ,
346- name ,
347- args [0 ],
348- args [1 ],
349- args [2 ],
350- args [3 ],
351- args [4 ],
352- )
290+ return activation .leaky_relu (network , target , SourceIR .ATEN , name , args [0 ], args [1 ])
353291
354292
355293@tensorrt_converter (torch .ops .aten .linear )
@@ -452,42 +390,6 @@ def aten_ops_relu(
452390 )
453391
454392
455- @tensorrt_converter (torch .ops .aten .relu .default )
456- def aten_ops_relu (
457- network : TRTNetwork ,
458- target : Target ,
459- args : Tuple [Argument , ...],
460- kwargs : Dict [str , Argument ],
461- name : str ,
462- ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
463-
464- return activation .relu (
465- network ,
466- target ,
467- SourceIR .ATEN ,
468- name ,
469- args [0 ],
470- )
471-
472-
473- @tensorrt_converter (torch .ops .aten .rsqrt .default )
474- def aten_ops_rsqrt (
475- network : TRTNetwork ,
476- target : Target ,
477- args : Tuple [Argument , ...],
478- kwargs : Dict [str , Argument ],
479- name : str ,
480- ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
481-
482- return rsqrt (
483- network ,
484- target ,
485- SourceIR .ATEN ,
486- name ,
487- args [0 ],
488- )
489-
490-
491393@tensorrt_converter (torch .ops .aten .sub .Tensor )
492394def aten_ops_sub (
493395 network : TRTNetwork ,
@@ -503,29 +405,6 @@ def aten_ops_sub(
503405 return acc_ops_converters .acc_ops_sub (network , target , None , kwargs_new , name )
504406
505407
506- @tensorrt_converter (torch .ops .aten .squeeze .dim )
507- @tensorrt_converter (torch .ops .aten .squeeze .dims )
508- def aten_ops_squeeze (
509- network : TRTNetwork ,
510- target : Target ,
511- args : Tuple [Argument , ...],
512- kwargs : Dict [str , Argument ],
513- name : str ,
514- ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
515- return squeeze (network , target , SourceIR .ATEN , name , args [0 ], args [1 ])
516-
517-
518- @tensorrt_converter (torch .ops .aten .unsqueeze .default )
519- def aten_ops_unsqueeze (
520- network : TRTNetwork ,
521- target : Target ,
522- args : Tuple [Argument , ...],
523- kwargs : Dict [str , Argument ],
524- name : str ,
525- ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
526- return unsqueeze (network , target , SourceIR .ATEN , name , input_t = args [0 ], dim = args [1 ])
527-
528-
529408@tensorrt_converter (torch .ops .aten .view .default )
530409def aten_ops_reshape (
531410 network : TRTNetwork ,
@@ -563,31 +442,6 @@ def aten_ops_reshape(
563442 return layer .get_output (0 )
564443
565444
566- @tensorrt_converter (torch .ops .aten .rsub .Tensor )
567- def aten_ops_rsub (
568- network : TRTNetwork ,
569- target : Target ,
570- args : Tuple [Argument , ...],
571- kwargs : Dict [str , Argument ],
572- name : str ,
573- ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
574- alpha = None
575- if "alpha" in kwargs :
576- alpha = kwargs ["alpha" ]
577- return rsub (network , target , SourceIR .ATEN , name , args [0 ], args [1 ], alpha )
578-
579-
580- @tensorrt_converter (torch .ops .aten ._softmax .default )
581- def aten_ops_softmax (
582- network : TRTNetwork ,
583- target : Target ,
584- args : Tuple [Argument , ...],
585- kwargs : Dict [str , Argument ],
586- name : str ,
587- ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
588- return softmax (network , target , SourceIR .ATEN , name , args [0 ], args [1 ])
589-
590-
591445@tensorrt_converter (torch .ops .aten .tanh .default )
592446def aten_ops_tanh (
593447 network : TRTNetwork ,
@@ -596,30 +450,12 @@ def aten_ops_tanh(
596450 kwargs : Dict [str , Argument ],
597451 name : str ,
598452) -> Union [TRTTensor , Sequence [TRTTensor ]]:
599- return activation .tanh (
600- network ,
601- target ,
602- SourceIR .ATEN ,
603- name ,
604- args [0 ],
605- )
606453
607-
608- @tensorrt_converter (torch .ops .aten .where .self )
609- def aten_ops_where (
610- network : TRTNetwork ,
611- target : Target ,
612- args : Tuple [Argument , ...],
613- kwargs : Dict [str , Argument ],
614- name : str ,
615- ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
616- return where (
454+ return activation .tanh (
617455 network ,
618456 target ,
619457 SourceIR .ATEN ,
620458 name ,
621- args [1 ],
622- args [2 ],
623459 args [0 ],
624460 )
625461
@@ -639,25 +475,6 @@ def aten_ops_cat(
639475 return acc_ops_converters .acc_ops_cat (network , target , None , kwargs_new , name )
640476
641477
642- @tensorrt_converter (torch .ops .aten .clamp .default )
643- def aten_ops_clamp (
644- network : TRTNetwork ,
645- target : Target ,
646- args : Tuple [Argument , ...],
647- kwargs : Dict [str , Argument ],
648- name : str ,
649- ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
650- return clamp .clamp (
651- network ,
652- target ,
653- SourceIR .ACC ,
654- name ,
655- input_val = args [0 ],
656- min_val = or_none (args , 1 ),
657- max_val = or_none (args , 2 ),
658- )
659-
660-
661478@tensorrt_converter (torch .ops .aten .expand .default )
662479def aten_ops_expand (
663480 network : TRTNetwork ,
@@ -720,17 +537,6 @@ def aten_ops_operator_add(
720537 return acc_ops_converters .acc_ops_add (network , target , None , kwargs_new , name )
721538
722539
723- @tensorrt_converter (torch .ops .aten .select .int )
724- def aten_ops_select (
725- network : TRTNetwork ,
726- target : Target ,
727- args : Tuple [Argument , ...],
728- kwargs : Dict [str , Argument ],
729- name : str ,
730- ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
731- return select (network , target , SourceIR .ATEN , name , args [0 ], args [1 ], args [2 ])
732-
733-
734540@tensorrt_converter (operator .sub )
735541def aten_ops_operator_sub (
736542 network : TRTNetwork ,
@@ -766,27 +572,6 @@ def aten_ops_sym_numel(
766572 return reduce_layer .get_output (0 )
767573
768574
769- @tensorrt_converter (torch .ops .aten .slice .Tensor )
770- def aten_ops_slice (
771- network : TRTNetwork ,
772- target : Target ,
773- args : Tuple [Argument , ...],
774- kwargs : Dict [str , Argument ],
775- name : str ,
776- ) -> Union [TRTTensor , Sequence [TRTTensor ]]:
777- return slice_op (
778- network ,
779- target ,
780- SourceIR .ATEN ,
781- name ,
782- args [0 ],
783- args [1 ],
784- args [2 ],
785- args [3 ],
786- args [4 ],
787- )
788-
789-
790575@tensorrt_converter (torch .ops .aten .sym_size )
791576def aten_ops_sym_size (
792577 network : TRTNetwork ,
0 commit comments