From 3391c634f26cd0b00f1f1ed13c3059c4b59dcd15 Mon Sep 17 00:00:00 2001
From: nninept
Date: Wed, 20 Jul 2022 13:34:09 +0900
Subject: [PATCH 01/15] remove corresponding line from pyproject.toml

---
 pyproject.toml | 1 -
 1 file changed, 1 deletion(-)

diff --git a/pyproject.toml b/pyproject.toml
index 6d973aa0dde51..77538a22d8b09 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -75,6 +75,5 @@ module = [
     "pytorch_lightning.tuner.batch_size_scaling",
     "pytorch_lightning.utilities.auto_restart",
     "pytorch_lightning.utilities.data",
-    "pytorch_lightning.utilities.meta",
 ]
 ignore_errors = "True"

From 0e9cd3186d75fdbd639fb162e197813cf444a84b Mon Sep 17 00:00:00 2001
From: nninept
Date: Wed, 20 Jul 2022 15:46:37 +0900
Subject: [PATCH 02/15] update enable_python_mode annotation

---
 src/pytorch_lightning/utilities/meta.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/utilities/meta.py b/src/pytorch_lightning/utilities/meta.py
index 77da02f7231d4..23ff554a57649 100644
--- a/src/pytorch_lightning/utilities/meta.py
+++ b/src/pytorch_lightning/utilities/meta.py
@@ -42,7 +42,7 @@
     ####################################################################

     @contextmanager
-    def enable_python_mode(cls) -> Iterator[None]:
+    def enable_python_mode(cls: Type[_MetaContext]) -> Iterator[None]:
         if not hasattr(cls, "__torch_dispatch__"):
             raise ValueError("The class passed to enable_python_mode " "must have a __torch_dispatch__ classmethod")
         if not isinstance(cls, type) or not issubclass(cls, (Tensor,)):
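Context for the `Type[_MetaContext]` hint above: `enable_python_mode` receives the class object itself, never an instance, and installs its `__torch_dispatch__` for the duration of the block. A minimal usage sketch, assuming a PyTorch 1.10-era build that still exposes the private `torch._C._enter_python_mode` hook used in this file (`LoggingTensor` is a hypothetical name, not part of the patch):

    import torch
    from torch import Tensor

    class LoggingTensor(Tensor):
        @classmethod
        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
            # Observe every ATen op routed through the mode, then run it as-is,
            # mirroring the direct `func(...)` call _MetaContext itself makes.
            print(f"dispatch: {func}")
            return func(*args, **(kwargs or {}))

    # The class object is what gets passed -- hence Type[...] in the annotation.
    with enable_python_mode(LoggingTensor):
        torch.ones(2, 2)  # prints the ops hit while constructing the tensor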
From 5c64311e6abb72cb2d5c0190e6ed20182a54f9f1 Mon Sep 17 00:00:00 2001
From: nninept
Date: Wed, 20 Jul 2022 15:52:11 +0900
Subject: [PATCH 03/15] update _handle_arange annotation

---
 src/pytorch_lightning/utilities/meta.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/pytorch_lightning/utilities/meta.py b/src/pytorch_lightning/utilities/meta.py
index 23ff554a57649..b5608016f1356 100644
--- a/src/pytorch_lightning/utilities/meta.py
+++ b/src/pytorch_lightning/utilities/meta.py
@@ -64,8 +64,8 @@ def _no_dispatch() -> Iterator[None]:
             yield
         finally:
             del guard
-
-    def _handle_arange(func, args, kwargs):
+ 
+    def _handle_arange(func: Callable, args: Any, kwargs: Any) -> Tensor:
         kwargs["device"] = torch.device("cpu")
         return torch.empty_like(func(*args, **kwargs), device="meta")


From 494af9d1c47c4a40742282cd80521c08967f1f45 Mon Sep 17 00:00:00 2001
From: nninept
Date: Wed, 20 Jul 2022 16:37:28 +0900
Subject: [PATCH 04/15] update _handle_tril annotation

---
 src/pytorch_lightning/utilities/meta.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/pytorch_lightning/utilities/meta.py b/src/pytorch_lightning/utilities/meta.py
index b5608016f1356..ed06ec49577d9 100644
--- a/src/pytorch_lightning/utilities/meta.py
+++ b/src/pytorch_lightning/utilities/meta.py
@@ -19,7 +19,7 @@
 from functools import partial
 from itertools import chain
 from types import ModuleType
-from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Set, Type
+from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Set, Type, Union

 import torch
 from torch import nn, Tensor
@@ -69,12 +69,12 @@ def _handle_arange(func: Callable, args: Any, kwargs: Any) -> Tensor:
         kwargs["device"] = torch.device("cpu")
         return torch.empty_like(func(*args, **kwargs), device="meta")

-    def _handle_tril(func, args, kwargs):
+    def _handle_tril(func: Callable, args: Any, kwargs: Any) -> Union[Tensor, Any]:
         if args and isinstance(args[0], Tensor):
             return torch.empty_like(args[0], device="meta")

         return NotImplemented
-
+ 
     class _MetaContext(Tensor):
         _op_handlers: Dict[Callable, Callable] = {}

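Why these two handlers exist at all, which makes the new annotations easier to read: in the 1.10-era dispatcher a few ops, `aten.arange` and `aten.tril` among them, lacked meta kernels, so the workaround runs the real kernel elsewhere and mirrors only shape and dtype onto the meta device. A sketch of the trick `_handle_arange` performs, assuming PyTorch >= 1.10:

    import torch

    # Run the real kernel on CPU, then keep only the metadata.
    cpu_result = torch.arange(10, device="cpu")
    meta_result = torch.empty_like(cpu_result, device="meta")

    assert meta_result.is_meta           # no storage was allocated
    assert meta_result.shape == (10,)    # but shape/dtype match the real op

`_handle_tril` returning `NotImplemented` for non-tensor arguments is what lets dispatch fall through to the default path, which is why its return type is the union `Union[Tensor, Any]` rather than plain `Tensor`.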
From 0c196e746c63f6a71e774e9d451456e32d14f500 Mon Sep 17 00:00:00 2001
From: nninept
Date: Wed, 20 Jul 2022 16:47:49 +0900
Subject: [PATCH 05/15] update __torch_dispatch__ annotation

---
 src/pytorch_lightning/utilities/meta.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/utilities/meta.py b/src/pytorch_lightning/utilities/meta.py
index ed06ec49577d9..6c887d2a1c251 100644
--- a/src/pytorch_lightning/utilities/meta.py
+++ b/src/pytorch_lightning/utilities/meta.py
@@ -91,7 +91,7 @@ def _ensure_handlers_initialized(cls) -> None:
             )

         @classmethod
-        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        def __torch_dispatch__(cls, func: Callable, types: Any, args: Any=(), kwargs: Optional[Any]=None) -> Any:
             cls._ensure_handlers_initialized()

             op_handler: Optional[Callable]

From b25dd53dc41cb1c5aaef2510169a03a539df3a81 Mon Sep 17 00:00:00 2001
From: nninept
Date: Wed, 20 Jul 2022 17:32:18 +0900
Subject: [PATCH 06/15] add condition that kwargs is not None

---
 src/pytorch_lightning/utilities/meta.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/src/pytorch_lightning/utilities/meta.py b/src/pytorch_lightning/utilities/meta.py
index 6c887d2a1c251..964585a84d774 100644
--- a/src/pytorch_lightning/utilities/meta.py
+++ b/src/pytorch_lightning/utilities/meta.py
@@ -107,10 +107,10 @@ def __torch_dispatch__(cls, func: Callable, types: Any, args: Any=(), kwargs: Op
                 if result is not NotImplemented:
                     return result

-            if "device" in kwargs:
+            if kwargs is not None and "device" in kwargs.keys():
                 kwargs["device"] = torch.device("meta")

-            return func(*args, **kwargs)
+            return func(*args, **(kwargs if kwargs is not None else {}))

     def init_meta(module_fn: Callable[..., Module], *args, **kwargs) -> Module:
         def create_instance(module=None) -> Module:
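The guard added in PATCH 06 matters because `__torch_dispatch__` can receive `kwargs=None` rather than an empty dict when an op is called without keyword arguments; the old `"device" in kwargs` check then raised a `TypeError` instead of dispatching. A two-line illustration of the failure mode the patch removes:

    kwargs = None              # what the dispatcher passes for kwarg-less ops
    try:
        "device" in kwargs     # the unguarded check the old code performed
    except TypeError as err:
        print(err)             # argument of type 'NoneType' is not iterable

With the new code the `None` case short-circuits, and `**(kwargs if kwargs is not None else {})` keeps the final call valid either way.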
From 666d8ecc80bc232b01f8069bf70c64f89393b41d Mon Sep 17 00:00:00 2001
From: nninept
Date: Wed, 20 Jul 2022 17:53:13 +0900
Subject: [PATCH 07/15] update init_meta arguments annotation

---
 src/pytorch_lightning/utilities/meta.py | 6 +++---
 1 file changed, 3 insertions(+), 3 deletions(-)

diff --git a/src/pytorch_lightning/utilities/meta.py b/src/pytorch_lightning/utilities/meta.py
index 964585a84d774..9e847c4334802 100644
--- a/src/pytorch_lightning/utilities/meta.py
+++ b/src/pytorch_lightning/utilities/meta.py
@@ -112,8 +112,8 @@ def __torch_dispatch__(cls, func: Callable, types: Any, args: Any=(), kwargs: Op

             return func(*args, **(kwargs if kwargs is not None else {}))

-    def init_meta(module_fn: Callable[..., Module], *args, **kwargs) -> Module:
-        def create_instance(module=None) -> Module:
+    def init_meta(module_fn: Callable[..., Module], *args: Any, **kwargs: Any) -> Union[Module, MisconfigurationException]:
+        def create_instance(module: Optional[Any]=None) -> Module:
             if module:
                 module.__init__(*args, **kwargs)
                 return module
@@ -144,7 +144,7 @@ def is_meta_init() -> bool:

 else:

-    def init_meta(*_, **__):
+    def init_meta(module_fn: Callable[..., Module], *args: Any, **kwargs: Any) -> Union[Module, MisconfigurationException]:
         if not _TORCH_GREATER_EQUAL_1_10:
             return MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0")


From 4cbff1ff9fb2ba8bf4f12a587de32aed8f3409d9 Mon Sep 17 00:00:00 2001
From: nninept
Date: Wed, 20 Jul 2022 21:13:54 +0900
Subject: [PATCH 08/15] update get_all_subclasses annotation

---
 src/pytorch_lightning/utilities/meta.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/utilities/meta.py b/src/pytorch_lightning/utilities/meta.py
index 9e847c4334802..df843de27d64f 100644
--- a/src/pytorch_lightning/utilities/meta.py
+++ b/src/pytorch_lightning/utilities/meta.py
@@ -153,7 +153,7 @@ def init_meta(module_fn: Callable[..., Module], *args: Any, **kwargs: Any) -> Un
 def get_all_subclasses(cls: Type) -> Set[Type]:
     subclass_list = []

-    def recurse(cl):
+    def recurse(cl: Type) -> None:
         for subclass in cl.__subclasses__():
             subclass_list.append(subclass)
             recurse(subclass)
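For readers checking the `-> None` on `recurse`: the inner function only mutates `subclass_list` through its closure, so no value flows back. A compact, self-contained variant of the same traversal (hypothetical classes `A`/`B`/`C`, not from the patch):

    class A: ...
    class B(A): ...
    class C(B): ...

    def get_all_subclasses(cls):
        found = []

        def recurse(cl):  # returns None; it appends via the closure
            for sub in cl.__subclasses__():
                found.append(sub)
                recurse(sub)

        recurse(cls)
        return set(found)

    assert get_all_subclasses(A) == {B, C}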
From 97dfd7832c4f07d3706769ab45c477c4461ad217 Mon Sep 17 00:00:00 2001
From: nninept
Date: Wed, 20 Jul 2022 21:18:33 +0900
Subject: [PATCH 09/15] add ignore comment for operator error

---
 src/pytorch_lightning/utilities/meta.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/utilities/meta.py b/src/pytorch_lightning/utilities/meta.py
index df843de27d64f..a494f534c0a45 100644
--- a/src/pytorch_lightning/utilities/meta.py
+++ b/src/pytorch_lightning/utilities/meta.py
@@ -170,7 +170,7 @@ def recursively_setattr(root_module: nn.Module, prefix: str, materialized_module

     try:
         index = int(name)
-        root_module[index] = materialized_module
+        root_module[index] = materialized_module  # type: ignore[operator]
     except ValueError:
         setattr(root_module, name, materialized_module)


From cfcadde8880270f8aabaae7ebc7c5e4f682ad2e4 Mon Sep 17 00:00:00 2001
From: nninept
Date: Wed, 20 Jul 2022 23:22:42 +0900
Subject: [PATCH 10/15] update _MaterializerModule class annotation

---
 src/pytorch_lightning/utilities/meta.py | 16 +++++++++-------
 1 file changed, 9 insertions(+), 7 deletions(-)

diff --git a/src/pytorch_lightning/utilities/meta.py b/src/pytorch_lightning/utilities/meta.py
index a494f534c0a45..3063b0f6e02c9 100644
--- a/src/pytorch_lightning/utilities/meta.py
+++ b/src/pytorch_lightning/utilities/meta.py
@@ -259,22 +259,22 @@ def __instancecheck__(self, instance: Any) -> bool:
         # Create a class subclassing current `subclass` overriding its new method.
         # this will enable use to use `torch.distributed.nn.utils.init_meta` to create a `meta`
         # version of the current subclass module
-        class _MaterializerModule(subclass, metaclass=_IsinstanceMetaclass):
+        class _MaterializerModule(subclass, metaclass=_IsinstanceMetaclass):  # type: ignore[valid-type, misc]
             @classmethod
             @contextmanager
-            def instantiation_context(cls):
+            def instantiation_context(cls) -> Generator[None, None, None]:
                 _unset_meta_device(from_created=True)
                 yield
                 _set_meta_device_populated(from_created=True)

             @classmethod
-            def materialize(cls, materialize_fn: Callable):
+            def materialize(cls, materialize_fn: Callable) -> Type:
                 with cls.instantiation_context():
                     obj = materialize_fn()
                 return obj

             @staticmethod
-            def add_subclasses(subclass):
+            def add_subclasses(subclass: Type) -> None:
                 """This is used to unroll the instantiation tree while creating the modules."""
                 # Don't store the LightningModule as skipped from the Meta process.
                 if subclass != pl.LightningModule:
@@ -282,14 +282,16 @@ def add_subclasses(subclass):
                 if subclass.__bases__[0] != torch.nn.modules.module.Module:
                     _MaterializerModule.add_subclasses(subclass.__bases__[0])

-            def __new__(cls, *args, **kwargs):
+            def __new__(cls, *args: Any, **kwargs: Any) -> Type:
                 subclass = cls.__bases__[0]
                 cls.add_subclasses(subclass)
                 with cls.instantiation_context():
                     obj = init_meta(subclass, *args, **kwargs)
+                if(isinstance(obj, Exception)):
+                    raise obj

-                obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize)
-                return obj
+                obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize)  # type: ignore[assignment]
+                return obj  # type: ignore

         def search(mod: ModuleType) -> List[ModuleType]:
             out = []
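Two notes on these changes. The `# type: ignore[operator]` in PATCH 09 silences mypy because `root_module[index] = ...` is only valid at runtime when the concrete module is an `nn.Sequential`/`nn.ModuleList`; the static type `nn.Module` defines no `__setitem__`. And the `isinstance(obj, Exception)` check added in PATCH 10 follows from PATCH 07: `init_meta` now *returns* a `MisconfigurationException` on unsupported PyTorch versions instead of raising it, so `__new__` re-raises on the caller's behalf. A sketch of that calling convention from user code, assuming the patched `pytorch_lightning.utilities.meta` and PyTorch >= 1.10:

    import torch.nn as nn
    from pytorch_lightning.utilities.meta import init_meta

    module = init_meta(nn.Linear, 8, 4)
    if isinstance(module, Exception):
        raise module               # torch < 1.10: surface the error ourselves
    assert module.weight.is_meta   # torch >= 1.10: parameters have no storage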
From c48fc83e8bc4546a16e955ac31def83b2fe35b6e Mon Sep 17 00:00:00 2001
From: nninept
Date: Wed, 20 Jul 2022 23:43:19 +0900
Subject: [PATCH 11/15] update __STORAGE_META__ & __CREATED_MODULES__ annotation

---
 src/pytorch_lightning/utilities/meta.py | 10 +++++-----
 1 file changed, 5 insertions(+), 5 deletions(-)

diff --git a/src/pytorch_lightning/utilities/meta.py b/src/pytorch_lightning/utilities/meta.py
index 3063b0f6e02c9..d5a344e886fa8 100644
--- a/src/pytorch_lightning/utilities/meta.py
+++ b/src/pytorch_lightning/utilities/meta.py
@@ -19,7 +19,7 @@
 from functools import partial
 from itertools import chain
 from types import ModuleType
-from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Set, Type, Union
+from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Set, Tuple, Type, Union

 import torch
 from torch import nn, Tensor
@@ -194,8 +194,8 @@ def materialize_module(root_module: nn.Module) -> nn.Module:


 # cache subclasses to optimize the search when resetting the meta device later on.
-__STORAGE_META__ = {}
-__CREATED_MODULES__ = set()
+__STORAGE_META__: Dict[Type, Tuple]= {}
+__CREATED_MODULES__: Set[Type] = set()


 def _unset_meta_device(from_created: bool = False) -> None:
@@ -206,7 +206,7 @@ def _unset_meta_device(from_created: bool = False) -> None:
     if from_created:
         values = [__STORAGE_META__[key] for key in __CREATED_MODULES__]
     else:
-        values = __STORAGE_META__.values()
+        values = list(__STORAGE_META__.values())

     for mods, subclass, _ in values:
         for mod in mods:
@@ -221,7 +221,7 @@ def _set_meta_device_populated(from_created: bool = False) -> None:
     if from_created:
         values = [__STORAGE_META__[key] for key in __CREATED_MODULES__]
     else:
-        values = __STORAGE_META__.values()
+        values = list(__STORAGE_META__.values())

     for mods, subclass, meta_class in values:
         for mod in mods:

From 3262f228bbc2685673ee78a8606ccd288c3b30ce Mon Sep 17 00:00:00 2001
From: nninept
Date: Wed, 20 Jul 2022 23:44:45 +0900
Subject: [PATCH 12/15] update _IsinstanceMetaclass class annotation

---
 src/pytorch_lightning/utilities/meta.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/utilities/meta.py b/src/pytorch_lightning/utilities/meta.py
index d5a344e886fa8..e26141d89931e 100644
--- a/src/pytorch_lightning/utilities/meta.py
+++ b/src/pytorch_lightning/utilities/meta.py
@@ -251,7 +251,7 @@ def _set_meta_device() -> None:
             setattr(mod, subclass.__name__, meta_class)
             continue

-        class _IsinstanceMetaclass(type(subclass)):
+        class _IsinstanceMetaclass(type(subclass)):  # type: ignore[misc]
            def __instancecheck__(self, instance: Any) -> bool:
                 """Overrides the ``isinstance`` check on ``_MaterializerModule`` objects."""
                 return isinstance(instance, self.__bases__[0])
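The `# type: ignore[misc]` in PATCH 12 is needed because mypy cannot verify a dynamic base class such as `type(subclass)`. The runtime behaviour being annotated is a classic `__instancecheck__` override; a self-contained sketch of the same trick with hypothetical classes:

    class Original:
        pass

    class _IsinstanceMeta(type):
        def __instancecheck__(cls, instance):
            # Defer the check to the wrapped base, as _IsinstanceMetaclass does.
            return isinstance(instance, cls.__bases__[0])

    class Wrapper(Original, metaclass=_IsinstanceMeta):
        pass

    assert isinstance(Original(), Wrapper)  # passes thanks to the metaclass

This is what lets modules created while the meta device is patched in still satisfy `isinstance` checks against the original subclass.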
"__torch_dispatch__"): + raise ValueError("The class passed to enable_python_mode " "must have a __torch_dispatch__ classmethod") + if not isinstance(cls, type) or not issubclass(cls, (Tensor,)): + raise ValueError("The argument passed to enable_python_mode " "must be the type of a Tensor subclass") + torch._C._enter_python_mode(cls) + try: + yield + finally: + torch._C._exit_python_mode() + + _tls = threading.local() + _tls.in_call = False + + @contextmanager + def _no_dispatch() -> Iterator[None]: + """Temporarily disables the Python dispatch mode.""" + guard = _DisableTorchDispatch() # noqa F841 + try: + yield + finally: + del guard + + def _handle_arange(func: Callable, args: Any, kwargs: Any) -> Tensor: + kwargs["device"] = torch.device("cpu") + return torch.empty_like(func(*args, **kwargs), device="meta") + + def _handle_tril(func: Callable, args: Any, kwargs: Any) -> Union[Tensor, Any]: + if args and isinstance(args[0], Tensor): + return torch.empty_like(args[0], device="meta") + + return NotImplemented + def init_meta( module_fn: Callable[..., Module], *args: Any, **kwargs: Any ) -> Union[Module, MisconfigurationException]: From 13b633e7ba5b9d3efa6024d2f2b86dab1dc80130 Mon Sep 17 00:00:00 2001 From: nninept Date: Wed, 27 Jul 2022 17:10:45 +0900 Subject: [PATCH 15/15] modify todo --- src/pytorch_lightning/utilities/meta.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/meta.py b/src/pytorch_lightning/utilities/meta.py index 7dc655f2205b6..6708159da86b5 100644 --- a/src/pytorch_lightning/utilities/meta.py +++ b/src/pytorch_lightning/utilities/meta.py @@ -38,7 +38,7 @@ #################################################################### # BELOW: TAKEN FROM https://github.com/pytorch/pytorch/pull/66317. # - # TODO: Removed once merged and released on PyTorch side # + # TODO: Remove once merged and released on PyTorch side # #################################################################### class _MetaContext(Tensor):