diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3d4eeddb5b6ca..5df2ea86683e6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -55,6 +55,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `IoU` class interface ([#4704](https://github.com/PyTorchLightning/pytorch-lightning/pull/4704))
 
+- Added support for tying weights after moving the model to the TPU via the `on_post_move_to_device` hook
 
 - Added missing val/test hooks in `LightningModule` ([#5467](https://github.com/PyTorchLightning/pytorch-lightning/pull/5467))
 
diff --git a/docs/source/advanced/tpu.rst b/docs/source/advanced/tpu.rst
index e98003a3edab6..b9688ce425b5f 100644
--- a/docs/source/advanced/tpu.rst
+++ b/docs/source/advanced/tpu.rst
@@ -197,7 +197,62 @@ set the 16-bit flag.
 Under the hood the xla library will use the `bfloat16 type `_.
 
-----------------
+
+-----------------
+
+Weight Sharing/Tying
+--------------------
+Weight tying/sharing is a technique in which module weights are shared between two or more layers.
+It is a common way to reduce memory consumption and is used in many state-of-the-art architectures today.
+
+PyTorch XLA requires these weights to be tied/shared after the model has been moved to the TPU device.
+To support this requirement, Lightning provides a model hook which is called after the model is moved
+to the device. Any weights that need to be tied should be tied in the `on_post_move_to_device` model hook.
+This ensures that the weights are shared between the modules rather than copied.
+
+PyTorch Lightning has a built-in check which verifies that the number of model parameters stays the same
+once the model has been moved to the device. If it does not, Lightning emits a warning.
+
+Example:
+
+.. code-block:: python
+
+    from pytorch_lightning.core.lightning import LightningModule
+    from torch import nn
+    from pytorch_lightning.trainer.trainer import Trainer
+
+
+    class WeightSharingModule(LightningModule):
+        def __init__(self):
+            super().__init__()
+            self.layer_1 = nn.Linear(32, 10, bias=False)
+            self.layer_2 = nn.Linear(10, 32, bias=False)
+            self.layer_3 = nn.Linear(32, 10, bias=False)
+            # TPU shared weights are copied independently
+            # on the XLA device and this line won't have any effect.
+            # However, it works fine for CPU and GPU.
+            self.layer_3.weight = self.layer_1.weight
+
+        def forward(self, x):
+            x = self.layer_1(x)
+            x = self.layer_2(x)
+            x = self.layer_3(x)
+            return x
+
+        def on_post_move_to_device(self):
+            # Weights shared after the model has been moved to the TPU device
+            self.layer_3.weight = self.layer_1.weight
+
+
+    model = WeightSharingModule()
+    trainer = Trainer(max_epochs=1, tpu_cores=8)
+
+See `XLA Documentation <https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks>`_
+
+-----------------------
 
 Performance considerations
 --------------------------
diff --git a/pytorch_lightning/core/decorators.py b/pytorch_lightning/core/decorators.py
index e67b7c230e93c..024f7d8bec49a 100644
--- a/pytorch_lightning/core/decorators.py
+++ b/pytorch_lightning/core/decorators.py
@@ -16,7 +16,7 @@
 from functools import wraps
 from typing import Callable
 
-from pytorch_lightning.core.lightning import LightningModule
+from pytorch_lightning.utilities import rank_zero_warn
 
 
 def auto_move_data(fn: Callable) -> Callable:
@@ -54,6 +54,7 @@ def forward(self, x):
 
     @wraps(fn)
     def auto_transfer_args(self, *args, **kwargs):
+        from pytorch_lightning.core.lightning import LightningModule
         if not isinstance(self, LightningModule):
             return fn(self, *args, **kwargs)
 
@@ -62,3 +63,42 @@ def auto_transfer_args(self, *args, **kwargs):
         return fn(self, *args, **kwargs)
 
     return auto_transfer_args
+
+
+def parameter_validation(fn: Callable) -> Callable:
+    """
+    Decorator for the :meth:`~pytorch_lightning.core.LightningModule.to` method.
+    Validates that the number of module parameters is unchanged after moving to the device.
+    It is useful when tying weights on TPUs.
+
+    Args:
+        fn: ``.to`` method
+
+    Note:
+        TPUs require weights to be tied/shared after moving the module to the device.
+        Failure to do this results in the initialization of new weights which are not tied.
+        To overcome this issue, weights should be tied using the ``on_post_move_to_device`` model hook,
+        which is called after the module has been moved to the device.
+
+    See Also:
+        - `XLA Documentation <https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks>`_
+    """
+
+    @wraps(fn)
+    def inner_fn(self, *args, **kwargs):
+        pre_layer_count = len(list(self.parameters()))
+        module = fn(self, *args, **kwargs)
+        self.on_post_move_to_device()
+        post_layer_count = len(list(self.parameters()))
+
+        if not pre_layer_count == post_layer_count:
+            rank_zero_warn(
+                'The model layers do not match after moving to the target device.'
+                ' If your model employs weight sharing on TPU,'
+                ' please tie your weights using the `on_post_move_to_device` model hook.\n'
+                f'Layer count: [Before: {pre_layer_count} After: {post_layer_count}]'
+            )
+
+        return module
+
+    return inner_fn
diff --git a/pytorch_lightning/core/hooks.py b/pytorch_lightning/core/hooks.py
index ac7bb2a1d20e1..a91abc1bc6c82 100644
--- a/pytorch_lightning/core/hooks.py
+++ b/pytorch_lightning/core/hooks.py
@@ -318,6 +318,22 @@ def on_after_backward(self):
 
         """
 
+    def on_post_move_to_device(self) -> None:
+        """
+        Called in the ``parameter_validation`` decorator after :meth:`~pytorch_lightning.core.LightningModule.to`
+        is called. This is a good place to tie weights between modules after moving them to a device.
+        It can be used when training models with weight-sharing properties on TPU.
+
+        Addresses the handling of shared weights on TPU:
+        https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks
+
+        Example::
+
+            def on_post_move_to_device(self):
+                self.decoder.weight = self.encoder.weight
+
+        """
+
 
 class DataHooks:
     """Hooks to be used with LightningDataModule."""
diff --git a/pytorch_lightning/utilities/device_dtype_mixin.py b/pytorch_lightning/utilities/device_dtype_mixin.py
index c9e0b512cf30c..6408c6e21cad4 100644
--- a/pytorch_lightning/utilities/device_dtype_mixin.py
+++ b/pytorch_lightning/utilities/device_dtype_mixin.py
@@ -17,6 +17,8 @@
 import torch
 from torch.nn import Module
 
+from pytorch_lightning.core.decorators import parameter_validation
+
 
 class DeviceDtypeModuleMixin(Module):
     __jit_unused_properties__ = ['device', 'dtype']
@@ -50,6 +52,7 @@ def device(self, new_device: Union[str, torch.device]):
         # Necessary to avoid infinite recursion
         raise RuntimeError('Cannot set the device explicitly. Please use module.to(new_device).')
 
+    @parameter_validation
     def to(self, *args, **kwargs) -> Module:
         """Moves and/or casts the parameters and buffers.
 
@@ -86,6 +89,9 @@ def to(self, *args, **kwargs) -> Module:
         ...     def __init__(self, weight: torch.Tensor):
         ...         super().__init__()
         ...         self.register_buffer('weight', weight)
+        ...
+        ...     def on_post_move_to_device(self):
+        ...         pass
         >>> _ = torch.manual_seed(0)
         >>> module = ExampleModule(torch.rand(3, 4))
         >>> module.weight #doctest: +ELLIPSIS
diff --git a/tests/accelerators/test_tpu_backend.py b/tests/accelerators/test_tpu_backend.py
index 8e20cefe3b3d5..daea22968b253 100644
--- a/tests/accelerators/test_tpu_backend.py
+++ b/tests/accelerators/test_tpu_backend.py
@@ -14,15 +14,32 @@
 
 import pytest
 import torch
+from torch import nn
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.trainer.states import TrainerState
-from pytorch_lightning.utilities.xla_device import XLADeviceUtils
+from pytorch_lightning.utilities import _TPU_AVAILABLE
 from tests.helpers.boring_model import BoringModel
 from tests.helpers.utils import pl_multi_process_test
 
 
-@pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(), reason="test requires TPU machine")
+class WeightSharingModule(BoringModel):
+
+    def __init__(self):
+        super().__init__()
+        self.layer_1 = nn.Linear(32, 10, bias=False)
+        self.layer_2 = nn.Linear(10, 32, bias=False)
+        self.layer_3 = nn.Linear(32, 10, bias=False)
+        self.layer_3.weight = self.layer_1.weight
+
+    def forward(self, x):
+        x = self.layer_1(x)
+        x = self.layer_2(x)
+        x = self.layer_3(x)
+        return x
+
+
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
 @pl_multi_process_test
 def test_resume_training_on_cpu(tmpdir):
     """ Checks if training can be resumed from a saved checkpoint on CPU"""
@@ -53,7 +70,7 @@ def test_resume_training_on_cpu(tmpdir):
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
 
 
-@pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(), reason="test requires TPU machine")
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
 @pl_multi_process_test
 def test_if_test_works_after_train(tmpdir):
     """ Ensure that .test() works after .fit() """
@@ -63,3 +80,43 @@ def test_if_test_works_after_train(tmpdir):
     trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True)
     trainer.fit(model)
     assert trainer.test(model) == 1
+
+
+@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
+@pl_multi_process_test
+def test_weight_tying_warning(tmpdir, capsys=None):
""" + Ensure a warning is thrown if model parameter lengths do not match + post moving to device. + """ + + model = WeightSharingModule() + trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1) + + with pytest.warns(UserWarning, match=r'The model layers do not match after moving to the target device.'): + result = trainer.fit(model) + assert result + + +@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine") +@pl_multi_process_test +def test_if_weights_tied(tmpdir, capsys=None): + """ + Test if weights are properly tied on `on_post_move_to_device`. + Ensure no warning for parameter mismatch is thrown. + """ + + class Model(WeightSharingModule): + + def on_post_move_to_device(self): + self.layer_3.weight = self.layer_1.weight + + model = Model() + trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1) + + with pytest.warns(UserWarning) as warnings: + result = trainer.fit(model) + assert result + + assert not list(filter(lambda x: 'The model layers do not match' in str(x), warnings.list)) + assert trainer.test(model) == 1