
Commit 3139fd5

Undo test changes
1 parent 3a30760 commit 3139fd5

File tree

2 files changed: +37 lines, -11 lines


pytorch_lightning/loops/base.py

Lines changed: 1 addition & 2 deletions
@@ -205,8 +205,7 @@ def run(self, *args, **kwargs):
                 self._restarting = False
             except StopIteration:
                 break
-            else:
-                self._restarting = False
+        self._restarting = False
 
         output = self.on_run_end()
         return output
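
The one-line move above changes behavior on the StopIteration path: in a try/except/else statement, the else suite runs only when the try body completes without raising, so if advance raised StopIteration before the in-try reset on line 205 executed, _restarting was never cleared before the break. Placing the assignment after the while loop clears the flag on every exit. A minimal standalone sketch of the two shapes (hypothetical function names, not the actual Loop class):

    def reset_in_else(advance):
        # Old shape: the reset lives in the try/else, so a StopIteration
        # raised before any successful advance leaves the flag set.
        restarting = True
        while True:
            try:
                advance()
            except StopIteration:
                break
            else:
                restarting = False
        return restarting

    def reset_after_loop(advance):
        # New shape: the reset sits after the loop and runs on every exit path.
        restarting = True
        while True:
            try:
                advance()
            except StopIteration:
                break
        restarting = False
        return restarting

    empty = iter(())
    assert reset_in_else(lambda: next(empty)) is True      # flag leaked
    empty = iter(())
    assert reset_after_loop(lambda: next(empty)) is False  # flag cleared

In the real loop the flag is also cleared inside the try after each successful advance, so the post-loop assignment only matters when StopIteration ends the run.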

tests/checkpointing/test_trainer_checkpoint.py

Lines changed: 36 additions & 9 deletions
@@ -20,35 +20,62 @@
 
 
 def test_finetuning_with_ckpt_path(tmpdir):
-    model = BoringModel()
+    """This test validates that generated ModelCheckpoint is pointing to the right best_model_path during test."""
+
+    checkpoint_callback = ModelCheckpoint(monitor="val_loss", dirpath=tmpdir, filename="{epoch:02d}", save_top_k=-1)
+
+    class ExtendedBoringModel(BoringModel):
+        def configure_optimizers(self):
+            import torch
+
+            optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.001)
+            lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1)
+            return [optimizer], [lr_scheduler]
+
+        def validation_step(self, batch, batch_idx):
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            self.log("val_loss", loss, on_epoch=True, prog_bar=True)
+
+    model = ExtendedBoringModel()
+    model.validation_epoch_end = None
     trainer = Trainer(
         default_root_dir=tmpdir,
         max_epochs=1,
-        limit_train_batches=1,
-        callbacks=ModelCheckpoint(dirpath=tmpdir, filename="{epoch:02d}", save_top_k=-1),
+        limit_train_batches=12,
+        limit_val_batches=6,
+        limit_test_batches=12,
+        callbacks=[checkpoint_callback],
         logger=False,
     )
     trainer.fit(model)
     assert os.listdir(tmpdir) == ["epoch=00.ckpt"]
 
-    best_model_paths = [trainer.checkpoint_callback.best_model_path]
+    best_model_paths = [checkpoint_callback.best_model_path]
+    results = []
 
     for idx in range(3, 6):
         # load from checkpoint
         trainer = pl.Trainer(
             default_root_dir=tmpdir,
             max_epochs=idx,
-            limit_train_batches=1,
+            limit_train_batches=12,
+            limit_val_batches=12,
+            limit_test_batches=12,
             enable_progress_bar=False,
-            callbacks=ModelCheckpoint(dirpath=tmpdir, filename="{epoch:02d}", save_top_k=-1),
         )
         trainer.fit(model, ckpt_path=best_model_paths[-1])
-        trainer.test(model)
+        trainer.test()
+        from copy import deepcopy
+
+        results.append(deepcopy(trainer.callback_metrics))
         best_model_paths.append(trainer.checkpoint_callback.best_model_path)
 
-    assert len(best_model_paths) == 4
     for idx, best_model_path in enumerate(best_model_paths):
-        assert best_model_path.endswith(f"epoch=0{idx}.ckpt")
+        if idx == 0:
+            assert best_model_path.endswith(f"epoch=0{idx}.ckpt")
+        else:
+            assert f"epoch={idx + 1}" in best_model_path
 
 
 def test_accumulated_gradient_batches_with_ckpt_path(tmpdir):
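
Two details in the updated test are easy to miss. First, trainer.callback_metrics is a dictionary the trainer mutates in place, which is presumably why the test snapshots it with deepcopy before appending; appending the live dict would leave every element of results aliasing the same object. A self-contained illustration of that pitfall (the metrics dict here is a stand-in, not the real trainer attribute):

    from copy import deepcopy

    metrics = {}  # stand-in for trainer.callback_metrics
    aliased, snapshotted = [], []

    for epoch in range(3):
        metrics["val_loss"] = 1.0 / (epoch + 1)  # trainer-style in-place update
        aliased.append(metrics)                  # every entry is the same dict
        snapshotted.append(deepcopy(metrics))    # independent copy per iteration

    print([round(m["val_loss"], 2) for m in aliased])      # [0.33, 0.33, 0.33]
    print([round(m["val_loss"], 2) for m in snapshotted])  # [1.0, 0.5, 0.33]

Second, trainer.test(model) becoming trainer.test() makes the test run against a checkpoint rather than the in-memory module; in Lightning releases of this era, calling test() without a model falls back to the best checkpoint from the preceding fit call, which is what the best_model_path assertions exercise.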
