diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py index b1b5d91eef7a2..94baf1781e7ab 100644 --- a/pytorch_lightning/loggers/wandb.py +++ b/pytorch_lightning/loggers/wandb.py @@ -29,9 +29,7 @@ from pytorch_lightning.utilities import _module_available, rank_zero_only from pytorch_lightning.utilities.exceptions import MisconfigurationException from pytorch_lightning.utilities.imports import _compare_version -from pytorch_lightning.utilities.warnings import WarningCache - -warning_cache = WarningCache() +from pytorch_lightning.utilities.warnings import rank_zero_deprecation, rank_zero_warn _WANDB_AVAILABLE = _module_available("wandb") _WANDB_GREATER_EQUAL_0_10_22 = _compare_version("wandb", operator.ge, "0.10.22") @@ -129,7 +127,7 @@ def __init__( ) if log_model and not _WANDB_GREATER_EQUAL_0_10_22: - warning_cache.warn( + rank_zero_warn( f"Providing log_model={log_model} requires wandb version >= 0.10.22" " for logging associated model metadata.\n" "Hint: Upgrade with `pip install --ugrade wandb`." @@ -186,7 +184,7 @@ def experiment(self) -> Run: if wandb.run is None: self._experiment = wandb.init(**self._wandb_init) else: - warning_cache.warn( + rank_zero_warn( "There is a wandb run already in progress and newly created instances of `WandbLogger` will reuse" " this run. If this is not desired, call `wandb.finish()` before instantiating `WandbLogger`." ) diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py index 03953388c1877..e5b80993b392c 100644 --- a/tests/loggers/test_wandb.py +++ b/tests/loggers/test_wandb.py @@ -56,8 +56,8 @@ def test_wandb_logger_init(wandb): # verify default resume value assert logger._wandb_init["resume"] == "allow" - _ = logger.experiment - assert any("There is a wandb run already in progress" in w for w in pytorch_lightning.loggers.wandb.warning_cache) + with pytest.warns(UserWarning, match="There is a wandb run already in progress"): + _ = logger.experiment logger.log_metrics({"acc": 1.0}, step=3) wandb.init.assert_called_once()