diff --git a/src/pytorch_lightning/CHANGELOG.md b/src/pytorch_lightning/CHANGELOG.md index f8341248b20e8..baf01371fb8bc 100644 --- a/src/pytorch_lightning/CHANGELOG.md +++ b/src/pytorch_lightning/CHANGELOG.md @@ -348,6 +348,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Improved support for custom `DataLoader`s when instantiated in `*_dataloader` hook ([#12981](https://github.com/PyTorchLightning/pytorch-lightning/pull/12981)) +- Allowed custom `BatchSampler`s when instantiated in `*_dataloader` hook [#13640](https://github.com/PyTorchLightning/pytorch-lightning/pull/13640)) + - Fixed an issue with unsupported torch.inference_mode() on hpu backends by making it use no_grad ([#13014](https://github.com/PyTorchLightning/pytorch-lightning/pull/13014)) diff --git a/src/pytorch_lightning/lite/lite.py b/src/pytorch_lightning/lite/lite.py index 0195e6852eb28..981eed30635f6 100644 --- a/src/pytorch_lightning/lite/lite.py +++ b/src/pytorch_lightning/lite/lite.py @@ -22,7 +22,7 @@ import torch.nn as nn from torch import Tensor from torch.optim import Optimizer -from torch.utils.data import DataLoader, DistributedSampler +from torch.utils.data import BatchSampler, DataLoader, DistributedSampler from pytorch_lightning.accelerators.accelerator import Accelerator from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer @@ -35,7 +35,7 @@ from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors from pytorch_lightning.utilities.data import ( _auto_add_worker_init_fn, - _replace_dataloader_init_method, + _replace_init_method, _update_dataloader, has_iterable_dataset, ) @@ -403,7 +403,9 @@ def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: def _run_with_strategy_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any: self._strategy.setup_environment() - with self._strategy.model_sharded_context(), _replace_dataloader_init_method(): + with self._strategy.model_sharded_context(), _replace_init_method(DataLoader, "dataset"), _replace_init_method( + BatchSampler + ): return run_method(*args, **kwargs) def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module: diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index add62ceece65c..7831316a98ae1 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -17,7 +17,7 @@ from typing import Any, Callable, Collection, List, Optional, Tuple, Union from weakref import proxy -from torch.utils.data import DataLoader, Sampler, SequentialSampler +from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler import pytorch_lightning as pl @@ -31,7 +31,7 @@ from pytorch_lightning.utilities.data import ( _auto_add_worker_init_fn, _is_dataloader_shuffled, - _replace_dataloader_init_method, + _replace_init_method, _update_dataloader, has_iterable_dataset, has_len_all_ranks, @@ -424,7 +424,7 @@ def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[Dat """ source = getattr(self, f"_{stage.dataloader_prefix}_dataloader_source") - with _replace_dataloader_init_method(): + with _replace_init_method(DataLoader, "dataset"), _replace_init_method(BatchSampler): # under this context manager, the arguments passed to `DataLoader.__init__` will be captured and saved as # attributes on the instance in case the dataloader needs to be re-instantiated later by Lightning dataloader = source.dataloader() diff --git a/src/pytorch_lightning/utilities/auto_restart.py b/src/pytorch_lightning/utilities/auto_restart.py index 0bd2942b17cca..3877a1ab3944c 100644 --- a/src/pytorch_lightning/utilities/auto_restart.py +++ b/src/pytorch_lightning/utilities/auto_restart.py @@ -16,15 +16,7 @@ from functools import partial, wraps from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union -from torch.utils.data import ( - BatchSampler, - Dataset, - DistributedSampler, - get_worker_info, - RandomSampler, - Sampler, - SequentialSampler, -) +from torch.utils.data import Dataset, DistributedSampler, get_worker_info, RandomSampler, Sampler, SequentialSampler from torch.utils.data.dataloader import ( _BaseDataLoaderIter, _MultiProcessingDataLoaderIter, @@ -757,10 +749,6 @@ def _validate_map_dataset(dataloader: DataLoader) -> None: if sampler is not None and type(sampler) not in SUPPORTED_SAMPLERS: raise TypeError(f"Fault-tolerance supports only {SUPPORTED_SAMPLERS}.") - batch_sampler = getattr(dataloader, "batch_sampler", None) - if batch_sampler is not None and type(batch_sampler) is not BatchSampler: - raise TypeError("Fault-tolerance supports only a `BatchSampler`.") - if type(sampler) is DistributedSampler and sampler.shuffle: raise TypeError("A `DistributedSampler` sampler shuffle attribute is set to True.") elif type(sampler) is RandomSampler: diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index e60c56f6c7a7e..862c7f2de905b 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -14,6 +14,7 @@ import functools import inspect import os +from collections import OrderedDict from contextlib import contextmanager from dataclasses import fields from functools import partial @@ -220,11 +221,11 @@ def _get_dataloader_init_args_and_kwargs( if not isinstance(dataloader, DataLoader): raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`") - was_wrapped = hasattr(dataloader, "__pl_dl_args") + was_wrapped = hasattr(dataloader, "__pl_saved_args") if was_wrapped: - dl_args = dataloader.__pl_dl_args - dl_kwargs = dataloader.__pl_dl_kwargs - arg_names = dataloader.__pl_dl_arg_names + dl_args = dataloader.__pl_saved_args + dl_kwargs = dataloader.__pl_saved_kwargs + arg_names = dataloader.__pl_saved_arg_names original_dataset = dataloader.__dataset # we have this saved from _wrap_init else: # get the dataloader instance attributes @@ -323,6 +324,9 @@ def _dataloader_init_kwargs_resolve_sampler( If the dataloader is being used for prediction, the sampler will be wrapped into an `IndexBatchSamplerWrapper`, so Lightning can keep track of its indices. If fault tolerant training is enabled, the sampler will be wrapped into a `FastForwardSampler`. + + If there are multiple devices in IPU mode, it is necessary to disallow BatchSampler that isn't instantiated + automatically, since `poptorch.DataLoader` will try to increase the batch_size """ fault_tolerant_mode = _FaultTolerantMode.detect_current_mode() batch_sampler = getattr(dataloader, "batch_sampler") @@ -341,11 +345,59 @@ def _dataloader_init_kwargs_resolve_sampler( "when running on multiple IPU devices." ) elif type(batch_sampler) is not BatchSampler or is_predicting: - batch_sampler = type(batch_sampler)( - sampler, - batch_size=batch_sampler.batch_size, - drop_last=(False if is_predicting else batch_sampler.drop_last), - ) + batch_sampler_cls = type(batch_sampler) + if hasattr(batch_sampler, "__pl_saved_args"): + args = batch_sampler.__pl_saved_args + kwargs = batch_sampler.__pl_saved_kwargs + default_kwargs = batch_sampler.__pl_saved_default_kwargs + arg_names = batch_sampler.__pl_saved_arg_names + + if is_predicting: + success, args, kwargs = _replace_value_in_saved_args( + "drop_last", False, args, kwargs, default_kwargs, arg_names + ) + if not success: + rank_zero_warn( + f"Trying to inject `drop_last=False` into batch sampler since you are predicting, however " + f"it seems the class `{batch_sampler_cls.__qualname__}` does not support it. " + "Your predictions might be incomplete. To mitigate this, expose `drop_last` in " + "the `__init__` method of your custom class." + ) + + success, args, kwargs = _replace_value_in_saved_args( + "sampler", sampler, args, kwargs, default_kwargs, arg_names + ) + if not success: + raise TypeError( + "Trying to inject a modified sampler into the batch sampler; however, it seems the class " + f"`{batch_sampler_cls.__qualname__}` does not have an argument called `sampler.` To mitigate " + "this, expose an argument `sampler` in the `__init__` method of your custom class." + ) + + batch_sampler = batch_sampler_cls(*args, **kwargs) + else: + try: + batch_sampler = batch_sampler_cls( + sampler, + batch_size=batch_sampler.batch_size, + drop_last=(False if is_predicting else batch_sampler.drop_last), + ) + except TypeError as e: + import re + + match = re.match(r".*__init__\(\) (got multiple values)|(missing \d required)", str(e)) + if not match: + # an unexpected `TypeError`, continue failure + raise + + # There could either be too few or too many arguments. Customizing the message based on this doesn't + # make much sense since our MisconfigurationException is going to be raised from the original one. + raise MisconfigurationException( + "We tried to re-instantiate your custom batch sampler and failed. " + "To mitigate this, either follow the API of `BatchSampler` or instantiate " + "your custom batch sampler inside `*_dataloader` hooks of your module." + ) from e + if is_predicting: batch_sampler = IndexBatchSamplerWrapper(batch_sampler) @@ -368,39 +420,73 @@ def _dataloader_init_kwargs_resolve_sampler( return {"sampler": sampler, "shuffle": False, "batch_sampler": None} +def _replace_value_in_saved_args( + replace_key: str, + replace_value: Any, + args: Tuple[Any, ...], + kwargs: Dict[str, Any], + default_kwargs: Dict[str, Any], + arg_names: Tuple[str, ...], +) -> Tuple[bool, Tuple[Any, ...], Dict[str, Any]]: + """Tries to replace an argument value in a saved list of args and kwargs. + + Returns a tuple indicating success of the operation and modified saved args and kwargs + """ + + if replace_key in arg_names: + replace_index = arg_names.index(replace_key) + args = args[:replace_index] + (replace_value,) + args[replace_index + 1 :] + return True, args, kwargs + elif replace_key in kwargs or replace_key in default_kwargs: + kwargs[replace_key] = replace_value + return True, args, kwargs + + return False, args, kwargs + + def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None: if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None: dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank) -def _wrap_dataloader_init(init: Callable) -> Callable: - """Wraps the ``__init__`` method of :class:`~torch.utils.data.DataLoader` in order to enable re-instantiation - of custom subclasses.""" +def _wrap_init_method(init: Callable, store_explicit_arg: Optional[str] = None) -> Callable: + """Wraps the ``__init__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and + :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses.""" @functools.wraps(init) - def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None: + def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None: # We need to inspect `init`, as inspecting `obj.__init__` # can lead to inspecting the wrong function with multiple inheritance params = inspect.signature(init).parameters - param_names = tuple( - param.name + + parameters_defaults = OrderedDict( + (param.name, param.default) for param in params.values() if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD) ) - param_names = param_names[: len(args)] - if not hasattr(obj, "__pl_dl_args"): - obj.__pl_dl_args = args - obj.__pl_dl_kwargs = kwargs - obj.__pl_dl_arg_names = param_names + param_names = tuple(parameters_defaults)[: len(args)] - # We want to use the latest possible value for dataset argument (i.e. ideally what gets passed to DataLoader) + default_kwargs = { + name: value + for name, value in parameters_defaults.items() + if name not in kwargs and name not in param_names and value != inspect.Parameter.empty + } + + if not hasattr(obj, "__pl_saved_args"): + obj.__pl_saved_args = args + obj.__pl_saved_kwargs = kwargs + obj.__pl_saved_arg_names = param_names + obj.__pl_saved_default_kwargs = default_kwargs + + # We want to use the latest possible value for explicit argument (i.e. ideally what gets passed to base class) # so that we can be sure, that it will not get changed anymore. # That is why we are setting this in every `__init__` - if "dataset" in param_names: - setattr(obj, "__dataset", args[param_names.index("dataset")]) - elif "dataset" in kwargs: - setattr(obj, "__dataset", kwargs["dataset"]) + if store_explicit_arg is not None: + if store_explicit_arg in param_names: + setattr(obj, f"__{store_explicit_arg}", args[param_names.index(store_explicit_arg)]) + elif store_explicit_arg in kwargs: + setattr(obj, f"__{store_explicit_arg}", kwargs[store_explicit_arg]) init(obj, *args, **kwargs) @@ -422,15 +508,17 @@ def recurse(cl: Type[Any]) -> None: @contextmanager -def _replace_dataloader_init_method() -> Generator[None, None, None]: - """This context manager is used to add support for re-instantiation of custom (subclasses) of - :class:`~torch.utils.data.DataLoader`. It patches the ``__init__`` method.""" - classes = _get_all_subclasses(DataLoader) | {DataLoader} +def _replace_init_method(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]: + """This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`. + + It patches the ``__init__`` method. + """ + classes = _get_all_subclasses(base_cls) | {base_cls} wrapped = set() for cls in classes: if cls.__init__ not in wrapped: cls._old_init = cls.__init__ - cls.__init__ = _wrap_dataloader_init(cls.__init__) + cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg) wrapped.add(cls.__init__) yield for cls in classes: @@ -475,13 +563,13 @@ def _apply_fault_tolerant_automatic_capture_dataset_wrapper( def _is_dataloader_shuffled(dataloader: object) -> bool: - if hasattr(dataloader, "__pl_dl_kwargs"): + if hasattr(dataloader, "__pl_saved_kwargs"): # this attribute is not part of PyTorch's DataLoader, but could have been set by - # our `_replace_dataloader_init_method` context manager - if "shuffle" in dataloader.__pl_dl_kwargs: - return dataloader.__pl_dl_kwargs["shuffle"] - if "shuffle" in dataloader.__pl_dl_arg_names: - return dataloader.__pl_dl_args[dataloader.__pl_dl_arg_names.index("shuffle")] + # our `_replace_init_method` context manager + if "shuffle" in dataloader.__pl_saved_kwargs: + return dataloader.__pl_saved_kwargs["shuffle"] + if "shuffle" in dataloader.__pl_saved_arg_names: + return dataloader.__pl_saved_args[dataloader.__pl_saved_arg_names.index("shuffle")] if isinstance(dataloader.dataset, IterableDataset): # shuffling is useless with iterable datasets return False diff --git a/tests/tests_pytorch/lite/test_lite.py b/tests/tests_pytorch/lite/test_lite.py index c0439854013a2..3652613526549 100644 --- a/tests/tests_pytorch/lite/test_lite.py +++ b/tests/tests_pytorch/lite/test_lite.py @@ -177,16 +177,17 @@ def test_setup_dataloaders_return_type(): assert lite_dataloader1.dataset is dataset1 -@mock.patch("pytorch_lightning.lite.lite._replace_dataloader_init_method") +@mock.patch("pytorch_lightning.lite.lite._replace_init_method") def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager): """Test that Lite intercepts the DataLoader constructor arguments with a context manager in its run method.""" class Lite(LightningLite): def run(self): - ctx_manager().__enter__.assert_called_once() + # One for BatchSampler, another for DataLoader + assert ctx_manager().__enter__.call_count == 2 Lite().run() - ctx_manager().__exit__.assert_called_once() + assert ctx_manager().__exit__.call_count == 2 def test_setup_dataloaders_raises_for_unknown_custom_args(): diff --git a/tests/tests_pytorch/utilities/test_auto_restart.py b/tests/tests_pytorch/utilities/test_auto_restart.py index 5a5982ad009f9..8a888ce09c90a 100644 --- a/tests/tests_pytorch/utilities/test_auto_restart.py +++ b/tests/tests_pytorch/utilities/test_auto_restart.py @@ -34,7 +34,6 @@ from torch.utils.data._utils.worker import _generate_state, get_worker_info from torch.utils.data.dataloader import DataLoader, default_collate from torch.utils.data.dataset import Dataset, IterableDataset -from torch.utils.data.sampler import Sampler import tests_pytorch.helpers.utils as tutils from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer @@ -1177,15 +1176,6 @@ class CustomRandomSampler(RandomSampler): with pytest.raises(TypeError, match="RandomSampler"): _validate_fault_tolerant_automatic(dl, RunningStage.TRAINING) - class CustomBatchSampler(BatchSampler): - pass - - sampler = Sampler(data()) - batch_sampler = CustomBatchSampler(sampler, 2, False) - dl = DataLoader(data(), batch_sampler=batch_sampler) - with pytest.raises(TypeError, match="BatchSampler"): - _validate_fault_tolerant_automatic(dl, RunningStage.TRAINING) - class CustomIterable(IterableDataset): pass diff --git a/tests/tests_pytorch/utilities/test_data.py b/tests/tests_pytorch/utilities/test_data.py index 5f66d802ea939..ffb898efaa815 100644 --- a/tests/tests_pytorch/utilities/test_data.py +++ b/tests/tests_pytorch/utilities/test_data.py @@ -3,15 +3,17 @@ import pytest import torch from torch import Tensor -from torch.utils.data import BatchSampler, DataLoader, SequentialSampler +from torch.utils.data import BatchSampler, DataLoader, RandomSampler, SequentialSampler from pytorch_lightning import Trainer from pytorch_lightning.demos.boring_classes import BoringModel, RandomDataset +from pytorch_lightning.overrides.distributed import IndexBatchSamplerWrapper from pytorch_lightning.trainer.states import RunningStage from pytorch_lightning.utilities.data import ( _dataloader_init_kwargs_resolve_sampler, _get_dataloader_init_args_and_kwargs, - _replace_dataloader_init_method, + _replace_init_method, + _replace_value_in_saved_args, _update_dataloader, extract_batch_size, get_len, @@ -145,7 +147,7 @@ def __init__(self, foo, *args, **kwargs): with pytest.raises(MisconfigurationException, match="`DataLoader` implementation has an error.*`dataset`"): _update_dataloader(dataloader, dataloader.sampler) - with _replace_dataloader_init_method(): + with _replace_init_method(DataLoader, "dataset"): dataloader = BadStandaloneGoodHookImpl([1, 2, 3]) new_dataloader = _update_dataloader(dataloader, dataloader.sampler) assert isinstance(new_dataloader, BadStandaloneGoodHookImpl) @@ -296,13 +298,14 @@ def __init__(self, dataset, **kwargs): pytest.param(ChangingDataLoader, (range(5),), dict(), ("dataset",), list(range(10)), dict(), id="test9"), ], ) -def test_replace_dataloader_init_method(cls, args, kwargs, arg_names, dataset, checked_values): - with _replace_dataloader_init_method(): +def test_replace_init_method_dataloader(cls, args, kwargs, arg_names, dataset, checked_values): + with _replace_init_method(DataLoader, "dataset"): dataloader = cls(*args, **kwargs) - assert dataloader.__pl_dl_args == args - assert dataloader.__pl_dl_kwargs == kwargs - assert dataloader.__pl_dl_arg_names == arg_names + assert dataloader.__pl_saved_args == args + assert dataloader.__pl_saved_kwargs == kwargs + assert dataloader.__pl_saved_arg_names == arg_names + assert dataloader.__pl_saved_default_kwargs == {} assert dataloader.__dataset == dataset assert dataloader.dataset == dataset @@ -312,14 +315,15 @@ def test_replace_dataloader_init_method(cls, args, kwargs, arg_names, dataset, c if isinstance(dataloader_value, torch.Tensor): assert dataloader_value is value else: - assert getattr(dataloader, key) == value + assert dataloader_value == value dataloader = _update_dataloader(dataloader, dataloader.sampler) assert isinstance(dataloader, cls) - assert not hasattr(dataloader, "__pl_dl_kwargs") - assert not hasattr(dataloader, "__pl_dl_arg_names") - assert not hasattr(dataloader, "__pl_dl_args") + assert not hasattr(dataloader, "__pl_saved_kwargs") + assert not hasattr(dataloader, "__pl_saved_arg_names") + assert not hasattr(dataloader, "__pl_saved_args") + assert not hasattr(dataloader, "__pl_saved_default_kwargs") assert not hasattr(dataloader, "__dataset") assert dataloader.dataset == dataset @@ -329,7 +333,168 @@ def test_replace_dataloader_init_method(cls, args, kwargs, arg_names, dataset, c if isinstance(dataloader_value, torch.Tensor): assert dataloader_value is value else: - assert getattr(dataloader, key) == value + assert dataloader_value == value + + +def test_replace_init_method_extra_kwargs(): + class LoaderSubclass(DataLoader): + def __init__(self, dataset, *args, batch_size=10, **kwargs): + super().__init__(dataset, *args, batch_size=batch_size, **kwargs) + + with _replace_init_method(DataLoader, "dataset"): + dataloader = LoaderSubclass(range(10)) + + assert dataloader.__pl_saved_args == (range(10),) + assert dataloader.__pl_saved_kwargs == {} + assert dataloader.__pl_saved_arg_names == ("dataset",) + assert dataloader.__pl_saved_default_kwargs == {"batch_size": 10} + assert dataloader.__dataset == range(10) + + +@pytest.mark.parametrize("predicting", [True, False]) +def test_custom_batch_sampler(predicting): + """This test asserts, that custom `BatchSampler`, with all the arguments, that are required in order to + properly reinstantiate the class, is invoked properly. + + It also asserts, that during the reinstantiation, the wrapper of `__init__` method is not present anymore, therefore + not setting `__pl_saved_{args,arg_names,kwargs}` attributes. + """ + + class MyBatchSampler(BatchSampler): + # Custom Batch sampler with extra argument and default value + def __init__(self, sampler, extra_arg, drop_last=True): + self.extra_arg = extra_arg + super().__init__(sampler, 10, drop_last) + + sampler = RandomSampler(range(10)) + with _replace_init_method(BatchSampler): + # instantiate within `_replace_init_method` context manager, simulating `*_dataloader` hooks + batch_sampler = MyBatchSampler(sampler, "random_str") + + dataloader = DataLoader(range(10), batch_sampler=batch_sampler) + + # assert that passed information got saved + assert dataloader.batch_sampler.__pl_saved_args == (sampler, "random_str") + assert dataloader.batch_sampler.__pl_saved_kwargs == {} + assert dataloader.batch_sampler.__pl_saved_arg_names == ("sampler", "extra_arg") + assert dataloader.batch_sampler.__pl_saved_default_kwargs == {"drop_last": True} + + # updating dataloader, what happens on access of the dataloaders. + # This should not fail, and would fail before support for custom args. + dataloader = _update_dataloader( + dataloader, dataloader.sampler, mode=RunningStage.PREDICTING if predicting else None + ) + + # Assert the `__init__` method is not replaced anymore and everything is instantiated to correct types + batch_sampler = dataloader.batch_sampler + + if predicting: + assert isinstance(batch_sampler, IndexBatchSamplerWrapper) + batch_sampler = batch_sampler._sampler + + assert isinstance(batch_sampler, MyBatchSampler) + assert batch_sampler.drop_last == (not predicting) + + assert batch_sampler.extra_arg == "random_str" + assert not hasattr(batch_sampler, "__pl_saved_kwargs") + assert not hasattr(batch_sampler, "__pl_saved_arg_names") + assert not hasattr(batch_sampler, "__pl_saved_args") + assert not hasattr(batch_sampler, "__pl_saved_default_kwargs") + + +def test_custom_batch_sampler_no_drop_last(): + """Tests whether appropriate warning is raised when the custom `BatchSampler` does not support `drop_last` and + we want to reset it.""" + + class MyBatchSampler(BatchSampler): + # Custom batch sampler with extra argument, but without `drop_last` + def __init__(self, sampler, extra_arg): + self.extra_arg = extra_arg + super().__init__(sampler, 10, False) + + sampler = RandomSampler(range(10)) + with _replace_init_method(BatchSampler): + # instantiate within `_replace_init_method` context manager, simulating `*_dataloader` hooks + batch_sampler = MyBatchSampler(sampler, "random_str") + + dataloader = DataLoader(range(10), batch_sampler=batch_sampler) + + # assert that passed information got saved + assert dataloader.batch_sampler.__pl_saved_args == (sampler, "random_str") + assert dataloader.batch_sampler.__pl_saved_kwargs == {} + assert dataloader.batch_sampler.__pl_saved_arg_names == ("sampler", "extra_arg") + assert dataloader.batch_sampler.__pl_saved_default_kwargs == {} + + # Assert that warning is raised + with pytest.warns(UserWarning, match="drop_last=False"): + dataloader = _update_dataloader(dataloader, dataloader.sampler, mode=RunningStage.PREDICTING) + + +def test_custom_batch_sampler_no_sampler(): + """Tests whether appropriate error is raised when the custom `BatchSampler` does not support sampler + argument.""" + + class MyBatchSampler(BatchSampler): + # Custom batch sampler, without sampler argument. + def __init__(self, extra_arg): + self.extra_arg = extra_arg + super().__init__(RandomSampler(range(10)), 10, False) + + with _replace_init_method(BatchSampler): + # instantiate within `_replace_init_method` context manager, simulating `*_dataloader` hooks + batch_sampler = MyBatchSampler("random_str") + dataloader = DataLoader(range(10), batch_sampler=batch_sampler) + + # assert that passed information got saved + assert dataloader.batch_sampler.__pl_saved_args == ("random_str",) + assert dataloader.batch_sampler.__pl_saved_kwargs == {} + assert dataloader.batch_sampler.__pl_saved_arg_names == ("extra_arg",) + assert dataloader.batch_sampler.__pl_saved_default_kwargs == {} + + # Assert that error is raised + with pytest.raises(TypeError, match="sampler into the batch sampler"): + dataloader = _update_dataloader(dataloader, dataloader.sampler, mode=RunningStage.PREDICTING) + + +@pytest.mark.parametrize( + [ + "args", + "kwargs", + "default_kwargs", + "arg_names", + "replace_key", + "replace_value", + "expected_status", + "expected_args", + "expected_kwargs", + ], + [ + pytest.param((), {}, {}, [], "a", 1, False, (), {}, id="empty"), + pytest.param((1,), {}, {}, ["a"], "a", 2, True, (2,), {}, id="simple1"), + pytest.param((1, 2, 3), {}, {}, ["a", "b", "c"], "b", False, True, (1, False, 3), {}, id="simple2"), + pytest.param((1, 2, 3), {"a": 1}, {}, ["b", "c", "d"], "a", 2, True, (1, 2, 3), {"a": 2}, id="simple_kwargs"), + pytest.param( + (1, 2, 3), + {"a": 1}, + {"e": 5}, + ["b", "c", "d"], + "e", + 2, + True, + (1, 2, 3), + {"a": 1, "e": 2}, + id="default_kwargs", + ), + ], +) +def test_replace_value_in_args( + args, kwargs, default_kwargs, arg_names, replace_key, replace_value, expected_status, expected_args, expected_kwargs +): + assert _replace_value_in_saved_args(replace_key, replace_value, args, kwargs, default_kwargs, arg_names) == ( + expected_status, + expected_args, + expected_kwargs, + ) def test_dataloader_disallow_batch_sampler():