Commit abd8d7b

Merge branch 'master' into docs/chlog_post_173
2 parents f03b712 + 70deac2
File tree

10 files changed (+125, -42 lines)


pyproject.toml

Lines changed: 0 additions & 1 deletion
@@ -51,7 +51,6 @@ warn_no_return = "False"
 module = [
     "pytorch_lightning.callbacks.progress.rich_progress",
     "pytorch_lightning.core.datamodule",
-    "pytorch_lightning.demos.mnist_datamodule",
     "pytorch_lightning.profilers.base",
     "pytorch_lightning.profilers.pytorch",
     "pytorch_lightning.strategies.sharded",

src/pytorch_lightning/CHANGELOG.md

Lines changed: 5 additions & 0 deletions
@@ -15,6 +15,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added support for passing extra init-parameters to the `LightningDataModule.from_datasets` ([#14185](https://github.com/Lightning-AI/lightning/issues/14185))


+- Added support for saving sharded optimizer state dict outside of `DDPShardedStrategy` ([#14208](https://github.com/PyTorchLightning/pytorch-lightning/pull/14208))
+
+

 ### Changed

@@ -75,6 +78,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `LightningDataModule` hparams parsing ([#12806](https://github.com/PyTorchLightning/pytorch-lightning/pull/12806))


+- Reset epoch progress with batch size scaler ([#13846](https://github.com/Lightning-AI/lightning/pull/13846))
+

 ## [1.7.3] - 2022-08-25
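For context on the first new entry above: a sharded optimizer can now be checkpointed through the generic DDP strategies, not only through `DDPShardedStrategy`. A minimal usage sketch under assumed names (`LitModel` is a hypothetical `LightningModule`; FairScale's `OSS` is one optimizer that exposes `consolidate_state_dict()`):

    # Hypothetical sketch: a LightningModule returning a FairScale OSS (sharded) optimizer,
    # trained with the plain "ddp" strategy; saving a checkpoint now consolidates the
    # optimizer state on rank 0 instead of requiring DDPShardedStrategy.
    import torch
    import pytorch_lightning as pl
    from fairscale.optim import OSS

    class LitModel(pl.LightningModule):  # assumed to define layers/steps/dataloaders elsewhere
        def configure_optimizers(self):
            base = torch.optim.SGD(self.parameters(), lr=0.1)
            return OSS(params=base.param_groups, optim=type(base), **base.defaults)

    trainer = pl.Trainer(accelerator="gpu", devices=2, strategy="ddp", max_steps=1)
    trainer.fit(LitModel())
    trainer.save_checkpoint("model.ckpt")  # sharded optimizer state is consolidated before saving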

src/pytorch_lightning/demos/mnist_datamodule.py

Lines changed: 2 additions & 1 deletion
@@ -16,7 +16,7 @@
 import random
 import time
 import urllib
-from typing import Any, Callable, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Sized, Tuple, Union
 from urllib.error import HTTPError
 from warnings import warn

@@ -199,6 +199,7 @@ def setup(self, stage: Optional[str] = None) -> None:
         """Split the train and valid dataset."""
         extra = dict(transform=self.default_transforms) if self.default_transforms else {}
         dataset: Dataset = MNIST(self.data_dir, train=True, download=False, **extra)
+        assert isinstance(dataset, Sized)
         train_length = len(dataset)
         self.dataset_train, self.dataset_val = random_split(dataset, [train_length - self.val_split, self.val_split])
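The `Sized` assertion is a typing aid: `Dataset` does not guarantee `__len__`, so narrowing the value lets a static checker accept the `len(dataset)` call, which is also what allows the module to be dropped from the mypy exclusion list in `pyproject.toml` above. A small standalone illustration of the pattern (hypothetical helper, not from the repository):

    from typing import Sized
    from torch.utils.data import Dataset

    def train_length(dataset: Dataset) -> int:
        # Dataset does not declare __len__, so a type checker rejects len(dataset) as-is;
        # asserting Sized narrows the type and documents the runtime expectation.
        assert isinstance(dataset, Sized)
        return len(dataset)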

src/pytorch_lightning/strategies/sharded.py

Lines changed: 1 addition & 16 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Dict, Generator, List, Optional, Tuple, Union
+from typing import Dict, Generator, List, Tuple, Union

 from torch import Tensor
 from torch.nn import Module
@@ -27,7 +27,6 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _FAIRSCALE_OSS_FP16_BROADCAST_AVAILABLE
 from pytorch_lightning.utilities.optimizer import optimizers_to_device
-from pytorch_lightning.utilities.rank_zero import rank_zero_only

 if _FAIRSCALE_AVAILABLE:
     from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
@@ -120,20 +119,6 @@ def _reinit_optimizers_with_oss(self, optimizers: List[Union[Optimizer, Lightnin
                 del optimizer
         return optimizers

-    def optimizer_state(self, optimizer: "OSS") -> Optional[dict]:
-        if isinstance(optimizer, LightningOptimizer):
-            optimizer = optimizer._optimizer
-        optimizer.consolidate_state_dict()
-        return self._optim_state_dict(optimizer)
-
-    @rank_zero_only
-    def _optim_state_dict(self, optimizer):
-        """
-        Retrieves state dict only on rank 0, which contains the entire optimizer state after calling
-        :meth:`consolidate_state_dict`.
-        """
-        return optimizer.state_dict()
-
     def pre_backward(self, closure_loss: Tensor) -> None:
         pass

src/pytorch_lightning/strategies/sharded_spawn.py

Lines changed: 1 addition & 15 deletions
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, List, Tuple
+from typing import Dict, Generator, List, Tuple

 from torch import Tensor
 from torch.nn import Module
@@ -25,7 +25,6 @@
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE
 from pytorch_lightning.utilities.optimizer import optimizers_to_device
-from pytorch_lightning.utilities.rank_zero import rank_zero_only

 if _FAIRSCALE_AVAILABLE:
     from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
@@ -85,11 +84,6 @@ def _wrap_optimizers(self, optimizers: List[Optimizer]) -> List["OSS"]:

         return self._reinit_optimizers_with_oss(optimizers)

-    def optimizer_state(self, optimizer: "OSS") -> Dict[str, Any]:
-        if isinstance(optimizer, OSS):
-            optimizer.consolidate_state_dict()
-        return self._optim_state_dict(optimizer)
-
     @contextmanager
     def block_backward_sync(self) -> Generator:
         """Blocks syncing gradients behaviour on backwards pass.
@@ -103,14 +97,6 @@ def block_backward_sync(self) -> Generator:
         else:
             yield None

-    @rank_zero_only
-    def _optim_state_dict(self, optimizer: Optimizer) -> Dict[str, Any]:
-        """
-        Retrieves state dict only on rank 0, which contains the entire optimizer state after calling
-        :meth:`consolidate_state_dict`.
-        """
-        return optimizer.state_dict()
-
     def pre_backward(self, closure_loss: Tensor) -> None:
         pass

src/pytorch_lightning/strategies/strategy.py

Lines changed: 10 additions & 0 deletions
@@ -170,6 +170,16 @@ def optimizer_state(self, optimizer: Optimizer) -> Dict[str, Tensor]:

         Allows for syncing/collating optimizer state from processes in custom plugins.
         """
+        if isinstance(optimizer, LightningOptimizer):
+            optimizer = optimizer._optimizer
+
+        if hasattr(optimizer, "consolidate_state_dict"):
+            # there are optimizers like Fairscale's OSS or PyTorch's ZeroRedundancyOptimizer that shard their
+            # states, and to avoid OOM we consolidate the full state on rank 0 only
+            optimizer.consolidate_state_dict()
+            return optimizer.state_dict() if self.is_global_zero else {}
+
+        # for optimizers that are not sharded, we return the state dict on all ranks
         return optimizer.state_dict()

     def backward(
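The consolidation path added above is duck-typed on `consolidate_state_dict()`, so it covers both FairScale's `OSS` and PyTorch's `ZeroRedundancyOptimizer` without strategy-specific overrides (hence the deletions in `sharded.py` and `sharded_spawn.py`). A rough sketch of what those two calls amount to in plain PyTorch, assuming an initialized process group and a placeholder `model`:

    # Illustrative only: the gather-then-checkpoint pattern for a sharded optimizer,
    # shown with torch.distributed.optim.ZeroRedundancyOptimizer.
    import torch
    import torch.distributed as dist
    from torch.distributed.optim import ZeroRedundancyOptimizer

    optimizer = ZeroRedundancyOptimizer(model.parameters(), optimizer_class=torch.optim.Adam, lr=0.1)

    # Each rank holds only a shard of the optimizer state, so the full state is
    # gathered onto one rank before it can be saved.
    optimizer.consolidate_state_dict()  # gathers shards onto rank 0 by default
    state = optimizer.state_dict() if dist.get_rank() == 0 else {}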

src/pytorch_lightning/tuner/batch_size_scaling.py

Lines changed: 17 additions & 2 deletions
@@ -128,7 +128,10 @@ def _run_power_scaling(
     """Batch scaling mode where the size is doubled at each iteration until an OOM error is encountered."""
     for _ in range(max_trials):
         garbage_collection_cuda()
-        trainer.fit_loop.global_step = 0  # reset after each try
+
+        # reset after each try
+        _reset_progress(trainer)
+
         try:
             # Try fit
             trainer.tuner._run(model)
@@ -166,7 +169,10 @@ def _run_binsearch_scaling(
     count = 0
     while True:
         garbage_collection_cuda()
-        trainer.fit_loop.global_step = 0  # reset after each try
+
+        # reset after each try
+        _reset_progress(trainer)
+
         try:
             # Try fit
             trainer.tuner._run(model)
@@ -249,3 +255,12 @@ def _adjust_batch_size(
 def _is_valid_batch_size(batch_size: int, dataloader: DataLoader, trainer: "pl.Trainer"):
     module = trainer.lightning_module or trainer.datamodule
     return not has_len_all_ranks(dataloader, trainer.strategy, module) or batch_size <= len(dataloader)
+
+
+def _reset_progress(trainer: "pl.Trainer") -> None:
+    if trainer.lightning_module.automatic_optimization:
+        trainer.fit_loop.epoch_loop.batch_loop.optimizer_loop.optim_progress.reset()
+    else:
+        trainer.fit_loop.epoch_loop.batch_loop.manual_loop.optim_step_progress.reset()
+
+    trainer.fit_loop.epoch_progress.reset()
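In short, the scaler previously reset only `global_step` between trials, so epoch and optimizer progress counters kept accumulating and later trials could stop early; `_reset_progress` now clears those counters as well (the `#13846` changelog entry above). A hedged usage sketch of the tuner path that exercises this, with a hypothetical `DemoBatchSizeModel` exposing a tunable `batch_size` attribute:

    import pytorch_lightning as pl

    model = DemoBatchSizeModel(batch_size=4)  # hypothetical LightningModule with self.batch_size
    trainer = pl.Trainer(max_epochs=1, auto_scale_batch_size="binsearch")

    # Each trial now starts from freshly reset epoch/optimizer progress, so every trial
    # runs its full `steps_per_trial` budget before the batch size is adjusted.
    result = trainer.tune(model, scale_batch_size_kwargs={"max_trials": 5, "steps_per_trial": 2})
    print(result["scale_batch_size"])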

tests/tests_pytorch/strategies/test_ddp_strategy.py

Lines changed: 53 additions & 0 deletions
@@ -24,8 +24,14 @@
 from pytorch_lightning.plugins.environments import ClusterEnvironment, LightningEnvironment
 from pytorch_lightning.strategies import DDPStrategy
 from pytorch_lightning.trainer.states import TrainerFn
+from pytorch_lightning.utilities.imports import _FAIRSCALE_AVAILABLE, _TORCH_GREATER_EQUAL_1_10
 from tests_pytorch.helpers.runif import RunIf

+if _FAIRSCALE_AVAILABLE:
+    from fairscale.optim import OSS
+if _TORCH_GREATER_EQUAL_1_10:
+    from torch.distributed.optim import ZeroRedundancyOptimizer
+

 class BoringModelGPU(BoringModel):
     def on_train_start(self) -> None:
@@ -252,3 +258,50 @@ def test_ddp_strategy_set_timeout(mock_init_process_group):
     mock_init_process_group.assert_called_with(
         process_group_backend, rank=global_rank, world_size=world_size, timeout=test_timedelta
     )
+
+
+class BoringFairScaleOptimizerModel(BoringModel):
+    def configure_optimizers(self):
+        base_optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+        return OSS(params=base_optimizer.param_groups, optim=type(base_optimizer), **base_optimizer.defaults)
+
+
+@RunIf(min_cuda_gpus=2, skip_windows=True, fairscale=True)
+@pytest.mark.parametrize("strategy", (pytest.param("ddp", marks=RunIf(standalone=True)), "ddp_spawn"))
+def test_ddp_strategy_checkpoint_multi_gpu_fairscale_optimizer(tmpdir, strategy):
+    """Test to ensure that checkpoint is saved correctly when using fairscale optimizer."""
+    model = BoringFairScaleOptimizerModel()
+    trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, max_steps=1)
+
+    trainer.fit(model)
+
+    checkpoint_path = os.path.join(tmpdir, "model.pt")
+    trainer.save_checkpoint(checkpoint_path)
+    saved_model = BoringModel.load_from_checkpoint(checkpoint_path)
+
+    # Assert model parameters are identical after loading
+    for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()):
+        assert torch.equal(trained_param.to("cpu"), loaded_param)
+
+
+class BoringZeroRedundancyOptimizerModel(BoringModel):
+    def configure_optimizers(self):
+        return ZeroRedundancyOptimizer(self.layer.parameters(), optimizer_class=torch.optim.Adam, lr=0.1)
+
+
+@RunIf(min_cuda_gpus=2, skip_windows=True, min_torch="1.10")
+@pytest.mark.parametrize("strategy", (pytest.param("ddp", marks=RunIf(standalone=True)), "ddp_spawn"))
+def test_ddp_strategy_checkpoint_zero_redundancy_optimizer(tmpdir, strategy):
+    """Test to ensure that checkpoint is saved correctly when using zero redundancy optimizer."""
+    model = BoringZeroRedundancyOptimizerModel()
+    trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, max_steps=1)
+
+    trainer.fit(model)
+
+    checkpoint_path = os.path.join(tmpdir, "model.pt")
+    trainer.save_checkpoint(checkpoint_path)
+    saved_model = BoringModel.load_from_checkpoint(checkpoint_path)
+
+    # Assert model parameters are identical after loading
+    for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()):
+        assert torch.equal(trained_param.to("cpu"), loaded_param)

tests/tests_pytorch/strategies/test_sharded_strategy.py

Lines changed: 29 additions & 4 deletions
@@ -14,6 +14,7 @@

 if _FAIRSCALE_AVAILABLE:
     from fairscale.nn.data_parallel.sharded_ddp import ShardedDataParallel
+    from fairscale.optim import OSS


 @pytest.mark.parametrize("clip_val", [0, 10])
@@ -70,8 +71,8 @@ def test_ddp_sharded_strategy_checkpoint_cpu(tmpdir):
     saved_model = BoringModel.load_from_checkpoint(checkpoint_path)

     # Assert model parameters are identical after loading
-    for ddp_param, shard_param in zip(model.parameters(), saved_model.parameters()):
-        assert torch.equal(ddp_param.to("cpu"), shard_param)
+    for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()):
+        assert torch.equal(trained_param.to("cpu"), loaded_param)


 @RunIf(min_cuda_gpus=2, skip_windows=True, fairscale=True)
@@ -87,8 +88,8 @@ def test_ddp_sharded_strategy_checkpoint_multi_gpu(tmpdir):
     saved_model = BoringModel.load_from_checkpoint(checkpoint_path)

     # Assert model parameters are identical after loading
-    for ddp_param, shard_param in zip(model.parameters(), saved_model.parameters()):
-        assert torch.equal(ddp_param.to("cpu"), shard_param)
+    for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()):
+        assert torch.equal(trained_param.to("cpu"), loaded_param)


 @RunIf(min_cuda_gpus=2, skip_windows=True, fairscale=True)
@@ -314,3 +315,27 @@ def test_block_backward_sync():
 def test_ddp_kwargs_from_registry(strategy_name, expected_ddp_kwargs):
     trainer = Trainer(strategy=strategy_name)
     assert trainer.strategy._ddp_kwargs == expected_ddp_kwargs
+
+
+class BoringFairScaleOptimizerModel(BoringModel):
+    def configure_optimizers(self):
+        base_optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.1)
+        return OSS(params=base_optimizer.param_groups, optim=type(base_optimizer), **base_optimizer.defaults)
+
+
+@RunIf(min_cuda_gpus=2, skip_windows=True, fairscale=True)
+@pytest.mark.parametrize("strategy", (pytest.param("ddp_sharded", marks=RunIf(standalone=True)), "ddp_sharded_spawn"))
+def test_ddp_sharded_strategy_checkpoint_multi_gpu_fairscale_optimizer(tmpdir, strategy):
+    """Test to ensure that checkpoint is saved correctly when using fairscale optimizers."""
+    model = BoringFairScaleOptimizerModel()
+    trainer = Trainer(accelerator="gpu", devices=2, strategy=strategy, max_steps=1)
+
+    trainer.fit(model)
+
+    checkpoint_path = os.path.join(tmpdir, "model.pt")
+    trainer.save_checkpoint(checkpoint_path)
+    saved_model = BoringModel.load_from_checkpoint(checkpoint_path)
+
+    # Assert model parameters are identical after loading
+    for trained_param, loaded_param in zip(model.parameters(), saved_model.parameters()):
+        assert torch.equal(trained_param.to("cpu"), loaded_param)

tests/tests_pytorch/tuner/test_scale_batch_size.py

Lines changed: 7 additions & 3 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.
 import os
 from copy import deepcopy
+from unittest.mock import patch

 import pytest
 import torch
@@ -308,10 +309,13 @@ def __init__(self):
 def test_dataloader_reset_with_scale_batch_size(tmpdir, scale_method):
     """Test that train and val dataloaders are reset at every update in scale batch size."""
     model = BatchSizeModel(batch_size=16)
-    scale_batch_size_kwargs = {"max_trials": 5, "init_val": 4, "mode": scale_method}
+    max_trials = 5
+    scale_batch_size_kwargs = {"max_trials": max_trials, "steps_per_trial": 2, "init_val": 4, "mode": scale_method}

-    trainer = Trainer(max_epochs=2, auto_scale_batch_size=True)
-    new_batch_size = trainer.tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs)["scale_batch_size"]
+    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, auto_scale_batch_size=True)
+    with patch.object(model, "on_train_epoch_end") as advance_mocked:
+        new_batch_size = trainer.tune(model, scale_batch_size_kwargs=scale_batch_size_kwargs)["scale_batch_size"]
+        assert advance_mocked.call_count == max_trials

     assert trainer.train_dataloader.loaders.batch_size == new_batch_size
     assert trainer.val_dataloaders[0].batch_size == new_batch_size
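The mocked hook here serves as a trial counter: `on_train_epoch_end` fires once per tuning trial, so asserting on `call_count` checks that all `max_trials` trials actually ran after the progress reset. A tiny self-contained illustration of the `patch.object` counting idiom (hypothetical names):

    from unittest.mock import patch

    class Worker:
        def on_done(self) -> None:
            pass

        def run(self, trials: int) -> None:
            for _ in range(trials):
                self.on_done()

    worker = Worker()
    with patch.object(worker, "on_done") as mocked_hook:
        worker.run(trials=5)
    assert mocked_hook.call_count == 5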
