Skip to content

Commit ffb23f9

Browse files
krishnakalyan3carmocca
authored andcommitted
Remove the deprecated agg_and_log_metrics (#14840)
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent e81115b commit ffb23f9

File tree

9 files changed

+19
-207
lines changed

9 files changed

+19
-207
lines changed

src/pytorch_lightning/CHANGELOG.md

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -203,6 +203,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
203203
- Removed the deprecated device attributes `Trainer.{devices,gpus,num_gpus,ipus,tpu_cores}` in favor of the accelerator-agnostic `Trainer.num_devices` ([#14829](https://github.com/Lightning-AI/lightning/pull/14829))
204204

205205

206+
- Removed the deprecated `Logger.agg_and_log_metrics` hook in favour of `Logger.log_metrics` and the `agg_key_funcs` and `agg_default_func` arguments. ([#14840](https://github.com/Lightning-AI/lightning/pull/14840))
207+
208+
206209
- Removed the deprecated precision plugin checkpoint hooks `PrecisionPlugin.on_load_checkpoint` and `PrecisionPlugin.on_save_checkpoint` ([#14833](https://github.com/Lightning-AI/lightning/pull/14833))
207210

208211

src/pytorch_lightning/loggers/base.py

Lines changed: 1 addition & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -29,25 +29,7 @@ def rank_zero_experiment(fn: Callable) -> Callable:
2929

3030

3131
class LightningLoggerBase(logger.Logger):
32-
"""Base class for experiment loggers.
33-
34-
Args:
35-
agg_key_funcs:
36-
Dictionary which maps a metric name to a function, which will
37-
aggregate the metric values for the same steps.
38-
agg_default_func:
39-
Default function to aggregate metric values. If some metric name
40-
is not presented in the `agg_key_funcs` dictionary, then the
41-
`agg_default_func` will be used for aggregation.
42-
43-
.. deprecated:: v1.6
44-
The parameters `agg_key_funcs` and `agg_default_func` are deprecated
45-
in v1.6 and will be removed in v1.8.
46-
47-
Note:
48-
The `agg_key_funcs` and `agg_default_func` arguments are used only when
49-
one logs metrics with the :meth:`~LightningLoggerBase.agg_and_log_metrics` method.
50-
"""
32+
"""Base class for experiment loggers."""
5133

5234
def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def]
5335
rank_zero_deprecation(

src/pytorch_lightning/loggers/comet.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import logging
2020
import os
2121
from argparse import Namespace
22-
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union
22+
from typing import Any, Dict, Mapping, Optional, Union
2323

2424
from lightning_utilities.core.imports import module_available
2525
from torch import Tensor
@@ -219,15 +219,13 @@ def __init__(
219219
experiment_key: Optional[str] = None,
220220
offline: bool = False,
221221
prefix: str = "",
222-
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
223-
agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
224222
**kwargs: Any,
225223
):
226224
if comet_ml is None:
227225
raise ModuleNotFoundError(
228226
"You want to use `comet_ml` logger which is not installed yet, install it with `pip install comet-ml`."
229227
)
230-
super().__init__(agg_key_funcs=agg_key_funcs, agg_default_func=agg_default_func)
228+
super().__init__()
231229
self._experiment = None
232230
self._save_dir: Optional[str]
233231
self.rest_api_key: Optional[str]

src/pytorch_lightning/loggers/logger.py

Lines changed: 6 additions & 88 deletions
Original file line numberDiff line numberDiff line change
@@ -20,15 +20,15 @@
2020
from argparse import Namespace
2121
from collections import defaultdict
2222
from functools import wraps
23-
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union
23+
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union
2424
from weakref import ReferenceType
2525

2626
import numpy as np
2727
from torch import Tensor
2828

2929
import pytorch_lightning as pl
3030
from pytorch_lightning.callbacks import Checkpoint
31-
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation, rank_zero_only
31+
from pytorch_lightning.utilities.rank_zero import rank_zero_only
3232

3333

3434
def rank_zero_experiment(fn: Callable) -> Callable:
@@ -57,47 +57,7 @@ def get_experiment() -> Callable:
5757

5858

5959
class Logger(ABC):
60-
"""Base class for experiment loggers.
61-
62-
Args:
63-
agg_key_funcs:
64-
Dictionary which maps a metric name to a function, which will
65-
aggregate the metric values for the same steps.
66-
agg_default_func:
67-
Default function to aggregate metric values. If some metric name
68-
is not presented in the `agg_key_funcs` dictionary, then the
69-
`agg_default_func` will be used for aggregation.
70-
71-
.. deprecated:: v1.6
72-
The parameters `agg_key_funcs` and `agg_default_func` are deprecated
73-
in v1.6 and will be removed in v1.8.
74-
75-
Note:
76-
The `agg_key_funcs` and `agg_default_func` arguments are used only when
77-
one logs metrics with the :meth:`~Logger.agg_and_log_metrics` method.
78-
"""
79-
80-
def __init__(
81-
self,
82-
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
83-
agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
84-
):
85-
self._prev_step: int = -1
86-
self._metrics_to_agg: List[Dict[str, float]] = []
87-
if agg_key_funcs:
88-
self._agg_key_funcs = agg_key_funcs
89-
rank_zero_deprecation(
90-
"The `agg_key_funcs` parameter for `Logger` was deprecated in v1.6" " and will be removed in v1.8."
91-
)
92-
else:
93-
self._agg_key_funcs = {}
94-
if agg_default_func:
95-
self._agg_default_func = agg_default_func
96-
rank_zero_deprecation(
97-
"The `agg_default_func` parameter for `Logger` was deprecated in v1.6" " and will be removed in v1.8."
98-
)
99-
else:
100-
self._agg_default_func = np.mean
60+
"""Base class for experiment loggers."""
10161

10262
def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]") -> None:
10363
"""Called after model checkpoint callback saves a new checkpoint.
@@ -107,52 +67,9 @@ def after_save_checkpoint(self, checkpoint_callback: "ReferenceType[Checkpoint]"
10767
"""
10868
pass
10969

110-
def update_agg_funcs(
111-
self,
112-
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
113-
agg_default_func: Callable[[Sequence[float]], float] = np.mean,
114-
) -> None:
115-
"""Update aggregation methods.
116-
117-
.. deprecated:: v1.6
118-
`update_agg_funcs` is deprecated in v1.6 and will be removed in v1.8.
119-
120-
Args:
121-
agg_key_funcs:
122-
Dictionary which maps a metric name to a function, which will
123-
aggregate the metric values for the same steps.
124-
agg_default_func:
125-
Default function to aggregate metric values. If some metric name
126-
is not presented in the `agg_key_funcs` dictionary, then the
127-
`agg_default_func` will be used for aggregation.
128-
"""
129-
if agg_key_funcs:
130-
self._agg_key_funcs.update(agg_key_funcs)
131-
if agg_default_func:
132-
self._agg_default_func = agg_default_func
133-
rank_zero_deprecation("`Logger.update_agg_funcs` was deprecated in v1.6 and will be removed in v1.8.")
134-
135-
def agg_and_log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
136-
"""Aggregates and records metrics. This method doesn't log the passed metrics instantaneously, but instead
137-
it aggregates them and logs only if metrics are ready to be logged.
138-
139-
.. deprecated:: v1.6
140-
This method is deprecated in v1.6 and will be removed in v1.8.
141-
Please use `Logger.log_metrics` instead.
142-
143-
Args:
144-
metrics: Dictionary with metric names as keys and measured quantities as values
145-
step: Step number at which the metrics should be recorded
146-
"""
147-
self.log_metrics(metrics=metrics, step=step)
148-
14970
@abstractmethod
15071
def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
151-
"""
152-
Records metrics.
153-
This method logs metrics as as soon as it received them. If you want to aggregate
154-
metrics for one specific `step`, use the
155-
:meth:`~pytorch_lightning.loggers.base.Logger.agg_and_log_metrics` method.
72+
"""Records metrics. This method logs metrics as soon as it received them.
15673
15774
Args:
15875
metrics: Dictionary with metric names as keys and measured quantities as values
@@ -273,7 +190,8 @@ def method(*args: Any, **kwargs: Any) -> None:
273190
return method
274191

275192

276-
def merge_dicts(
193+
# TODO: this should have been deprecated
194+
def merge_dicts( # pragma: no cover
277195
dicts: Sequence[Mapping],
278196
agg_key_funcs: Optional[Mapping] = None,
279197
default_func: Callable[[Sequence[float]], float] = np.mean,

src/pytorch_lightning/loggers/neptune.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
import logging
2323
import os
2424
from argparse import Namespace
25-
from typing import Any, Callable, Dict, Generator, List, Mapping, Optional, Sequence, Set, Union
25+
from typing import Any, Dict, Generator, List, Optional, Set, Union
2626
from weakref import ReferenceType
2727

2828
from lightning_utilities.core.imports import RequirementCache
@@ -227,15 +227,13 @@ def __init__(
227227
run: Optional["Run"] = None,
228228
log_model_checkpoints: Optional[bool] = True,
229229
prefix: str = "training",
230-
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
231-
agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
232230
**neptune_run_kwargs: Any,
233231
):
234232
if not _NEPTUNE_AVAILABLE:
235233
raise ModuleNotFoundError(str(_NEPTUNE_AVAILABLE))
236234
# verify if user passed proper init arguments
237235
self._verify_input_arguments(api_key, project, name, run, neptune_run_kwargs)
238-
super().__init__(agg_key_funcs=agg_key_funcs, agg_default_func=agg_default_func)
236+
super().__init__()
239237
self._log_model_checkpoints = log_model_checkpoints
240238
self._prefix = prefix
241239
self._run_name = name

src/pytorch_lightning/loggers/tensorboard.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
import logging
2020
import os
2121
from argparse import Namespace
22-
from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union
22+
from typing import Any, Dict, Mapping, Optional, Union
2323

2424
import numpy as np
2525
from torch import Tensor
@@ -94,11 +94,9 @@ def __init__(
9494
default_hp_metric: bool = True,
9595
prefix: str = "",
9696
sub_dir: Optional[str] = None,
97-
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
98-
agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
9997
**kwargs: Any,
10098
):
101-
super().__init__(agg_key_funcs=agg_key_funcs, agg_default_func=agg_default_func)
99+
super().__init__()
102100
self._save_dir = save_dir
103101
self._name = name or ""
104102
self._version = version

src/pytorch_lightning/loggers/wandb.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import os
1919
from argparse import Namespace
2020
from pathlib import Path
21-
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union
21+
from typing import Any, Dict, List, Mapping, Optional, Union
2222
from weakref import ReferenceType
2323

2424
import torch.nn as nn
@@ -294,8 +294,6 @@ def __init__(
294294
log_model: Union[str, bool] = False,
295295
experiment: Union[Run, RunDisabled, None] = None,
296296
prefix: str = "",
297-
agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
298-
agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
299297
**kwargs: Any,
300298
) -> None:
301299
if wandb is None:
@@ -318,7 +316,7 @@ def __init__(
318316
"Hint: Upgrade with `pip install --upgrade wandb`."
319317
)
320318

321-
super().__init__(agg_key_funcs=agg_key_funcs, agg_default_func=agg_default_func)
319+
super().__init__()
322320
self._offline = offline
323321
self._log_model = log_model
324322
self._prefix = prefix

src/pytorch_lightning/trainer/connectors/logger_connector/logger_connector.py

Lines changed: 1 addition & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,6 @@
2222
from pytorch_lightning.loggers import Logger, TensorBoardLogger
2323
from pytorch_lightning.trainer.connectors.logger_connector.result import _METRICS, _OUT_DICT, _PBAR_DICT
2424
from pytorch_lightning.utilities.metrics import metrics_to_scalars
25-
from pytorch_lightning.utilities.model_helpers import is_overridden
26-
from pytorch_lightning.utilities.rank_zero import rank_zero_deprecation
2725

2826

2927
class LoggerConnector:
@@ -36,7 +34,6 @@ def __init__(self, trainer: "pl.Trainer") -> None:
3634
self._current_fx: Optional[str] = None
3735
self._batch_idx: Optional[int] = None
3836
self._split_idx: Optional[int] = None
39-
self._override_agg_and_log_metrics: bool = False
4037

4138
def on_trainer_init(
4239
self,
@@ -47,15 +44,6 @@ def on_trainer_init(
4744
self.configure_logger(logger)
4845
self.trainer.log_every_n_steps = log_every_n_steps
4946
self.trainer.move_metrics_to_cpu = move_metrics_to_cpu
50-
for logger in self.trainer.loggers:
51-
if is_overridden("agg_and_log_metrics", logger, Logger):
52-
self._override_agg_and_log_metrics = True
53-
rank_zero_deprecation(
54-
"`Logger.agg_and_log_metrics` is deprecated in v1.6 and will be removed"
55-
" in v1.8. `Trainer` will directly call `Logger.log_metrics` so custom"
56-
" loggers should not implement `Logger.agg_and_log_metrics`."
57-
)
58-
break
5947

6048
@property
6149
def should_update_logs(self) -> bool:
@@ -104,10 +92,7 @@ def log_metrics(self, metrics: _OUT_DICT, step: Optional[int] = None) -> None:
10492

10593
# log actual metrics
10694
for logger in self.trainer.loggers:
107-
if self._override_agg_and_log_metrics:
108-
logger.agg_and_log_metrics(metrics=scalar_metrics, step=step)
109-
else:
110-
logger.log_metrics(metrics=scalar_metrics, step=step)
95+
logger.log_metrics(metrics=scalar_metrics, step=step)
11196
logger.save()
11297

11398
"""

tests/tests_pytorch/deprecated_api/test_remove_1-8.py

Lines changed: 0 additions & 68 deletions
Original file line numberDiff line numberDiff line change
@@ -22,12 +22,10 @@
2222
from pytorch_lightning import Callback, Trainer
2323
from pytorch_lightning.callbacks import ModelCheckpoint
2424
from pytorch_lightning.demos.boring_classes import BoringDataModule, BoringModel
25-
from pytorch_lightning.loggers import CSVLogger, Logger
2625
from pytorch_lightning.profilers import AdvancedProfiler, SimpleProfiler
2726
from pytorch_lightning.strategies.ipu import LightningIPUModule
2827
from pytorch_lightning.trainer.configuration_validator import _check_datamodule_checkpoint_hooks
2928
from pytorch_lightning.trainer.states import RunningStage
30-
from pytorch_lightning.utilities.rank_zero import rank_zero_only
3129

3230

3331
def test_v1_8_0_on_init_start_end(tmpdir):
@@ -289,72 +287,6 @@ def on_before_accelerator_backend_setup(self, *args, **kwargs):
289287
trainer.fit(model)
290288

291289

292-
def test_v1_8_0_logger_agg_parameters():
293-
class CustomLogger(Logger):
294-
@rank_zero_only
295-
def log_hyperparams(self, params):
296-
pass
297-
298-
@rank_zero_only
299-
def log_metrics(self, metrics, step):
300-
pass
301-
302-
@property
303-
def name(self):
304-
pass
305-
306-
@property
307-
def version(self):
308-
pass
309-
310-
with pytest.deprecated_call(
311-
match="The `agg_key_funcs` parameter for `Logger` was deprecated in v1.6" " and will be removed in v1.8."
312-
):
313-
CustomLogger(agg_key_funcs={"mean", np.mean})
314-
315-
with pytest.deprecated_call(
316-
match="The `agg_default_func` parameter for `Logger` was deprecated in v1.6" " and will be removed in v1.8."
317-
):
318-
CustomLogger(agg_default_func=np.mean)
319-
320-
# Should have no deprecation warning
321-
logger = CustomLogger()
322-
323-
with pytest.deprecated_call(match="`Logger.update_agg_funcs` was deprecated in v1.6 and will be removed in v1.8."):
324-
logger.update_agg_funcs()
325-
326-
327-
def test_v1_8_0_deprecated_agg_and_log_metrics_override(tmpdir):
328-
class AggregationOverrideLogger(CSVLogger):
329-
@rank_zero_only
330-
def agg_and_log_metrics(self, metrics, step):
331-
self.log_metrics(metrics=metrics, step=step)
332-
333-
logger = AggregationOverrideLogger(tmpdir)
334-
logger2 = CSVLogger(tmpdir)
335-
logger3 = CSVLogger(tmpdir)
336-
337-
# Test single loggers
338-
with pytest.deprecated_call(
339-
match="`Logger.agg_and_log_metrics` is deprecated in v1.6 and will be removed"
340-
" in v1.8. `Trainer` will directly call `Logger.log_metrics` so custom"
341-
" loggers should not implement `Logger.agg_and_log_metrics`."
342-
):
343-
Trainer(logger=logger)
344-
# Should have no deprecation warning
345-
Trainer(logger=logger2)
346-
347-
# Test multiple loggers
348-
with pytest.deprecated_call(
349-
match="`Logger.agg_and_log_metrics` is deprecated in v1.6 and will be removed"
350-
" in v1.8. `Trainer` will directly call `Logger.log_metrics` so custom"
351-
" loggers should not implement `Logger.agg_and_log_metrics`."
352-
):
353-
Trainer(logger=[logger, logger3])
354-
# Should have no deprecation warning
355-
Trainer(logger=[logger2, logger3])
356-
357-
358290
def test_v1_8_0_callback_on_pretrain_routine_start_end(tmpdir):
359291
class TestCallback(Callback):
360292
def on_pretrain_routine_start(self, trainer, pl_module):

0 commit comments

Comments
 (0)