
Commit d1fd1d7

Correcting pre-commit hooks
1 parent 09a52b9 commit d1fd1d7

File tree

4 files changed, +39 -33 lines changed


py/torch_tensorrt/fx/converters/converter_utils.py

Lines changed: 7 additions & 1 deletion
@@ -87,7 +87,13 @@ def set_layer_name(layer: TRTLayer, target: Target, name: str, is_acc=True) -> N
             the node represents.
         name (str): Consists of fx node.name with optional suffix.
     """
-    target_name = target if isinstance(target, str) else f"acc_ops.{target.__name__}" if is_acc else f"aten_ops.{target.__name__}"
+    target_name = (
+        target
+        if isinstance(target, str)
+        else f"acc_ops.{target.__name__}"
+        if is_acc
+        else f"aten_ops.{target.__name__}"
+    )
     layer.name = f"[{layer.type.name}]-[{target_name}]-[{name}]"
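
The reformatted assignment is a single chained conditional expression. A minimal standalone sketch of how it resolves the three cases (the resolve_target_name helper and the relu stand-in below are hypothetical, not repo code):

# Hypothetical sketch, not repo code: how the chained conditional picks a name.
def resolve_target_name(target, is_acc=True):
    return (
        target
        if isinstance(target, str)
        else f"acc_ops.{target.__name__}"
        if is_acc
        else f"aten_ops.{target.__name__}"
    )

def relu(x):  # stand-in callable that has a __name__ attribute
    return x

print(resolve_target_name("getitem"))           # getitem
print(resolve_target_name(relu))                # acc_ops.relu
print(resolve_target_name(relu, is_acc=False))  # aten_ops.relu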

py/torch_tensorrt/fx/converters/operator.py

Lines changed: 9 additions & 9 deletions
@@ -1118,17 +1118,17 @@ def add_expand(network, target, kwargs, name):
 
     ranks = len(input_val.shape)
     # TRT does not support different dimension size
-    #though this condition is not seen in the case of bmm
+    # though this condition is not seen in the case of bmm
     # where input_t and shape dimensions are not equal
     assert len(shape) >= ranks
-    if(len(shape) != ranks):
-        shape_tuple = tuple([0] * len(shape))
-        shape_tensor = get_trt_tensor(network, input_t, f"{name}_shape")
-        input_val, shape_tensor = broadcast(network, input_val, shape_tensor,
-                                            f"{name}_input_val",
-                                            f"{name}_shape_val")
-        ranks = len(shape)
-
+    if len(shape) != ranks:
+        shape_tuple = tuple([0] * len(shape))
+        shape_tensor = get_trt_tensor(network, input_t, f"{name}_shape")
+        input_val, shape_tensor = broadcast(
+            network, input_val, shape_tensor, f"{name}_input_val", f"{name}_shape_val"
+        )
+        ranks = len(shape)
+
     shape = [input_val.shape[i] if shape[i] == -1 else shape[i] for i in range(ranks)]
 
     inshape = tuple(input_val.shape)
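
As context for add_expand, the list comprehension kept in the surrounding lines replaces -1 entries in the requested shape with the matching input dimension. A small, hypothetical illustration in plain Python (not repo code):

# Hypothetical illustration, not repo code: resolving -1 entries in a target
# shape from the input's shape, as the comprehension above does.
input_shape = (2, 3, 4)
shape = [-1, 3, -1]
ranks = len(shape)
resolved = [input_shape[i] if shape[i] == -1 else shape[i] for i in range(ranks)]
print(resolved)  # [2, 3, 4]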

py/torch_tensorrt/fx/passes/lower_basic_pass_aten.py

Lines changed: 4 additions & 4 deletions
@@ -417,11 +417,11 @@ def compose_bmm(
     input_n = node.all_input_nodes[0]
     other_n = node.all_input_nodes[1]
 
-    # If no input nodes are available, the bmm argument itself could be an input
+    # If no input nodes are available, the bmm argument itself could be an input
     # Alternatively, if the node has no users, it can be eliminated
     if len(input_n.all_input_nodes) == 0 or len(node.users) == 0:
         return PassResult(module, modified)
-
+
     output = next(iter(node.users))
     input_input_n = input_n.all_input_nodes[0]
     if (
@@ -434,7 +434,7 @@ def compose_bmm(
             + "Skipping bmm lowering on this operation"
         )
         return PassResult(module, modified)
-
+
     real_input = input_input_n.all_input_nodes[0]
     input_other_n = other_n.all_input_nodes[0]
     if (
@@ -447,7 +447,7 @@ def compose_bmm(
             + "Skipping bmm lowering on this operation"
        )
         return PassResult(module, modified)
-
+
     real_other = input_other_n.all_input_nodes[0]
     if len(real_other.meta["val"].size()) == 2:
         new_func = aten_compose_bmm_2d

py/torch_tensorrt/fx/test/converters/aten_op/test_matmul_aten.py

Lines changed: 19 additions & 19 deletions
@@ -19,13 +19,13 @@ class TestMatMulConverter(DispatchTestCase):
         [
             ("2_2", (2, 3), (3, 2)),
             ("2_2", (2, 3), (3, 1)),
-            #FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm?
+            # FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm?
             # (2,3), (3,) torch.ops.aten.mv.default
-            # Following cases use torch.ops.aten.bmm.defauly
-            #("4_3", (3,1,3,2), (2,2,3)),
-            #("3_4", (3,1,3,2), (2,2,3)),
-            #("3_4", (2, 2, 3), (3, 1, 3, 3)),
-            #("4_2", (1, 2, 2, 3), (3, 2)),
+            # Following cases use torch.ops.aten.bmm.defauly
+            # ("4_3", (3,1,3,2), (2,2,3)),
+            # ("3_4", (3,1,3,2), (2,2,3)),
+            # ("3_4", (2, 2, 3), (3, 1, 3, 3)),
+            # ("4_2", (1, 2, 2, 3), (3, 2)),
         ]
     )
     def test_matmul_other_constant(self, _, input_shape, other_shape):
@@ -38,7 +38,7 @@ def forward(self, input):
                 return torch.matmul(input, self.other)
 
         inputs = [torch.randn(*input_shape)]
-
+
         self.run_test(
             MatMul(),
             inputs,
@@ -50,14 +50,13 @@ def forward(self, input):
         [
             ("2_2", (2, 3), (3, 2)),
             ("1_2", (1, 3), (3, 2)),
-            #FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm?
+            # FIXME torch.ops.aten.mv.default for (2,3), (3,1) - should mv be lowered to mm?
             # (2,3), (3,) torch.ops.aten.mv.default
-            # Following cases use torch.ops.aten.bmm.defauly
-            #("4_3", (3,1,3,2), (2,2,3)),
-            #("3_4", (3,1,3,2), (2,2,3)),
-            #("3_4", (2, 2, 3), (3, 1, 3, 3)),
-            #("4_2", (1, 2, 2, 3), (3, 2)),
-
+            # Following cases use torch.ops.aten.bmm.defauly
+            # ("4_3", (3,1,3,2), (2,2,3)),
+            # ("3_4", (3,1,3,2), (2,2,3)),
+            # ("3_4", (2, 2, 3), (3, 1, 3, 3)),
+            # ("4_2", (1, 2, 2, 3), (3, 2)),
         ]
     )
     def test_matmul_input_constant(self, _, input_shape, other_shape):
@@ -75,8 +74,8 @@ def forward(self, other):
             MatMul(),
             inputs,
             expected_ops={torch.ops.aten.mm.default},
-            test_explicit_batch_dim=True
-            #test_explicit_batch_dim=(len(other_shape) <= 2),
+            test_explicit_batch_dim=True
+            # test_explicit_batch_dim=(len(other_shape) <= 2),
         )
 
     @parameterized.expand(
@@ -96,7 +95,7 @@ def forward(self, input, other):
                 return torch.matmul(input, other)
 
         inputs = [torch.randn(*input_shape), torch.randn(*other_shape)]
-        test_explicit_batch_dim = not(
+        test_explicit_batch_dim = not (
             input_shape[0] == other_shape[0]
             and len(input_shape) > 2
             and len(other_shape) > 2
@@ -108,7 +107,8 @@ def forward(self, input, other):
             test_explicit_batch_dim=test_explicit_batch_dim,
         )
 
-    #FIXME: dynamic shape is giving bmm
+    # FIXME: dynamic shape is giving bmm
+
 
 if __name__ == "__main__":
-    run_tests()
+    run_tests()
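
The test_explicit_batch_dim expression touched above decides when the matmul tests also run with an explicit batch dimension. A small, hypothetical sketch of that check (use_explicit_batch_dim is illustrative, not repo code):

# Hypothetical sketch, not repo code: explicit-batch-dim testing is skipped only
# when both operands are batched (3-D or higher) with the same leading batch dim,
# mirroring the condition in the test above.
def use_explicit_batch_dim(input_shape, other_shape):
    return not (
        input_shape[0] == other_shape[0]
        and len(input_shape) > 2
        and len(other_shape) > 2
    )

print(use_explicit_batch_dim((2, 3), (3, 2)))        # True  (plain mm case)
print(use_explicit_batch_dim((4, 2, 3), (4, 3, 2)))  # False (bmm case)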
