From 88e469304ba02650580aa497bed90fa418800741 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 27 Jul 2022 02:57:45 +0200 Subject: [PATCH 01/14] Remove meta device utilities in favor of torchdistx --- pyproject.toml | 1 - .../core/mixins/device_dtype_mixin.py | 8 +- src/pytorch_lightning/trainer/trainer.py | 22 +- src/pytorch_lightning/utilities/meta.py | 320 +++--------------- .../strategies/test_deepspeed_strategy.py | 20 -- tests/tests_pytorch/utilities/test_meta.py | 93 ++--- 6 files changed, 88 insertions(+), 376 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 6d973aa0dde51..77538a22d8b09 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,6 +75,5 @@ module = [ "pytorch_lightning.tuner.batch_size_scaling", "pytorch_lightning.utilities.auto_restart", "pytorch_lightning.utilities.data", - "pytorch_lightning.utilities.meta", ] ignore_errors = "True" diff --git a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py index b12e1cf042a1f..62e81e4839da6 100644 --- a/src/pytorch_lightning/core/mixins/device_dtype_mixin.py +++ b/src/pytorch_lightning/core/mixins/device_dtype_mixin.py @@ -18,8 +18,6 @@ from torch.nn import Module from typing_extensions import Self -import pytorch_lightning as pl - class DeviceDtypeModuleMixin(Module): __jit_unused_properties__ = ["device", "dtype"] @@ -180,10 +178,8 @@ def half(self) -> Self: # type: ignore[valid-type] def __update_properties( self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None ) -> None: - def apply_fn(module: Union["DeviceDtypeModuleMixin", Module]) -> None: - # TODO: Find why `isinstance(module, DeviceDtypeModuleMixin)` doesn't - # work when using `init_meta_context`. 
- if not isinstance(module, (DeviceDtypeModuleMixin, pl.LightningModule)): + def apply_fn(module: Union[DeviceDtypeModuleMixin, Module]) -> None: + if not isinstance(module, DeviceDtypeModuleMixin): return if device is not None: module._device = device diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 561fe799f1010..6d67fa2c716b9 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -70,7 +70,6 @@ XLAProfiler, ) from pytorch_lightning.strategies import ParallelStrategy, Strategy -from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations from pytorch_lightning.trainer.connectors.accelerator_connector import _LITERAL_WARN, AcceleratorConnector @@ -106,8 +105,7 @@ from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, has_len_all_ranks from pytorch_lightning.utilities.distributed import distributed_available from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException -from pytorch_lightning.utilities.imports import _fault_tolerant_training -from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module +from pytorch_lightning.utilities.imports import _fault_tolerant_training, _module_available from pytorch_lightning.utilities.model_helpers import is_overridden from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn from pytorch_lightning.utilities.seed import isolate_rng @@ -1451,20 +1449,14 @@ def _call_setup_hook(self) -> None: def _call_configure_sharded_model(self) -> None: with self.strategy.model_sharded_context(): - self._handle_meta_model() - self._call_lightning_module_hook("configure_sharded_model") - self._call_callback_hooks("on_configure_sharded_model") - - def _handle_meta_model(self) -> None: - if not is_on_meta_device(self.lightning_module): - return + # experimental support for torchdistx + if _module_available("torchdistx.deferred_init"): + from torchdistx.deferred_init import materialize_module - if isinstance(self.strategy, DDPSpawnStrategy): - raise MisconfigurationException("LightningModule on meta device isn't supported with spawn.") + materialize_module(self.lightning_module) - materialize_module(self.lightning_module) - # the trainer reference is lost during materialization - self.lightning_module.trainer = proxy(self) + self._call_lightning_module_hook("configure_sharded_model") + self._call_callback_hooks("on_configure_sharded_model") def _call_teardown_hook(self) -> None: fn = self.state.fn._setup_fn diff --git a/src/pytorch_lightning/utilities/meta.py b/src/pytorch_lightning/utilities/meta.py index 77da02f7231d4..eb47e074940ed 100644 --- a/src/pytorch_lightning/utilities/meta.py +++ b/src/pytorch_lightning/utilities/meta.py @@ -11,146 +11,41 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
-import importlib -import inspect -import operator -import threading from contextlib import contextmanager -from functools import partial -from itertools import chain -from types import ModuleType -from typing import Any, Callable, Dict, Generator, Iterator, List, Optional, Set, Type +from typing import Any, Callable, Generator, Set, Type -import torch -from torch import nn, Tensor -from torch.nn import Module -from torch.nn.modules.container import ModuleDict, ModuleList, Sequential +from torch import Module, nn -import pytorch_lightning as pl -from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _compare_version -from pytorch_lightning.utilities.rank_zero import rank_zero_warn +from pytorch_lightning.utilities import rank_zero_deprecation -_TORCH_GREATER_EQUAL_1_10 = _compare_version("torch", operator.ge, "1.10.0") -if _TORCH_GREATER_EQUAL_1_10: - from torch._C import _DisableTorchDispatch # type: ignore[attr-defined] - - #################################################################### - # BELOW: TAKEN FROM https://github.com/pytorch/pytorch/pull/66317. # - # TODO: Removed once merged and released on PyTorch side # - #################################################################### - - @contextmanager - def enable_python_mode(cls) -> Iterator[None]: - if not hasattr(cls, "__torch_dispatch__"): - raise ValueError("The class passed to enable_python_mode " "must have a __torch_dispatch__ classmethod") - if not isinstance(cls, type) or not issubclass(cls, (Tensor,)): - raise ValueError("The argument passed to enable_python_mode " "must be the type of a Tensor subclass") - torch._C._enter_python_mode(cls) - try: - yield - finally: - torch._C._exit_python_mode() - - _tls = threading.local() - _tls.in_call = False - - @contextmanager - def _no_dispatch() -> Iterator[None]: - """Temporarily disables the Python dispatch mode.""" - guard = _DisableTorchDispatch() # noqa F841 - try: - yield - finally: - del guard - - def _handle_arange(func, args, kwargs): - kwargs["device"] = torch.device("cpu") - return torch.empty_like(func(*args, **kwargs), device="meta") - - def _handle_tril(func, args, kwargs): - if args and isinstance(args[0], Tensor): - return torch.empty_like(args[0], device="meta") - - return NotImplemented - - class _MetaContext(Tensor): - _op_handlers: Dict[Callable, Callable] = {} - - @classmethod - def _ensure_handlers_initialized(cls) -> None: - if cls._op_handlers: - return - - cls._op_handlers.update( - { - torch.ops.aten.arange: _handle_arange, - torch.ops.aten.tril: _handle_tril, - } - ) - - @classmethod - def __torch_dispatch__(cls, func, types, args=(), kwargs=None): - cls._ensure_handlers_initialized() - - op_handler: Optional[Callable] - - try: - op_handler = cls._op_handlers[func] - except KeyError: - op_handler = None - - with _no_dispatch(): - if op_handler: - result = op_handler(func, args, kwargs) - if result is not NotImplemented: - return result - - if "device" in kwargs: - kwargs["device"] = torch.device("meta") - - return func(*args, **kwargs) - - def init_meta(module_fn: Callable[..., Module], *args, **kwargs) -> Module: - def create_instance(module=None) -> Module: - if module: - module.__init__(*args, **kwargs) - return module - return module_fn(*args, **kwargs) - - if _tls.in_call: - module = create_instance() - else: - _tls.in_call = True - try: - with enable_python_mode(_MetaContext): - module = create_instance() - finally: - _tls.in_call = False - - module.materialize = 
partial(create_instance, module=module) # type: ignore[assignment] - - return module +def is_meta_init() -> bool: + rank_zero_deprecation( + "`pytorch_lightning.utilities.meta.is_meta_init` is deprecated in v1.8 and will be removed in v1.9." + " The function has become a no-op." + " Please check out the `torchdistx` project instead: https://github.com/pytorch/torchdistx" + ) + return False - def is_meta_init() -> bool: - """Indicates whether the module is being instantiated by ``init_meta()``.""" - return _tls.in_call - #################################################################### - # ABOVE: TAKEN FROM https://github.com/pytorch/pytorch/pull/66317. # - # TODO: Removed once merged and released on PyTorch side # - #################################################################### +def init_meta(module_fn: Callable[..., Module], *args: Any, **kwargs: Any) -> None: + rank_zero_deprecation( + "`pytorch_lightning.utilities.meta.init_meta` is deprecated in v1.8 and will be removed in v1.9." + " The function has become a no-op." + " Please check out the `torchdistx` project instead: https://github.com/pytorch/torchdistx" + ) -else: - def init_meta(*_, **__): - if not _TORCH_GREATER_EQUAL_1_10: - return MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0") +def get_all_subclasses(cls: Type) -> Set[Type]: + rank_zero_deprecation( + "`pytorch_lightning.utilities.meta.get_all_subclasses` is deprecated in v1.8 and will be removed in v1.9." + " Please copy its implementation if you have an use for it." + ) + return _get_all_subclasses(cls) # https://stackoverflow.com/a/63851681/9201239 -def get_all_subclasses(cls: Type) -> Set[Type]: +def _get_all_subclasses(cls: Type) -> Set[Type]: subclass_list = [] def recurse(cl): @@ -164,6 +59,10 @@ def recurse(cl): def recursively_setattr(root_module: nn.Module, prefix: str, materialized_module: nn.Module) -> None: + rank_zero_deprecation( + "`pytorch_lightning.utilities.meta.recursively_setattr` is deprecated in v1.8 and will be removed in v1.9." + " Please copy its implementation if you have an use for it." + ) *path, name = prefix.split(".") for p in path: root_module = getattr(root_module, p) @@ -175,166 +74,31 @@ def recursively_setattr(root_module: nn.Module, prefix: str, materialized_module setattr(root_module, name, materialized_module) -def materialize_module(root_module: nn.Module) -> nn.Module: - """This utility performs an in-place operation by materialize a module and its children.""" - if not _TORCH_GREATER_EQUAL_1_10: - return root_module - - materialize_fn = getattr(root_module, "materialize", None) - if materialize_fn and not isinstance(root_module, (Sequential, ModuleList, ModuleDict)): - return materialize_fn() - - for name, child in root_module.named_children(): - materialize_fn = getattr(child, "materialize", None) - if not materialize_fn or isinstance(child, (Sequential, ModuleList, ModuleDict)): - materialize_module(child) - else: - setattr(root_module, name, materialize_fn()) - return root_module - - -# cache subclasses to optimize the search when resetting the meta device later on. 
-__STORAGE_META__ = {} -__CREATED_MODULES__ = set() - - -def _unset_meta_device(from_created: bool = False) -> None: - """Replace all meta module by their original version.""" - if not _TORCH_GREATER_EQUAL_1_10: - raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0") - - if from_created: - values = [__STORAGE_META__[key] for key in __CREATED_MODULES__] - else: - values = __STORAGE_META__.values() - - for mods, subclass, _ in values: - for mod in mods: - setattr(mod, subclass.__name__, subclass) - - -def _set_meta_device_populated(from_created: bool = False) -> None: - """Replace all meta module by their original version.""" - if not _TORCH_GREATER_EQUAL_1_10: - raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0") - - if from_created: - values = [__STORAGE_META__[key] for key in __CREATED_MODULES__] - else: - values = __STORAGE_META__.values() - - for mods, subclass, meta_class in values: - for mod in mods: - setattr(mod, subclass.__name__, meta_class) - - -def _set_meta_device() -> None: - """Replace all torch.nn.Module by their meta replacement.""" - - if not _TORCH_GREATER_EQUAL_1_10: - raise MisconfigurationException("`init_meta` is supported from PyTorch 1.10.0") - - # Author note: This can be optimized further by searching all subclasses at once. - # Its time complexity is O(n*m) where n is the number of all subclasses if there's no multiple inheritance - # and m the number of all subclasses belonging to its subclass module. - - for subclass in get_all_subclasses(torch.nn.modules.module.Module): - - if subclass in (Sequential, ModuleList, ModuleDict, pl.LightningModule): - continue - - # if a subclass has already been stored, we should use the cache - if str(subclass) in __STORAGE_META__: - # reset the class import package to its rightful state. - mods, subclass, meta_class = __STORAGE_META__[subclass] - for mod in mods: - setattr(mod, subclass.__name__, meta_class) - continue - - class _IsinstanceMetaclass(type(subclass)): - def __instancecheck__(self, instance: Any) -> bool: - """Overrides the ``isinstance`` check on ``_MaterializerModule`` objects.""" - return isinstance(instance, self.__bases__[0]) - - # Create a class subclassing current `subclass` overriding its new method. - # this will enable use to use `torch.distributed.nn.utils.init_meta` to create a `meta` - # version of the current subclass module - class _MaterializerModule(subclass, metaclass=_IsinstanceMetaclass): - @classmethod - @contextmanager - def instantiation_context(cls): - _unset_meta_device(from_created=True) - yield - _set_meta_device_populated(from_created=True) - - @classmethod - def materialize(cls, materialize_fn: Callable): - with cls.instantiation_context(): - obj = materialize_fn() - return obj - - @staticmethod - def add_subclasses(subclass): - """This is used to unroll the instantiation tree while creating the modules.""" - # Don't store the LightningModule as skipped from the Meta process. 
- if subclass != pl.LightningModule: - __CREATED_MODULES__.add(subclass) - if subclass.__bases__[0] != torch.nn.modules.module.Module: - _MaterializerModule.add_subclasses(subclass.__bases__[0]) - - def __new__(cls, *args, **kwargs): - subclass = cls.__bases__[0] - cls.add_subclasses(subclass) - with cls.instantiation_context(): - obj = init_meta(subclass, *args, **kwargs) - - obj.materialize = partial(cls.materialize, materialize_fn=obj.materialize) - return obj - - def search(mod: ModuleType) -> List[ModuleType]: - out = [] - for _, obj in inspect.getmembers(mod): - if obj == subclass: - out.append(mod) - return out - - submodules = subclass.__module__.split(".") - mod = importlib.import_module(submodules[0]) - - # nn.Module class can be imported at different level and they all need to be mocked. - # Example: torch.nn.Linear is actually torch.nn.modules.linear.Linear - # Therefore, torch.nn.Linear, torch.nn.modules.Linear, torch.nn.modules.linear.Linear - # needs to be replaced by the torch.nn.linear.modules.Linear _MaterializerModule - out = [search(mod)] - for name in submodules[1:]: - mod = getattr(mod, name) - out.append(search(mod)) - - # drop empty module - mods = [mod for mod in chain(*out) if mod] - - # store the modules search so it doesn't have to be performed again for this class - __STORAGE_META__[subclass] = (mods, subclass, _MaterializerModule) - - # replace all subclass by its meta form - for mod in mods: - setattr(mod, subclass.__name__, _MaterializerModule) +def materialize_module(root_module: nn.Module) -> None: + rank_zero_deprecation( + "`pytorch_lightning.utilities.meta.materialize_module` is deprecated in v1.8 and will be removed in v1.9." + " The function has become a no-op." + " Please check out the `torchdistx` project instead: https://github.com/pytorch/torchdistx" + ) @contextmanager def init_meta_context() -> Generator: - rank_zero_warn( - "Be aware this feature is highly experimental and there are a number of weird edge cases " - "where it can internal assert and/or crash. A more stable version is to be expected from PyTorch 1.11." + rank_zero_deprecation( + "`pytorch_lightning.utilities.meta.init_meta_context` is deprecated in v1.8 and will be removed in v1.9." + " The function has become a no-op." + " Please check out the `torchdistx` project instead: https://github.com/pytorch/torchdistx" ) - _set_meta_device() yield - _unset_meta_device() def is_on_meta_device(module: nn.Module) -> bool: + rank_zero_deprecation( + "`pytorch_lightning.utilities.meta.is_on_meta_device` is deprecated in v1.8 and will be removed in v1.9." + " Please copy its implementation if you have an use for it." 
+ ) try: param = next(module.parameters()) - return param.device.type == "meta" + return param.is_meta except StopIteration: return False diff --git a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py index 79562134f9ccb..8cad3a2f89d26 100644 --- a/tests/tests_pytorch/strategies/test_deepspeed_strategy.py +++ b/tests/tests_pytorch/strategies/test_deepspeed_strategy.py @@ -35,7 +35,6 @@ from pytorch_lightning.strategies.deepspeed import _DEEPSPEED_AVAILABLE, LightningDeepSpeedModule from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _RequirementAvailable -from pytorch_lightning.utilities.meta import init_meta_context from tests_pytorch.helpers.datamodules import ClassifDataModule from tests_pytorch.helpers.datasets import RandomIterableDataset from tests_pytorch.helpers.runif import RunIf @@ -1239,25 +1238,6 @@ def on_test_batch_start( trainer.test(model) -@RunIf(min_cuda_gpus=2, min_torch="1.10.0", standalone=True, deepspeed=True) -def test_deepspeed_with_meta_device(tmpdir): - with init_meta_context(): - model = BoringModel() - assert model.layer.weight.device.type == "meta" - trainer = Trainer( - default_root_dir=tmpdir, - strategy=DeepSpeedStrategy(stage=3), - accelerator="gpu", - devices=2, - fast_dev_run=True, - precision=16, - enable_progress_bar=False, - enable_model_summary=False, - ) - trainer.fit(model) - assert model.layer.weight.device.type == "cpu" - - @RunIf(min_cuda_gpus=2, standalone=True, deepspeed=True) def test_deepspeed_multi_save_same_filepath(tmpdir): """Test that verifies that deepspeed saves only latest checkpoint in the specified path and deletes the old diff --git a/tests/tests_pytorch/utilities/test_meta.py b/tests/tests_pytorch/utilities/test_meta.py index b19483e29bbe2..9f844cfe15db4 100644 --- a/tests/tests_pytorch/utilities/test_meta.py +++ b/tests/tests_pytorch/utilities/test_meta.py @@ -14,71 +14,52 @@ import pytest from torch import nn +from pytorch_lightning import Trainer from pytorch_lightning.core.module import LightningModule from pytorch_lightning.demos.boring_classes import BoringModel -from pytorch_lightning.utilities.meta import init_meta_context, is_on_meta_device, materialize_module +from pytorch_lightning.utilities.imports import _RequirementAvailable from tests_pytorch.helpers.runif import RunIf - -class MLP(nn.Module): - def __init__(self, num_layers: int): - super().__init__() - self.layer = nn.Sequential(*[nn.Linear(1, 1) for _ in range(num_layers)] + [nn.Dropout(), nn.LayerNorm(1)]) +_TORCHDISTX_AVAILABLE = _RequirementAvailable("torchdistx") class SimpleBoringModel(LightningModule): - def __init__(self, num_layers: int): + def __init__(self, num_layers): super().__init__() - self.save_hyperparameters() - self.layer = nn.Sequential(*[nn.Linear(1, 1) for _ in range(self.hparams.num_layers)]) - - -@RunIf(min_torch="1.10.0", standalone=True) -def test_init_meta_context(): - - with init_meta_context(): - m = nn.Linear(in_features=1, out_features=1) - assert isinstance(m, nn.Linear) - assert m.weight.device.type == "meta" - assert is_on_meta_device(m) - mlp = MLP(4) - assert mlp.layer[0].weight.device.type == "meta" - - mlp = materialize_module(mlp) - assert mlp.layer[0].weight.device.type == "cpu" - - assert not is_on_meta_device(mlp) - assert not is_on_meta_device(nn.Module()) + self.layer = nn.Sequential(*[nn.Linear(1, 1) for _ in range(num_layers)]) - model = SimpleBoringModel(4) - assert 
model.layer[0].weight.device.type == "meta" - materialize_module(model) - assert model.layer[0].weight.device.type == "cpu" - mlp = MLP(4) - assert mlp.layer[0].weight.device.type == "cpu" - # no-op as already materialized. - materialize_module(mlp) - assert mlp.layer[0].weight.device.type == "cpu" - - m = nn.Linear(in_features=1, out_features=1) - assert m.weight.device.type == "cpu" - - with init_meta_context(): - m = nn.Linear(in_features=1, out_features=1) - assert m.weight.device.type == "meta" - - m = nn.Linear(in_features=1, out_features=1) - assert m.weight.device.type == "cpu" - - -@RunIf(min_torch="1.10.0", standalone=True) -def test_materialize_module_recursive_child(): - """Test materialize_module doesn't set a child recursively to a model instantiated within init_meta_context.""" - with init_meta_context(): - model = BoringModel() +@pytest.mark.skipif(not _TORCHDISTX_AVAILABLE, reason=_TORCHDISTX_AVAILABLE.message) +def test_deferred_init_with_lightning_module(): + from torchdistx.deferred_init import deferred_init, materialize_module + model = deferred_init(SimpleBoringModel, 4) + assert model.layer[0].weight.device.type == "cpu" materialize_module(model) - - with pytest.raises(AttributeError, match="'Linear' object has no attribute 'layer'"): - model.layer.layer + materialize_module(model) # make sure it's idempotent + assert model.layer[0].weight.device.type == "cpu" + + +@pytest.mark.skipif(not _TORCHDISTX_AVAILABLE, reason=_TORCHDISTX_AVAILABLE.message) +@pytest.mark.parametrize( + "trainer_kwargs", + ( + {"accelerator": "auto"}, + pytest.param( + {"strategy": "deepspeed_stage_3", "accelerator": "gpu", "devices": 2, "precision": 16}, + marks=RunIf(min_cuda_gpus=2, deepspeed=True), + ), + ), +) +def test_deferred_init_with_trainer(tmpdir, trainer_kwargs): + from torchdistx.deferred_init import deferred_init + + model = deferred_init(BoringModel) + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + enable_progress_bar=False, + enable_model_summary=False, + **trainer_kwargs + ) + trainer.fit(model) From db1207c2d8c2ae0c2c8ea771b5aaea35fdd44707 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 27 Jul 2022 03:07:54 +0200 Subject: [PATCH 02/14] Deprecation tests --- src/pytorch_lightning/utilities/meta.py | 8 ++++---- .../deprecated_api/test_remove_1-9.py | 14 ++++++++++++++ 2 files changed, 18 insertions(+), 4 deletions(-) diff --git a/src/pytorch_lightning/utilities/meta.py b/src/pytorch_lightning/utilities/meta.py index eb47e074940ed..73dd80c00a1fb 100644 --- a/src/pytorch_lightning/utilities/meta.py +++ b/src/pytorch_lightning/utilities/meta.py @@ -14,7 +14,7 @@ from contextlib import contextmanager from typing import Any, Callable, Generator, Set, Type -from torch import Module, nn +from torch.nn import Module from pytorch_lightning.utilities import rank_zero_deprecation @@ -58,7 +58,7 @@ def recurse(cl): return set(subclass_list) -def recursively_setattr(root_module: nn.Module, prefix: str, materialized_module: nn.Module) -> None: +def recursively_setattr(root_module: Module, prefix: str, materialized_module: Module) -> None: rank_zero_deprecation( "`pytorch_lightning.utilities.meta.recursively_setattr` is deprecated in v1.8 and will be removed in v1.9." " Please copy its implementation if you have an use for it." 
@@ -74,7 +74,7 @@ def recursively_setattr(root_module: nn.Module, prefix: str, materialized_module setattr(root_module, name, materialized_module) -def materialize_module(root_module: nn.Module) -> None: +def materialize_module(root_module: Module) -> None: rank_zero_deprecation( "`pytorch_lightning.utilities.meta.materialize_module` is deprecated in v1.8 and will be removed in v1.9." " The function has become a no-op." @@ -92,7 +92,7 @@ def init_meta_context() -> Generator: yield -def is_on_meta_device(module: nn.Module) -> bool: +def is_on_meta_device(module: Module) -> bool: rank_zero_deprecation( "`pytorch_lightning.utilities.meta.is_on_meta_device` is deprecated in v1.8 and will be removed in v1.9." " Please copy its implementation if you have an use for it." diff --git a/tests/tests_pytorch/deprecated_api/test_remove_1-9.py b/tests/tests_pytorch/deprecated_api/test_remove_1-9.py index 54c59bec62b5d..dcd8ecfd0169c 100644 --- a/tests/tests_pytorch/deprecated_api/test_remove_1-9.py +++ b/tests/tests_pytorch/deprecated_api/test_remove_1-9.py @@ -217,3 +217,17 @@ def test_gpu_accelerator_deprecation_warning(): ) ): GPUAccelerator() + + +def test_meta_utility_deprecations(): + import pytorch_lightning.utilities.meta as meta + + pytest.deprecated_call(meta.is_meta_init, match="is_meta_init.*removed in v1.9") + pytest.deprecated_call(meta.init_meta, Mock(), match="init_meta.*removed in v1.9") + pytest.deprecated_call(meta.get_all_subclasses, Mock, match="get_all_subclasses.*removed in v1.9") + pytest.deprecated_call(meta.recursively_setattr, Mock(), "foo", 1, match="recursively_setattr.*removed in v1.9") + pytest.deprecated_call(meta.materialize_module, Mock(), match="materialize_module.*removed in v1.9") + with pytest.deprecated_call(match="init_meta_context.*removed in v1.9"): + with meta.init_meta_context(): + pass + pytest.deprecated_call(meta.is_on_meta_device, LightningModule(), match="is_on_meta_device.*removed in v1.9") From bcc900333d846020d09d1492721b36e345fa7fef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 27 Jul 2022 03:12:39 +0200 Subject: [PATCH 03/14] Reuse code --- src/pytorch_lightning/utilities/cli.py | 14 +++++++------- src/pytorch_lightning/utilities/data.py | 17 ++--------------- 2 files changed, 9 insertions(+), 22 deletions(-) diff --git a/src/pytorch_lightning/utilities/cli.py b/src/pytorch_lightning/utilities/cli.py index 285b5361f9cd2..8af919b78ce93 100644 --- a/src/pytorch_lightning/utilities/cli.py +++ b/src/pytorch_lightning/utilities/cli.py @@ -22,7 +22,7 @@ import pytorch_lightning as pl import pytorch_lightning.cli as new_cli -from pytorch_lightning.utilities.meta import get_all_subclasses +from pytorch_lightning.utilities.meta import _get_all_subclasses from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation _deprecate_registry_message = ( @@ -108,17 +108,17 @@ def _populate_registries(subclasses: bool) -> None: # Remove in v1.9 if subclasses: rank_zero_deprecation(_deprecate_auto_registry_message) # this will register any subclasses from all loaded modules including userland - for cls in get_all_subclasses(torch.optim.Optimizer): + for cls in _get_all_subclasses(torch.optim.Optimizer): OPTIMIZER_REGISTRY(cls, show_deprecation=False) - for cls in get_all_subclasses(torch.optim.lr_scheduler._LRScheduler): + for cls in _get_all_subclasses(torch.optim.lr_scheduler._LRScheduler): LR_SCHEDULER_REGISTRY(cls, show_deprecation=False) - for cls in get_all_subclasses(pl.Callback): + for cls in 
_get_all_subclasses(pl.Callback): CALLBACK_REGISTRY(cls, show_deprecation=False) - for cls in get_all_subclasses(pl.LightningModule): + for cls in _get_all_subclasses(pl.LightningModule): MODEL_REGISTRY(cls, show_deprecation=False) - for cls in get_all_subclasses(pl.LightningDataModule): + for cls in _get_all_subclasses(pl.LightningDataModule): DATAMODULE_REGISTRY(cls, show_deprecation=False) - for cls in get_all_subclasses(pl.loggers.Logger): + for cls in _get_all_subclasses(pl.loggers.Logger): LOGGER_REGISTRY(cls, show_deprecation=False) else: # manually register torch's subclasses and our subclasses diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index 2de82ceff088e..0dfa223dd5cbf 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -17,7 +17,7 @@ from contextlib import contextmanager from dataclasses import fields from functools import partial -from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Tuple, Union import torch from torch import Tensor @@ -38,6 +38,7 @@ from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler from pytorch_lightning.utilities.enums import _FaultTolerantMode from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.utilities.meta import _get_all_subclasses from pytorch_lightning.utilities.rank_zero import rank_zero_warn from pytorch_lightning.utilities.seed import pl_worker_init_function from pytorch_lightning.utilities.warnings import WarningCache @@ -389,20 +390,6 @@ def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None: return wrapper -# https://stackoverflow.com/a/63851681/9201239 -def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]: - """Returns a list of all classes that inherit directly or indirectly from the given class.""" - subclasses = set() - - def recurse(cl: Type[Any]) -> None: - for subclass in cl.__subclasses__(): - subclasses.add(subclass) - recurse(subclass) - - recurse(cls) - return subclasses - - @contextmanager def _replace_dataloader_init_method() -> Generator[None, None, None]: """This context manager is used to add support for re-instantiation of custom (subclasses) of From 8095f1c768b681ac3312eec93658e6618d90357d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 27 Jul 2022 04:00:05 +0200 Subject: [PATCH 04/14] is_deferred --- src/pytorch_lightning/trainer/trainer.py | 12 ++++++++++-- src/pytorch_lightning/utilities/meta.py | 16 +++++++++++++++- tests/tests_pytorch/utilities/test_meta.py | 16 +++++++++++++--- 3 files changed, 38 insertions(+), 6 deletions(-) diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 6d67fa2c716b9..1b0f33ce8037a 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -69,7 +69,7 @@ SimpleProfiler, XLAProfiler, ) -from pytorch_lightning.strategies import ParallelStrategy, Strategy +from pytorch_lightning.strategies import DDPSpawnStrategy, ParallelStrategy, Strategy from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations from pytorch_lightning.trainer.connectors.accelerator_connector import _LITERAL_WARN, AcceleratorConnector @@ -1453,7 
+1453,15 @@ def _call_configure_sharded_model(self) -> None: if _module_available("torchdistx.deferred_init"): from torchdistx.deferred_init import materialize_module - materialize_module(self.lightning_module) + from pytorch_lightning.utilities.meta import _is_deferred + + if _is_deferred(self.lightning_module): + if isinstance(self.strategy, DDPSpawnStrategy): + raise NotImplementedError( + f"The {type(self.strategy).__name__} strategy does not support `torchdistx`'s deferred" + f" initialization." + ) + materialize_module(self.lightning_module) self._call_lightning_module_hook("configure_sharded_model") self._call_callback_hooks("on_configure_sharded_model") diff --git a/src/pytorch_lightning/utilities/meta.py b/src/pytorch_lightning/utilities/meta.py index 73dd80c00a1fb..f507837f95528 100644 --- a/src/pytorch_lightning/utilities/meta.py +++ b/src/pytorch_lightning/utilities/meta.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import contextmanager -from typing import Any, Callable, Generator, Set, Type +from typing import Any, Callable, Dict, Generator, Optional, Set, Type +from torch import Tensor from torch.nn import Module from pytorch_lightning.utilities import rank_zero_deprecation +from pytorch_lightning.utilities.imports import _module_available def is_meta_init() -> bool: @@ -102,3 +104,15 @@ def is_on_meta_device(module: Module) -> bool: return param.is_meta except StopIteration: return False + + +def _is_deferred(module: Module) -> bool: + if not _module_available("torchdistx.fake"): + return False + from torchdistx.fake import is_fake + + def any_fake(tensors: Dict[str, Optional[Tensor]]) -> bool: + return any(is_fake(t) for t in tensors.values() if t is not None) + + is_deferred = any(_is_deferred(m) for m in module.children()) + return is_deferred or any_fake(module._parameters) or any_fake(module._buffers) diff --git a/tests/tests_pytorch/utilities/test_meta.py b/tests/tests_pytorch/utilities/test_meta.py index 9f844cfe15db4..2739fdd850c42 100644 --- a/tests/tests_pytorch/utilities/test_meta.py +++ b/tests/tests_pytorch/utilities/test_meta.py @@ -32,19 +32,25 @@ def __init__(self, num_layers): @pytest.mark.skipif(not _TORCHDISTX_AVAILABLE, reason=_TORCHDISTX_AVAILABLE.message) def test_deferred_init_with_lightning_module(): from torchdistx.deferred_init import deferred_init, materialize_module + from torchdistx.fake import is_fake model = deferred_init(SimpleBoringModel, 4) - assert model.layer[0].weight.device.type == "cpu" + weight = model.layer[0].weight + assert weight.device.type == "cpu" + assert is_fake(weight) + materialize_module(model) materialize_module(model) # make sure it's idempotent - assert model.layer[0].weight.device.type == "cpu" + weight = model.layer[0].weight + assert weight.device.type == "cpu" + assert not is_fake(weight) @pytest.mark.skipif(not _TORCHDISTX_AVAILABLE, reason=_TORCHDISTX_AVAILABLE.message) @pytest.mark.parametrize( "trainer_kwargs", ( - {"accelerator": "auto"}, + {"accelerator": "auto", "devices": 1}, pytest.param( {"strategy": "deepspeed_stage_3", "accelerator": "gpu", "devices": 2, "precision": 16}, marks=RunIf(min_cuda_gpus=2, deepspeed=True), @@ -63,3 +69,7 @@ def test_deferred_init_with_trainer(tmpdir, trainer_kwargs): **trainer_kwargs ) trainer.fit(model) + + +def test_deferred_init_ddp_spawn(): + ... 
# FIXME From 1da7c0a4c37bd0ee9106a34384f27e4c7a2582da Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 27 Jul 2022 04:12:48 +0200 Subject: [PATCH 05/14] ddp spawn sanity check --- src/pytorch_lightning/trainer/trainer.py | 20 +++++++++++--------- tests/tests_pytorch/utilities/test_meta.py | 21 +++++++++++++++++++-- 2 files changed, 30 insertions(+), 11 deletions(-) diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 1b0f33ce8037a..aac4cb8ac8418 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -647,6 +647,7 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: """ try: if self.strategy.launcher is not None: + self._check_torchdistx_support() return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs) else: return trainer_fn(*args, **kwargs) @@ -1453,15 +1454,7 @@ def _call_configure_sharded_model(self) -> None: if _module_available("torchdistx.deferred_init"): from torchdistx.deferred_init import materialize_module - from pytorch_lightning.utilities.meta import _is_deferred - - if _is_deferred(self.lightning_module): - if isinstance(self.strategy, DDPSpawnStrategy): - raise NotImplementedError( - f"The {type(self.strategy).__name__} strategy does not support `torchdistx`'s deferred" - f" initialization." - ) - materialize_module(self.lightning_module) + materialize_module(self.lightning_module) self._call_lightning_module_hook("configure_sharded_model") self._call_callback_hooks("on_configure_sharded_model") @@ -1790,6 +1783,15 @@ def _log_device_info(self) -> None: f" `Trainer(accelerator='mps', devices={MPSAccelerator.auto_device_count()})`." ) + def _check_torchdistx_support(self) -> None: + from pytorch_lightning.utilities.meta import _is_deferred + + if _is_deferred(self.lightning_module) and isinstance(self.strategy, DDPSpawnStrategy): + raise NotImplementedError( + f"The `{type(self.strategy).__name__}` strategy does not support `torchdistx`'s deferred" + f" initialization." + ) + """ Data loading methods """ diff --git a/tests/tests_pytorch/utilities/test_meta.py b/tests/tests_pytorch/utilities/test_meta.py index 2739fdd850c42..aa3f8e34bfaac 100644 --- a/tests/tests_pytorch/utilities/test_meta.py +++ b/tests/tests_pytorch/utilities/test_meta.py @@ -18,6 +18,7 @@ from pytorch_lightning.core.module import LightningModule from pytorch_lightning.demos.boring_classes import BoringModel from pytorch_lightning.utilities.imports import _RequirementAvailable +from pytorch_lightning.utilities.meta import _is_deferred from tests_pytorch.helpers.runif import RunIf _TORCHDISTX_AVAILABLE = _RequirementAvailable("torchdistx") @@ -38,9 +39,11 @@ def test_deferred_init_with_lightning_module(): weight = model.layer[0].weight assert weight.device.type == "cpu" assert is_fake(weight) + assert _is_deferred(model) materialize_module(model) materialize_module(model) # make sure it's idempotent + assert not _is_deferred(model) weight = model.layer[0].weight assert weight.device.type == "cpu" assert not is_fake(weight) @@ -71,5 +74,19 @@ def test_deferred_init_with_trainer(tmpdir, trainer_kwargs): trainer.fit(model) -def test_deferred_init_ddp_spawn(): - ... 
# FIXME +@pytest.mark.skipif(not _TORCHDISTX_AVAILABLE, reason=_TORCHDISTX_AVAILABLE.message) +def test_deferred_init_ddp_spawn(tmpdir): + from torchdistx.deferred_init import deferred_init + + model = deferred_init(BoringModel) + trainer = Trainer( + default_root_dir=tmpdir, + fast_dev_run=True, + enable_progress_bar=False, + enable_model_summary=False, + accelerator="auto", + devices="1", + strategy="ddp_spawn", + ) + with pytest.raises(NotImplementedError, match="DDPSpawnStrategy` strategy does not support.*torchdistx"): + trainer.fit(model) From 36e5937e44e28413b62ee7c67c21c9308f44b2ed Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 27 Jul 2022 04:13:11 +0200 Subject: [PATCH 06/14] Rename --- .../tests_pytorch/utilities/{test_meta.py => test_torchdistx.py} | 0 1 file changed, 0 insertions(+), 0 deletions(-) rename tests/tests_pytorch/utilities/{test_meta.py => test_torchdistx.py} (100%) diff --git a/tests/tests_pytorch/utilities/test_meta.py b/tests/tests_pytorch/utilities/test_torchdistx.py similarity index 100% rename from tests/tests_pytorch/utilities/test_meta.py rename to tests/tests_pytorch/utilities/test_torchdistx.py From 722b60f60c8e305ee5ecdee9c78604be14113b14 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 27 Jul 2022 04:22:07 +0200 Subject: [PATCH 07/14] mypy --- src/pytorch_lightning/utilities/meta.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/pytorch_lightning/utilities/meta.py b/src/pytorch_lightning/utilities/meta.py index f507837f95528..af589becfe1d3 100644 --- a/src/pytorch_lightning/utilities/meta.py +++ b/src/pytorch_lightning/utilities/meta.py @@ -12,10 +12,10 @@ # See the License for the specific language governing permissions and # limitations under the License. from contextlib import contextmanager -from typing import Any, Callable, Dict, Generator, Optional, Set, Type +from typing import Any, Callable, Generator, Mapping, Optional, Set, Type, Union from torch import Tensor -from torch.nn import Module +from torch.nn import Module, Parameter from pytorch_lightning.utilities import rank_zero_deprecation from pytorch_lightning.utilities.imports import _module_available @@ -50,7 +50,7 @@ def get_all_subclasses(cls: Type) -> Set[Type]: def _get_all_subclasses(cls: Type) -> Set[Type]: subclass_list = [] - def recurse(cl): + def recurse(cl: Type) -> None: for subclass in cl.__subclasses__(): subclass_list.append(subclass) recurse(subclass) @@ -60,7 +60,7 @@ def recurse(cl): return set(subclass_list) -def recursively_setattr(root_module: Module, prefix: str, materialized_module: Module) -> None: +def recursively_setattr(root_module: Any, prefix: str, materialized_module: Module) -> None: rank_zero_deprecation( "`pytorch_lightning.utilities.meta.recursively_setattr` is deprecated in v1.8 and will be removed in v1.9." " Please copy its implementation if you have an use for it." 
@@ -111,7 +111,7 @@ def _is_deferred(module: Module) -> bool: return False from torchdistx.fake import is_fake - def any_fake(tensors: Dict[str, Optional[Tensor]]) -> bool: + def any_fake(tensors: Mapping[str, Optional[Union[Tensor, Parameter]]]) -> bool: return any(is_fake(t) for t in tensors.values() if t is not None) is_deferred = any(_is_deferred(m) for m in module.children()) From 53fb65f586f246c8a88bc38c90b0d80a12a9bad2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 27 Jul 2022 11:36:37 +0200 Subject: [PATCH 08/14] Install torchdistx --- requirements/pytorch/test.txt | 2 ++ 1 file changed, 2 insertions(+) diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index ce54cd087b1de..e396fc8e81ff8 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -16,3 +16,5 @@ psutil # for `DeviceStatsMonitor` pandas # needed in benchmarks fastapi uvicorn +--extra-index-url https://download.pytorch.org/whl/cpu +torchdistx From 00df5d3c5877d9dc385537bdc5d47a203f90bf39 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Wed, 27 Jul 2022 18:32:23 +0200 Subject: [PATCH 09/14] Move into the launcher --- .../strategies/launchers/multiprocessing.py | 11 +++++++++++ src/pytorch_lightning/trainer/trainer.py | 12 +----------- 2 files changed, 12 insertions(+), 11 deletions(-) diff --git a/src/pytorch_lightning/strategies/launchers/multiprocessing.py b/src/pytorch_lightning/strategies/launchers/multiprocessing.py index 37e6c8d893150..2c80b8b3d44d5 100644 --- a/src/pytorch_lightning/strategies/launchers/multiprocessing.py +++ b/src/pytorch_lightning/strategies/launchers/multiprocessing.py @@ -83,6 +83,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"] a selected set of attributes get restored in the main process after processes join. **kwargs: Optional keyword arguments to be passed to the given function. """ + self._check_torchdistx_support() # The default cluster environment in Lightning chooses a random free port number # This needs to be done in the main process here before starting processes to ensure each rank will connect # through the same port @@ -164,6 +165,16 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra) + def _check_torchdistx_support(self) -> None: + if self._start_method == "spawn": + from pytorch_lightning.utilities.meta import _is_deferred + + if _is_deferred(self._strategy.lightning_module): + raise NotImplementedError( + f"The `{type(self._strategy).__name__}` strategy does not support `torchdistx`'s deferred" + f" initialization." + ) + def add_to_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None: """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory sharing, we cast the data to numpy. 
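For reference, a minimal standalone sketch of the detection that `_check_torchdistx_support` above relies on, assuming `torchdistx` is installed; the `nn.Linear` module is only an illustration and does not appear in this patch:

    import torch.nn as nn
    from torchdistx.deferred_init import deferred_init
    from torchdistx.fake import is_fake

    # deferred_init constructs the module without allocating real storage:
    # every parameter stays a "fake" tensor until the module is materialized.
    model = deferred_init(nn.Linear, 2, 2)
    assert is_fake(model.weight)

    # _is_deferred (added to utilities/meta.py earlier in this series) reduces this
    # same check over the parameters and buffers of the whole module tree; the
    # spawn launcher raises NotImplementedError when it returns True.
    assert any(is_fake(p) for p in model.parameters())
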
diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index aac4cb8ac8418..6d67fa2c716b9 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -69,7 +69,7 @@ SimpleProfiler, XLAProfiler, ) -from pytorch_lightning.strategies import DDPSpawnStrategy, ParallelStrategy, Strategy +from pytorch_lightning.strategies import ParallelStrategy, Strategy from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations from pytorch_lightning.trainer.connectors.accelerator_connector import _LITERAL_WARN, AcceleratorConnector @@ -647,7 +647,6 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: """ try: if self.strategy.launcher is not None: - self._check_torchdistx_support() return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs) else: return trainer_fn(*args, **kwargs) @@ -1783,15 +1782,6 @@ def _log_device_info(self) -> None: f" `Trainer(accelerator='mps', devices={MPSAccelerator.auto_device_count()})`." ) - def _check_torchdistx_support(self) -> None: - from pytorch_lightning.utilities.meta import _is_deferred - - if _is_deferred(self.lightning_module) and isinstance(self.strategy, DDPSpawnStrategy): - raise NotImplementedError( - f"The `{type(self.strategy).__name__}` strategy does not support `torchdistx`'s deferred" - f" initialization." - ) - """ Data loading methods """ From e617b701c8c0b2e3270ebb90555c0f2dad90ce1c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 28 Jul 2022 16:45:28 +0200 Subject: [PATCH 10/14] Adrian review --- src/pytorch_lightning/utilities/meta.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/pytorch_lightning/utilities/meta.py b/src/pytorch_lightning/utilities/meta.py index af589becfe1d3..9f4cd72bfe65d 100644 --- a/src/pytorch_lightning/utilities/meta.py +++ b/src/pytorch_lightning/utilities/meta.py @@ -41,7 +41,7 @@ def init_meta(module_fn: Callable[..., Module], *args: Any, **kwargs: Any) -> No def get_all_subclasses(cls: Type) -> Set[Type]: rank_zero_deprecation( "`pytorch_lightning.utilities.meta.get_all_subclasses` is deprecated in v1.8 and will be removed in v1.9." - " Please copy its implementation if you have an use for it." + " Please copy its implementation if you have a use for it." ) return _get_all_subclasses(cls) @@ -63,7 +63,7 @@ def recurse(cl: Type) -> None: def recursively_setattr(root_module: Any, prefix: str, materialized_module: Module) -> None: rank_zero_deprecation( "`pytorch_lightning.utilities.meta.recursively_setattr` is deprecated in v1.8 and will be removed in v1.9." - " Please copy its implementation if you have an use for it." + " Please copy its implementation if you have a use for it." ) *path, name = prefix.split(".") for p in path: @@ -97,7 +97,7 @@ def init_meta_context() -> Generator: def is_on_meta_device(module: Module) -> bool: rank_zero_deprecation( "`pytorch_lightning.utilities.meta.is_on_meta_device` is deprecated in v1.8 and will be removed in v1.9." - " Please copy its implementation if you have an use for it." + " Please copy its implementation if you have a use for it." 
) try: param = next(module.parameters()) @@ -106,8 +106,8 @@ def is_on_meta_device(module: Module) -> bool: return False -def _is_deferred(module: Module) -> bool: - if not _module_available("torchdistx.fake"): +def _is_deferred(module: Optional[Module]) -> bool: + if module is None or not _module_available("torchdistx.fake"): return False from torchdistx.fake import is_fake From 022e149b154ec3180bf0c884e7e57fa44b6492ec Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 28 Jul 2022 17:30:06 +0200 Subject: [PATCH 11/14] setup_tools refactor --- .actions/setup_tools.py | 78 +++++++++++++++++++++++++++-------------- 1 file changed, 51 insertions(+), 27 deletions(-) diff --git a/.actions/setup_tools.py b/.actions/setup_tools.py index 3a105f508fd45..c451e42f3a0b4 100644 --- a/.actions/setup_tools.py +++ b/.actions/setup_tools.py @@ -12,7 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. import glob -import logging import os import pathlib import re @@ -22,8 +21,11 @@ import urllib.request from importlib.util import module_from_spec, spec_from_file_location from itertools import groupby +from pathlib import Path from types import ModuleType -from typing import List +from typing import Any, Iterable, Iterator, List, Optional + +import pkg_resources _PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__)) _PACKAGE_MAPPING = {"pytorch": "pytorch_lightning", "app": "lightning_app"} @@ -41,33 +43,56 @@ def _load_py_module(name: str, location: str) -> ModuleType: return py -def load_requirements( - path_dir: str, file_name: str = "base.txt", comment_char: str = "#", unfreeze: bool = True -) -> List[str]: +class _RequirementWithComment(pkg_resources.Requirement): + def __init__(self, *args: Any, comment: str = "", pip_argument: Optional[str] = None, **kwargs: Any) -> None: + super().__init__(*args, **kwargs) + self.comment = comment + assert pip_argument is None or pip_argument # sanity check that it's not an empty str + self.pip_argument = pip_argument + self.strict = "# strict" in comment.lower() + + def clean_str(self, unfreeze: bool) -> str: + # remove version restrictions unless they are strict + return self.project_name if unfreeze and not self.strict else str(self) + + +def _parse_requirements(strs: Iterable) -> Iterator[_RequirementWithComment]: + """Adapted from `pkg_resources.parse_requirements` to include comments.""" + lines = pkg_resources.yield_lines(strs) + pip_argument = None + for line in lines: + # Drop comments -- a hash without a space may be in a URL. + if " #" in line: + comment_pos = line.find(" #") + line, comment = line[:comment_pos], line[comment_pos:] + else: + comment = "" + # If there is a line continuation, drop it, and append the next line. + if line.endswith("\\"): + line = line[:-2].strip() + try: + line += next(lines) + except StopIteration: + return + # If there's a pip argument, save it + if line.startswith("--"): + pip_argument = line + continue + yield _RequirementWithComment(line, comment=comment, pip_argument=pip_argument) + pip_argument = None + + +def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: bool = True) -> List[str]: """Load requirements from a file. - >>> path_req = os.path.join(_PROJECT_ROOT, "requirements") + >>> path_req = os.path.join(_PROJECT_ROOT, "requirements", "pytorch") >>> load_requirements(path_req) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE ['numpy...', 'torch...', ...] 
""" - with open(os.path.join(path_dir, file_name)) as file: - lines = [ln.strip() for ln in file.readlines()] - reqs = [] - for ln in lines: - # filer all comments - comment = "" - if comment_char in ln: - comment = ln[ln.index(comment_char) :] - ln = ln[: ln.index(comment_char)] - req = ln.strip() - # skip directly installed dependencies - if not req or req.startswith("http") or "@http" in req: - continue - # remove version restrictions unless they are strict - if unfreeze and "<" in req and "strict" not in comment: - req = re.sub(r",? *<=? *[\d\.\*]+", "", req).strip() - reqs.append(req) - return reqs + path = Path(path_dir) / file_name + assert path.exists(), (path_dir, file_name, path) + text = path.read_text() + return [req.clean_str(unfreeze) for req in _parse_requirements(text)] def load_readme_description(path_dir: str, homepage: str, version: str) -> str: @@ -294,9 +319,8 @@ class implementations by cross-imports to the true package. if fname in ("__about__.py", "__version__.py"): body = lines else: - if fname.startswith("_") and fname not in ("__init__.py", "__main__.py"): - logging.warning(f"unsupported file: {local_path}") - continue + if fname.startswith("_") and fname not in ("__init__.py", "__main__.py", "__setup__.py"): + raise ValueError(f"Unsupported file: {fname}") # ToDO: perform some smarter parsing - preserve Constants, lambdas, etc body = prune_comments_docstrings(lines) if fname not in ("__init__.py", "__main__.py"): From ceb782308e3e020fb3af3ed69ef7edae378637ea Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 4 Aug 2022 18:22:39 +0200 Subject: [PATCH 12/14] Revert "setup_tools refactor" This reverts commit 022e149b154ec3180bf0c884e7e57fa44b6492ec. --- .actions/setup_tools.py | 78 ++++++++++++++--------------------------- 1 file changed, 27 insertions(+), 51 deletions(-) diff --git a/.actions/setup_tools.py b/.actions/setup_tools.py index c451e42f3a0b4..3a105f508fd45 100644 --- a/.actions/setup_tools.py +++ b/.actions/setup_tools.py @@ -12,6 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. 
import glob +import logging import os import pathlib import re @@ -21,11 +22,8 @@ import urllib.request from importlib.util import module_from_spec, spec_from_file_location from itertools import groupby -from pathlib import Path from types import ModuleType -from typing import Any, Iterable, Iterator, List, Optional - -import pkg_resources +from typing import List _PROJECT_ROOT = os.path.dirname(os.path.dirname(__file__)) _PACKAGE_MAPPING = {"pytorch": "pytorch_lightning", "app": "lightning_app"} @@ -43,56 +41,33 @@ def _load_py_module(name: str, location: str) -> ModuleType: return py -class _RequirementWithComment(pkg_resources.Requirement): - def __init__(self, *args: Any, comment: str = "", pip_argument: Optional[str] = None, **kwargs: Any) -> None: - super().__init__(*args, **kwargs) - self.comment = comment - assert pip_argument is None or pip_argument # sanity check that it's not an empty str - self.pip_argument = pip_argument - self.strict = "# strict" in comment.lower() - - def clean_str(self, unfreeze: bool) -> str: - # remove version restrictions unless they are strict - return self.project_name if unfreeze and not self.strict else str(self) - - -def _parse_requirements(strs: Iterable) -> Iterator[_RequirementWithComment]: - """Adapted from `pkg_resources.parse_requirements` to include comments.""" - lines = pkg_resources.yield_lines(strs) - pip_argument = None - for line in lines: - # Drop comments -- a hash without a space may be in a URL. - if " #" in line: - comment_pos = line.find(" #") - line, comment = line[:comment_pos], line[comment_pos:] - else: - comment = "" - # If there is a line continuation, drop it, and append the next line. - if line.endswith("\\"): - line = line[:-2].strip() - try: - line += next(lines) - except StopIteration: - return - # If there's a pip argument, save it - if line.startswith("--"): - pip_argument = line - continue - yield _RequirementWithComment(line, comment=comment, pip_argument=pip_argument) - pip_argument = None - - -def load_requirements(path_dir: str, file_name: str = "base.txt", unfreeze: bool = True) -> List[str]: +def load_requirements( + path_dir: str, file_name: str = "base.txt", comment_char: str = "#", unfreeze: bool = True +) -> List[str]: """Load requirements from a file. - >>> path_req = os.path.join(_PROJECT_ROOT, "requirements", "pytorch") + >>> path_req = os.path.join(_PROJECT_ROOT, "requirements") >>> load_requirements(path_req) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE ['numpy...', 'torch...', ...] """ - path = Path(path_dir) / file_name - assert path.exists(), (path_dir, file_name, path) - text = path.read_text() - return [req.clean_str(unfreeze) for req in _parse_requirements(text)] + with open(os.path.join(path_dir, file_name)) as file: + lines = [ln.strip() for ln in file.readlines()] + reqs = [] + for ln in lines: + # filer all comments + comment = "" + if comment_char in ln: + comment = ln[ln.index(comment_char) :] + ln = ln[: ln.index(comment_char)] + req = ln.strip() + # skip directly installed dependencies + if not req or req.startswith("http") or "@http" in req: + continue + # remove version restrictions unless they are strict + if unfreeze and "<" in req and "strict" not in comment: + req = re.sub(r",? *<=? *[\d\.\*]+", "", req).strip() + reqs.append(req) + return reqs def load_readme_description(path_dir: str, homepage: str, version: str) -> str: @@ -319,8 +294,9 @@ class implementations by cross-imports to the true package. 
if fname in ("__about__.py", "__version__.py"): body = lines else: - if fname.startswith("_") and fname not in ("__init__.py", "__main__.py", "__setup__.py"): - raise ValueError(f"Unsupported file: {fname}") + if fname.startswith("_") and fname not in ("__init__.py", "__main__.py"): + logging.warning(f"unsupported file: {local_path}") + continue # ToDO: perform some smarter parsing - preserve Constants, lambdas, etc body = prune_comments_docstrings(lines) if fname not in ("__init__.py", "__main__.py"): From fb4049e69f798769c280c5363e6ad1e19aeb8d04 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Thu, 4 Aug 2022 18:23:02 +0200 Subject: [PATCH 13/14] Drop torchdistx install --- requirements/pytorch/test.txt | 2 -- 1 file changed, 2 deletions(-) diff --git a/requirements/pytorch/test.txt b/requirements/pytorch/test.txt index e396fc8e81ff8..ce54cd087b1de 100644 --- a/requirements/pytorch/test.txt +++ b/requirements/pytorch/test.txt @@ -16,5 +16,3 @@ psutil # for `DeviceStatsMonitor` pandas # needed in benchmarks fastapi uvicorn ---extra-index-url https://download.pytorch.org/whl/cpu -torchdistx From 8fc890054b662d0c448d832dadc8821edb0035b6 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Carlos=20Mochol=C3=AD?= Date: Fri, 5 Aug 2022 11:42:55 +0200 Subject: [PATCH 14/14] Fix missing import --- src/pytorch_lightning/utilities/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index 2b12dbfd31dad..00a7cb8486709 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -18,7 +18,7 @@ from contextlib import contextmanager from dataclasses import fields from functools import partial -from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Tuple, Union +from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Tuple, Type, Union import torch from torch import Tensor
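Taken together, the series swaps Lightning's home-grown meta-device utilities for their `torchdistx` equivalents. A hedged before/after sketch of the user-facing migration, assuming `torchdistx` is installed; `BoringModel` is borrowed from the tests above purely as a stand-in:

    from pytorch_lightning.demos.boring_classes import BoringModel

    # Before (deprecated in v1.8, slated for removal in v1.9):
    #
    #     from pytorch_lightning.utilities.meta import init_meta_context, materialize_module
    #     with init_meta_context():
    #         model = BoringModel()  # parameters land on the "meta" device
    #     materialize_module(model)

    # After: defer allocation through torchdistx; the Trainer materializes the module
    # inside _call_configure_sharded_model, or it can be done manually as shown here.
    from torchdistx.deferred_init import deferred_init, materialize_module

    model = deferred_init(BoringModel)  # parameters are fake, no storage allocated yet
    materialize_module(model)           # idempotent; allocates real storage on CPU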