@@ -44,7 +44,7 @@ class DeviceStatsMonitor(Callback):
4444 """
4545
4646 def setup (self , trainer : "pl.Trainer" , pl_module : "pl.LightningModule" , stage : Optional [str ] = None ) -> None :
47- if not trainer .logger :
47+ if not trainer .loggers :
4848 raise MisconfigurationException ("Cannot use DeviceStatsMonitor callback with Trainer that has no logger." )
4949
5050 def on_train_batch_start (
@@ -55,17 +55,18 @@ def on_train_batch_start(
5555 batch_idx : int ,
5656 unused : Optional [int ] = 0 ,
5757 ) -> None :
58- if not trainer .logger :
58+ if not trainer .loggers :
5959 raise MisconfigurationException ("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`." )
6060
6161 if not trainer .logger_connector .should_update_logs :
6262 return
6363
6464 device = trainer .strategy .root_device
6565 device_stats = trainer .accelerator .get_device_stats (device )
66- separator = trainer .logger .group_separator
67- prefixed_device_stats = _prefix_metric_keys (device_stats , "on_train_batch_start" , separator )
68- trainer .logger .log_metrics (prefixed_device_stats , step = trainer .global_step )
66+ for logger in trainer .loggers :
67+ separator = logger .group_separator
68+ prefixed_device_stats = _prefix_metric_keys (device_stats , "on_train_batch_start" , separator )
69+ logger .log_metrics (prefixed_device_stats , step = trainer .global_step )
6970
7071 def on_train_batch_end (
7172 self ,
@@ -76,17 +77,18 @@ def on_train_batch_end(
7677 batch_idx : int ,
7778 unused : Optional [int ] = 0 ,
7879 ) -> None :
79- if not trainer .logger :
80+ if not trainer .loggers :
8081 raise MisconfigurationException ("Cannot use `DeviceStatsMonitor` callback with `Trainer(logger=False)`." )
8182
8283 if not trainer .logger_connector .should_update_logs :
8384 return
8485
8586 device = trainer .strategy .root_device
8687 device_stats = trainer .accelerator .get_device_stats (device )
87- separator = trainer .logger .group_separator
88- prefixed_device_stats = _prefix_metric_keys (device_stats , "on_train_batch_end" , separator )
89- trainer .logger .log_metrics (prefixed_device_stats , step = trainer .global_step )
88+ for logger in trainer .loggers :
89+ separator = logger .group_separator
90+ prefixed_device_stats = _prefix_metric_keys (device_stats , "on_train_batch_end" , separator )
91+ logger .log_metrics (prefixed_device_stats , step = trainer .global_step )
9092
9193
9294def _prefix_metric_keys (metrics_dict : Dict [str , float ], prefix : str , separator : str ) -> Dict [str , float ]:
0 commit comments