Skip to content

Commit 815751b

Browse files
gs-oliveperi044
authored andcommitted
fix: FakeTensors appearing in get_attr calls (#2669)
1 parent d859859 commit 815751b

File tree

1 file changed

+39
-6
lines changed

1 file changed

+39
-6
lines changed

py/torch_tensorrt/dynamo/lowering/passes/constant_folding.py

Lines changed: 39 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -11,12 +11,9 @@
1111

1212
# Modify import location of utilities based on Torch version
1313
if version.parse(sanitized_torch_version()) < version.parse("2.1.1"):
14-
from torch._inductor.freezing import ConstantFolder, replace_node_with_constant
14+
from torch._inductor.freezing import ConstantFolder
1515
else:
16-
from torch._inductor.constant_folding import (
17-
ConstantFolder,
18-
replace_node_with_constant,
19-
)
16+
from torch._inductor.constant_folding import ConstantFolder
2017

2118
logger = logging.getLogger(__name__)
2219

@@ -36,7 +33,9 @@ def constant_fold(
3633
cf.run()
3734

3835
for node, constant in cf.node_replacements.items():
39-
replace_node_with_constant(gm, node, constant)
36+
replace_node_with_constant(
37+
gm, node, torch.nn.Parameter(constant.cuda(), requires_grad=False)
38+
)
4039

4140
erased_params = []
4241
for node in gm.graph.nodes:
@@ -55,6 +54,40 @@ def constant_fold(
5554
return gm
5655

5756

57+
def replace_node_with_constant(
58+
gm: torch.fx.GraphModule, node: torch.fx.Node, constant: torch.Tensor
59+
) -> None:
60+
"""Adapted from:
61+
https://github.com/pytorch/pytorch/blob/bcf35c6ae62bb6560befa3550e37a8283944e5f4/torch/_inductor/constant_folding.py#L17-L43
62+
63+
Modified to register parameters, instead of buffers for frozen constants
64+
"""
65+
g = gm.graph
66+
67+
if not hasattr(gm, "_frozen_param_count"):
68+
gm._frozen_param_count = 0
69+
70+
i = gm._frozen_param_count
71+
72+
while True:
73+
qualname = f"_frozen_param{i}"
74+
if not hasattr(gm, qualname):
75+
break
76+
i += 1
77+
78+
gm._frozen_param_count = i + 1
79+
80+
with g.inserting_before(node):
81+
new_input_node = g.create_node("get_attr", qualname, (), {})
82+
node.replace_all_uses_with(new_input_node)
83+
new_input_node.meta.update(node.meta)
84+
g.erase_node(node)
85+
86+
# Needed to suppress `does not reference an nn.Module, nn.Parameter, or buffer` warning
87+
gm.register_parameter(qualname, constant)
88+
setattr(gm, qualname, constant)
89+
90+
5891
# TODO: Delete this class when the following code is fixed in nightly:
5992
# https://github.com/pytorch/pytorch/blob/4b881b0da390c1290bb12850ef9daad6f6eb2cb6/torch/_inductor/constant_folding.py#L53-L63
6093
class _TorchTensorRTConstantFolder(ConstantFolder): # type: ignore[misc]

0 commit comments

Comments
 (0)