diff --git a/CHANGELOG.md b/CHANGELOG.md
index 655484292ee59..cc8fcd63e3e00 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -213,10 +213,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - LightningLite:
     * Added `PrecisionPlugin.forward_context`, making it the default implementation for all `{train,val,test,predict}_step_context()` methods ([#9988](https://github.com/PyTorchLightning/pytorch-lightning/pull/9988))
     * Added `DDPSpawnPlugin.spawn()` for spawning new processes of a given function ([#10018](https://github.com/PyTorchLightning/pytorch-lightning/pull/10018), [#10022](https://github.com/PyTorchLightning/pytorch-lightning/pull/10022))
-    * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994))
+    * Added `TrainingTypePlugin.{_setup_model, _setup_optimizer}` methods ([#9994](https://github.com/PyTorchLightning/pytorch-lightning/pull/9994), [#10064](https://github.com/PyTorchLightning/pytorch-lightning/pull/10064))
     * Implemented `DataParallelPlugin._setup_model` ([#10010](https://github.com/PyTorchLightning/pytorch-lightning/pull/10010))
-    * Implemented `DeepSpeedPlugin._setup_models_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009))
-    * Implemented `{DDPShardedPlugin,DDPShardedSpawnPlugin}._setup_models_and_optimizers` ([#10028](https://github.com/PyTorchLightning/pytorch-lightning/pull/10028))
+    * Implemented `DeepSpeedPlugin._setup_model_and_optimizers` ([#10009](https://github.com/PyTorchLightning/pytorch-lightning/pull/10009), [#10064](https://github.com/PyTorchLightning/pytorch-lightning/pull/10064))
+    * Implemented `{DDPShardedPlugin,DDPShardedSpawnPlugin}._setup_model_and_optimizers` ([#10028](https://github.com/PyTorchLightning/pytorch-lightning/pull/10028), [#10064](https://github.com/PyTorchLightning/pytorch-lightning/pull/10064))
     * Added optional `model` argument to the `optimizer_step` methods in accelerators and plugins ([#10023](https://github.com/PyTorchLightning/pytorch-lightning/pull/10023))
diff --git a/pytorch_lightning/plugins/training_type/deepspeed.py b/pytorch_lightning/plugins/training_type/deepspeed.py
index 03eee1db167fc..1ff22cc07ecb0 100644
--- a/pytorch_lightning/plugins/training_type/deepspeed.py
+++ b/pytorch_lightning/plugins/training_type/deepspeed.py
@@ -379,30 +379,28 @@ def pre_dispatch(self):
         self.init_deepspeed()
         self.barrier()
 
-    def _setup_models_and_optimizers(
-        self, models: List[Module], optimizers: List[Optimizer]
-    ) -> Tuple[List[Module], List[Optimizer]]:
-        """Setup multiple models and multiple optimizers together.
+    def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
+        """Setup a model and multiple optimizers together.
 
-        Currently only one model paired with a single optimizer is supported.
+        Currently only a single optimizer is supported.
 
         Return:
-            A list with one model wrapped into a :class:`deepspeed.DeepSpeedEngine` and list with a single
+            The model wrapped into a :class:`deepspeed.DeepSpeedEngine` and a list with a single
             deepspeed optimizer.
         """
-        if not (len(models) == len(optimizers) == 1):
+        if len(optimizers) != 1:
             raise ValueError(
-                f"Currently only one model and one optimizer is supported with DeepSpeed."
-                f" Got {len(models)} models and {len(optimizers)} optimizers instead."
+                f"Currently only one optimizer is supported with DeepSpeed."
+                f" Got {len(optimizers)} optimizers instead."
             )
 
         # train_micro_batch_size_per_gpu is used for throughput logging purposes
         # normally we set this to the batch size, but it is not available here unless the user provides it
         # as part of the config
         self.config.setdefault("train_micro_batch_size_per_gpu", 1)
-        self._model, optimizer = self._setup_model_and_optimizer(models[0], optimizers[0])
+        self._model, optimizer = self._setup_model_and_optimizer(model, optimizers[0])
         self._set_deepspeed_activation_checkpointing()
-        return [self._model], [optimizer]
+        return self._model, [optimizer]
 
     def _setup_model_and_optimizer(
         self, model: Module, optimizer: Optimizer, lr_scheduler: Optional[_LRScheduler] = None
diff --git a/pytorch_lightning/plugins/training_type/sharded.py b/pytorch_lightning/plugins/training_type/sharded.py
index 9712b5356091f..6e278d44e5cb8 100644
--- a/pytorch_lightning/plugins/training_type/sharded.py
+++ b/pytorch_lightning/plugins/training_type/sharded.py
@@ -47,33 +47,23 @@ def configure_ddp(self) -> None:
         # For multi-node training, enabling bucketing will improve performance.
         self._ddp_kwargs["reduce_buffer_size"] = self._REDUCE_BUFFER_SIZE_DEFAULT if self.num_nodes > 1 else 0
 
-        [self._model], optimizers = self._setup_models_and_optimizers(
-            models=[LightningShardedDataParallel(self.model)],
+        self._model, optimizers = self._setup_model_and_optimizers(
+            model=LightningShardedDataParallel(self.model),
             optimizers=trainer.optimizers,
         )
         trainer.optimizers = optimizers
         trainer.convert_to_lightning_optimizers()
 
-    def _setup_models_and_optimizers(
-        self, models: List[Module], optimizers: List[Optimizer]
-    ) -> Tuple[List[Module], List[Optimizer]]:
+    def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
         """Wraps the model and optimizers with fairscale components.
 
-        Currently only one model can be setup at once.
-
         Return:
-            A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
+            The model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
             and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`.
         """
-        if len(models) > 1:
-            raise ValueError(
-                "DDPSharded only supports setting up a single model with one or several optimizers."
-                f" Got {len(models)} models."
-            )
-
         optimizers = self._wrap_optimizers(optimizers)
-        model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
-        return [model], optimizers
+        model = ShardedDataParallel(model, sharded_optimizer=optimizers, **self._ddp_kwargs)
+        return model, optimizers
 
     def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, LightningOptimizer]]) -> List["OSS"]:
         for x, optimizer in enumerate(optimizers):
diff --git a/pytorch_lightning/plugins/training_type/sharded_spawn.py b/pytorch_lightning/plugins/training_type/sharded_spawn.py
index 9503ffb951abb..13615ce05e2fb 100644
--- a/pytorch_lightning/plugins/training_type/sharded_spawn.py
+++ b/pytorch_lightning/plugins/training_type/sharded_spawn.py
@@ -39,32 +39,22 @@ class DDPSpawnShardedPlugin(DDPSpawnPlugin):
     def configure_ddp(self) -> None:
         trainer = self.lightning_module.trainer
 
-        [self._model], optimizers = self._setup_models_and_optimizers(
-            models=[LightningShardedDataParallel(self.model)],
+        self._model, optimizers = self._setup_model_and_optimizers(
+            model=LightningShardedDataParallel(self.model),
             optimizers=trainer.optimizers,
         )
         trainer.optimizers = optimizers
 
-    def _setup_models_and_optimizers(
-        self, models: List[Module], optimizers: List[Optimizer]
-    ) -> Tuple[List[Module], List[Optimizer]]:
+    def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
         """Wraps the model and optimizers with fairscale components.
 
-        Currently only one model can be setup at once.
-
         Return:
-            A list with one model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
+            The model wrapped into a :class:`~fairscale.nn.data_parallel.ShardedDataParallel` module
             and a list of optimizer wrapped in :class:~`fairscale.optim.OSS`.
         """
-        if len(models) > 1:
-            raise ValueError(
-                f"DDPShardedSpawn only supports setting up a single model with one or several optimizers."
-                f" Got {len(models)} models."
-            )
-
         optimizers = self._wrap_optimizers(optimizers)
-        model = ShardedDataParallel(models[0], sharded_optimizer=optimizers, **self._ddp_kwargs)
-        return [model], optimizers
+        model = ShardedDataParallel(model, sharded_optimizer=optimizers, **self._ddp_kwargs)
+        return model, optimizers
 
     def _reinit_optimizers_with_oss(self, optimizers: List[Optimizer]) -> List["OSS"]:
         for x, optimizer in enumerate(optimizers):
diff --git a/pytorch_lightning/plugins/training_type/training_type_plugin.py b/pytorch_lightning/plugins/training_type/training_type_plugin.py
index 95c74d4a87b70..e1cfdda2d68d8 100644
--- a/pytorch_lightning/plugins/training_type/training_type_plugin.py
+++ b/pytorch_lightning/plugins/training_type/training_type_plugin.py
@@ -61,18 +61,16 @@ def setup_environment(self) -> None:
     def setup(self) -> None:
         """Called by the accelerator to finish setup."""
 
-    def _setup_models_and_optimizers(
-        self, models: List[Module], optimizers: List[Optimizer]
-    ) -> Tuple[List[Module], List[Optimizer]]:
-        """Setup multiple models and multiple optimizers together.
+    def _setup_model_and_optimizers(self, model: Module, optimizers: List[Optimizer]) -> Tuple[Module, List[Optimizer]]:
+        """Setup a model and multiple optimizers together.
 
         The returned objects are expected to be in the same order they were passed in. The default implementation will
-        call :meth:`_setup_model` and :meth:`_setup_optimizer` on the input lists.
+        call :meth:`_setup_model` and :meth:`_setup_optimizer` on the inputs.
         """
         # TODO (@awaelchli): standardize this across all plugins in Lightning and Lite. Related refactor: #7324
-        models = [self._setup_model(model) for model in models]
+        model = self._setup_model(model)
         optimizers = [self._setup_optimizer(optimizer) for optimizer in optimizers]
-        return models, optimizers
+        return model, optimizers
 
     def _setup_model(self, model: Module) -> Module:
         """Performs setup for the model, e.g., by wrapping it by another class."""