Merged
Changes from all commits
55 commits
a36a0f0
remove module from pyproject.toml
jxtngx Jul 23, 2022
c65a2e9
Merge branch 'master' into codeq/trainer-dataconnector
jxtngx Jul 23, 2022
8a39011
update
jxtngx Jul 23, 2022
b70fc21
Merge branch 'master' into codeq/trainer-dataconnector
jxtngx Jul 23, 2022
bf5e037
update
jxtngx Jul 23, 2022
17893e3
Merge remote-tracking branch 'origin/codeq/trainer-dataconnector' int…
jxtngx Jul 23, 2022
ca0ceab
update
jxtngx Jul 23, 2022
96d8f1f
update
jxtngx Jul 23, 2022
477a4a0
update
jxtngx Jul 23, 2022
09b157b
clean
jxtngx Jul 23, 2022
ac8ea25
Merge branch 'master' into codeq/trainer-dataconnector
jxtngx Jul 23, 2022
87cb7a6
update
jxtngx Jul 25, 2022
6a88e79
update
jxtngx Jul 25, 2022
82ee684
update
jxtngx Jul 25, 2022
6020d0b
update
jxtngx Jul 25, 2022
62d1c83
update
jxtngx Jul 25, 2022
1aebec4
update
jxtngx Jul 25, 2022
17cf37b
clean
jxtngx Jul 25, 2022
c4039b1
update
jxtngx Jul 25, 2022
b0cb040
revert dataloader
jxtngx Jul 25, 2022
5f7c0f2
Merge branch 'master' into codeq/trainer-dataconnector
jxtngx Jul 26, 2022
83411af
Merge branch 'master' into codeq/trainer-dataconnector
jxtngx Jul 28, 2022
dfca1fe
update
jxtngx Jul 28, 2022
66620b2
update
jxtngx Jul 28, 2022
270bd9b
Merge branch 'master' into codeq/trainer-dataconnector
jxtngx Jul 28, 2022
8549cf0
Merge branch 'master' into codeq/trainer-dataconnector
jxtngx Aug 1, 2022
5e7a4bb
Merge branch 'master' into codeq/trainer-dataconnector
jxtngx Aug 1, 2022
b9922a1
Merge branch 'master' into codeq/trainer-dataconnector
jxtngx Aug 2, 2022
88a8ab5
update
jxtngx Aug 2, 2022
c2431ae
clean
jxtngx Aug 3, 2022
ee6de74
Merge branch 'master' into codeq/trainer-dataconnector
jxtngx Aug 3, 2022
ee11137
Merge branch 'master' into codeq/trainer-dataconnector
jxtngx Aug 3, 2022
a1e5694
Merge branch 'master' into codeq/trainer-dataconnector
jxtngx Aug 5, 2022
fe64780
Merge branch 'master' into codeq/trainer-dataconnector
jxtngx Aug 5, 2022
78311ff
resolve merge conflicts
jxtngx Aug 8, 2022
dd79fd7
Merge branch 'master' into codeq/trainer-dataconnector
jxtngx Aug 9, 2022
698b026
fix another
rohitgr7 Aug 9, 2022
4c9e080
fix another 2
rohitgr7 Aug 9, 2022
ca4d8cb
fix
rohitgr7 Aug 10, 2022
efd8238
Merge branch 'master' into codeq/trainer-dataconnector
rohitgr7 Aug 10, 2022
b1b711e
fix merge conflict
jxtngx Aug 12, 2022
873aa3a
fix mypy
rohitgr7 Aug 16, 2022
bc83d7c
redundant arg
rohitgr7 Aug 16, 2022
0711062
Merge branch 'master' into codeq/trainer-dataconnector
otaj Aug 17, 2022
e861a5e
update
rohitgr7 Aug 17, 2022
67a0bf7
Apply suggestions
Aug 22, 2022
295f2bb
merge master
Aug 22, 2022
90780d0
one extra assert
Aug 22, 2022
cba5901
fix failing test
Aug 22, 2022
5313513
Merge branch 'master' into codeq/trainer-dataconnector
otaj Aug 22, 2022
50e0305
Merge branch 'master' into codeq/trainer-dataconnector
otaj Aug 23, 2022
dea0855
Merge branch 'master' into codeq/trainer-dataconnector
otaj Aug 24, 2022
5423f7a
Merge branch 'master' into codeq/trainer-dataconnector
otaj Aug 25, 2022
ec1702e
Merge branch 'master' into codeq/trainer-dataconnector
otaj Aug 26, 2022
8fa58a8
Merge branch 'master' into codeq/trainer-dataconnector
Borda Aug 26, 2022
1 change: 0 additions & 1 deletion pyproject.toml
@@ -55,7 +55,6 @@ module = [
     "pytorch_lightning.profilers.pytorch",
     "pytorch_lightning.strategies.sharded",
     "pytorch_lightning.trainer.callback_hook",
-    "pytorch_lightning.trainer.connectors.data_connector",
     "pytorch_lightning.trainer.supporters",
     "pytorch_lightning.trainer.trainer",
     "pytorch_lightning.tuner.batch_size_scaling",
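Removing the entry from this module list means mypy no longer skips pytorch_lightning.trainer.connectors.data_connector, which is what the annotation work in the rest of the diff supports. A minimal sketch (class names are hypothetical, not from the repository) of the kind of error strict checking would otherwise raise:

from typing import Optional


class _HookSelector:
    """Hypothetical stand-in for the selector object used later in this diff."""


class _Connector:
    def __init__(self) -> None:
        # With a bare `self._selector = None`, mypy infers the attribute type as
        # `None` and rejects a later assignment of a real object. The explicit
        # Optional annotation (the same fix applied to `_datahook_selector` below)
        # makes both the unset and the set state valid.
        self._selector: Optional[_HookSelector] = None

    def setup(self) -> None:
        self._selector = _HookSelector()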
3 changes: 2 additions & 1 deletion src/pytorch_lightning/core/datamodule.py
@@ -18,6 +18,7 @@
 
 from torch.utils.data import DataLoader, Dataset, IterableDataset
 
+import pytorch_lightning as pl
 from pytorch_lightning.core.hooks import CheckpointHooks, DataHooks
 from pytorch_lightning.core.mixins import HyperparametersMixin
 from pytorch_lightning.core.saving import _load_from_checkpoint
@@ -62,7 +63,7 @@ def teardown(self):
     def __init__(self) -> None:
         super().__init__()
         # Pointer to the trainer object
-        self.trainer = None
+        self.trainer: Optional["pl.Trainer"] = None
 
     @classmethod
     def add_argparse_args(cls, parent_parser: ArgumentParser, **kwargs) -> ArgumentParser:
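The new import pytorch_lightning as pl exists only to support the quoted annotation Optional["pl.Trainer"]; the string is a forward reference that only the type checker evaluates, so Trainer does not have to be resolvable at class-definition time and no import cycle is introduced. A self-contained sketch of the same pattern with hypothetical names:

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Imported for annotations only; never executed at runtime, so no import cycle.
    from myproject.trainer import Trainer  # hypothetical module path


class DataModule:
    def __init__(self) -> None:
        # None until a Trainer attaches itself; the annotation makes mypy check
        # every access for the unset case.
        self.trainer: Optional["Trainer"] = None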
3 changes: 2 additions & 1 deletion src/pytorch_lightning/core/module.py
@@ -105,7 +105,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
         self._use_amp: bool = False
 
         # the precision used
-        self.precision: int = 32
+        self.precision: Union[int, str] = 32
 
         # optionally can be set by user
         self._example_input_array = None
@@ -294,6 +294,7 @@ def loggers(self) -> List[Logger]:
     def _call_batch_hook(self, hook_name: str, *args: Any) -> Any:
         if self._trainer:
             datahook_selector = self._trainer._data_connector._datahook_selector
+            assert datahook_selector is not None
             obj = datahook_selector.get_instance(hook_name)
             if isinstance(obj, self.__class__):
                 trainer_method = self._trainer._call_lightning_module_hook
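precision is now Union[int, str] because precision settings are not always integers (string values such as "bf16" exist alongside 16/32/64), so code that branches on it has to accept both forms. A rough sketch of handling the widened type; the mapping below is illustrative only, not Lightning's actual logic:

from typing import Union


def describe_precision(precision: Union[int, str]) -> str:
    # Normalise to str so 16 and "16" (or "bf16") are handled by the same branch.
    value = str(precision)
    if value in ("16", "bf16"):
        return "mixed precision"
    if value == "64":
        return "double precision"
    return "full precision"


assert describe_precision(32) == "full precision"
assert describe_precision("bf16") == "mixed precision"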
6 changes: 4 additions & 2 deletions src/pytorch_lightning/trainer/configuration_validator.py
@@ -46,7 +46,7 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
     elif trainer.state.fn == TrainerFn.PREDICTING:
         __verify_eval_loop_configuration(trainer, model, "predict")
 
-    __verify_batch_transfer_support(trainer, model)
+    __verify_batch_transfer_support(trainer)
     _check_deprecated_callback_hooks(trainer)
     # TODO: Delete _check_on_hpc_hooks in v1.8
     _check_on_hpc_hooks(model)
@@ -149,10 +149,12 @@ def __verify_eval_loop_configuration(trainer: "pl.Trainer", model: "pl.Lightning
         raise MisconfigurationException(f"No `{step_name}()` method defined to run `Trainer.{trainer_method}`.")
 
 
-def __verify_batch_transfer_support(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
+def __verify_batch_transfer_support(trainer: "pl.Trainer") -> None:
     """Raise Misconfiguration exception since these hooks are not supported in DP mode."""
     batch_transfer_hooks = ("transfer_batch_to_device", "on_after_batch_transfer")
+    datahook_selector = trainer._data_connector._datahook_selector
+    assert datahook_selector is not None
 
     for hook in batch_transfer_hooks:
         # TODO: Remove this blocker once batch transfer to device is integrated in Lightning for DP mode.
         if isinstance(trainer.strategy, DataParallelStrategy) and (
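Because _datahook_selector is now declared Optional (see the data_connector.py diff below), call sites assert it is not None before using it; the assert both narrows the type for mypy and fails loudly if the connector was never set up. A minimal sketch of the narrowing idiom with hypothetical classes:

from typing import Optional


class Selector:
    def get_instance(self, hook_name: str) -> str:
        return hook_name


class Connector:
    def __init__(self) -> None:
        self.selector: Optional[Selector] = None


def lookup(connector: Connector, hook_name: str) -> str:
    selector = connector.selector
    # mypy sees Optional[Selector] here; the assert narrows it to Selector so the
    # attribute access below type-checks, and at runtime it surfaces a missing setup.
    assert selector is not None
    return selector.get_instance(hook_name)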
50 changes: 27 additions & 23 deletions src/pytorch_lightning/trainer/connectors/data_connector.py
@@ -14,7 +14,7 @@
 import multiprocessing
 import os
 from dataclasses import dataclass, field
-from typing import Any, Collection, List, Optional, Tuple, Union
+from typing import Any, Iterable, List, Optional, Tuple, Union
 from weakref import proxy
 
 from torch.utils.data import BatchSampler, DataLoader, Sampler, SequentialSampler
@@ -55,7 +55,7 @@ def __init__(self, trainer: "pl.Trainer", multiple_trainloader_mode: str = "max_
         self._test_dataloader_source = _DataLoaderSource(None, "")
         self._predict_dataloader_source = _DataLoaderSource(None, "")
 
-        self._datahook_selector = _DataHookSelector(None, None)
+        self._datahook_selector: Optional[_DataHookSelector] = None
 
     @property
     def _should_reload_train_dl(self) -> bool:
@@ -230,7 +230,7 @@ def _worker_check(self, dataloader: DataLoader, name: str) -> None:
                 category=PossibleUserWarning,
             )
 
-    def _requires_distributed_sampler(self, dataloader) -> bool:
+    def _requires_distributed_sampler(self, dataloader: DataLoader) -> bool:
         return (
             self.trainer._accelerator_connector.replace_sampler_ddp
             and self.trainer._accelerator_connector.is_distributed
@@ -292,14 +292,18 @@ def _prepare_dataloader(
 
         return dataloader
 
-    def _resolve_sampler(self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None) -> Sampler:
+    def _resolve_sampler(
+        self, dataloader: DataLoader, shuffle: bool, mode: Optional[RunningStage] = None
+    ) -> Union[Sampler, Iterable]:
         if self._requires_distributed_sampler(dataloader):
+            distributed_sampler_kwargs = self.trainer.distributed_sampler_kwargs
+            assert distributed_sampler_kwargs is not None
             sampler = self._get_distributed_sampler(
                 dataloader,
                 shuffle,
                 mode=mode,
                 overfit_batches=self.trainer.overfit_batches,
-                **self.trainer.distributed_sampler_kwargs,
+                **distributed_sampler_kwargs,
             )
 
             # update docs too once this is resolved
@@ -357,7 +361,7 @@ def _reset_eval_dataloader(
             dataloaders = self._resolve_overfit_batches(dataloaders, mode)
 
         if not isinstance(dataloaders, list):
-            dataloaders = [dataloaders]
+            dataloaders = [dataloaders]  # type: ignore[assignment]
 
         if any(dl is None for dl in dataloaders):
             rank_zero_warn("One of given dataloaders is None and it will be skipped.")
@@ -426,7 +430,7 @@ def _reset_eval_dataloader(
 
         return loader_num_batches, dataloaders
 
-    def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[DataLoader]]:
+    def _request_dataloader(self, stage: RunningStage) -> TRAIN_DATALOADERS:
         """Requests a dataloader from the given model by calling dataloader hooks corresponding to the given stage.
 
         Returns:
@@ -447,10 +451,12 @@ def _request_dataloader(self, stage: RunningStage) -> Union[DataLoader, List[Dat
         return dataloader
 
     @staticmethod
-    def _resolve_overfit_batches(dataloaders: Collection[DataLoader], mode: RunningStage) -> Collection[DataLoader]:
+    def _resolve_overfit_batches(
+        dataloaders: Union[TRAIN_DATALOADERS, EVAL_DATALOADERS], mode: RunningStage
+    ) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
         all_have_sequential_sampler = True
 
-        def resolve_has_no_sequential_sampler(dataloader: DataLoader):
+        def resolve_has_no_sequential_sampler(dataloader: DataLoader) -> None:
             nonlocal all_have_sequential_sampler
             all_have_sequential_sampler = all_have_sequential_sampler & isinstance(
                 dataloader.sampler, SequentialSampler
@@ -460,19 +466,23 @@ def resolve_has_no_sequential_sampler(dataloader: DataLoader):
 
         if not all_have_sequential_sampler:
             rank_zero_warn(
-                "You requested to overfit but enabled training dataloader shuffling."
+                f"You requested to overfit but enabled {mode.dataloader_prefix} dataloader shuffling."
                 f" We are turning off the {mode.dataloader_prefix} dataloader shuffling for you."
             )
 
         def replace_sampler(dataloader: DataLoader) -> DataLoader:
-            return _update_dataloader(dataloader, sampler=SequentialSampler(dataloader.dataset), mode=mode)
+            return _update_dataloader(
+                dataloader,
+                sampler=SequentialSampler(dataloader.dataset),  # type: ignore[arg-type]
+                mode=mode,
+            )
 
         dataloaders = apply_to_collection(dataloaders, DataLoader, replace_sampler)
 
         return dataloaders
 
     @staticmethod
-    def _check_eval_shuffling(dataloader, mode):
+    def _check_eval_shuffling(dataloader: DataLoader, mode: RunningStage) -> None:
         # limit this warning only for samplers assigned automatically when shuffle is set
         if _is_dataloader_shuffled(dataloader):
             rank_zero_warn(
@@ -506,18 +516,14 @@ def dataloader(self) -> Union[TRAIN_DATALOADERS, EVAL_DATALOADERS]:
 
         If the source is a module, the method with the corresponding :attr:`name` gets called.
         """
-        from pytorch_lightning import LightningDataModule, LightningModule  # prevent cyclic import
-
-        if not self.name:
-            return self.instance
-
-        if isinstance(self.instance, LightningModule):
+        if isinstance(self.instance, pl.LightningModule):
             return self.instance.trainer._call_lightning_module_hook(self.name, pl_module=self.instance)
 
-        if isinstance(self.instance, LightningDataModule):
+        if isinstance(self.instance, pl.LightningDataModule):
             method = getattr(self.instance, self.name)
             return method()
 
+        assert self.instance is not None
         return self.instance
 
     def is_defined(self) -> bool:
@@ -532,9 +538,7 @@ def is_module(self) -> bool:
 
         It does not check whether ``*_dataloader`` methods are actually overridden.
         """
-        from pytorch_lightning import LightningDataModule, LightningModule  # prevent cyclic import
-
-        return isinstance(self.instance, (LightningModule, LightningDataModule))
+        return isinstance(self.instance, (pl.LightningModule, pl.LightningDataModule))
 
 
 @dataclass
@@ -553,7 +557,7 @@ class _DataHookSelector:
 
     model: "pl.LightningModule"
     datamodule: Optional["pl.LightningDataModule"]
-    _valid_hooks: Tuple[str] = field(
+    _valid_hooks: Tuple[str, ...] = field(
         default=("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer")
     )
 
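One subtle fix in this file: Tuple[str] annotates a tuple of exactly one string, so the three-element default on _valid_hooks was itself a type error; Tuple[str, ...] is the variable-length form. A small sketch with a stand-in dataclass (the class name here is hypothetical):

from dataclasses import dataclass, field
from typing import Tuple


@dataclass
class HookSelector:
    # Tuple[str] would mean "exactly one string"; the ellipsis form accepts any
    # length, matching the three-hook default.
    valid_hooks: Tuple[str, ...] = field(
        default=("on_before_batch_transfer", "transfer_batch_to_device", "on_after_batch_transfer")
    )


assert len(HookSelector().valid_hooks) == 3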
2 changes: 1 addition & 1 deletion src/pytorch_lightning/trainer/trainer.py
@@ -2234,7 +2234,7 @@ def is_global_zero(self) -> bool:
         return self.strategy.is_global_zero
 
     @property
-    def distributed_sampler_kwargs(self) -> Optional[dict]:
+    def distributed_sampler_kwargs(self) -> Optional[Dict[str, Any]]:
         if isinstance(self.strategy, ParallelStrategy):
             return self.strategy.distributed_sampler_kwargs
 
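Typing the property as Optional[Dict[str, Any]] instead of Optional[dict] documents that the value is a keyword-argument mapping, which is what lets _resolve_sampler above unpack it with ** after asserting it is not None. A sketch of that consumption pattern; the helper and the example kwargs are assumptions, not Lightning code:

from typing import Any, Dict, Optional

from torch.utils.data import Dataset, DistributedSampler


def build_distributed_sampler(dataset: Dataset, kwargs: Optional[Dict[str, Any]]) -> DistributedSampler:
    # Mirrors the call site in _resolve_sampler: the property is Optional, but a
    # distributed sampler is only built when the strategy provided the kwargs,
    # e.g. {"num_replicas": 2, "rank": 0}.
    assert kwargs is not None
    return DistributedSampler(dataset, shuffle=False, **kwargs)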
@@ -66,7 +66,7 @@ def val_dataloader(self):
     model = TestModel()
     trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, overfit_batches=2)
 
-    with pytest.warns(UserWarning, match="requested to overfit but enabled training dataloader shuffling"):
+    with pytest.warns(UserWarning, match="requested to overfit but enabled train dataloader shuffling"):
         trainer.fit(model)
 
     assert isinstance(trainer.train_dataloader.loaders.sampler, SequentialSampler)
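The expected message changes from "training" to "train" because the first sentence of the warning now interpolates mode.dataloader_prefix instead of hard-coding "training". A tiny, self-contained illustration of asserting on that message with pytest.warns (the emitting function is a stand-in for the Trainer, not real Lightning code):

import warnings

import pytest


def emit_overfit_warning(prefix: str) -> None:
    # Stand-in for the trainer: the warning embeds the dataloader prefix
    # ("train", "val", "test") instead of the old hard-coded "training".
    warnings.warn(
        f"You requested to overfit but enabled {prefix} dataloader shuffling."
        f" We are turning off the {prefix} dataloader shuffling for you.",
        UserWarning,
    )


def test_warning_mentions_prefix() -> None:
    with pytest.warns(UserWarning, match="requested to overfit but enabled train dataloader shuffling"):
        emit_overfit_warning("train")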