3 | 3 | import sys |
4 | 4 | import warnings |
5 | 5 | from argparse import ArgumentParser |
6 | | -from typing import Union, Optional, List, Dict, Tuple, Iterable |
| 6 | +from typing import Union, Optional, List, Dict, Tuple, Iterable, Any |
| 7 | +import distutils.util |
7 | 8 |
|
8 | 9 | import torch |
9 | | -from torch import optim |
10 | 10 | import torch.distributed as torch_distrib |
11 | 11 | import torch.multiprocessing as mp |
| 12 | +from torch import optim |
12 | 13 | from torch.optim.optimizer import Optimizer |
13 | 14 | from torch.utils.data import DataLoader |
14 | 15 | from tqdm.auto import tqdm |
15 | 16 |
|
16 | 17 | from pytorch_lightning import _logger as log |
17 | 18 | from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, Callback |
| 19 | +from pytorch_lightning.core.lightning import LightningModule |
18 | 20 | from pytorch_lightning.loggers import LightningLoggerBase |
19 | 21 | from pytorch_lightning.profiler import Profiler, PassThroughProfiler |
20 | 22 | from pytorch_lightning.profiler.profiler import BaseProfiler |
21 | 23 | from pytorch_lightning.trainer.auto_mix_precision import TrainerAMPMixin |
22 | 24 | from pytorch_lightning.trainer.callback_config import TrainerCallbackConfigMixin |
23 | | -from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin |
24 | | -from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin |
25 | | -from pytorch_lightning.trainer.distrib_parts import ( |
26 | | - TrainerDPMixin, |
27 | | - parse_gpu_ids, |
28 | | - determine_root_gpu_device |
29 | | -) |
30 | | -from pytorch_lightning.core.lightning import LightningModule |
31 | 25 | from pytorch_lightning.trainer.callback_hook import TrainerCallbackHookMixin |
| 26 | +from pytorch_lightning.trainer.data_loading import TrainerDataLoadingMixin |
32 | 27 | from pytorch_lightning.trainer.deprecated_api import TrainerDeprecatedAPITillVer0_8 |
| 28 | +from pytorch_lightning.trainer.distrib_data_parallel import TrainerDDPMixin |
| 29 | +from pytorch_lightning.trainer.distrib_parts import TrainerDPMixin, parse_gpu_ids, determine_root_gpu_device |
33 | 30 | from pytorch_lightning.trainer.evaluation_loop import TrainerEvaluationLoopMixin |
34 | 31 | from pytorch_lightning.trainer.logging import TrainerLoggingMixin |
35 | 32 | from pytorch_lightning.trainer.model_hooks import TrainerModelHooksMixin |
@@ -70,6 +67,11 @@ class Trainer( |
70 | 67 | TrainerCallbackHookMixin, |
71 | 68 | TrainerDeprecatedAPITillVer0_8, |
72 | 69 | ): |
| 70 | + DEPRECATED_IN_0_8 = ( |
| 71 | + 'gradient_clip', 'nb_gpu_nodes', 'max_nb_epochs', 'min_nb_epochs', |
| 72 | + 'add_row_log_interval', 'nb_sanity_val_steps' |
| 73 | + ) |
| 74 | + DEPRECATED_IN_0_9 = ('use_amp',) |
73 | 75 |
|
74 | 76 | def __init__( |
75 | 77 | self, |
@@ -466,21 +468,91 @@ def default_attributes(cls): |
466 | 468 |
|
467 | 469 | return args |
468 | 470 |
|
| 471 | + @classmethod |
| 472 | + def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]: |
| 473 | + r"""Scans the Trainer signature and returns argument names, types and default values. |
| 474 | +
|
| 475 | + Returns: |
| 476 | + List with tuples of 3 values: |
| 477 | + (argument name, tuple of argument types, argument default value). |
| 478 | +
|
| 479 | + Examples: |
| 480 | + >>> args = Trainer.get_init_arguments_and_types() |
| 481 | + >>> import pprint |
| 482 | + >>> pprint.pprint(sorted(args)) # doctest: +ELLIPSIS +NORMALIZE_WHITESPACE |
| 483 | + [('accumulate_grad_batches', |
| 484 | + (<class 'int'>, typing.Dict[int, int], typing.List[list]), |
| 485 | + 1), |
| 486 | + ... |
| 487 | + ('callbacks', (<class 'pytorch_lightning.callbacks.base.Callback'>,), []), |
| 488 | + ('check_val_every_n_epoch', (<class 'int'>,), 1), |
| 489 | + ... |
| 490 | + ('max_epochs', (<class 'int'>,), 1000), |
| 491 | + ... |
| 492 | + ('precision', (<class 'int'>,), 32), |
| 493 | + ('print_nan_grads', (<class 'bool'>,), False), |
| 494 | + ('process_position', (<class 'int'>,), 0), |
| 495 | + ('profiler', |
| 496 | + (<class 'pytorch_lightning.profiler.profiler.BaseProfiler'>, |
| 497 | + <class 'NoneType'>), |
| 498 | + None), |
| 499 | + ... |
| 500 | + """ |
| 501 | + trainer_default_params = inspect.signature(cls).parameters |
| 502 | + name_type_default = [] |
| 503 | + for arg in trainer_default_params: |
| 504 | + arg_type = trainer_default_params[arg].annotation |
| 505 | + arg_default = trainer_default_params[arg].default |
| 506 | + try: |
| 507 | + arg_types = tuple(arg_type.__args__) |
| 508 | + except AttributeError: |
| 509 | + arg_types = (arg_type,) |
| 510 | + |
| 511 | + name_type_default.append((arg, arg_types, arg_default)) |
| 512 | + |
| 513 | + return name_type_default |
| 514 | + |
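A quick note on the `__args__` lookup wrapped in try/except above: `typing` generics expose their member types through `__args__`, while plain classes do not, so the `AttributeError` branch falls back to a one-element tuple. A minimal illustration of that standard `typing` behaviour:

    from typing import Optional, Union

    Optional[int].__args__       # (<class 'int'>, <class 'NoneType'>)
    Union[int, str].__args__     # (<class 'int'>, <class 'str'>)
    hasattr(int, '__args__')     # False -> the except branch yields (int,)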
| 515 | + @classmethod |
| 516 | + def get_deprecated_arg_names(cls) -> List: |
| 517 | + """Returns a list with deprecated Trainer arguments.""" |
| 518 | + depr_arg_names = [] |
| 519 | + for name, val in cls.__dict__.items(): |
| 520 | + if name.startswith('DEPRECATED') and isinstance(val, (tuple, list)): |
| 521 | + depr_arg_names.extend(val) |
| 522 | + return depr_arg_names |
| 523 | + |
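Given the `DEPRECATED_IN_0_8` and `DEPRECATED_IN_0_9` tuples defined on the class above, and assuming no other `DEPRECATED*` attributes are set directly on `Trainer`, this helper flattens them into a single list, roughly:

    Trainer.get_deprecated_arg_names()
    # ['gradient_clip', 'nb_gpu_nodes', 'max_nb_epochs', 'min_nb_epochs',
    #  'add_row_log_interval', 'nb_sanity_val_steps', 'use_amp']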
469 | 524 | @classmethod |
470 | 525 | def add_argparse_args(cls, parent_parser: ArgumentParser) -> ArgumentParser: |
471 | | - """Extend existing argparse by default `Trainer` attributes.""" |
472 | | - parser = ArgumentParser(parents=[parent_parser], add_help=False) |
| 526 | + r"""Extends existing argparse by default `Trainer` attributes. |
473 | 527 |
|
474 | | - trainer_default_params = Trainer.default_attributes() |
| 528 | + Args: |
| 529 | + parent_parser: |
| 530 | + The custom cli arguments parser, which will be extended by |
| 531 | + the Trainer default arguments. |
| 532 | +
|
| 533 | + Only arguments of the allowed types (str, float, int, bool) will |
| 534 | + extend the `parent_parser`. |
| 535 | + """ |
| 536 | + parser = ArgumentParser(parents=[parent_parser], add_help=False) |
475 | 537 |
|
| 538 | + depr_arg_names = cls.get_deprecated_arg_names() |
| 539 | + |
| 540 | + allowed_types = (str, float, int, bool) |
476 | 541 | # TODO: get "help" from docstring :) |
477 | | - for arg in trainer_default_params: |
478 | | - parser.add_argument( |
479 | | - f'--{arg}', |
480 | | - default=trainer_default_params[arg], |
481 | | - dest=arg, |
482 | | - help='autogenerated by pl.Trainer' |
483 | | - ) |
| 542 | + for arg, arg_types, arg_default in cls.get_init_arguments_and_types(): |
| 543 | + if arg not in depr_arg_names: |
| 544 | + for allowed_type in allowed_types: |
| 545 | + if allowed_type in arg_types: |
| 546 | + if allowed_type is bool: |
| 547 | + allowed_type = lambda x: bool(distutils.util.strtobool(x)) |
| 548 | + parser.add_argument( |
| 549 | + f'--{arg}', |
| 550 | + default=arg_default, |
| 551 | + type=allowed_type, |
| 552 | + dest=arg, |
| 553 | + help='autogenerated by pl.Trainer' |
| 554 | + ) |
| 555 | + break |
484 | 556 |
|
485 | 557 | return parser |
486 | 558 |
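A minimal usage sketch of `add_argparse_args`, assuming the patched `Trainer` above (the script and argument values are illustrative). Booleans are routed through `strtobool` because a plain `type=bool` in argparse would turn any non-empty string, including 'false', into True:

    from argparse import ArgumentParser
    from pytorch_lightning import Trainer

    parser = ArgumentParser(description='minimal example')
    parser = Trainer.add_argparse_args(parser)   # adds --max_epochs, --fast_dev_run, ...
    args = parser.parse_args(['--max_epochs', '5', '--fast_dev_run', 'true'])

    trainer = Trainer(max_epochs=args.max_epochs, fast_dev_run=args.fast_dev_run)

Deprecated flags are skipped, and any argument whose annotation contains none of the allowed types (str, float, int, bool) is simply left off the generated parser.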
|
|