From b6cda178bc1816f1889daad9f90cb4c649bf6a53 Mon Sep 17 00:00:00 2001 From: Masahiro Wada Date: Sun, 3 Jul 2022 08:50:29 +0000 Subject: [PATCH 01/21] Fix type hints --- pyproject.toml | 1 - src/pytorch_lightning/trainer/trainer.py | 2 +- .../tuner/batch_size_scaling.py | 20 ++++++++++++++----- 3 files changed, 16 insertions(+), 7 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 51781d4953935..4a8a643ac905e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -88,7 +88,6 @@ module = [ "pytorch_lightning.trainer.data_loading", "pytorch_lightning.trainer.supporters", "pytorch_lightning.trainer.trainer", - "pytorch_lightning.tuner.batch_size_scaling", "pytorch_lightning.tuner.lr_finder", "pytorch_lightning.tuner.tuning", "pytorch_lightning.utilities.auto_restart", diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 7201ef53501c0..1d511371585a6 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -601,7 +601,7 @@ def _init_debugging_flags( "Logging and checkpointing is suppressed." ) - self.limit_train_batches = _determine_batch_limits(limit_train_batches, "limit_train_batches") + self.limit_train_batches: Union[int, float] = _determine_batch_limits(limit_train_batches, "limit_train_batches") self.limit_val_batches = _determine_batch_limits(limit_val_batches, "limit_val_batches") self.limit_test_batches = _determine_batch_limits(limit_test_batches, "limit_test_batches") self.limit_predict_batches = _determine_batch_limits(limit_predict_batches, "limit_predict_batches") diff --git a/src/pytorch_lightning/tuner/batch_size_scaling.py b/src/pytorch_lightning/tuner/batch_size_scaling.py index 316fc5a2197da..c2c821375822c 100644 --- a/src/pytorch_lightning/tuner/batch_size_scaling.py +++ b/src/pytorch_lightning/tuner/batch_size_scaling.py @@ -14,12 +14,14 @@ import logging import os import uuid -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Union, List, Optional, Tuple from torch.utils.data import DataLoader import pytorch_lightning as pl -from pytorch_lightning.loggers.logger import DummyLogger +from pytorch_lightning.callbacks.callback import Callback +from pytorch_lightning.loggers.logger import DummyLogger, Logger + from pytorch_lightning.utilities.data import has_len_all_ranks from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error @@ -41,7 +43,7 @@ def scale_batch_size( """See :meth:`~pytorch_lightning.tuner.tuning.Tuner.scale_batch_size`""" if trainer.fast_dev_run: rank_zero_warn("Skipping batch size scaler since fast_dev_run is enabled.") - return + return None if not lightning_hasattr(model, batch_arg_name): raise MisconfigurationException(f"Field {batch_arg_name} not found in both `model` and `model.hparams`") @@ -234,11 +236,19 @@ def _adjust_batch_size( """ model = trainer.lightning_module batch_size = lightning_getattr(model, batch_arg_name) - new_size = value if value is not None else int(batch_size * factor) + if value is not None: + new_size = value + else: + if not isinstance(batch_size, int): + raise ValueError(f"value is None and batch_size is not int value: {batch_size}") + new_size = int(batch_size * factor) + if desc: log.info(f"Batch size {batch_size} {desc}, trying batch size {new_size}") if not _is_valid_batch_size(new_size, trainer.train_dataloader, trainer): + if not isinstance(trainer.train_dataloader, DataLoader): + raise 
ValueError("train_dataloader is not a DataLoader") new_size = min(new_size, len(trainer.train_dataloader.dataset)) changed = new_size != batch_size @@ -246,6 +256,6 @@ def _adjust_batch_size( return new_size, changed -def _is_valid_batch_size(batch_size: int, dataloader: DataLoader, trainer: "pl.Trainer"): +def _is_valid_batch_size(batch_size: int, dataloader: DataLoader, trainer: "pl.Trainer") -> bool: module = trainer.lightning_module or trainer.datamodule return not has_len_all_ranks(dataloader, trainer.strategy, module) or batch_size <= len(dataloader) From 3f1d70ccf94d510c438be96bc710bd2fc565ea8c Mon Sep 17 00:00:00 2001 From: Masahiro Wada Date: Sun, 3 Jul 2022 13:01:47 +0000 Subject: [PATCH 02/21] fix --- src/pytorch_lightning/tuner/batch_size_scaling.py | 4 +--- src/pytorch_lightning/utilities/data.py | 2 +- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/src/pytorch_lightning/tuner/batch_size_scaling.py b/src/pytorch_lightning/tuner/batch_size_scaling.py index c2c821375822c..19b13183c9191 100644 --- a/src/pytorch_lightning/tuner/batch_size_scaling.py +++ b/src/pytorch_lightning/tuner/batch_size_scaling.py @@ -247,8 +247,6 @@ def _adjust_batch_size( log.info(f"Batch size {batch_size} {desc}, trying batch size {new_size}") if not _is_valid_batch_size(new_size, trainer.train_dataloader, trainer): - if not isinstance(trainer.train_dataloader, DataLoader): - raise ValueError("train_dataloader is not a DataLoader") new_size = min(new_size, len(trainer.train_dataloader.dataset)) changed = new_size != batch_size @@ -256,6 +254,6 @@ def _adjust_batch_size( return new_size, changed -def _is_valid_batch_size(batch_size: int, dataloader: DataLoader, trainer: "pl.Trainer") -> bool: +def _is_valid_batch_size(batch_size: int, dataloader: Any, trainer: "pl.Trainer") -> bool: module = trainer.lightning_module or trainer.datamodule return not has_len_all_ranks(dataloader, trainer.strategy, module) or batch_size <= len(dataloader) diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index 2de82ceff088e..ec596ccc7b2ab 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -126,7 +126,7 @@ def has_len(dataloader: Union[DataLoader, Iterable]) -> bool: def has_len_all_ranks( - dataloader: DataLoader, + dataloader: Any, training_type: "pl.Strategy", model: Union["pl.LightningModule", "pl.LightningDataModule"], ) -> bool: From ab9bcba3f8b1c7b6b5f76c31fae2372aabf16908 Mon Sep 17 00:00:00 2001 From: Masahiro Wada Date: Sun, 3 Jul 2022 13:21:11 +0000 Subject: [PATCH 03/21] Use Any to use data.py --- src/pytorch_lightning/tuner/batch_size_scaling.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/tuner/batch_size_scaling.py b/src/pytorch_lightning/tuner/batch_size_scaling.py index 19b13183c9191..5e8fb774b2b87 100644 --- a/src/pytorch_lightning/tuner/batch_size_scaling.py +++ b/src/pytorch_lightning/tuner/batch_size_scaling.py @@ -246,8 +246,9 @@ def _adjust_batch_size( if desc: log.info(f"Batch size {batch_size} {desc}, trying batch size {new_size}") - if not _is_valid_batch_size(new_size, trainer.train_dataloader, trainer): - new_size = min(new_size, len(trainer.train_dataloader.dataset)) + train_dataloader: Any = trainer.train_dataloader + if not _is_valid_batch_size(new_size, train_dataloader, trainer): + new_size = min(new_size, len(train_dataloader.dataset)) changed = new_size != batch_size lightning_setattr(model, batch_arg_name, new_size) From 
8361712f7f99c36ab6accb7f85138f85b5dafac6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 3 Jul 2022 13:24:02 +0000 Subject: [PATCH 04/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/trainer/trainer.py | 4 +++- src/pytorch_lightning/tuner/batch_size_scaling.py | 3 +-- 2 files changed, 4 insertions(+), 3 deletions(-) diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 1d511371585a6..b6144520a7108 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -601,7 +601,9 @@ def _init_debugging_flags( "Logging and checkpointing is suppressed." ) - self.limit_train_batches: Union[int, float] = _determine_batch_limits(limit_train_batches, "limit_train_batches") + self.limit_train_batches: Union[int, float] = _determine_batch_limits( + limit_train_batches, "limit_train_batches" + ) self.limit_val_batches = _determine_batch_limits(limit_val_batches, "limit_val_batches") self.limit_test_batches = _determine_batch_limits(limit_test_batches, "limit_test_batches") self.limit_predict_batches = _determine_batch_limits(limit_predict_batches, "limit_predict_batches") diff --git a/src/pytorch_lightning/tuner/batch_size_scaling.py b/src/pytorch_lightning/tuner/batch_size_scaling.py index 5e8fb774b2b87..3aeda6e6d001a 100644 --- a/src/pytorch_lightning/tuner/batch_size_scaling.py +++ b/src/pytorch_lightning/tuner/batch_size_scaling.py @@ -14,14 +14,13 @@ import logging import os import uuid -from typing import Any, Dict, Union, List, Optional, Tuple +from typing import Any, Dict, List, Optional, Tuple, Union from torch.utils.data import DataLoader import pytorch_lightning as pl from pytorch_lightning.callbacks.callback import Callback from pytorch_lightning.loggers.logger import DummyLogger, Logger - from pytorch_lightning.utilities.data import has_len_all_ranks from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error From fb088c38b6a1c16bf7694159242ac611eb150e31 Mon Sep 17 00:00:00 2001 From: Masahiro Wada Date: Sun, 3 Jul 2022 21:37:50 +0000 Subject: [PATCH 05/21] Remove unused import --- src/pytorch_lightning/tuner/batch_size_scaling.py | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/src/pytorch_lightning/tuner/batch_size_scaling.py b/src/pytorch_lightning/tuner/batch_size_scaling.py index 3aeda6e6d001a..63094a14bddf0 100644 --- a/src/pytorch_lightning/tuner/batch_size_scaling.py +++ b/src/pytorch_lightning/tuner/batch_size_scaling.py @@ -14,13 +14,9 @@ import logging import os import uuid -from typing import Any, Dict, List, Optional, Tuple, Union - -from torch.utils.data import DataLoader +from typing import Any, Dict, Optional, Tuple import pytorch_lightning as pl -from pytorch_lightning.callbacks.callback import Callback -from pytorch_lightning.loggers.logger import DummyLogger, Logger from pytorch_lightning.utilities.data import has_len_all_ranks from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error From b5270c7ddf214eff57dc641dceb92ebd55c25de8 Mon Sep 17 00:00:00 2001 From: Masahiro Wada Date: Sun, 3 Jul 2022 21:41:00 +0000 Subject: [PATCH 06/21] Fix missing import --- src/pytorch_lightning/tuner/batch_size_scaling.py | 1 + 1 
file changed, 1 insertion(+) diff --git a/src/pytorch_lightning/tuner/batch_size_scaling.py b/src/pytorch_lightning/tuner/batch_size_scaling.py index 63094a14bddf0..8a10b4b6fde7c 100644 --- a/src/pytorch_lightning/tuner/batch_size_scaling.py +++ b/src/pytorch_lightning/tuner/batch_size_scaling.py @@ -17,6 +17,7 @@ from typing import Any, Dict, Optional, Tuple import pytorch_lightning as pl +from pytorch_lightning.loggers.logger import DummyLogger from pytorch_lightning.utilities.data import has_len_all_ranks from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.memory import garbage_collection_cuda, is_oom_error From 61e3c8f872f45e689c32bc98213055857d35d8cd Mon Sep 17 00:00:00 2001 From: Masahiro Wada Date: Mon, 4 Jul 2022 23:44:23 +0900 Subject: [PATCH 07/21] Update src/pytorch_lightning/tuner/batch_size_scaling.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- src/pytorch_lightning/tuner/batch_size_scaling.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/tuner/batch_size_scaling.py b/src/pytorch_lightning/tuner/batch_size_scaling.py index 8a10b4b6fde7c..ccb78b87eb170 100644 --- a/src/pytorch_lightning/tuner/batch_size_scaling.py +++ b/src/pytorch_lightning/tuner/batch_size_scaling.py @@ -236,7 +236,7 @@ def _adjust_batch_size( new_size = value else: if not isinstance(batch_size, int): - raise ValueError(f"value is None and batch_size is not int value: {batch_size}") + raise ValueError(f"Batch size attribute in LightningModule must be an integer, got: {batch_size!r}") new_size = int(batch_size * factor) if desc: From 4d952ec92887b923bf9f2366a8cf867865170edd Mon Sep 17 00:00:00 2001 From: Masahiro Wada Date: Tue, 5 Jul 2022 03:02:26 +0000 Subject: [PATCH 08/21] Remove unneccesary fix --- src/pytorch_lightning/utilities/data.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/utilities/data.py b/src/pytorch_lightning/utilities/data.py index ec596ccc7b2ab..2de82ceff088e 100644 --- a/src/pytorch_lightning/utilities/data.py +++ b/src/pytorch_lightning/utilities/data.py @@ -126,7 +126,7 @@ def has_len(dataloader: Union[DataLoader, Iterable]) -> bool: def has_len_all_ranks( - dataloader: Any, + dataloader: DataLoader, training_type: "pl.Strategy", model: Union["pl.LightningModule", "pl.LightningDataModule"], ) -> bool: From 5419848c30a28bf32ab9b9360b42faee68f46ee4 Mon Sep 17 00:00:00 2001 From: Masahiro Wada Date: Tue, 5 Jul 2022 03:03:35 +0000 Subject: [PATCH 09/21] Use assert to narrow types --- src/pytorch_lightning/tuner/batch_size_scaling.py | 14 ++++++++------ 1 file changed, 8 insertions(+), 6 deletions(-) diff --git a/src/pytorch_lightning/tuner/batch_size_scaling.py b/src/pytorch_lightning/tuner/batch_size_scaling.py index ccb78b87eb170..4a2c69d0d3d09 100644 --- a/src/pytorch_lightning/tuner/batch_size_scaling.py +++ b/src/pytorch_lightning/tuner/batch_size_scaling.py @@ -242,15 +242,17 @@ def _adjust_batch_size( if desc: log.info(f"Batch size {batch_size} {desc}, trying batch size {new_size}") - train_dataloader: Any = trainer.train_dataloader - if not _is_valid_batch_size(new_size, train_dataloader, trainer): - new_size = min(new_size, len(train_dataloader.dataset)) + if not _is_valid_batch_size(new_size, trainer.train_dataloader, trainer): + assert trainer.train_dataloader is not None + new_size = min(new_size, len(trainer.train_dataloader.dataset)) changed = new_size != 
batch_size
     lightning_setattr(model, batch_arg_name, new_size)
     return new_size, changed
 
-
-def _is_valid_batch_size(batch_size: int, dataloader: Any, trainer: "pl.Trainer") -> bool:
+def _is_valid_batch_size(batch_size: int, dataloader: Optional[Any], trainer: "pl.Trainer") -> bool:
     module = trainer.lightning_module or trainer.datamodule
-    return not has_len_all_ranks(dataloader, trainer.strategy, module) or batch_size <= len(dataloader)
+    if not has_len_all_ranks(dataloader, trainer.strategy, module):
+        return True
+    assert dataloader is not None
+    return batch_size <= len(dataloader)

From 1a1e265ed8ec68d2b7b8d00ae8724a3f020aa70f Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 5 Jul 2022 03:08:34 +0000
Subject: [PATCH 10/21] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/pytorch_lightning/tuner/batch_size_scaling.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/pytorch_lightning/tuner/batch_size_scaling.py b/src/pytorch_lightning/tuner/batch_size_scaling.py
index 4a2c69d0d3d09..148625064e105 100644
--- a/src/pytorch_lightning/tuner/batch_size_scaling.py
+++ b/src/pytorch_lightning/tuner/batch_size_scaling.py
@@ -250,6 +250,7 @@ def _adjust_batch_size(
     lightning_setattr(model, batch_arg_name, new_size)
     return new_size, changed
 
+
 def _is_valid_batch_size(batch_size: int, dataloader: Optional[Any], trainer: "pl.Trainer") -> bool:
     module = trainer.lightning_module or trainer.datamodule
     if not has_len_all_ranks(dataloader, trainer.strategy, module):

From ca45a66f4cdf91ec122f10f3e1905b779b5b787a Mon Sep 17 00:00:00 2001
From: Masahiro Wada
Date: Tue, 5 Jul 2022 03:32:50 +0000
Subject: [PATCH 11/21] Add None check explicitly

To fix a type-check issue, add the None check explicitly. This early return
doesn't change the behavior of _is_valid_batch_size, because
has_len_all_ranks always returns True if dataloader is None.
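
For readers unfamiliar with the pattern used in this commit, the sketch below is illustrative only and not part of the patch; the function and argument names are made up. It shows how an explicit None check with an early return lets mypy narrow an Optional argument: below the early return, the checker treats the value as non-None, so calling len() on it type-checks.

    from typing import Optional, Sized

    def fits(batch_size: int, data: Optional[Sized]) -> bool:
        # Early return on None: from here down, mypy narrows `data` to Sized.
        if data is None:
            return True
        return batch_size <= len(data)
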
---
 src/pytorch_lightning/tuner/batch_size_scaling.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/pytorch_lightning/tuner/batch_size_scaling.py b/src/pytorch_lightning/tuner/batch_size_scaling.py
index 148625064e105..2458d7f2e694a 100644
--- a/src/pytorch_lightning/tuner/batch_size_scaling.py
+++ b/src/pytorch_lightning/tuner/batch_size_scaling.py
@@ -252,8 +252,7 @@ def _adjust_batch_size(
 
 
 def _is_valid_batch_size(batch_size: int, dataloader: Optional[Any], trainer: "pl.Trainer") -> bool:
-    module = trainer.lightning_module or trainer.datamodule
-    if not has_len_all_ranks(dataloader, trainer.strategy, module):
+    if dataloader is None:
         return True
-    assert dataloader is not None
-    return batch_size <= len(dataloader)
+    module = trainer.lightning_module or trainer.datamodule
+    return not has_len_all_ranks(dataloader, trainer.strategy, module) or batch_size <= len(dataloader)
\ No newline at end of file

From d23e262703f380357bf896a1675dd88f7a0cdc61 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Tue, 5 Jul 2022 03:37:53 +0000
Subject: [PATCH 12/21] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/pytorch_lightning/tuner/batch_size_scaling.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/tuner/batch_size_scaling.py b/src/pytorch_lightning/tuner/batch_size_scaling.py
index 2458d7f2e694a..0c992e9bb8bf7 100644
--- a/src/pytorch_lightning/tuner/batch_size_scaling.py
+++ b/src/pytorch_lightning/tuner/batch_size_scaling.py
@@ -255,4 +255,4 @@ def _is_valid_batch_size(batch_size: int, dataloader: Optional[Any], trainer: "p
     if dataloader is None:
         return True
     module = trainer.lightning_module or trainer.datamodule
-    return not has_len_all_ranks(dataloader, trainer.strategy, module) or batch_size <= len(dataloader)
\ No newline at end of file
+    return not has_len_all_ranks(dataloader, trainer.strategy, module) or batch_size <= len(dataloader)

From 122c9302d039fade1e9097b982b0e88a14721ac7 Mon Sep 17 00:00:00 2001
From: Masahiro Wada
Date: Sun, 17 Jul 2022 11:12:36 +0000
Subject: [PATCH 13/21] Simplify type narrowing

Take some code from #11089 to simplify type narrowing.
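
As a quick illustration of the assert-based narrowing this commit switches to (a sketch only, not taken from the patch; the function name is hypothetical): after an assert that a value is not None, mypy treats an Optional[int] as int, so no isinstance check or cast is needed.

    from typing import Optional

    def scaled_size(batch_size: Optional[int], factor: float) -> int:
        # The assert narrows Optional[int] to int for mypy; at runtime it
        # simply raises AssertionError if a None slips through.
        assert batch_size is not None
        return int(batch_size * factor)
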
--- .../tuner/batch_size_scaling.py | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/src/pytorch_lightning/tuner/batch_size_scaling.py b/src/pytorch_lightning/tuner/batch_size_scaling.py index 0c992e9bb8bf7..3c3d09e91d6f2 100644 --- a/src/pytorch_lightning/tuner/batch_size_scaling.py +++ b/src/pytorch_lightning/tuner/batch_size_scaling.py @@ -16,6 +16,8 @@ import uuid from typing import Any, Dict, Optional, Tuple +from torch.utils.data import DataLoader + import pytorch_lightning as pl from pytorch_lightning.loggers.logger import DummyLogger from pytorch_lightning.utilities.data import has_len_all_ranks @@ -232,18 +234,13 @@ def _adjust_batch_size( """ model = trainer.lightning_module batch_size = lightning_getattr(model, batch_arg_name) - if value is not None: - new_size = value - else: - if not isinstance(batch_size, int): - raise ValueError(f"Batch size attribute in LightningModule must be an integer, got: {batch_size!r}") - new_size = int(batch_size * factor) - + assert batch_size is not None + new_size: int = value if value is not None else int(batch_size * factor) if desc: log.info(f"Batch size {batch_size} {desc}, trying batch size {new_size}") + assert trainer.train_dataloader is not None if not _is_valid_batch_size(new_size, trainer.train_dataloader, trainer): - assert trainer.train_dataloader is not None new_size = min(new_size, len(trainer.train_dataloader.dataset)) changed = new_size != batch_size @@ -251,8 +248,6 @@ def _adjust_batch_size( return new_size, changed -def _is_valid_batch_size(batch_size: int, dataloader: Optional[Any], trainer: "pl.Trainer") -> bool: - if dataloader is None: - return True +def _is_valid_batch_size(batch_size: int, dataloader: DataLoader, trainer: "pl.Trainer") -> bool: module = trainer.lightning_module or trainer.datamodule return not has_len_all_ranks(dataloader, trainer.strategy, module) or batch_size <= len(dataloader) From 3344c4824e342c73ef70a624b65ec5544d19811b Mon Sep 17 00:00:00 2001 From: Masahiro Wada Date: Sun, 17 Jul 2022 11:47:19 +0000 Subject: [PATCH 14/21] Fix wrong code merging --- pyproject.toml | 38 +++++++------------ src/pytorch_lightning/trainer/trainer.py | 4 +- .../tuner/batch_size_scaling.py | 2 +- 3 files changed, 16 insertions(+), 28 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 4a8a643ac905e..3886f7a555091 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -22,7 +22,13 @@ line-length = 120 [tool.mypy] -files = ["pytorch_lightning"] +files = [ + "src/pytorch_lightning", + # TODO: Check typing in app source + # "src/lightning_app", +] +install_types = "True" +non_interactive = "True" disallow_untyped_defs = "True" ignore_missing_imports = "True" show_error_codes = "True" @@ -39,60 +45,44 @@ warn_no_return = "False" # TODO: the goal is for this to be empty [[tool.mypy.overrides]] # the list can be generated with: -# mypy | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g' | sed 's|\/|\.|g' | xargs -I {} echo '"{}",' +# mypy --no-error-summary 2>&1 | tr ':' ' ' | awk '{print $1}' | sort | uniq | sed 's/\.py//g; s|src/||g; s|\/|\.|g' | xargs -I {} echo '"{}",' module = [ - "pytorch_lightning.callbacks.finetuning", "pytorch_lightning.callbacks.model_checkpoint", "pytorch_lightning.callbacks.progress.rich_progress", "pytorch_lightning.callbacks.quantization", "pytorch_lightning.callbacks.stochastic_weight_avg", "pytorch_lightning.core.datamodule", "pytorch_lightning.core.decorators", - "pytorch_lightning.core.module", 
"pytorch_lightning.core.mixins.device_dtype_mixin", + "pytorch_lightning.core.module", "pytorch_lightning.core.saving", "pytorch_lightning.demos.boring_classes", "pytorch_lightning.demos.mnist_datamodule", - "pytorch_lightning.distributed.dist", - "pytorch_lightning.loggers.base", - "pytorch_lightning.loggers.logger", "pytorch_lightning.loggers.comet", - "pytorch_lightning.loggers.csv_logs", "pytorch_lightning.loggers.mlflow", "pytorch_lightning.loggers.neptune", "pytorch_lightning.loggers.tensorboard", "pytorch_lightning.loggers.wandb", - "pytorch_lightning.loops.epoch.training_epoch_loop", + "pytorch_lightning.profilers.advanced", + "pytorch_lightning.profilers.base", + "pytorch_lightning.profilers.pytorch", + "pytorch_lightning.profilers.simple", "pytorch_lightning.strategies.ddp", - "pytorch_lightning.strategies.ddp2", "pytorch_lightning.strategies.ddp_spawn", "pytorch_lightning.strategies.deepspeed", - "pytorch_lightning.strategies.dp", "pytorch_lightning.strategies.fully_sharded", - "pytorch_lightning.strategies.horovod", "pytorch_lightning.strategies.ipu", - "pytorch_lightning.strategies.parallel", "pytorch_lightning.strategies.sharded", "pytorch_lightning.strategies.sharded_spawn", - "pytorch_lightning.strategies.single_device", - "pytorch_lightning.strategies.single_tpu", - "pytorch_lightning.strategies.tpu_spawn", "pytorch_lightning.strategies.strategy", - "pytorch_lightning.profilers.advanced", - "pytorch_lightning.profilers.base", - "pytorch_lightning.profilers.pytorch", - "pytorch_lightning.profilers.simple", + "pytorch_lightning.strategies.tpu_spawn", "pytorch_lightning.trainer.callback_hook", "pytorch_lightning.trainer.connectors.callback_connector", "pytorch_lightning.trainer.connectors.data_connector", - "pytorch_lightning.trainer.data_loading", "pytorch_lightning.trainer.supporters", "pytorch_lightning.trainer.trainer", - "pytorch_lightning.tuner.lr_finder", - "pytorch_lightning.tuner.tuning", "pytorch_lightning.utilities.auto_restart", "pytorch_lightning.utilities.data", - "pytorch_lightning.utilities.distributed", "pytorch_lightning.utilities.meta", ] ignore_errors = "True" diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index 363bc15ab5cbf..ff8113cb63c8b 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -599,9 +599,7 @@ def _init_debugging_flags( "Logging and checkpointing is suppressed." 
) - self.limit_train_batches: Union[int, float] = _determine_batch_limits( - limit_train_batches, "limit_train_batches" - ) + self.limit_train_batches: Union[int, float] = _determine_batch_limits(limit_train_batches, "limit_train_batches") self.limit_val_batches = _determine_batch_limits(limit_val_batches, "limit_val_batches") self.limit_test_batches = _determine_batch_limits(limit_test_batches, "limit_test_batches") self.limit_predict_batches = _determine_batch_limits(limit_predict_batches, "limit_predict_batches") diff --git a/src/pytorch_lightning/tuner/batch_size_scaling.py b/src/pytorch_lightning/tuner/batch_size_scaling.py index 3c3d09e91d6f2..ce2034290b00c 100644 --- a/src/pytorch_lightning/tuner/batch_size_scaling.py +++ b/src/pytorch_lightning/tuner/batch_size_scaling.py @@ -235,7 +235,7 @@ def _adjust_batch_size( model = trainer.lightning_module batch_size = lightning_getattr(model, batch_arg_name) assert batch_size is not None - new_size: int = value if value is not None else int(batch_size * factor) + new_size = value if value is not None else int(batch_size * factor) if desc: log.info(f"Batch size {batch_size} {desc}, trying batch size {new_size}") From 5c4b36042aebc354281698a0e9e6f3b1da60292f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sun, 17 Jul 2022 11:49:05 +0000 Subject: [PATCH 15/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/trainer/trainer.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/trainer/trainer.py b/src/pytorch_lightning/trainer/trainer.py index ff8113cb63c8b..363bc15ab5cbf 100644 --- a/src/pytorch_lightning/trainer/trainer.py +++ b/src/pytorch_lightning/trainer/trainer.py @@ -599,7 +599,9 @@ def _init_debugging_flags( "Logging and checkpointing is suppressed." ) - self.limit_train_batches: Union[int, float] = _determine_batch_limits(limit_train_batches, "limit_train_batches") + self.limit_train_batches: Union[int, float] = _determine_batch_limits( + limit_train_batches, "limit_train_batches" + ) self.limit_val_batches = _determine_batch_limits(limit_val_batches, "limit_val_batches") self.limit_test_batches = _determine_batch_limits(limit_test_batches, "limit_test_batches") self.limit_predict_batches = _determine_batch_limits(limit_predict_batches, "limit_predict_batches") From e8754de4262c871b36ebd2c5a4d74b5ae7abf5a5 Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 14 Sep 2022 15:32:10 +0200 Subject: [PATCH 16/21] , --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index c016cc0ea952a..de0b4c21e8dd0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,7 +53,7 @@ warn_no_return = "False" module = [ "pytorch_lightning.callbacks.progress.rich_progress", "pytorch_lightning.trainer.trainer", - "pytorch_lightning.tuner.batch_size_scaling" + "pytorch_lightning.tuner.batch_size_scaling", "pytorch_lightning.utilities.auto_restart", "pytorch_lightning.utilities.data", "lightning_lite.utilities.data", From 319c4ffca2c01fe5f87f4809372b2a859f8354fb Mon Sep 17 00:00:00 2001 From: Jirka Date: Wed, 14 Sep 2022 15:33:14 +0200 Subject: [PATCH 17/21] .. 
--- pyproject.toml | 4 ---- 1 file changed, 4 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index de0b4c21e8dd0..8b3fd290c38c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -53,9 +53,5 @@ warn_no_return = "False" module = [ "pytorch_lightning.callbacks.progress.rich_progress", "pytorch_lightning.trainer.trainer", - "pytorch_lightning.tuner.batch_size_scaling", - "pytorch_lightning.utilities.auto_restart", - "pytorch_lightning.utilities.data", - "lightning_lite.utilities.data", ] ignore_errors = "True" From 28ac120824b00ecd59051fd11ac1b0b543a988b0 Mon Sep 17 00:00:00 2001 From: Masahiro Wada Date: Wed, 28 Sep 2022 15:11:27 +0000 Subject: [PATCH 18/21] Add type hints --- src/pytorch_lightning/tuner/batch_size_scaling.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/pytorch_lightning/tuner/batch_size_scaling.py b/src/pytorch_lightning/tuner/batch_size_scaling.py index a85ef6a814fc9..1eeda7838c044 100644 --- a/src/pytorch_lightning/tuner/batch_size_scaling.py +++ b/src/pytorch_lightning/tuner/batch_size_scaling.py @@ -35,10 +35,10 @@ def scale_batch_size( init_val: int = 2, max_trials: int = 25, batch_arg_name: str = "batch_size", -): +) -> Optional[int]: if trainer.fast_dev_run: rank_zero_warn("Skipping batch size scaler since `fast_dev_run` is enabled.") - return + return None # Save initial model, that is loaded after batch size is found ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt") @@ -141,7 +141,7 @@ def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) def _run_power_scaling( - trainer: "pl.Trainer", pl_module: "pl.LightningModule", new_size: int, batch_arg_name: str, max_trials: int, params + trainer: "pl.Trainer", pl_module: "pl.LightningModule", new_size: int, batch_arg_name: str, max_trials: int, params: Dict[str, Any] ) -> int: """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered.""" # this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not @@ -179,7 +179,7 @@ def _run_power_scaling( def _run_binary_scaling( - trainer: "pl.Trainer", pl_module: "pl.LightningModule", new_size: int, batch_arg_name: str, max_trials: int, params + trainer: "pl.Trainer", pl_module: "pl.LightningModule", new_size: int, batch_arg_name: str, max_trials: int, params: Dict[str, Any] ) -> int: """Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is encountered. 
@@ -309,7 +309,7 @@ def _reset_dataloaders(trainer: "pl.Trainer", pl_module: "pl.LightningModule") - reset_fn(pl_module) -def _try_loop_run(trainer: "pl.Trainer", params) -> None: +def _try_loop_run(trainer: "pl.Trainer", params: Dict[str, Any]) -> None: if trainer.state.fn == "fit": loop = trainer.fit_loop else: From ffeb484a35cd1cba457147790b01e65c57a31b5c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 28 Sep 2022 15:17:07 +0000 Subject: [PATCH 19/21] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/tuner/batch_size_scaling.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/tuner/batch_size_scaling.py b/src/pytorch_lightning/tuner/batch_size_scaling.py index 1eeda7838c044..781c7ee1196e1 100644 --- a/src/pytorch_lightning/tuner/batch_size_scaling.py +++ b/src/pytorch_lightning/tuner/batch_size_scaling.py @@ -141,7 +141,12 @@ def __scale_batch_restore_params(trainer: "pl.Trainer", params: Dict[str, Any]) def _run_power_scaling( - trainer: "pl.Trainer", pl_module: "pl.LightningModule", new_size: int, batch_arg_name: str, max_trials: int, params: Dict[str, Any] + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + new_size: int, + batch_arg_name: str, + max_trials: int, + params: Dict[str, Any], ) -> int: """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered.""" # this flag is used to determine whether the previously scaled batch size, right before OOM, was a success or not @@ -179,7 +184,12 @@ def _run_power_scaling( def _run_binary_scaling( - trainer: "pl.Trainer", pl_module: "pl.LightningModule", new_size: int, batch_arg_name: str, max_trials: int, params: Dict[str, Any] + trainer: "pl.Trainer", + pl_module: "pl.LightningModule", + new_size: int, + batch_arg_name: str, + max_trials: int, + params: Dict[str, Any], ) -> int: """Batch scaling mode where the size is initially is doubled at each iteration until an OOM error is encountered. 
From d7f7388930b5af9f1cc40b2781d70e8d9e36a886 Mon Sep 17 00:00:00 2001
From: Masahiro Wada
Date: Wed, 28 Sep 2022 23:01:27 +0000
Subject: [PATCH 20/21] Add property type annotation

---
 src/pytorch_lightning/callbacks/batch_size_finder.py | 2 ++
 1 file changed, 2 insertions(+)

diff --git a/src/pytorch_lightning/callbacks/batch_size_finder.py b/src/pytorch_lightning/callbacks/batch_size_finder.py
index d4a8d37da4c88..48d48eca274e3 100644
--- a/src/pytorch_lightning/callbacks/batch_size_finder.py
+++ b/src/pytorch_lightning/callbacks/batch_size_finder.py
@@ -31,6 +31,8 @@ class BatchSizeFinder(Callback):
 
     SUPPORTED_MODES = ("power", "binsearch")
 
+    optimal_batch_size: Optional[int] 
+
     def __init__(
         self,
         mode: str = "power",

From c6c793b64570356e96863e2c9d468623068ada27 Mon Sep 17 00:00:00 2001
From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Date: Thu, 29 Sep 2022 01:17:14 +0000
Subject: [PATCH 21/21] [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
---
 src/pytorch_lightning/callbacks/batch_size_finder.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/src/pytorch_lightning/callbacks/batch_size_finder.py b/src/pytorch_lightning/callbacks/batch_size_finder.py
index 48d48eca274e3..96b9f6eef874e 100644
--- a/src/pytorch_lightning/callbacks/batch_size_finder.py
+++ b/src/pytorch_lightning/callbacks/batch_size_finder.py
@@ -31,7 +31,7 @@ class BatchSizeFinder(Callback):
 
     SUPPORTED_MODES = ("power", "binsearch")
 
-    optimal_batch_size: Optional[int] 
+    optimal_batch_size: Optional[int]
 
     def __init__(
         self,
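
A closing illustration of the pattern in the last two patches: a class-level annotation such as optimal_batch_size: Optional[int] declares the attribute's type for mypy without assigning a value, which helps when the attribute is only set later at runtime. A minimal sketch of the idea, using an invented class and method name:

    from typing import Optional

    class Finder:
        # Annotation only: nothing is assigned at class level, but mypy now
        # knows the instance attribute exists and has type Optional[int].
        optimal_batch_size: Optional[int]

        def record(self, size: int) -> None:
            self.optimal_batch_size = size
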