|
20 | 20 |
|
21 | 21 |
|
22 | 22 | def test_finetuning_with_ckpt_path(tmpdir): |
23 | | - model = BoringModel() |
| 23 | + """This test validates that generated ModelCheckpoint is pointing to the right best_model_path during test.""" |
| 24 | + |
| 25 | + checkpoint_callback = ModelCheckpoint(monitor="val_loss", dirpath=tmpdir, filename="{epoch:02d}", save_top_k=-1) |
| 26 | + |
| 27 | + class ExtendedBoringModel(BoringModel): |
| 28 | + def configure_optimizers(self): |
| 29 | + import torch |
| 30 | + |
| 31 | + optimizer = torch.optim.SGD(self.layer.parameters(), lr=0.001) |
| 32 | + lr_scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=1) |
| 33 | + return [optimizer], [lr_scheduler] |
| 34 | + |
| 35 | + def validation_step(self, batch, batch_idx): |
| 36 | + output = self.layer(batch) |
| 37 | + loss = self.loss(batch, output) |
| 38 | + self.log("val_loss", loss, on_epoch=True, prog_bar=True) |
| 39 | + |
| 40 | + model = ExtendedBoringModel() |
| 41 | + model.validation_epoch_end = None |
24 | 42 | trainer = Trainer( |
25 | 43 | default_root_dir=tmpdir, |
26 | 44 | max_epochs=1, |
27 | | - limit_train_batches=1, |
28 | | - callbacks=ModelCheckpoint(dirpath=tmpdir, filename="{epoch:02d}", save_top_k=-1), |
| 45 | + limit_train_batches=12, |
| 46 | + limit_val_batches=6, |
| 47 | + limit_test_batches=12, |
| 48 | + callbacks=[checkpoint_callback], |
29 | 49 | logger=False, |
30 | 50 | ) |
31 | 51 | trainer.fit(model) |
32 | 52 | assert os.listdir(tmpdir) == ["epoch=00.ckpt"] |
33 | 53 |
|
34 | | - best_model_paths = [trainer.checkpoint_callback.best_model_path] |
| 54 | + best_model_paths = [checkpoint_callback.best_model_path] |
| 55 | + results = [] |
35 | 56 |
|
36 | 57 | for idx in range(3, 6): |
37 | 58 | # load from checkpoint |
38 | 59 | trainer = pl.Trainer( |
39 | 60 | default_root_dir=tmpdir, |
40 | 61 | max_epochs=idx, |
41 | | - limit_train_batches=1, |
| 62 | + limit_train_batches=12, |
| 63 | + limit_val_batches=12, |
| 64 | + limit_test_batches=12, |
42 | 65 | enable_progress_bar=False, |
43 | | - callbacks=ModelCheckpoint(dirpath=tmpdir, filename="{epoch:02d}", save_top_k=-1), |
44 | 66 | ) |
45 | 67 | trainer.fit(model, ckpt_path=best_model_paths[-1]) |
46 | | - trainer.test(model) |
| 68 | + trainer.test() |
| 69 | + from copy import deepcopy |
| 70 | + |
| 71 | + results.append(deepcopy(trainer.callback_metrics)) |
47 | 72 | best_model_paths.append(trainer.checkpoint_callback.best_model_path) |
48 | 73 |
|
49 | | - assert len(best_model_paths) == 4 |
50 | 74 | for idx, best_model_path in enumerate(best_model_paths): |
51 | | - assert best_model_path.endswith(f"epoch=0{idx}.ckpt") |
| 75 | + if idx == 0: |
| 76 | + assert best_model_path.endswith(f"epoch=0{idx}.ckpt") |
| 77 | + else: |
| 78 | + assert f"epoch={idx + 1}" in best_model_path |
52 | 79 |
|
53 | 80 |
|
54 | 81 | def test_accumulated_gradient_batches_with_ckpt_path(tmpdir): |
|
0 commit comments