Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
95d6889
move helpers to the bottom
awaelchli Nov 1, 2021
642b9c0
update docs for wrappers
awaelchli Nov 1, 2021
2c1bcfd
rename iterator variable
awaelchli Nov 1, 2021
29ae286
mention iterable in the docstring
awaelchli Nov 1, 2021
a10df79
update type
awaelchli Nov 1, 2021
134122e
add comment, improve readability
awaelchli Nov 1, 2021
16fc44d
add typing for generator
awaelchli Nov 1, 2021
8b352e2
update docs for LiteDataLoader
awaelchli Nov 1, 2021
cc2673a
every Python object has a dict
awaelchli Nov 1, 2021
61bfb09
wrap_init code improvement
awaelchli Nov 1, 2021
f047032
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 1, 2021
5e8d88e
Merge branch 'master' into feature/lite-dataloader
awaelchli Nov 2, 2021
f5b19b7
add changes from master
awaelchli Nov 2, 2021
53b15af
change order for review
awaelchli Nov 2, 2021
98834e4
fix iterator
awaelchli Nov 2, 2021
e4dd939
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2021
4b29301
Merge branch 'master' into feature/lite-dataloader
awaelchli Nov 2, 2021
c3456a7
simplify reference to old_init
awaelchli Nov 2, 2021
d60cf92
inline code
awaelchli Nov 2, 2021
0563c8c
add docs
awaelchli Nov 2, 2021
3e3d7c4
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 2, 2021
7411a62
Merge branch 'master' into feature/lite-dataloader
awaelchli Nov 3, 2021
23fab9e
wip
awaelchli Nov 3, 2021
ccfbf56
wip
awaelchli Nov 3, 2021
cf5923b
wip
awaelchli Nov 3, 2021
66961f3
Merge branch 'master' into feature/lite-dataloader
awaelchli Nov 5, 2021
a11c358
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 5, 2021
cf6131e
Merge branch 'master' into feature/lite-dataloader
awaelchli Nov 5, 2021
5f8e67c
Merge branch 'master' into feature/lite-dataloader
awaelchli Nov 18, 2021
6c05ab0
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 18, 2021
265b8b9
Merge branch 'master' into feature/lite-dataloader
awaelchli Nov 22, 2021
c9a05e0
update docs
awaelchli Nov 22, 2021
6cc836b
update signature
awaelchli Nov 22, 2021
9c53ddc
use init reference directly
awaelchli Nov 22, 2021
bb40f54
remove
awaelchli Nov 22, 2021
5b4c1d7
remove
awaelchli Nov 22, 2021
0c69d3a
update message
awaelchli Nov 22, 2021
da5f1a0
unused imprts
awaelchli Nov 22, 2021
527fe52
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Nov 22, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 19 additions & 22 deletions pytorch_lightning/lite/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
import inspect
from contextlib import contextmanager
from itertools import chain
from typing import Any, Callable, Dict, Generator, Iterator, Optional, Set, Type, Union
from typing import Any, Callable, Generator, Iterator, Optional, Set, Type, Union

import torch
from torch import nn as nn
Expand Down Expand Up @@ -110,21 +110,25 @@ def _convert_float_tensor(t: Tensor) -> Tensor:
return output


def _wrap_init(f: Callable) -> Callable:
@functools.wraps(f)
def wrapper(module: Any, *args: Any, **kwargs: Dict[str, Any]) -> None:
params = dict(inspect.signature(module._old_init).parameters)
def _wrap_init(init: Callable) -> Callable:
"""Wraps the ``__init__`` method of the dataloader in order to enable re-instantiation of custom subclasses of
:class:`~torch.utils.data.DataLoader`."""

@functools.wraps(init)
def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None:
params = dict(inspect.signature(obj.__init__).parameters)
params.pop("args", None)
params.pop("kwargs", None)
for init_name, init_arg in chain(zip(params, args), kwargs.items()):
setattr(module, init_name, init_arg)
f(module, *args, **kwargs)
for arg_name, arg_value in chain(zip(params, args), kwargs.items()):
setattr(obj, arg_name, arg_value)
init(obj, *args, **kwargs)

return wrapper


