
Commit 7e2f9fb

Refactor codebase to use trainer.loggers over trainer.logger when needed (#11920)
1 parent 244f365 commit 7e2f9fb

Showing 22 changed files with 185 additions and 115 deletions.

pl_examples/domain_templates/generative_adversarial_net.py

Lines changed: 2 additions & 1 deletion
@@ -206,7 +206,8 @@ def on_train_epoch_end(self):
         # log sampled images
         sample_imgs = self(z)
         grid = torchvision.utils.make_grid(sample_imgs)
-        self.logger.experiment.add_image("generated_images", grid, self.current_epoch)
+        for logger in self.loggers:
+            logger.experiment.add_image("generated_images", grid, self.current_epoch)


 def main(args: Namespace) -> None:
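The loop above only has a visible effect when the Trainer is configured with more than one logger. A minimal sketch of such a setup (save directories and experiment names are illustrative, not part of this commit); each logger's `experiment` is assumed to expose `add_image`, which holds for TensorBoard:

from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger

# Two loggers attached at once: inside the LightningModule, `self.loggers`
# yields both, so the sampled image grid is written to each run directory.
loggers = [
    TensorBoardLogger("logs/", name="gan_tb_a"),
    TensorBoardLogger("logs/", name="gan_tb_b"),
]
trainer = Trainer(max_epochs=5, logger=loggers)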

pytorch_lightning/callbacks/device_stats_monitor.py

Lines changed: 11 additions & 9 deletions
@@ -44,7 +44,7 @@ class DeviceStatsMonitor(Callback):
     """

     def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException("Cannot use DeviceStatsMonitor callback with Trainer that has no logger.")

     def on_train_batch_start(
@@ -55,17 +55,18 @@ def on_train_batch_start(
         batch_idx: int,
         unused: Optional[int] = 0,
     ) -> None:
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")

         if not trainer.logger_connector.should_update_logs:
             return

         device = trainer.strategy.root_device
         device_stats = trainer.accelerator.get_device_stats(device)
-        separator = trainer.logger.group_separator
-        prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator)
-        trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
+        for logger in trainer.loggers:
+            separator = logger.group_separator
+            prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_start", separator)
+            logger.log_metrics(prefixed_device_stats, step=trainer.global_step)

     def on_train_batch_end(
         self,
@@ -76,17 +77,18 @@ def on_train_batch_end(
         batch_idx: int,
         unused: Optional[int] = 0,
     ) -> None:
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`.")

         if not trainer.logger_connector.should_update_logs:
             return

         device = trainer.strategy.root_device
         device_stats = trainer.accelerator.get_device_stats(device)
-        separator = trainer.logger.group_separator
-        prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator)
-        trainer.logger.log_metrics(prefixed_device_stats, step=trainer.global_step)
+        for logger in trainer.loggers:
+            separator = logger.group_separator
+            prefixed_device_stats = _prefix_metric_keys(device_stats, "on_train_batch_end", separator)
+            logger.log_metrics(prefixed_device_stats, step=trainer.global_step)


 def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]:
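The body of `_prefix_metric_keys` is outside this hunk; judging from its signature and the call sites above, it presumably just namespaces each metric key with the hook name, roughly as in this sketch (an assumption, not the file's actual body):

from typing import Dict


def _prefix_metric_keys(metrics_dict: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]:
    # e.g. {"active_bytes": 1.0} -> {"on_train_batch_start/active_bytes": 1.0} when separator is "/"
    return {prefix + separator + k: v for k, v in metrics_dict.items()}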

pytorch_lightning/callbacks/gpu_stats_monitor.py

Lines changed: 5 additions & 5 deletions
@@ -123,7 +123,7 @@ def __init__(
         self._gpu_ids: List[str] = []  # will be assigned later in setup()

     def setup(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule", stage: Optional[str] = None) -> None:
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException("Cannot use GPUStatsMonitor callback with Trainer that has no logger.")

         if trainer.strategy.root_device.type != "cuda":
@@ -161,8 +161,8 @@ def on_train_batch_start(
             # First log at beginning of second step
             logs["batch_time/inter_step (ms)"] = (time.time() - self._snap_inter_step_time) * 1000

-        assert trainer.logger is not None
-        trainer.logger.log_metrics(logs, step=trainer.global_step)
+        for logger in trainer.loggers:
+            logger.log_metrics(logs, step=trainer.global_step)

     @rank_zero_only
     def on_train_batch_end(
@@ -186,8 +186,8 @@ def on_train_batch_end(
         if self._log_stats.intra_step_time and self._snap_intra_step_time:
             logs["batch_time/intra_step (ms)"] = (time.time() - self._snap_intra_step_time) * 1000

-        assert trainer.logger is not None
-        trainer.logger.log_metrics(logs, step=trainer.global_step)
+        for logger in trainer.loggers:
+            logger.log_metrics(logs, step=trainer.global_step)

     @staticmethod
     def _get_gpu_ids(device_ids: List[int]) -> List[str]:
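Note that the old `assert trainer.logger is not None` lines become unnecessary here (and in lr_monitor.py below): iterating `trainer.loggers` is simply a no-op when the list is empty. In isolation (the metric value is made up):

loggers = []  # what trainer.loggers looks like with Trainer(logger=False)
logs = {"batch_time/intra_step (ms)": 12.3}

# With no loggers attached this loop body never runs, so no None-check is needed.
for logger in loggers:
    logger.log_metrics(logs, step=0)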

pytorch_lightning/callbacks/lr_monitor.py

Lines changed: 5 additions & 5 deletions
@@ -104,7 +104,7 @@ def on_train_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> No
             MisconfigurationException:
                 If ``Trainer`` has no ``logger``.
         """
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException(
                 "Cannot use `LearningRateMonitor` callback with `Trainer` that has no logger."
             )
@@ -149,7 +149,6 @@ def _check_no_key(key: str) -> bool:
         self.last_momentum_values = {name + "-momentum": None for name in names_flatten}

     def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
-        assert trainer.logger is not None
         if not trainer.logger_connector.should_update_logs:
             return

@@ -158,16 +157,17 @@ def on_train_batch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any)
             latest_stat = self._extract_stats(trainer, interval)

             if latest_stat:
-                trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
+                for logger in trainer.loggers:
+                    logger.log_metrics(latest_stat, step=trainer.global_step)

     def on_train_epoch_start(self, trainer: "pl.Trainer", *args: Any, **kwargs: Any) -> None:
-        assert trainer.logger is not None
         if self.logging_interval != "step":
             interval = "epoch" if self.logging_interval is None else "any"
             latest_stat = self._extract_stats(trainer, interval)

             if latest_stat:
-                trainer.logger.log_metrics(latest_stat, step=trainer.global_step)
+                for logger in trainer.loggers:
+                    logger.log_metrics(latest_stat, step=trainer.global_step)

     def _extract_stats(self, trainer: "pl.Trainer", interval: str) -> Dict[str, float]:
         latest_stat = {}

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 13 additions & 11 deletions
@@ -35,6 +35,7 @@
 from pytorch_lightning.callbacks.base import Callback
 from pytorch_lightning.utilities.cloud_io import get_filesystem
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
+from pytorch_lightning.utilities.logger import _name, _version
 from pytorch_lightning.utilities.rank_zero import rank_zero_info, rank_zero_warn
 from pytorch_lightning.utilities.types import _METRIC, _PATH, STEP_OUTPUT
 from pytorch_lightning.utilities.warnings import WarningCache
@@ -379,8 +380,9 @@ def save_checkpoint(self, trainer: "pl.Trainer") -> None:
         self._save_last_checkpoint(trainer, monitor_candidates)

         # notify loggers
-        if trainer.is_global_zero and trainer.logger:
-            trainer.logger.after_save_checkpoint(proxy(self))
+        if trainer.is_global_zero:
+            for logger in trainer.loggers:
+                logger.after_save_checkpoint(proxy(self))

     def _should_skip_saving_checkpoint(self, trainer: "pl.Trainer") -> bool:
         from pytorch_lightning.trainer.states import TrainerFn
@@ -572,20 +574,20 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
         """
         if self.dirpath is not None:
             return  # short circuit
-
-        if trainer.logger is not None:
+        if trainer.loggers:
             if trainer.weights_save_path != trainer.default_root_dir:
                 # the user has changed weights_save_path, it overrides anything
                 save_dir = trainer.weights_save_path
-            else:
+            elif len(trainer.loggers) == 1:
                 save_dir = trainer.logger.save_dir or trainer.default_root_dir
+            else:
+                save_dir = trainer.default_root_dir

-            version = (
-                trainer.logger.version
-                if isinstance(trainer.logger.version, str)
-                else f"version_{trainer.logger.version}"
-            )
-            ckpt_path = os.path.join(save_dir, str(trainer.logger.name), version, "checkpoints")
+            name = _name(trainer.loggers)
+            version = _version(trainer.loggers)
+            version = version if isinstance(version, str) else f"version_{version}"
+
+            ckpt_path = os.path.join(save_dir, str(name), version, "checkpoints")
         else:
             ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints")
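The `_name` and `_version` helpers imported at the top of the file are not shown in this commit; they live in pytorch_lightning/utilities/logger.py. As a rough, assumed sketch of their intent — use the single logger's value directly and otherwise combine the distinct values from all loggers:

from typing import Any, List, Union


def _name(loggers: List[Any], separator: str = "_") -> str:
    # Assumed behaviour: one logger -> its name; several -> distinct names joined.
    if len(loggers) == 1:
        return loggers[0].name
    return separator.join(dict.fromkeys(str(logger.name) for logger in loggers))


def _version(loggers: List[Any], separator: str = "_") -> Union[int, str]:
    # Assumed behaviour: one logger -> its version; several -> distinct versions joined.
    if len(loggers) == 1:
        return loggers[0].version
    return separator.join(dict.fromkeys(str(logger.version) for logger in loggers))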

pytorch_lightning/callbacks/progress/base.py

Lines changed: 8 additions & 6 deletions
@@ -15,6 +15,7 @@

 import pytorch_lightning as pl
 from pytorch_lightning.callbacks import Callback
+from pytorch_lightning.utilities.logger import _version
 from pytorch_lightning.utilities.rank_zero import rank_zero_warn


@@ -213,11 +214,12 @@ def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule")
     if pl_module.truncated_bptt_steps > 0:
         items_dict["split_idx"] = trainer.fit_loop.split_idx

-    if trainer.logger is not None and trainer.logger.version is not None:
-        version = trainer.logger.version
-        if isinstance(version, str):
-            # show last 4 places of long version strings
-            version = version[-4:]
-        items_dict["v_num"] = version
+    if trainer.loggers:
+        version = _version(trainer.loggers)
+        if version is not None:
+            if isinstance(version, str):
+                # show last 4 places of long version strings
+                version = version[-4:]
+            items_dict["v_num"] = version

     return items_dict
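For illustration only (the version string below is invented): the truncation keeps `v_num` in the progress bar short even when a logger reports a long version string.

version = "2f1b9c0d7e2f9fb4"   # hypothetical long version string from a logger
if isinstance(version, str):
    # show last 4 places of long version strings
    version = version[-4:]
print(version)  # -> "9fb4"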

pytorch_lightning/callbacks/xla_stats_monitor.py

Lines changed: 7 additions & 6 deletions
@@ -70,7 +70,7 @@ def __init__(self, verbose: bool = True) -> None:
         self._verbose = verbose

     def on_train_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")

         if isinstance(trainer.accelerator, TPUAccelerator):
@@ -88,7 +88,7 @@ def on_train_epoch_start(self, trainer: "pl.Trainer", pl_module: "pl.LightningMo
         self._start_time = time.time()

     def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModule") -> None:
-        if not trainer.logger:
+        if not trainer.loggers:
             raise MisconfigurationException("Cannot use XLAStatsMonitor callback with Trainer that has no logger.")

         device = trainer.strategy.root_device
@@ -102,10 +102,11 @@ def on_train_epoch_end(self, trainer: "pl.Trainer", pl_module: "pl.LightningModu
         peak_memory = trainer.strategy.reduce(peak_memory) * 0.001
         epoch_time = trainer.strategy.reduce(epoch_time)

-        trainer.logger.log_metrics(
-            {"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)},
-            step=trainer.current_epoch,
-        )
+        for logger in trainer.loggers:
+            logger.log_metrics(
+                {"avg. free memory (MB)": float(free_memory), "avg. peak memory (MB)": float(peak_memory)},
+                step=trainer.current_epoch,
+            )

         if self._verbose:
             rank_zero_info(f"Average Epoch time: {epoch_time:.2f} seconds")

pytorch_lightning/core/lightning.py

Lines changed: 1 addition & 1 deletion
@@ -253,7 +253,7 @@ def logger(self) -> Optional[LightningLoggerBase]:

     @property
     def loggers(self) -> List[LightningLoggerBase]:
-        """Reference to the loggers object in the Trainer."""
+        """Reference to the list of loggers in the Trainer."""
         return self.trainer.loggers if self.trainer else []

     def _apply_batch_transfer_handler(
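The revised docstring matches the property's contract: `self.loggers` mirrors `trainer.loggers` and is always a plain list, with one entry per configured logger and empty when logging is disabled. A small sketch (the logger choices are illustrative):

import pytorch_lightning as pl
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger

trainer = pl.Trainer(logger=[TensorBoardLogger("logs/"), CSVLogger("logs/")])
assert len(trainer.loggers) == 2  # a plain list, one entry per logger

no_log_trainer = pl.Trainer(logger=False)
assert no_log_trainer.loggers == []  # empty list, never None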

pytorch_lightning/loops/epoch/training_epoch_loop.py

Lines changed: 4 additions & 2 deletions
@@ -504,8 +504,10 @@ def _save_loggers_on_train_batch_end(self) -> None:
         """Flushes loggers to disk."""
         # when loggers should save to disk
         should_flush_logs = self.trainer.logger_connector.should_flush_logs
-        if should_flush_logs and self.trainer.is_global_zero and self.trainer.logger is not None:
-            self.trainer.logger.save()
+        # TODO: is_global_zero check should be moved to logger.save() implementation
+        if should_flush_logs and self.trainer.is_global_zero:
+            for logger in self.trainer.loggers:
+                logger.save()

     def _reload_dataloader_state_dict(self, data_fetcher: AbstractDataFetcher) -> None:
         if self._dataloader_state_dict:

pytorch_lightning/plugins/precision/precision_plugin.py

Lines changed: 5 additions & 1 deletion
@@ -155,7 +155,11 @@ def optimizer_step(
     def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
         if trainer.track_grad_norm == -1:
             return
-        kwargs = {"group_separator": trainer.logger.group_separator} if trainer.logger is not None else {}
+
+        kwargs = {}
+        if len(trainer.loggers) == 1:
+            kwargs["group_separator"] = trainer.loggers[0].group_separator
+
         grad_norm_dict = grad_norm(trainer.lightning_module, trainer.track_grad_norm, **kwargs)
         if grad_norm_dict:
             prev_fx = trainer.lightning_module._current_fx_name
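In the hunk above, the `group_separator` override is only forwarded when exactly one logger is attached; with zero or several loggers, `grad_norm` keeps its own default separator rather than being handed an arbitrary logger's choice. The conditional-kwargs pattern in isolation, with a stand-in function (the names and the separator default are illustrative):

def fake_grad_norm(module, norm_type, group_separator="/"):
    # Stand-in for the real helper; only the separator handling matters here.
    return {f"grad_{norm_type}_norm{group_separator}dummy_param": 0.0}


loggers = []  # e.g. Trainer(logger=False), or a list with several loggers
kwargs = {}
if len(loggers) == 1:
    kwargs["group_separator"] = loggers[0].group_separator

# With zero (or several) loggers the default separator is used.
print(fake_grad_norm(object(), 2.0, **kwargs))  # {'grad_2.0_norm/dummy_param': 0.0}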