1
+ import collections
1
2
import functools
2
3
import logging
3
4
import re
@@ -50,20 +51,28 @@ def get_node_io(
50
51
) -> str :
51
52
"""Gets a string representing the node inputs and outputs including tensor shapes and dtypes"""
52
53
53
- def format_tensor_metadata (
54
- metadata : Union [TensorMetadata , Sequence [TensorMetadata ]]
55
- ) -> str :
54
+ def format_tensor_metadata (metadata : Union [Any , Sequence [Any ]]) -> str :
56
55
"""Formats the metadata for a single node"""
57
56
# If the provided data is a simple TensorMetadata object, parse it
58
- if isinstance (metadata , TensorMetadata ):
59
- return f"{ tuple (metadata .shape )} @{ metadata .dtype } "
57
+ if isinstance (metadata , TensorMetadata ) or issubclass (
58
+ type (metadata ), torch .Tensor
59
+ ):
60
+ return f"{ tuple (metadata .shape )} @{ metadata .dtype } " # type: ignore
61
+ # If the provided data is a scalar, return it as is
62
+ elif isinstance (metadata , (int , float , bool )):
63
+ return f"{ metadata } @Python-{ type (metadata )} "
60
64
# If the provided data is a sequence, recursively parse it
61
- else :
65
+ elif isinstance ( metadata , collections . abc . Sequence ) :
62
66
formatted_str = "("
63
67
for meta in metadata :
64
68
formatted_str += format_tensor_metadata (meta ) + ", "
65
69
66
70
return formatted_str [:- 2 ] + ")"
71
+ else :
72
+ _LOGGER .warning (
73
+ f"Detected unparseable type in node formatting: { type (metadata )} "
74
+ )
75
+ return ""
67
76
68
77
# Format input tensors
69
78
metadata_string = "Inputs: ("
@@ -74,8 +83,10 @@ def format_tensor_metadata(
74
83
if arg .op == "get_attr" :
75
84
shape , dtype = constant_mapping [str (arg )]
76
85
arg_repr = f"{ shape } @{ dtype } "
77
- elif arg .meta .get ("tensor_meta" , False ) :
86
+ elif arg .meta .get ("tensor_meta" ) is not None :
78
87
arg_repr = format_tensor_metadata (arg .meta ["tensor_meta" ])
88
+ elif arg .meta .get ("val" ) is not None :
89
+ arg_repr = format_tensor_metadata (arg .meta ["val" ])
79
90
else :
80
91
arg_repr = ""
81
92
@@ -92,8 +103,10 @@ def format_tensor_metadata(
92
103
if node .op == "get_attr" :
93
104
shape , dtype = constant_mapping [str (node )]
94
105
node_repr = f"{ shape } @{ dtype } "
95
- elif node .meta .get ("tensor_meta" , False ) :
106
+ elif node .meta .get ("tensor_meta" ) is not None :
96
107
node_repr = format_tensor_metadata (node .meta ["tensor_meta" ])
108
+ elif node .meta .get ("val" ) is not None :
109
+ node_repr = format_tensor_metadata (node .meta ["val" ])
97
110
else :
98
111
node_repr = ""
99
112
metadata_string += f"{ node } : { node_repr } , "
0 commit comments