
Commit ecd7d79

Remove rank 0 restrictions from logger
1 parent 0cdf8ae commit ecd7d79

File tree

3 files changed: +53 -16 lines changed


CHANGELOG.md

Lines changed: 5 additions & 1 deletion
@@ -56,6 +56,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - The accelerator and training type plugin `setup` hooks no longer have a `model` argument ([#8536](https://github.com/PyTorchLightning/pytorch-lightning/pull/8536))
 
+
+- Removed restrictions in the trainer that loggers can only log from rank 0. Existing logger behavior has not changed. ([#8608](https://github.com/PyTorchLightning/pytorch-lightning/pull/8608))
+
+
 ### Deprecated
 
 - Deprecated `LightningModule.summarize()` in favor of `pytorch_lightning.utilities.model_summary.summarize()`
@@ -2590,4 +2594,4 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ## [0.2.x] - 2019-07-09
 
-## [0.1.x] - 2019-06-DD
+## [0.1.x] - 2019-06-DD

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 2 additions & 3 deletions
@@ -102,9 +102,8 @@ def log_metrics(self, metrics: Dict[str, _METRIC], step: Optional[int] = None) -> None:
             step = self.trainer.global_step
 
         # log actual metrics
-        if self.trainer.is_global_zero:
-            self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step)
-            self.trainer.logger.save()
+        self.trainer.logger.agg_and_log_metrics(scalar_metrics, step=step)
+        self.trainer.logger.save()
 
         self._logged_metrics.update(scalar_metrics)
 

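Note on the change above: the trainer no longer wraps `agg_and_log_metrics` and `save` in an `is_global_zero` check, so rank filtering is now the logger's responsibility. Built-in loggers such as `TensorBoardLogger` already guard their write methods with `rank_zero_only`, which appears to be why the changelog says existing behavior is unchanged. Below is a minimal sketch, not part of this commit, of a custom logger that keeps the old rank-0-only behavior on its own; it assumes the `LightningLoggerBase` API and the `rank_zero_only` decorator of this era of PyTorch Lightning, and the class name `RankZeroOnlyLogger` is hypothetical.

```python
# Hypothetical example (not from this commit): a custom logger that restricts
# its own writes to global rank 0, mirroring what the trainer used to enforce.
from typing import Any, Dict, Optional, Union

from pytorch_lightning.loggers.base import LightningLoggerBase
from pytorch_lightning.utilities import rank_zero_only


class RankZeroOnlyLogger(LightningLoggerBase):
    """Collects metrics in a dict, but only on global rank 0."""

    def __init__(self):
        super().__init__()
        self.logs = {}

    @property
    def experiment(self) -> Any:
        return None

    @rank_zero_only
    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        # Only the global-zero process executes this body; other ranks no-op.
        self.logs.update(metrics)

    @rank_zero_only
    def log_hyperparams(self, *args, **kwargs) -> None:
        pass

    @property
    def name(self) -> str:
        return "RankZeroOnly"

    @property
    def version(self) -> Union[int, str]:
        return 0
```

Passing `logger=RankZeroOnlyLogger()` to `Trainer` would then reproduce the pre-#8608 behavior, while a logger like the test's `AllRankLogger` below receives calls on every rank.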
tests/trainer/logging_/test_distributed_logging.py

Lines changed: 46 additions & 12 deletions
@@ -12,31 +12,61 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
+from typing import Any, Dict, Optional, Union
 from unittest import mock
 from unittest.mock import Mock
 
 from pytorch_lightning import Callback, Trainer
+from pytorch_lightning.loggers.base import LightningLoggerBase
+from pytorch_lightning.loggers import TensorBoardLogger
 from tests.helpers import BoringModel
 from tests.helpers.runif import RunIf
 
 
+class AllRankLogger(LightningLoggerBase):
+    """
+    Logger to test all-rank logging (i.e. not just rank 0).
+    Logs are saved to local variable `logs`.
+    """
+
+    def __init__(self):
+        super().__init__()
+        self.logs = {}
+        self.exp = object()
+
+    def experiment(self) -> Any:
+        return self.exp
+
+    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
+        self.logs.update(metrics)
+
+    def version(self) -> Union[int, str]:
+        return 1
+
+    def name(self) -> str:
+        return "AllRank"
+
+    def log_hyperparams(self, *args, **kwargs) -> None:
+        pass
+
+
 class TestModel(BoringModel):
-    def on_pretrain_routine_end(self) -> None:
-        with mock.patch("pytorch_lightning.loggers.base.LightningLoggerBase.agg_and_log_metrics") as m:
-            self.trainer.logger_connector.log_metrics({"a": 2})
-        logged_times = m.call_count
-        expected = int(self.trainer.is_global_zero)
-        msg = f"actual logger called from non-global zero, logged_times: {logged_times}, expected: {expected}"
-        assert logged_times == expected, msg
+    log_name = "rank-{rank}"
+
+    def on_train_start(self):
+        self.log(self.log_name.format(rank=self.local_rank), 0)
+
+    def on_train_end(self):
+        assert self.log_name.format(rank=self.local_rank) in self.logger.logs, "Expected rank to be logged"
 
 
 @RunIf(skip_windows=True)
-def test_global_zero_only_logging_ddp_cpu(tmpdir):
+def test_all_rank_logging_ddp_cpu(tmpdir):
     """
-    Makes sure logging only happens from root zero
+    Check that all ranks can be logged from
     """
     model = TestModel()
-    model.training_epoch_end = None
+    all_rank_logger = AllRankLogger()
     trainer = Trainer(
         accelerator="ddp_cpu",
         num_processes=2,
@@ -45,16 +75,19 @@ def test_global_zero_only_logging_ddp_cpu(tmpdir):
         limit_val_batches=1,
         max_epochs=1,
         weights_summary=None,
+        logger=all_rank_logger,
+        log_every_n_steps=1,
     )
     trainer.fit(model)
 
 
 @RunIf(min_gpus=2)
-def test_global_zero_only_logging_ddp_spawn(tmpdir):
+def test_all_rank_logging_ddp_spawn(tmpdir):
     """
-    Makes sure logging only happens from root zero
+    Check that all ranks can be logged from
     """
     model = TestModel()
+    all_rank_logger = AllRankLogger()
     model.training_epoch_end = None
     trainer = Trainer(
         accelerator="ddp_spawn",
@@ -63,6 +96,7 @@ def test_global_zero_only_logging_ddp_spawn(tmpdir):
         limit_train_batches=1,
         limit_val_batches=1,
         max_epochs=1,
+        logger=all_rank_logger,
         weights_summary=None,
     )
     trainer.fit(model)
