
Commit 8fe5dd4

awaelchli authored and lexierule committed
fix logger creating directory structure too early in DDP (#6380)

* fix
* add simple test
* fix imports
* add changelog
* tighter test with on_fit_start hook closer to the dispatch call
* move class inside test function
* add a comment

(cherry picked from commit fc6d402)
1 parent 0631827 commit 8fe5dd4
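In short: before this fix, `Trainer.fit()` wrote the first experiment logs (hyper-parameters, graph) in the launching process, before the DDP worker processes existed and ranks were assigned, so the logger created its directory structure too early. Below is a minimal, self-contained sketch of a plausible failure mode, assuming a TensorBoardLogger-style auto-incrementing version scheme; `next_version` and `create_experiment` are illustrative names, not Lightning APIs.

import os
import tempfile


def next_version(root):
    # Pick the next version by counting existing version_* directories.
    return len([d for d in os.listdir(root) if d.startswith("version_")])


def create_experiment(root):
    # The logger's first filesystem access creates the directory structure.
    path = os.path.join(root, f"version_{next_version(root)}")
    os.makedirs(path)
    return path


if __name__ == "__main__":
    root = tempfile.mkdtemp()
    # Pre-fix, each process could effectively do this before ranks were known...
    for _ in range(4):  # stand-in for four DDP worker processes
        print(create_experiment(root))
    # ...yielding version_0 .. version_3 for what should be a single run.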

File tree

3 files changed (+176, -20 lines):

- CHANGELOG.md
- pytorch_lightning/trainer/trainer.py
- tests/trainer/logging_/test_distributed_logging.py


CHANGELOG.md

Lines changed: 132 additions & 3 deletions
@@ -5,6 +5,138 @@ All notable changes to this project will be documented in this file.
 The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 
+## [UnReleased] - 2021-MM-DD
+
+### Added
+
+- Added a way to print to terminal without breaking up the progress bar ([#5470](https://github.com/PyTorchLightning/pytorch-lightning/pull/5470))
+
+- Added `checkpoint` parameter to callback's `on_save_checkpoint` hook ([#6072](https://github.com/PyTorchLightning/pytorch-lightning/pull/6072))
+
+- Added `RunningStage.SANITY_CHECKING` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
+
+- Added `TrainerState.{FITTING,VALIDATING,TESTING,PREDICTING,TUNING}` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
+
+- Added `LightningEnvironment` for Lightning-specific DDP ([#5915](https://github.com/PyTorchLightning/pytorch-lightning/pull/5915))
+
+- Added arg to `self.log` that enables users to give custom names when dealing with multiple dataloaders ([#6274](https://github.com/PyTorchLightning/pytorch-lightning/pull/6274))
+
+- Added no return warning to predict ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139))
+
+### Changed
+
+- Renamed `pytorch_lightning.callbacks.swa` to `pytorch_lightning.callbacks.stochastic_weight_avg` ([#6259](https://github.com/PyTorchLightning/pytorch-lightning/pull/6259))
+
+- Refactor `RunningStage` and `TrainerState` usage ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
+
+- Changed `trainer.evaluating` to return `True` if validating or testing ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
+
+- Changed `setup()` and `teardown()` stage argument to take any of `{fit,validate,test,predict}` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386))
+
+### Deprecated
+
+- Deprecated `trainer.running_sanity_check` in favor of `trainer.sanity_checking` ([#4945](https://github.com/PyTorchLightning/pytorch-lightning/pull/4945))
+
+### Removed
+
+- Removed support for passing a bool value to `profiler` argument of Trainer ([#6164](https://github.com/PyTorchLightning/pytorch-lightning/pull/6164))
+
+- Removed no return warning from val/test step ([#6139](https://github.com/PyTorchLightning/pytorch-lightning/pull/6139))
+
+- Removed passing a `ModelCheckpoint` instance to `Trainer(checkpoint_callback)` ([#6166](https://github.com/PyTorchLightning/pytorch-lightning/pull/6166))
+
+- Removed deprecated Trainer argument `enable_pl_optimizer` and `automatic_optimization` ([#6163](https://github.com/PyTorchLightning/pytorch-lightning/pull/6163))
+
+- Removed deprecated metrics ([#6161](https://github.com/PyTorchLightning/pytorch-lightning/pull/6161))
+    * from `pytorch_lightning.metrics.functional.classification` removed `to_onehot`, `to_categorical`, `get_num_classes`, `roc`, `multiclass_roc`, `average_precision`, `precision_recall_curve`, `multiclass_precision_recall_curve`
+    * from `pytorch_lightning.metrics.functional.reduction` removed `reduce`, `class_reduce`
+
+- Removed deprecated `ModelCheckpoint` arguments `prefix`, `mode="auto"` ([#6162](https://github.com/PyTorchLightning/pytorch-lightning/pull/6162))
+
+- Removed `mode='auto'` from `EarlyStopping` ([#6167](https://github.com/PyTorchLightning/pytorch-lightning/pull/6167))
+
+- Removed deprecated `LightningModule` `hparams` setter ([#6207](https://github.com/PyTorchLightning/pytorch-lightning/pull/6207))
+
+- Removed `optimizer_idx` argument from `training_step` in manual optimization ([#6093](https://github.com/PyTorchLightning/pytorch-lightning/pull/6093))
+
+### Fixed
+
+- Made the `Plugin.reduce` method more consistent across all Plugins to reflect a mean-reduction by default ([#6011](https://github.com/PyTorchLightning/pytorch-lightning/pull/6011))
+
+- Move lightning module to correct device type when using LightningDistributedWrapper ([#6070](https://github.com/PyTorchLightning/pytorch-lightning/pull/6070))
+
+- Do not print top-k verbose log with `ModelCheckpoint(monitor=None)` ([#6109](https://github.com/PyTorchLightning/pytorch-lightning/pull/6109))
+
+- Fixed `ModelCheckpoint(monitor=None, save_last=True)` not saving checkpoints ([#6136](https://github.com/PyTorchLightning/pytorch-lightning/pull/6136))
+
+- Fixed `ModelCheckpoint(save_top_k=0, save_last=True)` not saving the `last` checkpoint ([#6136](https://github.com/PyTorchLightning/pytorch-lightning/pull/6136))
+
+- Expose DeepSpeed loss parameters to allow users to fix loss instability ([#6115](https://github.com/PyTorchLightning/pytorch-lightning/pull/6115))
+
+- Fixed `AttributeError` when `logger=None` on TPU ([#6221](https://github.com/PyTorchLightning/pytorch-lightning/pull/6221))
+
+- Fixed `ModelPruning(make_pruning_permanent=True)` pruning buffers getting removed when saved during training ([#6073](https://github.com/PyTorchLightning/pytorch-lightning/pull/6073))
+
+- Fixed `trainer.test` from `best_path` hangs after calling `trainer.fit` ([#6272](https://github.com/PyTorchLightning/pytorch-lightning/pull/6272))
+
+- Fixed duplicate logs appearing in console when using the python logging module ([#5509](https://github.com/PyTorchLightning/pytorch-lightning/pull/5509), [#6275](https://github.com/PyTorchLightning/pytorch-lightning/pull/6275))
+
+- Fixed `SingleTPU` calling `all_gather` ([#6296](https://github.com/PyTorchLightning/pytorch-lightning/pull/6296))
+
+- Fixed DP reduction with collection ([#6324](https://github.com/PyTorchLightning/pytorch-lightning/pull/6324))
+
+- Fixed `.teardown(stage='fit')` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386))
+
+- Fixed `.on_fit_{start,end}()` getting called during `trainer.test` ([#6386](https://github.com/PyTorchLightning/pytorch-lightning/pull/6386))
+
+- Fixed PyTorch Profiler with `emit_nvtx` ([#6260](https://github.com/PyTorchLightning/pytorch-lightning/pull/6260))
+
+- Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372))
+
+- Fixed an issue where the tuner would not tune the learning rate if also tuning the batch size ([#4688](https://github.com/PyTorchLightning/pytorch-lightning/pull/4688))
+
+- Fixed logger creating directory structure too early in DDP ([#6380](https://github.com/PyTorchLightning/pytorch-lightning/pull/6380))
+
+
 ## [1.2.3] - 2021-03-09
 
 
@@ -23,9 +155,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed `Trainer` not resetting `lightning_optimizers` when calling `Trainer.fit()` multiple times ([#6372](https://github.com/PyTorchLightning/pytorch-lightning/pull/6372))
 
 
-- Fixed an issue where the tuner would not tune the learning rate if also tuning the batch size ([#4688](https://github.com/PyTorchLightning/pytorch-lightning/pull/4688))
-
-
 ## [1.2.2] - 2021-03-02
 
 ### Added

pytorch_lightning/trainer/trainer.py

Lines changed: 8 additions & 17 deletions
@@ -407,21 +407,6 @@ def __init__(
         # Callback system
         self.on_init_end()
 
-    def setup_trainer(self, model: LightningModule):
-        """
-        Sanity check a few things before starting actual training or testing.
-
-        Args:
-            model: The model to run sanity test on.
-        """
-
-        # log hyper-parameters
-        if self.logger is not None:
-            # save exp to get started (this is where the first experiment logs are written)
-            self.logger.log_hyperparams(model.hparams_initial)
-            self.logger.log_graph(model)
-            self.logger.save()
-
     def fit(
         self,
         model: LightningModule,

@@ -471,8 +456,7 @@ def fit(
         # ----------------------------
         self.call_setup_hook(model)
         self.call_hook("on_before_accelerator_backend_setup", model)
-        self.accelerator.setup(self, model)
-        self.setup_trainer(model)
+        self.accelerator.setup(self, model)  # note: this sets up self.lightning_module
 
         # ----------------------------
         # INSPECT THE CORE LOOPS

@@ -539,6 +523,13 @@ def fit(
     def pre_dispatch(self):
         self.accelerator.pre_dispatch()
 
+        # log hyper-parameters
+        if self.logger is not None:
+            # save exp to get started (this is where the first experiment logs are written)
+            self.logger.log_hyperparams(self.lightning_module.hparams_initial)
+            self.logger.log_graph(self.lightning_module)
+            self.logger.save()
+
     def post_dispatch(self):
         self.accelerator.post_dispatch()
         self.accelerator.teardown()
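The net effect of the trainer.py change, restated: `setup_trainer()` is deleted and its logger calls move into `pre_dispatch()`, which runs after `accelerator.setup()` has attached `self.lightning_module` and after the worker processes are in place. Below is a condensed, self-contained sketch of that ordering; the stub classes are illustrative, not Lightning's actual API.

class StubAccelerator:
    def setup(self, trainer, model):
        trainer.lightning_module = model  # mirrors the "note" in the diff

    def pre_dispatch(self):
        pass  # in real DDP, processes exist and ranks are assigned by now


class StubLogger:
    def log_hyperparams(self, hparams):
        print("log_hyperparams", hparams)

    def log_graph(self, module):
        print("log_graph", type(module).__name__)

    def save(self):
        print("save")


class StubModel:
    hparams_initial = {"lr": 0.1}


class SketchTrainer:
    def __init__(self, logger, accelerator):
        self.logger = logger
        self.accelerator = accelerator
        self.lightning_module = None

    def fit(self, model):
        self.accelerator.setup(self, model)
        # pre-fix, the first logger writes happened here: too early
        self.pre_dispatch()

    def pre_dispatch(self):
        self.accelerator.pre_dispatch()
        # post-fix, the first experiment logs are written here
        if self.logger is not None:
            self.logger.log_hyperparams(self.lightning_module.hparams_initial)
            self.logger.log_graph(self.lightning_module)
            self.logger.save()


SketchTrainer(StubLogger(), StubAccelerator()).fit(StubModel())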

tests/trainer/logging_/test_distributed_logging.py

Lines changed: 36 additions & 0 deletions
@@ -69,3 +69,39 @@ def test_global_zero_only_logging_ddp_spawn(tmpdir):
         weights_summary=None,
     )
     trainer.fit(model)
+
+
+def test_first_logger_call_in_subprocess(tmpdir):
+    """
+    Test that the Trainer does not call the logger too early. Only when the worker processes are initialized
+    do we have access to the rank and know which one is the main process.
+    """
+
+    class LoggerCallsObserver(Callback):
+
+        def on_fit_start(self, trainer, pl_module):
+            # this hook is executed directly before Trainer.pre_dispatch
+            # logger should not write any logs until this point
+            assert not trainer.logger.method_calls
+            assert not os.listdir(trainer.logger.save_dir)
+
+        def on_train_start(self, trainer, pl_module):
+            assert trainer.logger.method_calls
+            trainer.logger.log_hyperparams.assert_called_once()
+            trainer.logger.log_graph.assert_called_once()
+
+    logger = Mock()
+    logger.version = "0"
+    logger.name = "name"
+    logger.save_dir = tmpdir
+
+    model = BoringModel()
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        limit_train_batches=1,
+        limit_val_batches=1,
+        max_epochs=1,
+        logger=logger,
+        callbacks=[LoggerCallsObserver()]
+    )
+    trainer.fit(model)
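Note: the added test relies on names imported at the top of the file, which the diff does not show (the "fix imports" bullet in the commit message suggests the header was touched). A plausible import block, stated as an assumption rather than the commit's actual header:

import os
from unittest.mock import Mock

from pytorch_lightning import Callback, Trainer
from tests.helpers import BoringModel  # assumption: helper location in this era of the repo

The test can then be run on its own, e.g. with `pytest tests/trainer/logging_/test_distributed_logging.py::test_first_logger_call_in_subprocess`.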
