
Commit 2103b5e

Move sync code from step result to lightning module [6/n] (#7651)
1 parent 0c958c5 commit 2103b5e
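
From the user's side, this code path is reached through LightningModule.log(..., sync_dist=True): the value is now reduced across workers inside the module (by the new private __sync helper) before it is handed to Result.log, instead of inside Result.log itself. A minimal sketch of how that call looks in practice; the module and loss computation below are illustrative, not part of this commit:

    import torch
    from pytorch_lightning import LightningModule


    class IllustrativeModule(LightningModule):  # hypothetical module, for illustration only
        def training_step(self, batch, batch_idx):
            loss = batch.float().mean()  # placeholder "loss" just to have a tensor to log
            # With sync_dist=True, LightningModule.log now reduces `loss` across
            # workers (via __sync) before forwarding it to Result.log.
            self.log("train_loss", loss, sync_dist=True, sync_dist_op="mean")
            return loss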

File tree (5 files changed: 44 additions, 38 deletions)

pytorch_lightning/core/lightning.py
pytorch_lightning/core/step_result.py
pytorch_lightning/utilities/types.py
tests/core/test_metric_result_integration.py
tests/core/test_results.py

pytorch_lightning/core/lightning.py

Lines changed: 37 additions & 6 deletions
@@ -17,6 +17,7 @@
 import copy
 import inspect
 import logging
+import numbers
 import os
 import tempfile
 import types
@@ -42,10 +43,11 @@
 from pytorch_lightning.utilities.apply_func import apply_to_collection, convert_to_tensors
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.device_dtype_mixin import DeviceDtypeModuleMixin
+from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.parsing import AttributeDict, collect_init_args, save_hyperparameters
 from pytorch_lightning.utilities.signature_utils import is_param_in_hook_signature
-from pytorch_lightning.utilities.types import EPOCH_OUTPUT, STEP_OUTPUT
+from pytorch_lightning.utilities.types import _METRIC, EPOCH_OUTPUT, STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import WarningCache

 warning_cache = WarningCache()
@@ -336,6 +338,15 @@ def log(
                 f"Logged key: {name} should not contain information about dataloader_idx."
             )

+        value = self.__sync(
+            value,
+            sync_fn=self.trainer.training_type_plugin.reduce,
+            sync_dist=sync_dist,
+            sync_dist_op=sync_dist_op,
+            sync_dist_group=sync_dist_group,
+            device=self.device,
+        )
+
         self._results.log(
             name,
             value,
@@ -345,12 +356,7 @@ def log(
             on_epoch=on_epoch,
             reduce_fx=reduce_fx,
             enable_graph=enable_graph,
-            sync_dist=sync_dist,
-            sync_dist_op=sync_dist_op,
-            sync_dist_group=sync_dist_group,
-            sync_fn=self.trainer.training_type_plugin.reduce,
             dataloader_idx=(self._current_dataloader_idx if add_dataloader_idx else None),
-            device=self.device,
         )

     def log_dict(
@@ -410,6 +416,31 @@ def log_dict(
             add_dataloader_idx=add_dataloader_idx
         )

+    @staticmethod
+    def __sync(
+        value: _METRIC,
+        sync_fn: Optional[Callable] = None,
+        sync_dist: bool = False,
+        sync_dist_op: Union[Any, str] = 'mean',
+        sync_dist_group: Optional[Any] = None,
+        device: torch.device = None,
+    ) -> _METRIC:
+        """Sync across workers when using distributed training"""
+        if not isinstance(value, (torch.Tensor, numbers.Number)):
+            return value
+
+        sync_fn = sync_fn or sync_ddp_if_available
+        dist_available = torch.distributed.is_available() and torch.distributed.is_initialized() or tpu_distributed()
+        if not sync_dist or not dist_available:
+            return value
+
+        # TODO: Find a way to make the reduction only once, so we don't need to clone.
+        if isinstance(value, torch.Tensor):
+            value = value.clone()
+        else:
+            value = torch.tensor(value, device=device, dtype=torch.float)
+        return sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op)
+
     def write_prediction(
         self, name: str, value: Union[torch.Tensor, List[torch.Tensor]], filename: str = 'predictions.pt'
     ):
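
Because __sync is a private static method, code outside the class (such as the updated test in tests/core/test_results.py) reaches it through Python's name mangling as LightningModule._LightningModule__sync. A rough sketch of its passthrough behaviour in a plain single-process interpreter, based only on the logic added above and assuming no process group is initialized and no TPU is present:

    import torch
    from pytorch_lightning import LightningModule

    # Name-mangled access to the private helper, mirroring the updated test.
    sync = LightningModule._LightningModule__sync

    value = torch.tensor(3.0)
    # Non-numeric values are returned untouched regardless of the flags.
    assert sync("not a metric", sync_dist=True) == "not a metric"
    # Tensors and plain numbers also pass through unchanged when sync_dist is False
    # or when neither torch.distributed nor a TPU process group is initialized.
    assert sync(value, sync_dist=False) is value
    assert sync(5, sync_dist=True) == 5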

pytorch_lightning/core/step_result.py

Lines changed: 0 additions & 20 deletions
@@ -13,16 +13,13 @@
 # limitations under the License.
 """Result class for easier logging and epoch-wise reduction."""

-import numbers
 from copy import copy
 from typing import Any, Callable, Dict, Iterable, List, MutableMapping, Optional, Sequence, Tuple, Union

 import torch
 from torch import Tensor
 from torchmetrics import Metric

-from pytorch_lightning.utilities.distributed import sync_ddp_if_available, tpu_distributed
-

 class Result(Dict):

@@ -86,29 +83,12 @@ def log(
         on_epoch: bool = True,
         reduce_fx: Callable = torch.mean,
         enable_graph: bool = False,
-        sync_dist: bool = False,
-        sync_dist_op: Union[Any, str] = 'mean',
-        sync_dist_group: Optional[Any] = None,
-        sync_fn: Callable = None,
         dataloader_idx: Optional[int] = None,
-        device: torch.device = None,
     ):
         # no metrics should be logged with graphs
         if not enable_graph and isinstance(value, torch.Tensor):
             value = value.detach()

-        # sync across workers when using distributed training
-        sync_fn = sync_fn or sync_ddp_if_available
-
-        if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
-            is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
-            # TODO: Find a way to make the reduction only once, so we don't need to clone.
-            if (is_dist_initialized or tpu_distributed()) and isinstance(value, torch.Tensor):
-                value = value.clone()
-            else:
-                value = torch.tensor(value, device=device, dtype=torch.float)
-            value = sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op)
-
         if isinstance(value, torch.Tensor) and value.device.type == "xla":
             value = value.cpu()

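
With the sync arguments removed, Result.log now expects an already-reduced value; the sync_dist*, sync_fn, and device keyword arguments are gone from its signature. A small sketch of the consequence for callers (assuming Result.log has no **kwargs catch-all, which the signature above suggests):

    import torch
    from pytorch_lightning.core.step_result import Result

    res = Result()
    # Still supported: logging a value that has already been synced upstream.
    res.log("loss", torch.tensor(1.0))
    # No longer supported after this commit; passing sync_dist would raise a TypeError
    # because the parameter was removed from Result.log:
    # res.log("loss", torch.tensor(1.0), sync_dist=True)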

pytorch_lightning/utilities/types.py

Lines changed: 2 additions & 1 deletion
@@ -16,12 +16,13 @@
 - Do not include any `_TYPE` suffix
 - Types used in public hooks (as those in the `LightningModule` and `Callback`) should be public (no trailing `_`)
 """
+from numbers import Number
 from typing import Any, Dict, Iterator, List, Union

 import torch
 from torchmetrics import Metric

-_METRIC = Union[Metric, torch.Tensor, int, float]
+_METRIC = Union[Metric, torch.Tensor, Number]
 STEP_OUTPUT = Union[torch.Tensor, Dict[str, Any]]
 EPOCH_OUTPUT = List[STEP_OUTPUT]
 _EVALUATE_OUTPUT = List[Dict[str, float]]  # 1 dict per DataLoader
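
numbers.Number is the abstract base class behind int, float, bool, complex, and fractions.Fraction, so the widened _METRIC alias lines up with the isinstance(value, (torch.Tensor, numbers.Number)) check inside the new __sync helper. A quick illustration of what the union now covers:

    import numbers
    from fractions import Fraction

    import torch

    # Each of these values matches _METRIC = Union[Metric, torch.Tensor, Number]
    # and passes the isinstance check performed by LightningModule.__sync.
    for value in (1, 0.5, True, Fraction(1, 2), torch.tensor(2.0)):
        assert isinstance(value, (torch.Tensor, numbers.Number))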

tests/core/test_metric_result_integration.py

Lines changed: 0 additions & 1 deletion
@@ -93,7 +93,6 @@ def _ddp_test_fn(rank, worldsize):
 @RunIf(skip_windows=True)
 def test_result_reduce_ddp():
     """Make sure result logging works with DDP"""
-    tutils.reset_seed()
     tutils.set_random_master_port()

     worldsize = 2

tests/core/test_results.py

Lines changed: 5 additions & 10 deletions
@@ -21,7 +21,7 @@
 from torch.utils.data import DataLoader

 import tests.helpers.utils as tutils
-from pytorch_lightning import Trainer
+from pytorch_lightning import LightningModule, Trainer
 from pytorch_lightning.core.step_result import Result
 from tests.helpers import BoringDataModule, BoringModel
 from tests.helpers.runif import RunIf
@@ -36,24 +36,19 @@ def _setup_ddp(rank, worldsize):
     dist.init_process_group("gloo", rank=rank, world_size=worldsize)


-def _ddp_test_fn(rank, worldsize, result_cls: Result):
+def _ddp_test_fn(rank, worldsize):
     _setup_ddp(rank, worldsize)
     tensor = torch.tensor([1.0])
-
-    res = result_cls()
-    res.log("test_tensor", tensor, sync_dist=True, sync_dist_op=torch.distributed.ReduceOp.SUM)
-
-    assert res["test_tensor"].item() == dist.get_world_size(), "Result-Log does not work properly with DDP and Tensors"
+    actual = LightningModule._LightningModule__sync(tensor, sync_dist=True, sync_dist_op=torch.distributed.ReduceOp.SUM)
+    assert actual.item() == dist.get_world_size(), "Result-Log does not work properly with DDP and Tensors"


 @RunIf(skip_windows=True)
 def test_result_reduce_ddp():
     """Make sure result logging works with DDP"""
-    tutils.reset_seed()
     tutils.set_random_master_port()
-
     worldsize = 2
-    mp.spawn(_ddp_test_fn, args=(worldsize, Result), nprocs=worldsize)
+    mp.spawn(_ddp_test_fn, args=(worldsize, ), nprocs=worldsize)


 @pytest.mark.parametrize(
