Skip to content

Commit 5afe552

Browse files
committed
Fixes
1 parent dc28ad0 commit 5afe552

File tree

4 files changed

+11
-10
lines changed

4 files changed

+11
-10
lines changed

pytorch_lightning/core/optimizer.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,8 +54,8 @@ def __init__(self, optimizer: Optimizer):
         self._strategy: Optional[pl.strategies.Strategy] = None
         self._optimizer_idx = 0
         # to inject logic around the optimizer step, particularly useful with manual optimization
-        self.on_before_step = do_nothing_closure
-        self.on_after_step = do_nothing_closure
+        self._on_before_step = do_nothing_closure
+        self._on_after_step = do_nothing_closure

     @property
     def optimizer(self) -> Optimizer:
@@ -157,7 +157,7 @@ def closure_dis():
             with opt_dis.toggle_model(sync_grad=accumulated_grad_batches):
                 opt_dis.step(closure=closure_dis)
         """
-        self.on_before_step()
+        self._on_before_step()

         if closure is None:
             closure = do_nothing_closure
@@ -173,7 +173,7 @@ def closure_dis():
         with self._strategy.lightning_module.trainer.profiler.profile(profiler_action):
             step_output = self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)

-        self.on_after_step()
+        self._on_after_step()

         return step_output

pytorch_lightning/loops/optimization/manual_loop.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,8 +94,8 @@ def reset(self) -> None:
     def on_run_start(self, *_: Any, **__: Any) -> None:
         # inject logic around the optimizer step
         for i, lightning_optimizer in self.trainer.strategy._lightning_optimizers.items():
-            lightning_optimizer.on_before_step = self._on_before_step
-            lightning_optimizer.on_after_step = self._on_after_step
+            lightning_optimizer._on_before_step = self._on_before_step
+            lightning_optimizer._on_after_step = self._on_after_step

     def advance(self, batch: Any, batch_idx: int) -> None:  # type: ignore[override]
         """Performs the training step for manual optimization.
@@ -140,8 +140,8 @@ def on_run_end(self) -> _OUTPUTS_TYPE:
         output, self._output = self._output, {}  # free memory
         # reset logic around the optimizer step
         for i, lightning_optimizer in self.trainer.strategy._lightning_optimizers.items():
-            lightning_optimizer.on_before_step = do_nothing_closure
-            lightning_optimizer.on_after_step = do_nothing_closure
+            lightning_optimizer._on_before_step = do_nothing_closure
+            lightning_optimizer._on_after_step = do_nothing_closure
        return output

    def _on_before_step(self) -> None:

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
        checkpoint = {
            # the epoch is saved for compatibility but it's not relevant for restoration
            "epoch": self.trainer.current_epoch,
-           "global_step": self.trainer.global_step + 1,
+           "global_step": self.trainer.global_step + model.automatic_optimization,
            "pytorch-lightning_version": pl.__version__,
            "state_dict": self._get_lightning_module_state_dict(),
            "loops": self._get_loops_state_dict(),

tests/core/test_lightning_optimizer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,8 @@ def test_state(tmpdir):
        lightning_dict = {
            k: v
            for k, v in lightning_optimizer.__dict__.items()
-           if k not in {"_optimizer", "_optimizer_idx", "_strategy", "_lightning_module"}
+           if k
+           not in {"_optimizer", "_optimizer_idx", "_strategy", "_lightning_module", "_on_before_step", "_on_after_step"}
        }

        assert lightning_dict == optimizer.__dict__

0 commit comments

Comments
 (0)