3 changes: 3 additions & 0 deletions CHANGELOG.md
@@ -32,6 +32,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

### Changed

- Changed `automatic casting` for LoggerConnector `metrics` ([#5218](https://github.com/PyTorchLightning/pytorch-lightning/pull/5218))


- `stat_scores` metric now calculates stat scores over all classes and gains new parameters, in line with the new `StatScores` metric ([#4839](https://github.com/PyTorchLightning/pytorch-lightning/pull/4839))


13 changes: 1 addition & 12 deletions pytorch_lightning/callbacks/early_stopping.py
@@ -19,14 +19,12 @@
Monitor a metric and stop training when it stops improving.

"""
import numbers

import numpy as np
import torch

from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_info, rank_zero_warn
from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn


class EarlyStopping(Callback):
@@ -196,15 +194,6 @@ def _run_early_stopping_check(self, trainer, pl_module):
# when in dev debugging
trainer.dev_debugger.track_early_stopping_history(self, current)

if current is not None:
if isinstance(current, Metric):
current = current.compute()
elif isinstance(current, numbers.Number):
current = torch.tensor(current, device=pl_module.device, dtype=torch.float)

if trainer.use_tpu and _TPU_AVAILABLE:
current = current.cpu()

if self.monitor_op(current - self.min_delta, self.best_score):
self.best_score = current
self.wait_count = 0
10 changes: 1 addition & 9 deletions pytorch_lightning/callbacks/model_checkpoint.py
@@ -20,7 +20,6 @@

"""

import numbers
import os
import re
from copy import deepcopy
@@ -33,7 +32,6 @@

from pytorch_lightning import _logger as log
from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn
from pytorch_lightning.utilities.cloud_io import get_filesystem
from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -554,12 +552,6 @@ def _save_top_k_checkpoints(self, trainer, pl_module, metrics):
epoch = metrics.get("epoch")
step = metrics.get("step")

if current is not None:
if isinstance(current, Metric):
current = current.compute()
elif isinstance(current, numbers.Number):
current = torch.tensor(current, device=pl_module.device, dtype=torch.float)

if self.check_monitor_top_k(current):
self._update_best_and_save(current, epoch, step, trainer, pl_module, metrics)
elif self.verbose:
@@ -587,7 +579,7 @@ def _update_best_and_save(
self.best_k_models.pop(del_filepath)

# do not save nan, replace with +/- inf
if torch.isnan(current):
if isinstance(current, torch.Tensor) and torch.isnan(current):
current = torch.tensor(float('inf' if self.mode == "min" else '-inf'))
Review thread on this line:

Member: should this always be on cpu? or should it be on current.device?

Contributor Author: Great question! I am not sure. What do you think?

Member: IMO it should live on current.device, since all the other tensors (especially current, if not NaN) also live on this device.

filepath = self._get_metric_interpolated_filepath_name(ckpt_name_metrics, epoch, step, del_filepath)
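For illustration only (not part of this PR's diff): a minimal sketch of the device-preserving variant suggested in the review thread above, assuming current is a tensor and self.mode is either "min" or "max".

    if isinstance(current, torch.Tensor) and torch.isnan(current):
        # keep the replacement value on the same device as the original metric
        current = torch.tensor(
            float('inf' if self.mode == "min" else '-inf'),
            device=current.device,
        )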
@@ -379,7 +379,7 @@ def update_logger_connector(self) -> None:

if is_train:
# Only log and add to callback epoch step during evaluation, test.
logger_connector.logged_metrics.update(batch_log_metrics)
logger_connector._logged_metrics.update(batch_log_metrics)
callback_metrics.update(batch_pbar_metrics)
callback_metrics.update(batch_log_metrics)
else:
@@ -389,8 +389,8 @@

# get logged_metrics
epoch_log_metrics = self.get_epoch_log_metrics()
logger_connector.logged_metrics.update(epoch_log_metrics)
logger_connector.logged_metrics.update(epoch=self.trainer.current_epoch)
logger_connector._logged_metrics.update(epoch_log_metrics)
logger_connector._logged_metrics.update({"epoch": self.trainer.current_epoch})

# get forked_metrics
forked_metrics = self.get_forked_metrics()
@@ -403,8 +403,8 @@
logger_connector.evaluation_callback_metrics.update(callback_metrics)

# update callback_metrics
logger_connector.callback_metrics.update(callback_metrics)
logger_connector.callback_metrics.pop("epoch", None)
logger_connector._callback_metrics.update(callback_metrics)
logger_connector._callback_metrics.pop("epoch", None)

batch_pbar_metrics.pop("debug_epoch", None)
return batch_pbar_metrics, batch_log_metrics
@@ -14,7 +14,7 @@
from copy import deepcopy
import os
from pprint import pprint
from typing import Iterable, Union
from typing import Any, Iterable, Union, Dict

import torch

@@ -23,6 +23,7 @@
from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger
from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator
from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore, LoggerStages
from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
from pytorch_lightning.utilities import flatten_dict
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from pytorch_lightning.utilities.model_helpers import is_overridden
@@ -31,19 +32,64 @@
class LoggerConnector:
def __init__(self, trainer):
self.trainer = trainer
self.callback_metrics = {}
self.evaluation_callback_metrics = {}
self.logged_metrics = {}
self.progress_bar_metrics = {}
self._callback_metrics = MetricsHolder()
self._evaluation_callback_metrics = MetricsHolder(to_float=True)
self._logged_metrics = MetricsHolder()
self._progress_bar_metrics = MetricsHolder()
self.eval_loop_results = []
self._cached_results = {stage: EpochResultStore(trainer, stage) for stage in LoggerStages}
self._callback_hook_validator = CallbackHookNameValidator()
self._current_stage = None

@property
def callback_metrics(self) -> Dict:
return self.get_metrics("callback_metrics")

@callback_metrics.setter
def callback_metrics(self, callback_metrics: Dict) -> None:
self.set_metrics("callback_metrics", callback_metrics)

@property
def evaluation_callback_metrics(self) -> Dict:
return self.get_metrics("evaluation_callback_metrics")

@evaluation_callback_metrics.setter
def evaluation_callback_metrics(self, evaluation_callback_metrics: Dict) -> None:
self.set_metrics("evaluation_callback_metrics", evaluation_callback_metrics)

@property
def logged_metrics(self) -> Dict:
return self.get_metrics("logged_metrics")

@logged_metrics.setter
def logged_metrics(self, logged_metrics: Dict) -> None:
self.set_metrics("logged_metrics", logged_metrics)

@property
def progress_bar_metrics(self) -> Dict:
return self.get_metrics("progress_bar_metrics")

@progress_bar_metrics.setter
def progress_bar_metrics(self, progress_bar_metrics: Dict) -> None:
self.set_metrics("progress_bar_metrics", progress_bar_metrics)

@property
def cached_results(self) -> Union[EpochResultStore, None]:
return self._cached_results.get(self._current_stage) # type: ignore

def get_metrics(self, key: str) -> Dict:
metrics_holder = getattr(self, f"_{key}", None)
model_ref = self.trainer.get_model()
metrics_holder.convert(
self.trainer.use_tpu,
model_ref.device if model_ref is not None else model_ref
)
return metrics_holder.metrics

def set_metrics(self, key: str, val: Any) -> None:
metrics_holder = getattr(self, f"_{key}", None)
metrics_holder.reset(val)

def set_stage(self, stage_or_testing: Union[str, bool], reset: bool = False) -> None:
self._current_stage = LoggerStages.determine_stage(stage_or_testing)
if reset:
@@ -153,10 +199,10 @@ def cache_training_step_metrics(self, opt_closure_result):
if len(pbar_metrics_tmp) > 0:
self.add_progress_bar_metrics(pbar_metrics_tmp)

self.callback_metrics.update(callback_metrics_tmp)
self._callback_metrics.update(callback_metrics_tmp)

# save legacy log metrics
self.logged_metrics.update(logged_metrics_tmp)
self._logged_metrics.update(logged_metrics_tmp)
self.cached_results.legacy_batch_log_metrics.update(logged_metrics_tmp)

def log_metrics(self, metrics, grad_norm_dic, step=None, log_train_step_metrics=False):
Expand Down Expand Up @@ -209,7 +255,7 @@ def add_progress_bar_metrics(self, metrics):
if isinstance(v, torch.Tensor):
v = v.item()

self.progress_bar_metrics[k] = v
self._progress_bar_metrics.metrics[k] = v

self.trainer.dev_debugger.track_pbar_metrics_history(metrics)

@@ -311,6 +357,7 @@ def _track_callback_metrics(self, eval_results, using_eval_result):
if 'val_loss' in flat:
flat['checkpoint_on'] = flat['val_loss']
flat['early_stop_on'] = flat['val_loss']

self.trainer.logger_connector.callback_metrics.update(flat)
if self.trainer.testing:
self.trainer.logger_connector.evaluation_callback_metrics.update(flat)
@@ -441,15 +488,15 @@ def log_train_epoch_end_metrics(
# add the metrics to the loggers and callbacks
if epoch_log_metrics and len(epoch_log_metrics) > 0:
self.log_metrics(epoch_log_metrics, {})
self.callback_metrics.update(epoch_log_metrics)
self._callback_metrics.update(epoch_log_metrics)

# add metrics to callbacks
self.callback_metrics.update(epoch_callback_metrics)
self._callback_metrics.update(epoch_callback_metrics)

# add metrics to progress_bar and callbacks
if len(epoch_progress_bar_metrics) > 0:
self.add_progress_bar_metrics(epoch_progress_bar_metrics)
self.callback_metrics.update(epoch_progress_bar_metrics)
self._callback_metrics.update(epoch_progress_bar_metrics)

# reset epoch loop result for next epoch
self.cached_results.reset()
@@ -605,4 +652,4 @@ def log_train_step_metrics(self, batch_output):
grad_norm_dic = {}
if len(batch_log_metrics) > 0 or len(grad_norm_dic) > 0:
self.log_metrics(batch_log_metrics, grad_norm_dic, log_train_step_metrics=True)
self.callback_metrics.update(batch_log_metrics)
self._callback_metrics.update(batch_log_metrics)
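To make the new access pattern above concrete, here is a self-contained, stripped-down sketch. The _Holder and _Connector names are stand-ins for illustration, not real Lightning classes: the setter stores raw values in a holder, and the getter converts them on access, mirroring the property/MetricsHolder split introduced in this file.

    import torch


    class _Holder:
        """Stand-in for MetricsHolder: keeps raw values, converts them on demand."""

        def __init__(self):
            self.metrics = {}

        def reset(self, metrics):
            self.metrics = metrics

        def convert(self, device):
            self.metrics = {
                k: v if isinstance(v, torch.Tensor) else torch.tensor(v, device=device, dtype=torch.float)
                for k, v in self.metrics.items()
            }


    class _Connector:
        """Stand-in for LoggerConnector: exposes a plain dict through a property."""

        def __init__(self):
            self._callback_metrics = _Holder()

        @property
        def callback_metrics(self):
            self._callback_metrics.convert(torch.device("cpu"))
            return self._callback_metrics.metrics

        @callback_metrics.setter
        def callback_metrics(self, metrics):
            self._callback_metrics.reset(metrics)


    connector = _Connector()
    connector.callback_metrics = {"val_loss": 0.1}   # stored as a plain float
    print(connector.callback_metrics["val_loss"])    # read back as tensor(0.1000)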
80 changes: 80 additions & 0 deletions pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py
@@ -0,0 +1,80 @@
# Copyright The PyTorch Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numbers
from typing import Any

import torch

from pytorch_lightning.metrics.metric import Metric
from pytorch_lightning.utilities import _TPU_AVAILABLE


class MetricsHolder:

"""
This class acts as a dictionary holder.
It holds metrics and implements conversion functions.
Those functions will be triggered within LoggerConnector
when the property is being requested from the user.
"""

def __init__(self, to_float: bool = False):
self.metrics = {}
self._to_float = to_float

def update(self, metrics):
self.metrics.update(metrics)

def pop(self, key, default):
return self.metrics.pop(key, default)

def reset(self, metrics):
self.metrics = metrics

def convert(self, use_tpu: bool, device: torch.device):
for key, value in self.metrics.items():
self.metrics[key] = self._convert(value, use_tpu, device)

def _convert(self, current: Any, use_tpu: bool, device: torch.device):
if self._to_float:
return self._convert_to_float(current, use_tpu, device)
return self._convert_to_tensor(current, use_tpu, device)

def _convert_to_float(self, current, use_tpu: bool, device: torch.device):
if isinstance(current, Metric):
current = current.compute().detach()

if isinstance(current, torch.Tensor):
current = float(current.item())

elif isinstance(current, int):
current = float(current)

return current

def _convert_to_tensor(self, current: Any, use_tpu: bool, device: torch.device):
if current is not None:
if isinstance(current, Metric):
current = current.compute().detach()

elif isinstance(current, numbers.Number):
if device is None:
current = torch.tensor(current, dtype=torch.float)
else:
current = torch.tensor(current, device=device, dtype=torch.float)

if use_tpu and _TPU_AVAILABLE:
current = current.cpu()

return current
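A brief usage sketch of the class above (assuming the module path shown in the imports elsewhere in this diff, and a CPU-only run, so use_tpu=False):

    import torch

    from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder

    holder = MetricsHolder()                      # default: convert stored values to tensors
    holder.update({"loss": 0.25, "epoch": 3})
    holder.convert(use_tpu=False, device=torch.device("cpu"))
    print(holder.metrics)                         # {'loss': tensor(0.2500), 'epoch': tensor(3.)}

    float_holder = MetricsHolder(to_float=True)   # convert stored values to plain Python floats
    float_holder.update({"acc": torch.tensor(0.9)})
    float_holder.convert(use_tpu=False, device=torch.device("cpu"))
    print(float_holder.metrics)                   # {'acc': 0.9...} (a plain float)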
35 changes: 31 additions & 4 deletions tests/trainer/logging/test_logger_connector.py
@@ -15,22 +15,25 @@
Tests to ensure that the training loop works with a dict (1.0)
"""
from copy import deepcopy
from typing import Any, Callable

import pytest
import torch
from torch.utils.data import DataLoader

from pytorch_lightning.callbacks.base import Callback
from pytorch_lightning.core.step_result import Result
from pytorch_lightning.metrics import Accuracy
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator
from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.base.boring_model import BoringModel, RandomDataset


def decorator_with_arguments(fx_name='', hook_fx_name=None):
def decorator(func):
def wrapper(self, *args, **kwargs):
def decorator_with_arguments(fx_name: str = '', hook_fx_name: str = None) -> Callable:
def decorator(func: Callable) -> Callable:
def wrapper(self, *args, **kwargs) -> Any:
# Set information
self._current_fx_name = fx_name
self._current_hook_fx_name = hook_fx_name
@@ -43,7 +46,6 @@ def wrapper(self, *args, **kwargs):
return result

return wrapper

return decorator


@@ -425,3 +427,28 @@ def test_dataloader(self):
)
trainer.fit(model)
trainer.test(model, ckpt_path=None)


@pytest.mark.parametrize('to_float', [False, True])
def test_metrics_holder(to_float, tmpdir):

device = "cuda" if torch.cuda.is_available() else "cpu"
preds = torch.tensor([[0.9, 0.1]], device=device)

def is_float(value: Any) -> bool:
return isinstance(value, float)

expected_function = is_float if to_float else torch.is_tensor
targets = torch.tensor([1], device=device)
acc = Accuracy().to(device)
metric_holder = MetricsHolder(to_float=to_float)
metric_holder.update({
"x": 1,
"y": torch.tensor(2),
"z": acc(preds, targets),
})
metric_holder.convert(False, device)
metrics = metric_holder.metrics
assert excepted_function(metrics["x"])
assert excepted_function(metrics["y"])
assert excepted_function(metrics["z"])