Weight tying is broken on TPUs leading to silent errors #2705

@ibeltagy

Description

🐛 Bug

The PyTorch/XLA documentation mentions here that weight tying should happen after moving tensors to XLA, otherwise the tensors are copied. This is a silent error that can easily go undetected (thanks to @matt-peters for pointing it out), and it would be good if PL guarded the user against it. Note that weight tying is pretty common in today's models, not a corner case.
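
For context, "weight tying" here means pointing two modules at the same Parameter object, as in the minimal sketch below (the TinyLM class, its sizes, and the commented-out XLA lines are illustrative, not from this issue):

import torch.nn as nn

class TinyLM(nn.Module):
    """Toy LM whose input embedding and output projection share one weight matrix."""
    def __init__(self, vocab_size=100, hidden=32):
        super().__init__()
        self.embed = nn.Embedding(vocab_size, hidden)
        self.lm_head = nn.Linear(hidden, vocab_size, bias=False)
        self.tie_weights()

    def tie_weights(self):
        # Make both modules reference the same Parameter object.
        self.lm_head.weight = self.embed.weight

model = TinyLM()
# parameters() deduplicates shared tensors, so the tied weight is counted once.
assert len(list(model.parameters())) == 1

# Per the PyTorch/XLA docs cited above, moving to an XLA device copies the
# tensors, so the tie must be re-established on the device:
# import torch_xla.core.xla_model as xm
# model = model.to(xm.xla_device())
# model.tie_weights()

The parameter-count check in the code sample below relies on this deduplication: if the tie is silently undone by the device move and not re-established, the parameter count increases.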

Code sample

The following code snippet shows how to detect that this issue is happening and how to guard against it.

import pytorch_lightning as pl

class MyPLModel(pl.LightningModule):
    def to(self, *args, **kwargs):
        param_count_before_moving_to_device = len(list(self.parameters()))
        super().to(*args, **kwargs)
        if self.trainer.use_tpu:
            # need to re-tie the weights after moving to XLA!
            self.tie_weights()  # a new function that the user needs to implement
        param_count_after_moving_to_device = len(list(self.parameters()))
        assert param_count_before_moving_to_device == param_count_after_moving_to_device
        return self
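
A more direct way to verify a tie at runtime is to compare parameter identity rather than counts; a sketch, reusing the hypothetical TinyLM from above:

def weights_are_tied(model):
    # Tied weights are literally the same Parameter object, so identity is the reliable test.
    return model.lm_head.weight is model.embed.weight

model = TinyLM()
assert weights_are_tied(model)

# After model.to(xm.xla_device()), this check would fail until
# model.tie_weights() is called again on the XLA device.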

Labels

accelerator: tpu, bug, feature, help wanted
