3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -593,6 +593,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `DDPStrategy` and `DDPSpawnStrategy` to initialize optimizers only after moving the module to the device ([#11886](https://github.com/PyTorchLightning/pytorch-lightning/pull/11886))


- Fixed `Strategy` to support state saving for FairScale `OSS` and `torch.distributed` `ZeroRedundancyOptimizer` optimizers outside of `DDPShardedStrategy` ([#11867](https://github.com/PyTorchLightning/pytorch-lightning/pull/11867))


## [1.5.10] - 2022-02-08

### Fixed
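The new changelog entry above is the user-facing change of this PR: a sharded optimizer returned from `configure_optimizers` is now consolidated before checkpointing even under the plain `ddp`/`ddp_spawn` strategies. Below is a minimal sketch of that usage, assuming torch >= 1.10 and two GPUs; `LitModel`, the layer size, and the random data are illustrative assumptions, not part of this PR.

import torch
from pytorch_lightning import LightningModule, Trainer
from torch.distributed.optim import ZeroRedundancyOptimizer


class LitModel(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def train_dataloader(self):
        return torch.utils.data.DataLoader(torch.randn(64, 32), batch_size=8)

    def configure_optimizers(self):
        # a sharded optimizer used outside of DDPShardedStrategy
        return ZeroRedundancyOptimizer(self.parameters(), optimizer_class=torch.optim.SGD, lr=0.1)


if __name__ == "__main__":
    trainer = Trainer(gpus=2, strategy="ddp", max_epochs=1)
    trainer.fit(LitModel())
    # with this fix, the checkpoint contains the optimizer state consolidated on rank 0
    trainer.save_checkpoint("model.ckpt")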
15 changes: 0 additions & 15 deletions pytorch_lightning/strategies/sharded.py
@@ -25,7 +25,6 @@
from pytorch_lightning.utilities.enums import _StrategyType, PrecisionType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_only

if _FAIRSCALE_AVAILABLE:
    from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
@@ -85,20 +84,6 @@ def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:

        return self._reinit_optimizers_with_oss(optimizers)

    def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
        if isinstance(optimizer, LightningOptimizer):
            optimizer = optimizer._optimizer
        optimizer.consolidate_state_dict()
        return self._optim_state_dict(optimizer)

    @rank_zero_only
    def _optim_state_dict(self, optimizer):
        """
        Retrieves state dict only on rank 0, which contains the entire optimizer state after calling
        :meth:`consolidate_state_dict`.
        """
        return optimizer.state_dict()

    @property
    def lightning_module(self) -> Optional["pl.LightningModule"]:
        if not _FAIRSCALE_AVAILABLE:  # pragma: no cover
14 changes: 0 additions & 14 deletions pytorch_lightning/strategies/sharded_spawn.py
@@ -24,7 +24,6 @@
from pytorch_lightning.utilities.enums import _StrategyType
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE
from pytorch_lightning.utilities.rank_zero import rank_zero_only

if _FAIRSCALE_AVAILABLE:
    from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
@@ -69,11 +68,6 @@ def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:

        return self._reinit_optimizers_with_oss(optimizers)

    def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
        if isinstance(optimizer, OSS):
            optimizer.consolidate_state_dict()
        return self._optim_state_dict(optimizer)

    @contextmanager
    def block_backward_sync(self) -> Generator:
        """Blocks syncing gradients behaviour on backwards pass.
@@ -87,14 +81,6 @@ def block_backward_sync(self) -> Generator:
        else:
            yield None

    @rank_zero_only
    def _optim_state_dict(self, optimizer):
        """
        Retrieves state dict only on rank 0, which contains the entire optimizer state after calling
        :meth:`consolidate_state_dict`.
        """
        return optimizer.state_dict()

    @property
    def lightning_module(self) -> Optional["pl.LightningModule"]:
        if not _FAIRSCALE_AVAILABLE:  # pragma: no cover
20 changes: 17 additions & 3 deletions pytorch_lightning/strategies/strategy.py
@@ -28,12 +28,19 @@
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
from pytorch_lightning.plugins.precision import PrecisionPlugin
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import rank_zero_deprecation
from pytorch_lightning.utilities.apply_func import apply_to_collection, move_data_to_device
from pytorch_lightning.utilities.distributed import ReduceOp
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _TORCH_GREATER_EQUAL_1_10
from pytorch_lightning.utilities.model_helpers import is_overridden
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
from pytorch_lightning.utilities.types import _PATH, LRSchedulerConfig, STEP_OUTPUT

if _FAIRSCALE_AVAILABLE:
    from fairscale.optim import OSS
if _TORCH_GREATER_EQUAL_1_10:
    from torch.distributed.optim import ZeroRedundancyOptimizer


TBroadcast = TypeVar("TBroadcast")


@@ -148,12 +155,19 @@ def _move_optimizer_state(self, device: Optional[torch.device] = None) -> None:
                # while training on 8 and more cores.
                opt.state[p] = apply_to_collection(v, torch.Tensor, move_data_to_device, device or self.root_device)

    def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:
    def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Any]:
        """Returns state of an optimizer.

        Allows for syncing/collating optimizer state from processes in custom plugins.
        """
        return optimizer.state_dict()
        if (_TORCH_GREATER_EQUAL_1_10 and isinstance(optimizer, ZeroRedundancyOptimizer)) or (
            _FAIRSCALE_AVAILABLE and isinstance(optimizer, OSS)
        ):
            optimizer.consolidate_state_dict()
            # only call state_dict on the rank where the states were consolidated
            return optimizer.state_dict() if self.is_global_zero else {}
        else:
            return optimizer.state_dict()

    def backward(self, closure_loss: Tensor, *args: Any, **kwargs: Any) -> Tensor:
        """Forwards backward-calls to the precision plugin.
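The `optimizer_state` change above is the core of this PR: when the optimizer is a FairScale `OSS` or a torch `ZeroRedundancyOptimizer`, the base `Strategy` consolidates the sharded state and only rank 0 returns the full state dict. The following is a standalone sketch of that consolidation pattern outside Lightning; the `run_demo` helper, the `gloo` backend, and the address/port values are illustrative assumptions, and only `consolidate_state_dict()` plus the rank-0-only `state_dict()` call mirror the change.

import os

import torch
import torch.distributed as dist
from torch.distributed.optim import ZeroRedundancyOptimizer
from torch.nn.parallel import DistributedDataParallel


def run_demo(rank: int, world_size: int) -> None:
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=rank, world_size=world_size)

    model = DistributedDataParallel(torch.nn.Linear(32, 2))
    optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.SGD, lr=0.1)

    model(torch.randn(8, 32)).sum().backward()
    optimizer.step()

    # each rank only holds its shard of the optimizer state; gather everything onto rank 0
    optimizer.consolidate_state_dict(to=0)
    # as in the fixed Strategy.optimizer_state, only the consolidation target calls state_dict()
    full_state = optimizer.state_dict() if rank == 0 else {}
    if rank == 0:
        torch.save(full_state, "optimizer_state.pt")
    dist.destroy_process_group()


if __name__ == "__main__":
    torch.multiprocessing.spawn(run_demo, args=(2,), nprocs=2)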
1 change: 0 additions & 1 deletion tests/strategies/test_ddp_spawn_strategy.py
@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
from pathlib import Path
from unittest.mock import Mock

53 changes: 53 additions & 0 deletions tests/strategies/test_ddp_strategy.py
@@ -22,9 +22,15 @@
from pytorch_lightning.plugins.environments import LightningEnvironment
from pytorch_lightning.strategies import DDPStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _TORCH_GREATER_EQUAL_1_10
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf

if _FAIRSCALE_AVAILABLE:
    from fairscale.optim import OSS
if _TORCH_GREATER_EQUAL_1_10:
    from torch.distributed.optim import ZeroRedundancyOptimizer


class BoringModelGPU(BoringModel):
    def on_train_start(self) -> None:
@@ -148,3 +154,50 @@ def test_model_parameters_on_device_for_optimizer(strategy):
        strategy=strategy,
    )
    trainer.fit(model)


class BoringFairScaleOptimizerModel(BoringModel):
    def configure_optimizers(self):
        base_optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        return OSS(params=base_optimizer.param_groups, optim=type(base_optimizer), **base_optimizer.defaults)


@RunIf(min_gpus=2, skip_windows=True, fairscale=True)
@pytest.mark.parametrize("strategy", ("ddp", "ddp_spawn"))
def test_ddp_strategy_checkpoint_multi_gpu_fairscale_optimizer(tmpdir, strategy):
    """Test to ensure that checkpoint is saved correctly when using a FairScale optimizer."""
    model = BoringFairScaleOptimizerModel()
    trainer = Trainer(gpus=2, max_epochs=2, strategy=strategy)

    trainer.fit(model)

    checkpoint_path = os.path.join(tmpdir, "model.pt")
    trainer.save_checkpoint(checkpoint_path)
    saved_model = BoringModel.load_from_checkpoint(checkpoint_path)

    # Assert model parameters are identical after loading
    for ddp_param, shard_param in zip(model.parameters(), saved_model.parameters()):
        assert torch.equal(ddp_param.to("cpu"), shard_param)


class BoringZeroRedundancyOptimizerModel(BoringModel):
    def configure_optimizers(self):
        return ZeroRedundancyOptimizer(self.layer.parameters(), optimizer_class=torch.optim.Adam, lr=0.1)


@RunIf(min_gpus=2, skip_windows=True, min_torch="1.10")
@pytest.mark.parametrize("strategy", ("ddp", "ddp_spawn"))
def test_ddp_strategy_checkpoint_zero_redundancy_optimizer(tmpdir, strategy):
    """Test to ensure that checkpoint is saved correctly when using ZeroRedundancyOptimizer."""
    model = BoringZeroRedundancyOptimizerModel()
    trainer = Trainer(max_epochs=2, gpus=2, strategy=strategy)

    trainer.fit(model)

    checkpoint_path = os.path.join(tmpdir, "model.pt")
    trainer.save_checkpoint(checkpoint_path)
    saved_model = BoringModel.load_from_checkpoint(checkpoint_path)

    # Assert model parameters are identical after loading
    for ddp_param, shard_param in zip(model.parameters(), saved_model.parameters()):
        assert torch.equal(ddp_param.to("cpu"), shard_param)
27 changes: 26 additions & 1 deletion tests/strategies/test_sharded_strategy.py
@@ -8,12 +8,13 @@
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.strategies import DDPShardedStrategy, DDPSpawnShardedStrategy
from pytorch_lightning.trainer.states import TrainerFn
from pytorch_lightning.utilities import _FAIRSCALE_AVAILABLE
from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE
from tests.helpers.boring_model import BoringModel
from tests.helpers.runif import RunIf

if _FAIRSCALE_AVAILABLE:
    from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
    from fairscale.optim import OSS


@pytest.mark.parametrize("clip_val", [0, 10])
@@ -278,3 +279,27 @@ def test_block_backward_sync(tmpdir):
    with strategy.block_backward_sync():
        pass
    model.no_sync.assert_called_once()


class BoringFairScaleOptimizerModel(BoringModel):
    def configure_optimizers(self):
        base_optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
        return OSS(params=base_optimizer.param_groups, optim=type(base_optimizer), **base_optimizer.defaults)


@RunIf(min_gpus=2, skip_windows=True, fairscale=True)
@pytest.mark.parametrize("strategy", ("ddp_sharded", "ddp_sharded_spawn"))
def test_ddp_sharded_strategy_checkpoint_multi_gpu_fairscale_optimizer(tmpdir, strategy):
    """Test to ensure that checkpoint is saved correctly when using FairScale optimizers."""
    model = BoringFairScaleOptimizerModel()
    trainer = Trainer(gpus=2, max_epochs=2, strategy=strategy)

    trainer.fit(model)

    checkpoint_path = os.path.join(tmpdir, "model.pt")
    trainer.save_checkpoint(checkpoint_path)
    saved_model = BoringModel.load_from_checkpoint(checkpoint_path)

    # Assert model parameters are identical after loading
    for ddp_param, shard_param in zip(model.parameters(), saved_model.parameters()):
        assert torch.equal(ddp_param.to("cpu"), shard_param)
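
Beyond comparing model parameters as the tests above do, the saved checkpoint could also be inspected directly. A hedged sketch, assuming the "model.pt" path from the tests and Lightning's "optimizer_states" checkpoint key:

import torch

checkpoint = torch.load("model.pt", map_location="cpu")
# Lightning stores one entry per optimizer; with this fix it holds the state dict
# consolidated on rank 0 rather than failing or missing the sharded state
optimizer_state = checkpoint["optimizer_states"][0]
assert "state" in optimizer_state
assert "param_groups" in optimizer_state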