Commit 045c879

tchaton, pre-commit-ci[bot], rohitgr7, and carmocca authored
Fix self.log(sync_dist=True, reduce_fx={mean,max}) (#9142)
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Carlos Mocholi <[email protected]>
1 parent: b5fb49a · commit: 045c879

File tree

6 files changed: +104 / -51 lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -241,6 +241,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed a bug in the binary search mode of auto batch size scaling where exception was thrown if the first trainer run resulted in OOM ([#8954](https://github.com/PyTorchLightning/pytorch-lightning/pull/8954))
 
 
+- Fixed reduction using `self.log(sync_dict=True, reduce_fx={mean,max})` ([#9142](https://github.com/PyTorchLightning/pytorch-lightning/pull/9142))
+
+
 - Fixed not setting a default value for `max_epochs` if `max_time` was specified on the `Trainer` constructor ([#9072](https://github.com/PyTorchLightning/pytorch-lightning/pull/9072))
 
 

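For reference, the call pattern this changelog entry refers to looks like the following minimal sketch; the module name, metric name, and logged value are illustrative and not part of this commit:

import torch
from pytorch_lightning import LightningModule

class SketchModule(LightningModule):  # illustrative module, not from this commit
    def training_step(self, batch, batch_idx):
        loss = torch.tensor(0.5)  # placeholder value to log
        # sync_dist=True combined with a non-default reduce_fx ("mean" or "max")
        # is the combination whose cross-process reduction this commit fixes
        self.log("my_metric", loss, sync_dist=True, reduce_fx="max")
        return loss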
pytorch_lightning/trainer/connectors/logger_connector/result.py

Lines changed: 34 additions & 25 deletions
@@ -45,22 +45,33 @@ class MetricSource(LightningEnum):
 @dataclass
 class _Sync:
     fn: Optional[Callable] = None
-    should: bool = False
+    _should: bool = False
     rank_zero_only: bool = False
     op: Optional[str] = None
     group: Optional[Any] = None
 
     def __post_init__(self) -> None:
-        if self.fn is None:
-            self.fn = self.no_op
+        self._generate_sync_fn()
+
+    @property
+    def should(self) -> bool:
+        return self._should
+
+    @should.setter
+    def should(self, should: bool) -> None:
+        self._should = should
+        # `self._fn` needs to be re-generated.
+        self._generate_sync_fn()
+
+    def _generate_sync_fn(self) -> None:
+        """Used to compute the syncing function and cache it."""
+        fn = self.no_op if self.fn is None or not self.should or self.rank_zero_only else self.fn
+        # save the function as `_fn` as the meta are being re-created and the object references need to match.
+        self._fn = partial(fn, reduce_op=self.op, group=self.group)
 
     @property
     def __call__(self) -> Any:
-        return (
-            partial(self.fn, reduce_op=self.op, group=self.group)
-            if self.should and not self.rank_zero_only
-            else self.no_op
-        )
+        return self._fn
 
     @staticmethod
     def no_op(value: Any, *_, **__) -> Any:
@@ -75,31 +86,28 @@ class _Metadata:
     logger: bool = True
     on_step: bool = False
     on_epoch: bool = True
-    _reduce_fx: Callable = torch.mean
+    reduce_fx: Callable = torch.mean
     enable_graph: bool = False
     dataloader_idx: Optional[int] = None
     metric_attribute: Optional[str] = None
     _sync: Optional[_Sync] = None
 
-    @property
-    def reduce_fx(self) -> Callable:
-        return self._reduce_fx
+    def __post_init__(self) -> None:
+        self._parse_reduce_fx()
 
-    @reduce_fx.setter
-    def reduce_fx(self, reduce_fx: Union[str, Callable]) -> None:
+    def _parse_reduce_fx(self) -> None:
         error = (
             "Only `self.log(..., reduce_fx={min,max,mean,sum})` are currently supported."
             " Please, open an issue in `https://github.com/PyTorchLightning/pytorch-lightning/issues`."
-            f" Found: {reduce_fx}"
+            f" Found: {self.reduce_fx}"
         )
-        self._reduce_fx = reduce_fx
-        if isinstance(reduce_fx, str):
-            reduce_fx = reduce_fx.lower()
+        if isinstance(self.reduce_fx, str):
+            reduce_fx = self.reduce_fx.lower()
             if reduce_fx == "avg":
                 reduce_fx = "mean"
             if reduce_fx not in ("min", "max", "mean", "sum"):
                 raise MisconfigurationException(error)
-            self._reduce_fx = getattr(torch, reduce_fx)
+            self.reduce_fx = getattr(torch, reduce_fx)
         elif self.is_custom_reduction:
             raise MisconfigurationException(error)
 
@@ -178,11 +186,11 @@ def __init__(self, metadata: _Metadata, is_tensor: bool) -> None:
     def update(self, value: _METRIC, batch_size: torch.Tensor) -> None:
         if self.is_tensor:
             value = value.float()
-            self._forward_cache = value
             # performance: no need to accumulate on values only logged on_step
             if self.meta.on_step and not self.meta.on_epoch:
-                self.value = self.meta.sync(value)
+                self._forward_cache = self.value = self.meta.sync(value)
                 return
+            self._forward_cache = value
             # perform accumulation with reduction
             if self.meta.is_mean_reduction:
                 self.value += value.mean() * batch_size
@@ -201,8 +209,7 @@ def compute(self) -> torch.Tensor:
             if self.meta.is_mean_reduction:
                 cumulated_batch_size = self.meta.sync(self.cumulated_batch_size)
                 return value / cumulated_batch_size
-            elif self.meta.is_max_reduction or self.meta.is_min_reduction or self.meta.is_sum_reduction:
-                return value
+            return value
         return self.value.compute()
 
     def reset(self) -> None:
@@ -449,12 +456,12 @@ def log(
             logger=logger,
             on_step=on_step,
             on_epoch=on_epoch,
+            reduce_fx=reduce_fx,
             enable_graph=enable_graph,
             dataloader_idx=dataloader_idx,
             metric_attribute=metric_attribute,
         )
-        meta.reduce_fx = reduce_fx
-        meta.sync = _Sync(should=sync_dist, fn=sync_dist_fn, group=sync_dist_group, rank_zero_only=rank_zero_only)
+        meta.sync = _Sync(_should=sync_dist, fn=sync_dist_fn, group=sync_dist_group, rank_zero_only=rank_zero_only)
 
         # register logged value if it doesn't exist
         if key not in self:
@@ -680,6 +687,8 @@ def load_state_dict(
 
         if not metrics:
             return
+
+        # iterate through result metrics and re-attached Metric references on reload.
         result_metrics = self.result_metrics
         for metric_attribute, metric in metrics.items():
            for result_metric in result_metrics:

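The core of the change above is that `_Sync.should` became a property whose setter re-generates a cached sync function, so the `sync_dist` flag passed to `log()` is reflected in the function that actually runs. A standalone, simplified sketch of that pattern (illustrative only, not the library class):

from dataclasses import dataclass
from functools import partial
from typing import Any, Callable, Optional

@dataclass
class SyncSketch:
    fn: Optional[Callable] = None
    _should: bool = False
    op: Optional[str] = None

    def __post_init__(self) -> None:
        self._generate()

    @property
    def should(self) -> bool:
        return self._should

    @should.setter
    def should(self, should: bool) -> None:
        self._should = should
        self._generate()  # re-generate the cached callable whenever the flag flips

    def _generate(self) -> None:
        # fall back to a pass-through when syncing is disabled or no fn was provided
        fn = self.fn if (self.fn is not None and self._should) else (lambda value, **_: value)
        self._cached = partial(fn, reduce_op=self.op)

    def __call__(self, value: Any) -> Any:
        return self._cached(value)

Toggling `sync.should = True` after construction swaps in the real reduction function instead of leaving a stale no-op partial in place, which mirrors why the hunk caches `_fn` and rebuilds it in the setter.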
pytorch_lightning/utilities/distributed.py

Lines changed: 8 additions & 4 deletions
@@ -181,10 +181,14 @@ def sync_ddp(
     if group is None:
         group = torch.distributed.group.WORLD
 
-    op = reduce_op if isinstance(reduce_op, ReduceOp) else ReduceOp.SUM
-
-    if isinstance(reduce_op, str) and reduce_op.lower() in ("avg", "mean"):
-        divide_by_world_size = True
+    if isinstance(reduce_op, str):
+        if reduce_op.lower() in ("avg", "mean"):
+            op = ReduceOp.SUM
+            divide_by_world_size = True
+        else:
+            op = getattr(ReduceOp, reduce_op.upper())
+    else:
+        op = reduce_op
 
     # sync all processes before reduction
     torch.distributed.barrier(group=group)

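The hunk above changes how `sync_ddp` resolves string reduce ops: previously any string other than "avg"/"mean" silently fell back to `ReduceOp.SUM`, so `reduce_fx="max"` was summed across processes; now the string is looked up on `ReduceOp`. A small illustrative helper restating that branching (the function name is ours, not the library's):

import torch.distributed as dist

def resolve_reduce_op(reduce_op):
    """Sketch mirroring the new branching in `sync_ddp`; not the library function itself."""
    divide_by_world_size = False
    if isinstance(reduce_op, str):
        if reduce_op.lower() in ("avg", "mean"):
            op = dist.ReduceOp.SUM       # mean is expressed as SUM ...
            divide_by_world_size = True  # ... followed by a divide by the world size
        else:
            op = getattr(dist.ReduceOp, reduce_op.upper())  # e.g. "max" -> ReduceOp.MAX
    else:
        op = reduce_op  # already a ReduceOp member, pass it through
    return op, divide_by_world_size

For example, `resolve_reduce_op("max")` yields `ReduceOp.MAX`, whereas the old code would have reduced with SUM.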
tests/core/test_metric_result_integration.py

Lines changed: 2 additions & 2 deletions
@@ -27,7 +27,7 @@
 import tests.helpers.utils as tutils
 from pytorch_lightning import Trainer
 from pytorch_lightning.callbacks import ModelCheckpoint
-from pytorch_lightning.trainer.connectors.logger_connector.result import _Sync, MetricSource, ResultCollection
+from pytorch_lightning.trainer.connectors.logger_connector.result import MetricSource, ResultCollection
 from pytorch_lightning.utilities.imports import _fault_tolerant_training, _TORCH_GREATER_EQUAL_1_7
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
@@ -336,7 +336,7 @@ def on_save_checkpoint(self, checkpoint) -> None:
     # default sync fn
     new_results = ResultCollection(False, device)
     new_results.load_state_dict(state_dict, map_location="cpu")
-    assert new_results["validation_step.v"].meta.sync.fn == _Sync.no_op
+    assert new_results["validation_step.v"].meta.sync.fn is None
 
     # check map location
     assert new_results["validation_step.v"].value.device.type == "cpu"

tests/core/test_results.py

Lines changed: 1 addition & 1 deletion
@@ -33,7 +33,7 @@ def _setup_ddp(rank, worldsize):
 def _ddp_test_fn(rank, worldsize):
     _setup_ddp(rank, worldsize)
     tensor = torch.tensor([1.0])
-    sync = _Sync(sync_ddp_if_available, should=True, op="SUM")
+    sync = _Sync(sync_ddp_if_available, _should=True, op="SUM")
     actual = sync(tensor)
     assert actual.item() == dist.get_world_size(), "Result-Log does not work properly with DDP and Tensors"

tests/trainer/logging_/test_train_loop_logging.py

Lines changed: 56 additions & 19 deletions
@@ -359,37 +359,74 @@ def get_expected(on_epoch, values):
     assert is_included if should_include else not is_included
 
 
-@pytest.mark.parametrize("gpus", [None, pytest.param(1, marks=RunIf(min_gpus=1))])
+class LoggingSyncDistModel(BoringModel):
+    def __init__(self, fake_result):
+        super().__init__()
+        self.fake_result = fake_result
+
+    @property
+    def rank(self) -> int:
+        return self.trainer.global_rank
+
+    def training_step(self, batch, batch_idx):
+        value = self.fake_result + self.rank
+        self.log("foo", value, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="sum")
+        self.log("foo_2", 2, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="sum")
+        self.log("foo_3", 2, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="mean")
+        self.log("foo_4", value, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="mean")
+        self.log("foo_5", batch_idx + self.rank, on_step=True, on_epoch=False, sync_dist=True, reduce_fx="max")
+
+        self.log("foo_6", value, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
+        self.log("foo_7", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
+        self.log("foo_8", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean")
+        self.log("foo_9", value, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean")
+        self.log("foo_10", batch_idx, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="max")
+        return super().training_step(batch, batch_idx)
+
+    def validation_step(self, batch, batch_idx):
+        self.log("bar", self.fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
+        self.log("bar_2", self.fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="mean")
+        self.log("bar_3", batch_idx + self.rank, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="max")
+        return super().validation_step(batch, batch_idx)
+
+
+@pytest.mark.parametrize(
+    "gpus", [None, pytest.param(1, marks=RunIf(min_gpus=1)), pytest.param(2, marks=RunIf(min_gpus=2))]
+)
 def test_logging_sync_dist_true(tmpdir, gpus):
     """
     Tests to ensure that the sync_dist flag works (should just return the original value)
     """
     fake_result = 1
-
-    class TestModel(BoringModel):
-        def training_step(self, batch, batch_idx):
-            self.log("foo", fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
-            self.log("foo_2", 2, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
-            return super().training_step(batch, batch_idx)
-
-        def validation_step(self, batch, batch_idx):
-            self.log("bar", fake_result, on_step=False, on_epoch=True, sync_dist=True, reduce_fx="sum")
-            return super().validation_step(batch, batch_idx)
-
-    model = TestModel()
+    model = LoggingSyncDistModel(fake_result)
     trainer = Trainer(
+        max_epochs=1,
         default_root_dir=tmpdir,
-        limit_train_batches=1,
-        limit_val_batches=1,
-        max_epochs=2,
+        limit_train_batches=3,
+        limit_val_batches=3,
         weights_summary=None,
         gpus=gpus,
     )
     trainer.fit(model)
 
-    assert trainer.logged_metrics["foo"] == fake_result
-    assert trainer.logged_metrics["foo_2"] == 2
-    assert trainer.logged_metrics["bar"] == fake_result
+    num_devices = 1 if gpus is None else gpus
+    use_multiple_devices = num_devices > 1
+    total = fake_result * num_devices + 1
+
+    metrics = trainer.callback_metrics
+    assert metrics["foo"] == total if use_multiple_devices else fake_result
+    assert metrics["foo_2"] == 2 * num_devices
+    assert metrics["foo_3"] == 2
+    assert metrics["foo_4"] == total / num_devices if use_multiple_devices else 1
+    assert metrics["foo_5"] == fake_result * 2 + 1 if use_multiple_devices else fake_result * 2
+    assert metrics["foo_6"] == fake_result * 3 * 2 + 3 if use_multiple_devices else fake_result * 3 * 2
+    assert metrics["foo_7"] == 2 * num_devices * 3
+    assert metrics["foo_8"] == 2
+    assert metrics["foo_9"] == (fake_result * 2 + 1) / num_devices if use_multiple_devices else fake_result
+    assert metrics["foo_10"] == 2
+    assert metrics["bar"] == fake_result * 3 * num_devices
+    assert metrics["bar_2"] == fake_result
    assert metrics["bar_3"] == 2 + int(use_multiple_devices)
 
 
 @RunIf(min_gpus=2, special=True)

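As a sanity check on the new assertions above, the expected-value arithmetic for the two-GPU case works out as follows (a sketch assuming the same `fake_result = 1` and three batches used by the test):

# Two devices, fake_result = 1, limit_train_batches = 3, mirroring the parametrization above.
num_devices, fake_result, batches = 2, 1, 3

# "foo": logged on_step with reduce_fx="sum"; each rank logs fake_result + rank,
# so the synced value for a step is (1 + 0) + (1 + 1) = 3, matching `total` above.
assert sum(fake_result + rank for rank in range(num_devices)) == 3

# "foo_7": logged on_epoch with reduce_fx="sum"; every rank logs 2 on each of the
# 3 batches, so the epoch total is 2 * num_devices * batches = 12.
assert 2 * num_devices * batches == 12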