
Commit d2cd7cb

Authored by lezwon, Borda, rohitgr7, tchaton, and Your Name
Add option for weight tying on TPU's (#5441)
* added on_post_move_to_device
* added tests
* docs and refactors
* Update tests/backends/test_tpu_backend.py (Co-authored-by: Jirka Borovec <[email protected]>)
* Update docs/source/tpu.rst (Co-authored-by: Jirka Borovec <[email protected]>)
* Update docs/source/tpu.rst (Co-authored-by: Jirka Borovec <[email protected]>)
* Update pytorch_lightning/core/decorators.py (Co-authored-by: Jirka Borovec <[email protected]>)
* Update pytorch_lightning/core/decorators.py (Co-authored-by: Jirka Borovec <[email protected]>)
* Update docs/source/tpu.rst (Co-authored-by: Rohit Gupta <[email protected]>)
* Update pytorch_lightning/core/decorators.py (Co-authored-by: Rohit Gupta <[email protected]>)
* Update pytorch_lightning/core/decorators.py (Co-authored-by: Rohit Gupta <[email protected]>)
* Update pytorch_lightning/core/decorators.py (Co-authored-by: Rohit Gupta <[email protected]>)
* Update pytorch_lightning/core/decorators.py (Co-authored-by: Rohit Gupta <[email protected]>)
* Update pytorch_lightning/core/hooks.py (Co-authored-by: Rohit Gupta <[email protected]>)
* moved weight sharing module back to test, updated tpu available
* add count to warning
* fix doctest
* import trainer in doctest
* import trainer in doctest
* do not test code as no TPU device
* param count to layer count
* formatting
* update docs
* update import
* update
* resolve tests
* remove legacy accelerator

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: tchaton <[email protected]>
Co-authored-by: Your Name <[email protected]>
1 parent bac617f commit d2cd7cb

File tree (6 files changed: +180 -5 lines)

CHANGELOG.md
docs/source/advanced/tpu.rst
pytorch_lightning/core/decorators.py
pytorch_lightning/core/hooks.py
pytorch_lightning/utilities/device_dtype_mixin.py
tests/accelerators/test_tpu_backend.py

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -55,6 +55,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `IoU` class interface ([#4704](https://github.com/PyTorchLightning/pytorch-lightning/pull/4704))
 
+- Support to tie weights after moving model to TPU via `on_post_move_to_device` hook
 
 - Added missing val/test hooks in `LightningModule` ([#5467](https://github.com/PyTorchLightning/pytorch-lightning/pull/5467))
 

docs/source/advanced/tpu.rst

Lines changed: 56 additions & 1 deletion
@@ -197,7 +197,62 @@ set the 16-bit flag.
 
 Under the hood the xla library will use the `bfloat16 type <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_.
 
-----------------
+
+-----------------
+
+Weight Sharing/Tying
+--------------------
+Weight Tying/Sharing is a technique in which module weights are shared among two or more layers.
+This is a common method to reduce memory consumption and is utilized in many state-of-the-art
+architectures today.
+
+PyTorch XLA requires these weights to be tied/shared after moving the model
+to the TPU device. To support this requirement, Lightning provides a model hook which is
+called after the model is moved to the device. Any weights that need to be tied should
+be tied in the `on_post_move_to_device` model hook. This ensures that the weights
+among the modules are shared and not copied.
+
+PyTorch Lightning has a built-in check which verifies that the number of model parameters
+matches once the model is moved to the device. If the counts do not match, Lightning
+throws a warning message.
+
+Example:
+
+.. code-block:: python
+
+    from pytorch_lightning.core.lightning import LightningModule
+    from torch import nn
+    from pytorch_lightning.trainer.trainer import Trainer
+
+
+    class WeightSharingModule(LightningModule):
+        def __init__(self):
+            super().__init__()
+            self.layer_1 = nn.Linear(32, 10, bias=False)
+            self.layer_2 = nn.Linear(10, 32, bias=False)
+            self.layer_3 = nn.Linear(32, 10, bias=False)
+            # TPU shared weights are copied independently
+            # on the XLA device and this line won't have any effect.
+            # However, it works fine for CPU and GPU.
+            self.layer_3.weight = self.layer_1.weight
+
+        def forward(self, x):
+            x = self.layer_1(x)
+            x = self.layer_2(x)
+            x = self.layer_3(x)
+            return x
+
+        def on_post_move_to_device(self):
+            # Weights shared after the model has been moved to the TPU device
+            self.layer_3.weight = self.layer_1.weight
+
+
+    model = WeightSharingModule()
+    trainer = Trainer(max_epochs=1, tpu_cores=8)
+
+See `XLA Documentation <https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks>`_
+
+-----------------------
 
 Performance considerations
 --------------------------
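The comment in the example above (shared weights are copied independently on the XLA device) can be checked directly: tied layers hold the very same `Parameter` object, while a broken tie leaves two separate copies. Below is a minimal CPU-only sketch of such a check; the helper `weights_are_tied` is illustrative and not part of Lightning or of the diff above.

from torch import nn


def weights_are_tied(a: nn.Linear, b: nn.Linear) -> bool:
    # Tied layers reference the same Parameter, so both the identity check and
    # the storage-pointer check pass; independent copies fail the identity check.
    return a.weight is b.weight and a.weight.data_ptr() == b.weight.data_ptr()


layer_1 = nn.Linear(32, 10, bias=False)
layer_3 = nn.Linear(32, 10, bias=False)
layer_3.weight = layer_1.weight            # tie on the host, as in __init__ above

print(weights_are_tied(layer_1, layer_3))  # True on CPU (and GPU)

# On XLA, `.to(xla_device)` copies each shared weight independently, which is
# why the tie has to be re-applied in `on_post_move_to_device` after the move.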

pytorch_lightning/core/decorators.py

Lines changed: 41 additions & 1 deletion
@@ -16,7 +16,7 @@
 from functools import wraps
 from typing import Callable
 
-from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.utilities import rank_zero_warn
 
 
 def auto_move_data(fn: Callable) -> Callable:
@@ -54,6 +54,7 @@ def forward(self, x):
 
     @wraps(fn)
    def auto_transfer_args(self, *args, **kwargs):
+        from pytorch_lightning.core.lightning import LightningModule
        if not isinstance(self, LightningModule):
            return fn(self, *args, **kwargs)
 
@@ -62,3 +63,42 @@ def auto_transfer_args(self, *args, **kwargs):
        return fn(self, *args, **kwargs)
 
    return auto_transfer_args
+
+
+def parameter_validation(fn: Callable) -> Callable:
+    """
+    Decorator for :meth:`~pytorch_lightning.core.LightningModule.to` method.
+    Validates that the module parameter lengths match after moving to the device. It is useful
+    when tying weights on TPU's.
+
+    Args:
+        fn: ``.to`` method
+
+    Note:
+        TPU's require weights to be tied/shared after moving the module to the device.
+        Failure to do this results in the initialization of new weights which are not tied.
+        To overcome this issue, weights should be tied using the ``on_post_move_to_device`` model hook
+        which is called after the module has been moved to the device.
+
+    See Also:
+        - `XLA Documentation <https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks>`_
+    """
+
+    @wraps(fn)
+    def inner_fn(self, *args, **kwargs):
+        pre_layer_count = len(list(self.parameters()))
+        module = fn(self, *args, **kwargs)
+        self.on_post_move_to_device()
+        post_layer_count = len(list(self.parameters()))
+
+        if not pre_layer_count == post_layer_count:
+            rank_zero_warn(
+                f'The model layers do not match after moving to the target device.'
+                ' If your model employs weight sharing on TPU,'
+                ' please tie your weights using the `on_post_move_to_device` model hook.\n'
+                f'Layer count: [Before: {pre_layer_count} After: {post_layer_count}]'
+            )
+
+        return module
+
+    return inner_fn
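For readers without a TPU at hand, the check performed by `inner_fn` above can be reproduced by hand: count the parameters, break the tie the way XLA's independent copies would, and count again. This is a minimal sketch under that assumption; it uses a plain `warnings.warn` in place of Lightning's `rank_zero_warn`, and the `TiedNet` class is illustrative only.

import warnings

from torch import nn


class TiedNet(nn.Module):
    """Toy module mirroring the docs example: layer_3 shares layer_1's weight."""

    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(32, 10, bias=False)
        self.layer_3 = nn.Linear(32, 10, bias=False)
        self.layer_3.weight = self.layer_1.weight  # tied parameters are counted once


net = TiedNet()
pre_layer_count = len(list(net.parameters()))   # 1: the tie deduplicates the weight

# Simulate what XLA does on `.to(device)`: the shared weight becomes an
# independent copy, silently breaking the tie.
net.layer_3.weight = nn.Parameter(net.layer_1.weight.detach().clone())
post_layer_count = len(list(net.parameters()))  # 2: the tie is gone

if pre_layer_count != post_layer_count:         # same comparison as inner_fn above
    warnings.warn(
        'The model layers do not match after moving to the target device. '
        f'Layer count: [Before: {pre_layer_count} After: {post_layer_count}]'
    )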

pytorch_lightning/core/hooks.py

Lines changed: 16 additions & 0 deletions
@@ -318,6 +318,22 @@ def on_after_backward(self):
 
        """
 
+    def on_post_move_to_device(self) -> None:
+        """
+        Called in the ``parameter_validation`` decorator after :meth:`~pytorch_lightning.core.LightningModule.to`
+        is called. This is a good place to tie weights between modules after moving them to a device. Can be
+        used when training models with weight sharing properties on TPU.
+
+        Addresses the handling of shared weights on TPU:
+        https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks
+
+        Example::
+
+            def on_post_move_to_device(self):
+                self.decoder.weight = self.encoder.weight
+
+        """
+
 
 class DataHooks:
     """Hooks to be used with LightningDataModule."""

pytorch_lightning/utilities/device_dtype_mixin.py

Lines changed: 6 additions & 0 deletions
@@ -17,6 +17,8 @@
 import torch
 from torch.nn import Module
 
+from pytorch_lightning.core.decorators import parameter_validation
+
 
 class DeviceDtypeModuleMixin(Module):
     __jit_unused_properties__ = ['device', 'dtype']
@@ -50,6 +52,7 @@ def device(self, new_device: Union[str, torch.device]):
        # Necessary to avoid infinite recursion
        raise RuntimeError('Cannot set the device explicitly. Please use module.to(new_device).')
 
+    @parameter_validation
    def to(self, *args, **kwargs) -> Module:
        """Moves and/or casts the parameters and buffers.
 
@@ -86,6 +89,9 @@ def to(self, *args, **kwargs) -> Module:
        ...     def __init__(self, weight: torch.Tensor):
        ...         super().__init__()
        ...         self.register_buffer('weight', weight)
+        ...
+        ...     def on_post_move_to_device(self):
+        ...         pass
        >>> _ = torch.manual_seed(0)
        >>> module = ExampleModule(torch.rand(3, 4))
        >>> module.weight #doctest: +ELLIPSIS
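A side effect of decorating `to`, and the reason the doctest above gains a no-op hook: a class that mixes in `DeviceDtypeModuleMixin` directly, rather than through `LightningModule` (which inherits the default hook from `ModelHooks`), must now define `on_post_move_to_device` itself, because `parameter_validation` calls it unconditionally. A minimal sketch of that requirement follows; `BufferOnlyModule` is a hypothetical example class.

import torch

from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin


class BufferOnlyModule(DeviceDtypeModuleMixin):
    """Uses the mixin directly, so the hook called by `parameter_validation` must exist."""

    def __init__(self, weight: torch.Tensor):
        super().__init__()
        self.register_buffer('weight', weight)

    def on_post_move_to_device(self) -> None:
        # Nothing to re-tie here; the hook only needs to be callable.
        pass


module = BufferOnlyModule(torch.rand(3, 4))
module = module.to(torch.device('cpu'))  # runs parameter_validation, which calls the hook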

tests/accelerators/test_tpu_backend.py

Lines changed: 60 additions & 3 deletions
@@ -14,15 +14,32 @@
 
 import pytest
 import torch
+from torch import nn
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities.xla_device import XLADeviceUtils
+from pytorch_lightning.utilities import _TPU_AVAILABLE
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.utils import pl_multi_process_test
 
 
-@pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(), reason="test requires TPU machine")
+class WeightSharingModule(BoringModel):
+
+    def __init__(self):
+        super().__init__()
+        self.layer_1 = nn.Linear(32, 10, bias=False)
+        self.layer_2 = nn.Linear(10, 32, bias=False)
+        self.layer_3 = nn.Linear(32, 10, bias=False)
+        self.layer_3.weight = self.layer_1.weight
+
+    def forward(self, x):
+        x = self.layer_1(x)
+        x = self.layer_2(x)
+        x = self.layer_3(x)
+        return x
+
+
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
 @pl_multi_process_test
 def test_resume_training_on_cpu(tmpdir):
     """ Checks if training can be resumed from a saved checkpoint on CPU"""
@@ -53,7 +70,7 @@ def test_resume_training_on_cpu(tmpdir):
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
 
 
-@pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(), reason="test requires TPU machine")
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
 @pl_multi_process_test
 def test_if_test_works_after_train(tmpdir):
     """ Ensure that .test() works after .fit() """
@@ -63,3 +80,43 @@ def test_if_test_works_after_train(tmpdir):
     trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True)
     trainer.fit(model)
     assert trainer.test(model) == 1
+
+
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@pl_multi_process_test
+def test_weight_tying_warning(tmpdir, capsys=None):
+    """
+    Ensure a warning is thrown if model parameter lengths do not match
+    post moving to device.
+    """
+
+    model = WeightSharingModule()
+    trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)
+
+    with pytest.warns(UserWarning, match=r'The model layers do not match after moving to the target device.'):
+        result = trainer.fit(model)
+        assert result
+
+
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@pl_multi_process_test
+def test_if_weights_tied(tmpdir, capsys=None):
+    """
+    Test if weights are properly tied on `on_post_move_to_device`.
+    Ensure no warning for parameter mismatch is thrown.
+    """
+
+    class Model(WeightSharingModule):
+
+        def on_post_move_to_device(self):
+            self.layer_3.weight = self.layer_1.weight
+
+    model = Model()
+    trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)
+
+    with pytest.warns(UserWarning) as warnings:
+        result = trainer.fit(model)
+        assert result
+
+    assert not list(filter(lambda x: 'The model layers do not match' in str(x), warnings.list))
+    assert trainer.test(model) == 1
