1 change: 0 additions & 1 deletion pyproject.toml
@@ -76,6 +76,5 @@ module = [
    "pytorch_lightning.tuner.batch_size_scaling",
    "pytorch_lightning.utilities.auto_restart",
    "pytorch_lightning.utilities.data",
-    "pytorch_lightning.utilities.meta",
]
ignore_errors = "True"
118 changes: 62 additions & 56 deletions src/pytorch_lightning/utilities/meta.py
@@ -19,7 +19,7 @@
from functools import partial
from itertools import chain
from types import ModuleType
-from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Set, Type
+from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Set, Tuple, Type, Union

import torch
from torch import nn, Tensor
@@ -38,43 +38,9 @@

    ####################################################################
    # BELOW: TAKEN FROM https://github.com/pytorch/pytorch/pull/66317. #
-    # TODO: Removed once merged and released on PyTorch side           #
+    # TODO: Remove once merged and released on PyTorch side            #
    ####################################################################

-    @contextmanager
-    def enable_python_mode(cls) -> Iterator[None]:
-        if not hasattr(cls, "__torch_dispatch__"):
-            raise ValueError("The class passed to enable_python_mode " "must have a __torch_dispatch__ classmethod")
-        if not isinstance(cls, type) or not issubclass(cls, (Tensor,)):
-            raise ValueError("The argument passed to enable_python_mode " "must be the type of a Tensor subclass")
-        torch._C._enter_python_mode(cls)
-        try:
-            yield
-        finally:
-            torch._C._exit_python_mode()
-
-    _tls = threading.local()
-    _tls.in_call = False
-
-    @contextmanager
-    def _no_dispatch() -> Iterator[None]:
-        """Temporarily disables the Python dispatch mode."""
-        guard = _DisableTorchDispatch()  # noqa F841
-        try:
-            yield
-        finally:
-            del guard
-
-    def _handle_arange(func, args, kwargs):
-        kwargs["device"] = torch.device("cpu")
-        return torch.empty_like(func(*args, **kwargs), device="meta")
-
-    def _handle_tril(func, args, kwargs):
-        if args and isinstance(args[0], Tensor):
-            return torch.empty_like(args[0], device="meta")
-
-        return NotImplemented

    class _MetaContext(Tensor):
        _op_handlers: Dict[Callable, Callable] = {}

@@ -91,7 +57,7 @@ def _ensure_handlers_initialized(cls) -> None:
            )

        @classmethod
-        def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
+        def __torch_dispatch__(cls, func: Callable, types: Any, args: Any = (), kwargs: Optional[Any] = None) -> Any:
            cls._ensure_handlers_initialized()

            op_handler: Optional[Callable]
@@ -107,13 +73,49 @@ def __torch_dispatch__(cls, func, types, args=(), kwargs=None):
                    if result is not NotImplemented:
                        return result

-                if "device" in kwargs:
+                if kwargs is not None and "device" in kwargs.keys():
                    kwargs["device"] = torch.device("meta")

-                return func(*args, **kwargs)
+                return func(*args, **(kwargs if kwargs is not None else {}))

-    def init_meta(module_fn: Callable[..., Module], *args, **kwargs) -> Module:
-        def create_instance(module=None) -> Module:
+    @contextmanager
+    def enable_python_mode(cls: Type[_MetaContext]) -> Iterator[None]:
+        if not hasattr(cls, "__torch_dispatch__"):
+            raise ValueError("The class passed to enable_python_mode " "must have a __torch_dispatch__ classmethod")
+        if not isinstance(cls, type) or not issubclass(cls, (Tensor,)):
+            raise ValueError("The argument passed to enable_python_mode " "must be the type of a Tensor subclass")
+        torch._C._enter_python_mode(cls)
+        try:
+            yield
+        finally:
+            torch._C._exit_python_mode()
+
+    _tls = threading.local()
+    _tls.in_call = False
+
+    @contextmanager
+    def _no_dispatch() -> Iterator[None]:
+        """Temporarily disables the Python dispatch mode."""
+        guard = _DisableTorchDispatch()  # noqa F841
+        try:
+            yield
+        finally:
+            del guard
+
+    def _handle_arange(func: Callable, args: Any, kwargs: Any) -> Tensor:
+        kwargs["device"] = torch.device("cpu")
+        return torch.empty_like(func(*args, **kwargs), device="meta")
+
+    def _handle_tril(func: Callable, args: Any, kwargs: Any) -> Union[Tensor, Any]:
+        if args and isinstance(args[0], Tensor):
+            return torch.empty_like(args[0], device="meta")
+
+        return NotImplemented
+
+    def init_meta(
+        module_fn: Callable[..., Module], *args: Any, **kwargs: Any
+    ) -> Union[Module, MisconfigurationException]:
+        def create_instance(module: Optional[Any] = None) -> Module:
            if module:
                module.__init__(*args, **kwargs)
                return module
@@ -144,7 +146,9 @@ def is_meta_init() -> bool:

else:

-    def init_meta(*_, **__):
+    def init_meta(
+        module_fn: Callable[..., Module], *args: Any, **kwargs: Any
+    ) -> Union[Module, MisconfigurationException]:
        if not _TORCH_GREATER_EQUAL_1_10:
            return MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0")
Contributor:
Wondering why it returns an Exception instead of raising it?

Contributor:
Looks like a bug.

Contributor Author:
I modified the definition line of the `init_meta` function because mypy reports an error when conditional function variants do not have identical signatures. The original code returned a `MisconfigurationException`, so I kept returning the error rather than changing that behavior here.
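For context, a minimal, self-contained sketch of the mypy constraint the author describes; all names (`build`, `Widget`, `ConfigError`, `HAS_FEATURE`) are hypothetical stand-ins, not part of this PR:

```python
# Sketch of mypy's rule: conditionally defined variants of a function must
# share an identical signature. Names here are illustrative stand-ins only.
from typing import Any, Callable, Union


class Widget:
    """Stand-in for nn.Module."""


class ConfigError(Exception):
    """Stand-in for MisconfigurationException."""


HAS_FEATURE = False  # stand-in for _TORCH_GREATER_EQUAL_1_10

if HAS_FEATURE:

    def build(factory: Callable[..., Widget], *args: Any, **kwargs: Any) -> Union[Widget, ConfigError]:
        # Real branch: construct and return the object.
        return factory(*args, **kwargs)

else:

    # A fallback written as `def build(*_, **__):` makes mypy report:
    #   error: All conditional function variants must have identical signatures
    # so the fallback mirrors the signature and return type of the branch above.
    def build(factory: Callable[..., Widget], *args: Any, **kwargs: Any) -> Union[Widget, ConfigError]:
        return ConfigError("this feature requires a newer dependency")
```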

@@ -153,7 +157,7 @@ def init_meta(*_, **__):
def get_all_subclasses(cls: Type) -> Set[Type]:
    subclass_list = []

-    def recurse(cl):
+    def recurse(cl: Type) -> None:
        for subclass in cl.__subclasses__():
            subclass_list.append(subclass)
            recurse(subclass)
@@ -170,7 +174,7 @@ def recursively_setattr(root_module: nn.Module, prefix: str, materialized_module

    try:
        index = int(name)
-        root_module[index] = materialized_module
+        root_module[index] = materialized_module  # type: ignore[operator]
    except ValueError:
        setattr(root_module, name, materialized_module)

@@ -194,8 +198,8 @@ def materialize_module(root_module: nn.Module) -> nn.Module:


# cache subclasses to optimize the search when resetting the meta device later on.
-__STORAGE_META__ = {}
-__CREATED_MODULES__ = set()
+__STORAGE_META__: Dict[Type, Tuple] = {}
+__CREATED_MODULES__: Set[Type] = set()


def _unset_meta_device(from_created: bool = False) -> None:
@@ -206,7 +210,7 @@ def _unset_meta_device(from_created: bool = False) -> None:
    if from_created:
        values = [__STORAGE_META__[key] for key in __CREATED_MODULES__]
    else:
-        values = __STORAGE_META__.values()
+        values = list(__STORAGE_META__.values())

    for mods, subclass, _ in values:
        for mod in mods:
@@ -221,7 +225,7 @@ def _set_meta_device_populated(from_created: bool = False) -> None:
    if from_created:
        values = [__STORAGE_META__[key] for key in __CREATED_MODULES__]
    else:
-        values = __STORAGE_META__.values()
+        values = list(__STORAGE_META__.values())

    for mods, subclass, meta_class in values:
        for mod in mods:
@@ -251,45 +255,47 @@ def _set_meta_device() -> None:
                setattr(mod, subclass.__name__, meta_class)
            continue

-        class _IsinstanceMetaclass(type(subclass)):
+        class _IsinstanceMetaclass(type(subclass)):  # type: ignore[misc]
            def __instancecheck__(self, instance: Any) -> bool:
                """Overrides the ``isinstance`` check on ``_MaterializerModule`` objects."""
                return isinstance(instance, self.__bases__[0])

        # Create a class subclassing the current `subclass` and overriding its `__new__` method.
        # This will enable us to use `torch.distributed.nn.utils.init_meta` to create a `meta`
        # version of the current subclass module.
-        class _MaterializerModule(subclass, metaclass=_IsinstanceMetaclass):
+        class _MaterializerModule(subclass, metaclass=_IsinstanceMetaclass):  # type: ignore[valid-type, misc]
            @classmethod
            @contextmanager
-            def instantiation_context(cls):
+            def instantiation_context(cls) -> Generator[None, None, None]:
                _unset_meta_device(from_created=True)
                yield
                _set_meta_device_populated(from_created=True)

            @classmethod
-            def materialize(cls, materialize_fn: Callable):
+            def materialize(cls, materialize_fn: Callable) -> Type:
                with cls.instantiation_context():
                    obj = materialize_fn()
                return obj

            @staticmethod
-            def add_subclasses(subclass):
+            def add_subclasses(subclass: Type) -> None:
                """This is used to unroll the instantiation tree while creating the modules."""
                # Don't store the LightningModule as skipped from the Meta process.
                if subclass != pl.LightningModule:
                    __CREATED_MODULES__.add(subclass)
                if subclass.__bases__[0] != torch.nn.modules.module.Module:
                    _MaterializerModule.add_subclasses(subclass.__bases__[0])

-            def __new__(cls, *args, **kwargs):
+            def __new__(cls, *args: Any, **kwargs: Any) -> Type:
                subclass = cls.__bases__[0]
                cls.add_subclasses(subclass)
                with cls.instantiation_context():
                    obj = init_meta(subclass, *args, **kwargs)
+                    if isinstance(obj, Exception):
+                        raise obj
Comment on lines +294 to +295

Contributor:
Are you taking these changes from upstream?

Contributor Author:
I wrote this line because the `init_meta` function returns a `MisconfigurationException` when `_TORCH_GREATER_EQUAL_1_10` is false. If the error were only returned, `obj.materialize` would be `None`, so I thought it would be better to raise the error explicitly.
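A small hedged sketch of the failure mode discussed above; all names here (`init_meta_stub`, `FakeModule`, `torch_is_recent`) are hypothetical stand-ins for the Lightning/torch ones:

```python
# On torch < 1.10, `init_meta` returns an exception instance rather than
# raising it, so the caller must check the result before using it.
from typing import Union


class MisconfigurationException(Exception):
    """Stand-in for the Lightning exception class."""


class FakeModule:
    """Stand-in for a materializable nn.Module."""

    def materialize(self) -> "FakeModule":
        return self


def init_meta_stub(torch_is_recent: bool) -> Union[FakeModule, MisconfigurationException]:
    if not torch_is_recent:
        return MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0")
    return FakeModule()


try:
    obj = init_meta_stub(torch_is_recent=False)
    if isinstance(obj, Exception):
        raise obj  # without this check, `obj.materialize` would fail later
    obj.materialize()
except MisconfigurationException as err:
    print(f"caught: {err}")
```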
-                obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize)
-                return obj
+                obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize)  # type: ignore[assignment]
+                return obj  # type: ignore

        def search(mod: ModuleType) -> List[ModuleType]:
            out = []