Commit 45143fd

Improve val step logging (#7351)
* Fix val step logging
* Add a type
* Fix
* Update CHANGELOG.md
1 parent f9e050c commit 45143fd

File tree

7 files changed: +51 -39 lines changed

- CHANGELOG.md
- pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
- pytorch_lightning/trainer/evaluation_loop.py
- pytorch_lightning/trainer/trainer.py
- tests/accelerators/test_multi_nodes_gpu.py
- tests/trainer/logging_/test_eval_loop_logging.py
- tests/trainer/logging_/test_train_loop_logging.py

CHANGELOG.md

Lines changed: 3 additions & 0 deletions

```diff
@@ -141,6 +141,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Ensure accelerator is valid if running interactively ([#5970](https://github.com/PyTorchLightning/pytorch-lightning/pull/5970))
 - Disabled batch transfer in DP mode ([#6098](https://github.com/PyTorchLightning/pytorch-lightning/pull/6098))

+- Changed the behaviour when logging evaluation step metrics to no longer append `/epoch_*` to the metric name ([#7351](https://github.com/PyTorchLightning/pytorch-lightning/pull/7351))
+
+
 ### Deprecated

 - Deprecated `outputs` in both `LightningModule.on_train_epoch_end` and `Callback.on_train_epoch_end` hooks ([#7339](https://github.com/PyTorchLightning/pytorch-lightning/pull/7339))
```
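In practice, a metric logged with `on_step=True` during validation used to be recorded under epoch-suffixed keys such as `b_step/epoch_0`; it is now recorded under a single `b_step` key whose step index keeps increasing across epochs. A minimal sketch of the user-visible effect (the model and data below are illustrative, not part of the commit):

```python
import torch
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer


class DemoModel(LightningModule):
    """Illustrative module that logs a step-level metric during validation."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def validation_step(self, batch, batch_idx):
        # on_step=True previously produced 'b_step/epoch_<n>' keys; after this
        # commit it produces a single 'b_step' key
        self.log('b', self.layer(batch).sum(), on_step=True, on_epoch=True)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


data = DataLoader(torch.randn(64, 32), batch_size=32)
trainer = Trainer(max_epochs=2, limit_train_batches=2, limit_val_batches=2,
                  num_sanity_val_steps=0)
trainer.fit(DemoModel(), data, data)

# before #7351: {'b_step/epoch_0', 'b_step/epoch_1', 'b_epoch', 'epoch', ...}
# after  #7351: {'b_step', 'b_epoch', 'epoch', ...}
print(set(trainer.logged_metrics))
```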

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 32 additions & 1 deletion

```diff
@@ -43,6 +43,8 @@ def __init__(self, trainer, log_gpu_memory: Optional[str] = None):
         self._cached_results = {stage: EpochResultStore(trainer) for stage in RunningStage}
         self._cached_results[None] = EpochResultStore(trainer)
         self._callback_hook_validator = CallbackHookNameValidator()
+        self._val_log_step: int = 0
+        self._test_log_step: int = 0

     @property
     def callback_metrics(self) -> Dict:
@@ -201,7 +203,8 @@ def log_metrics(self, metrics, grad_norm_dic, step=None):
         Args:
             metrics (dict): Metric values
             grad_norm_dic (dict): Gradient norms
-            step (int): Step for which metrics should be logged. Default value corresponds to `self.global_step`
+            step (int): Step for which metrics should be logged. Default value is `self.global_step` during training or
+                the total validation / test log step count during validation and testing.
         """
         # add gpu memory
         if self.trainer._device_type == DeviceType.GPU and self.log_gpu_memory:
@@ -371,3 +374,31 @@ def log_train_step_metrics(self, batch_output):
         if len(batch_log_metrics) > 0 or len(grad_norm_dic) > 0:
             self.log_metrics(batch_log_metrics, grad_norm_dic)
         self._callback_metrics.update(batch_log_metrics)
+
+    @property
+    def evaluation_log_step(self) -> Optional[int]:
+        if self.trainer.state.stage is RunningStage.VALIDATING:
+            return self._val_log_step
+        elif self.trainer.state.stage is RunningStage.TESTING:
+            return self._test_log_step
+        else:
+            return None
+
+    def increment_evaluation_log_step(self) -> None:
+        if self.trainer.state.stage is RunningStage.VALIDATING:
+            self._val_log_step += 1
+        elif self.trainer.state.stage is RunningStage.TESTING:
+            self._test_log_step += 1
+
+    def log_evaluation_step_metrics(self) -> None:
+        if self.trainer.sanity_checking:
+            return
+        _, batch_log_metrics = self.cached_results.update_logger_connector()
+
+        # logs user requested information to logger
+        if len(batch_log_metrics) > 0:
+            kwargs = dict() if "step" in batch_log_metrics else dict(step=self.evaluation_log_step)
+            self.log_metrics(batch_log_metrics, {}, **kwargs)
+
+        # increment the step even if nothing was logged
+        self.increment_evaluation_log_step()
```
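The connector now keeps one monotonically increasing log step per evaluation stage, so step-level metrics from successive epochs share a continuous x-axis. A standalone sketch of just the counter behaviour (the class below mirrors the new methods but is illustrative, not the Lightning API):

```python
class EvalLogStepCounter:
    """Illustrative stand-in for the per-stage counters added above."""

    def __init__(self) -> None:
        self._val_log_step = 0
        self._test_log_step = 0

    def next_step(self, stage: str) -> int:
        # return the current step for the stage, then advance it, mirroring
        # log_evaluation_step_metrics followed by increment_evaluation_log_step
        if stage == "validate":
            current, self._val_log_step = self._val_log_step, self._val_log_step + 1
        elif stage == "test":
            current, self._test_log_step = self._test_log_step, self._test_log_step + 1
        else:
            raise ValueError(f"unknown evaluation stage: {stage}")
        return current


counter = EvalLogStepCounter()
# two epochs of three validation batches: the step never resets
assert [counter.next_step("validate") for _ in range(6)] == [0, 1, 2, 3, 4, 5]
# the test stage has its own independent counter
assert counter.next_step("test") == 0
```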

pytorch_lightning/trainer/evaluation_loop.py

Lines changed: 0 additions & 19 deletions

```diff
@@ -267,22 +267,3 @@ def on_evaluation_epoch_end(self) -> None:
         self.trainer._cache_logged_metrics()

         self.trainer.call_hook('on_epoch_end')
-
-    def log_evaluation_step_metrics(self, batch_idx: int) -> None:
-        if self.trainer.sanity_checking:
-            return
-
-        cached_results = self.trainer.logger_connector.cached_results
-        if cached_results is not None:
-            cached_batch_pbar_metrics, cached_batch_log_metrics = cached_results.update_logger_connector()
-
-            if len(cached_batch_log_metrics) > 0:
-                # make the metrics appear as a different line in the same graph
-                metrics_by_epoch = {}
-                for k, v in cached_batch_log_metrics.items():
-                    metrics_by_epoch[f'{k}/epoch_{self.trainer.current_epoch}'] = v
-
-                self.trainer.logger_connector.log_metrics(metrics_by_epoch, {}, step=batch_idx)
-
-            if len(cached_batch_pbar_metrics) > 0:
-                self.trainer.logger_connector.add_progress_bar_metrics(cached_batch_pbar_metrics)
```
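The deleted helper rewrote each key as `{name}/epoch_{current_epoch}` and logged at `step=batch_idx`, so the step restarted at zero every epoch and each epoch became a separate series. A quick comparison of the (key, step) pairs each scheme emits for two epochs of three validation batches (the metric name is illustrative):

```python
# old scheme: per-epoch keys, step resets to the batch index each epoch
old = [(f"val_loss_step/epoch_{epoch}", batch_idx)
       for epoch in range(2) for batch_idx in range(3)]
# [('val_loss_step/epoch_0', 0), ..., ('val_loss_step/epoch_1', 2)]

# new scheme: one stable key, step keeps counting across epochs
new = [("val_loss_step", step) for step in range(6)]
# [('val_loss_step', 0), ..., ('val_loss_step', 5)]

assert len({key for key, _ in old}) == 2  # two separate series
assert len({key for key, _ in new}) == 1  # one continuous series
```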

pytorch_lightning/trainer/trainer.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -966,7 +966,7 @@ def run_evaluation(self, on_epoch: bool = False) -> _EVALUATE_OUTPUT:
                 self.evaluation_loop.on_evaluation_batch_end(output, batch, batch_idx, dataloader_idx)

                 # log batch metrics
-                self.evaluation_loop.log_evaluation_step_metrics(batch_idx)
+                self.logger_connector.log_evaluation_step_metrics()

                 # track epoch level outputs
                 dl_outputs = self.track_output_for_epoch_end(dl_outputs, output)
```

tests/accelerators/test_multi_nodes_gpu.py

Lines changed: 1 addition & 2 deletions

```diff
@@ -121,8 +121,7 @@ def backward(self, loss, optimizer, optimizer_idx):
         'a2',
         'a_step',
         'a_epoch',
-        'b_step/epoch_0',
-        'b_step/epoch_1',
+        'b_step',
         'b_epoch',
         'epoch',
     }
```

tests/trainer/logging_/test_eval_loop_logging.py

Lines changed: 13 additions & 15 deletions

```diff
@@ -76,16 +76,15 @@ def backward(self, loss, optimizer, optimizer_idx):
         'a2',
         'a_step',
         'a_epoch',
-        'b_step/epoch_0',
-        'b_step/epoch_1',
+        'b_step',
         'b_epoch',
         'epoch',
     }
     logged_metrics = set(trainer.logged_metrics.keys())
     assert expected_logged_metrics == logged_metrics

     # we don't want to enable val metrics during steps because it is not something that users should do
-    # on purpose DO NOT allow step_b... it's silly to monitor val step metrics
+    # on purpose DO NOT allow b_step... it's silly to monitor val step metrics
     callback_metrics = set(trainer.callback_metrics.keys())
     expected_cb_metrics = {'a', 'a2', 'b', 'a_epoch', 'b_epoch', 'a_step'}
     assert expected_cb_metrics == callback_metrics
@@ -145,8 +144,7 @@ def backward(self, loss, optimizer, optimizer_idx):
         'b_step',
         'b_epoch',
         'c',
-        'd_step/epoch_0',
-        'd_step/epoch_1',
+        'd_step',
         'd_epoch',
         'g',
     }
@@ -294,15 +292,15 @@ def validation_epoch_end(self, outputs) -> None:

     # make sure values are correct
     assert trainer.logged_metrics['val_loss_epoch'] == manual_mean
-    assert trainer.callback_metrics['val_loss'] == trainer.logged_metrics['val_loss_step/epoch_0']
+    assert trainer.callback_metrics['val_loss'] == trainer.logged_metrics['val_loss_step']

     # make sure correct values were logged
     logged_val = trainer.dev_debugger.logged_metrics

     # 3 val batches
-    assert logged_val[0]['val_loss_step/epoch_0'] == model.seen_vals[0]
-    assert logged_val[1]['val_loss_step/epoch_0'] == model.seen_vals[1]
-    assert logged_val[2]['val_loss_step/epoch_0'] == model.seen_vals[2]
+    assert logged_val[0]['val_loss_step'] == model.seen_vals[0]
+    assert logged_val[1]['val_loss_step'] == model.seen_vals[1]
+    assert logged_val[2]['val_loss_step'] == model.seen_vals[2]

     # epoch mean
     assert logged_val[3]['val_loss_epoch'] == model.manual_epoch_end_mean
@@ -872,29 +870,29 @@ def get_metrics_at_idx(idx):
         else:
             return mock_calls[idx][2]["metrics"]

-    expected = ['valid_loss_0_step/epoch_0', 'valid_loss_2/epoch_0', 'global_step']
+    expected = ['valid_loss_0_step', 'valid_loss_2', 'global_step']
     assert sorted(get_metrics_at_idx(1)) == sorted(expected)
     assert sorted(get_metrics_at_idx(2)) == sorted(expected)

     expected = model.val_losses[2]
-    assert get_metrics_at_idx(1)["valid_loss_0_step/epoch_0"] == expected
+    assert get_metrics_at_idx(1)["valid_loss_0_step"] == expected
     expected = model.val_losses[3]
-    assert get_metrics_at_idx(2)["valid_loss_0_step/epoch_0"] == expected
+    assert get_metrics_at_idx(2)["valid_loss_0_step"] == expected

     expected = ['valid_loss_0_epoch', 'valid_loss_1', 'epoch', 'global_step']
     assert sorted(get_metrics_at_idx(3)) == sorted(expected)

     expected = torch.stack(model.val_losses[2:4]).mean()
     assert get_metrics_at_idx(3)["valid_loss_1"] == expected
-    expected = ['valid_loss_0_step/epoch_1', 'valid_loss_2/epoch_1', 'global_step']
+    expected = ['valid_loss_0_step', 'valid_loss_2', 'global_step']

     assert sorted(get_metrics_at_idx(4)) == sorted(expected)
     assert sorted(get_metrics_at_idx(5)) == sorted(expected)

     expected = model.val_losses[4]
-    assert get_metrics_at_idx(4)["valid_loss_0_step/epoch_1"] == expected
+    assert get_metrics_at_idx(4)["valid_loss_0_step"] == expected
     expected = model.val_losses[5]
-    assert get_metrics_at_idx(5)["valid_loss_0_step/epoch_1"] == expected
+    assert get_metrics_at_idx(5)["valid_loss_0_step"] == expected

     expected = ['valid_loss_0_epoch', 'valid_loss_1', 'epoch', 'global_step']
     assert sorted(get_metrics_at_idx(6)) == sorted(expected)
```
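For readers updating similar tests, a hedged sketch of the mock-logger pattern these assertions use, asserting the absence of `/epoch_*` suffixes instead of exact call indices; the model and Trainer flags are illustrative and assume the Lightning version this commit targets:

```python
from unittest import mock

import torch
from torch.utils.data import DataLoader
from pytorch_lightning import LightningModule, Trainer
from pytorch_lightning.loggers import TensorBoardLogger


class ValStepModel(LightningModule):
    """Illustrative module logging a step-level validation metric."""

    def __init__(self):
        super().__init__()
        self.layer = torch.nn.Linear(32, 2)

    def training_step(self, batch, batch_idx):
        return self.layer(batch).sum()

    def validation_step(self, batch, batch_idx):
        self.log('valid_loss_0', self.layer(batch).sum(), on_step=True, on_epoch=True)

    def configure_optimizers(self):
        return torch.optim.SGD(self.parameters(), lr=0.1)


@mock.patch.object(TensorBoardLogger, 'log_metrics')
def test_no_epoch_suffix_on_val_step_keys(mock_log_metrics, tmpdir):
    data = DataLoader(torch.randn(64, 32), batch_size=32)
    trainer = Trainer(default_root_dir=tmpdir, max_epochs=2, limit_train_batches=1,
                      limit_val_batches=2, num_sanity_val_steps=0)
    trainer.fit(ValStepModel(), data, data)

    # collect every metric key passed to the (mocked) logger, whether the
    # metrics dict was passed positionally or as a keyword
    logged_keys = set()
    for _, args, kwargs in mock_log_metrics.mock_calls:
        metrics = kwargs.get('metrics', args[0] if args else {})
        logged_keys.update(metrics)

    assert 'valid_loss_0_step' in logged_keys
    assert not any('/epoch_' in key for key in logged_keys)
```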

tests/trainer/logging_/test_train_loop_logging.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -401,7 +401,7 @@ def val_dataloader(self):
     trainer.fit(model)

     generated = set(trainer.logger_connector.logged_metrics)
-    expected = {'a_step', 'a_epoch', 'n_step/epoch_0', 'n_epoch', 'epoch'}
+    expected = {'a_step', 'a_epoch', 'n_step', 'n_epoch', 'epoch'}

     assert generated == expected

```