Skip to content

Commit dd8d073

Browse files
committed
add new test
1 parent 59b42ce commit dd8d073

File tree

2 files changed

+30
-1
lines changed

2 files changed

+30
-1
lines changed

pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -459,6 +459,7 @@ def fit(
459459
# SET UP TRAINING
460460
# ----------------------------
461461
self.call_setup_hook(model)
462+
self.call_hook("on_before_accelerator_backend_setup", model)
462463
self.accelerator_backend.setup(self, model)
463464
self.setup_trainer(model)
464465

@@ -470,7 +471,6 @@ def fit(
470471

471472
# plugin will setup training (e.g. ddp will launch child processes)
472473
# TODO: the old setup is now called "pre_training", where should this hook be called now?
473-
self.call_hook("on_before_accelerator_backend_setup", model)
474474
self.training_type_plugin.pre_training()
475475
self.precision_plugin.pre_training()
476476

tests/callbacks/test_finetuning_callback.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919

2020
from pytorch_lightning import LightningModule, seed_everything, Trainer
2121
from pytorch_lightning.callbacks import BackboneFinetuning, BaseFinetuning
22+
from pytorch_lightning.callbacks.base import Callback
2223
from tests.base import BoringModel, RandomDataset
2324

2425

@@ -215,3 +216,31 @@ def __init__(self):
215216
assert torch.equal(optimizer.param_groups[2]["params"][0], model.backbone[2].weight)
216217
assert torch.equal(optimizer.param_groups[2]["params"][1], model.backbone[3].weight)
217218
assert torch.equal(optimizer.param_groups[2]["params"][2], model.backbone[4].weight)
219+
220+
221+
def test_on_before_accelerator_backend_setup(tmpdir):
222+
"""
223+
`on_before_accelerator_backend_setup` hook is used make sure the finetuning freeze call is made
224+
before configure_optimizers call.
225+
"""
226+
227+
class TestCallback(Callback):
228+
229+
def on_before_accelerator_backend_setup(self, trainer, pl_module):
230+
pl_module.on_before_accelerator_backend_setup_called = True
231+
232+
class TestModel(BoringModel):
233+
234+
def __init__(self):
235+
super().__init__()
236+
self.on_before_accelerator_backend_setup_called = False
237+
238+
def configure_optimizers(self):
239+
assert self.on_before_accelerator_backend_setup_called
240+
return super().configure_optimizers()
241+
242+
model = TestModel()
243+
callback = TestCallback()
244+
245+
trainer = Trainer(default_root_dir=tmpdir, callbacks=[callback], fast_dev_run=True)
246+
trainer.fit(model)

0 commit comments

Comments
 (0)