
Commit 4654833

Add support for logging the model checkpoints to MLFlowLogger (#15246)

Authored by: AlessioQuercia
Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Jirka <[email protected]>
Co-authored-by: Carlos Mocholí <[email protected]>
1 parent 6714ca7 commit 4654833

File tree

5 files changed: +132 -15 lines

src/pytorch_lightning/CHANGELOG.md
Lines changed: 6 additions & 2 deletions

@@ -21,6 +21,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Add an axes argument `ax` to the `.lr_find().plot()` to enable writing to a user-defined axes in a matplotlib figure ([#15652](https://github.com/Lightning-AI/lightning/pull/15652))


+- Added `log_model` parameter to `MLFlowLogger` ([#9187](https://github.com/PyTorchLightning/pytorch-lightning/pull/9187))
+
+
 - Added a check to validate that wrapped FSDP models are used while initializing optimizers ([#15301](https://github.com/Lightning-AI/lightning/pull/15301))


@@ -56,6 +59,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Enhanced `reduce_boolean_decision` to accommodate `any`-analogous semantics expected by the `EarlyStopping` callback ([#15253](https://github.com/Lightning-AI/lightning/pull/15253))


+- Fixed the automatic fallback from `Trainer(strategy="ddp_spawn", ...)` to `Trainer(strategy="ddp", ...)` when on an LSF cluster ([#15103](https://github.com/PyTorchLightning/pytorch-lightning/issues/15103))
+
+
 -

 ## [1.8.1] - 2022-11-10
@@ -80,8 +86,6 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed manual optimization raising `AttributeError` with Bagua Strategy ([#12534](https://github.com/PyTorchLightning/pytorch-lightning/issues/12534))
 - Fixed the import of `pytorch_lightning` causing a warning 'Redirects are currently not supported in Windows or MacOs' ([#15610](https://github.com/PyTorchLightning/pytorch-lightning/issues/15610))

-- Fixed the automatic fallback from `Trainer(strategy="ddp_spawn", ...)` to `Trainer(strategy="ddp", ...)` when on an LSF cluster ([#15103](https://github.com/PyTorchLightning/pytorch-lightning/issues/15103))
-

 ## [1.8.0] - 2022-11-01

src/pytorch_lightning/loggers/logger.py
Lines changed: 2 additions & 2 deletions

@@ -26,7 +26,7 @@
 from torch import Tensor

 import pytorch_lightning as pl
-from pytorch_lightning.callbacks import Checkpoint
+from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from pytorch_lightning.utilities.rank_zero import rank_zero_only


@@ -58,7 +58,7 @@ def get_experiment() -> Callable:
 class Logger(ABC):
     """Base class for experiment loggers."""

-    def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
+    def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
         """Called after model checkpoint callback saves a new checkpoint.

         Args:

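For context, `after_save_checkpoint` is the hook a logger overrides to be notified of newly saved checkpoints; this commit narrows its annotation from the bare `Checkpoint` base class to `ModelCheckpoint`, whose attributes (e.g. `best_model_path`, `save_top_k`) the loggers rely on. A minimal sketch of a custom logger using the hook — the `PrintingLogger` class and its body are illustrative, not part of this commit:

```python
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers.logger import Logger


class PrintingLogger(Logger):
    """Illustrative logger that only reports newly saved checkpoints."""

    @property
    def name(self):
        return "printing"

    @property
    def version(self):
        return "0"

    def log_hyperparams(self, params):
        pass

    def log_metrics(self, metrics, step=None):
        pass

    def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
        # `best_model_path` exists on ModelCheckpoint, which is why the
        # hook's annotation was narrowed in this commit.
        print(f"checkpoint saved, current best: {checkpoint_callback.best_model_path}")
```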
src/pytorch_lightning/loggers/mlflow.py
Lines changed: 81 additions & 1 deletion

@@ -18,14 +18,20 @@
 import logging
 import os
 import re
+import tempfile
 from argparse import Namespace
+from pathlib import Path
 from time import time
 from typing import Any, Dict, Mapping, Optional, Union

+import torch
+import yaml
 from lightning_utilities.core.imports import module_available
+from typing_extensions import Literal

+from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
-from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict
+from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict, _scan_checkpoints
 from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn

 log = logging.getLogger(__name__)
@@ -108,6 +114,15 @@ def any_lightning_module_function_or_hook(self):
         save_dir: A path to a local directory where the MLflow runs get saved.
             Defaults to `./mlflow` if `tracking_uri` is not provided.
             Has no effect if `tracking_uri` is provided.
+        log_model: Log checkpoints created by :class:`~pytorch_lightning.callbacks.model_checkpoint.ModelCheckpoint`
+            as MLFlow artifacts.
+
+            * if ``log_model == 'all'``, checkpoints are logged during training.
+            * if ``log_model == True``, checkpoints are logged at the end of training, except when
+              :paramref:`~pytorch_lightning.callbacks.Checkpoint.save_top_k` ``== -1``,
+              which also logs every checkpoint during training.
+            * if ``log_model == False`` (default), no checkpoint is logged.
+
         prefix: A string to put at the beginning of metric keys.
         artifact_location: The location to store run artifacts. If not provided, the server picks an appropriate
             default.
@@ -127,6 +142,7 @@ def __init__(
         tracking_uri: Optional[str] = os.getenv("MLFLOW_TRACKING_URI"),
         tags: Optional[Dict[str, Any]] = None,
         save_dir: Optional[str] = "./mlruns",
+        log_model: Literal[True, False, "all"] = False,
         prefix: str = "",
         artifact_location: Optional[str] = None,
         run_id: Optional[str] = None,
@@ -145,6 +161,9 @@ def __init__(
         self._run_name = run_name
         self._run_id = run_id
         self.tags = tags
+        self._log_model = log_model
+        self._logged_model_time: Dict[str, float] = {}
+        self._checkpoint_callback: Optional[ModelCheckpoint] = None
         self._prefix = prefix
         self._artifact_location = artifact_location

@@ -261,6 +280,11 @@ def finalize(self, status: str = "success") -> None:
             status = "FINISHED"
         elif status == "failed":
             status = "FAILED"
+
+        # log checkpoints as artifacts
+        if self._checkpoint_callback:
+            self._scan_and_log_checkpoints(self._checkpoint_callback)
+
         if self.experiment.get_run(self.run_id):
             self.experiment.set_terminated(self.run_id, status)

@@ -292,3 +316,59 @@ def version(self) -> Optional[str]:
             The run id.
         """
         return self.run_id
+
+    def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
+        # log checkpoints as artifacts
+        if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
+            self._scan_and_log_checkpoints(checkpoint_callback)
+        elif self._log_model is True:
+            self._checkpoint_callback = checkpoint_callback
+
+    def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
+        # get checkpoints to be saved with associated score
+        checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time)
+
+        # log iteratively all new checkpoints
+        for t, p, s, tag in checkpoints:
+            metadata = {
+                # Ensure .item() is called to store Tensor contents
+                "score": s.item() if isinstance(s, torch.Tensor) else s,
+                "original_filename": Path(p).name,
+                "Checkpoint": {
+                    k: getattr(checkpoint_callback, k)
+                    for k in [
+                        "monitor",
+                        "mode",
+                        "save_last",
+                        "save_top_k",
+                        "save_weights_only",
+                        "_every_n_train_steps",
+                        "_every_n_val_epochs",
+                    ]
+                    # ensure it does not break if `Checkpoint` args change
+                    if hasattr(checkpoint_callback, k)
+                },
+            }
+            aliases = ["latest", "best"] if p == checkpoint_callback.best_model_path else ["latest"]
+
+            # Artifact path on mlflow
+            artifact_path = f"model/checkpoints/{Path(p).stem}"
+
+            # Log the checkpoint
+            self.experiment.log_artifact(self._run_id, p, artifact_path)
+
+            # Create a temporary directory to log on mlflow
+            with tempfile.TemporaryDirectory(prefix="test", suffix="test", dir=os.getcwd()) as tmp_dir:
+                # Log the metadata
+                with open(f"{tmp_dir}/metadata.yaml", "w") as tmp_file_metadata:
+                    yaml.dump(metadata, tmp_file_metadata, default_flow_style=False)
+
+                # Log the aliases
+                with open(f"{tmp_dir}/aliases.txt", "w") as tmp_file_aliases:
+                    tmp_file_aliases.write(str(aliases))
+
+                # Log the metadata and aliases
+                self.experiment.log_artifacts(self._run_id, tmp_dir, artifact_path)
+
+            # remember logged models - timestamp needed in case filename didn't change (last.ckpt or custom name)
+            self._logged_model_time[p] = t

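A quick usage sketch of the new parameter (assumptions beyond this diff: the `BoringModel` demo module and a local `./mlruns` store; any `LightningModule` would do):

```python
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.demos.boring_classes import BoringModel
from pytorch_lightning.loggers import MLFlowLogger

# "all" uploads every checkpoint as it is saved; True uploads only at the
# end of training (unless save_top_k == -1); False (default) uploads nothing.
mlf_logger = MLFlowLogger(experiment_name="ckpt-demo", save_dir="./mlruns", log_model="all")

trainer = Trainer(
    logger=mlf_logger,
    callbacks=[ModelCheckpoint(save_top_k=-1)],  # keep a checkpoint per epoch
    max_epochs=2,
)
trainer.fit(BoringModel())
# Each checkpoint ends up under the run's artifacts at
# model/checkpoints/<filename-stem>/, next to metadata.yaml and aliases.txt.
```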
src/pytorch_lightning/loggers/wandb.py
Lines changed: 5 additions & 10 deletions

@@ -25,7 +25,7 @@
 from torch import Tensor

 from lightning_lite.utilities.types import _PATH
-from pytorch_lightning.callbacks import Checkpoint
+from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from pytorch_lightning.utilities.logger import (
@@ -331,7 +331,7 @@ def __init__(
         self._prefix = prefix
         self._experiment = experiment
         self._logged_model_time: Dict[str, float] = {}
-        self._checkpoint_callback: Optional[Checkpoint] = None
+        self._checkpoint_callback: Optional[ModelCheckpoint] = None

         # paths are processed as strings
         if save_dir is not None:
@@ -513,14 +513,9 @@ def version(self) -> Optional[str]:
         # don't create an experiment if we don't have one
         return self._experiment.id if self._experiment else self._id

-    def after_save_checkpoint(self, checkpoint_callback: Checkpoint) -> None:
+    def after_save_checkpoint(self, checkpoint_callback: ModelCheckpoint) -> None:
         # log checkpoints as artifacts
-        if (
-            self._log_model == "all"
-            or self._log_model is True
-            and hasattr(checkpoint_callback, "save_top_k")
-            and checkpoint_callback.save_top_k == -1
-        ):
+        if self._log_model == "all" or self._log_model is True and checkpoint_callback.save_top_k == -1:
             self._scan_and_log_checkpoints(checkpoint_callback)
         elif self._log_model is True:
             self._checkpoint_callback = checkpoint_callback
@@ -574,7 +569,7 @@ def finalize(self, status: str) -> None:
         if self._checkpoint_callback and self._experiment is not None:
             self._scan_and_log_checkpoints(self._checkpoint_callback)

-    def _scan_and_log_checkpoints(self, checkpoint_callback: Checkpoint) -> None:
+    def _scan_and_log_checkpoints(self, checkpoint_callback: ModelCheckpoint) -> None:
         # get checkpoints to be saved with associated score
         checkpoints = _scan_checkpoints(checkpoint_callback, self._logged_model_time)

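One subtlety in the simplified condition shared by both loggers: Python's `and` binds tighter than `or`, so the one-liner parses as "always when 'all', or when True and every checkpoint is kept". The commit can also safely drop the `hasattr(checkpoint_callback, "save_top_k")` guard once the parameter is annotated as `ModelCheckpoint`. A self-contained sketch (the `_should_log_now` helper is hypothetical, named here only for illustration):

```python
def _should_log_now(log_model, save_top_k: int) -> bool:
    # Parsed as: (log_model == "all") or ((log_model is True) and (save_top_k == -1))
    return log_model == "all" or log_model is True and save_top_k == -1


assert _should_log_now("all", 2)        # "all": upload during training
assert _should_log_now(True, -1)        # True + keep-everything: also upload now
assert not _should_log_now(True, 2)     # True otherwise: deferred to finalize()
assert not _should_log_now(False, -1)   # False: never uploaded
```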
tests/tests_pytorch/loggers/test_mlflow.py
Lines changed: 38 additions & 0 deletions

@@ -277,3 +277,41 @@ def test_mlflow_logger_finalize_when_exception(*_):
     assert logger._initialized
     logger.finalize("failed")
     logger.experiment.set_terminated.assert_called_once_with(logger.run_id, "FAILED")
+
+
+@mock.patch("pytorch_lightning.loggers.mlflow.mlflow")
+@mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient")
+@pytest.mark.parametrize("log_model", ["all", True, False])
+def test_mlflow_log_model(client, _, tmpdir, log_model):
+    """Test that the logger creates the folders and files in the right place."""
+    # Get model, logger, trainer and train
+    model = BoringModel()
+    logger = MLFlowLogger("test", save_dir=tmpdir, log_model=log_model)
+    logger = mock_mlflow_run_creation(logger, experiment_id="test-id")
+
+    trainer = Trainer(
+        default_root_dir=tmpdir,
+        logger=logger,
+        max_epochs=2,
+        limit_train_batches=3,
+        limit_val_batches=3,
+    )
+    trainer.fit(model)
+
+    if log_model == "all":
+        # Checkpoint log
+        assert client.return_value.log_artifact.call_count == 2
+        # Metadata and aliases log
+        assert client.return_value.log_artifacts.call_count == 2
+
+    elif log_model is True:
+        # Checkpoint log
+        client.return_value.log_artifact.assert_called_once()
+        # Metadata and aliases log
+        client.return_value.log_artifacts.assert_called_once()
+
+    elif log_model is False:
+        # Checkpoint log
+        assert not client.return_value.log_artifact.called
+        # Metadata and aliases log
+        assert not client.return_value.log_artifacts.called

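After a real (un-mocked) run, the uploaded files can be pulled back with the MLflow client. A sketch, assuming `run_id` comes from a logger like the one in the earlier usage example and that the installed `mlflow` still exposes `MlflowClient.download_artifacts` (newer releases move this functionality to `mlflow.artifacts`):

```python
from mlflow.tracking import MlflowClient

client = MlflowClient(tracking_uri="file:./mlruns")
run_id = mlf_logger.run_id  # from the usage sketch above

# Each checkpoint directory also holds metadata.yaml (score, original
# filename, ModelCheckpoint settings) and aliases.txt (["latest"] or
# ["latest", "best"]).
for artifact in client.list_artifacts(run_id, "model/checkpoints"):
    local_path = client.download_artifacts(run_id, artifact.path, dst_path="./restored")
    print(artifact.path, "->", local_path)
```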