Merged
Commits
26 commits
b0d96ae
remove corresponding line from pyproject.toml
jxtngx Jul 1, 2022
47d8cdb
Merge branch 'Lightning-AI:master' into typing/loggers/update-logger.py
jxtngx Jul 1, 2022
5bf3cce
add notebooks
jxtngx Jul 5, 2022
7bcf463
Merge branch 'master' into typing/loggers/13445-update-logger.py
jxtngx Jul 5, 2022
a6f2cd5
add notebooks
jxtngx Jul 5, 2022
b701321
update rank_zero_experiment
jxtngx Jul 5, 2022
576c6c7
update update_agg_funcs
jxtngx Jul 5, 2022
8d682bd
update agg_and_log_metrics and log_metrics
jxtngx Jul 5, 2022
14ce08e
update log_hyperparams annotations and return type
jxtngx Jul 5, 2022
e077e11
update log_graph annotation
jxtngx Jul 5, 2022
725df82
update multiple None return types
jxtngx Jul 5, 2022
2b9dcbe
update DummyExperiment and log_text
jxtngx Jul 5, 2022
6ec8365
update DummyLogger
jxtngx Jul 5, 2022
6df935d
fix return types for rank_zero_experiment
jxtngx Jul 5, 2022
b67e68f
update DummyLogger
jxtngx Jul 5, 2022
01112c1
update Logger.log_hyperparams
jxtngx Jul 5, 2022
b75fb46
updated DummyLogger.__iter__ return type
jxtngx Jul 5, 2022
3e21af1
update log_hyperparams
jxtngx Jul 5, 2022
fd235cd
update merge_dicts
jxtngx Jul 5, 2022
0976466
update return
jxtngx Jul 5, 2022
2e9a0a6
Update src/pytorch_lightning/loggers/logger.py
jxtngx Jul 6, 2022
b4c3ed5
change rank_zero_experiment; update import statements
jxtngx Jul 6, 2022
0382fb0
Revert notebooks
akihironitta Jul 8, 2022
2f78109
remove conditional import statement
jxtngx Jul 9, 2022
5d2d314
update log_graph
jxtngx Jul 9, 2022
a75cd2e
fix rank_zero_experiment
jxtngx Jul 10, 2022
1 change: 0 additions & 1 deletion pyproject.toml
@@ -54,7 +54,6 @@ module = [
"pytorch_lightning.demos.mnist_datamodule",
"pytorch_lightning.distributed.dist",
"pytorch_lightning.loggers.base",
"pytorch_lightning.loggers.logger",
"pytorch_lightning.loggers.comet",
"pytorch_lightning.loggers.csv_logs",
"pytorch_lightning.loggers.mlflow",
75 changes: 44 additions & 31 deletions src/pytorch_lightning/loggers/logger.py
@@ -13,16 +13,18 @@
# limitations under the License.
"""Abstract base class used to build new loggers."""

import argparse

import functools
import operator
from abc import ABC, abstractmethod
from argparse import Namespace
from collections import defaultdict
from functools import wraps
from typing import Any, Callable, Dict, Iterable, List, Mapping, Optional, Sequence, Union
from typing import Any, Callable, Dict, Generator, Iterable, List, Mapping, Optional, Sequence, Union
from weakref import ReferenceType

import numpy as np
from torch import Tensor

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Checkpoint
@@ -33,9 +35,20 @@ def rank_zero_experiment(fn: Callable) -> Callable:
"""Returns the real experiment on rank 0 and otherwise the DummyExperiment."""

@wraps(fn)
def experiment(self):
def experiment(self) -> Union[Any, DummyExperiment]: # type: ignore[no-untyped-def]
"""
Note:
`self` is a custom logger instance. Loggers typically wrap an `experiment` method
with the @rank_zero_experiment decorator; one exception is `loggers.neptune`, which wraps
both `experiment` and `run` with rank_zero_experiment.

Union[Any, DummyExperiment] is used because the wrapped hooks have several return
types that are specific to the custom logger. The return type can be read as
Union[return type of logger.experiment, DummyExperiment]
"""

@rank_zero_only
def get_experiment():
def get_experiment() -> Callable:
return fn(self)

return get_experiment() or DummyExperiment()
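A minimal sketch of how a custom logger would typically apply this decorator (the SketchLogger class and its placeholder experiment object are hypothetical; only Logger and rank_zero_experiment come from pytorch_lightning.loggers.logger):

from argparse import Namespace
from typing import Any, Dict, Optional, Union

from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment


class SketchLogger(Logger):
    """Hypothetical logger used only to illustrate the decorator."""

    def __init__(self) -> None:
        super().__init__()
        self._experiment = object()  # stand-in for a real experiment handle

    @property
    @rank_zero_experiment
    def experiment(self) -> Any:
        # On rank 0 the decorator returns this object; on every other rank
        # it substitutes a DummyExperiment, hence Union[Any, DummyExperiment].
        return self._experiment

    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
        pass

    def log_hyperparams(self, params: Union[Dict[str, Any], Namespace], *args: Any, **kwargs: Any) -> None:
        pass

    @property
    def name(self) -> str:
        return "sketch"

    @property
    def version(self) -> str:
        return "0"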
@@ -98,7 +111,7 @@ def update_agg_funcs(
self,
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
agg_default_func: Callable[[Sequence[float]], float] = np.mean,
):
) -> None:
"""Update aggregation methods.

.. deprecated:: v1.6
@@ -119,7 +132,7 @@ def update_agg_funcs(
self._agg_default_func = agg_default_func
rank_zero_deprecation("`Logger.update_agg_funcs` was deprecated in v1.6 and will be removed in v1.8.")

def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
"""Aggregates and records metrics. This method doesn't log the passed metrics instantaneously, but instead
it aggregates them and logs only if metrics are ready to be logged.

@@ -134,7 +147,7 @@ def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
self.log_metrics(metrics=metrics, step=step)

@abstractmethod
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
"""
Records metrics.
This method logs metrics as soon as it receives them. If you want to aggregate
@@ -148,16 +161,16 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
pass

@abstractmethod
def log_hyperparams(self, params: argparse.Namespace, *args, **kwargs):
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace], *args: Any, **kwargs: Any) -> None:
"""Record hyperparameters.

Args:
params: :class:`~argparse.Namespace` containing the hyperparameters
params: :class:`~argparse.Namespace` or `Dict` containing the hyperparameters
args: Optional positional arguments, depends on the specific logger being used
kwargs: Optional keyword arguments, depends on the specific logger being used
"""

def log_graph(self, model: "pl.LightningModule", input_array=None) -> None:
def log_graph(self, model: "pl.LightningModule", input_array: Optional[Tensor] = None) -> None:
"""Record model graph.

Args:
@@ -184,7 +197,7 @@ def save_dir(self) -> Optional[str]:
return None

@property
def group_separator(self):
def group_separator(self) -> str:
"""Return the default separator used by the logger to group the data into subfolders."""
return "/"

@@ -229,7 +242,7 @@ def update_agg_funcs(
self,
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
agg_default_func: Callable[[Sequence[float]], float] = np.mean,
):
) -> None:
for logger in self._logger_iterable:
logger.update_agg_funcs(agg_key_funcs, agg_default_func)

@@ -238,27 +251,27 @@ def experiment(self) -> List[Any]:
"""Returns a list of experiment objects for all the loggers in the logger collection."""
return [logger.experiment for logger in self._logger_iterable]

def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None):
def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
for logger in self._logger_iterable:
logger.agg_and_log_metrics(metrics=metrics, step=step)

def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
for logger in self._logger_iterable:
logger.log_metrics(metrics=metrics, step=step)

def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
def log_hyperparams(self, params: Union[Dict[str, Any], Namespace], *args: Any, **kwargs: Any) -> None:
for logger in self._logger_iterable:
logger.log_hyperparams(params)
logger.log_hyperparams(params, *args, **kwargs)

def log_graph(self, model: "pl.LightningModule", input_array=None) -> None:
def log_graph(self, model: "pl.LightningModule", input_array: Optional[Tensor] = None) -> None:
for logger in self._logger_iterable:
logger.log_graph(model, input_array)

def log_text(self, *args, **kwargs) -> None:
def log_text(self, *args: Any, **kwargs: Any) -> None:
for logger in self._logger_iterable:
logger.log_text(*args, **kwargs)

def log_image(self, *args, **kwargs) -> None:
def log_image(self, *args: Any, **kwargs: Any) -> None:
for logger in self._logger_iterable:
logger.log_image(*args, **kwargs)

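With the annotations above, log_hyperparams accepts either a plain dict or an argparse.Namespace, and LoggerCollection now forwards extra positional and keyword arguments to each wrapped logger. A small usage sketch, with CSVLogger chosen purely for illustration and an arbitrary save_dir:

from argparse import Namespace

from pytorch_lightning.loggers import CSVLogger

logger = CSVLogger(save_dir="logs")

# both containers satisfy Union[Dict[str, Any], Namespace]
logger.log_hyperparams({"lr": 1e-3, "batch_size": 32})
logger.log_hyperparams(Namespace(lr=1e-3, batch_size=32))
logger.save()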
@@ -293,17 +306,17 @@ def version(self) -> str:
class DummyExperiment:
"""Dummy experiment."""

def nop(self, *args, **kw):
def nop(self, *args: Any, **kw: Any) -> None:
pass

def __getattr__(self, _):
def __getattr__(self, _: Any) -> Callable:
return self.nop

def __getitem__(self, idx) -> "DummyExperiment":
def __getitem__(self, idx: int) -> "DummyExperiment":
# enables self.logger.experiment[0].add_image(...)
return self

def __setitem__(self, *args, **kwargs) -> None:
def __setitem__(self, *args: Any, **kwargs: Any) -> None:
pass


@@ -313,7 +326,7 @@ class DummyLogger(Logger):
It is useful if we want to disable the user's logger for a feature, but still ensure that user code can run
"""

def __init__(self):
def __init__(self) -> None:
super().__init__()
self._experiment = DummyExperiment()

@@ -322,10 +335,10 @@ def experiment(self) -> DummyExperiment:
"""Return the experiment object associated with this logger."""
return self._experiment

def log_metrics(self, *args, **kwargs) -> None:
def log_metrics(self, *args: Any, **kwargs: Any) -> None:
pass

def log_hyperparams(self, *args, **kwargs) -> None:
def log_hyperparams(self, *args: Any, **kwargs: Any) -> None:
pass

@property
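The annotations above describe purely no-op behaviour. A short sketch of what that means in practice (the method names called on the dummies are arbitrary):

from pytorch_lightning.loggers.logger import DummyLogger

logger = DummyLogger()
logger.log_metrics({"loss": 0.5})         # no-op
logger.made_up_method(1, key="x")         # __getattr__ hands back a no-op callable
assert list(logger) == []                 # __iter__ yields nothing
assert logger[0] is logger                # __getitem__ returns the logger itself

experiment = logger.experiment            # a DummyExperiment
experiment.add_image("sample", None)      # resolves to DummyExperiment.nop
experiment[0].add_scalar("loss", 0.1)     # __getitem__ returns the same dummy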
@@ -338,26 +351,26 @@ def version(self) -> str:
"""Return the experiment version."""
return ""

def __getitem__(self, idx) -> "DummyLogger":
def __getitem__(self, idx: int) -> "DummyLogger":
# enables self.logger[0].experiment.add_image(...)
return self

def __iter__(self):
def __iter__(self) -> Generator[None, None, None]:
# if DummyLogger is substituting a logger collection, pretend it is empty
yield from ()

def __getattr__(self, name: str) -> Callable:
"""Allows the DummyLogger to be called with arbitrary methods, to avoid AttributeErrors."""

def method(*args, **kwargs):
def method(*args: Any, **kwargs: Any) -> None:
return None

return method


def merge_dicts(
dicts: Sequence[Mapping],
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
agg_key_funcs: Optional[Mapping] = None,
default_func: Callable[[Sequence[float]], float] = np.mean,
) -> Dict:
"""Merge a sequence with dictionaries into one dictionary by aggregating the same keys with some given
@@ -395,7 +408,7 @@ def merge_dicts(
"""
agg_key_funcs = agg_key_funcs or {}
keys = list(functools.reduce(operator.or_, [set(d.keys()) for d in dicts]))
d_out = {}
d_out: Dict = defaultdict(dict)
for k in keys:
fn = agg_key_funcs.get(k)
values_to_agg = [v for v in [d_in.get(k) for d_in in dicts] if v is not None]
@@ -405,4 +418,4 @@
else:
d_out[k] = (fn or default_func)(values_to_agg)

return d_out
return dict(d_out)
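For reference, a small usage sketch of merge_dicts with arbitrary values: the key "a" gets its own aggregation function (min) while the remaining keys fall back to default_func:

import numpy as np

from pytorch_lightning.loggers.logger import merge_dicts

d1 = {"a": 1.7, "b": 2.0, "c": 1}
d2 = {"a": 1.1, "b": 2.2, "v": 1}
d3 = {"a": 1.1, "v": 2.3}

# "a" is reduced with min; "b", "c" and "v" are averaged with np.mean
merged = merge_dicts([d1, d2, d3], agg_key_funcs={"a": min}, default_func=np.mean)
# -> {"a": 1.1, "b": 2.1, "c": 1.0, "v": 1.65}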