
Commit b3fc662

tchaton authored and carmocca committed
support number for logging with sync_dist=True (#5080)
* support number
* add two tests
* wip
* add ddp in special test
* remove a test
* move device to bottom
* simplify test
* update test
* Update pytorch_lightning/core/step_result.py

Co-authored-by: Carlos Mocholí <[email protected]>

* resolve sync_ddp

Co-authored-by: Carlos Mocholí <[email protected]>
1 parent a50a1e4 commit b3fc662

File tree

5 files changed (+56, -13 lines)

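Taken together, these changes let LightningModule.log accept a plain Python number (int/float) with sync_dist=True, where previously only tensors could be reduced across processes. A minimal usage sketch of what the commit enables (the module, metric name, and commented run are illustrative, not part of the commit):

import torch
from torch.utils.data import DataLoader, TensorDataset
from pytorch_lightning import LightningModule, Trainer

class SyncDistExample(LightningModule):
    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch[0]).sum()
        # After this commit a plain number can be logged with sync_dist=True;
        # it is converted to a tensor on self.device before the cross-process reduction.
        self.log('batch_size', float(batch[0].size(0)),
                 on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum')
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)

# Hypothetical two-process DDP run:
# trainer = Trainer(max_epochs=1, gpus=2, accelerator='ddp')
# trainer.fit(SyncDistExample(), DataLoader(TensorDataset(torch.randn(64, 32)), batch_size=8))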

pytorch_lightning/core/lightning.py

Lines changed: 1 addition & 0 deletions

@@ -279,6 +279,7 @@ def log(
             sync_dist_group,
             accelerator.sync_tensor,
             self._current_dataloader_idx,
+            self.device,
         )

     def log_dict(

pytorch_lightning/core/step_result.py

Lines changed: 8 additions & 4 deletions

@@ -15,15 +15,15 @@
 """[Train, Eval]Result for easier logging, checkpointing, early stopping, epoch-wise reduction."""

 import numbers
+import os
 from copy import copy
-from typing import Optional, Dict, Union, Sequence, Callable, MutableMapping, Any, List, Tuple, Iterable
+from typing import Any, Callable, Dict, Iterable, List, MutableMapping, Optional, Sequence, Tuple, Union

 import torch
 from torch import Tensor
-import os

-from pytorch_lightning.utilities.distributed import sync_ddp_if_available
 from pytorch_lightning.metrics import Metric
+from pytorch_lightning.utilities.distributed import sync_ddp_if_available


 class Result(Dict):
@@ -128,6 +128,7 @@ def log(
         sync_dist_group: Optional[Any] = None,
         sync_fn: Callable = None,
         dataloader_idx: Optional[int] = None,
+        device: torch.device = None,
     ):
         # no metrics should be logged with graphs
         if not enable_graph and isinstance(value, torch.Tensor):
@@ -138,7 +139,10 @@ def log(
         if sync_dist and isinstance(value, (torch.Tensor, numbers.Number)):
             is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
             # TODO: Find a way to make the reduction only once, so we don't need to clone.
-            value = value.clone() if is_dist_initialized else value
+            if is_dist_initialized and isinstance(value, torch.Tensor):
+                value = value.clone()
+            else:
+                value = torch.tensor(value, device=device, dtype=torch.float)
             value = sync_fn(value, group=sync_dist_group, reduce_op=sync_dist_op)

         if 'meta' not in self:
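The behavioral change is in the branch above: a tensor is still cloned before the reduction, while a plain number (or a tensor outside an initialized process group) is wrapped in a float tensor on the logging module's device so that sync_fn can reduce it. A standalone sketch of that conversion, outside Result.log (the helper name is illustrative):

import torch

def _prepare_for_sync(value, device=None):
    # Sketch of the conversion Result.log now performs before calling sync_fn.
    is_dist_initialized = torch.distributed.is_available() and torch.distributed.is_initialized()
    if is_dist_initialized and isinstance(value, torch.Tensor):
        # clone so the subsequent all_reduce does not mutate the caller's tensor in place
        return value.clone()
    # numbers (and tensors when no process group is initialized) become float tensors
    # on the requested device, e.g. the LightningModule's self.device
    return torch.tensor(value, device=device, dtype=torch.float)

assert isinstance(_prepare_for_sync(2, device=torch.device('cpu')), torch.Tensor)
assert _prepare_for_sync(2.5).item() == 2.5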

pytorch_lightning/utilities/distributed.py

Lines changed: 7 additions & 8 deletions

@@ -15,14 +15,14 @@
 import os
 import warnings
 from functools import wraps
+from typing import Any, Optional, Union

 import torch
+
 from pytorch_lightning import _logger as log
-from typing import Union, Optional, Any

 if torch.distributed.is_available():
-    from torch.distributed import ReduceOp
-    from torch.distributed import group
+    from torch.distributed import ReduceOp, group
 else:
     class ReduceOp:
         SUM = None
@@ -145,15 +145,14 @@ def sync_ddp(
     if group is None:
         group = torch.distributed.group.WORLD

-    if reduce_op is None:
-        reduce_op = torch.distributed.ReduceOp.SUM
-    elif isinstance(reduce_op, str) and reduce_op in ("avg", "mean"):
-        reduce_op = torch.distributed.ReduceOp.SUM
+    op = reduce_op if isinstance(reduce_op, ReduceOp) else ReduceOp.SUM
+
+    if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
         divide_by_world_size = True

     # sync all processes before reduction
     torch.distributed.barrier(group=group)
-    torch.distributed.all_reduce(result, op=reduce_op, group=group, async_op=False)
+    torch.distributed.all_reduce(result, op=op, group=group, async_op=False)

     if divide_by_world_size:
         result = result / torch.distributed.get_world_size(group)
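Note that in the rewritten sync_ddp a string op such as 'avg' or 'mean' is not passed to all_reduce directly: the collective always runs with ReduceOp.SUM, and the mean is obtained by dividing the summed result by the world size afterwards. A small sketch of that resolution logic in isolation (the helper name is illustrative; it omits the barrier and the actual collective):

from torch.distributed import ReduceOp

def _resolve_reduce_op(reduce_op):
    # Map a user-supplied reduce_op to (op, divide_by_world_size) the way sync_ddp now does.
    op = reduce_op if isinstance(reduce_op, ReduceOp) else ReduceOp.SUM
    divide_by_world_size = isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean")
    return op, divide_by_world_size

# 'AVG'/'mean' (case-insensitive) become a SUM followed by a division by the world size;
# None defaults to SUM, and an explicit ReduceOp is forwarded unchanged.
assert _resolve_reduce_op("AVG") == (ReduceOp.SUM, True)
assert _resolve_reduce_op(None) == (ReduceOp.SUM, False)
assert _resolve_reduce_op(ReduceOp.MAX) == (ReduceOp.MAX, False)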

tests/special_tests.sh

Lines changed: 1 addition & 1 deletion

@@ -19,4 +19,4 @@ python ${DEFAULTS} tests/plugins/test_rpc_plugin.py::test_rpc_function_calls_ddp
 python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual
 python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_manual_amp
 python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_automatic
-# python ${DEFAULTS} tests/plugins/test_ddp_sequential_plugin.py::test_ddp_sequential_plugin_ddp_rpc_with_wrong_balance
+python ${DEFAULTS} tests/trainer/logging_tests/test_train_loop_logging_1_0.py::test_logging_sync_dist_true_ddp

tests/trainer/logging_tests/test_train_loop_logging_1_0.py

Lines changed: 39 additions & 0 deletions

@@ -18,6 +18,7 @@
 import collections
 import itertools
 import os
+import platform
 from unittest import mock

 import numpy as np
@@ -685,6 +686,7 @@ class TestModel(BoringModel):
         def training_step(self, batch, batch_idx):
             acc = self.step(batch[0])
             self.log('foo', torch.tensor(fake_result), on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum')
+            self.log('foo_2', 2, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='sum')
             return acc

         def validation_step(self, batch, batch_idx):
@@ -704,9 +706,46 @@ def validation_step(self, batch, batch_idx):
     trainer.fit(model)

     assert trainer.logged_metrics['foo'] == fake_result
+    assert trainer.logged_metrics['foo_2'] == 2
     assert trainer.logged_metrics['bar'] == fake_result


+@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="test requires multi-GPU machine")
+@pytest.mark.skipif(not os.getenv("PL_RUNNING_SPECIAL_TESTS", '0') == '1',
+                    reason="test should be run outside of pytest")
+def test_logging_sync_dist_true_ddp(tmpdir):
+    """
+    Tests to ensure that the sync_dist flag works with ddp
+    """
+    class TestLoggingSyncDistModel(BoringModel):
+        def training_step(self, batch, batch_idx):
+            acc = self.step(batch[0])
+            self.log('foo', 1, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='SUM')
+            return acc
+
+        def validation_step(self, batch, batch_idx):
+            self.training_step_called = True
+            output = self.layer(batch)
+            loss = self.loss(batch, output)
+            self.log('bar', 2, on_step=False, on_epoch=True, sync_dist=True, sync_dist_op='AVG')
+            return {"x": loss}
+
+    model = TestLoggingSyncDistModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        max_epochs=2,
+        weights_summary=None,
+        accelerator="ddp",
+        gpus=2,
+    )
+    trainer.fit(model)
+
+    assert trainer.logged_metrics['foo'] == 2
+    assert trainer.logged_metrics['bar'] == 2
+
+
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_logging_sync_dist_true_gpu(tmpdir):
     """