
Commit a304dba

Resolve FIXME: implement option 3b
1 parent ea8299d commit a304dba

File tree

8 files changed: +44 −32 lines

pytorch_lightning/loops/base.py

Lines changed: 14 additions & 3 deletions
@@ -48,7 +48,7 @@ class Loop(ABC, Generic[T]):
     """

     def __init__(self) -> None:
-        self.restarting = False
+        self._restarting = False
         self._trainer: Optional["pl.Trainer"] = None

     @property
@@ -69,6 +69,17 @@ def trainer(self, trainer: "pl.Trainer") -> None:
             if isinstance(v, Loop):
                 v.trainer = trainer

+    @property
+    def restarting(self) -> bool:
+        return self._restarting
+
+    @restarting.setter
+    def restarting(self, restarting: bool) -> None:
+        self._restarting = restarting
+        for loop in vars(self).values():
+            if isinstance(loop, Loop):
+                loop.restarting = restarting
+
     @property
     @abstractmethod
     def done(self) -> bool:
@@ -189,7 +200,7 @@ def run(self, *args, **kwargs):
                 self.on_advance_start(*args, **kwargs)
                 self.advance(*args, **kwargs)
                 self.on_advance_end()
-                self.restarting = False
+                self._restarting = False
             except StopIteration:
                 break

@@ -298,6 +309,7 @@ def load_state_dict(
         for k, v in self.__dict__.items():
             if isinstance(v, Loop):
                 v.load_state_dict(state_dict.copy(), prefix + k + ".")
+        self.restarting = True

     def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional[Dict[str, Metric]] = None) -> None:
         for k, v in self.__dict__.items():
@@ -333,4 +345,3 @@ def _load_from_state_dict(self, state_dict: Dict, prefix: str, metrics: Optional

         if prefix + "state_dict" in state_dict:  # compatibility with old checkpoints
             self.on_load_checkpoint(state_dict[prefix + "state_dict"])
-        self.restarting = True
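
The new property turns `restarting` into state that cascades: setting it on a loop walks the loop's attributes and applies the same value to every nested `Loop`. A minimal standalone sketch of that pattern, using a made-up `MiniLoop` class rather than the real Lightning `Loop`:

# Minimal sketch of the cascading `restarting` setter above (MiniLoop is a
# stand-in for illustration, not the real pytorch_lightning Loop class).
class MiniLoop:
    def __init__(self) -> None:
        self._restarting = False

    @property
    def restarting(self) -> bool:
        return self._restarting

    @restarting.setter
    def restarting(self, restarting: bool) -> None:
        self._restarting = restarting
        for child in vars(self).values():  # propagate to any nested loops
            if isinstance(child, MiniLoop):
                child.restarting = restarting

outer = MiniLoop()
outer.inner = MiniLoop()  # a nested loop stored as an attribute
outer.restarting = True
assert outer.inner.restarting  # the flag cascaded to the child loop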

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 5 additions & 8 deletions
@@ -88,7 +88,7 @@ def batch_idx(self) -> int:
         """Returns the current batch index (within this epoch)"""
         # use `ready` instead of `completed` in case this is accessed after `completed` has been increased
         # but before the next `ready` increase
-        return self.batch_progress.current.ready - 1
+        return max(self.batch_progress.current.ready - 1, 0)

     @property
     def _is_training_done(self) -> bool:
@@ -130,12 +130,6 @@ def reset(self) -> None:
             self.batch_progress.reset_on_restart()
             self.scheduler_progress.reset_on_restart()
             self.batch_loop.optimizer_loop.optim_progress.reset_on_restart()
-            # FIXME: fuck me this makes
-            # 1) test_restore::test_correct_step_and_epoch pass
-            # 2) test_model_checkpoint::test_checkpoint_repeated_strategy_extended fail
-            # 1) restarts after on_train_end (ce: 2, gs: 4) -> (ce: 4, gs: 8)
-            # 2) restarts after on_train_epoch_end (ce: 1, gs: 4) -> (ce: 2, gs: 4)
-            # if not self.restarting or self.done:
         else:
             self.batch_progress.reset_on_run()
             self.scheduler_progress.reset_on_run()
@@ -148,7 +142,7 @@ def reset(self) -> None:

     def on_run_start(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]
         self._reload_dataloader_state_dict(data_fetcher)
-        self._dataloader_iter = _update_dataloader_iter(data_fetcher, self.batch_idx + 1)
+        self._dataloader_iter = _update_dataloader_iter(data_fetcher, self.batch_idx)

     def advance(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]
         """Runs a single training batch.
@@ -159,6 +153,9 @@ def advance(self, data_fetcher: AbstractDataFetcher) -> None:  # type: ignore[override]
         if self.restarting and self._should_check_val_fx(self.batch_idx, self.batch_progress.is_last_batch):
             # skip training and run validation in `on_advance_end`
             return
+        else:
+            # we are going to train first so the val loop does not need to restart
+            self.val_loop.restarting = False

         assert self._dataloader_iter is not None
         batch_idx, (batch, self.batch_progress.is_last_batch) = next(self._dataloader_iter)
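
Of the changes above, the `batch_idx` clamp is the easiest to look at in isolation: `ready` is 0 before any batch has started, so the old `ready - 1` expression could return -1, while the new expression bottoms out at 0. A tiny standalone sketch, with a plain int standing in for the progress counter:

# Sketch of the `batch_idx` clamp above; `ready` stands in for
# `batch_progress.current.ready` (the count of batches that have become ready).
def batch_idx(ready: int) -> int:
    # old behaviour was `ready - 1`, which is -1 before the first batch is ready
    return max(ready - 1, 0)

assert batch_idx(0) == 0  # before any batch: clamped instead of -1
assert batch_idx(3) == 2  # after three batches have become ready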

pytorch_lightning/loops/fit_loop.py

Lines changed: 16 additions & 14 deletions
@@ -56,16 +56,6 @@ def __init__(
         self._is_fresh_start_epoch: bool = True
         self._outputs: _EPOCH_OUTPUTS_TYPE = []

-    @property
-    def current_epoch(self) -> int:
-        """Return the current epoch."""
-        return self.epoch_progress.current.completed
-
-    @current_epoch.setter
-    def current_epoch(self, value: int) -> None:
-        """Setter for the current epoch."""
-        self.epoch_progress.current.completed = value
-
     @property
     def global_step(self) -> int:
         """Returns the global step."""
@@ -129,6 +119,18 @@ def running_loss(self) -> TensorRunningAccum:
         """Returns the running loss."""
         return self.epoch_loop.batch_loop.running_loss

+    @Loop.restarting.setter
+    def restarting(self, restarting: bool) -> None:
+        # if the last epoch completely finished, we are not actually restarting, we can check this to see if all
+        # current values are equal
+        values = (
+            self.epoch_progress.current.ready,
+            self.epoch_progress.current.started,
+            self.epoch_progress.current.processed,
+        )
+        restarting &= any(v != self.epoch_progress.current.completed for v in values)
+        Loop.restarting.fset(self, restarting)  # call the parent setter
+
     @property
     def _skip_backward(self) -> bool:
         """Determines whether the loop will skip backward during automatic optimization."""
@@ -152,11 +154,11 @@ def done(self) -> bool:
         """Evaluates when to leave the loop."""
         # TODO(@awaelchli): Move track steps inside training loop and move part of these condition inside training loop
         stop_steps = _is_max_limit_reached(self.global_step, self.max_steps)
-        stop_epochs = _is_max_limit_reached(self.current_epoch, self.max_epochs)
+        stop_epochs = _is_max_limit_reached(self.epoch_progress.current.processed, self.max_epochs)

         should_stop = self.trainer.should_stop
         if should_stop:
-            should_stop = self.current_epoch >= self.min_epochs if self.min_epochs else True
+            should_stop = self.epoch_progress.current.processed >= self.min_epochs if self.min_epochs else True
             if not should_stop:
                 log.info(
                     f"Trainer was signaled to stop but required minimum epochs ({self.min_epochs}) has not been met."
@@ -169,7 +171,7 @@ def skip(self) -> bool:
         """Whether we should skip the training and immediately return from the call to :meth:`run`."""
         # since `trainer.num_training_batches` depends on the `train_dataloader` but that won't be called
         # until `on_run_start`, we use `limit_train_batches` instead
-        return self.done or self.trainer.limit_train_batches == 0
+        return self.trainer.limit_train_batches == 0

     def connect(self, epoch_loop: TrainingEpochLoop) -> None:  # type: ignore[override]
         """Connects a training epoch loop to this fit loop."""
@@ -207,7 +209,7 @@ def on_advance_start(self) -> None:  # type: ignore[override]
             getattr(self.trainer.train_dataloader.sampler, "set_epoch", None)
         ):
             # set seed for distributed sampler (enables shuffling for each epoch)
-            self.trainer.train_dataloader.sampler.set_epoch(self.current_epoch)
+            self.trainer.train_dataloader.sampler.set_epoch(self.epoch_progress.current.processed)

         # changing gradient according accumulation_scheduler
         self.trainer.accumulation_scheduler.on_train_epoch_start(self.trainer, self.trainer.lightning_module)
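
The overriding `restarting` setter carries the new behaviour: when the checkpoint was taken at a clean epoch boundary, `ready`, `started` and `processed` all equal `completed`, so the fit loop decides it is not actually restarting and the next epoch begins normally. A small sketch of that decision with plain integers in place of `epoch_progress.current` (hypothetical values, not the real progress objects):

# Sketch of the "are we really restarting?" check in the setter above.
def effective_restarting(restarting: bool, ready: int, started: int, processed: int, completed: int) -> bool:
    # if every counter already equals `completed`, the previous epoch finished
    # cleanly and there is nothing to resume mid-epoch
    return restarting and any(v != completed for v in (ready, started, processed))

assert effective_restarting(True, ready=2, started=2, processed=2, completed=2) is False  # clean epoch boundary
assert effective_restarting(True, ready=2, started=2, processed=1, completed=1) is True   # stopped mid-epoch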

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 3 additions & 2 deletions
@@ -212,8 +212,9 @@ def restore_loops(self) -> None:
             return

         self.trainer.fit_loop.global_step = self._loaded_checkpoint["global_step"]
-        # FIXME: keep in mind old checkpoints without progress tracking
-        self.trainer.fit_loop.current_epoch = self._loaded_checkpoint["epoch"]
+        # set the `current_epoch` value for old checkpoints without the progress tracking state
+        # it will be overwritten by the loop's state if it was also saved
+        self.trainer.fit_loop.epoch_progress.current.completed = self._loaded_checkpoint["epoch"]

         assert self.trainer.state.fn is not None
         state_dict = self._loaded_checkpoint.get("loops")

pytorch_lightning/trainer/trainer.py

Lines changed: 2 additions & 1 deletion
@@ -2342,7 +2342,8 @@ def global_step(self) -> int:

     @property
     def current_epoch(self) -> int:
-        return self.fit_loop.current_epoch
+        """The current epoch, updated after the epoch end hooks are run."""
+        return self.fit_loop.epoch_progress.current.completed

     @property
     def max_epochs(self) -> int:
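
With `current_epoch` now reading `epoch_progress.current.completed`, it only advances once the epoch-end hooks have finished, as the new docstring states. A rough way to observe the underlying counters from user code, using a hypothetical callback written only for illustration:

import pytorch_lightning as pl

# Hypothetical callback, only here to show where the counters live.
class EpochCounterLogger(pl.Callback):
    def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
        current = trainer.fit_loop.epoch_progress.current
        # `trainer.current_epoch` is backed by `current.completed`, which per the
        # docstring above is only updated after the epoch-end hooks have run
        print(f"ready={current.ready} processed={current.processed} completed={current.completed}")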

pytorch_lightning/tuner/batch_size_scaling.py

Lines changed: 0 additions & 3 deletions
@@ -103,7 +103,6 @@ def __scale_batch_dump_params(trainer: "pl.Trainer") -> None:
     # Prevent going into infinite loop
     trainer.__dumped_params = {
         "auto_lr_find": trainer.auto_lr_find,
-        "current_epoch": trainer.current_epoch,
         "global_step": trainer.global_step,
         "max_steps": trainer.max_steps,
         "logger": trainer.logger,
@@ -118,7 +117,6 @@ def __scale_batch_reset_params(trainer: "pl.Trainer", model: "pl.LightningModule", steps_per_trial: int) -> None:
 def __scale_batch_reset_params(trainer: "pl.Trainer", model: "pl.LightningModule", steps_per_trial: int) -> None:
     trainer.auto_scale_batch_size = None  # prevent recursion
     trainer.auto_lr_find = False  # avoid lr find being called multiple times
-    trainer.fit_loop.current_epoch = 0
     trainer.fit_loop.max_steps = steps_per_trial  # take few steps
     trainer.logger = DummyLogger() if trainer.logger is not None else None
     trainer.callbacks = []  # not needed before full run
@@ -129,7 +127,6 @@ def __scale_batch_reset_params(trainer: "pl.Trainer", model: "pl.LightningModule", steps_per_trial: int) -> None:

 def __scale_batch_restore_params(trainer: "pl.Trainer") -> None:
     trainer.auto_lr_find = trainer.__dumped_params["auto_lr_find"]
-    trainer.fit_loop.current_epoch = trainer.__dumped_params["current_epoch"]
     trainer.fit_loop.global_step = trainer.__dumped_params["global_step"]
     trainer.fit_loop.max_steps = trainer.__dumped_params["max_steps"]
     trainer.logger = trainer.__dumped_params["logger"]

tests/checkpointing/test_model_checkpoint.py

Lines changed: 3 additions & 0 deletions
@@ -980,14 +980,17 @@ def assert_checkpoint_log_dir(idx):
     trainer.fit(model, ckpt_path=chk)
     assert trainer.global_step == epochs * limit_train_batches
     assert trainer.current_epoch == epochs
+    assert trainer.fit_loop.epoch_progress.current.processed == epochs

     trainer.validate(model)
     assert trainer.global_step == epochs * limit_train_batches
     assert trainer.current_epoch == epochs
+    assert trainer.fit_loop.epoch_progress.current.processed == epochs

     trainer.fit(model)
     assert trainer.global_step == epochs * limit_train_batches
     assert trainer.current_epoch == epochs
+    assert trainer.fit_loop.epoch_progress.current.processed == epochs
     assert_checkpoint_log_dir(idx)

tests/trainer/test_trainer.py

Lines changed: 1 addition & 1 deletion
@@ -331,7 +331,7 @@ def mock_save_function(filepath, *args):

     # emulate callback's calls during the training
     for i, loss in enumerate(losses):
-        trainer.fit_loop.current_epoch = i
+        trainer.fit_loop.epoch_progress.current.completed = i  # sets `trainer.current_epoch`
         trainer.fit_loop.global_step = i
         trainer.callback_metrics.update({"checkpoint_on": torch.tensor(loss)})
         checkpoint_callback.on_validation_end(trainer, trainer.lightning_module)
