Commit e10ebac

Remove rank 0 restrictions from logger
1 parent 4928dc5 commit e10ebac

File tree: 3 files changed (+49, -16 lines)

CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -53,6 +53,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - The accelerator and training type plugin `setup` hooks no longer have a `model` argument ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536))
 
+- Removed restrictions in the trainer that loggers can only log from rank 0. Existing logger behavior has not changed. ([#8608](https://github.com/PyTorchLightning/pytorch-lightning/pull/8608))
+
+
+
 ### Deprecated
 
 - Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
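For context, here is a minimal sketch of the kind of logger this change enables. It is not part of the commit; the name `PerRankLogger` and its in-memory `logs` dict are invented here, modeled on the `AllRankLogger` added in the test file below. Because the trainer no longer filters metric logging to global rank 0, a logger without its own rank guard now receives `log_metrics` calls on every DDP process.

from typing import Any, Dict, Optional, Union

from pytorch_lightning.loggers import LightningLoggerBase


class PerRankLogger(LightningLoggerBase):
    """Hypothetical logger that keeps metrics in an in-memory dict on each rank."""

    def __init__(self) -> None:
        super().__init__()
        self.logs: Dict[str, float] = {}

    @property
    def experiment(self) -> Any:
        return None

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        # With this commit, the trainer calls this on every rank, not only on rank 0.
        self.logs.update(metrics)

    def log_hyperparams(self, *args: Any, **kwargs: Any) -> None:
        pass

    @property
    def name(self) -> str:
        return "per_rank"

    @property
    def version(self) -> Union[int, str]:
        return 0


# Example use: Trainer(logger=PerRankLogger(), accelerator="ddp_cpu", num_processes=2)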

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 2 additions & 3 deletions
@@ -102,9 +102,8 @@ def log_metrics(self, metrics: Dict[str, _METRIC], step: Optional[int] = None) -> None:
             step = self.trainer.global_step
 
         # log actual metrics
-        if self.trainer.is_global_zero:
-            self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step)
-            self.trainer.logger.save()
+        self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step)
+        self.trainer.logger.save()
 
         self._logged_metrics.update(scalar_metrics)
 
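The changelog's note that existing logger behavior has not changed rests on the built-in loggers guarding their own `log_metrics` implementations with the `rank_zero_only` decorator, so removing the `is_global_zero` check here does not make them emit from other ranks. A minimal, hedged illustration of that decorator follows; the function `log_on_rank_zero` is invented for this example, and a custom logger that wants the old rank-zero-only behavior can apply the same decorator to its `log_metrics`.

from pytorch_lightning.utilities import rank_zero_only


@rank_zero_only
def log_on_rank_zero(metrics: dict, step: int = 0) -> None:
    # Executes only on global rank 0; on every other rank the call is a no-op.
    print(f"step {step}: {metrics}")


log_on_rank_zero({"loss": 0.123}, step=1)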

tests/trainer/logging_/test_distributed_logging.py

Lines changed: 43 additions & 13 deletions
@@ -13,7 +13,6 @@
 # limitations under the License.
 import os
 from typing import Any, Dict, Optional, Union
-from unittest import mock
 from unittest.mock import Mock
 
 import pytorch_lightning as pl
@@ -23,23 +22,50 @@
 from tests.helpers.runif import RunIf
 
 
+class AllRankLogger(LightningLoggerBase):
+    """
+    Logger to test all-rank logging (i.e. not just rank 0).
+    Logs are saved to local variable `logs`.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.logs = {}
+        self.exp = object()
+
+    def experiment(self) -> Any:
+        return self.exp
+
+    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
+        self.logs.update(metrics)
+
+    def version(self) -> Union[int, str]:
+        return 1
+
+    def name(self) -> str:
+        return "AllRank"
+
+    def log_hyperparams(self, *args, **kwargs) -> None:
+        pass
+
+
 class TestModel(BoringModel):
-    def on_pretrain_routine_end(self) -> None:
-        with mock.patch("pytorch_lightning.loggers.base.LightningLoggerBase.agg_and_log_metrics") as m:
-            self.trainer.logger_connector.log_metrics({"a": 2})
-            logged_times = m.call_count
-            expected = int(self.trainer.is_global_zero)
-            msg = f"actual logger called from non-global zero, logged_times: {logged_times}, expected: {expected}"
-            assert logged_times == expected, msg
+    log_name = "rank-{rank}"
+
+    def on_train_start(self):
+        self.log(self.log_name.format(rank=self.local_rank), 0)
+
+    def on_train_end(self):
+        assert self.log_name.format(rank=self.local_rank) in self.logger.logs, "Expected rank to be logged"
 
 
 @RunIf(skip_windows=True)
-def test_global_zero_only_logging_ddp_cpu(tmpdir):
+def test_all_rank_logging_ddp_cpu(tmpdir):
     """
-    Makes sure logging only happens from root zero
+    Check that all ranks can be logged from
     """
     model = TestModel()
-    model.training_epoch_end = None
+    all_rank_logger = AllRankLogger()
     trainer = Trainer(
         accelerator="ddp_cpu",
         num_processes=2,
@@ -48,16 +74,19 @@ def test_global_zero_only_logging_ddp_cpu(tmpdir):
         limit_val_batches=1,
         max_epochs=1,
         weights_summary=None,
+        logger=all_rank_logger,
+        log_every_n_steps=1,
     )
     trainer.fit(model)
 
 
 @RunIf(min_gpus=2)
-def test_global_zero_only_logging_ddp_spawn(tmpdir):
+def test_all_rank_logging_ddp_spawn(tmpdir):
     """
-    Makes sure logging only happens from root zero
+    Check that all ranks can be logged from
     """
     model = TestModel()
+    all_rank_logger = AllRankLogger()
     model.training_epoch_end = None
     trainer = Trainer(
         accelerator="ddp_spawn",
@@ -66,6 +95,7 @@ def test_global_zero_only_logging_ddp_spawn(tmpdir):
         limit_train_batches=1,
         limit_val_batches=1,
         max_epochs=1,
+        logger=all_rank_logger,
         weights_summary=None,
     )
     trainer.fit(model)
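A note on the test's design: `TestModel` logs under a per-rank key ("rank-{rank}") rather than a shared name because `AllRankLogger` stores metrics in a plain dict, where a shared key would be overwritten by whichever rank writes last. A small self-contained sketch of that effect, with keys and values invented for illustration:

logs = {}
for rank in (0, 1):
    logs.update({"loss": float(rank)})          # shared key: the last rank wins
    logs.update({f"rank-{rank}": float(rank)})  # per-rank keys coexist

assert logs["loss"] == 1.0                      # rank 0's value was overwritten
assert "rank-0" in logs and "rank-1" in logs    # both ranks remain visible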
