3 changes: 3 additions & 0 deletions pytorch_lightning/callbacks/base.py
@@ -325,3 +325,6 @@ def on_before_optimizer_step(

     def on_before_zero_grad(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer) -> None:
         """Called before ``optimizer.zero_grad()``."""
+
+    def _on_before_launch(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+        """For internal use."""
3 changes: 3 additions & 0 deletions pytorch_lightning/trainer/trainer.py
@@ -670,6 +670,9 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs:
"""
try:
if self.strategy.launcher is not None:
# used in the CLI's `SaveConfigCallback` so that it still works with DDP spawn
self._call_callback_hooks("_on_before_launch")

return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
else:
return trainer_fn(*args, **kwargs)
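For context on why the hook must fire before `launcher.launch()`: spawn-based launchers pickle the trainer and its callbacks to ship them to worker processes, and argument-parser objects generally do not survive that. A hedged, standalone illustration with stdlib `argparse` (jsonargparse parsers used by the CLI behave similarly):

import argparse
import pickle

parser = argparse.ArgumentParser()
parser.add_argument("--lr", type=float, default=1e-3)

try:
    pickle.dumps(parser)
except Exception as err:
    # argparse parsers typically hold locally defined functions and cannot be pickled,
    # which is why the config has to be written before the spawn happens
    print(f"cannot pickle the parser: {err!r}")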
22 changes: 17 additions & 5 deletions pytorch_lightning/utilities/cli.py
@@ -27,6 +27,7 @@

 import pytorch_lightning as pl
 from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer
+from pytorch_lightning.strategies import DDPSpawnStrategy
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _JSONARGPARSE_AVAILABLE
@@ -414,9 +415,7 @@ def __init__(
         self.overwrite = overwrite
         self.multifile = multifile

-    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
-        # save the config in `setup` because (1) we want it to save regardless of the trainer function run
-        # and we want to save before processes are spawned
+    def save(self, trainer: "pl.Trainer") -> None:
         log_dir = trainer.log_dir  # this broadcasts the directory
         assert log_dir is not None
         config_path = os.path.join(log_dir, self.config_filename)
@@ -445,9 +444,22 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[st
             self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile
         )

+    def _on_before_launch(self, trainer: "pl.Trainer", *_: Any) -> None:
+        if isinstance(trainer.strategy, DDPSpawnStrategy):
+            self.save(trainer)
+
+    def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None:
+        if not isinstance(trainer.strategy, DDPSpawnStrategy):
+            self.save(trainer)
+
     def __reduce__(self) -> Tuple[Type["SaveConfigCallback"], Tuple, Dict]:
-        # `ArgumentParser` is un-pickleable. Drop it
-        return self.__class__, (None, self.config, self.config_filename), {}
+        # `ArgumentParser` is un-pickleable. Since we will be dropping it when DDP spawn is used, DDP spawn needs to
+        # save the config file before this is called, and for this, we use an internal hook `on_before_launch`.
+        return (
+            self.__class__,
+            (None, self.config, self.config_filename),
+            {"overwrite": self.overwrite, "multifile": self.multifile},
+        )


 class LightningCLI:
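To make the `__reduce__` change concrete, here is a hedged sketch (not part of the diff) of a pickle round-trip. It assumes `parser` and `config` come from an already constructed `LightningCLI`, and that the instance attribute names mirror the constructor arguments shown above.

import pickle

from pytorch_lightning.utilities.cli import SaveConfigCallback

# `parser` and `config` are assumed to come from a constructed LightningCLI
callback = SaveConfigCallback(parser, config, "config.yaml", overwrite=True, multifile=False)
clone = pickle.loads(pickle.dumps(callback))

assert clone.parser is None                    # the un-picklable parser is dropped
assert clone.overwrite and not clone.multifile  # these options now survive pickling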
28 changes: 12 additions & 16 deletions tests/utilities/test_cli.py
@@ -50,9 +50,8 @@
     SaveConfigCallback,
 )
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCHVISION_AVAILABLE
 from tests.helpers import BoringDataModule, BoringModel
-from tests.helpers.runif import RunIf
 from tests.helpers.utils import no_warning_call

 torchvision_version = version.parse("0")
@@ -577,20 +576,15 @@ def on_fit_start(self):


@pytest.mark.parametrize("logger", (False, True))
@pytest.mark.parametrize(
"trainer_kwargs",
(
# dict(strategy="ddp_spawn")
# dict(strategy="ddp")
# the previous accl_conn will choose singleDeviceStrategy for both strategy=ddp/ddp_spawn
# TODO revisit this test as it never worked with DDP or DDPSpawn
dict(strategy="single_device"),
pytest.param({"tpu_cores": 1}, marks=RunIf(tpu=True)),
),
)
def test_cli_distributed_save_config_callback(tmpdir, logger, trainer_kwargs):
@pytest.mark.parametrize("strategy", ("ddp_spawn", "ddp"))
def test_cli_distributed_save_config_callback(tmpdir, logger, strategy):
if _TORCH_GREATER_EQUAL_1_8:
from torch.multiprocessing import ProcessRaisedException
else:
ProcessRaisedException = Exception

with mock.patch("sys.argv", ["any.py", "fit"]), pytest.raises(
MisconfigurationException, match=r"Error on fit start"
(MisconfigurationException, ProcessRaisedException), match=r"Error on fit start"
):
LightningCLI(
EarlyExitTestModel,
@@ -599,7 +593,9 @@ def test_cli_distributed_save_config_callback(tmpdir, logger, trainer_kwargs):
"logger": logger,
"max_steps": 1,
"max_epochs": 1,
**trainer_kwargs,
"strategy": strategy,
"accelerator": "auto",
"devices": 1,
},
)
if logger:
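Finally, a hedged end-to-end sketch (not from the PR) of the behavior this fix enables: with a spawn-based strategy the config file is still written to the log directory, because it is saved in the main process before the workers start. `BoringModel` comes from Lightning's own test helpers, as in the test above; strategy, devices, and batch limits are illustrative.

import os
from unittest import mock

from pytorch_lightning.utilities.cli import LightningCLI
from tests.helpers import BoringModel

with mock.patch("sys.argv", ["any.py", "fit"]):
    cli = LightningCLI(
        BoringModel,
        trainer_defaults={
            "strategy": "ddp_spawn",
            "accelerator": "cpu",
            "devices": 2,
            "max_epochs": 1,
            "limit_train_batches": 1,
            "limit_val_batches": 1,
        },
    )

# config.yaml (the default config_filename) was written before the spawn
assert os.path.isfile(os.path.join(cli.trainer.log_dir, "config.yaml"))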