diff --git a/pyproject.toml b/pyproject.toml index 226b109459f24..5473e73c52e19 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -71,6 +71,5 @@ module = [ "pytorch_lightning.tuner.batch_size_scaling", "pytorch_lightning.utilities.auto_restart", "pytorch_lightning.utilities.data", - "pytorch_lightning.utilities.meta", ] ignore_errors = "True" diff --git a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py index b12e1cf042a1f..62e81e4839da6 100644 --- a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py +++ b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py @@ -18,8 +18,6 @@ from torch.nn import Module from typing_extensions import Self -import pytorch_lightning as pl - class DeviceDtypeModuleMixin(Module): __jit_unused_properties__ = ["device", "dtype"] @@ -180,10 +178,8 @@ def half(self) -> Self: # type: ignore[valid-type] def __update_properties( self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None ) -> None: - def apply_fn(module: Union["DeviceDtypeModuleMixin", Module]) -> None: - # TODO: Find why `isinstance(module, DeviceDtypeModuleMixin)` doesn't - # work when using `init_meta_context`. - if not isinstance(module, (DeviceDtypeModuleMixin, pl.LightningModule)): + def apply_fn(module: Union[DeviceDtypeModuleMixin, Module]) -> None: + if not isinstance(module, DeviceDtypeModuleMixin): return if device is not None: module._device = device diff --git a/src/pytorch_lightning/strategies/launchers/multiprocessing.py b/src/pytorch_lightning/strategies/launchers/multiprocessing.py index 91fa92b555ae0..39bba092e9c60 100644 --- a/src/pytorch_lightning/strategies/launchers/multiprocessing.py +++ b/src/pytorch_lightning/strategies/launchers/multiprocessing.py @@ -87,6 +87,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] a selected set of attributes get restored in the main process after processes join. **kwargs: Optional keyword arguments to be passed to the given function. """ + self._check_torchdistx_support() # The default cluster environment in Lightning chooses a random free port number # This needs to be done in the main process here before starting processes to ensure each rank will connect # through the same port @@ -178,6 +179,16 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra) + def _check_torchdistx_support(self) -> None: + if self._start_method == "spawn": + from pytorch_lightning.utilities.meta import _is_deferred + + if _is_deferred(self._strategy.lightning_module): + raise NotImplementedError( + f"The `{type(self._strategy).__name__}` strategy does not support `torchdistx`'s deferred" + f" initialization." + ) + def add_to_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None: """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory sharing, we cast the data to numpy. 
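For context on the new `_check_torchdistx_support` guard: `torchdistx.deferred_init` builds a module whose parameters are "fake" (storage-less) tensors, and `_is_deferred` (added to `pytorch_lightning.utilities.meta` further down in this diff) reports whether any such fake parameters or buffers remain. A minimal sketch of that behaviour, assuming the optional `torchdistx` package is installed:

# Sketch only; assumes the optional `torchdistx` package is installed.
from torch import nn
from torchdistx.deferred_init import deferred_init, materialize_module
from torchdistx.fake import is_fake

# `deferred_init` records the constructor call and allocates fake tensors instead of real storage.
linear = deferred_init(nn.Linear, 2, 2)
assert is_fake(linear.weight)  # this is what `_is_deferred` checks for each parameter and buffer

# Materialization allocates real storage in place; the guard above raises NotImplementedError
# instead of handing such a deferred module to a spawn-based launcher.
materialize_module(linear)
assert not is_fake(linear.weight)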
diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 2a1c5082a1ac8..6853c4328af46 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -70,7 +70,6 @@ XLAProfiler, ) from pytorch_lightning.strategies import ParallelStrategy, Strategy -from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations from pytorch_lightning.trainer.connectors.accelerator_connector import _LITERAL_WARN, AcceleratorConnector @@ -106,8 +105,7 @@ from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, has_len_all_ranks from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException -from pytorch_lightning.utilities.imports import _fault_tolerant_training -from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module +from pytorch_lightning.utilities.imports import _fault_tolerant_training, _module_available from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.seed import isolate_rng @@ -1469,20 +1467,14 @@ def _call_setup_hook(self) -> None: def _call_configure_sharded_model(self) -> None: with self.strategy.model_sharded_context(): - self._handle_meta_model() - self._call_lightning_module_hook("configure_sharded_model") - self._call_callback_hooks("on_configure_sharded_model") - - def _handle_meta_model(self) -> None: - if not is_on_meta_device(self.lightning_module): - return + # experimental support for torchdistx + if _module_available("torchdistx.deferred_init"): + from torchdistx.deferred_init import materialize_module - if isinstance(self.strategy, DDPSpawnStrategy): - raise MisconfigurationException("LightningModule on meta device isn't supported with spawn.") + materialize_module(self.lightning_module) - materialize_module(self.lightning_module) - # the trainer reference is lost during materialization - self.lightning_module.trainer = proxy(self) + self._call_lightning_module_hook("configure_sharded_model") + self._call_callback_hooks("on_configure_sharded_model") def _call_teardown_hook(self) -> None: fn = self.state.fn._setup_fn diff --git a/src/pytorch_lightning/utilities/cli.py b/src/pytorch_lightning/utilities/cli.py index 285b5361f9cd2..8af919b78ce93 100644 --- a/src/pytorch_lightning/utilities/cli.py +++ b/src/pytorch_lightning/utilities/cli.py @@ -22,7 +22,7 @@ import pytorch_lightning as pl import pytorch_lightning.cli as new_cli -from pytorch_lightning.utilities.meta import get_all_subclasses +from pytorch_lightning.utilities.meta import _get_all_subclasses from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation _deprecate_registry_message = ( @@ -108,17 +108,17 @@ def _populate_registries(subclasses: bool) -> None: # Remove in v1.9 if subclasses: rank_zero_deprecation(_deprecate_auto_registry_message) # this will register any subclasses from all loaded modules including userland - for cls in get_all_subclasses(torch.optim.Optimizer): + for cls in _get_all_subclasses(torch.optim.Optimizer): OPTIMIZER_REGISTRY(cls, show_deprecation=False) - for cls in get_all_subclasses(torch.optim.lr_scheduler._LRScheduler): + for cls in 
_get_all_subclasses(torch.optim.lr_scheduler._LRScheduler): LR_SCHEDULER_REGISTRY(cls, show_deprecation=False) - for cls in get_all_subclasses(pl.Callback): + for cls in _get_all_subclasses(pl.Callback): CALLBACK_REGISTRY(cls, show_deprecation=False) - for cls in get_all_subclasses(pl.LightningModule): + for cls in _get_all_subclasses(pl.LightningModule): MODEL_REGISTRY(cls, show_deprecation=False) - for cls in get_all_subclasses(pl.LightningDataModule): + for cls in _get_all_subclasses(pl.LightningDataModule): DATAMODULE_REGISTRY(cls, show_deprecation=False) - for cls in get_all_subclasses(pl.loggers.Logger): + for cls in _get_all_subclasses(pl.loggers.Logger): LOGGER_REGISTRY(cls, show_deprecation=False) else: # manually register torch's subclasses and our subclasses diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index 862c7f2de905b..00a7cb8486709 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -18,7 +18,7 @@ from contextlib import contextmanager from dataclasses import fields from functools import partial -from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Tuple, Type, Union import torch from torch import Tensor @@ -39,6 +39,7 @@ from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler from pytorch_lightning.utilities.enums import _FaultTolerantMode from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.meta import _get_all_subclasses from pytorch_lightning.utilities.rank_zero import rank_zero_warn from pytorch_lightning.utilities.seed import pl_worker_init_function from pytorch_lightning.utilities.warnings import WarningCache @@ -493,20 +494,6 @@ def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None: return wrapper -# https://stackoverflow.com/a/63851681/9201239 -def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]: - """Returns a list of all classes that inherit directly or indirectly from the given class.""" - subclasses = set() - - def recurse(cl: Type[Any]) -> None: - for subclass in cl.__subclasses__(): - subclasses.add(subclass) - recurse(subclass) - - recurse(cls) - return subclasses - - @contextmanager def _replace_init_method(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]: """This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`. diff --git a/src/pytorch_lightning/utilities/meta.py b/src/pytorch_lightning/utilities/meta.py index 77da02f7231d4..9f4cd72bfe65d 100644 --- a/src/pytorch_lightning/utilities/meta.py +++ b/src/pytorch_lightning/utilities/meta.py @@ -11,149 +11,46 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import importlib -import inspect -import operator -import threading from contextlib import contextmanager -from functools import partial -from itertools import chain -from types import ModuleType -from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Set, Type +from typing import Any, Callable, Generator, Mapping, Optional, Set, Type, Union -import torch -from torch import nn, Tensor -from torch.nn import Module -from torch.nn.modules.container import ModuleDict, ModuleList, Sequential +from torch import Tensor +from torch.nn import Module, Parameter -import pytorch_lightning as pl -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _compare_version -from pytorch_lightning.utilities.rank_zero import rank_zero_warn +from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities.imports import _module_available -_TORCH_GREATER_EQUAL_1_10 = _compare_version("torch", operator.ge, "1.10.0") -if _TORCH_GREATER_EQUAL_1_10: - from torch._C import _DisableTorchDispatch # type: ignore[attr-defined] - - #################################################################### - # BELOW: TAKEN FROM https://github.com/pytorch/pytorch/pull/66317. # - # TODO: Removed once merged and released on PyTorch side # - #################################################################### - - @contextmanager - def enable_python_mode(cls) -> Iterator[None]: - if not hasattr(cls, "__torch_dispatch__"): - raise ValueError("The class passed to enable_python_mode " "must have a __torch_dispatch__ classmethod") - if not isinstance(cls, type) or not issubclass(cls, (Tensor,)): - raise ValueError("The argument passed to enable_python_mode " "must be the type of a Tensor subclass") - torch._C._enter_python_mode(cls) - try: - yield - finally: - torch._C._exit_python_mode() - - _tls = threading.local() - _tls.in_call = False - - @contextmanager - def _no_dispatch() -> Iterator[None]: - """Temporarily disables the Python dispatch mode.""" - guard = _DisableTorchDispatch() # noqa F841 - try: - yield - finally: - del guard - - def _handle_arange(func, args, kwargs): - kwargs["device"] = torch.device("cpu") - return torch.empty_like(func(*args, **kwargs), device="meta") - - def _handle_tril(func, args, kwargs): - if args and isinstance(args[0], Tensor): - return torch.empty_like(args[0], device="meta") - - return NotImplemented - - class _MetaContext(Tensor): - _op_handlers: Dict[Callable, Callable] = {} - - @classmethod - def _ensure_handlers_initialized(cls) -> None: - if cls._op_handlers: - return - - cls._op_handlers.update( - { - torch.ops.aten.arange: _handle_arange, - torch.ops.aten.tril: _handle_tril, - } - ) - - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - cls._ensure_handlers_initialized() - - op_handler: Optional[Callable] - - try: - op_handler = cls._op_handlers[func] - except KeyError: - op_handler = None - - with _no_dispatch(): - if op_handler: - result = op_handler(func, args, kwargs) - if result is not NotImplemented: - return result - - if "device" in kwargs: - kwargs["device"] = torch.device("meta") - - return func(*args, **kwargs) - - def init_meta(module_fn: Callable[..., Module], *args, **kwargs) -> Module: - def create_instance(module=None) -> Module: - if module: - module.__init__(*args, **kwargs) - return module - return module_fn(*args, **kwargs) - - if _tls.in_call: - module = create_instance() - else: - _tls.in_call = True - try: - with 
enable_python_mode(_MetaContext): - module = create_instance() - finally: - _tls.in_call = False - - module.materialize = partial(create_instance, module=module) # type: ignore[assignment] - - return module +def is_meta_init() -> bool: + rank_zero_deprecation( + "`pytorch_lightning.utilities.meta.is_meta_init` is deprecated in v1.8 and will be removed in v1.9." + " The function has become a no-op." + " Please check out the `torchdistx` project instead: https://github.com/pytorch/torchdistx" + ) + return False - def is_meta_init() -> bool: - """Indicates whether the module is being instantiated by ``init_meta()``.""" - return _tls.in_call - #################################################################### - # ABOVE: TAKEN FROM https://github.com/pytorch/pytorch/pull/66317. # - # TODO: Removed once merged and released on PyTorch side # - #################################################################### +def init_meta(module_fn: Callable[..., Module], *args: Any, **kwargs: Any) -> None: + rank_zero_deprecation( + "`pytorch_lightning.utilities.meta.init_meta` is deprecated in v1.8 and will be removed in v1.9." + " The function has become a no-op." + " Please check out the `torchdistx` project instead: https://github.com/pytorch/torchdistx" + ) -else: - def init_meta(*_, **__): - if not _TORCH_GREATER_EQUAL_1_10: - return MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0") +def get_all_subclasses(cls: Type) -> Set[Type]: + rank_zero_deprecation( + "`pytorch_lightning.utilities.meta.get_all_subclasses` is deprecated in v1.8 and will be removed in v1.9." + " Please copy its implementation if you have a use for it." + ) + return _get_all_subclasses(cls) # https://stackoverflow.com/a/63851681/9201239 -def get_all_subclasses(cls: Type) -> Set[Type]: +def _get_all_subclasses(cls: Type) -> Set[Type]: subclass_list = [] - def recurse(cl): + def recurse(cl: Type) -> None: for subclass in cl.__subclasses__(): subclass_list.append(subclass) recurse(subclass) @@ -163,7 +60,11 @@ def recurse(cl): return set(subclass_list) -def recursively_setattr(root_module: nn.Module, prefix: str, materialized_module: nn.Module) -> None: +def recursively_setattr(root_module: Any, prefix: str, materialized_module: Module) -> None: + rank_zero_deprecation( + "`pytorch_lightning.utilities.meta.recursively_setattr` is deprecated in v1.8 and will be removed in v1.9." + " Please copy its implementation if you have a use for it." + ) *path, name = prefix.split(".") for p in path: root_module = getattr(root_module, p) @@ -175,166 +76,43 @@ def recursively_setattr(root_module: nn.Module, prefix: str, materialized_module setattr(root_module, name, materialized_module) -def materialize_module(root_module: nn.Module) -> nn.Module: - """This utility performs an in-place operation by materialize a module and its children.""" - if not _TORCH_GREATER_EQUAL_1_10: - return root_module - - materialize_fn = getattr(root_module, "materialize", None) - if materialize_fn and not isinstance(root_module, (Sequential, ModuleList, ModuleDict)): - return materialize_fn() - - for name, child in root_module.named_children(): - materialize_fn = getattr(child, "materialize", None) - if not materialize_fn or isinstance(child, (Sequential, ModuleList, ModuleDict)): - materialize_module(child) - else: - setattr(root_module, name, materialize_fn()) - return root_module - - -# cache subclasses to optimize the search when resetting the meta device later on. 
-__STORAGE_META__ = {} -__CREATED_MODULES__ = set() - - -def _unset_meta_device(from_created: bool = False) -> None: - """Replace all meta module by their original version.""" - if not _TORCH_GREATER_EQUAL_1_10: - raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0") - - if from_created: - values = [__STORAGE_META__[key] for key in __CREATED_MODULES__] - else: - values = __STORAGE_META__.values() - - for mods, subclass, _ in values: - for mod in mods: - setattr(mod, subclass.__name__, subclass) - - -def _set_meta_device_populated(from_created: bool = False) -> None: - """Replace all meta module by their original version.""" - if not _TORCH_GREATER_EQUAL_1_10: - raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0") - - if from_created: - values = [__STORAGE_META__[key] for key in __CREATED_MODULES__] - else: - values = __STORAGE_META__.values() - - for mods, subclass, meta_class in values: - for mod in mods: - setattr(mod, subclass.__name__, meta_class) - - -def _set_meta_device() -> None: - """Replace all torch.nn.Module by their meta replacement.""" - - if not _TORCH_GREATER_EQUAL_1_10: - raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0") - - # Author note: This can be optimized further by searching all subclasses at once. - # Its time complexity is O(n*m) where n is the number of all subclasses if there's no multiple inheritance - # and m the number of all subclasses belonging to its subclass module. - - for subclass in get_all_subclasses(torch.nn.modules.module.Module): - - if subclass in (Sequential, ModuleList, ModuleDict, pl.LightningModule): - continue - - # if a subclass has already been stored, we should use the cache - if str(subclass) in __STORAGE_META__: - # reset the class import package to its rightful state. - mods, subclass, meta_class = __STORAGE_META__[subclass] - for mod in mods: - setattr(mod, subclass.__name__, meta_class) - continue - - class _IsinstanceMetaclass(type(subclass)): - def __instancecheck__(self, instance: Any) -> bool: - """Overrides the ``isinstance`` check on ``_MaterializerModule`` objects.""" - return isinstance(instance, self.__bases__[0]) - - # Create a class subclassing current `subclass` overriding its new method. - # this will enable use to use `torch.distributed.nn.utils.init_meta` to create a `meta` - # version of the current subclass module - class _MaterializerModule(subclass, metaclass=_IsinstanceMetaclass): - @classmethod - @contextmanager - def instantiation_context(cls): - _unset_meta_device(from_created=True) - yield - _set_meta_device_populated(from_created=True) - - @classmethod - def materialize(cls, materialize_fn: Callable): - with cls.instantiation_context(): - obj = materialize_fn() - return obj - - @staticmethod - def add_subclasses(subclass): - """This is used to unroll the instantiation tree while creating the modules.""" - # Don't store the LightningModule as skipped from the Meta process. 
- if subclass != pl.LightningModule: - __CREATED_MODULES__.add(subclass) - if subclass.__bases__[0] != torch.nn.modules.module.Module: - _MaterializerModule.add_subclasses(subclass.__bases__[0]) - - def __new__(cls, *args, **kwargs): - subclass = cls.__bases__[0] - cls.add_subclasses(subclass) - with cls.instantiation_context(): - obj = init_meta(subclass, *args, **kwargs) - - obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize) - return obj - - def search(mod: ModuleType) -> List[ModuleType]: - out = [] - for _, obj in inspect.getmembers(mod): - if obj == subclass: - out.append(mod) - return out - - submodules = subclass.__module__.split(".") - mod = importlib.import_module(submodules[0]) - - # nn.Module class can be imported at different level and they all need to be mocked. - # Example: torch.nn.Linear is actually torch.nn.modules.linear.Linear - # Therefore, torch.nn.Linear, torch.nn.modules.Linear, torch.nn.modules.linear.Linear - # needs to be replaced by the torch.nn.linear.modules.Linear _MaterializerModule - out = [search(mod)] - for name in submodules[1:]: - mod = getattr(mod, name) - out.append(search(mod)) - - # drop empty module - mods = [mod for mod in chain(*out) if mod] - - # store the modules search so it doesn't have to be performed again for this class - __STORAGE_META__[subclass] = (mods, subclass, _MaterializerModule) - - # replace all subclass by its meta form - for mod in mods: - setattr(mod, subclass.__name__, _MaterializerModule) +def materialize_module(root_module: Module) -> None: + rank_zero_deprecation( + "`pytorch_lightning.utilities.meta.materialize_module` is deprecated in v1.8 and will be removed in v1.9." + " The function has become a no-op." + " Please check out the `torchdistx` project instead: https://github.com/pytorch/torchdistx" + ) @contextmanager def init_meta_context() -> Generator: - rank_zero_warn( - "Be aware this feature is highly experimental and there are a number of weird edge cases " - "where it can internal assert and/or crash. A more stable version is to be expected from PyTorch 1.11." + rank_zero_deprecation( + "`pytorch_lightning.utilities.meta.init_meta_context` is deprecated in v1.8 and will be removed in v1.9." + " The function has become a no-op." + " Please check out the `torchdistx` project instead: https://github.com/pytorch/torchdistx" ) - _set_meta_device() yield - _unset_meta_device() -def is_on_meta_device(module: nn.Module) -> bool: +def is_on_meta_device(module: Module) -> bool: + rank_zero_deprecation( + "`pytorch_lightning.utilities.meta.is_on_meta_device` is deprecated in v1.8 and will be removed in v1.9." + " Please copy its implementation if you have a use for it." 
+ ) try: param = next(module.parameters()) - return param.device.type == "meta" + return param.is_meta except StopIteration: return False + + +def _is_deferred(module: Optional[Module]) -> bool: + if module is None or not _module_available("torchdistx.fake"): + return False + from torchdistx.fake import is_fake + + def any_fake(tensors: Mapping[str, Optional[Union[Tensor, Parameter]]]) -> bool: + return any(is_fake(t) for t in tensors.values() if t is not None) + + is_deferred = any(_is_deferred(m) for m in module.children()) + return is_deferred or any_fake(module._parameters) or any_fake(module._buffers) diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-9.py b/tests/tests_pytorch/deprecated_api/test_remove_1-9.py index 54c59bec62b5d..dcd8ecfd0169c 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-9.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-9.py @@ -217,3 +217,17 @@ def test_gpu_accelerator_deprecation_warning(): ) ): GPUAccelerator() + + +def test_meta_utility_deprecations(): + import pytorch_lightning.utilities.meta as meta + + pytest.deprecated_call(meta.is_meta_init, match="is_meta_init.*removed in v1.9") + pytest.deprecated_call(meta.init_meta, Mock(), match="init_meta.*removed in v1.9") + pytest.deprecated_call(meta.get_all_subclasses, Mock, match="get_all_subclasses.*removed in v1.9") + pytest.deprecated_call(meta.recursively_setattr, Mock(), "foo", 1, match="recursively_setattr.*removed in v1.9") + pytest.deprecated_call(meta.materialize_module, Mock(), match="materialize_module.*removed in v1.9") + with pytest.deprecated_call(match="init_meta_context.*removed in v1.9"): + with meta.init_meta_context(): + pass + pytest.deprecated_call(meta.is_on_meta_device, LightningModule(), match="is_on_meta_device.*removed in v1.9") diff --git a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py index 1f955a2520faa..14f7ab1e79b08 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py +++ b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py @@ -33,7 +33,6 @@ from pytorch_lightning.strategies import DeepSpeedStrategy from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE, LightningDeepSpeedModule from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.meta import init_meta_context from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.datasets import RandomIterableDataset from tests_pytorch.helpers.runif import RunIf @@ -1232,25 +1231,6 @@ def on_test_batch_start( trainer.test(model) -@RunIf(min_cuda_gpus=2, min_torch="1.10.0", max_torch="1.12.0", standalone=True, deepspeed=True) -def test_deepspeed_with_meta_device(tmpdir): - with init_meta_context(): - model = BoringModel() - assert model.layer.weight.device.type == "meta" - trainer = Trainer( - default_root_dir=tmpdir, - strategy=DeepSpeedStrategy(stage=3), - accelerator="gpu", - devices=2, - fast_dev_run=True, - precision=16, - enable_progress_bar=False, - enable_model_summary=False, - ) - trainer.fit(model) - assert model.layer.weight.device.type == "cpu" - - @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) def test_deepspeed_multi_save_same_filepath(tmpdir): """Test that verifies that deepspeed saves only latest checkpoint in the specified path and deletes the old diff --git a/tests/tests_pytorch/utilities/test_meta.py b/tests/tests_pytorch/utilities/test_meta.py deleted file mode 100644 index 
f7fcce4cb835e..0000000000000 --- a/tests/tests_pytorch/utilities/test_meta.py +++ /dev/null @@ -1,84 +0,0 @@ -# Copyright The PyTorch Lightning team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import pytest -from torch import nn - -from pytorch_lightning.core.module import LightningModule -from pytorch_lightning.demos.boring_classes import BoringModel -from pytorch_lightning.utilities.meta import init_meta_context, is_on_meta_device, materialize_module -from tests_pytorch.helpers.runif import RunIf - - -class MLP(nn.Module): - def __init__(self, num_layers: int): - super().__init__() - self.layer = nn.Sequential(*[nn.Linear(1, 1) for _ in range(num_layers)] + [nn.Dropout(), nn.LayerNorm(1)]) - - -class SimpleBoringModel(LightningModule): - def __init__(self, num_layers: int): - super().__init__() - self.save_hyperparameters() - self.layer = nn.Sequential(*[nn.Linear(1, 1) for _ in range(self.hparams.num_layers)]) - - -@RunIf(min_torch="1.10.0", max_torch="1.12.0", standalone=True) -def test_init_meta_context(): - - with init_meta_context(): - m = nn.Linear(in_features=1, out_features=1) - assert isinstance(m, nn.Linear) - assert m.weight.device.type == "meta" - assert is_on_meta_device(m) - mlp = MLP(4) - assert mlp.layer[0].weight.device.type == "meta" - - mlp = materialize_module(mlp) - assert mlp.layer[0].weight.device.type == "cpu" - - assert not is_on_meta_device(mlp) - assert not is_on_meta_device(nn.Module()) - - model = SimpleBoringModel(4) - assert model.layer[0].weight.device.type == "meta" - materialize_module(model) - assert model.layer[0].weight.device.type == "cpu" - - mlp = MLP(4) - assert mlp.layer[0].weight.device.type == "cpu" - # no-op as already materialized. - materialize_module(mlp) - assert mlp.layer[0].weight.device.type == "cpu" - - m = nn.Linear(in_features=1, out_features=1) - assert m.weight.device.type == "cpu" - - with init_meta_context(): - m = nn.Linear(in_features=1, out_features=1) - assert m.weight.device.type == "meta" - - m = nn.Linear(in_features=1, out_features=1) - assert m.weight.device.type == "cpu" - - -@RunIf(min_torch="1.10.0", max_torch="1.12.0", standalone=True) -def test_materialize_module_recursive_child(): - """Test materialize_module doesn't set a child recursively to a model instantiated within init_meta_context.""" - with init_meta_context(): - model = BoringModel() - - materialize_module(model) - - with pytest.raises(AttributeError, match="'Linear' object has no attribute 'layer'"): - model.layer.layer diff --git a/tests/tests_pytorch/utilities/test_torchdistx.py b/tests/tests_pytorch/utilities/test_torchdistx.py new file mode 100644 index 0000000000000..aa3f8e34bfaac --- /dev/null +++ b/tests/tests_pytorch/utilities/test_torchdistx.py @@ -0,0 +1,92 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import pytest +from torch import nn + +from pytorch_lightning import Trainer +from pytorch_lightning.core.module import LightningModule +from pytorch_lightning.demos.boring_classes import BoringModel +from pytorch_lightning.utilities.imports import _RequirementAvailable +from pytorch_lightning.utilities.meta import _is_deferred +from tests_pytorch.helpers.runif import RunIf + +_TORCHDISTX_AVAILABLE = _RequirementAvailable("torchdistx") + + +class SimpleBoringModel(LightningModule): + def __init__(self, num_layers): + super().__init__() + self.layer = nn.Sequential(*[nn.Linear(1, 1) for _ in range(num_layers)]) + + +@pytest.mark.skipif(not _TORCHDISTX_AVAILABLE, reason=_TORCHDISTX_AVAILABLE.message) +def test_deferred_init_with_lightning_module(): + from torchdistx.deferred_init import deferred_init, materialize_module + from torchdistx.fake import is_fake + + model = deferred_init(SimpleBoringModel, 4) + weight = model.layer[0].weight + assert weight.device.type == "cpu" + assert is_fake(weight) + assert _is_deferred(model) + + materialize_module(model) + materialize_module(model) # make sure it's idempotent + assert not _is_deferred(model) + weight = model.layer[0].weight + assert weight.device.type == "cpu" + assert not is_fake(weight) + + +@pytest.mark.skipif(not _TORCHDISTX_AVAILABLE, reason=_TORCHDISTX_AVAILABLE.message) +@pytest.mark.parametrize( + "trainer_kwargs", + ( + {"accelerator": "auto", "devices": 1}, + pytest.param( + {"strategy": "deepspeed_stage_3", "accelerator": "gpu", "devices": 2, "precision": 16}, + marks=RunIf(min_cuda_gpus=2, deepspeed=True), + ), + ), +) +def test_deferred_init_with_trainer(tmpdir, trainer_kwargs): + from torchdistx.deferred_init import deferred_init + + model = deferred_init(BoringModel) + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + enable_progress_bar=False, + enable_model_summary=False, + **trainer_kwargs + ) + trainer.fit(model) + + +@pytest.mark.skipif(not _TORCHDISTX_AVAILABLE, reason=_TORCHDISTX_AVAILABLE.message) +def test_deferred_init_ddp_spawn(tmpdir): + from torchdistx.deferred_init import deferred_init + + model = deferred_init(BoringModel) + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + enable_progress_bar=False, + enable_model_summary=False, + accelerator="auto", + devices="1", + strategy="ddp_spawn", + ) + with pytest.raises(NotImplementedError, match="DDPSpawnStrategy` strategy does not support.*torchdistx"): + trainer.fit(model)
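Taken together, these tests exercise the new end-to-end flow: a LightningModule can be built with `torchdistx.deferred_init.deferred_init` and handed to the `Trainer`, which materializes it inside `_call_configure_sharded_model` whenever `torchdistx.deferred_init` is importable. A minimal usage sketch, assuming `torchdistx` is installed:

# Minimal usage sketch; assumes the optional `torchdistx` package is installed.
from torchdistx.deferred_init import deferred_init

from pytorch_lightning import Trainer
from pytorch_lightning.demos.boring_classes import BoringModel

# Parameters are recorded but not allocated yet (fake tensors).
model = deferred_init(BoringModel)

# The Trainer materializes the deferred module before the `configure_sharded_model` hook runs,
# so real weights are only allocated under the strategy's model-sharded context.
trainer = Trainer(accelerator="auto", devices=1, fast_dev_run=True)
trainer.fit(model)

# Note: spawn-based strategies (e.g. `strategy="ddp_spawn"`) raise NotImplementedError for
# deferred modules, as covered by `test_deferred_init_ddp_spawn` above.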