diff --git a/pytorch_lightning/loggers/__init__.py b/pytorch_lightning/loggers/__init__.py
index d0e3cf759aeb8..d8049346f69bf 100644
--- a/pytorch_lightning/loggers/__init__.py
+++ b/pytorch_lightning/loggers/__init__.py
@@ -23,7 +23,8 @@
 from pytorch_lightning.loggers.mlflow import _MLFLOW_AVAILABLE, MLFlowLogger  # noqa: F401
 from pytorch_lightning.loggers.neptune import _NEPTUNE_AVAILABLE, NeptuneLogger  # noqa: F401
 from pytorch_lightning.loggers.test_tube import _TESTTUBE_AVAILABLE, TestTubeLogger  # noqa: F401
-from pytorch_lightning.loggers.wandb import _WANDB_AVAILABLE, WandbLogger  # noqa: F401
+from pytorch_lightning.loggers.wandb import WandbLogger  # noqa: F401
+from pytorch_lightning.utilities.imports import _WANDB_AVAILABLE
 
 if _COMET_AVAILABLE:
     __all__.append("CometLogger")
diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py
index 6e409029e0a95..a256ab05d1d89 100644
--- a/pytorch_lightning/loggers/wandb.py
+++ b/pytorch_lightning/loggers/wandb.py
@@ -15,7 +15,6 @@
 Weights and Biases Logger
 -------------------------
 """
-import operator
 import os
 from argparse import Namespace
 from pathlib import Path
@@ -27,13 +26,10 @@
 from pytorch_lightning.callbacks.model_checkpoint import ModelCheckpoint
 from pytorch_lightning.loggers.base import LightningLoggerBase, rank_zero_experiment
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
-from pytorch_lightning.utilities.imports import _compare_version, _module_available
+from pytorch_lightning.utilities.imports import _WANDB_GREATER_EQUAL_0_10_22, _WANDB_GREATER_EQUAL_0_12_10
 from pytorch_lightning.utilities.logger import _add_prefix, _convert_params, _flatten_dict, _sanitize_callable_params
 from pytorch_lightning.utilities.rank_zero import rank_zero_only, rank_zero_warn
 
-_WANDB_AVAILABLE = _module_available("wandb")
-_WANDB_GREATER_EQUAL_0_10_22 = _compare_version("wandb", operator.ge, "0.10.22")
-
 try:
     import wandb
     from wandb.wandb_run import Run
@@ -307,11 +303,18 @@ def __init__(
         self._save_dir = self._wandb_init.get("dir")
         self._name = self._wandb_init.get("name")
         self._id = self._wandb_init.get("id")
+        # start wandb run (to create an attach_id for distributed modes)
+        if _WANDB_GREATER_EQUAL_0_12_10:
+            wandb.require("service")
+            _ = self.experiment
 
     def __getstate__(self):
         state = self.__dict__.copy()
         # args needed to reload correct experiment
-        state["_id"] = self._experiment.id if self._experiment is not None else None
+        if self._experiment is not None:
+            state["_id"] = getattr(self._experiment, "id", None)
+            state["_attach_id"] = getattr(self._experiment, "_attach_id", None)
+            state["_name"] = self._experiment.project_name()
 
         # cannot be pickled
         state["_experiment"] = None
@@ -335,19 +338,26 @@ def experiment(self) -> Run:
         if self._experiment is None:
             if self._offline:
                 os.environ["WANDB_MODE"] = "dryrun"
-            if wandb.run is None:
-                self._experiment = wandb.init(**self._wandb_init)
-            else:
+
+            attach_id = getattr(self, "_attach_id", None)
+            if wandb.run is not None:
+                # wandb process already created in this instance
                 rank_zero_warn(
                     "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse"
                     " this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`."
                 )
                 self._experiment = wandb.run
+            elif attach_id is not None and hasattr(wandb, "_attach"):
+                # attach to wandb process referenced
+                self._experiment = wandb._attach(attach_id)
+            else:
+                # create new wandb process
+                self._experiment = wandb.init(**self._wandb_init)
 
-        # define default x-axis (for latest wandb versions)
-        if getattr(self._experiment, "define_metric", None):
-            self._experiment.define_metric("trainer/global_step")
-            self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)
+            # define default x-axis
+            if getattr(self._experiment, "define_metric", None):
+                self._experiment.define_metric("trainer/global_step")
+                self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)
 
         return self._experiment
 
diff --git a/pytorch_lightning/utilities/imports.py b/pytorch_lightning/utilities/imports.py
index 2d335450d02e6..6c20d90e01646 100644
--- a/pytorch_lightning/utilities/imports.py
+++ b/pytorch_lightning/utilities/imports.py
@@ -116,8 +116,12 @@ def _compare_version(package: str, op: Callable, version: str, use_base_version:
 _TORCHTEXT_AVAILABLE = _package_available("torchtext")
 _TORCHTEXT_LEGACY: bool = _TORCHTEXT_AVAILABLE and _compare_version("torchtext", operator.lt, "0.11.0")
 _TORCHVISION_AVAILABLE = _package_available("torchvision")
+_WANDB_AVAILABLE = _package_available("wandb")
+_WANDB_GREATER_EQUAL_0_10_22 = _WANDB_AVAILABLE and _compare_version("wandb", operator.ge, "0.10.22")
+_WANDB_GREATER_EQUAL_0_12_10 = _WANDB_AVAILABLE and _compare_version("wandb", operator.ge, "0.12.10")
 _XLA_AVAILABLE: bool = _package_available("torch_xla")
 
+
 from pytorch_lightning.utilities.xla_device import XLADeviceUtils  # noqa: E402
 
 _TPU_AVAILABLE = XLADeviceUtils.tpu_device_exists()
diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py
index c447e8d8e92ea..280303a3f7318 100644
--- a/tests/loggers/test_wandb.py
+++ b/tests/loggers/test_wandb.py
@@ -21,14 +21,16 @@
 from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
 from tests.helpers import BoringModel
+from tests.helpers.utils import no_warning_call
 
 
 @mock.patch("pytorch_lightning.loggers.wandb.wandb")
-def test_wandb_logger_init(wandb):
+def test_wandb_logger_init(wandb, monkeypatch):
     """Verify that basic functionality of wandb logger works.
 
     Wandb doesn't work well with pytest so we have to mock it out here.
