
Commit cfad43f

Fix and CHANGELOG
1 parent 70fd1dd commit cfad43f


2 files changed: 9 additions & 1 deletion


CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -77,6 +77,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `opt_idx` to scheduler config if not assigned by user ([#11247](https://github.com/PyTorchLightning/pytorch-lightning/pull/11247))
 
+- Added support for optimizer step progress tracking with manual optimization ([#11848](https://github.com/PyTorchLightning/pytorch-lightning/pull/11848))
+
+
 
 - Return the output of the `optimizer.step`. This can be useful for `LightningLite` users, manual optimization users, or users overriding `LightningModule.optimizer_step` ([#11711](https://github.com/PyTorchLightning/pytorch-lightning/pull/11711))
 
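For context on the two entries above, here is a minimal sketch (not part of this commit) of what they enable for manual optimization users: `optimizer.step` forwards the return value of the wrapped `torch.optim.Optimizer.step` (#11711), and the manual optimization loop counts steps in the `optim_step_progress` tracker added in the file below (#11848). The `ManualOptModel` class, the LBFGS choice, and the logging call are illustrative assumptions, not code from this repository.

```python
import torch
from pytorch_lightning import LightningModule


class ManualOptModel(LightningModule):
    """Illustrative module using manual optimization (hypothetical, not from this commit)."""

    def __init__(self):
        super().__init__()
        self.automatic_optimization = False  # switch to manual optimization
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        opt = self.optimizers()

        def closure():
            opt.zero_grad()
            loss = self.layer(batch).sum()
            self.manual_backward(loss)
            return loss

        # With #11711, `step` returns whatever the underlying torch optimizer returns,
        # e.g. the closure's loss for LBFGS, instead of discarding it.
        loss = opt.step(closure=closure)
        self.log("train_loss", loss)

    def configure_optimizers(self):
        return torch.optim.LBFGS(self.parameters(), lr=0.1)
```

After `trainer.fit(ManualOptModel(), ...)`, the step counts recorded by #11848 should be readable from the manual loop's `optim_step_progress` (e.g. `trainer.fit_loop.epoch_loop.batch_loop.manual_loop.optim_step_progress.total.completed`), assuming the loop layout of this release line.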

pytorch_lightning/loops/optimization/manual_loop.py

Lines changed: 6 additions & 1 deletion
@@ -16,6 +16,7 @@
 
 from torch import Tensor
 
+from pytorch_lightning.core.optimizer import do_nothing_closure
 from pytorch_lightning.loops import Loop
 from pytorch_lightning.loops.optimization.closure import OutputResult
 from pytorch_lightning.loops.utilities import _build_training_step_kwargs, _extract_hiddens
@@ -76,7 +77,7 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):
     def __init__(self) -> None:
         super().__init__()
         # since manual optimization does not track lr scheduler or optimizer frequencies, we use a simpler progress than
-        # `OptimizerProgress`
+        # `OptimizationProgress`
         self.optim_step_progress = Progress.from_defaults(ReadyCompletedTracker)
 
         self._done: bool = False
@@ -137,6 +138,10 @@ def advance(self, batch: Any, batch_idx: int) -> None: # type: ignore[override]
     def on_run_end(self) -> _OUTPUTS_TYPE:
         """Returns the result of this loop, i.e., the post-processed outputs from the training step."""
         output, self._output = self._output, {}  # free memory
+        # reset logic around the optimizer step
+        for i, lightning_optimizer in self.trainer.strategy._lightning_optimizers.items():
+            lightning_optimizer.on_before_step = do_nothing_closure
+            lightning_optimizer.on_after_step = do_nothing_closure
         return output
 
     def _on_before_step(self) -> None:
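
The hunk above restores `do_nothing_closure` on each `LightningOptimizer` once the loop finishes, so the progress hooks only fire while manual optimization is running. The diff cuts off at `_on_before_step`, so the following is only a sketch of how the new tracker is presumably driven around each `optimizer.step`; the hook bodies are an assumption, and only the `Progress` / `ReadyCompletedTracker` construction comes from the diff.

```python
from pytorch_lightning.trainer.progress import Progress, ReadyCompletedTracker

# Stand-in for the `optim_step_progress` attribute created in `__init__` above.
optim_step_progress = Progress.from_defaults(ReadyCompletedTracker)

# Presumed effect of the `_on_before_step` / `_on_after_step` hooks that the loop
# binds to `lightning_optimizer.on_before_step` / `.on_after_step` while it runs:
optim_step_progress.increment_ready()      # just before `optimizer.step`
optim_step_progress.increment_completed()  # right after `optimizer.step` returns

assert optim_step_progress.current.ready == 1
assert optim_step_progress.current.completed == 1
```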
