
Commit 25640bc

tchaton authored and carmocca committed
support number for logging with sync_dist=True (#5080)
* support number
* add two tests
* wip
* add ddp in special test
* remove a test
* move device to bottom
* simplify test
* update test
* Update pytorch_lightning/core/step_result.py
* resolve sync_ddp

Co-authored-by: Carlos Mocholí <[email protected]>
1 parent e975c98 commit 25640bc
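
At the user level, this commit lets self.log accept a plain Python number together with sync_dist=True: the value is wrapped in a float tensor on the module's device before the cross-process reduction. A minimal sketch of that usage, assuming a toy LightningModule (the class name, metric names, and layer size here are illustrative, not part of the commit):

import torch
from pytorch_lightning import LightningModule

class SyncDistNumberModel(LightningModule):
    """Toy module: logs a bare int and a tensor, both with sync_dist=True."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # New with this commit: a bare number is converted to a float tensor on
        # self.device before being reduced across processes.
        self.log('num_seen', 1, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum')
        # Tensors keep working as before; they are cloned prior to the reduction.
        self.log('train_loss', loss, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='mean')
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)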

File tree

5 files changed (+56, -13 lines)


pytorch_lightning/core/lightning.py

Lines changed: 1 addition & 0 deletions
@@ -276,6 +276,7 @@ def log(
                 sync_dist_group,
                 accelerator.sync_tensor,
                 self._current_dataloader_idx,
+                self.device,
             )
 
     def log_dict(

pytorch_lightning/core/step_result.py

Lines changed: 8 additions & 4 deletions
@@ -15,15 +15,15 @@
 """[Train, Eval]Result for easier logging, checkpointing, early stopping, epoch-wise reduction."""
 
 import numbers
+import os
 from copy import copy
-from typing import Optional, Dict, Union, Sequence, Callable, MutableMapping, Any, List, Tuple, Iterable
+from typing import Any, Callable, Dict, Iterable, List, MutableMapping, Optional, Sequence, Tuple, Union
 
 import torch
 from torch import Tensor
-import os
 
-from pytorch_lightning.utilities.distributed import sync_ddp_if_available
 from pytorch_lightning.metrics import Metric
+from pytorch_lightning.utilities.distributed import sync_ddp_if_available
 
 
 class Result(Dict):
@@ -128,6 +128,7 @@ def log(
         sync_dist_group: Optional[Any] = None,
         sync_fn: Callable = None,
         dataloader_idx: Optional[int] = None,
+        device: torch.device = None,
     ):
         # no metrics should be logged with graphs
         if not enable_graph and isinstance(value, torch.Tensor):
@@ -138,7 +139,10 @@
         if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
             is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
             # TODO: Find a way to make the reduction only once, so we don't need to clone.
-            value = value.clone() if is_dist_initialized else value
+            if is_dist_initialized and isinstance(value, torch.Tensor):
+                value = value.clone()
+            else:
+                value = torch.tensor(value, device=device, dtype=torch.float)
             value = sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op)
 
         if 'meta' not in self:
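
The heart of the step_result.py change: with sync_dist=True, a tensor is cloned so the all-reduce does not mutate the caller's value, while a plain number is promoted to a float tensor on the device passed down from the LightningModule. A standalone sketch of that conversion outside Lightning; the helper name _prepare_for_sync is made up for illustration and slightly simplified relative to the diff above:

import numbers
import torch

def _prepare_for_sync(value, device=None):
    """Hypothetical helper mirroring Result.log: make `value` safe to all-reduce."""
    is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
    if is_dist_initialized and isinstance(value, torch.Tensor):
        # Clone so the in-place all_reduce does not overwrite the logged tensor.
        return value.clone()
    if isinstance(value, numbers.Number):
        # Numbers have no device; build a float tensor where the collective expects it.
        return torch.tensor(value, device=device, dtype=torch.float)
    return value

For example, _prepare_for_sync(2, device=torch.device('cuda:0')) yields tensor(2., device='cuda:0'), which can then be handed to sync_ddp_if_available with the requested reduce op.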

pytorch_lightning/utilities/distributed.py

Lines changed: 7 additions & 8 deletions
@@ -15,14 +15,14 @@
 import os
 import warnings
 from functools import wraps
+from typing import Any, Optional, Union
 
 import torch
+
 from pytorch_lightning import _logger as log
-from typing import Union, Optional, Any
 
 if torch.distributed.is_available():
-    from torch.distributed import ReduceOp
-    from torch.distributed import group
+    from torch.distributed import ReduceOp, group
 else:
     class ReduceOp:
         SUM = None
@@ -145,15 +145,14 @@ def sync_ddp(
     if group is None:
         group = torch.distributed.group.WORLD
 
-    if reduce_op is None:
-        reduce_op = torch.distributed.ReduceOp.SUM
-    elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"):
-        reduce_op = torch.distributed.ReduceOp.SUM
+    op = reduce_op if isinstance(reduce_op, ReduceOp) else ReduceOp.SUM
+
+    if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
         divide_by_world_size = True
 
     # sync all processes before reduction
     torch.distributed.barrier(group=group)
-    torch.distributed.all_reduce(result, op=reduce_op, group=group, async_op=False)
+    torch.distributed.all_reduce(result, op=op, group=group, async_op=False)
 
     if divide_by_world_size:
         result = result / torch.distributed.get_world_size(group)
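
The sync_ddp refactor separates two decisions: which ReduceOp to hand to all_reduce (anything that is not already a ReduceOp falls back to SUM) and whether to divide by the world size afterwards (now triggered by 'avg'/'mean' in any casing, which is why the new test can pass 'AVG'). A small sketch of that normalization, assuming torch.distributed is available; the helper name _normalize_reduce_op is illustrative:

from typing import Tuple, Union

from torch.distributed import ReduceOp

def _normalize_reduce_op(reduce_op: Union[ReduceOp, str, None]) -> Tuple[ReduceOp, bool]:
    """Illustrative helper: map a user-supplied reduce_op to (collective op, divide flag)."""
    # The collective itself always sums unless an explicit ReduceOp was given.
    op = reduce_op if isinstance(reduce_op, ReduceOp) else ReduceOp.SUM
    # 'avg'/'mean' (any casing) means: sum, then divide by the world size.
    divide_by_world_size = isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean")
    return op, divide_by_world_size

So _normalize_reduce_op('MEAN') returns (ReduceOp.SUM, True), while an explicit ReduceOp passed by the caller goes through unchanged with no division.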

tests/special_tests.sh

Lines changed: 1 addition & 1 deletion
@@ -19,4 +19,4 @@ python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp
 python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual
 python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual_amp
 python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_automatic
-# python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance
+python ${DEFAULTS} tests/trainer/logging_tests/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp

tests/trainer/logging_tests/test_train_loop_logging_1_0.py

Lines changed: 39 additions & 0 deletions
@@ -18,6 +18,7 @@
 import collections
 import itertools
 import os
+import platform
 from unittest import mock
 
 import numpy as np
@@ -686,6 +687,7 @@ class TestModel(BoringModel):
         def training_step(self, batch, batch_idx):
             acc = self.step(batch[0])
             self.log('foo', torch.tensor(fake_result), on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum')
+            self.log('foo_2', 2, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum')
             return acc
 
         def validation_step(self, batch, batch_idx):
@@ -705,9 +707,46 @@ def validation_step(self, batch, batch_idx):
     trainer.fit(model)
 
     assert trainer.logged_metrics['foo'] == fake_result
+    assert trainer.logged_metrics['foo_2'] == 2
     assert trainer.logged_metrics['bar'] == fake_result
 
 
+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
+                    reason="test should be run outside of pytest")
+def test_logging_sync_dist_true_ddp(tmpdir):
+    """
+    Tests to ensure that the sync_dist flag works with ddp
+    """
+    class TestLoggingSyncDistModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            acc = self.step(batch[0])
+            self.log('foo', 1, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='SUM')
+            return acc
+
+        def validation_step(self, batch, batch_idx):
+            self.training_step_called = True
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            self.log('bar', 2, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='AVG')
+            return {"x": loss}
+
+    model = TestLoggingSyncDistModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        max_epochs=2,
+        weights_summary=None,
+        accelerator="ddp",
+        gpus=2,
+    )
+    trainer.fit(model)
+
+    assert trainer.logged_metrics['foo'] == 2
+    assert trainer.logged_metrics['bar'] == 2
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_logging_sync_dist_true_gpu(tmpdir):
     """
