
Commit 9c8701f

tchaton, Borda, ananthsub, SeanNaren, and williamFalcon authored
[feat] Logging refactor 2/n - train (#4495)
* update logging
* solve more bugs
* replace Mapping by Dict
* update on comments
* resolve pep8
* Apply suggestions from code review
  Co-authored-by: ananthsub <[email protected]>
* Update pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
  Co-authored-by: Jirka Borovec <[email protected]>
* update on comments
* typo
* update for coverage
* update test
* update
* Update tests/models/test_hooks.py
  Co-authored-by: Sean Naren <[email protected]>
* Update tests/models/test_hooks.py
  Co-authored-by: Sean Naren <[email protected]>
* update on comments
* remove deepcopy
* remove useless look for
* another small optim
* extra optim
* remove lastest optim, can be source of bug
* resolve bug
* add docstring
* optimize coverage
* Update pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
  Co-authored-by: Jirka Borovec <[email protected]>
* Update pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
  Co-authored-by: Jirka Borovec <[email protected]>
* Update tests/trainer/logging_tests/test_distributed_logging.py
  Co-authored-by: Jirka Borovec <[email protected]>
* Update pytorch_lightning/trainer/evaluation_loop.py
  Co-authored-by: Jirka Borovec <[email protected]>
* Update tests/trainer/logging/test_logger_connector.py
  Co-authored-by: Jirka Borovec <[email protected]>
* Update tests/trainer/logging_tests/test_train_loop_logging_1_0.py
  Co-authored-by: Jirka Borovec <[email protected]>
* update on comments
* update
* update on comments
* update parity speed
* get it down to 0.65
* update
* 0.8 max_dif

Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: ananthsub <[email protected]>
Co-authored-by: Sean Naren <[email protected]>
Co-authored-by: William Falcon <[email protected]>
1 parent 62ea461 commit 9c8701f

File tree

15 files changed (+733, -257 lines)


benchmarks/test_parity.py

Lines changed: 1 addition & 1 deletion
@@ -11,7 +11,7 @@

 @pytest.mark.parametrize('cls_model,max_diff', [
     (ParityModuleRNN, 0.05),
-    (ParityModuleMNIST, 0.70)
+    (ParityModuleMNIST, 0.8)
 ])
 @pytest.mark.skipif(not torch.cuda.is_available(), reason="test requires GPU machine")
 def test_pytorch_parity(tmpdir, cls_model, max_diff):
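The only change here is the MNIST parity budget, raised from 0.70 to 0.8. As a rough, hypothetical sketch of how such a max_diff threshold is typically consumed (the real benchmark body lives in benchmarks/test_parity.py and may differ):

# Hypothetical helper, not the benchmark's actual code: fail when the Lightning run
# is slower than the plain-PyTorch run by more than `max_diff` seconds.
def assert_parity(lightning_seconds: float, vanilla_seconds: float, max_diff: float) -> None:
    diff = lightning_seconds - vanilla_seconds
    assert diff < max_diff, f"Lightning overhead {diff:.3f}s exceeds allowed {max_diff}s"

assert_parity(lightning_seconds=10.4, vanilla_seconds=9.7, max_diff=0.8)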

pytorch_lightning/trainer/connectors/logger_connector/callback_hook_validator.py

Lines changed: 1 addition & 1 deletion
@@ -192,7 +192,7 @@ def _on_validation_start_log():
     @staticmethod
     def _on_validation_end_log():
         """Called when the validation loop ends."""
-        return {"on_step": [False], "on_epoch": [False, True]}
+        return None

     @staticmethod
     def _on_test_start_log():

pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py

Lines changed: 169 additions & 83 deletions
Large diffs are not rendered by default.

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 94 additions & 43 deletions
@@ -44,25 +44,14 @@ def __init__(self, trainer):
         self._callback_hook_validator = CallbackHookNameValidator()
         self._current_stage = None

-    def cached_results(self, stage_or_testing: Union[str, bool]) -> Union[EpochResultStore, None]:
-        """ Function to access cached_results using str or bool. Bool is used only for testing"""
-        stage_or_testing = str(stage_or_testing)
-        stages = self._stages
-        if stage_or_testing in self._stages:
-            return self._cached_results[stage_or_testing]
-        if stage_or_testing in LOOKUP_TABLE:
-            # Acces using trainer.testing
-            stage = LOOKUP_TABLE[stage_or_testing]
-            return self._cached_results[stage]
-        raise MisconfigurationException(
-            f"Provide stage_or_testing {stage_or_testing} doesn't belong either to {self._stages}"
-            f" or {LOOKUP_TABLE.keys()}"
-        )
+    @property
+    def cached_results(self) -> Union[EpochResultStore, None]:
+        return self._cached_results[self._current_stage]

     def set_stage(self, stage_or_testing: str, reset: bool = False) -> None:
         self._current_stage = self._determine_stage(stage_or_testing)
         if reset:
-            self.cached_results(stage_or_testing).reset()
+            self.cached_results.reset()

     def check_logging_in_callbacks(self, hook_fx_name, on_step: bool = None, on_epoch: bool = None) -> None:
         self._callback_hook_validator.check_logging_in_callbacks(current_hook_fx_name=hook_fx_name,
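The hunk above turns cached_results from a stage-keyed lookup method into a property resolved from the connector's current stage. A minimal, self-contained sketch of that pattern (the class and the list-valued stores are illustrative stand-ins, not the connector's real internals):

# Illustrative sketch of a stage-keyed results store exposed as a property.
class MiniLoggerConnector:
    def __init__(self):
        self._stages = ("train", "val", "test")
        self._cached_results = {stage: [] for stage in self._stages}  # placeholder stores
        self._current_stage = None

    def set_stage(self, stage: str, reset: bool = False) -> None:
        self._current_stage = stage
        if reset:
            self.cached_results.clear()

    @property
    def cached_results(self):
        # callers no longer pass a stage; the property resolves it from connector state
        return self._cached_results[self._current_stage]

connector = MiniLoggerConnector()
connector.set_stage("train")
connector.cached_results.append({"loss": 0.42})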
@@ -75,17 +64,17 @@ def on_evaluation_batch_start(self, testing, batch, dataloader_idx, num_dataload
         model._current_dataloader_idx = dataloader_idx if num_dataloaders > 1 else None

         # track batch_size
-        self.cached_results(testing)._batch_size = Result.extract_batch_size(batch)
+        self.cached_results._batch_size = Result.extract_batch_size(batch)

-    def on_batch_start(self, split_idx: int, opt_idx: int, split_batch) -> None:
-        self._cached_results["train"]._split_idx = split_idx
-        self._cached_results["train"]._opt_idx = opt_idx
-        self._cached_results["train"]._batch_size = Result.extract_batch_size(split_batch)
+    def on_train_split_start(self, split_idx: int, opt_idx: int, split_batch) -> None:
+        self.cached_results._split_idx = split_idx
+        self.cached_results._opt_idx = opt_idx
+        self.cached_results._batch_size = Result.extract_batch_size(split_batch)

     def on_train_batch_end(self) -> None:
-        self._cached_results["train"]._split_idx = None
-        self._cached_results["train"]._opt_idx = None
-        self._cached_results["train"]._batch_size = None
+        self.cached_results._split_idx = None
+        self.cached_results._opt_idx = None
+        self.cached_results._batch_size = None

     def _determine_stage(self, stage_or_testing: Union[str, bool]) -> str:
         stage_or_testing = str(stage_or_testing)
@@ -112,6 +101,16 @@ def on_trainer_init(self, logger, flush_logs_every_n_steps, log_every_n_steps):
         self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
         self.trainer.log_every_n_steps = log_every_n_steps

+    @property
+    def should_flush_logs(self):
+        should_flush = (self.trainer.global_step + 1) % self.trainer.flush_logs_every_n_steps == 0
+        return should_flush or self.trainer.should_stop
+
+    @property
+    def should_update_logs(self):
+        should_log_every_n_steps = (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0
+        return should_log_every_n_steps or self.trainer.should_stop
+
     def configure_logger(self, logger):
         if logger is True:
             version = os.environ.get('PL_EXP_VERSION', self.trainer.slurm_job_id)
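Both new properties reduce to the same modulus check on the global step, so the logging cadence can be reasoned about in isolation. A small standalone sketch of that arithmetic, using plain ints in place of trainer state:

# Which global steps emit logs for a given cadence? Mirrors the modulus used above.
def should_update(global_step: int, log_every_n_steps: int, should_stop: bool = False) -> bool:
    return (global_step + 1) % log_every_n_steps == 0 or should_stop

# With log_every_n_steps=50, steps 49, 99, 149, ... trigger logging.
print([step for step in range(200) if should_update(step, 50)])  # [49, 99, 149, 199]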
@@ -130,6 +129,53 @@ def configure_logger(self, logger):
         else:
             self.trainer.logger = logger

+    def cache_training_step_metrics(self, opt_closure_result):
+        """
+        This function is responsible to update
+        logger_connector internals metrics holder based for depreceated logging
+        """
+        using_results_obj = isinstance(opt_closure_result.training_step_output, Result)
+
+        # temporary dict to collect metrics
+        logged_metrics_tmp = {}
+        pbar_metrics_tmp = {}
+        callback_metrics_tmp = {}
+
+        if using_results_obj:
+            batch_log_metrics = opt_closure_result.training_step_output.get_batch_log_metrics(
+                include_forked_originals=False
+            )
+            logged_metrics_tmp.update(batch_log_metrics)
+
+            batch_pbar_metrics = opt_closure_result.training_step_output.get_batch_pbar_metrics(
+                include_forked_originals=False
+            )
+            pbar_metrics_tmp.update(batch_pbar_metrics)
+
+            forked_metrics = opt_closure_result.training_step_output.get_forked_metrics()
+            callback_metrics_tmp.update(forked_metrics)
+            callback_metrics_tmp.update(logged_metrics_tmp)
+
+        else:
+            batch_log_metrics = opt_closure_result.training_step_output.log_metrics
+            logged_metrics_tmp.update(batch_log_metrics)
+
+            callback_metrics = opt_closure_result.training_step_output.callback_metrics
+            callback_metrics_tmp.update(callback_metrics)
+
+            batch_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end
+            pbar_metrics_tmp.update(batch_pbar_metrics)
+
+        # track progress bar metrics
+        if len(pbar_metrics_tmp) > 0:
+            self.add_progress_bar_metrics(pbar_metrics_tmp)
+
+        self.callback_metrics.update(callback_metrics_tmp)
+
+        # save legacy log metrics
+        self.logged_metrics.update(logged_metrics_tmp)
+        self.cached_results.legacy_batch_log_metrics.update(logged_metrics_tmp)
+
     def log_metrics(self, metrics, grad_norm_dic, step=None):
         """Logs the metric dict passed in.
         If `step` parameter is None and `step` key is presented is metrics,
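The new cache_training_step_metrics method collects the step's logged, progress-bar, and callback metrics into temporary dicts and only then updates the connector's long-lived holders. A minimal, self-contained sketch of that collect-then-merge pattern (all names and the dict-shaped step output below are illustrative assumptions, not the connector's API):

# Collect batch metrics into temporaries, then update the long-lived holders once.
def cache_step_metrics(step_output: dict, logged: dict, pbar: dict, callbacks: dict) -> None:
    logged_tmp = dict(step_output.get("log", {}))
    pbar_tmp = dict(step_output.get("progress_bar", {}))
    callbacks_tmp = {**logged_tmp}

    if pbar_tmp:
        pbar.update(pbar_tmp)
    callbacks.update(callbacks_tmp)
    logged.update(logged_tmp)

logged, pbar, callbacks = {}, {}, {}
cache_step_metrics({"log": {"train_loss": 0.42}}, logged, pbar, callbacks)
print(logged, callbacks)  # {'train_loss': 0.42} {'train_loss': 0.42}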
@@ -396,8 +442,9 @@ def __process_eval_epoch_end_results_and_log_legacy(self, eval_results, test_mod
         if num_loaders == 1:
             self.__process_eval_epoch_end_results_and_log_legacy_update(prog_bar_metrics, log_metrics, callback_metrics)

-    def on_train_epoch_end(self, epoch_output):
-        pass
+    def on_train_epoch_end(self):
+        # inform cached logger connector epoch finished
+        self.cached_results.has_batch_loop_finished = True

     def log_train_epoch_end_metrics(self,
                                     epoch_output,
@@ -441,12 +488,10 @@ def log_train_epoch_end_metrics(self,
         # ------------------
         if is_1_0_result:
             # lightning module hook
-            epoch_end_log_result = self.training_epoch_end(model, epoch_output, num_optimizers)
+            self.training_epoch_end(model, epoch_output, num_optimizers)

             # log/aggregate metrics automatically
             epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output)
-            epoch_log_metrics.update(epoch_end_log_result.get_epoch_log_metrics())
-            epoch_progress_bar_metrics.update(epoch_end_log_result.get_epoch_pbar_metrics())

         # TODO: deprecate 1.0
         else:
@@ -459,6 +504,14 @@
             )
             epoch_log_metrics, epoch_progress_bar_metrics, epoch_callback_metrics = out

+        # it will perform reduction over epoch and return log metrics
+        cached_epoch_log_metrics = self.cached_results.get_epoch_log_metrics()
+        cached_epoch_pbar_metrics = self.cached_results.get_epoch_pbar_metrics()
+
+        # update
+        epoch_log_metrics.update(cached_epoch_log_metrics)
+        epoch_progress_bar_metrics.update(cached_epoch_pbar_metrics)
+
         # --------------------------
         # track results
         # --------------------------
@@ -475,15 +528,16 @@ def log_train_epoch_end_metrics(self,
             self.add_progress_bar_metrics(epoch_progress_bar_metrics)
             self.callback_metrics.update(epoch_progress_bar_metrics)

+        # reset epoch loop result for next epoch
+        self.cached_results.reset()
+
     def training_epoch_end(self, model, epoch_output, num_optimizers):
         if not is_overridden('training_epoch_end', model=model):
-            return Result()
+            return

         # run training_epoch_end
         # refresh the result for custom logging at the epoch level
         model._current_fx_name = 'training_epoch_end'
-        model._results = Result()
-
         epoch_output = self.__prepare_epoch_end_inputs(epoch_output)

         if num_optimizers == 1 or not self.trainer.train_loop.automatic_optimization:
@@ -492,15 +546,11 @@ def training_epoch_end(self, model, epoch_output, num_optimizers):
         # lightningmodule hook
         epoch_output = model.training_epoch_end(epoch_output)

-        model._current_fx_name = ''
-
         if epoch_output is not None:
             raise MisconfigurationException('training_epoch_end expects a return of None. '
                                             'HINT: remove the return statement in training_epoch_end')
-
-        # user can ALSO log at the end of an epoch
-        new_epoch_end_logs = model._results
-        return new_epoch_end_logs
+        # capture logging
+        self.trainer.logger_connector.cache_logged_metrics()

     def __run_legacy_training_epoch_end(
             self,
@@ -527,8 +577,12 @@ def __run_legacy_training_epoch_end(

         # run training_epoch_end
         # a list with a result per optimizer index
+        model._current_fx_name = 'training_epoch_end'
         epoch_output = model.training_epoch_end(epoch_output)

+        # capture logging
+        self.trainer.logger_connector.cache_logged_metrics()
+
         if isinstance(epoch_output, Result):
             epoch_log_metrics = epoch_output.epoch_log_metrics
             epoch_progress_bar_metrics = epoch_output.epoch_pbar_metrics
@@ -563,7 +617,7 @@ def __auto_reduce_results_on_epoch_end(self, epoch_output):
             # reduce across training steps
             opt_outputs = time_reduced_outputs[0].__class__.reduce_on_epoch_end(time_reduced_outputs)

-            # with manual opt need 1+ metrics because meta is always there
+            # with manual opt need 1 + metrics because meta is always there
             if opt_outputs.minimize is not None:
                 opt_outputs.minimize = opt_outputs.minimize.mean()
             epoch_log_metrics.update(opt_outputs.epoch_log_metrics)
@@ -623,12 +677,9 @@ def __gather_result_across_time_and_optimizers(self, epoch_output):

     def log_train_step_metrics(self, batch_output):
         # when metrics should be logged
-        should_log_metrics = (
-            (self.trainer.global_step + 1) % self.trainer.log_every_n_steps == 0 or self.trainer.should_stop
-        )
-        if should_log_metrics or self.trainer.fast_dev_run:
+        if self.should_update_logs or self.trainer.fast_dev_run:
             # logs user requested information to logger
-            metrics = batch_output.batch_log_metrics
+            metrics = self.cached_results.get_latest_batch_log_metrics()
             grad_norm_dic = batch_output.grad_norm_dic
             if len(metrics) > 0 or len(grad_norm_dic) > 0:
                 self.log_metrics(metrics, grad_norm_dic)

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 3 additions & 0 deletions
@@ -358,6 +358,9 @@ def __log_result_step_metrics(self, output, batch_idx):
         step_log_metrics = output.get_batch_log_metrics(include_forked_originals=False)
         step_pbar_metrics = output.get_batch_pbar_metrics(include_forked_originals=False)

+        cached_batch_log_metrics = \
+            self.trainer.logger_connector.cached_results.get_latest_batch_log_metrics()
+
         if len(step_log_metrics) > 0:
             # make the metrics appear as a different line in the same graph
             metrics_by_epoch = {}

pytorch_lightning/trainer/trainer.py

Lines changed: 23 additions & 1 deletion
@@ -838,7 +838,25 @@ def call_setup_hook(self, model):
         self.setup(stage_name)
         model.setup(stage_name)

+    def _reset_result_and_set_hook_fx_name(self, hook_name):
+        model_ref = self.get_model()
+        if model_ref is not None:
+            # used to track current hook name called
+            model_ref._results = Result()
+            model_ref._current_hook_fx_name = hook_name
+
+    def _cache_logged_metrics(self):
+        model_ref = self.get_model()
+        if model_ref is not None:
+            # capture logging for this hook
+            self.logger_connector.cache_logged_metrics()
+
     def call_hook(self, hook_name, *args, **kwargs):
+        # temporary. Don't modify evaluation behaviour
+        if self.logger_connector._current_stage == "train":
+            # set hook_name to model + reset Result obj
+            self._reset_result_and_set_hook_fx_name(hook_name)
+
         # always profile hooks
         with self.profiler.profile(hook_name):

@@ -860,4 +878,8 @@ def call_hook(self, hook_name, *args, **kwargs):
                 accelerator_hook = getattr(self.accelerator_backend, hook_name)
                 output = accelerator_hook(*args, **kwargs)

-        return output
+        # temporary. Don't modify evaluation behaviour
+        if self.logger_connector._current_stage == "train":
+            # capture logging
+            self._cache_logged_metrics()
+        return output
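Taken together, the two trainer.py hunks wrap every training-stage hook call in a reset-then-capture pair: a fresh Result and the current hook name are placed on the LightningModule before the hook runs, and whatever the hook logged is cached into the logger connector afterwards. A standalone sketch of that control flow (the class and helper names below are illustrative stand-ins, not the trainer's real API):

# Standalone sketch of the reset-then-capture pattern wrapped around a hook call.
class MiniModule:
    def __init__(self):
        self._results = {}
        self._current_hook_fx_name = None

    def on_train_batch_start(self, batch):
        self._results["batch_sum"] = sum(batch)     # pretend the hook logs something


def call_hook(module, cached, hook_name, *args, stage="train"):
    if stage == "train":
        module._results = {}                        # fresh Result stand-in
        module._current_hook_fx_name = hook_name    # track which hook is logging

    output = getattr(module, hook_name)(*args)

    if stage == "train":
        cached.update(module._results)              # capture whatever the hook logged
    return output


cache = {}
call_hook(MiniModule(), cache, "on_train_batch_start", [1, 2, 3])
print(cache)  # {'batch_sum': 6}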
