Commit 07d94ad

Merge branch 'master' into all-rank-logger

2 parents 00184af + efec3d4

8 files changed: +89 −19 lines

CHANGELOG.md
Lines changed: 2 additions & 0 deletions

@@ -121,6 +121,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed recursive call for `apply_to_collection(include_none=False)` ([#8719](https://github.com/PyTorchLightning/pytorch-lightning/pull/8719))


+- Fixed an issue with logger outputs not being finalized correctly after prediction runs ([#8333](https://github.com/PyTorchLightning/pytorch-lightning/issues/8333))
+

 ## [1.4.0] - 2021-07-27

pytorch_lightning/loops/dataloader/evaluation_loop.py
Lines changed: 0 additions & 5 deletions

@@ -20,7 +20,6 @@
 from pytorch_lightning.loops.dataloader import DataLoaderLoop
 from pytorch_lightning.loops.epoch import EvaluationEpochLoop
 from pytorch_lightning.trainer.connectors.logger_connector.result import ResultCollection
-from pytorch_lightning.trainer.states import TrainerFn
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.types import EPOCH_OUTPUT

@@ -206,10 +205,6 @@ def on_evaluation_end(self, *args: Any, **kwargs: Any) -> None:
         else:
             self.trainer.call_hook("on_validation_end", *args, **kwargs)

-        if self.trainer.state.fn != TrainerFn.FITTING:
-            # summarize profile results
-            self.trainer.profiler.describe()
-
         # reset any `torchmetrics.Metric` and the logger connector state
         self.trainer.logger_connector.reset(metrics=True)

pytorch_lightning/loops/dataloader/prediction_loop.py
Lines changed: 0 additions & 2 deletions

@@ -119,8 +119,6 @@ def on_predict_epoch_end(self) -> Optional[_PREDICT_OUTPUT]:
         Returns:
             the results for all dataloaders
         """
-        self.trainer.profiler.describe()
-
         results = self.predictions

         self.trainer.call_hook("on_predict_epoch_end", results)

pytorch_lightning/loops/fit_loop.py
Lines changed: 0 additions & 9 deletions

@@ -225,15 +225,6 @@ def on_run_end(self) -> None:
         # hook
         self.trainer.call_hook("on_train_end")

-        # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
-        # It might be related to xla tensors blocked when moving the cpu
-        # kill loggers
-        if self.trainer.logger is not None:
-            self.trainer.logger.finalize("success")
-
-        # summarize profile results
-        self.trainer.profiler.describe()
-
         # give accelerators a chance to finish
         self.trainer.accelerator.on_train_end()

pytorch_lightning/plugins/training_type/ddp_spawn.py
Lines changed: 3 additions & 0 deletions

@@ -203,6 +203,9 @@ def new_process(self, process_idx, trainer, mp_queue):
         # persist info in ddp_spawn
         self.transfer_distrib_spawn_state_on_fit_end(results)

+        # ensure that spawned processes go through teardown before joining
+        trainer._call_teardown_hook()
+
     def post_dispatch(self):
         # restore main state with best weights
         best_path = self.mp_queue.get()

pytorch_lightning/plugins/training_type/tpu_spawn.py
Lines changed: 3 additions & 0 deletions

@@ -172,6 +172,9 @@ def new_process(self, process_idx: int, trainer, mp_queue) -> None:
         if self.local_rank == 0:
             time.sleep(2)

+        # ensure that spawned processes go through teardown before joining
+        trainer._call_teardown_hook()
+
     @parameter_validation
     def model_to_device(self) -> None:
         self.model = self.wrapped_model.to(self.root_device)
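
Both spawn plugins now run `trainer._call_teardown_hook()` at the end of `new_process`, right before the worker returns and the parent joins it. A minimal, non-Lightning illustration of why the hook must run inside the spawned worker (a sketch assuming only `torch.multiprocessing`; the `worker` function is hypothetical): once a worker returns, the parent can only join it, so any per-process state such as logger buffers or profiler summaries has to be flushed before that point.

import torch.multiprocessing as mp


def worker(rank: int) -> None:
    # ... per-rank training / prediction work would happen here ...
    # Flush per-process state before returning; this is the point that corresponds
    # to calling `trainer._call_teardown_hook()` at the end of `new_process`.
    print(f"rank {rank}: finalizing before the parent joins this process")


if __name__ == "__main__":
    # The parent only waits for the workers to exit; after this point it has no
    # way to trigger cleanup inside them.
    mp.spawn(worker, nprocs=2, join=True)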

pytorch_lightning/trainer/trainer.py
Lines changed: 14 additions & 3 deletions

@@ -76,6 +76,7 @@
 )
 from pytorch_lightning.utilities.debugging import InternalDebugger
 from pytorch_lightning.utilities.distributed import distributed_available
+from pytorch_lightning.utilities.enums import DistributedType
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.imports import _fault_tolerant_enabled
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -944,8 +945,10 @@ def _run(self, model: "pl.LightningModule") -> Optional[Union[_EVALUATE_OUTPUT,
         if self.state.fn == TrainerFn.FITTING:
             self.call_hook("on_fit_end")

-        # teardown
-        self._call_teardown_hook()
+        # teardown if necessary (similar calls for spawn plugins are excluded as they have
+        # been included at the end of `new_process` functions)
+        if self._distrib_type not in DistributedType.interactive_compatible_types():
+            self._call_teardown_hook()

         if self.state.status != TrainerStatus.INTERRUPTED:
             self.state.status = TrainerStatus.FINISHED
@@ -1211,7 +1214,7 @@ def _call_teardown_hook(self) -> None:

         if self.datamodule is not None:
             self.datamodule.teardown(stage=fn)
-        self.profiler.teardown(stage=fn)
+
         self.teardown(stage=fn)
         self.lightning_module.teardown(stage=fn)

@@ -1220,6 +1223,14 @@ def _call_teardown_hook(self) -> None:
         # these could have become stale if metrics are defined in `setup`
         self.lightning_module._metric_attributes = None

+        # todo: TPU 8 cores hangs in flush with TensorBoard. Might do for all loggers.
+        # It might be related to xla tensors blocked when moving the cpu kill loggers.
+        if self.logger is not None:
+            self.logger.finalize("success")
+
+        # summarize profile results
+        self.profiler.describe()
+
     def call_hook(self, hook_name: str, *args, **kwargs) -> Any:
         if self.lightning_module:
             prev_fx_name = self.lightning_module._current_fx_name
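
The net effect of moving logger finalization and the profiler summary into `_call_teardown_hook` is that they now run for every entry point rather than only at the end of `fit`. A minimal sketch of the user-visible behaviour, assuming the API shown in this diff; `DictLogger` is an illustrative stand-in (not part of the commit) and `BoringModel` is the test helper used in the file below:

from typing import Any, Dict, Optional, Union

from pytorch_lightning import Trainer
from pytorch_lightning.loggers.base import LightningLoggerBase
from tests.helpers import BoringModel


class DictLogger(LightningLoggerBase):
    """Illustrative in-memory logger; `finalize` records that the run was flushed."""

    def __init__(self):
        super().__init__()
        self.metrics: Dict[str, float] = {}
        self.finalized_status: Optional[str] = None

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        self.metrics.update(metrics)

    def log_hyperparams(self, *args, **kwargs) -> None:
        pass

    def finalize(self, status: str) -> None:
        self.finalized_status = status

    @property
    def experiment(self) -> Any:
        return None

    @property
    def name(self) -> str:
        return "DictLogger"

    @property
    def version(self) -> Union[int, str]:
        return 0


logger = DictLogger()
trainer = Trainer(limit_predict_batches=1, logger=logger)
# Before this change, prediction runs never reached the logger finalization that
# lived in `fit_loop.on_run_end` (see issue #8333 in the CHANGELOG entry above).
trainer.predict(BoringModel())
assert logger.finalized_status == "success"  # finalize now fires via `_call_teardown_hook`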

tests/trainer/logging_/test_distributed_logging.py
Lines changed: 67 additions & 0 deletions

@@ -13,8 +13,10 @@
 # limitations under the License.
 import os
 from typing import Any, Dict, Optional, Union
+from unittest import mock
 from unittest.mock import Mock

+import pytorch_lightning as pl
 from pytorch_lightning import Callback, Trainer
 from pytorch_lightning.loggers.base import LightningLoggerBase
 from tests.helpers import BoringModel
@@ -136,3 +138,68 @@ def on_train_start(self, trainer, pl_module):
         callbacks=[LoggerCallsObserver()],
     )
     trainer.fit(model)
+
+
+def test_logger_after_fit_predict_test_calls(tmpdir):
+    """
+    Make sure logger outputs are finalized after fit, prediction, and test calls.
+    """
+
+    class BufferLogger(LightningLoggerBase):
+        def __init__(self):
+            super().__init__()
+            self.buffer = {}
+            self.logs = {}
+
+        def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
+            self.buffer.update(metrics)
+
+        def finalize(self, status: str) -> None:
+            self.logs.update(self.buffer)
+            self.buffer = {}
+
+        @property
+        def experiment(self) -> Any:
+            return None
+
+        @property
+        def version(self) -> Union[int, str]:
+            return 1
+
+        @property
+        def name(self) -> str:
+            return "BufferLogger"
+
+        def log_hyperparams(self, *args, **kwargs) -> None:
+            return None
+
+    class LoggerCallsObserver(Callback):
+        def on_fit_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+            trainer.logger.log_metrics({"fit": 1})
+
+        def on_validation_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+            trainer.logger.log_metrics({"validate": 1})
+
+        def on_predict_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+            trainer.logger.log_metrics({"predict": 1})
+
+        def on_test_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
+            trainer.logger.log_metrics({"test": 1})
+
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        max_epochs=1,
+        logger=BufferLogger(),
+        callbacks=[LoggerCallsObserver()],
+    )
+
+    assert not trainer.logger.logs
+    trainer.fit(model)
+    assert trainer.logger.logs == {"fit": 1, "validate": 1}
+    trainer.test(model)
+    assert trainer.logger.logs == {"fit": 1, "validate": 1, "test": 1}
+    trainer.predict(model)
+    assert trainer.logger.logs == {"fit": 1, "validate": 1, "test": 1, "predict": 1}
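
Assuming the repository's standard pytest setup, the new test can be run on its own, for example:

python -m pytest tests/trainer/logging_/test_distributed_logging.py::test_logger_after_fit_predict_test_calls -v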
