
Commit 101368d

Merge branch 'master' into refactor/mc-save-on-train-epoch-end
2 parents: 277525a + 000fbe6

File tree

15 files changed: +177 −150 lines

.github/workflows/code-checks.yml

Lines changed: 1 addition & 1 deletion
@@ -17,6 +17,6 @@ jobs:
         python-version: 3.8
     - name: Install mypy
       run: |
-        pip install mypy==0.790
+        grep mypy requirements/test.txt | xargs -0 pip install
         pip list
     - run: mypy

.pre-commit-config.yaml

Lines changed: 7 additions & 1 deletion
@@ -15,6 +15,12 @@
 default_language_version:
   python: python3.8

+ci:
+  autofix_prs: true
+  autoupdate_commit_msg: '[pre-commit.ci] pre-commit suggestions'
+  autoupdate_schedule: quarterly
+  # submodules: true
+
 repos:
   - repo: https://github.com/pre-commit/pre-commit-hooks
     rev: v4.0.1
@@ -40,7 +46,7 @@ repos:
       - id: detect-private-key

   - repo: https://github.com/PyCQA/isort
-    rev: 5.9.1
+    rev: 5.9.2
    hooks:
      - id: isort
        name: Format imports

CHANGELOG.md

Lines changed: 6 additions & 1 deletion
@@ -9,6 +9,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ### Added

+- Add `extract_batch_size` utility and corresponding tests to extract batch dimension from multiple batch types. ([#8357](https://github.com/PyTorchLightning/pytorch-lightning/pull/8357/))
+
 - Add support for named parameter groups in `LearningRateMonitor` ([#7987](https://github.com/PyTorchLightning/pytorch-lightning/pull/7987))


@@ -116,7 +118,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support `LightningModule.save_hyperparameters` when `LightningModule` is a dataclass ([#7992](https://github.com/PyTorchLightning/pytorch-lightning/pull/7992))


-- Add support for overriding `optimizer_zero_grad` and `optimizer_step` when using accumulate_grad_batches ([#7980](https://github.com/PyTorchLightning/pytorch-lightning/pull/7980))
+- Added support for overriding `optimizer_zero_grad` and `optimizer_step` when using accumulate_grad_batches ([#7980](https://github.com/PyTorchLightning/pytorch-lightning/pull/7980))
+
+
+- Added `logger` boolean flag to `save_hyperparameters` ([#7960](https://github.com/PyTorchLightning/pytorch-lightning/pull/7960))


 - Add support for calling scripts using the module syntax (`python -m package.script`) ([#8073](https://github.com/PyTorchLightning/pytorch-lightning/pull/8073))

pyproject.toml

Lines changed: 49 additions & 0 deletions
@@ -4,6 +4,7 @@ requires = [
     "wheel",
 ]

+
 [tool.isort]
 known_first_party = [
     "benchmarks",
@@ -16,3 +17,51 @@ profile = "black"
 line_length = 120
 force_sort_within_sections = "False"
 order_by_type = "False"
+
+
+[tool.mypy]
+files = ["pytorch_lightning", "pl_examples", "benchmarks"]
+disallow_untyped_defs = "True"
+ignore_missing_imports = "True"
+show_error_codes = "True"
+warn_redundant_casts = "True"
+warn_unused_configs = "True"
+warn_unused_ignores = "True"
+allow_redefinition = "True"
+# disable this rule as the Trainer attributes are defined in the connectors, not in its __init__
+disable_error_code = "attr-defined"
+
+
+# TODO: Fix typing for these modules
+[[tool.mypy.overrides]]
+module = [
+    "pytorch_lightning.callbacks.*",
+    "pytorch_lightning.core.*",
+    "pytorch_lightning.loggers.*",
+    "pytorch_lightning.loops.*",
+    "pytorch_lightning.metrics.*",
+    "pytorch_lightning.overrides.*",
+    "pytorch_lightning.plugins.environments.*",
+    "pytorch_lightning.plugins.training_type.*",
+    "pytorch_lightning.profiler.*",
+    "pytorch_lightning.trainer.*",
+    "pytorch_lightning.distributed.*",
+    "pytorch_lightning.tuner.*",
+    "pytorch_lightning.utilities.*",
+    "pl_examples.*",
+    "benchmarks.*",
+    "tests.helpers.*"
+]
+ignore_errors = "True"
+
+[[tool.mypy.overrides]]
+module = [
+    "pytorch_lightning.callbacks.pruning",
+    "pytorch_lightning.trainer.evaluation_loop",
+    "pytorch_lightning.trainer.connectors.logger_connector",
+    "pytorch_lightning.utilities.cli",
+    "pytorch_lightning.utilities.device_dtype_mixin",
+    "pytorch_lightning.utilities.device_parser",
+    "pytorch_lightning.utilities.parsing",
+]
+ignore_errors = "False"
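
Note: the `disable_error_code = "attr-defined"` setting reflects the comment in the hunk above: Trainer attributes are assigned by its connectors rather than in `Trainer.__init__`, which mypy reports as `attr-defined`. A minimal, self-contained sketch of that pattern (class and attribute names below are illustrative, not the actual Lightning internals):

# Illustrative sketch only: mimics the "attributes defined in the connectors,
# not in __init__" pattern that the pyproject.toml comment refers to.


class LoggerConnector:
    def __init__(self, trainer: "Trainer") -> None:
        self.trainer = trainer

    def on_trainer_init(self, log_every_n_steps: int) -> None:
        # The attribute is created here, outside Trainer.__init__ ...
        self.trainer.log_every_n_steps = log_every_n_steps


class Trainer:
    def __init__(self, log_every_n_steps: int = 50) -> None:
        self.logger_connector = LoggerConnector(self)
        self.logger_connector.on_trainer_init(log_every_n_steps)

    def should_log(self, step: int) -> bool:
        # ... so mypy flags this access with [attr-defined] unless that
        # error code is disabled, as the configuration above does.
        return (step + 1) % self.log_every_n_steps == 0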

pytorch_lightning/accelerators/accelerator.py

Lines changed: 2 additions & 0 deletions
@@ -174,6 +174,8 @@ def batch_to_device(
             dataloader_idx: The index of the dataloader to which the batch belongs.
         """
         model = self.lightning_module
+        device = device or self.root_device
+
         if model is not None and not isinstance(self.training_type_plugin, DataParallelPlugin):
             # no need to transfer batch to device in DP mode
             return model._apply_batch_transfer_handler(batch, device, dataloader_idx)

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 2 additions & 2 deletions
@@ -46,7 +46,7 @@ def __init__(self, trainer: 'pl.Trainer', log_gpu_memory: Optional[str] = None)

     def on_trainer_init(
         self,
-        logger: LightningLoggerBase,
+        logger: Union[bool, LightningLoggerBase, Iterable[LightningLoggerBase]],
         flush_logs_every_n_steps: int,
         log_every_n_steps: int,
         move_metrics_to_cpu: bool,
@@ -66,7 +66,7 @@ def should_update_logs(self) -> bool:
         should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0
         return should_log_every_n_steps or self.trainer.should_stop

-    def configure_logger(self, logger: Union[bool, Iterable, LightningLoggerBase]) -> None:
+    def configure_logger(self, logger: Union[bool, LightningLoggerBase, Iterable[LightningLoggerBase]]) -> None:
         if logger is True:
             version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id)
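
Note: the widened annotation matches what `configure_logger` already handled: `logger` may be a boolean, a single logger, or an iterable of loggers. A short illustration of the accepted values at the Trainer level (a sketch assuming the usual public API; the save directories are placeholders):

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

Trainer(logger=False)                         # bool: disable logging entirely
Trainer(logger=TensorBoardLogger("tb_logs"))  # a single LightningLoggerBase
Trainer(logger=[TensorBoardLogger("tb_logs"), CSVLogger("csv_logs")])  # Iterable[LightningLoggerBase]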

pytorch_lightning/trainer/connectors/logger_connector/result.py

Lines changed: 3 additions & 23 deletions
@@ -14,13 +14,14 @@
 from collections.abc import Generator
 from dataclasses import asdict, dataclass, replace
 from functools import partial, wraps
-from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

 import torch
 from torchmetrics import Metric

 from pytorch_lightning.utilities import rank_zero_warn
 from pytorch_lightning.utilities.apply_func import apply_to_collection, apply_to_collections
+from pytorch_lightning.utilities.data import extract_batch_size
 from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
 from pytorch_lightning.utilities.distributed import distributed_available
 from pytorch_lightning.utilities.enums import LightningEnum
@@ -589,31 +590,10 @@ def fn(item: ResultMetric) -> None:

     def extract_batch_size(self, batch: Any) -> None:
         try:
-            self.batch_size = self._extract_batch_size(batch)
+            self.batch_size = extract_batch_size(batch)
         except RecursionError:
             self.batch_size = 1

-    def _extract_batch_size(self, batch: Any) -> int:
-        """
-        Recursively unpack a batch to find a torch.Tensor.
-
-        Returns:
-            ``len(tensor)`` when found, or ``1`` when it hits an empty or non iterable.
-        """
-        if isinstance(batch, torch.Tensor):
-            size = batch.size(0)
-        elif isinstance(batch, str):
-            return len(batch)
-        elif isinstance(batch, dict):
-            sample = next(iter(batch.values()), 1)
-            size = self._extract_batch_size(sample)
-        elif isinstance(batch, Iterable):
-            sample = next(iter(batch), 1)
-            size = self._extract_batch_size(sample)
-        else:
-            size = 1
-        return size
-
     def to(self, *args, **kwargs) -> 'ResultCollection':
         """Move all data to the given device."""

pytorch_lightning/trainer/trainer.py

Lines changed: 20 additions & 10 deletions
@@ -907,20 +907,30 @@ def _pre_dispatch(self):

     def _log_hyperparams(self):
         # log hyper-parameters
+        hparams_initial = None
+
         if self.logger is not None:
             # save exp to get started (this is where the first experiment logs are written)
-            datamodule_hparams = self.datamodule.hparams_initial if self.datamodule is not None else {}
-            lightning_hparams = self.lightning_module.hparams_initial
-            colliding_keys = lightning_hparams.keys() & datamodule_hparams.keys()
-            if colliding_keys:
-                raise MisconfigurationException(
-                    f"Error while merging hparams: the keys {colliding_keys} are present "
-                    "in both the LightningModule's and LightningDataModule's hparams."
-                )
+            datamodule_log_hyperparams = self.datamodule._log_hyperparams if self.datamodule is not None else False

-            hparams_initial = {**lightning_hparams, **datamodule_hparams}
+            if self.lightning_module._log_hyperparams and datamodule_log_hyperparams:
+                datamodule_hparams = self.datamodule.hparams_initial
+                lightning_hparams = self.lightning_module.hparams_initial

-            self.logger.log_hyperparams(hparams_initial)
+                colliding_keys = lightning_hparams.keys() & datamodule_hparams.keys()
+                if colliding_keys:
+                    raise MisconfigurationException(
+                        f"Error while merging hparams: the keys {colliding_keys} are present "
+                        "in both the LightningModule's and LightningDataModule's hparams."
+                    )
+                hparams_initial = {**lightning_hparams, **datamodule_hparams}
+            elif self.lightning_module._log_hyperparams:
+                hparams_initial = self.lightning_module.hparams_initial
+            elif datamodule_log_hyperparams:
+                hparams_initial = self.datamodule.hparams_initial
+
+            if hparams_initial is not None:
+                self.logger.log_hyperparams(hparams_initial)
             self.logger.log_graph(self.lightning_module)
             self.logger.save()
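
Note: the new `_log_hyperparams` / `datamodule_log_hyperparams` checks tie into the `logger` flag on `save_hyperparameters` mentioned in the CHANGELOG hunk above (#7960): hyperparameters saved with `logger=False` stay on the module but are skipped when the trainer logs hparams. A hedged usage sketch (model and argument names are illustrative):

import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self, lr: float = 1e-3, hidden_dim: int = 128):
        super().__init__()
        # Stored in self.hparams (and in checkpoints), but the module's
        # _log_hyperparams flag is cleared, so Trainer._log_hyperparams()
        # will not forward these values to self.logger.
        self.save_hyperparameters(logger=False)


model = LitModel()
print(model.hparams)  # hyperparameters remain accessible locally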

pytorch_lightning/utilities/data.py

Lines changed: 25 additions & 1 deletion
@@ -12,12 +12,36 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-from typing import Union
+from typing import Any, Iterable, Mapping, Union

+import torch
 from torch.utils.data import DataLoader, IterableDataset

 from pytorch_lightning.utilities import rank_zero_warn

+BType = Union[torch.Tensor, str, Mapping[Any, 'BType'], Iterable['BType']]
+
+
+def extract_batch_size(batch: BType) -> int:
+    """
+    Recursively unpack a batch to find a torch.Tensor.
+
+    Returns:
+        ``len(tensor)`` when found, or ``1`` when it hits an empty or non iterable.
+    """
+    if isinstance(batch, torch.Tensor):
+        return batch.size(0)
+    if isinstance(batch, str):
+        return len(batch)
+    if isinstance(batch, dict):
+        sample = next(iter(batch.values()), 1)
+        return extract_batch_size(sample)
+    if isinstance(batch, Iterable):
+        sample = next(iter(batch), 1)
+        return extract_batch_size(sample)
+
+    return 1
+

 def has_iterable_dataset(dataloader: DataLoader):
     return hasattr(dataloader, 'dataset') and isinstance(dataloader.dataset, IterableDataset)
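
Note: for reference, a few calls showing how the recursion above resolves common batch shapes (assuming the function is importable from `pytorch_lightning.utilities.data` as added here):

import torch
from pytorch_lightning.utilities.data import extract_batch_size

batch = torch.zeros(8, 3, 32, 32)
print(extract_batch_size(batch))                    # 8: size of the first tensor dimension
print(extract_batch_size({"x": batch, "y": None}))  # 8: recurses into the first dict value
print(extract_batch_size([batch, "labels"]))        # 8: recurses into the first element
print(extract_batch_size(object()))                 # 1: nothing batch-like found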

pytorch_lightning/utilities/device_dtype_mixin.py

Lines changed: 3 additions & 4 deletions
@@ -121,10 +121,9 @@ def cuda(self, device: Optional[Union[torch.device, int]] = None) -> 'DeviceDtypeModuleMixin':
         Returns:
             Module: self
         """
-        property_device = (
-            device if isinstance(device, torch.device) else torch.device('cuda', index=device)  # type: ignore
-        )  # mypy expects `device` for `index` to be int, while `Optional[int]` is okay => ignore typing for now
-        self.__update_properties(device=property_device)
+        if device is None or isinstance(device, int):
+            device = torch.device('cuda', index=device)
+        self.__update_properties(device=device)
         return super().cuda(device=device)

     def cpu(self) -> 'DeviceDtypeModuleMixin':
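
Note: the rewritten branch only normalizes `None` or an integer index into a `torch.device`, and passes an explicit `torch.device` through untouched, so `__update_properties` always receives a proper device object. A standalone sketch of that normalization (the helper name is illustrative):

from typing import Optional, Union

import torch


def normalize_cuda_device(device: Optional[Union[torch.device, int]]) -> torch.device:
    # Mirrors the updated branch: None or an int index becomes a CUDA
    # torch.device; an existing torch.device is returned unchanged.
    if device is None or isinstance(device, int):
        device = torch.device('cuda', index=device)
    return device


print(normalize_cuda_device(1))                        # cuda:1
print(normalize_cuda_device(torch.device('cuda', 0)))  # cuda:0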
