24 | 24 | from pytorch_lightning.plugins.environments import ClusterEnvironment, LightningEnvironment |
25 | 25 | from pytorch_lightning.strategies import DDPStrategy |
26 | 26 | from pytorch_lightning.trainer.states import TrainerFn |
| 27 | +from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _TORCH_GREATER_EQUAL_1_10 |
27 | 28 | from tests_pytorch.helpers.runif import RunIf |
28 | 29 |
| 30 | +if _FAIRSCALE_AVAILABLE: |
| 31 | + from fairscale.optim import OSS |
| 32 | +if _TORCH_GREATER_EQUAL_1_10: |
| 33 | + from torch.distributed.optim import ZeroRedundancyOptimizer |
| 34 | + |
29 | 35 |
30 | 36 | class BoringModelGPU(BoringModel): |
31 | 37 | def on_train_start(self) -> None: |
@@ -252,3 +258,50 @@ def test_ddp_strategy_set_timeout(mock_init_process_group): |
252 | 258 | mock_init_process_group.assert_called_with( |
253 | 259 | process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta |
254 | 260 | ) |
| 261 | + |
| 262 | + |
| 263 | +class BoringFairScaleOptimizerModel(BoringModel): |
| 264 | + def configure_optimizers(self): |
| 265 | + base_optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1) |
| 266 | + return OSS(params=base_optimizer.param_groups, optim=type(base_optimizer), **base_optimizer.defaults) |
| 267 | + |
| 268 | + |
| 269 | +@RunIf(min_cuda_gpus=2, skip_windows=True, fairscale=True) |
| 270 | +@pytest.mark.parametrize("strategy", (pytest.param("ddp", marks=RunIf(standalone=True)), "ddp_spawn")) |
| 271 | +def test_ddp_strategy_checkpoint_multi_gpu_fairscale_optimizer(tmpdir, strategy): |
| 272 | + """Test to ensure that checkpoint is saved correctly when using fairscale optimizer.""" |
| 273 | + model = BoringFairScaleOptimizerModel() |
| 274 | + trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, max_steps=1) |
| 275 | + |
| 276 | + trainer.fit(model) |
| 277 | + |
| 278 | + checkpoint_path = os.path.join(tmpdir, "model.pt") |
| 279 | + trainer.save_checkpoint(checkpoint_path) |
| 280 | + saved_model = BoringModel.load_from_checkpoint(checkpoint_path) |
| 281 | + |
| 282 | + # Assert model parameters are identical after loading |
| 283 | + for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()): |
| 284 | + assert torch.equal(trained_param.to("cpu"), loaded_param) |
| 285 | + |
| 286 | + |
| 287 | +class BoringZeroRedundancyOptimizerModel(BoringModel): |
| 288 | + def configure_optimizers(self): |
| 289 | + return ZeroRedundancyOptimizer(self.layer.parameters(), optimizer_class=torch.optim.Adam, lr=0.1) |
| 290 | + |
| 291 | + |
| 292 | +@RunIf(min_cuda_gpus=2, skip_windows=True, min_torch="1.10") |
| 293 | +@pytest.mark.parametrize("strategy", (pytest.param("ddp", marks=RunIf(standalone=True)), "ddp_spawn")) |
| 294 | +def test_ddp_strategy_checkpoint_zero_redundancy_optimizer(tmpdir, strategy): |
| 295 | + """Test to ensure that checkpoint is saved correctly when using zero redundancy optimizer.""" |
| 296 | + model = BoringZeroRedundancyOptimizerModel() |
| 297 | + trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, max_steps=1) |
| 298 | + |
| 299 | + trainer.fit(model) |
| 300 | + |
| 301 | + checkpoint_path = os.path.join(tmpdir, "model.pt") |
| 302 | + trainer.save_checkpoint(checkpoint_path) |
| 303 | + saved_model = BoringModel.load_from_checkpoint(checkpoint_path) |
| 304 | + |
| 305 | + # Assert model parameters are identical after loading |
| 306 | + for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()): |
| 307 | + assert torch.equal(trained_param.to("cpu"), loaded_param) |