Merged
1 change: 0 additions & 1 deletion .azure/gpu-tests-pytorch.yml
@@ -117,7 +117,6 @@ jobs:
set -e
extra=$(python -c "print({'lightning': 'pytorch-'}.get('$(PACKAGE_NAME)', ''))")
pip install -e ".[${extra}dev]" pytest-timeout -U --extra-index-url="${TORCH_URL}"
pip install setuptools==75.6.0 jsonargparse==4.35.0
displayName: "Install package & dependencies"

- bash: pip uninstall -y lightning
2 changes: 1 addition & 1 deletion dockers/base-cuda/Dockerfile
@@ -34,7 +34,7 @@ ENV \
MAKEFLAGS="-j2"

RUN \
apt-get update && apt-get install -y wget && \
apt-get update --fix-missing && apt-get install -y wget && \
apt-get update -qq --fix-missing && \
NCCL_VER=$(dpkg -s libnccl2 | grep '^Version:' | awk -F ' ' '{print $2}' | awk -F '-' '{print $1}' | grep -ve '^\s*$') && \
CUDA_VERSION_MM=${CUDA_VERSION%.*} && \
2 changes: 1 addition & 1 deletion requirements/pytorch/extra.txt
@@ -5,7 +5,7 @@
matplotlib>3.1, <3.10.0
omegaconf >=2.2.3, <2.4.0
hydra-core >=1.2.0, <1.4.0
jsonargparse[signatures] >=4.28.0, <=4.40.0
jsonargparse[signatures] >=4.39.0, <4.40.0
rich >=12.3.0, <14.1.0
tensorboardX >=2.2, <2.7.0 # min version is set by torch.onnx missing attribute
bitsandbytes >=0.45.2,<0.45.3; platform_system != "Darwin"
5 changes: 4 additions & 1 deletion src/lightning/pytorch/CHANGELOG.md
@@ -29,7 +29,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Fixed

- Fixed logger_connector has edge case where step can be a float ([#20692](https://github.com/Lightning-AI/pytorch-lightning/issues/20692))
- Fixed `save_hyperparameters` not working correctly with `LightningCLI` when there are parsing links applied on instantiation ([#20777](https://github.com/Lightning-AI/pytorch-lightning/pull/20777))


- Fixed logger_connector has edge case where step can be a float ([#20692](https://github.com/Lightning-AI/pytorch-lightning/pull/20692))


- Fix: Synchronize SIGTERM Handling in DDP to Prevent Deadlocks ([#20825](https://github.com/Lightning-AI/pytorch-lightning/pull/20825))
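
For context, a minimal sketch of the linked-arguments pattern that the `save_hyperparameters` fix above targets (class names here are illustrative; the actual test classes appear in `tests/tests_pytorch/test_cli.py` below): a value computed on the datamodule at instantiation time is linked into the model's `__init__`, and after this PR it is also recorded by `save_hyperparameters`.

```python
# Sketch only -- mirrors the pattern exercised by the new tests further down.
from lightning.pytorch import LightningDataModule, LightningModule
from lightning.pytorch.cli import LightningCLI


class MyDataModule(LightningDataModule):
    def __init__(self, batch_size: int = 8):
        super().__init__()
        self.batch_size = batch_size
        self.num_classes = 5  # only known once the datamodule is instantiated


class MyModel(LightningModule):
    def __init__(self, num_classes: int, batch_size: int = 8):
        super().__init__()
        self.save_hyperparameters()  # should now record the linked num_classes/batch_size


class MyCLI(LightningCLI):
    def add_arguments_to_parser(self, parser):
        parser.link_arguments("data.batch_size", "model.batch_size")
        parser.link_arguments("data.num_classes", "model.num_classes", apply_on="instantiate")


if __name__ == "__main__":
    MyCLI(MyModel, MyDataModule)  # e.g. `python script.py fit --data.batch_size=12`
```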
40 changes: 37 additions & 3 deletions src/lightning/pytorch/cli.py
@@ -327,6 +327,7 @@ def __init__(
args: ArgsType = None,
run: bool = True,
auto_configure_optimizers: bool = True,
load_from_checkpoint_support: bool = True,
) -> None:
"""Receives as input pytorch-lightning classes (or callables which return pytorch-lightning classes), which are
called / instantiated using a parsed configuration file and / or command line args.
@@ -367,6 +368,11 @@
``dict`` or ``jsonargparse.Namespace``.
run: Whether subcommands should be added to run a :class:`~lightning.pytorch.trainer.trainer.Trainer`
method. If set to ``False``, the trainer and model classes will be instantiated only.
auto_configure_optimizers: Whether to automatically add default optimizer and lr_scheduler arguments.
load_from_checkpoint_support: Whether ``save_hyperparameters`` should save the original parsed
hyperparameters (instead of what ``__init__`` receives), such that it is possible for
``load_from_checkpoint`` to correctly instantiate classes even when using complex nesting and
dependency injection.

"""
self.save_config_callback = save_config_callback
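
A minimal usage sketch for the two flags documented above, assuming the demo classes shipped with Lightning; the checkpoint path is a placeholder:

```python
# Sketch, not part of the diff: passing the new constructor flag through the public API.
from lightning.pytorch.cli import LightningCLI
from lightning.pytorch.demos.boring_classes import BoringDataModule, DemoModel

cli = LightningCLI(
    DemoModel,
    BoringDataModule,
    run=False,                          # instantiate only; call trainer methods manually
    auto_configure_optimizers=True,     # default: add optimizer/lr_scheduler arguments
    load_from_checkpoint_support=True,  # default: save parsed hparams; False restores old behavior
)
cli.trainer.fit(cli.model, datamodule=cli.datamodule)

# Because the parsed hyperparameters were saved, a checkpoint written during fit can be
# restored even when the config used nested init_args or instantiation links:
model = DemoModel.load_from_checkpoint("path/to/checkpoint.ckpt")  # placeholder path
```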
@@ -396,7 +402,8 @@ def __init__(

self._set_seed()

self._add_instantiators()
if load_from_checkpoint_support:
self._add_instantiators()
self.before_instantiate_classes()
self.instantiate_classes()
self.after_instantiate_classes()
@@ -544,11 +551,14 @@ def parse_arguments(self, parser: LightningArgumentParser, args: ArgsType) -> No
else:
self.config = parser.parse_args(args)

def _add_instantiators(self) -> None:
def _dump_config(self) -> None:
if hasattr(self, "config_dump"):
return
self.config_dump = yaml.safe_load(self.parser.dump(self.config, skip_link_targets=False, skip_none=False))
if "subcommand" in self.config:
self.config_dump = self.config_dump[self.config.subcommand]

def _add_instantiators(self) -> None:
self.parser.add_instantiator(
_InstantiatorFn(cli=self, key="model"),
_get_module_type(self._model_class),
@@ -799,12 +809,27 @@ def _get_module_type(value: Union[Callable, type]) -> type:
return value


def _set_dict_nested(data: dict, key: str, value: Any) -> None:
keys = key.split(".")
for k in keys[:-1]:
assert k in data, f"Expected key {key} to be in data"
data = data[k]
data[keys[-1]] = value


class _InstantiatorFn:
def __init__(self, cli: LightningCLI, key: str) -> None:
self.cli = cli
self.key = key

def __call__(self, class_type: type[ModuleType], *args: Any, **kwargs: Any) -> ModuleType:
def __call__(
self,
class_type: type[ModuleType],
*args: Any,
applied_instantiation_links: dict,
**kwargs: Any,
) -> ModuleType:
self.cli._dump_config()
hparams = self.cli.config_dump.get(self.key, {})
if "class_path" in hparams:
# To make hparams backwards compatible, and so that it is the same irrespective of subclass_mode, the
@@ -815,6 +840,15 @@ def __call__(self, class_type: type[ModuleType], *args: Any, **kwargs: Any) -> M
**hparams.get("init_args", {}),
**hparams.get("dict_kwargs", {}),
}
# get instantiation link target values from kwargs
for key, value in applied_instantiation_links.items():
if not key.startswith(f"{self.key}."):
continue
key = key[len(f"{self.key}.") :]
if key.startswith("init_args."):
key = key[len("init_args.") :]
_set_dict_nested(hparams, key, value)

with _given_hyperparameters_context(
hparams=hparams,
instantiator="lightning.pytorch.cli.instantiate_module",
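Net effect of the instantiator changes above, sketched with the values from the new deep-link test below: the `model.` and `init_args.` prefixes of an applied instantiation link are stripped and the linked value is written into the dumped hparams before `save_hyperparameters` records them. Plain-Python illustration, with the nested-set step inlined instead of calling the private `_set_dict_nested`:

```python
# Illustration only -- values mirror test_lightning_cli_link_arguments_subcommands_nested_target.
applied_instantiation_links = {"model.init_args.optimizer.init_args.num_classes": 5}
hparams = {"optimizer": {"class_path": "tests_pytorch.test_cli.CustomAdam", "init_args": {}}}

for key, value in applied_instantiation_links.items():
    if not key.startswith("model."):
        continue
    key = key[len("model."):]          # -> "init_args.optimizer.init_args.num_classes"
    if key.startswith("init_args."):
        key = key[len("init_args."):]  # -> "optimizer.init_args.num_classes"
    *parents, last = key.split(".")    # equivalent of _set_dict_nested(hparams, key, value)
    target = hparams
    for part in parents:
        target = target[part]
    target[last] = value

# hparams.yaml will therefore contain the injected value:
assert hparams["optimizer"]["init_args"]["num_classes"] == 5
```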
79 changes: 74 additions & 5 deletions tests/tests_pytorch/test_cli.py
@@ -550,6 +550,7 @@ def __init__(self, activation: torch.nn.Module = None, transform: Optional[list[
class BoringModelRequiredClasses(BoringModel):
def __init__(self, num_classes: int, batch_size: int = 8):
super().__init__()
self.save_hyperparameters()
self.num_classes = num_classes
self.batch_size = batch_size

@@ -561,35 +562,103 @@ def __init__(self, batch_size: int = 8):
self.num_classes = 5 # only available after instantiation


def test_lightning_cli_link_arguments():
def test_lightning_cli_link_arguments(cleandir):
class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.link_arguments("data.batch_size", "model.batch_size")
parser.link_arguments("data.num_classes", "model.num_classes", apply_on="instantiate")

cli_args = ["--data.batch_size=12"]
cli_args = ["--data.batch_size=12", "--trainer.max_epochs=1"]

with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = MyLightningCLI(BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, run=False)

assert cli.model.batch_size == 12
assert cli.model.num_classes == 5

class MyLightningCLI(LightningCLI):
cli.trainer.fit(cli.model)
hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
assert hparams_path.is_file()
hparams = yaml.safe_load(hparams_path.read_text())

hparams.pop("_instantiator")
assert hparams == {"batch_size": 12, "num_classes": 5}

class MyLightningCLI2(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.link_arguments("data.batch_size", "model.init_args.batch_size")
parser.link_arguments("data.num_classes", "model.init_args.num_classes", apply_on="instantiate")

cli_args[-1] = "--model=tests_pytorch.test_cli.BoringModelRequiredClasses"
cli_args[0] = "--model=tests_pytorch.test_cli.BoringModelRequiredClasses"

with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = MyLightningCLI(
cli = MyLightningCLI2(
BoringModelRequiredClasses, BoringDataModuleBatchSizeAndClasses, subclass_mode_model=True, run=False
)

assert cli.model.batch_size == 8
assert cli.model.num_classes == 5

cli.trainer.fit(cli.model)
hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
assert hparams_path.is_file()
hparams = yaml.safe_load(hparams_path.read_text())

hparams.pop("_instantiator")
assert hparams == {"batch_size": 8, "num_classes": 5}


class CustomAdam(torch.optim.Adam):
def __init__(self, params, num_classes: Optional[int] = None, **kwargs):
super().__init__(params, **kwargs)


class DeepLinkTargetModel(BoringModel):
def __init__(
self,
optimizer: OptimizerCallable = torch.optim.Adam,
):
super().__init__()
self.save_hyperparameters()
self.optimizer = optimizer

def configure_optimizers(self):
optimizer = self.optimizer(self.parameters())
return {"optimizer": optimizer}


def test_lightning_cli_link_arguments_subcommands_nested_target(cleandir):
class MyLightningCLI(LightningCLI):
def add_arguments_to_parser(self, parser):
parser.link_arguments(
"data.num_classes",
"model.init_args.optimizer.init_args.num_classes",
apply_on="instantiate",
)

cli_args = [
"fit",
"--data.batch_size=12",
"--trainer.max_epochs=1",
"--model=tests_pytorch.test_cli.DeepLinkTargetModel",
"--model.optimizer=tests_pytorch.test_cli.CustomAdam",
]

with mock.patch("sys.argv", ["any.py"] + cli_args):
cli = MyLightningCLI(
DeepLinkTargetModel,
BoringDataModuleBatchSizeAndClasses,
subclass_mode_model=True,
auto_configure_optimizers=False,
)

hparams_path = Path(cli.trainer.log_dir) / "hparams.yaml"
assert hparams_path.is_file()
hparams = yaml.safe_load(hparams_path.read_text())

assert hparams["optimizer"]["class_path"] == "tests_pytorch.test_cli.CustomAdam"
assert hparams["optimizer"]["init_args"]["num_classes"] == 5


class EarlyExitTestModel(BoringModel):
def on_fit_start(self):