Commit fdaa5d7

committed
Fix all_gather for tpu_cores=8
1 parent 9e35f97 commit fdaa5d7

File tree

2 files changed: +7 −4 lines changed


CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -163,6 +163,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed an exception in the layer summary when the model contains torch.jit scripted submodules ([#6511](https://github.com/PyTorchLightning/pytorch-lightning/pull/6511))


+- Fixed a bug where `all_gather` would not work correctly with `tpu_cores=8`
+
+
 ## [1.2.3] - 2021-03-09

 ### Fixed

pytorch_lightning/accelerators/tpu.py

Lines changed: 4 additions & 4 deletions

@@ -46,12 +46,12 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_grads
         Function to gather a tensor from several distributed processes

         Args:
             tensor: tensor of shape (batch, ...)
-            group: the process group to gather results from. Defaults to all processes (world)
-            sync_grads: flag that allows users to synchronize gradients for all_gather op
+            group: not available with TPUs
+            sync_grads: not available with TPUs

         Return:
             A tensor of shape (world_size, batch, ...)
         """
         # todo: Add support for backward with all_gather
-        if torch.distributed.is_initialized():
-            return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
+        if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed:
+            return xm.all_gather(tensor).view(-1, *tensor.shape)
         return tensor
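A note on the `.view(-1, *tensor.shape)` call in the fix: XLA's `all_gather` concatenates the per-core tensors along dim 0, producing a tensor of shape `(world_size * batch, ...)`, while the docstring promises `(world_size, batch, ...)`. The sketch below simulates that reshape with plain `torch.cat` (no TPU required); the `per_core` list is a hypothetical stand-in for what the 8 TPU cores would each contribute.

```python
import torch

# Hypothetical stand-in for what xm.all_gather(tensor) returns on
# tpu_cores=8: the per-core tensors of shape (batch, ...) concatenated
# along dim 0 into shape (world_size * batch, ...).
world_size, batch, features = 8, 4, 3
per_core = [torch.full((batch, features), float(rank)) for rank in range(world_size)]
gathered = torch.cat(per_core, dim=0)        # shape (32, 3)

# The fix reshapes the concatenated result back to the documented
# return shape (world_size, batch, ...).
tensor = per_core[0]
stacked = gathered.view(-1, *tensor.shape)   # shape (8, 4, 3)

# Each slice along dim 0 now corresponds to one core's contribution.
print(tuple(stacked.shape))
```

This also suggests why the old guard misfired: under XLA multiprocessing `torch.distributed.is_initialized()` need not be true, so checking the plugin type is the more reliable signal that a gather is actually needed.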

0 commit comments