Add option for weight tying on TPU's #5441
Changes from all commits
**`pytorch_lightning/core/decorators.py`**
@@ -16,7 +16,7 @@
```python
from functools import wraps
from typing import Callable

from pytorch_lightning.core.lightning import LightningModule
from pytorch_lightning.utilities import rank_zero_warn


def auto_move_data(fn: Callable) -> Callable:
```
@@ -54,6 +54,7 @@ def forward(self, x):
```python
    @wraps(fn)
    def auto_transfer_args(self, *args, **kwargs):
        from pytorch_lightning.core.lightning import LightningModule
        if not isinstance(self, LightningModule):
            return fn(self, *args, **kwargs)
```
@@ -62,3 +63,42 @@ def auto_transfer_args(self, *args, **kwargs):
```python
        return fn(self, *args, **kwargs)

    return auto_transfer_args


def parameter_validation(fn: Callable) -> Callable:
    """
    Decorator for :meth:`~pytorch_lightning.core.LightningModule.to` method.
    Validates that the module parameter lengths match after moving to the device. It is useful
    when tying weights on TPUs.

    Args:
        fn: ``.to`` method

    Note:
        TPUs require weights to be tied/shared after moving the module to the device.
        Failure to do this results in the initialization of new weights which are not tied.
        To overcome this issue, weights should be tied using the ``on_post_move_to_device`` model hook,
        which is called after the module has been moved to the device.

    See Also:
        - `XLA Documentation <https://github.com/pytorch/xla/blob/master/TROUBLESHOOTING.md#xla-tensor-quirks>`_
    """

    @wraps(fn)
    def inner_fn(self, *args, **kwargs):
        pre_layer_count = len(list(self.parameters()))
        module = fn(self, *args, **kwargs)
        self.on_post_move_to_device()
        post_layer_count = len(list(self.parameters()))

        if not pre_layer_count == post_layer_count:
            rank_zero_warn(
                f'The model layers do not match after moving to the target device.'
                ' If your model employs weight sharing on TPU,'
                ' please tie your weights using the `on_post_move_to_device` model hook.\n'
                f'Layer count: [Before: {pre_layer_count} After: {post_layer_count}]'
            )

        return module

    return inner_fn
```

Review thread on the layer-count check in `inner_fn`:

**kaushikb11 (Contributor):** @lezwon Could you help me out with what this check means?

**lezwon (Contributor, Author):** @kaushikb11 If the layer weights are not tied while on TPU, the layer count on the TPU will show up as no_of_layers + 1, because the XLA library will make a copy of the layer weights. This check makes sure the layer count matches after moving the model to the device. 😊👍

**kaushikb11 (Contributor):** @lezwon Yup, but I don't think it's necessary to do this. Just an FYI: currently I am doing this https://github.com/PyTorchLightning/pytorch-lightning/blob/tpu_spawn_added/pytorch_lightning/plugins/training_type/tpu_spawn.py#L172, by using XLA's MpModelWrapper. This will make the test fail. Maybe I could add this check specific to TPU accelerators.

**lezwon (Contributor, Author):** Yes, you could maybe add it specifically to TPU accelerators. Also do check that it works well when using XLA with GPUs.
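To make the weight-tying requirement concrete, here is a minimal, CPU-runnable sketch of the pattern the docstring and the thread above describe. Only the `on_post_move_to_device` hook name comes from this PR; the model, its `tie_weights` helper, and the printed counts are illustrative. With the weights tied, the shared tensor is counted once by `parameters()`; if a device move silently untied them, the count would grow by one, which is exactly the mismatch the decorator's before/after comparison warns about.

```python
import torch
from torch import nn


class TiedEmbeddingModel(nn.Module):
    """Illustrative model that shares its input embedding with its output layer."""

    def __init__(self, vocab_size: int = 100, hidden: int = 16):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden)
        self.decoder = nn.Linear(hidden, vocab_size, bias=False)
        self.tie_weights()

    def tie_weights(self):
        # Point the decoder at the embedding weight so both layers share one tensor.
        self.decoder.weight = self.embedding.weight

    def on_post_move_to_device(self):
        # Hook added by this PR: re-tie weights after the module lands on the device.
        self.tie_weights()


model = TiedEmbeddingModel()
print(len(list(model.parameters())))  # 1 -> the shared weight is counted once

# If the weights were left untied (e.g. after an XLA move without re-tying),
# the same model would report one extra parameter:
model.decoder.weight = nn.Parameter(model.embedding.weight.detach().clone())
print(len(list(model.parameters())))  # 2 -> this mismatch is what triggers the warning
```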
**`DeviceDtypeModuleMixin`** (device/dtype mixin module)
@@ -17,6 +17,8 @@
```python
import torch
from torch.nn import Module

from pytorch_lightning.core.decorators import parameter_validation


class DeviceDtypeModuleMixin(Module):
    __jit_unused_properties__ = ['device', 'dtype']
```
@@ -50,6 +52,7 @@ def device(self, new_device: Union[str, torch.device]):
```python
        # Necessary to avoid infinite recursion
        raise RuntimeError('Cannot set the device explicitly. Please use module.to(new_device).')

    @parameter_validation
    def to(self, *args, **kwargs) -> Module:
        """Moves and/or casts the parameters and buffers.
```

Review thread on the `@parameter_validation` decorator:

**Contributor:** We are trying not to use decorators, as they are hard to debug and easy to forget. @rohitgr7 doesn't it overlap with your new one?

**Contributor:** This one is for models.

**Collaborator:** I feel like this kind of check belongs in a sanity check; there is no need to execute it on each call.

**Collaborator:** @lezwon do we really need to run this check with every call?

**Collaborator:** cc: @PyTorchLightning/core-contributors

**Contributor:** Any chance the weight-tying validation can happen within the TPU accelerator, and only after model_to_device is called?

**Contributor:** Yes, we should move it there in another PR.
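Sketching the reviewers' suggestion above: the same before/after comparison could run once, right after the model has been moved, instead of wrapping every `to()` call. The helper below is a hypothetical illustration (the function name and the `warnings.warn` call are mine, not part of this PR); inside Lightning it would presumably sit next to the TPU accelerator's model_to_device step, keeping the validation out of the hot path.

```python
import warnings

import torch
from torch import nn


def move_and_validate(model: nn.Module, device: torch.device) -> nn.Module:
    """Move `model` to `device` once and verify that shared weights stayed tied."""
    pre_layer_count = len(list(model.parameters()))
    model = model.to(device)
    # Let the model re-tie its shared weights on the new device, if it knows how.
    if hasattr(model, "on_post_move_to_device"):
        model.on_post_move_to_device()
    post_layer_count = len(list(model.parameters()))

    if pre_layer_count != post_layer_count:
        warnings.warn(
            "The model layers do not match after moving to the target device. "
            "If your model employs weight sharing, tie your weights in "
            f"`on_post_move_to_device`. Layer count: [Before: {pre_layer_count} "
            f"After: {post_layer_count}]"
        )
    return model
```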
@@ -86,6 +89,9 @@ def to(self, *args, **kwargs) -> Module:
```python
...     def __init__(self, weight: torch.Tensor):
...         super().__init__()
...         self.register_buffer('weight', weight)
...
...     def on_post_move_to_device(self):
...         pass
>>> _ = torch.manual_seed(0)
>>> module = ExampleModule(torch.rand(3, 4))
>>> module.weight #doctest: +ELLIPSIS
```
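For completeness, a usage sketch of the decorator itself, mirroring the updated doctest above. It assumes a pytorch_lightning checkout that includes this PR (so `parameter_validation` is importable from `pytorch_lightning.core.decorators`); applying it to a plain `nn.Module` subclass is only for illustration. The `on_post_move_to_device` hook must exist, even as a no-op, because the decorated `to()` calls it after every move, which is why the doctest's `ExampleModule` gained one.

```python
import torch
from torch import nn

from pytorch_lightning.core.decorators import parameter_validation


class ExampleModule(nn.Module):
    def __init__(self, weight: torch.Tensor):
        super().__init__()
        self.register_buffer('weight', weight)

    def on_post_move_to_device(self):
        # Nothing to re-tie here; defined so the decorated `to()` can call it.
        pass

    @parameter_validation
    def to(self, *args, **kwargs) -> nn.Module:
        return super().to(*args, **kwargs)


module = ExampleModule(torch.rand(3, 4))
module = module.to(torch.device('cpu'))  # no warning: the parameter count is unchanged
```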