🐛 Bug
The PyTorch/XLA documentation mentions here that weight tying should happen after moving the tensors to the XLA device; otherwise the tied tensors are copied. This is a silent error that can easily go undetected (thanks to @matt-peters for pointing it out), and it would be good if PL guarded the user against it. Note that weight tying is quite common in today's models, not a corner case.
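For readers unfamiliar with the pattern, weight tying means registering the same Parameter object in two places, typically sharing a language model's input embedding with its output projection. Below is a minimal sketch; the module and method names (TiedLM, tie_weights) are illustrative, not taken from the issue.

import torch.nn as nn


class TiedLM(nn.Module):
    """Toy model whose output projection shares the input embedding weight."""

    def __init__(self, vocab_size=1000, hidden=128):
        super().__init__()
        self.embedding = nn.Embedding(vocab_size, hidden)
        self.decoder = nn.Linear(hidden, vocab_size, bias=False)
        self.tie_weights()

    def tie_weights(self):
        # Both modules now point at the same Parameter object. If the model is
        # moved to an XLA device *after* this call, the shared tensor can be
        # copied and the tie silently broken, so it must be re-established.
        self.decoder.weight = self.embedding.weight

    def forward(self, token_ids):
        return self.decoder(self.embedding(token_ids))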
Code sample
The following code snippet shows how to detect when this issue occurs and how to guard against it.
import pytorch_lightning as pl


class MyPLModel(pl.LightningModule):
    def to(self, *args, **kwargs):
        # parameters() deduplicates shared tensors, so tied weights count once
        param_count_before_moving_to_device = len(list(self.parameters()))
        super().to(*args, **kwargs)
        if self.trainer.use_tpu:
            # need to re-tie the weights after moving to XLA!
            self.tie_weights()  # a new function that the user needs to implement
            param_count_after_moving_to_device = len(list(self.parameters()))
            assert param_count_before_moving_to_device == param_count_after_moving_to_device
        return self
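As a possible framework-level guard, the same parameter-count trick could be factored out and run after any device move. This is only a sketch; check_no_untied_weights is a hypothetical helper, not an existing PL API.

import warnings

import torch.nn as nn


def check_no_untied_weights(module: nn.Module, param_count_before: int) -> None:
    # parameters() deduplicates shared Parameter objects, so tied weights
    # count once; if .to() silently copied them, the count increases.
    param_count_after = len(list(module.parameters()))
    if param_count_after != param_count_before:
        warnings.warn(
            "Parameter count changed after moving the module to the device; "
            "tied weights were likely copied. Re-tie them after .to()."
        )


# usage: record the count, move the module, then verify nothing was untied
# count = len(list(model.parameters()))
# model = model.to(xla_device)
# check_no_untied_weights(model, count)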