From a48ca183d46db609f82de33990b92d4848f7c2b3 Mon Sep 17 00:00:00 2001 From: chaton Date: Wed, 16 Dec 2020 22:08:06 +0100 Subject: [PATCH 01/14] [bug-fix] Metric reduction with Logging (#5150) * add test * resolve bug * udpate test * wrongly copy / paste * update test * resolve a second bug Co-authored-by: Ubuntu --- pytorch_lightning/callbacks/early_stopping.py | 7 ++- .../callbacks/model_checkpoint.py | 10 +++- pytorch_lightning/core/step_result.py | 5 +- .../test_train_loop_logging_1_0.py | 49 ++++++++++++++++++- 4 files changed, 64 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 066effc68a03c..3576420715cd5 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -199,8 +199,11 @@ def _run_early_stopping_check(self, trainer, pl_module): # when in dev debugging trainer.dev_debugger.track_early_stopping_history(self, current) - if not isinstance(current, torch.Tensor): - current = torch.tensor(current, device=pl_module.device) + if current is not None: + if isinstance(current, Metric): + current = current.compute() + elif isinstance(current, numbers.Number): + current = torch.tensor(current, device=pl_module.device, dtype=torch.float) if trainer.use_tpu and _TPU_AVAILABLE: current = current.cpu() diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 64b7353554220..b2dff332e99ce 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -20,6 +20,7 @@ """ +import numbers import os import re from copy import deepcopy @@ -32,6 +33,8 @@ from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback +from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.plugins.rpc_plugin import RPCPlugin from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException @@ -574,8 +577,11 @@ def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath): epoch = metrics.get("epoch") step = metrics.get("step") - if not isinstance(current, torch.Tensor) and current is not None: - current = torch.tensor(current, device=pl_module.device) + if current is not None: + if isinstance(current, Metric): + current = current.compute() + elif isinstance(current, numbers.Number): + current = torch.tensor(current, device=pl_module.device, dtype=torch.float) if self.check_monitor_top_k(current): self._update_best_and_save(filepath, current, epoch, step, trainer, pl_module) diff --git a/pytorch_lightning/core/step_result.py b/pytorch_lightning/core/step_result.py index 7101ec17c4bbc..61cd11f900ea7 100644 --- a/pytorch_lightning/core/step_result.py +++ b/pytorch_lightning/core/step_result.py @@ -367,7 +367,10 @@ def get_forked_metrics(self, add_dataloader_idx=False): dl_key = self._add_dataloader_idx(k, options["dataloader_idx"], add_dataloader_idx) if options['forked']: - result[dl_key] = self[k] + if isinstance(self[k], Metric): + result[dl_key] = self[k].compute().detach() + else: + result[dl_key] = self[k] return result diff --git a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py index a77b4eb451e28..52d48e2460d58 100644 --- a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py +++ b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py @@ -26,8 +26,8 @@ from torch.utils.data import Dataset import pytorch_lightning as pl -from pytorch_lightning import Trainer, callbacks -from pytorch_lightning.callbacks import ModelCheckpoint +from pytorch_lightning import callbacks, Trainer +from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.core.lightning import LightningModule from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset from tests.base.deterministic_model import DeterministicModel @@ -817,3 +817,48 @@ def on_train_epoch_end(self, trainer, pl_module, outputs): 'on_epoch_end': 5, 'on_train_epoch_end': 6} assert trainer.callback_metrics == expected + + +@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") +def test_metric_are_properly_reduced(tmpdir): + class TestingModel(BoringModel): + def __init__(self, *args, **kwargs): + super().__init__() + self.train_acc = pl.metrics.Accuracy() + self.val_acc = pl.metrics.Accuracy() + + def training_step(self, batch, batch_idx): + self.train_acc(torch.rand(1, 3, device=self.device), torch.randint(0, 2, (1,), device=self.device)) + self.log('train_acc', self.train_acc, on_step=True, on_epoch=True) + return super().training_step(batch, batch_idx) + + def validation_step(self, batch, batch_idx): + preds = torch.tensor(0, device=self.device) + targets = torch.tensor(1, device=self.device) + if batch_idx < 8: + targets = preds + self.val_acc(preds, targets) + self.log('val_acc', self.val_acc, on_step=True, on_epoch=True) + return super().validation_step(batch, batch_idx) + + early_stop = EarlyStopping(monitor='val_acc', mode='max') + + checkpoint = ModelCheckpoint( + monitor='val_acc', + save_last=True, + save_top_k=2, + mode='max', + ) + + model = TestingModel() + trainer = Trainer( + default_root_dir=tmpdir, + gpus=1, + max_epochs=2, + limit_train_batches=5, + limit_val_batches=32, + callbacks=[early_stop, checkpoint]) + trainer.fit(model) + + assert trainer.callback_metrics["val_acc"] == 8 / 32. + assert "train_acc" in trainer.callback_metrics From 884a4544b16728a50ce369c71b218d501c00234a Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 21 Dec 2020 14:12:58 +0000 Subject: [PATCH 02/14] iupdate --- .../test_train_loop_logging_1_0.py | 78 ------------------- 1 file changed, 78 deletions(-) diff --git a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py index 52d48e2460d58..8e0cfb933bc52 100644 --- a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py +++ b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py @@ -741,84 +741,6 @@ def validation_step(self, batch, batch_idx): assert trainer.logged_metrics['foo'] == fake_result assert trainer.logged_metrics['bar'] == fake_result - -def test_progress_bar_dict_contains_values_on_train_epoch_end(tmpdir): - class TestModel(BoringModel): - def training_step(self, *args): - self.log("foo", torch.tensor(self.current_epoch), on_step=False, on_epoch=True, prog_bar=True) - return super().training_step(*args) - - def on_epoch_end(self): - self.epoch_end_called = True - self.log('foo_2', torch.tensor(self.current_epoch), prog_bar=True, - on_epoch=True, sync_dist=True, sync_dist_op='sum') - - def on_train_epoch_end(self, *_): - self.on_train_epoch_end_called = True - assert self.trainer.progress_bar_dict["foo"] == self.current_epoch - assert self.trainer.progress_bar_dict["foo_2"] == self.current_epoch - - trainer = Trainer( - default_root_dir=tmpdir, - max_epochs=2, - limit_train_batches=1, - limit_val_batches=0, - checkpoint_callback=False, - logger=False, - weights_summary=None, - progress_bar_refresh_rate=0, - ) - model = TestModel() - trainer.fit(model) - assert model.epoch_end_called - assert model.on_train_epoch_end_called - - -def test_logging_in_callbacks_with_log_function(tmpdir): - """ - Tests ensure self.log can be used directly in callbacks. - """ - class LoggingCallback(callbacks.Callback): - def on_train_start(self, trainer, pl_module): - self.log("on_train_start", 1) - - def on_train_epoch_start(self, trainer, pl_module): - self.log("on_train_epoch_start", 2) - - def on_batch_end(self, trainer, pl_module): - self.log("on_batch_end", 3) - - def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): - self.log("on_train_batch_end", 4) - - def on_epoch_end(self, trainer, pl_module): - self.log("on_epoch_end", 5) - - def on_train_epoch_end(self, trainer, pl_module, outputs): - self.log("on_train_epoch_end", 6) - self.callback_metrics = trainer.logger_connector.callback_metrics - - model = BoringModel() - trainer = Trainer( - default_root_dir=tmpdir, - limit_train_batches=1, - limit_val_batches=1, - max_epochs=1, - weights_summary=None, - callbacks=[LoggingCallback()] - ) - trainer.fit(model) - - expected = { - 'on_train_start': 1, - 'on_train_epoch_start': 2, - 'on_batch_end': 3, - 'on_train_batch_end': 4, - 'on_epoch_end': 5, - 'on_train_epoch_end': 6} - assert trainer.callback_metrics == expected - - @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") def test_metric_are_properly_reduced(tmpdir): class TestingModel(BoringModel): From f30de4cf3979d3b30e7e447f87c978789ffd3fb6 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 21 Dec 2020 14:27:42 +0000 Subject: [PATCH 03/14] resolve bugs --- pytorch_lightning/callbacks/early_stopping.py | 9 -- .../callbacks/model_checkpoint.py | 9 +- .../logger_connector/epoch_result_store.py | 12 +-- .../logger_connector/logger_connector.py | 90 ++++++++++++++----- .../logger_connector/metrics_holder.py | 75 ++++++++++++++++ 5 files changed, 150 insertions(+), 45 deletions(-) create mode 100644 pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index 3576420715cd5..d84d89c306a2e 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -199,15 +199,6 @@ def _run_early_stopping_check(self, trainer, pl_module): # when in dev debugging trainer.dev_debugger.track_early_stopping_history(self, current) - if current is not None: - if isinstance(current, Metric): - current = current.compute() - elif isinstance(current, numbers.Number): - current = torch.tensor(current, device=pl_module.device, dtype=torch.float) - - if trainer.use_tpu and _TPU_AVAILABLE: - current = current.cpu() - if self.monitor_op(current - self.min_delta, self.best_score): self.best_score = current self.wait_count = 0 diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index b2dff332e99ce..44dae3821ee96 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -33,7 +33,6 @@ from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.plugins.rpc_plugin import RPCPlugin from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem @@ -577,12 +576,6 @@ def _save_top_k_checkpoints(self, metrics, trainer, pl_module, filepath): epoch = metrics.get("epoch") step = metrics.get("step") - if current is not None: - if isinstance(current, Metric): - current = current.compute() - elif isinstance(current, numbers.Number): - current = torch.tensor(current, device=pl_module.device, dtype=torch.float) - if self.check_monitor_top_k(current): self._update_best_and_save(filepath, current, epoch, step, trainer, pl_module) elif self.verbose: @@ -611,7 +604,7 @@ def _update_best_and_save( del_list.append(delpath) # do not save nan, replace with +/- inf - if torch.isnan(current): + if isinstance(current, torch.Tensor) and torch.isnan(current): current = torch.tensor(float('inf' if self.mode == "min" else '-inf')) # save the current score diff --git a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py index 28025859814cc..bb664e45ba0db 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/epoch_result_store.py @@ -375,7 +375,7 @@ def update_logger_connector(self) -> None: if is_train: # Only log and add to callback epoch step during evaluation, test. - logger_connector.logged_metrics.update(batch_log_metrics) + logger_connector._logged_metrics.update(batch_log_metrics) callback_metrics.update(batch_pbar_metrics) callback_metrics.update(batch_log_metrics) else: @@ -385,8 +385,8 @@ def update_logger_connector(self) -> None: # get logged_metrics epoch_log_metrics = self.get_epoch_log_metrics() - logger_connector.logged_metrics.update(epoch_log_metrics) - logger_connector.logged_metrics.update(epoch=self.trainer.current_epoch) + logger_connector._logged_metrics.update(epoch_log_metrics) + logger_connector._logged_metrics.update({"epoch": self.trainer.current_epoch}) # get forked_metrics forked_metrics = self.get_forked_metrics() @@ -396,11 +396,11 @@ def update_logger_connector(self) -> None: callback_metrics.update(forked_metrics) if not is_train: - logger_connector.evaluation_callback_metrics.update(callback_metrics) + logger_connector._evaluation_callback_metrics.update(callback_metrics) # update callback_metrics - logger_connector.callback_metrics.update(callback_metrics) - logger_connector.callback_metrics.pop("epoch", None) + logger_connector._callback_metrics.update(callback_metrics) + logger_connector._callback_metrics.pop("epoch", None) batch_pbar_metrics.pop("debug_epoch", None) return batch_pbar_metrics, batch_log_metrics diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index a3f86f62874ca..1cdf1ca64edaa 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -14,7 +14,7 @@ import os from copy import deepcopy from pprint import pprint -from typing import Iterable, Union +from typing import Iterable, Union, Any import torch @@ -25,25 +25,71 @@ from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore, LoggerStages from pytorch_lightning.utilities import flatten_dict from pytorch_lightning.utilities.exceptions import MisconfigurationException +from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder from pytorch_lightning.utilities.model_utils import is_overridden class LoggerConnector: def __init__(self, trainer): self.trainer = trainer - self.callback_metrics = {} - self.evaluation_callback_metrics = {} - self.logged_metrics = {} - self.progress_bar_metrics = {} + self._callback_metrics = MetricsHolder() + self._evaluation_callback_metrics = MetricsHolder() + self._logged_metrics = MetricsHolder() + self._progress_bar_metrics = MetricsHolder() self.eval_loop_results = [] self._cached_results = {stage: EpochResultStore(trainer, stage) for stage in LoggerStages} self._callback_hook_validator = CallbackHookNameValidator() self._current_stage = None + @property + def callback_metrics(self): + return self.get_metrics("callback_metrics") + + @callback_metrics.setter + def callback_metrics(self, callback_metrics): + self.set_metrics("callback_metrics", callback_metrics) + + @property + def evaluation_callback_metrics(self): + return self.get_metrics("evaluation_callback_metrics") + + @evaluation_callback_metrics.setter + def evaluation_callback_metrics(self, evaluation_callback_metrics): + self.set_metrics("evaluation_callback_metrics", evaluation_callback_metrics) + + @property + def logged_metrics(self): + return self.get_metrics("logged_metrics") + + @logged_metrics.setter + def logged_metrics(self, logged_metrics): + self.set_metrics("logged_metrics", logged_metrics) + + @property + def progress_bar_metrics(self): + return self.get_metrics("progress_bar_metrics") + + @progress_bar_metrics.setter + def progress_bar_metrics(self, progress_bar_metrics): + self.set_metrics("progress_bar_metrics", progress_bar_metrics) + @property def cached_results(self) -> Union[EpochResultStore, None]: return self._cached_results.get(self._current_stage) # type: ignore + def get_metrics(self, key): + metrics_holder = getattr(self, f"_{key}", None) + model_ref = self.trainer.get_model() + metrics_holder.convert( + self.trainer.use_tpu, + model_ref.device if model_ref is not None else model_ref + ) + return metrics_holder.metrics + + def set_metrics(self, key: str, val: Any): + metrics_holder = getattr(self, f"_{key}", None) + metrics_holder.reset(val) + def set_stage(self, stage_or_testing: Union[str, bool], reset: bool = False) -> None: self._current_stage = LoggerStages.determine_stage(stage_or_testing) if reset: @@ -153,10 +199,10 @@ def cache_training_step_metrics(self, opt_closure_result): if len(pbar_metrics_tmp) > 0: self.add_progress_bar_metrics(pbar_metrics_tmp) - self.callback_metrics.update(callback_metrics_tmp) + self._callback_metrics.update(callback_metrics_tmp) # save legacy log metrics - self.logged_metrics.update(logged_metrics_tmp) + self._logged_metrics.update(logged_metrics_tmp) self.cached_results.legacy_batch_log_metrics.update(logged_metrics_tmp) def log_metrics(self, metrics, grad_norm_dic, step=None, log_train_step_metrics=False): @@ -209,7 +255,7 @@ def add_progress_bar_metrics(self, metrics): if isinstance(v, torch.Tensor): v = v.item() - self.progress_bar_metrics[k] = v + self._progress_bar_metrics.metrics[k] = v self.trainer.dev_debugger.track_pbar_metrics_history(metrics) @@ -275,11 +321,11 @@ def _track_callback_metrics(self, eval_results, using_eval_result): if using_eval_result: if isinstance(eval_results, list): for eval_result in eval_results: - self.trainer.logger_connector.callback_metrics.update(eval_result.callback_metrics) - self.trainer.logger_connector.evaluation_callback_metrics.update(eval_result.callback_metrics) + self.trainer.logger_connector._callback_metrics.update(eval_result.callback_metrics) + self.trainer.logger_connector._evaluation_callback_metrics.update(eval_result.callback_metrics) else: - self.trainer.logger_connector.callback_metrics.update(eval_results.callback_metrics) - self.trainer.logger_connector.evaluation_callback_metrics.update(eval_results.callback_metrics) + self.trainer.logger_connector._callback_metrics.update(eval_results.callback_metrics) + self.trainer.logger_connector._evaluation_callback_metrics.update(eval_results.callback_metrics) else: flat = {} if isinstance(eval_results, list): @@ -294,8 +340,8 @@ def _track_callback_metrics(self, eval_results, using_eval_result): if 'val_loss' in flat: flat['checkpoint_on'] = flat['val_loss'] flat['early_stop_on'] = flat['val_loss'] - self.trainer.logger_connector.callback_metrics.update(flat) - self.trainer.logger_connector.evaluation_callback_metrics.update(flat) + self.trainer.logger_connector._callback_metrics.update(flat) + self.trainer.logger_connector._evaluation_callback_metrics.update(flat) else: # with a scalar return, auto set it to "val_loss" for callbacks if isinstance(eval_results, torch.Tensor): @@ -307,8 +353,8 @@ def _track_callback_metrics(self, eval_results, using_eval_result): if 'val_loss' in flat: flat['checkpoint_on'] = flat['val_loss'] flat['early_stop_on'] = flat['val_loss'] - self.trainer.logger_connector.callback_metrics.update(flat) - self.trainer.logger_connector.evaluation_callback_metrics.update(flat) + self.trainer.logger_connector._callback_metrics.update(flat) + self.trainer.logger_connector._evaluation_callback_metrics.update(flat) def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metrics, log_metrics, callback_metrics): # eval loop returns all metrics @@ -324,8 +370,8 @@ def __process_eval_epoch_end_results_and_log_legacy_update(self, prog_bar_metric # track metrics for callbacks (all prog bar, logged and callback metrics) callback_metrics.update(log_metrics) callback_metrics.update(prog_bar_metrics) - self.trainer.logger_connector.callback_metrics.update(callback_metrics) - self.trainer.logger_connector.evaluation_callback_metrics.update(callback_metrics) + self.trainer.logger_connector._callback_metrics.update(callback_metrics) + self.trainer.logger_connector._evaluation_callback_metrics.update(callback_metrics) if len(dataloader_result_metrics) > 0: self.eval_loop_results.append(dataloader_result_metrics) @@ -435,15 +481,15 @@ def log_train_epoch_end_metrics( # add the metrics to the loggers and callbacks if epoch_log_metrics and len(epoch_log_metrics) > 0: self.log_metrics(epoch_log_metrics, {}) - self.callback_metrics.update(epoch_log_metrics) + self._callback_metrics.update(epoch_log_metrics) # add metrics to callbacks - self.callback_metrics.update(epoch_callback_metrics) + self._callback_metrics.update(epoch_callback_metrics) # add metrics to progress_bar and callbacks if len(epoch_progress_bar_metrics) > 0: self.add_progress_bar_metrics(epoch_progress_bar_metrics) - self.callback_metrics.update(epoch_progress_bar_metrics) + self._callback_metrics.update(epoch_progress_bar_metrics) # reset epoch loop result for next epoch self.cached_results.reset() @@ -599,4 +645,4 @@ def log_train_step_metrics(self, batch_output): grad_norm_dic = {} if len(batch_log_metrics) > 0 or len(grad_norm_dic) > 0: self.log_metrics(batch_log_metrics, grad_norm_dic, log_train_step_metrics=True) - self.callback_metrics.update(batch_log_metrics) + self._callback_metrics.update(batch_log_metrics) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py new file mode 100644 index 0000000000000..89bd1a5177d73 --- /dev/null +++ b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py @@ -0,0 +1,75 @@ +# Copyright The PyTorch Lightning team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numbers +from typing import Any, Dict + +import torch + +from pytorch_lightning.metrics.metric import Metric +from pytorch_lightning.utilities import _TPU_AVAILABLE + + +class MetricsHolder: + + """ + This class hold metris and implement convertion function which are called when user + asked for them. + """ + + def __init__(self, to_float: bool = False): + self.metrics = {} + self._convert = self._convert_to_tensor + if to_float: + self._convert = self._convert_to_float + + def update(self, metrics): + self.metrics.update(metrics) + + def pop(self, key, default): + return self.metrics.pop(key, default) + + def reset(self, metrics): + self.metrics = metrics + + def convert(self, use_tpu: bool, device: torch.device): + for key, value in self.metrics.items(): + self.metrics[key] = self._convert(value, use_tpu, device) + + def _convert_to_float(self, current, use_tpu: bool, device: torch.device): + if isinstance(current, Metric): + current = current.compute() + + if isinstance(current, torch.Tensor): + current = float(current.item()) + + elif isinstance(current, int): + current = float(current) + + return current + + def _convert_to_tensor(self, current: Any, use_tpu: bool, device: torch.device): + if current is not None: + if isinstance(current, Metric): + current = current.compute() + + elif isinstance(current, numbers.Number): + if device is None: + current = torch.tensor(current, dtype=torch.float) + else: + current = torch.tensor(current, device=device, dtype=torch.float) + + if use_tpu and _TPU_AVAILABLE: + current = current.cpu() + + return current \ No newline at end of file From d6bae3435939fe957fd9b46265943f204c90781d Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 21 Dec 2020 14:31:45 +0000 Subject: [PATCH 04/14] add test back --- .../test_train_loop_logging_1_0.py | 77 +++++++++++++++++++ 1 file changed, 77 insertions(+) diff --git a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py index 8e0cfb933bc52..9588c480e1091 100644 --- a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py +++ b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py @@ -784,3 +784,80 @@ def validation_step(self, batch, batch_idx): assert trainer.callback_metrics["val_acc"] == 8 / 32. assert "train_acc" in trainer.callback_metrics + + +def test_progress_bar_dict_contains_values_on_train_epoch_end(tmpdir): + class TestModel(BoringModel): + def training_step(self, *args): + self.log("foo", torch.tensor(self.current_epoch), on_step=False, on_epoch=True, prog_bar=True) + return super().training_step(*args) + + def on_epoch_end(self): + self.epoch_end_called = True + self.log('foo_2', torch.tensor(self.current_epoch), prog_bar=True, + on_epoch=True, sync_dist=True, sync_dist_op='sum') + + def on_train_epoch_end(self, *_): + self.on_train_epoch_end_called = True + assert self.trainer.progress_bar_dict["foo"] == self.current_epoch + assert self.trainer.progress_bar_dict["foo_2"] == self.current_epoch + + trainer = Trainer( + default_root_dir=tmpdir, + max_epochs=2, + limit_train_batches=1, + limit_val_batches=0, + checkpoint_callback=False, + logger=False, + weights_summary=None, + progress_bar_refresh_rate=0, + ) + model = TestModel() + trainer.fit(model) + assert model.epoch_end_called + assert model.on_train_epoch_end_called + + +def test_logging_in_callbacks_with_log_function(tmpdir): + """ + Tests ensure self.log can be used directly in callbacks. + """ + class LoggingCallback(callbacks.Callback): + def on_train_start(self, trainer, pl_module): + self.log("on_train_start", 1) + + def on_train_epoch_start(self, trainer, pl_module): + self.log("on_train_epoch_start", 2) + + def on_batch_end(self, trainer, pl_module): + self.log("on_batch_end", 3) + + def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx): + self.log("on_train_batch_end", 4) + + def on_epoch_end(self, trainer, pl_module): + self.log("on_epoch_end", 5) + + def on_train_epoch_end(self, trainer, pl_module, outputs): + self.log("on_train_epoch_end", 6) + self.callback_metrics = trainer.logger_connector.callback_metrics + + model = BoringModel() + trainer = Trainer( + default_root_dir=tmpdir, + limit_train_batches=1, + limit_val_batches=1, + max_epochs=1, + weights_summary=None, + callbacks=[LoggingCallback()] + ) + trainer.fit(model) + + expected = { + 'on_train_start': 1, + 'on_train_epoch_start': 2, + 'on_batch_end': 3, + 'on_train_batch_end': 4, + 'on_epoch_end': 5, + 'on_train_epoch_end': 6} + assert trainer.callback_metrics == expected From 303e85d9989daf8e1c2e44f702e9d310b9822f93 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 21 Dec 2020 15:34:05 +0100 Subject: [PATCH 05/14] correct flake8 --- pytorch_lightning/callbacks/early_stopping.py | 2 +- pytorch_lightning/callbacks/model_checkpoint.py | 2 -- .../trainer/connectors/logger_connector/metrics_holder.py | 6 +++--- tests/trainer/logging_tests/test_train_loop_logging_1_0.py | 3 ++- 4 files changed, 6 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index d84d89c306a2e..8454ae20d0e16 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -24,7 +24,7 @@ import torch from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn, _TPU_AVAILABLE +from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn class EarlyStopping(Callback): diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index 44dae3821ee96..322dc8fd6cae2 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -20,7 +20,6 @@ """ -import numbers import os import re from copy import deepcopy @@ -33,7 +32,6 @@ from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.plugins.rpc_plugin import RPCPlugin from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py index 89bd1a5177d73..5d13fb7bdebdc 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import numbers -from typing import Any, Dict +from typing import Any import torch @@ -62,7 +62,7 @@ def _convert_to_tensor(self, current: Any, use_tpu: bool, device: torch.device): if current is not None: if isinstance(current, Metric): current = current.compute() - + elif isinstance(current, numbers.Number): if device is None: current = torch.tensor(current, dtype=torch.float) @@ -72,4 +72,4 @@ def _convert_to_tensor(self, current: Any, use_tpu: bool, device: torch.device): if use_tpu and _TPU_AVAILABLE: current = current.cpu() - return current \ No newline at end of file + return current diff --git a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py index 9588c480e1091..be7d6ce14bca2 100644 --- a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py +++ b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py @@ -26,7 +26,7 @@ from torch.utils.data import Dataset import pytorch_lightning as pl -from pytorch_lightning import callbacks, Trainer +from pytorch_lightning import Trainer, callbacks from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.core.lightning import LightningModule from tests.base.boring_model import BoringModel, RandomDictDataset, RandomDictStringDataset @@ -741,6 +741,7 @@ def validation_step(self, batch, batch_idx): assert trainer.logged_metrics['foo'] == fake_result assert trainer.logged_metrics['bar'] == fake_result + @pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") def test_metric_are_properly_reduced(tmpdir): class TestingModel(BoringModel): From 3af20feeaf60a933c8e2c00f288c1f994cebc42e Mon Sep 17 00:00:00 2001 From: tchaton Date: Wed, 23 Dec 2020 13:59:52 +0100 Subject: [PATCH 06/14] resolve flake8 --- tests/deprecated_api/test_remove_1-3.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/deprecated_api/test_remove_1-3.py b/tests/deprecated_api/test_remove_1-3.py index c855086c9526d..3deb4e219fcee 100644 --- a/tests/deprecated_api/test_remove_1-3.py +++ b/tests/deprecated_api/test_remove_1-3.py @@ -21,7 +21,6 @@ from pytorch_lightning import LightningModule, Trainer from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint from pytorch_lightning.profiler.profilers import PassThroughProfiler, SimpleProfiler -from tests.deprecated_api import _soft_unimport_module def test_v1_3_0_deprecated_arguments(tmpdir): From 5d3d7ce2db0da75a064a2740444c4776591e6df2 Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 28 Dec 2020 11:51:44 +0100 Subject: [PATCH 07/14] update on comments --- .../connectors/logger_connector/logger_connector.py | 6 +++--- .../connectors/logger_connector/metrics_holder.py | 10 ++++++---- 2 files changed, 9 insertions(+), 7 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index c38df217518fc..531e3d013ca23 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -14,7 +14,7 @@ import os from copy import deepcopy from pprint import pprint -from typing import Iterable, Union, Any +from typing import Any, Iterable, Union import torch @@ -23,9 +23,9 @@ from pytorch_lightning.loggers import LoggerCollection, TensorBoardLogger from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator from pytorch_lightning.trainer.connectors.logger_connector.epoch_result_store import EpochResultStore, LoggerStages +from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder from pytorch_lightning.utilities import flatten_dict from pytorch_lightning.utilities.exceptions import MisconfigurationException -from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder from pytorch_lightning.utilities.model_helpers import is_overridden @@ -33,7 +33,7 @@ class LoggerConnector: def __init__(self, trainer): self.trainer = trainer self._callback_metrics = MetricsHolder() - self._evaluation_callback_metrics = MetricsHolder() + self._evaluation_callback_metrics = MetricsHolder(to_float=True) self._logged_metrics = MetricsHolder() self._progress_bar_metrics = MetricsHolder() self.eval_loop_results = [] diff --git a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py index 5d13fb7bdebdc..91fb2c2e176db 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py @@ -23,8 +23,10 @@ class MetricsHolder: """ - This class hold metris and implement convertion function which are called when user - asked for them. + This class acts as a dictonary holder. + It holds metris and implement convertion functions. + Those functions will be triggered within LoggerConnector + when the property is being requested from the user. """ def __init__(self, to_float: bool = False): @@ -48,7 +50,7 @@ def convert(self, use_tpu: bool, device: torch.device): def _convert_to_float(self, current, use_tpu: bool, device: torch.device): if isinstance(current, Metric): - current = current.compute() + current = current.compute().detach() if isinstance(current, torch.Tensor): current = float(current.item()) @@ -61,7 +63,7 @@ def _convert_to_float(self, current, use_tpu: bool, device: torch.device): def _convert_to_tensor(self, current: Any, use_tpu: bool, device: torch.device): if current is not None: if isinstance(current, Metric): - current = current.compute() + current = current.compute().detach() elif isinstance(current, numbers.Number): if device is None: From fe95959907020adf02443e8263f986f75976e9d4 Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 28 Dec 2020 12:46:30 +0000 Subject: [PATCH 08/14] update tests --- .../logging_tests/test_train_loop_logging_1_0.py | 15 +++++++-------- 1 file changed, 7 insertions(+), 8 deletions(-) diff --git a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py index be7d6ce14bca2..99db894c39311 100644 --- a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py +++ b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py @@ -747,19 +747,18 @@ def test_metric_are_properly_reduced(tmpdir): class TestingModel(BoringModel): def __init__(self, *args, **kwargs): super().__init__() - self.train_acc = pl.metrics.Accuracy() self.val_acc = pl.metrics.Accuracy() def training_step(self, batch, batch_idx): - self.train_acc(torch.rand(1, 3, device=self.device), torch.randint(0, 2, (1,), device=self.device)) - self.log('train_acc', self.train_acc, on_step=True, on_epoch=True) - return super().training_step(batch, batch_idx) + output = super().training_step(batch, batch_idx) + self.log("train_loss", output["loss"]) + return output def validation_step(self, batch, batch_idx): - preds = torch.tensor(0, device=self.device) - targets = torch.tensor(1, device=self.device) + preds = torch.tensor([[0.9, 0.1]], device=self.device) + targets = torch.tensor([1], device=self.device) if batch_idx < 8: - targets = preds + preds = torch.tensor([[0.1, 0.9]], device=self.device) self.val_acc(preds, targets) self.log('val_acc', self.val_acc, on_step=True, on_epoch=True) return super().validation_step(batch, batch_idx) @@ -784,7 +783,7 @@ def validation_step(self, batch, batch_idx): trainer.fit(model) assert trainer.callback_metrics["val_acc"] == 8 / 32. - assert "train_acc" in trainer.callback_metrics + assert "train_loss" in trainer.callback_metrics def test_progress_bar_dict_contains_values_on_train_epoch_end(tmpdir): From bf2e78edfaaa690ca24481ecec1754c9b5ebffef Mon Sep 17 00:00:00 2001 From: Ubuntu Date: Mon, 28 Dec 2020 16:33:49 +0000 Subject: [PATCH 09/14] add a test --- .../logger_connector/logger_connector.py | 22 ++++++++-------- .../logger_connector/metrics_holder.py | 9 ++++--- .../trainer/logging/test_logger_connector.py | 25 +++++++++++++++++++ 3 files changed, 42 insertions(+), 14 deletions(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py index 531e3d013ca23..d417d21a22da4 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py @@ -14,7 +14,7 @@ import os from copy import deepcopy from pprint import pprint -from typing import Any, Iterable, Union +from typing import Any, Iterable, Union, Dict import torch @@ -42,42 +42,42 @@ def __init__(self, trainer): self._current_stage = None @property - def callback_metrics(self): + def callback_metrics(self) -> Dict: return self.get_metrics("callback_metrics") @callback_metrics.setter - def callback_metrics(self, callback_metrics): + def callback_metrics(self, callback_metrics: Dict) -> None: self.set_metrics("callback_metrics", callback_metrics) @property - def evaluation_callback_metrics(self): + def evaluation_callback_metrics(self) -> Dict: return self.get_metrics("evaluation_callback_metrics") @evaluation_callback_metrics.setter - def evaluation_callback_metrics(self, evaluation_callback_metrics): + def evaluation_callback_metrics(self, evaluation_callback_metrics: Dict) -> None: self.set_metrics("evaluation_callback_metrics", evaluation_callback_metrics) @property - def logged_metrics(self): + def logged_metrics(self) -> Dict: return self.get_metrics("logged_metrics") @logged_metrics.setter - def logged_metrics(self, logged_metrics): + def logged_metrics(self, logged_metrics: Dict) -> None: self.set_metrics("logged_metrics", logged_metrics) @property - def progress_bar_metrics(self): + def progress_bar_metrics(self) -> Dict: return self.get_metrics("progress_bar_metrics") @progress_bar_metrics.setter - def progress_bar_metrics(self, progress_bar_metrics): + def progress_bar_metrics(self, progress_bar_metrics: Dict) -> None: self.set_metrics("progress_bar_metrics", progress_bar_metrics) @property def cached_results(self) -> Union[EpochResultStore, None]: return self._cached_results.get(self._current_stage) # type: ignore - def get_metrics(self, key): + def get_metrics(self, key: str) -> Dict: metrics_holder = getattr(self, f"_{key}", None) model_ref = self.trainer.get_model() metrics_holder.convert( @@ -86,7 +86,7 @@ def get_metrics(self, key): ) return metrics_holder.metrics - def set_metrics(self, key: str, val: Any): + def set_metrics(self, key: str, val: Any) -> None: metrics_holder = getattr(self, f"_{key}", None) metrics_holder.reset(val) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py index 91fb2c2e176db..81388fc37bedd 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py @@ -31,9 +31,7 @@ class MetricsHolder: def __init__(self, to_float: bool = False): self.metrics = {} - self._convert = self._convert_to_tensor - if to_float: - self._convert = self._convert_to_float + self._to_float = to_float def update(self, metrics): self.metrics.update(metrics) @@ -48,6 +46,11 @@ def convert(self, use_tpu: bool, device: torch.device): for key, value in self.metrics.items(): self.metrics[key] = self._convert(value, use_tpu, device) + def _convert(self, current: Any, use_tpu: bool, device: torch.device): + if self._to_float: + return self._convert_to_float(current, use_tpu, device) + return self._convert_to_tensor(current, use_tpu, device) + def _convert_to_float(self, current, use_tpu: bool, device: torch.device): if isinstance(current, Metric): current = current.compute().detach() diff --git a/tests/trainer/logging/test_logger_connector.py b/tests/trainer/logging/test_logger_connector.py index 56e5765c7f4b8..02082ae4b8ed4 100644 --- a/tests/trainer/logging/test_logger_connector.py +++ b/tests/trainer/logging/test_logger_connector.py @@ -26,6 +26,8 @@ from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base.boring_model import BoringModel, RandomDataset +from pytorch_lightning.metrics import Accuracy +from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder def decorator_with_arguments(fx_name='', hook_fx_name=None): @@ -425,3 +427,26 @@ def test_dataloader(self): ) trainer.fit(model) trainer.test(model, ckpt_path=None) + + +@pytest.mark.parametrize('to_float', [False, True]) +def test_metrics_holder(to_float, tmpdir): + + device = "cuda" if torch.cuda.is_available() else "cpu" + preds = torch.tensor([[0.9, 0.1]], device=device) + def is_float(value): + return isinstance(value, float) + excepted_function = is_float if to_float else torch.is_tensor + targets = torch.tensor([1], device=device) + acc = Accuracy().to(device) + metric_holder = MetricsHolder(to_float=to_float) + metric_holder.update({ + "x": 1, + "y": torch.tensor(2), + "z": acc(preds, targets), + }) + metric_holder.convert(False, device) + metrics = metric_holder.metrics + assert excepted_function(metrics["x"]) + assert excepted_function(metrics["y"]) + assert excepted_function(metrics["z"]) From 89f901e447e5035768d8b8b38da37c031f37de9e Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 28 Dec 2020 17:38:53 +0100 Subject: [PATCH 10/14] add test --- tests/trainer/logging/test_logger_connector.py | 18 ++++++++++-------- 1 file changed, 10 insertions(+), 8 deletions(-) diff --git a/tests/trainer/logging/test_logger_connector.py b/tests/trainer/logging/test_logger_connector.py index 02082ae4b8ed4..e715e8c6bd82a 100644 --- a/tests/trainer/logging/test_logger_connector.py +++ b/tests/trainer/logging/test_logger_connector.py @@ -15,6 +15,7 @@ Tests to ensure that the training loop works with a dict (1.0) """ from copy import deepcopy +from typing import Any import pytest import torch @@ -22,17 +23,17 @@ from pytorch_lightning.callbacks.base import Callback from pytorch_lightning.core.step_result import Result +from pytorch_lightning.metrics import Accuracy from pytorch_lightning.trainer import Trainer from pytorch_lightning.trainer.connectors.logger_connector.callback_hook_validator import CallbackHookNameValidator +from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder from pytorch_lightning.utilities.exceptions import MisconfigurationException from tests.base.boring_model import BoringModel, RandomDataset -from pytorch_lightning.metrics import Accuracy -from pytorch_lightning.trainer.connectors.logger_connector.metrics_holder import MetricsHolder -def decorator_with_arguments(fx_name='', hook_fx_name=None): - def decorator(func): - def wrapper(self, *args, **kwargs): +def decorator_with_arguments(fx_name: str = '', hook_fx_name: str = None) -> callable: + def decorator(func: callable) -> callable: + def wrapper(self, *args, **kwargs) -> Any: # Set information self._current_fx_name = fx_name self._current_hook_fx_name = hook_fx_name @@ -45,7 +46,6 @@ def wrapper(self, *args, **kwargs): return result return wrapper - return decorator @@ -434,14 +434,16 @@ def test_metrics_holder(to_float, tmpdir): device = "cuda" if torch.cuda.is_available() else "cpu" preds = torch.tensor([[0.9, 0.1]], device=device) - def is_float(value): + + def is_float(value: Any) -> bool: return isinstance(value, float) + excepted_function = is_float if to_float else torch.is_tensor targets = torch.tensor([1], device=device) acc = Accuracy().to(device) metric_holder = MetricsHolder(to_float=to_float) metric_holder.update({ - "x": 1, + "x": 1, "y": torch.tensor(2), "z": acc(preds, targets), }) From f2ffa5244c8bde31170983829cf99e8315c3ee4a Mon Sep 17 00:00:00 2001 From: tchaton Date: Mon, 28 Dec 2020 18:31:12 +0100 Subject: [PATCH 11/14] update to Callable --- tests/trainer/logging/test_logger_connector.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/trainer/logging/test_logger_connector.py b/tests/trainer/logging/test_logger_connector.py index e715e8c6bd82a..f911c793b0707 100644 --- a/tests/trainer/logging/test_logger_connector.py +++ b/tests/trainer/logging/test_logger_connector.py @@ -15,7 +15,7 @@ Tests to ensure that the training loop works with a dict (1.0) """ from copy import deepcopy -from typing import Any +from typing import Any, Callable import pytest import torch @@ -31,8 +31,8 @@ from tests.base.boring_model import BoringModel, RandomDataset -def decorator_with_arguments(fx_name: str = '', hook_fx_name: str = None) -> callable: - def decorator(func: callable) -> callable: +def decorator_with_arguments(fx_name: str = '', hook_fx_name: str = None) -> Callable: + def decorator(func: Callable) -> Callable: def wrapper(self, *args, **kwargs) -> Any: # Set information self._current_fx_name = fx_name From 97c9e5191cfcd83492e76e51a35b013ef586cf0d Mon Sep 17 00:00:00 2001 From: chaton Date: Thu, 7 Jan 2021 11:07:13 +0100 Subject: [PATCH 12/14] Update pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py Co-authored-by: Roger Shieh --- .../trainer/connectors/logger_connector/metrics_holder.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py index 81388fc37bedd..d2e2c9b7870cf 100644 --- a/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py +++ b/pytorch_lightning/trainer/connectors/logger_connector/metrics_holder.py @@ -24,7 +24,7 @@ class MetricsHolder: """ This class acts as a dictonary holder. - It holds metris and implement convertion functions. + It holds metrics and implements conversion functions. Those functions will be triggered within LoggerConnector when the property is being requested from the user. """ From 46de65bb9c051cc91f97dba0d55d217c034ee06b Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 7 Jan 2021 13:16:17 +0100 Subject: [PATCH 13/14] add changelog --- CHANGELOG.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/CHANGELOG.md b/CHANGELOG.md index 2a46f49211268..8d30f421bdb46 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -17,6 +17,8 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Changed +- Changed `automatic casting` for LoggerConnector `metrics` ([#5218](https://github.com/PyTorchLightning/pytorch-lightning/pull/5218)) + ### Deprecated From 2650ae28e39fe4902e35a48fd1a2cee1200a721d Mon Sep 17 00:00:00 2001 From: tchaton Date: Thu, 7 Jan 2021 13:21:52 +0100 Subject: [PATCH 14/14] resolve flake8 --- pytorch_lightning/callbacks/early_stopping.py | 4 +- .../callbacks/model_checkpoint.py | 2 - .../test_train_loop_logging_1_0.py | 44 ------------------- 3 files changed, 1 insertion(+), 49 deletions(-) diff --git a/pytorch_lightning/callbacks/early_stopping.py b/pytorch_lightning/callbacks/early_stopping.py index c5686c44fdd15..fca39036c9404 100644 --- a/pytorch_lightning/callbacks/early_stopping.py +++ b/pytorch_lightning/callbacks/early_stopping.py @@ -19,14 +19,12 @@ Monitor a metric and stop training when it stops improving. """ -import numbers import numpy as np import torch from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.metrics.metric import Metric -from pytorch_lightning.utilities import _TPU_AVAILABLE, rank_zero_info, rank_zero_warn +from pytorch_lightning.utilities import rank_zero_info, rank_zero_warn class EarlyStopping(Callback): diff --git a/pytorch_lightning/callbacks/model_checkpoint.py b/pytorch_lightning/callbacks/model_checkpoint.py index e49767eba5a82..3fc2b54d98162 100644 --- a/pytorch_lightning/callbacks/model_checkpoint.py +++ b/pytorch_lightning/callbacks/model_checkpoint.py @@ -20,7 +20,6 @@ """ -import numbers import os import re from copy import deepcopy @@ -33,7 +32,6 @@ from pytorch_lightning import _logger as log from pytorch_lightning.callbacks.base import Callback -from pytorch_lightning.metrics.metric import Metric from pytorch_lightning.utilities import rank_zero_info, rank_zero_only, rank_zero_warn from pytorch_lightning.utilities.cloud_io import get_filesystem from pytorch_lightning.utilities.exceptions import MisconfigurationException diff --git a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py index 044ca54dd045f..f418db2bd72a5 100644 --- a/tests/trainer/logging_tests/test_train_loop_logging_1_0.py +++ b/tests/trainer/logging_tests/test_train_loop_logging_1_0.py @@ -781,50 +781,6 @@ def validation_step(self, batch, batch_idx): assert trainer.logged_metrics['bar'] == fake_result -@pytest.mark.skipif(not torch.cuda.is_available(), reason="requires GPU machine") -def test_metric_are_properly_reduced(tmpdir): - class TestingModel(BoringModel): - def __init__(self, *args, **kwargs): - super().__init__() - self.val_acc = pl.metrics.Accuracy() - - def training_step(self, batch, batch_idx): - output = super().training_step(batch, batch_idx) - self.log("train_loss", output["loss"]) - return output - - def validation_step(self, batch, batch_idx): - preds = torch.tensor([[0.9, 0.1]], device=self.device) - targets = torch.tensor([1], device=self.device) - if batch_idx < 8: - preds = torch.tensor([[0.1, 0.9]], device=self.device) - self.val_acc(preds, targets) - self.log('val_acc', self.val_acc, on_step=True, on_epoch=True) - return super().validation_step(batch, batch_idx) - - early_stop = EarlyStopping(monitor='val_acc', mode='max') - - checkpoint = ModelCheckpoint( - monitor='val_acc', - save_last=True, - save_top_k=2, - mode='max', - ) - - model = TestingModel() - trainer = Trainer( - default_root_dir=tmpdir, - gpus=1, - max_epochs=2, - limit_train_batches=5, - limit_val_batches=32, - callbacks=[early_stop, checkpoint]) - trainer.fit(model) - - assert trainer.callback_metrics["val_acc"] == 8 / 32. - assert "train_loss" in trainer.callback_metrics - - def test_progress_bar_dict_contains_values_on_train_epoch_end(tmpdir): class TestModel(BoringModel): def training_step(self, *args):