From e7b2876ac687f2718c546755a608dc828c40bfa2 Mon Sep 17 00:00:00 2001 From: nandwalritik Date: Thu, 28 Jul 2022 14:57:11 +0530 Subject: [PATCH 1/6] Update pyproject.toml --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 5a710faf3544b..1deb7a4d101a0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -76,7 +76,6 @@ module = [ "pytorch_lightning.trainer.trainer", "pytorch_lightning.tuner.batch_size_scaling", "pytorch_lightning.utilities.auto_restart", - "pytorch_lightning.utilities.data", "pytorch_lightning.utilities.meta", ] ignore_errors = "True" From 047738e7b2d289079da1179ce0152a59b23fbb22 Mon Sep 17 00:00:00 2001 From: nandwalritik Date: Thu, 28 Jul 2022 15:03:56 +0530 Subject: [PATCH 2/6] Add changes for data.py --- src/pytorch_lightning/utilities/data.py | 38 +++++++++++++------------ 1 file changed, 20 insertions(+), 18 deletions(-) diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index 2de82ceff088e..2e253f1957a2f 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -42,12 +42,12 @@ from pytorch_lightning.utilities.seed import pl_worker_init_function from pytorch_lightning.utilities.warnings import WarningCache -BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]] +BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]] # type: ignore warning_cache = WarningCache() -def _extract_batch_size(batch: BType) -> Generator[int, None, None]: +def _extract_batch_size(batch: BType) -> Generator[Optional[int], None, None]: if isinstance(batch, Tensor): if batch.ndim == 0: yield 1 @@ -100,12 +100,12 @@ def has_iterable_dataset(dataloader: DataLoader) -> bool: return hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset) -def has_len(dataloader: Union[DataLoader, Iterable]) -> bool: +def has_len(dataloader: Union[DataLoader, Dataset, Iterable]) -> bool: """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or infinite dataloader.""" try: # try getting the length - if len(dataloader) == 0: + if isinstance(dataloader, DataLoader) and len(dataloader) == 0: rank_zero_warn( f"`{dataloader.__class__.__name__}` returned 0 length. Please make sure this was your intention." ) @@ -115,7 +115,7 @@ def has_len(dataloader: Union[DataLoader, Iterable]) -> bool: except NotImplementedError: # e.g. raised by torchtext if a batch_size_fn is used has_len = False - if has_len and has_iterable_dataset(dataloader): + if has_len and isinstance(dataloader, DataLoader) and has_iterable_dataset(dataloader): rank_zero_warn( "Your `IterableDataset` has `__len__` defined." " In combination with multi-process data loading (when num_workers > 1)," @@ -127,14 +127,16 @@ def has_len(dataloader: Union[DataLoader, Iterable]) -> bool: def has_len_all_ranks( dataloader: DataLoader, - training_type: "pl.Strategy", + training_type: "pl.Strategy", # type: ignore model: Union["pl.LightningModule", "pl.LightningDataModule"], ) -> bool: """Checks if a given Dataloader has ``__len__`` method implemented i.e. 
if it is a finite dataloader or infinite dataloader.""" try: local_length = len(dataloader) - total_length = training_type.reduce(torch.tensor(local_length).to(model.device), reduce_op="sum") + total_length = training_type.reduce( + torch.tensor(local_length).to(model.device), reduce_op="sum" + ) # type: ignore if total_length == 0: rank_zero_warn( @@ -171,13 +173,13 @@ def has_len_all_ranks( return has_len -def get_len(dataloader: DataLoader) -> Union[int, float]: +def get_len(dataloader: Union[DataLoader, Dataset]) -> Union[int, float]: """Return the length of the given DataLoader. If ``__len__`` method is not implemented, return float('inf'). """ - if has_len(dataloader): + if has_len(dataloader) and isinstance(dataloader, DataLoader): return len(dataloader) return float("inf") @@ -186,7 +188,7 @@ def get_len(dataloader: DataLoader) -> Union[int, float]: def _update_dataloader( dataloader: DataLoader, sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None ) -> DataLoader: - dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode=mode) + dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler, mode=mode) # type: ignore dl_cls = type(dataloader) try: dataloader = dl_cls(*dl_args, **dl_kwargs) @@ -234,7 +236,7 @@ def _get_dataloader_init_args_and_kwargs( arg_names = () # get the dataloader instance `__init__` parameters - params = dict(inspect.signature(dataloader.__init__).parameters) + params = dict(inspect.signature(dataloader.__init__).parameters) # type: ignore has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values()) if has_variadic_kwargs: # if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)` @@ -264,9 +266,9 @@ def _get_dataloader_init_args_and_kwargs( dl_kwargs["batch_sampler"] = None dl_kwargs["sampler"] = None else: - dl_kwargs.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode=mode)) + dl_kwargs.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode=mode)) # type: ignore - required_args = { + required_args: Union[list[str], set[str]] = { p.name for p in params.values() if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) @@ -289,7 +291,7 @@ def _get_dataloader_init_args_and_kwargs( if not has_variadic_kwargs: # the dataloader signature does not allow keyword arguments that need to be passed - missing_kwargs = (set(dl_kwargs) | set(arg_names)) - params.keys() + missing_kwargs: Union[list[str], set[str]] = (set(dl_kwargs) | set(arg_names)) - params.keys() if missing_kwargs: missing_kwargs = sorted(missing_kwargs) dataloader_cls_name = dataloader.__class__.__name__ @@ -309,7 +311,7 @@ def _get_dataloader_init_args_and_kwargs( def _dataloader_init_kwargs_resolve_sampler( - dataloader: DataLoader, sampler: Optional[Sampler], mode: Optional[RunningStage] = None + dataloader: DataLoader, sampler: Union[Sampler, Generator], mode: Optional[RunningStage] = None ) -> Dict[str, Any]: """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its re-instantiation. @@ -408,15 +410,15 @@ def _replace_dataloader_init_method() -> Generator[None, None, None]: """This context manager is used to add support for re-instantiation of custom (subclasses) of :class:`~torch.utils.data.DataLoader`. 
It patches the ``__init__`` method.""" classes = _get_all_subclasses(DataLoader) | {DataLoader} - wrapped = set() + wrapped: set[Any] = set() for cls in classes: - if cls.__init__ not in wrapped: + if cls.__init__ not in wrapped and isinstance(cls, DataLoader): cls._old_init = cls.__init__ cls.__init__ = _wrap_dataloader_init(cls.__init__) wrapped.add(cls.__init__) yield for cls in classes: - if hasattr(cls, "_old_init"): + if hasattr(cls, "_old_init") and isinstance(cls, DataLoader): cls.__init__ = cls._old_init del cls._old_init From 89e1987d4e835457f637e7a5f6e81c431feb4b33 Mon Sep 17 00:00:00 2001 From: nandwalritik Date: Wed, 3 Aug 2022 17:35:41 +0530 Subject: [PATCH 3/6] Update typing for training_type arg in has_len_all_ranks --- src/pytorch_lightning/utilities/data.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index 2e253f1957a2f..8ea3ed59800ba 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -127,16 +127,18 @@ def has_len(dataloader: Union[DataLoader, Dataset, Iterable]) -> bool: def has_len_all_ranks( dataloader: DataLoader, - training_type: "pl.Strategy", # type: ignore + training_type: "pl.strategies.Strategy", model: Union["pl.LightningModule", "pl.LightningDataModule"], ) -> bool: """Checks if a given Dataloader has ``__len__`` method implemented i.e. if it is a finite dataloader or infinite dataloader.""" try: local_length = len(dataloader) - total_length = training_type.reduce( - torch.tensor(local_length).to(model.device), reduce_op="sum" - ) # type: ignore + if isinstance(model, pl.LightningModule): + local_length_tensor = torch.tensor(local_length).to(model.device) + else: + local_length_tensor = torch.tensor(local_length) + total_length = training_type.reduce(local_length_tensor, reduce_op="sum") if total_length == 0: rank_zero_warn( From b3c594e1bd4ac2236069bc35a7a2195a662de7aa Mon Sep 17 00:00:00 2001 From: otaj Date: Tue, 13 Sep 2022 18:08:09 +0200 Subject: [PATCH 4/6] WIP for the day --- pyproject.toml | 3 +-- src/lightning_lite/utilities/data.py | 20 ++++++++++---------- src/pytorch_lightning/utilities/data.py | 8 ++++---- 3 files changed, 15 insertions(+), 16 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 2aa527f7af315..ced59ba19a6fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,6 @@ module = [ "pytorch_lightning.profilers.base", "pytorch_lightning.profilers.pytorch", "pytorch_lightning.trainer.trainer", - "pytorch_lightning.tuner.batch_size_scaling", - "lightning_lite.utilities.data", + "pytorch_lightning.tuner.batch_size_scaling" ] ignore_errors = "True" diff --git a/src/lightning_lite/utilities/data.py b/src/lightning_lite/utilities/data.py index cdaf806a0c48d..9db95dd167d82 100644 --- a/src/lightning_lite/utilities/data.py +++ b/src/lightning_lite/utilities/data.py @@ -33,7 +33,7 @@ class _WrapAttrTag(LightningEnum): SET = "set" DEL = "del" - def __call__(self, *args): + def __call__(self, *args: Any) -> None: if self == self.SET: fn = setattr else: @@ -99,7 +99,7 @@ def _get_dataloader_init_args_and_kwargs( arg_names = () # get the dataloader instance `__init__` parameters - params = dict(inspect.signature(dataloader.__init__).parameters) + params = dict(inspect.signature(dataloader.__init__).parameters) # type: ignore[misc] has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values()) if has_variadic_kwargs: # if the signature takes **kwargs, 
assume they will be passed down with `super().__init__(**kwargs)` @@ -141,14 +141,14 @@ def _get_dataloader_init_args_and_kwargs( } # the dataloader has required args which we could not extract from the existing attributes if required_args: - required_args = sorted(required_args) + sorted_required_args = sorted(required_args) dataloader_cls_name = dataloader.__class__.__name__ - missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in required_args) + missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in sorted_required_args) raise MisconfigurationException( f"Trying to inject custom `Sampler` into the `{dataloader_cls_name}` instance. " "This would fail as some of the `__init__` arguments are not available as instance attributes. " - f"The missing attributes are {required_args}. If you instantiate your `{dataloader_cls_name}` inside a " - "`*_dataloader` hook of your module, we will do this for you." + f"The missing attributes are {sorted_required_args}. If you instantiate your `{dataloader_cls_name}` " + "inside a `*_dataloader` hook of your module, we will do this for you." f" Otherwise, define {missing_args_message} inside your `__init__`." ) @@ -156,13 +156,13 @@ def _get_dataloader_init_args_and_kwargs( # the dataloader signature does not allow keyword arguments that need to be passed missing_kwargs = (set(dl_kwargs) | set(arg_names)) - params.keys() if missing_kwargs: - missing_kwargs = sorted(missing_kwargs) + sorted_missing_kwargs = sorted(missing_kwargs) dataloader_cls_name = dataloader.__class__.__name__ raise TypeError( f"Trying to inject parameters into the `{dataloader_cls_name}` instance. " "This would fail as it doesn't expose all its attributes in the `__init__` signature. " - f"The missing arguments are {missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` class, " - "add the `__init__` arguments or allow passing `**kwargs`" + f"The missing arguments are {sorted_missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` " + "class, add the `__init__` arguments or allow passing `**kwargs`" ) return dl_args, dl_kwargs @@ -334,7 +334,7 @@ def _wrap_attr_method(method: Callable, tag: _WrapAttrTag) -> Callable: :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses.""" @functools.wraps(method) - def wrapper(obj: Any, *args: Any): + def wrapper(obj: Any, *args: Any) -> None: # First, let's find out if we're the first in inheritance chain calling the patched method. 
name, *_ = args prev_call_name, prev_call_method = getattr(obj, "__pl_current_call", (None, "method")) diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index c7e53f515ac18..8059a4b9f47a4 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -41,7 +41,7 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn -BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]] # type: ignore +BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]] warning_cache = WarningCache() @@ -50,7 +50,7 @@ class _WrapAttrTag(LightningEnum): SET = "set" DEL = "del" - def __call__(self, *args): + def __call__(self, *args: Any) -> None: if self == self.SET: fn = setattr else: @@ -58,7 +58,7 @@ def __call__(self, *args): return fn(*args) -def _extract_batch_size(batch: BType) -> Generator[int, None, None]: +def _extract_batch_size(batch: BType) -> Generator[Optional[int], None, None]: if isinstance(batch, Tensor): if batch.ndim == 0: yield 1 @@ -197,7 +197,7 @@ def _get_dataloader_init_args_and_kwargs( arg_names = () # get the dataloader instance `__init__` parameters - params = dict(inspect.signature(dataloader.__init__).parameters) # type: ignore + params = dict(inspect.signature(dataloader.__init__).parameters) # type: ignore[misc] has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values()) if has_variadic_kwargs: # if the signature takes **kwargs, assume they will be passed down with `super().__init__(**kwargs)` From a31ed671f7adc24da3fede24199bd45bd64355a2 Mon Sep 17 00:00:00 2001 From: otaj Date: Wed, 14 Sep 2022 11:24:37 +0200 Subject: [PATCH 5/6] finish typing --- src/lightning_lite/utilities/data.py | 13 +++++----- src/pytorch_lightning/strategies/ipu.py | 2 +- .../utilities/auto_restart.py | 2 +- src/pytorch_lightning/utilities/data.py | 24 +++++-------------- 4 files changed, 15 insertions(+), 26 deletions(-) diff --git a/src/lightning_lite/utilities/data.py b/src/lightning_lite/utilities/data.py index 9db95dd167d82..ca50344567b8e 100644 --- a/src/lightning_lite/utilities/data.py +++ b/src/lightning_lite/utilities/data.py @@ -21,7 +21,7 @@ from typing import Any, Callable, Dict, Generator, Iterable, Optional, Tuple, Type, Union from lightning_utilities.core.inheritance import get_all_subclasses -from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler +from torch.utils.data import BatchSampler, DataLoader, Dataset, IterableDataset, Sampler from lightning_lite.utilities.enums import LightningEnum from lightning_lite.utilities.exceptions import MisconfigurationException @@ -34,6 +34,7 @@ class _WrapAttrTag(LightningEnum): DEL = "del" def __call__(self, *args: Any) -> None: + fn: Union[Callable[[object, str], None], Callable[[object, str, Any], None]] if self == self.SET: fn = setattr else: @@ -45,12 +46,12 @@ def has_iterable_dataset(dataloader: DataLoader) -> bool: return hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset) -def has_len(dataloader: Union[DataLoader, Iterable]) -> bool: +def has_len(dataloader: Union[DataLoader, Iterable, Dataset]) -> bool: """Checks if a given Dataloader has ``__len__`` method implemented i.e. 
if it is a finite dataloader or infinite dataloader.""" try: # try getting the length - if len(dataloader) == 0: + if len(dataloader) == 0: # type: ignore [arg-type] rank_zero_warn( f"`{dataloader.__class__.__name__}` returned 0 length. Please make sure this was your intention." ) @@ -58,7 +59,7 @@ def has_len(dataloader: Union[DataLoader, Iterable]) -> bool: except (TypeError, NotImplementedError): has_len = False - if has_len and has_iterable_dataset(dataloader): + if has_len and isinstance(dataloader, DataLoader) and has_iterable_dataset(dataloader): rank_zero_warn( "Your `IterableDataset` has `__len__` defined." " In combination with multi-process data loading (when num_workers > 1)," @@ -76,7 +77,7 @@ def _update_dataloader(dataloader: DataLoader, sampler: Union[Sampler, Iterable] def _get_dataloader_init_args_and_kwargs( dataloader: DataLoader, - sampler: Optional[Sampler], + sampler: Union[Sampler, Iterable], disallow_batch_sampler: bool = False, ) -> Tuple[Tuple[Any], Dict[str, Any]]: if not isinstance(dataloader, DataLoader): @@ -170,7 +171,7 @@ def _get_dataloader_init_args_and_kwargs( def _dataloader_init_kwargs_resolve_sampler( dataloader: DataLoader, - sampler: Optional[Sampler], + sampler: Union[Sampler, Iterable], disallow_batch_sampler: bool = False, ) -> Dict[str, Any]: """This function is used to handle the sampler, batch_sampler arguments associated within a DataLoader for its diff --git a/src/pytorch_lightning/strategies/ipu.py b/src/pytorch_lightning/strategies/ipu.py index 64898e6c76251..966789a07feaa 100644 --- a/src/pytorch_lightning/strategies/ipu.py +++ b/src/pytorch_lightning/strategies/ipu.py @@ -245,7 +245,7 @@ def _convert_to_poptorch_loader( return dataloader dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs( - dataloader, sampler, mode, self.replication_factor > 1 # type: ignore[arg-type] + dataloader, sampler, mode, self.replication_factor > 1 ) opts = self.training_opts if mode == RunningStage.TRAINING else self.inference_opts dataloader = _reinstantiate_wrapped_cls( diff --git a/src/pytorch_lightning/utilities/auto_restart.py b/src/pytorch_lightning/utilities/auto_restart.py index d9d8c5da38858..34033b898f3be 100644 --- a/src/pytorch_lightning/utilities/auto_restart.py +++ b/src/pytorch_lightning/utilities/auto_restart.py @@ -62,7 +62,7 @@ class FastForwardSampler(Sampler): samples seen in the last iterations (for the current worker). 
""" - def __init__(self, sampler: Iterator, attr_name: Optional[str] = None) -> None: + def __init__(self, sampler: Union[Sampler, Iterable], attr_name: Optional[str] = None) -> None: super().__init__(data_source=None) self._sampler = sampler self.restarting: bool = False diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index 8059a4b9f47a4..9e1dc2fa51498 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -30,7 +30,6 @@ ) import pytorch_lightning as pl -from lightning_lite.utilities import LightningEnum from lightning_lite.utilities.data import _reinstantiate_wrapped_cls, _replace_value_in_saved_args from lightning_lite.utilities.data import has_iterable_dataset as new_has_iterable_dataset from lightning_lite.utilities.data import has_len as new_has_len @@ -41,23 +40,12 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_warn -BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]] +# might be supported in later releases, see https://github.com/python/mypy/pull/13297 +BType = Union[Tensor, str, Mapping[Any, "BType"], Iterable["BType"]] # type: ignore[misc] warning_cache = WarningCache() -class _WrapAttrTag(LightningEnum): - SET = "set" - DEL = "del" - - def __call__(self, *args: Any) -> None: - if self == self.SET: - fn = setattr - else: - fn = delattr - return fn(*args) - - def _extract_batch_size(batch: BType) -> Generator[Optional[int], None, None]: if isinstance(batch, Tensor): if batch.ndim == 0: @@ -109,7 +97,7 @@ def extract_batch_size(batch: BType) -> int: def has_len_all_ranks( dataloader: DataLoader, - strategy: "pl.Strategy", + strategy: "pl.strategies.Strategy", model: Union["pl.LightningModule", "pl.LightningDataModule"], ) -> bool: """Checks if a given Dataloader has ``__len__`` method implemented i.e. 
if it is a finite dataloader or @@ -158,7 +146,7 @@ def get_len(dataloader: Union[DataLoader, Dataset]) -> Union[int, float]: """ if new_has_len(dataloader): - return len(dataloader) + return len(dataloader) # type: ignore [arg-type] return float("inf") @@ -173,7 +161,7 @@ def _update_dataloader( def _get_dataloader_init_args_and_kwargs( dataloader: DataLoader, - sampler: Optional[Sampler], + sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None, disallow_batch_sampler: bool = False, ) -> Tuple[Tuple[Any], Dict[str, Any]]: @@ -273,7 +261,7 @@ def _get_dataloader_init_args_and_kwargs( def _dataloader_init_kwargs_resolve_sampler( dataloader: DataLoader, - sampler: Optional[Sampler], + sampler: Union[Sampler, Iterable], mode: Optional[RunningStage] = None, disallow_batch_sampler: bool = False, ) -> Dict[str, Any]: From 2e0f6db12759c322630e22067bd82c174f2b062c Mon Sep 17 00:00:00 2001 From: otaj Date: Wed, 14 Sep 2022 11:32:39 +0200 Subject: [PATCH 6/6] parity with lightning_lite --- src/pytorch_lightning/utilities/data.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index 9e1dc2fa51498..17f8b9f101cdd 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -217,7 +217,7 @@ def _get_dataloader_init_args_and_kwargs( else: dl_kwargs.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler, mode, disallow_batch_sampler)) - required_args: Union[list[str], set[str]] = { + required_args = { p.name for p in params.values() if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD) @@ -227,28 +227,28 @@ def _get_dataloader_init_args_and_kwargs( } # the dataloader has required args which we could not extract from the existing attributes if required_args: - required_args = sorted(required_args) + sorted_required_args = sorted(required_args) dataloader_cls_name = dataloader.__class__.__name__ - missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in required_args) + missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in sorted_required_args) raise MisconfigurationException( f"Trying to inject custom `Sampler` into the `{dataloader_cls_name}` instance. " "This would fail as some of the `__init__` arguments are not available as instance attributes. " - f"The missing attributes are {required_args}. If you instantiate your `{dataloader_cls_name}` inside a " - "`*_dataloader` hook of your module, we will do this for you." + f"The missing attributes are {sorted_required_args}. If you instantiate your `{dataloader_cls_name}` " + "inside a `*_dataloader` hook of your module, we will do this for you." f" Otherwise, define {missing_args_message} inside your `__init__`." ) if not has_variadic_kwargs: # the dataloader signature does not allow keyword arguments that need to be passed - missing_kwargs: Union[list[str], set[str]] = (set(dl_kwargs) | set(arg_names)) - params.keys() + missing_kwargs = (set(dl_kwargs) | set(arg_names)) - params.keys() if missing_kwargs: - missing_kwargs = sorted(missing_kwargs) + sorted_missing_kwargs = sorted(missing_kwargs) dataloader_cls_name = dataloader.__class__.__name__ raise MisconfigurationException( f"Trying to inject parameters into the `{dataloader_cls_name}` instance. " "This would fail as it doesn't expose all its attributes in the `__init__` signature. " - f"The missing arguments are {missing_kwargs}. 
HINT: If you wrote the `{dataloader_cls_name}` class, " - "add the `__init__` arguments or allow passing `**kwargs`" + f"The missing arguments are {sorted_missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` " + "class, add the `__init__` arguments or allow passing `**kwargs`" ) if _FaultTolerantMode.detect_current_mode().is_automatic:
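For reference, a minimal usage sketch of the `has_len` helper whose parameter type the series widens to `Union[DataLoader, Iterable, Dataset]`. This is a sketch only, not part of any patch above; it assumes an installed build that exposes `lightning_lite.utilities.data.has_len` in the layout these patches use.

import torch
from torch.utils.data import DataLoader, TensorDataset

# Import path assumed from the lightning_lite layout shown in the diffs above.
from lightning_lite.utilities.data import has_len

# A finite DataLoader defines __len__, so has_len reports True.
finite_loader = DataLoader(TensorDataset(torch.randn(8, 2)), batch_size=4)
print(has_len(finite_loader))  # True


class InfiniteIterable:
    # A plain iterable without __len__: len() raises TypeError inside has_len,
    # which the widened signature now accepts and reports as False.
    def __iter__(self):
        while True:
            yield torch.zeros(1)


print(has_len(InfiniteIterable()))  # False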