Commit 61b7fd5

carmocca authored and lexierule committed
Do not add return dict items to callback_metrics (#6682)
1 parent f6d5782 commit 61b7fd5

File tree

18 files changed (+101 additions, -342 deletions)


CHANGELOG.md

Lines changed: 4 additions & 0 deletions
@@ -10,6 +10,10 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

 - Changed the behavior of `on_epoch_start` to run at the beginning of validation & test epoch ([#6498](https://github.com/PyTorchLightning/pytorch-lightning/pull/6498))

+### Removed
+
+- Removed legacy code to include `step` dictionary returns in `callback_metrics`. Use `self.log_dict` instead. ([#6682](https://github.com/PyTorchLightning/pytorch-lightning/pull/6682))
+
 ### Fixed

 - Fixed `DummyLogger.log_hyperparams` raising a `TypeError` when running with `fast_dev_run=True` ([#6398](https://github.com/PyTorchLightning/pytorch-lightning/pull/6398))
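
Note: in practice this removal means that extra keys returned from `training_step` / `*_epoch_end` dictionaries (including the old `'log'` key) no longer land in `callback_metrics`; only `self.log` / `self.log_dict` populate it. A minimal before/after sketch (the loss computation is an illustrative placeholder):

    # Before (legacy behaviour removed by this commit): dict returns fed callback_metrics
    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical helper
        return {'loss': loss, 'log': {'train_loss': loss}}

    # After: log explicitly and return just the loss
    def training_step(self, batch, batch_idx):
        loss = self.compute_loss(batch)  # hypothetical helper
        self.log_dict({'train_loss': loss})
        return loss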

docs/source/ecosystem/asr_nlp_tts.rst

Lines changed: 5 additions & 5 deletions
@@ -270,12 +270,12 @@ with PyTorch Lightning since every NeMo model is a Lightning Module.
             log_probs=log_probs, targets=transcript, input_lengths=encoded_len, target_lengths=transcript_len
         )
         wer_num, wer_denom = self._wer(predictions, transcript, transcript_len)
-        tensorboard_logs = {
+        self.log_dict({
             'train_loss': loss_value,
             'training_batch_wer': wer_num / wer_denom,
             'learning_rate': self._optimizer.param_groups[0]['lr'],
-        }
-        return {'loss': loss_value, 'log': tensorboard_logs}
+        })
+        return loss_value

 Neural Types in NeMo ASR
 ------------------------
@@ -539,8 +539,8 @@ since every NeMo model is a Lightning Module.
         logits = self(input_ids=input_ids, token_type_ids=input_type_ids, attention_mask=input_mask)

         loss = self.loss(logits=logits, labels=labels, loss_mask=loss_mask)
-        tensorboard_logs = {'train_loss': loss, 'lr': self._optimizer.param_groups[0]['lr']}
-        return {'loss': loss, 'log': tensorboard_logs}
+        self.log_dict({'train_loss': loss, 'lr': self._optimizer.param_groups[0]['lr']})
+        return loss
         ...

 Neural Types in NeMo NLP
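
As a usage note, `self.log_dict` accepts the same keyword arguments as `self.log`, so the destinations that the removed `'log'` / `'progress_bar'` return keys used to control can be selected explicitly. A hedged sketch with illustrative metric names:

    self.log_dict(
        {'train_loss': loss_value, 'learning_rate': self._optimizer.param_groups[0]['lr']},
        prog_bar=True,   # also show on the progress bar
        logger=True,     # send to the attached logger(s), e.g. TensorBoard
        on_step=True,
        on_epoch=False,
    )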

docs/source/ecosystem/bolts.rst

Lines changed: 2 additions & 2 deletions
@@ -68,8 +68,8 @@ you can trust the implementations and use them to bootstrap your research much f

         loss = self.criterion(logits.view(-1, logits.size(-1)), x.view(-1).long())

-        logs = {"loss": loss}
-        return {"loss": loss, "log": logs}
+        self.log("loss", loss)
+        return loss


 ----------
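
For reference, in Lightning releases around this commit `self.log` defaults to `on_step=True, on_epoch=False` inside `training_step` (and the reverse inside validation/test hooks), with `prog_bar=False` and `logger=True`, so the `self.log("loss", loss)` call in this diff logs the per-step loss to the logger only. A short sketch of overriding those defaults (names are illustrative):

    def training_step(self, batch, batch_idx):
        loss = self.criterion(self(batch[0]), batch[1])
        # log the per-step value, an epoch-level average, and show it on the progress bar
        self.log("loss", loss, on_step=True, on_epoch=True, prog_bar=True)
        return loss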

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 1 addition & 1 deletion
@@ -490,7 +490,7 @@ def _validate_monitor_key(self, trainer):
             m = (
                 f"ModelCheckpoint(monitor='{self.monitor}') not found in the returned metrics:"
                 f" {list(metrics.keys())}. "
-                f"HINT: Did you call self.log('{self.monitor}', tensor) in the LightningModule?"
+                f"HINT: Did you call self.log('{self.monitor}', value) in the LightningModule?"
             )
             raise MisconfigurationException(m)
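
The reworded hint reflects the underlying contract: the monitored key must be present in `callback_metrics`, which after this commit is populated only via `self.log` / `self.log_dict`. A minimal sketch with illustrative names:

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint

    class LitModel(pl.LightningModule):
        def validation_step(self, batch, batch_idx):
            loss = self._shared_eval(batch)  # hypothetical helper
            self.log('val_loss', loss)       # makes 'val_loss' visible to ModelCheckpoint
            return loss

    trainer = pl.Trainer(callbacks=[ModelCheckpoint(monitor='val_loss', mode='min')])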

pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py

Lines changed: 0 additions & 1 deletion
@@ -346,7 +346,6 @@ def update_logger_connector(self) -> Tuple[Dict, Dict]:

         # update callback_metrics
         logger_connector._callback_metrics.update(callback_metrics)
-        logger_connector._callback_metrics.pop("epoch", None)

         batch_pbar_metrics.pop("debug_epoch", None)
         return batch_pbar_metrics, batch_log_metrics

pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 17 additions & 50 deletions
@@ -78,7 +78,7 @@ def progress_bar_metrics(self, progress_bar_metrics: Dict) -> None:

     @property
     def cached_results(self) -> Union[EpochResultStore, None]:
-        return self._cached_results.get(self.trainer._running_stage)  # type: ignore
+        return self._cached_results.get(self.trainer._running_stage)

     def get_metrics(self, key: str) -> Dict:
         metrics_holder = getattr(self, f"_{key}", None)
@@ -125,8 +125,6 @@ def cache_logged_metrics(self):
     def on_trainer_init(self, logger, flush_logs_every_n_steps: int, log_every_n_steps: int, move_metrics_to_cpu: bool):
         # logging
         self.configure_logger(logger)
-        # todo: IDE is complaining, these shall be initialized in the Trainer init at leas as placeholders
-        # and assign here the desired value
         self.trainer.flush_logs_every_n_steps = flush_logs_every_n_steps
         self.trainer.log_every_n_steps = log_every_n_steps
         self.trainer.move_metrics_to_cpu = move_metrics_to_cpu
@@ -189,9 +187,6 @@ def cache_training_step_metrics(self, opt_closure_result):
         batch_log_metrics = opt_closure_result.training_step_output.log_metrics
         logged_metrics_tmp.update(batch_log_metrics)

-        callback_metrics = opt_closure_result.training_step_output.callback_metrics
-        callback_metrics_tmp.update(callback_metrics)
-
         batch_pbar_metrics = opt_closure_result.training_step_output.pbar_on_batch_end
         pbar_metrics_tmp.update(batch_pbar_metrics)

@@ -214,9 +209,6 @@ def log_metrics(self, metrics, grad_norm_dic, step=None):
             metrics (dict): Metric values
             grad_norm_dic (dict): Gradient norms
             step (int): Step for which metrics should be logged. Default value corresponds to `self.global_step`
-            log_train_step_metrics (bool): Used to track if `log_metrics` function is being called in during training
-                steps. In training steps, we will log metrics on step: `total_nb_idx` (for accumulated gradients)
-                and global_step for the rest.
         """
         # add gpu memory
         if self.trainer._device_type == DeviceType.GPU and self.log_gpu_memory:
@@ -350,27 +342,6 @@ def _track_callback_metrics(self, eval_results):
             if self.trainer.testing:
                 self.trainer.logger_connector.evaluation_callback_metrics.update(flat)

-    def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metrics, log_metrics, callback_metrics):
-        # eval loop returns all metrics
-        dataloader_result_metrics = {**prog_bar_metrics, **log_metrics, **callback_metrics}
-
-        # add metrics to prog bar
-        self.trainer.logger_connector.add_progress_bar_metrics(prog_bar_metrics)
-
-        # log metrics
-        if len(log_metrics) > 0:
-            self.trainer.logger_connector.log_metrics(log_metrics, {})
-
-        # track metrics for callbacks (all prog bar, logged and callback metrics)
-        callback_metrics.update(log_metrics)
-        callback_metrics.update(prog_bar_metrics)
-        self.trainer.logger_connector.callback_metrics.update(callback_metrics)
-        if self.trainer.testing:
-            self.trainer.logger_connector.evaluation_callback_metrics.update(callback_metrics)
-
-        if len(dataloader_result_metrics) > 0:
-            self.eval_loop_results.append(dataloader_result_metrics)
-
     def __process_eval_epoch_end_results_and_log_legacy(self, eval_results):
         if self.trainer.running_sanity_check:
             return
@@ -381,21 +352,21 @@ def __process_eval_epoch_end_results_and_log_legacy(self, eval_results):
         if not isinstance(eval_results, list):
             eval_results = [eval_results]

-        num_loaders: int = self.trainer.evaluation_loop.num_dataloaders
-        prog_bar_metrics, log_metrics, callback_metrics = {}, {}, {}
-
         for result_idx, result in enumerate(eval_results):
-            _, prog_bar_metrics, log_metrics, callback_metrics, _ = self.trainer.process_dict_result(result)
+            _, prog_bar_metrics, log_metrics, _ = self.trainer.process_dict_result(result)
+
+            # eval loop returns all metrics
+            dataloader_result_metrics = {**prog_bar_metrics, **log_metrics}
+
+            # add metrics to prog bar
+            self.trainer.logger_connector.add_progress_bar_metrics(prog_bar_metrics)

-            if num_loaders > 1:
-                self.__process_eval_epoch_end_results_and_log_legacy_update(
-                    prog_bar_metrics, log_metrics, callback_metrics
-                )
+            # log metrics
+            if len(log_metrics) > 0:
+                self.trainer.logger_connector.log_metrics(log_metrics, {})

-            if num_loaders == 1:
-                self.__process_eval_epoch_end_results_and_log_legacy_update(
-                    prog_bar_metrics, log_metrics, callback_metrics
-                )
+            if len(dataloader_result_metrics) > 0:
+                self.eval_loop_results.append(dataloader_result_metrics)

     def on_train_epoch_end(self):
         # inform cached logger connector epoch finished
@@ -448,10 +419,9 @@ def log_train_epoch_end_metrics(

         # TODO: deprecate 1.0
         else:
-            out = self.__run_legacy_training_epoch_end(
-                num_optimizers, epoch_output, model, is_result_obj, epoch_callback_metrics
+            epoch_log_metrics, epoch_progress_bar_metrics = self.__run_legacy_training_epoch_end(
+                num_optimizers, epoch_output, model, is_result_obj
             )
-            epoch_log_metrics, epoch_progress_bar_metrics, epoch_callback_metrics = out

         # it will perform reduction over epoch and return log metrics
         cached_epoch_log_metrics = self.cached_results.get_epoch_log_metrics()
@@ -503,9 +473,7 @@ def training_epoch_end(self, model, epoch_output, num_optimizers):
         # capture logging
         self.trainer.logger_connector.cache_logged_metrics()

-    def __run_legacy_training_epoch_end(
-        self, num_optimizers, epoch_output, model, is_result_obj, epoch_callback_metrics
-    ):
+    def __run_legacy_training_epoch_end(self, num_optimizers, epoch_output, model, is_result_obj):

         epoch_log_metrics = {}
         epoch_progress_bar_metrics = {}
@@ -536,15 +504,14 @@ def __run_legacy_training_epoch_end(

             _processed_outputs = self.trainer.process_dict_result(epoch_output)
             epoch_progress_bar_metrics = _processed_outputs[1]
             epoch_log_metrics = _processed_outputs[2]
-            epoch_callback_metrics = _processed_outputs[3]

         # --------------------------
         # Structured Result (auto epoch end)
         # --------------------------
         elif is_result_obj:
             epoch_log_metrics, epoch_progress_bar_metrics = self.__auto_reduce_results_on_epoch_end(epoch_output)

-        return epoch_log_metrics, epoch_progress_bar_metrics, epoch_callback_metrics
+        return epoch_log_metrics, epoch_progress_bar_metrics

     def __auto_reduce_results_on_epoch_end(self, epoch_output):
         epoch_log_metrics = {}
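
The user-facing consequence of this refactor is that dictionary keys returned from `validation_epoch_end` no longer flow into `callback_metrics`, so monitor-based callbacks such as `EarlyStopping` only see metrics that were logged. A hedged sketch of the expected pattern (names are illustrative):

    import torch
    from pytorch_lightning.callbacks import EarlyStopping

    def validation_epoch_end(self, outputs):
        # assumes validation_step returned dicts containing 'val_loss'
        val_loss = torch.stack([out['val_loss'] for out in outputs]).mean()
        # returning {'val_loss': val_loss} would no longer populate callback_metrics
        self.log('val_loss', val_loss)

    early_stopping = EarlyStopping(monitor='val_loss', patience=3, mode='min')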

pytorch_lightning/trainer/logging.py

Lines changed: 9 additions & 22 deletions
@@ -21,6 +21,7 @@
 from pytorch_lightning.loggers import LightningLoggerBase
 from pytorch_lightning.utilities import DeviceType, DistributedType
 from pytorch_lightning.utilities.distributed import rank_zero_warn
+from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.memory import recursive_detach


@@ -42,8 +43,14 @@ class TrainerLoggingMixin(ABC):

     def metrics_to_scalars(self, metrics):
         new_metrics = {}
+        # TODO: this is duplicated in MetricsHolder. should be unified
         for k, v in metrics.items():
             if isinstance(v, torch.Tensor):
+                if v.numel() != 1:
+                    raise MisconfigurationException(
+                        f"The metric `{k}` does not contain a single element"
+                        f" thus it cannot be converted to float. Found `{v}`"
+                    )
                 v = v.item()

             if isinstance(v, dict):
@@ -81,23 +88,8 @@ def process_dict_result(self, output, train=False):
         if isinstance(output, torch.Tensor):
             progress_bar_metrics = {}
             log_metrics = {}
-            callback_metrics = {}
             hiddens = None
-            return output, progress_bar_metrics, log_metrics, callback_metrics, hiddens
-
-        # ---------------
-        # EXTRACT CALLBACK KEYS
-        # ---------------
-        # all keys not progress_bar or log are candidates for callbacks
-        callback_metrics = {}
-        if isinstance(output, Mapping):
-            for k, v in output.items():
-                if k not in ['progress_bar', 'log', 'hiddens']:
-                    callback_metrics[k] = v
-
-        if train and self._distrib_type in (DistributedType.DP, DistributedType.DDP2):
-            num_gpus = self.num_gpus
-            callback_metrics = self.reduce_distributed_output(callback_metrics, num_gpus)
+            return output, progress_bar_metrics, log_metrics, hiddens

         # ---------------
         # EXTRACT PROGRESS BAR KEYS
@@ -159,17 +151,12 @@ def process_dict_result(self, output, train=False):
         # ---------------
         hiddens = output.get('hiddens', None) if isinstance(output, Mapping) else None

-        # use every metric passed in as a candidate for callback
-        callback_metrics.update(progress_bar_metrics)
-        callback_metrics.update(log_metrics)
-
         # detach all metrics for callbacks to prevent memory leaks
         # no .item() because it will slow things down
-        callback_metrics = recursive_detach(callback_metrics)
         progress_bar_metrics = recursive_detach(progress_bar_metrics)
         log_metrics = recursive_detach(log_metrics)

-        return loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens
+        return loss, progress_bar_metrics, log_metrics, hiddens

     def reduce_distributed_output(self, output, num_gpus):
         if num_gpus <= 1:
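
The new `MisconfigurationException` makes the scalar requirement explicit: a tensor handed to the loggers must contain exactly one element before it is converted with `.item()`. A sketch of the user-facing consequence (metric names are illustrative):

    def validation_step(self, batch, batch_idx):
        per_sample_loss = self.loss_fn(self(batch[0]), batch[1])  # shape: [batch_size]
        # a multi-element tensor would trip a scalar-conversion check like the one above,
        # so reduce it to a scalar before logging
        self.log('val_loss', per_sample_loss.mean())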

pytorch_lightning/trainer/trainer.py

Lines changed: 0 additions & 9 deletions
@@ -859,15 +859,6 @@ def run_sanity_check(self, ref_model):
             # run eval step
             _, eval_results = self.run_evaluation(max_batches=self.num_sanity_val_batches)

-            # allow no returns from eval
-            if eval_results is not None and len(eval_results) > 0:
-                # when we get a list back, used only the last item
-                if isinstance(eval_results, list):
-                    eval_results = eval_results[-1]
-
-                _, _, _, callback_metrics, _ = self.process_dict_result(eval_results)
-                self.logger_connector.callback_metrics = callback_metrics
-
             self.on_sanity_check_end()
             self.running_sanity_check = False

pytorch_lightning/trainer/training_loop.py

Lines changed: 1 addition & 2 deletions
@@ -360,8 +360,7 @@ def _process_training_step_output(self, training_step_output, split_batch):
             batch_loss=training_step_output[0],
             pbar_on_batch_end=training_step_output[1],
             log_metrics=training_step_output[2],
-            callback_metrics=training_step_output[3],
-            hiddens=training_step_output[4],
+            hiddens=training_step_output[3],
         )
         # if the user decides to finally reduce things in epoch_end, save raw output without graphs
         if isinstance(training_step_output_for_epoch_end, torch.Tensor):
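
For readers tracking the internal API, `process_dict_result` now returns a 4-tuple, which is why the positional indices above shift down by one. Illustrative of the new shape only:

    # before: loss, progress_bar_metrics, log_metrics, callback_metrics, hiddens
    # after this commit:
    loss, progress_bar_metrics, log_metrics, hiddens = trainer.process_dict_result(output, train=True)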

tests/base/model_valid_epoch_ends.py

Lines changed: 2 additions & 3 deletions
@@ -43,9 +43,8 @@ def _mean(res, key):
         val_loss_mean = val_loss_mean.item()
         val_acc_mean = val_acc_mean.item()

-        metrics_dict = {'early_stop_on': val_loss_mean, 'val_acc': val_acc_mean}
-        results = {'progress_bar': metrics_dict, 'log': metrics_dict}
-        return results
+        self.log('early_stop_on', val_loss_mean, prog_bar=True)
+        self.log('val_acc', val_acc_mean, prog_bar=True)

     def validation_epoch_end__multiple_dataloaders(self, outputs):
         """
