Skip to content

Commit dca1776

Browse files
authored
LiteDataLoader wrapper improvements (#10297)
1 parent 7cf6374 commit dca1776

File tree

2 files changed

+41
-50
lines changed

2 files changed

+41
-50
lines changed

pytorch_lightning/lite/wrappers.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
import inspect
1616
from contextlib import contextmanager
1717
from itertools import chain
18-
from typing import Any, Callable, Dict, Generator, Iterator, Optional, Set, Type, Union
18+
from typing import Any, Callable, Generator, Iterator, Optional, Set, Type, Union
1919

2020
import torch
2121
from torch import nn as nn
@@ -110,21 +110,25 @@ def _convert_float_tensor(t: Tensor) -> Tensor:
110110
return output
111111

112112

113-
def _wrap_init(f: Callable) -> Callable:
114-
@functools.wraps(f)
115-
def wrapper(module: Any, *args: Any, **kwargs: Dict[str, Any]) -> None:
116-
params = dict(inspect.signature(module._old_init).parameters)
113+
def _wrap_init(init: Callable) -> Callable:
114+
"""Wraps the ``__init__`` method of the dataloader in order to enable re-instantiation of custom subclasses of
115+
:class:`~torch.utils.data.DataLoader`."""
116+
117+
@functools.wraps(init)
118+
def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None:
119+
params = dict(inspect.signature(obj.__init__).parameters)
117120
params.pop("args", None)
118121
params.pop("kwargs", None)
119-
for init_name, init_arg in chain(zip(params, args), kwargs.items()):
120-
setattr(module, init_name, init_arg)
121-
f(module, *args, **kwargs)
122+
for arg_name, arg_value in chain(zip(params, args), kwargs.items()):
123+
setattr(obj, arg_name, arg_value)
124+
init(obj, *args, **kwargs)
122125

123126
return wrapper
124127

125128

126129
# https://stackoverflow.com/a/63851681/9201239
127130
def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]:
131+
"""Returns a set of all classes that inherit directly or indirectly from the given class."""
128132
subclasses = set()
129133

130134
def recurse(cl: Type[Any]) -> None:
@@ -136,24 +140,17 @@ def recurse(cl: Type[Any]) -> None:
136140
return subclasses
137141

138142

139-
def _enable_class(cls: Type[Any]) -> None:
140-
cls._old_init = cls.__init__
141-
cls.__init__ = _wrap_init(cls.__init__)
142-
143-
144-
def _disable_class(cls: Type[Any]) -> None:
145-
cls.__init__ = cls._old_init
146-
del cls._old_init
147-
148-
149143
@contextmanager
150-
def _replace_dataloader_init_method() -> Generator:
151-
"""This context manager is used to support custom :class:`~torch.utils.data.DataLoader`."""
144+
def _replace_dataloader_init_method() -> Generator[None, None, None]:
145+
"""This context manager is used to add support for re-instantiation of custom (subclasses) of
146+
:class:`~torch.utils.data.DataLoader`. It patches the ``__init__`` method."""
152147
for subclass in _get_all_subclasses(DataLoader):
153-
_enable_class(subclass)
148+
subclass._old_init = subclass.__init__
149+
subclass.__init__ = _wrap_init(subclass.__init__)
154150
yield
155151
for subclass in _get_all_subclasses(DataLoader):
156-
_disable_class(subclass)
152+
subclass.__init__ = subclass._old_init
153+
del subclass._old_init
157154

158155

159156
class _LiteDataLoader:

tests/lite/test_lite.py

Lines changed: 22 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,7 @@
2424
from torch.utils.data import DataLoader, DistributedSampler, Sampler
2525

2626
from pytorch_lightning.lite import LightningLite
27-
from pytorch_lightning.lite.wrappers import (
28-
_LiteDataLoader,
29-
_LiteModule,
30-
_LiteOptimizer,
31-
_replace_dataloader_init_method,
32-
)
27+
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
3328
from pytorch_lightning.plugins import DeepSpeedPlugin, PrecisionPlugin, TrainingTypePlugin
3429
from pytorch_lightning.utilities import _StrategyType
3530
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -197,6 +192,27 @@ def run(self):
197192
LiteWithCustomDataLoader().run()
198193

199194

195+
def test_setup_dataloaders_raises_for_unknown_custom_args():
196+
"""Test that an error is raised when custom dataloaders with unknown arguments are created from outside Lite's run
197+
method."""
198+
lite = EmptyLite()
199+
200+
class CustomDataLoader(DataLoader):
201+
def __init__(self, new_arg, *args, **kwargs):
202+
super().__init__(range(5), *args, **kwargs)
203+
204+
with pytest.raises(
205+
MisconfigurationException,
206+
match=(
207+
r"Trying to inject `DistributedSampler` into the `CustomDataLoader` instance.*"
208+
r"The missing attributes are \['new_arg'\]"
209+
),
210+
):
211+
# The dataloader was not created within the run function, and therefore init args were not intercepted
212+
dataloader = CustomDataLoader(2, batch_size=2)
213+
lite.setup_dataloaders(dataloader)
214+
215+
200216
def test_setup_dataloaders_twice_fails():
201217
"""Test that calling setup_dataloaders with a dataloader that is already wrapped fails."""
202218
lite = EmptyLite()
@@ -444,25 +460,3 @@ def run(self):
444460
assert self.is_global_zero == (self.local_rank == 0)
445461

446462
Lite(strategy=DeepSpeedPlugin(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu").run()
447-
448-
449-
def test_replace_dataloader_init_method():
450-
"""Test that the context manager enables to save the parameters passed to the DataLoader __init__ method."""
451-
452-
class CustomDataLoader(DataLoader):
453-
def __init__(self, extra_argument: int, *args, **kwargs):
454-
super().__init__(*args, **kwargs)
455-
456-
dataloader = CustomDataLoader(extra_argument=1, dataset=range(1))
457-
lite = EmptyLite()
458-
with pytest.raises(MisconfigurationException, match="extra_argument"):
459-
dataloader = lite.setup_dataloaders(dataloader)
460-
461-
with _replace_dataloader_init_method():
462-
dataloader = CustomDataLoader(extra_argument=1, dataset=range(1))
463-
assert dataloader.extra_argument == 1
464-
dataloader = lite.setup_dataloaders(dataloader)
465-
466-
dataloader = CustomDataLoader(1, range(1))
467-
assert dataloader.extra_argument == 1
468-
dataloader = lite.setup_dataloaders(dataloader)

0 commit comments

Comments
 (0)