Skip to content

Commit f663cfd

Browse files
committed
Integrate global step with progress tracking
1 parent b687fd1 commit f663cfd

File tree

17 files changed

+87
-98
lines changed

17 files changed

+87
-98
lines changed

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 7 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def __init__(
222222
self.save_weights_only = save_weights_only
223223
self.auto_insert_metric_name = auto_insert_metric_name
224224
self._save_on_train_epoch_end = save_on_train_epoch_end
225-
self._last_global_step_saved = -1
225+
self._last_global_step_saved = 0 # no need to save when no steps were taken
226226
self._last_time_checked: Optional[float] = None
227227
self.current_score = None
228228
self.best_k_models = {}
@@ -275,8 +275,7 @@ def on_train_batch_end(
275275
"""Save checkpoint on train batch end if we meet the criteria for `every_n_train_steps`"""
276276
if self._should_skip_saving_checkpoint(trainer):
277277
return
278-
step = trainer.global_step
279-
skip_batch = self._every_n_train_steps < 1 or ((step + 1) % self._every_n_train_steps != 0)
278+
skip_batch = self._every_n_train_steps < 1 or (trainer.global_step % self._every_n_train_steps != 0)
280279

281280
train_time_interval = self._train_time_interval
282281
skip_time = True
@@ -297,16 +296,13 @@ def on_train_batch_end(
297296

298297
def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
299298
"""Save a checkpoint at the end of the training epoch."""
300-
# as we advance one step at end of training, we use `global_step - 1` to avoid saving duplicates
301-
trainer.fit_loop.global_step -= 1
302299
if (
303300
not self._should_skip_saving_checkpoint(trainer)
304301
and self._save_on_train_epoch_end
305302
and self._every_n_epochs > 0
306303
and (trainer.current_epoch + 1) % self._every_n_epochs == 0
307304
):
308305
self.save_checkpoint(trainer)
309-
trainer.fit_loop.global_step += 1
310306

311307
def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
312308
"""Save a checkpoint at the end of the validation stage."""
@@ -329,11 +325,8 @@ def on_train_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -
329325
return
330326
if self.verbose:
331327
rank_zero_info("Saving latest checkpoint...")
332-
# as we advance one step at end of training, we use `global_step - 1` to avoid saving duplicates
333-
monitor_candidates = self._monitor_candidates(trainer, trainer.current_epoch, trainer.global_step - 1)
334-
trainer.fit_loop.global_step -= 1
328+
monitor_candidates = self._monitor_candidates(trainer, trainer.current_epoch, trainer.global_step)
335329
self._save_last_checkpoint(trainer, monitor_candidates)
336-
trainer.fit_loop.global_step += 1
337330

338331
def on_save_checkpoint(
339332
self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", checkpoint: Dict[str, Any]
@@ -368,12 +361,8 @@ def save_checkpoint(self, trainer: "pl.Trainer") -> None:
368361
"""
369362
self._validate_monitor_key(trainer)
370363

371-
# track epoch when ckpt was last checked
372-
global_step = trainer.global_step
373-
self._last_global_step_saved = global_step
374-
375364
# what can be monitored
376-
monitor_candidates = self._monitor_candidates(trainer, epoch=trainer.current_epoch, step=global_step)
365+
monitor_candidates = self._monitor_candidates(trainer, epoch=trainer.current_epoch, step=trainer.global_step)
377366

378367
# callback supports multiple simultaneous modes
379368
# here we call each mode sequentially
@@ -638,6 +627,7 @@ def _monitor_candidates(self, trainer: "pl.Trainer", epoch: int, step: int) -> D
638627
def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
639628
if not self.save_last:
640629
return
630+
self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step)
641631

642632
filepath = self.format_checkpoint_name(monitor_candidates, self.CHECKPOINT_NAME_LAST)
643633
# set the last model path before saving because it will be part of the state.
@@ -649,9 +639,9 @@ def _save_last_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[
649639
def _save_top_k_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
650640
if self.monitor is None or self.save_top_k == 0:
651641
return
642+
self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step)
652643

653644
current = monitor_candidates.get(self.monitor)
654-
655645
if self.check_monitor_top_k(trainer, current):
656646
self._update_best_and_save(current, trainer, monitor_candidates)
657647
elif self.verbose:
@@ -662,6 +652,7 @@ def _save_top_k_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict
662652
def _save_none_monitor_checkpoint(self, trainer: "pl.Trainer", monitor_candidates: Dict[str, _METRIC]) -> None:
663653
if self.monitor is not None or self.save_top_k == 0:
664654
return
655+
self._last_global_step_saved = monitor_candidates.get("step", trainer.global_step)
665656

666657
filepath = self._get_metric_interpolated_filepath_name(monitor_candidates, trainer)
667658
# set the best model path before saving because it will be part of the state.

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ def __init__(self, min_steps: Optional[int] = None, max_steps: int = -1) -> None
5959
self.min_steps = min_steps
6060
self.max_steps = max_steps
6161

62-
self.global_step: int = 0
6362
self.batch_progress = BatchProgress()
6463
self.scheduler_progress = SchedulerProgress()
6564

@@ -72,6 +71,7 @@ def __init__(self, min_steps: Optional[int] = None, max_steps: int = -1) -> None
7271
self._dataloader_iter: Optional[Iterator] = None
7372
# caches the loaded dataloader state until dataloader objects are available
7473
self._dataloader_state_dict: Dict[str, Any] = {}
74+
self._legacy_global_step: int = 0
7575

7676
@property
7777
def total_batch_idx(self) -> int:
@@ -87,6 +87,13 @@ def batch_idx(self) -> int:
8787
# but before the next `ready` increase
8888
return self.batch_progress.current.ready - 1
8989

90+
@property
91+
def global_step(self) -> int:
92+
lightning_module = self.trainer.lightning_module
93+
if lightning_module is None or lightning_module.automatic_optimization:
94+
return self.batch_loop.optimizer_loop.optim_progress.optimizer_steps
95+
return self.batch_loop.manual_loop.optim_step_progress.total.completed
96+
9097
@property
9198
def _is_training_done(self) -> bool:
9299
max_steps_reached = _is_max_limit_reached(self.global_step, self.max_steps)
@@ -253,9 +260,9 @@ def on_advance_end(self) -> None:
253260
# update plateau LR scheduler after metrics are logged
254261
self.update_lr_schedulers("step", update_plateau_schedulers=True)
255262

256-
if not self._should_accumulate():
257-
# progress global step according to grads progress
258-
self.global_step += 1
263+
if self._should_accumulate():
264+
# this is increased once per batch disregarding multiple optimizers or tbptt on purpose for loggers
265+
self._legacy_global_step += 1
259266

260267
# if training finished, defer exit to the parent. this assumes there will be enough time in between
261268
# which might not be the case depending on what's in the `*_epoch_end` hooks

pytorch_lightning/loops/fit_loop.py

Lines changed: 5 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -68,16 +68,6 @@ def __init__(
6868
self._outputs: _EPOCH_OUTPUTS_TYPE = []
6969
self._data_fetcher: Optional[AbstractDataFetcher] = None
7070

71-
@property
72-
def global_step(self) -> int:
73-
"""Returns the global step."""
74-
return self.epoch_loop.global_step
75-
76-
@global_step.setter
77-
def global_step(self, value: int) -> None:
78-
"""Sets the global step (forwards to epoch_loop)"""
79-
self.epoch_loop.global_step = value
80-
8171
@property
8272
def total_batch_idx(self) -> int:
8373
"""Returns the current batch index (across epochs)"""
@@ -168,16 +158,16 @@ def _results(self) -> _ResultCollection:
168158
def done(self) -> bool:
169159
"""Evaluates when to leave the loop."""
170160
# TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
171-
stop_steps = _is_max_limit_reached(self.global_step, self.max_steps)
172161
# `processed` is increased before `on_train_epoch_end`, the hook where checkpoints are typically saved.
173162
# we use it here because the checkpoint data won't have `completed` increased yet
163+
stop_steps = _is_max_limit_reached(self.epoch_loop.global_step, self.max_steps)
174164
stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)
175165

176166
should_stop = False
177167
if self.trainer.should_stop:
178168
# early stopping
179169
met_min_epochs = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True
180-
met_min_steps = self.global_step >= self.min_steps if self.min_steps else True
170+
met_min_steps = self.epoch_loop.global_step >= self.min_steps if self.min_steps else True
181171
if met_min_epochs and met_min_steps:
182172
should_stop = True
183173
else:
@@ -308,14 +298,13 @@ def on_advance_end(self) -> None:
308298

309299
self.epoch_progress.increment_completed()
310300

311-
# the global step is manually decreased here due to backwards compatibility with existing loggers
301+
# the legacy global step is manually decreased here due to backwards compatibility with existing loggers
312302
# as they expect that the same step is used when logging epoch end metrics even when the batch loop has
313303
# finished. this means the attribute does not exactly track the number of optimizer steps applied.
314-
# TODO(@carmocca): deprecate and rename so users don't get confused
315-
self.global_step -= 1
304+
self.epoch_loop._legacy_global_step -= 1
316305
# log epoch metrics
317306
self.trainer.logger_connector.update_train_epoch_metrics()
318-
self.global_step += 1
307+
self.epoch_loop._legacy_global_step += 1
319308

320309
# if fault tolerant is enabled and process has been notified, exit.
321310
self.trainer._exit_gracefully_on_signal()

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -221,16 +221,20 @@ def restore_loops(self) -> None:
221221
if not self._loaded_checkpoint:
222222
return
223223

224-
self.trainer.fit_loop.global_step = self._loaded_checkpoint["global_step"]
224+
fit_loop = self.trainer.fit_loop
225+
# set the `global_step` value for old checkpoints without the progress tracking state.
226+
# it will be overwritten by the loop's state if it was also saved
227+
optimizer_loop = fit_loop.epoch_loop.batch_loop.optimizer_loop
228+
optimizer_loop.optim_progress.optimizer.step.total.completed = self._loaded_checkpoint["global_step"]
225229
# set the `current_epoch` value for old checkpoints without the progress tracking state.
226230
# it will be overwritten by the loop's state if it was also saved
227-
self.trainer.fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"]
231+
fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"]
228232

229233
assert self.trainer.state.fn is not None
230234
state_dict = self._loaded_checkpoint.get("loops")
231235
if state_dict is not None:
232236
if self.trainer.state.fn in (TrainerFn.FITTING, TrainerFn.TUNING):
233-
self.trainer.fit_loop.load_state_dict(state_dict["fit_loop"])
237+
fit_loop.load_state_dict(state_dict["fit_loop"])
234238
elif self.trainer.state.fn == TrainerFn.VALIDATING:
235239
self.trainer.validate_loop.load_state_dict(state_dict["validate_loop"])
236240
elif self.trainer.state.fn == TrainerFn.TESTING:
@@ -331,9 +335,9 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
331335
model = self.trainer.lightning_module
332336

333337
checkpoint = {
334-
# the epoch is saved for compatibility but it's not relevant for restoration
338+
# the epoch and global step are saved for compatibility but it's not relevant for restoration
335339
"epoch": self.trainer.current_epoch,
336-
"global_step": self.trainer.global_step + 1,
340+
"global_step": self.trainer.global_step,
337341
"pytorch-lightning_version": pl.__version__,
338342
"state_dict": self._get_lightning_module_state_dict(),
339343
"loops": self._get_loops_state_dict(),

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -67,12 +67,12 @@ def on_trainer_init(
6767

6868
@property
6969
def should_flush_logs(self) -> bool:
70-
should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0
70+
should_flush = self.trainer.global_step % self.trainer.flush_logs_every_n_steps == 0
7171
return should_flush or self.trainer.should_stop
7272

7373
@property
7474
def should_update_logs(self) -> bool:
75-
should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0
75+
should_log_every_n_steps = self.trainer.global_step % self.trainer.log_every_n_steps == 0
7676
return should_log_every_n_steps or self.trainer.should_stop
7777

7878
def configure_logger(self, logger: Union[bool, LightningLoggerBase, Iterable[LightningLoggerBase]]) -> None:
@@ -111,7 +111,7 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
111111
if step is None:
112112
# added metrics for convenience
113113
scalar_metrics.setdefault("epoch", self.trainer.current_epoch)
114-
step = self.trainer.global_step
114+
step = self.trainer.fit_loop.epoch_loop._legacy_global_step
115115

116116
# log actual metrics
117117
self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step)

pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2427,7 +2427,7 @@ def sanity_checking(self, val: bool) -> None:
24272427

24282428
@property
24292429
def global_step(self) -> int:
2430-
return self.fit_loop.global_step
2430+
return self.fit_loop.epoch_loop.global_step
24312431

24322432
@property
24332433
def current_epoch(self) -> int:

pytorch_lightning/tuner/batch_size_scaling.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,9 +60,7 @@ def scale_batch_size(
6060

6161
# Save initial model, that is loaded after batch size is found
6262
ckpt_path = os.path.join(trainer.default_root_dir, f".scale_batch_size_{uuid.uuid4()}.ckpt")
63-
trainer.fit_loop.global_step -= 1
6463
trainer.save_checkpoint(ckpt_path)
65-
trainer.fit_loop.global_step += 1
6664
params = __scale_batch_dump_params(trainer)
6765

6866
# Set to values that are required by the algorithm

pytorch_lightning/tuner/lr_finder.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -204,9 +204,7 @@ def lr_find(
204204

205205
# Save initial model, that is loaded after learning rate is found
206206
ckpt_path = os.path.join(trainer.default_root_dir, f".lr_find_{uuid.uuid4()}.ckpt")
207-
trainer.fit_loop.global_step -= 1
208207
trainer.save_checkpoint(ckpt_path)
209-
trainer.fit_loop.global_step += 1
210208
params = __lr_finder_dump_params(trainer)
211209

212210
# Set to values that are required by the algorithm

tests/callbacks/test_rich_progress_bar.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -354,15 +354,15 @@ def test_step(self, batch, batch_idx):
354354
trainer.fit(model)
355355
assert pbar.calls["fit"] == [
356356
("sanity_check", 0, 0, {"b": 0}),
357-
("train", 0, 0, {}),
358357
("train", 0, 1, {}),
359-
("validate", 0, 1, {"b": 1}), # validation end
358+
("train", 0, 2, {}),
359+
("validate", 0, 2, {"b": 2}), # validation end
360360
# epoch end over, `on_epoch=True` metrics are computed
361-
("train", 0, 2, {"a": 1, "b": 1}), # training epoch end
362-
("train", 1, 2, {"a": 1, "b": 1}),
363-
("train", 1, 3, {"a": 1, "b": 1}),
364-
("validate", 1, 3, {"a": 1, "b": 3}), # validation end
365-
("train", 1, 4, {"a": 3, "b": 3}), # training epoch end
361+
("train", 0, 2, {"a": 1, "b": 2}), # training epoch end
362+
("train", 1, 3, {"a": 1, "b": 2}),
363+
("train", 1, 4, {"a": 1, "b": 2}),
364+
("validate", 1, 4, {"a": 1, "b": 4}), # validation end
365+
("train", 1, 4, {"a": 3, "b": 4}), # training epoch end
366366
]
367367

368368
trainer.validate(model, verbose=False)

tests/callbacks/test_tqdm_progress_bar.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -608,15 +608,15 @@ def test_step(self, batch, batch_idx):
608608
trainer.fit(model)
609609
assert pbar.calls["fit"] == [
610610
("sanity_check", 0, 0, {"b": 0}),
611-
("train", 0, 0, {}),
612611
("train", 0, 1, {}),
613-
("validate", 0, 1, {"b": 1}), # validation end
612+
("train", 0, 2, {}),
613+
("validate", 0, 2, {"b": 2}), # validation end
614614
# epoch end over, `on_epoch=True` metrics are computed
615-
("train", 0, 2, {"a": 1, "b": 1}), # training epoch end
616-
("train", 1, 2, {"a": 1, "b": 1}),
617-
("train", 1, 3, {"a": 1, "b": 1}),
618-
("validate", 1, 3, {"a": 1, "b": 3}), # validation end
619-
("train", 1, 4, {"a": 3, "b": 3}), # training epoch end
615+
("train", 0, 2, {"a": 1, "b": 2}), # training epoch end
616+
("train", 1, 3, {"a": 1, "b": 2}),
617+
("train", 1, 4, {"a": 1, "b": 2}),
618+
("validate", 1, 4, {"a": 1, "b": 4}), # validation end
619+
("train", 1, 4, {"a": 3, "b": 4}), # training epoch end
620620
]
621621

622622
trainer.validate(model, verbose=False)

0 commit comments

Comments
 (0)