From b0ed863becf7218e9a506ed8a4e00beed37b53bb Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Mon, 7 Mar 2022 23:05:51 -0800
Subject: [PATCH 1/3] Update fully_sharded.py

---
 pytorch_lightning/strategies/fully_sharded.py | 13 ++++---------
 1 file changed, 4 insertions(+), 9 deletions(-)

diff --git a/pytorch_lightning/strategies/fully_sharded.py b/pytorch_lightning/strategies/fully_sharded.py
index a25f2b0b8695d..6214bb82ce81c 100644
--- a/pytorch_lightning/strategies/fully_sharded.py
+++ b/pytorch_lightning/strategies/fully_sharded.py
@@ -135,16 +135,11 @@ def setup_distributed(self) -> None:
 
     def setup(self, trainer: "pl.Trainer") -> None:
         self.accelerator.setup(trainer)
-        self.setup_optimizers(trainer)
-        self.setup_precision_plugin()
-        optimizers_to_device(self.optimizers, self.root_device)
 
         if self._layer_sync:
             self.model = self._layer_sync.apply(self.model)
 
         self.configure_ddp()
-        self.barrier()
-        self.setup_optimizers(trainer)
 
     @contextlib.contextmanager
     def model_sharded_context(self) -> Generator:
@@ -171,9 +166,12 @@ def wrap_policy(*args, **kwargs):
             yield
 
         log.detail(f"{self.__class__.__name__}: exiting model_sharded_context.")
+        self.setup_optimizers(self.lightning_module.trainer)
+        optimizers_to_device(self.optimizers, self.root_device)
+        self.setup_precision_plugin()
 
     def configure_ddp(self) -> None:
-        log.detail(f"{self.__class__.__name__}: configuring DDP... (cpu_offload: [{self.cpu_offload}])")
+        log.detail(f"{self.__class__.__name__}: configuring FSDP... (cpu_offload: [{self.cpu_offload}])")
         if not self.cpu_offload:
             # When using CPU Offload, FSDP will manage the CUDA movement for us.
             # Note: this would be problematic for large model (which could not fit in one GPU)
@@ -181,9 +179,6 @@ def configure_ddp(self) -> None:
             # (TODO: need to figure out solution)
             self.model_to_device()
 
-        # setup optimizers after fully sharded has wrapped the lightning module
-        self.setup_optimizers(self.lightning_module.trainer)
-
     def model_to_device(self) -> None:
         log.detail(f"{self.__class__.__name__}: moving model to device [{self.root_device}]...")
         # ensure we update the device type in the lightning module

From 10bda43bfce48b7cabb0245264d39bd512de09df Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Mon, 7 Mar 2022 23:09:46 -0800
Subject: [PATCH 2/3] Update CHANGELOG.md

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index 4efa7677191a7..fa25b2228a458 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -772,6 +772,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed to avoid common hook warning if no hook is overridden ([#12131](https://github.com/PyTorchLightning/pytorch-lightning/pull/12131))
 
 
+- Fixed initializing optimizers unnecessarily in `DDPFullyShardedStrategy` ([#12267](https://github.com/PyTorchLightning/pytorch-lightning/pull/12267))
+
+
 ## [1.5.10] - 2022-02-08
 
 ### Fixed

From c81a4771714622ab5cd27b855039434b961e29a8 Mon Sep 17 00:00:00 2001
From: ananthsub
Date: Wed, 9 Mar 2022 17:05:15 -0800
Subject: [PATCH 3/3] Update fully_sharded.py

---
 pytorch_lightning/strategies/fully_sharded.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/pytorch_lightning/strategies/fully_sharded.py b/pytorch_lightning/strategies/fully_sharded.py
index 6214bb82ce81c..cd15af695ab73 100644
--- a/pytorch_lightning/strategies/fully_sharded.py
+++ b/pytorch_lightning/strategies/fully_sharded.py
@@ -140,6 +140,10 @@ def setup(self, trainer: "pl.Trainer") -> None:
             self.model = self._layer_sync.apply(self.model)
 
         self.configure_ddp()
+        self.barrier()
+        self.setup_optimizers(trainer)
+        optimizers_to_device(self.optimizers, self.root_device)
+        self.setup_precision_plugin()
 
     @contextlib.contextmanager
     def model_sharded_context(self) -> Generator:
@@ -166,9 +170,6 @@ def wrap_policy(*args, **kwargs):
             yield
 
         log.detail(f"{self.__class__.__name__}: exiting model_sharded_context.")
-        self.setup_optimizers(self.lightning_module.trainer)
-        optimizers_to_device(self.optimizers, self.root_device)
-        self.setup_precision_plugin()
 
     def configure_ddp(self) -> None:
         log.detail(f"{self.__class__.__name__}: configuring FSDP... (cpu_offload: [{self.cpu_offload}])")
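
Note: taken together, the three patches leave `DDPFullyShardedStrategy.setup()` with the ordering sketched below. This is a condensed, illustrative sketch rather than the full class; it assumes the attributes and helpers already shown in the diffs above (`configure_ddp`, `barrier`, `setup_optimizers`, `optimizers_to_device`, `setup_precision_plugin`), and the comments spell out why the optimizers are created only after FSDP has wrapped the module.

    def setup(self, trainer: "pl.Trainer") -> None:
        self.accelerator.setup(trainer)

        if self._layer_sync:
            self.model = self._layer_sync.apply(self.model)

        # Wrap the LightningModule with FSDP first: wrapping replaces the module's
        # original parameters with flattened, sharded ones.
        self.configure_ddp()
        self.barrier()

        # Only now create the optimizers, so they reference the flattened FSDP
        # parameters instead of the pre-wrap ones, and create them exactly once
        # (previously setup_optimizers ran both before and after wrapping).
        self.setup_optimizers(trainer)
        optimizers_to_device(self.optimizers, self.root_device)
        self.setup_precision_plugin()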