Skip to content

Commit a5c903f

Browse files
authored
Fix looping in torch guard decorator (#42260)
* fix * add * fix * switch loop order for perfs * typo
1 parent 67302b0 commit a5c903f

File tree

1 file changed

+28
-11
lines changed

1 file changed

+28
-11
lines changed

src/transformers/initialization.py

Lines changed: 28 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -162,6 +162,25 @@ def copy_(tensor: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
162162
return tensor
163163

164164

165+
# We need to check several imported modules and hot-patch each of them, because torch sometimes
# does `from torch.nn.init import xavier_uniform_` in its internals (e.g. in
# torch.nn.modules.activation, where MultiheadAttention lives). The function name is therefore
# bound at import time, so `setattr(torch.nn.init, name, globals()[name])` alone is not enough.
# The following list should cover all torch versions we work with.
TORCH_MODULES_TO_PATCH = (
    "torch.nn.init",
    "torch.nn.modules.activation",
    "torch.nn.modules.transformer",
    "torch.nn.modules.linear",
    "torch.nn.modules.loss",
    "torch.nn.modules.batchnorm",
    "torch.nn.modules.conv",
    "torch.nn.modules.normalization",
    "torch.nn.modules.rnn",
    "torch.nn.modules.sparse",
)
182+
183+
165184
@contextmanager
def guard_torch_init_functions():
    """
    Context manager that temporarily replaces torch's weight-init functions with the versions
    defined in this file, on every already-imported torch module listed in
    `TORCH_MODULES_TO_PATCH`, then restores the original functions on exit (even on error).

    Patching each module (not just `torch.nn.init`) is required because torch binds these
    names at import time via `from torch.nn.init import ...` in its submodules.
    """
    # Maps each patched module to {function_name: original_function} so we can undo the patch.
    originals = defaultdict(dict)
    try:
        # Swap in the init functions defined in this file on every relevant torch module.
        for module_name in TORCH_MODULES_TO_PATCH:
            module = sys.modules.get(module_name)
            if module is None:
                # Module not imported in this environment/torch version — nothing to patch.
                continue
            for func_name in TORCH_INIT_FUNCTIONS:
                if hasattr(module, func_name):
                    originals[module][func_name] = getattr(module, func_name)
                    setattr(module, func_name, globals()[func_name])
        yield
    finally:
        # Restore every original function on every module we touched.
        for module, patched in originals.items():
            for func_name, original_func in patched.items():
                setattr(module, func_name, original_func)

0 commit comments

Comments
 (0)