diff --git a/CHANGELOG.md b/CHANGELOG.md
index 928144320394a..8efcd6f06c2d6 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -32,6 +32,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 ### Changed
 
+- Changed `automatic casting` for LoggerConnector `metrics` ([#5218](https://github.com/PyTorchLightning/pytorch-lightning/pull/5218))
+
+
 - `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))
diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py
index c8725f4cde6fd..fca39036c9404 100644
--- a/pytorch_lightning/callbacks/early_stopping.py
+++ b/pytorch_lightning/callbacks/early_stopping.py
@@ -19,14 +19,12 @@
 Monitor a metric and stop training when it stops improving.
 
 """
-import numbers
 
 import numpy as np
 import torch
 
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.metrics.metric import Metric
-from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_info, rank_zero_warn
+from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn
 
 
 class EarlyStopping(Callback):
@@ -196,15 +194,6 @@ def _run_early_stopping_check(self, trainer, pl_module):
         # when in dev debugging
         trainer.dev_debugger.track_early_stopping_history(self, current)
 
-        if current is not None:
-            if isinstance(current, Metric):
-                current = current.compute()
-            elif isinstance(current, numbers.Number):
-                current = torch.tensor(current, device=pl_module.device, dtype=torch.float)
-
-            if trainer.use_tpu and _TPU_AVAILABLE:
-                current = current.cpu()
-
         if self.monitor_op(current - self.min_delta, self.best_score):
             self.best_score = current
             self.wait_count = 0
diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py
index 7fd7a571a47ce..3fc2b54d98162 100644
--- a/pytorch_lightning/callbacks/model_checkpoint.py
+++ b/pytorch_lightning/callbacks/model_checkpoint.py
@@ -20,7 +20,6 @@
 
 """
 
-import numbers
 import os
 import re
 from copy import deepcopy
@@ -33,7 +32,6 @@
 
 from pytorch_lightning import _logger as log
 from pytorch_lightning.callbacks.base import Callback
-from pytorch_lightning.metrics.metric import Metric
 from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -554,12 +552,6 @@ def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
         epoch = metrics.get("epoch")
         step = metrics.get("step")
 
-        if current is not None:
-            if isinstance(current, Metric):
-                current = current.compute()
-            elif isinstance(current, numbers.Number):
-                current = torch.tensor(current, device=pl_module.device, dtype=torch.float)
-
         if self.check_monitor_top_k(current):
             self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)
         elif self.verbose:
@@ -587,7 +579,7 @@ def _update_best_and_save(
             self.best_k_models.pop(del_filepath)
 
         # do not save nan, replace with +/- inf
-        if torch.isnan(current):
+        if isinstance(current, torch.Tensor) and torch.isnan(current):
             current = torch.tensor(float('inf' if self.mode == "min" else '-inf'))
 
         filepath = self._get_metric_interpolated_filepath_name(ckpt_name_metrics, epoch, step, del_filepath)
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
index dd12a2970727a..2796a61ee5c83 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py
@@ -379,7 +379,7 @@ def update_logger_connector(self) -> None:
 
         if is_train:
             # Only log and add to callback epoch step during evaluation, test.
-            logger_connector.logged_metrics.update(batch_log_metrics)
+            logger_connector._logged_metrics.update(batch_log_metrics)
             callback_metrics.update(batch_pbar_metrics)
             callback_metrics.update(batch_log_metrics)
         else:
@@ -389,8 +389,8 @@ def update_logger_connector(self) -> None:
 
             # get logged_metrics
             epoch_log_metrics = self.get_epoch_log_metrics()
-            logger_connector.logged_metrics.update(epoch_log_metrics)
-            logger_connector.logged_metrics.update(epoch=self.trainer.current_epoch)
+            logger_connector._logged_metrics.update(epoch_log_metrics)
+            logger_connector._logged_metrics.update({"epoch": self.trainer.current_epoch})
 
             # get forked_metrics
             forked_metrics = self.get_forked_metrics()
@@ -403,8 +403,8 @@ def update_logger_connector(self) -> None:
             logger_connector.evaluation_callback_metrics.update(callback_metrics)
 
         # update callback_metrics
-        logger_connector.callback_metrics.update(callback_metrics)
-        logger_connector.callback_metrics.pop("epoch", None)
+        logger_connector._callback_metrics.update(callback_metrics)
+        logger_connector._callback_metrics.pop("epoch", None)
 
         batch_pbar_metrics.pop("debug_epoch", None)
         return batch_pbar_metrics, batch_log_metrics
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
index 6b55b3bce1b9a..73e9223fb7d0f 100644
--- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
+++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py
@@ -14,7 +14,7 @@
 from copy import deepcopy
 import os
 from pprint import pprint
-from typing import Iterable, Union
+from typing import Any, Iterable, Union, Dict
 
 import torch
 
@@ -23,6 +23,7 @@
 from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
 from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator
 from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore, LoggerStages
+from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
 from pytorch_lightning.utilities import flatten_dict
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -31,19 +32,64 @@ class LoggerConnector:
 
     def __init__(self, trainer):
         self.trainer = trainer
-        self.callback_metrics = {}
-        self.evaluation_callback_metrics = {}
-        self.logged_metrics = {}
-        self.progress_bar_metrics = {}
+        self._callback_metrics = MetricsHolder()
+        self._evaluation_callback_metrics = MetricsHolder(to_float=True)
+        self._logged_metrics = MetricsHolder()
+        self._progress_bar_metrics = MetricsHolder()
         self.eval_loop_results = []
         self._cached_results = {stage: EpochResultStore(trainer, stage) for stage in LoggerStages}
         self._callback_hook_validator = CallbackHookNameValidator()
         self._current_stage = None
 
+    @property
+    def callback_metrics(self) -> Dict:
+        return self.get_metrics("callback_metrics")
+
+    @callback_metrics.setter
+    def callback_metrics(self, callback_metrics: Dict) -> None:
+        self.set_metrics("callback_metrics", callback_metrics)
+
+    @property
+    def evaluation_callback_metrics(self) -> Dict:
+        return self.get_metrics("evaluation_callback_metrics")
+
+    @evaluation_callback_metrics.setter
+    def evaluation_callback_metrics(self, evaluation_callback_metrics: Dict) -> None:
+        self.set_metrics("evaluation_callback_metrics", evaluation_callback_metrics)
+
+    @property
+    def logged_metrics(self) -> Dict:
+        return self.get_metrics("logged_metrics")
+
+    @logged_metrics.setter
+    def logged_metrics(self, logged_metrics: Dict) -> None:
+        self.set_metrics("logged_metrics", logged_metrics)
+
+    @property
+    def progress_bar_metrics(self) -> Dict:
+        return self.get_metrics("progress_bar_metrics")
+
+    @progress_bar_metrics.setter
+    def progress_bar_metrics(self, progress_bar_metrics: Dict) -> None:
+        self.set_metrics("progress_bar_metrics", progress_bar_metrics)
+
     @property
     def cached_results(self) -> Union[EpochResultStore, None]:
         return self._cached_results.get(self._current_stage)  # type: ignore
 
+    def get_metrics(self, key: str) -> Dict:
+        metrics_holder = getattr(self, f"_{key}", None)
+        model_ref = self.trainer.get_model()
+        metrics_holder.convert(
+            self.trainer.use_tpu,
+            model_ref.device if model_ref is not None else model_ref
+        )
+        return metrics_holder.metrics
+
+    def set_metrics(self, key: str, val: Any) -> None:
+        metrics_holder = getattr(self, f"_{key}", None)
+        metrics_holder.reset(val)
+
     def set_stage(self, stage_or_testing: Union[str, bool], reset: bool = False) -> None:
         self._current_stage = LoggerStages.determine_stage(stage_or_testing)
         if reset:
@@ -153,10 +199,10 @@ def cache_training_step_metrics(self, opt_closure_result):
         if len(pbar_metrics_tmp) > 0:
             self.add_progress_bar_metrics(pbar_metrics_tmp)
 
-        self.callback_metrics.update(callback_metrics_tmp)
+        self._callback_metrics.update(callback_metrics_tmp)
 
         # save legacy log metrics
-        self.logged_metrics.update(logged_metrics_tmp)
+        self._logged_metrics.update(logged_metrics_tmp)
         self.cached_results.legacy_batch_log_metrics.update(logged_metrics_tmp)
 
     def log_metrics(self, metrics, grad_norm_dic, step=None, log_train_step_metrics=False):
@@ -209,7 +255,7 @@ def add_progress_bar_metrics(self, metrics):
             if isinstance(v, torch.Tensor):
                 v = v.item()
 
-            self.progress_bar_metrics[k] = v
+            self._progress_bar_metrics.metrics[k] = v
 
         self.trainer.dev_debugger.track_pbar_metrics_history(metrics)
@@ -311,6 +357,7 @@ def _track_callback_metrics(self, eval_results, using_eval_result):
                     if 'val_loss' in flat:
                         flat['checkpoint_on'] = flat['val_loss']
                         flat['early_stop_on'] = flat['val_loss']
+                    self.trainer.logger_connector.callback_metrics.update(flat)
                     if self.trainer.testing:
                         self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
@@ -441,15 +488,15 @@ def log_train_epoch_end_metrics(
         # add the metrics to the loggers and callbacks
         if epoch_log_metrics and len(epoch_log_metrics) > 0:
             self.log_metrics(epoch_log_metrics, {})
-            self.callback_metrics.update(epoch_log_metrics)
+            self._callback_metrics.update(epoch_log_metrics)
 
         # add metrics to callbacks
-        self.callback_metrics.update(epoch_callback_metrics)
+        self._callback_metrics.update(epoch_callback_metrics)
 
         # add metrics to progress_bar and callbacks
         if len(epoch_progress_bar_metrics) > 0:
             self.add_progress_bar_metrics(epoch_progress_bar_metrics)
-            self.callback_metrics.update(epoch_progress_bar_metrics)
+            self._callback_metrics.update(epoch_progress_bar_metrics)
 
         # reset epoch loop result for next epoch
         self.cached_results.reset()
@@ -605,4 +652,4 @@ def log_train_step_metrics(self, batch_output):
             grad_norm_dic = {}
         if len(batch_log_metrics) > 0 or len(grad_norm_dic) > 0:
             self.log_metrics(batch_log_metrics, grad_norm_dic, log_train_step_metrics=True)
-            self.callback_metrics.update(batch_log_metrics)
+            self._callback_metrics.update(batch_log_metrics)
diff --git a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py
new file mode 100644
index 0000000000000..d2e2c9b7870cf
--- /dev/null
+++ b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py
@@ -0,0 +1,80 @@
+# Copyright The PyTorch Lightning team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import numbers
+from typing import Any
+
+import torch
+
+from pytorch_lightning.metrics.metric import Metric
+from pytorch_lightning.utilities import _TPU_AVAILABLE
+
+
+class MetricsHolder:
+
+    """
+    This class acts as a dictionary holder.
+    It holds metrics and implements conversion functions.
+    Those functions will be triggered within LoggerConnector
+    when the property is being requested from the user.
+    """
+
+    def __init__(self, to_float: bool = False):
+        self.metrics = {}
+        self._to_float = to_float
+
+    def update(self, metrics):
+        self.metrics.update(metrics)
+
+    def pop(self, key, default):
+        return self.metrics.pop(key, default)
+
+    def reset(self, metrics):
+        self.metrics = metrics
+
+    def convert(self, use_tpu: bool, device: torch.device):
+        for key, value in self.metrics.items():
+            self.metrics[key] = self._convert(value, use_tpu, device)
+
+    def _convert(self, current: Any, use_tpu: bool, device: torch.device):
+        if self._to_float:
+            return self._convert_to_float(current, use_tpu, device)
+        return self._convert_to_tensor(current, use_tpu, device)
+
+    def _convert_to_float(self, current, use_tpu: bool, device: torch.device):
+        if isinstance(current, Metric):
+            current = current.compute().detach()
+
+        if isinstance(current, torch.Tensor):
+            current = float(current.item())
+
+        elif isinstance(current, int):
+            current = float(current)
+
+        return current
+
+    def _convert_to_tensor(self, current: Any, use_tpu: bool, device: torch.device):
+        if current is not None:
+            if isinstance(current, Metric):
+                current = current.compute().detach()
+
+            elif isinstance(current, numbers.Number):
+                if device is None:
+                    current = torch.tensor(current, dtype=torch.float)
+                else:
+                    current = torch.tensor(current, device=device, dtype=torch.float)
+
+            if use_tpu and _TPU_AVAILABLE:
+                current = current.cpu()
+
+        return current
diff --git a/tests/trainer/logging/test_logger_connector.py b/tests/trainer/logging/test_logger_connector.py
index 56e5765c7f4b8..f911c793b0707 100644
--- a/tests/trainer/logging/test_logger_connector.py
+++ b/tests/trainer/logging/test_logger_connector.py
@@ -15,6 +15,7 @@
 Tests to ensure that the training loop works with a dict (1.0)
 """
 from copy import deepcopy
+from typing import Any, Callable
 
 import pytest
 import torch
@@ -22,15 +23,17 @@
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.core.step_result import Result
+from pytorch_lightning.metrics import Accuracy
 from pytorch_lightning.trainer import Trainer
 from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator
+from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.base.boring_model import BoringModel, RandomDataset
 
 
-def decorator_with_arguments(fx_name='', hook_fx_name=None):
-    def decorator(func):
-        def wrapper(self, *args, **kwargs):
+def decorator_with_arguments(fx_name: str = '', hook_fx_name: str = None) -> Callable:
+    def decorator(func: Callable) -> Callable:
+        def wrapper(self, *args, **kwargs) -> Any:
             # Set information
             self._current_fx_name = fx_name
             self._current_hook_fx_name = hook_fx_name
@@ -43,7 +46,6 @@ def wrapper(self, *args, **kwargs):
             return result
 
         return wrapper
-
     return decorator
 
 
@@ -425,3 +427,28 @@ def test_dataloader(self):
     )
     trainer.fit(model)
     trainer.test(model, ckpt_path=None)
+
+
+@pytest.mark.parametrize('to_float', [False, True])
+def test_metrics_holder(to_float, tmpdir):
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    preds = torch.tensor([[0.9, 0.1]], device=device)
+
+    def is_float(value: Any) -> bool:
+        return isinstance(value, float)
+
+    expected_function = is_float if to_float else torch.is_tensor
+    targets = torch.tensor([1], device=device)
+    acc = Accuracy().to(device)
+    metric_holder = MetricsHolder(to_float=to_float)
+    metric_holder.update({
+        "x": 1,
+        "y": torch.tensor(2),
+        "z": acc(preds, targets),
+    })
+    metric_holder.convert(False, device)
+    metrics = metric_holder.metrics
+    assert expected_function(metrics["x"])
+    assert expected_function(metrics["y"])
+    assert expected_function(metrics["z"])
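
Illustrative usage sketch (not part of the patch): a minimal example of how the new MetricsHolder behaves after this change, assuming a CPU-only setup; the `holder` objects below are hypothetical stand-ins created directly rather than through a Trainer.

    import torch

    from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder

    # Default holder (backing callback/logged/progress-bar metrics):
    # plain numbers are cast to float tensors on the given device.
    holder = MetricsHolder()
    holder.update({"loss": 0.25})
    holder.convert(use_tpu=False, device=torch.device("cpu"))
    assert torch.is_tensor(holder.metrics["loss"])

    # to_float=True (backing evaluation_callback_metrics):
    # tensors and ints are converted to plain Python floats instead.
    float_holder = MetricsHolder(to_float=True)
    float_holder.update({"loss": torch.tensor(0.25)})
    float_holder.convert(use_tpu=False, device=torch.device("cpu"))
    assert isinstance(float_holder.metrics["loss"], float)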