From 59e33f9c0df2f0727c34f8d69d0e38af566f46f8 Mon Sep 17 00:00:00 2001
From: Boris Dayma
Date: Tue, 27 Oct 2020 18:44:38 -0500
Subject: [PATCH 1/4] feat(wandb): log in sync with Trainer step

---
 pytorch_lightning/loggers/wandb.py |  9 +++++++--
 tests/loggers/test_wandb.py        | 11 +++++++++--
 2 files changed, 16 insertions(+), 4 deletions(-)

diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py
index e6ce264d597bf..f837cd3ddd53a 100644
--- a/pytorch_lightning/loggers/wandb.py
+++ b/pytorch_lightning/loggers/wandb.py
@@ -94,6 +94,8 @@ def __init__(
         self._offline = offline
         self._log_model = log_model
         self._kwargs = kwargs
+        # logging multiple Trainer on a single W&B run (k-fold, etc)
+        self._step_offset = 0

     def __getstate__(self):
         state = self.__dict__.copy()
@@ -141,8 +143,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
     @rank_zero_only
     def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
         assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'
-
-        self.experiment.log({'global_step': step, **metrics} if step is not None else metrics)
+        self.experiment.log(metrics, step=step + self._step_offset if step is not None else None)

     @property
     def save_dir(self) -> Optional[str]:
@@ -159,6 +160,10 @@ def version(self) -> Optional[str]:
         return self._experiment.id if self._experiment else self._id

     def finalize(self, status: str) -> None:
+        # offset future training logged on same W&B run
+        if self._experiment is not None:
+            self._step_offset = self._experiment.step
+
         # upload all checkpoints from saving dir
         if self._log_model:
             wandb.save(os.path.join(self.save_dir, "*.ckpt"))
diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py
index 6682cfdc8830a..5a9e87cfb9ae8 100644
--- a/tests/loggers/test_wandb.py
+++ b/tests/loggers/test_wandb.py
@@ -29,11 +29,18 @@ def test_wandb_logger(wandb):
     logger = WandbLogger(anonymous=True, offline=True)
     logger.log_metrics({'acc': 1.0})
-    wandb.init().log.assert_called_once_with({'acc': 1.0})
+    wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None)

     wandb.init().log.reset_mock()
     logger.log_metrics({'acc': 1.0}, step=3)
-    wandb.init().log.assert_called_once_with({'global_step': 3, 'acc': 1.0})
+    wandb.init().log.assert_called_once_with({'acc': 1.0}, step=3)
+
+    # continue training on same W&B run
+    wandb.init().step = 3
+    logger.finalize('success')
+    logger.log_metrics({'acc': 1.0}, step=3)
+    wandb.init().log.assert_called_with({'acc': 1.0}, step=6)
+
     logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]})
     wandb.init().config.update.assert_called_once_with(

From 766fb847a39b531df10937ccb62dd6947e549c01 Mon Sep 17 00:00:00 2001
From: Boris Dayma
Date: Tue, 27 Oct 2020 20:09:03 -0500
Subject: [PATCH 2/4] docs: update CHANGELOG

---
 CHANGELOG.md | 3 +++
 1 file changed, 3 insertions(+)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index f8ed74235c859..cfc51dce15ad9 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -42,6 +42,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed sanitized parameters for `WandbLogger.log_hyperparams` ([#4320](https://github.com/PyTorchLightning/pytorch-lightning/pull/4320))
+- W&B log in sync with Trainer step ([#4405](https://github.com/PyTorchLightning/pytorch-lightning/pull/4405))
+
+
 ### Deprecated

From 7c4fdcf08ab9c0c2e3ecf0b9a19ad60b912ff75b Mon Sep 17 00:00:00 2001
From: Boris Dayma
Date: Tue, 27 Oct 2020 20:11:38 -0500
Subject: [PATCH 3/4] style(test_wandb): fix formatting

---
 tests/loggers/test_wandb.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/loggers/test_wandb.py b/tests/loggers/test_wandb.py
index 5a9e87cfb9ae8..cfb6533bd913b 100644
--- a/tests/loggers/test_wandb.py
+++ b/tests/loggers/test_wandb.py
@@ -41,7 +41,6 @@ def test_wandb_logger(wandb):
     logger.log_metrics({'acc': 1.0}, step=3)
     wandb.init().log.assert_called_with({'acc': 1.0}, step=6)
-
     logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]})
     wandb.init().config.update.assert_called_once_with(
         {'test': 'None', 'nested/a': 1, 'b': [2, 3, 4]},

From 36033a368450d108a600eee6ba84d8c45fd7c72f Mon Sep 17 00:00:00 2001
From: Adrian Wälchli
Date: Wed, 28 Oct 2020 13:23:05 +0100
Subject: [PATCH 4/4] parentheses

---
 pytorch_lightning/loggers/wandb.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/pytorch_lightning/loggers/wandb.py b/pytorch_lightning/loggers/wandb.py
index f837cd3ddd53a..5786a52a8e371 100644
--- a/pytorch_lightning/loggers/wandb.py
+++ b/pytorch_lightning/loggers/wandb.py
@@ -143,7 +143,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
     @rank_zero_only
     def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
         assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'
-        self.experiment.log(metrics, step=step + self._step_offset if step is not None else None)
+        self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)

     @property
     def save_dir(self) -> Optional[str]:
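
For context beyond the diffs: the series makes `WandbLogger.log_metrics` pass the Trainer step directly to `wandb.log`, and `finalize` stores the run's last W&B step in `_step_offset` so that a later `Trainer.fit` on the same logger continues the step sequence instead of rewinding it. A minimal sketch of the k-fold workflow this enables; `MyLightningModule` and `make_fold_loaders` are hypothetical placeholders, not part of the patch, and the snippet is illustrative rather than a tested recipe:

    import pytorch_lightning as pl
    from pytorch_lightning.loggers import WandbLogger

    # A single logger instance means a single W&B run shared by all folds.
    logger = WandbLogger(project="kfold-demo")

    for fold in range(5):
        model = MyLightningModule()                 # hypothetical LightningModule
        train_dl, val_dl = make_fold_loaders(fold)  # hypothetical DataLoader factory
        trainer = pl.Trainer(max_epochs=1, logger=logger)
        trainer.fit(model, train_dl, val_dl)
        # Each fit restarts global_step at 0. When a fit ends the logger is
        # finalized, which stores the run's current wandb step in _step_offset,
        # so the next fold logs at step + offset instead of rewinding W&B's
        # monotonically increasing step counter.

Note that the parentheses added in PATCH 4 do not change behavior: `+` binds tighter than Python's conditional expression, so `step + self._step_offset if step is not None else None` already parses as `(step + self._step_offset) if step is not None else None`. The change is purely for readability.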