Add option for weight tying on TPU's #5441
Conversation
Force-pushed from 7dd0f7e to 0e6d43b.
Hello @lezwon! Thanks for updating this PR. There are currently no PEP 8 issues detected in this Pull Request. Cheers! 🍻 Comment last updated at 2021-02-17 23:42:15 UTC
Codecov Report

@@            Coverage Diff             @@
##           master    #5441      +/-   ##
==========================================
- Coverage      93%      91%       -2%
==========================================
  Files         160      160
  Lines       11358    11720     +362
==========================================
+ Hits        10569    10648      +79
- Misses        789     1072     +283
Force-pushed from 2275b35 to 5a00c88.
This should go to release/1.2-dev.
Force-pushed from 5a00c88 to eebc4de.
Removed ready-to-go since this is now blocked by the accelerator refactor.
…lightning into bugfix/2705_weights_tying
post_layer_count = len(list(self.parameters()))

if not pre_layer_count == post_layer_count:
    rank_zero_warn(
@lezwon Could you help me understand what this check means?
@kaushikb11 If the layer weights are not tied while on TPU, the layer count on the TPU shows up as no_of_layers + 1, because the XLA library makes a copy of the tied layer weights. This check makes sure the layer count matches after moving the model to the device. 😊👍
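For readers following along, here is a minimal sketch of what such a pre/post parameter-count check can look like, assuming it is implemented as a decorator around whatever method moves the model to the device. The decorator name `parameter_validation` and the warning text are illustrative, not necessarily the exact code in this PR:

```python
from functools import wraps

from pytorch_lightning.utilities import rank_zero_warn


def parameter_validation(fn):
    """Warn if the parameter count changes after moving the model to a device.

    On TPU, XLA can materialize a separate copy of weights that were tied on
    the host, so a count mismatch usually means the weights need to be
    re-tied on the device.
    """

    @wraps(fn)
    def inner_fn(self, *args, **kwargs):
        pre_layer_count = len(list(self.parameters()))
        module = fn(self, *args, **kwargs)
        post_layer_count = len(list(self.parameters()))

        if not pre_layer_count == post_layer_count:
            rank_zero_warn(
                'The model layers do not match after moving to the target device. '
                f'Layer count: [Before: {pre_layer_count} After: {post_layer_count}]'
            )

        return module

    return inner_fn
```

Such a decorator would then wrap the method that actually performs the device transfer (e.g. a `to()` override), so the counts are taken immediately before and after the move.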
@lezwon Yup, but I don't think it's necessary to run parameter_validation on every module call, wdyt?
Just an FYI: currently I am doing this https://github.com/PyTorchLightning/pytorch-lightning/blob/tpu_spawn_added/pytorch_lightning/plugins/training_type/tpu_spawn.py#L172 by using XLA's MpModelWrapper. This will make the test fail. Maybe I could add this check specifically to TPU accelerators.
Yes, you could maybe add it specifically to TPU accelerators. Also do check that it works well when using XLA with GPUs.
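For reference, the MpModelWrapper pattern mentioned above looks roughly like this. This is a sketch based on the public torch_xla API, not the actual tpu_spawn.py code, and `MyModel` is a placeholder module:

```python
import torch.nn as nn
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp


class MyModel(nn.Module):  # stand-in for the real LightningModule
    def __init__(self):
        super().__init__()
        self.layer = nn.Linear(4, 4)

    def forward(self, x):
        return self.layer(x)


# Wrapping once in the parent process lets all forked TPU processes share the
# single host-memory copy instead of each holding their own.
wrapped_model = xmp.MpModelWrapper(MyModel())


def _mp_fn(index):
    device = xm.xla_device()
    # The move to the XLA device happens inside each spawned process; this is
    # where host-side weight tying can end up duplicated on the device.
    model = wrapped_model.to(device)
    # ... per-core training loop would go here ...


if __name__ == '__main__':
    xmp.spawn(_mp_fn, nprocs=8, start_method='fork')
```

Note that MpModelWrapper only shares host memory when the processes are forked, which is why the sketch uses start_method='fork'.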
What does this PR do?
Fixes #2705
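For context, #2705 is about modules that share weights (e.g. an embedding tied to an output projection) silently becoming untied when the model is moved to a TPU. A hedged sketch of the pattern is below; the `on_post_move_to_device` hook name is an assumption about what this PR adds, not confirmed API from the thread:

```python
import torch
from torch import nn
import pytorch_lightning as pl


class TiedModel(pl.LightningModule):
    def __init__(self, vocab_size=1000, hidden=64):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden)
        self.decoder = nn.Linear(hidden, vocab_size, bias=False)
        # Tie the decoder weight to the embedding weight. On CPU/GPU the two
        # modules keep sharing one tensor after .to(device); on TPU, XLA can
        # materialize two separate device copies instead.
        self.decoder.weight = self.embedding.weight

    def forward(self, x):
        return self.decoder(torch.relu(self.embedding(x)))

    def on_post_move_to_device(self):
        # Re-tie the weights once the model is on the XLA device so that the
        # sharing (and the parameter count) matches the pre-move state.
        self.decoder.weight = self.embedding.weight
```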
Before submitting
PR review
Anyone in the community is free to review the PR once the tests have passed.
Before you start reviewing, make sure you have read the Review guidelines. In short, see the following bullet-list:
Did you have fun?
Make sure you had fun coding 🙃