
Commit ceb8bdf

tchaton, pre-commit-ci[bot], rohitgr7, and carmocca authored and committed
Fix self.log(sync_dist=True, reduce_fx={mean,max}) (#9142)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
1 parent 2edd154 commit ceb8bdf
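
For context, the call pattern this commit fixes looks like the sketch below. This is an illustrative example written for this summary, not code from the PR: the model and metric names are made up, and only the standard `LightningModule.log` API is assumed.

import torch
from pytorch_lightning import LightningModule


class SyncDistExample(LightningModule):
    """Illustrative only: shows self.log(sync_dist=True, reduce_fx=...) usage."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        loss = self.layer(batch).sum()
        # Before this fix, sync_dist=True combined with reduce_fx="mean" or
        # reduce_fx="max" could produce an incorrectly reduced value across processes.
        self.log("train_loss", loss, sync_dist=True, reduce_fx="mean")
        self.log("train_loss_max", loss, sync_dist=True, reduce_fx="max")
        return loss

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)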

6 files changed: +104 -51 lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -7,6 +7,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [1.4.5] - 2021-08-31
 
+- Fixed reduction using `self.log(sync_dist=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))
+
+
 - Fixed not setting a default value for `max_epochs` if `max_time` was specified on the `Trainer` constructor ([#9072](https://github.com/PyTorchLightning/pytorch-lightning/pull/9072))
 
 ## [1.4.4] - 2021-08-24

pytorch_lightning/trainer/connectors/logger_connector/result.py

Lines changed: 34 additions & 25 deletions
@@ -45,22 +45,33 @@ class MetricSource(LightningEnum):
 @dataclass
 class _Sync:
     fn: Optional[Callable] = None
-    should: bool = False
+    _should: bool = False
     rank_zero_only: bool = False
     op: Optional[str] = None
     group: Optional[Any] = None
 
     def __post_init__(self) -> None:
-        if self.fn is None:
-            self.fn = self.no_op
+        self._generate_sync_fn()
+
+    @property
+    def should(self) -> bool:
+        return self._should
+
+    @should.setter
+    def should(self, should: bool) -> None:
+        self._should = should
+        # `self._fn` needs to be re-generated.
+        self._generate_sync_fn()
+
+    def _generate_sync_fn(self) -> None:
+        """Used to compute the syncing function and cache it."""
+        fn = self.no_op if self.fn is None or not self.should or self.rank_zero_only else self.fn
+        # save the function as `_fn` as the meta are being re-created and the object references need to match.
+        self._fn = partial(fn, reduce_op=self.op, group=self.group)
 
     @property
     def __call__(self) -> Any:
-        return (
-            partial(self.fn, reduce_op=self.op, group=self.group)
-            if self.should and not self.rank_zero_only
-            else self.no_op
-        )
+        return self._fn
 
     @staticmethod
     def no_op(value: Any, *_, **__) -> Any:
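
The key idea in the hunk above is that `_Sync` now builds its sync callable once, caches it in `_fn`, and rebuilds it only when `should` is reassigned. Below is a stripped-down, self-contained sketch of that caching pattern; the `CachedSync` and `no_op` names are illustrative stand-ins, not the Lightning class itself.

from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Optional


def no_op(value: Any, *_, **__) -> Any:
    return value


@dataclass
class CachedSync:
    # Simplified stand-in for `_Sync`: the callable is computed in __post_init__
    # and cached in `_fn`, then rebuilt whenever `should` is reassigned.
    fn: Optional[Callable] = None
    _should: bool = False
    op: Optional[str] = None

    def __post_init__(self) -> None:
        self._generate_sync_fn()

    @property
    def should(self) -> bool:
        return self._should

    @should.setter
    def should(self, should: bool) -> None:
        self._should = should
        self._generate_sync_fn()  # the cached partial must be regenerated

    def _generate_sync_fn(self) -> None:
        fn = no_op if self.fn is None or not self._should else self.fn
        self._fn = partial(fn, reduce_op=self.op)


sync = CachedSync(fn=lambda value, reduce_op=None: value * 2, op="sum")
assert sync._fn(3) == 3   # should=False -> cached no-op
sync.should = True        # reassignment rebuilds the cache
assert sync._fn(3) == 6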
@@ -75,31 +86,28 @@ class _Metadata:
     logger: bool = True
     on_step: bool = False
     on_epoch: bool = True
-    _reduce_fx: Callable = torch.mean
+    reduce_fx: Callable = torch.mean
     enable_graph: bool = False
     dataloader_idx: Optional[int] = None
     metric_attribute: Optional[str] = None
     _sync: Optional[_Sync] = None
 
-    @property
-    def reduce_fx(self) -> Callable:
-        return self._reduce_fx
+    def __post_init__(self) -> None:
+        self._parse_reduce_fx()
 
-    @reduce_fx.setter
-    def reduce_fx(self, reduce_fx: Union[str, Callable]) -> None:
+    def _parse_reduce_fx(self) -> None:
         error = (
             "Only `self.log(..., reduce_fx={min,max,mean,sum})` are currently supported."
             " Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`."
-            f" Found: {reduce_fx}"
+            f" Found: {self.reduce_fx}"
         )
-        self._reduce_fx = reduce_fx
-        if isinstance(reduce_fx, str):
-            reduce_fx = reduce_fx.lower()
+        if isinstance(self.reduce_fx, str):
+            reduce_fx = self.reduce_fx.lower()
             if reduce_fx == "avg":
                 reduce_fx = "mean"
             if reduce_fx not in ("min", "max", "mean", "sum"):
                 raise MisconfigurationException(error)
-            self._reduce_fx = getattr(torch, reduce_fx)
+            self.reduce_fx = getattr(torch, reduce_fx)
         elif self.is_custom_reduction:
             raise MisconfigurationException(error)
 
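The reduce_fx parsing above moved from a property setter into `__post_init__`/`_parse_reduce_fx`, so the string handling now runs once when the metadata is created. The helper below is a hypothetical standalone restatement of that mapping (not the Lightning method), assuming only `torch`.

import torch


def parse_reduce_fx(reduce_fx):
    # Mirrors the string handling above: "avg" is an alias for "mean", and only
    # min/max/mean/sum resolve to the corresponding torch reduction functions.
    if not isinstance(reduce_fx, str):
        return reduce_fx  # already a callable, e.g. torch.mean
    reduce_fx = reduce_fx.lower()
    if reduce_fx == "avg":
        reduce_fx = "mean"
    if reduce_fx not in ("min", "max", "mean", "sum"):
        raise ValueError(f"Unsupported reduce_fx: {reduce_fx}")
    return getattr(torch, reduce_fx)


assert parse_reduce_fx("avg") is torch.mean
assert parse_reduce_fx("MAX") is torch.max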
@@ -178,11 +186,11 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
     def update(self, value: _METRIC, batch_size: torch.Tensor) -> None:
         if self.is_tensor:
             value = value.float()
-            self._forward_cache = value
             # performance: no need to accumulate on values only logged on_step
             if self.meta.on_step and not self.meta.on_epoch:
-                self.value = self.meta.sync(value)
+                self._forward_cache = self.value = self.meta.sync(value)
                 return
+            self._forward_cache = value
             # perform accumulation with reduction
             if self.meta.is_mean_reduction:
                 self.value += value.mean() * batch_size
@@ -201,8 +209,7 @@ def compute(self) -> torch.Tensor:
             if self.meta.is_mean_reduction:
                 cumulated_batch_size = self.meta.sync(self.cumulated_batch_size)
                 return value / cumulated_batch_size
-            elif self.meta.is_max_reduction or self.meta.is_min_reduction or self.meta.is_sum_reduction:
-                return value
+            return value
         return self.value.compute()
 
     def reset(self) -> None:
@@ -448,12 +455,12 @@ def log(
             logger=logger,
             on_step=on_step,
             on_epoch=on_epoch,
+            reduce_fx=reduce_fx,
             enable_graph=enable_graph,
             dataloader_idx=dataloader_idx,
             metric_attribute=metric_attribute,
         )
-        meta.reduce_fx = reduce_fx
-        meta.sync = _Sync(should=sync_dist, fn=sync_dist_fn, group=sync_dist_group, rank_zero_only=rank_zero_only)
+        meta.sync = _Sync(_should=sync_dist, fn=sync_dist_fn, group=sync_dist_group, rank_zero_only=rank_zero_only)
 
         # register logged value if it doesn't exist
         if key not in self:
@@ -669,6 +676,8 @@ def load_state_dict(
 
         if not metrics:
             return
+
+        # iterate through result metrics and re-attached Metric references on reload.
         result_metrics = self.result_metrics
         for metric_attribute, metric in metrics.items():
             for result_metric in result_metrics:

pytorch_lightning/utilities/distributed.py

Lines changed: 8 additions & 4 deletions
@@ -179,10 +179,14 @@ def sync_ddp(
     if group is None:
         group = torch.distributed.group.WORLD
 
-    op = reduce_op if isinstance(reduce_op, ReduceOp) else ReduceOp.SUM
-
-    if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
-        divide_by_world_size = True
+    if isinstance(reduce_op, str):
+        if reduce_op.lower() in ("avg", "mean"):
+            op = ReduceOp.SUM
+            divide_by_world_size = True
+        else:
+            op = getattr(ReduceOp, reduce_op.upper())
+    else:
+        op = reduce_op
 
     # sync all processes before reduction
     torch.distributed.barrier(group=group)
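
The behavioural change in this hunk is that string reduce ops other than mean/avg are now looked up on `ReduceOp` instead of silently falling back to SUM, while mean/avg still sum and then divide by the world size. A small sketch of just that mapping follows; `resolve_reduce_op` is a made-up helper for illustration (not part of the PR), and it assumes a torch build with `torch.distributed` available.

from torch.distributed import ReduceOp


def resolve_reduce_op(reduce_op):
    # mean/avg: SUM across processes, then divide by the world size afterwards;
    # other strings ("max", "min", ...) resolve to the matching ReduceOp member;
    # ReduceOp values are passed through unchanged.
    divide_by_world_size = False
    if isinstance(reduce_op, str):
        if reduce_op.lower() in ("avg", "mean"):
            op = ReduceOp.SUM
            divide_by_world_size = True
        else:
            op = getattr(ReduceOp, reduce_op.upper())
    else:
        op = reduce_op
    return op, divide_by_world_size


assert resolve_reduce_op("mean") == (ReduceOp.SUM, True)
assert resolve_reduce_op("max") == (ReduceOp.MAX, False)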

tests/core/test_metric_result_integration.py

Lines changed: 2 additions & 2 deletions
@@ -23,7 +23,7 @@
 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync, MetricSource, ResultCollection
+from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource, ResultCollection
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
 
@@ -331,7 +331,7 @@ def on_save_checkpoint(self, checkpoint) -> None:
     # default sync fn
     new_results = ResultCollection(False, device)
     new_results.load_state_dict(state_dict, map_location="cpu")
-    assert new_results["validation_step.v"].meta.sync.fn == _Sync.no_op
+    assert new_results["validation_step.v"].meta.sync.fn is None
 
     # check map location
     assert new_results["validation_step.v"].value.device.type == "cpu"

tests/core/test_results.py

Lines changed: 1 addition & 1 deletion
@@ -39,7 +39,7 @@ def _setup_ddp(rank, worldsize):
 def _ddp_test_fn(rank, worldsize):
     _setup_ddp(rank, worldsize)
     tensor = torch.tensor([1.0])
-    sync = _Sync(sync_ddp_if_available, should=True, op="SUM")
+    sync = _Sync(sync_ddp_if_available, _should=True, op="SUM")
     actual = sync(tensor)
     assert actual.item() == dist.get_world_size(), "Result-Log does not work properly with DDP and Tensors"
 

tests/trainer/logging_/test_train_loop_logging.py

Lines changed: 56 additions & 19 deletions
@@ -360,37 +360,74 @@ def get_expected(on_epoch, values):
         assert is_included if should_include else not is_included
 
 
-@pytest.mark.parametrize("gpus", [None, pytest.param(1, marks=RunIf(min_gpus=1))])
+class LoggingSyncDistModel(BoringModel):
+    def __init__(self, fake_result):
+        super().__init__()
+        self.fake_result = fake_result
+
+    @property
+    def rank(self) -> int:
+        return self.trainer.global_rank
+
+    def training_step(self, batch, batch_idx):
+        value = self.fake_result + self.rank
+        self.log("foo", value, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="sum")
+        self.log("foo_2", 2, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="sum")
+        self.log("foo_3", 2, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="mean")
+        self.log("foo_4", value, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="mean")
+        self.log("foo_5", batch_idx + self.rank, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="max")
+
+        self.log("foo_6", value, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
+        self.log("foo_7", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
+        self.log("foo_8", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean")
+        self.log("foo_9", value, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean")
+        self.log("foo_10", batch_idx, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="max")
+        return super().training_step(batch, batch_idx)
+
+    def validation_step(self, batch, batch_idx):
+        self.log("bar", self.fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
+        self.log("bar_2", self.fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean")
+        self.log("bar_3", batch_idx + self.rank, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="max")
+        return super().validation_step(batch, batch_idx)
+
+
+@pytest.mark.parametrize(
+    "gpus", [None, pytest.param(1, marks=RunIf(min_gpus=1)), pytest.param(2, marks=RunIf(min_gpus=2))]
+)
 def test_logging_sync_dist_true(tmpdir, gpus):
     """
     Tests to ensure that the sync_dist flag works (should just return the original value)
     """
     fake_result = 1
-
-    class TestModel(BoringModel):
-        def training_step(self, batch, batch_idx):
-            self.log("foo", fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
-            self.log("foo_2", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
-            return super().training_step(batch, batch_idx)
-
-        def validation_step(self, batch, batch_idx):
-            self.log("bar", fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
-            return super().validation_step(batch, batch_idx)
-
-    model = TestModel()
+    model = LoggingSyncDistModel(fake_result)
     trainer = Trainer(
+        max_epochs=1,
         default_root_dir=tmpdir,
-        limit_train_batches=1,
-        limit_val_batches=1,
-        max_epochs=2,
+        limit_train_batches=3,
+        limit_val_batches=3,
         weights_summary=None,
         gpus=gpus,
     )
     trainer.fit(model)
 
-    assert trainer.logged_metrics["foo"] == fake_result
-    assert trainer.logged_metrics["foo_2"] == 2
-    assert trainer.logged_metrics["bar"] == fake_result
+    num_devices = 1 if gpus is None else gpus
+    use_multiple_devices = num_devices > 1
+    total = fake_result * num_devices + 1
+
+    metrics = trainer.callback_metrics
+    assert metrics["foo"] == total if use_multiple_devices else fake_result
+    assert metrics["foo_2"] == 2 * num_devices
+    assert metrics["foo_3"] == 2
+    assert metrics["foo_4"] == total / num_devices if use_multiple_devices else 1
+    assert metrics["foo_5"] == fake_result * 2 + 1 if use_multiple_devices else fake_result * 2
+    assert metrics["foo_6"] == fake_result * 3 * 2 + 3 if use_multiple_devices else fake_result * 3 * 2
+    assert metrics["foo_7"] == 2 * num_devices * 3
+    assert metrics["foo_8"] == 2
+    assert metrics["foo_9"] == (fake_result * 2 + 1) / num_devices if use_multiple_devices else fake_result
+    assert metrics["foo_10"] == 2
+    assert metrics["bar"] == fake_result * 3 * num_devices
+    assert metrics["bar_2"] == fake_result
+    assert metrics["bar_3"] == 2 + int(use_multiple_devices)
 
 
 @RunIf(min_gpus=2, special=True)
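
To make the expected values in the new test easier to follow, the arithmetic behind two of them is restated below for the 2-GPU case. This is commentary derived from the test code above, not part of the commit.

# Assumes 2 devices, fake_result = 1, limit_train_batches = limit_val_batches = 3.
fake_result, num_devices, num_val_batches = 1, 2, 3

# "foo" (on_step, reduce_fx="sum"): rank r logs fake_result + r on each step,
# so the synced value is (1 + 0) + (1 + 1) = fake_result * num_devices + 1.
foo = sum(fake_result + rank for rank in range(num_devices))
assert foo == fake_result * num_devices + 1 == 3

# "bar" (on_epoch, reduce_fx="sum"): every rank logs fake_result on each of the
# 3 validation batches, giving fake_result * 3 * num_devices.
bar = fake_result * num_val_batches * num_devices
assert bar == 6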
