From 7a0fed1e253b4e19ed2c2ef7ebb9d339428bb8b7 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 18 Feb 2022 12:58:16 +0100 Subject: [PATCH 1/4] Restore test --- pytorch_lightning/utilities/cli.py | 6 +++++- tests/utilities/test_cli.py | 19 +++++-------------- 2 files changed, 10 insertions(+), 15 deletions(-) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index a9c06febbb9b4..0332679c0741a 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -447,7 +447,11 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[st def __reduce__(self) -> Tuple[Type["SaveConfigCallback"], Tuple, Dict]: # `ArgumentParser` is un-pickleable. Drop it - return self.__class__, (None, self.config, self.config_filename), {} + return ( + self.__class__, + (None, self.config, self.config_filename), + {"overwrite": self.overwrite, "multifile": self.multifile}, + ) class LightningCLI: diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 2803c0c4601c1..250fed33ecb45 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -52,7 +52,6 @@ from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE from tests.helpers import BoringDataModule, BoringModel -from tests.helpers.runif import RunIf from tests.helpers.utils import no_warning_call torchvision_version = version.parse("0") @@ -577,18 +576,8 @@ def on_fit_start(self): @pytest.mark.parametrize("logger", (False, True)) -@pytest.mark.parametrize( - "trainer_kwargs", - ( - # dict(strategy="ddp_spawn") - # dict(strategy="ddp") - # the previous accl_conn will choose singleDeviceStrategy for both strategy=ddp/ddp_spawn - # TODO revisit this test as it never worked with DDP or DDPSpawn - dict(strategy="single_device"), - pytest.param({"tpu_cores": 1}, marks=RunIf(tpu=True)), - ), -) -def 
test_cli_distributed_save_config_callback(tmpdir, logger, trainer_kwargs): +@pytest.mark.parametrize("strategy", ("ddp_spawn", "ddp")) +def test_cli_distributed_save_config_callback(tmpdir, logger, strategy): with mock.patch("sys.argv", ["any.py", "fit"]), pytest.raises( MisconfigurationException, match=r"Error on fit start" ): @@ -599,7 +588,9 @@ def test_cli_distributed_save_config_callback(tmpdir, logger, trainer_kwargs): "logger": logger, "max_steps": 1, "max_epochs": 1, - **trainer_kwargs, + "strategy": strategy, + "accelerator": "cpu", + "devices": 1, }, ) if logger: From 068fcc9da99c93920f58b84361884d8b6ee67cbd Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 18 Feb 2022 14:53:11 +0100 Subject: [PATCH 2/4] Introduce on_before_launch --- pytorch_lightning/callbacks/base.py | 3 +++ pytorch_lightning/trainer/trainer.py | 3 +++ pytorch_lightning/utilities/cli.py | 16 ++++++++++++---- tests/utilities/test_cli.py | 3 ++- 4 files changed, 20 insertions(+), 5 deletions(-) diff --git a/pytorch_lightning/callbacks/base.py b/pytorch_lightning/callbacks/base.py index a24fef72e5b36..017d447c94702 100644 --- a/pytorch_lightning/callbacks/base.py +++ b/pytorch_lightning/callbacks/base.py @@ -325,3 +325,6 @@ def on_before_optimizer_step( def on_before_zero_grad(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", optimizer: Optimizer) -> None: """Called before ``optimizer.zero_grad()``.""" + + def _on_before_launch(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None: + """For internal use.""" diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 6ed5d6c31f719..26c61e0f885da 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -670,6 +670,9 @@ def _call_and_handle_interrupt(self, trainer_fn: Callable, *args: Any, **kwargs: """ try: if self.strategy.launcher is not None: + # used in the CLI's `SaveConfigCallback` so that it still works with DDP spawn + 
self._call_callback_hooks("_on_before_launch") + return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs) else: return trainer_fn(*args, **kwargs) diff --git a/pytorch_lightning/utilities/cli.py b/pytorch_lightning/utilities/cli.py index 0332679c0741a..517c8d3146b76 100644 --- a/pytorch_lightning/utilities/cli.py +++ b/pytorch_lightning/utilities/cli.py @@ -27,6 +27,7 @@ import pytorch_lightning as pl from pytorch_lightning import Callback, LightningDataModule, LightningModule, seed_everything, Trainer +from pytorch_lightning.strategies import DDPSpawnStrategy from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _JSONARGPARSE_AVAILABLE @@ -414,9 +415,7 @@ def __init__( self.overwrite = overwrite self.multifile = multifile - def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: - # save the config in `setup` because (1) we want it to save regardless of the trainer function run - # and we want to save before processes are spawned + def save(self, trainer: "pl.Trainer") -> None: log_dir = trainer.log_dir # this broadcasts the directory assert log_dir is not None config_path = os.path.join(log_dir, self.config_filename) @@ -445,8 +444,17 @@ def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[st self.config, config_path, skip_none=False, overwrite=self.overwrite, multifile=self.multifile ) + def _on_before_launch(self, trainer: "pl.Trainer", *_: Any) -> None: + if isinstance(trainer.strategy, DDPSpawnStrategy): + self.save(trainer) + + def setup(self, trainer: Trainer, pl_module: LightningModule, stage: Optional[str] = None) -> None: + if not isinstance(trainer.strategy, DDPSpawnStrategy): + self.save(trainer) + def __reduce__(self) -> Tuple[Type["SaveConfigCallback"], Tuple, Dict]: - # `ArgumentParser` is un-pickleable. 
Drop it + # `ArgumentParser` is un-pickleable. Since we will be dropping it when DDP spawn is used, DDP spawn needs to + # save the config file before this is called, and for this, we use an internal hook `_on_before_launch`. return ( self.__class__, (None, self.config, self.config_filename), diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 250fed33ecb45..499ee49610db7 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -28,6 +28,7 @@ import torch import yaml from packaging import version +from torch.multiprocessing import ProcessRaisedException from torch.optim import SGD from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR @@ -579,7 +580,7 @@ def on_fit_start(self): @pytest.mark.parametrize("strategy", ("ddp_spawn", "ddp")) def test_cli_distributed_save_config_callback(tmpdir, logger, strategy): with mock.patch("sys.argv", ["any.py", "fit"]), pytest.raises( - MisconfigurationException, match=r"Error on fit start" + (MisconfigurationException, ProcessRaisedException), match=r"Error on fit start" ): LightningCLI( EarlyExitTestModel, From cee339423a2ab050fd558c92c7671106fd2f5ef1 Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 18 Feb 2022 14:55:24 +0100 Subject: [PATCH 3/4] Auto --- tests/utilities/test_cli.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 499ee49610db7..97487bab23783 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -590,7 +590,7 @@ def test_cli_distributed_save_config_callback(tmpdir, logger, strategy): "max_steps": 1, "max_epochs": 1, "strategy": strategy, - "accelerator": "cpu", + "accelerator": "auto", "devices": 1, }, ) From a49e75a9a395eb75d7ce973e5de3e842a002231c Mon Sep 17 00:00:00 2001 From: Carlos Mocholi Date: Fri, 18 Feb 2022 15:11:16 +0100 Subject: [PATCH 4/4] Fix import --- tests/utilities/test_cli.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2
deletions(-) diff --git a/tests/utilities/test_cli.py b/tests/utilities/test_cli.py index 97487bab23783..5b37adfa1398b 100644 --- a/tests/utilities/test_cli.py +++ b/tests/utilities/test_cli.py @@ -28,7 +28,6 @@ import torch import yaml from packaging import version -from torch.multiprocessing import ProcessRaisedException from torch.optim import SGD from torch.optim.lr_scheduler import ReduceLROnPlateau, StepLR @@ -51,7 +50,7 @@ SaveConfigCallback, ) from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.utilities.imports import _TORCHVISION_AVAILABLE +from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCHVISION_AVAILABLE from tests.helpers import BoringDataModule, BoringModel from tests.helpers.utils import no_warning_call @@ -579,6 +578,11 @@ def on_fit_start(self): @pytest.mark.parametrize("logger", (False, True)) @pytest.mark.parametrize("strategy", ("ddp_spawn", "ddp")) def test_cli_distributed_save_config_callback(tmpdir, logger, strategy): + if _TORCH_GREATER_EQUAL_1_8: + from torch.multiprocessing import ProcessRaisedException + else: + ProcessRaisedException = Exception + with mock.patch("sys.argv", ["any.py", "fit"]), pytest.raises( (MisconfigurationException, ProcessRaisedException), match=r"Error on fit start" ):