Skip to content

Commit 649260d

Browse files
committed
Possible fixes for todos
1 parent bc28576 commit 649260d

File tree

2 files changed

+18
-12
lines changed

2 files changed

+18
-12
lines changed

pytorch_lightning/callbacks/model_checkpoint.py

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -584,15 +584,21 @@ def __resolve_ckpt_dir(self, trainer: "pl.Trainer") -> None:
584584
# the user has changed weights_save_path, it overrides anything
585585
save_dir = trainer.weights_save_path
586586
else:
587-
save_dir = trainer.logger.save_dir if len(trainer.loggers) == 1 else trainer.default_root_dir
588-
589-
version = (
590-
trainer.logger.version
591-
if isinstance(trainer.logger.version, str)
592-
else f"version_{trainer.logger.version}"
593-
)
594-
# TODO: Find out what ckpt_path should be with multiple loggers
595-
ckpt_path = os.path.join(save_dir, str(trainer.logger.name), version, "checkpoints")
587+
if len(trainer.loggers) == 1:
588+
save_dir = trainer.logger.save_dir or trainer.default_root_dir
589+
else:
590+
save_dir = trainer.default_root_dir
591+
592+
if len(trainer.loggers) == 1:
593+
version = (
594+
trainer.logger.version
595+
if isinstance(trainer.logger.version, str)
596+
else f"version_{trainer.logger.version}"
597+
)
598+
# TODO: Find out what ckpt_path should be with multiple loggers
599+
ckpt_path = os.path.join(save_dir, str(trainer.logger.name), version, "checkpoints")
600+
else:
601+
ckpt_path = os.path.join(save_dir, "checkpoints")
596602
else:
597603
ckpt_path = os.path.join(trainer.weights_save_path, "checkpoints")
598604

pytorch_lightning/callbacks/progress/base.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -213,9 +213,9 @@ def get_standard_metrics(trainer: "pl.Trainer", pl_module: "pl.LightningModule")
213213
if pl_module.truncated_bptt_steps > 0:
214214
items_dict["split_idx"] = trainer.fit_loop.split_idx
215215

216-
# TODO: Adapt for trainer.loggers
217-
if trainer.logger is not None and trainer.logger.version is not None:
218-
version = trainer.logger.version
216+
# TODO: Find out if this is the correct approach
217+
if len(trainer.loggers) == 1 and trainer.loggers[0].version is not None:
218+
version = trainer.loggers[0].version
219219
if isinstance(version, str):
220220
# show last 4 places of long version strings
221221
version = version[-4:]

0 commit comments

Comments
 (0)