
Commit 782145e

kaushikb11 authored and Borda committed
Fix sync_dist for tpus (#6950)
(cherry picked from commit 1b3e4f9)
1 parent 3e77551 commit 782145e

File tree

7 files changed, +53 -9 lines changed

- CHANGELOG.md
- pytorch_lightning/core/step_result.py
- pytorch_lightning/plugins/training_type/tpu_spawn.py
- pytorch_lightning/utilities/__init__.py
- pytorch_lightning/utilities/distributed.py
- pytorch_lightning/utilities/imports.py
- tests/models/test_tpu.py


CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -237,6 +237,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `AttributeError for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))


+- Fixed `sync_dist` for tpus ([#6950](https://github.com/PyTorchLightning/pytorch-lightning/pull/6950))
+
+
 ## [1.2.7] - 2021-04-06

 ### Fixed

pytorch_lightning/core/step_result.py

Lines changed: 3 additions & 2 deletions

@@ -22,7 +22,7 @@
 from torch import Tensor
 from torchmetrics import Metric

-from pytorch_lightning.utilities.distributed import sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed


 class Result(Dict):
@@ -139,10 +139,11 @@ def log(

         # sync across workers when using distributed training
         sync_fn = sync_fn or sync_ddp_if_available
+
         if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
             is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
             # TODO: Find a way to make the reduction only once, so we don't need to clone.
-            if is_dist_initialized and isinstance(value, torch.Tensor):
+            if (is_dist_initialized or tpu_distributed) and isinstance(value, torch.Tensor):
                 value = value.clone()
             else:
                 value = torch.tensor(value, device=device, dtype=torch.float)
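
In effect, `Result.log` now treats an active multi-core TPU world the same way as an initialized `torch.distributed` backend when deciding whether to clone a tensor before handing it to `sync_fn`. A minimal sketch of that decision, under the assumption that `_maybe_sync` is a hypothetical standalone helper (the committed logic lives inside `Result.log` itself):

# Sketch only: mirrors the clone-before-reduce branching added above, outside of Result.
import numbers

import torch

from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed


def _maybe_sync(value, sync_dist=True, sync_fn=None, device=None):
    sync_fn = sync_fn or sync_ddp_if_available
    if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
        is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
        if (is_dist_initialized or tpu_distributed()) and isinstance(value, torch.Tensor):
            # Clone so the reduction cannot modify the caller's tensor in place.
            value = value.clone()
        else:
            # Plain numbers (or no active backend) are wrapped into a fresh tensor.
            value = torch.tensor(value, device=device, dtype=torch.float)
        value = sync_fn(value)
    return value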

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 4 additions & 3 deletions

@@ -15,7 +15,7 @@
 import os
 import re
 import time
-from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
+from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union

 import torch
 import torch.multiprocessing as mp
@@ -40,7 +40,6 @@
 if _OMEGACONF_AVAILABLE:
     from omegaconf import DictConfig, ListConfig, OmegaConf

-
 if TYPE_CHECKING:
     from torch.nn import Module
     from torch.utils.data import DataLoader
@@ -276,4 +275,6 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
         Return:
             A tensor of shape (world_size, batch, ...)
         """
-        return xm.all_gather(tensor.unsqueeze(0))
+        if isinstance(tensor, torch.Tensor) and tensor.dim() == 0:
+            tensor = tensor.unsqueeze(0)
+        return xm.all_gather(tensor)
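
The `all_gather` change matters because `xm.all_gather` already concatenates the per-process tensors along dimension 0, so the old unconditional `unsqueeze(0)` only makes sense for 0-dim (scalar) tensors such as logged metrics. A rough shape sketch, using `torch.cat` as a stand-in for `xm.all_gather` and an assumed 8-process world, since the real call requires a TPU:

import torch

WORLD_SIZE = 8  # assumed number of TPU processes, for illustration only


def fake_all_gather(tensor: torch.Tensor) -> torch.Tensor:
    # Emulates the fixed TPUSpawnPlugin.all_gather: promote scalars so dim 0
    # exists, then concatenate along dim 0 as xm.all_gather would.
    if isinstance(tensor, torch.Tensor) and tensor.dim() == 0:
        tensor = tensor.unsqueeze(0)
    return torch.cat([tensor] * WORLD_SIZE, dim=0)


assert fake_all_gather(torch.tensor(3.0)).shape == (WORLD_SIZE,)       # scalar metric
assert fake_all_gather(torch.ones(2, 4)).shape == (WORLD_SIZE * 2, 4)  # batched tensor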

pytorch_lightning/utilities/__init__.py

Lines changed: 1 addition & 3 deletions

@@ -42,12 +42,10 @@
     _TORCH_QUANTIZE_AVAILABLE,
     _TORCHTEXT_AVAILABLE,
     _TORCHVISION_AVAILABLE,
+    _TPU_AVAILABLE,
     _XLA_AVAILABLE,
 )
 from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable  # noqa: F401
-from pytorch_lightning.utilities.xla_device import XLADeviceUtils  # noqa: F401
-
-_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

 FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
 FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
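
Since `_TPU_AVAILABLE` is now computed in `pytorch_lightning.utilities.imports` and merely re-exported here, call sites that import it from the package root (for example `tests/models/test_tpu.py` further down) keep working unchanged. A quick sanity check, assuming pytorch-lightning at this commit is importable:

# Both import paths should resolve to the same boolean flag after this change.
from pytorch_lightning.utilities import _TPU_AVAILABLE as tpu_flag_public
from pytorch_lightning.utilities.imports import _TPU_AVAILABLE as tpu_flag_internal

assert tpu_flag_public is tpu_flag_internal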

pytorch_lightning/utilities/distributed.py

Lines changed: 13 additions & 1 deletion

@@ -20,7 +20,10 @@

 import torch

-log = logging.getLogger(__name__)
+from pytorch_lightning.utilities.imports import _TPU_AVAILABLE
+
+if _TPU_AVAILABLE:
+    import torch_xla.core.xla_model as xm

 if torch.distributed.is_available():
     from torch.distributed import group, ReduceOp
@@ -34,6 +37,9 @@ class group:
         WORLD = None


+log = logging.getLogger(__name__)
+
+
 def rank_zero_only(fn):

     @wraps(fn)
@@ -222,3 +228,9 @@ def all_gather_ddp_if_available(
         with torch.no_grad():
             return AllGatherGrad.apply(tensor, group)
     return tensor
+
+
+def tpu_distributed() -> bool:
+    if _TPU_AVAILABLE:
+        return xm.xrt_world_size() > 1
+    return False
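
The new `tpu_distributed()` helper lets callers such as `Result.log` ask whether a multi-core XLA world is actually running, while the guarded import at the top of the module keeps `torch_xla` out of the picture on CPU/GPU-only installs. A hedged usage sketch, where `tpu_reduce_fn` is a hypothetical callable (e.g. `TPUSpawnPlugin().reduce`), not part of this commit:

import torch

from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed


def reduce_metric(value: torch.Tensor, tpu_reduce_fn=None) -> torch.Tensor:
    # Prefer the TPU reduction only when an XLA world with more than one
    # replica is active; otherwise fall back to the DDP-aware helper, which
    # returns the tensor unchanged if torch.distributed is not initialized.
    if tpu_distributed() and tpu_reduce_fn is not None:
        return tpu_reduce_fn(value.clone())
    return sync_ddp_if_available(value.clone())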

pytorch_lightning/utilities/imports.py

Lines changed: 4 additions & 0 deletions

@@ -82,3 +82,7 @@ def _compare_version(package: str, op, version) -> bool:
 _TORCHTEXT_AVAILABLE = _module_available("torchtext")
 _TORCHVISION_AVAILABLE = _module_available('torchvision')
 _XLA_AVAILABLE = _module_available("torch_xla")
+
+from pytorch_lightning.utilities.xla_device import XLADeviceUtils  # noqa: E402
+
+_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
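
Moving the `_TPU_AVAILABLE` computation down into `utilities/imports.py` (with `# noqa: E402` because the import sits below module-level code) is presumably what lets a low-level module like `utilities/distributed.py` read the flag without importing the `pytorch_lightning.utilities` package itself, which would be circular. The consumer-side pattern, mirroring the `distributed.py` hunk above:

# Guarded optional import: torch_xla is only touched when a TPU was detected,
# so importing the consuming module on a CPU-only machine stays safe.
from pytorch_lightning.utilities.imports import _TPU_AVAILABLE

if _TPU_AVAILABLE:
    import torch_xla.core.xla_model as xm  # noqa: F401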

tests/models/test_tpu.py

Lines changed: 25 additions & 0 deletions

@@ -16,13 +16,15 @@
 from unittest import mock

 import pytest
+import torch
 from torch.utils.data import DataLoader

 import tests.helpers.pipelines as tpipes
 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators import TPUAccelerator
 from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.plugins import TPUSpawnPlugin
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import _TPU_AVAILABLE
@@ -397,3 +399,26 @@ def test_if_test_works_with_checkpoint_false(tmpdir):
     trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True, checkpoint_callback=False)
     trainer.fit(model)
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
+
+
+@RunIf(tpu=True)
+@pl_multi_process_test
+def test_tpu_sync_dist():
+    """Test tpu spawn sync dist operation """
+
+    def test_sync_dist(rank):
+        tensor = torch.tensor([1.0])
+        training_type_plugin = TPUSpawnPlugin()
+
+        res = Result()
+        res.log(
+            "test_tensor",
+            tensor,
+            sync_fn=training_type_plugin.reduce,
+            sync_dist=True,
+            sync_dist_op=torch.distributed.ReduceOp.SUM
+        )
+
+        assert res["test_tensor"].item() == 8, "Result-Log does not work properly with TPU Spawn and Tensors"
+
+    xmp.spawn(test_sync_dist, nprocs=8, start_method='fork')
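
The expected value of 8 is just the reduction arithmetic: each of the 8 spawned processes logs a tensor holding 1.0, and `sync_dist_op=ReduceOp.SUM` applied through `TPUSpawnPlugin.reduce` sums that value across the world. As a back-of-the-envelope check (world size taken from the `nprocs=8` above):

world_size = 8            # nprocs passed to xmp.spawn
logged_per_process = 1.0  # value of the tensor each process logs
assert logged_per_process * world_size == 8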
