From 2d480e83560dd1826df9afecdd87e75a20c05f56 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Sat, 13 Mar 2021 23:14:52 +0100 Subject: [PATCH 1/7] change tests --- .../trainer/connectors/env_vars_connector.py | 2 +- pytorch_lightning/trainer/trainer.py | 4 ++-- tests/trainer/flags/test_env_vars.py | 23 +++++++++++-------- 3 files changed, 16 insertions(+), 13 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/env_vars_connector.py b/pytorch_lightning/trainer/connectors/env_vars_connector.py index 2e788c256af0d..bc7ee7a286c38 100644 --- a/pytorch_lightning/trainer/connectors/env_vars_connector.py +++ b/pytorch_lightning/trainer/connectors/env_vars_connector.py @@ -18,7 +18,7 @@ from pytorch_lightning.utilities.argparse import get_init_arguments_and_types, parse_env_variables -def overwrite_by_env_vars(fn: Callable) -> Callable: +def defaults_from_env_vars(fn: Callable) -> Callable: """ Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which input arguments should be moved automatically to the correct device. diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 5af9746582450..b9a0f058e5355 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -38,7 +38,7 @@ from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector -from pytorch_lightning.trainer.connectors.env_vars_connector import overwrite_by_env_vars +from pytorch_lightning.trainer.connectors.env_vars_connector import defaults_from_env_vars from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.trainer.connectors.model_connector import ModelConnector from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector @@ -84,7 +84,7 @@ class Trainer( DeprecatedTrainerAttributes, ): - @overwrite_by_env_vars + @defaults_from_env_vars def __init__( self, logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True, diff --git a/tests/trainer/flags/test_env_vars.py b/tests/trainer/flags/test_env_vars.py index ba76820d15ee8..fab775e2167fa 100644 --- a/tests/trainer/flags/test_env_vars.py +++ b/tests/trainer/flags/test_env_vars.py @@ -12,11 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. import os +from unittest import mock from pytorch_lightning import Trainer +from pytorch_lightning.loggers import LightningLoggerBase -def test_passing_env_variables(tmpdir): +def test_passing_no_env_variables(): """Testing overwriting trainer arguments """ trainer = Trainer() assert trainer.logger is not None @@ -25,17 +27,18 @@ def test_passing_env_variables(tmpdir): assert trainer.logger is None assert trainer.max_steps == 42 - os.environ['PL_TRAINER_LOGGER'] = 'False' - os.environ['PL_TRAINER_MAX_STEPS'] = '7' + +@mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "False", "PL_TRAINER_MAX_STEPS": "7"}) +def test_passing_env_variables_only(): + """Testing overwriting trainer arguments """ trainer = Trainer() assert trainer.logger is None assert trainer.max_steps == 7 - os.environ['PL_TRAINER_LOGGER'] = 'True' - trainer = Trainer(False, max_steps=42) - assert trainer.logger is not None - assert trainer.max_steps == 7 - # this has to be cleaned - del os.environ['PL_TRAINER_LOGGER'] - del os.environ['PL_TRAINER_MAX_STEPS'] +@mock.patch.dict(os.environ, {"PL_TRAINER_LOGGER": "True", "PL_TRAINER_MAX_STEPS": "7"}) +def test_passing_env_variables_defaults(): + """Testing overwriting trainer arguments """ + trainer = Trainer(False, max_steps=42) + assert isinstance(trainer.logger, LightningLoggerBase) + assert trainer.max_steps == 42 From 04d1844a2ee80678ce2f62b01a5c5a7e3c80187e Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 15 Mar 2021 21:23:27 +0100 Subject: [PATCH 2/7] fix --- .../trainer/connectors/env_vars_connector.py | 10 ++++------ pytorch_lightning/utilities/argparse.py | 2 +- tests/trainer/flags/test_env_vars.py | 2 +- 3 files changed, 6 insertions(+), 8 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/env_vars_connector.py b/pytorch_lightning/trainer/connectors/env_vars_connector.py index bc7ee7a286c38..ddbd6ee3331b5 100644 --- a/pytorch_lightning/trainer/connectors/env_vars_connector.py +++ b/pytorch_lightning/trainer/connectors/env_vars_connector.py @@ -22,11 +22,9 @@ def defaults_from_env_vars(fn: Callable) -> Callable: """ Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which input arguments should be moved automatically to the correct device. - """ - @wraps(fn) - def overwrite_by_env_vars(self, *args, **kwargs): + def insert_env_defaults(self, *args, **kwargs): # get the class cls = self.__class__ if args: # inace any args passed move them to kwargs @@ -34,11 +32,11 @@ def overwrite_by_env_vars(self, *args, **kwargs): cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)] # convert args to kwargs kwargs.update({k: v for k, v in zip(cls_arg_names, args)}) + env_variables = vars(parse_env_variables(cls)) # update the kwargs by env variables - # todo: maybe add a warning that some init args were overwritten by Env arguments - kwargs.update(vars(parse_env_variables(cls))) + kwargs = dict(list(env_variables.items()) + list(kwargs.items())) # all args were already moved to kwargs return fn(self, **kwargs) - return overwrite_by_env_vars + return insert_env_defaults diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py index ee42ab3241ff6..e04c15bb5d769 100644 --- a/pytorch_lightning/utilities/argparse.py +++ b/pytorch_lightning/utilities/argparse.py @@ -107,7 +107,7 @@ def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s") def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: - r"""Scans the Trainer signature and returns argument names, types and default values. + r"""Scans the class signature and returns argument names, types and default values. Returns: List with tuples of 3 values: diff --git a/tests/trainer/flags/test_env_vars.py b/tests/trainer/flags/test_env_vars.py index fab775e2167fa..f8ec4ec2f10af 100644 --- a/tests/trainer/flags/test_env_vars.py +++ b/tests/trainer/flags/test_env_vars.py @@ -40,5 +40,5 @@ def test_passing_env_variables_only(): def test_passing_env_variables_defaults(): """Testing overwriting trainer arguments """ trainer = Trainer(False, max_steps=42) - assert isinstance(trainer.logger, LightningLoggerBase) + assert trainer.logger is None assert trainer.max_steps == 42 From 8a968798e3c30033cb4b7ef45d39d3278d97adc4 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 15 Mar 2021 21:30:58 +0100 Subject: [PATCH 3/7] test --- tests/trainer/flags/test_env_vars.py | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tests/trainer/flags/test_env_vars.py b/tests/trainer/flags/test_env_vars.py index f8ec4ec2f10af..65b251a6633b5 100644 --- a/tests/trainer/flags/test_env_vars.py +++ b/tests/trainer/flags/test_env_vars.py @@ -15,7 +15,6 @@ from unittest import mock from pytorch_lightning import Trainer -from pytorch_lightning.loggers import LightningLoggerBase def test_passing_no_env_variables(): @@ -42,3 +41,14 @@ def test_passing_env_variables_defaults(): trainer = Trainer(False, max_steps=42) assert trainer.logger is None assert trainer.max_steps == 42 + + +@mock.patch.dict(os.environ, {"CUDA_VISIBLE_DEVICES": "0,1", "PL_TRAINER_GPUS": "2"}) +@mock.patch('torch.cuda.device_count', return_value=2) +@mock.patch('torch.cuda.is_available', return_value=True) +def test_passing_env_variables_gpus(cuda_available_mock, device_count_mock): + """Testing overwriting trainer arguments """ + trainer = Trainer() + assert trainer.gpus == 2 + trainer = Trainer(gpus=1) + assert trainer.gpus == 1 From 8ed010198db30d17101c01f804dfacbcd19d3b5d Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 15 Mar 2021 21:55:39 +0100 Subject: [PATCH 4/7] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Carlos MocholĂ­ --- pytorch_lightning/trainer/connectors/env_vars_connector.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/env_vars_connector.py b/pytorch_lightning/trainer/connectors/env_vars_connector.py index ddbd6ee3331b5..8d9f2773f61c1 100644 --- a/pytorch_lightning/trainer/connectors/env_vars_connector.py +++ b/pytorch_lightning/trainer/connectors/env_vars_connector.py @@ -18,7 +18,7 @@ from pytorch_lightning.utilities.argparse import get_init_arguments_and_types, parse_env_variables -def defaults_from_env_vars(fn: Callable) -> Callable: +def _defaults_from_env_vars(fn: Callable) -> Callable: """ Decorator for :class:`~pytorch_lightning.trainer.trainer.Trainer` methods for which input arguments should be moved automatically to the correct device. From 78955ce7c37829759f9545418262b7de10c880ff Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 15 Mar 2021 21:56:28 +0100 Subject: [PATCH 5/7] _defaults_from_env_vars --- pytorch_lightning/trainer/trainer.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index b9a0f058e5355..44b0e716a90c0 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -38,7 +38,7 @@ from pytorch_lightning.trainer.connectors.checkpoint_connector import CheckpointConnector from pytorch_lightning.trainer.connectors.data_connector import DataConnector from pytorch_lightning.trainer.connectors.debugging_connector import DebuggingConnector -from pytorch_lightning.trainer.connectors.env_vars_connector import defaults_from_env_vars +from pytorch_lightning.trainer.connectors.env_vars_connector import _defaults_from_env_vars from pytorch_lightning.trainer.connectors.logger_connector import LoggerConnector from pytorch_lightning.trainer.connectors.model_connector import ModelConnector from pytorch_lightning.trainer.connectors.optimizer_connector import OptimizerConnector @@ -84,7 +84,7 @@ class Trainer( DeprecatedTrainerAttributes, ): - @defaults_from_env_vars + @_defaults_from_env_vars def __init__( self, logger: Union[LightningLoggerBase, Iterable[LightningLoggerBase], bool] = True, From 8dbccac09eb06c396c9fb8871019e5102b1b9af0 Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Mon, 15 Mar 2021 23:14:27 +0100 Subject: [PATCH 6/7] var --- pytorch_lightning/utilities/argparse.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/pytorch_lightning/utilities/argparse.py b/pytorch_lightning/utilities/argparse.py index e04c15bb5d769..49cbaf3c6bdcf 100644 --- a/pytorch_lightning/utilities/argparse.py +++ b/pytorch_lightning/utilities/argparse.py @@ -119,11 +119,11 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: >>> args = get_init_arguments_and_types(Trainer) """ - trainer_default_params = inspect.signature(cls).parameters + cls_default_params = inspect.signature(cls).parameters name_type_default = [] - for arg in trainer_default_params: - arg_type = trainer_default_params[arg].annotation - arg_default = trainer_default_params[arg].default + for arg in cls_default_params: + arg_type = cls_default_params[arg].annotation + arg_default = cls_default_params[arg].default try: arg_types = tuple(arg_type.__args__) except AttributeError: From 02e6151256ef524605439294dbb098efc1f6b20e Mon Sep 17 00:00:00 2001 From: Jirka Borovec Date: Tue, 16 Mar 2021 10:44:33 +0100 Subject: [PATCH 7/7] trigger --- pytorch_lightning/trainer/connectors/env_vars_connector.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/env_vars_connector.py b/pytorch_lightning/trainer/connectors/env_vars_connector.py index 8d9f2773f61c1..f4209f40d002e 100644 --- a/pytorch_lightning/trainer/connectors/env_vars_connector.py +++ b/pytorch_lightning/trainer/connectors/env_vars_connector.py @@ -25,8 +25,7 @@ def _defaults_from_env_vars(fn: Callable) -> Callable: """ @wraps(fn) def insert_env_defaults(self, *args, **kwargs): - # get the class - cls = self.__class__ + cls = self.__class__ # get the class if args: # inace any args passed move them to kwargs # parse only the argument names cls_arg_names = [arg[0] for arg in get_init_arguments_and_types(cls)]