Commit 09a52b9

matmul changes, bmm changes and adding broadcastable
1 parent dba4988 commit 09a52b9

File tree

4 files changed: +167 -18 lines changed

py/torch_tensorrt/fx/converters/converter_utils.py

Lines changed: 35 additions & 2 deletions
@@ -77,7 +77,7 @@ def get_positive_dim(dim: int, dim_size: int) -> int:
     return dim


-def set_layer_name(layer: TRTLayer, target: Target, name: str) -> None:
+def set_layer_name(layer: TRTLayer, target: Target, name: str, is_acc=True) -> None:
     """
     Set the TensorRT layer name to "[TensorRT Layer Type]_[Original Op Name]_[FX Node Name with Suffix]"
@@ -87,7 +87,7 @@ def set_layer_name(layer: TRTLayer, target: Target, name: str) -> None:
             the node represents.
         name (str): Consists of fx node.name with optional suffix.
     """
-    target_name = target if isinstance(target, str) else f"acc_ops.{target.__name__}"
+    target_name = target if isinstance(target, str) else f"acc_ops.{target.__name__}" if is_acc else f"aten_ops.{target.__name__}"
     layer.name = f"[{layer.type.name}]-[{target_name}]-[{name}]"


@@ -288,6 +288,39 @@ def prepend_ones(
     return layer.get_output(0)


+def broadcastable(
+    a: TRTTensor,
+    b: TRTTensor,
+) -> bool:
+    "Check if two tensors are broadcastable according to torch rules"
+    a_shape = tuple(a.shape)
+    b_shape = tuple(b.shape)
+    print("a shape is", a_shape)
+    print("b shape is", b_shape)
+    # check from the trailing
+    diff = len(a_shape) - len(b_shape)
+    if diff == 0:
+        return True
+    if diff > 0:
+        max = len(a_shape)
+        min = len(b_shape)
+        greater_tensor = a_shape
+        lesser_tensor = b_shape
+    elif diff < 0:
+        max = len(b_shape)
+        min = len(a_shape)
+        greater_tensor = b_shape
+        lesser_tensor = a_shape
+    j = min - 1
+    for i in range(max - 1, diff - 1, -1):
+        if not (
+            greater_tensor[i] != lesser_tensor[j]
+            and (greater_tensor[i] == 1 or lesser_tensor[i] == 1)
+        ):
+            return False
+    return True
+
+
 def broadcast(
     network: TRTNetwork,
     a: TRTTensor,
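
For context, the torch broadcasting rule that broadcastable is meant to encode compares shapes from the trailing dimension: two sizes are compatible when they are equal or one of them is 1, and a shorter shape behaves as if padded with leading 1s. A minimal standalone sketch of that rule on plain shape tuples (a hypothetical helper for illustration, not part of this commit):

    # Sketch of torch-style broadcast compatibility on plain shape tuples.
    def shapes_broadcastable(a_shape: tuple, b_shape: tuple) -> bool:
        # Walk both shapes from the trailing dimension; missing leading dims act as size 1.
        for a_dim, b_dim in zip(reversed(a_shape), reversed(b_shape)):
            if a_dim != b_dim and a_dim != 1 and b_dim != 1:
                return False
        return True

    assert shapes_broadcastable((1, 2, 2, 3), (2, 3))  # trailing dims line up
    assert not shapes_broadcastable((2, 3), (3, 2))    # 3 vs 2, neither is 1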

py/torch_tensorrt/fx/converters/operator.py

Lines changed: 12 additions & 1 deletion
@@ -16,6 +16,7 @@
 from .converter_utils import set_layer_name
 from .converter_utils import get_trt_tensor
 from .converter_utils import broadcast
+from .converter_utils import broadcastable
 from .converter_utils import squeeze_left
 from .converter_utils import dtype_uniform
 from .converter_utils import get_trt_plugin
@@ -1117,7 +1118,17 @@ def add_expand(network, target, kwargs, name):

     ranks = len(input_val.shape)
     # TRT does not support different dimension size
-    assert len(shape) == ranks
+    # though this condition is not seen in the case of bmm
+    # where input_t and shape dimensions are not equal
+    assert len(shape) >= ranks
+    if len(shape) != ranks:
+        shape_tuple = tuple([0] * len(shape))
+        shape_tensor = get_trt_tensor(network, input_t, f"{name}_shape")
+        input_val, shape_tensor = broadcast(
+            network, input_val, shape_tensor, f"{name}_input_val", f"{name}_shape_val"
+        )
+        ranks = len(shape)
+
     shape = [input_val.shape[i] if shape[i] == -1 else shape[i] for i in range(ranks)]

     inshape = tuple(input_val.shape)
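
The relaxed assert above admits the case where the requested shape has more dimensions than the input, which mirrors how torch.Tensor.expand behaves: extra target dimensions are added at the front. A small plain-PyTorch illustration of that semantic (independent of the TensorRT converter):

    import torch

    # expand may prepend new dimensions: a (2, 3) tensor expanded to (4, 2, 3).
    x = torch.randn(2, 3)
    y = x.expand(4, 2, 3)
    assert y.shape == (4, 2, 3)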

py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py

Lines changed: 23 additions & 5 deletions
@@ -416,29 +416,46 @@ def compose_bmm(
             node = n
             input_n = node.all_input_nodes[0]
             other_n = node.all_input_nodes[1]
+
+            # If no input nodes are available, the bmm argument itself could be an input
+            # Alternatively, if the node has no users, it can be eliminated
+            if len(input_n.all_input_nodes) == 0 or len(node.users) == 0:
+                return PassResult(module, modified)
+
             output = next(iter(node.users))
             input_input_n = input_n.all_input_nodes[0]
             if (
                 input_input_n.target != torch.ops.aten.expand.default
                 and input_n.target != torch.ops.aten.view.default
             ):
-                raise RuntimeError(
-                    "Bmm is addressed in fixed pattern. A new pattern is met!"
+                _LOGGER.warn(
+                    "Bmm is addressed in fixed pattern. "
+                    + f"A new pattern {input_input_n.target}, {input_n.target} is met! "
+                    + "Skipping bmm lowering on this operation"
                 )
+                return PassResult(module, modified)
+
             real_input = input_input_n.all_input_nodes[0]
             input_other_n = other_n.all_input_nodes[0]
             if (
                 input_other_n.target != torch.ops.aten.expand.default
                 and other_n.target != torch.ops.aten.view.default
             ):
-                raise RuntimeError(
-                    "Bmm is addressed in fixed pattern. A new pattern is met!"
+                _LOGGER.warn(
+                    "Bmm is addressed in fixed pattern. "
+                    + f"A new pattern {input_other_n.target}, {other_n.target} is met! "
+                    + "Skipping bmm lowering on this operation"
                 )
+                return PassResult(module, modified)
+
             real_other = input_other_n.all_input_nodes[0]
             if len(real_other.meta["val"].size()) == 2:
                 new_func = aten_compose_bmm_2d
-            if len(real_other.meta["val"].size()) == 3:
+            elif len(real_other.meta["val"].size()) == 3:
                 new_func = aten_compose_bmm_3d
+            else:
+                # No valid bmm replacement exists for the specified dimensions
+                return PassResult(module, modified)

             with module.graph.inserting_after(node):
                 new_args = (real_input, real_other)
@@ -449,6 +466,7 @@ def compose_bmm(
                     kwargs=None,
                 )
             output.replace_all_uses_with(new_node)
+            modified = True

     module.graph.eliminate_dead_code()
     module.recompile()
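
As background for this pass: for 3-D inputs, torch.matmul computes the same batched product as torch.bmm, which is the pattern compose_bmm recomposes from the traced expand/view/bmm nodes. A quick plain-PyTorch sanity check of that equivalence:

    import torch

    a = torch.randn(2, 3, 4)
    b = torch.randn(2, 4, 5)
    # For 3-D operands, matmul is a batched matrix multiply, i.e. bmm.
    assert torch.allclose(torch.matmul(a, b), torch.bmm(a, b))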

py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py

Lines changed: 97 additions & 10 deletions
@@ -6,22 +6,109 @@
 from torch.testing._internal.common_utils import run_tests
 from torch_tensorrt.fx.tools.common_fx2trt import DispatchTestCase, InputTensorSpec

+import torch
+import torch.nn as nn
+import torch_tensorrt.fx.tracer.acc_tracer.acc_ops as acc_ops
+from parameterized import parameterized
+from torch.testing._internal.common_utils import run_tests
+from torch_tensorrt.fx.tools.common_fx2trt import AccTestCase, InputTensorSpec
+

 class TestMatMulConverter(DispatchTestCase):
-    def test_matmul(self):
-        class TestModule(torch.nn.Module):
-            def forward(self, x, y):
-                return torch.matmul(x, y)
-
-        inputOne = torch.randn(2, 32)
-        inputTwo = torch.randn(32, 2)
-        inputs = [inputOne, inputTwo]
+    @parameterized.expand(
+        [
+            ("2_2", (2, 3), (3, 2)),
+            ("2_2", (2, 3), (3, 1)),
+            # FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm?
+            # (2,3), (3,) torch.ops.aten.mv.default
+            # Following cases use torch.ops.aten.bmm.default
+            # ("4_3", (3, 1, 3, 2), (2, 2, 3)),
+            # ("3_4", (3, 1, 3, 2), (2, 2, 3)),
+            # ("3_4", (2, 2, 3), (3, 1, 3, 3)),
+            # ("4_2", (1, 2, 2, 3), (3, 2)),
+        ]
+    )
+    def test_matmul_other_constant(self, _, input_shape, other_shape):
+        class MatMul(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.other = nn.Parameter(torch.randn(*other_shape))
+
+            def forward(self, input):
+                return torch.matmul(input, self.other)
+
+        inputs = [torch.randn(*input_shape)]
+
+        self.run_test(
+            MatMul(),
+            inputs,
+            expected_ops={torch.ops.aten.mm.default},
+            test_explicit_batch_dim=(len(input_shape) >= 1),
+        )
+
+    @parameterized.expand(
+        [
+            ("2_2", (2, 3), (3, 2)),
+            ("1_2", (1, 3), (3, 2)),
+            # FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm?
+            # (2,3), (3,) torch.ops.aten.mv.default
+            # Following cases use torch.ops.aten.bmm.default
+            # ("4_3", (3, 1, 3, 2), (2, 2, 3)),
+            # ("3_4", (3, 1, 3, 2), (2, 2, 3)),
+            # ("3_4", (2, 2, 3), (3, 1, 3, 3)),
+            # ("4_2", (1, 2, 2, 3), (3, 2)),
+
+        ]
+    )
+    def test_matmul_input_constant(self, _, input_shape, other_shape):
+        class MatMul(nn.Module):
+            def __init__(self):
+                super().__init__()
+                self.input = nn.Parameter(torch.randn(*input_shape))
+
+            def forward(self, other):
+                return torch.matmul(self.input, other)
+
+        inputs = [torch.randn(*other_shape)]
+
+        self.run_test(
+            MatMul(),
+            inputs,
+            expected_ops={torch.ops.aten.mm.default},
+            test_explicit_batch_dim=True,
+            # test_explicit_batch_dim=(len(other_shape) <= 2),
+        )
+
+    @parameterized.expand(
+        [
+            ("2_2", (2, 3), (3, 2)),
+            # ("2_3", (2, 3), (2, 3, 4)),
+            # ("4_4", (2, 2, 2, 3), (2, 1, 3, 2)),
+            # ("4_2", (2, 1, 2, 3), (3, 2)),
+            # ("2_1", (2, 3), (3,)),
+            # ("1_2", (3,), (3, 2)),
+            # ("1_1", (3,), (3,)),
+        ]
+    )
+    def test_matmul(self, _, input_shape, other_shape):
+        class MatMul(nn.Module):
+            def forward(self, input, other):
+                return torch.matmul(input, other)
+
+        inputs = [torch.randn(*input_shape), torch.randn(*other_shape)]
+        test_explicit_batch_dim = not (
+            input_shape[0] == other_shape[0]
+            and len(input_shape) > 2
+            and len(other_shape) > 2
+        )
         self.run_test(
-            TestModule(),
+            MatMul(),
             inputs,
             expected_ops={torch.ops.aten.mm.default},
+            test_explicit_batch_dim=test_explicit_batch_dim,
         )

+# FIXME: dynamic shape is giving bmm

 if __name__ == "__main__":
-    run_tests()
+    run_tests()
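
The commented-out cases rely on matmul's batch broadcasting (batch dimensions are broadcast before the per-matrix multiply) and on its 1-D handling, which is why they trace to aten.bmm or aten.mv rather than aten.mm, as the inline comments note. Expected result shapes for two of those cases, checked in plain PyTorch:

    import torch

    # Batch dims broadcast, matrix dims multiply: (2, 1, 2, 3) @ (3, 2) -> (2, 1, 2, 2).
    out = torch.matmul(torch.randn(2, 1, 2, 3), torch.randn(3, 2))
    assert out.shape == (2, 1, 2, 2)

    # A 1-D other is treated as a vector: (2, 3) @ (3,) -> (2,)  (the aten.mv case noted above).
    vec = torch.matmul(torch.randn(2, 3), torch.randn(3))
    assert vec.shape == (2,)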
