
Commit 91dd6a6

Remove meta device utilities in favor of torchdistx (#13868)
1 parent 0883971 commit 91dd6a6

File tree

11 files changed (+194, -429 lines)
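
For context, this commit replaces Lightning's own meta-device utilities with the deferred-initialization API from torchdistx. Below is a minimal sketch of that API, assuming torchdistx is installed (MyModule is a placeholder module, not part of this commit):

import torch
from torch import nn
from torchdistx.deferred_init import deferred_init, materialize_module


class MyModule(nn.Module):
    # Placeholder module for illustration only.
    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(10_000, 10_000)


# Construct the module without allocating real storage for its parameters.
module = deferred_init(MyModule)
# Allocate and initialize the weights in place once they are actually needed.
materialize_module(module)
out = module.layer(torch.randn(2, 10_000))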


pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -71,6 +71,5 @@ module = [
     "pytorch_lightning.tuner.batch_size_scaling",
     "pytorch_lightning.utilities.auto_restart",
     "pytorch_lightning.utilities.data",
-    "pytorch_lightning.utilities.meta",
 ]
 ignore_errors = "True"

src/pytorch_lightning/core/mixins/device_dtype_mixin.py

Lines changed: 2 additions & 6 deletions
@@ -18,8 +18,6 @@
 from torch.nn import Module
 from typing_extensions import Self
 
-import pytorch_lightning as pl
-
 
 class DeviceDtypeModuleMixin(Module):
     __jit_unused_properties__ = ["device", "dtype"]
@@ -180,10 +178,8 @@ def half(self) -> Self: # type: ignore[valid-type]
     def __update_properties(
         self, device: Optional[torch.device] = None, dtype: Optional[Union[str, torch.dtype]] = None
     ) -> None:
-        def apply_fn(module: Union["DeviceDtypeModuleMixin", Module]) -> None:
-            # TODO: Find why `isinstance(module, DeviceDtypeModuleMixin)` doesn't
-            # work when using `init_meta_context`.
-            if not isinstance(module, (DeviceDtypeModuleMixin, pl.LightningModule)):
+        def apply_fn(module: Union[DeviceDtypeModuleMixin, Module]) -> None:
+            if not isinstance(module, DeviceDtypeModuleMixin):
                 return
             if device is not None:
                 module._device = device
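
As a side note on the simplified isinstance check: __update_properties walks the module tree with Module.apply and syncs the device/dtype attributes of every DeviceDtypeModuleMixin it encounters. A hedged sketch (the classes are hypothetical; the import path pytorch_lightning.core.mixins is assumed):

import torch
from torch import nn
from pytorch_lightning.core.mixins import DeviceDtypeModuleMixin  # assumed import path


class Child(DeviceDtypeModuleMixin):
    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(4, 4)


class Parent(DeviceDtypeModuleMixin):
    def __init__(self) -> None:
        super().__init__()
        self.child = Child()


model = Parent()
model.to(torch.float64)
# apply_fn updates every DeviceDtypeModuleMixin in the hierarchy, not just the root.
print(model.dtype, model.child.dtype)  # expected: torch.float64 torch.float64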

src/pytorch_lightning/strategies/launchers/multiprocessing.py

Lines changed: 11 additions & 0 deletions
@@ -87,6 +87,7 @@ def launch(self, function: Callable, *args: Any, trainer: Optional["pl.Trainer"]
                 a selected set of attributes get restored in the main process after processes join.
             **kwargs: Optional keyword arguments to be passed to the given function.
         """
+        self._check_torchdistx_support()
         # The default cluster environment in Lightning chooses a random free port number
         # This needs to be done in the main process here before starting processes to ensure each rank will connect
         # through the same port
@@ -178,6 +179,16 @@ def _collect_rank_zero_results(self, trainer: "pl.Trainer", results: Any) -> Opt
 
         return _WorkerOutput(best_model_path, weights_path, trainer.state, results, extra)
 
+    def _check_torchdistx_support(self) -> None:
+        if self._start_method == "spawn":
+            from pytorch_lightning.utilities.meta import _is_deferred
+
+            if _is_deferred(self._strategy.lightning_module):
+                raise NotImplementedError(
+                    f"The `{type(self._strategy).__name__}` strategy does not support `torchdistx`'s deferred"
+                    f" initialization."
+                )
+
     def add_to_queue(self, trainer: "pl.Trainer", queue: "_FakeQueue") -> None:
         """Appends the :attr:`trainer.callback_metrics` dictionary to the given queue. To avoid issues with memory
         sharing, we cast the data to numpy.
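
A hedged sketch of the condition this guard checks: _is_deferred (added to pytorch_lightning.utilities.meta in this commit) reports whether a module still holds deferred, not-yet-materialized parameters; its exact semantics are inferred from its name and usage here:

from torch import nn
from torchdistx.deferred_init import deferred_init, materialize_module
from pytorch_lightning.utilities.meta import _is_deferred

module = deferred_init(nn.Linear, 4, 4)
print(_is_deferred(module))   # expected: True (the spawn launcher refuses to launch in this state)
materialize_module(module)
print(_is_deferred(module))   # expected: False (materialized weights are ordinary tensors)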

src/pytorch_lightning/trainer/trainer.py

Lines changed: 7 additions & 15 deletions
@@ -70,7 +70,6 @@
     XLAProfiler,
 )
 from pytorch_lightning.strategies import ParallelStrategy, Strategy
-from pytorch_lightning.strategies.ddp_spawn import DDPSpawnStrategy
 from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
 from pytorch_lightning.trainer.configuration_validator import verify_loop_configurations
 from pytorch_lightning.trainer.connectors.accelerator_connector import _LITERAL_WARN, AcceleratorConnector
@@ -106,8 +105,7 @@
 from pytorch_lightning.utilities.data import _auto_add_worker_init_fn, has_len_all_ranks
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.exceptions import ExitGracefullyException, MisconfigurationException
-from pytorch_lightning.utilities.imports import _fault_tolerant_training
-from pytorch_lightning.utilities.meta import is_on_meta_device, materialize_module
+from pytorch_lightning.utilities.imports import _fault_tolerant_training, _module_available
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.seed import isolate_rng
@@ -1469,20 +1467,14 @@ def _call_setup_hook(self) -> None:
 
     def _call_configure_sharded_model(self) -> None:
         with self.strategy.model_sharded_context():
-            self._handle_meta_model()
-            self._call_lightning_module_hook("configure_sharded_model")
-            self._call_callback_hooks("on_configure_sharded_model")
-
-    def _handle_meta_model(self) -> None:
-        if not is_on_meta_device(self.lightning_module):
-            return
+            # experimental support for torchdistx
+            if _module_available("torchdistx.deferred_init"):
+                from torchdistx.deferred_init import materialize_module
 
-        if isinstance(self.strategy, DDPSpawnStrategy):
-            raise MisconfigurationException("LightningModule on meta device isn't supported with spawn.")
+                materialize_module(self.lightning_module)
 
-        materialize_module(self.lightning_module)
-        # the trainer reference is lost during materialization
-        self.lightning_module.trainer = proxy(self)
+            self._call_lightning_module_hook("configure_sharded_model")
+            self._call_callback_hooks("on_configure_sharded_model")
 
     def _call_teardown_hook(self) -> None:
         fn = self.state.fn._setup_fn
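
With the import guard above, a deferred module is materialized transparently when torchdistx is installed, right before the configure_sharded_model hooks run. A hedged end-to-end sketch on a single process (TinyModule is a placeholder LightningModule, not part of this commit):

import torch
from torch import nn
from torch.utils.data import DataLoader
import pytorch_lightning as pl
from pytorch_lightning.utilities.imports import _module_available
from torchdistx.deferred_init import deferred_init


class TinyModule(pl.LightningModule):
    # Placeholder LightningModule for illustration only.
    def __init__(self) -> None:
        super().__init__()
        self.layer = nn.Linear(4, 1)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


assert _module_available("torchdistx.deferred_init")  # the same check the Trainer now performs
model = deferred_init(TinyModule)
trainer = pl.Trainer(accelerator="cpu", fast_dev_run=True)
dataloader = DataLoader(torch.randn(8, 4), batch_size=4)
# _call_configure_sharded_model materializes the deferred weights before training starts.
trainer.fit(model, dataloader)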

src/pytorch_lightning/utilities/cli.py

Lines changed: 7 additions & 7 deletions
@@ -22,7 +22,7 @@
 
 import pytorch_lightning as pl
 import pytorch_lightning.cli as new_cli
-from pytorch_lightning.utilities.meta import get_all_subclasses
+from pytorch_lightning.utilities.meta import _get_all_subclasses
 from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
 
 _deprecate_registry_message = (
@@ -108,17 +108,17 @@ def _populate_registries(subclasses: bool) -> None: # Remove in v1.9
     if subclasses:
         rank_zero_deprecation(_deprecate_auto_registry_message)
         # this will register any subclasses from all loaded modules including userland
-        for cls in get_all_subclasses(torch.optim.Optimizer):
+        for cls in _get_all_subclasses(torch.optim.Optimizer):
             OPTIMIZER_REGISTRY(cls, show_deprecation=False)
-        for cls in get_all_subclasses(torch.optim.lr_scheduler._LRScheduler):
+        for cls in _get_all_subclasses(torch.optim.lr_scheduler._LRScheduler):
             LR_SCHEDULER_REGISTRY(cls, show_deprecation=False)
-        for cls in get_all_subclasses(pl.Callback):
+        for cls in _get_all_subclasses(pl.Callback):
             CALLBACK_REGISTRY(cls, show_deprecation=False)
-        for cls in get_all_subclasses(pl.LightningModule):
+        for cls in _get_all_subclasses(pl.LightningModule):
             MODEL_REGISTRY(cls, show_deprecation=False)
-        for cls in get_all_subclasses(pl.LightningDataModule):
+        for cls in _get_all_subclasses(pl.LightningDataModule):
             DATAMODULE_REGISTRY(cls, show_deprecation=False)
-        for cls in get_all_subclasses(pl.loggers.Logger):
+        for cls in _get_all_subclasses(pl.loggers.Logger):
             LOGGER_REGISTRY(cls, show_deprecation=False)
     else:
         # manually register torch's subclasses and our subclasses

src/pytorch_lightning/utilities/data.py

Lines changed: 2 additions & 15 deletions
@@ -18,7 +18,7 @@
 from contextlib import contextmanager
 from dataclasses import fields
 from functools import partial
-from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Tuple, Type, Union
+from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Tuple, Type, Union
 
 import torch
 from torch import Tensor
@@ -39,6 +39,7 @@
 from pytorch_lightning.utilities.auto_restart import CaptureIterableDataset, CaptureMapDataset, FastForwardSampler
 from pytorch_lightning.utilities.enums import _FaultTolerantMode
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.meta import _get_all_subclasses
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn
 from pytorch_lightning.utilities.seed import pl_worker_init_function
 from pytorch_lightning.utilities.warnings import WarningCache
@@ -493,20 +494,6 @@ def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None:
     return wrapper
 
 
-# https://stackoverflow.com/a/63851681/9201239
-def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]:
-    """Returns a list of all classes that inherit directly or indirectly from the given class."""
-    subclasses = set()
-
-    def recurse(cl: Type[Any]) -> None:
-        for subclass in cl.__subclasses__():
-            subclasses.add(subclass)
-            recurse(subclass)
-
-    recurse(cls)
-    return subclasses
-
-
 @contextmanager
 def _replace_init_method(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]:
     """This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`.
