From 12443ef7623bc27dd0c1429c4074b5cf97253d25 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 26 Oct 2023 18:11:33 +0200 Subject: [PATCH 1/9] update connector --- .../pytorch/bug_report/bug_report_model.py | 23 ++++++++----------- .../connectors/accelerator_connector.py | 15 +++++++++++- 2 files changed, 23 insertions(+), 15 deletions(-) diff --git a/examples/pytorch/bug_report/bug_report_model.py b/examples/pytorch/bug_report/bug_report_model.py index aa3f4cad710fe..3a091147da73b 100644 --- a/examples/pytorch/bug_report/bug_report_model.py +++ b/examples/pytorch/bug_report/bug_report_model.py @@ -4,6 +4,8 @@ from lightning.pytorch import LightningModule, Trainer from torch.utils.data import DataLoader, Dataset +from pytorch.cli import LightningCLI + class RandomDataset(Dataset): def __init__(self, size, length): @@ -45,20 +47,13 @@ def configure_optimizers(self): def run(): train_data = DataLoader(RandomDataset(32, 64), batch_size=2) val_data = DataLoader(RandomDataset(32, 64), batch_size=2) - test_data = DataLoader(RandomDataset(32, 64), batch_size=2) - - model = BoringModel() - trainer = Trainer( - default_root_dir=os.getcwd(), - limit_train_batches=1, - limit_val_batches=1, - limit_test_batches=1, - num_sanity_val_steps=0, - max_epochs=1, - enable_model_summary=False, - ) - trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) - trainer.test(model, dataloaders=test_data) + + import os + os.environ["PL_TRAINER_ACCELERATOR"] = "cpu" + os.environ["PL_TRAINER_DEVICES"] = "2" + + cli = LightningCLI(BoringModel, run=False, trainer_defaults={"max_epochs": 1}) + cli.trainer.fit(cli.model, train_data, val_data) if __name__ == "__main__": diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index 540161c300400..7bc68ce3ee592 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -15,7 +15,7 @@ import logging import os from collections import Counter -from typing import Dict, List, Literal, Optional, Union +from typing import Dict, List, Literal, Optional, Union, Any import torch @@ -116,6 +116,12 @@ def __init__( B. 
Strategy > Accelerator/precision/plugins """ + accelerator = self._argument_from_env("accelerator", accelerator, default="auto") + strategy = self._argument_from_env("strategy", strategy, default="auto") + devices = self._argument_from_env("devices", devices, default="auto") + num_nodes = int(self._argument_from_env("num_nodes", num_nodes, default=1)) + precision = self._argument_from_env("precision", precision, default="32-true") + self.use_distributed_sampler = use_distributed_sampler _set_torch_flags(deterministic=deterministic, benchmark=benchmark) @@ -656,6 +662,13 @@ def is_distributed(self) -> bool: return self.strategy.is_distributed return False + @staticmethod + def _argument_from_env(name: str, current: Any, default: Any) -> Any: + env_value: Optional[str] = os.environ.get("PL_TRAINER_" + name.upper()) + if env_value is None: + return current + return env_value + def _set_torch_flags( *, deterministic: Optional[Union[bool, _LITERAL_WARN]] = None, benchmark: Optional[bool] = None From 78ea214f53d8b51473036f5219487b9c957e9098 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 26 Oct 2023 21:46:03 +0200 Subject: [PATCH 2/9] undo --- .../trainer/connectors/accelerator_connector.py | 15 +-------------- 1 file changed, 1 insertion(+), 14 deletions(-) diff --git a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py index 7bc68ce3ee592..540161c300400 100644 --- a/src/lightning/pytorch/trainer/connectors/accelerator_connector.py +++ b/src/lightning/pytorch/trainer/connectors/accelerator_connector.py @@ -15,7 +15,7 @@ import logging import os from collections import Counter -from typing import Dict, List, Literal, Optional, Union, Any +from typing import Dict, List, Literal, Optional, Union import torch @@ -116,12 +116,6 @@ def __init__( B. 
Strategy > Accelerator/precision/plugins """ - accelerator = self._argument_from_env("accelerator", accelerator, default="auto") - strategy = self._argument_from_env("strategy", strategy, default="auto") - devices = self._argument_from_env("devices", devices, default="auto") - num_nodes = int(self._argument_from_env("num_nodes", num_nodes, default=1)) - precision = self._argument_from_env("precision", precision, default="32-true") - self.use_distributed_sampler = use_distributed_sampler _set_torch_flags(deterministic=deterministic, benchmark=benchmark) @@ -662,13 +656,6 @@ def is_distributed(self) -> bool: return self.strategy.is_distributed return False - @staticmethod - def _argument_from_env(name: str, current: Any, default: Any) -> Any: - env_value: Optional[str] = os.environ.get("PL_TRAINER_" + name.upper()) - if env_value is None: - return current - return env_value - def _set_torch_flags( *, deterministic: Optional[Union[bool, _LITERAL_WARN]] = None, benchmark: Optional[bool] = None From c3b3a78e1bb940822e630a9637a502b5d3a462d0 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 26 Oct 2023 21:46:09 +0200 Subject: [PATCH 3/9] precedence --- src/lightning/pytorch/utilities/argparse.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/utilities/argparse.py b/src/lightning/pytorch/utilities/argparse.py index eb7273b54a577..41bd1d7948fe2 100644 --- a/src/lightning/pytorch/utilities/argparse.py +++ b/src/lightning/pytorch/utilities/argparse.py @@ -64,7 +64,7 @@ def insert_env_defaults(self: Any, *args: Any, **kwargs: Any) -> Any: kwargs.update(dict(zip(cls_arg_names, args))) env_variables = vars(_parse_env_variables(cls)) # update the kwargs by env variables - kwargs = dict(list(env_variables.items()) + list(kwargs.items())) + kwargs = dict(list(kwargs.items()) + list(env_variables.items())) # all args were already moved to kwargs return fn(self, **kwargs) From 996363f22abe2eea3d31ce3ffa04755027893960 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 26 Oct 2023 21:51:47 +0200 Subject: [PATCH 4/9] update --- src/lightning/pytorch/trainer/trainer.py | 4 ++-- src/lightning/pytorch/utilities/argparse.py | 6 +++--- tests/tests_pytorch/trainer/flags/test_env_vars.py | 12 +++++++----- 3 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py index d0dcb8f437558..56885bec82a9b 100644 --- a/src/lightning/pytorch/trainer/trainer.py +++ b/src/lightning/pytorch/trainer/trainer.py @@ -66,7 +66,7 @@ from lightning.pytorch.trainer.connectors.signal_connector import _SignalConnector from lightning.pytorch.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus from lightning.pytorch.utilities import GradClipAlgorithmType, parsing -from lightning.pytorch.utilities.argparse import _defaults_from_env_vars +from lightning.pytorch.utilities.argparse import _overrides_from_env_vars from lightning.pytorch.utilities.compile import _maybe_unwrap_optimized, _verify_strategy_supports_compile from lightning.pytorch.utilities.exceptions import MisconfigurationException from lightning.pytorch.utilities.model_helpers import is_overridden @@ -89,7 +89,7 @@ class Trainer: - @_defaults_from_env_vars + @_overrides_from_env_vars def __init__( self, *, diff --git a/src/lightning/pytorch/utilities/argparse.py b/src/lightning/pytorch/utilities/argparse.py index 41bd1d7948fe2..285d935fe8d33 100644 --- 
a/src/lightning/pytorch/utilities/argparse.py +++ b/src/lightning/pytorch/utilities/argparse.py @@ -53,9 +53,9 @@ def _parse_env_variables(cls: Type, template: str = "PL_%(cls_name)s_%(cls_argum return Namespace(**env_args) -def _defaults_from_env_vars(fn: _T) -> _T: +def _overrides_from_env_vars(fn: _T) -> _T: @wraps(fn) - def insert_env_defaults(self: Any, *args: Any, **kwargs: Any) -> Any: + def update_kwargs_from_env(self: Any, *args: Any, **kwargs: Any) -> Any: cls = self.__class__ # get the class if args: # in case any args passed move them to kwargs # parse the argument names @@ -69,4 +69,4 @@ def insert_env_defaults(self: Any, *args: Any, **kwargs: Any) -> Any: # all args were already moved to kwargs return fn(self, **kwargs) - return cast(_T, insert_env_defaults) + return cast(_T, update_kwargs_from_env) diff --git a/tests/tests_pytorch/trainer/flags/test_env_vars.py b/tests/tests_pytorch/trainer/flags/test_env_vars.py index 62c94d4cc277e..5c2b508364301 100644 --- a/tests/tests_pytorch/trainer/flags/test_env_vars.py +++ b/tests/tests_pytorch/trainer/flags/test_env_vars.py @@ -44,14 +44,16 @@ def test_passing_env_variables_only(): def test_passing_env_variables_defaults(): """Testing overwriting trainer arguments.""" trainer = Trainer(logger=False, max_steps=42) - assert trainer.logger is None - assert trainer.max_steps == 42 + assert trainer.logger is not None + assert trainer.max_steps == 7 -@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_DEVICES": "2"}) -def test_passing_env_variables_devices(cuda_count_2): +@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_DEVICES": "2", "PL_TRAINER_NUM_NODES": "4"}) +def test_passing_env_variables_devices(cuda_count_2, mps_count_0): """Testing overwriting trainer arguments.""" trainer = Trainer() assert trainer.num_devices == 2 + assert trainer.num_nodes == 4 trainer = Trainer(accelerator="gpu", devices=1) - assert trainer.num_devices == 1 + assert trainer.num_devices == 2 + assert trainer.num_nodes == 4 From 09b6707c3cbf47f61f03cb2f601bf3daece857b4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 26 Oct 2023 21:59:12 +0200 Subject: [PATCH 5/9] test --- tests/tests_pytorch/test_cli.py | 17 +++++++++++++++++ 1 file changed, 17 insertions(+) diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index fdc49f787c759..326418bfe95d1 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -58,6 +58,8 @@ from tests_pytorch.helpers.runif import RunIf +from lightning.pytorch.accelerators import CPUAccelerator + if _JSONARGPARSE_SIGNATURES_AVAILABLE: from jsonargparse import Namespace, lazy_instance else: @@ -288,6 +290,21 @@ def test_lightning_env_parse(cleandir): assert cli.config.fit.trainer.logger is False +def test_lightning_cli_and_trainer_kwargs_override(cleandir): + env_vars = { + "PL_TRAINER_ACCELERATOR": "cpu", + "PL_TRAINER_DEVICES": "2", + "PL_TRAINER_NUM_NODES": "4", + "PL_TRAINER_PRECISION": "16-true", + } + with mock.patch.dict(os.environ, env_vars), mock.patch("sys.argv", [""]): + cli = LightningCLI(BoringModel, run=False) + assert isinstance(cli.trainer.accelerator, CPUAccelerator) + assert cli.trainer.num_devices == 2 + assert cli.trainer.num_nodes == 4 + assert cli.trainer.precision == "16-true" + + def test_lightning_cli_save_config_cases(cleandir): config_path = "config.yaml" cli_args = ["fit", "--trainer.logger=false", "--trainer.fast_dev_run=1"] From 363ed8346b05ed1cff42f2a696b58c93a95c036e Mon Sep 17 
00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 26 Oct 2023 22:03:13 +0200 Subject: [PATCH 6/9] chlog --- src/lightning/pytorch/CHANGELOG.md | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index a364618a3eb36..c96a461cdf5f8 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -17,6 +17,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846)) - `LightningCLI` no longer allows setting a normal class instance as default. A `lazy_instance` can be used instead ([#18822](https://github.com/Lightning-AI/lightning/pull/18822)) +- The `PL_TRAINER_*` env variables now takes precendence over the arguments passed to the Trainer ([#18876](https://github.com/Lightning-AI/lightning/issues/18876)) + + ### Deprecated @@ -36,6 +39,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - Fixed an issue when `BatchSizeFinder` `steps_per_trial` parameter ends up defining how many validation batches to run during the entire training ([#18394](https://github.com/Lightning-AI/lightning/issues/18394)) +- Fixed an issue that would prevent the user to override the Trainer arguments with the `PL_TRAINER_*` env variable ([#18876](https://github.com/Lightning-AI/lightning/issues/18876)) + ## [2.1.0] - 2023-10-11 From 4c939a92c7beff5605c96811d16791eb9a60cbfe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Thu, 26 Oct 2023 22:03:31 +0200 Subject: [PATCH 7/9] reset --- .../pytorch/bug_report/bug_report_model.py | 23 +++++++++++-------- 1 file changed, 14 insertions(+), 9 deletions(-) diff --git a/examples/pytorch/bug_report/bug_report_model.py b/examples/pytorch/bug_report/bug_report_model.py index 3a091147da73b..aa3f4cad710fe 100644 --- a/examples/pytorch/bug_report/bug_report_model.py +++ b/examples/pytorch/bug_report/bug_report_model.py @@ -4,8 +4,6 @@ from lightning.pytorch import LightningModule, Trainer from torch.utils.data import DataLoader, Dataset -from pytorch.cli import LightningCLI - class RandomDataset(Dataset): def __init__(self, size, length): @@ -47,13 +45,20 @@ def configure_optimizers(self): def run(): train_data = DataLoader(RandomDataset(32, 64), batch_size=2) val_data = DataLoader(RandomDataset(32, 64), batch_size=2) - - import os - os.environ["PL_TRAINER_ACCELERATOR"] = "cpu" - os.environ["PL_TRAINER_DEVICES"] = "2" - - cli = LightningCLI(BoringModel, run=False, trainer_defaults={"max_epochs": 1}) - cli.trainer.fit(cli.model, train_data, val_data) + test_data = DataLoader(RandomDataset(32, 64), batch_size=2) + + model = BoringModel() + trainer = Trainer( + default_root_dir=os.getcwd(), + limit_train_batches=1, + limit_val_batches=1, + limit_test_batches=1, + num_sanity_val_steps=0, + max_epochs=1, + enable_model_summary=False, + ) + trainer.fit(model, train_dataloaders=train_data, val_dataloaders=val_data) + trainer.test(model, dataloaders=test_data) if __name__ == "__main__": From e8e4d840d1c8829f22eb501588d45ea5aec9e5e0 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 26 Oct 2023 20:06:14 +0000 Subject: [PATCH 8/9] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/tests_pytorch/test_cli.py 
| 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py index 326418bfe95d1..032181dcd03a5 100644 --- a/tests/tests_pytorch/test_cli.py +++ b/tests/tests_pytorch/test_cli.py @@ -29,6 +29,7 @@ import yaml from lightning.fabric.plugins.environments import SLURMEnvironment from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, __version__, seed_everything +from lightning.pytorch.accelerators import CPUAccelerator from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint from lightning.pytorch.cli import ( _JSONARGPARSE_SIGNATURES_AVAILABLE, @@ -58,8 +59,6 @@ from tests_pytorch.helpers.runif import RunIf -from lightning.pytorch.accelerators import CPUAccelerator - if _JSONARGPARSE_SIGNATURES_AVAILABLE: from jsonargparse import Namespace, lazy_instance else: From 6b22a0fc0b6c46ec12a141e7742ef288b934556e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adrian=20W=C3=A4lchli?= Date: Fri, 27 Oct 2023 01:13:20 +0200 Subject: [PATCH 9/9] fixtypo --- src/lightning/pytorch/CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md index c96a461cdf5f8..b03830455cc37 100644 --- a/src/lightning/pytorch/CHANGELOG.md +++ b/src/lightning/pytorch/CHANGELOG.md @@ -17,7 +17,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). - `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846)) - `LightningCLI` no longer allows setting a normal class instance as default. A `lazy_instance` can be used instead ([#18822](https://github.com/Lightning-AI/lightning/pull/18822)) -- The `PL_TRAINER_*` env variables now takes precendence over the arguments passed to the Trainer ([#18876](https://github.com/Lightning-AI/lightning/issues/18876)) +- The `PL_TRAINER_*` env variables now takes precedence over the arguments passed to the Trainer ([#18876](https://github.com/Lightning-AI/lightning/issues/18876))
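
A minimal sketch of the behavior this series changes, for context (illustrative only, not part of the diffs; it mirrors `test_passing_env_variables_defaults` above). Patch 3 flips the merge order inside the env-var decorator (renamed to `update_kwargs_from_env` in patch 4) so that `PL_TRAINER_*` environment variables override explicitly passed `Trainer` arguments instead of merely filling in defaults. The precedence comes from plain `dict` construction, where later items win:

    import os

    from lightning.pytorch import Trainer

    # Simulates launching with PL_TRAINER_MAX_STEPS=7 in the environment.
    os.environ["PL_TRAINER_MAX_STEPS"] = "7"

    # Before: dict(env_items + kwarg_items) -> the explicit argument won.
    # After:  dict(kwarg_items + env_items) -> the env variable wins.
    trainer = Trainer(max_steps=42)
    assert trainer.max_steps == 7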