1+ import torch
2+ from torch_tensorrt .fx .tracer .acc_tracer import acc_ops
3+
4+
def check_permute(node: torch.fx.Node) -> bool:
    """
    Return True iff *node* is a permute that only swaps the last two
    dimensions of its input (i.e. a batched transpose).

    Negative axis indices are normalized with ``i % ranks`` so that e.g.
    ``[-1, -2]`` on a 2-D tensor compares equal to ``[1, 0]``.

    Args:
        node: an FX node whose ``kwargs["permutation"]`` holds the axis
            order and whose ``meta["tensor_meta"].shape`` gives the rank.

    Returns:
        True when the permutation is exactly the identity with the last
        two axes exchanged; False otherwise.
    """
    ranks = len(node.meta["tensor_meta"].shape)
    # Swapping the last two dims requires at least rank 2; lower-rank
    # inputs can never match (and would make the index math ill-defined).
    if ranks < 2:
        return False
    # Normalize (possibly negative) axes into [0, ranks).
    permutation = [i % ranks for i in node.kwargs["permutation"]]  # type: ignore[union-attr]
    # Identity permutation with the final two axes exchanged.
    allowed_permutation = list(range(ranks))
    allowed_permutation[-1] = ranks - 2
    allowed_permutation[-2] = ranks - 1
    return permutation == allowed_permutation
12+
13+
def fuse_permute_matmul(gm: torch.fx.GraphModule):
    """
    Fuse pattern like permute + matmul if permute is transposing the last
    two dimensions.

    For every ``acc_ops.matmul`` node whose lhs and/or rhs operand is a
    last-two-dims permute (per ``check_permute``), rewrite the matmul to a
    single ``trt_transposed_matmul`` call on the permute's *input*, with
    boolean flags recording which side was transposed. The orphaned
    permute/matmul nodes are then removed by dead-code elimination.

    Args:
        gm: graph module to rewrite in place.

    Returns:
        The same ``gm``, linted and recompiled after the rewrite.
    """
    for node in gm.graph.nodes:
        if node.target != acc_ops.matmul:
            continue
        lhs, rhs = node.kwargs["input"], node.kwargs["other"]
        lhs_transposed = rhs_transposed = False

        if lhs.target == acc_ops.permute and check_permute(lhs):
            lhs_transposed = True
            lhs = lhs.kwargs["input"]

        if rhs.target == acc_ops.permute and check_permute(rhs):
            rhs_transposed = True
            rhs = rhs.kwargs["input"]

        if lhs_transposed or rhs_transposed:
            # Insert the fused op where the matmul was and reroute all users.
            with gm.graph.inserting_before(node):
                fused_node = gm.graph.call_function(
                    trt_transposed_matmul,
                    args=(lhs, rhs, lhs_transposed, rhs_transposed),
                )
            node.replace_all_uses_with(fused_node)

    # Drop the now-unused permute/matmul nodes and refresh the module.
    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()
    return gm
44+
45+
def trt_transposed_linear(
    input: torch.Tensor, weight: torch.Tensor, bias: torch.Tensor
):
    """
    Apply a linear layer to *input* after transposing its last two
    dimensions: ``matmul(input.transpose(-1, -2), weight.t()) + bias``.

    Used as the fusion target for a permute-then-linear pattern, so the
    transpose can be folded into the matmul by the converter.
    """
    transposed_input = input.transpose(-1, -2)
    projected = torch.matmul(transposed_input, weight.t())
    return projected + bias
50+
51+
def fuse_permute_linear(gm: torch.fx.GraphModule):
    """
    Fuse pattern like permute + linear if permute is transposing the last
    two dimensions.

    Each ``acc_ops.linear`` node fed by a last-two-dims permute (per
    ``check_permute``) is replaced with a ``trt_transposed_linear`` call
    on the permute's input; dead-code elimination then removes the
    orphaned nodes.

    Args:
        gm: graph module to rewrite in place.

    Returns:
        The same ``gm``, linted and recompiled after the rewrite.
    """
    for node in gm.graph.nodes:
        if node.target != acc_ops.linear:
            continue
        permute_node = node.kwargs["input"]
        if permute_node.target != acc_ops.permute or not check_permute(permute_node):
            continue
        # Bypass the permute: feed its input straight into the fused op.
        source = permute_node.kwargs["input"]
        with gm.graph.inserting_before(node):
            replacement = gm.graph.call_function(
                trt_transposed_linear,
                args=(source, node.kwargs["weight"], node.kwargs["bias"]),
            )
        node.replace_all_uses_with(replacement)

    gm.graph.eliminate_dead_code()
    gm.graph.lint()
    gm.recompile()
    return gm