From 848e82b79ae6daa5f9aeb683d263c64b6ff497db Mon Sep 17 00:00:00 2001 From: otaj Date: Wed, 13 Jul 2022 11:01:03 +0200 Subject: [PATCH 01/23] generalize replace_init_method --- src/pytorch_lightning/lite/lite.py | 4 +- .../trainer/connectors/data_connector.py | 4 +- src/pytorch_lightning/utilities/data.py | 64 ++++++++++--------- tests/tests_pytorch/lite/test_lite.py | 2 +- tests/tests_pytorch/utilities/test_data.py | 20 +++--- 5 files changed, 50 insertions(+), 44 deletions(-) diff --git a/src/pytorch_lightning/lite/lite.py b/src/pytorch_lightning/lite/lite.py index 4dfcde177f953..bdd2d5308a17b 100644 --- a/src/pytorch_lightning/lite/lite.py +++ b/src/pytorch_lightning/lite/lite.py @@ -35,7 +35,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.data import ( _auto_add_worker_init_fn, - _replace_dataloader_init_method, + _replace_init_method, _update_dataloader, has_iterable_dataset, ) @@ -409,7 +409,7 @@ def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: def _run_with_strategy_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: self._strategy.setup_environment() - with self._strategy.model_sharded_context(), _replace_dataloader_init_method(): + with self._strategy.model_sharded_context(), _replace_init_method(DataLoader, ["dataset"]): return run_method(*args, **kwargs) def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module: diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index add62ceece65c..cf739c4ee69e1 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -31,7 +31,7 @@ from pytorch_lightning.utilities.data import ( _auto_add_worker_init_fn, _is_dataloader_shuffled, - _replace_dataloader_init_method, + _replace_init_method, _update_dataloader, has_iterable_dataset, has_len_all_ranks, @@ -424,7 +424,7 @@ def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[Dat """ source = getattr(self, f"_{stage.dataloader_prefix}_dataloader_source") - with _replace_dataloader_init_method(): + with _replace_init_method(DataLoader, ["dataset"]): # under this context manager, the arguments passed to `DataLoader.__init__` will be captured and saved as # attributes on the instance in case the dataloader needs to be re-instantiated later by Lightning dataloader = source.dataloader() diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index 2de82ceff088e..12dd47d4a3a98 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -17,7 +17,7 @@ from contextlib import contextmanager from dataclasses import fields from functools import partial -from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Set, Tuple, Type, Union import torch from torch import Tensor @@ -217,11 +217,11 @@ def _get_dataloader_init_args_and_kwargs( if not isinstance(dataloader, DataLoader): raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`") - was_wrapped = hasattr(dataloader, "__pl_dl_args") + was_wrapped = hasattr(dataloader, "__pl_saved_args") if was_wrapped: - dl_args = dataloader.__pl_dl_args - dl_kwargs = 
dataloader.__pl_dl_kwargs - arg_names = dataloader.__pl_dl_arg_names + dl_args = dataloader.__pl_saved_args + dl_kwargs = dataloader.__pl_saved_kwargs + arg_names = dataloader.__pl_saved_arg_names original_dataset = dataloader.__dataset # we have this saved from _wrap_init else: # get the dataloader instance attributes @@ -355,12 +355,12 @@ def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None: dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank) -def _wrap_dataloader_init(init: Callable) -> Callable: - """Wraps the ``__init__`` method of :class:`~torch.utils.data.DataLoader` in order to enable re-instantiation - of custom subclasses.""" +def _wrap_init_method(init: Callable, store_explicit_args: Optional[List[str]] = None) -> Callable: + """Wraps the ``__init__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and + :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses.""" @functools.wraps(init) - def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None: + def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None: # We need to inspect `init`, as inspecting `obj.__init__` # can lead to inspecting the wrong function with multiple inheritance params = inspect.signature(init).parameters @@ -371,18 +371,20 @@ def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None: ) param_names = param_names[: len(args)] - if not hasattr(obj, "__pl_dl_args"): - obj.__pl_dl_args = args - obj.__pl_dl_kwargs = kwargs - obj.__pl_dl_arg_names = param_names + if not hasattr(obj, "__pl_saved_args"): + obj.__pl_saved_args = args + obj.__pl_saved_kwargs = kwargs + obj.__pl_saved_arg_names = param_names - # We want to use the latest possible value for dataset argument (i.e. ideally what gets passed to DataLoader) + # We want to use the latest possible value for explicit arguments (i.e. ideally what gets passed to base class) # so that we can be sure, that it will not get changed anymore. # That is why we are setting this in every `__init__` - if "dataset" in param_names: - setattr(obj, "__dataset", args[param_names.index("dataset")]) - elif "dataset" in kwargs: - setattr(obj, "__dataset", kwargs["dataset"]) + if store_explicit_args is not None: + for explicit_arg in store_explicit_args: + if explicit_arg in param_names: + setattr(obj, f"__{explicit_arg}", args[param_names.index(explicit_arg)]) + elif explicit_arg in kwargs: + setattr(obj, f"__{explicit_arg}", kwargs[explicit_arg]) init(obj, *args, **kwargs) @@ -404,15 +406,19 @@ def recurse(cl: Type[Any]) -> None: @contextmanager -def _replace_dataloader_init_method() -> Generator[None, None, None]: - """This context manager is used to add support for re-instantiation of custom (subclasses) of - :class:`~torch.utils.data.DataLoader`. It patches the ``__init__`` method.""" - classes = _get_all_subclasses(DataLoader) | {DataLoader} +def _replace_init_method( + base_cls: Type, store_explicit_args: Optional[List[str]] = None +) -> Generator[None, None, None]: + """This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`. + + It patches the ``__init__`` method. 
+ """ + classes = _get_all_subclasses(base_cls) | {base_cls} wrapped = set() for cls in classes: if cls.__init__ not in wrapped: cls._old_init = cls.__init__ - cls.__init__ = _wrap_dataloader_init(cls.__init__) + cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_args) wrapped.add(cls.__init__) yield for cls in classes: @@ -457,13 +463,13 @@ def _apply_fault_tolerant_automatic_capture_dataset_wrapper( def _is_dataloader_shuffled(dataloader: object) -> bool: - if hasattr(dataloader, "__pl_dl_kwargs"): + if hasattr(dataloader, "__pl_saved_kwargs"): # this attribute is not part of PyTorch's DataLoader, but could have been set by - # our `_replace_dataloader_init_method` context manager - if "shuffle" in dataloader.__pl_dl_kwargs: - return dataloader.__pl_dl_kwargs["shuffle"] - if "shuffle" in dataloader.__pl_dl_arg_names: - return dataloader.__pl_dl_args[dataloader.__pl_dl_arg_names.index("shuffle")] + # our `_replace_init_method` context manager + if "shuffle" in dataloader.__pl_saved_kwargs: + return dataloader.__pl_saved_kwargs["shuffle"] + if "shuffle" in dataloader.__pl_saved_arg_names: + return dataloader.__pl_saved_args[dataloader.__pl_saved_arg_names.index("shuffle")] if isinstance(dataloader.dataset, IterableDataset): # shuffling is useless with iterable datasets return False diff --git a/tests/tests_pytorch/lite/test_lite.py b/tests/tests_pytorch/lite/test_lite.py index 7166be0981846..5e71bf43271ca 100644 --- a/tests/tests_pytorch/lite/test_lite.py +++ b/tests/tests_pytorch/lite/test_lite.py @@ -177,7 +177,7 @@ def test_setup_dataloaders_return_type(): assert lite_dataloader1.dataset is dataset1 -@mock.patch("pytorch_lightning.lite.lite._replace_dataloader_init_method") +@mock.patch("pytorch_lightning.lite.lite._replace_init_method") def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager): """Test that Lite intercepts the DataLoader constructor arguments with a context manager in its run method.""" diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py index 7b1e596d50f8c..1a8be5a0160c9 100644 --- a/tests/tests_pytorch/utilities/test_data.py +++ b/tests/tests_pytorch/utilities/test_data.py @@ -10,7 +10,7 @@ from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.data import ( _get_dataloader_init_args_and_kwargs, - _replace_dataloader_init_method, + _replace_init_method, _update_dataloader, extract_batch_size, get_len, @@ -144,7 +144,7 @@ def __init__(self, foo, *args, **kwargs): with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`dataset`"): _update_dataloader(dataloader, dataloader.sampler) - with _replace_dataloader_init_method(): + with _replace_init_method(DataLoader, ["dataset"]): dataloader = BadStandaloneGoodHookImpl([1, 2, 3]) new_dataloader = _update_dataloader(dataloader, dataloader.sampler) assert isinstance(new_dataloader, BadStandaloneGoodHookImpl) @@ -295,13 +295,13 @@ def __init__(self, dataset, **kwargs): pytest.param(ChangingDataLoader, (range(5),), dict(), ("dataset",), list(range(10)), dict(), id="test9"), ], ) -def test_replace_dataloader_init_method(cls, args, kwargs, arg_names, dataset, checked_values): - with _replace_dataloader_init_method(): +def test_replace_init_method_dataloader(cls, args, kwargs, arg_names, dataset, checked_values): + with _replace_init_method(DataLoader, ["dataset"]): dataloader = cls(*args, **kwargs) - assert dataloader.__pl_dl_args == args - assert dataloader.__pl_dl_kwargs == kwargs - 
assert dataloader.__pl_dl_arg_names == arg_names + assert dataloader.__pl_saved_args == args + assert dataloader.__pl_saved_kwargs == kwargs + assert dataloader.__pl_saved_arg_names == arg_names assert dataloader.__dataset == dataset assert dataloader.dataset == dataset @@ -316,9 +316,9 @@ def test_replace_dataloader_init_method(cls, args, kwargs, arg_names, dataset, c dataloader = _update_dataloader(dataloader, dataloader.sampler) assert isinstance(dataloader, cls) - assert not hasattr(dataloader, "__pl_dl_kwargs") - assert not hasattr(dataloader, "__pl_dl_arg_names") - assert not hasattr(dataloader, "__pl_dl_args") + assert not hasattr(dataloader, "__pl_saved_kwargs") + assert not hasattr(dataloader, "__pl_saved_arg_names") + assert not hasattr(dataloader, "__pl_saved_args") assert not hasattr(dataloader, "__dataset") assert dataloader.dataset == dataset From 6e01b0778dc3f2a16f234239ce025cb503f64f4b Mon Sep 17 00:00:00 2001 From: otaj Date: Wed, 13 Jul 2022 18:38:13 +0200 Subject: [PATCH 02/23] custom batch sampler support --- src/pytorch_lightning/lite/lite.py | 6 +- .../trainer/connectors/data_connector.py | 4 +- .../utilities/auto_restart.py | 14 +-- src/pytorch_lightning/utilities/data.py | 74 +++++++++++- tests/tests_pytorch/lite/test_lite.py | 5 +- .../utilities/test_auto_restart.py | 10 -- tests/tests_pytorch/utilities/test_data.py | 108 +++++++++++++++++- 7 files changed, 183 insertions(+), 38 deletions(-) diff --git a/src/pytorch_lightning/lite/lite.py b/src/pytorch_lightning/lite/lite.py index bdd2d5308a17b..d04d6932fb3f5 100644 --- a/src/pytorch_lightning/lite/lite.py +++ b/src/pytorch_lightning/lite/lite.py @@ -22,7 +22,7 @@ import torch.nn as nn from torch import Tensor from torch.optim import Optimizer -from torch.utils.data import DataLoader, DistributedSampler, RandomSampler, SequentialSampler +from torch.utils.data import BatchSampler, DataLoader, DistributedSampler, RandomSampler, SequentialSampler from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer @@ -409,7 +409,9 @@ def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: def _run_with_strategy_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: self._strategy.setup_environment() - with self._strategy.model_sharded_context(), _replace_init_method(DataLoader, ["dataset"]): + with self._strategy.model_sharded_context(), _replace_init_method( + DataLoader, ["dataset"] + ), _replace_init_method(BatchSampler): return run_method(*args, **kwargs) def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module: diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index cf739c4ee69e1..3336e63084cfa 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -17,7 +17,7 @@ from typing import Any, Callable, Collection, List, Optional, Tuple, Union from weakref import proxy -from torch.utils.data import DataLoader, Sampler, SequentialSampler +from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler import pytorch_lightning as pl @@ -424,7 +424,7 @@ def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[Dat """ source = getattr(self, f"_{stage.dataloader_prefix}_dataloader_source") - with 
_replace_init_method(DataLoader, ["dataset"]): + with _replace_init_method(DataLoader, ["dataset"]), _replace_init_method(BatchSampler): # under this context manager, the arguments passed to `DataLoader.__init__` will be captured and saved as # attributes on the instance in case the dataloader needs to be re-instantiated later by Lightning dataloader = source.dataloader() diff --git a/src/pytorch_lightning/utilities/auto_restart.py b/src/pytorch_lightning/utilities/auto_restart.py index 0bd2942b17cca..3877a1ab3944c 100644 --- a/src/pytorch_lightning/utilities/auto_restart.py +++ b/src/pytorch_lightning/utilities/auto_restart.py @@ -16,15 +16,7 @@ from functools import partial, wraps from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union -from torch.utils.data import ( - BatchSampler, - Dataset, - DistributedSampler, - get_worker_info, - RandomSampler, - Sampler, - SequentialSampler, -) +from torch.utils.data import Dataset, DistributedSampler, get_worker_info, RandomSampler, Sampler, SequentialSampler from torch.utils.data.dataloader import ( _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, @@ -757,10 +749,6 @@ def _validate_map_dataset(dataloader: DataLoader) -> None: if sampler is not None and type(sampler) not in SUPPORTED_SAMPLERS: raise TypeError(f"Fault-tolerance supports only {SUPPORTED_SAMPLERS}.") - batch_sampler = getattr(dataloader, "batch_sampler", None) - if batch_sampler is not None and type(batch_sampler) is not BatchSampler: - raise TypeError("Fault-tolerance supports only a `BatchSampler`.") - if type(sampler) is DistributedSampler and sampler.shuffle: raise TypeError("A `DistributedSampler` sampler shuffle attribute is set to True.") elif type(sampler) is RandomSampler: diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index 12dd47d4a3a98..f50929d94fc71 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -322,12 +322,44 @@ def _dataloader_init_kwargs_resolve_sampler( batch_sampler = getattr(dataloader, "batch_sampler") is_predicting = mode == RunningStage.PREDICTING # checking the batch sampler type is different than PyTorch default. - if batch_sampler is not None and (type(batch_sampler) is not BatchSampler or is_predicting): - batch_sampler = type(batch_sampler)( - sampler, - batch_size=batch_sampler.batch_size, - drop_last=(False if is_predicting else batch_sampler.drop_last), - ) + batch_cls = type(batch_sampler) + if batch_sampler is not None and (batch_cls is not BatchSampler or is_predicting): + if hasattr(batch_sampler, "__pl_saved_args"): + args = list(batch_sampler.__pl_saved_args) + kwargs = batch_sampler.__pl_saved_kwargs + arg_names = batch_sampler.__pl_saved_arg_names + + if is_predicting: + success, args, kwargs = _inject_arg_to_saved(args, kwargs, arg_names, "drop_last", False) + if not success: + rank_zero_warn( + "Trying to inject `drop_last=False` into batch sampler since you are predicting, however it " + f"seems the class `{batch_cls}` does not support it. Your predictions might be incomplete. " + "To mitigate this, expose `drop_last` in the `__init__` method of your custom class." + ) + + success, args, kwargs = _inject_arg_to_saved(args, kwargs, arg_names, "sampler", sampler) + if not success: + raise MisconfigurationException( + "Trying to inject modified sampler into batch sampler, however it seems the class " + f"`{batch_cls}` does not support argument called sampler. 
To mitigate this, " + "expose argument `sampler` in the `__init__` method of your custom class." + ) + + batch_sampler = batch_cls(*args, **kwargs) + else: + try: + batch_sampler = batch_cls( + sampler, + batch_size=batch_sampler.batch_size, + drop_last=(False if is_predicting else batch_sampler.drop_last), + ) + except TypeError: + raise MisconfigurationException( + "We tried to reinstantiate your custom batch sampler and failed. " + "To mitigate this, either follow API of `BatchSampler` or instantiate " + "your custom batch sampler inside `*_dataloader` hooks of your module." + ) if is_predicting: batch_sampler = IndexBatchSamplerWrapper(batch_sampler) @@ -350,6 +382,25 @@ def _dataloader_init_kwargs_resolve_sampler( return {"sampler": sampler, "shuffle": False, "batch_sampler": None} +def _inject_arg_to_saved( + args: List[Any], kwargs: Dict[str, Any], arg_names: List[str], inject_name: str, inject_object: Any +) -> Tuple[bool, List[Any], Dict[str, Any]]: + """Tries to inject a custom argument to a saved list of args and kwargs. + + Returns a tuple indicating success of the operation and modified saved args and kwargs + """ + + if inject_name in arg_names: + inject_index = arg_names.index(inject_name) + args = args[:inject_index] + [inject_object] + args[inject_index + 1 :] + return True, args, kwargs + elif inject_name in kwargs: + kwargs[inject_name] = inject_object + return True, args, kwargs + + return False, args, kwargs + + def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None: if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None: dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank) @@ -371,6 +422,17 @@ def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None: ) param_names = param_names[: len(args)] + default_kwargs = { + param.name: param.default + for param in params.values() + if param.name != "self" + and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD) + and param.default != param.empty + and (param.name not in kwargs and param.name not in param_names) + } + + kwargs = {**kwargs, **default_kwargs} + if not hasattr(obj, "__pl_saved_args"): obj.__pl_saved_args = args obj.__pl_saved_kwargs = kwargs diff --git a/tests/tests_pytorch/lite/test_lite.py b/tests/tests_pytorch/lite/test_lite.py index 5e71bf43271ca..24be3f23efa12 100644 --- a/tests/tests_pytorch/lite/test_lite.py +++ b/tests/tests_pytorch/lite/test_lite.py @@ -183,10 +183,11 @@ def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager): class Lite(LightningLite): def run(self): - ctx_manager().__enter__.assert_called_once() + # One for BatchSampler, another for DataLoader + assert ctx_manager().__enter__.call_count == 2 Lite().run() - ctx_manager().__exit__.assert_called_once() + assert ctx_manager().__exit__.call_count == 2 def test_setup_dataloaders_raises_for_unknown_custom_args(): diff --git a/tests/tests_pytorch/utilities/test_auto_restart.py b/tests/tests_pytorch/utilities/test_auto_restart.py index 47051d4efd098..7b35bbe596ae2 100644 --- a/tests/tests_pytorch/utilities/test_auto_restart.py +++ b/tests/tests_pytorch/utilities/test_auto_restart.py @@ -34,7 +34,6 @@ from torch.utils.data._utils.worker import _generate_state, get_worker_info from torch.utils.data.dataloader import DataLoader, default_collate from torch.utils.data.dataset import Dataset, IterableDataset -from torch.utils.data.sampler import Sampler import tests_pytorch.helpers.utils as tutils from pytorch_lightning import Callback, LightningModule, 
seed_everything, Trainer @@ -1186,15 +1185,6 @@ class CustomRandomSampler(RandomSampler): with pytest.raises(TypeError, match="RandomSampler"): _validate_fault_tolerant_automatic(dl, RunningStage.TRAINING) - class CustomBatchSampler(BatchSampler): - pass - - sampler = Sampler(data()) - batch_sampler = CustomBatchSampler(sampler, 2, False) - dl = DataLoader(data(), batch_sampler=batch_sampler) - with pytest.raises(TypeError, match="BatchSampler"): - _validate_fault_tolerant_automatic(dl, RunningStage.TRAINING) - class CustomIterable(IterableDataset): pass diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py index 1a8be5a0160c9..7f06f109cc4fa 100644 --- a/tests/tests_pytorch/utilities/test_data.py +++ b/tests/tests_pytorch/utilities/test_data.py @@ -3,13 +3,15 @@ import pytest import torch from torch import Tensor -from torch.utils.data.dataloader import DataLoader +from torch.utils.data import BatchSampler, DataLoader, RandomSampler from pytorch_lightning import Trainer from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset +from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.data import ( _get_dataloader_init_args_and_kwargs, + _inject_arg_to_saved, _replace_init_method, _update_dataloader, extract_batch_size, @@ -311,7 +313,7 @@ def test_replace_init_method_dataloader(cls, args, kwargs, arg_names, dataset, c if isinstance(dataloader_value, torch.Tensor): assert dataloader_value is value else: - assert getattr(dataloader, key) == value + assert dataloader_value == value dataloader = _update_dataloader(dataloader, dataloader.sampler) @@ -328,7 +330,107 @@ def test_replace_init_method_dataloader(cls, args, kwargs, arg_names, dataset, c if isinstance(dataloader_value, torch.Tensor): assert dataloader_value is value else: - assert getattr(dataloader, key) == value + assert dataloader_value == value + + +def test_replace_init_method_extra_kwargs(): + class LoaderSubclass(DataLoader): + def __init__(self, dataset, *args, batch_size=10, **kwargs): + super().__init__(dataset, *args, batch_size=batch_size, **kwargs) + + with _replace_init_method(DataLoader, ["dataset"]): + dataloader = LoaderSubclass(range(10)) + + assert dataloader.__pl_saved_args == (range(10),) + assert dataloader.__pl_saved_kwargs == {"batch_size": 10} + assert dataloader.__pl_saved_arg_names == ("dataset",) + assert dataloader.__dataset == range(10) + + +@pytest.mark.parametrize("predicting", [True, False]) +def test_custom_batch_sampler(predicting): + class MyBatchSampler(BatchSampler): + def __init__(self, sampler, extra_arg, drop_last=True): + self.extra_arg = extra_arg + super().__init__(sampler, 10, drop_last) + + with _replace_init_method(BatchSampler): + sampler = RandomSampler(range(10)) + dataloader = DataLoader(range(10), batch_sampler=MyBatchSampler(sampler, "random_str")) + + assert dataloader.batch_sampler.__pl_saved_args == (sampler, "random_str") + assert dataloader.batch_sampler.__pl_saved_kwargs == {"drop_last": True} + assert dataloader.batch_sampler.__pl_saved_arg_names == ("sampler", "extra_arg") + + dataloader = _update_dataloader( + dataloader, dataloader.sampler, mode=RunningStage.PREDICTING if predicting else None + ) + + batch_sampler = dataloader.batch_sampler + + if predicting: + assert isinstance(batch_sampler, IndexBatchSamplerWrapper) + batch_sampler = batch_sampler._sampler + + assert isinstance(batch_sampler, 
MyBatchSampler) + assert batch_sampler.drop_last == (not predicting) + + assert batch_sampler.extra_arg == "random_str" + assert not hasattr(batch_sampler, "__pl_saved_kwargs") + assert not hasattr(batch_sampler, "__pl_saved_arg_names") + assert not hasattr(batch_sampler, "__pl_saved_args") + + +def test_custom_batch_sampler_no_drop_last(): + class MyBatchSampler(BatchSampler): + def __init__(self, sampler, extra_arg): + self.extra_arg = extra_arg + super().__init__(sampler, 10, False) + + with _replace_init_method(BatchSampler): + sampler = RandomSampler(range(10)) + dataloader = DataLoader(range(10), batch_sampler=MyBatchSampler(sampler, "random_str")) + + assert dataloader.batch_sampler.__pl_saved_args == (sampler, "random_str") + assert dataloader.batch_sampler.__pl_saved_kwargs == {} + assert dataloader.batch_sampler.__pl_saved_arg_names == ("sampler", "extra_arg") + + with pytest.warns(UserWarning, match="drop_last=False"): + dataloader = _update_dataloader(dataloader, dataloader.sampler, mode=RunningStage.PREDICTING) + + +def test_custom_batch_sampler_no_sampler(): + class MyBatchSampler(BatchSampler): + def __init__(self, extra_arg): + self.extra_arg = extra_arg + super().__init__(RandomSampler(range(10)), 10, False) + + with _replace_init_method(BatchSampler): + dataloader = DataLoader(range(10), batch_sampler=MyBatchSampler("random_str")) + + assert dataloader.batch_sampler.__pl_saved_args == ("random_str",) + assert dataloader.batch_sampler.__pl_saved_kwargs == {} + assert dataloader.batch_sampler.__pl_saved_arg_names == ("extra_arg",) + + with pytest.raises(MisconfigurationException, match="sampler into batch sampler"): + dataloader = _update_dataloader(dataloader, dataloader.sampler, mode=RunningStage.PREDICTING) + + +@pytest.mark.parametrize( + ["args", "kwargs", "arg_names", "inject_name", "inject_obj", "expected_status", "expected_args", "expected_kwargs"], + [ + pytest.param([], {}, [], "a", 1, False, [], {}, id="empty"), + pytest.param([1], {}, ["a"], "a", 2, True, [2], {}, id="simple1"), + pytest.param([1, 2, 3], {}, ["a", "b", "c"], "b", False, True, [1, False, 3], {}, id="simple2"), + pytest.param([1, 2, 3], {"a": 1}, ["b", "c", "d"], "a", 2, True, [1, 2, 3], {"a": 2}, id="simple_kwargs"), + ], +) +def test_inject_args(args, kwargs, arg_names, inject_name, inject_obj, expected_status, expected_args, expected_kwargs): + assert _inject_arg_to_saved(args, kwargs, arg_names, inject_name, inject_obj) == ( + expected_status, + expected_args, + expected_kwargs, + ) @pytest.mark.parametrize("mode", [RunningStage.TRAINING, RunningStage.PREDICTING, RunningStage.TESTING]) From d5cee5a6861ed14d090664d9ee95dbeb5ca71514 Mon Sep 17 00:00:00 2001 From: otaj Date: Wed, 13 Jul 2022 18:47:41 +0200 Subject: [PATCH 03/23] changelog --- src/pytorch_lightning/CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index 9b3cefc66bd19..cee7a5b7df028 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -286,6 +286,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). 
- Improved support for custom `DataLoader`s when instantiated in `*_dataloader` hook ([#12981](https://github.com/PyTorchLightning/pytorch-lightning/pull/12981)) +- Allowed custom `BatchSampler`s when instantiated in `*_dataloader` hook [#13640](https://github.com/PyTorchLightning/pytorch-lightning/pull/13640)) + - Fixed an issue with unsupported torch.inference_mode() on hpu backends by making it use no_grad ([#13014](https://github.com/PyTorchLightning/pytorch-lightning/pull/13014)) From 834e98c71ecdc0339516155ef539d5ec4c3c5cdc Mon Sep 17 00:00:00 2001 From: otaj Date: Thu, 14 Jul 2022 09:06:20 +0200 Subject: [PATCH 04/23] apply suggestions from code review --- src/pytorch_lightning/lite/lite.py | 6 +- .../trainer/connectors/data_connector.py | 2 +- src/pytorch_lightning/utilities/data.py | 71 ++++++++++--------- tests/tests_pytorch/utilities/test_data.py | 25 +++++-- 4 files changed, 61 insertions(+), 43 deletions(-) diff --git a/src/pytorch_lightning/lite/lite.py b/src/pytorch_lightning/lite/lite.py index d04d6932fb3f5..a12826ecf2ced 100644 --- a/src/pytorch_lightning/lite/lite.py +++ b/src/pytorch_lightning/lite/lite.py @@ -409,9 +409,9 @@ def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: def _run_with_strategy_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: self._strategy.setup_environment() - with self._strategy.model_sharded_context(), _replace_init_method( - DataLoader, ["dataset"] - ), _replace_init_method(BatchSampler): + with self._strategy.model_sharded_context(), _replace_init_method(DataLoader, "dataset"), _replace_init_method( + BatchSampler + ): return run_method(*args, **kwargs) def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module: diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index 3336e63084cfa..7831316a98ae1 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -424,7 +424,7 @@ def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[Dat """ source = getattr(self, f"_{stage.dataloader_prefix}_dataloader_source") - with _replace_init_method(DataLoader, ["dataset"]), _replace_init_method(BatchSampler): + with _replace_init_method(DataLoader, "dataset"), _replace_init_method(BatchSampler): # under this context manager, the arguments passed to `DataLoader.__init__` will be captured and saved as # attributes on the instance in case the dataloader needs to be re-instantiated later by Lightning dataloader = source.dataloader() diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index f50929d94fc71..1bb4b7f59ec92 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -322,44 +322,54 @@ def _dataloader_init_kwargs_resolve_sampler( batch_sampler = getattr(dataloader, "batch_sampler") is_predicting = mode == RunningStage.PREDICTING # checking the batch sampler type is different than PyTorch default. 
- batch_cls = type(batch_sampler) - if batch_sampler is not None and (batch_cls is not BatchSampler or is_predicting): + batch_sampler_cls = type(batch_sampler) + if batch_sampler is not None and (batch_sampler_cls is not BatchSampler or is_predicting): if hasattr(batch_sampler, "__pl_saved_args"): args = list(batch_sampler.__pl_saved_args) kwargs = batch_sampler.__pl_saved_kwargs arg_names = batch_sampler.__pl_saved_arg_names if is_predicting: - success, args, kwargs = _inject_arg_to_saved(args, kwargs, arg_names, "drop_last", False) + success, args, kwargs = _replace_value_in_saved_args("drop_last", False, args, kwargs, arg_names) if not success: rank_zero_warn( - "Trying to inject `drop_last=False` into batch sampler since you are predicting, however it " - f"seems the class `{batch_cls}` does not support it. Your predictions might be incomplete. " - "To mitigate this, expose `drop_last` in the `__init__` method of your custom class." + f"Trying to inject `drop_last=False` into batch sampler since you are predicting, however it " + f"seems the class `{batch_sampler_cls.__qualname__}` does not support it. " + "Your predictions might be incomplete. To mitigate this, expose `drop_last` in the `__init__` " + "method of your custom class." ) - success, args, kwargs = _inject_arg_to_saved(args, kwargs, arg_names, "sampler", sampler) + success, args, kwargs = _replace_value_in_saved_args("sampler", sampler, args, kwargs, arg_names) if not success: raise MisconfigurationException( "Trying to inject modified sampler into batch sampler, however it seems the class " - f"`{batch_cls}` does not support argument called sampler. To mitigate this, " + f"`{batch_sampler_cls.__qualname__}` does not support argument called sampler. To mitigate this, " "expose argument `sampler` in the `__init__` method of your custom class." ) - batch_sampler = batch_cls(*args, **kwargs) + batch_sampler = batch_sampler_cls(*args, **kwargs) else: try: - batch_sampler = batch_cls( + batch_sampler = batch_sampler_cls( sampler, batch_size=batch_sampler.batch_size, drop_last=(False if is_predicting else batch_sampler.drop_last), ) - except TypeError: + except TypeError as e: + import re + + match = re.match(r".*__init__\(\) (got multiple values)|(missing \d required)", str(e)) + if not match: + # an unexpected `TypeError`, continue failure + raise + + # There could either be too few or too many arguments. Customizing the message based on this doesn't + # make much sense since our MisconfigurationException is going to be thrown from the original one. raise MisconfigurationException( "We tried to reinstantiate your custom batch sampler and failed. " - "To mitigate this, either follow API of `BatchSampler` or instantiate " + "To mitigate this, either follow the API of `BatchSampler` or instantiate " "your custom batch sampler inside `*_dataloader` hooks of your module." - ) + ) from e if is_predicting: batch_sampler = IndexBatchSamplerWrapper(batch_sampler) @@ -382,20 +392,20 @@ def _dataloader_init_kwargs_resolve_sampler( return {"sampler": sampler, "shuffle": False, "batch_sampler": None} -def _inject_arg_to_saved( - args: List[Any], kwargs: Dict[str, Any], arg_names: List[str], inject_name: str, inject_object: Any +def _replace_value_in_saved_args( + replace_key: str, replace_value: Any, args: List[Any], kwargs: Dict[str, Any], arg_names: List[str] ) -> Tuple[bool, List[Any], Dict[str, Any]]: """Tries to inject a custom argument to a saved list of args and kwargs. 
Returns a tuple indicating success of the operation and modified saved args and kwargs """ - if inject_name in arg_names: - inject_index = arg_names.index(inject_name) - args = args[:inject_index] + [inject_object] + args[inject_index + 1 :] + if replace_key in arg_names: + replace_index = arg_names.index(replace_key) + args = args[:replace_index] + [replace_value] + args[replace_index + 1 :] return True, args, kwargs - elif inject_name in kwargs: - kwargs[inject_name] = inject_object + elif replace_key in kwargs: + kwargs[replace_key] = replace_value return True, args, kwargs return False, args, kwargs @@ -406,7 +416,7 @@ def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None: dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank) -def _wrap_init_method(init: Callable, store_explicit_args: Optional[List[str]] = None) -> Callable: +def _wrap_init_method(init: Callable, store_explicit_arg: Optional[str] = None) -> Callable: """Wraps the ``__init__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses.""" @@ -438,15 +448,14 @@ def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None: obj.__pl_saved_kwargs = kwargs obj.__pl_saved_arg_names = param_names - # We want to use the latest possible value for explicit arguments (i.e. ideally what gets passed to base class) + # We want to use the latest possible value for explicit argument (i.e. ideally what gets passed to base class) # so that we can be sure, that it will not get changed anymore. # That is why we are setting this in every `__init__` - if store_explicit_args is not None: - for explicit_arg in store_explicit_args: - if explicit_arg in param_names: - setattr(obj, f"__{explicit_arg}", args[param_names.index(explicit_arg)]) - elif explicit_arg in kwargs: - setattr(obj, f"__{explicit_arg}", kwargs[explicit_arg]) + if store_explicit_arg is not None: + if store_explicit_arg in param_names: + setattr(obj, f"__{store_explicit_arg}", args[param_names.index(store_explicit_arg)]) + elif store_explicit_arg in kwargs: + setattr(obj, f"__{store_explicit_arg}", kwargs[store_explicit_arg]) init(obj, *args, **kwargs) @@ -468,9 +477,7 @@ def recurse(cl: Type[Any]) -> None: @contextmanager -def _replace_init_method( - base_cls: Type, store_explicit_args: Optional[List[str]] = None -) -> Generator[None, None, None]: +def _replace_init_method(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]: """This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`. It patches the ``__init__`` method. 
@@ -480,7 +487,7 @@ def _replace_init_method( for cls in classes: if cls.__init__ not in wrapped: cls._old_init = cls.__init__ - cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_args) + cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg) wrapped.add(cls.__init__) yield for cls in classes: diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py index 7f06f109cc4fa..cc5be8ef91fce 100644 --- a/tests/tests_pytorch/utilities/test_data.py +++ b/tests/tests_pytorch/utilities/test_data.py @@ -11,8 +11,8 @@ from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.data import ( _get_dataloader_init_args_and_kwargs, - _inject_arg_to_saved, _replace_init_method, + _replace_value_in_saved_args, _update_dataloader, extract_batch_size, get_len, @@ -146,7 +146,7 @@ def __init__(self, foo, *args, **kwargs): with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`dataset`"): _update_dataloader(dataloader, dataloader.sampler) - with _replace_init_method(DataLoader, ["dataset"]): + with _replace_init_method(DataLoader, "dataset"): dataloader = BadStandaloneGoodHookImpl([1, 2, 3]) new_dataloader = _update_dataloader(dataloader, dataloader.sampler) assert isinstance(new_dataloader, BadStandaloneGoodHookImpl) @@ -298,7 +298,7 @@ def __init__(self, dataset, **kwargs): ], ) def test_replace_init_method_dataloader(cls, args, kwargs, arg_names, dataset, checked_values): - with _replace_init_method(DataLoader, ["dataset"]): + with _replace_init_method(DataLoader, "dataset"): dataloader = cls(*args, **kwargs) assert dataloader.__pl_saved_args == args @@ -338,7 +338,7 @@ class LoaderSubclass(DataLoader): def __init__(self, dataset, *args, batch_size=10, **kwargs): super().__init__(dataset, *args, batch_size=batch_size, **kwargs) - with _replace_init_method(DataLoader, ["dataset"]): + with _replace_init_method(DataLoader, "dataset"): dataloader = LoaderSubclass(range(10)) assert dataloader.__pl_saved_args == (range(10),) @@ -417,7 +417,16 @@ def __init__(self, extra_arg): @pytest.mark.parametrize( - ["args", "kwargs", "arg_names", "inject_name", "inject_obj", "expected_status", "expected_args", "expected_kwargs"], + [ + "args", + "kwargs", + "arg_names", + "replace_key", + "replace_value", + "expected_status", + "expected_args", + "expected_kwargs", + ], [ pytest.param([], {}, [], "a", 1, False, [], {}, id="empty"), pytest.param([1], {}, ["a"], "a", 2, True, [2], {}, id="simple1"), @@ -425,8 +434,10 @@ def __init__(self, extra_arg): pytest.param([1, 2, 3], {"a": 1}, ["b", "c", "d"], "a", 2, True, [1, 2, 3], {"a": 2}, id="simple_kwargs"), ], ) -def test_inject_args(args, kwargs, arg_names, inject_name, inject_obj, expected_status, expected_args, expected_kwargs): - assert _inject_arg_to_saved(args, kwargs, arg_names, inject_name, inject_obj) == ( +def test_replace_value_in_args( + args, kwargs, arg_names, replace_key, replace_value, expected_status, expected_args, expected_kwargs +): + assert _replace_value_in_saved_args(replace_key, replace_value, args, kwargs, arg_names) == ( expected_status, expected_args, expected_kwargs, From e7831c141e08804c0ad089992b8d5cef15d66255 Mon Sep 17 00:00:00 2001 From: otaj Date: Thu, 14 Jul 2022 09:20:28 +0200 Subject: [PATCH 05/23] change docstring --- src/pytorch_lightning/utilities/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index 
1bb4b7f59ec92..829d7ef9a7b9f 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -395,7 +395,7 @@ def _dataloader_init_kwargs_resolve_sampler( def _replace_value_in_saved_args( replace_key: str, replace_value: Any, args: List[Any], kwargs: Dict[str, Any], arg_names: List[str] ) -> Tuple[bool, List[Any], Dict[str, Any]]: - """Tries to inject a custom argument to a saved list of args and kwargs. + """Tries to replace an argument value in a saved list of args and kwargs. Returns a tuple indicating success of the operation and modified saved args and kwargs """ From 42683238b929f157364715fc53606531f0a8fc88 Mon Sep 17 00:00:00 2001 From: otaj Date: Thu, 14 Jul 2022 13:44:51 +0200 Subject: [PATCH 06/23] code review suggestions --- src/pytorch_lightning/utilities/data.py | 26 +++++++++++----------- tests/tests_pytorch/utilities/test_data.py | 23 +++++++++++-------- 2 files changed, 27 insertions(+), 22 deletions(-) diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index 829d7ef9a7b9f..fe84393f268c2 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -14,6 +14,7 @@ import functools import inspect import os +from collections import OrderedDict from contextlib import contextmanager from dataclasses import fields from functools import partial @@ -325,7 +326,7 @@ def _dataloader_init_kwargs_resolve_sampler( batch_sampler_cls = type(batch_sampler) if batch_sampler is not None and (batch_sampler_cls is not BatchSampler or is_predicting): if hasattr(batch_sampler, "__pl_saved_args"): - args = list(batch_sampler.__pl_saved_args) + args = batch_sampler.__pl_saved_args kwargs = batch_sampler.__pl_saved_kwargs arg_names = batch_sampler.__pl_saved_arg_names @@ -393,8 +394,8 @@ def _dataloader_init_kwargs_resolve_sampler( def _replace_value_in_saved_args( - replace_key: str, replace_value: Any, args: List[Any], kwargs: Dict[str, Any], arg_names: List[str] -) -> Tuple[bool, List[Any], Dict[str, Any]]: + replace_key: str, replace_value: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any], arg_names: List[str] +) -> Tuple[bool, Tuple[Any, ...], Dict[str, Any]]: """Tries to replace an argument value in a saved list of args and kwargs. 
Returns a tuple indicating success of the operation and modified saved args and kwargs @@ -402,7 +403,7 @@ def _replace_value_in_saved_args( if replace_key in arg_names: replace_index = arg_names.index(replace_key) - args = args[:replace_index] + [replace_value] + args[replace_index + 1 :] + args = args[:replace_index] + (replace_value,) + args[replace_index + 1 :] return True, args, kwargs elif replace_key in kwargs: kwargs[replace_key] = replace_value @@ -425,20 +426,19 @@ def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None: # We need to inspect `init`, as inspecting `obj.__init__` # can lead to inspecting the wrong function with multiple inheritance params = inspect.signature(init).parameters - param_names = tuple( - param.name + + parameters_defaults = OrderedDict( + (param.name, param.default) for param in params.values() if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD) ) - param_names = param_names[: len(args)] + + param_names = tuple(parameters_defaults.keys())[: len(args)] default_kwargs = { - param.name: param.default - for param in params.values() - if param.name != "self" - and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD) - and param.default != param.empty - and (param.name not in kwargs and param.name not in param_names) + name: value + for name, value in parameters_defaults.items() + if name not in kwargs and name not in param_names and value != inspect.Parameter.empty } kwargs = {**kwargs, **default_kwargs} diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py index cc5be8ef91fce..e96f1baa25355 100644 --- a/tests/tests_pytorch/utilities/test_data.py +++ b/tests/tests_pytorch/utilities/test_data.py @@ -354,9 +354,11 @@ def __init__(self, sampler, extra_arg, drop_last=True): self.extra_arg = extra_arg super().__init__(sampler, 10, drop_last) + sampler = RandomSampler(range(10)) with _replace_init_method(BatchSampler): - sampler = RandomSampler(range(10)) - dataloader = DataLoader(range(10), batch_sampler=MyBatchSampler(sampler, "random_str")) + batch_sampler = MyBatchSampler(sampler, "random_str") + + dataloader = DataLoader(range(10), batch_sampler=batch_sampler) assert dataloader.batch_sampler.__pl_saved_args == (sampler, "random_str") assert dataloader.batch_sampler.__pl_saved_kwargs == {"drop_last": True} @@ -387,9 +389,11 @@ def __init__(self, sampler, extra_arg): self.extra_arg = extra_arg super().__init__(sampler, 10, False) + sampler = RandomSampler(range(10)) with _replace_init_method(BatchSampler): - sampler = RandomSampler(range(10)) - dataloader = DataLoader(range(10), batch_sampler=MyBatchSampler(sampler, "random_str")) + batch_sampler = MyBatchSampler(sampler, "random_str") + + dataloader = DataLoader(range(10), batch_sampler=batch_sampler) assert dataloader.batch_sampler.__pl_saved_args == (sampler, "random_str") assert dataloader.batch_sampler.__pl_saved_kwargs == {} @@ -406,7 +410,8 @@ def __init__(self, extra_arg): super().__init__(RandomSampler(range(10)), 10, False) with _replace_init_method(BatchSampler): - dataloader = DataLoader(range(10), batch_sampler=MyBatchSampler("random_str")) + batch_sampler = MyBatchSampler("random_str") + dataloader = DataLoader(range(10), batch_sampler=batch_sampler) assert dataloader.batch_sampler.__pl_saved_args == ("random_str",) assert dataloader.batch_sampler.__pl_saved_kwargs == {} @@ -428,10 +433,10 @@ def __init__(self, extra_arg): "expected_kwargs", ], [ - pytest.param([], {}, [], "a", 1, False, [], {}, id="empty"), 
- pytest.param([1], {}, ["a"], "a", 2, True, [2], {}, id="simple1"), - pytest.param([1, 2, 3], {}, ["a", "b", "c"], "b", False, True, [1, False, 3], {}, id="simple2"), - pytest.param([1, 2, 3], {"a": 1}, ["b", "c", "d"], "a", 2, True, [1, 2, 3], {"a": 2}, id="simple_kwargs"), + pytest.param((), {}, [], "a", 1, False, (), {}, id="empty"), + pytest.param((1,), {}, ["a"], "a", 2, True, (2,), {}, id="simple1"), + pytest.param((1, 2, 3), {}, ["a", "b", "c"], "b", False, True, (1, False, 3), {}, id="simple2"), + pytest.param((1, 2, 3), {"a": 1}, ["b", "c", "d"], "a", 2, True, (1, 2, 3), {"a": 2}, id="simple_kwargs"), ], ) def test_replace_value_in_args( From 0d2475a0b888b350572651b28bd03ad64d752591 Mon Sep 17 00:00:00 2001 From: otaj Date: Thu, 14 Jul 2022 13:48:39 +0200 Subject: [PATCH 07/23] types --- src/pytorch_lightning/utilities/data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index fe84393f268c2..a4f87ea308e22 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -18,7 +18,7 @@ from contextlib import contextmanager from dataclasses import fields from functools import partial -from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Set, Tuple, Type, Union +from typing import Any, Callable, Dict, Generator, Iterable, Mapping, Optional, Set, Tuple, Type, Union import torch from torch import Tensor @@ -394,7 +394,7 @@ def _dataloader_init_kwargs_resolve_sampler( def _replace_value_in_saved_args( - replace_key: str, replace_value: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any], arg_names: List[str] + replace_key: str, replace_value: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any], arg_names: Tuple[str, ...] ) -> Tuple[bool, Tuple[Any, ...], Dict[str, Any]]: """Tries to replace an argument value in a saved list of args and kwargs. 
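
A minimal illustrative sketch (not part of the patch series) of what the wrapped ``__init__`` from the patches above records on a custom loader, assuming the post-refactor signature `_replace_init_method(DataLoader, "dataset")`; `MyDataLoader` and its `extra_arg` are hypothetical, and the assertions mirror `test_replace_init_method_dataloader`:

from torch.utils.data import DataLoader

from pytorch_lightning.utilities.data import _replace_init_method, _update_dataloader


class MyDataLoader(DataLoader):
    # hypothetical custom loader with one extra constructor argument
    def __init__(self, dataset, extra_arg, **kwargs):
        self.extra_arg = extra_arg
        super().__init__(dataset, **kwargs)


with _replace_init_method(DataLoader, "dataset"):
    # simulates the context the `*_dataloader` hooks run under
    loader = MyDataLoader(range(10), extra_arg="foo")

# the wrapped `__init__` captured the original constructor call ...
assert loader.__pl_saved_args == (range(10),)
assert loader.__pl_saved_kwargs == {"extra_arg": "foo"}
assert loader.__pl_saved_arg_names == ("dataset",)

# ... so Lightning can later rebuild the loader, e.g. with a different sampler
new_loader = _update_dataloader(loader, loader.sampler)
assert isinstance(new_loader, MyDataLoader)
assert new_loader.extra_arg == "foo"
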
From 9e106e12acf9ab30da22b9a611452164df0ecdec Mon Sep 17 00:00:00 2001 From: otaj Date: Thu, 14 Jul 2022 17:46:22 +0200 Subject: [PATCH 08/23] code review suggestion --- src/pytorch_lightning/utilities/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index a4f87ea308e22..9843421309a74 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -433,7 +433,7 @@ def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None: if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD) ) - param_names = tuple(parameters_defaults.keys())[: len(args)] + param_names = tuple(parameters_defaults)[: len(args)] default_kwargs = { name: value From 5e3786834b783f6bbc205d0d97c41a1c73104d6d Mon Sep 17 00:00:00 2001 From: otaj <6065855+otaj@users.noreply.github.com> Date: Fri, 22 Jul 2022 15:30:50 +0200 Subject: [PATCH 09/23] Apply suggestions from code review Co-authored-by: Rohit Gupta --- src/pytorch_lightning/utilities/data.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index 9843421309a74..1de0131aebdf3 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -343,7 +343,7 @@ def _dataloader_init_kwargs_resolve_sampler( success, args, kwargs = _replace_value_in_saved_args("sampler", sampler, args, kwargs, arg_names) if not success: raise MisconfigurationException( - "Trying to inject modified sampler into batch sampler, however it seems the class " + "Trying to inject modified sampler into the batch sampler; however, it seems the class " f"`{batch_sampler_cls.__qualname__}` does not support argument called sampler. To mitigate this, " "expose argument `sampler` in the `__init__` method of your custom class." ) @@ -367,7 +367,7 @@ def _dataloader_init_kwargs_resolve_sampler( # There could either be too few or too many arguments. Customizing the message based on this doesn't # make much sense since our MisconfigurationException is going to be thrown from the original one. raise MisconfigurationException( - "We tried to reinstantiate your custom batch sampler and failed. " + "We tried to re-instantiate your custom batch sampler and failed. " "To mitigate this, either follow the API of `BatchSampler` or instantiate " "your custom batch sampler inside `*_dataloader` hooks of your module." 
) from e From 309298a3557e0afe54a130f93755868699fdb31b Mon Sep 17 00:00:00 2001 From: otaj <6065855+otaj@users.noreply.github.com> Date: Mon, 25 Jul 2022 09:57:36 +0200 Subject: [PATCH 10/23] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- src/pytorch_lightning/utilities/data.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index 1de0131aebdf3..2530ba105c89b 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -342,10 +342,10 @@ def _dataloader_init_kwargs_resolve_sampler( success, args, kwargs = _replace_value_in_saved_args("sampler", sampler, args, kwargs, arg_names) if not success: - raise MisconfigurationException( - "Trying to inject modified sampler into the batch sampler; however, it seems the class " - f"`{batch_sampler_cls.__qualname__}` does not support argument called sampler. To mitigate this, " - "expose argument `sampler` in the `__init__` method of your custom class." + raise TypeError( + "Trying to inject a modified sampler into the batch sampler; however, it seems the class " + f"`{batch_sampler_cls.__qualname__}` does not have an argument called `sampler.` To mitigate this, " + "expose an argument `sampler` in the `__init__` method of your custom class." ) batch_sampler = batch_sampler_cls(*args, **kwargs) @@ -365,7 +365,7 @@ def _dataloader_init_kwargs_resolve_sampler( raise # There could either be too few or too many arguments. Customizing the message based on this doesn't - # make much sense since our MisconfigurationException is going to be thrown from the original one. + # make much sense since our MisconfigurationException is going to be raised from the original one. raise MisconfigurationException( "We tried to re-instantiate your custom batch sampler and failed. 
" "To mitigate this, either follow the API of `BatchSampler` or instantiate " From 04f19216e1154a91418abd62031c0ba3e5dacb85 Mon Sep 17 00:00:00 2001 From: otaj Date: Mon, 25 Jul 2022 10:40:28 +0200 Subject: [PATCH 11/23] make test not fail --- tests/tests_pytorch/utilities/test_data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py index e96f1baa25355..c23916cca1f34 100644 --- a/tests/tests_pytorch/utilities/test_data.py +++ b/tests/tests_pytorch/utilities/test_data.py @@ -417,7 +417,7 @@ def __init__(self, extra_arg): assert dataloader.batch_sampler.__pl_saved_kwargs == {} assert dataloader.batch_sampler.__pl_saved_arg_names == ("extra_arg",) - with pytest.raises(MisconfigurationException, match="sampler into batch sampler"): + with pytest.raises(TypeError, match="sampler into the batch sampler"): dataloader = _update_dataloader(dataloader, dataloader.sampler, mode=RunningStage.PREDICTING) From 88764c01e78dec00d267e3c70e82364276ca0ece Mon Sep 17 00:00:00 2001 From: otaj Date: Mon, 25 Jul 2022 12:13:54 +0200 Subject: [PATCH 12/23] comment the tests --- tests/tests_pytorch/utilities/test_data.py | 27 ++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py index c23916cca1f34..d0747b4271e8b 100644 --- a/tests/tests_pytorch/utilities/test_data.py +++ b/tests/tests_pytorch/utilities/test_data.py @@ -349,25 +349,38 @@ def __init__(self, dataset, *args, batch_size=10, **kwargs): @pytest.mark.parametrize("predicting", [True, False]) def test_custom_batch_sampler(predicting): + """This test asserts, that custom `BatchSampler`, with all the arguments, that are required in order to + properly reinstantiate the class, is invoked properly. + + It also asserts, that during the reinstantiation, the wrapper of `__init__` method is not present anymore, therefore + not setting `__pl_saved_{args,arg_names,kwargs}` attributes. + """ + class MyBatchSampler(BatchSampler): + # Custom Batch sampler with extra argument and default value def __init__(self, sampler, extra_arg, drop_last=True): self.extra_arg = extra_arg super().__init__(sampler, 10, drop_last) sampler = RandomSampler(range(10)) with _replace_init_method(BatchSampler): + # instantiate within `_replace_init_method` context manager, simulating `*_dataloader` hooks batch_sampler = MyBatchSampler(sampler, "random_str") dataloader = DataLoader(range(10), batch_sampler=batch_sampler) + # assert that passed information got saved assert dataloader.batch_sampler.__pl_saved_args == (sampler, "random_str") assert dataloader.batch_sampler.__pl_saved_kwargs == {"drop_last": True} assert dataloader.batch_sampler.__pl_saved_arg_names == ("sampler", "extra_arg") + # updating dataloader, what happens on access of the dataloaders. + # This should not fail, and would fail before support for custom args. 
dataloader = _update_dataloader( dataloader, dataloader.sampler, mode=RunningStage.PREDICTING if predicting else None ) + # Assert the `__init__` method is not replaced anymore and everything is instantiated to correct types batch_sampler = dataloader.batch_sampler if predicting: @@ -384,39 +397,53 @@ def __init__(self, sampler, extra_arg, drop_last=True): def test_custom_batch_sampler_no_drop_last(): + """Tests whether appropriate warning is raised when the custom `BatchSampler` does not support `drop_last` and + we want to reset it.""" + class MyBatchSampler(BatchSampler): + # Custom batch sampler with extra argument, but without `drop_last` def __init__(self, sampler, extra_arg): self.extra_arg = extra_arg super().__init__(sampler, 10, False) sampler = RandomSampler(range(10)) with _replace_init_method(BatchSampler): + # instantiate within `_replace_init_method` context manager, simulating `*_dataloader` hooks batch_sampler = MyBatchSampler(sampler, "random_str") dataloader = DataLoader(range(10), batch_sampler=batch_sampler) + # assert that passed information got saved assert dataloader.batch_sampler.__pl_saved_args == (sampler, "random_str") assert dataloader.batch_sampler.__pl_saved_kwargs == {} assert dataloader.batch_sampler.__pl_saved_arg_names == ("sampler", "extra_arg") + # Assert that warning is raised with pytest.warns(UserWarning, match="drop_last=False"): dataloader = _update_dataloader(dataloader, dataloader.sampler, mode=RunningStage.PREDICTING) def test_custom_batch_sampler_no_sampler(): + """Tests whether appropriate error is raised when the custom `BatchSampler` does not support sampler + argument.""" + class MyBatchSampler(BatchSampler): + # Custom batch sampler, without sampler argument. def __init__(self, extra_arg): self.extra_arg = extra_arg super().__init__(RandomSampler(range(10)), 10, False) with _replace_init_method(BatchSampler): + # instantiate within `_replace_init_method` context manager, simulating `*_dataloader` hooks batch_sampler = MyBatchSampler("random_str") dataloader = DataLoader(range(10), batch_sampler=batch_sampler) + # assert that passed information got saved assert dataloader.batch_sampler.__pl_saved_args == ("random_str",) assert dataloader.batch_sampler.__pl_saved_kwargs == {} assert dataloader.batch_sampler.__pl_saved_arg_names == ("extra_arg",) + # Assert that error is raised with pytest.raises(TypeError, match="sampler into the batch sampler"): dataloader = _update_dataloader(dataloader, dataloader.sampler, mode=RunningStage.PREDICTING) From e996117f2d1962838d82ff8ff5f4da9a354f7ffc Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 26 Jul 2022 13:07:49 +0200 Subject: [PATCH 13/23] pass mode in ipu --- src/pytorch_lightning/strategies/ipu.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py index 5413756c15271..3813527c3081b 100644 --- a/src/pytorch_lightning/strategies/ipu.py +++ b/src/pytorch_lightning/strategies/ipu.py @@ -228,7 +228,7 @@ def _convert_to_poptorch_loader( # the user is returning the `poptorch.DataLoader` directly, don't change anything. 
return dataloader - dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler) + dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode) opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts dataloader = poptorch.DataLoader(opts, *dl_args, **dl_kwargs) return dataloader From 5a6513b886f50600dd7a75f50ee6c880ac0f477b Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 26 Jul 2022 13:18:37 +0200 Subject: [PATCH 14/23] batch_size is None when batch_sampler is set --- src/pytorch_lightning/utilities/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index 2530ba105c89b..14e47eab78acd 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -382,7 +382,7 @@ def _dataloader_init_kwargs_resolve_sampler( "sampler": None, "shuffle": False, "batch_sampler": batch_sampler, - "batch_size": 1, + "batch_size": None, "drop_last": False, } From 4314fa3cb90c588770bf53c884db0c990eaae716 Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 26 Jul 2022 13:41:05 +0200 Subject: [PATCH 15/23] don't pass anything more than necessary --- src/pytorch_lightning/utilities/data.py | 4 ---- 1 file changed, 4 deletions(-) diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index 14e47eab78acd..c9a5f9cb227ac 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -379,11 +379,7 @@ def _dataloader_init_kwargs_resolve_sampler( fast_forward_sampler.setup(dataloader_batch_size=1) return { - "sampler": None, - "shuffle": False, "batch_sampler": batch_sampler, - "batch_size": None, - "drop_last": False, } if fault_tolerant_mode.is_automatic: From a0bb6ff86c3794b3864ba0a1fb6d991972046507 Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 26 Jul 2022 14:01:04 +0200 Subject: [PATCH 16/23] test ipu failing test --- src/pytorch_lightning/strategies/ipu.py | 6 ++++-- src/pytorch_lightning/utilities/data.py | 4 ++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py index 3813527c3081b..b1a9e4e674d23 100644 --- a/src/pytorch_lightning/strategies/ipu.py +++ b/src/pytorch_lightning/strategies/ipu.py @@ -132,8 +132,10 @@ def setup(self, trainer: "pl.Trainer") -> None: # patch the dataloader creation function with the custom `poptorch.DataLoader`. 
# this violates the intended control flow for the plugins, but since this is experimental, we have chosen # to use the simpler solution before adding abstractions to override the `DataLoader` class - self._update_dataloader_original = pl.trainer.connectors.data_connector._update_dataloader - pl.trainer.connectors.data_connector._update_dataloader = self._convert_to_poptorch_loader + + # Commented out to test if test failures are due to the poptorch loader + # self._update_dataloader_original = pl.trainer.connectors.data_connector._update_dataloader + # pl.trainer.connectors.data_connector._update_dataloader = self._convert_to_poptorch_loader super().setup(trainer) diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index c9a5f9cb227ac..2530ba105c89b 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -379,7 +379,11 @@ def _dataloader_init_kwargs_resolve_sampler( fast_forward_sampler.setup(dataloader_batch_size=1) return { + "sampler": None, + "shuffle": False, "batch_sampler": batch_sampler, + "batch_size": 1, + "drop_last": False, } if fault_tolerant_mode.is_automatic: From 7dc1cf4dda9b092ffe86f87142035534cf6cc4d9 Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 26 Jul 2022 14:58:48 +0200 Subject: [PATCH 17/23] return to poptorch dataloader --- src/pytorch_lightning/strategies/ipu.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py index b1a9e4e674d23..3813527c3081b 100644 --- a/src/pytorch_lightning/strategies/ipu.py +++ b/src/pytorch_lightning/strategies/ipu.py @@ -132,10 +132,8 @@ def setup(self, trainer: "pl.Trainer") -> None: # patch the dataloader creation function with the custom `poptorch.DataLoader`. # this violates the intended control flow for the plugins, but since this is experimental, we have chosen # to use the simpler solution before adding abstractions to override the `DataLoader` class - - # Commented out to test if test failures are due to the poptorch loader - # self._update_dataloader_original = pl.trainer.connectors.data_connector._update_dataloader - # pl.trainer.connectors.data_connector._update_dataloader = self._convert_to_poptorch_loader + self._update_dataloader_original = pl.trainer.connectors.data_connector._update_dataloader + pl.trainer.connectors.data_connector._update_dataloader = self._convert_to_poptorch_loader super().setup(trainer) From 305730d55311198e977db13644216d7105228d3d Mon Sep 17 00:00:00 2001 From: otaj Date: Wed, 27 Jul 2022 12:32:12 +0200 Subject: [PATCH 18/23] added a bit of docstring --- src/pytorch_lightning/utilities/data.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index f5ee1c46a7d1d..716c0481697bb 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -324,6 +324,9 @@ def _dataloader_init_kwargs_resolve_sampler( If the dataloader is being used for prediction, the sampler will be wrapped into an `IndexBatchSamplerWrapper`, so Lightning can keep track of its indices. If fault tolerant training is enabled, the sampler will be wrapped into a `FastForwardSampler`. 
+ + If there are multiple devices in IPU mode, it is necessary to disallow BatchSampler that isn't instantiated + automatically, since `poptorch.DataLoader` will try to increase the batch_size """ fault_tolerant_mode = _FaultTolerantMode.detect_current_mode() batch_sampler = getattr(dataloader, "batch_sampler") From 5a51c5aba9515bc64e7f959f4bbefd6b525a017d Mon Sep 17 00:00:00 2001 From: otaj Date: Wed, 27 Jul 2022 15:19:38 +0200 Subject: [PATCH 19/23] pprint debugging --- src/pytorch_lightning/strategies/ipu.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py index 001ad77fbb5cc..108c725d1abdf 100644 --- a/src/pytorch_lightning/strategies/ipu.py +++ b/src/pytorch_lightning/strategies/ipu.py @@ -234,7 +234,13 @@ def _convert_to_poptorch_loader( dataloader, sampler, mode, self.replication_factor > 1 ) opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts + from pprint import pprint + + pprint(opts) + pprint(dl_args) + pprint(dl_kwargs) dataloader = poptorch.DataLoader(opts, *dl_args, **dl_kwargs) + pprint(dataloader) return dataloader def _handle_gradient_accumulation_steps(self) -> None: From ab20f5fc28bf4721cf39b5a876aaa89e07b4606f Mon Sep 17 00:00:00 2001 From: otaj Date: Wed, 27 Jul 2022 15:26:54 +0200 Subject: [PATCH 20/23] pprint debugging --- src/pytorch_lightning/strategies/ipu.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py index 108c725d1abdf..aa82fb9a201c9 100644 --- a/src/pytorch_lightning/strategies/ipu.py +++ b/src/pytorch_lightning/strategies/ipu.py @@ -236,11 +236,11 @@ def _convert_to_poptorch_loader( opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts from pprint import pprint - pprint(opts) + pprint(repr(opts)) pprint(dl_args) pprint(dl_kwargs) dataloader = poptorch.DataLoader(opts, *dl_args, **dl_kwargs) - pprint(dataloader) + pprint(repr(dataloader)) return dataloader def _handle_gradient_accumulation_steps(self) -> None: From ff14329241438051223c0d3696b83097b99dfe2a Mon Sep 17 00:00:00 2001 From: otaj Date: Wed, 27 Jul 2022 16:43:55 +0200 Subject: [PATCH 21/23] Default kwargs handling --- src/pytorch_lightning/utilities/data.py | 21 +++++++++---- tests/tests_pytorch/utilities/test_data.py | 36 +++++++++++++++++----- 2 files changed, 43 insertions(+), 14 deletions(-) diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index 716c0481697bb..862c7f2de905b 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -349,10 +349,13 @@ def _dataloader_init_kwargs_resolve_sampler( if hasattr(batch_sampler, "__pl_saved_args"): args = batch_sampler.__pl_saved_args kwargs = batch_sampler.__pl_saved_kwargs + default_kwargs = batch_sampler.__pl_saved_default_kwargs arg_names = batch_sampler.__pl_saved_arg_names if is_predicting: - success, args, kwargs = _replace_value_in_saved_args("drop_last", False, args, kwargs, arg_names) + success, args, kwargs = _replace_value_in_saved_args( + "drop_last", False, args, kwargs, default_kwargs, arg_names + ) if not success: rank_zero_warn( f"Trying to inject `drop_last=False` into batch sampler since you are predicting, however " @@ -361,7 +364,9 @@ def _dataloader_init_kwargs_resolve_sampler( "the `__init__` method of your custom class." 
) - success, args, kwargs = _replace_value_in_saved_args("sampler", sampler, args, kwargs, arg_names) + success, args, kwargs = _replace_value_in_saved_args( + "sampler", sampler, args, kwargs, default_kwargs, arg_names + ) if not success: raise TypeError( "Trying to inject a modified sampler into the batch sampler; however, it seems the class " @@ -416,7 +421,12 @@ def _dataloader_init_kwargs_resolve_sampler( def _replace_value_in_saved_args( - replace_key: str, replace_value: Any, args: Tuple[Any, ...], kwargs: Dict[str, Any], arg_names: Tuple[str, ...] + replace_key: str, + replace_value: Any, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + default_kwargs: Dict[str, Any], + arg_names: Tuple[str, ...], ) -> Tuple[bool, Tuple[Any, ...], Dict[str, Any]]: """Tries to replace an argument value in a saved list of args and kwargs. @@ -427,7 +437,7 @@ def _replace_value_in_saved_args( replace_index = arg_names.index(replace_key) args = args[:replace_index] + (replace_value,) + args[replace_index + 1 :] return True, args, kwargs - elif replace_key in kwargs: + elif replace_key in kwargs or replace_key in default_kwargs: kwargs[replace_key] = replace_value return True, args, kwargs @@ -463,12 +473,11 @@ def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None: if name not in kwargs and name not in param_names and value != inspect.Parameter.empty } - kwargs = {**kwargs, **default_kwargs} - if not hasattr(obj, "__pl_saved_args"): obj.__pl_saved_args = args obj.__pl_saved_kwargs = kwargs obj.__pl_saved_arg_names = param_names + obj.__pl_saved_default_kwargs = default_kwargs # We want to use the latest possible value for explicit argument (i.e. ideally what gets passed to base class) # so that we can be sure, that it will not get changed anymore. diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py index b84ef28ce7a58..ffb898efaa815 100644 --- a/tests/tests_pytorch/utilities/test_data.py +++ b/tests/tests_pytorch/utilities/test_data.py @@ -305,6 +305,7 @@ def test_replace_init_method_dataloader(cls, args, kwargs, arg_names, dataset, c assert dataloader.__pl_saved_args == args assert dataloader.__pl_saved_kwargs == kwargs assert dataloader.__pl_saved_arg_names == arg_names + assert dataloader.__pl_saved_default_kwargs == {} assert dataloader.__dataset == dataset assert dataloader.dataset == dataset @@ -322,6 +323,7 @@ def test_replace_init_method_dataloader(cls, args, kwargs, arg_names, dataset, c assert not hasattr(dataloader, "__pl_saved_kwargs") assert not hasattr(dataloader, "__pl_saved_arg_names") assert not hasattr(dataloader, "__pl_saved_args") + assert not hasattr(dataloader, "__pl_saved_default_kwargs") assert not hasattr(dataloader, "__dataset") assert dataloader.dataset == dataset @@ -343,8 +345,9 @@ def __init__(self, dataset, *args, batch_size=10, **kwargs): dataloader = LoaderSubclass(range(10)) assert dataloader.__pl_saved_args == (range(10),) - assert dataloader.__pl_saved_kwargs == {"batch_size": 10} + assert dataloader.__pl_saved_kwargs == {} assert dataloader.__pl_saved_arg_names == ("dataset",) + assert dataloader.__pl_saved_default_kwargs == {"batch_size": 10} assert dataloader.__dataset == range(10) @@ -372,8 +375,9 @@ def __init__(self, sampler, extra_arg, drop_last=True): # assert that passed information got saved assert dataloader.batch_sampler.__pl_saved_args == (sampler, "random_str") - assert dataloader.batch_sampler.__pl_saved_kwargs == {"drop_last": True} + assert dataloader.batch_sampler.__pl_saved_kwargs == {} assert 
dataloader.batch_sampler.__pl_saved_arg_names == ("sampler", "extra_arg") + assert dataloader.batch_sampler.__pl_saved_default_kwargs == {"drop_last": True} # updating dataloader, what happens on access of the dataloaders. # This should not fail, and would fail before support for custom args. @@ -395,6 +399,7 @@ def __init__(self, sampler, extra_arg, drop_last=True): assert not hasattr(batch_sampler, "__pl_saved_kwargs") assert not hasattr(batch_sampler, "__pl_saved_arg_names") assert not hasattr(batch_sampler, "__pl_saved_args") + assert not hasattr(batch_sampler, "__pl_saved_default_kwargs") def test_custom_batch_sampler_no_drop_last(): @@ -418,6 +423,7 @@ def __init__(self, sampler, extra_arg): assert dataloader.batch_sampler.__pl_saved_args == (sampler, "random_str") assert dataloader.batch_sampler.__pl_saved_kwargs == {} assert dataloader.batch_sampler.__pl_saved_arg_names == ("sampler", "extra_arg") + assert dataloader.batch_sampler.__pl_saved_default_kwargs == {} # Assert that warning is raised with pytest.warns(UserWarning, match="drop_last=False"): @@ -443,6 +449,7 @@ def __init__(self, extra_arg): assert dataloader.batch_sampler.__pl_saved_args == ("random_str",) assert dataloader.batch_sampler.__pl_saved_kwargs == {} assert dataloader.batch_sampler.__pl_saved_arg_names == ("extra_arg",) + assert dataloader.batch_sampler.__pl_saved_default_kwargs == {} # Assert that error is raised with pytest.raises(TypeError, match="sampler into the batch sampler"): @@ -453,6 +460,7 @@ def __init__(self, extra_arg): [ "args", "kwargs", + "default_kwargs", "arg_names", "replace_key", "replace_value", @@ -461,16 +469,28 @@ def __init__(self, extra_arg): "expected_kwargs", ], [ - pytest.param((), {}, [], "a", 1, False, (), {}, id="empty"), - pytest.param((1,), {}, ["a"], "a", 2, True, (2,), {}, id="simple1"), - pytest.param((1, 2, 3), {}, ["a", "b", "c"], "b", False, True, (1, False, 3), {}, id="simple2"), - pytest.param((1, 2, 3), {"a": 1}, ["b", "c", "d"], "a", 2, True, (1, 2, 3), {"a": 2}, id="simple_kwargs"), + pytest.param((), {}, {}, [], "a", 1, False, (), {}, id="empty"), + pytest.param((1,), {}, {}, ["a"], "a", 2, True, (2,), {}, id="simple1"), + pytest.param((1, 2, 3), {}, {}, ["a", "b", "c"], "b", False, True, (1, False, 3), {}, id="simple2"), + pytest.param((1, 2, 3), {"a": 1}, {}, ["b", "c", "d"], "a", 2, True, (1, 2, 3), {"a": 2}, id="simple_kwargs"), + pytest.param( + (1, 2, 3), + {"a": 1}, + {"e": 5}, + ["b", "c", "d"], + "e", + 2, + True, + (1, 2, 3), + {"a": 1, "e": 2}, + id="default_kwargs", + ), ], ) def test_replace_value_in_args( - args, kwargs, arg_names, replace_key, replace_value, expected_status, expected_args, expected_kwargs + args, kwargs, default_kwargs, arg_names, replace_key, replace_value, expected_status, expected_args, expected_kwargs ): - assert _replace_value_in_saved_args(replace_key, replace_value, args, kwargs, arg_names) == ( + assert _replace_value_in_saved_args(replace_key, replace_value, args, kwargs, default_kwargs, arg_names) == ( expected_status, expected_args, expected_kwargs, From 5fffd334713d0ad30df07ca8083f857872217d1a Mon Sep 17 00:00:00 2001 From: otaj Date: Wed, 27 Jul 2022 16:51:04 +0200 Subject: [PATCH 22/23] Revert "pprint debugging" This reverts commit ab20f5fc28bf4721cf39b5a876aaa89e07b4606f. 
--- src/pytorch_lightning/strategies/ipu.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py index aa82fb9a201c9..108c725d1abdf 100644 --- a/src/pytorch_lightning/strategies/ipu.py +++ b/src/pytorch_lightning/strategies/ipu.py @@ -236,11 +236,11 @@ def _convert_to_poptorch_loader( opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts from pprint import pprint - pprint(repr(opts)) + pprint(opts) pprint(dl_args) pprint(dl_kwargs) dataloader = poptorch.DataLoader(opts, *dl_args, **dl_kwargs) - pprint(repr(dataloader)) + pprint(dataloader) return dataloader def _handle_gradient_accumulation_steps(self) -> None: From b3d6075b486bdf78dcee18e582ea507f585720e3 Mon Sep 17 00:00:00 2001 From: otaj Date: Wed, 27 Jul 2022 16:51:20 +0200 Subject: [PATCH 23/23] Revert "pprint debugging" This reverts commit 5a51c5aba9515bc64e7f959f4bbefd6b525a017d. --- src/pytorch_lightning/strategies/ipu.py | 6 ------ 1 file changed, 6 deletions(-) diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py index 108c725d1abdf..001ad77fbb5cc 100644 --- a/src/pytorch_lightning/strategies/ipu.py +++ b/src/pytorch_lightning/strategies/ipu.py @@ -234,13 +234,7 @@ def _convert_to_poptorch_loader( dataloader, sampler, mode, self.replication_factor > 1 ) opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts - from pprint import pprint - - pprint(opts) - pprint(dl_args) - pprint(dl_kwargs) dataloader = poptorch.DataLoader(opts, *dl_args, **dl_kwargs) - pprint(dataloader) return dataloader def _handle_gradient_accumulation_steps(self) -> None:
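
Note on the saved-argument replacement these patches converge on: once the patched ``__init__`` has recorded ``__pl_saved_args``, ``__pl_saved_kwargs``, ``__pl_saved_default_kwargs`` and ``__pl_saved_arg_names`` on a custom ``BatchSampler``, re-instantiation only needs to swap individual values (``sampler``, ``drop_last``) wherever they were captured. The following is a minimal standalone sketch of that logic, not the library code: the helper name, the example classes and the placeholder values ("original_sampler", "new_sampler") are illustrative stand-ins, and the real helper in ``pytorch_lightning.utilities.data`` additionally drives the ``drop_last`` warning and the ``TypeError`` exercised by the tests above.

    from typing import Any, Dict, Tuple


    def replace_value_in_saved_args(
        replace_key: str,
        replace_value: Any,
        args: Tuple[Any, ...],
        kwargs: Dict[str, Any],
        default_kwargs: Dict[str, Any],
        arg_names: Tuple[str, ...],
    ) -> Tuple[bool, Tuple[Any, ...], Dict[str, Any]]:
        """Swap one saved constructor argument, wherever it was recorded."""
        if replace_key in arg_names:
            # The value was passed positionally: replace it inside the saved args tuple.
            index = arg_names.index(replace_key)
            return True, args[:index] + (replace_value,) + args[index + 1 :], kwargs
        if replace_key in kwargs or replace_key in default_kwargs:
            # The value was passed by keyword, or only exists as a signature default
            # that should now be overridden explicitly.
            return True, args, {**kwargs, replace_key: replace_value}
        # The saved signature does not accept this argument at all.
        return False, args, kwargs


    # Example mirroring the tests: MyBatchSampler(sampler, "random_str") with drop_last=True as a default.
    saved_args = ("original_sampler", "random_str")
    saved_kwargs: Dict[str, Any] = {}
    saved_default_kwargs = {"drop_last": True}
    saved_arg_names = ("sampler", "extra_arg")

    ok, new_args, new_kwargs = replace_value_in_saved_args(
        "drop_last", False, saved_args, saved_kwargs, saved_default_kwargs, saved_arg_names
    )
    assert ok and new_kwargs == {"drop_last": False}

    ok, new_args, new_kwargs = replace_value_in_saved_args(
        "sampler", "new_sampler", new_args, new_kwargs, saved_default_kwargs, saved_arg_names
    )
    assert ok and new_args == ("new_sampler", "random_str")

When the replacement fails (the key is neither a saved positional argument, a saved keyword argument, nor a defaulted keyword), the returned status is ``False``; in the library that status is what triggers the ``drop_last=False`` warning and the "sampler into the batch sampler" ``TypeError`` checked in ``test_custom_batch_sampler_no_drop_last`` and ``test_custom_batch_sampler_no_sampler``.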