Skip to content

Commit 72f0e5b

Browse files
Piyush-97rohitgr7daniellepintzananthsubawaelchli
authored
Deprecate on_configure_sharded_model callback hook for v1.6 (#11627)
Co-authored-by: Rohit Gupta <[email protected]> Co-authored-by: Danielle Pintz <[email protected]> Co-authored-by: ananthsub <[email protected]> Co-authored-by: Adrian Wälchli <[email protected]>
1 parent 6586dd2 commit 72f0e5b

File tree

4 files changed

+41
-1
lines changed

4 files changed

+41
-1
lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -290,6 +290,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
290290
- Deprecated function `pytorch_lightning.callbacks.device_stats_monitor.prefix_metric_keys` ([#11254](https://github.com/PyTorchLightning/pytorch-lightning/pull/11254))
291291

292292

293+
- Deprecated `on_configure_sharded_model` callback hook in favor of `setup` ([#11627](https://github.com/PyTorchLightning/pytorch-lightning/pull/11627))
294+
295+
293296
### Removed
294297

295298
- Removed deprecated parameter `method` in `pytorch_lightning.utilities.model_helpers.is_overridden` ([#10507](https://github.com/PyTorchLightning/pytorch-lightning/pull/10507))

pytorch_lightning/callbacks/base.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,12 @@ def _generate_state_key(self, **kwargs: Any) -> str:
5757
return f"{self.__class__.__qualname__}{repr(kwargs)}"
5858

5959
def on_configure_sharded_model(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
60-
"""Called before configure sharded model."""
60+
r"""
61+
.. deprecated:: v1.6
62+
This callback hook was deprecated in v1.6 and will be removed in v1.8. Use `setup()` instead.
63+
64+
Called before configure sharded model.
65+
"""
6166

6267
def on_before_accelerator_backend_setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
6368
"""Called before accelerator is being setup."""

pytorch_lightning/trainer/configuration_validator.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
5757
_check_on_init_start_end(trainer)
5858
# TODO: Delete _check_on_hpc_hooks in v1.8
5959
_check_on_hpc_hooks(model)
60+
# TODO: Remove this in v1.8
61+
_check_on_configure_sharded_model(trainer)
6062

6163

6264
def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule") -> None:
@@ -322,3 +324,12 @@ def _check_on_hpc_hooks(model: "pl.LightningModule") -> None:
322324
"Method `LightningModule.on_hpc_load` is deprecated in v1.6 and"
323325
" will be removed in v1.8. Please use `LightningModule.on_load_checkpoint` instead."
324326
)
327+
328+
329+
def _check_on_configure_sharded_model(trainer: "pl.Trainer") -> None:
330+
for callback in trainer.callbacks:
331+
if is_overridden(method_name="on_configure_sharded_model", instance=callback):
332+
rank_zero_deprecation(
333+
"The `on_configure_sharded_model` callback hook was deprecated in"
334+
" v1.6 and will be removed in v1.8. Use `setup()` instead."
335+
)

tests/deprecated_api/test_remove_1-8.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -351,3 +351,24 @@ def test_v1_8_0_deprecated_lightning_optimizers():
351351
match="Trainer.lightning_optimizers` is deprecated in v1.6 and will be removed in v1.8"
352352
):
353353
assert trainer.lightning_optimizers == {}
354+
355+
356+
def test_v1_8_0_on_configure_sharded_model(tmpdir):
357+
class TestCallback(Callback):
358+
def on_configure_sharded_model(self, trainer, model):
359+
print("Configuring sharded model")
360+
361+
model = BoringModel()
362+
363+
trainer = Trainer(
364+
callbacks=[TestCallback()],
365+
max_epochs=1,
366+
fast_dev_run=True,
367+
enable_progress_bar=False,
368+
logger=False,
369+
default_root_dir=tmpdir,
370+
)
371+
with pytest.deprecated_call(
372+
match="The `on_configure_sharded_model` callback hook was deprecated in v1.6 and will be removed in v1.8."
373+
):
374+
trainer.fit(model)

0 commit comments

Comments
 (0)