# https://stackoverflow.com/a/63851681/9201239
def _get_all_subclasses(cls: Type[Any]) -> Set[Type[Any]]:
"""Returns a list of all classes that inherit directly or indirectly from the given class."""
subclasses = set()

def recurse(cl: Type[Any]) -> None:
Expand All @@ -136,24 +140,17 @@ def recurse(cl: Type[Any]) -> None:
return subclasses


def _enable_class(cls: Type[Any]) -> None:
cls._old_init = cls.__init__
cls.__init__ = _wrap_init(cls.__init__)


def _disable_class(cls: Type[Any]) -> None:
cls.__init__ = cls._old_init
del cls._old_init


@contextmanager
def _replace_dataloader_init_method() -> Generator:
"""This context manager is used to support custom :class:`~torch.utils.data.DataLoader."""
def _replace_dataloader_init_method() -> Generator[None, None, None]:
"""This context manager is used to add support for re-instantiation of custom (subclasses) of
:class:`~torch.utils.data.DataLoader`. It patches the ``__init__`` method."""
for subclass in _get_all_subclasses(DataLoader):
_enable_class(subclass)
subclass._old_init = subclass.__init__
subclass.__init__ = _wrap_init(subclass.__init__)
yield
for subclass in _get_all_subclasses(DataLoader):
_disable_class(subclass)
subclass.__init__ = subclass._old_init
del subclass._old_init


class _LiteDataLoader:
Expand Down
50 changes: 22 additions & 28 deletions tests/lite/test_lite.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,7 @@
from torch.utils.data import DataLoader, DistributedSampler, Sampler

from pytorch_lightning.lite import LightningLite
from pytorch_lightning.lite.wrappers import (
_LiteDataLoader,
_LiteModule,
_LiteOptimizer,
_replace_dataloader_init_method,
)
from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
from pytorch_lightning.plugins import DeepSpeedPlugin, PrecisionPlugin, TrainingTypePlugin
from pytorch_lightning.utilities import _StrategyType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
Expand Down Expand Up @@ -197,6 +192,27 @@ def run(self):
LiteWithCustomDataLoader().run()


def test_setup_dataloaders_raises_for_unknown_custom_args():
"""Test that an error raises when custom dataloaders with unknown arguments are created from outside Lite's run
method."""
lite = EmptyLite()

class CustomDataLoader(DataLoader):
def __init__(self, new_arg, *args, **kwargs):
super().__init__(range(5), *args, **kwargs)

with pytest.raises(
MisconfigurationException,
match=(
r"Trying to inject `DistributedSampler` into the `CustomDataLoader` instance.*"
r"The missing attributes are \['new_arg'\]"
),
):
# The dataloader was not created within the run function, and therefore init args were not intercepted
dataloader = CustomDataLoader(2, batch_size=2)
lite.setup_dataloaders(dataloader)


def test_setup_dataloaders_twice_fails():
"""Test that calling setup_dataloaders with a dataloader that is already wrapped fails."""
lite = EmptyLite()
Expand Down Expand Up @@ -444,25 +460,3 @@ def run(self):
assert self.is_global_zero == (self.local_rank == 0)

Lite(strategy=DeepSpeedPlugin(stage=3, logging_batch_size_per_gpu=1), devices=2, accelerator="gpu").run()


def test_replace_dataloader_init_method():
"""Test that the context manager enables to save the parameters passed to the DataLoader __init__ method."""

class CustomDataLoader(DataLoader):
def __init__(self, extra_argument: int, *args, **kwargs):
super().__init__(*args, **kwargs)

dataloader = CustomDataLoader(extra_argument=1, dataset=range(1))
lite = EmptyLite()
with pytest.raises(MisconfigurationException, match="extra_argument"):
dataloader = lite.setup_dataloaders(dataloader)

with _replace_dataloader_init_method():
dataloader = CustomDataLoader(extra_argument=1, dataset=range(1))
assert dataloader.extra_argument == 1
dataloader = lite.setup_dataloaders(dataloader)

dataloader = CustomDataLoader(1, range(1))
assert dataloader.extra_argument == 1
dataloader = lite.setup_dataloaders(dataloader)