Fix mypy errors attributed to pytorch_lightning/utilities/meta.py #13763
Changes from all commits: 3391c63, 0e9cd31, 5c64311, 494af9d, 0c196e7, b25dd53, 666d8ec, 4cbff1f, 97dfd78, cfcadde, c48fc83, 3262f22, 9e080f9, f72cbe6, 13b633e, c451038
pytorch_lightning/utilities/meta.py

```diff
@@ -19,7 +19,7 @@
 from functools import partial
 from itertools import chain
 from types import ModuleType
-from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Set, Type
+from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Set, Tuple, Type, Union

 import torch
 from torch import nn, Tensor
@@ -38,43 +38,9 @@
 ####################################################################
 # BELOW: TAKEN FROM https://github.com/pytorch/pytorch/pull/66317. #
-# TODO: Removed once merged and released on PyTorch side           #
+# TODO: Remove once merged and released on PyTorch side            #
 ####################################################################

-@contextmanager
-def enable_python_mode(cls) -> Iterator[None]:
-    if not hasattr(cls, "__torch_dispatch__"):
-        raise ValueError("The class passed to enable_python_mode " "must have a __torch_dispatch__ classmethod")
-    if not isinstance(cls, type) or not issubclass(cls, (Tensor,)):
-        raise ValueError("The argument passed to enable_python_mode " "must be the type of a Tensor subclass")
-    torch._C._enter_python_mode(cls)
-    try:
-        yield
-    finally:
-        torch._C._exit_python_mode()
-
-
-_tls = threading.local()
-_tls.in_call = False
-
-
-@contextmanager
-def _no_dispatch() -> Iterator[None]:
-    """Temporarily disables the Python dispatch mode."""
-    guard = _DisableTorchDispatch()  # noqa F841
-    try:
-        yield
-    finally:
-        del guard
-
-
-def _handle_arange(func, args, kwargs):
-    kwargs["device"] = torch.device("cpu")
-    return torch.empty_like(func(*args, **kwargs), device="meta")
-
-
-def _handle_tril(func, args, kwargs):
-    if args and isinstance(args[0], Tensor):
-        return torch.empty_like(args[0], device="meta")
-
-    return NotImplemented
-
-
 class _MetaContext(Tensor):
     _op_handlers: Dict[Callable, Callable] = {}
@@ -91,7 +57,7 @@ def _ensure_handlers_initialized(cls) -> None:
         )

     @classmethod
-    def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+    def __torch_dispatch__(cls, func: Callable, types: Any, args: Any = (), kwargs: Optional[Any] = None) -> Any:
         cls._ensure_handlers_initialized()

         op_handler: Optional[Callable]
@@ -107,13 +73,49 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                 if result is not NotImplemented:
                     return result

-            if "device" in kwargs:
+            if kwargs is not None and "device" in kwargs.keys():
                 kwargs["device"] = torch.device("meta")

-            return func(*args, **kwargs)
+            return func(*args, **(kwargs if kwargs is not None else {}))


-def init_meta(module_fn: Callable[..., Module], *args, **kwargs) -> Module:
-    def create_instance(module=None) -> Module:
+@contextmanager
+def enable_python_mode(cls: Type[_MetaContext]) -> Iterator[None]:
+    if not hasattr(cls, "__torch_dispatch__"):
+        raise ValueError("The class passed to enable_python_mode " "must have a __torch_dispatch__ classmethod")
+    if not isinstance(cls, type) or not issubclass(cls, (Tensor,)):
+        raise ValueError("The argument passed to enable_python_mode " "must be the type of a Tensor subclass")
+    torch._C._enter_python_mode(cls)
+    try:
+        yield
+    finally:
+        torch._C._exit_python_mode()
+
+
+_tls = threading.local()
+_tls.in_call = False
+
+
+@contextmanager
+def _no_dispatch() -> Iterator[None]:
+    """Temporarily disables the Python dispatch mode."""
+    guard = _DisableTorchDispatch()  # noqa F841
+    try:
+        yield
+    finally:
+        del guard
+
+
+def _handle_arange(func: Callable, args: Any, kwargs: Any) -> Tensor:
+    kwargs["device"] = torch.device("cpu")
+    return torch.empty_like(func(*args, **kwargs), device="meta")
+
+
+def _handle_tril(func: Callable, args: Any, kwargs: Any) -> Union[Tensor, Any]:
+    if args and isinstance(args[0], Tensor):
+        return torch.empty_like(args[0], device="meta")
+
+    return NotImplemented
+
+
+def init_meta(
+    module_fn: Callable[..., Module], *args: Any, **kwargs: Any
+) -> Union[Module, MisconfigurationException]:
+    def create_instance(module: Optional[Any] = None) -> Module:
         if module:
             module.__init__(*args, **kwargs)
             return module
@@ -144,7 +146,9 @@ def is_meta_init() -> bool:
 else:

-    def init_meta(*_, **__):
+    def init_meta(
+        module_fn: Callable[..., Module], *args: Any, **kwargs: Any
+    ) -> Union[Module, MisconfigurationException]:
         if not _TORCH_GREATER_EQUAL_1_10:
             return MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0")
@@ -153,7 +157,7 @@ def init_meta(*_, **__):
 def get_all_subclasses(cls: Type) -> Set[Type]:
     subclass_list = []

-    def recurse(cl):
+    def recurse(cl: Type) -> None:
         for subclass in cl.__subclasses__():
             subclass_list.append(subclass)
             recurse(subclass)
@@ -170,7 +174,7 @@ def recursively_setattr(root_module: nn.Module, prefix: str, materialized_module
     try:
         index = int(name)
-        root_module[index] = materialized_module
+        root_module[index] = materialized_module  # type: ignore[operator]
     except ValueError:
         setattr(root_module, name, materialized_module)
@@ -194,8 +198,8 @@ def materialize_module(root_module: nn.Module) -> nn.Module:

 # cache subclasses to optimize the search when resetting the meta device later on.
-__STORAGE_META__ = {}
-__CREATED_MODULES__ = set()
+__STORAGE_META__: Dict[Type, Tuple] = {}
+__CREATED_MODULES__: Set[Type] = set()


 def _unset_meta_device(from_created: bool = False) -> None:
@@ -206,7 +210,7 @@ def _unset_meta_device(from_created: bool = False) -> None:
     if from_created:
         values = [__STORAGE_META__[key] for key in __CREATED_MODULES__]
     else:
-        values = __STORAGE_META__.values()
+        values = list(__STORAGE_META__.values())

     for mods, subclass, _ in values:
         for mod in mods:
@@ -221,7 +225,7 @@ def _set_meta_device_populated(from_created: bool = False) -> None:
     if from_created:
         values = [__STORAGE_META__[key] for key in __CREATED_MODULES__]
     else:
-        values = __STORAGE_META__.values()
+        values = list(__STORAGE_META__.values())

     for mods, subclass, meta_class in values:
         for mod in mods:
@@ -251,45 +255,47 @@ def _set_meta_device() -> None:
                 setattr(mod, subclass.__name__, meta_class)
             continue

-        class _IsinstanceMetaclass(type(subclass)):
+        class _IsinstanceMetaclass(type(subclass)):  # type: ignore[misc]
             def __instancecheck__(self, instance: Any) -> bool:
                 """Overrides the ``isinstance`` check on ``_MaterializerModule`` objects."""
                 return isinstance(instance, self.__bases__[0])

         # Create a class subclassing current `subclass` overriding its new method.
         # this will enable use to use `torch.distributed.nn.utils.init_meta` to create a `meta`
         # version of the current subclass module
-        class _MaterializerModule(subclass, metaclass=_IsinstanceMetaclass):
+        class _MaterializerModule(subclass, metaclass=_IsinstanceMetaclass):  # type: ignore[valid-type, misc]
             @classmethod
             @contextmanager
-            def instantiation_context(cls):
+            def instantiation_context(cls) -> Generator[None, None, None]:
                 _unset_meta_device(from_created=True)
                 yield
                 _set_meta_device_populated(from_created=True)

             @classmethod
-            def materialize(cls, materialize_fn: Callable):
+            def materialize(cls, materialize_fn: Callable) -> Type:
                 with cls.instantiation_context():
                     obj = materialize_fn()
                 return obj

             @staticmethod
-            def add_subclasses(subclass):
+            def add_subclasses(subclass: Type) -> None:
                 """This is used to unroll the instantiation tree while creating the modules."""
                 # Don't store the LightningModule as skipped from the Meta process.
                 if subclass != pl.LightningModule:
                     __CREATED_MODULES__.add(subclass)
                 if subclass.__bases__[0] != torch.nn.modules.module.Module:
                     _MaterializerModule.add_subclasses(subclass.__bases__[0])

-            def __new__(cls, *args, **kwargs):
+            def __new__(cls, *args: Any, **kwargs: Any) -> Type:
                 subclass = cls.__bases__[0]
                 cls.add_subclasses(subclass)
                 with cls.instantiation_context():
                     obj = init_meta(subclass, *args, **kwargs)
+                    if isinstance(obj, Exception):
+                        raise obj
```
Comment on lines +294 to +295

Contributor: Are you taking these changes from upstream?

Contributor (Author): I wrote this line because …
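For readers following the thread, here is a minimal, self-contained sketch (not taken from the PR; `init_meta_stub` is an illustrative stand-in) of the return-then-raise pattern the two added lines implement: `init_meta` can *return* a `MisconfigurationException` instead of raising it, so the caller must promote the returned object to a raised exception.

```python
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def init_meta_stub(*args, **kwargs):
    # Mirrors the pre-PyTorch-1.10 fallback: the error is returned, not raised.
    return MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0")


obj = init_meta_stub()
if isinstance(obj, Exception):
    # The caller (in the diff, `_MaterializerModule.__new__`) checks the
    # returned value and raises it itself.
    raise obj
```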
The diff continues:

```diff
-                obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize)
-                return obj
+                obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize)  # type: ignore[assignment]
+                return obj  # type: ignore

         def search(mod: ModuleType) -> List[ModuleType]:
             out = []
```
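As background for the file as a whole (this snippet is not part of the PR): a tensor on PyTorch's `meta` device carries shape and dtype but no storage, which is what lets these utilities describe a module cheaply before materializing it on a real device.

```python
import torch

# A meta tensor records shape/dtype without allocating any memory,
# so arbitrarily large modules can be described for free.
t = torch.empty(1024, 1024, device="meta")
print(t.shape, t.device)  # torch.Size([1024, 1024]) meta

# "Materializing" re-creates an equivalent tensor on a real device.
real = torch.empty_like(t, device="cpu")
print(real.device)  # cpu
```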
Contributor: Wondering, why does it return an Exception instead of raising it?

Contributor: Looks like a bug.

Contributor (Author): I modified the definition line of the `init_meta` function because mypy reports an error when conditionally defined functions do not have identical signatures. The original code returns a `MisconfigurationException`, so I followed that and return the error without further changes.
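To illustrate the mypy rule the author refers to, here is a hedged, standalone sketch (the flag and function bodies are simplified, not copied from the PR): when the same function is defined in both branches of a module-level conditional, mypy requires the two definitions to have identical signatures.

```python
from typing import Any, Callable, Union

from torch.nn import Module

_TORCH_GREATER_EQUAL_1_10 = False  # simplified; normally probed at import time

if _TORCH_GREATER_EQUAL_1_10:
    def init_meta(module_fn: Callable[..., Module], *args: Any, **kwargs: Any) -> Union[Module, Exception]:
        ...

else:
    # With the old stub `def init_meta(*_, **__):`, mypy reports:
    #   error: All conditional function variants must have identical signatures
    # Repeating the full signature in both branches resolves the error.
    def init_meta(module_fn: Callable[..., Module], *args: Any, **kwargs: Any) -> Union[Module, Exception]:
        return Exception("`init_meta` is supported from PyTorch 1.10.0")
```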