1818from functools import partial
1919from itertools import chain
2020from types import ModuleType
21- from typing import Callable , Dict , Generator , Iterator , List , Optional , Set , Type
21+ from typing import Any , Callable , Dict , Generator , Iterator , List , Optional , Set , Type
2222
2323import torch
2424from torch import nn , Tensor
2525from torch .nn import Module
2626from torch .nn .modules .container import ModuleDict , ModuleList , Sequential
2727
28+ import pytorch_lightning as pl
2829from pytorch_lightning .utilities import rank_zero_warn
2930from pytorch_lightning .utilities .exceptions import MisconfigurationException
3031from pytorch_lightning .utilities .imports import _TORCH_GREATER_EQUAL_1_10
@@ -191,7 +192,6 @@ def materialize_module(root_module: nn.Module) -> nn.Module:
191192
# cache subclasses to optimize the search when resetting the meta device later on.
# NOTE(review): keys appear to be the subclass objects themselves, with values
# (modules-to-patch, subclass, generated meta class) — but the membership check
# below uses str(subclass); confirm the intended key type.
__STORAGE_META__ = {}
# Subclasses instantiated while the meta device was active; populated through
# ``add_subclasses`` during ``__new__`` of the generated materializer classes.
__CREATED_MODULES__ = set()
196196
197197
@@ -237,45 +237,52 @@ def _set_meta_device() -> None:
237237
238238 for subclass in get_all_subclasses (torch .nn .modules .module .Module ):
239239
240- if isinstance ( subclass , (Sequential , ModuleList , ModuleDict ) ):
240+ if subclass in (Sequential , ModuleList , ModuleDict , pl . LightningModule ):
241241 continue
242242
243243 # if a subclass has already been stored, we should use the cache
244244 if str (subclass ) in __STORAGE_META__ :
245- # reset the class import package to its rightfull state.
245+ # reset the class import package to its rightful state.
246246 mods , subclass , meta_class = __STORAGE_META__ [subclass ]
247247 for mod in mods :
248248 setattr (mod , subclass .__name__ , meta_class )
249249 continue
250250
        class _IsinstanceMetaclass(type(subclass)):
            def __instancecheck__(self, instance: Any) -> bool:
                """Overrides the ``isinstance`` check on ``_MaterializerModule`` objects."""
                # Delegate the check to the wrapped original class (first base),
                # so instances still look like the real ``subclass`` to callers.
                return isinstance(instance, self.__bases__[0])
255+
251256 # Create a class subclassing current `subclass` overriding its new method.
252257 # this will enable use to use `torch.distributed.nn.utils.init_meta` to create a `meta`
253258 # version of the current subclass module
254- class _MetaClass (subclass ):
259+ class _MaterializerModule (subclass , metaclass = _IsinstanceMetaclass ):
            @classmethod
            @contextmanager
            def instantiation_context(cls):
                """Temporarily disable the meta-device class patching for the duration
                of the ``with`` body, then re-apply it for the created classes.

                The un-patching on entry and re-patching on exit both use
                ``from_created=True``, i.e. they only touch the classes recorded in
                ``__CREATED_MODULES__``.
                """
                _unset_meta_device(from_created=True)
                yield
                _set_meta_device_populated(from_created=True)
261266
            @classmethod
            def materialize(cls, materialize_fn: Callable):
                """Run ``materialize_fn`` with the meta-device patching suspended and
                return its result (the materialized module)."""
                with cls.instantiation_context():
                    obj = materialize_fn()
                return obj
267272
            @staticmethod
            def add_subclasses(subclass):
                """This is used to unroll the instantiation tree while creating the modules.

                Records ``subclass`` (and, recursively, each ancestor up to but not
                including ``torch.nn.Module``) in ``__CREATED_MODULES__``.
                """
                # Don't store the LightningModule as skipped from the Meta process.
                if subclass != pl.LightningModule:
                    __CREATED_MODULES__.add(subclass)
                # Recurse through the first base class until the root nn.Module is hit.
                if subclass.__bases__[0] != torch.nn.modules.module.Module:
                    _MaterializerModule.add_subclasses(subclass.__bases__[0])
274281
275282 def __new__ (cls , * args , ** kwargs ):
276283 subclass = cls .__bases__ [0 ]
277284 cls .add_subclasses (subclass )
278- with cls .instantiation_context (materialize = False ):
285+ with cls .instantiation_context ():
279286 obj = init_meta (subclass , * args , ** kwargs )
280287
281288 obj .materialize = partial (cls .materialize , materialize_fn = obj .materialize )
@@ -294,9 +301,8 @@ def search(mod: ModuleType) -> List[ModuleType]:
294301 # nn.Module class can be imported at different level and they all need to be mocked.
295302 # Example: torch.nn.Linear is actually torch.nn.modules.linear.Linear
296303 # Therefore, torch.nn.Linear, torch.nn.modules.Linear, torch.nn.modules.linear.Linear
297- # needs to be replaced by the torch.nn.linear.modules.Linear _MetaClass
298- out = []
299- out .append (search (mod ))
304+ # needs to be replaced by the torch.nn.linear.modules.Linear _MaterializerModule
305+ out = [search (mod )]
300306 for name in submodules [1 :]:
301307 mod = getattr (mod , name )
302308 out .append (search (mod ))
@@ -305,11 +311,11 @@ def search(mod: ModuleType) -> List[ModuleType]:
305311 mods = [mod for mod in chain (* out ) if mod ]
306312
307313 # store the modules search so it doesn't have to be performed again for this class
308- __STORAGE_META__ [subclass ] = (mods , subclass , _MetaClass )
314+ __STORAGE_META__ [subclass ] = (mods , subclass , _MaterializerModule )
309315
310316 # replace all subclass by its meta form
311317 for mod in mods :
312- setattr (mod , subclass .__name__ , _MetaClass )
318+ setattr (mod , subclass .__name__ , _MaterializerModule )
313319
314320
315321@contextmanager
@@ -321,3 +327,11 @@ def init_meta_context() -> Generator:
321327 _set_meta_device ()
322328 yield
323329 _unset_meta_device ()
330+
331+
def is_on_meta_device(module: nn.Module) -> bool:
    """Check whether ``module`` lives on the ``meta`` device.

    The check inspects the module's first parameter only; a module without any
    parameters is reported as *not* being on the meta device.
    """
    params = module.parameters()
    try:
        first_param = next(params)
    except StopIteration:
        # No parameters at all -> cannot be on the meta device.
        return False
    return first_param.device.type == "meta"
0 commit comments