
Commit eebc4de

docs and refactors
1 parent 46bc680 commit eebc4de

File tree

4 files changed: +91 -32 lines

docs/source/tpu.rst

Lines changed: 53 additions & 1 deletion
@@ -192,7 +192,59 @@ set the 16-bit flag.

Under the hood the xla library will use the `bfloat16 type <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_.

-----------------

Weight Sharing/Tying
-----------------------

Weight tying/sharing is a technique in which module weights are shared between two or more layers.
It is a common way to reduce memory consumption and is used in many state-of-the-art architectures today.

PyTorch XLA requires these weights to be tied/shared after the model has been moved to the TPU device.
To support this requirement, Lightning provides a model hook that is called after the model is moved
to the device. Any weights that need to be tied should be tied in the `on_post_move_to_device` model
hook. This ensures that the weights among the modules are shared rather than copied.

PyTorch Lightning has a built-in check that verifies that the model parameter lengths match once the
model is moved to the device. If the lengths do not match, Lightning throws a warning message.

Example:

.. code-block:: python

    import pytorch_lightning as pl
    from pytorch_lightning import Trainer
    from torch import nn


    class WeightSharingModule(pl.LightningModule):
        def __init__(self):
            super().__init__()
            self.layer_1 = nn.Linear(32, 10, bias=False)
            self.layer_2 = nn.Linear(10, 32, bias=False)
            self.layer_3 = nn.Linear(32, 10, bias=False)
            self.layer_3.weight = self.layer_1.weight  # Weights will be copied on TPU

        def forward(self, x):
            x = self.layer_1(x)
            x = self.layer_2(x)
            x = self.layer_3(x)
            return x

        def on_post_move_to_device(self):
            # Weights shared after the model has been moved to the TPU device
            self.layer_3.weight = self.layer_1.weight


    model = WeightSharingModule()
    trainer = Trainer(max_epochs=1, tpu_cores=8)
    result = trainer.fit(model)

See `XLA Documentation <https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks>`_

-----------------------

About XLA
----------
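
The parameter-length check described above works because tied weights are deduplicated by
`Module.parameters()`, while copied weights are counted separately. A minimal CPU-only sketch
(not part of this commit; the `TwoLinears` helper is purely illustrative):

.. code-block:: python

    from torch import nn


    class TwoLinears(nn.Module):
        def __init__(self, tie: bool):
            super().__init__()
            self.layer_1 = nn.Linear(32, 10, bias=False)
            self.layer_3 = nn.Linear(32, 10, bias=False)
            if tie:
                # Same Parameter object, so parameters() yields it only once.
                self.layer_3.weight = self.layer_1.weight
            else:
                # Independent copy, counted as a separate parameter.
                self.layer_3.weight = nn.Parameter(self.layer_1.weight.detach().clone())


    print(len(list(TwoLinears(tie=True).parameters())))   # 1
    print(len(list(TwoLinears(tie=False).parameters())))  # 2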

pytorch_lightning/core/decorators.py

Lines changed: 17 additions & 0 deletions
@@ -68,6 +68,23 @@ def auto_transfer_args(self, *args, **kwargs):


def parameter_validation(fn: Callable) -> Callable:
    """
    Decorator for the `~pytorch_lightning.core.LightningModule.to` method.
    Validates that the module parameter lengths match after moving to the device. It is useful
    when tying weights on TPUs.

    Args:
        fn: the `.to` method

    Note:
        TPUs require weights to be tied/shared after moving the module to the device.
        Failure to do this results in the initialization of new weights that are not tied.
        To overcome this issue, weights should be tied using the `on_post_move_to_device` model hook,
        which is called after the module has been moved to the device.

    See Also:
        - `XLA Documentation <https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks>`_
    """
    @wraps(fn)
    def inner_f(self, *args, **kwargs):
        pre_param_count = len(list(self.parameters()))
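
The hunk above shows only the head of the wrapper body. A minimal sketch of how such a
parameter-count validation wrapper could be completed (the hook invocation and warning text are
assumptions for illustration, not necessarily the exact Lightning implementation):

.. code-block:: python

    from functools import wraps
    from typing import Callable
    from warnings import warn


    def parameter_validation(fn: Callable) -> Callable:
        @wraps(fn)
        def inner_f(self, *args, **kwargs):
            pre_param_count = len(list(self.parameters()))
            module = fn(self, *args, **kwargs)
            # Assumes the module defines the `on_post_move_to_device` hook
            # (LightningModule provides it as a no-op by default), giving the
            # user a chance to re-tie weights on the target device.
            self.on_post_move_to_device()
            post_param_count = len(list(self.parameters()))

            if pre_param_count != post_param_count:
                warn(
                    "The model layers do not match after moving to the target device. "
                    "If your model employs weight sharing on TPU, please tie your "
                    "weights using the `on_post_move_to_device` model hook.",
                    UserWarning,
                )
            return module

        return inner_f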

tests/backends/test_tpu_backend.py

Lines changed: 3 additions & 31 deletions
@@ -14,13 +14,12 @@

 import pytest
 import torch
-from torch import nn

 from pytorch_lightning import Trainer
-from tests.base import SimpleModule
 from pytorch_lightning.utilities.xla_device import XLADeviceUtils
 from tests.base.boring_model import BoringModel
 from tests.base.develop_utils import pl_multi_process_test
+from tests.base.weight_sharing_module import WeightSharingModule


 @pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(), reason="test requires TPU machine")

@@ -73,20 +72,6 @@ def test_weight_tying_warning(tmpdir, capsys=None):
     post moving to device.
     """

-    class WeightSharingModule(SimpleModule):
-        def __init__(self):
-            super().__init__()
-            self.layer_1 = nn.Linear(32, 10, bias=False)
-            self.layer_2 = nn.Linear(10, 32, bias=False)
-            self.layer_3 = nn.Linear(32, 10, bias=False)
-            self.layer_3.weight = self.layer_1.weight
-
-        def forward(self, x):
-            x = self.layer_1(x)
-            x = self.layer_2(x)
-            x = self.layer_3(x)
-            return x
-
     model = WeightSharingModule()
     trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)

@@ -103,24 +88,11 @@ def test_if_weights_tied(tmpdir, capsys=None):
     Ensure no warning for parameter mismatch is thrown.
     """

-    class WeightSharingModule(SimpleModule):
-        def __init__(self):
-            super().__init__()
-            self.layer_1 = nn.Linear(32, 10, bias=False)
-            self.layer_2 = nn.Linear(10, 32, bias=False)
-            self.layer_3 = nn.Linear(32, 10, bias=False)
-            self.layer_3.weight = self.layer_1.weight
-
-        def forward(self, x):
-            x = self.layer_1(x)
-            x = self.layer_2(x)
-            x = self.layer_3(x)
-            return x
-
+    class Model(WeightSharingModule):
         def on_post_move_to_device(self):
             self.layer_3.weight = self.layer_1.weight

-    model = WeightSharingModule()
+    model = Model()
     trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)

     with pytest.warns(UserWarning) as warnings:
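
For context, a sketch of how the warning-path test presumably reads after this refactor; the hunks
above show only the changed lines, and the asserted warning text is an assumption:

.. code-block:: python

    @pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(), reason="test requires TPU machine")
    @pl_multi_process_test
    def test_weight_tying_warning(tmpdir, capsys=None):
        """Ensure a warning is thrown when parameter lengths do not match post moving to device."""

        model = WeightSharingModule()
        trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)

        with pytest.warns(UserWarning) as warnings:
            trainer.fit(model)

        # The exact warning text is an assumption for illustration.
        assert any("The model layers do not match" in str(w.message) for w in warnings)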
tests/base/weight_sharing_module.py

Lines changed: 18 additions & 0 deletions

@@ -0,0 +1,18 @@
from torch import nn

from tests.base import SimpleModule


class WeightSharingModule(SimpleModule):
    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(32, 10, bias=False)
        self.layer_2 = nn.Linear(10, 32, bias=False)
        self.layer_3 = nn.Linear(32, 10, bias=False)
        self.layer_3.weight = self.layer_1.weight

    def forward(self, x):
        x = self.layer_1(x)
        x = self.layer_2(x)
        x = self.layer_3(x)
        return x
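
A brief usage sketch, mirroring the refactored test above: subclass the shared helper and re-tie the
weights in `on_post_move_to_device` (the `TiedOnDevice` name is purely illustrative):

.. code-block:: python

    from tests.base.weight_sharing_module import WeightSharingModule


    class TiedOnDevice(WeightSharingModule):
        def on_post_move_to_device(self):
            # Re-tie after the module has been moved to the XLA device.
            self.layer_3.weight = self.layer_1.weight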
