diff --git a/pyproject.toml b/pyproject.toml
index a0960c58f6e6d..177410cba79a6 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -57,7 +57,6 @@ module = [
     "pytorch_lightning.core.saving",
     "pytorch_lightning.demos.boring_classes",
     "pytorch_lightning.demos.mnist_datamodule",
-    "pytorch_lightning.loggers.comet",
     "pytorch_lightning.loggers.neptune",
     "pytorch_lightning.profilers.base",
     "pytorch_lightning.profilers.pytorch",
diff --git a/src/pytorch_lightning/loggers/comet.py b/src/pytorch_lightning/loggers/comet.py
index 2b853f59259ff..363d47c1166e6 100644
--- a/src/pytorch_lightning/loggers/comet.py
+++ b/src/pytorch_lightning/loggers/comet.py
@@ -21,7 +21,7 @@
 from argparse import Namespace
 from typing import Any, Callable, Dict, Mapping, Optional, Sequence, Union
 
-from torch import is_tensor, Tensor
+from torch import Tensor
 
 import pytorch_lightning as pl
 from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
@@ -141,7 +141,7 @@ def __init__(
         prefix: str = "",
         agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
         agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
-        **kwargs,
+        **kwargs: Any,
     ):
         if comet_ml is None:
             raise ModuleNotFoundError(
@@ -149,6 +149,8 @@ def __init__(
             )
         super().__init__(agg_key_funcs=agg_key_funcs, agg_default_func=agg_default_func)
         self._experiment = None
+        self._save_dir: Optional[str]
+        self.rest_api_key: Optional[str]
 
         # Determine online or offline mode based on which arguments were passed to CometLogger
         api_key = api_key or comet_ml.config.get_api_key(None, comet_ml.config.get_config())
@@ -170,12 +172,12 @@ def __init__(
 
         log.info(f"CometLogger will be initialized in {self.mode} mode")
 
-        self._project_name = project_name
-        self._experiment_key = experiment_key
-        self._experiment_name = experiment_name
-        self._prefix = prefix
-        self._kwargs = kwargs
-        self._future_experiment_key = None
+        self._project_name: Optional[str] = project_name
+        self._experiment_key: Optional[str] = experiment_key
+        self._experiment_name: Optional[str] = experiment_name
+        self._prefix: str = prefix
+        self._kwargs: Any = kwargs
+        self._future_experiment_key: Optional[str] = None
 
         if rest_api_key is not None:
             # Comet.ml rest API, used to determine version number
@@ -185,9 +187,7 @@ def __init__(
             self.rest_api_key = None
             self.comet_api = None
 
-        self._kwargs = kwargs
-
-    @property
+    @property  # type: ignore[misc]
     @rank_zero_experiment
     def experiment(self) -> Union[CometExperiment, CometExistingExperiment, CometOfflineExperiment]:
         r"""
@@ -240,19 +240,19 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
         self.experiment.log_parameters(params)
 
     @rank_zero_only
-    def log_metrics(self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None) -> None:
+    def log_metrics(self, metrics: Mapping[str, Union[Tensor, float]], step: Optional[int] = None) -> None:
         assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
         # Comet.ml expects metrics to be a dictionary of detached tensors on CPU
         metrics_without_epoch = metrics.copy()
         for key, val in metrics_without_epoch.items():
-            if is_tensor(val):
+            if isinstance(val, Tensor):
                 metrics_without_epoch[key] = val.cpu().detach()
 
         epoch = metrics_without_epoch.pop("epoch", None)
         metrics_without_epoch = _add_prefix(metrics_without_epoch, self._prefix, self.LOGGER_JOIN_CHAR)
         self.experiment.log_metrics(metrics_without_epoch, step=step, epoch=epoch)
 
-    def reset_experiment(self):
+    def reset_experiment(self) -> None:
         self._experiment = None
 
     @rank_zero_only
@@ -326,7 +326,7 @@ def version(self) -> str:
 
         return self._future_experiment_key
 
-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, Any]:
         state = self.__dict__.copy()
 
         # Save the experiment id in case an experiment object already exists,
@@ -340,6 +340,6 @@ def __getstate__(self):
         state["_experiment"] = None
         return state
 
-    def log_graph(self, model: "pl.LightningModule", input_array=None) -> None:
+    def log_graph(self, model: "pl.LightningModule", input_array: Optional[Tensor] = None) -> None:
         if self._experiment is not None:
             self._experiment.set_model_graph(model)
diff --git a/src/pytorch_lightning/loggers/csv_logs.py b/src/pytorch_lightning/loggers/csv_logs.py
index 72d21ae2c4974..45d5fffb51e33 100644
--- a/src/pytorch_lightning/loggers/csv_logs.py
+++ b/src/pytorch_lightning/loggers/csv_logs.py
@@ -195,7 +195,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
         self.experiment.log_hparams(params)
 
     @rank_zero_only
-    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
+    def log_metrics(self, metrics: Dict[str, Union[Tensor, float]], step: Optional[int] = None) -> None:
         metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
         self.experiment.log_metrics(metrics, step)
         if step is not None and (step + 1) % self._flush_logs_every_n_steps == 0:
diff --git a/src/pytorch_lightning/loggers/mlflow.py b/src/pytorch_lightning/loggers/mlflow.py
index 313fcfe07f10e..5675a3bd9fc67 100644
--- a/src/pytorch_lightning/loggers/mlflow.py
+++ b/src/pytorch_lightning/loggers/mlflow.py
@@ -20,7 +20,7 @@
 import re
 from argparse import Namespace
 from time import time
-from typing import Any, Dict, Optional, Union
+from typing import Any, Dict, Mapping, Optional, Union
 
 from pytorch_lightning.loggers.logger import Logger, rank_zero_experiment
 from pytorch_lightning.utilities.imports import _module_available
@@ -230,7 +230,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
             self.experiment.log_param(self.run_id, k, v)
 
     @rank_zero_only
-    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
+    def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
         assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
 
         metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
diff --git a/src/pytorch_lightning/loggers/tensorboard.py b/src/pytorch_lightning/loggers/tensorboard.py
index 12ec2e21b84ce..dacecf129523b 100644
--- a/src/pytorch_lightning/loggers/tensorboard.py
+++ b/src/pytorch_lightning/loggers/tensorboard.py
@@ -216,7 +216,7 @@ def log_hyperparams(
             writer.add_summary(sei)
 
     @rank_zero_only
-    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
+    def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
         assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
 
         metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
diff --git a/src/pytorch_lightning/loggers/wandb.py b/src/pytorch_lightning/loggers/wandb.py
index bc2a84dc82b00..8e30827759b99 100644
--- a/src/pytorch_lightning/loggers/wandb.py
+++ b/src/pytorch_lightning/loggers/wandb.py
@@ -379,7 +379,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
         self.experiment.config.update(params, allow_val_change=True)
 
     @rank_zero_only
-    def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
+    def log_metrics(self, metrics: Mapping[str, float], step: Optional[int] = None) -> None:
         assert rank_zero_only.rank == 0, "experiment tried to log from global_rank != 0"
 
         metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
diff --git a/src/pytorch_lightning/utilities/logger.py b/src/pytorch_lightning/utilities/logger.py
index 07ecf4c3c0ca0..24d75e4f41034 100644
--- a/src/pytorch_lightning/utilities/logger.py
+++ b/src/pytorch_lightning/utilities/logger.py
@@ -14,7 +14,7 @@
 """Utilities for loggers."""
 
 from argparse import Namespace
-from typing import Any, Dict, Generator, List, MutableMapping, Optional, Union
+from typing import Any, Dict, Generator, List, Mapping, MutableMapping, Optional, Union
 
 import numpy as np
 import torch
@@ -132,7 +132,9 @@ def _sanitize_params(params: Dict[str, Any]) -> Dict[str, Any]:
     return params
 
 
-def _add_prefix(metrics: Dict[str, float], prefix: str, separator: str) -> Dict[str, float]:
+def _add_prefix(
+    metrics: Mapping[str, Union[Tensor, float]], prefix: str, separator: str
+) -> Mapping[str, Union[Tensor, float]]:
     """Insert prefix before each key in a dict, separated by the separator.
 
     Args:
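Illustrative aside, not part of the patch: the recurring change above is widening the `log_metrics` signatures from `Dict[str, float]` to `Mapping[...]`, plus swapping `is_tensor(val)` for `isinstance(val, Tensor)` in `CometLogger`. A minimal standalone sketch of the rationale, using hypothetical metric values: `Mapping` accepts any dict-like argument while signaling read-only intent to mypy, and `isinstance` narrows the `Union` so the `.cpu().detach()` call type-checks, which `torch.is_tensor` did not do under the stubs this patch targets.

from typing import Dict, Mapping, Optional, Union

from torch import Tensor, tensor


def log_metrics(metrics: Mapping[str, Union[Tensor, float]], step: Optional[int] = None) -> None:
    # Copy into a plain dict so the caller's mapping is never mutated;
    # the Mapping annotation documents that read-only contract to mypy.
    converted: Dict[str, Union[Tensor, float]] = dict(metrics)
    for key, val in converted.items():
        # isinstance(val, Tensor) narrows Union[Tensor, float], so the
        # .cpu().detach() call below type-checks under mypy.
        if isinstance(val, Tensor):
            converted[key] = val.cpu().detach()
    print(step, converted)


# Hypothetical usage with mixed tensor/float metrics.
log_metrics({"loss": tensor(0.25), "lr": 1e-3}, step=10)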