
Commit 9cd985b

ethanwharris authored and Borda committed
Fix all_gather for tpu_cores=8 (#6587)
(cherry picked from commit 983a888)
1 parent caebaea · commit 9cd985b

2 files changed: 27 additions, 16 deletions

CHANGELOG.md

Lines changed: 23 additions & 12 deletions
@@ -119,6 +119,29 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416))
 
 
+- Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541))
+
+
+- Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275))
+
+
+- Disabled batch transfer in DP mode ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093))
+
+
+- Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115)
+
+
+
+## [1.2.5] - 2021-03-23
+
+### Changed
+
+
+### Fixed
+
+- Fixed a bug where `all_gather` would not work correctly with `tpu_cores=8` ([#6587](https://github.com/PyTorchLightning/pytorch-lightning/pull/6587))
+
+
 ## [1.2.4] - 2021-03-16
 
 ### Changed
@@ -139,9 +162,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541))
 
 
-- Fixed when Train loop config was run during `Trainer.predict` ([#6541](https://github.com/PyTorchLightning/pytorch-lightning/pull/6541))
-
-
 ## [1.2.3] - 2021-03-09
 
 ### Fixed
@@ -180,9 +200,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed error thrown when using valid distributed mode in multi node ([#6297](https://github.com/PyTorchLightning/pytorch-lightning/pull/6297)
 
 
-- Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275))
-
-
 ## [1.2.1] - 2021-02-23
 
 ### Fixed
@@ -192,12 +209,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed error message for AMP + CPU incompatibility ([#6107](https://github.com/PyTorchLightning/pytorch-lightning/pull/6107))
 
 
-- Disabled batch transfer in DP mode ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093))
-
-
-- Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115)
-
-
 ## [1.2.0] - 2021-02-18
 
 ### Added
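
For orientation, a hedged sketch of the user-facing path that the new 1.2.5 entry refers to: inside a `LightningModule` step, `self.all_gather` dispatches to the active accelerator, so a run with `tpu_cores=8` now lands in the patched `TPUAccelerator.all_gather` shown in the next file. The `LitModel` class and its `validation_step` body below are illustrative only, not part of this commit.

```python
import torch
import pytorch_lightning as pl


class LitModel(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def validation_step(self, batch, batch_idx):
        logits = self.layer(batch)  # local tensor of shape (batch, 2)
        # With tpu_cores=8 this call dispatches to TPUAccelerator.all_gather
        # and returns a tensor of shape (world_size, batch, 2) on every core.
        gathered = self.all_gather(logits)
        self.log("gathered_mean", gathered.mean())
```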

pytorch_lightning/accelerators/tpu.py

Lines changed: 4 additions & 4 deletions
@@ -35,12 +35,12 @@ def all_gather(self, tensor: Union[torch.Tensor], group: Optional[Any] = None, s
         Function to gather a tensor from several distributed processes
         Args:
             tensor: tensor of shape (batch, ...)
-            group: the process group to gather results from. Defaults to all processes (world)
-            sync_grads: flag that allows users to synchronize gradients for all_gather op
+            group: not available with TPUs
+            sync_grads: not available with TPUs
         Return:
             A tensor of shape (world_size, batch, ...)
         """
         # todo: Add support for backward with all_gather
-        if torch.distributed.is_initialized():
-            return xm.all_gather(tensor, group=group, sync_grads=sync_grads)
+        if isinstance(self.training_type_plugin, TPUSpawnPlugin) and self.training_type_plugin.is_distributed:
+            return xm.all_gather(tensor).view(-1, *tensor.shape)
         return tensor
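
The behavioral core of the tpu.py change is the reshape: `xm.all_gather` returns the per-core tensors concatenated along dim 0, i.e. shape `(world_size * batch, ...)`, while this method is documented to return `(world_size, batch, ...)`. A minimal sketch of that reshape using plain `torch` (no TPU required; the concatenation below merely stands in for what `xm.all_gather` produces):

```python
import torch

world_size, batch = 8, 4           # e.g. tpu_cores=8
local = torch.randn(batch, 3, 3)   # per-core tensor of shape (batch, ...)

# Stand-in for xm.all_gather: each core contributes its local tensor and the
# results arrive concatenated along dim 0 -> shape (world_size * batch, ...).
gathered = torch.cat([local.clone() for _ in range(world_size)], dim=0)

# The fix views the result back into the documented (world_size, batch, ...) layout.
result = gathered.view(-1, *local.shape)
assert result.shape == (world_size, batch, 3, 3)
```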

0 commit comments
