Skip to content
1 change: 0 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ module = [
"pytorch_lightning.distributed.dist",
"pytorch_lightning.loggers.base",
"pytorch_lightning.loggers.comet",
"pytorch_lightning.loggers.csv_logs",
"pytorch_lightning.loggers.mlflow",
"pytorch_lightning.loggers.neptune",
"pytorch_lightning.loggers.tensorboard",
Expand Down
22 changes: 11 additions & 11 deletions src/pytorch_lightning/loggers/csv_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
import logging
import os
from argparse import Namespace
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, List, Optional, Union

from torch import Tensor

Expand All @@ -49,8 +49,8 @@ class ExperimentWriter:
NAME_METRICS_FILE = "metrics.csv"

def __init__(self, log_dir: str) -> None:
self.hparams = {}
self.metrics = []
self.hparams: Dict[str, Any] = {}
self.metrics: List[Dict[str, float]] = []

self.log_dir = log_dir
if os.path.exists(self.log_dir) and os.listdir(self.log_dir):
Expand All @@ -69,7 +69,7 @@ def log_hparams(self, params: Dict[str, Any]) -> None:
def log_metrics(self, metrics_dict: Dict[str, float], step: Optional[int] = None) -> None:
"""Record metrics."""

def _handle_value(value):
def _handle_value(value: Union[Tensor, Any]) -> Any:
if isinstance(value, Tensor):
return value.item()
return value
Expand Down Expand Up @@ -126,7 +126,7 @@ class CSVLogger(Logger):
def __init__(
self,
save_dir: str,
name: Optional[str] = "lightning_logs",
name: str = "lightning_logs",
version: Optional[Union[int, str]] = None,
prefix: str = "",
flush_logs_every_n_steps: int = 100,
Expand All @@ -136,7 +136,7 @@ def __init__(
self._name = name or ""
self._version = version
self._prefix = prefix
self._experiment = None
self._experiment: Optional[ExperimentWriter] = None
self._flush_logs_every_n_steps = flush_logs_every_n_steps

@property
Expand All @@ -161,15 +161,15 @@ def log_dir(self) -> str:
return log_dir

@property
def save_dir(self) -> Optional[str]:
def save_dir(self) -> str:
"""The current directory where logs are saved.

Returns:
The path to current directory where logs are saved.
"""
return self._save_dir

@property
@property # type: ignore[misc]
@rank_zero_experiment
def experiment(self) -> ExperimentWriter:
r"""
Expand All @@ -182,7 +182,7 @@ def experiment(self) -> ExperimentWriter:
self.logger.experiment.some_experiment_writer_function()

"""
if self._experiment:
if self._experiment is not None:
return self._experiment

os.makedirs(self.root_dir, exist_ok=True)
Expand Down Expand Up @@ -220,7 +220,7 @@ def name(self) -> str:
return self._name

@property
def version(self) -> int:
def version(self) -> Union[int, str]:
"""Gets the version of the experiment.

Returns:
Expand All @@ -230,7 +230,7 @@ def version(self) -> int:
self._version = self._get_next_version()
return self._version

def _get_next_version(self):
def _get_next_version(self) -> int:
root_dir = self.root_dir

if not os.path.isdir(root_dir):
Expand Down