@@ -162,6 +162,25 @@ def copy_(tensor: torch.Tensor, other: torch.Tensor) -> torch.Tensor:
162162 return tensor
163163
164164
165+ # Here, we need to check several modules imported, and hot patch all of them, as sometimes torch does
166+ # something like `from torch.nn.init import xavier_uniform_` in their internals (e.g in torch.nn.modules.activations,
167+ # where MultiHeadAttention lives), so the function name is binded at import time and just doing
168+ # `setattr(torch.nn.init, name, globals()[name])` is thus not enough
169+ # The following list should be enough for all torch versions we work with
170+ TORCH_MODULES_TO_PATCH = (
171+ "torch.nn.init" ,
172+ "torch.nn.modules.activation" ,
173+ "torch.nn.modules.transformer" ,
174+ "torch.nn.modules.linear" ,
175+ "torch.nn.modules.loss" ,
176+ "torch.nn.modules.batchnorm" ,
177+ "torch.nn.modules.conv" ,
178+ "torch.nn.modules.normalization" ,
179+ "torch.nn.modules.rnn" ,
180+ "torch.nn.modules.sparse" ,
181+ )
182+
183+
165184@contextmanager
166185def guard_torch_init_functions ():
167186 """
@@ -174,18 +193,16 @@ def guard_torch_init_functions():
174193 originals = defaultdict (dict )
175194 try :
176195 # Replace all torch funcs by the ones in this file
177- for name in TORCH_INIT_FUNCTIONS .keys ():
178- # Here, we need to check all modules imported, and hot patch all of them, as usually torch does
179- # something like `from torch.nn.init import xavier_uniform_` in their internals (e.g in torch.nn.modules,
180- # where MultiHeadAttention lives), so the function name is binded at import time and just doing
181- # `setattr(torch.nn.init, name, gloabls()[name])` is thus not enough
182- for module in sys .modules .copy ().values ():
183- if module and hasattr (module , name ):
184- originals [module ][name ] = getattr (module , name )
185- setattr (module , name , globals ()[name ])
196+ for module_name in TORCH_MODULES_TO_PATCH :
197+ if module_name in sys .modules :
198+ module = sys .modules [module_name ]
199+ for func_name in TORCH_INIT_FUNCTIONS .keys ():
200+ if hasattr (module , func_name ):
201+ originals [module ][func_name ] = getattr (module , func_name )
202+ setattr (module , func_name , globals ()[func_name ])
186203 yield
187204 finally :
188205 # Set back the original functions on all modules
189206 for module , functions in originals .items ():
190- for name , func in functions .items ():
191- setattr (module , name , func )
207+ for func_name , func in functions .items ():
208+ setattr (module , func_name , func )
0 commit comments