Commit 95f5f17

otaj, rohitgr7, and awaelchli authored

Allowed custom BatchSamplers when instantiated in *_dataloader hook (#13640)

Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Adrian Wälchli <[email protected]>
1 parent c58d351 commit 95f5f17

File tree

8 files changed: +317 −81 lines changed


src/pytorch_lightning/CHANGELOG.md

Lines changed: 2 additions & 0 deletions

@@ -348,6 +348,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Improved support for custom `DataLoader`s when instantiated in `*_dataloader` hook ([#12981](https://github.com/PyTorchLightning/pytorch-lightning/pull/12981))
 
+- Allowed custom `BatchSampler`s when instantiated in `*_dataloader` hook ([#13640](https://github.com/PyTorchLightning/pytorch-lightning/pull/13640))
+
 
 - Fixed an issue with unsupported torch.inference_mode() on hpu backends by making it use no_grad ([#13014](https://github.com/PyTorchLightning/pytorch-lightning/pull/13014))
 
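For illustration, a minimal sketch of what this entry enables. `MyBatchSampler` and `MyDataModule` are hypothetical names, not part of this commit; keeping `sampler` and `drop_last` exposed in `__init__` follows the contract described in the `utilities/data.py` changes below.

import pytorch_lightning as pl
from torch.utils.data import BatchSampler, DataLoader, SequentialSampler


class MyBatchSampler(BatchSampler):
    # Exposing `sampler` and `drop_last` lets Lightning re-instantiate this class with a
    # distributed sampler, or with `drop_last=False` when predicting.
    def __init__(self, sampler, batch_size, drop_last, shuffle_windows=False):
        super().__init__(sampler, batch_size, drop_last)
        self.shuffle_windows = shuffle_windows


class MyDataModule(pl.LightningDataModule):
    def train_dataloader(self):
        dataset = list(range(64))
        batch_sampler = MyBatchSampler(SequentialSampler(dataset), batch_size=8, drop_last=False)
        # Constructed inside the hook, so the patched `__init__` can record the arguments.
        return DataLoader(dataset, batch_sampler=batch_sampler)

Without `sampler` and `drop_last` in the signature, Lightning would warn (incomplete predictions) or raise when it needs to rebuild the batch sampler, as the error messages added in `utilities/data.py` explain.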

src/pytorch_lightning/lite/lite.py

Lines changed: 5 additions & 3 deletions

@@ -22,7 +22,7 @@
 import torch.nn as nn
 from torch import Tensor
 from torch.optim import Optimizer
-from torch.utils.data import DataLoader, DistributedSampler
+from torch.utils.data import BatchSampler, DataLoader, DistributedSampler
 
 from pytorch_lightning.accelerators.accelerator import Accelerator
 from pytorch_lightning.lite.wrappers import _LiteDataLoader, _LiteModule, _LiteOptimizer
@@ -35,7 +35,7 @@
 from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
 from pytorch_lightning.utilities.data import (
     _auto_add_worker_init_fn,
-    _replace_dataloader_init_method,
+    _replace_init_method,
     _update_dataloader,
     has_iterable_dataset,
 )
@@ -403,7 +403,9 @@ def _run_impl(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
 
     def _run_with_strategy_setup(self, run_method: Callable, *args: Any, **kwargs: Any) -> Any:
         self._strategy.setup_environment()
-        with self._strategy.model_sharded_context(), _replace_dataloader_init_method():
+        with self._strategy.model_sharded_context(), _replace_init_method(DataLoader, "dataset"), _replace_init_method(
+            BatchSampler
+        ):
             return run_method(*args, **kwargs)
 
     def _move_model_to_device(self, model: nn.Module, optimizers: List[Optimizer]) -> nn.Module:

src/pytorch_lightning/trainer/connectors/data_connector.py

Lines changed: 3 additions & 3 deletions

@@ -17,7 +17,7 @@
 from typing import Any, Callable, Collection, List, Optional, Tuple, Union
 from weakref import proxy
 
-from torch.utils.data import DataLoader, Sampler, SequentialSampler
+from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSampler
 from torch.utils.data.distributed import DistributedSampler
 
 import pytorch_lightning as pl
@@ -31,7 +31,7 @@
 from pytorch_lightning.utilities.data import (
     _auto_add_worker_init_fn,
     _is_dataloader_shuffled,
-    _replace_dataloader_init_method,
+    _replace_init_method,
     _update_dataloader,
     has_iterable_dataset,
     has_len_all_ranks,
@@ -424,7 +424,7 @@ def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[Dat
         """
         source = getattr(self, f"_{stage.dataloader_prefix}_dataloader_source")
 
-        with _replace_dataloader_init_method():
+        with _replace_init_method(DataLoader, "dataset"), _replace_init_method(BatchSampler):
            # under this context manager, the arguments passed to `DataLoader.__init__` will be captured and saved as
            # attributes on the instance in case the dataloader needs to be re-instantiated later by Lightning
            dataloader = source.dataloader()
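A hedged sketch of what the stacked context managers capture, using `_replace_init_method` as introduced by this commit; the asserted attribute names come from the `utilities/data.py` diff below.

from torch.utils.data import BatchSampler, DataLoader
from pytorch_lightning.utilities.data import _replace_init_method

# Same pattern as `_request_dataloader` above: both constructors are patched while the hook runs.
with _replace_init_method(DataLoader, "dataset"), _replace_init_method(BatchSampler):
    dl = DataLoader(range(10), batch_size=2)

# The constructor call was recorded on the instance so it can be re-instantiated later.
assert dl.__pl_saved_arg_names == ("dataset",)
assert dl.__pl_saved_kwargs == {"batch_size": 2}
# Once the context exits, the original `__init__` methods are restored.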

src/pytorch_lightning/utilities/auto_restart.py

Lines changed: 1 addition & 13 deletions

@@ -16,15 +16,7 @@
 from functools import partial, wraps
 from typing import Any, Callable, Dict, Generator, Iterable, Iterator, List, Optional, Tuple, Union
 
-from torch.utils.data import (
-    BatchSampler,
-    Dataset,
-    DistributedSampler,
-    get_worker_info,
-    RandomSampler,
-    Sampler,
-    SequentialSampler,
-)
+from torch.utils.data import Dataset, DistributedSampler, get_worker_info, RandomSampler, Sampler, SequentialSampler
 from torch.utils.data.dataloader import (
     _BaseDataLoaderIter,
     _MultiProcessingDataLoaderIter,
@@ -757,10 +749,6 @@ def _validate_map_dataset(dataloader: DataLoader) -> None:
     if sampler is not None and type(sampler) not in SUPPORTED_SAMPLERS:
         raise TypeError(f"Fault-tolerance supports only {SUPPORTED_SAMPLERS}.")
 
-    batch_sampler = getattr(dataloader, "batch_sampler", None)
-    if batch_sampler is not None and type(batch_sampler) is not BatchSampler:
-        raise TypeError("Fault-tolerance supports only a `BatchSampler`.")
-
     if type(sampler) is DistributedSampler and sampler.shuffle:
         raise TypeError("A `DistributedSampler` sampler shuffle attribute is set to True.")
     elif type(sampler) is RandomSampler:

src/pytorch_lightning/utilities/data.py

Lines changed: 124 additions & 36 deletions

@@ -14,6 +14,7 @@
 import functools
 import inspect
 import os
+from collections import OrderedDict
 from contextlib import contextmanager
 from dataclasses import fields
 from functools import partial
@@ -220,11 +221,11 @@ def _get_dataloader_init_args_and_kwargs(
     if not isinstance(dataloader, DataLoader):
         raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")
 
-    was_wrapped = hasattr(dataloader, "__pl_dl_args")
+    was_wrapped = hasattr(dataloader, "__pl_saved_args")
     if was_wrapped:
-        dl_args = dataloader.__pl_dl_args
-        dl_kwargs = dataloader.__pl_dl_kwargs
-        arg_names = dataloader.__pl_dl_arg_names
+        dl_args = dataloader.__pl_saved_args
+        dl_kwargs = dataloader.__pl_saved_kwargs
+        arg_names = dataloader.__pl_saved_arg_names
         original_dataset = dataloader.__dataset  # we have this saved from _wrap_init
     else:
         # get the dataloader instance attributes
@@ -323,6 +324,9 @@ def _dataloader_init_kwargs_resolve_sampler(
     If the dataloader is being used for prediction, the sampler will be wrapped into an `IndexBatchSamplerWrapper`, so
     Lightning can keep track of its indices. If fault tolerant training is enabled, the sampler will be wrapped into a
     `FastForwardSampler`.
+
+    If there are multiple devices in IPU mode, it is necessary to disallow BatchSampler that isn't instantiated
+    automatically, since `poptorch.DataLoader` will try to increase the batch_size
     """
     fault_tolerant_mode = _FaultTolerantMode.detect_current_mode()
     batch_sampler = getattr(dataloader, "batch_sampler")
@@ -341,11 +345,59 @@
                     "when running on multiple IPU devices."
                 )
         elif type(batch_sampler) is not BatchSampler or is_predicting:
-            batch_sampler = type(batch_sampler)(
-                sampler,
-                batch_size=batch_sampler.batch_size,
-                drop_last=(False if is_predicting else batch_sampler.drop_last),
-            )
+            batch_sampler_cls = type(batch_sampler)
+            if hasattr(batch_sampler, "__pl_saved_args"):
+                args = batch_sampler.__pl_saved_args
+                kwargs = batch_sampler.__pl_saved_kwargs
+                default_kwargs = batch_sampler.__pl_saved_default_kwargs
+                arg_names = batch_sampler.__pl_saved_arg_names
+
+                if is_predicting:
+                    success, args, kwargs = _replace_value_in_saved_args(
+                        "drop_last", False, args, kwargs, default_kwargs, arg_names
+                    )
+                    if not success:
+                        rank_zero_warn(
+                            f"Trying to inject `drop_last=False` into batch sampler since you are predicting, however "
+                            f"it seems the class `{batch_sampler_cls.__qualname__}` does not support it. "
+                            "Your predictions might be incomplete. To mitigate this, expose `drop_last` in "
+                            "the `__init__` method of your custom class."
+                        )
+
+                success, args, kwargs = _replace_value_in_saved_args(
+                    "sampler", sampler, args, kwargs, default_kwargs, arg_names
+                )
+                if not success:
+                    raise TypeError(
+                        "Trying to inject a modified sampler into the batch sampler; however, it seems the class "
+                        f"`{batch_sampler_cls.__qualname__}` does not have an argument called `sampler.` To mitigate "
+                        "this, expose an argument `sampler` in the `__init__` method of your custom class."
+                    )
+
+                batch_sampler = batch_sampler_cls(*args, **kwargs)
+            else:
+                try:
+                    batch_sampler = batch_sampler_cls(
+                        sampler,
+                        batch_size=batch_sampler.batch_size,
+                        drop_last=(False if is_predicting else batch_sampler.drop_last),
+                    )
+                except TypeError as e:
+                    import re
+
+                    match = re.match(r".*__init__\(\) (got multiple values)|(missing \d required)", str(e))
+                    if not match:
+                        # an unexpected `TypeError`, continue failure
+                        raise
+
+                    # There could either be too few or too many arguments. Customizing the message based on this doesn't
+                    # make much sense since our MisconfigurationException is going to be raised from the original one.
+                    raise MisconfigurationException(
+                        "We tried to re-instantiate your custom batch sampler and failed. "
+                        "To mitigate this, either follow the API of `BatchSampler` or instantiate "
+                        "your custom batch sampler inside `*_dataloader` hooks of your module."
+                    ) from e
+
             if is_predicting:
                 batch_sampler = IndexBatchSamplerWrapper(batch_sampler)
 
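A sketch of the injection step above, done by hand with `_replace_value_in_saved_args` (defined in the next hunk). `MyBatchSampler` and the captured values are illustrative, not taken from a real run.

from torch.utils.data import BatchSampler, SequentialSampler
from pytorch_lightning.utilities.data import _replace_value_in_saved_args


class MyBatchSampler(BatchSampler):  # hypothetical subclass that keeps `sampler` and `drop_last` exposed
    pass


dataset = range(32)
args = (SequentialSampler(dataset),)           # saved positional args
kwargs = {"batch_size": 4, "drop_last": True}  # saved keyword args
default_kwargs = {}
arg_names = ("sampler",)

# Mirror the prediction path: force `drop_last=False`, then swap in the (possibly distributed) sampler.
ok, args, kwargs = _replace_value_in_saved_args("drop_last", False, args, kwargs, default_kwargs, arg_names)
assert ok and kwargs["drop_last"] is False
ok, args, kwargs = _replace_value_in_saved_args("sampler", SequentialSampler(dataset), args, kwargs, default_kwargs, arg_names)
assert ok

rebuilt = MyBatchSampler(*args, **kwargs)  # same class, rebuilt with the updated arguments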

@@ -368,39 +420,73 @@ def _dataloader_init_kwargs_resolve_sampler(
     return {"sampler": sampler, "shuffle": False, "batch_sampler": None}
 
 
+def _replace_value_in_saved_args(
+    replace_key: str,
+    replace_value: Any,
+    args: Tuple[Any, ...],
+    kwargs: Dict[str, Any],
+    default_kwargs: Dict[str, Any],
+    arg_names: Tuple[str, ...],
+) -> Tuple[bool, Tuple[Any, ...], Dict[str, Any]]:
+    """Tries to replace an argument value in a saved list of args and kwargs.
+
+    Returns a tuple indicating success of the operation and modified saved args and kwargs
+    """
+
+    if replace_key in arg_names:
+        replace_index = arg_names.index(replace_key)
+        args = args[:replace_index] + (replace_value,) + args[replace_index + 1 :]
+        return True, args, kwargs
+    elif replace_key in kwargs or replace_key in default_kwargs:
+        kwargs[replace_key] = replace_value
+        return True, args, kwargs
+
+    return False, args, kwargs
+
+
 def _auto_add_worker_init_fn(dataloader: DataLoader, rank: int) -> None:
     if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
         dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)
 
 
-def _wrap_dataloader_init(init: Callable) -> Callable:
-    """Wraps the ``__init__`` method of :class:`~torch.utils.data.DataLoader` in order to enable re-instantiation
-    of custom subclasses."""
+def _wrap_init_method(init: Callable, store_explicit_arg: Optional[str] = None) -> Callable:
+    """Wraps the ``__init__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and
+    :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses."""
 
     @functools.wraps(init)
-    def wrapper(obj: DataLoader, *args: Any, **kwargs: Any) -> None:
+    def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None:
         # We need to inspect `init`, as inspecting `obj.__init__`
         # can lead to inspecting the wrong function with multiple inheritance
         params = inspect.signature(init).parameters
-        param_names = tuple(
-            param.name
+
+        parameters_defaults = OrderedDict(
+            (param.name, param.default)
             for param in params.values()
             if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
         )
-        param_names = param_names[: len(args)]
 
-        if not hasattr(obj, "__pl_dl_args"):
-            obj.__pl_dl_args = args
-            obj.__pl_dl_kwargs = kwargs
-            obj.__pl_dl_arg_names = param_names
+        param_names = tuple(parameters_defaults)[: len(args)]
 
-        # We want to use the latest possible value for dataset argument (i.e. ideally what gets passed to DataLoader)
+        default_kwargs = {
+            name: value
+            for name, value in parameters_defaults.items()
+            if name not in kwargs and name not in param_names and value != inspect.Parameter.empty
+        }
+
+        if not hasattr(obj, "__pl_saved_args"):
+            obj.__pl_saved_args = args
+            obj.__pl_saved_kwargs = kwargs
+            obj.__pl_saved_arg_names = param_names
+            obj.__pl_saved_default_kwargs = default_kwargs
+
+        # We want to use the latest possible value for explicit argument (i.e. ideally what gets passed to base class)
         # so that we can be sure, that it will not get changed anymore.
         # That is why we are setting this in every `__init__`
-        if "dataset" in param_names:
-            setattr(obj, "__dataset", args[param_names.index("dataset")])
-        elif "dataset" in kwargs:
-            setattr(obj, "__dataset", kwargs["dataset"])
+        if store_explicit_arg is not None:
+            if store_explicit_arg in param_names:
+                setattr(obj, f"__{store_explicit_arg}", args[param_names.index(store_explicit_arg)])
+            elif store_explicit_arg in kwargs:
+                setattr(obj, f"__{store_explicit_arg}", kwargs[store_explicit_arg])
 
         init(obj, *args, **kwargs)
 
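A small sketch of what `_wrap_init_method` records when a custom `BatchSampler` subclass with a defaulted argument is constructed under the patch. `WindowBatchSampler` is a hypothetical example class, not part of this commit.

from torch.utils.data import BatchSampler, SequentialSampler
from pytorch_lightning.utilities.data import _replace_init_method


class WindowBatchSampler(BatchSampler):
    def __init__(self, sampler, batch_size, drop_last, window=4):
        super().__init__(sampler, batch_size, drop_last)
        self.window = window


with _replace_init_method(BatchSampler):
    bs = WindowBatchSampler(SequentialSampler(range(16)), batch_size=4, drop_last=False)

assert bs.__pl_saved_arg_names == ("sampler",)
assert bs.__pl_saved_kwargs == {"batch_size": 4, "drop_last": False}
# Defaults that were not passed explicitly are kept too, so re-instantiation can honour them.
assert bs.__pl_saved_default_kwargs == {"window": 4}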

@@ -422,15 +508,17 @@ def recurse(cl: Type[Any]) -> None:
 
 
 @contextmanager
-def _replace_dataloader_init_method() -> Generator[None, None, None]:
-    """This context manager is used to add support for re-instantiation of custom (subclasses) of
-    :class:`~torch.utils.data.DataLoader`. It patches the ``__init__`` method."""
-    classes = _get_all_subclasses(DataLoader) | {DataLoader}
+def _replace_init_method(base_cls: Type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]:
+    """This context manager is used to add support for re-instantiation of custom (subclasses) of `base_cls`.
+
+    It patches the ``__init__`` method.
+    """
+    classes = _get_all_subclasses(base_cls) | {base_cls}
     wrapped = set()
     for cls in classes:
         if cls.__init__ not in wrapped:
             cls._old_init = cls.__init__
-            cls.__init__ = _wrap_dataloader_init(cls.__init__)
+            cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg)
             wrapped.add(cls.__init__)
     yield
     for cls in classes:
@@ -475,13 +563,13 @@ def _apply_fault_tolerant_automatic_capture_dataset_wrapper(
 
 
 def _is_dataloader_shuffled(dataloader: object) -> bool:
-    if hasattr(dataloader, "__pl_dl_kwargs"):
+    if hasattr(dataloader, "__pl_saved_kwargs"):
         # this attribute is not part of PyTorch's DataLoader, but could have been set by
-        # our `_replace_dataloader_init_method` context manager
-        if "shuffle" in dataloader.__pl_dl_kwargs:
-            return dataloader.__pl_dl_kwargs["shuffle"]
-        if "shuffle" in dataloader.__pl_dl_arg_names:
-            return dataloader.__pl_dl_args[dataloader.__pl_dl_arg_names.index("shuffle")]
+        # our `_replace_init_method` context manager
+        if "shuffle" in dataloader.__pl_saved_kwargs:
+            return dataloader.__pl_saved_kwargs["shuffle"]
+        if "shuffle" in dataloader.__pl_saved_arg_names:
+            return dataloader.__pl_saved_args[dataloader.__pl_saved_arg_names.index("shuffle")]
     if isinstance(dataloader.dataset, IterableDataset):
         # shuffling is useless with iterable datasets
         return False
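A short sketch of how the renamed attributes feed `_is_dataloader_shuffled`: when the constructor ran under the patch, the `shuffle=True` argument is read back from the captured `__pl_saved_kwargs`.

from torch.utils.data import DataLoader
from pytorch_lightning.utilities.data import _is_dataloader_shuffled, _replace_init_method

with _replace_init_method(DataLoader, "dataset"):
    shuffled = DataLoader(range(8), shuffle=True)

assert _is_dataloader_shuffled(shuffled)  # resolved via the captured `__pl_saved_kwargs["shuffle"]`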

tests/tests_pytorch/lite/test_lite.py

Lines changed: 4 additions & 3 deletions

@@ -177,16 +177,17 @@ def test_setup_dataloaders_return_type():
     assert lite_dataloader1.dataset is dataset1
 
 
-@mock.patch("pytorch_lightning.lite.lite._replace_dataloader_init_method")
+@mock.patch("pytorch_lightning.lite.lite._replace_init_method")
 def test_setup_dataloaders_captures_dataloader_arguments(ctx_manager):
     """Test that Lite intercepts the DataLoader constructor arguments with a context manager in its run method."""
 
     class Lite(LightningLite):
         def run(self):
-            ctx_manager().__enter__.assert_called_once()
+            # One for BatchSampler, another for DataLoader
+            assert ctx_manager().__enter__.call_count == 2
 
     Lite().run()
-    ctx_manager().__exit__.assert_called_once()
+    assert ctx_manager().__exit__.call_count == 2
 
 
 def test_setup_dataloaders_raises_for_unknown_custom_args():

tests/tests_pytorch/utilities/test_auto_restart.py

Lines changed: 0 additions & 10 deletions

@@ -34,7 +34,6 @@
 from torch.utils.data._utils.worker import _generate_state, get_worker_info
 from torch.utils.data.dataloader import DataLoader, default_collate
 from torch.utils.data.dataset import Dataset, IterableDataset
-from torch.utils.data.sampler import Sampler
 
 import tests_pytorch.helpers.utils as tutils
 from pytorch_lightning import Callback, LightningModule, seed_everything, Trainer
@@ -1177,15 +1176,6 @@ class CustomRandomSampler(RandomSampler):
     with pytest.raises(TypeError, match="RandomSampler"):
         _validate_fault_tolerant_automatic(dl, RunningStage.TRAINING)
 
-    class CustomBatchSampler(BatchSampler):
-        pass
-
-    sampler = Sampler(data())
-    batch_sampler = CustomBatchSampler(sampler, 2, False)
-    dl = DataLoader(data(), batch_sampler=batch_sampler)
-    with pytest.raises(TypeError, match="BatchSampler"):
-        _validate_fault_tolerant_automatic(dl, RunningStage.TRAINING)
-
     class CustomIterable(IterableDataset):
         pass
 
