diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py index 2acd5a3cc8cb3..060601049f9b7 100644 --- a/pytorch_lightning/trainer/connectors/model_connector.py +++ b/pytorch_lightning/trainer/connectors/model_connector.py @@ -42,6 +42,6 @@ def get_model(self): return self._get_reference_model(self.trainer.model) def _get_reference_model(self, model): - if self.trainer.accelerator_backend: + if self.trainer.accelerator_backend and self.trainer.accelerator_backend.lightning_module: return self.trainer.accelerator_backend.lightning_module return model diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py index 8e833c33cbbcf..cedb491340b05 100644 --- a/pytorch_lightning/trainer/trainer.py +++ b/pytorch_lightning/trainer/trainer.py @@ -405,12 +405,6 @@ def setup_trainer(self, model: LightningModule): Args: model: The model to run sanity test on. """ - # -------------------------- - # Setup?? - # -------------------------- - - # set local properties on the model - self.model_connector.copy_trainer_model_properties(model) # init amp. Must be done here instead of __init__ to allow ddp to work if self.amp_backend == AMPType.NATIVE and self.precision == 16 and self._device_type != DeviceType.TPU: @@ -449,6 +443,9 @@ def fit( self._state = TrainerState.RUNNING self._set_wide_running_stage(RunningStage.TRAINING) + # set local properties on the model + self.model_connector.copy_trainer_model_properties(model) + # ---------------------------- # LINK DATA # ---------------------------- diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py index c7551fb811b86..0fb452f7a47ff 100644 --- a/tests/trainer/test_trainer.py +++ b/tests/trainer/test_trainer.py @@ -1726,3 +1726,17 @@ def training_epoch_end(self, *args, **kwargs): assert trainer.current_epoch == current_epoch assert model.training_step_invoked == should_train, f"`training_step` {error_string}" assert model.training_epoch_end_invoked == should_train, f"`training_epoch_end` {error_string}" + + +def test_trainer_access_in_configure_optimizers(tmpdir): + + class TestModel(BoringModel): + + def configure_optimizers(self): + assert self.trainer is not None, "Expect to have access to the trainer within `configure_optimizers`" + + train_data = torch.utils.data.DataLoader(RandomDataset(32, 64)) + + model = TestModel() + trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True) + trainer.fit(model, train_data)