
Commit c43ee1f

Support optimizer step progress tracking with manual optimization
1 parent 6356ef3 commit c43ee1f

9 files changed: +67 −55 lines changed

pytorch_lightning/core/optimizer.py

Lines changed: 14 additions & 2 deletions
@@ -161,10 +161,22 @@ def closure_dis():
         profiler_action += f"_{self._optimizer_idx}"

         assert self._strategy is not None
-        assert self._strategy.lightning_module is not None
-        with self._strategy.lightning_module.trainer.profiler.profile(profiler_action):
+        lightning_module = self._strategy.lightning_module
+        assert lightning_module is not None
+        trainer = lightning_module.trainer
+        batch_loop = trainer.fit_loop.epoch_loop.batch_loop
+        if lightning_module.automatic_optimization:
+            progress = batch_loop.optimizer_loop.optim_progress
+        else:
+            progress = batch_loop.manual_loop.optim_progress
+
+        progress.optimizer.step.increment_ready()
+
+        with trainer.profiler.profile(profiler_action):
             self._strategy.optimizer_step(self._optimizer, self._optimizer_idx, closure, **kwargs)

+        progress.optimizer.step.increment_completed()
+

 def _init_optimizers_and_lr_schedulers(
     model: "pl.LightningModule",

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 4 additions & 1 deletion
@@ -89,7 +89,10 @@ def batch_idx(self) -> int:

     @property
     def global_step(self) -> int:
-        return self.batch_loop.optimizer_loop.optim_progress.optimizer_steps
+        lightning_module = self.trainer.lightning_module
+        if lightning_module is None or lightning_module.automatic_optimization:
+            return self.batch_loop.optimizer_loop.optim_progress.optimizer_steps
+        return self.batch_loop.manual_loop.optim_progress.optimizer_steps

     @property
     def _is_training_done(self) -> bool:
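A hedged usage sketch of what the dispatch above implies (reusing the illustrative `ManualOptimizationModel` from the previous sketch): after fitting a manual-optimization module that calls `opt.step()` once per batch, `trainer.global_step` is now expected to count those steps instead of staying at zero.

from pytorch_lightning import Trainer

trainer = Trainer(max_epochs=1, limit_train_batches=4, limit_val_batches=0, logger=False)
trainer.fit(ManualOptimizationModel())

# global_step is now read from batch_loop.manual_loop.optim_progress when
# automatic_optimization is False, so it reflects the opt.step() call per batch.
assert trainer.global_step == 4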

pytorch_lightning/loops/optimization/manual_loop.py

Lines changed: 4 additions & 0 deletions
@@ -19,6 +19,7 @@
 from pytorch_lightning.loops import Loop
 from pytorch_lightning.loops.optimization.closure import OutputResult
 from pytorch_lightning.loops.utilities import _build_training_step_kwargs, _extract_hiddens
+from pytorch_lightning.trainer.progress import OptimizationProgress
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.types import STEP_OUTPUT

@@ -74,6 +75,9 @@ class ManualOptimization(Loop[_OUTPUTS_TYPE]):

     def __init__(self) -> None:
         super().__init__()
+        # FIXME: should this be a simpler progress? lr schedulers are not wrapped anyways
+        self.optim_progress = OptimizationProgress()
+
         self._done: bool = False
         self._hiddens: Optional[Any] = None
         self._output: _OUTPUTS_TYPE = {}
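A hedged sketch of reading the new attribute at runtime; the attribute path follows the diffs above, and `optimizer_steps` is the same shortcut the `global_step` property uses.

# Illustrative; assumes `trainer` has finished fit() with a manual-optimization module.
manual_loop = trainer.fit_loop.epoch_loop.batch_loop.manual_loop
progress = manual_loop.optim_progress

# Total optimizer.step() calls that ran to completion under manual optimization.
print(progress.optimizer.step.total.completed)
# Convenience property read by TrainingEpochLoop.global_step in this commit.
print(progress.optimizer_steps)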

pytorch_lightning/loops/optimization/optimizer_loop.py

Lines changed: 0 additions & 4 deletions
@@ -359,8 +359,6 @@ def _optimizer_step(
         else:
             optimizer = self.trainer.strategy._lightning_optimizers[opt_idx]

-        self.optim_progress.optimizer.step.increment_ready()
-
         # model hook
         self.trainer._call_lightning_module_hook(
             "optimizer_step",
@@ -374,8 +372,6 @@ def _optimizer_step(
             using_lbfgs=is_lbfgs,
         )

-        self.optim_progress.optimizer.step.increment_completed()
-
     def _on_before_zero_grad(self, optimizer: torch.optim.Optimizer) -> None:
         """Calls the ``on_before_zero_grad`` hook.

pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 0 deletions
@@ -2010,6 +2010,7 @@ def data_parallel_device_ids(self) -> Optional[List[int]]:

     @property
     def lightning_module(self) -> "pl.LightningModule":
+        # TODO: this is actually an optional return
         return self.strategy.lightning_module

     @property

tests/loops/test_loop_state_dict.py

Lines changed: 10 additions & 0 deletions
@@ -56,6 +56,16 @@ def test_loops_state_dict_structure():
         },
         "epoch_loop.batch_loop.state_dict": {},
         "epoch_loop.batch_loop.manual_loop.state_dict": {},
+        "epoch_loop.batch_loop.manual_loop.optim_progress": {
+            "optimizer": {
+                "step": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
+                "zero_grad": {
+                    "total": {"ready": 0, "started": 0, "completed": 0},
+                    "current": {"ready": 0, "started": 0, "completed": 0},
+                },
+            },
+            "optimizer_position": 0,
+        },
         "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
         "epoch_loop.batch_loop.optimizer_loop.optim_progress": {
             "optimizer": {

tests/loops/test_loops.py

Lines changed: 20 additions & 0 deletions
@@ -512,6 +512,16 @@ def configure_optimizers_multiple(self):
             },
         "epoch_loop.batch_loop.state_dict": ANY,
         "epoch_loop.batch_loop.manual_loop.state_dict": ANY,
+        "epoch_loop.batch_loop.manual_loop.optim_progress": {
+            "optimizer_position": 0,
+            "optimizer": {
+                "step": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
+                "zero_grad": {
+                    "total": {"ready": 0, "started": 0, "completed": 0},
+                    "current": {"ready": 0, "started": 0, "completed": 0},
+                },
+            },
+        },
         "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
         "epoch_loop.batch_loop.optimizer_loop.optim_progress": {
             "optimizer_position": stop_optimizer,
@@ -680,6 +690,16 @@ def train_dataloader(self):
             },
         "epoch_loop.batch_loop.state_dict": ANY,
         "epoch_loop.batch_loop.manual_loop.state_dict": ANY,
+        "epoch_loop.batch_loop.manual_loop.optim_progress": {
+            "optimizer_position": 0,
+            "optimizer": {
+                "step": {"total": {"ready": 0, "completed": 0}, "current": {"ready": 0, "completed": 0}},
+                "zero_grad": {
+                    "total": {"ready": 0, "started": 0, "completed": 0},
+                    "current": {"ready": 0, "started": 0, "completed": 0},
+                },
+            },
+        },
         "epoch_loop.batch_loop.optimizer_loop.state_dict": {},
         "epoch_loop.batch_loop.optimizer_loop.optim_progress": {
             "optimizer_position": n_optimizers,

tests/models/test_hooks.py

Lines changed: 0 additions & 1 deletion
@@ -558,7 +558,6 @@ def training_step(self, batch, batch_idx):
         dict(name="on_validation_model_train"),
         dict(name="training_epoch_end", args=([dict(loss=ANY)] * train_batches,)),
         dict(name="Callback.on_train_epoch_end", args=(trainer, model)),
-        # FIXME: there seems to be a problem with manual here
         # `ModelCheckpoint.save_checkpoint` is called here from `Callback.on_train_epoch_end`
         dict(name="Callback.on_save_checkpoint", args=(trainer, model, saved_ckpt)),
         dict(name="on_save_checkpoint", args=(saved_ckpt,)),

tests/trainer/optimization/test_manual_optimization.py

Lines changed: 14 additions & 47 deletions
@@ -165,6 +165,7 @@ def training_epoch_end(self, outputs) -> None:
     with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) as bwd_mock:
         trainer.fit(model)
     assert bwd_mock.call_count == limit_train_batches * 3
+    assert trainer.global_step == limit_train_batches * 2


 def test_multiple_optimizers_manual_log(tmpdir):
@@ -524,18 +525,14 @@ def optimizer_closure():
             weight_after = self.layer.weight.clone()
             assert not torch.equal(weight_before, weight_after)

-        def configure_optimizers(self):
-            return torch.optim.SGD(self.layer.parameters(), lr=0.1)
-
     model = TestModel()
-    model.val_dataloader = None
     model.training_epoch_end = None

     limit_train_batches = 2
     trainer = Trainer(
         default_root_dir=tmpdir,
         limit_train_batches=limit_train_batches,
-        limit_val_batches=2,
+        limit_val_batches=0,
         max_epochs=1,
         log_every_n_steps=1,
     )
@@ -547,58 +544,45 @@ def configure_optimizers(self):
     assert trainer.progress_bar_metrics["train_loss_epoch"] == torch.stack(model._losses).mean()


-def test_step_with_optimizer_closure_and_accumulated_grad(tmpdir):
-    """Tests that `step` works with optimizer_closure and accumulated_grad."""
-
+def test_step_with_optimizer_closure_2(tmpdir):
     class TestModel(BoringModel):
         def __init__(self):
             super().__init__()
             self.automatic_optimization = False

         def training_step(self, batch, batch_idx):
-            # manual
             opt = self.optimizers()
             x = batch[0]
-
-            loss_1 = self(x)
-            loss_1 = self.loss(loss_1, loss_1)
+            loss = self(x).sum()

             def optimizer_closure():
                 # emulate bayesian optimization.
                 num_backward = 1
                 for backward_idx in range(num_backward + 1):
                     retain_graph = num_backward != backward_idx
-                    self.manual_backward(loss_1, retain_graph=retain_graph)
+                    self.manual_backward(loss, retain_graph=retain_graph)

             weight_before = self.layer.weight.clone()
-
             opt.step(closure=optimizer_closure)
-
             weight_after = self.layer.weight.clone()
-            if not self.trainer.fit_loop._should_accumulate():
-                assert not torch.equal(weight_before, weight_after)
-            else:
-                assert self.layer.weight.grad is not None
-
-        def configure_optimizers(self):
-            return torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            assert not torch.equal(weight_before, weight_after)

     model = TestModel()
-    model.val_dataloader = None
     model.training_epoch_end = None

     limit_train_batches = 4
     trainer = Trainer(
         default_root_dir=tmpdir,
         limit_train_batches=limit_train_batches,
-        limit_val_batches=2,
+        limit_val_batches=0,
         max_epochs=1,
         log_every_n_steps=1,
     )

     with mock.patch.object(Strategy, "backward", wraps=trainer.strategy.backward) as bwd_mock:
         trainer.fit(model)
     assert bwd_mock.call_count == limit_train_batches * 2
+    assert trainer.global_step == limit_train_batches


 @patch("torch.optim.SGD.step")
@@ -614,41 +598,23 @@ def on_train_start(self) -> None:
             step_mock.reset_mock()

         def training_step(self, batch, batch_idx):
-            # manual
             opt = self.optimizers()
-            x = batch[0]
-
-            loss_1 = self(x)
-            loss_1 = self.loss(loss_1, loss_1)
-
-            def optimizer_closure():
-                # emulate bayesian optimization.
-                num_backward = 1
-                for backward_idx in range(num_backward + 1):
-                    retain_graph = num_backward != backward_idx
-                    self.manual_backward(loss_1, retain_graph=retain_graph)
-
-            opt.step(closure=optimizer_closure)
-            opt.zero_grad()
-
-        def configure_optimizers(self):
-            return torch.optim.SGD(self.layer.parameters(), lr=0.1)
+            opt.step(closure=lambda: ..., foo=123)

     model = TestModel()
-    model.val_dataloader = None
     model.training_epoch_end = None

-    limit_train_batches = 4
+    limit_train_batches = 2
     trainer = Trainer(
         default_root_dir=tmpdir,
         limit_train_batches=limit_train_batches,
-        limit_val_batches=2,
+        limit_val_batches=0,
         max_epochs=1,
-        log_every_n_steps=1,
     )

     trainer.fit(model)
-    assert step_mock.mock_calls == [call(closure=ANY) for _ in range(limit_train_batches)]
+    assert step_mock.mock_calls == [call(closure=ANY, foo=123) for _ in range(limit_train_batches)]
+    assert trainer.global_step == limit_train_batches


 @patch("torch.optim.Adam.step")
@@ -724,6 +690,7 @@ def configure_optimizers(self):
     trainer.fit(model)
     assert mock_sgd_step.mock_calls == [call(closure=ANY, optim="sgd") for _ in range(4)]
     assert mock_adam_step.mock_calls == [call(closure=ANY) for _ in range(2)]
+    assert trainer.global_step == 4 + 2


 class TesManualOptimizationDDPModel(BoringModel):
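The updated `@patch("torch.optim.SGD.step")` test above also checks that extra keyword arguments passed to `LightningOptimizer.step` (here `foo=123`) are forwarded to the wrapped optimizer's `step`. Since the test replaces `SGD.step` with a mock, a hedged sketch of an optimizer that would actually accept such a keyword looks like this (the subclass and the `foo` name are illustrative):

import torch


class SGDWithExtraArg(torch.optim.SGD):  # illustrative subclass, not from this commit
    def step(self, closure=None, foo=None):
        # `foo` arrives here when training_step calls opt.step(closure=..., foo=123).
        print("received foo =", foo)
        return super().step(closure=closure)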
