
Commit e7134a9

SeanNaren and chaton authored
Sharded Plugin 2/n: Allow ddp plugin to modify optimizer state saving (#4675)
* Allow ddp plugin to modify optimizer state saving
* Rely on the accelerator for optimizer states
* Ensure we init the accelerator for the saving function
* Better comment for optim state dump
* Revert "Ensure we init the accelerator for the saving function" (this reverts commit af65eff)
* Added accelerator check to initialize tuner before saving model checkpoint
* Simplify comment
* Revert "Added accelerator check to initialize tuner before saving model checkpoint" (this reverts commit f9929c0)
* Return single optimizer state to reduce duplication
* Fixed docstring
* Fixed typing
* Fixed comment
* Added CHANGELOG.md

Co-authored-by: chaton <[email protected]>
1 parent 8283680 commit e7134a9

4 files changed (+24, -3 lines)

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -36,6 +36,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
     [#4439](https://github.com/PyTorchLightning/pytorch-lightning/pull/4439))
 
 
+- Added ability for DDP plugin to modify optimizer state saving ([#4675](https://github.com/PyTorchLightning/pytorch-lightning/pull/4675))
+
+
 ### Changed
 
 - Tuner algorithms will be skipped if `fast_dev_run=True` ([#3903](https://github.com/PyTorchLightning/pytorch-lightning/pull/3903))

pytorch_lightning/accelerators/accelerator.py

Lines changed: 12 additions & 1 deletion

@@ -13,7 +13,7 @@
 # limitations under the License.
 import os
 from enum import Enum
-from typing import Any, Optional, Union
+from typing import Any, Optional, Union, List
 
 import torch
 from torch.optim import Optimizer

@@ -202,6 +202,17 @@ def sync_tensor(self,
         """
         raise NotImplementedError()
 
+    def optimizer_state(self, optimizer: Optimizer) -> dict:
+        """
+        Returns state of an optimizer. Allows for syncing/collating optimizer state from processes in custom
+        plugins.
+        Return:
+            Optimizer state dict
+        """
+        if self.ddp_plugin:
+            return self.ddp_plugin.optimizer_state(optimizer)
+        return optimizer.state_dict()
+
     def __getstate__(self):
         return {
             'trainer': self.trainer,
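
For readers following the change outside the Lightning codebase: the new accelerator hook is a plain delegation. The accelerator asks its DDP plugin for the optimizer state when one is configured and falls back to optimizer.state_dict() otherwise. Below is a minimal, self-contained sketch of that pattern; ToyPlugin and ToyAccelerator are illustrative stand-ins, not Lightning classes.

import torch
from torch.optim import SGD, Optimizer


class ToyPlugin:
    # Stand-in for DDPPlugin: may post-process optimizer state before saving.
    def optimizer_state(self, optimizer: Optimizer) -> dict:
        return optimizer.state_dict()


class ToyAccelerator:
    # Stand-in for Accelerator: delegate to the plugin when one is set.
    def __init__(self, ddp_plugin=None):
        self.ddp_plugin = ddp_plugin

    def optimizer_state(self, optimizer: Optimizer) -> dict:
        if self.ddp_plugin:
            return self.ddp_plugin.optimizer_state(optimizer)
        return optimizer.state_dict()


model = torch.nn.Linear(4, 2)
opt = SGD(model.parameters(), lr=0.1)

# Without a plugin the accelerator returns the plain optimizer state dict;
# with one, the call is routed through the plugin instead.
assert ToyAccelerator().optimizer_state(opt) == opt.state_dict()
assert ToyAccelerator(ToyPlugin()).optimizer_state(opt) == opt.state_dict()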

pytorch_lightning/plugins/ddp_plugin.py

Lines changed: 5 additions & 0 deletions

@@ -1,5 +1,7 @@
 from typing import List, Dict, Any
 
+from torch.optim import Optimizer
+
 from pytorch_lightning.core.lightning import LightningModule
 from pytorch_lightning.overrides.data_parallel import LightningDistributedDataParallel
 

@@ -80,3 +82,6 @@ def on_before_forward(self, model, *args):
         Returns: args moved to correct device if needed.
         """
         return args
+
+    def optimizer_state(self, optimizer: Optimizer) -> dict:
+        return optimizer.state_dict()
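
The base implementation keeps existing behaviour, but a custom plugin can now override this hook, for example to gather a sharded optimizer's state onto one process before it is written to the checkpoint. A hedged sketch follows; it assumes the wrapped optimizer exposes a consolidate_state_dict() method (as fairscale's OSS optimizer does, for instance), and ShardedCheckpointDDPPlugin is an illustrative name, not a Lightning class.

from torch.optim import Optimizer

from pytorch_lightning.plugins.ddp_plugin import DDPPlugin


class ShardedCheckpointDDPPlugin(DDPPlugin):
    # Illustrative only: collect sharded optimizer state before checkpointing.
    def optimizer_state(self, optimizer: Optimizer) -> dict:
        if hasattr(optimizer, "consolidate_state_dict"):
            # Assumed API: gathers the full state (e.g. onto rank 0) so that
            # state_dict() below returns the complete, unsharded state.
            optimizer.consolidate_state_dict()
        return optimizer.state_dict()

In this version of Lightning such a plugin would typically be handed to the Trainer through its plugins argument, e.g. Trainer(accelerator='ddp', plugins=[ShardedCheckpointDDPPlugin()]), though the exact wiring depends on the surrounding setup.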

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 4 additions & 2 deletions

@@ -298,10 +298,12 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
             callback_states = self.trainer.on_save_checkpoint()
             checkpoint['callbacks'] = callback_states
 
-            # dump optimizers
             optimizer_states = []
             for i, optimizer in enumerate(self.trainer.optimizers):
-                optimizer_states.append(optimizer.state_dict())
+                # Rely on accelerator to dump optimizer state
+                optimizer_state = self.trainer.accelerator_backend.optimizer_state(optimizer)
+                optimizer_states.append(optimizer_state)
+
             checkpoint['optimizer_states'] = optimizer_states
 
             # dump lr schedulers
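
The shape of the saved checkpoint is unchanged in the default case: checkpoint['optimizer_states'] remains a list with one state dict per optimizer, in the same order as trainer.optimizers, so restoring works as before. A minimal standalone round trip with plain torch (file name illustrative):

import torch
from torch.optim import SGD

model = torch.nn.Linear(4, 2)
optimizer = SGD(model.parameters(), lr=0.1, momentum=0.9)

# Saving: one entry per optimizer, mirroring dump_checkpoint above.
checkpoint = {"optimizer_states": [optimizer.state_dict()]}
torch.save(checkpoint, "example.ckpt")

# Restoring: load each saved state back into the matching optimizer.
restored = torch.load("example.ckpt")
for opt, opt_state in zip([optimizer], restored["optimizer_states"]):
    opt.load_state_dict(opt_state)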
