
Commit 34ceeb7

Merge 15d5fd8 into 8001987

2 parents 8001987 + 15d5fd8

File tree

2 files changed, 7 insertions(+), 1 deletion(-)

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -86,6 +86,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275))
 
 
+- Fixed `SingleTPU` calling `all_gather` ([#6296](https://github.com/PyTorchLightning/pytorch-lightning/pull/6296))
+
+
 - Fixed error thrown when using valid distributed mode in multi node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297))
 
 

pytorch_lightning/accelerators/tpu.py

Lines changed: 4 additions & 1 deletion

@@ -46,4 +46,7 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
         Return:
             A tensor of shape (world_size, batch, ...)
         """
-        return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
+        # todo: Add support for backward with all_gather
+        if torch.distributed.is_initialized():
+            return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
+        return tensor
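
For readers skimming the hunk, here is a minimal sketch of how the patched method reads in context. The enclosing class name, the docstring wording, the signature tail (cut off in the hunk header above as `sync_gra`), and the return annotation are assumptions reconstructed around the diff, not verbatim source:

# Sketch only: assumes torch_xla is installed and that the surrounding
# class and signature match pytorch_lightning/accelerators/tpu.py at
# this commit.
from typing import Any, Optional

import torch
import torch_xla.core.xla_model as xm


class TPUAccelerator:

    def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None,
                   sync_grads: bool = False) -> torch.Tensor:
        """Gather a tensor from all TPU processes.

        Return:
            A tensor of shape (world_size, batch, ...)
        """
        # todo: Add support for backward with all_gather
        if torch.distributed.is_initialized():
            # Multi-process case: gather across the process group.
            return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
        # Single-TPU case: no process group is initialized, so return the
        # local tensor unchanged (a world-size-1 "gather").
        return tensor

With this guard, a `SingleTPU` run treats `all_gather` as an identity operation instead of calling into `xm.all_gather` without an initialized process group, while multi-core runs keep the previous behavior.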
