Commit ced662f

alexeykarnachev, Borda, Joe Davison, and williamFalcon authored
Custom argparser extension with Trainer arguments (argument types added) (#1147)
* `add_argparse_args` method fixed (argument types added) * CHANGELOG.md upd * autopep8 fixes * --gpus=0 removed from test (for ci tests) * typo fixed * reduce on plateau scheduler fixed * Trainer cli related tests moved to test_trainer_cli.py * refactored: get_init_arguments_and_types is a public classmethod of the Trainer now * test_get_init_arguments_and_types added * autopep8 fixes * Trainer cli related tests moved to test_trainer_cli.py * refactored: get_init_arguments_and_types is a public classmethod of the Trainer now * test_get_init_arguments_and_types added * autopep8 fixes * Trainer cli related tests moved to test_trainer_cli.py * refactored: get_init_arguments_and_types is a public classmethod of the Trainer now * test_get_init_arguments_and_types added * autopep8 fixes * Trainer cli related tests moved to test_trainer_cli.py * test_get_init_arguments_and_types added * autopep8 fixes * Apply suggestions from code review * cosmetics * cosmetics * Update pytorch_lightning/trainer/trainer.py Co-Authored-By: Jirka Borovec <[email protected]> * `Trainer.get_init_arguments_and_types` now returns arg types wrapped in tuples (not in sets) * deprecated args are now ignored in argparser * get_deprecated_arg_names small refactor * get_deprecated_arg_names bug fixed * Trainer cli related tests moved to test_trainer_cli.py * refactored: get_init_arguments_and_types is a public classmethod of the Trainer now * test_get_init_arguments_and_types added * autopep8 fixes * Trainer cli related tests moved to test_trainer_cli.py * autopep8 fixes * Trainer cli related tests moved to test_trainer_cli.py * Trainer cli related tests moved to test_trainer_cli.py * test_get_init_arguments_and_types added * autopep8 fixes * autopep8 fixes * Apply suggestions from code review * cosmetics * cosmetics * Update pytorch_lightning/trainer/trainer.py Co-Authored-By: Jirka Borovec <[email protected]> * `Trainer.get_init_arguments_and_types` now returns arg types wrapped in tuples (not in sets) * deprecated args are now ignored in argparser * get_deprecated_arg_names small refactor * get_deprecated_arg_names bug fixed * Update pytorch_lightning/trainer/trainer.py Co-Authored-By: Joe Davison <[email protected]> * Update pytorch_lightning/trainer/trainer.py Co-Authored-By: Joe Davison <[email protected]> Co-authored-by: Jirka Borovec <[email protected]> Co-authored-by: Joe Davison <[email protected]> Co-authored-by: William Falcon <[email protected]>
1 parent f6dabc2 · commit ced662f

File tree: 7 files changed, +188 -43 lines

CHANGELOG.md

Lines changed: 3 additions & 1 deletion

```diff
@@ -29,7 +29,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Fixed
 
-- Fixed bug related to type checking of `ReduceLROnPlateau` lr schedulers ([#1114](https://github.com/PyTorchLightning/pytorch-lightning/issues/1114))
+
+- `Trainer.add_argparse_args` classmethod fixed. Now it adds a type for the arguments ([#1147](https://github.com/PyTorchLightning/pytorch-lightning/pull/1147)).
+- Fixed bug related to type checking of `ReduceLROnPlateau` lr schedulers ([#1114](https://github.com/PyTorchLightning/pytorch-lightning/issues/1114))
 - Fixed a bug to ensure lightning checkpoints to be backward compatible ([#1132](https://github.com/PyTorchLightning/pytorch-lightning/pull/1132))
 - Fixed all warnings and errors in the docs build process ([#1191](https://github.com/PyTorchLightning/pytorch-lightning/pull/1191))
```

pl_examples/full_examples/semantic_segmentation/models/unet/model.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -13,6 +13,7 @@ class UNet(nn.Module):
     bilinear (bool) - Whether to use bilinear interpolation or transposed
     convolutions for upsampling.
     '''
+
     def __init__(self, num_classes=19, bilinear=False):
         super().__init__()
         self.layer1 = DoubleConv(3, 64)
```

pl_examples/full_examples/semantic_segmentation/models/unet/parts.py

Lines changed: 3 additions & 0 deletions

```diff
@@ -8,6 +8,7 @@ class DoubleConv(nn.Module):
     Double Convolution and BN and ReLU
     (3x3 conv -> BN -> ReLU) ** 2
     '''
+
     def __init__(self, in_ch, out_ch):
         super().__init__()
         self.net = nn.Sequential(
@@ -27,6 +28,7 @@ class Down(nn.Module):
     '''
     Combination of MaxPool2d and DoubleConv in series
     '''
+
     def __init__(self, in_ch, out_ch):
         super().__init__()
         self.net = nn.Sequential(
@@ -44,6 +46,7 @@ class Up(nn.Module):
     followed by concatenation of feature map from contracting path,
     followed by double 3x3 convolution.
     '''
+
     def __init__(self, in_ch, out_ch, bilinear=False):
         super().__init__()
         self.upsample = None
```

pl_examples/full_examples/semantic_segmentation/semseg.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -34,6 +34,7 @@ class KITTI(Dataset):
     encoded using `encode_segmap`, and given `transform` (if any) are applied to the image only
     (mask does not usually require transforms, but they can be implemented in a similar way).
     '''
+
     def __init__(
         self,
         root_path,
@@ -120,6 +121,7 @@ class SegModel(pl.LightningModule):
 
     Adam optimizer is used along with Cosine Annealing learning rate scheduler.
     '''
+
     def __init__(self, hparams):
         super(SegModel, self).__init__()
         self.root_path = hparams.root
```

pytorch_lightning/trainer/trainer.py

Lines changed: 92 additions & 20 deletions

```diff
@@ -3,33 +3,30 @@
 import sys
 import warnings
 from argparse import ArgumentParser
-from typing import Union, Optional, List, Dict, Tuple, Iterable
+from typing import Union, Optional, List, Dict, Tuple, Iterable, Any
+import distutils
 
 import torch
-from torch import optim
 import torch.distributed as torch_distrib
 import torch.multiprocessing as mp
+from torch import optim
 from torch.optim.optimizer import Optimizer
 from torch.utils.data import DataLoader
 from tqdm.auto import tqdm
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback
+from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.profiler import Profiler, PassThroughProfiler
 from pytorch_lightning.profiler.profiler import BaseProfiler
 from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin
 from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin
-from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
-from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
-from pytorch_lightning.trainer.distrib_parts import (
-    TrainerDPMixin,
-    parse_gpu_ids,
-    determine_root_gpu_device
-)
-from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin
+from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin
 from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_8
+from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin
+from pytorch_lightning.trainer.distrib_parts import TrainerDPMixin, parse_gpu_ids, determine_root_gpu_device
 from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin
 from pytorch_lightning.trainer.logging import TrainerLoggingMixin
 from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin
```
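The newly imported `distutils.util.strtobool` is what later converts boolean-like CLI strings into actual booleans. A quick sketch of its behavior; note it is imported here explicitly as a submodule, since a bare `import distutils` as in the hunk above is not, by itself, guaranteed to bind `distutils.util`:

```python
from distutils.util import strtobool

# strtobool maps 'y', 'yes', 't', 'true', 'on', '1' to 1 and
# 'n', 'no', 'f', 'false', 'off', '0' to 0 (case-insensitive),
# raising ValueError for anything else
for raw in ('1', 'true', 'On', '0', 'false', 'Off'):
    print(raw, '->', bool(strtobool(raw)))
```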
```diff
@@ -70,6 +67,11 @@ class Trainer(
     TrainerCallbackHookMixin,
     TrainerDeprecatedAPITillVer0_8,
 ):
+    DEPRECATED_IN_0_8 = (
+        'gradient_clip', 'nb_gpu_nodes', 'max_nb_epochs', 'min_nb_epochs',
+        'add_row_log_interval', 'nb_sanity_val_steps'
+    )
+    DEPRECATED_IN_0_9 = ('use_amp',)
 
     def __init__(
         self,
```
```diff
@@ -466,21 +468,91 @@ def default_attributes(cls):
 
         return args
 
+    @classmethod
+    def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
+        r"""Scans the Trainer signature and returns argument names, types and default values.
+
+        Returns:
+            List with tuples of 3 values:
+            (argument name, tuple with argument types, argument default value).
+
+        Examples:
+            >>> args = Trainer.get_init_arguments_and_types()
+            >>> import pprint
+            >>> pprint.pprint(sorted(args))  # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE
+            [('accumulate_grad_batches',
+              (<class 'int'>, typing.Dict[int, int], typing.List[list]),
+              1),
+             ...
+             ('callbacks', (<class 'pytorch_lightning.callbacks.base.Callback'>,), []),
+             ('check_val_every_n_epoch', (<class 'int'>,), 1),
+             ...
+             ('max_epochs', (<class 'int'>,), 1000),
+             ...
+             ('precision', (<class 'int'>,), 32),
+             ('print_nan_grads', (<class 'bool'>,), False),
+             ('process_position', (<class 'int'>,), 0),
+             ('profiler',
+              (<class 'pytorch_lightning.profiler.profiler.BaseProfiler'>,
+               <class 'NoneType'>),
+              None),
+             ...
+        """
+        trainer_default_params = inspect.signature(cls).parameters
+        name_type_default = []
+        for arg in trainer_default_params:
+            arg_type = trainer_default_params[arg].annotation
+            arg_default = trainer_default_params[arg].default
+            try:
+                arg_types = tuple(arg_type.__args__)
+            except AttributeError:
+                arg_types = (arg_type,)
+
+            name_type_default.append((arg, arg_types, arg_default))
+
+        return name_type_default
+
```
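The `try`/`except AttributeError` above is the part that unpacks `typing.Union` (and `Optional`) annotations: those expose their member types via `__args__`, while plain classes like `int` do not. A minimal sketch of the same introspection outside the Trainer; the `demo` function is a made-up stand-in for the real signature:

```python
import inspect
from typing import Optional, Union


def demo(gpus: Optional[Union[int, str]] = None, max_epochs: int = 1000):
    ...


for name, param in inspect.signature(demo).parameters.items():
    annotation = param.annotation
    try:
        # typing.Union/Optional store their member types in `__args__`
        types = tuple(annotation.__args__)
    except AttributeError:
        # plain annotations (e.g. `int`) have no `__args__`
        types = (annotation,)
    print(name, types, param.default)
# gpus (<class 'int'>, <class 'str'>, <class 'NoneType'>) None
# max_epochs (<class 'int'>,) 1000
```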
```diff
+    @classmethod
+    def get_deprecated_arg_names(cls) -> List:
+        """Returns a list with deprecated Trainer arguments."""
+        depr_arg_names = []
+        for name, val in cls.__dict__.items():
+            if name.startswith('DEPRECATED') and isinstance(val, (tuple, list)):
+                depr_arg_names.extend(val)
+        return depr_arg_names
+
```
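The scan over `cls.__dict__` means any class attribute whose name starts with `DEPRECATED` and holds a tuple or list contributes its entries, so `DEPRECATED_IN_0_8` and `DEPRECATED_IN_0_9` are both picked up without being listed explicitly. A self-contained sketch of the same pattern with a hypothetical `Config` class:

```python
class Config:
    DEPRECATED_IN_0_8 = ('gradient_clip', 'nb_gpu_nodes')
    DEPRECATED_IN_0_9 = ('use_amp',)
    ALLOWED = ('max_epochs',)  # skipped: name does not start with 'DEPRECATED'


deprecated = []
for name, val in Config.__dict__.items():
    if name.startswith('DEPRECATED') and isinstance(val, (tuple, list)):
        deprecated.extend(val)

print(deprecated)  # ['gradient_clip', 'nb_gpu_nodes', 'use_amp']
```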
```diff
     @classmethod
     def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser:
-        """Extend existing argparse by default `Trainer` attributes."""
-        parser = ArgumentParser(parents=[parent_parser], add_help=False)
+        r"""Extends existing argparse by default `Trainer` attributes.
 
-        trainer_default_params = Trainer.default_attributes()
+        Args:
+            parent_parser:
+                The custom cli arguments parser, which will be extended by
+                the Trainer default arguments.
+
+        Only arguments of the allowed types (str, float, int, bool) will
+        extend the `parent_parser`.
+        """
+        parser = ArgumentParser(parents=[parent_parser], add_help=False)
 
+        depr_arg_names = cls.get_deprecated_arg_names()
+
+        allowed_types = (str, float, int, bool)
         # TODO: get "help" from docstring :)
-        for arg in trainer_default_params:
-            parser.add_argument(
-                f'--{arg}',
-                default=trainer_default_params[arg],
-                dest=arg,
-                help='autogenerated by pl.Trainer'
-            )
+        for arg, arg_types, arg_default in cls.get_init_arguments_and_types():
+            if arg not in depr_arg_names:
+                for allowed_type in allowed_types:
+                    if allowed_type in arg_types:
+                        if allowed_type is bool:
+                            allowed_type = lambda x: bool(distutils.util.strtobool(x))
+                        parser.add_argument(
+                            f'--{arg}',
+                            default=arg_default,
+                            type=allowed_type,
+                            dest=arg,
+                            help='autogenerated by pl.Trainer'
+                        )
+                        break
 
         return parser
```
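Two details of the new loop are easy to miss. The `break` adds each argument once, using the first allowed type found in its annotation; and the `strtobool` lambda replaces `type=bool` because argparse would otherwise call `bool()` on the raw string, so any non-empty value, including `'0'` or `'false'`, would parse as `True`. A usage sketch of the extended parser, assuming the argument names and defaults of this version of the Trainer:

```python
from argparse import ArgumentParser

from pytorch_lightning import Trainer

parser = ArgumentParser(add_help=False)
parser = Trainer.add_argparse_args(parser)

# annotated types are enforced: '5' arrives as the int 5
args = parser.parse_args(['--max_epochs', '5'])
assert args.max_epochs == 5

# bool arguments go through strtobool, so '0' parses as False
args = parser.parse_args(['--print_nan_grads', '0'])
assert args.print_nan_grads is False

# the resulting namespace feeds straight into the Trainer
trainer = Trainer.from_argparse_args(args)
```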

tests/trainer/test_trainer.py

Lines changed: 2 additions & 22 deletions

```diff
@@ -1,8 +1,7 @@
 import glob
 import math
 import os
-from argparse import ArgumentParser, Namespace
-from unittest import mock
+from argparse import Namespace
 
 import pytest
 import torch
@@ -251,6 +250,7 @@ def test_dp_output_reduce():
 
 def test_model_checkpoint_options(tmpdir):
     """Test ModelCheckpoint options."""
+
     def mock_save_function(filepath):
         open(filepath, 'a').close()
 
@@ -624,23 +624,3 @@ def test_epoch_end(self, outputs):
 
     model = LightningTestModel(hparams)
     Trainer().test(model)
-
-
-@mock.patch('argparse.ArgumentParser.parse_args',
-            return_value=Namespace(**Trainer.default_attributes()))
-def test_default_args(tmpdir):
-    """Tests default argument parser for Trainer"""
-    tutils.reset_seed()
-
-    # logger file to get meta
-    logger = tutils.get_test_tube_logger(tmpdir, False)
-
-    parser = ArgumentParser(add_help=False)
-    args = parser.parse_args()
-    args.logger = logger
-
-    args.max_epochs = 5
-    trainer = Trainer.from_argparse_args(args)
-
-    assert isinstance(trainer, Trainer)
-    assert trainer.max_epochs == 5
```

tests/trainer/test_trainer_cli.py

Lines changed: 85 additions & 0 deletions

```diff
@@ -0,0 +1,85 @@
+import inspect
+from argparse import ArgumentParser, Namespace
+from unittest import mock
+
+import pytest
+
+import tests.models.utils as tutils
+from pytorch_lightning import Trainer
+
+
+@mock.patch('argparse.ArgumentParser.parse_args',
+            return_value=Namespace(**Trainer.default_attributes()))
+def test_default_args(tmpdir):
+    """Tests default argument parser for Trainer"""
+    tutils.reset_seed()
+
+    # logger file to get meta
+    logger = tutils.get_test_tube_logger(tmpdir, False)
+
+    parser = ArgumentParser(add_help=False)
+    args = parser.parse_args()
+    args.logger = logger
+
+    args.max_epochs = 5
+    trainer = Trainer.from_argparse_args(args)
+
+    assert isinstance(trainer, Trainer)
+    assert trainer.max_epochs == 5
+
+
+@pytest.mark.parametrize('cli_args', [
+    ['--accumulate_grad_batches=22'],
+    ['--print_nan_grads=1', '--weights_save_path=./'],
+    []
+])
+def test_add_argparse_args_redefined(cli_args):
+    """Redefines some default Trainer arguments via the cli and
+    tests the Trainer initialization correctness.
+    """
+    parser = ArgumentParser(add_help=False)
+    parser = Trainer.add_argparse_args(parent_parser=parser)
+
+    args = parser.parse_args(cli_args)
+
+    # Check that a few deprecated args are not in the namespace:
+    for depr_name in ('gradient_clip', 'nb_gpu_nodes', 'max_nb_epochs'):
+        assert depr_name not in args
+
+    trainer = Trainer.from_argparse_args(args=args)
+    assert isinstance(trainer, Trainer)
+
+
+def test_get_init_arguments_and_types():
+    """Asserts the correctness of the `get_init_arguments_and_types` Trainer classmethod."""
+    args = Trainer.get_init_arguments_and_types()
+    parameters = inspect.signature(Trainer).parameters
+    assert len(parameters) == len(args)
+    for arg in args:
+        assert parameters[arg[0]].default == arg[2]
+
+    kwargs = {arg[0]: arg[2] for arg in args}
+    trainer = Trainer(**kwargs)
+    assert isinstance(trainer, Trainer)
+
+
+@pytest.mark.parametrize('cli_args', [
+    ['--callbacks=1', '--logger'],
+    ['--foo', '--bar=1']
+])
+def test_add_argparse_args_redefined_error(cli_args, monkeypatch):
+    """Asserts that an error is raised when non-default cli arguments are passed."""
+
+    class _UnkArgError(Exception):
+        pass
+
+    def _raise():
+        raise _UnkArgError
+
+    parser = ArgumentParser(add_help=False)
+    parser = Trainer.add_argparse_args(parent_parser=parser)
+
+    monkeypatch.setattr(parser, 'exit', lambda *args: _raise(), raising=True)
+
+    with pytest.raises(_UnkArgError):
+        parser.parse_args(cli_args)
```
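The `monkeypatch.setattr(parser, 'exit', ...)` line in the last test exists because argparse reports unrecognized arguments through `parser.error`, which in turn calls `parser.exit` and raises `SystemExit` rather than an ordinary exception. Overriding `exit` converts the failure into something `pytest.raises` can assert on. A standalone sketch of the same trick:

```python
from argparse import ArgumentParser


class UnknownArgError(Exception):
    pass


def fake_exit(*args, **kwargs):
    # argparse funnels every failure through parser.exit(status, message);
    # raising here keeps the process alive so the failure is assertable
    raise UnknownArgError


parser = ArgumentParser(add_help=False)
parser.add_argument('--max_epochs', type=int)
parser.exit = fake_exit

try:
    parser.parse_args(['--foo'])  # unrecognized argument
except UnknownArgError:
    print('caught the parse failure as a normal exception')
```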
