diff --git a/src/lightning/pytorch/CHANGELOG.md b/src/lightning/pytorch/CHANGELOG.md
index a364618a3eb36..b03830455cc37 100644
--- a/src/lightning/pytorch/CHANGELOG.md
+++ b/src/lightning/pytorch/CHANGELOG.md
@@ -17,6 +17,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))
 
 - `LightningCLI` no longer allows setting a normal class instance as default. A `lazy_instance` can be used instead ([#18822](https://github.com/Lightning-AI/lightning/pull/18822))
+- The `PL_TRAINER_*` env variables now take precedence over the arguments passed to the Trainer ([#18876](https://github.com/Lightning-AI/lightning/issues/18876))
+
+
 
 ### Deprecated
@@ -36,6 +39,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed an issue when `BatchSizeFinder` `steps_per_trial` parameter ends up defining how many validation batches to run during the entire training ([#18394](https://github.com/Lightning-AI/lightning/issues/18394))
 
+- Fixed an issue that would prevent the user from overriding the Trainer arguments with the `PL_TRAINER_*` env variables ([#18876](https://github.com/Lightning-AI/lightning/issues/18876))
+
 
 
 ## [2.1.0] - 2023-10-11
diff --git a/src/lightning/pytorch/trainer/trainer.py b/src/lightning/pytorch/trainer/trainer.py
index d0dcb8f437558..56885bec82a9b 100644
--- a/src/lightning/pytorch/trainer/trainer.py
+++ b/src/lightning/pytorch/trainer/trainer.py
@@ -66,7 +66,7 @@
 from lightning.pytorch.trainer.connectors.signal_connector import _SignalConnector
 from lightning.pytorch.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
 from lightning.pytorch.utilities import GradClipAlgorithmType, parsing
-from lightning.pytorch.utilities.argparse import _defaults_from_env_vars
+from lightning.pytorch.utilities.argparse import _overrides_from_env_vars
 from lightning.pytorch.utilities.compile import _maybe_unwrap_optimized, _verify_strategy_supports_compile
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.model_helpers import is_overridden
@@ -89,7 +89,7 @@
 
 
 class Trainer:
-    @_defaults_from_env_vars
+    @_overrides_from_env_vars
     def __init__(
         self,
         *,
diff --git a/src/lightning/pytorch/utilities/argparse.py b/src/lightning/pytorch/utilities/argparse.py
index eb7273b54a577..285d935fe8d33 100644
--- a/src/lightning/pytorch/utilities/argparse.py
+++ b/src/lightning/pytorch/utilities/argparse.py
@@ -53,9 +53,9 @@ def _parse_env_variables(cls: Type, template: str = "PL_%(cls_name)s_%(cls_argum
     return Namespace(**env_args)
 
 
-def _defaults_from_env_vars(fn: _T) -> _T:
+def _overrides_from_env_vars(fn: _T) -> _T:
     @wraps(fn)
-    def insert_env_defaults(self: Any, *args: Any, **kwargs: Any) -> Any:
+    def update_kwargs_from_env(self: Any, *args: Any, **kwargs: Any) -> Any:
         cls = self.__class__  # get the class
         if args:  # in case any args passed move them to kwargs
             # parse the argument names
@@ -64,9 +64,9 @@ def insert_env_defaults(self: Any, *args: Any, **kwargs: Any) -> Any:
             kwargs.update(dict(zip(cls_arg_names, args)))
         env_variables = vars(_parse_env_variables(cls))
         # update the kwargs by env variables
-        kwargs = dict(list(env_variables.items()) + list(kwargs.items()))
+        kwargs = dict(list(kwargs.items()) + list(env_variables.items()))
         # all args were already moved to kwargs
         return fn(self, **kwargs)
 
-    return cast(_T, insert_env_defaults)
+    return cast(_T, update_kwargs_from_env)
 
diff --git a/tests/tests_pytorch/test_cli.py b/tests/tests_pytorch/test_cli.py
index fdc49f787c759..032181dcd03a5 100644
--- a/tests/tests_pytorch/test_cli.py
+++ b/tests/tests_pytorch/test_cli.py
@@ -29,6 +29,7 @@
 import yaml
 from lightning.fabric.plugins.environments import SLURMEnvironment
 from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, __version__, seed_everything
+from lightning.pytorch.accelerators import CPUAccelerator
 from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
 from lightning.pytorch.cli import (
     _JSONARGPARSE_SIGNATURES_AVAILABLE,
@@ -288,6 +289,21 @@
     assert cli.config.fit.trainer.logger is False
 
 
+def test_lightning_cli_and_trainer_kwargs_override(cleandir):
+    env_vars = {
+        "PL_TRAINER_ACCELERATOR": "cpu",
+        "PL_TRAINER_DEVICES": "2",
+        "PL_TRAINER_NUM_NODES": "4",
+        "PL_TRAINER_PRECISION": "16-true",
+    }
+    with mock.patch.dict(os.environ, env_vars), mock.patch("sys.argv", [""]):
+        cli = LightningCLI(BoringModel, run=False)
+    assert isinstance(cli.trainer.accelerator, CPUAccelerator)
+    assert cli.trainer.num_devices == 2
+    assert cli.trainer.num_nodes == 4
+    assert cli.trainer.precision == "16-true"
+
+
 def test_lightning_cli_save_config_cases(cleandir):
     config_path = "config.yaml"
     cli_args = ["fit", "--trainer.logger=false", "--trainer.fast_dev_run=1"]
diff --git a/tests/tests_pytorch/trainer/flags/test_env_vars.py b/tests/tests_pytorch/trainer/flags/test_env_vars.py
index 62c94d4cc277e..5c2b508364301 100644
--- a/tests/tests_pytorch/trainer/flags/test_env_vars.py
+++ b/tests/tests_pytorch/trainer/flags/test_env_vars.py
@@ -44,14 +44,16 @@
 def test_passing_env_variables_defaults():
     """Testing overwriting trainer arguments."""
     trainer = Trainer(logger=False, max_steps=42)
-    assert trainer.logger is None
-    assert trainer.max_steps == 42
+    assert trainer.logger is not None
+    assert trainer.max_steps == 7
 
 
-@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_DEVICES": "2"})
-def test_passing_env_variables_devices(cuda_count_2):
+@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_DEVICES": "2", "PL_TRAINER_NUM_NODES": "4"})
+def test_passing_env_variables_devices(cuda_count_2, mps_count_0):
     """Testing overwriting trainer arguments."""
     trainer = Trainer()
     assert trainer.num_devices == 2
+    assert trainer.num_nodes == 4
     trainer = Trainer(accelerator="gpu", devices=1)
-    assert trainer.num_devices == 1
+    assert trainer.num_devices == 2
+    assert trainer.num_nodes == 4
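
Note for reviewers: the behavioral core of this patch is the one-line merge-order flip in `_overrides_from_env_vars`. A minimal sketch of why the order matters (the variable names below are illustrative, not part of the patch): when a `dict` is built from a concatenated list of key/value pairs, later duplicates win.

```python
# Illustrative sketch of the precedence flip; in the real decorator,
# env_variables comes from _parse_env_variables(cls).
explicit_kwargs = {"devices": 1, "max_steps": 42}  # arguments passed to Trainer(...)
env_overrides = {"devices": 2}                     # parsed from PL_TRAINER_DEVICES=2

# Before: env vars were merged first, so explicit kwargs won (env vars as defaults).
before = dict(list(env_overrides.items()) + list(explicit_kwargs.items()))
assert before == {"devices": 1, "max_steps": 42}

# After: env vars are merged last, so they take precedence (env vars as overrides).
after = dict(list(explicit_kwargs.items()) + list(env_overrides.items()))
assert after == {"devices": 2, "max_steps": 42}
```

In user terms, running a script with `PL_TRAINER_DEVICES=2` set now yields `num_devices == 2` even for `Trainer(devices=1)`, which is exactly what the updated `test_passing_env_variables_devices` asserts.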