Commit feb8e7d

Remove deprecated LightningModule.on_post_move_to_device (#13548)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 07e7d6d commit feb8e7d

File tree: 8 files changed (+10, -77 lines)

docs/source-pytorch/accelerators/tpu_advanced.rst

Lines changed: 6 additions & 12 deletions

@@ -12,11 +12,10 @@ Weight Tying/Sharing is a technique where in the module weights are shared among
 This is a common method to reduce memory consumption and is utilized in many State of the Art
 architectures today.
 
-PyTorch XLA requires these weights to be tied/shared after moving the model
-to the TPU device. To support this requirement Lightning provides a model hook which is
-called after the model is moved to the device. Any weights that require to be tied should
-be done in the `on_post_move_to_device` model hook. This will ensure that the weights
-among the modules are shared and not copied.
+PyTorch XLA requires these weights to be tied/shared after moving the model to the XLA device.
+To support this requirement, Lightning automatically finds these weights and ties them after
+the modules are moved to the XLA device under the hood. It will ensure that the weights among
+the modules are shared but not copied independently.
 
 PyTorch Lightning has an inbuilt check which verifies that the model parameter lengths
 match once the model is moved to the device. If the lengths do not match Lightning
@@ -37,9 +36,8 @@ Example:
         self.layer_1 = nn.Linear(32, 10, bias=False)
         self.layer_2 = nn.Linear(10, 32, bias=False)
         self.layer_3 = nn.Linear(32, 10, bias=False)
-        # TPU shared weights are copied independently
-        # on the XLA device and this line won't have any effect.
-        # However, it works fine for CPU and GPU.
+        # Lightning automatically ties these weights after moving to the XLA device,
+        # so all you need is to write the following just like on other accelerators.
         self.layer_3.weight = self.layer_1.weight
 
     def forward(self, x):
@@ -48,10 +46,6 @@ Example:
         x = self.layer_3(x)
         return x
 
-    def on_post_move_to_device(self):
-        # Weights shared after the model has been moved to TPU Device
-        self.layer_3.weight = self.layer_1.weight
-
 
 model = WeightSharingModule()
 trainer = Trainer(max_epochs=1, accelerator="tpu", devices=8)
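Read together, the updated example amounts to roughly the following self-contained sketch (the imports and the base-class spelling are assumed from the surrounding documentation and are not part of this hunk):

from torch import nn
from pytorch_lightning import LightningModule, Trainer


class WeightSharingModule(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(32, 10, bias=False)
        self.layer_2 = nn.Linear(10, 32, bias=False)
        self.layer_3 = nn.Linear(32, 10, bias=False)
        # Tie the weights as on any other accelerator; on XLA, Lightning
        # re-ties them automatically after the model is moved to the device.
        self.layer_3.weight = self.layer_1.weight

    def forward(self, x):
        x = self.layer_1(x)
        x = self.layer_2(x)
        x = self.layer_3(x)
        return x


model = WeightSharingModule()
trainer = Trainer(max_epochs=1, accelerator="tpu", devices=8)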

docs/source-pytorch/common/lightning_module.rst

Lines changed: 0 additions & 6 deletions

@@ -1501,12 +1501,6 @@ on_validation_epoch_end
 .. automethod:: pytorch_lightning.core.module.LightningModule.on_validation_epoch_end
     :noindex:
 
-on_post_move_to_device
-~~~~~~~~~~~~~~~~~~~~~~
-
-.. automethod:: pytorch_lightning.core.module.LightningModule.on_post_move_to_device
-    :noindex:
-
 configure_sharded_model
 ~~~~~~~~~~~~~~~~~~~~~~~

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -284,6 +284,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Removed deprecated `Callback.on_keyboard_interrupt` ([#13438](https://github.com/Lightning-AI/lightning/pull/13438))
 
 
+- Removed deprecated `LightningModule.on_post_move_to_device` ([#13548](https://github.com/Lightning-AI/lightning/pull/13548))
+
+
 ### Fixed
 

src/pytorch_lightning/core/hooks.py

Lines changed: 0 additions & 15 deletions

@@ -298,21 +298,6 @@ def on_before_optimizer_step(self, optimizer, optimizer_idx):
             )
         """
 
-    def on_post_move_to_device(self) -> None:
-        """Called in the ``parameter_validation`` decorator after
-        :meth:`~pytorch_lightning.core.LightningModule.to` is called. This is a good place to tie weights between
-        modules after moving them to a device. Can be used when training models with weight sharing properties on
-        TPU.
-
-        Addresses the handling of shared weights on TPU:
-        https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks
-
-        Example::
-
-            def on_post_move_to_device(self):
-                self.decoder.weight = self.encoder.weight
-        """
-
     def configure_sharded_model(self) -> None:
         """Hook to create modules in a distributed aware context. This is useful for when using sharded plugins,
         where we'd like to shard the model instantly, which is useful for extremely large models which can save
src/pytorch_lightning/overrides/base.py

Lines changed: 0 additions & 6 deletions

@@ -52,9 +52,6 @@ def predict_step(self, *args: Any, **kwargs: Any) -> Any:
     def forward(self, *args: Any, **kwargs: Any) -> Any:
         raise NotImplementedError
 
-    def on_post_move_to_device(self) -> None:
-        pass
-
 
 class _LightningModuleWrapperBase(DeviceDtypeModuleMixin, torch.nn.Module):
     def __init__(self, pl_module: Union["pl.LightningModule", _LightningPrecisionModuleWrapperBase]) -> None:
@@ -95,9 +92,6 @@ def forward(self, *inputs: Any, **kwargs: Any) -> Any:
             return self.module.predict_step(*inputs, **kwargs)
         return self.module(*inputs, **kwargs)
 
-    def on_post_move_to_device(self) -> None:
-        pass
-
 
 def unwrap_lightning_module(wrapped_model: nn.Module) -> "pl.LightningModule":
     """Recursively unwraps a :class:`~pytorch_lightning.core.module.LightningModule` by following the ``.module``

src/pytorch_lightning/strategies/tpu_spawn.py

Lines changed: 1 addition & 6 deletions

@@ -33,7 +33,6 @@
 from pytorch_lightning.utilities.data import has_len
 from pytorch_lightning.utilities.distributed import ReduceOp
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.optimizer import optimizers_to_device
 from pytorch_lightning.utilities.rank_zero import rank_zero_only
 from pytorch_lightning.utilities.seed import reset_seed
@@ -124,11 +123,7 @@ def setup(self, trainer: "pl.Trainer") -> None:
 
         shared_params = find_shared_parameters(self.model)
         self.model_to_device()
-        if is_overridden("on_post_move_to_device", self.lightning_module):
-            self.model.module.on_post_move_to_device()
-        else:
-            set_shared_parameters(self.model.module, shared_params)
-
+        set_shared_parameters(self.model.module, shared_params)
         self.setup_precision_plugin()
 
         if trainer.state.fn == TrainerFn.FITTING:
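The setup flow above now always records shared parameters before the move and re-ties them afterwards, instead of deferring to a user hook. As an illustration of the idea only (this is not Lightning's actual `find_shared_parameters`/`set_shared_parameters` code), the two helpers can be approximated in plain PyTorch like this:

from torch import nn


def find_shared_parameters(module: nn.Module):
    """Collect (source_name, duplicate_name) pairs of parameters that share one tensor."""
    seen = {}  # id(parameter) -> first dotted name it was seen under
    shared = []
    for mod_name, mod in module.named_modules():
        for param_name, param in mod.named_parameters(recurse=False):
            full_name = f"{mod_name}.{param_name}" if mod_name else param_name
            if id(param) in seen:
                shared.append((seen[id(param)], full_name))
            else:
                seen[id(param)] = full_name
    return shared


def set_shared_parameters(module: nn.Module, shared):
    """Re-assign each duplicate to its source parameter, restoring ties broken by a move."""
    for source_name, duplicate_name in shared:
        source = module.get_parameter(source_name)
        owner_path, _, attr = duplicate_name.rpartition(".")
        owner = module.get_submodule(owner_path) if owner_path else module
        setattr(owner, attr, source)
    return module


# Usage mirroring the setup() flow in the hunk above:
#   shared = find_shared_parameters(model)   # record ties before the move
#   model = model.to(xla_device)             # XLA may copy instead of share
#   set_shared_parameters(model, shared)     # re-tie on the target device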

src/pytorch_lightning/trainer/configuration_validator.py

Lines changed: 0 additions & 16 deletions

@@ -46,8 +46,6 @@ def verify_loop_configurations(trainer: "pl.Trainer") -> None:
     __verify_eval_loop_configuration(trainer, model, "predict")
 
     __verify_dp_batch_transfer_support(trainer, model)
-    # TODO: Delete _check_on_post_move_to_device in v1.7
-    _check_on_post_move_to_device(model)
     _check_deprecated_callback_hooks(trainer)
     # TODO: Delete _check_on_hpc_hooks in v1.8
     _check_on_hpc_hooks(model)
@@ -122,20 +120,6 @@ def __verify_train_val_loop_configuration(trainer: "pl.Trainer", model: "pl.Ligh
         )
 
 
-def _check_on_post_move_to_device(model: "pl.LightningModule") -> None:
-    r"""
-    Checks if `on_post_move_to_device` method is overridden and sends a deprecation warning.
-
-    Args:
-        model: The model to check the `on_post_move_to_device` method.
-    """
-    if is_overridden("on_post_move_to_device", model):
-        rank_zero_deprecation(
-            "Method `on_post_move_to_device` has been deprecated in v1.5 and will be removed in v1.7. "
-            "We perform automatic parameters tying without the need of implementing `on_post_move_to_device`."
-        )
-
-
 def __verify_eval_loop_configuration(trainer: "pl.Trainer", model: "pl.LightningModule", stage: str) -> None:
     loader_name = f"{stage}_dataloader"
     step_name = "validation_step" if stage == "val" else f"{stage}_step"
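The deleted check followed a common pattern: detect whether a subclass overrides a hook and, if so, emit a deprecation warning. A rough, self-contained sketch of that pattern (not Lightning's actual `is_overridden` or `rank_zero_deprecation` implementation):

import warnings


class Base:
    def on_post_move_to_device(self) -> None:
        """Default no-op hook."""


def is_overridden(method_name: str, instance: object, parent: type = Base) -> bool:
    # A subclass "overrides" the hook when its class attribute is a different
    # function object than the one defined on the parent class.
    instance_attr = getattr(type(instance), method_name, None)
    parent_attr = getattr(parent, method_name, None)
    if instance_attr is None or parent_attr is None:
        return False
    return instance_attr is not parent_attr


def check_deprecated_hook(model: Base) -> None:
    if is_overridden("on_post_move_to_device", model):
        warnings.warn(
            "`on_post_move_to_device` is deprecated; parameters are tied automatically.",
            DeprecationWarning,
        )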

tests/tests_pytorch/deprecated_api/test_remove_1-7.py

Lines changed: 0 additions & 16 deletions

@@ -20,7 +20,6 @@
 import torch
 
 from pytorch_lightning import Trainer
-from pytorch_lightning.demos.boring_classes import BoringModel
 from pytorch_lightning.plugins.environments import (
     KubeflowEnvironment,
     LightningEnvironment,
@@ -39,21 +38,6 @@ def test_v1_7_0_deprecate_lightning_distributed(tmpdir):
         _ = LightningDistributed()
 
 
-def test_v1_7_0_deprecate_on_post_move_to_device(tmpdir):
-    class TestModel(BoringModel):
-        def on_post_move_to_device(self):
-            print("on_post_move_to_device")
-
-    model = TestModel()
-
-    trainer = Trainer(default_root_dir=tmpdir, limit_train_batches=5, max_epochs=1)
-
-    with pytest.deprecated_call(
-        match=r"Method `on_post_move_to_device` has been deprecated in v1.5 and will be removed in v1.7"
-    ):
-        trainer.fit(model)
-
-
 def test_v1_7_0_deprecated_max_steps_none(tmpdir):
     with pytest.deprecated_call(match="`max_steps = None` is deprecated in v1.5"):
         _ = Trainer(max_steps=None)
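The removed test used the standard pytest idiom for asserting that a deprecation warning is raised. A generic version of that idiom, with a hypothetical helper standing in for the deprecated code path:

import warnings

import pytest


def _legacy_entry_point() -> None:
    # Hypothetical stand-in for a code path that still emits a deprecation warning.
    warnings.warn(
        "`legacy_hook` has been deprecated in v1.5 and will be removed in v1.7",
        DeprecationWarning,
    )


def test_legacy_entry_point_warns():
    # `match` takes a regex that must be found in the warning message.
    with pytest.deprecated_call(match=r"deprecated in v1.5 and will be removed in v1.7"):
        _legacy_entry_point()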
