Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -59,7 +59,6 @@ module = [
"pytorch_lightning.demos.boring_classes",
"pytorch_lightning.demos.mnist_datamodule",
"pytorch_lightning.loggers.comet",
"pytorch_lightning.loggers.mlflow",
"pytorch_lightning.loggers.neptune",
"pytorch_lightning.loggers.tensorboard",
"pytorch_lightning.loggers.wandb",
Expand Down
4 changes: 2 additions & 2 deletions src/pytorch_lightning/loggers/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,12 +203,12 @@ def group_separator(self) -> str:

@property
@abstractmethod
def name(self) -> str:
def name(self) -> Optional[str]:
"""Return the experiment name."""

@property
@abstractmethod
def version(self) -> Union[int, str]:
def version(self) -> Optional[Union[int, str]]:
"""Return the experiment version."""


Expand Down
24 changes: 17 additions & 7 deletions src/pytorch_lightning/loggers/mlflow.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,17 @@
from mlflow.tracking.context.registry import resolve_tags
else:

def resolve_tags(tags=None):
def resolve_tags(tags: Optional[Dict] = None) -> Optional[Dict]:
"""
Args:
tags: A dictionary of tags to override. If specified, tags passed in this argument will
override those inferred from the context.

Returns: A dictionary of resolved tags.

Note:
See ``mlflow.tracking.context.registry`` for more details.
"""
return tags


Expand Down Expand Up @@ -129,7 +139,7 @@ def __init__(
tracking_uri = f"{LOCAL_FILE_URI_PREFIX}{save_dir}"

self._experiment_name = experiment_name
self._experiment_id = None
self._experiment_id: Optional[str] = None
self._tracking_uri = tracking_uri
self._run_name = run_name
self._run_id = run_id
Expand All @@ -141,7 +151,7 @@ def __init__(

self._mlflow_client = MlflowClient(tracking_uri)

@property
@property # type: ignore[misc]
@rank_zero_experiment
def experiment(self) -> MlflowClient:
r"""
Expand Down Expand Up @@ -187,7 +197,7 @@ def experiment(self) -> MlflowClient:
return self._mlflow_client

@property
def run_id(self) -> str:
def run_id(self) -> Optional[str]:
"""Create the experiment if it does not exist to get the run id.

Returns:
Expand All @@ -197,7 +207,7 @@ def run_id(self) -> str:
return self._run_id

@property
def experiment_id(self) -> str:
def experiment_id(self) -> Optional[str]:
"""Create the experiment if it does not exist to get the experiment id.

Returns:
Expand Down Expand Up @@ -261,7 +271,7 @@ def save_dir(self) -> Optional[str]:
return self._tracking_uri.lstrip(LOCAL_FILE_URI_PREFIX)

@property
def name(self) -> str:
def name(self) -> Optional[str]:
"""Get the experiment id.

Returns:
Expand All @@ -270,7 +280,7 @@ def name(self) -> str:
return self.experiment_id

@property
def version(self) -> str:
def version(self) -> Optional[str]:
"""Get the run id.

Returns:
Expand Down