From 0484f4aae1b8b56b03211cc96a4b3f82d0ed233f Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Sat, 6 Feb 2021 23:20:34 +0000
Subject: [PATCH 1/2] Move copying of trainer parameters to happen earlier
 within the loop and add safeguard to get ref model

---
 .../trainer/connectors/model_connector.py |  2 +-
 pytorch_lightning/trainer/trainer.py      |  9 +++------
 tests/trainer/test_trainer.py             | 14 ++++++++++++++
 3 files changed, 18 insertions(+), 7 deletions(-)

diff --git a/pytorch_lightning/trainer/connectors/model_connector.py b/pytorch_lightning/trainer/connectors/model_connector.py
index 2acd5a3cc8cb3..060601049f9b7 100644
--- a/pytorch_lightning/trainer/connectors/model_connector.py
+++ b/pytorch_lightning/trainer/connectors/model_connector.py
@@ -42,6 +42,6 @@ def get_model(self):
         return self._get_reference_model(self.trainer.model)
 
     def _get_reference_model(self, model):
-        if self.trainer.accelerator_backend:
+        if self.trainer.accelerator_backend and self.trainer.accelerator_backend.lightning_module:
             return self.trainer.accelerator_backend.lightning_module
         return model
diff --git a/pytorch_lightning/trainer/trainer.py b/pytorch_lightning/trainer/trainer.py
index 8e833c33cbbcf..cedb491340b05 100644
--- a/pytorch_lightning/trainer/trainer.py
+++ b/pytorch_lightning/trainer/trainer.py
@@ -405,12 +405,6 @@ def setup_trainer(self, model: LightningModule):
         Args:
             model: The model to run sanity test on.
         """
-        # --------------------------
-        # Setup??
-        # --------------------------
-
-        # set local properties on the model
-        self.model_connector.copy_trainer_model_properties(model)
 
         # init amp. Must be done here instead of __init__ to allow ddp to work
         if self.amp_backend == AMPType.NATIVE and self.precision == 16 and self._device_type != DeviceType.TPU:
@@ -449,6 +443,9 @@ def fit(
         self._state = TrainerState.RUNNING
         self._set_wide_running_stage(RunningStage.TRAINING)
 
+        # set local properties on the model
+        self.model_connector.copy_trainer_model_properties(model)
+
         # ----------------------------
         # LINK DATA
         # ----------------------------
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index c7551fb811b86..65f69dfe394e4 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1726,3 +1726,17 @@ def training_epoch_end(self, *args, **kwargs):
     assert trainer.current_epoch == current_epoch
     assert model.training_step_invoked == should_train, f"`training_step` {error_string}"
     assert model.training_epoch_end_invoked == should_train, f"`training_epoch_end` {error_string}"
+
+
+def test_trainer_access_in_configure_optimizers(tmpdir):
+
+    class TestModel(BoringModel):
+
+        def configure_optimizers(self):
+            assert self.trainer is not None, "Expect to have access to the trainer within `configure_optimizers`"
+
+    train_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
+
+    model = TestModel()
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=1)
+    trainer.fit(model, train_data)

From a15648e42e19a1bf27717b89ec3e79b1d1606e95 Mon Sep 17 00:00:00 2001
From: SeanNaren
Date: Sat, 6 Feb 2021 23:52:28 +0000
Subject: [PATCH 2/2] Remove GPU flag

---
 tests/trainer/test_trainer.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index 65f69dfe394e4..0fb452f7a47ff 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -1738,5 +1738,5 @@ def configure_optimizers(self):
     train_data = torch.utils.data.DataLoader(RandomDataset(32, 64))
 
     model = TestModel()
-    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True, gpus=1)
+    trainer = Trainer(default_root_dir=tmpdir, fast_dev_run=True)
     trainer.fit(model, train_data)
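
Note on the safeguard in PATCH 1/2: before the added condition, `_get_reference_model` returned `self.trainer.accelerator_backend.lightning_module` whenever an accelerator backend existed, even if no LightningModule had been attached to it yet, so early calls (for example from `copy_trainer_model_properties` at the start of `fit`) could get `None` back. Below is a minimal, runnable sketch of the guarded lookup; `FakeBackend` and `FakeTrainer` are hypothetical stand-ins rather than real Lightning classes, and only the `if` condition mirrors the patched code.

class FakeBackend:
    def __init__(self):
        # mirrors trainer.accelerator_backend.lightning_module, which is
        # None until the backend is connected to a LightningModule
        self.lightning_module = None


class FakeTrainer:
    def __init__(self, model):
        self.model = model
        self.accelerator_backend = FakeBackend()

    def _get_reference_model(self, model):
        # the patched condition: the backend must exist AND already hold a
        # module; otherwise fall back to the model that was passed in
        if self.accelerator_backend and self.accelerator_backend.lightning_module:
            return self.accelerator_backend.lightning_module
        return model


trainer = FakeTrainer(model="raw model")
# early in setup: the backend exists but has no module attached yet,
# so the raw model is returned instead of None
assert trainer._get_reference_model(trainer.model) == "raw model"

# once the backend is connected, the wrapped module takes precedence
trainer.accelerator_backend.lightning_module = "wrapped module"
assert trainer._get_reference_model(trainer.model) == "wrapped module"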