diff --git a/pytorch_lightning/lite/wrappers.py b/pytorch_lightning/lite/wrappers.py index 3cd2f5eb69712..3bd94b5f36b74 100644 --- a/pytorch_lightning/lite/wrappers.py +++ b/pytorch_lightning/lite/wrappers.py @@ -15,7 +15,7 @@ import inspect from contextlib import contextmanager from itertools import chain -from typing import Any, Callable, Dict, Generator, Iterator, Optional, Set, Type, Union +from typing import Any, Callable, Generator, Iterator, Optional, Set, Type, Union import torch from torch import nn as nn @@ -110,21 +110,25 @@ def _convert_float_tensor(t: Tensor) -> Tensor: return output -def _wrap_init(f: Callable) -> Callable: - @functools.wraps(f) - def wrapper(module: Any, *args: Any, **kwargs: Dict[str, Any]) -> None: - params = dict(inspect.signature(module._old_init).parameters) +def _wrap_init(init: Callable) -> Callable: + """Wraps the ``__init__`` method of the dataloader in order to enable re-instantiation of custom subclasses of + :class:`~torch.utils.data.DataLoader`.""" + + @functools.wraps(init) + def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None: + params = dict(inspect.signature(obj.__init__).parameters) params.pop("args", None) params.pop("kwargs", None) - for init_name, init_arg in chain(zip(params, args), kwargs.items()): - setattr(module, init_name, init_arg) - f(module, *args, **kwargs) + for arg_name, arg_value in chain(zip(params, args), kwargs.items()): + setattr(obj, arg_name, arg_value) + init(obj, *args, **kwargs) return wrapper # https://stackoverflow.com/a/63851681/9201239 def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]: + """Returns a list of all classes that inherit directly or indirectly from the given class.""" subclasses = set() def recurse(cl: Type[Any]) -> None: @@ -136,24 +140,17 @@ def recurse(cl: Type[Any]) -> None: return subclasses -def _enable_class(cls: Type[Any]) -> None: - cls._old_init = cls.__init__ - cls.__init__ = _wrap_init(cls.__init__) - - -def _disable_class(cls: Type[Any]) -> None: - cls.__init__ = cls._old_init - del cls._old_init - - @contextmanager -def _replace_dataloader_init_method() -> Generator: - """This context manager is used to support custom :class:`~torch.utils.data.DataLoader.""" +def _replace_dataloader_init_method() -> Generator[None, None, None]: + """This context manager is used to add support for re-instantiation of custom (subclasses) of + :class:`~torch.utils.data.DataLoader`. It patches the ``__init__`` method.""" for subclass in _get_all_subclasses(DataLoader): - _enable_class(subclass) + subclass._old_init = subclass.__init__ + subclass.__init__ = _wrap_init(subclass.__init__) yield for subclass in _get_all_subclasses(DataLoader): - _disable_class(subclass) + subclass.__init__ = subclass._old_init + del subclass._old_init class _LiteDataLoader: diff --git a/tests/lite/test_lite.py b/tests/lite/test_lite.py index 7c79cb7f2e709..f9ed4a9da7d9d 100644 --- a/tests/lite/test_lite.py +++ b/tests/lite/test_lite.py @@ -24,12 +24,7 @@ from torch.utils.data import DataLoader, DistributedSampler, Sampler from pytorch_lightning.lite import LightningLite -from pytorch_lightning.lite.wrappers import ( - _LiteDataLoader, - _LiteModule, - _LiteOptimizer, - _replace_dataloader_init_method, -) +from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer from pytorch_lightning.plugins import DeepSpeedPlugin, PrecisionPlugin, TrainingTypePlugin from pytorch_lightning.utilities import _StrategyType from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -197,6 +192,27 @@ def run(self): LiteWithCustomDataLoader().run() +def test_setup_dataloaders_raises_for_unknown_custom_args(): + """Test that an error raises when custom dataloaders with unknown arguments are created from outside Lite's run + method.""" + lite = EmptyLite() + + class CustomDataLoader(DataLoader): + def __init__(self, new_arg, *args, **kwargs): + super().__init__(range(5), *args, **kwargs) + + with pytest.raises( + MisconfigurationException, + match=( + r"Trying to inject `DistributedSampler` into the `CustomDataLoader` instance.*" + r"The missing attributes are \['new_arg'\]" + ), + ): + # The dataloader was not created within the run function, and therefore init args were not intercepted + dataloader = CustomDataLoader(2, batch_size=2) + lite.setup_dataloaders(dataloader) + + def test_setup_dataloaders_twice_fails(): """Test that calling setup_dataloaders with a dataloader that is already wrapped fails.""" lite = EmptyLite() @@ -444,25 +460,3 @@ def run(self): assert self.is_global_zero == (self.local_rank == 0) Lite(strategy=DeepSpeedPlugin(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu").run() - - -def test_replace_dataloader_init_method(): - """Test that the context manager enables to save the parameters passed to the DataLoader __init__ method.""" - - class CustomDataLoader(DataLoader): - def __init__(self, extra_argument: int, *args, **kwargs): - super().__init__(*args, **kwargs) - - dataloader = CustomDataLoader(extra_argument=1, dataset=range(1)) - lite = EmptyLite() - with pytest.raises(MisconfigurationException, match="extra_argument"): - dataloader = lite.setup_dataloaders(dataloader) - - with _replace_dataloader_init_method(): - dataloader = CustomDataLoader(extra_argument=1, dataset=range(1)) - assert dataloader.extra_argument == 1 - dataloader = lite.setup_dataloaders(dataloader) - - dataloader = CustomDataLoader(1, range(1)) - assert dataloader.extra_argument == 1 - dataloader = lite.setup_dataloaders(dataloader)