
all_gather not called correctly when using tpu_cores=8 #6586

@ethanwharris


🐛 Bug

Currently, when tpu_cores is set to 8, all_gather simply returns the original tensor, even though it should gather the tensor across all cores. In particular, this clause is never entered: https://github.com/PyTorchLightning/pytorch-lightning/blob/38a2119359f22dd9525cb7978eb2ac230a36ab59/pytorch_lightning/accelerators/tpu.py#L55

In addition, `xm.all_gather` does not support the `groups` or `sync_grads` arguments.

Please reproduce using the BoringModel

The problem is illustrated here (the notebook only prints tensor sizes, but it makes the point): https://colab.research.google.com/drive/15CM8EKxP2E_CBu9jcXFoVuS7HDUY2D7V?usp=sharing
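For reference, here is a minimal sketch of the kind of repro used in the notebook (class and variable names here are hypothetical, not taken from the notebook, and it needs a TPU runtime). With `tpu_cores=8`, the gathered tensor should have a leading dimension equal to the number of cores, but `self.all_gather` currently returns the tensor with its original shape:

```python
import torch
import pytorch_lightning as pl
from torch.utils.data import DataLoader, TensorDataset


class GatherModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        (x,) = batch
        loss = self.layer(x).sum()
        # Expected: shape (8,) when running on 8 TPU cores.
        # Actual: the original scalar shape, i.e. no gather happens.
        gathered = self.all_gather(loss)
        print("gathered shape:", gathered.shape)
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


if __name__ == "__main__":
    data = DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8)
    trainer = pl.Trainer(tpu_cores=8, max_epochs=1)
    trainer.fit(GatherModel(), data)
```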

Additional context

This should probably be sorted out as a first step before #6295 can be addressed.
