Skip to content
6 changes: 6 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,12 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
- Deprecated `pytorch_lightning.utilities.warnings.LightningDeprecationWarning` in favor of `pytorch_lightning.utilities.rank_zero.LightningDeprecationWarning`


- Deprecated `agg_key_funcs` and `agg_default_func` parameters from `LightningLoggerBase` ([#11871](https://github.com/PyTorchLightning/pytorch-lightning/pull/11871))


- Deprecated `LightningLoggerBase.update_agg_funcs` ([#11871](https://github.com/PyTorchLightning/pytorch-lightning/pull/11871))


- Deprecated `LightningLoggerBase.agg_and_log_metrics` in favor of `LightningLoggerBase.log_metrics` ([#11832](https://github.com/PyTorchLightning/pytorch-lightning/pull/11832))


Expand Down
30 changes: 27 additions & 3 deletions pytorch_lightning/loggers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,10 @@ class LightningLoggerBase(ABC):
is not presented in the `agg_key_funcs` dictionary, then the
`agg_default_func` will be used for aggregation.

.. deprecated:: v1.6
The parameters `agg_key_funcs` and `agg_default_func` are deprecated
in v1.6 and will be removed in v1.8.

Note:
The `agg_key_funcs` and `agg_default_func` arguments are used only when
one logs metrics with the :meth:`~LightningLoggerBase.agg_and_log_metrics` method.
Expand All @@ -63,12 +67,26 @@ class LightningLoggerBase(ABC):
def __init__(
self,
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
agg_default_func: Callable[[Sequence[float]], float] = np.mean,
agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
):
self._prev_step: int = -1
self._metrics_to_agg: List[Dict[str, float]] = []
self._agg_key_funcs = agg_key_funcs if agg_key_funcs else {}
self._agg_default_func = agg_default_func
if agg_key_funcs:
self._agg_key_funcs = agg_key_funcs
rank_zero_deprecation(
"The `agg_key_funcs` parameter for `LightningLoggerBase` was deprecated in v1.6"
" and will be removed in v1.8."
)
else:
self._agg_key_funcs = {}
if agg_default_func:
self._agg_default_func = agg_default_func
rank_zero_deprecation(
"The `agg_default_func` parameter for `LightningLoggerBase` was deprecated in v1.6"
" and will be removed in v1.8."
)
else:
self._agg_default_func = np.mean

def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[ModelCheckpoint]") -> None:
"""Called after model checkpoint callback saves a new checkpoint.
Expand All @@ -85,6 +103,9 @@ def update_agg_funcs(
):
"""Update aggregation methods.

.. deprecated:: v1.6
`update_agg_funcs` is deprecated in v1.6 and will be removed in v1.8.

Args:
agg_key_funcs:
Dictionary which maps a metric name to a function, which will
Expand All @@ -98,6 +119,9 @@ def update_agg_funcs(
self._agg_key_funcs.update(agg_key_funcs)
if agg_default_func:
self._agg_default_func = agg_default_func
rank_zero_deprecation(
"`LightningLoggerBase.update_agg_funcs` was deprecated in v1.6 and will be removed in v1.8."
)

def _aggregate_metrics(
self, metrics: Dict[str, float], step: Optional[int] = None
Expand Down
42 changes: 41 additions & 1 deletion tests/deprecated_api/test_remove_1-8.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@
"""Test deprecated functionality which will be removed in v1.8.0."""
from unittest.mock import Mock

import numpy as np
import pytest
import torch
from torch import optim

from pytorch_lightning import Callback, Trainer
from pytorch_lightning.loggers import CSVLogger
from pytorch_lightning.loggers import CSVLogger, LightningLoggerBase
from pytorch_lightning.plugins.training_type.ddp import DDPPlugin
from pytorch_lightning.plugins.training_type.ddp2 import DDP2Plugin
from pytorch_lightning.plugins.training_type.ddp_spawn import DDPSpawnPlugin
Expand Down Expand Up @@ -503,6 +504,45 @@ def on_before_accelerator_backend_setup(self, *args, **kwargs):
trainer.fit(model)


def test_v1_8_0_logger_agg_parameters():
class CustomLogger(LightningLoggerBase):
@rank_zero_only
def log_hyperparams(self, params):
pass

@rank_zero_only
def log_metrics(self, metrics, step):
pass

@property
def name(self):
pass

@property
def version(self):
pass

with pytest.deprecated_call(
match="The `agg_key_funcs` parameter for `LightningLoggerBase` was deprecated in v1.6"
" and will be removed in v1.8."
):
CustomLogger(agg_key_funcs={"mean", np.mean})

with pytest.deprecated_call(
match="The `agg_default_func` parameter for `LightningLoggerBase` was deprecated in v1.6"
" and will be removed in v1.8."
):
CustomLogger(agg_default_func=np.mean)

# Should have no deprecation warning
logger = CustomLogger()

with pytest.deprecated_call(
match="`LightningLoggerBase.update_agg_funcs` was deprecated in v1.6 and will be removed in v1.8."
):
logger.update_agg_funcs()


def test_v1_8_0_deprecated_agg_and_log_metrics_override(tmpdir):
class AggregationOverrideLogger(CSVLogger):
@rank_zero_only
Expand Down