Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion requirements/pytorch/extra.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,6 @@ matplotlib>3.1, <3.5.3
torchtext>=0.10.*, <=0.12.0
omegaconf>=2.0.5, <2.3.0
hydra-core>=1.0.5, <1.3.0
jsonargparse[signatures]>=4.10.2, <=4.10.2
jsonargparse[signatures]>=4.12.0, <=4.12.0
gcsfs>=2021.5.0, <2022.6.0
rich>=10.14.0, !=10.15.0.a, <13.0.0
2 changes: 1 addition & 1 deletion src/pytorch_lightning/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_deprecation, rank_zero_warn

_JSONARGPARSE_SIGNATURES_AVAILABLE = _RequirementAvailable("jsonargparse[signatures]>=4.10.2")
_JSONARGPARSE_SIGNATURES_AVAILABLE = _RequirementAvailable("jsonargparse[signatures]>=4.12.0")

if _JSONARGPARSE_SIGNATURES_AVAILABLE:
import docstring_parser
Expand Down
45 changes: 36 additions & 9 deletions tests/tests_pytorch/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -1414,21 +1414,47 @@ def _test_logger_init_args(logger_name, init, unresolved={}):

@pytest.mark.skipif(not _COMET_AVAILABLE, reason="comet-ml is required")
def test_comet_logger_init_args():
_test_logger_init_args("CometLogger", {"save_dir": "comet", "workspace": "comet"})
_test_logger_init_args(
"CometLogger",
{
"save_dir": "comet", # Resolve from CometLogger.__init__
"workspace": "comet", # Resolve from Comet{,Existing,Offline}Experiment.__init__
},
)


@pytest.mark.skipif(not _NEPTUNE_AVAILABLE, reason="neptune-client is required")
def test_neptune_logger_init_args():
_test_logger_init_args("NeptuneLogger", {"name": "neptune"}, {"description": "neptune"})
_test_logger_init_args(
"NeptuneLogger",
{
"name": "neptune", # Resolve from NeptuneLogger.__init__
},
{
        "description": "neptune",  # Cannot be resolved from neptune.new.internal.init.run.init_run
},
)


def test_tensorboard_logger_init_args():
_test_logger_init_args("TensorBoardLogger", {"save_dir": "tb", "name": "tb"})
_test_logger_init_args(
"TensorBoardLogger",
{
"save_dir": "tb", # Resolve from TensorBoardLogger.__init__
"comment": "tb", # Resolve from tensorboard.writer.SummaryWriter.__init__
},
)


@pytest.mark.skipif(not _WANDB_AVAILABLE, reason="wandb is required")
def test_wandb_logger_init_args():
_test_logger_init_args("WandbLogger", {"save_dir": "wandb", "notes": "wandb"})
_test_logger_init_args(
"WandbLogger",
{
"save_dir": "wandb", # Resolve from WandbLogger.__init__
"notes": "wandb", # Resolve from wandb.sdk.wandb_init.init
},
)


def test_cli_auto_seeding():
Expand Down Expand Up @@ -1512,13 +1538,13 @@ def test_pytorch_profiler_init_args():
from pytorch_lightning.profilers import Profiler, PyTorchProfiler

init = {
"dirpath": "profiler",
"row_limit": 10,
"group_by_input_shapes": True,
"dirpath": "profiler", # Resolve from PyTorchProfiler.__init__
"row_limit": 10, # Resolve from PyTorchProfiler.__init__
"group_by_input_shapes": True, # Resolve from PyTorchProfiler.__init__
}
unresolved = {
"profile_memory": True,
"record_shapes": True,
"profile_memory": True, # Not possible to resolve parameters from dynamically chosen Type[_PROFILER]
        "record_shapes": True,  # Resolved from PyTorchProfiler.__init__; gets moved to init_args
}
cli_args = ["--trainer.profiler=PyTorchProfiler"]
cli_args += [f"--trainer.profiler.{k}={v}" for k, v in init.items()]
Expand All @@ -1528,5 +1554,6 @@ def test_pytorch_profiler_init_args():
cli = LightningCLI(TestModel, run=False)

assert isinstance(cli.config_init.trainer.profiler, PyTorchProfiler)
init["record_shapes"] = unresolved.pop("record_shapes") # Test move to init_args
assert {k: cli.config.trainer.profiler.init_args[k] for k in init} == init
assert cli.config.trainer.profiler.dict_kwargs == unresolved