Skip to content

Commit bdd1757

Browse files
committed
return early if no loggers
1 parent b96f570 commit bdd1757

File tree

1 file changed

+26
-24
lines changed

1 file changed

+26
-24
lines changed

pytorch_lightning/trainer/trainer.py

Lines changed: 26 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -1245,36 +1245,38 @@ def _run(
12451245
return results
12461246

12471247
def _log_hyperparams(self) -> None:
1248+
if not self.loggers:
1249+
return
1250+
12481251
# log hyper-parameters
12491252
hparams_initial = None
12501253

12511254
# save exp to get started (this is where the first experiment logs are written)
12521255
datamodule_log_hyperparams = self.datamodule._log_hyperparams if self.datamodule is not None else False
12531256

1254-
if self.loggers:
1255-
if self.lightning_module._log_hyperparams and datamodule_log_hyperparams:
1256-
datamodule_hparams = self.datamodule.hparams_initial
1257-
lightning_hparams = self.lightning_module.hparams_initial
1258-
inconsistent_keys = []
1259-
for key in lightning_hparams.keys() & datamodule_hparams.keys():
1260-
lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
1261-
if type(lm_val) != type(dm_val):
1262-
inconsistent_keys.append(key)
1263-
elif isinstance(lm_val, torch.Tensor) and id(lm_val) != id(dm_val):
1264-
inconsistent_keys.append(key)
1265-
elif lm_val != dm_val:
1266-
inconsistent_keys.append(key)
1267-
if inconsistent_keys:
1268-
raise MisconfigurationException(
1269-
f"Error while merging hparams: the keys {inconsistent_keys} are present "
1270-
"in both the LightningModule's and LightningDataModule's hparams "
1271-
"but have different values."
1272-
)
1273-
hparams_initial = {**lightning_hparams, **datamodule_hparams}
1274-
elif self.lightning_module._log_hyperparams:
1275-
hparams_initial = self.lightning_module.hparams_initial
1276-
elif datamodule_log_hyperparams:
1277-
hparams_initial = self.datamodule.hparams_initial
1257+
if self.lightning_module._log_hyperparams and datamodule_log_hyperparams:
1258+
datamodule_hparams = self.datamodule.hparams_initial
1259+
lightning_hparams = self.lightning_module.hparams_initial
1260+
inconsistent_keys = []
1261+
for key in lightning_hparams.keys() & datamodule_hparams.keys():
1262+
lm_val, dm_val = lightning_hparams[key], datamodule_hparams[key]
1263+
if type(lm_val) != type(dm_val):
1264+
inconsistent_keys.append(key)
1265+
elif isinstance(lm_val, torch.Tensor) and id(lm_val) != id(dm_val):
1266+
inconsistent_keys.append(key)
1267+
elif lm_val != dm_val:
1268+
inconsistent_keys.append(key)
1269+
if inconsistent_keys:
1270+
raise MisconfigurationException(
1271+
f"Error while merging hparams: the keys {inconsistent_keys} are present "
1272+
"in both the LightningModule's and LightningDataModule's hparams "
1273+
"but have different values."
1274+
)
1275+
hparams_initial = {**lightning_hparams, **datamodule_hparams}
1276+
elif self.lightning_module._log_hyperparams:
1277+
hparams_initial = self.lightning_module.hparams_initial
1278+
elif datamodule_log_hyperparams:
1279+
hparams_initial = self.datamodule.hparams_initial
12781280

12791281
for logger in self.loggers:
12801282
if hparams_initial is not None:

0 commit comments

Comments (0)