From b0d96ae415bdae1394a0154d271dfcd2d3063739 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Fri, 1 Jul 2022 07:53:17 -0400 Subject: [PATCH 01/24] remove corresponding line from pyproject.toml --- pyproject.toml | 1 - 1 file changed, 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 4b9f45068e089..dcd319835deb4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -55,7 +55,6 @@ module = [ "pytorch_lightning.demos.mnist_datamodule", "pytorch_lightning.distributed.dist", "pytorch_lightning.loggers.base", - "pytorch_lightning.loggers.logger", "pytorch_lightning.loggers.comet", "pytorch_lightning.loggers.csv_logs", "pytorch_lightning.loggers.mlflow", From 5bf3cce044be5db18816db1151867e26aaca2fe8 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Mon, 4 Jul 2022 22:23:03 -0400 Subject: [PATCH 02/24] add notebooks --- _notebooks | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_notebooks b/_notebooks index 8a36a41548f34..f48fad5489272 160000 --- a/_notebooks +++ b/_notebooks @@ -1 +1 @@ -Subproject commit 8a36a41548f34c44ac455d515a72994487e85813 +Subproject commit f48fad5489272af915f25e98320badfca600c588 From a6f2cd54e947b345065ad5b74261c29e0c6813ff Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Mon, 4 Jul 2022 22:28:25 -0400 Subject: [PATCH 03/24] add notebooks --- _notebooks | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_notebooks b/_notebooks index f48fad5489272..85bdeeeb4ef7e 160000 --- a/_notebooks +++ b/_notebooks @@ -1 +1 @@ -Subproject commit f48fad5489272af915f25e98320badfca600c588 +Subproject commit 85bdeeeb4ef7ea8b5c7cd1a4309321becc57cddb From b701321dfd464e1a3cc81a688c85ff74e32028fb Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Mon, 4 Jul 2022 22:34:41 -0400 Subject: [PATCH 04/24] update rank_zero_experiment --- src/pytorch_lightning/loggers/logger.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/loggers/logger.py b/src/pytorch_lightning/loggers/logger.py index d532aae413650..60eaef7de1913 100644 --- a/src/pytorch_lightning/loggers/logger.py +++ b/src/pytorch_lightning/loggers/logger.py @@ -33,9 +33,9 @@ def rank_zero_experiment(fn: Callable) -> Callable: """Returns the real experiment on rank 0 and otherwise the DummyExperiment.""" @wraps(fn) - def experiment(self): + def experiment(self: type) -> Union[Any, DummyExperiment]: @rank_zero_only - def get_experiment(): + def get_experiment() -> Callable: return fn(self) return get_experiment() or DummyExperiment() From 576c6c79acad24b2a4247434e28f56c7bc370983 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Mon, 4 Jul 2022 22:36:24 -0400 Subject: [PATCH 05/24] update update_agg_funcs --- src/pytorch_lightning/loggers/logger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/loggers/logger.py b/src/pytorch_lightning/loggers/logger.py index 60eaef7de1913..216ac6e23ba8e 100644 --- a/src/pytorch_lightning/loggers/logger.py +++ b/src/pytorch_lightning/loggers/logger.py @@ -98,7 +98,7 @@ def update_agg_funcs( self, agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None, agg_default_func: Callable[[Sequence[float]], float] = np.mean, - ): + ) -> None: """Update aggregation methods. .. deprecated:: v1.6 From 8d682bde8a874977f71897c9e149a4e8387e3bc8 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Mon, 4 Jul 2022 22:40:43 -0400 Subject: [PATCH 06/24] update agg_and_log_metrics and log_metrics --- src/pytorch_lightning/loggers/logger.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/loggers/logger.py b/src/pytorch_lightning/loggers/logger.py index 216ac6e23ba8e..a6ecb57fbb524 100644 --- a/src/pytorch_lightning/loggers/logger.py +++ b/src/pytorch_lightning/loggers/logger.py @@ -119,7 +119,7 @@ def update_agg_funcs( self._agg_default_func = agg_default_func rank_zero_deprecation("`Logger.update_agg_funcs` was deprecated in v1.6 and will be removed in v1.8.") - def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): + def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: """Aggregates and records metrics. This method doesn't log the passed metrics instantaneously, but instead it aggregates them and logs only if metrics are ready to be logged. @@ -134,7 +134,7 @@ def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = N self.log_metrics(metrics=metrics, step=step) @abstractmethod - def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): + def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: """ Records metrics. This method logs metrics as as soon as it received them. If you want to aggregate From 14ce08e5ae75087efeff4aa9410f539a37fcfe5a Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Mon, 4 Jul 2022 22:41:50 -0400 Subject: [PATCH 07/24] update log_hyperparams annotations and return type --- src/pytorch_lightning/loggers/logger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/loggers/logger.py b/src/pytorch_lightning/loggers/logger.py index a6ecb57fbb524..b2bc2e31ab7cf 100644 --- a/src/pytorch_lightning/loggers/logger.py +++ b/src/pytorch_lightning/loggers/logger.py @@ -148,7 +148,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> pass @abstractmethod - def log_hyperparams(self, params: argparse.Namespace, *args, **kwargs): + def log_hyperparams(self, params: argparse.Namespace, *args: Any, **kwargs: Any) -> None: """Record hyperparameters. Args: From e077e117e3765ee89afbf47b051f7edb2885d68e Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Mon, 4 Jul 2022 22:56:35 -0400 Subject: [PATCH 08/24] update log_graph annotation --- src/pytorch_lightning/loggers/logger.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/loggers/logger.py b/src/pytorch_lightning/loggers/logger.py index b2bc2e31ab7cf..4d3b05ff71c1b 100644 --- a/src/pytorch_lightning/loggers/logger.py +++ b/src/pytorch_lightning/loggers/logger.py @@ -23,6 +23,7 @@ from weakref import ReferenceType import numpy as np +from torch import Tensor import pytorch_lightning as pl from pytorch_lightning.callbacks import Checkpoint @@ -157,7 +158,7 @@ def log_hyperparams(self, params: argparse.Namespace, *args: Any, **kwargs: Any) kwargs: Optional keyword arguments, depends on the specific logger being used """ - def log_graph(self, model: "pl.LightningModule", input_array=None) -> None: + def log_graph(self, model: "pl.LightningModule", input_array: Tensor = None) -> None: """Record model graph. Args: From 725df82b441e311e4ff46edbaf59c8d19c1c1aaa Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Mon, 4 Jul 2022 23:02:42 -0400 Subject: [PATCH 09/24] update multiple None return types --- src/pytorch_lightning/loggers/logger.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/pytorch_lightning/loggers/logger.py b/src/pytorch_lightning/loggers/logger.py index 4d3b05ff71c1b..62665a1a1ab8d 100644 --- a/src/pytorch_lightning/loggers/logger.py +++ b/src/pytorch_lightning/loggers/logger.py @@ -185,7 +185,7 @@ def save_dir(self) -> Optional[str]: return None @property - def group_separator(self): + def group_separator(self) -> str: """Return the default separator used by the logger to group the data into subfolders.""" return "/" @@ -230,7 +230,7 @@ def update_agg_funcs( self, agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None, agg_default_func: Callable[[Sequence[float]], float] = np.mean, - ): + ) -> None: for logger in self._logger_iterable: logger.update_agg_funcs(agg_key_funcs, agg_default_func) @@ -239,7 +239,7 @@ def experiment(self) -> List[Any]: """Returns a list of experiment objects for all the loggers in the logger collection.""" return [logger.experiment for logger in self._logger_iterable] - def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None): + def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None: for logger in self._logger_iterable: logger.agg_and_log_metrics(metrics=metrics, step=step) @@ -251,7 +251,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: for logger in self._logger_iterable: logger.log_hyperparams(params) - def log_graph(self, model: "pl.LightningModule", input_array=None) -> None: + def log_graph(self, model: "pl.LightningModule", input_array: Tensor = None) -> None: for logger in self._logger_iterable: logger.log_graph(model, input_array) @@ -259,7 +259,7 @@ def log_text(self, *args, **kwargs) -> None: for logger in self._logger_iterable: logger.log_text(*args, **kwargs) - def log_image(self, *args, **kwargs) -> None: + def log_image(self, *args: Any, **kwargs: Any) -> None: for logger in self._logger_iterable: logger.log_image(*args, **kwargs) From 2b9dcbea2f813b36aa5134c62b42886ca8fa97ae Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Mon, 4 Jul 2022 23:08:11 -0400 Subject: [PATCH 10/24] update DummyExperiment and log_text --- src/pytorch_lightning/loggers/logger.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/pytorch_lightning/loggers/logger.py b/src/pytorch_lightning/loggers/logger.py index 62665a1a1ab8d..f9b36534f13eb 100644 --- a/src/pytorch_lightning/loggers/logger.py +++ b/src/pytorch_lightning/loggers/logger.py @@ -255,7 +255,7 @@ def log_graph(self, model: "pl.LightningModule", input_array: Tensor = None) -> for logger in self._logger_iterable: logger.log_graph(model, input_array) - def log_text(self, *args, **kwargs) -> None: + def log_text(self, *args: Any, **kwargs: Any) -> None: for logger in self._logger_iterable: logger.log_text(*args, **kwargs) @@ -294,17 +294,17 @@ def version(self) -> str: class DummyExperiment: """Dummy experiment.""" - def nop(self, *args, **kw): + def nop(self, *args: Any, **kw: Any) -> None: pass - def __getattr__(self, _): + def __getattr__(self, _: Any) -> Callable: return self.nop - def __getitem__(self, idx) -> "DummyExperiment": + def __getitem__(self, idx: int) -> "DummyExperiment": # enables self.logger.experiment[0].add_image(...) return self - def __setitem__(self, *args, **kwargs) -> None: + def __setitem__(self, *args: Any, **kwargs: Any) -> None: pass From 6ec8365bd8d9e4ad2dc6727671b5d092a6c7daef Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Tue, 5 Jul 2022 07:24:15 -0400 Subject: [PATCH 11/24] update DummyLogger --- src/pytorch_lightning/loggers/logger.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/pytorch_lightning/loggers/logger.py b/src/pytorch_lightning/loggers/logger.py index f9b36534f13eb..6f96b8f981797 100644 --- a/src/pytorch_lightning/loggers/logger.py +++ b/src/pytorch_lightning/loggers/logger.py @@ -323,10 +323,10 @@ def experiment(self) -> DummyExperiment: """Return the experiment object associated with this logger.""" return self._experiment - def log_metrics(self, *args, **kwargs) -> None: + def log_metrics(self, *args: Any, **kwargs: Any) -> None: pass - def log_hyperparams(self, *args, **kwargs) -> None: + def log_hyperparams(self, *args: Any, **kwargs: Any) -> None: pass @property @@ -339,7 +339,7 @@ def version(self) -> str: """Return the experiment version.""" return "" - def __getitem__(self, idx) -> "DummyLogger": + def __getitem__(self, idx: int) -> "DummyLogger": # enables self.logger[0].experiment.add_image(...) return self @@ -350,7 +350,7 @@ def __iter__(self): def __getattr__(self, name: str) -> Callable: """Allows the DummyLogger to be called with arbitrary methods, to avoid AttributeErrors.""" - def method(*args, **kwargs): + def method(*args: Any, **kwargs: Any) -> None: return None return method From 6df935d6867649b505522d7a850f65bf8040f809 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Tue, 5 Jul 2022 08:40:21 -0400 Subject: [PATCH 12/24] fix return types for rank_zero_experiment --- src/pytorch_lightning/loggers/logger.py | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/src/pytorch_lightning/loggers/logger.py b/src/pytorch_lightning/loggers/logger.py index 6f96b8f981797..bfdda91c1f7e6 100644 --- a/src/pytorch_lightning/loggers/logger.py +++ b/src/pytorch_lightning/loggers/logger.py @@ -16,6 +16,7 @@ import argparse import functools import operator +import typing as t from abc import ABC, abstractmethod from argparse import Namespace from functools import wraps @@ -29,12 +30,15 @@ from pytorch_lightning.callbacks import Checkpoint from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only +if t.TYPE_CHECKING: + from pytorch_lightning.loggers.csv_logs import ExperimentWriter + def rank_zero_experiment(fn: Callable) -> Callable: """Returns the real experiment on rank 0 and otherwise the DummyExperiment.""" @wraps(fn) - def experiment(self: type) -> Union[Any, DummyExperiment]: + def experiment(self: Callable) -> Union[ExperimentWriter, DummyExperiment]: @rank_zero_only def get_experiment() -> Callable: return fn(self) From b67e68f2df05a6ffcceb097b83759ae8ce059ded Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Tue, 5 Jul 2022 09:07:31 -0400 Subject: [PATCH 13/24] update DummyLogger --- src/pytorch_lightning/loggers/logger.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pytorch_lightning/loggers/logger.py b/src/pytorch_lightning/loggers/logger.py index bfdda91c1f7e6..8d6cfe45ec097 100644 --- a/src/pytorch_lightning/loggers/logger.py +++ b/src/pytorch_lightning/loggers/logger.py @@ -20,7 +20,7 @@ from abc import ABC, abstractmethod from argparse import Namespace from functools import wraps -from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Union +from typing import Any, Callable, Collection, Dict, Iterable, List, Mapping, Optional, Sequence, Union from weakref import ReferenceType import numpy as np @@ -318,7 +318,7 @@ class DummyLogger(Logger): It is useful if we want to disable user's logger for a feature, but still ensure that user code can run """ - def __init__(self): + def __init__(self) -> None: super().__init__() self._experiment = DummyExperiment() @@ -347,7 +347,7 @@ def __getitem__(self, idx: int) -> "DummyLogger": # enables self.logger[0].experiment.add_image(...) return self - def __iter__(self): + def __iter__(self) -> Iterable[Collection]: # if DummyLogger is substituting a logger collection, pretend it is empty yield from () From 01112c1f0e53847af6576d101616943db3f47da7 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Tue, 5 Jul 2022 09:15:53 -0400 Subject: [PATCH 14/24] update Logger.log_hyperparams --- src/pytorch_lightning/loggers/logger.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/pytorch_lightning/loggers/logger.py b/src/pytorch_lightning/loggers/logger.py index 8d6cfe45ec097..413fe10623cc0 100644 --- a/src/pytorch_lightning/loggers/logger.py +++ b/src/pytorch_lightning/loggers/logger.py @@ -13,7 +13,7 @@ # limitations under the License. """Abstract base class used to build new loggers.""" -import argparse + import functools import operator import typing as t @@ -153,11 +153,11 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> pass @abstractmethod - def log_hyperparams(self, params: argparse.Namespace, *args: Any, **kwargs: Any) -> None: + def log_hyperparams(self, params: Union[Dict[str, Any], Namespace], *args: Any, **kwargs: Any) -> None: """Record hyperparameters. Args: - params: :class:`~argparse.Namespace` containing the hyperparameters + params: :class:`~argparse.Namespace` or `Dict` containing the hyperparameters args: Optional positional arguments, depends on the specific logger being used kwargs: Optional keyword arguments, depends on the specific logger being used """ From b75fb460346c91e98deb974ba8b452c5abac602f Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Tue, 5 Jul 2022 10:16:11 -0400 Subject: [PATCH 15/24] updated DummyLogger.__iter__ return type --- src/pytorch_lightning/loggers/logger.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/loggers/logger.py b/src/pytorch_lightning/loggers/logger.py index 413fe10623cc0..ece1888a0ccee 100644 --- a/src/pytorch_lightning/loggers/logger.py +++ b/src/pytorch_lightning/loggers/logger.py @@ -20,7 +20,7 @@ from abc import ABC, abstractmethod from argparse import Namespace from functools import wraps -from typing import Any, Callable, Collection, Dict, Iterable, List, Mapping, Optional, Sequence, Union +from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Sequence, Union from weakref import ReferenceType import numpy as np @@ -347,7 +347,7 @@ def __getitem__(self, idx: int) -> "DummyLogger": # enables self.logger[0].experiment.add_image(...) return self - def __iter__(self) -> Iterable[Collection]: + def __iter__(self) -> Generator[None, None, None]: # if DummyLogger is substituting a logger collection, pretend it is empty yield from () From 3e21af15c0ee4aa045d372608e48df87a2d39412 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Tue, 5 Jul 2022 10:34:33 -0400 Subject: [PATCH 16/24] update log_hyperparams --- src/pytorch_lightning/loggers/logger.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/loggers/logger.py b/src/pytorch_lightning/loggers/logger.py index ece1888a0ccee..b45fe3a96db8c 100644 --- a/src/pytorch_lightning/loggers/logger.py +++ b/src/pytorch_lightning/loggers/logger.py @@ -251,9 +251,9 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> for logger in self._logger_iterable: logger.log_metrics(metrics=metrics, step=step) - def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None: + def log_hyperparams(self, params: Union[Dict[str, Any], Namespace], *args: Any, **kwargs: Any) -> None: for logger in self._logger_iterable: - logger.log_hyperparams(params) + logger.log_hyperparams(params, *args, **kwargs) def log_graph(self, model: "pl.LightningModule", input_array: Tensor = None) -> None: for logger in self._logger_iterable: From fd235cd7e31fa5adb7b4bd96601eee178e8538a6 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Tue, 5 Jul 2022 17:46:20 -0400 Subject: [PATCH 17/24] update merge_dicts --- src/pytorch_lightning/loggers/logger.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/loggers/logger.py b/src/pytorch_lightning/loggers/logger.py index b45fe3a96db8c..fc7cd035181b2 100644 --- a/src/pytorch_lightning/loggers/logger.py +++ b/src/pytorch_lightning/loggers/logger.py @@ -19,6 +19,7 @@ import typing as t from abc import ABC, abstractmethod from argparse import Namespace +from collections import defaultdict from functools import wraps from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Sequence, Union from weakref import ReferenceType @@ -362,7 +363,7 @@ def method(*args: Any, **kwargs: Any) -> None: def merge_dicts( dicts: Sequence[Mapping], - agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None, + agg_key_funcs: Optional[Mapping] = None, default_func: Callable[[Sequence[float]], float] = np.mean, ) -> Dict: """Merge a sequence with dictionaries into one dictionary by aggregating the same keys with some given @@ -400,7 +401,7 @@ def merge_dicts( """ agg_key_funcs = agg_key_funcs or {} keys = list(functools.reduce(operator.or_, [set(d.keys()) for d in dicts])) - d_out = {} + d_out: Dict = defaultdict(dict) for k in keys: fn = agg_key_funcs.get(k) values_to_agg = [v for v in [d_in.get(k) for d_in in dicts] if v is not None] From 0976466de2e0fc75bb609bba5b932c9e785c56d6 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Tue, 5 Jul 2022 17:49:33 -0400 Subject: [PATCH 18/24] update return --- src/pytorch_lightning/loggers/logger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/loggers/logger.py b/src/pytorch_lightning/loggers/logger.py index fc7cd035181b2..a0f5eb74e1b69 100644 --- a/src/pytorch_lightning/loggers/logger.py +++ b/src/pytorch_lightning/loggers/logger.py @@ -411,4 +411,4 @@ def merge_dicts( else: d_out[k] = (fn or default_func)(values_to_agg) - return d_out + return dict(d_out) From 2e9a0a668af0df6c57574afb9ee66c5e32c6f11c Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Wed, 6 Jul 2022 07:55:55 -0400 Subject: [PATCH 19/24] Update src/pytorch_lightning/loggers/logger.py MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrian Wälchli --- src/pytorch_lightning/loggers/logger.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/pytorch_lightning/loggers/logger.py b/src/pytorch_lightning/loggers/logger.py index a0f5eb74e1b69..72bf4034f21ac 100644 --- a/src/pytorch_lightning/loggers/logger.py +++ b/src/pytorch_lightning/loggers/logger.py @@ -31,7 +31,7 @@ from pytorch_lightning.callbacks import Checkpoint from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only -if t.TYPE_CHECKING: +if TYPE_CHECKING: from pytorch_lightning.loggers.csv_logs import ExperimentWriter From b4c3ed514f32e4f346b83efa43b23e500b9ce23a Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Wed, 6 Jul 2022 08:41:29 -0400 Subject: [PATCH 20/24] change rank_zero_experiment; update import statements --- src/pytorch_lightning/loggers/logger.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/src/pytorch_lightning/loggers/logger.py b/src/pytorch_lightning/loggers/logger.py index 72bf4034f21ac..acb31cbf49f20 100644 --- a/src/pytorch_lightning/loggers/logger.py +++ b/src/pytorch_lightning/loggers/logger.py @@ -16,12 +16,11 @@ import functools import operator -import typing as t from abc import ABC, abstractmethod from argparse import Namespace from collections import defaultdict from functools import wraps -from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Sequence, Union +from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Sequence, TYPE_CHECKING, Union from weakref import ReferenceType import numpy as np @@ -39,7 +38,7 @@ def rank_zero_experiment(fn: Callable) -> Callable: """Returns the real experiment on rank 0 and otherwise the DummyExperiment.""" @wraps(fn) - def experiment(self: Callable) -> Union[ExperimentWriter, DummyExperiment]: + def experiment(self) -> Union[ExperimentWriter, DummyExperiment]: # type: ignore[no-untyped-def] @rank_zero_only def get_experiment() -> Callable: return fn(self) From 0382fb008841a7427c79b520f91393bbf1bc3e76 Mon Sep 17 00:00:00 2001 From: Akihiro Nitta Date: Fri, 8 Jul 2022 09:09:59 +0900 Subject: [PATCH 21/24] Revert notebooks --- _notebooks | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/_notebooks b/_notebooks index 85bdeeeb4ef7e..8a36a41548f34 160000 --- a/_notebooks +++ b/_notebooks @@ -1 +1 @@ -Subproject commit 85bdeeeb4ef7ea8b5c7cd1a4309321becc57cddb +Subproject commit 8a36a41548f34c44ac455d515a72994487e85813 From 2f78109c2e38230179b301bdb817a4ce8c7d7a4e Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Fri, 8 Jul 2022 20:27:04 -0400 Subject: [PATCH 22/24] remove conditional import statement --- src/pytorch_lightning/loggers/logger.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/src/pytorch_lightning/loggers/logger.py b/src/pytorch_lightning/loggers/logger.py index acb31cbf49f20..c0b98a4e58b01 100644 --- a/src/pytorch_lightning/loggers/logger.py +++ b/src/pytorch_lightning/loggers/logger.py @@ -20,7 +20,7 @@ from argparse import Namespace from collections import defaultdict from functools import wraps -from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Sequence, TYPE_CHECKING, Union +from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Sequence, Union from weakref import ReferenceType import numpy as np @@ -28,11 +28,9 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import Checkpoint +from pytorch_lightning.loggers.csv_logs import ExperimentWriter from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only -if TYPE_CHECKING: - from pytorch_lightning.loggers.csv_logs import ExperimentWriter - def rank_zero_experiment(fn: Callable) -> Callable: """Returns the real experiment on rank 0 and otherwise the DummyExperiment.""" From 5d2d31436b69c25b2c4a6f925e32c94d768182b2 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Fri, 8 Jul 2022 20:53:43 -0400 Subject: [PATCH 23/24] update log_graph --- src/pytorch_lightning/loggers/logger.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/loggers/logger.py b/src/pytorch_lightning/loggers/logger.py index c0b98a4e58b01..829c65d342297 100644 --- a/src/pytorch_lightning/loggers/logger.py +++ b/src/pytorch_lightning/loggers/logger.py @@ -160,7 +160,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace], *args: Any, kwargs: Optional keyword arguments, depends on the specific logger being used """ - def log_graph(self, model: "pl.LightningModule", input_array: Tensor = None) -> None: + def log_graph(self, model: "pl.LightningModule", input_array: Optional[Tensor] = None) -> None: """Record model graph. Args: @@ -253,7 +253,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace], *args: Any, for logger in self._logger_iterable: logger.log_hyperparams(params, *args, **kwargs) - def log_graph(self, model: "pl.LightningModule", input_array: Tensor = None) -> None: + def log_graph(self, model: "pl.LightningModule", input_array: Optional[Tensor] = None) -> None: for logger in self._logger_iterable: logger.log_graph(model, input_array) From a75cd2e005369229f72ad5412a9779ec01db6462 Mon Sep 17 00:00:00 2001 From: Justin Goheen <26209687+JustinGoheen@users.noreply.github.com> Date: Sun, 10 Jul 2022 11:29:10 -0400 Subject: [PATCH 24/24] fix rank_zero_experiment --- src/pytorch_lightning/loggers/logger.py | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/src/pytorch_lightning/loggers/logger.py b/src/pytorch_lightning/loggers/logger.py index 829c65d342297..4113b61627d8f 100644 --- a/src/pytorch_lightning/loggers/logger.py +++ b/src/pytorch_lightning/loggers/logger.py @@ -28,7 +28,6 @@ import pytorch_lightning as pl from pytorch_lightning.callbacks import Checkpoint -from pytorch_lightning.loggers.csv_logs import ExperimentWriter from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only @@ -36,7 +35,18 @@ def rank_zero_experiment(fn: Callable) -> Callable: """Returns the real experiment on rank 0 and otherwise the DummyExperiment.""" @wraps(fn) - def experiment(self) -> Union[ExperimentWriter, DummyExperiment]: # type: ignore[no-untyped-def] + def experiment(self) -> Union[Any, DummyExperiment]: # type: ignore[no-untyped-def] + """ + Note: + `self` is a custom logger instance. The loggers typical wrap an `experiment` method + with a @rank_zero_experiment decorator. An exception being `loggers.neptune` wraps + `experiment` and `run` with rank_zero_experiment. + + Union[Any, DummyExperiment] is used because the wrapped hooks have several returns + types that are specific to the custom logger. The return type can be considered as + Union[return type of logger.experiment, DummyExperiment] + """ + @rank_zero_only def get_experiment() -> Callable: return fn(self)