
Commit f819212

Merge 590e0df into 55dd3a4 (2 parents: 55dd3a4 + 590e0df)

3 files changed: +15 -0 lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -131,6 +131,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372))


+- Fixed LightningModule `all_gather` on cpu tensors ([#6416](https://github.com/PyTorchLightning/pytorch-lightning/pull/6416))
+
+
 - Fixed an issue where the tuner would not tune the learning rate if also tuning the batch size ([#4688](https://github.com/PyTorchLightning/pytorch-lightning/pull/4688))


pytorch_lightning/utilities/apply_func.py

Lines changed: 8 additions & 0 deletions
@@ -164,6 +164,14 @@ def batch_to(data):
 def convert_to_tensors(data, device: torch.device = None):
     if device is None:
         raise MisconfigurationException("device (torch.device) should be provided.")
+
     for src_dtype, conversion_func in CONVERSION_DTYPES:
         data = apply_to_collection(data, src_dtype, partial(conversion_func, device=device))
+
+    def _move_to_device_and_make_contiguous(t: torch.Tensor, device: torch.device):
+        if t.device != device:
+            t = t.to(device)
+        return t.contiguous()
+
+    data = apply_to_collection(data, torch.Tensor, partial(_move_to_device_and_make_contiguous, device=device))
     return data
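For context, a minimal usage sketch of the patched `convert_to_tensors` (the import path follows this commit's file layout; the sample data and assertions are illustrative, not part of the change):

# Illustrative only: exercises the tensor branch added in this commit.
import numpy as np
import torch
from pytorch_lightning.utilities.apply_func import convert_to_tensors

data = {
    "np": np.array([1, 2, 3]),           # converted via CONVERSION_DTYPES
    "cpu_tensor": torch.ones(4, 2).t(),  # plain tensor, non-contiguous
}
out = convert_to_tensors(data, device=torch.device("cpu"))

# After this commit, plain tensors are also routed through
# _move_to_device_and_make_contiguous, so every tensor leaf ends up on the
# requested device and contiguous before it reaches the distributed gather.
assert isinstance(out["np"], torch.Tensor)
assert out["cpu_tensor"].is_contiguous()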

tests/utilities/test_all_gather_grad.py

Lines changed: 4 additions & 0 deletions
@@ -58,13 +58,17 @@ def training_epoch_end(self, outputs) -> None:
         self.training_epoch_end_called = True
         losses = torch.stack([x["loss"] for x in outputs])
         gathered_loss = self.all_gather({
+            "losses_tensor_int": torch.tensor([1, 2, 3]),
+            "losses_tensor_float": torch.tensor([1., 2., 3.]),
             "losses_np_ndarray": np.array([1, 2, 3]),
             "losses_bool": [True, False],
             "losses_float": [0., 1., 2.],
             "losses_int": [0, 1, 2],
             "losses": losses,
             "losses_list": [losses, losses]
         })
+        assert gathered_loss["losses_tensor_int"][0].dtype == torch.int64
+        assert gathered_loss["losses_tensor_float"][0].dtype == torch.float
         assert gathered_loss["losses_np_ndarray"][0].dtype == torch.int64
         # torch.bool can't be all_gathered
         assert gathered_loss["losses_bool"][0].dtype == torch.uint8
