diff --git a/pyproject.toml b/pyproject.toml
index 32cc6e8452d25..488ce7a51094e 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -76,6 +76,5 @@ module = [
     "pytorch_lightning.tuner.batch_size_scaling",
     "pytorch_lightning.utilities.auto_restart",
     "pytorch_lightning.utilities.data",
-    "pytorch_lightning.utilities.meta",
 ]
 ignore_errors = "True"
diff --git a/src/pytorch_lightning/utilities/meta.py b/src/pytorch_lightning/utilities/meta.py
index 77da02f7231d4..6708159da86b5 100644
--- a/src/pytorch_lightning/utilities/meta.py
+++ b/src/pytorch_lightning/utilities/meta.py
@@ -19,7 +19,7 @@
 from functools import partial
 from itertools import chain
 from types import ModuleType
-from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Set, Type
+from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Set, Tuple, Type, Union
 
 import torch
 from torch import nn, Tensor
@@ -38,43 +38,9 @@
 
     ####################################################################
     # BELOW: TAKEN FROM https://github.com/pytorch/pytorch/pull/66317. #
-    # TODO: Removed once merged and released on PyTorch side           #
+    # TODO: Remove once merged and released on PyTorch side            #
     ####################################################################
 
-    @contextmanager
-    def enable_python_mode(cls) -> Iterator[None]:
-        if not hasattr(cls, "__torch_dispatch__"):
-            raise ValueError("The class passed to enable_python_mode " "must have a __torch_dispatch__ classmethod")
-        if not isinstance(cls, type) or not issubclass(cls, (Tensor,)):
-            raise ValueError("The argument passed to enable_python_mode " "must be the type of a Tensor subclass")
-        torch._C._enter_python_mode(cls)
-        try:
-            yield
-        finally:
-            torch._C._exit_python_mode()
-
-    _tls = threading.local()
-    _tls.in_call = False
-
-    @contextmanager
-    def _no_dispatch() -> Iterator[None]:
-        """Temporarily disables the Python dispatch mode."""
-        guard = _DisableTorchDispatch()  # noqa F841
-        try:
-            yield
-        finally:
-            del guard
-
-    def _handle_arange(func, args, kwargs):
-        kwargs["device"] = torch.device("cpu")
-        return torch.empty_like(func(*args, **kwargs), device="meta")
-
-    def _handle_tril(func, args, kwargs):
-        if args and isinstance(args[0], Tensor):
-            return torch.empty_like(args[0], device="meta")
-
-        return NotImplemented
-
     class _MetaContext(Tensor):
         _op_handlers: Dict[Callable, Callable] = {}
 
@@ -91,7 +57,7 @@ def _ensure_handlers_initialized(cls) -> None:
             )
 
         @classmethod
-        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        def __torch_dispatch__(cls, func: Callable, types: Any, args: Any = (), kwargs: Optional[Any] = None) -> Any:
             cls._ensure_handlers_initialized()
 
             op_handler: Optional[Callable]
@@ -107,13 +73,49 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                     if result is not NotImplemented:
                         return result
 
-                if "device" in kwargs:
+                if kwargs is not None and "device" in kwargs.keys():
                     kwargs["device"] = torch.device("meta")
 
-                return func(*args, **kwargs)
+                return func(*args, **(kwargs if kwargs is not None else {}))
 
-    def init_meta(module_fn: Callable[..., Module], *args, **kwargs) -> Module:
-        def create_instance(module=None) -> Module:
+    @contextmanager
+    def enable_python_mode(cls: Type[_MetaContext]) -> Iterator[None]:
+        if not hasattr(cls, "__torch_dispatch__"):
+            raise ValueError("The class passed to enable_python_mode " "must have a __torch_dispatch__ classmethod")
+        if not isinstance(cls, type) or not issubclass(cls, (Tensor,)):
+            raise ValueError("The argument passed to enable_python_mode " "must be the type of a Tensor subclass")
+        torch._C._enter_python_mode(cls)
+        try:
+            yield
+        finally:
+            torch._C._exit_python_mode()
+
+    _tls = threading.local()
+    _tls.in_call = False
+
+    @contextmanager
+    def _no_dispatch() -> Iterator[None]:
+        """Temporarily disables the Python dispatch mode."""
+        guard = _DisableTorchDispatch()  # noqa F841
+        try:
+            yield
+        finally:
+            del guard
+
+    def _handle_arange(func: Callable, args: Any, kwargs: Any) -> Tensor:
+        kwargs["device"] = torch.device("cpu")
+        return torch.empty_like(func(*args, **kwargs), device="meta")
+
+    def _handle_tril(func: Callable, args: Any, kwargs: Any) -> Union[Tensor, Any]:
+        if args and isinstance(args[0], Tensor):
+            return torch.empty_like(args[0], device="meta")
+
+        return NotImplemented
+
+    def init_meta(
+        module_fn: Callable[..., Module], *args: Any, **kwargs: Any
+    ) -> Union[Module, MisconfigurationException]:
+        def create_instance(module: Optional[Any] = None) -> Module:
             if module:
                 module.__init__(*args, **kwargs)
                 return module
@@ -144,7 +146,9 @@ def is_meta_init() -> bool:
 
 else:
 
-    def init_meta(*_, **__):
+    def init_meta(
+        module_fn: Callable[..., Module], *args: Any, **kwargs: Any
+    ) -> Union[Module, MisconfigurationException]:
         if not _TORCH_GREATER_EQUAL_1_10:
             return MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0")
 
@@ -153,7 +157,7 @@ def init_meta(*_, **__):
 def get_all_subclasses(cls: Type) -> Set[Type]:
     subclass_list = []
 
-    def recurse(cl):
+    def recurse(cl: Type) -> None:
         for subclass in cl.__subclasses__():
             subclass_list.append(subclass)
             recurse(subclass)
@@ -170,7 +174,7 @@ def recursively_setattr(root_module: nn.Module, prefix: str, materialized_module
 
     try:
         index = int(name)
-        root_module[index] = materialized_module
+        root_module[index] = materialized_module  # type: ignore[operator]
     except ValueError:
         setattr(root_module, name, materialized_module)
 
@@ -194,8 +198,8 @@ def materialize_module(root_module: nn.Module) -> nn.Module:
 
 
 # cache subclasses to optimize the search when resetting the meta device later on.
-__STORAGE_META__ = {}
-__CREATED_MODULES__ = set()
+__STORAGE_META__: Dict[Type, Tuple] = {}
+__CREATED_MODULES__: Set[Type] = set()
 
 
 def _unset_meta_device(from_created: bool = False) -> None:
@@ -206,7 +210,7 @@ def _unset_meta_device(from_created: bool = False) -> None:
     if from_created:
         values = [__STORAGE_META__[key] for key in __CREATED_MODULES__]
     else:
-        values = __STORAGE_META__.values()
+        values = list(__STORAGE_META__.values())
 
     for mods, subclass, _ in values:
         for mod in mods:
@@ -221,7 +225,7 @@ def _set_meta_device_populated(from_created: bool = False) -> None:
     if from_created:
         values = [__STORAGE_META__[key] for key in __CREATED_MODULES__]
     else:
-        values = __STORAGE_META__.values()
+        values = list(__STORAGE_META__.values())
 
     for mods, subclass, meta_class in values:
         for mod in mods:
@@ -251,7 +255,7 @@ def _set_meta_device() -> None:
                 setattr(mod, subclass.__name__, meta_class)
             continue
 
-        class _IsinstanceMetaclass(type(subclass)):
+        class _IsinstanceMetaclass(type(subclass)):  # type: ignore[misc]
             def __instancecheck__(self, instance: Any) -> bool:
                 """Overrides the ``isinstance`` check on ``_MaterializerModule`` objects."""
                 return isinstance(instance, self.__bases__[0])
@@ -259,22 +263,22 @@ def __instancecheck__(self, instance: Any) -> bool:
         # Create a class subclassing current `subclass` overriding its new method.
        # this will enable use to use `torch.distributed.nn.utils.init_meta` to create a `meta`
         # version of the current subclass module
-        class _MaterializerModule(subclass, metaclass=_IsinstanceMetaclass):
+        class _MaterializerModule(subclass, metaclass=_IsinstanceMetaclass):  # type: ignore[valid-type, misc]
             @classmethod
             @contextmanager
-            def instantiation_context(cls):
+            def instantiation_context(cls) -> Generator[None, None, None]:
                 _unset_meta_device(from_created=True)
                 yield
                 _set_meta_device_populated(from_created=True)
 
             @classmethod
-            def materialize(cls, materialize_fn: Callable):
+            def materialize(cls, materialize_fn: Callable) -> Type:
                 with cls.instantiation_context():
                     obj = materialize_fn()
                 return obj
 
             @staticmethod
-            def add_subclasses(subclass):
+            def add_subclasses(subclass: Type) -> None:
                 """This is used to unroll the instantiation tree while creating the modules."""
                 # Don't store the LightningModule as skipped from the Meta process.
                 if subclass != pl.LightningModule:
@@ -282,14 +286,16 @@ def add_subclasses(subclass):
                 if subclass.__bases__[0] != torch.nn.modules.module.Module:
                     _MaterializerModule.add_subclasses(subclass.__bases__[0])
 
-            def __new__(cls, *args, **kwargs):
+            def __new__(cls, *args: Any, **kwargs: Any) -> Type:
                 subclass = cls.__bases__[0]
                 cls.add_subclasses(subclass)
                 with cls.instantiation_context():
                     obj = init_meta(subclass, *args, **kwargs)
+                    if isinstance(obj, Exception):
+                        raise obj
 
-                obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize)
-                return obj
+                obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize)  # type: ignore[assignment]
+                return obj  # type: ignore
 
         def search(mod: ModuleType) -> List[ModuleType]:
             out = []
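
Reviewer note (not part of the patch): the `enable_python_mode` helper moved below `_MetaContext` wraps the private `torch._C._enter_python_mode` hook copied from pytorch/pytorch#66317. For readers unfamiliar with the technique, here is a minimal, self-contained sketch of the same `__torch_dispatch__` interception, written against the `TorchDispatchMode` helper that newer PyTorch releases (>= 1.12, an assumption on our side) provide in place of that private hook; the class name `_MetaMode` is ours:

    import torch
    from torch.utils._python_dispatch import TorchDispatchMode

    class _MetaMode(TorchDispatchMode):
        """Redirect any factory call carrying a `device` kwarg to the meta device."""

        def __torch_dispatch__(self, func, types, args=(), kwargs=None):
            kwargs = dict(kwargs or {})  # copy before mutating, as a precaution
            if "device" in kwargs:
                kwargs["device"] = torch.device("meta")
            return func(*args, **kwargs)

    with _MetaMode():
        t = torch.ones(8, device="cpu")  # intercepted before any kernel runs

    assert t.device.type == "meta"  # shape/dtype metadata only, no storage allocated

This is the same trick `_MetaContext.__torch_dispatch__` applies above, minus the `arange`/`tril` handlers that PyTorch 1.10 still needed because those ops lacked meta kernels.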
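Reviewer note (not part of the patch): the change keeps the contract that `init_meta` returns a `MisconfigurationException` on unsupported PyTorch versions rather than raising it, which is why `_MaterializerModule.__new__` now checks `isinstance(obj, Exception)` and raises explicitly. A hedged usage sketch of the annotated API, assuming PyTorch >= 1.10:

    from torch.nn import Linear

    from pytorch_lightning.utilities.meta import init_meta, materialize_module

    module = init_meta(Linear, in_features=16, out_features=4)
    if isinstance(module, Exception):  # returned, not raised, on PyTorch < 1.10
        raise module

    assert module.weight.device.type == "meta"  # parameters have no real storage yet
    module = materialize_module(module)  # re-runs __init__ outside the meta context

The `Union[Module, MisconfigurationException]` return annotation added in this patch makes that caller-side `isinstance` check visible to mypy instead of hiding it behind an untyped `init_meta(*_, **__)`.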