diff --git a/CHANGELOG.md b/CHANGELOG.md
index e73909397c3db..4f29379ebb8a6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -593,6 +593,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Fixed `DDPStrategy` and `DDPSpawnStrategy` to initialize optimizers only after moving the module to the device ([#11886](https://github.com/PyTorchLightning/pytorch-lightning/pull/11886))
 
+- Fixed `Strategy` to support state saving for FairScale `OSS` and `torch.distributed` `ZeroRedundancyOptimizer` optimizers outside of `DDPShardedStrategy` ([#11867](https://github.com/PyTorchLightning/pytorch-lightning/pull/11867))
+
+
 ## [1.5.10] - 2022-02-08
 
 ### Fixed
diff --git a/pytorch_lightning/strategies/sharded.py b/pytorch_lightning/strategies/sharded.py
index 2d1584a2e15e5..8bdc6766346e7 100644
--- a/pytorch_lightning/strategies/sharded.py
+++ b/pytorch_lightning/strategies/sharded.py
@@ -25,7 +25,6 @@
 from pytorch_lightning.utilities.enums import _StrategyType, PrecisionType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE
-from pytorch_lightning.utilities.rank_zero import rank_zero_only
 
 if _FAIRSCALE_AVAILABLE:
     from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
@@ -85,20 +84,6 @@ def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
 
         return self._reinit_optimizers_with_oss(optimizers)
 
-    def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
-        if isinstance(optimizer, LightningOptimizer):
-            optimizer = optimizer._optimizer
-        optimizer.consolidate_state_dict()
-        return self._optim_state_dict(optimizer)
-
-    @rank_zero_only
-    def _optim_state_dict(self, optimizer):
-        """
-        Retrieves state dict only on rank 0, which contains the entire optimizer state after calling
-        :meth:`consolidate_state_dict`.
-        """
-        return optimizer.state_dict()
-
     @property
     def lightning_module(self) -> Optional["pl.LightningModule"]:
         if not _FAIRSCALE_AVAILABLE:  # pragma: no cover
diff --git a/pytorch_lightning/strategies/sharded_spawn.py b/pytorch_lightning/strategies/sharded_spawn.py
index 289e3491be0b4..82315b3ead87a 100644
--- a/pytorch_lightning/strategies/sharded_spawn.py
+++ b/pytorch_lightning/strategies/sharded_spawn.py
@@ -24,7 +24,6 @@
 from pytorch_lightning.utilities.enums import _StrategyType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE
-from pytorch_lightning.utilities.rank_zero import rank_zero_only
 
 if _FAIRSCALE_AVAILABLE:
     from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
@@ -69,11 +68,6 @@ def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:
 
         return self._reinit_optimizers_with_oss(optimizers)
 
-    def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
-        if isinstance(optimizer, OSS):
-            optimizer.consolidate_state_dict()
-        return self._optim_state_dict(optimizer)
-
     @contextmanager
     def block_backward_sync(self) -> Generator:
         """Blocks syncing gradients behaviour on backwards pass.
@@ -87,14 +81,6 @@ def block_backward_sync(self) -> Generator:
         else:
             yield None
 
-    @rank_zero_only
-    def _optim_state_dict(self, optimizer):
-        """
-        Retrieves state dict only on rank 0, which contains the entire optimizer state after calling
-        :meth:`consolidate_state_dict`.
-        """
-        return optimizer.state_dict()
-
     @property
     def lightning_module(self) -> Optional["pl.LightningModule"]:
         if not _FAIRSCALE_AVAILABLE:  # pragma: no cover
diff --git a/pytorch_lightning/strategies/strategy.py b/pytorch_lightning/strategies/strategy.py
index 629911911b780..99513d4af4e70 100644
--- a/pytorch_lightning/strategies/strategy.py
+++ b/pytorch_lightning/strategies/strategy.py
@@ -28,12 +28,19 @@
 from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
 from pytorch_lightning.plugins.precision import PrecisionPlugin
 from pytorch_lightning.trainer.states import TrainerFn
-from pytorch_lightning.utilities import rank_zero_deprecation
 from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
 from pytorch_lightning.utilities.distributed import ReduceOp
+from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _TORCH_GREATER_EQUAL_1_10
 from pytorch_lightning.utilities.model_helpers import is_overridden
+from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
 from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, STEP_OUTPUT
 
+if _FAIRSCALE_AVAILABLE:
+    from fairscale.optim import OSS
+if _TORCH_GREATER_EQUAL_1_10:
+    from torch.distributed.optim import ZeroRedundancyOptimizer
+
+
 TBroadcast = TypeVar("TBroadcast")
 
 
@@ -148,12 +155,19 @@ def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
                 # while training on 8 and more cores.
                 opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device or self.root_device)
 
-    def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
+    def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Any]:
         """Returns state of an optimizer.
 
         Allows for syncing/collating optimizer state from processes in custom plugins.
         """
-        return optimizer.state_dict()
+        if (_TORCH_GREATER_EQUAL_1_10 and isinstance(optimizer, ZeroRedundancyOptimizer)) or (
+            _FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS)
+        ):
+            optimizer.consolidate_state_dict()
+            # only call state_dict on the rank where the states were consolidated
+            return optimizer.state_dict() if self.is_global_zero else {}
+        else:
+            return optimizer.state_dict()
 
     def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
         """Forwards backward-calls to the precision plugin.
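Note on the change above (illustrative, not part of the patch): `Strategy.optimizer_state` now performs the rank-zero consolidation that the sharded strategies previously handled themselves. The sketch below reproduces that flow outside Lightning as a minimal example; it assumes torch >= 1.10, a launch such as `torchrun --nproc_per_node=2 consolidate_demo.py`, and a throwaway `Linear` model. The script name, backend choice, and output path are placeholders, not anything defined in this PR.

# consolidate_demo.py -- illustrative sketch only, not part of the patch.
# Assumes torch >= 1.10 and launch via `torchrun --nproc_per_node=2 consolidate_demo.py`.
import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel


def main() -> None:
    dist.init_process_group(backend="gloo")  # "nccl" for a real multi-GPU run
    rank = dist.get_rank()

    model = DistributedDataParallel(torch.nn.Linear(32, 2))
    optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=0.1)

    # One step so every rank holds a shard of optimizer state.
    model(torch.randn(4, 32)).sum().backward()
    optimizer.step()

    # Every rank must take part in consolidation (it is a collective call) ...
    optimizer.consolidate_state_dict(to=0)
    # ... but only the consolidation target may read the full state dict afterwards,
    # which is why the patched `optimizer_state` returns `{}` on the other ranks.
    state = optimizer.state_dict() if rank == 0 else {}
    if rank == 0:
        torch.save(state, "optimizer_state.pt")

    dist.destroy_process_group()


if __name__ == "__main__":
    main()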
diff --git a/tests/strategies/test_ddp_spawn_strategy.py b/tests/strategies/test_ddp_spawn_strategy.py
index 77322ac12ec24..2c0dff23aafd3 100644
--- a/tests/strategies/test_ddp_spawn_strategy.py
+++ b/tests/strategies/test_ddp_spawn_strategy.py
@@ -11,7 +11,6 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import os
 from pathlib import Path
 from unittest.mock import Mock
 
diff --git a/tests/strategies/test_ddp_strategy.py b/tests/strategies/test_ddp_strategy.py
index df5b1ca54ab50..342bf2d6fb315 100644
--- a/tests/strategies/test_ddp_strategy.py
+++ b/tests/strategies/test_ddp_strategy.py
@@ -22,9 +22,15 @@
 from pytorch_lightning.plugins.environments import LightningEnvironment
 from pytorch_lightning.strategies import DDPStrategy
 from pytorch_lightning.trainer.states import TrainerFn
+from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _TORCH_GREATER_EQUAL_1_10
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf
 
+if _FAIRSCALE_AVAILABLE:
+    from fairscale.optim import OSS
+if _TORCH_GREATER_EQUAL_1_10:
+    from torch.distributed.optim import ZeroRedundancyOptimizer
+
 
 class BoringModelGPU(BoringModel):
     def on_train_start(self) -> None:
@@ -148,3 +154,50 @@ def test_model_parameters_on_device_for_optimizer(strategy):
         strategy=strategy,
     )
     trainer.fit(model)
+
+
+class BoringFairScaleOptimizerModel(BoringModel):
+    def configure_optimizers(self):
+        base_optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+        return OSS(params=base_optimizer.param_groups, optim=type(base_optimizer), **base_optimizer.defaults)
+
+
+@RunIf(min_gpus=2, skip_windows=True, fairscale=True)
+@pytest.mark.parametrize("strategy", ("ddp", "ddp_spawn"))
+def test_ddp_strategy_checkpoint_multi_gpu_fairscale_optimizer(tmpdir, strategy):
+    """Test to ensure that checkpoint is saved correctly when using fairscale optimizer."""
+    model = BoringFairScaleOptimizerModel()
+    trainer = Trainer(gpus=2, max_epochs=2, strategy=strategy)
+
+    trainer.fit(model)
+
+    checkpoint_path = os.path.join(tmpdir, "model.pt")
+    trainer.save_checkpoint(checkpoint_path)
+    saved_model = BoringModel.load_from_checkpoint(checkpoint_path)
+
+    # Assert model parameters are identical after loading
+    for ddp_param, shard_param in zip(model.parameters(), saved_model.parameters()):
+        assert torch.equal(ddp_param.to("cpu"), shard_param)
+
+
+class BoringZeroRedundancyOptimizerModel(BoringModel):
+    def configure_optimizers(self):
+        return ZeroRedundancyOptimizer(self.layer.parameters(), optimizer_class=torch.optim.Adam, lr=0.1)
+
+
+@RunIf(min_gpus=2, skip_windows=True, min_torch="1.10")
+@pytest.mark.parametrize("strategy", ("ddp", "ddp_spawn"))
+def test_ddp_strategy_checkpoint_zero_redundancy_optimizer(tmpdir, strategy):
+    """Test to ensure that checkpoint is saved correctly when using zero redundancy optimizer."""
+    model = BoringZeroRedundancyOptimizerModel()
+    trainer = Trainer(max_epochs=2, gpus=2, strategy=strategy)
+
+    trainer.fit(model)
+
+    checkpoint_path = os.path.join(tmpdir, "model.pt")
+    trainer.save_checkpoint(checkpoint_path)
+    saved_model = BoringModel.load_from_checkpoint(checkpoint_path)
+
+    # Assert model parameters are identical after loading
+    for ddp_param, shard_param in zip(model.parameters(), saved_model.parameters()):
+        assert torch.equal(ddp_param.to("cpu"), shard_param)
diff --git a/tests/strategies/test_sharded_strategy.py b/tests/strategies/test_sharded_strategy.py
index 1fdb3c6f557a5..5d8e7a1011ec9 100644
--- a/tests/strategies/test_sharded_strategy.py
+++ b/tests/strategies/test_sharded_strategy.py
@@ -8,12 +8,13 @@
 from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.strategies import DDPShardedStrategy, DDPSpawnShardedStrategy
 from pytorch_lightning.trainer.states import TrainerFn
-from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE
+from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.runif import RunIf
 
 if _FAIRSCALE_AVAILABLE:
     from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
+    from fairscale.optim import OSS
 
 
 @pytest.mark.parametrize("clip_val", [0, 10])
@@ -278,3 +279,27 @@ def test_block_backward_sync(tmpdir):
     with strategy.block_backward_sync():
         pass
     model.no_sync.assert_called_once()
+
+
+class BoringFairScaleOptimizerModel(BoringModel):
+    def configure_optimizers(self):
+        base_optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+        return OSS(params=base_optimizer.param_groups, optim=type(base_optimizer), **base_optimizer.defaults)
+
+
+@RunIf(min_gpus=2, skip_windows=True, fairscale=True)
+@pytest.mark.parametrize("strategy", ("ddp_sharded", "ddp_sharded_spawn"))
+def test_ddp_sharded_strategy_checkpoint_multi_gpu_fairscale_optimizer(tmpdir, strategy):
+    """Test to ensure that checkpoint is saved correctly when using fairscale optimizers."""
+    model = BoringFairScaleOptimizerModel()
+    trainer = Trainer(gpus=2, max_epochs=2, strategy=strategy)
+
+    trainer.fit(model)
+
+    checkpoint_path = os.path.join(tmpdir, "model.pt")
+    trainer.save_checkpoint(checkpoint_path)
+    saved_model = BoringModel.load_from_checkpoint(checkpoint_path)
+
+    # Assert model parameters are identical after loading
+    for ddp_param, shard_param in zip(model.parameters(), saved_model.parameters()):
+        assert torch.equal(ddp_param.to("cpu"), shard_param)
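The tests above exercise only the save path. For completeness, a hedged sketch of the restore side follows; it is not code from this PR. It assumes fairscale is installed, a process group has already been initialized on every rank, and that `checkpoint_path` points at a file holding a raw consolidated `OSS` state dict saved from rank 0 (in a Lightning checkpoint the same dict is stored under the checkpoint's optimizer states). The helper name `restore_sharded_optimizer` is hypothetical.

# Illustrative sketch only, not part of the patch. Assumes fairscale is installed,
# `dist.init_process_group` has already run on every rank, and `checkpoint_path`
# holds a raw consolidated OSS state dict saved from rank 0 (hypothetical file).
import torch
from fairscale.optim import OSS


def restore_sharded_optimizer(model: torch.nn.Module, checkpoint_path: str) -> OSS:
    optimizer = OSS(params=model.parameters(), optim=torch.optim.SGD, lr=0.1)
    # Load the full (unsharded) state on CPU; every rank reads the same file.
    full_state = torch.load(checkpoint_path, map_location="cpu")
    # Each rank then keeps only the slice of state that belongs to its parameter shard.
    optimizer.load_state_dict(full_state)
    return optimizer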