Skip to content

Commit 62ae3ea

Browse files
committed
Minor fixes
1 parent 41309da commit 62ae3ea

File tree

2 files changed

+15
-10
lines changed

2 files changed

+15
-10
lines changed

pytorch_lightning/trainer/progress.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -148,9 +148,6 @@ def from_defaults(cls, tracker_cls: Type[ReadyCompletedTracker], **kwargs: int)
148148
"""Utility function to easily create an instance from keyword arguments to both ``Tracker``s."""
149149
return cls(total=tracker_cls(**kwargs), current=tracker_cls(**kwargs))
150150

151-
def reset_on_epoch(self) -> None:
152-
self.current.reset()
153-
154151
def reset_on_run(self) -> None:
155152
self.current.reset()
156153

tests/models/test_restore.py

Lines changed: 15 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -224,31 +224,39 @@ def test_correct_step_and_epoch(tmpdir):
224224
model = BoringModel()
225225
first_max_epochs = 2
226226
train_batches = 2
227-
trainer = Trainer(default_root_dir=tmpdir, max_epochs=first_max_epochs, limit_train_batches=train_batches)
227+
trainer = Trainer(
228+
default_root_dir=tmpdir, max_epochs=first_max_epochs, limit_train_batches=train_batches, limit_val_batches=0
229+
)
228230
assert trainer.current_epoch == 0
229231
assert trainer.global_step == 0
230232

231233
trainer.fit(model)
232234
assert trainer.current_epoch == first_max_epochs
233235
assert trainer.global_step == first_max_epochs * train_batches
234236

235-
ckpt = str(tmpdir / "model.ckpt")
236-
trainer.save_checkpoint(ckpt)
237+
ckpt_path = str(tmpdir / "model.ckpt")
238+
trainer.save_checkpoint(ckpt_path)
239+
240+
ckpt = torch.load(ckpt_path)
241+
assert ckpt["epoch"] == first_max_epochs
237242
# TODO(@carmocca): should not need `+1`
238-
assert torch.load(ckpt)["global_step"] == first_max_epochs * train_batches + 1
243+
assert ckpt["global_step"] == first_max_epochs * train_batches + 1
239244

240-
max_epochs = 4
241-
trainer = Trainer(default_root_dir=tmpdir, max_epochs=max_epochs, limit_train_batches=train_batches)
245+
max_epochs = first_max_epochs + 2
246+
trainer = Trainer(
247+
default_root_dir=tmpdir, max_epochs=max_epochs, limit_train_batches=train_batches, limit_val_batches=0
248+
)
242249
# the ckpt state is not loaded at this point
243250
assert trainer.current_epoch == 0
244251
assert trainer.global_step == 0
245252

246253
class TestModel(BoringModel):
247254
def on_pretrain_routine_end(self) -> None:
255+
assert self.trainer.current_epoch == first_max_epochs
248256
# TODO(@carmocca): should not need `+1`
249257
assert self.trainer.global_step == first_max_epochs * train_batches + 1
250258

251-
trainer.fit(TestModel(), ckpt_path=ckpt)
259+
trainer.fit(TestModel(), ckpt_path=ckpt_path)
252260
assert trainer.current_epoch == max_epochs
253261
# TODO(@carmocca): should not need `+1`
254262
assert trainer.global_step == max_epochs * train_batches + 1

0 commit comments

Comments
 (0)