Skip to content

Commit baf4f35

Browse files
authored
add parsing OS env vars (#4022)
* add parsing OS env vars * fix env * Apply suggestions from code review * overwrite init * Apply suggestions from code review
1 parent 8a3c800 commit baf4f35

File tree

5 files changed

+111
-0
lines changed

5 files changed

+111
-0
lines changed
Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
# Copyright The PyTorch Lightning team.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from functools import wraps
16+
from typing import Callable
17+
18+
from pytorch_lightning.utilities.argparse_utils import parse_env_variables, get_init_arguments_and_types
19+
20+
21+
def overwrite_by_env_vars(fn: Callable) -> Callable:
22+
"""
23+
Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which
24+
input arguments should be moved automatically to the correct device.
25+
26+
"""
27+
@wraps(fn)
28+
def overwrite_by_env_vars(self, *args, **kwargs):
29+
# get the class
30+
cls = self.__class__
31+
if args: # inace any args passed move them to kwargs
32+
# parse only the argument names
33+
cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)]
34+
# convert args to kwargs
35+
kwargs.update({k: v for k, v in zip(cls_arg_names, args)})
36+
# update the kwargs by env variables
37+
# todo: maybe add a warning that some init args were overwritten by Env arguments
38+
kwargs.update(vars(parse_env_variables(cls)))
39+
40+
# all args were already moved to kwargs
41+
return fn(self, **kwargs)
42+
43+
return overwrite_by_env_vars

pytorch_lightning/trainer/properties.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -112,6 +112,10 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
112112
def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
113113
return argparse_utils.parse_argparser(cls, arg_parser)
114114

115+
@classmethod
116+
def match_env_arguments(cls) -> Namespace:
117+
return argparse_utils.parse_env_variables(cls)
118+
115119
@classmethod
116120
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
117121
return argparse_utils.add_argparse_args(cls, parent_parser)

pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from pytorch_lightning.profiler import BaseProfiler
2929
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
3030
from pytorch_lightning.trainer.configuration_validator import ConfigValidator
31+
from pytorch_lightning.trainer.connectors.env_vars_connector import overwrite_by_env_vars
3132
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
3233
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
3334
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
@@ -79,6 +80,7 @@ class Trainer(
7980
TrainerTrainingTricksMixin,
8081
TrainerDataLoadingMixin,
8182
):
83+
@overwrite_by_env_vars
8284
def __init__(
8385
self,
8486
logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,

pytorch_lightning/utilities/argparse_utils.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import inspect
2+
import os
23
from argparse import ArgumentParser, Namespace
34
from typing import Union, List, Tuple, Any
45
from pytorch_lightning.utilities import parsing
@@ -7,6 +8,7 @@
78
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
89
"""
910
Create an instance from CLI arguments.
11+
Eventually use varibles from OS environement which are defined as "PL_<CLASS-NAME>_<CLASS_ARUMENT_NAME>"
1012
1113
Args:
1214
args: The parser or namespace to take arguments from. Only known arguments will be
@@ -22,8 +24,11 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
2224
>>> args = Trainer.parse_argparser(parser.parse_args(""))
2325
>>> trainer = Trainer.from_argparse_args(args, logger=False)
2426
"""
27+
# fist check if any args are defined in environment for the class and set as default
28+
2529
if isinstance(args, ArgumentParser):
2630
args = cls.parse_argparser(args)
31+
# if other arg passed, update parameters
2732
params = vars(args)
2833

2934
# we only want to pass in valid Trainer args, the rest may be user specific
@@ -61,6 +66,35 @@ def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namesp
6166
return Namespace(**modified_args)
6267

6368

69+
def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace:
70+
"""Parse environment arguments if they are defined.
71+
72+
Example:
73+
>>> from pytorch_lightning import Trainer
74+
>>> parse_env_variables(Trainer)
75+
Namespace()
76+
>>> import os
77+
>>> os.environ["PL_TRAINER_GPUS"] = '42'
78+
>>> os.environ["PL_TRAINER_BLABLABLA"] = '1.23'
79+
>>> parse_env_variables(Trainer)
80+
Namespace(gpus=42)
81+
>>> del os.environ["PL_TRAINER_GPUS"]
82+
"""
83+
cls_arg_defaults = get_init_arguments_and_types(cls)
84+
85+
env_args = {}
86+
for arg_name, _, _ in cls_arg_defaults:
87+
env = template % {'cls_name': cls.__name__.upper(), 'cls_argument': arg_name.upper()}
88+
val = os.environ.get(env)
89+
if not (val is None or val == ''):
90+
try: # converting to native types like int/float/bool
91+
val = eval(val)
92+
except Exception:
93+
pass
94+
env_args[arg_name] = val
95+
return Namespace(**env_args)
96+
97+
6498
def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
6599
r"""Scans the Trainer signature and returns argument names, types and default values.
66100
Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
import os
2+
3+
from pytorch_lightning import Trainer
4+
5+
6+
def test_passing_env_variables(tmpdir):
7+
"""Testing overwriting trainer arguments """
8+
trainer = Trainer()
9+
assert trainer.logger is not None
10+
assert trainer.max_steps is None
11+
trainer = Trainer(False, max_steps=42)
12+
assert trainer.logger is None
13+
assert trainer.max_steps == 42
14+
15+
os.environ['PL_TRAINER_LOGGER'] = 'False'
16+
os.environ['PL_TRAINER_MAX_STEPS'] = '7'
17+
trainer = Trainer()
18+
assert trainer.logger is None
19+
assert trainer.max_steps == 7
20+
21+
os.environ['PL_TRAINER_LOGGER'] = 'True'
22+
trainer = Trainer(False, max_steps=42)
23+
assert trainer.logger is not None
24+
assert trainer.max_steps == 7
25+
26+
# this has to be cleaned
27+
del os.environ['PL_TRAINER_LOGGER']
28+
del os.environ['PL_TRAINER_MAX_STEPS']

0 commit comments

Comments
 (0)