From a36a0f0e05bfca55fd244ec517ff861d9aa9ff8f Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Sat, 23 Jul 2022 07:11:39 -0400 Subject: [PATCH 01/29] remove module from pyproject.toml --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index a0960c58f6e6d..93e2db1e45b79 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -73,7 +73,6 @@ module = [ "pytorch_lightning.strategies.tpu_spawn", "pytorch_lightning.trainer.callback_hook", "pytorch_lightning.trainer.connectors.callback_connector", - "pytorch_lightning.trainer.connectors.data_connector", "pytorch_lightning.trainer.supporters", "pytorch_lightning.trainer.trainer", "pytorch_lightning.tuner.batch_size_scaling", From 8a39011d311f59a196d0516add1b98677b5b72c9 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Sat, 23 Jul 2022 11:10:05 -0400 Subject: [PATCH 02/29] update --- src/pytorch_lightning/core/module.py | 2 +- .../trainer/connectors/data_connector.py | 18 ++++++++++++------ src/pytorch_lightning/trainer/trainer.py | 4 ++-- 3 files changed, 15 insertions(+), 9 deletions(-) diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py index a66c7679b3ee0..9a8ba2bdebd80 100644 --- a/src/pytorch_lightning/core/module.py +++ b/src/pytorch_lightning/core/module.py @@ -99,7 +99,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None: self._use_amp: bool = False # the precision used - self.precision: int = 32 + self.precision: Union[int, str] = 32 # optionally can be set by user self._example_input_array = None diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index add62ceece65c..75f4506e91f4a 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -45,6 +45,8 @@ warning_cache = WarningCache() +RESOLVE_OVERFIT_BATCH_DATALOADER_TYPE = Union[Collection[DataLoader], Union[DataLoader[Any], List[DataLoader[Any]]]] + class DataConnector: def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"): @@ -224,7 +226,7 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None: category=PossibleUserWarning, ) - def _requires_distributed_sampler(self, dataloader) -> bool: + def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool: return ( self.trainer._accelerator_connector.replace_sampler_ddp and self.trainer._accelerator_connector.is_distributed @@ -435,10 +437,12 @@ def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[Dat return dataloader @staticmethod - def _resolve_overfit_batches(dataloaders: Collection[DataLoader], mode: RunningStage) -> Collection[DataLoader]: + def _resolve_overfit_batches( + dataloaders: RESOLVE_OVERFIT_BATCH_DATALOADER_TYPE, mode: RunningStage + ) -> RESOLVE_OVERFIT_BATCH_DATALOADER_TYPE: all_have_sequential_sampler = True - def resolve_has_no_sequential_sampler(dataloader: DataLoader): + def resolve_has_no_sequential_sampler(dataloader: DataLoader) -> None: nonlocal all_have_sequential_sampler all_have_sequential_sampler = all_have_sequential_sampler & isinstance( dataloader.sampler, SequentialSampler @@ -453,14 +457,16 @@ def resolve_has_no_sequential_sampler(dataloader: DataLoader): ) def replace_sampler(dataloader: DataLoader) -> DataLoader: - return _update_dataloader(dataloader, 
sampler=SequentialSampler(dataloader.dataset), mode=mode) + return _update_dataloader( + dataloader, sampler=SequentialSampler(dataloader.dataset), mode=mode + ) # type: ignore[arg-type] dataloaders = apply_to_collection(dataloaders, DataLoader, replace_sampler) return dataloaders @staticmethod - def _check_eval_shuffling(dataloader, mode): + def _check_eval_shuffling(dataloader: DataLoader, mode: RunningStage) -> None: # limit this warning only for samplers assigned automatically when shuffle is set if _is_dataloader_shuffled(dataloader): rank_zero_warn( @@ -539,7 +545,7 @@ class _DataHookSelector: datamodule: A LightningDataModule """ - model: "pl.LightningModule" + model: Optional["pl.LightningModule"] datamodule: Optional["pl.LightningDataModule"] _valid_hooks: Tuple[str] = field( default=("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer") diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index d10225fea2d65..b7ff94c35fc09 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -25,7 +25,7 @@ from datetime import timedelta from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Type, Union +from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Type, Union from weakref import proxy import torch @@ -2229,7 +2229,7 @@ def is_global_zero(self) -> bool: return self.strategy.is_global_zero @property - def distributed_sampler_kwargs(self) -> Optional[dict]: + def distributed_sampler_kwargs(self) -> Mapping: if isinstance(self.strategy, ParallelStrategy): return self.strategy.distributed_sampler_kwargs From bf5e03723d720e5caca19f63653a24ad770bc692 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Sat, 23 Jul 2022 11:25:55 -0400 Subject: [PATCH 03/29] update --- src/pytorch_lightning/trainer/connectors/data_connector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index 75f4506e91f4a..d864067cf836f 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -458,8 +458,8 @@ def resolve_has_no_sequential_sampler(dataloader: DataLoader) -> None: def replace_sampler(dataloader: DataLoader) -> DataLoader: return _update_dataloader( - dataloader, sampler=SequentialSampler(dataloader.dataset), mode=mode - ) # type: ignore[arg-type] + dataloader, sampler=SequentialSampler(dataloader.dataset), mode=mode # type: ignore[arg-type] + ) dataloaders = apply_to_collection(dataloaders, DataLoader, replace_sampler) From ca0ceab5ab2051a968b58c6361f465ee610e325f Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Sat, 23 Jul 2022 11:49:48 -0400 Subject: [PATCH 04/29] update --- .../trainer/connectors/data_connector.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index d864067cf836f..5f277a0c2d456 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -14,7 +14,7 @@ import multiprocessing import os from dataclasses import dataclass, 
field -from typing import Any, Callable, Collection, List, Optional, Tuple, Union +from typing import Any, Callable, Collection, Iterable, List, Optional, Tuple, TypeAlias, Union from weakref import proxy from torch.utils.data import DataLoader, Sampler, SequentialSampler @@ -288,7 +288,9 @@ def _prepare_dataloader( return dataloader - def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None) -> Sampler: + def _resolve_sampler( + self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None + ) -> Union[Sampler, Iterable]: if self._requires_distributed_sampler(dataloader): sampler = self._get_distributed_sampler( dataloader, @@ -495,7 +497,7 @@ class _DataLoaderSource: instance: Optional[Union[TRAIN_DATALOADERS, EVAL_DATALOADERS, "pl.LightningModule", "pl.LightningDataModule"]] name: str - def dataloader(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: + def dataloader(self) -> Union[TypeAlias["instance"], TRAIN_DATALOADERS, EVAL_DATALOADERS]: """Returns the dataloader from the source. If the source is a module, the method with the corresponding :attr:`name` gets called. From 96d8f1f2702388de40779ec0a7b69181303b86c5 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Sat, 23 Jul 2022 12:18:56 -0400 Subject: [PATCH 05/29] update --- .../trainer/connectors/data_connector.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index 5f277a0c2d456..8a255f565fa82 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -45,11 +45,12 @@ warning_cache = WarningCache() -RESOLVE_OVERFIT_BATCH_DATALOADER_TYPE = Union[Collection[DataLoader], Union[DataLoader[Any], List[DataLoader[Any]]]] +BATCH_DATALOADER = Union[Collection[DataLoader], List[DataLoader], Union[DataLoader[Any], List[DataLoader[Any]]]] +REQUEST_DATALOADER = Union[DataLoader, List[DataLoader], BATCH_DATALOADER] class DataConnector: - def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"): + def __init__(self, trainer: TypeAlias["pl.Trainer"], multiple_trainloader_mode: str = "max_size_cycle"): self.trainer = trainer self.multiple_trainloader_mode = multiple_trainloader_mode self._train_dataloader_source = _DataLoaderSource(None, "") @@ -143,7 +144,7 @@ def attach_data( # set local properties on the model self._copy_trainer_model_properties(model) - def _copy_trainer_model_properties(self, model: "pl.LightningModule") -> None: + def _copy_trainer_model_properties(self, model: TypeAlias["pl.LightningModule"]) -> None: model.trainer = proxy(self.trainer) # Remove setting use_amp in v1.8 model._use_amp = self.trainer.amp_backend is not None @@ -420,7 +421,7 @@ def _reset_eval_dataloader( return loader_num_batches, dataloaders - def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[DataLoader]]: + def _request_dataloader(self, stage: RunningStage) -> REQUEST_DATALOADER: """Requests a dataloader from the given model by calling dataloader hooks corresponding to the given stage. 
Returns: @@ -439,9 +440,7 @@ def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[Dat return dataloader @staticmethod - def _resolve_overfit_batches( - dataloaders: RESOLVE_OVERFIT_BATCH_DATALOADER_TYPE, mode: RunningStage - ) -> RESOLVE_OVERFIT_BATCH_DATALOADER_TYPE: + def _resolve_overfit_batches(dataloaders: BATCH_DATALOADER, mode: RunningStage) -> BATCH_DATALOADER: all_have_sequential_sampler = True def resolve_has_no_sequential_sampler(dataloader: DataLoader) -> None: @@ -549,7 +548,7 @@ class _DataHookSelector: model: Optional["pl.LightningModule"] datamodule: Optional["pl.LightningDataModule"] - _valid_hooks: Tuple[str] = field( + _valid_hooks: Tuple[str, str, str] = field( default=("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer") ) From 477a4a0e230b9e168cd5f158a3c6cceb90c3ae30 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Sat, 23 Jul 2022 12:25:21 -0400 Subject: [PATCH 06/29] update --- src/pytorch_lightning/trainer/connectors/data_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index 8a255f565fa82..5cf64b3830796 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -352,7 +352,7 @@ def _reset_eval_dataloader( dataloaders = self._resolve_overfit_batches(dataloaders, mode) if not isinstance(dataloaders, list): - dataloaders = [dataloaders] + dataloaders = [dataloaders] # type: ignore[list-item] if any(dl is None for dl in dataloaders): rank_zero_warn("One of given dataloaders is None and it will be skipped.") From 09b157b3157c847f3130d6d6435da643ca3ef825 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Sat, 23 Jul 2022 12:40:33 -0400 Subject: [PATCH 07/29] clean --- src/pytorch_lightning/trainer/connectors/data_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index 5cf64b3830796..bfa25aba2c026 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -45,7 +45,7 @@ warning_cache = WarningCache() -BATCH_DATALOADER = Union[Collection[DataLoader], List[DataLoader], Union[DataLoader[Any], List[DataLoader[Any]]]] +BATCH_DATALOADER = Union[Collection[DataLoader], List[DataLoader], DataLoader[Any]] REQUEST_DATALOADER = Union[DataLoader, List[DataLoader], BATCH_DATALOADER] From 87cb7a66d3ed03cdca5ed8ae60d211a7049e329e Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Mon, 25 Jul 2022 06:25:31 -0400 Subject: [PATCH 08/29] update --- src/pytorch_lightning/trainer/connectors/data_connector.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index bfa25aba2c026..826ed3df0174e 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -14,11 +14,12 @@ import multiprocessing import os from dataclasses import dataclass, field -from typing import Any, Callable, Collection, Iterable, List, Optional, Tuple, 
TypeAlias, Union +from typing import Any, Callable, Collection, Iterable, List, Optional, Tuple, Union from weakref import proxy from torch.utils.data import DataLoader, Sampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler +from typing_extensions import TypeAlias import pytorch_lightning as pl from pytorch_lightning.accelerators.ipu import IPUAccelerator From 6a88e79a884d1dce2e4a0b9e839a66b34e9f90e4 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Mon, 25 Jul 2022 06:53:19 -0400 Subject: [PATCH 09/29] update --- src/pytorch_lightning/core/datamodule.py | 3 ++- src/pytorch_lightning/trainer/connectors/data_connector.py | 4 ++-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/pytorch_lightning/core/datamodule.py b/src/pytorch_lightning/core/datamodule.py index 60a010ff7c3b9..286031b68ca2d 100644 --- a/src/pytorch_lightning/core/datamodule.py +++ b/src/pytorch_lightning/core/datamodule.py @@ -17,6 +17,7 @@ from torch.utils.data import DataLoader, Dataset, IterableDataset +import pytorch_lightning as pl from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks from pytorch_lightning.core.mixins import HyperparametersMixin from pytorch_lightning.core.saving import _load_from_checkpoint @@ -61,7 +62,7 @@ def teardown(self): def __init__(self) -> None: super().__init__() # Pointer to the trainer object - self.trainer = None + self.trainer: Optional["pl.Trainer"] = None @classmethod def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser: diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index 826ed3df0174e..87b10e1601512 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -51,7 +51,7 @@ class DataConnector: - def __init__(self, trainer: TypeAlias["pl.Trainer"], multiple_trainloader_mode: str = "max_size_cycle"): + def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"): self.trainer = trainer self.multiple_trainloader_mode = multiple_trainloader_mode self._train_dataloader_source = _DataLoaderSource(None, "") @@ -162,7 +162,7 @@ def attach_dataloaders( self.trainer.train_dataloader = None self.trainer.val_dataloaders = None self.trainer.test_dataloaders = None - self.trainer.predict_dataloaders = None + self.trainer.predict_dataloaders = None # type: ignore[assignment] self._train_dataloader_source = _DataLoaderSource( train_dataloaders if train_dataloaders is not None else model, "train_dataloader" From 82ee684a0b76486eee08b0b9f7462762edf3ae51 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Mon, 25 Jul 2022 06:59:26 -0400 Subject: [PATCH 10/29] update --- src/pytorch_lightning/trainer/connectors/data_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index 87b10e1601512..6b0c999bf002a 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -145,7 +145,7 @@ def attach_data( # set local properties on the model self._copy_trainer_model_properties(model) - def _copy_trainer_model_properties(self, model: TypeAlias["pl.LightningModule"]) -> None: + def 
_copy_trainer_model_properties(self, model: "pl.LightningModule") -> None: model.trainer = proxy(self.trainer) # Remove setting use_amp in v1.8 model._use_amp = self.trainer.amp_backend is not None From 6020d0b3ec9afa56ee34757adea889163eac16bc Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Mon, 25 Jul 2022 07:20:04 -0400 Subject: [PATCH 11/29] update --- src/pytorch_lightning/trainer/connectors/data_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index 6b0c999bf002a..fb7b2bbc7e6b8 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -549,7 +549,7 @@ class _DataHookSelector: model: Optional["pl.LightningModule"] datamodule: Optional["pl.LightningDataModule"] - _valid_hooks: Tuple[str, str, str] = field( + _valid_hooks: Tuple[str, ...] = field( default=("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer") ) From 62d1c83b374878aae1ac373b36ad605aa61ff2a9 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Mon, 25 Jul 2022 07:46:28 -0400 Subject: [PATCH 12/29] update --- src/pytorch_lightning/trainer/connectors/data_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index fb7b2bbc7e6b8..aaf3e5f10777f 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -353,7 +353,7 @@ def _reset_eval_dataloader( dataloaders = self._resolve_overfit_batches(dataloaders, mode) if not isinstance(dataloaders, list): - dataloaders = [dataloaders] # type: ignore[list-item] + dataloaders = list(dataloaders) if any(dl is None for dl in dataloaders): rank_zero_warn("One of given dataloaders is None and it will be skipped.") From 1aebec4fe5f7cb291784e2164367c60a578b8f79 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Mon, 25 Jul 2022 08:10:25 -0400 Subject: [PATCH 13/29] update --- src/pytorch_lightning/trainer/connectors/data_connector.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index aaf3e5f10777f..d45bca6bcc66a 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -19,7 +19,6 @@ from torch.utils.data import DataLoader, Sampler, SequentialSampler from torch.utils.data.distributed import DistributedSampler -from typing_extensions import TypeAlias import pytorch_lightning as pl from pytorch_lightning.accelerators.ipu import IPUAccelerator @@ -48,6 +47,7 @@ BATCH_DATALOADER = Union[Collection[DataLoader], List[DataLoader], DataLoader[Any]] REQUEST_DATALOADER = Union[DataLoader, List[DataLoader], BATCH_DATALOADER] +INSTANCE = Optional[Union[TRAIN_DATALOADERS, EVAL_DATALOADERS, "pl.LightningModule", "pl.LightningDataModule"]] class DataConnector: @@ -497,7 +497,7 @@ class _DataLoaderSource: instance: Optional[Union[TRAIN_DATALOADERS, EVAL_DATALOADERS, "pl.LightningModule", "pl.LightningDataModule"]] name: str - def dataloader(self) -> 
Union[TypeAlias["instance"], TRAIN_DATALOADERS, EVAL_DATALOADERS]: + def dataloader(self) -> Union[INSTANCE, TRAIN_DATALOADERS, EVAL_DATALOADERS]: """Returns the dataloader from the source. If the source is a module, the method with the corresponding :attr:`name` gets called. From 17cf37becdd3eca265ce18955952c4487907abc4 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Mon, 25 Jul 2022 08:13:02 -0400 Subject: [PATCH 14/29] clean --- src/pytorch_lightning/trainer/connectors/data_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index d45bca6bcc66a..a5f97f7dd9658 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -494,7 +494,7 @@ class _DataLoaderSource: that returns the desired dataloader(s). """ - instance: Optional[Union[TRAIN_DATALOADERS, EVAL_DATALOADERS, "pl.LightningModule", "pl.LightningDataModule"]] + instance: INSTANCE name: str def dataloader(self) -> Union[INSTANCE, TRAIN_DATALOADERS, EVAL_DATALOADERS]: From c4039b1bc929aeb96049cfbd85e3b032f2166f14 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Mon, 25 Jul 2022 08:15:24 -0400 Subject: [PATCH 15/29] update --- src/pytorch_lightning/trainer/connectors/data_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index a5f97f7dd9658..b104f3e491aa1 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -353,7 +353,7 @@ def _reset_eval_dataloader( dataloaders = self._resolve_overfit_batches(dataloaders, mode) if not isinstance(dataloaders, list): - dataloaders = list(dataloaders) + dataloaders = [dataloaders] # type: ignore[list-item] if any(dl is None for dl in dataloaders): rank_zero_warn("One of given dataloaders is None and it will be skipped.") From b0cb0409a05469990f4805b953ed937f06707bd3 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Mon, 25 Jul 2022 08:26:01 -0400 Subject: [PATCH 16/29] revert dataloader --- src/pytorch_lightning/trainer/connectors/data_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index b104f3e491aa1..dbe26284ee6b6 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -497,7 +497,7 @@ class _DataLoaderSource: instance: INSTANCE name: str - def dataloader(self) -> Union[INSTANCE, TRAIN_DATALOADERS, EVAL_DATALOADERS]: + def dataloader(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: """Returns the dataloader from the source. If the source is a module, the method with the corresponding :attr:`name` gets called. 
From dfca1fe56087253e8fc3fb2bab52b7bdcce72377 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Thu, 28 Jul 2022 06:04:41 -0400 Subject: [PATCH 17/29] update --- src/pytorch_lightning/trainer/connectors/data_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index e20b3960c3323..9fc82fa1f6c59 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -497,7 +497,7 @@ class _DataLoaderSource: instance: INSTANCE name: str - def dataloader(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: + def dataloader(self) -> Union[INSTANCE, TRAIN_DATALOADERS, EVAL_DATALOADERS]: """Returns the dataloader from the source. If the source is a module, the method with the corresponding :attr:`name` gets called. From 66620b23f3b46acd86b113a052a62b7cf3a2cb93 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Thu, 28 Jul 2022 06:23:56 -0400 Subject: [PATCH 18/29] update --- src/pytorch_lightning/strategies/deepspeed.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index e7fbcf91967fc..e5a8f6c1e14e8 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -674,7 +674,7 @@ def _auto_select_batch_size(self) -> int: try: train_dataloader = train_dl_source.dataloader() if hasattr(train_dataloader, "batch_sampler"): - batch_size = train_dataloader.batch_sampler.batch_size # type: ignore[union-attr] + batch_size = train_dataloader.batch_sampler.batch_size # type: ignore[union-attr, assignment] # broad exception on purpose as `source.dataloader()` will fail if the dataloader requires `setup` # to have been called before except Exception: From 88a8ab5d355f53f306af7cd8ce55fbc65e35bc94 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Tue, 2 Aug 2022 12:50:49 -0400 Subject: [PATCH 19/29] update --- src/pytorch_lightning/strategies/deepspeed.py | 2 +- src/pytorch_lightning/trainer/connectors/data_connector.py | 6 +++--- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/src/pytorch_lightning/strategies/deepspeed.py b/src/pytorch_lightning/strategies/deepspeed.py index e5a8f6c1e14e8..e7fbcf91967fc 100644 --- a/src/pytorch_lightning/strategies/deepspeed.py +++ b/src/pytorch_lightning/strategies/deepspeed.py @@ -674,7 +674,7 @@ def _auto_select_batch_size(self) -> int: try: train_dataloader = train_dl_source.dataloader() if hasattr(train_dataloader, "batch_sampler"): - batch_size = train_dataloader.batch_sampler.batch_size # type: ignore[union-attr, assignment] + batch_size = train_dataloader.batch_sampler.batch_size # type: ignore[union-attr] # broad exception on purpose as `source.dataloader()` will fail if the dataloader requires `setup` # to have been called before except Exception: diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index cb3ef64243bb1..b1f6fa976864b 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -497,7 +497,7 @@ class _DataLoaderSource: instance: INSTANCE name: str - def 
dataloader(self) -> Union[INSTANCE, TRAIN_DATALOADERS, EVAL_DATALOADERS]: + def dataloader(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: """Returns the dataloader from the source. If the source is a module, the method with the corresponding :attr:`name` gets called. @@ -505,7 +505,7 @@ def dataloader(self) -> Union[INSTANCE, TRAIN_DATALOADERS, EVAL_DATALOADERS]: from pytorch_lightning import LightningDataModule, LightningModule # prevent cyclic import if not self.name: - return self.instance + return self.instance # type: ignore[return-value] if isinstance(self.instance, LightningModule): return self.instance.trainer._call_lightning_module_hook(self.name, pl_module=self.instance) @@ -514,7 +514,7 @@ def dataloader(self) -> Union[INSTANCE, TRAIN_DATALOADERS, EVAL_DATALOADERS]: method = getattr(self.instance, self.name) return method() - return self.instance + return self.instance # type: ignore[return-value] def is_defined(self) -> bool: """Returns whether the source dataloader can be retrieved or not. From c2431aeb555bace222bbe470d8a60c381458f125 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Wed, 3 Aug 2022 08:02:33 -0400 Subject: [PATCH 20/29] clean --- src/pytorch_lightning/trainer/connectors/data_connector.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index b1f6fa976864b..e214182c97b8c 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -47,7 +47,6 @@ BATCH_DATALOADER = Union[Collection[DataLoader], List[DataLoader], DataLoader[Any]] REQUEST_DATALOADER = Union[DataLoader, List[DataLoader], BATCH_DATALOADER] -INSTANCE = Optional[Union[TRAIN_DATALOADERS, EVAL_DATALOADERS, "pl.LightningModule", "pl.LightningDataModule"]] class DataConnector: @@ -494,7 +493,7 @@ class _DataLoaderSource: that returns the desired dataloader(s). 
""" - instance: INSTANCE + instance: Optional[Union[TRAIN_DATALOADERS, EVAL_DATALOADERS, "pl.LightningModule", "pl.LightningDataModule"]] name: str def dataloader(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: From 698b02620a59b790654ddc66c7c59962c5d0b1e6 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 10 Aug 2022 01:15:37 +0530 Subject: [PATCH 21/29] fix another --- .../loops/dataloader/prediction_loop.py | 3 ++- .../loops/epoch/prediction_epoch_loop.py | 3 ++- .../trainer/connectors/data_connector.py | 18 ++++++------------ src/pytorch_lightning/trainer/trainer.py | 4 ++-- 4 files changed, 12 insertions(+), 16 deletions(-) diff --git a/src/pytorch_lightning/loops/dataloader/prediction_loop.py b/src/pytorch_lightning/loops/dataloader/prediction_loop.py index f026ae5c7c813..2faf66a3e0c3e 100644 --- a/src/pytorch_lightning/loops/dataloader/prediction_loop.py +++ b/src/pytorch_lightning/loops/dataloader/prediction_loop.py @@ -60,7 +60,8 @@ def max_batches(self) -> List[int]: @property def dataloaders(self) -> Sequence[DataLoader]: """Returns all prediction dataloaders.""" - return self.trainer.predict_dataloaders + dataloaders = self.trainer.predict_dataloaders + return [] if dataloaders is None else dataloaders @property def skip(self) -> bool: diff --git a/src/pytorch_lightning/loops/epoch/prediction_epoch_loop.py b/src/pytorch_lightning/loops/epoch/prediction_epoch_loop.py index ba16c56feee4c..015f44d55795d 100644 --- a/src/pytorch_lightning/loops/epoch/prediction_epoch_loop.py +++ b/src/pytorch_lightning/loops/epoch/prediction_epoch_loop.py @@ -162,8 +162,9 @@ def _get_batch_indices(self, dataloader_idx: int) -> List[List[int]]: """Returns a reference to the seen batch indices if the dataloader has a batch sampler wrapped by our :class:`~pytorch_lightning.overrides.distributed.IndexBatchSamplerWrapper`.""" # the batch_sampler is not be defined in case of CombinedDataLoaders + assert self.trainer.predict_dataloaders batch_sampler = getattr( - self.trainer.predict_dataloaders[dataloader_idx], # type: ignore[has-type] + self.trainer.predict_dataloaders[dataloader_idx], "batch_sampler", None, ) diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index e214182c97b8c..dc740d7f0249a 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -161,7 +161,7 @@ def attach_dataloaders( self.trainer.train_dataloader = None self.trainer.val_dataloaders = None self.trainer.test_dataloaders = None - self.trainer.predict_dataloaders = None # type: ignore[assignment] + self.trainer.predict_dataloaders = None self._train_dataloader_source = _DataLoaderSource( train_dataloaders if train_dataloaders is not None else model, "train_dataloader" @@ -501,19 +501,15 @@ def dataloader(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: If the source is a module, the method with the corresponding :attr:`name` gets called. 
""" - from pytorch_lightning import LightningDataModule, LightningModule # prevent cyclic import - - if not self.name: - return self.instance # type: ignore[return-value] - - if isinstance(self.instance, LightningModule): + if isinstance(self.instance, pl.LightningModule): return self.instance.trainer._call_lightning_module_hook(self.name, pl_module=self.instance) - if isinstance(self.instance, LightningDataModule): + if isinstance(self.instance, pl.LightningDataModule): method = getattr(self.instance, self.name) return method() - return self.instance # type: ignore[return-value] + assert self.instance + return self.instance def is_defined(self) -> bool: """Returns whether the source dataloader can be retrieved or not. @@ -527,9 +523,7 @@ def is_module(self) -> bool: It does not check whether ``*_dataloader`` methods are actually overridden. """ - from pytorch_lightning import LightningDataModule, LightningModule # prevent cyclic import - - return isinstance(self.instance, (LightningModule, LightningDataModule)) + return isinstance(self.instance, (pl.LightningModule, pl.LightningDataModule)) @dataclass diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index fba1a0e41ad32..3ac88f05ce338 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -628,13 +628,13 @@ def _setup_on_init(self) -> None: self.num_sanity_val_batches = [] self.num_test_batches = [] self.num_val_batches = [] + self.num_predict_batches = [] self.test_dataloaders = None self.val_dataloaders = None + self.predict_dataloaders = None self._last_train_dl_reload_epoch = float("-inf") self._last_val_dl_reload_epoch = float("-inf") - self.num_predict_batches = [] - def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: Any) -> Any: r""" Error handling, intended to be used only for main trainer function entry points (fit, validate, test, predict) From 4c9e08064bcb6dbbafefd44754b6851b9a8381d8 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 10 Aug 2022 01:48:51 +0530 Subject: [PATCH 22/29] fix another 2 --- .../trainer/connectors/data_connector.py | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index dc740d7f0249a..53a549ce6307f 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -14,7 +14,7 @@ import multiprocessing import os from dataclasses import dataclass, field -from typing import Any, Callable, Collection, Iterable, List, Optional, Tuple, Union +from typing import Any, Callable, cast, Iterable, List, Optional, Tuple, Union from weakref import proxy from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSampler @@ -45,9 +45,6 @@ warning_cache = WarningCache() -BATCH_DATALOADER = Union[Collection[DataLoader], List[DataLoader], DataLoader[Any]] -REQUEST_DATALOADER = Union[DataLoader, List[DataLoader], BATCH_DATALOADER] - class DataConnector: def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_size_cycle"): @@ -347,13 +344,14 @@ def _reset_eval_dataloader( # always get the loaders first so we can count how many there are dataloaders = self._request_dataloader(mode) - - if self.trainer.overfit_batches > 0: - dataloaders = self._resolve_overfit_batches(dataloaders, mode) + dataloaders = cast(EVAL_DATALOADERS, dataloaders) if 
not isinstance(dataloaders, list): dataloaders = [dataloaders] # type: ignore[list-item] + if self.trainer.overfit_batches > 0: + dataloaders = self._resolve_overfit_batches(dataloaders, mode) + if any(dl is None for dl in dataloaders): rank_zero_warn("One of given dataloaders is None and it will be skipped.") @@ -421,7 +419,7 @@ def _reset_eval_dataloader( return loader_num_batches, dataloaders - def _request_dataloader(self, stage: RunningStage) -> REQUEST_DATALOADER: + def _request_dataloader(self, stage: RunningStage) -> TRAIN_DATALOADERS: """Requests a dataloader from the given model by calling dataloader hooks corresponding to the given stage. Returns: @@ -440,7 +438,7 @@ def _request_dataloader(self, stage: RunningStage) -> REQUEST_DATALOADER: return dataloader @staticmethod - def _resolve_overfit_batches(dataloaders: BATCH_DATALOADER, mode: RunningStage) -> BATCH_DATALOADER: + def _resolve_overfit_batches(dataloaders: EVAL_DATALOADERS, mode: RunningStage) -> EVAL_DATALOADERS: all_have_sequential_sampler = True def resolve_has_no_sequential_sampler(dataloader: DataLoader) -> None: From ca4d8cb48b1bd3ff9163351588d0d7bfb583fcbb Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 10 Aug 2022 14:00:19 +0530 Subject: [PATCH 23/29] fix --- src/pytorch_lightning/trainer/connectors/data_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index 53a549ce6307f..7015ec44bad58 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -506,7 +506,7 @@ def dataloader(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: method = getattr(self.instance, self.name) return method() - assert self.instance + assert self.instance is not None return self.instance def is_defined(self) -> bool: From 873aa3ad2bb0f277f010240b6844582d1f0091c2 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 16 Aug 2022 23:15:26 +0530 Subject: [PATCH 24/29] fix mypy --- src/pytorch_lightning/trainer/configuration_validator.py | 2 ++ src/pytorch_lightning/trainer/connectors/data_connector.py | 4 ++-- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/trainer/configuration_validator.py b/src/pytorch_lightning/trainer/configuration_validator.py index 6cf3e6d52ed95..0c78c2815421a 100644 --- a/src/pytorch_lightning/trainer/configuration_validator.py +++ b/src/pytorch_lightning/trainer/configuration_validator.py @@ -153,6 +153,8 @@ def __verify_batch_transfer_support(trainer: "pl.Trainer", model: "pl.LightningM """Raise Misconfiguration exception since these hooks are not supported in DP mode.""" batch_transfer_hooks = ("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer") datahook_selector = trainer._data_connector._datahook_selector + assert datahook_selector is not None + for hook in batch_transfer_hooks: # TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode. 
if isinstance(trainer.strategy, DataParallelStrategy) and ( diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index c76d0d2dbf0d4..0b3c25b48546a 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -55,7 +55,7 @@ def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_ self._test_dataloader_source = _DataLoaderSource(None, "") self._predict_dataloader_source = _DataLoaderSource(None, "") - self._datahook_selector = _DataHookSelector(None, None) + self._datahook_selector: Optional[_DataHookSelector] = None @property def _should_reload_train_dl(self) -> bool: @@ -542,7 +542,7 @@ class _DataHookSelector: datamodule: A ``LightningDataModule`` """ - model: Optional["pl.LightningModule"] + model: "pl.LightningModule" datamodule: Optional["pl.LightningDataModule"] _valid_hooks: Tuple[str, ...] = field( default=("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer") From bc83d7ce4000d22bd44bd09c93172dd675113c9a Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Tue, 16 Aug 2022 23:16:05 +0530 Subject: [PATCH 25/29] redundant arg --- src/pytorch_lightning/trainer/configuration_validator.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/trainer/configuration_validator.py b/src/pytorch_lightning/trainer/configuration_validator.py index 0c78c2815421a..20bf67e5ec5e6 100644 --- a/src/pytorch_lightning/trainer/configuration_validator.py +++ b/src/pytorch_lightning/trainer/configuration_validator.py @@ -46,7 +46,7 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None: elif trainer.state.fn == TrainerFn.PREDICTING: __verify_eval_loop_configuration(trainer, model, "predict") - __verify_batch_transfer_support(trainer, model) + __verify_batch_transfer_support(trainer) _check_deprecated_callback_hooks(trainer) # TODO: Delete _check_on_hpc_hooks in v1.8 _check_on_hpc_hooks(model) @@ -149,7 +149,7 @@ def __verify_eval_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightning raise MisconfigurationException(f"No `{step_name}()` method defined to run `Trainer.{trainer_method}`.") -def __verify_batch_transfer_support(trainer: "pl.Trainer", model: "pl.LightningModule") -> None: +def __verify_batch_transfer_support(trainer: "pl.Trainer") -> None: """Raise Misconfiguration exception since these hooks are not supported in DP mode.""" batch_transfer_hooks = ("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer") datahook_selector = trainer._data_connector._datahook_selector From e861a5eb50c6f57560fd6dac52e6a8076eaa9c44 Mon Sep 17 00:00:00 2001 From: rohitgr7 Date: Wed, 17 Aug 2022 20:00:16 +0530 Subject: [PATCH 26/29] update --- .../trainer/connectors/data_connector.py | 20 +++++++++++-------- src/pytorch_lightning/trainer/trainer.py | 4 ++-- 2 files changed, 14 insertions(+), 10 deletions(-) diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index 0b3c25b48546a..61ca05a49931d 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -14,7 +14,7 @@ import multiprocessing import os from dataclasses import dataclass, field -from typing import Any, cast, Iterable, List, Optional, Tuple, Union +from typing import Any, Iterable, List, Optional, Tuple, Union from 
weakref import proxy from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSampler @@ -290,12 +290,15 @@ def _resolve_sampler( self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None ) -> Union[Sampler, Iterable]: if self._requires_distributed_sampler(dataloader): + distributed_sampler_kwargs = self.trainer.distributed_sampler_kwargs + assert distributed_sampler_kwargs is not None + assert self.trainer.distributed_sampler_kwargs is not None sampler = self._get_distributed_sampler( dataloader, shuffle, mode=mode, overfit_batches=self.trainer.overfit_batches, - **self.trainer.distributed_sampler_kwargs, + **distributed_sampler_kwargs, ) # update docs too once this is resolved @@ -348,14 +351,13 @@ def _reset_eval_dataloader( # always get the loaders first so we can count how many there are dataloaders = self._request_dataloader(mode) - dataloaders = cast(EVAL_DATALOADERS, dataloaders) - - if not isinstance(dataloaders, list): - dataloaders = [dataloaders] # type: ignore[list-item] if self.trainer.overfit_batches > 0: dataloaders = self._resolve_overfit_batches(dataloaders, mode) + if not isinstance(dataloaders, list): + dataloaders = [dataloaders] # type: ignore[assignment] + if any(dl is None for dl in dataloaders): rank_zero_warn("One of given dataloaders is None and it will be skipped.") @@ -442,7 +444,9 @@ def _request_dataloader(self, stage: RunningStage) -> TRAIN_DATALOADERS: return dataloader @staticmethod - def _resolve_overfit_batches(dataloaders: EVAL_DATALOADERS, mode: RunningStage) -> EVAL_DATALOADERS: + def _resolve_overfit_batches( + dataloaders: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS], mode: RunningStage + ) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]: all_have_sequential_sampler = True def resolve_has_no_sequential_sampler(dataloader: DataLoader) -> None: @@ -455,7 +459,7 @@ def resolve_has_no_sequential_sampler(dataloader: DataLoader) -> None: if not all_have_sequential_sampler: rank_zero_warn( - "You requested to overfit but enabled training dataloader shuffling." + f"You requested to overfit but enabled {mode.dataloader_prefix} dataloader shuffling." f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you." 
) diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 21228a2394494..93543c2ec13ba 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -25,7 +25,7 @@ from datetime import timedelta from functools import partial from pathlib import Path -from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Type, Union +from typing import Any, Callable, Dict, Generator, Iterable, List, Optional, Type, Union from weakref import proxy import torch @@ -2231,7 +2231,7 @@ def is_global_zero(self) -> bool: return self.strategy.is_global_zero @property - def distributed_sampler_kwargs(self) -> Mapping: + def distributed_sampler_kwargs(self) -> Optional[Dict[str, Any]]: if isinstance(self.strategy, ParallelStrategy): return self.strategy.distributed_sampler_kwargs From 67a0bf709afc056d8738494d953cd218d9ed4cb7 Mon Sep 17 00:00:00 2001 From: otaj Date: Mon, 22 Aug 2022 10:34:33 +0200 Subject: [PATCH 27/29] Apply suggestions --- src/pytorch_lightning/trainer/connectors/data_connector.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/trainer/connectors/data_connector.py b/src/pytorch_lightning/trainer/connectors/data_connector.py index 61ca05a49931d..c127fce3cf824 100644 --- a/src/pytorch_lightning/trainer/connectors/data_connector.py +++ b/src/pytorch_lightning/trainer/connectors/data_connector.py @@ -292,7 +292,6 @@ def _resolve_sampler( if self._requires_distributed_sampler(dataloader): distributed_sampler_kwargs = self.trainer.distributed_sampler_kwargs assert distributed_sampler_kwargs is not None - assert self.trainer.distributed_sampler_kwargs is not None sampler = self._get_distributed_sampler( dataloader, shuffle, @@ -465,7 +464,9 @@ def resolve_has_no_sequential_sampler(dataloader: DataLoader) -> None: def replace_sampler(dataloader: DataLoader) -> DataLoader: return _update_dataloader( - dataloader, sampler=SequentialSampler(dataloader.dataset), mode=mode # type: ignore[arg-type] + dataloader, + sampler=SequentialSampler(dataloader.dataset), # type: ignore[arg-type] + mode=mode, ) dataloaders = apply_to_collection(dataloaders, DataLoader, replace_sampler) From 90780d004a39398dfb34f4502a23bdbb6b45337a Mon Sep 17 00:00:00 2001 From: otaj Date: Mon, 22 Aug 2022 10:54:01 +0200 Subject: [PATCH 28/29] one extra assert --- src/pytorch_lightning/core/module.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pytorch_lightning/core/module.py b/src/pytorch_lightning/core/module.py index 6f2deada34c2e..a479beadc7931 100644 --- a/src/pytorch_lightning/core/module.py +++ b/src/pytorch_lightning/core/module.py @@ -294,6 +294,7 @@ def loggers(self) -> List[Logger]: def _call_batch_hook(self, hook_name: str, *args: Any) -> Any: if self._trainer: datahook_selector = self._trainer._data_connector._datahook_selector + assert datahook_selector is not None obj = datahook_selector.get_instance(hook_name) if isinstance(obj, self.__class__): trainer_method = self._trainer._call_lightning_module_hook From cba5901f341066f9bc88f2f0b4a4dad803611a81 Mon Sep 17 00:00:00 2001 From: otaj Date: Mon, 22 Aug 2022 13:29:15 +0200 Subject: [PATCH 29/29] fix failing test --- tests/tests_pytorch/trainer/flags/test_overfit_batches.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py index da3e154349e1b..dc73e76cc391b 100644 --- 
a/tests/tests_pytorch/trainer/flags/test_overfit_batches.py
+++ b/tests/tests_pytorch/trainer/flags/test_overfit_batches.py
@@ -66,7 +66,7 @@ def val_dataloader(self):
     model = TestModel()
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=2)
 
-    with pytest.warns(UserWarning, match="requested to overfit but enabled training dataloader shuffling"):
+    with pytest.warns(UserWarning, match="requested to overfit but enabled train dataloader shuffling"):
         trainer.fit(model)
 
     assert isinstance(trainer.train_dataloader.loaders.sampler, SequentialSampler)