
Commit 1b3e4f9

Fix sync_dist for tpus (#6950)
1 parent 80c5293 commit 1b3e4f9

8 files changed: +56 −23 lines

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

@@ -240,6 +240,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `AttributeError` for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))


+- Fixed `sync_dist` for tpus ([#6950](https://github.com/PyTorchLightning/pytorch-lightning/pull/6950))
+
+
 - Fixed `self.device` not returning the correct device in replicas of data-parallel ([#6414](https://github.com/PyTorchLightning/pytorch-lightning/pull/6414))


pytorch_lightning/accelerators/accelerator.py

Lines changed: 1 addition & 1 deletion

@@ -106,7 +106,7 @@ def pre_dispatch(self, trainer: 'pl.Trainer') -> None:
         self.precision_plugin.pre_dispatch()

     def post_dispatch(self, trainer: 'pl.Trainer') -> None:
-        """Hook to do something before the training/evaluation/prediction starts."""
+        """Hook to do something after the training/evaluation/prediction starts."""
         self.training_type_plugin.post_dispatch()
         self.precision_plugin.post_dispatch()

pytorch_lightning/core/step_result.py

Lines changed: 3 additions & 2 deletions

@@ -21,7 +21,7 @@
 from torch import Tensor
 from torchmetrics import Metric

-from pytorch_lightning.utilities.distributed import sync_ddp_if_available
+from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed


 class Result(Dict):
@@ -105,10 +105,11 @@ def log(

         # sync across workers when using distributed training
         sync_fn = sync_fn or sync_ddp_if_available
+
         if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
             is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
             # TODO: Find a way to make the reduction only once, so we don't need to clone.
-            if is_dist_initialized and isinstance(value, torch.Tensor):
+            if (is_dist_initialized or tpu_distributed) and isinstance(value, torch.Tensor):
                 value = value.clone()
             else:
                 value = torch.tensor(value, device=device, dtype=torch.float)
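Note: the intent of this hunk is to widen the clone-before-reduce branch so that a multi-core TPU run is handled like an initialized torch.distributed process group. A minimal CPU-side sketch of that branch, under the same assumptions; the helper name `_prepare_for_sync` and the `on_tpu` flag are illustrative, not library API:

    import torch

    def _prepare_for_sync(value, device, is_dist_initialized: bool, on_tpu: bool):
        # Clone so the reduction applied by sync_fn cannot mutate the tensor the
        # user logged; this branch is meant to cover TPU runs as well as DDP.
        if (is_dist_initialized or on_tpu) and isinstance(value, torch.Tensor):
            return value.clone()
        # Plain Python numbers (or single-process runs) are wrapped into a tensor.
        return torch.tensor(value, device=device, dtype=torch.float)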

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 4 additions & 3 deletions

@@ -15,7 +15,7 @@
 import os
 import re
 import time
-from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
+from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union

 import torch
 import torch.multiprocessing as mp
@@ -41,7 +41,6 @@
 if _OMEGACONF_AVAILABLE:
     from omegaconf import DictConfig, ListConfig, OmegaConf

-
 if TYPE_CHECKING:
     from torch.nn import Module
     from torch.utils.data import DataLoader
@@ -278,4 +277,6 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
         Return:
             A tensor of shape (world_size, batch, ...)
         """
-        return xm.all_gather(tensor.unsqueeze(0))
+        if isinstance(tensor, torch.Tensor) and tensor.dim() == 0:
+            tensor = tensor.unsqueeze(0)
+        return xm.all_gather(tensor)
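Note: `xm.all_gather` concatenates the per-replica tensors along dim 0, so only a 0-dim (scalar) tensor, as typically produced by metric logging, needs a leading dimension added first; tensors that already have a batch dimension pass through unchanged. A rough CPU-only shape check, assuming an 8-core run (`_fake_all_gather` merely imitates the concatenation and is not library code):

    import torch

    def _fake_all_gather(t: torch.Tensor, world_size: int = 8) -> torch.Tensor:
        # Stand-in for xm.all_gather: one copy per replica, concatenated along dim 0.
        return torch.cat([t for _ in range(world_size)], dim=0)

    scalar = torch.tensor(1.0)                        # dim() == 0, e.g. a logged loss
    gathered = _fake_all_gather(scalar.unsqueeze(0))
    assert gathered.shape == (8,)                     # one value per core, ready to reduce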

pytorch_lightning/utilities/__init__.py

Lines changed: 1 addition & 3 deletions

@@ -53,12 +53,10 @@
     _TORCH_QUANTIZE_AVAILABLE,
     _TORCHTEXT_AVAILABLE,
     _TORCHVISION_AVAILABLE,
+    _TPU_AVAILABLE,
     _XLA_AVAILABLE,
 )
 from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable  # noqa: F401
-from pytorch_lightning.utilities.xla_device import XLADeviceUtils  # noqa: F401
-
-_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

 FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
 FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps

pytorch_lightning/utilities/distributed.py

Lines changed: 15 additions & 14 deletions

@@ -17,16 +17,14 @@
 import warnings
 from functools import partial, wraps
 from typing import Any, Optional, Union
-from pytorch_lightning.utilities.imports import (
-    _TORCH_GREATER_EQUAL_1_8,
-    _TORCH_GREATER_EQUAL_1_9,
-)

 import torch
-
 from torch.nn.parallel.distributed import DistributedDataParallel

-log = logging.getLogger(__name__)
+from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, _TPU_AVAILABLE
+
+if _TPU_AVAILABLE:
+    import torch_xla.core.xla_model as xm

 if torch.distributed.is_available():
     from torch.distributed import group, ReduceOp
@@ -40,6 +38,9 @@ class group:
         WORLD = None


+log = logging.getLogger(__name__)
+
+
 def rank_zero_only(fn):

     @wraps(fn)
@@ -294,19 +295,13 @@ def register_ddp_comm_hook(
         )
     """
     if not _TORCH_GREATER_EQUAL_1_8:
-        rank_zero_warn(
-            "Not registering DDP comm hook. "
-            "To use communication hooks, please use pytorch>=1.8.0."
-        )
+        rank_zero_warn("Not registering DDP comm hook. To use communication hooks, please use pytorch>=1.8.0.")
         return
     if ddp_comm_hook is None:
         return
     if ddp_comm_wrapper is not None:
         if not _TORCH_GREATER_EQUAL_1_9:
-            rank_zero_warn(
-                "Not applying DDP comm wrapper. "
-                "To use communication wrapper, please use pytorch>=1.9.0."
-            )
+            rank_zero_warn("Not applying DDP comm wrapper. To use communication wrapper, please use pytorch>=1.9.0.")
         else:
             rank_zero_info(
                 f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."
@@ -318,3 +313,9 @@ def register_ddp_comm_hook(
         state=ddp_comm_state,
         hook=ddp_comm_hook,
     )
+
+
+def tpu_distributed() -> bool:
+    if _TPU_AVAILABLE:
+        return xm.xrt_world_size() > 1
+    return False
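Note: this new module-level helper is what the `step_result.py` change above imports; it reports True only when torch_xla is available and more than one XLA device participates in the run. A hedged usage sketch (`world_size_for_logging` is an illustrative name, not part of the library):

    from pytorch_lightning.utilities.distributed import tpu_distributed

    def world_size_for_logging() -> int:
        # Illustrative only: report the XLA world size on a multi-core TPU run,
        # and fall back to 1 everywhere else (CPU/GPU, or a single TPU core).
        if tpu_distributed():
            import torch_xla.core.xla_model as xm  # importable whenever tpu_distributed() is True
            return xm.xrt_world_size()
        return 1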

pytorch_lightning/utilities/imports.py

Lines changed: 4 additions & 0 deletions

@@ -89,3 +89,7 @@ def _compare_version(package: str, op, version) -> bool:
 _TORCHTEXT_AVAILABLE = _module_available("torchtext")
 _TORCHVISION_AVAILABLE = _module_available('torchvision')
 _XLA_AVAILABLE = _module_available("torch_xla")
+
+from pytorch_lightning.utilities.xla_device import XLADeviceUtils  # noqa: E402
+
+_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

tests/models/test_tpu.py

Lines changed: 25 additions & 0 deletions

@@ -16,13 +16,15 @@
 from unittest import mock

 import pytest
+import torch
 from torch.utils.data import DataLoader

 import tests.helpers.pipelines as tpipes
 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.accelerators import TPUAccelerator
 from pytorch_lightning.callbacks import EarlyStopping
+from pytorch_lightning.core.step_result import Result
 from pytorch_lightning.plugins import TPUSpawnPlugin
 from pytorch_lightning.trainer.states import TrainerState
 from pytorch_lightning.utilities import _TPU_AVAILABLE
@@ -416,3 +418,26 @@ def test_if_test_works_with_checkpoint_false(tmpdir):
     trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True, checkpoint_callback=False)
     trainer.fit(model)
     assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"
+
+
+@RunIf(tpu=True)
+@pl_multi_process_test
+def test_tpu_sync_dist():
+    """Test tpu spawn sync dist operation """
+
+    def test_sync_dist(rank):
+        tensor = torch.tensor([1.0])
+        training_type_plugin = TPUSpawnPlugin()
+
+        res = Result()
+        res.log(
+            "test_tensor",
+            tensor,
+            sync_fn=training_type_plugin.reduce,
+            sync_dist=True,
+            sync_dist_op=torch.distributed.ReduceOp.SUM
+        )
+
+        assert res["test_tensor"].item() == 8, "Result-Log does not work properly with TPU Spawn and Tensors"
+
+    xmp.spawn(test_sync_dist, nprocs=8, start_method='fork')
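Note: the expected value of 8 is simply the SUM reduction of one `[1.0]` tensor per core on an 8-core TPU; `xmp` here is presumably `torch_xla.distributed.xla_multiprocessing`, imported elsewhere in the test module. A CPU-side sanity check of that arithmetic:

    import torch

    # One [1.0] tensor per simulated core, SUM-reduced as in the test above.
    per_core = [torch.tensor([1.0]) for _ in range(8)]
    assert torch.stack(per_core).sum().item() == 8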