""" + import pytorch_lightning.loggers.wandb as imports # test wandb.init called when there is no W&B run wandb.run = None @@ -51,14 +53,17 @@ def test_wandb_logger_init(wandb): wandb.init().log.reset_mock() wandb.init.reset_mock() wandb.run = wandb.init() - logger = WandbLogger() - - # verify default resume value - assert logger._wandb_init["resume"] == "allow" + monkeypatch.setattr(imports, "_WANDB_GREATER_EQUAL_0_12_10", True) with pytest.warns(UserWarning, match="There is a wandb run already in progress"): + logger = WandbLogger() + # check that no new run is created + with no_warning_call(UserWarning, match="There is a wandb run already in progress"): _ = logger.experiment + # verify default resume value + assert logger._wandb_init["resume"] == "allow" + logger.log_metrics({"acc": 1.0}, step=3) wandb.init.assert_called_once() wandb.init().log.assert_called_once_with({"acc": 1.0, "trainer/global_step": 3}) @@ -120,11 +125,16 @@ def project_name(self): @mock.patch("pytorch_lightning.loggers.wandb.wandb") -def test_wandb_logger_dirs_creation(wandb, tmpdir): +def test_wandb_logger_dirs_creation(wandb, monkeypatch, tmpdir): """Test that the logger creates the folders and files in the right place.""" + import pytorch_lightning.loggers.wandb as imports + + monkeypatch.setattr(imports, "_WANDB_GREATER_EQUAL_0_12_10", True) + wandb.run = None logger = WandbLogger(save_dir=str(tmpdir), offline=True) - assert logger.version is None - assert logger.name is None + # the logger get initialized + assert logger.version == wandb.init().id + assert logger.name == wandb.init().project_name() # mock return values of experiment wandb.run = None @@ -151,8 +161,11 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir): @mock.patch("pytorch_lightning.loggers.wandb.wandb") -def test_wandb_log_model(wandb, tmpdir): +def test_wandb_log_model(wandb, monkeypatch, tmpdir): """Test that the logger creates the folders and files in the right place.""" + import pytorch_lightning.loggers.wandb as imports + + monkeypatch.setattr(imports, "_WANDB_GREATER_EQUAL_0_10_22", True) wandb.run = None model = BoringModel() @@ -186,13 +199,10 @@ def test_wandb_log_model(wandb, tmpdir): assert not wandb.init().log_artifact.called # test correct metadata - import pytorch_lightning.loggers.wandb as pl_wandb - - pl_wandb._WANDB_GREATER_EQUAL_0_10_22 = True wandb.init().log_artifact.reset_mock() wandb.init.reset_mock() wandb.Artifact.reset_mock() - logger = pl_wandb.WandbLogger(log_model=True) + logger = WandbLogger(log_model=True) logger.experiment.id = "1" logger.experiment.project_name.return_value = "project" trainer = Trainer(default_root_dir=tmpdir, logger=logger, max_epochs=2, limit_train_batches=3, limit_val_batches=3)