src/lightning/pytorch/CHANGELOG.md (5 changes: 5 additions & 0 deletions)
@@ -17,6 +17,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - `seed_everything()` without passing in a seed no longer randomly selects a seed, and now defaults to `0` ([#18846](https://github.com/Lightning-AI/lightning/pull/18846))
 - `LightningCLI` no longer allows setting a normal class instance as default. A `lazy_instance` can be used instead ([#18822](https://github.com/Lightning-AI/lightning/pull/18822))
 
+- The `PL_TRAINER_*` env variables now take precedence over the arguments passed to the Trainer ([#18876](https://github.com/Lightning-AI/lightning/issues/18876))
+
+
 
 ### Deprecated
 
@@ -36,6 +39,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an issue where the `BatchSizeFinder` `steps_per_trial` parameter ends up defining how many validation batches to run during the entire training ([#18394](https://github.com/Lightning-AI/lightning/issues/18394))
 
 
+- Fixed an issue that would prevent the user from overriding the Trainer arguments with the `PL_TRAINER_*` env variables ([#18876](https://github.com/Lightning-AI/lightning/issues/18876))
+
 
 ## [2.1.0] - 2023-10-11
 
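A minimal sketch of the behavior change these entries describe, assuming a `PL_TRAINER_MAX_STEPS` variable (mirroring the updated `test_env_vars.py` test further down): after this PR, the env variable wins even when the argument is passed explicitly to the constructor.

```python
import os
from unittest import mock

from lightning.pytorch import Trainer

# Illustrative only: with this PR, the env variable overrides the explicitly
# passed argument (previously it only served as a default).
with mock.patch.dict(os.environ, {"PL_TRAINER_MAX_STEPS": "7"}):
    trainer = Trainer(max_steps=42)
    assert trainer.max_steps == 7  # env var wins after this PR (was 42 before)
```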
src/lightning/pytorch/trainer/trainer.py (4 changes: 2 additions & 2 deletions)
@@ -66,7 +66,7 @@
 from lightning.pytorch.trainer.connectors.signal_connector import _SignalConnector
 from lightning.pytorch.trainer.states import RunningStage, TrainerFn, TrainerState, TrainerStatus
 from lightning.pytorch.utilities import GradClipAlgorithmType, parsing
-from lightning.pytorch.utilities.argparse import _defaults_from_env_vars
+from lightning.pytorch.utilities.argparse import _overrides_from_env_vars
 from lightning.pytorch.utilities.compile import _maybe_unwrap_optimized, _verify_strategy_supports_compile
 from lightning.pytorch.utilities.exceptions import MisconfigurationException
 from lightning.pytorch.utilities.model_helpers import is_overridden
@@ -89,7 +89,7 @@
 
 
 class Trainer:
-    @_defaults_from_env_vars
+    @_overrides_from_env_vars
     def __init__(
         self,
         *,
src/lightning/pytorch/utilities/argparse.py (8 changes: 4 additions & 4 deletions)
@@ -53,9 +53,9 @@ def _parse_env_variables(cls: Type, template: str = "PL_%(cls_name)s_%(cls_argument)s"):
     return Namespace(**env_args)
 
 
-def _defaults_from_env_vars(fn: _T) -> _T:
+def _overrides_from_env_vars(fn: _T) -> _T:
     @wraps(fn)
-    def insert_env_defaults(self: Any, *args: Any, **kwargs: Any) -> Any:
+    def update_kwargs_from_env(self: Any, *args: Any, **kwargs: Any) -> Any:
         cls = self.__class__  # get the class
         if args:  # in case any args passed move them to kwargs
             # parse the argument names
@@ -64,9 +64,9 @@ def insert_env_defaults(self: Any, *args: Any, **kwargs: Any) -> Any:
             kwargs.update(dict(zip(cls_arg_names, args)))
         env_variables = vars(_parse_env_variables(cls))
         # update the kwargs by env variables
-        kwargs = dict(list(env_variables.items()) + list(kwargs.items()))
+        kwargs = dict(list(kwargs.items()) + list(env_variables.items()))
**Comment from the PR author on the line above:**

This is the fix. The env vars now take precedence over the arguments passed by the user. This functionality is undocumented, and so I am OK with changing it. I am confident that this was always the intention.


         # all args were already moved to kwargs
         return fn(self, **kwargs)
 
-    return cast(_T, insert_env_defaults)
+    return cast(_T, update_kwargs_from_env)
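To make the one-line fix concrete, here is a standalone sketch (not Lightning code) of why the merge order matters: in `dict(list(a.items()) + list(b.items()))`, entries from `b` overwrite duplicate keys from `a`, so whichever mapping comes second wins.

```python
# Standalone illustration of the merge-order semantics behind the fix.
user_kwargs = {"max_steps": 42, "logger": False}
env_variables = {"max_steps": 7}

# Old order: user kwargs come last, so they silently beat the env variables.
old = dict(list(env_variables.items()) + list(user_kwargs.items()))
assert old["max_steps"] == 42

# New order: env variables come last, so they take precedence.
new = dict(list(user_kwargs.items()) + list(env_variables.items()))
assert new["max_steps"] == 7
```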
tests/tests_pytorch/test_cli.py (16 changes: 16 additions & 0 deletions)
@@ -29,6 +29,7 @@
 import yaml
 from lightning.fabric.plugins.environments import SLURMEnvironment
 from lightning.pytorch import Callback, LightningDataModule, LightningModule, Trainer, __version__, seed_everything
+from lightning.pytorch.accelerators import CPUAccelerator
 from lightning.pytorch.callbacks import LearningRateMonitor, ModelCheckpoint
 from lightning.pytorch.cli import (
     _JSONARGPARSE_SIGNATURES_AVAILABLE,
@@ -288,6 +289,21 @@ def test_lightning_env_parse(cleandir):
     assert cli.config.fit.trainer.logger is False
 
 
+def test_lightning_cli_and_trainer_kwargs_override(cleandir):
**Comment from the PR author on the new test:**

The test doesn't pass on master, because the CLI passes all arguments to the Trainer, including the defaults, and so env variables would never take precedence.

+    env_vars = {
+        "PL_TRAINER_ACCELERATOR": "cpu",
+        "PL_TRAINER_DEVICES": "2",
+        "PL_TRAINER_NUM_NODES": "4",
+        "PL_TRAINER_PRECISION": "16-true",
+    }
+    with mock.patch.dict(os.environ, env_vars), mock.patch("sys.argv", [""]):
+        cli = LightningCLI(BoringModel, run=False)
+    assert isinstance(cli.trainer.accelerator, CPUAccelerator)
+    assert cli.trainer.num_devices == 2
+    assert cli.trainer.num_nodes == 4
+    assert cli.trainer.precision == "16-true"
+
+
 def test_lightning_cli_save_config_cases(cleandir):
     config_path = "config.yaml"
     cli_args = ["fit", "--trainer.logger=false", "--trainer.fast_dev_run=1"]
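As an aside on where the `PL_TRAINER_*` names come from: the `_parse_env_variables` signature above shows the template `PL_%(cls_name)s_%(cls_argument)s`, so an env variable name is derived by upper-casing the class name and argument name. A minimal sketch of that mapping (`env_var_name` is a hypothetical helper for illustration, not Lightning's code):

```python
# Hypothetical helper mirroring the template "PL_%(cls_name)s_%(cls_argument)s"
# used by _parse_env_variables to match env variables to constructor arguments.
def env_var_name(cls_name: str, cls_argument: str) -> str:
    return "PL_%(cls_name)s_%(cls_argument)s" % {
        "cls_name": cls_name.upper(),
        "cls_argument": cls_argument.upper(),
    }

assert env_var_name("Trainer", "num_nodes") == "PL_TRAINER_NUM_NODES"
assert env_var_name("Trainer", "precision") == "PL_TRAINER_PRECISION"
```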
tests/tests_pytorch/trainer/flags/test_env_vars.py (12 changes: 7 additions & 5 deletions)
@@ -44,14 +44,16 @@ def test_passing_env_variables_only():
 def test_passing_env_variables_defaults():
     """Testing overwriting trainer arguments."""
     trainer = Trainer(logger=False, max_steps=42)
-    assert trainer.logger is None
-    assert trainer.max_steps == 42
+    assert trainer.logger is not None
+    assert trainer.max_steps == 7


-@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_DEVICES": "2"})
-def test_passing_env_variables_devices(cuda_count_2):
+@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_DEVICES": "2", "PL_TRAINER_NUM_NODES": "4"})
+def test_passing_env_variables_devices(cuda_count_2, mps_count_0):
     """Testing overwriting trainer arguments."""
     trainer = Trainer()
     assert trainer.num_devices == 2
+    assert trainer.num_nodes == 4
     trainer = Trainer(accelerator="gpu", devices=1)
-    assert trainer.num_devices == 1
+    assert trainer.num_devices == 2
+    assert trainer.num_nodes == 4
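Note the last assertions: even an explicit `Trainer(accelerator="gpu", devices=1)` now reports `num_devices == 2`, because `PL_TRAINER_DEVICES=2` takes precedence over the constructor argument, which is exactly the behavior change recorded in the changelog.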