From 9d633749965c8bb0bc9e40b9284cb99380121154 Mon Sep 17 00:00:00 2001 From: donlapark Date: Wed, 27 Jul 2022 02:02:50 +0700 Subject: [PATCH 01/17] fixes mypy errors in strategies/ddp_spawn.py --- src/pytorch_lightning/strategies/ddp_spawn.py | 65 ++++++++++++------- 1 file changed, 43 insertions(+), 22 deletions(-) diff --git a/src/pytorch_lightning/strategies/ddp_spawn.py b/src/pytorch_lightning/strategies/ddp_spawn.py index fdb0a7d851169..b9546f9d4238c 100644 --- a/src/pytorch_lightning/strategies/ddp_spawn.py +++ b/src/pytorch_lightning/strategies/ddp_spawn.py @@ -14,7 +14,7 @@ import logging import os from datetime import timedelta -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union import torch import torch.distributed @@ -26,6 +26,7 @@ import pytorch_lightning as pl from pytorch_lightning.overrides import LightningDistributedModule +from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO @@ -75,8 +76,8 @@ def __init__( checkpoint_io: Optional[CheckpointIO] = None, precision_plugin: Optional[PrecisionPlugin] = None, ddp_comm_state: Optional[object] = None, - ddp_comm_hook: Optional[callable] = None, - ddp_comm_wrapper: Optional[callable] = None, + ddp_comm_hook: Optional[Callable] = None, + ddp_comm_wrapper: Optional[Callable] = None, process_group_backend: Optional[str] = None, timeout: Optional[timedelta] = default_pg_timeout, start_method: Literal["spawn", "fork", "forkserver"] = "spawn", @@ -113,32 +114,36 @@ def local_rank(self) -> int: return self._local_rank @property - def root_device(self): + def root_device(self) -> torch.device: + assert self.parallel_devices is not None return self.parallel_devices[self.local_rank] @property - def num_processes(self): + def num_processes(self) -> int: return len(self.parallel_devices) if self.parallel_devices is not None else 0 @property - def distributed_sampler_kwargs(self): + def distributed_sampler_kwargs(self) -> Dict[str, int]: distributed_sampler_kwargs = dict(num_replicas=(self.num_nodes * self.num_processes), rank=self.global_rank) return distributed_sampler_kwargs @property - def _is_single_process_single_device(self): + def _is_single_process_single_device(self) -> bool: return True @property def process_group_backend(self) -> Optional[str]: return self._process_group_backend - def _configure_launcher(self): + def _configure_launcher(self) -> None: self._launcher = _MultiProcessingLauncher(self, start_method=self._start_method) def setup(self, trainer: "pl.Trainer") -> None: + + assert self.cluster_environment is not None os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) - + + assert self.accelerator is not None self.accelerator.setup(trainer) # move the model to the correct device @@ -148,6 +153,7 @@ def setup(self, trainer: "pl.Trainer") -> None: trainer_fn = trainer.state.fn if trainer_fn == TrainerFn.FITTING: if self._layer_sync: + assert self.model is not None self.model = self._layer_sync.apply(self.model) self.setup_precision_plugin() @@ -167,11 +173,12 @@ def set_world_ranks(self, process_idx: int = 0) -> None: self.cluster_environment.set_world_size(self.num_nodes * self.num_processes) rank_zero_only.rank = self.cluster_environment.global_rank() - def _worker_setup(self, 
process_idx: int): + def _worker_setup(self, process_idx: int) -> None: reset_seed() self.set_world_ranks(process_idx) rank_zero_only.rank = self.global_rank self._process_group_backend = self._get_process_group_backend() + assert self.cluster_environment is not None init_dist_connection( self.cluster_environment, self._process_group_backend, @@ -187,7 +194,7 @@ def _get_process_group_backend(self) -> str: or get_default_process_group_backend_for_device(self.root_device) ) - def pre_configure_ddp(self): + def pre_configure_ddp(self) -> None: # if unset, default `find_unused_parameters` `True` # Many models require setting this parameter to True, as there are corner cases # when not all parameter backward hooks are fired by the autograd engine even if require_grad is set to True. @@ -198,6 +205,7 @@ def _register_ddp_hooks(self) -> None: # currently, DDP communication hooks only work with NCCL backend and SPSD (single process single device) mode # https://github.com/pytorch/pytorch/blob/v1.8.0/torch/nn/parallel/distributed.py#L1080-L1084 if self.root_device.type == "cuda" and self._is_single_process_single_device: + assert isinstance(self.model, DistributedDataParallel) register_ddp_comm_hook( model=self.model, ddp_comm_state=self._ddp_comm_state, @@ -207,19 +215,21 @@ def _register_ddp_hooks(self) -> None: def configure_ddp(self) -> None: self.pre_configure_ddp() + assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) self.model = self._setup_model(LightningDistributedModule(self.model)) self._register_ddp_hooks() # set up optimizers after the wrapped module has been moved to the device + assert self.lightning_module is not None self.setup_optimizers(self.lightning_module.trainer) optimizers_to_device(self.optimizers, self.root_device) - def determine_ddp_device_ids(self): + def determine_ddp_device_ids(self) -> Optional[List[Union[int, torch.device]]]: if self.root_device.type == "cpu": return None return [self.root_device.index] - def barrier(self, *args, **kwargs) -> None: + def barrier(self, *args: Any, **kwargs: Any) -> None: if not distributed_available(): return if torch.distributed.get_backend() == "nccl": @@ -227,7 +237,7 @@ def barrier(self, *args, **kwargs) -> None: else: torch.distributed.barrier() - def broadcast(self, obj: object, src: int = 0) -> object: + def broadcast(self, obj: object, src: int = 0) -> object: # type: ignore[override] if not distributed_available(): return obj obj = [obj] @@ -236,18 +246,21 @@ def broadcast(self, obj: object, src: int = 0) -> object: torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD) return obj[0] - def model_to_device(self): + def model_to_device(self) -> None: if self.root_device.type == "cuda": # set the device on the spawned subprocesses torch.cuda.set_device(self.root_device) + assert self.model is not None self.model.to(self.root_device) def pre_backward(self, closure_loss: Tensor) -> None: """Run before precision plugin executes backward.""" + assert self.lightning_module is not None if not self.lightning_module.automatic_optimization: + assert isinstance(self.model, DistributedDataParallel) prepare_for_backward(self.model, closure_loss) - def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, str] = "mean") -> Tensor: + def reduce(self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean") -> Tensor: """Reduces a tensor from several distributed processes to one aggregated tensor. 
Args: @@ -263,12 +276,15 @@ def reduce(self, tensor, group: Optional[Any] = None, reduce_op: Union[ReduceOp, tensor = sync_ddp_if_available(tensor, group, reduce_op=reduce_op) return tensor - def training_step(self, *args, **kwargs) -> STEP_OUTPUT: + def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: + assert self.model is not None with self.precision_plugin.train_step_context(): return self.model(*args, **kwargs) - def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: + def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: with self.precision_plugin.val_step_context(): + assert self.lightning_module is not None + assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) if self.lightning_module.trainer.state.fn == TrainerFn.FITTING: # used when calling `trainer.fit` return self.model(*args, **kwargs) @@ -276,17 +292,21 @@ def validation_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: # used when calling `trainer.validate` return self.model.validation_step(*args, **kwargs) - def test_step(self, *args, **kwargs) -> Optional[STEP_OUTPUT]: + def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: + assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) with self.precision_plugin.test_step_context(): return self.model.test_step(*args, **kwargs) - def predict_step(self, *args, **kwargs) -> STEP_OUTPUT: + def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: + assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) with self.precision_plugin.predict_step_context(): return self.model.predict_step(*args, **kwargs) - def post_training_step(self): + def post_training_step(self) -> None: + assert self.lightning_module is not None if not self.lightning_module.automatic_optimization: - self.model.require_backward_grad_sync = True + assert self.model is not None + self.model.require_backward_grad_sync = Tensor(True) @classmethod def register_strategies(cls, strategy_registry: Dict) -> None: @@ -332,5 +352,6 @@ def teardown(self) -> None: and pl_module._trainer.state.fn == TrainerFn.FITTING and self._layer_sync ): + assert self.model is not None self.model = self._layer_sync.revert(self.model) super().teardown() From 4ff464c5818371242327bccc10713b38f2c3627a Mon Sep 17 00:00:00 2001 From: donlapark Date: Wed, 27 Jul 2022 02:10:17 +0700 Subject: [PATCH 02/17] fixes mypy errors in strategies/ddp_spawn.py --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 6d973aa0dde51..cbc06ba983634 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -60,7 +60,6 @@ module = [ "pytorch_lightning.profilers.pytorch", "pytorch_lightning.profilers.simple", "pytorch_lightning.strategies.ddp", - "pytorch_lightning.strategies.ddp_spawn", "pytorch_lightning.strategies.deepspeed", "pytorch_lightning.strategies.fully_sharded", "pytorch_lightning.strategies.ipu", From c63bac41d8dbd0d957f34efc737a086568e70f53 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 26 Jul 2022 19:16:17 +0000 Subject: [PATCH 03/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/strategies/ddp_spawn.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/src/pytorch_lightning/strategies/ddp_spawn.py b/src/pytorch_lightning/strategies/ddp_spawn.py index 
b9546f9d4238c..f0dd42ef28a78 100644 --- a/src/pytorch_lightning/strategies/ddp_spawn.py +++ b/src/pytorch_lightning/strategies/ddp_spawn.py @@ -139,10 +139,10 @@ def _configure_launcher(self) -> None: self._launcher = _MultiProcessingLauncher(self, start_method=self._start_method) def setup(self, trainer: "pl.Trainer") -> None: - + assert self.cluster_environment is not None os.environ["MASTER_PORT"] = str(self.cluster_environment.main_port) - + assert self.accelerator is not None self.accelerator.setup(trainer) @@ -260,7 +260,9 @@ def pre_backward(self, closure_loss: Tensor) -> None: assert isinstance(self.model, DistributedDataParallel) prepare_for_backward(self.model, closure_loss) - def reduce(self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean") -> Tensor: + def reduce( + self, tensor: Tensor, group: Optional[Any] = None, reduce_op: Optional[Union[ReduceOp, str]] = "mean" + ) -> Tensor: """Reduces a tensor from several distributed processes to one aggregated tensor. Args: From 351b981ef25ecd2fb5035d0d48a89637afe0c9cf Mon Sep 17 00:00:00 2001 From: donlapark Date: Wed, 27 Jul 2022 02:29:24 +0700 Subject: [PATCH 04/17] fixes mypy errors in strategies/ddp_spawn.py --- src/pytorch_lightning/strategies/ddp_spawn.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/pytorch_lightning/strategies/ddp_spawn.py b/src/pytorch_lightning/strategies/ddp_spawn.py index b9546f9d4238c..d5f532f364967 100644 --- a/src/pytorch_lightning/strategies/ddp_spawn.py +++ b/src/pytorch_lightning/strategies/ddp_spawn.py @@ -332,6 +332,7 @@ def teardown(self) -> None: pl_module = self.lightning_module if isinstance(self.model, DistributedDataParallel): + assert callable(self.model._get_ddp_logging_data) if ( _TORCH_GREATER_EQUAL_1_11 and not self.model.static_graph From 62a8cbcbe0376db4b46bdd7a2917b58f12bcb4a2 Mon Sep 17 00:00:00 2001 From: donlapark Date: Wed, 27 Jul 2022 03:15:02 +0700 Subject: [PATCH 05/17] adds check if is --- src/pytorch_lightning/strategies/ddp_spawn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/strategies/ddp_spawn.py b/src/pytorch_lightning/strategies/ddp_spawn.py index 5926e18fe3078..8e0f7b38734df 100644 --- a/src/pytorch_lightning/strategies/ddp_spawn.py +++ b/src/pytorch_lightning/strategies/ddp_spawn.py @@ -286,7 +286,8 @@ def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: with self.precision_plugin.val_step_context(): assert self.lightning_module is not None - assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) + assert self.model is not None + assert callable(self.model.validation) if self.lightning_module.trainer.state.fn == TrainerFn.FITTING: # used when calling `trainer.fit` return self.model(*args, **kwargs) From ae0167392a67eb5fa46da53f403196a93dc073d7 Mon Sep 17 00:00:00 2001 From: donlapark Date: Wed, 27 Jul 2022 03:29:38 +0700 Subject: [PATCH 06/17] adds check if self.model.validation_step is callable --- src/pytorch_lightning/strategies/ddp_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/strategies/ddp_spawn.py b/src/pytorch_lightning/strategies/ddp_spawn.py index 8e0f7b38734df..5f43bf9f620ff 100644 --- a/src/pytorch_lightning/strategies/ddp_spawn.py +++ b/src/pytorch_lightning/strategies/ddp_spawn.py @@ -287,7 +287,7 @@ def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: 
with self.precision_plugin.val_step_context(): assert self.lightning_module is not None assert self.model is not None - assert callable(self.model.validation) + assert callable(self.model.validation_step) if self.lightning_module.trainer.state.fn == TrainerFn.FITTING: # used when calling `trainer.fit` return self.model(*args, **kwargs) From 6f4691048c5797a7b6876ef0c155fb83712a62b2 Mon Sep 17 00:00:00 2001 From: donlapark Date: Wed, 27 Jul 2022 11:48:16 +0700 Subject: [PATCH 07/17] asserts that type of self.model is utilities.types.DistributedDataParallel --- src/pytorch_lightning/strategies/ddp_spawn.py | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/pytorch_lightning/strategies/ddp_spawn.py b/src/pytorch_lightning/strategies/ddp_spawn.py index 5f43bf9f620ff..a84924e0a0baa 100644 --- a/src/pytorch_lightning/strategies/ddp_spawn.py +++ b/src/pytorch_lightning/strategies/ddp_spawn.py @@ -215,7 +215,7 @@ def _register_ddp_hooks(self) -> None: def configure_ddp(self) -> None: self.pre_configure_ddp() - assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) + assert isinstance(self.model, pl.utilities.types.DistributedDataParallel) self.model = self._setup_model(LightningDistributedModule(self.model)) self._register_ddp_hooks() @@ -286,8 +286,7 @@ def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: with self.precision_plugin.val_step_context(): assert self.lightning_module is not None - assert self.model is not None - assert callable(self.model.validation_step) + assert isinstance(self.model, pl.utilities.types.DistributedDataParallel) if self.lightning_module.trainer.state.fn == TrainerFn.FITTING: # used when calling `trainer.fit` return self.model(*args, **kwargs) @@ -296,12 +295,12 @@ def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: return self.model.validation_step(*args, **kwargs) def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: - assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) + assert isinstance(self.model, pl.utilities.types.DistributedDataParallel) with self.precision_plugin.test_step_context(): return self.model.test_step(*args, **kwargs) def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: - assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) + assert isinstance(self.model, pl.utilities.types.DistributedDataParallel) with self.precision_plugin.predict_step_context(): return self.model.predict_step(*args, **kwargs) From 9293bba0717c532d7490f9337fcadd7df26bfa22 Mon Sep 17 00:00:00 2001 From: donlapark <10988155+donlapark@users.noreply.github.com> Date: Wed, 27 Jul 2022 11:51:40 +0700 Subject: [PATCH 08/17] minor --- src/pytorch_lightning/strategies/ddp_spawn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pytorch_lightning/strategies/ddp_spawn.py b/src/pytorch_lightning/strategies/ddp_spawn.py index a84924e0a0baa..43e6d9db49021 100644 --- a/src/pytorch_lightning/strategies/ddp_spawn.py +++ b/src/pytorch_lightning/strategies/ddp_spawn.py @@ -26,7 +26,6 @@ import pytorch_lightning as pl from pytorch_lightning.overrides import LightningDistributedModule -from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.cluster_environment import 
ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO From b42dad9003ca7f36f38ca0d0dabe803f266f158b Mon Sep 17 00:00:00 2001 From: donlapark <10988155+donlapark@users.noreply.github.com> Date: Wed, 27 Jul 2022 12:12:46 +0700 Subject: [PATCH 09/17] minor --- src/pytorch_lightning/strategies/ddp_spawn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/strategies/ddp_spawn.py b/src/pytorch_lightning/strategies/ddp_spawn.py index 43e6d9db49021..44bc1436763c7 100644 --- a/src/pytorch_lightning/strategies/ddp_spawn.py +++ b/src/pytorch_lightning/strategies/ddp_spawn.py @@ -26,6 +26,7 @@ import pytorch_lightning as pl from pytorch_lightning.overrides import LightningDistributedModule +from pytorch_lightning.overrides.base import _LightningPrecisionModuleWrapperBase from pytorch_lightning.overrides.distributed import prepare_for_backward from pytorch_lightning.plugins.environments.cluster_environment import ClusterEnvironment from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO @@ -214,7 +215,7 @@ def _register_ddp_hooks(self) -> None: def configure_ddp(self) -> None: self.pre_configure_ddp() - assert isinstance(self.model, pl.utilities.types.DistributedDataParallel) + assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) self.model = self._setup_model(LightningDistributedModule(self.model)) self._register_ddp_hooks() From 7df1eb2f9f6092e533a6ad63136a93b9893a0b7b Mon Sep 17 00:00:00 2001 From: donlapark Date: Wed, 27 Jul 2022 12:31:12 +0700 Subject: [PATCH 10/17] asserts that type of self.model is utilities.types.DistributedDataParallel --- src/pytorch_lightning/strategies/ddp_spawn.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/strategies/ddp_spawn.py b/src/pytorch_lightning/strategies/ddp_spawn.py index 44bc1436763c7..4bd08ab093ac4 100644 --- a/src/pytorch_lightning/strategies/ddp_spawn.py +++ b/src/pytorch_lightning/strategies/ddp_spawn.py @@ -295,12 +295,12 @@ def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: return self.model.validation_step(*args, **kwargs) def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: - assert isinstance(self.model, pl.utilities.types.DistributedDataParallel) + assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) with self.precision_plugin.test_step_context(): return self.model.test_step(*args, **kwargs) def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: - assert isinstance(self.model, pl.utilities.types.DistributedDataParallel) + assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) with self.precision_plugin.predict_step_context(): return self.model.predict_step(*args, **kwargs) From 8081a1b666fe9890cbaf7d318b8a7d93237a331f Mon Sep 17 00:00:00 2001 From: donlapark Date: Wed, 27 Jul 2022 14:24:09 +0700 Subject: [PATCH 11/17] ignore the 'Tensor not callable' errors --- src/pytorch_lightning/strategies/ddp_spawn.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/src/pytorch_lightning/strategies/ddp_spawn.py b/src/pytorch_lightning/strategies/ddp_spawn.py index 4bd08ab093ac4..a69a3d65412d0 100644 --- a/src/pytorch_lightning/strategies/ddp_spawn.py +++ b/src/pytorch_lightning/strategies/ddp_spawn.py @@ -33,6 +33,7 @@ from pytorch_lightning.plugins.precision import PrecisionPlugin from 
pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher from pytorch_lightning.strategies.parallel import ParallelStrategy +from pytorch_lightning.strategies.strategy import TBroadcast from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.distributed import ( _get_process_group_backend_from_env, @@ -286,29 +287,29 @@ def training_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: with self.precision_plugin.val_step_context(): assert self.lightning_module is not None - assert isinstance(self.model, pl.utilities.types.DistributedDataParallel) + assert self.model is not None if self.lightning_module.trainer.state.fn == TrainerFn.FITTING: # used when calling `trainer.fit` return self.model(*args, **kwargs) else: # used when calling `trainer.validate` - return self.model.validation_step(*args, **kwargs) + return self.model.validation_step(*args, **kwargs) # type: ignore[operator] def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: - assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) + assert self.model is not None with self.precision_plugin.test_step_context(): - return self.model.test_step(*args, **kwargs) + return self.model.test_step(*args, **kwargs) # type: ignore[operator] def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: - assert isinstance(self.model, (pl.LightningModule, _LightningPrecisionModuleWrapperBase)) + assert self.model is not None with self.precision_plugin.predict_step_context(): - return self.model.predict_step(*args, **kwargs) + return self.model.predict_step(*args, **kwargs) # type: ignore[operator] def post_training_step(self) -> None: assert self.lightning_module is not None if not self.lightning_module.automatic_optimization: assert self.model is not None - self.model.require_backward_grad_sync = Tensor(True) + self.model.require_backward_grad_sync = Tensor([True]) @classmethod def register_strategies(cls, strategy_registry: Dict) -> None: @@ -334,11 +335,10 @@ def teardown(self) -> None: pl_module = self.lightning_module if isinstance(self.model, DistributedDataParallel): - assert callable(self.model._get_ddp_logging_data) if ( _TORCH_GREATER_EQUAL_1_11 and not self.model.static_graph - and self.model._get_ddp_logging_data().get("can_set_static_graph") + and self.model._get_ddp_logging_data().get("can_set_static_graph") # type: ignore[operator] ): rank_zero_info( "Your model can run with static graph optimizations. 
For future training runs, we suggest you" From 4bbadf8da58e366f0261ff5d015ce2c677f405c7 Mon Sep 17 00:00:00 2001 From: donlapark Date: Wed, 27 Jul 2022 14:28:44 +0700 Subject: [PATCH 12/17] fixes typo --- src/pytorch_lightning/strategies/ddp_spawn.py | 1 - 1 file changed, 1 deletion(-) diff --git a/src/pytorch_lightning/strategies/ddp_spawn.py b/src/pytorch_lightning/strategies/ddp_spawn.py index a69a3d65412d0..98f83cb3cdc51 100644 --- a/src/pytorch_lightning/strategies/ddp_spawn.py +++ b/src/pytorch_lightning/strategies/ddp_spawn.py @@ -33,7 +33,6 @@ from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.launchers.multiprocessing import _MultiProcessingLauncher from pytorch_lightning.strategies.parallel import ParallelStrategy -from pytorch_lightning.strategies.strategy import TBroadcast from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.distributed import ( _get_process_group_backend_from_env, From 25844411c20cf3136781330573d48bc210130b49 Mon Sep 17 00:00:00 2001 From: donlapark <10988155+donlapark@users.noreply.github.com> Date: Thu, 28 Jul 2022 00:30:30 +0700 Subject: [PATCH 13/17] Apply suggestions from code review MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- src/pytorch_lightning/strategies/ddp_spawn.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pytorch_lightning/strategies/ddp_spawn.py b/src/pytorch_lightning/strategies/ddp_spawn.py index 98f83cb3cdc51..3c73a4cc5f63c 100644 --- a/src/pytorch_lightning/strategies/ddp_spawn.py +++ b/src/pytorch_lightning/strategies/ddp_spawn.py @@ -224,7 +224,7 @@ def configure_ddp(self) -> None: self.setup_optimizers(self.lightning_module.trainer) optimizers_to_device(self.optimizers, self.root_device) - def determine_ddp_device_ids(self) -> Optional[List[Union[int, torch.device]]]: + def determine_ddp_device_ids(self) -> Optional[List[int]]: if self.root_device.type == "cpu": return None return [self.root_device.index] @@ -237,7 +237,7 @@ def barrier(self, *args: Any, **kwargs: Any) -> None: else: torch.distributed.barrier() - def broadcast(self, obj: object, src: int = 0) -> object: # type: ignore[override] + def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: if not distributed_available(): return obj obj = [obj] @@ -308,7 +308,7 @@ def post_training_step(self) -> None: assert self.lightning_module is not None if not self.lightning_module.automatic_optimization: assert self.model is not None - self.model.require_backward_grad_sync = Tensor([True]) + self.model.require_backward_grad_sync = True # type: ignore[assignment] @classmethod def register_strategies(cls, strategy_registry: Dict) -> None: From d2a5030d24890f3bc5d2dd783707fcb189c3680a Mon Sep 17 00:00:00 2001 From: donlapark <10988155+donlapark@users.noreply.github.com> Date: Thu, 28 Jul 2022 00:35:34 +0700 Subject: [PATCH 14/17] ignore error when assigning `obj = [None]` --- src/pytorch_lightning/strategies/ddp_spawn.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/strategies/ddp_spawn.py b/src/pytorch_lightning/strategies/ddp_spawn.py index 3c73a4cc5f63c..e7dec0e089700 100644 --- a/src/pytorch_lightning/strategies/ddp_spawn.py +++ b/src/pytorch_lightning/strategies/ddp_spawn.py @@ -33,6 +33,7 @@ from pytorch_lightning.plugins.precision import PrecisionPlugin from pytorch_lightning.strategies.launchers.multiprocessing import 
_MultiProcessingLauncher from pytorch_lightning.strategies.parallel import ParallelStrategy +from pytorch_lightning.strategies.strategy import TBroadcast from pytorch_lightning.trainer.states import TrainerFn from pytorch_lightning.utilities.distributed import ( _get_process_group_backend_from_env, @@ -242,7 +243,7 @@ def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast: return obj obj = [obj] if self.global_rank != src: - obj = [None] + obj = [None] # type: ignore[list-item] torch.distributed.broadcast_object_list(obj, src, group=_group.WORLD) return obj[0] From b619da7dab15bbe53c20abdaaec7b9e00f32ba8f Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 27 Jul 2022 23:01:41 +0200 Subject: [PATCH 15/17] remove ignore for step methods --- src/pytorch_lightning/strategies/ddp_spawn.py | 11 ++++++----- 1 file changed, 6 insertions(+), 5 deletions(-) diff --git a/src/pytorch_lightning/strategies/ddp_spawn.py b/src/pytorch_lightning/strategies/ddp_spawn.py index e7dec0e089700..cfa1d99980d4a 100644 --- a/src/pytorch_lightning/strategies/ddp_spawn.py +++ b/src/pytorch_lightning/strategies/ddp_spawn.py @@ -51,7 +51,7 @@ from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only from pytorch_lightning.utilities.seed import reset_seed -from pytorch_lightning.utilities.types import STEP_OUTPUT +from pytorch_lightning.utilities.types import STEP_OUTPUT, ValidationStep, TestStep, PredictStep log = logging.getLogger(__name__) @@ -293,16 +293,17 @@ def validation_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: return self.model(*args, **kwargs) else: # used when calling `trainer.validate` - return self.model.validation_step(*args, **kwargs) # type: ignore[operator] + assert isinstance(self.model, ValidationStep) + return self.model.validation_step(*args, **kwargs) def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: - assert self.model is not None with self.precision_plugin.test_step_context(): - return self.model.test_step(*args, **kwargs) # type: ignore[operator] + assert isinstance(self.model, TestStep) + return self.model.test_step(*args, **kwargs) def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: - assert self.model is not None with self.precision_plugin.predict_step_context(): + assert isinstance(self.model, PredictStep) return self.model.predict_step(*args, **kwargs) # type: ignore[operator] def post_training_step(self) -> None: From fa610b747d948e1a45519d44a17d2b92a53041b4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 27 Jul 2022 21:04:04 +0000 Subject: [PATCH 16/17] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- src/pytorch_lightning/strategies/ddp_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/strategies/ddp_spawn.py b/src/pytorch_lightning/strategies/ddp_spawn.py index cfa1d99980d4a..407c8f4eda717 100644 --- a/src/pytorch_lightning/strategies/ddp_spawn.py +++ b/src/pytorch_lightning/strategies/ddp_spawn.py @@ -51,7 +51,7 @@ from pytorch_lightning.utilities.optimizer import optimizers_to_device from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_only from pytorch_lightning.utilities.seed import reset_seed -from pytorch_lightning.utilities.types import STEP_OUTPUT, ValidationStep, TestStep, PredictStep +from pytorch_lightning.utilities.types 
import PredictStep, STEP_OUTPUT, TestStep, ValidationStep log = logging.getLogger(__name__) From 133a25ef10e340e316cf87c2950a7f14bb4b0995 Mon Sep 17 00:00:00 2001 From: awaelchli Date: Wed, 27 Jul 2022 23:37:39 +0200 Subject: [PATCH 17/17] remove unused ignore --- src/pytorch_lightning/strategies/ddp_spawn.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/strategies/ddp_spawn.py b/src/pytorch_lightning/strategies/ddp_spawn.py index 407c8f4eda717..6a3460febbf07 100644 --- a/src/pytorch_lightning/strategies/ddp_spawn.py +++ b/src/pytorch_lightning/strategies/ddp_spawn.py @@ -304,7 +304,7 @@ def test_step(self, *args: Any, **kwargs: Any) -> Optional[STEP_OUTPUT]: def predict_step(self, *args: Any, **kwargs: Any) -> STEP_OUTPUT: with self.precision_plugin.predict_step_context(): assert isinstance(self.model, PredictStep) - return self.model.predict_step(*args, **kwargs) # type: ignore[operator] + return self.model.predict_step(*args, **kwargs) def post_training_step(self) -> None: assert self.lightning_module is not None
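
The common thread in this series is assert-based narrowing: mypy treats `assert x is not None` and `assert isinstance(x, T)` as type guards, so loosely typed attributes such as `self.model` can be used without Optional/union errors, with `# type: ignore[...]` reserved for the few call sites that cannot be narrowed. Below is a minimal, self-contained sketch of that pattern, illustrative only: `_SupportsValidationStep` and `SpawnStrategy` are invented stand-ins, not the actual Lightning classes; the real `ValidationStep`/`TestStep`/`PredictStep` names used for the isinstance checks come from `pytorch_lightning.utilities.types` (patch 15).

    from typing import Optional, Protocol, runtime_checkable


    @runtime_checkable
    class _SupportsValidationStep(Protocol):
        # Stand-in for a runtime-checkable step protocol.
        def validation_step(self, *args: object, **kwargs: object) -> object:
            ...


    class SpawnStrategy:
        def __init__(self) -> None:
            # The wrapped model may be unset or swapped for a DDP wrapper,
            # so the attribute is typed loosely.
            self.model: Optional[object] = None

        def validation_step(self, *args: object, **kwargs: object) -> object:
            # Narrows the type for mypy and fails fast at runtime if an
            # incompatible object was attached.
            assert isinstance(self.model, _SupportsValidationStep)
            return self.model.validation_step(*args, **kwargs)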