.. testsetup:: *
    :skipif: not _JSONARGPARSE_AVAILABLE

+    import torch
    from unittest import mock
    from typing import List
    from pytorch_lightning.core.lightning import LightningModule
@@ -385,7 +386,7 @@ instantiating the trainer class can be found in :code:`self.config['trainer']`.


Configurable callbacks
-~~~~~~~~~~~~~~~~~~~~~~
+^^^^^^^^^^^^^^^^^^^^^^

As explained previously, any callback can be added by including it in the config via :code:`class_path` and
:code:`init_args` entries. However, there are other cases in which a callback should always be present and be
@@ -417,7 +418,7 @@ To change the configuration of the :code:`EarlyStopping` in the config it would


Argument linking
-~~~~~~~~~~~~~~~~
+^^^^^^^^^^^^^^^^

Another case in which it might be desired to extend :class:`~pytorch_lightning.utilities.cli.LightningCLI` is that the
model and data module depend on a common parameter. For example in some cases both classes require to know the
@@ -470,3 +471,117 @@ Instantiation links are used to automatically determine the order of instantiati |
The linking of arguments can be used for more complex cases. For example to derive a value via a function that takes
multiple settings as input. For more details have a look at the API of `link_arguments
<https://jsonargparse.readthedocs.io/en/stable/#jsonargparse.core.ArgumentParser.link_arguments>`_.
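+
+For instance, a hypothetical link that derives a model parameter from two data options could look like the following
+(the argument keys and the :code:`steps_per_epoch` parameter are made up for illustration):
+
+.. code-block:: python
+
+    class MyLightningCLI(LightningCLI):
+
+        def add_arguments_to_parser(self, parser):
+            # Compute a value from two settings and link it to a model init parameter
+            parser.link_arguments(
+                ("data.num_samples", "data.batch_size"),
+                "model.steps_per_epoch",
+                compute_fn=lambda num_samples, batch_size: num_samples // batch_size,
+            )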
+
+
+Optimizers and learning rate schedulers
+^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
+
+Optimizers and learning rate schedulers can also be made configurable. The most common case is when a model only has a
+single optimizer and optionally a single learning rate scheduler. In this case, the
+:class:`~pytorch_lightning.core.lightning.LightningModule` can be left without a :code:`configure_optimizers`
+implementation, since it is normally always the same and only adds boilerplate. The following code snippet shows how to
+set this up:
+
+.. testcode::
+
+    import torch
+    from pytorch_lightning.utilities.cli import LightningCLI
+
+    class MyLightningCLI(LightningCLI):
+
+        def add_arguments_to_parser(self, parser):
+            parser.add_optimizer_args(torch.optim.Adam)
+            parser.add_lr_scheduler_args(torch.optim.lr_scheduler.ExponentialLR)
+
+    cli = MyLightningCLI(MyModel)
+
+With this, the :code:`configure_optimizers` method is automatically implemented, and in the config the
+:code:`optimizer` and :code:`lr_scheduler` groups accept all of the init options of the given classes, in this example
+:code:`Adam` and :code:`ExponentialLR`. Therefore, the config file would be structured like:
+
+.. code-block:: yaml
+
+    optimizer:
+      lr: 0.01
+    lr_scheduler:
+      gamma: 0.2
+    model:
+      ...
+    trainer:
+      ...
+
+And any of these arguments can be passed directly through the command line. For example:
+
+.. code-block:: bash
+
+    $ python train.py --optimizer.lr=0.01 --lr_scheduler.gamma=0.2
+
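+For clarity, in this single-class case the automatically added :code:`configure_optimizers` behaves roughly like the
+following hand-written method using the values from the config above. This is only a sketch for illustration, not the
+exact generated code:
+
+.. code-block:: python
+
+    def configure_optimizers(self):
+        # Roughly what the automatic implementation does with the config values shown above
+        optimizer = torch.optim.Adam(self.parameters(), lr=0.01)
+        lr_scheduler = torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.2)
+        return [optimizer], [lr_scheduler]
+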
+It is also possible to select among multiple classes by giving them as a tuple. For example:
+
+.. testcode::
+
+    class MyLightningCLI(LightningCLI):
+
+        def add_arguments_to_parser(self, parser):
+            parser.add_optimizer_args((torch.optim.SGD, torch.optim.Adam))
+
+In this case, instead of taking the init settings directly, the :code:`optimizer` group in the config should specify a
+:code:`class_path` and optionally :code:`init_args`. Sub-classes of the classes in the tuple are also accepted.
+A corresponding example of the config file would be:
+
+.. code-block:: yaml
+
+    optimizer:
+      class_path: torch.optim.Adam
+      init_args:
+        lr: 0.01
+    model:
+      ...
+    trainer:
+      ...
+
+And the same through the command line:
+
+.. code-block:: bash
+
+    $ python train.py --optimizer='{class_path: torch.optim.Adam, init_args: {lr: 0.01}}'
+
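+Since sub-classes of the tuple entries are accepted, a custom optimizer can also be selected through its full import
+path. For example, with a hypothetical :code:`LitAdam` subclass of :code:`torch.optim.Adam` defined in an importable
+:code:`mycode.optimizers` module (names made up for illustration), the config could be:
+
+.. code-block:: yaml
+
+    optimizer:
+      # hypothetical subclass of torch.optim.Adam
+      class_path: mycode.optimizers.LitAdam
+      init_args:
+        lr: 0.01
+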
+The automatic implementation of :code:`configure_optimizers` can be disabled by linking the configuration group to a
+model init parameter instead. A use case for this is :code:`ReduceLROnPlateau`, which requires specifying a monitor.
+This would be:
+
+.. testcode::
+
+    from pytorch_lightning.utilities.cli import instantiate_class, LightningCLI
+
+    class MyModel(LightningModule):
+
+        def __init__(self, optimizer_init: dict, lr_scheduler_init: dict):
+            super().__init__()
+            self.optimizer_init = optimizer_init
+            self.lr_scheduler_init = lr_scheduler_init
+
+        def configure_optimizers(self):
+            optimizer = instantiate_class(self.parameters(), self.optimizer_init)
+            scheduler = instantiate_class(optimizer, self.lr_scheduler_init)
+            return {"optimizer": optimizer, "lr_scheduler": scheduler, "monitor": "metric_to_track"}
+
+    class MyLightningCLI(LightningCLI):
+
+        def add_arguments_to_parser(self, parser):
+            parser.add_optimizer_args(
+                torch.optim.Adam,
+                link_to='model.optimizer_init',
+            )
+            parser.add_lr_scheduler_args(
+                torch.optim.lr_scheduler.ReduceLROnPlateau,
+                link_to='model.lr_scheduler_init',
+            )
+
+    cli = MyLightningCLI(MyModel)
+
+Whether :meth:`pytorch_lightning.utilities.cli.LightningArgumentParser.add_optimizer_args` is given a single class or a
+tuple of classes, the value passed to :code:`optimizer_init` will always be a dictionary with :code:`class_path` and
+:code:`init_args` entries. The :func:`~pytorch_lightning.utilities.cli.instantiate_class` function takes care of
+importing the class defined in :code:`class_path` and instantiating it with the given positional arguments, in this
+case :code:`self.parameters()`, plus the :code:`init_args`. Any number of optimizers and learning rate schedulers can
+be added when using :code:`link_to`.
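+
+As a reference for what this dictionary and the instantiation look like at runtime, the following is a simplified
+sketch of :func:`~pytorch_lightning.utilities.cli.instantiate_class`, only for illustration and not the library's exact
+implementation:
+
+.. code-block:: python
+
+    import importlib
+
+    def instantiate_class_sketch(args, init):
+        """Simplified illustration: import the class named in ``class_path`` and call it."""
+        kwargs = init.get("init_args", {})
+        if not isinstance(args, tuple):
+            args = (args,)
+        module_name, _, class_name = init["class_path"].rpartition(".")
+        cls = getattr(importlib.import_module(module_name), class_name)
+        return cls(*args, **kwargs)
+
+    # With optimizer_init == {"class_path": "torch.optim.Adam", "init_args": {"lr": 0.01}},
+    # instantiate_class_sketch(self.parameters(), optimizer_init) is equivalent to
+    # torch.optim.Adam(self.parameters(), lr=0.01).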