Merged
Changes from all commits (32 commits)
5cc9959
added on_post_move_to_device
lezwon Jan 5, 2021
3d1a313
added tests
lezwon Jan 9, 2021
bc7910b
docs and refactors
lezwon Jan 10, 2021
bb4c891
Update tests/backends/test_tpu_backend.py
lezwon Jan 12, 2021
9897945
Update docs/source/tpu.rst
lezwon Jan 12, 2021
0528db8
Update docs/source/tpu.rst
lezwon Jan 12, 2021
c41e4ac
Update pytorch_lightning/core/decorators.py
lezwon Jan 12, 2021
c7866ff
Update pytorch_lightning/core/decorators.py
lezwon Jan 12, 2021
ac494d1
Update docs/source/tpu.rst
lezwon Jan 14, 2021
5e49706
Update pytorch_lightning/core/decorators.py
lezwon Jan 14, 2021
e1126b2
Update pytorch_lightning/core/decorators.py
lezwon Jan 14, 2021
975b899
Update pytorch_lightning/core/decorators.py
lezwon Jan 14, 2021
814b163
Update pytorch_lightning/core/decorators.py
lezwon Jan 14, 2021
e69910a
Update pytorch_lightning/core/hooks.py
lezwon Jan 14, 2021
8a22096
moved weight sharing module back to test
lezwon Jan 16, 2021
bf5349b
add count to warning
lezwon Jan 16, 2021
35e67c0
fix doctest
lezwon Jan 16, 2021
829c041
import trainer in doctest
lezwon Jan 16, 2021
a1e574f
import trainer in doctest
lezwon Jan 18, 2021
ac71a15
do not test code as no TPU device
lezwon Jan 19, 2021
5a5f306
Merge branch 'release/1.2-dev' into bugfix/2705_weights_tying
lezwon Jan 19, 2021
432c3c3
param count to layer count
lezwon Jan 20, 2021
beb1ab4
formatting
Borda Jan 28, 2021
a23e23c
Merge remote-tracking branch 'origin/release/1.2-dev' into bugfix/270…
lezwon Jan 29, 2021
4cf6799
update docs
lezwon Feb 1, 2021
983619e
Merge remote-tracking branch 'origin/release/1.2-dev' into bugfix/270…
lezwon Feb 1, 2021
7f7e195
Merge remote-tracking branch 'origin/release/1.2-dev' into bugfix/270…
lezwon Feb 2, 2021
7d8823f
Merge branch 'master' of https://github.com/PyTorchLightning/pytorch-…
tchaton Feb 17, 2021
2268df8
update import
tchaton Feb 17, 2021
8acfd39
update
tchaton Feb 17, 2021
62042c6
resolve tests
Feb 17, 2021
3605c0b
remove legacy accelerator
tchaton Feb 17, 2021
1 change: 1 addition & 0 deletions CHANGELOG.md
@@ -55,6 +55,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

- Added `IoU` class interface ([#4704](https://github.com/PyTorchLightning/pytorch-lightning/pull/4704))

- Added support for tying weights after moving the model to the TPU via the `on_post_move_to_device` hook

- Added missing val/test hooks in `LightningModule` ([#5467](https://github.com/PyTorchLightning/pytorch-lightning/pull/5467))

57 changes: 56 additions & 1 deletion docs/source/advanced/tpu.rst
@@ -197,7 +197,62 @@ set the 16-bit flag.

Under the hood the xla library will use the `bfloat16 type <https://en.wikipedia.org/wiki/Bfloat16_floating-point_format>`_.

----------------

-----------------

Weight Sharing/Tying
--------------------
Weight tying/sharing is a technique in which module weights are shared between two or more layers.
It is a common way to reduce memory consumption and is used in many state-of-the-art
architectures today.

PyTorch XLA requires these weights to be tied/shared after the model has been moved
to the TPU device. To support this requirement, Lightning provides a model hook that is
called after the model is moved to the device. Any weight tying should be done in the
`on_post_move_to_device` model hook. This ensures that the weights
among the modules are shared rather than copied.

PyTorch Lightning has a built-in check that verifies that the model parameter counts
match once the model is moved to the device. If the counts do not match, Lightning
throws a warning.

Example:

.. code-block:: python

    from pytorch_lightning.core.lightning import LightningModule
    from torch import nn
    from pytorch_lightning.trainer.trainer import Trainer


    class WeightSharingModule(LightningModule):
        def __init__(self):
            super().__init__()
            self.layer_1 = nn.Linear(32, 10, bias=False)
            self.layer_2 = nn.Linear(10, 32, bias=False)
            self.layer_3 = nn.Linear(32, 10, bias=False)
            # TPU shared weights are copied independently
            # on the XLA device and this line won't have any effect.
            # However, it works fine for CPU and GPU.
            self.layer_3.weight = self.layer_1.weight

        def forward(self, x):
            x = self.layer_1(x)
            x = self.layer_2(x)
            x = self.layer_3(x)
            return x

        def on_post_move_to_device(self):
            # Weights shared after the model has been moved to TPU Device
            self.layer_3.weight = self.layer_1.weight


    model = WeightSharingModule()
    trainer = Trainer(max_epochs=1, tpu_cores=8)

See `XLA Documentation <https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks>`_
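
As a rough illustration (not part of this PR), the built-in check can be reproduced by hand.
The sketch below assumes ``torch_xla`` is installed on a TPU host and reuses the
``WeightSharingModule`` defined above; tied weights are counted only once by ``parameters()``,
so the count changes only if the shared weights end up as independent copies on the XLA device.

.. code-block:: python

    import torch_xla.core.xla_model as xm

    model = WeightSharingModule()
    layers_before = len(list(model.parameters()))  # tied weights are counted once

    # `LightningModule.to` runs the built-in check and calls `on_post_move_to_device`,
    # so the shared weights are re-tied on the XLA device.
    model = model.to(xm.xla_device())
    layers_after = len(list(model.parameters()))

    assert layers_before == layers_after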

-----------------------

Performance considerations
--------------------------
42 changes: 41 additions & 1 deletion pytorch_lightning/core/decorators.py
@@ -16,7 +16,7 @@
from functools import wraps
from typing import Callable

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_warn


def auto_move_data(fn: Callable) -> Callable:
@@ -54,6 +54,7 @@ def forward(self, x):

    @wraps(fn)
    def auto_transfer_args(self, *args, **kwargs):
        from pytorch_lightning.core.lightning import LightningModule
        if not isinstance(self, LightningModule):
            return fn(self, *args, **kwargs)

@@ -62,3 +63,42 @@ def auto_transfer_args(self, *args, **kwargs):
        return fn(self, *args, **kwargs)

    return auto_transfer_args


def parameter_validation(fn: Callable) -> Callable:
    """
    Decorator for the :meth:`~pytorch_lightning.core.LightningModule.to` method.
    Validates that the module parameter lengths match after moving to the device. It is useful
    when tying weights on TPUs.

    Args:
        fn: ``.to`` method

    Note:
        TPUs require weights to be tied/shared after moving the module to the device.
        Failing to do this results in the initialization of new weights that are not tied.
        To overcome this issue, weights should be tied using the ``on_post_move_to_device`` model hook,
        which is called after the module has been moved to the device.

    See Also:
        - `XLA Documentation <https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks>`_
    """

    @wraps(fn)
    def inner_fn(self, *args, **kwargs):
        pre_layer_count = len(list(self.parameters()))
        module = fn(self, *args, **kwargs)
        self.on_post_move_to_device()
        post_layer_count = len(list(self.parameters()))

        if not pre_layer_count == post_layer_count:
            rank_zero_warn(
                f'The model layers do not match after moving to the target device.'
                ' If your model employs weight sharing on TPU,'
                ' please tie your weights using the `on_post_move_to_device` model hook.\n'
                f'Layer count: [Before: {pre_layer_count} After: {post_layer_count}]'
            )

        return module

    return inner_fn

kaushikb11 (Contributor): @lezwon Could you help me out on what this check means?

lezwon (PR author): @kaushikb11 if the layer weights are not tied while on TPU, the layer count on the TPU will show as no_of_layers + 1, as the xla library will make a copy of the layer weights. This check makes sure the layer count matches after moving the model to the device. 😊👍

kaushikb11 (Contributor), Apr 20, 2021: @lezwon Yup, but I don't think it's necessary to do parameter_validation for every module's `to()` call? wdyt?

Just an FYI: currently I am doing this https://github.com/PyTorchLightning/pytorch-lightning/blob/tpu_spawn_added/pytorch_lightning/plugins/training_type/tpu_spawn.py#L172 by using xla's MpModelWrapper. This will make the test fail. Maybe I could add this check specifically for TPU accelerators.

lezwon (PR author): Yes, you could maybe add it specifically to TPU accelerators. Also do check that it works well when using XLA with GPUs.
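For context, the `MpModelWrapper` approach mentioned in the review comments above looks roughly like the sketch below. This is not part of the PR; it assumes `torch_xla` is installed on a TPU host and reuses the `WeightSharingModule` from the docs example as a stand-in model.

    import torch_xla.core.xla_model as xm
    import torch_xla.distributed.xla_multiprocessing as xmp

    # Wrap the model once in the parent process; with the 'fork' start method only a
    # single copy of the host weights is kept in shared memory across processes.
    WRAPPED_MODEL = xmp.MpModelWrapper(WeightSharingModule())


    def _mp_fn(index):
        device = xm.xla_device()
        # `to()` moves the wrapped model onto this process's TPU core and returns it.
        model = WRAPPED_MODEL.to(device)
        # ... per-process training loop would go here ...


    if __name__ == '__main__':
        xmp.spawn(_mp_fn, args=(), nprocs=8, start_method='fork')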
16 changes: 16 additions & 0 deletions pytorch_lightning/core/hooks.py
@@ -318,6 +318,22 @@ def on_after_backward(self):

        """

    def on_post_move_to_device(self) -> None:
        """
        Called in the ``parameter_validation`` decorator after :meth:`~pytorch_lightning.core.LightningModule.to`
        is called. This is a good place to tie weights between modules after moving them to a device. Can be
        used when training models with weight sharing properties on TPU.

        Addresses the handling of shared weights on TPU:
        https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks

        Example::

            def on_post_move_to_device(self):
                self.decoder.weight = self.encoder.weight

        """


class DataHooks:
    """Hooks to be used with LightningDataModule."""
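As a hedged illustration of the hook (not part of this diff), a module that ties an input embedding to an output projection, as in the docstring example above, might re-tie the weights like this; the class name and sizes are made up for the example.

    from torch import nn

    from pytorch_lightning import LightningModule


    class TiedEmbeddingModel(LightningModule):
        """Hypothetical example: ties the decoder projection to the encoder embedding."""

        def __init__(self, vocab_size: int = 1000, hidden_size: int = 64):
            super().__init__()
            self.encoder = nn.Embedding(vocab_size, hidden_size)
            self.decoder = nn.Linear(hidden_size, vocab_size, bias=False)
            self.decoder.weight = self.encoder.weight  # tied on the host

        def on_post_move_to_device(self):
            # On XLA the shared weight may be copied independently during `.to()`,
            # so tie it again once the module is on the device.
            self.decoder.weight = self.encoder.weight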
6 changes: 6 additions & 0 deletions pytorch_lightning/utilities/device_dtype_mixin.py
@@ -17,6 +17,8 @@
import torch
from torch.nn import Module

from pytorch_lightning.core.decorators import parameter_validation


class DeviceDtypeModuleMixin(Module):
    __jit_unused_properties__ = ['device', 'dtype']

@@ -50,6 +52,7 @@ def device(self, new_device: Union[str, torch.device]):
        # Necessary to avoid infinite recursion
        raise RuntimeError('Cannot set the device explicitly. Please use module.to(new_device).')

    @parameter_validation

tchaton (Contributor), Jan 11, 2021: We are trying not to use decorators as they are hard to debug and easy to forget.

@rohitgr7 doesn't it overlap with your new transfer_batch_to_device with on_before/after_transfer_to_device?

Contributor: this one is for models; on_before/after_transfer_to_device is for data. It's on_before/after_batch_transfer.

Collaborator: I feel like this kind of check should be a sanity check; there is no need to execute it on each call to `to(...)`.

lezwon (PR author): @tchaton @Borda should we have this check without the decorator? Move it somewhere else?

Collaborator: @lezwon do we really need to run this check with every to(...) or can it be just for a single epoch? But it can be overkill to wrap and check if it is empty...
so if this is not essential for this PR, let's move it to a separate PR and let this fix land asap 🐰

Collaborator: cc: @PyTorchLightning/core-contributors

Contributor: Any chance the weight-tying validation can happen within the TPU accelerator, and happen only after model_to_device is called?

Contributor: Yes, we should move it there in another PR.

    def to(self, *args, **kwargs) -> Module:
        """Moves and/or casts the parameters and buffers.

@@ -86,6 +89,9 @@ def to(self, *args, **kwargs) -> Module:
            ...     def __init__(self, weight: torch.Tensor):
            ...         super().__init__()
            ...         self.register_buffer('weight', weight)
            ...
            ...     def on_post_move_to_device(self):
            ...         pass
            >>> _ = torch.manual_seed(0)
            >>> module = ExampleModule(torch.rand(3, 4))
            >>> module.weight #doctest: +ELLIPSIS
63 changes: 60 additions & 3 deletions tests/accelerators/test_tpu_backend.py
@@ -14,15 +14,32 @@

import pytest
import torch
from torch import nn

from pytorch_lightning import Trainer
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities.xla_device import XLADeviceUtils
from pytorch_lightning.utilities import _TPU_AVAILABLE
from tests.helpers.boring_model import BoringModel
from tests.helpers.utils import pl_multi_process_test


@pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(), reason="test requires TPU machine")
class WeightSharingModule(BoringModel):

    def __init__(self):
        super().__init__()
        self.layer_1 = nn.Linear(32, 10, bias=False)
        self.layer_2 = nn.Linear(10, 32, bias=False)
        self.layer_3 = nn.Linear(32, 10, bias=False)
        self.layer_3.weight = self.layer_1.weight

    def forward(self, x):
        x = self.layer_1(x)
        x = self.layer_2(x)
        x = self.layer_3(x)
        return x


@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_resume_training_on_cpu(tmpdir):
""" Checks if training can be resumed from a saved checkpoint on CPU"""
Expand Down Expand Up @@ -53,7 +70,7 @@ def test_resume_training_on_cpu(tmpdir):
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"


@pytest.mark.skipif(not XLADeviceUtils.tpu_device_exists(), reason="test requires TPU machine")
@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_if_test_works_after_train(tmpdir):
""" Ensure that .test() works after .fit() """
Expand All @@ -63,3 +80,43 @@ def test_if_test_works_after_train(tmpdir):
trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True)
trainer.fit(model)
assert trainer.test(model) == 1


@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_weight_tying_warning(tmpdir, capsys=None):
"""
Ensure a warning is thrown if model parameter lengths do not match
post moving to device.
"""

model = WeightSharingModule()
trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)

with pytest.warns(UserWarning, match=r'The model layers do not match after moving to the target device.'):
result = trainer.fit(model)
assert result


@pytest.mark.skipif(not _TPU_AVAILABLE, reason="test requires TPU machine")
@pl_multi_process_test
def test_if_weights_tied(tmpdir, capsys=None):
"""
Test if weights are properly tied on `on_post_move_to_device`.
Ensure no warning for parameter mismatch is thrown.
"""

class Model(WeightSharingModule):

def on_post_move_to_device(self):
self.layer_3.weight = self.layer_1.weight

model = Model()
trainer = Trainer(checkpoint_callback=True, max_epochs=1, tpu_cores=1)

with pytest.warns(UserWarning) as warnings:
result = trainer.fit(model)
assert result

assert not list(filter(lambda x: 'The model layers do not match' in str(x), warnings.list))
assert trainer.test(model) == 1
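
As the comment in the docs example notes, weight tying works fine on CPU and GPU without the hook. A quick CPU-only sanity check (not part of the PR, a sketch assuming it lives in the same test module) illustrates why the warning is TPU-specific: a plain `.to()` on CPU keeps the shared `Parameter` object intact.

    import torch


    def test_weights_stay_tied_on_cpu():
        """Sanity check (no TPU required): tying survives a plain CPU `.to()` call."""
        model = WeightSharingModule()
        model = model.to(torch.device("cpu"))
        # `layer_3.weight` and `layer_1.weight` still reference the same Parameter object.
        assert model.layer_3.weight is model.layer_1.weight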