3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -240,6 +240,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Fixed `AttributeError` for `require_backward_grad_sync` when running manual optimization with sharded plugin ([#6915](https://github.com/PyTorchLightning/pytorch-lightning/pull/6915))


- Fixed `sync_dist` for TPUs ([#6950](https://github.com/PyTorchLightning/pytorch-lightning/pull/6950))


- Fixed `self.device` not returning the correct device in replicas of data-parallel ([#6414](https://github.com/PyTorchLightning/pytorch-lightning/pull/6414))


2 changes: 1 addition & 1 deletion pytorch_lightning/accelerators/accelerator.py
@@ -106,7 +106,7 @@ def pre_dispatch(self, trainer: 'pl.Trainer') -> None:
self.precision_plugin.pre_dispatch()

def post_dispatch(self, trainer: 'pl.Trainer') -> None:
"""Hook to do something before the training/evaluation/prediction starts."""
"""Hook to do something after the training/evaluation/prediction starts."""
self.training_type_plugin.post_dispatch()
self.precision_plugin.post_dispatch()

5 changes: 3 additions & 2 deletions pytorch_lightning/core/step_result.py
@@ -21,7 +21,7 @@
from torch import Tensor
from torchmetrics import Metric

from pytorch_lightning.utilities.distributed import sync_ddp_if_available
from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed


class Result(Dict):
@@ -105,10 +105,11 @@ def log(

# sync across workers when using distributed training
sync_fn = sync_fn or sync_ddp_if_available

if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
# TODO: Find a way to make the reduction only once, so we don't need to clone.
if is_dist_initialized and isinstance(value, torch.Tensor):
if (is_dist_initialized or tpu_distributed()) and isinstance(value, torch.Tensor):
value = value.clone()
else:
value = torch.tensor(value, device=device, dtype=torch.float)
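The branch above only clones when some distributed backend is active; otherwise plain numbers are promoted to float tensors on the logging device. A minimal sketch of that decision, using a hypothetical helper name `_prepare_for_sync` and boolean flags standing in for the real `torch.distributed`/TPU availability checks:

```python
import torch


def _prepare_for_sync(value, device="cpu", dist_ready=False, tpu_ready=False):
    """Mirror of the branch above (hypothetical helper, not Lightning API)."""
    if (dist_ready or tpu_ready) and isinstance(value, torch.Tensor):
        # A reduction may modify the tensor in place, so hand it a copy.
        return value.clone()
    # Plain Python numbers are promoted to float tensors on the logging device.
    return torch.tensor(value, device=device, dtype=torch.float)


print(_prepare_for_sync(3))                                  # tensor(3.)
print(_prepare_for_sync(torch.tensor(3.0), tpu_ready=True))  # cloned tensor(3.)
```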
7 changes: 4 additions & 3 deletions pytorch_lightning/plugins/training_type/tpu_spawn.py
@@ -15,7 +15,7 @@
import os
import re
import time
from typing import Any, Dict, List, Optional, Union, TYPE_CHECKING
from typing import Any, Dict, List, Optional, TYPE_CHECKING, Union

import torch
import torch.multiprocessing as mp
@@ -41,7 +41,6 @@
if _OMEGACONF_AVAILABLE:
from omegaconf import DictConfig, ListConfig, OmegaConf


if TYPE_CHECKING:
from torch.nn import Module
from torch.utils.data import DataLoader
@@ -278,4 +277,6 @@ def all_gather(self, tensor: torch.Tensor, group: Optional[Any] = None, sync_gra
Return:
A tensor of shape (world_size, batch, ...)
"""
return xm.all_gather(tensor.unsqueeze(0))
if isinstance(tensor, torch.Tensor) and tensor.dim() == 0:
tensor = tensor.unsqueeze(0)
return xm.all_gather(tensor)
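The new guard only unsqueezes 0-dim tensors, since a scalar has no dim 0 for the cross-replica gather to concatenate along. A minimal sketch of the shape handling, no TPU required; `torch.cat` stands in for the concatenation `xm.all_gather` performs across replicas, and the world size of 8 is illustrative:

```python
import torch

value = torch.tensor(3.14)            # 0-dim scalar, shape torch.Size([])
if isinstance(value, torch.Tensor) and value.dim() == 0:
    value = value.unsqueeze(0)        # promoted to shape torch.Size([1])

# Emulate gathering the same value from 8 replicas along dim 0.
gathered = torch.cat([value.clone() for _ in range(8)])
print(gathered.shape)                 # torch.Size([8])
```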
4 changes: 1 addition & 3 deletions pytorch_lightning/utilities/__init__.py
@@ -53,12 +53,10 @@
_TORCH_QUANTIZE_AVAILABLE,
_TORCHTEXT_AVAILABLE,
_TORCHVISION_AVAILABLE,
_TPU_AVAILABLE,
_XLA_AVAILABLE,
)
from pytorch_lightning.utilities.parsing import AttributeDict, flatten_dict, is_picklable # noqa: F401
from pytorch_lightning.utilities.xla_device import XLADeviceUtils # noqa: F401

_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()

FLOAT16_EPSILON = numpy.finfo(numpy.float16).eps
FLOAT32_EPSILON = numpy.finfo(numpy.float32).eps
29 changes: 15 additions & 14 deletions pytorch_lightning/utilities/distributed.py
@@ -17,16 +17,14 @@
import warnings
from functools import partial, wraps
from typing import Any, Optional, Union
from pytorch_lightning.utilities.imports import (
_TORCH_GREATER_EQUAL_1_8,
_TORCH_GREATER_EQUAL_1_9,
)

import torch

from torch.nn.parallel.distributed import DistributedDataParallel

log = logging.getLogger(__name__)
from pytorch_lightning.utilities.imports import _TORCH_GREATER_EQUAL_1_8, _TORCH_GREATER_EQUAL_1_9, _TPU_AVAILABLE

if _TPU_AVAILABLE:
import torch_xla.core.xla_model as xm

if torch.distributed.is_available():
from torch.distributed import group, ReduceOp
@@ -40,6 +38,9 @@ class group:
WORLD = None


log = logging.getLogger(__name__)


def rank_zero_only(fn):

@wraps(fn)
@@ -294,19 +295,13 @@ def register_ddp_comm_hook(
)
"""
if not _TORCH_GREATER_EQUAL_1_8:
rank_zero_warn(
"Not registering DDP comm hook. "
"To use communication hooks, please use pytorch>=1.8.0."
)
rank_zero_warn("Not registering DDP comm hook. To use communication hooks, please use pytorch>=1.8.0.")
return
if ddp_comm_hook is None:
return
if ddp_comm_wrapper is not None:
if not _TORCH_GREATER_EQUAL_1_9:
rank_zero_warn(
"Not applying DDP comm wrapper. "
"To use communication wrapper, please use pytorch>=1.9.0."
)
rank_zero_warn("Not applying DDP comm wrapper. To use communication wrapper, please use pytorch>=1.9.0.")
else:
rank_zero_info(
f"DDP comm wrapper is provided, apply {ddp_comm_wrapper.__qualname__}({ddp_comm_hook.__qualname__})."
@@ -318,3 +313,9 @@ def register_ddp_comm_hook(
state=ddp_comm_state,
hook=ddp_comm_hook,
)


def tpu_distributed() -> bool:
if _TPU_AVAILABLE:
return xm.xrt_world_size() > 1
return False
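A hypothetical call site for the new helper, assuming pytorch-lightning at this revision is importable; either an initialized `torch.distributed` process group or a multi-core TPU run counts as distributed:

```python
import torch

from pytorch_lightning.utilities.distributed import tpu_distributed

needs_sync = (
    torch.distributed.is_available() and torch.distributed.is_initialized()
) or tpu_distributed()

if needs_sync:
    print("distributed run -- reduce logged values across devices")
else:
    print("single device -- log values as-is")
```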
4 changes: 4 additions & 0 deletions pytorch_lightning/utilities/imports.py
@@ -89,3 +89,7 @@ def _compare_version(package: str, op, version) -> bool:
_TORCHTEXT_AVAILABLE = _module_available("torchtext")
_TORCHVISION_AVAILABLE = _module_available('torchvision')
_XLA_AVAILABLE = _module_available("torch_xla")

from pytorch_lightning.utilities.xla_device import XLADeviceUtils # noqa: E402

_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
25 changes: 25 additions & 0 deletions tests/models/test_tpu.py
@@ -16,13 +16,15 @@
from unittest import mock

import pytest
import torch
from torch.utils.data import DataLoader

import tests.helpers.pipelines as tpipes
import tests.helpers.utils as tutils
from pytorch_lightning import Trainer
from pytorch_lightning.accelerators import TPUAccelerator
from pytorch_lightning.callbacks import EarlyStopping
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.plugins import TPUSpawnPlugin
from pytorch_lightning.trainer.states import TrainerState
from pytorch_lightning.utilities import _TPU_AVAILABLE
@@ -416,3 +418,26 @@ def test_if_test_works_with_checkpoint_false(tmpdir):
trainer = Trainer(max_epochs=1, tpu_cores=8, default_root_dir=tmpdir, fast_dev_run=True, checkpoint_callback=False)
trainer.fit(model)
assert trainer.state == TrainerState.FINISHED, f"Training failed with {trainer.state}"


@RunIf(tpu=True)
@pl_multi_process_test
def test_tpu_sync_dist():
"""Test tpu spawn sync dist operation """

def test_sync_dist(rank):
tensor = torch.tensor([1.0])
training_type_plugin = TPUSpawnPlugin()

res = Result()
res.log(
"test_tensor",
tensor,
sync_fn=training_type_plugin.reduce,
sync_dist=True,
sync_dist_op=torch.distributed.ReduceOp.SUM
)

assert res["test_tensor"].item() == 8, "Result-Log does not work properly with TPU Spawn and Tensors"

xmp.spawn(test_sync_dist, nprocs=8, start_method='fork')