Skip to content

Commit b371e63

Browse files
committed
Address review comments; add support for integer inputs
1 parent fab9832 commit b371e63

File tree

2 files changed

+22
-8
lines changed

2 files changed

+22
-8
lines changed

py/torch_tensorrt/dynamo/conversion/converter_utils.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import collections
12
import functools
23
import logging
34
import re
@@ -50,20 +51,28 @@ def get_node_io(
5051
) -> str:
5152
"""Gets a string representing the node inputs and outputs including tensor shapes and dtypes"""
5253

53-
def format_tensor_metadata(
54-
metadata: Union[TensorMetadata, Sequence[TensorMetadata]]
55-
) -> str:
54+
def format_tensor_metadata(metadata: Union[Any, Sequence[Any]]) -> str:
5655
"""Formats the metadata for a single node"""
5756
# If the provided data is a simple TensorMetadata object, parse it
58-
if isinstance(metadata, TensorMetadata):
59-
return f"{tuple(metadata.shape)}@{metadata.dtype}"
57+
if isinstance(metadata, TensorMetadata) or issubclass(
58+
type(metadata), torch.Tensor
59+
):
60+
return f"{tuple(metadata.shape)}@{metadata.dtype}" # type: ignore
61+
# If the provided data is a scalar, return it as is
62+
elif isinstance(metadata, (int, float, bool)):
63+
return f"{metadata}@Python-{type(metadata)}"
6064
# If the provided data is a sequence, recursively parse it
61-
else:
65+
elif isinstance(metadata, collections.abc.Sequence):
6266
formatted_str = "("
6367
for meta in metadata:
6468
formatted_str += format_tensor_metadata(meta) + ", "
6569

6670
return formatted_str[:-2] + ")"
71+
else:
72+
_LOGGER.warning(
73+
f"Detected unparseable type in node formatting: {type(metadata)}"
74+
)
75+
return ""
6776

6877
# Format input tensors
6978
metadata_string = "Inputs: ("
@@ -74,8 +83,10 @@ def format_tensor_metadata(
7483
if arg.op == "get_attr":
7584
shape, dtype = constant_mapping[str(arg)]
7685
arg_repr = f"{shape}@{dtype}"
77-
elif arg.meta.get("tensor_meta", False):
86+
elif arg.meta.get("tensor_meta") is not None:
7887
arg_repr = format_tensor_metadata(arg.meta["tensor_meta"])
88+
elif arg.meta.get("val") is not None:
89+
arg_repr = format_tensor_metadata(arg.meta["val"])
7990
else:
8091
arg_repr = ""
8192

@@ -92,8 +103,10 @@ def format_tensor_metadata(
92103
if node.op == "get_attr":
93104
shape, dtype = constant_mapping[str(node)]
94105
node_repr = f"{shape}@{dtype}"
95-
elif node.meta.get("tensor_meta", False):
106+
elif node.meta.get("tensor_meta") is not None:
96107
node_repr = format_tensor_metadata(node.meta["tensor_meta"])
108+
elif node.meta.get("val") is not None:
109+
node_repr = format_tensor_metadata(node.meta["val"])
97110
else:
98111
node_repr = ""
99112
metadata_string += f"{node}: {node_repr}, "

py/torch_tensorrt/dynamo/lowering/passes/replace_max_pool_with_indices.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ def replace_max_pool_with_indices(
4343
args=node.args,
4444
kwargs=node.kwargs,
4545
)
46+
maxpool_fused.meta = node.meta
4647

4748
logger.debug(
4849
f"Replacing all uses of nodes {node}, {getitem_node} with fused maxpool node {maxpool_fused} "

0 commit comments

Comments
 (0)