43 changes: 43 additions & 0 deletions pytorch_lightning/trainer/connectors/env_vars_connector.py
@@ -0,0 +1,43 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from functools import wraps
from typing import Callable

from pytorch_lightning.utilities.argparse_utils import parse_env_variables, get_init_arguments_and_types


def overwrite_by_env_vars(fn: Callable) -> Callable:
"""
Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which
input arguments should be moved automatically to the correct device.
Comment on lines +23 to +24 (Contributor):
hehe, I think I know from where you copied this code :))
this docstring needs to be updated :)
"""
@wraps(fn)
def overwrite_by_env_vars(self, *args, **kwargs):
# get the class
cls = self.__class__
if args: # in case any args were passed, move them to kwargs
# parse only the argument names
cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)]
# convert args to kwargs
kwargs.update({k: v for k, v in zip(cls_arg_names, args)})
# update the kwargs by env variables
# todo: maybe add a warning that some init args were overwritten by environment variables
kwargs.update(vars(parse_env_variables(cls)))

# all args were already moved to kwargs
return fn(self, **kwargs)

return overwrite_by_env_vars
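
For orientation, a minimal usage sketch of what this decorator enables once applied to Trainer.__init__; it mirrors the test added in tests/trainer/flags/test_env_vars.py further below, and the PL_TRAINER_<ARGUMENT> naming follows the template used by parse_env_variables:

import os

from pytorch_lightning import Trainer

# an environment variable named PL_TRAINER_<ARGUMENT> overrides the matching init argument
os.environ["PL_TRAINER_MAX_STEPS"] = "7"
trainer = Trainer()  # no explicit argument, the env value is picked up
assert trainer.max_steps == 7
del os.environ["PL_TRAINER_MAX_STEPS"]  # clean up so later code is unaffected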
4 changes: 4 additions & 0 deletions pytorch_lightning/trainer/properties.py
@@ -112,6 +112,10 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
return argparse_utils.parse_argparser(cls, arg_parser)

@classmethod
def match_env_arguments(cls) -> Namespace:
return argparse_utils.parse_env_variables(cls)

@classmethod
def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
return argparse_utils.add_argparse_args(cls, parent_parser)
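
A small sketch of the new match_env_arguments classmethod in isolation; as in the parse_env_variables doctest further below, it returns an empty Namespace when no PL_TRAINER_* variables are set:

from argparse import Namespace

from pytorch_lightning import Trainer

# with no PL_TRAINER_* variables defined there is nothing to override
assert Trainer.match_env_arguments() == Namespace()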
2 changes: 2 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -28,6 +28,7 @@
from pytorch_lightning.profiler import BaseProfiler
from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
from pytorch_lightning.trainer.configuration_validator import ConfigValidator
from pytorch_lightning.trainer.connectors.env_vars_connector import overwrite_by_env_vars
from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
from pytorch_lightning.trainer.logging import TrainerLoggingMixin
from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
@@ -79,6 +80,7 @@ class Trainer(
TrainerTrainingTricksMixin,
TrainerDataLoadingMixin,
):
@overwrite_by_env_vars
def __init__(
self,
logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True,
34 changes: 34 additions & 0 deletions pytorch_lightning/utilities/argparse_utils.py
@@ -1,4 +1,5 @@
import inspect
import os
from argparse import ArgumentParser, Namespace
from typing import Union, List, Tuple, Any
from pytorch_lightning.utilities import parsing
@@ -7,6 +8,7 @@
def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
"""
Create an instance from CLI arguments.
Optionally, use variables from the OS environment which are defined as "PL_<CLASS_NAME>_<CLASS_ARGUMENT_NAME>".

Args:
args: The parser or namespace to take arguments from. Only known arguments will be
@@ -22,8 +24,11 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
>>> args = Trainer.parse_argparser(parser.parse_args(""))
>>> trainer = Trainer.from_argparse_args(args, logger=False)
"""
# first check if any args are defined in the environment for the class and set them as defaults

if isinstance(args, ArgumentParser):
args = cls.parse_argparser(args)
# if other args were passed, update the parameters
params = vars(args)

# we only want to pass in valid Trainer args, the rest may be user specific
Expand Down Expand Up @@ -61,6 +66,35 @@ def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namesp
return Namespace(**modified_args)


def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace:
"""Parse environment arguments if they are defined.

Example:
>>> from pytorch_lightning import Trainer
>>> parse_env_variables(Trainer)
Namespace()
>>> import os
>>> os.environ["PL_TRAINER_GPUS"] = '42'
>>> os.environ["PL_TRAINER_BLABLABLA"] = '1.23'
>>> parse_env_variables(Trainer)
Namespace(gpus=42)
>>> del os.environ["PL_TRAINER_GPUS"]
"""
cls_arg_defaults = get_init_arguments_and_types(cls)

env_args = {}
for arg_name, _, _ in cls_arg_defaults:
env = template % {'cls_name': cls.__name__.upper(), 'cls_argument': arg_name.upper()}
val = os.environ.get(env)
if not (val is None or val == ''):
try: # converting to native types like int/float/bool
val = eval(val)
except Exception:
pass
env_args[arg_name] = val
return Namespace(**env_args)
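
As a side note, a short sketch of what parse_env_variables does with the naming template and the raw environment strings; the type conversion simply relies on eval with a fallback to the original string:

# the env variable name is built from the class name and the init argument name
template = "PL_%(cls_name)s_%(cls_argument)s"
assert template % {'cls_name': 'Trainer'.upper(), 'cls_argument': 'gpus'.upper()} == 'PL_TRAINER_GPUS'
# raw string values are converted to native types via eval
assert eval('42') == 42          # int
assert eval('1.23') == 1.23      # float
assert eval('False') is False    # bool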


def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
r"""Scans the Trainer signature and returns argument names, types and default values.

28 changes: 28 additions & 0 deletions tests/trainer/flags/test_env_vars.py
@@ -0,0 +1,28 @@
import os

from pytorch_lightning import Trainer


def test_passing_env_variables(tmpdir):
Comment (Collaborator Author):
@williamFalcon is this the priority order you expect?
"""Testing overwriting trainer arguments """
trainer = Trainer()
assert trainer.logger is not None
assert trainer.max_steps is None
trainer = Trainer(False, max_steps=42)
assert trainer.logger is None
assert trainer.max_steps == 42

os.environ['PL_TRAINER_LOGGER'] = 'False'
os.environ['PL_TRAINER_MAX_STEPS'] = '7'
trainer = Trainer()
assert trainer.logger is None
assert trainer.max_steps == 7

os.environ['PL_TRAINER_LOGGER'] = 'True'
trainer = Trainer(False, max_steps=42)
assert trainer.logger is not None
assert trainer.max_steps == 7

# the env variables have to be cleaned up, otherwise they leak into the following tests
del os.environ['PL_TRAINER_LOGGER']
del os.environ['PL_TRAINER_MAX_STEPS']
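
A possible follow-up to the cleanup note above, sketched here as a suggestion rather than as part of the PR: pytest's monkeypatch fixture restores the environment automatically when the test finishes, so the manual del calls would not be needed.

def test_passing_env_variables_monkeypatch(monkeypatch):
    """Same override check as above, but the env variables are restored automatically."""
    monkeypatch.setenv('PL_TRAINER_LOGGER', 'False')
    monkeypatch.setenv('PL_TRAINER_MAX_STEPS', '7')
    trainer = Trainer()
    assert trainer.logger is None
    assert trainer.max_steps == 7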