Skip to content

Commit 02ac4b0

Browse files
authored
Replace .get_model() with explicit .lightning_module (#6035)
* rename get_model -> lightning_module * update references to get_model * pep8 * add proper deprecation * remove outdated _get_reference_model * fix cyclic import
1 parent 3449e2d commit 02ac4b0

File tree

21 files changed

+140
-131
lines changed

21 files changed

+140
-131
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -237,6 +237,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
237237
- Deprecated using `'val_loss'` to set the `ModelCheckpoint` monitor ([#6012](https://github.com/PyTorchLightning/pytorch-lightning/pull/6012))
238238

239239

240+
- Deprecated `.get_model()` with explicit `.lightning_module` property ([#6035](https://github.com/PyTorchLightning/pytorch-lightning/pull/6035))
241+
242+
240243
### Removed
241244

242245
- Removed deprecated checkpoint argument `filepath` ([#5321](https://github.com/PyTorchLightning/pytorch-lightning/pull/5321))

pytorch_lightning/core/optimizer.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -101,11 +101,11 @@ def _to_lightning_optimizer(cls, optimizer, trainer, opt_idx):
101101
return optimizer
102102

103103
def _toggle_model(self):
104-
model_ref = self._trainer.get_model()
104+
model_ref = self._trainer.lightning_module
105105
model_ref.toggle_optimizer(self, self._optimizer_idx)
106106

107107
def _untoggle_model(self):
108-
model_ref = self._trainer.get_model()
108+
model_ref = self._trainer.lightning_module
109109
model_ref.untoggle_optimizer(self)
110110

111111
@contextmanager
@@ -129,7 +129,7 @@ def toggle_model(self, sync_grad: bool = True):
129129
def __optimizer_step(self, closure: Optional[Callable] = None, profiler_name: str = None, **kwargs):
130130
trainer = self._trainer
131131
optimizer = self._optimizer
132-
model = trainer.get_model()
132+
model = trainer.lightning_module
133133

134134
with trainer.profiler.profile(profiler_name):
135135
trainer.accelerator_backend.optimizer_step(optimizer, self._optimizer_idx, lambda_closure=closure, **kwargs)

pytorch_lightning/trainer/callback_hook.py

Lines changed: 34 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,18 @@
1414

1515
from abc import ABC
1616
from copy import deepcopy
17-
from typing import Callable, List
17+
from typing import List
1818

1919
from pytorch_lightning.callbacks import Callback
20+
from pytorch_lightning.core.lightning import LightningModule
2021

2122

2223
class TrainerCallbackHookMixin(ABC):
2324

2425
# this is just a summary on variables used in this abstract class,
2526
# the proper values/initialisation should be done in child class
2627
callbacks: List[Callback] = []
27-
get_model: Callable
28+
lightning_module: LightningModule
2829

2930
def on_before_accelerator_backend_setup(self, model):
3031
"""Called in the beginning of fit and test"""
@@ -39,7 +40,7 @@ def setup(self, model, stage: str):
3940
def teardown(self, stage: str):
4041
"""Called at the end of fit and test"""
4142
for callback in self.callbacks:
42-
callback.teardown(self, self.get_model(), stage)
43+
callback.teardown(self, self.lightning_module, stage)
4344

4445
def on_init_start(self):
4546
"""Called when the trainer initialization begins, model has not yet been set."""
@@ -54,72 +55,72 @@ def on_init_end(self):
5455
def on_fit_start(self):
5556
"""Called when the trainer initialization begins, model has not yet been set."""
5657
for callback in self.callbacks:
57-
callback.on_fit_start(self, self.get_model())
58+
callback.on_fit_start(self, self.lightning_module)
5859

5960
def on_fit_end(self):
6061
"""Called when the trainer initialization begins, model has not yet been set."""
6162
for callback in self.callbacks:
62-
callback.on_fit_end(self, self.get_model())
63+
callback.on_fit_end(self, self.lightning_module)
6364

6465
def on_sanity_check_start(self):
6566
"""Called when the validation sanity check starts."""
6667
for callback in self.callbacks:
67-
callback.on_sanity_check_start(self, self.get_model())
68+
callback.on_sanity_check_start(self, self.lightning_module)
6869

6970
def on_sanity_check_end(self):
7071
"""Called when the validation sanity check ends."""
7172
for callback in self.callbacks:
72-
callback.on_sanity_check_end(self, self.get_model())
73+
callback.on_sanity_check_end(self, self.lightning_module)
7374

7475
def on_train_epoch_start(self):
7576
"""Called when the epoch begins."""
7677
for callback in self.callbacks:
77-
callback.on_train_epoch_start(self, self.get_model())
78+
callback.on_train_epoch_start(self, self.lightning_module)
7879

7980
def on_train_epoch_end(self, outputs):
8081
"""Called when the epoch ends."""
8182
for callback in self.callbacks:
82-
callback.on_train_epoch_end(self, self.get_model(), outputs)
83+
callback.on_train_epoch_end(self, self.lightning_module, outputs)
8384

8485
def on_validation_epoch_start(self):
8586
"""Called when the epoch begins."""
8687
for callback in self.callbacks:
87-
callback.on_validation_epoch_start(self, self.get_model())
88+
callback.on_validation_epoch_start(self, self.lightning_module)
8889

8990
def on_validation_epoch_end(self):
9091
"""Called when the epoch ends."""
9192
for callback in self.callbacks:
92-
callback.on_validation_epoch_end(self, self.get_model())
93+
callback.on_validation_epoch_end(self, self.lightning_module)
9394

9495
def on_test_epoch_start(self):
9596
"""Called when the epoch begins."""
9697
for callback in self.callbacks:
97-
callback.on_test_epoch_start(self, self.get_model())
98+
callback.on_test_epoch_start(self, self.lightning_module)
9899

99100
def on_test_epoch_end(self):
100101
"""Called when the epoch ends."""
101102
for callback in self.callbacks:
102-
callback.on_test_epoch_end(self, self.get_model())
103+
callback.on_test_epoch_end(self, self.lightning_module)
103104

104105
def on_epoch_start(self):
105106
"""Called when the epoch begins."""
106107
for callback in self.callbacks:
107-
callback.on_epoch_start(self, self.get_model())
108+
callback.on_epoch_start(self, self.lightning_module)
108109

109110
def on_epoch_end(self):
110111
"""Called when the epoch ends."""
111112
for callback in self.callbacks:
112-
callback.on_epoch_end(self, self.get_model())
113+
callback.on_epoch_end(self, self.lightning_module)
113114

114115
def on_train_start(self):
115116
"""Called when the train begins."""
116117
for callback in self.callbacks:
117-
callback.on_train_start(self, self.get_model())
118+
callback.on_train_start(self, self.lightning_module)
118119

119120
def on_train_end(self):
120121
"""Called when the train ends."""
121122
for callback in self.callbacks:
122-
callback.on_train_end(self, self.get_model())
123+
callback.on_train_end(self, self.lightning_module)
123124

124125
def on_pretrain_routine_start(self, model):
125126
"""Called when the train begins."""
@@ -134,74 +135,74 @@ def on_pretrain_routine_end(self, model):
134135
def on_batch_start(self):
135136
"""Called when the training batch begins."""
136137
for callback in self.callbacks:
137-
callback.on_batch_start(self, self.get_model())
138+
callback.on_batch_start(self, self.lightning_module)
138139

139140
def on_batch_end(self):
140141
"""Called when the training batch ends."""
141142
for callback in self.callbacks:
142-
callback.on_batch_end(self, self.get_model())
143+
callback.on_batch_end(self, self.lightning_module)
143144

144145
def on_train_batch_start(self, batch, batch_idx, dataloader_idx):
145146
"""Called when the training batch begins."""
146147
for callback in self.callbacks:
147-
callback.on_train_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx)
148+
callback.on_train_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx)
148149

149150
def on_train_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
150151
"""Called when the training batch ends."""
151152
for callback in self.callbacks:
152-
callback.on_train_batch_end(self, self.get_model(), outputs, batch, batch_idx, dataloader_idx)
153+
callback.on_train_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx)
153154

154155
def on_validation_batch_start(self, batch, batch_idx, dataloader_idx):
155156
"""Called when the validation batch begins."""
156157
for callback in self.callbacks:
157-
callback.on_validation_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx)
158+
callback.on_validation_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx)
158159

159160
def on_validation_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
160161
"""Called when the validation batch ends."""
161162
for callback in self.callbacks:
162-
callback.on_validation_batch_end(self, self.get_model(), outputs, batch, batch_idx, dataloader_idx)
163+
callback.on_validation_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx)
163164

164165
def on_test_batch_start(self, batch, batch_idx, dataloader_idx):
165166
"""Called when the test batch begins."""
166167
for callback in self.callbacks:
167-
callback.on_test_batch_start(self, self.get_model(), batch, batch_idx, dataloader_idx)
168+
callback.on_test_batch_start(self, self.lightning_module, batch, batch_idx, dataloader_idx)
168169

169170
def on_test_batch_end(self, outputs, batch, batch_idx, dataloader_idx):
170171
"""Called when the test batch ends."""
171172
for callback in self.callbacks:
172-
callback.on_test_batch_end(self, self.get_model(), outputs, batch, batch_idx, dataloader_idx)
173+
callback.on_test_batch_end(self, self.lightning_module, outputs, batch, batch_idx, dataloader_idx)
173174

174175
def on_validation_start(self):
175176
"""Called when the validation loop begins."""
176177
for callback in self.callbacks:
177-
callback.on_validation_start(self, self.get_model())
178+
callback.on_validation_start(self, self.lightning_module)
178179

179180
def on_validation_end(self):
180181
"""Called when the validation loop ends."""
181182
for callback in self.callbacks:
182-
callback.on_validation_end(self, self.get_model())
183+
callback.on_validation_end(self, self.lightning_module)
183184

184185
def on_test_start(self):
185186
"""Called when the test begins."""
186187
for callback in self.callbacks:
187-
callback.on_test_start(self, self.get_model())
188+
callback.on_test_start(self, self.lightning_module)
188189

189190
def on_test_end(self):
190191
"""Called when the test ends."""
191192
for callback in self.callbacks:
192-
callback.on_test_end(self, self.get_model())
193+
callback.on_test_end(self, self.lightning_module)
193194

194195
def on_keyboard_interrupt(self):
195196
"""Called when the training is interrupted by KeyboardInterrupt."""
196197
for callback in self.callbacks:
197-
callback.on_keyboard_interrupt(self, self.get_model())
198+
callback.on_keyboard_interrupt(self, self.lightning_module)
198199

199200
def on_save_checkpoint(self):
200201
"""Called when saving a model checkpoint."""
201202
callback_states = {}
202203
for callback in self.callbacks:
203204
callback_class = type(callback)
204-
state = callback.on_save_checkpoint(self, self.get_model())
205+
state = callback.on_save_checkpoint(self, self.lightning_module)
205206
if state:
206207
callback_states[callback_class] = state
207208
return callback_states
@@ -224,11 +225,11 @@ def on_after_backward(self):
224225
Called after loss.backward() and before optimizers do anything.
225226
"""
226227
for callback in self.callbacks:
227-
callback.on_after_backward(self, self.get_model())
228+
callback.on_after_backward(self, self.lightning_module)
228229

229230
def on_before_zero_grad(self, optimizer):
230231
"""
231232
Called after optimizer.step() and before optimizer.zero_grad().
232233
"""
233234
for callback in self.callbacks:
234-
callback.on_before_zero_grad(self, self.get_model(), optimizer)
235+
callback.on_before_zero_grad(self, self.lightning_module, optimizer)

pytorch_lightning/trainer/connectors/checkpoint_connector.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -94,7 +94,7 @@ def restore(self, checkpoint_path: str, on_gpu: bool) -> bool:
9494
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
9595

9696
# acquire the model
97-
model = self.trainer.get_model()
97+
model = self.trainer.lightning_module
9898

9999
# restore model and datamodule state
100100
self.restore_model_state(model, checkpoint)
@@ -214,7 +214,7 @@ def hpc_save(self, folderpath: str, logger):
214214
filepath = os.path.join(folderpath, f'hpc_ckpt_{ckpt_number}.ckpt')
215215

216216
# give model a chance to do something on hpc_save
217-
model = self.trainer.get_model()
217+
model = self.trainer.lightning_module
218218
checkpoint = self.dump_checkpoint()
219219

220220
model.on_hpc_save(checkpoint)
@@ -307,7 +307,7 @@ def dump_checkpoint(self, weights_only: bool = False) -> dict:
307307
checkpoint['amp_scaling_state'] = amp.state_dict()
308308

309309
# add the hyper_parameters and state_dict from the model
310-
model = self.trainer.get_model()
310+
model = self.trainer.lightning_module
311311

312312
# dump the module_arguments and state_dict from the model
313313
checkpoint['state_dict'] = model.state_dict()
@@ -339,7 +339,7 @@ def hpc_load(self, checkpoint_path: str, on_gpu: bool):
339339
checkpoint = pl_load(checkpoint_path, map_location=lambda storage, loc: storage)
340340

341341
# acquire the model
342-
model = self.trainer.get_model()
342+
model = self.trainer.lightning_module
343343

344344
# restore model and datamodule state
345345
self.restore_model_state(model, checkpoint)

pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -235,7 +235,7 @@ def info(self):
235235
"""
236236
This function provides necessary parameters to properly configure HookResultStore obj
237237
"""
238-
model_ref = self.trainer.get_model()
238+
model_ref = self.trainer.lightning_module
239239
return {
240240
"batch_idx": self.trainer.batch_idx,
241241
"fx_name": model_ref._current_hook_fx_name or model_ref._current_fx_name,
@@ -252,7 +252,7 @@ def reset_model(self):
252252
"""
253253
This function is used to reset model state at the end of the capture
254254
"""
255-
model_ref = self.trainer.get_model()
255+
model_ref = self.trainer.lightning_module
256256
model_ref._results = Result()
257257
model_ref._current_hook_fx_name = None
258258
model_ref._current_fx_name = ''
@@ -263,7 +263,7 @@ def cache_result(self) -> None:
263263
and store the result object
264264
"""
265265
with self.trainer.profiler.profile("cache_result"):
266-
model_ref = self.trainer.get_model()
266+
model_ref = self.trainer.lightning_module
267267

268268
# extract hook results
269269
hook_result = model_ref._results

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -82,7 +82,7 @@ def cached_results(self) -> Union[EpochResultStore, None]:
8282

8383
def get_metrics(self, key: str) -> Dict:
8484
metrics_holder = getattr(self, f"_{key}", None)
85-
model_ref = self.trainer.get_model()
85+
model_ref = self.trainer.lightning_module
8686
metrics_holder.convert(
8787
self.trainer._device_type == DeviceType.TPU,
8888
model_ref.device if model_ref is not None else model_ref,
@@ -103,7 +103,7 @@ def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoc
103103

104104
def on_evaluation_batch_start(self, testing, batch, dataloader_idx, num_dataloaders):
105105
# Todo: required argument `testing` is not used
106-
model = self.trainer.get_model()
106+
model = self.trainer.lightning_module
107107
# set dataloader_idx only if multiple ones
108108
model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None
109109
# track batch_size
@@ -263,7 +263,7 @@ def track_metrics_deprecated(self, deprecated_eval_results):
263263
def evaluation_epoch_end(self, testing):
264264
# Todo: required argument `testing` is not used
265265
# reset dataloader idx
266-
model_ref = self.trainer.get_model()
266+
model_ref = self.trainer.lightning_module
267267
model_ref._current_dataloader_idx = None
268268

269269
# setting `has_batch_loop_finished` to True
@@ -408,7 +408,7 @@ def log_train_epoch_end_metrics(
408408
# epoch_output[optimizer_idx][training_step_idx][tbptt_index]
409409
# remember that not using truncated backprop is equivalent with truncated back prop of len(1)
410410

411-
model = self.trainer.get_model()
411+
model = self.trainer.lightning_module
412412

413413
epoch_callback_metrics = {}
414414

pytorch_lightning/trainer/connectors/model_connector.py

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ def __init__(self, trainer):
2525
self.trainer = trainer
2626

2727
def copy_trainer_model_properties(self, model):
28-
ref_model = self._get_reference_model(model)
28+
ref_model = self.trainer.lightning_module or model
2929

3030
automatic_optimization = ref_model.automatic_optimization and self.trainer.train_loop.automatic_optimization
3131
self.trainer.train_loop.automatic_optimization = automatic_optimization
@@ -37,11 +37,3 @@ def copy_trainer_model_properties(self, model):
3737
m.use_amp = self.trainer.amp_backend is not None
3838
m.testing = self.trainer.testing
3939
m.precision = self.trainer.precision
40-
41-
def get_model(self):
42-
return self._get_reference_model(self.trainer.model)
43-
44-
def _get_reference_model(self, model):
45-
if self.trainer.accelerator_backend and self.trainer.accelerator_backend.lightning_module:
46-
return self.trainer.accelerator_backend.lightning_module
47-
return model

pytorch_lightning/trainer/deprecated_api.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
from pytorch_lightning.core.lightning import LightningModule
1415
from pytorch_lightning.trainer.connectors.accelerator_connector import AcceleratorConnector
1516
from pytorch_lightning.trainer.states import RunningStage
1617
from pytorch_lightning.utilities import DeviceType, DistributedType, rank_zero_warn
@@ -130,3 +131,15 @@ def use_single_gpu(self, val: bool) -> None:
130131
)
131132
if val:
132133
self.accelerator_connector._device_type = DeviceType.GPU
134+
135+
136+
class DeprecatedModelAttributes:
137+
138+
lightning_module = LightningModule
139+
140+
def get_model(self) -> LightningModule:
141+
rank_zero_warn(
142+
"The use of `Trainer.get_model()` is deprecated in favor of `Trainer.lightning_module`"
143+
" and will be removed in v1.4.", DeprecationWarning
144+
)
145+
return self.lightning_module

0 commit comments

Comments
 (0)