|
19 | 19 | import logging |
20 | 20 | import os |
21 | 21 | import tempfile |
| 22 | +import types |
22 | 23 | import uuid |
23 | 24 | from abc import ABC |
24 | 25 | from argparse import Namespace |
25 | 26 | from functools import partial |
26 | 27 | from pathlib import Path |
27 | | -from typing import Any, Callable, Dict, List, Optional, Tuple, TYPE_CHECKING, Union |
| 28 | +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union |
28 | 29 |
|
29 | 30 | import torch |
30 | 31 | from torch import ScriptModule, Tensor |
@@ -1591,55 +1592,84 @@ def _auto_collect_arguments(cls, frame=None) -> Tuple[Dict, Dict]: |
1591 | 1592 | parents_arguments.update(args) |
1592 | 1593 | return self_arguments, parents_arguments |
1593 | 1594 |
|
1594 | | - def save_hyperparameters(self, *args, frame=None) -> None: |
1595 | | - """Save all model arguments. |
| 1595 | + def save_hyperparameters( |
| 1596 | + self, |
| 1597 | + *args, |
| 1598 | + ignore: Optional[Union[Sequence[str], str]] = None, |
| 1599 | + frame: Optional[types.FrameType] = None |
| 1600 | + ) -> None: |
| 1601 | + """Save model arguments to ``hparams`` attribute. |
1596 | 1602 |
|
1597 | 1603 | Args: |
1598 | 1604 | args: single object of `dict`, `NameSpace` or `OmegaConf` |
1599 | | - or string names or arguments from class `__init__` |
1600 | | -
|
1601 | | - >>> class ManuallyArgsModel(LightningModule): |
1602 | | - ... def __init__(self, arg1, arg2, arg3): |
1603 | | - ... super().__init__() |
1604 | | - ... # manually assign arguments |
1605 | | - ... self.save_hyperparameters('arg1', 'arg3') |
1606 | | - ... def forward(self, *args, **kwargs): |
1607 | | - ... ... |
1608 | | - >>> model = ManuallyArgsModel(1, 'abc', 3.14) |
1609 | | - >>> model.hparams |
1610 | | - "arg1": 1 |
1611 | | - "arg3": 3.14 |
1612 | | -
|
1613 | | - >>> class AutomaticArgsModel(LightningModule): |
1614 | | - ... def __init__(self, arg1, arg2, arg3): |
1615 | | - ... super().__init__() |
1616 | | - ... # equivalent automatic |
1617 | | - ... self.save_hyperparameters() |
1618 | | - ... def forward(self, *args, **kwargs): |
1619 | | - ... ... |
1620 | | - >>> model = AutomaticArgsModel(1, 'abc', 3.14) |
1621 | | - >>> model.hparams |
1622 | | - "arg1": 1 |
1623 | | - "arg2": abc |
1624 | | - "arg3": 3.14 |
1625 | | -
|
1626 | | - >>> class SingleArgModel(LightningModule): |
1627 | | - ... def __init__(self, params): |
1628 | | - ... super().__init__() |
1629 | | - ... # manually assign single argument |
1630 | | - ... self.save_hyperparameters(params) |
1631 | | - ... def forward(self, *args, **kwargs): |
1632 | | - ... ... |
1633 | | - >>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14)) |
1634 | | - >>> model.hparams |
1635 | | - "p1": 1 |
1636 | | - "p2": abc |
1637 | | - "p3": 3.14 |
| 1605 | + or string names or arguments from class ``__init__`` |
| 1606 | + ignore: an argument name or a list of argument names from |
| 1607 | + class ``__init__`` to be ignored |
| 1608 | + frame: a frame object. Default is None |
| 1609 | +
|
| 1610 | + Example:: |
| 1611 | + >>> class ManuallyArgsModel(LightningModule): |
| 1612 | + ... def __init__(self, arg1, arg2, arg3): |
| 1613 | + ... super().__init__() |
| 1614 | + ... # manually assign arguments |
| 1615 | + ... self.save_hyperparameters('arg1', 'arg3') |
| 1616 | + ... def forward(self, *args, **kwargs): |
| 1617 | + ... ... |
| 1618 | + >>> model = ManuallyArgsModel(1, 'abc', 3.14) |
| 1619 | + >>> model.hparams |
| 1620 | + "arg1": 1 |
| 1621 | + "arg3": 3.14 |
| 1622 | +
|
| 1623 | + >>> class AutomaticArgsModel(LightningModule): |
| 1624 | + ... def __init__(self, arg1, arg2, arg3): |
| 1625 | + ... super().__init__() |
| 1626 | + ... # equivalent automatic |
| 1627 | + ... self.save_hyperparameters() |
| 1628 | + ... def forward(self, *args, **kwargs): |
| 1629 | + ... ... |
| 1630 | + >>> model = AutomaticArgsModel(1, 'abc', 3.14) |
| 1631 | + >>> model.hparams |
| 1632 | + "arg1": 1 |
| 1633 | + "arg2": abc |
| 1634 | + "arg3": 3.14 |
| 1635 | +
|
| 1636 | + >>> class SingleArgModel(LightningModule): |
| 1637 | + ... def __init__(self, params): |
| 1638 | + ... super().__init__() |
| 1639 | + ... # manually assign single argument |
| 1640 | + ... self.save_hyperparameters(params) |
| 1641 | + ... def forward(self, *args, **kwargs): |
| 1642 | + ... ... |
| 1643 | + >>> model = SingleArgModel(Namespace(p1=1, p2='abc', p3=3.14)) |
| 1644 | + >>> model.hparams |
| 1645 | + "p1": 1 |
| 1646 | + "p2": abc |
| 1647 | + "p3": 3.14 |
| 1648 | +
|
| 1649 | + >>> class ManuallyArgsModel(LightningModule): |
| 1650 | + ... def __init__(self, arg1, arg2, arg3): |
| 1651 | + ... super().__init__() |
| 1652 | + ... # pass argument(s) to ignore as a string or in a list |
| 1653 | + ... self.save_hyperparameters(ignore='arg2') |
| 1654 | + ... def forward(self, *args, **kwargs): |
| 1655 | + ... ... |
| 1656 | + >>> model = ManuallyArgsModel(1, 'abc', 3.14) |
| 1657 | + >>> model.hparams |
| 1658 | + "arg1": 1 |
| 1659 | + "arg3": 3.14 |
1638 | 1660 | """ |
1639 | 1661 | if not frame: |
1640 | 1662 | frame = inspect.currentframe().f_back |
1641 | 1663 | init_args = get_init_args(frame) |
1642 | 1664 | assert init_args, "failed to inspect the self init" |
| 1665 | + |
| 1666 | + if ignore is not None: |
| 1667 | + if isinstance(ignore, str): |
| 1668 | + ignore = [ignore] |
| 1669 | + if isinstance(ignore, (list, tuple)): |
| 1670 | + ignore = [arg for arg in ignore if isinstance(arg, str)] |
| 1671 | + init_args = {k: v for k, v in init_args.items() if k not in ignore} |
| 1672 | + |
1643 | 1673 | if not args: |
1644 | 1674 | # take all arguments |
1645 | 1675 | hp = init_args |
|
0 commit comments