3 changes: 2 additions & 1 deletion pytorch_lightning/loggers/__init__.py
@@ -23,7 +23,8 @@
from pytorch_lightning.loggers.mlflow import _MLFLOW_AVAILABLE, MLFlowLogger # noqa: F401
from pytorch_lightning.loggers.neptune import _NEPTUNE_AVAILABLE, NeptuneLogger # noqa: F401
from pytorch_lightning.loggers.test_tube import _TESTTUBE_AVAILABLE, TestTubeLogger # noqa: F401
-from pytorch_lightning.loggers.wandb import _WANDB_AVAILABLE, WandbLogger # noqa: F401
+from pytorch_lightning.loggers.wandb import WandbLogger # noqa: F401
+from pytorch_lightning.utilities.imports import _WANDB_AVAILABLE

if _COMET_AVAILABLE:
__all__.append("CometLogger")
36 changes: 23 additions & 13 deletions pytorch_lightning/loggers/wandb.py
@@ -15,7 +15,6 @@
Weights and Biases Logger
-------------------------
"""
-import operator
import os
from argparse import Namespace
from pathlib import Path
@@ -27,13 +26,10 @@
from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _compare_version, _module_available
+from pytorch_lightning.utilities.imports import _WANDB_GREATER_EQUAL_0_10_22, _WANDB_GREATER_EQUAL_0_12_10
from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict, _sanitize_callable_params
from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn

-_WANDB_AVAILABLE = _module_available("wandb")
-_WANDB_GREATER_EQUAL_0_10_22 = _compare_version("wandb", operator.ge, "0.10.22")

try:
import wandb
from wandb.wandb_run import Run
@@ -307,11 +303,18 @@ def __init__(
self._save_dir = self._wandb_init.get("dir")
self._name = self._wandb_init.get("name")
self._id = self._wandb_init.get("id")
+        # start wandb run (to create an attach_id for distributed modes)
+        if _WANDB_GREATER_EQUAL_0_12_10:
+            wandb.require("service")
+            _ = self.experiment

def __getstate__(self):
state = self.__dict__.copy()
# args needed to reload correct experiment
state["_id"] = self._experiment.id if self._experiment is not None else None
if self._experiment is not None:
state["_id"] = getattr(self._experiment, "id", None)
state["_attach_id"] = getattr(self._experiment, "_attach_id", None)
state["_name"] = self._experiment.project_name()

# cannot be pickled
state["_experiment"] = None
@@ -335,19 +338,26 @@ def experiment(self) -> Run:
if self._experiment is None:
if self._offline:
os.environ["WANDB_MODE"] = "dryrun"
-            if wandb.run is None:
-                self._experiment = wandb.init(**self._wandb_init)
-            else:
+
+            attach_id = getattr(self, "_attach_id", None)
+            if wandb.run is not None:
+                # wandb process already created in this instance
rank_zero_warn(
"There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse"
" this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`."
)
self._experiment = wandb.run
+            elif attach_id is not None and hasattr(wandb, "_attach"):
+                # attach to wandb process referenced
+                self._experiment = wandb._attach(attach_id)
+            else:
+                # create new wandb process
+                self._experiment = wandb.init(**self._wandb_init)

-        # define default x-axis (for latest wandb versions)
-        if getattr(self._experiment, "define_metric", None):
-            self._experiment.define_metric("trainer/global_step")
-            self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)
+            # define default x-axis
+            if getattr(self._experiment, "define_metric", None):
+                self._experiment.define_metric("trainer/global_step")
+                self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)

return self._experiment
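The reordered branches above also change the fallback order: an already existing `wandb.run` is checked first and reused (with a warning), before trying `wandb._attach` and finally `wandb.init`. A hedged sketch of the resulting behavior; the "demo" project name is a placeholder and wandb must be installed and configured:

    import wandb
    from pytorch_lightning.loggers import WandbLogger

    wandb.init(project="demo")             # a run is already in progress
    logger = WandbLogger()                 # warns and reuses that run
    assert logger.experiment is wandb.run

    wandb.finish()                         # finish the old run first if a separate run is wanted
    fresh_logger = WandbLogger(project="demo")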

4 changes: 4 additions & 0 deletions pytorch_lightning/utilities/imports.py
@@ -116,8 +116,12 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
_TORCHTEXT_AVAILABLE = _package_available("torchtext")
_TORCHTEXT_LEGACY: bool = _TORCHTEXT_AVAILABLE and _compare_version("torchtext", operator.lt, "0.11.0")
_TORCHVISION_AVAILABLE = _package_available("torchvision")
+_WANDB_AVAILABLE = _package_available("wandb")
+_WANDB_GREATER_EQUAL_0_10_22 = _WANDB_AVAILABLE and _compare_version("wandb", operator.ge, "0.10.22")
+_WANDB_GREATER_EQUAL_0_12_10 = _WANDB_AVAILABLE and _compare_version("wandb", operator.ge, "0.12.10")
_XLA_AVAILABLE: bool = _package_available("torch_xla")


from pytorch_lightning.utilities.xla_device import XLADeviceUtils # noqa: E402

_TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
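These module-level flags replace the ad-hoc checks that loggers/wandb.py used to compute locally (see the `_module_available`/`_compare_version` lines deleted above), so the logger and its tests can import one shared constant. A rough sketch of how such a guard is typically consumed downstream; the helper name is hypothetical and only mirrors the metadata behavior exercised by test_wandb_log_model:

    from pytorch_lightning.utilities.imports import _WANDB_GREATER_EQUAL_0_10_22

    def _artifact_kwargs(metadata: dict) -> dict:
        # hypothetical helper: older wandb releases do not accept artifact metadata,
        # so only pass it through when the installed version is new enough
        return {"metadata": metadata} if _WANDB_GREATER_EQUAL_0_10_22 else {}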
36 changes: 23 additions & 13 deletions tests/loggers/test_wandb.py
@@ -21,14 +21,16 @@
from pytorch_lightning.loggers import WandbLogger
from pytorch_lightning.utilities.exceptions import MisconfigurationException
from tests.helpers import BoringModel
+from tests.helpers.utils import no_warning_call


@mock.patch("pytorch_lightning.loggers.wandb.wandb")
-def test_wandb_logger_init(wandb):
+def test_wandb_logger_init(wandb, monkeypatch):
"""Verify that basic functionality of wandb logger works.

Wandb doesn't work well with pytest so we have to mock it out here.
"""
+    import pytorch_lightning.loggers.wandb as imports

# test wandb.init called when there is no W&B run
wandb.run = None
@@ -51,14 +53,17 @@ def test_wandb_logger_init(wandb):
wandb.init().log.reset_mock()
wandb.init.reset_mock()
wandb.run = wandb.init()
-    logger = WandbLogger()
-
-    # verify default resume value
-    assert logger._wandb_init["resume"] == "allow"

+    monkeypatch.setattr(imports, "_WANDB_GREATER_EQUAL_0_12_10", True)
+    with pytest.warns(UserWarning, match="There is a wandb run already in progress"):
+        logger = WandbLogger()
+    # check that no new run is created
+    with no_warning_call(UserWarning, match="There is a wandb run already in progress"):
+        _ = logger.experiment

+    # verify default resume value
+    assert logger._wandb_init["resume"] == "allow"

logger.log_metrics({"acc": 1.0}, step=3)
wandb.init.assert_called_once()
wandb.init().log.assert_called_once_with({"acc": 1.0, "trainer/global_step": 3})
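`no_warning_call` is a small test helper imported from tests/helpers/utils.py; its implementation is not part of this diff. Roughly, a helper like it records warnings raised inside the block and fails if any of them matches the forbidden category and pattern, along the lines of this sketch (an assumption, not the repository's actual code):

    import re
    import warnings
    from contextlib import contextmanager

    @contextmanager
    def no_warning_call(expected_warning=Warning, match=None):
        # record every warning raised inside the block ...
        with warnings.catch_warnings(record=True) as record:
            warnings.simplefilter("always")
            yield
        # ... and fail if one of them matches the forbidden category/pattern
        for warning in record:
            if issubclass(warning.category, expected_warning) and (
                match is None or re.search(match, str(warning.message))
            ):
                raise AssertionError(f"Unexpected warning raised: {warning.message}")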
@@ -120,11 +125,16 @@ def project_name(self):


@mock.patch("pytorch_lightning.loggers.wandb.wandb")
-def test_wandb_logger_dirs_creation(wandb, tmpdir):
+def test_wandb_logger_dirs_creation(wandb, monkeypatch, tmpdir):
"""Test that the logger creates the folders and files in the right place."""
+    import pytorch_lightning.loggers.wandb as imports
+
+    monkeypatch.setattr(imports, "_WANDB_GREATER_EQUAL_0_12_10", True)
wandb.run = None
logger = WandbLogger(save_dir=str(tmpdir), offline=True)
-    assert logger.version is None
-    assert logger.name is None
+    # the logger gets initialized eagerly
+    assert logger.version == wandb.init().id
+    assert logger.name == wandb.init().project_name()

# mock return values of experiment
wandb.run = None
@@ -151,8 +161,11 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):


@mock.patch("pytorch_lightning.loggers.wandb.wandb")
-def test_wandb_log_model(wandb, tmpdir):
+def test_wandb_log_model(wandb, monkeypatch, tmpdir):
"""Test that the logger creates the folders and files in the right place."""
+    import pytorch_lightning.loggers.wandb as imports
+
+    monkeypatch.setattr(imports, "_WANDB_GREATER_EQUAL_0_10_22", True)

wandb.run = None
model = BoringModel()
@@ -186,13 +199,10 @@ def test_wandb_log_model(wandb, tmpdir):
assert not wandb.init().log_artifact.called

# test correct metadata
-    import pytorch_lightning.loggers.wandb as pl_wandb
-
-    pl_wandb._WANDB_GREATER_EQUAL_0_10_22 = True
wandb.init().log_artifact.reset_mock()
wandb.init.reset_mock()
wandb.Artifact.reset_mock()
-    logger = pl_wandb.WandbLogger(log_model=True)
+    logger = WandbLogger(log_model=True)
logger.experiment.id = "1"
logger.experiment.project_name.return_value = "project"
trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)
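For context, test_wandb_log_model covers the `log_model=True` path, where checkpoints saved by `ModelCheckpoint` are uploaded as W&B artifacts, including checkpoint metadata such as the monitored score once `_WANDB_GREATER_EQUAL_0_10_22` is true. A hedged usage sketch of that feature; the project name, monitored metric, and model are placeholders:

    from pytorch_lightning import Trainer
    from pytorch_lightning.callbacks import ModelCheckpoint
    from pytorch_lightning.loggers import WandbLogger

    logger = WandbLogger(project="demo", log_model=True)        # upload checkpoints when training ends
    checkpoint_cb = ModelCheckpoint(monitor="val_loss", save_top_k=2)
    trainer = Trainer(logger=logger, callbacks=[checkpoint_cb], max_epochs=2)
    # trainer.fit(model)  # each saved checkpoint is logged as a wandb Artifact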