Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ module = [
"pytorch_lightning.callbacks.pruning",
"pytorch_lightning.trainer.evaluation_loop",
"pytorch_lightning.trainer.connectors.logger_connector",
"pytorch_lightning.utilities.argparse",
"pytorch_lightning.utilities.cli",
"pytorch_lightning.utilities.device_dtype_mixin",
"pytorch_lightning.utilities.device_parser",
Expand Down
45 changes: 30 additions & 15 deletions pytorch_lightning/utilities/argparse.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,14 +13,27 @@
# limitations under the License.
import inspect
import os
from abc import ABC
from argparse import _ArgumentGroup, ArgumentParser, Namespace
from contextlib import suppress
from typing import Any, Dict, List, Tuple, Union
from typing import Any, Callable, Dict, List, Tuple, Type, Union

import pytorch_lightning as pl
from pytorch_lightning.utilities.parsing import str_to_bool, str_to_bool_or_int, str_to_bool_or_str


def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
class ParseArgparserDataType(ABC):
def __init__(self, *_: Any, **__: Any) -> None:
pass

@classmethod
def parse_argparser(cls, args: "ArgumentParser") -> Any:
pass


def from_argparse_args(
cls: Type[ParseArgparserDataType], args: Union[Namespace, ArgumentParser], **kwargs: Any
) -> ParseArgparserDataType:
"""Create an instance from CLI arguments.
Eventually use varibles from OS environement which are defined as "PL_<CLASS-NAME>_<CLASS_ARUMENT_NAME>"

Expand Down Expand Up @@ -52,7 +65,7 @@ def from_argparse_args(cls, args: Union[Namespace, ArgumentParser], **kwargs):
return cls(**trainer_kwargs)


def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
def parse_argparser(cls: Type["pl.Trainer"], arg_parser: Union[ArgumentParser, Namespace]) -> Namespace:
"""Parse CLI arguments, required for custom bool types."""
args = arg_parser.parse_args() if isinstance(arg_parser, ArgumentParser) else arg_parser

Expand All @@ -77,7 +90,7 @@ def parse_argparser(cls, arg_parser: Union[ArgumentParser, Namespace]) -> Namesp
return Namespace(**modified_args)


def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace:
def parse_env_variables(cls: Type["pl.Trainer"], template: str = "PL_%(cls_name)s_%(cls_argument)s") -> Namespace:
"""Parse environment arguments if they are defined.

Example:
Expand Down Expand Up @@ -106,7 +119,7 @@ def parse_env_variables(cls, template: str = "PL_%(cls_name)s_%(cls_argument)s")
return Namespace(**env_args)


def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
def get_init_arguments_and_types(cls: Any) -> List[Tuple[str, Tuple, Any]]:
r"""Scans the class signature and returns argument names, types and default values.

Returns:
Expand Down Expand Up @@ -134,7 +147,7 @@ def get_init_arguments_and_types(cls) -> List[Tuple[str, Tuple, Any]]:
return name_type_default


def _get_abbrev_qualified_cls_name(cls):
def _get_abbrev_qualified_cls_name(cls: Any) -> str:
assert isinstance(cls, type), repr(cls)
if cls.__module__.startswith("pytorch_lightning."):
# Abbreviate.
Expand All @@ -143,7 +156,9 @@ def _get_abbrev_qualified_cls_name(cls):
return f"{cls.__module__}.{cls.__qualname__}"


def add_argparse_args(cls, parent_parser: ArgumentParser, *, use_argument_group=True) -> ArgumentParser:
def add_argparse_args(
cls: Type["pl.Trainer"], parent_parser: ArgumentParser, *, use_argument_group: bool = True
) -> Union[_ArgumentGroup, ArgumentParser]:
r"""Extends existing argparse by default attributes for ``cls``.

Args:
Expand Down Expand Up @@ -187,7 +202,7 @@ def add_argparse_args(cls, parent_parser: ArgumentParser, *, use_argument_group=
raise RuntimeError("Please only pass an ArgumentParser instance.")
if use_argument_group:
group_name = _get_abbrev_qualified_cls_name(cls)
parser = parent_parser.add_argument_group(group_name)
parser: Union[_ArgumentGroup, ArgumentParser] = parent_parser.add_argument_group(group_name)
else:
parser = ArgumentParser(parents=[parent_parser], add_help=False)

Expand All @@ -207,16 +222,16 @@ def add_argparse_args(cls, parent_parser: ArgumentParser, *, use_argument_group=
args_help = _parse_args_from_docstring(cls.__init__.__doc__ or cls.__doc__ or "")

for arg, arg_types, arg_default in args_and_types:
arg_types = [at for at in allowed_types if at in arg_types]
arg_types = tuple(at for at in allowed_types if at in arg_types)
if not arg_types:
# skip argument with not supported type
continue
arg_kwargs = {}
arg_kwargs: Dict[str, Any] = {}
if bool in arg_types:
arg_kwargs.update(nargs="?", const=True)
# if the only arg type is bool
if len(arg_types) == 1:
use_type = str_to_bool
use_type: Callable[[str], Union[bool, int, float, str]] = str_to_bool
elif int in arg_types:
use_type = str_to_bool_or_int
elif str in arg_types:
Expand Down Expand Up @@ -249,7 +264,7 @@ def add_argparse_args(cls, parent_parser: ArgumentParser, *, use_argument_group=

def _parse_args_from_docstring(docstring: str) -> Dict[str, str]:
arg_block_indent = None
current_arg = None
current_arg = ""
parsed = {}
for line in docstring.split("\n"):
stripped = line.lstrip()
Expand All @@ -270,20 +285,20 @@ def _parse_args_from_docstring(docstring: str) -> Dict[str, str]:
return parsed


def _gpus_allowed_type(x) -> Union[int, str]:
def _gpus_allowed_type(x: str) -> Union[int, str]:
if "," in x:
return str(x)
return int(x)


def _gpus_arg_default(x) -> Union[int, str]: # pragma: no-cover
def _gpus_arg_default(x: str) -> Union[int, str]: # pragma: no-cover
# unused, but here for backward compatibility with old checkpoints that need to be able to
# unpickle the function from the checkpoint, as it was not filtered out in versions < 1.2.8
# see: https://github.com/PyTorchLightning/pytorch-lightning/pull/6898
pass


def _int_or_float_type(x) -> Union[int, float]:
def _int_or_float_type(x: Union[int, float, str]) -> Union[int, float]:
if "." in str(x):
return float(x)
return int(x)