diff --git a/pyproject.toml b/pyproject.toml index 989e63122f640..8187bbd3507c8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -59,7 +59,6 @@ module = [ "pytorch_lightning.demos.boring_classes", "pytorch_lightning.demos.mnist_datamodule", "pytorch_lightning.loggers.comet", - "pytorch_lightning.loggers.mlflow", "pytorch_lightning.loggers.neptune", "pytorch_lightning.loggers.tensorboard", "pytorch_lightning.loggers.wandb", diff --git a/src/pytorch_lightning/loggers/logger.py b/src/pytorch_lightning/loggers/logger.py index 03d934aa58760..56bf4660c29dd 100644 --- a/src/pytorch_lightning/loggers/logger.py +++ b/src/pytorch_lightning/loggers/logger.py @@ -203,12 +203,12 @@ def group_separator(self) -> str: @property @abstractmethod - def name(self) -> str: + def name(self) -> Optional[str]: """Return the experiment name.""" @property @abstractmethod - def version(self) -> Union[int, str]: + def version(self) -> Optional[Union[int, str]]: """Return the experiment version.""" diff --git a/src/pytorch_lightning/loggers/mlflow.py b/src/pytorch_lightning/loggers/mlflow.py index b8ce0ef423a31..313fcfe07f10e 100644 --- a/src/pytorch_lightning/loggers/mlflow.py +++ b/src/pytorch_lightning/loggers/mlflow.py @@ -50,7 +50,17 @@ from mlflow.tracking.context.registry import resolve_tags else: - def resolve_tags(tags=None): + def resolve_tags(tags: Optional[Dict] = None) -> Optional[Dict]: + """ + Args: + tags: A dictionary of tags to override. If specified, tags passed in this argument will + override those inferred from the context. + + Returns: A dictionary of resolved tags. + + Note: + See ``mlflow.tracking.context.registry`` for more details. + """ return tags @@ -129,7 +139,7 @@ def __init__( tracking_uri = f"{LOCAL_FILE_URI_PREFIX}{save_dir}" self._experiment_name = experiment_name - self._experiment_id = None + self._experiment_id: Optional[str] = None self._tracking_uri = tracking_uri self._run_name = run_name self._run_id = run_id @@ -141,7 +151,7 @@ def __init__( self._mlflow_client = MlflowClient(tracking_uri) - @property + @property # type: ignore[misc] @rank_zero_experiment def experiment(self) -> MlflowClient: r""" @@ -187,7 +197,7 @@ def experiment(self) -> MlflowClient: return self._mlflow_client @property - def run_id(self) -> str: + def run_id(self) -> Optional[str]: """Create the experiment if it does not exist to get the run id. Returns: @@ -197,7 +207,7 @@ def run_id(self) -> str: return self._run_id @property - def experiment_id(self) -> str: + def experiment_id(self) -> Optional[str]: """Create the experiment if it does not exist to get the experiment id. Returns: @@ -261,7 +271,7 @@ def save_dir(self) -> Optional[str]: return self._tracking_uri.lstrip(LOCAL_FILE_URI_PREFIX) @property - def name(self) -> str: + def name(self) -> Optional[str]: """Get the experiment id. Returns: @@ -270,7 +280,7 @@ def name(self) -> str: return self.experiment_id @property - def version(self) -> str: + def version(self) -> Optional[str]: """Get the run id. Returns: