2 changes: 1 addition & 1 deletion environment.yml
@@ -50,5 +50,5 @@ dependencies:
 - test-tube>=0.7.5
 - mlflow>=1.0.0
 - comet_ml>=3.1.12
-- wandb>=0.8.21
+- wandb>=0.10.22
 - neptune-client>=0.10.0
1 change: 0 additions & 1 deletion pyproject.toml
@@ -59,7 +59,6 @@ module = [
     "pytorch_lightning.demos.mnist_datamodule",
     "pytorch_lightning.loggers.comet",
     "pytorch_lightning.loggers.neptune",
-    "pytorch_lightning.loggers.wandb",
     "pytorch_lightning.profilers.advanced",
     "pytorch_lightning.profilers.base",
     "pytorch_lightning.profilers.pytorch",
2 changes: 1 addition & 1 deletion requirements/pytorch/extra.txt
@@ -3,6 +3,6 @@ matplotlib>3.1, <3.5.3
 torchtext>=0.10.*, <=0.12.0
 omegaconf>=2.0.5, <2.3.0
 hydra-core>=1.0.5, <1.3.0
-jsonargparse[signatures]>=4.10.0, <=4.10.0
+jsonargparse[signatures]>=4.10.2, <=4.10.2
 gcsfs>=2021.5.0, <2022.6.0
 rich>=10.14.0, !=10.15.0.a, <13.0.0
2 changes: 1 addition & 1 deletion requirements/pytorch/loggers.txt
@@ -4,4 +4,4 @@ neptune-client>=0.10.0, <0.16.4
 comet-ml>=3.1.12, <3.31.6
 mlflow>=1.0.0, <1.27.0
 test_tube>=0.7.5, <=0.7.5
-wandb>=0.8.21, <0.12.20
+wandb>=0.10.22, <0.12.20
35 changes: 18 additions & 17 deletions src/pytorch_lightning/loggers/wandb.py
@@ -32,10 +32,11 @@

 try:
     import wandb
+    from wandb.sdk.lib import RunDisabled
     from wandb.wandb_run import Run
 except ModuleNotFoundError:
     # needed for test mocks, these tests shall be updated
-    wandb, Run = None, None
+    wandb, Run, RunDisabled = None, None, None  # type: ignore


 class WandbLogger(Logger):
@@ -251,18 +252,18 @@ def __init__(
         self,
         name: Optional[str] = None,
         save_dir: Optional[str] = None,
-        offline: Optional[bool] = False,
+        offline: bool = False,
         id: Optional[str] = None,
         anonymous: Optional[bool] = None,
         version: Optional[str] = None,
         project: Optional[str] = None,
         log_model: Union[str, bool] = False,
-        experiment=None,
-        prefix: Optional[str] = "",
+        experiment: Union[Run, RunDisabled, None] = None,
+        prefix: str = "",
         agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
         agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> None:
         if wandb is None:
             raise ModuleNotFoundError(
                 "You want to use `wandb` logger which is not installed yet,"
@@ -288,17 +289,16 @@ def __init__(
         self._log_model = log_model
         self._prefix = prefix
         self._experiment = experiment
-        self._logged_model_time = {}
-        self._checkpoint_callback = None
+        self._logged_model_time: Dict[str, float] = {}
+        self._checkpoint_callback: Optional["ReferenceType[Checkpoint]"] = None
         # set wandb init arguments
-        anonymous_lut = {True: "allow", False: None}
-        self._wandb_init = dict(
+        self._wandb_init: Dict[str, Any] = dict(
             name=name or project,
             project=project,
             id=version or id,
             dir=save_dir,
             resume="allow",
-            anonymous=anonymous_lut.get(anonymous, anonymous),
+            anonymous=("allow" if anonymous else None),
         )
         self._wandb_init.update(**kwargs)
         # extract parameters
@@ -310,7 +310,7 @@ def __init__(
             wandb.require("service")
             _ = self.experiment

-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, Any]:
         state = self.__dict__.copy()
         # args needed to reload correct experiment
         if self._experiment is not None:
@@ -322,7 +322,7 @@ def __getstate__(self):
         state["_experiment"] = None
         return state

-    @property
+    @property  # type: ignore[misc]
     @rank_zero_experiment
     def experiment(self) -> Run:
         r"""
@@ -357,13 +357,14 @@ def experiment(self) -> Run:
                 self._experiment = wandb.init(**self._wandb_init)

                 # define default x-axis
-                if getattr(self._experiment, "define_metric", None):
+                if isinstance(self._experiment, Run) and getattr(self._experiment, "define_metric", None):
                     self._experiment.define_metric("trainer/global_step")
                     self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)

+        assert isinstance(self._experiment, Run)
         return self._experiment

-    def watch(self, model: nn.Module, log: str = "gradients", log_freq: int = 100, log_graph: bool = True):
+    def watch(self, model: nn.Module, log: str = "gradients", log_freq: int = 100, log_graph: bool = True) -> None:
         self.experiment.watch(model, log=log, log_freq=log_freq, log_graph=log_graph)

     @rank_zero_only
@@ -379,7 +380,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->

         metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
         if step is not None:
-            self.experiment.log({**metrics, "trainer/global_step": step})
+            self.experiment.log(dict(metrics, **{"trainer/global_step": step}))
         else:
             self.experiment.log(metrics)

@@ -417,7 +418,7 @@ def log_text(
         self.log_table(key, columns, data, dataframe, step)

     @rank_zero_only
-    def log_image(self, key: str, images: List[Any], step: Optional[int] = None, **kwargs: str) -> None:
+    def log_image(self, key: str, images: List[Any], step: Optional[int] = None, **kwargs: Any) -> None:
         """Log images (tensors, numpy arrays, PIL Images or file paths).

         Optional kwargs are lists passed to each image (ex: caption, masks, boxes).
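Note: a minimal usage sketch of the logger after these changes (the project name, prefix, and metric values are invented for illustration; offline=True keeps wandb from making any network calls):

from pytorch_lightning.loggers import WandbLogger

# Hypothetical project and prefix, for illustration only.
logger = WandbLogger(project="demo-project", offline=True, prefix="train")

# Keys are joined with the logger's "-" join character and a step column is
# attached, so this should log {"train-loss": 0.25, "trainer/global_step": 10}.
logger.log_metrics({"loss": 0.25}, step=10)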
2 changes: 1 addition & 1 deletion src/pytorch_lightning/utilities/cli.py
@@ -31,7 +31,7 @@
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_deprecation, rank_zero_warn

-_JSONARGPARSE_SIGNATURES_AVAILABLE = _RequirementAvailable("jsonargparse[signatures]>=4.10.0")
+_JSONARGPARSE_SIGNATURES_AVAILABLE = _RequirementAvailable("jsonargparse[signatures]>=4.10.2")

 if _JSONARGPARSE_SIGNATURES_AVAILABLE:
     import docstring_parser
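Note: _RequirementAvailable checks whether the installed distribution satisfies a requirement string. A rough standalone sketch of the idea, not the actual implementation (the helper name is hypothetical):

from importlib.metadata import PackageNotFoundError, version

from packaging.requirements import Requirement
from packaging.version import Version


def requirement_available(spec: str) -> bool:
    # e.g. spec = "jsonargparse[signatures]>=4.10.2"; extras like
    # [signatures] are parsed but not verified by this sketch.
    req = Requirement(spec)
    try:
        return Version(version(req.name)) in req.specifier
    except PackageNotFoundError:
        return False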
5 changes: 4 additions & 1 deletion tests/tests_pytorch/loggers/test_all.py
@@ -45,6 +45,7 @@
     mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient"),
     mock.patch("pytorch_lightning.loggers.neptune.neptune", new_callable=create_neptune_mock),
     mock.patch("pytorch_lightning.loggers.wandb.wandb"),
+    mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock),
 )
 ALL_LOGGER_CLASSES = (
     CometLogger,
@@ -363,7 +364,9 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
     logger.experiment.add_scalar.assert_called_once_with("tmp-test", 1.0, 0)

     # WandB
-    with mock.patch("pytorch_lightning.loggers.wandb.wandb") as wandb:
+    with mock.patch("pytorch_lightning.loggers.wandb.wandb") as wandb, mock.patch(
+        "pytorch_lightning.loggers.wandb.Run", new=mock.Mock
+    ):
         logger = _instantiate_logger(WandbLogger, save_dir=tmpdir, prefix=prefix)
         wandb.run = None
         wandb.init().step = 0
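Note: patching with new=mock.Mock (the class, not an instance) is what keeps the logger's new isinstance(self._experiment, Run) checks passing under test: a mocked wandb.init() returns a MagicMock, and MagicMock subclasses Mock. A minimal sketch of the mechanism:

from unittest import mock

# What a mocked `wandb` module returns from wandb.init(): a MagicMock instance.
experiment = mock.MagicMock()

# mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock) swaps the
# Run symbol for the Mock class itself, so the isinstance check holds.
Run = mock.Mock
assert isinstance(experiment, Run)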
29 changes: 17 additions & 12 deletions tests/tests_pytorch/loggers/test_wandb.py
@@ -24,6 +24,7 @@
 from tests_pytorch.helpers.utils import no_warning_call


+@mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock)
 @mock.patch("pytorch_lightning.loggers.wandb.wandb")
 def test_wandb_logger_init(wandb, monkeypatch):
     """Verify that basic functionality of wandb logger works.
@@ -111,20 +112,21 @@ class Experiment:
         def name(self):
             return "the_run_name"

-    wandb.run = None
-    wandb.init.return_value = Experiment()
-    logger = WandbLogger(id="the_id", offline=True)
+    with mock.patch("pytorch_lightning.loggers.wandb.Run", new=Experiment):
+        wandb.run = None
+        wandb.init.return_value = Experiment()
+        logger = WandbLogger(id="the_id", offline=True)

-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, logger=logger)
-    # Access the experiment to ensure it's created
-    assert trainer.logger.experiment, "missing experiment"
-    assert trainer.log_dir == logger.save_dir
-    pkl_bytes = pickle.dumps(trainer)
-    trainer2 = pickle.loads(pkl_bytes)
+        trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, logger=logger)
+        # Access the experiment to ensure it's created
+        assert trainer.logger.experiment, "missing experiment"
+        assert trainer.log_dir == logger.save_dir
+        pkl_bytes = pickle.dumps(trainer)
+        trainer2 = pickle.loads(pkl_bytes)

-    assert os.environ["WANDB_MODE"] == "dryrun"
-    assert trainer2.logger.__class__.__name__ == WandbLogger.__name__
-    assert trainer2.logger.experiment, "missing experiment"
+        assert os.environ["WANDB_MODE"] == "dryrun"
+        assert trainer2.logger.__class__.__name__ == WandbLogger.__name__
+        assert trainer2.logger.experiment, "missing experiment"

     wandb.init.assert_called()
     assert "id" in wandb.init.call_args[1]
@@ -133,6 +135,7 @@ def name(self):
     del os.environ["WANDB_MODE"]


+@mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock)
 @mock.patch("pytorch_lightning.loggers.wandb.wandb")
 def test_wandb_logger_dirs_creation(wandb, monkeypatch, tmpdir):
     """Test that the logger creates the folders and files in the right place."""
@@ -169,6 +172,7 @@ def test_wandb_logger_dirs_creation(wandb, monkeypatch, tmpdir):
     assert trainer.log_dir == logger.save_dir


+@mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock)
 @mock.patch("pytorch_lightning.loggers.wandb.wandb")
 def test_wandb_log_model(wandb, monkeypatch, tmpdir):
     """Test that the logger creates the folders and files in the right place."""
@@ -234,6 +238,7 @@ def test_wandb_log_model(wandb, monkeypatch, tmpdir):
     )


+@mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock)
 @mock.patch("pytorch_lightning.loggers.wandb.wandb")
 def test_wandb_log_media(wandb, tmpdir):
     """Test that the logger creates the folders and files in the right place."""
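Note: the pickling test above works because __getstate__ drops the live run handle (state["_experiment"] = None) and the experiment property re-creates it lazily after unpickling. A self-contained sketch of that pattern, independent of wandb:

import pickle
from typing import Any, Dict, Optional


class LazyHandleOwner:
    """Sketch: exclude an unpicklable handle from the pickled state."""

    def __init__(self) -> None:
        self._experiment: Optional[object] = None

    @property
    def experiment(self) -> object:
        if self._experiment is None:
            self._experiment = object()  # stands in for wandb.init(...)
        return self._experiment

    def __getstate__(self) -> Dict[str, Any]:
        state = self.__dict__.copy()
        state["_experiment"] = None  # the live run handle is not picklable
        return state


clone = pickle.loads(pickle.dumps(LazyHandleOwner()))
assert clone.experiment is not None  # re-created lazily after unpickling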