Commit a332ee0

Merge branch 'all-rank-logger' of https://github.com/edward-io/pytorch-lightning into all-rank-logger

2 parents: 7dc09cf + 07d94ad

File tree: 8 files changed, +92 -20 lines

  CHANGELOG.md
  pytorch_lightning/loops/dataloader/evaluation_loop.py
  pytorch_lightning/loops/dataloader/prediction_loop.py
  pytorch_lightning/loops/fit_loop.py
  pytorch_lightning/plugins/training_type/ddp_spawn.py
  pytorch_lightning/plugins/training_type/tpu_spawn.py
  pytorch_lightning/trainer/trainer.py
  tests/trainer/logging_/test_distributed_logging.py

CHANGELOG.md

Lines changed: 3 additions & 1 deletion
@@ -121,6 +121,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed recursive call for `apply_to_collection(include_none=False)` ([#8719](https://github.com/PyTorchLightning/pytorch-lightning/pull/8719))


+- Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8333](https://github.com/PyTorchLightning/pytorch-lightning/issues/8333))
+

 ## [1.4.0] - 2021-07-27


@@ -2591,4 +2593,4 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 ## [0.2.x] - 2019-07-09

-## [0.1.x] - 2019-06-DD
+## [0.1.x] - 2019-06-DD
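To make that changelog entry concrete, here is a rough sketch of the call order the rest of this diff establishes (illustrative pseudocode written as comments, not code from the commit):

# After this commit, every Trainer entry point finishes through the teardown hook:
#
#   trainer.predict(model)
#       -> Trainer._run(...)
#       -> prediction loop
#       -> Trainer._call_teardown_hook()
#              -> self.logger.finalize("success")   # previously skipped for predict-only runs
#              -> self.profiler.describe()
#
# so a logger that buffers metrics until `finalize` is flushed even when
# neither fit() nor test() is called afterwards.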

pytorch_lightning/loops/dataloader/evaluation_loop.py

Lines changed: 0 additions & 5 deletions
@@ -20,7 +20,6 @@
 from pytorch_lightning.loops.dataloader import DataLoaderLoop
 from pytorch_lightning.loops.epoch import EvaluationEpochLoop
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
-from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT


@@ -206,10 +205,6 @@ def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
         else:
             self.trainer.call_hook("on_validation_end", *args, **kwargs)

-        if self.trainer.state.fn != TrainerFn.FITTING:
-            # summarize profile results
-            self.trainer.profiler.describe()
-
         # reset any `torchmetrics.Metric` and the logger connector state
         self.trainer.logger_connector.reset(metrics=True)

pytorch_lightning/loops/dataloader/prediction_loop.py

Lines changed: 0 additions & 2 deletions
@@ -119,8 +119,6 @@ def on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
         Returns:
             the results for all dataloaders
         """
-        self.trainer.profiler.describe()
-
         results = self.predictions

         self.trainer.call_hook("on_predict_epoch_end", results)

pytorch_lightning/loops/fit_loop.py

Lines changed: 0 additions & 9 deletions
@@ -225,15 +225,6 @@ def on_run_end(self) -> None:
         # hook
         self.trainer.call_hook("on_train_end")

-        # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
-        # It might be related to xla tensors blocked when moving the cpu
-        # kill loggers
-        if self.trainer.logger is not None:
-            self.trainer.logger.finalize("success")
-
-        # summarize profile results
-        self.trainer.profiler.describe()
-
         # give accelerators a chance to finish
         self.trainer.accelerator.on_train_end()

pytorch_lightning/plugins/training_type/ddp_spawn.py

Lines changed: 3 additions & 0 deletions
@@ -203,6 +203,9 @@ def new_process(self, process_idx, trainer, mp_queue):
         # persist info in ddp_spawn
         self.transfer_distrib_spawn_state_on_fit_end(results)

+        # ensure that spawned processes go through teardown before joining
+        trainer._call_teardown_hook()
+
     def post_dispatch(self):
         # restore main state with best weights
         best_path = self.mp_queue.get()

pytorch_lightning/plugins/training_type/tpu_spawn.py

Lines changed: 3 additions & 0 deletions
@@ -172,6 +172,9 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         if self.local_rank == 0:
             time.sleep(2)

+        # ensure that spawned processes go through teardown before joining
+        trainer._call_teardown_hook()
+
     @parameter_validation
     def model_to_device(self) -> None:
         self.model = self.wrapped_model.to(self.root_device)
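For background on why both spawn plugins now run the teardown hook inside `new_process`: each rank lives in a child process created by `torch.multiprocessing.spawn`, so per-rank state such as buffered logger output or profiler results has to be finalized inside the child before it returns, i.e. before the parent joins it, because the parent never sees that state. A minimal, self-contained sketch of that general pattern (not the plugin code; `_worker` and its printouts are made up for illustration):

import torch.multiprocessing as mp


def _worker(rank: int, world_size: int) -> None:
    # stand-in for `new_process`: run the stage, persist results for rank 0,
    # then tear down *in this child* -- its buffers are gone once it exits
    buffered = {"rank": rank, "loss": 0.1 * rank}  # e.g. un-flushed logger metrics
    print(f"[rank {rank}/{world_size}] finalizing before join: {buffered}")


if __name__ == "__main__":
    # join=True blocks until every child has returned from `_worker`,
    # i.e. only after each child has already executed its own teardown
    mp.spawn(_worker, args=(2,), nprocs=2, join=True)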

pytorch_lightning/trainer/trainer.py

Lines changed: 14 additions & 3 deletions
@@ -76,6 +76,7 @@
 )
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.distributed import distributed_available
+from pytorch_lightning.utilities.enums import DistributedType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
 from pytorch_lightning.utilities.model_helpers import is_overridden

@@ -944,8 +945,10 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
         if self.state.fn == TrainerFn.FITTING:
             self.call_hook("on_fit_end")

-        # teardown
-        self._call_teardown_hook()
+        # teardown if necessary (similar calls for spawn plugins are excluded as they have
+        # been included at the end of `new_process` functions)
+        if self._distrib_type not in DistributedType.interactive_compatible_types():
+            self._call_teardown_hook()

         if self.state.status != TrainerStatus.INTERRUPTED:
             self.state.status = TrainerStatus.FINISHED

@@ -1211,7 +1214,7 @@ def _call_teardown_hook(self) -> None:

         if self.datamodule is not None:
             self.datamodule.teardown(stage=fn)
-        self.profiler.teardown(stage=fn)
+
         self.teardown(stage=fn)
         self.lightning_module.teardown(stage=fn)

@@ -1220,6 +1223,14 @@ def _call_teardown_hook(self) -> None:
         # these could have become stale if metrics are defined in `setup`
         self.lightning_module._metric_attributes = None

+        # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
+        # It might be related to xla tensors blocked when moving the cpu kill loggers.
+        if self.logger is not None:
+            self.logger.finalize("success")
+
+        # summarize profile results
+        self.profiler.describe()
+
     def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
         if self.lightning_module:
             prev_fx_name = self.lightning_module._current_fx_name
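The new guard in `_run` keys off `DistributedType.interactive_compatible_types()`, which this commit uses as a proxy for the spawn-style strategies that now tear down inside `new_process`. A small sketch of that intent (assumes the PyTorch Lightning version at this commit is installed; the loop and print are illustrative only):

from pytorch_lightning.utilities.enums import DistributedType

# spawn-style strategies finalize loggers/profilers in the spawned child process;
# everything else still goes through the teardown call in `Trainer._run`
for strategy in (DistributedType.DDP, DistributedType.DDP_SPAWN):
    in_child = strategy in DistributedType.interactive_compatible_types()
    where = "spawned child (`new_process`)" if in_child else "`Trainer._run`"
    print(f"{strategy}: teardown hook runs in {where}")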

tests/trainer/logging_/test_distributed_logging.py

Lines changed: 69 additions & 0 deletions
@@ -15,6 +15,7 @@
 from typing import Any, Dict, Optional, Union
 from unittest.mock import Mock

+import pytorch_lightning as pl
 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.loggers.base import LightningLoggerBase
 from tests.helpers import BoringModel

@@ -32,15 +33,18 @@ def __init__(self):
         self.logs = {}
         self.exp = object()

+    @property
     def experiment(self) -> Any:
         return self.exp

     def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
         self.logs.update(metrics)

+    @property
     def version(self) -> Union[int, str]:
         return 1

+    @property
     def name(self) -> str:
         return "AllRank"

@@ -133,3 +137,68 @@ def on_train_start(self, trainer, pl_module):
         callbacks=[LoggerCallsObserver()],
     )
     trainer.fit(model)
+
+
+def test_logger_after_fit_predict_test_calls(tmpdir):
+    """
+    Make sure logger outputs are finalized after fit, prediction, and test calls.
+    """
+
+    class BufferLogger(LightningLoggerBase):
+        def __init__(self):
+            super().__init__()
+            self.buffer = {}
+            self.logs = {}
+
+        def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
+            self.buffer.update(metrics)
+
+        def finalize(self, status: str) -> None:
+            self.logs.update(self.buffer)
+            self.buffer = {}
+
+        @property
+        def experiment(self) -> Any:
+            return None
+
+        @property
+        def version(self) -> Union[int, str]:
+            return 1
+
+        @property
+        def name(self) -> str:
+            return "BufferLogger"
+
+        def log_hyperparams(self, *args, **kwargs) -> None:
+            return None
+
+    class LoggerCallsObserver(Callback):
+        def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+            trainer.logger.log_metrics({"fit": 1})
+
+        def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+            trainer.logger.log_metrics({"validate": 1})
+
+        def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+            trainer.logger.log_metrics({"predict": 1})
+
+        def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+            trainer.logger.log_metrics({"test": 1})
+
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        max_epochs=1,
+        logger=BufferLogger(),
+        callbacks=[LoggerCallsObserver()],
+    )
+
+    assert not trainer.logger.logs
+    trainer.fit(model)
+    assert trainer.logger.logs == {"fit": 1, "validate": 1}
+    trainer.test(model)
+    assert trainer.logger.logs == {"fit": 1, "validate": 1, "test": 1}
+    trainer.predict(model)
+    assert trainer.logger.logs == {"fit": 1, "validate": 1, "test": 1, "predict": 1}
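The new test can be run on its own from the repository root with pytest's node-id syntax, e.g. `python -m pytest tests/trainer/logging_/test_distributed_logging.py::test_logger_after_fit_predict_test_calls -v`.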
