@@ -19,6 +19,7 @@

 from pytorch_lightning import LightningModule, seed_everything, Trainer
 from pytorch_lightning.callbacks import BackboneFinetuning, BaseFinetuning
+from pytorch_lightning.callbacks.base import Callback
 from tests.base import BoringModel, RandomDataset


@@ -215,3 +216,31 @@ def __init__(self):
     assert torch.equal(optimizer.param_groups[2]["params"][0], model.backbone[2].weight)
     assert torch.equal(optimizer.param_groups[2]["params"][1], model.backbone[3].weight)
     assert torch.equal(optimizer.param_groups[2]["params"][2], model.backbone[4].weight)
+
+
+def test_on_before_accelerator_backend_setup(tmpdir):
+    """
+    The `on_before_accelerator_backend_setup` hook is used to make sure the
+    finetuning freeze call is made before the `configure_optimizers` call.
+    """
+
+    class TestCallback(Callback):
+
+        def on_before_accelerator_backend_setup(self, trainer, pl_module):
+            pl_module.on_before_accelerator_backend_setup_called = True
+
+    class TestModel(BoringModel):
+
+        def __init__(self):
+            super().__init__()
+            self.on_before_accelerator_backend_setup_called = False
+
+        def configure_optimizers(self):
+            assert self.on_before_accelerator_backend_setup_called
+            return super().configure_optimizers()
+
+    model = TestModel()
+    callback = TestCallback()
+
+    trainer = Trainer(default_root_dir=tmpdir, callbacks=[callback], fast_dev_run=True)
+    trainer.fit(model)
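For context (not part of this diff): a minimal sketch of a `BaseFinetuning` subclass that depends on the ordering this test verifies. `BaseFinetuning` dispatches `freeze_before_training` from `on_before_accelerator_backend_setup`, which, per the docstring above, runs before `configure_optimizers`. The class name, the `unfreeze_at_epoch` milestone, and the `backbone` attribute are illustrative assumptions.

```python
from pytorch_lightning.callbacks import BaseFinetuning


class MilestoneBackboneFinetuning(BaseFinetuning):
    """Hypothetical callback: freeze the backbone before training starts,
    then unfreeze it at a chosen epoch."""

    def __init__(self, unfreeze_at_epoch: int = 5):
        super().__init__()
        self.unfreeze_at_epoch = unfreeze_at_epoch

    def freeze_before_training(self, pl_module):
        # Called via `on_before_accelerator_backend_setup`, i.e. before
        # `configure_optimizers`, so the frozen backbone parameters never
        # end up in the optimizer's initial param groups.
        self.freeze(pl_module.backbone)

    def finetune_function(self, pl_module, current_epoch, optimizer, optimizer_idx):
        # Later, unfreeze the backbone and hand its parameters to the optimizer.
        if current_epoch == self.unfreeze_at_epoch:
            self.unfreeze_and_add_param_group(
                modules=pl_module.backbone,
                optimizer=optimizer,
            )
```

If the freeze ran after `configure_optimizers` instead, the backbone parameters would already be registered in the optimizer while frozen, which is exactly the situation the hook ordering avoids.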