diff --git a/CHANGELOG.md b/CHANGELOG.md
index 3ac97333e6d90..c25c247505cdf 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -24,6 +24,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Added DeepSpeed collate checkpoint utility function ([#8701](https://github.com/PyTorchLightning/pytorch-lightning/pull/8701))
 
 
+- Added a warning to `WandbLogger` when reusing a wandb run ([#8714](https://github.com/PyTorchLightning/pytorch-lightning/pull/8714))
+
+
 - Added `log_graph` argument for `watch` method of `WandbLogger` ([#8662](https://github.com/PyTorchLightning/pytorch-lightning/pull/8662))
 
 
diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py
index 0d7932c40c34c..e3c31c6c4cda2 100644
--- a/pytorch_lightning/loggers/wandb.py
+++ b/pytorch_lightning/loggers/wandb.py
@@ -183,7 +183,14 @@ def experiment(self) -> Run:
         if self._experiment is None:
             if self._offline:
                 os.environ["WANDB_MODE"] = "dryrun"
-            self._experiment = wandb.init(**self._wandb_init) if wandb.run is None else wandb.run
+            if wandb.run is None:
+                self._experiment = wandb.init(**self._wandb_init)
+            else:
+                warning_cache.warn(
+                    "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse"
+                    " this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`."
+                )
+                self._experiment = wandb.run
 
             # define default x-axis (for latest wandb versions)
             if getattr(self._experiment, "define_metric", None):
diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py
index 40243860b1cb9..03953388c1877 100644
--- a/tests/loggers/test_wandb.py
+++ b/tests/loggers/test_wandb.py
@@ -18,6 +18,7 @@
 
 import pytest
 
+import pytorch_lightning
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import WandbLogger
 from pytorch_lightning.utilities.exceptions import MisconfigurationException
@@ -51,8 +52,13 @@ def test_wandb_logger_init(wandb):
 
     wandb.init.reset_mock()
     wandb.run = wandb.init()
     logger = WandbLogger()
+    # verify default resume value
     assert logger._wandb_init["resume"] == "allow"
+
+    _ = logger.experiment
+    assert any("There is a wandb run already in progress" in w for w in pytorch_lightning.loggers.wandb.warning_cache)
+
     logger.log_metrics({"acc": 1.0}, step=3)
     wandb.init.assert_called_once()
     wandb.init().log.assert_called_once_with({"acc": 1.0, "trainer/global_step": 3})
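
A minimal usage sketch of the behavior this patch introduces; it is not part of the diff itself, the project name is arbitrary, and `wandb.finish()` is the workaround the warning message suggests:

    import wandb

    from pytorch_lightning.loggers import WandbLogger

    # A wandb run is already active before the logger is created ...
    wandb.init(project="demo")

    # ... so accessing `experiment` emits the new warning and reuses that run.
    logger = WandbLogger()
    _ = logger.experiment

    # To start a fresh run instead, close the existing one first.
    wandb.finish()
    fresh_logger = WandbLogger()
    _ = fresh_logger.experiment  # now calls wandb.init(**self._wandb_init)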