1 change: 0 additions & 1 deletion pyproject.toml
@@ -71,6 +71,5 @@ module = [
"pytorch_lightning.tuner.batch_size_scaling",
"pytorch_lightning.utilities.auto_restart",
"pytorch_lightning.utilities.data",
"pytorch_lightning.utilities.meta",
]
ignore_errors = "True"
8 changes: 2 additions & 6 deletions src/pytorch_lightning/core/mixins/device_dtype_mixin.py
@@ -18,8 +18,6 @@
from torch.nn import Module
from typing_extensions import Self

import pytorch_lightning as pl


class DeviceDtypeModuleMixin(Module):
__jit_unused_properties__ = ["device", "dtype"]
@@ -180,10 +178,8 @@ def half(self) -> Self: # type: ignore[valid-type]
def __update_properties(
self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None
) -> None:
def apply_fn(module: Union["DeviceDtypeModuleMixin", Module]) -> None:
# TODO: Find why `isinstance(module, DeviceDtypeModuleMixin)` doesn't
# work when using `init_meta_context`.
if not isinstance(module, (DeviceDtypeModuleMixin, pl.LightningModule)):
def apply_fn(module: Union[DeviceDtypeModuleMixin, Module]) -> None:
if not isinstance(module, DeviceDtypeModuleMixin):
return
if device is not None:
module._device = device
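Side note on the simplified `apply_fn` above: with the meta-context workaround gone, the check reduces to a plain `isinstance` test against the mixin. A toy sketch of the underlying pattern (stand-in classes, not the actual Lightning code): `torch.nn.Module.apply` visits the module and every submodule, and only mixin instances get their cached device updated.

```python
import torch
from torch.nn import Linear, Module, Sequential


class ToyDeviceMixin(Module):
    """Stand-in for DeviceDtypeModuleMixin that caches the target device."""

    def __init__(self) -> None:
        super().__init__()
        self._device = torch.device("cpu")


class ToyModel(ToyDeviceMixin):
    def __init__(self) -> None:
        super().__init__()
        self.layers = Sequential(Linear(2, 2), Linear(2, 2))


def update_device(root: Module, device: torch.device) -> None:
    def apply_fn(module: Module) -> None:
        if not isinstance(module, ToyDeviceMixin):
            return  # plain nn.Modules (the Sequential and Linear layers) are skipped
        module._device = device

    root.apply(apply_fn)  # applies `apply_fn` to `root` and every submodule


model = ToyModel()
update_device(model, torch.device("cpu"))
assert model._device == torch.device("cpu")
```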
11 changes: 11 additions & 0 deletions src/pytorch_lightning/strategies/launchers/multiprocessing.py
@@ -87,6 +87,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
a selected set of attributes get restored in the main process after processes join.
**kwargs: Optional keyword arguments to be passed to the given function.
"""
self._check_torchdistx_support()
# The default cluster environment in Lightning chooses a random free port number
# This needs to be done in the main process here before starting processes to ensure each rank will connect
# through the same port
@@ -178,6 +179,16 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt

return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra)

def _check_torchdistx_support(self) -> None:
if self._start_method == "spawn":
from pytorch_lightning.utilities.meta import _is_deferred

if _is_deferred(self._strategy.lightning_module):
raise NotImplementedError(
f"The `{type(self._strategy).__name__}` strategy does not support `torchdistx`'s deferred"
f" initialization."
)

def add_to_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None:
"""Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory
sharing, we cast the data to numpy.
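Why this guard exists: with the `spawn` start method the model is pickled and re-created in each worker process, and a LightningModule that is still in torchdistx's deferred (unmaterialized) state is not expected to survive that hand-off, so the launcher now fails fast with a clear error instead of crashing later. A minimal sketch of the deferred-initialization workflow the guard refers to, assuming the optional `torchdistx` package is installed (this is not part of the diff):

```python
import torch.nn as nn
from torchdistx.deferred_init import deferred_init, materialize_module

# Build the module without allocating or initializing its parameters.
module = deferred_init(nn.Linear, 50_000, 50_000)

# Nothing is allocated until the module is materialized; the Trainer now performs
# this step itself in `_call_configure_sharded_model` (see trainer.py below).
materialize_module(module)
```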
22 changes: 7 additions & 15 deletions src/pytorch_lightning/trainer/trainer.py
@@ -70,7 +70,6 @@
XLAProfiler,
)
from pytorch_lightning.strategies import ParallelStrategy, Strategy
from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations
from pytorch_lightning.trainer.connectors.accelerator_connector import _LITERAL_WARN, AcceleratorConnector
@@ -106,8 +105,7 @@
from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, has_len_all_ranks
from pytorch_lightning.utilities.distributed import distributed_available
from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
from pytorch_lightning.utilities.imports import _fault_tolerant_training
from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module
from pytorch_lightning.utilities.imports import _fault_tolerant_training, _module_available
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities.seed import isolate_rng
@@ -1469,20 +1467,14 @@ def _call_setup_hook(self) -> None:

def _call_configure_sharded_model(self) -> None:
with self.strategy.model_sharded_context():
self._handle_meta_model()
self._call_lightning_module_hook("configure_sharded_model")
self._call_callback_hooks("on_configure_sharded_model")

def _handle_meta_model(self) -> None:
if not is_on_meta_device(self.lightning_module):
return
# experimental support for torchdistx
if _module_available("torchdistx.deferred_init"):
from torchdistx.deferred_init import materialize_module

if isinstance(self.strategy, DDPSpawnStrategy):
raise MisconfigurationException("LightningModule on meta device isn't supported with spawn.")
materialize_module(self.lightning_module)

materialize_module(self.lightning_module)
# the trainer reference is lost during materialization
self.lightning_module.trainer = proxy(self)
self._call_lightning_module_hook("configure_sharded_model")
self._call_callback_hooks("on_configure_sharded_model")

def _call_teardown_hook(self) -> None:
fn = self.state.fn._setup_fn
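The net effect of the trainer change: the old `_handle_meta_model` path (meta device plus the `DDPSpawnStrategy` special case) is replaced by an experimental torchdistx path that materializes a deferred LightningModule inside `model_sharded_context()` right before the `configure_sharded_model` hooks run. An illustrative end-to-end sketch, assuming `torchdistx` is installed; the model and trainer settings are made up for the example:

```python
import pytorch_lightning as pl
import torch
from torch import nn
from torch.utils.data import DataLoader
from torchdistx.deferred_init import deferred_init


class LitModel(pl.LightningModule):
    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def train_dataloader(self):
        return DataLoader(torch.randn(64, 32), batch_size=8)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


# Parameters are not allocated here ...
model = deferred_init(LitModel)
# ... and are materialized by the Trainer just before `configure_sharded_model` is called.
trainer = pl.Trainer(accelerator="cpu", devices=1, fast_dev_run=True)
trainer.fit(model)
```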
14 changes: 7 additions & 7 deletions src/pytorch_lightning/utilities/cli.py
@@ -22,7 +22,7 @@

import pytorch_lightning as pl
import pytorch_lightning.cli as new_cli
from pytorch_lightning.utilities.meta import get_all_subclasses
from pytorch_lightning.utilities.meta import _get_all_subclasses
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation

_deprecate_registry_message = (
@@ -108,17 +108,17 @@ def _populate_registries(subclasses: bool) -> None: # Remove in v1.9
if subclasses:
rank_zero_deprecation(_deprecate_auto_registry_message)
# this will register any subclasses from all loaded modules including userland
for cls in get_all_subclasses(torch.optim.Optimizer):
for cls in _get_all_subclasses(torch.optim.Optimizer):
OPTIMIZER_REGISTRY(cls, show_deprecation=False)
for cls in get_all_subclasses(torch.optim.lr_scheduler._LRScheduler):
for cls in _get_all_subclasses(torch.optim.lr_scheduler._LRScheduler):
LR_SCHEDULER_REGISTRY(cls, show_deprecation=False)
for cls in get_all_subclasses(pl.Callback):
for cls in _get_all_subclasses(pl.Callback):
CALLBACK_REGISTRY(cls, show_deprecation=False)
for cls in get_all_subclasses(pl.LightningModule):
for cls in _get_all_subclasses(pl.LightningModule):
MODEL_REGISTRY(cls, show_deprecation=False)
for cls in get_all_subclasses(pl.LightningDataModule):
for cls in _get_all_subclasses(pl.LightningDataModule):
DATAMODULE_REGISTRY(cls, show_deprecation=False)
for cls in get_all_subclasses(pl.loggers.Logger):
for cls in _get_all_subclasses(pl.loggers.Logger):
LOGGER_REGISTRY(cls, show_deprecation=False)
else:
# manually register torch's subclasses and our subclasses
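The registry population itself is unchanged; only the helper moved and was renamed to the protected `_get_all_subclasses` in `pytorch_lightning.utilities.meta`. A quick illustration of what it yields (not part of the diff):

```python
import torch

from pytorch_lightning.utilities.meta import _get_all_subclasses

# Every direct and indirect subclass of the base class that is currently imported,
# which is how e.g. all torch optimizers end up in OPTIMIZER_REGISTRY.
optimizers = _get_all_subclasses(torch.optim.Optimizer)
assert torch.optim.Adam in optimizers and torch.optim.SGD in optimizers
```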
17 changes: 2 additions & 15 deletions src/pytorch_lightning/utilities/data.py
@@ -18,7 +18,7 @@
from contextlib import contextmanager
from dataclasses import fields
from functools import partial
from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Tuple, Type, Union
from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Tuple, Type, Union

import torch
from torch import Tensor
@@ -39,6 +39,7 @@
from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler
from pytorch_lightning.utilities.enums import _FaultTolerantMode
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.meta import _get_all_subclasses
from pytorch_lightning.utilities.rank_zero import rank_zero_warn
from pytorch_lightning.utilities.seed import pl_worker_init_function
from pytorch_lightning.utilities.warnings import WarningCache
@@ -493,20 +494,6 @@ def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None:
return wrapper


# https://stackoverflow.com/a/63851681/9201239
def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]:
"""Returns a list of all classes that inherit directly or indirectly from the given class."""
subclasses = set()

def recurse(cl: Type[Any]) -> None:
for subclass in cl.__subclasses__():
subclasses.add(subclass)
recurse(subclass)

recurse(cls)
return subclasses


@contextmanager
def _replace_init_method(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]:
"""This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`.
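For readers who land on the `_replace_init_method` docstring at the end of this hunk, a toy sketch of the idea it describes (hypothetical names, not Lightning's implementation): the `__init__` of the targeted class is temporarily wrapped so the user's constructor arguments are recorded, which later allows re-instantiating the object (for example a custom `DataLoader`) with the same arguments.

```python
import functools
from contextlib import contextmanager
from typing import Any, Generator, Type


@contextmanager
def record_init_args(base_cls: Type) -> Generator[None, None, None]:
    """Toy version: wrap ``__init__`` of ``base_cls`` to remember the call arguments."""
    original_init = base_cls.__init__

    @functools.wraps(original_init)
    def wrapped_init(self: Any, *args: Any, **kwargs: Any) -> None:
        self._init_args = (args, kwargs)  # remembered for later re-instantiation
        original_init(self, *args, **kwargs)

    base_cls.__init__ = wrapped_init
    try:
        yield
    finally:
        base_cls.__init__ = original_init


class Loader:
    def __init__(self, batch_size: int = 1) -> None:
        self.batch_size = batch_size


with record_init_args(Loader):
    loader = Loader(batch_size=8)

args, kwargs = loader._init_args
reloaded = Loader(*args, **kwargs)  # re-instantiate with the recorded arguments
assert reloaded.batch_size == 8
```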