Commit ff41d80

borisdayma, awaelchli, and rohitgr7 authored
feat(wandb): log in sync with Trainer step (#4405)
* feat(wandb): log in sync with Trainer step
* docs: update CHANGELOG
* style(test_wandb): fix formatting
* parentheses

Co-authored-by: Adrian Wälchli <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
1 parent 41de453 commit ff41d80
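
What the change means at the W&B API level: previously the Trainer step was attached as an extra `global_step` metric while W&B advanced its own internal step counter on every `log()` call; after this commit the step is forwarded as W&B's native `step` argument, so chart x-axes line up with `Trainer` steps. A minimal sketch of the difference using the wandb API directly (the `"demo"` project name and offline mode are placeholders, not part of this commit):

```python
import wandb

# placeholder run; offline mode avoids needing W&B credentials
run = wandb.init(project="demo", mode="offline")

# before #4405: the Trainer step was just another metric, W&B kept its own counter
run.log({"global_step": 3, "acc": 1.0})

# after #4405: the Trainer step is passed natively, so charts use it as the x-axis
run.log({"acc": 1.0}, step=3)

run.finish()
```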

File tree: 3 files changed, +18 -4 lines changed

CHANGELOG.md

Lines changed: 3 additions & 0 deletions
@@ -45,6 +45,9 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 - Fixed santized parameters for `WandbLogger.log_hyperparams` ([#4320](https://github.com/PyTorchLightning/pytorch-lightning/pull/4320))
 
 
+- W&B log in sync with Trainer step ([#4405](https://github.com/PyTorchLightning/pytorch-lightning/pull/4405))
+
+
 ### Deprecated
 
 
pytorch_lightning/loggers/wandb.py

Lines changed: 7 additions & 2 deletions
@@ -94,6 +94,8 @@ def __init__(
         self._offline = offline
         self._log_model = log_model
         self._kwargs = kwargs
+        # logging multiple Trainer on a single W&B run (k-fold, etc)
+        self._step_offset = 0
 
     def __getstate__(self):
         state = self.__dict__.copy()
@@ -141,8 +143,7 @@ def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
     @rank_zero_only
     def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) -> None:
         assert rank_zero_only.rank == 0, 'experiment tried to log from global_rank != 0'
-
-        self.experiment.log({'global_step': step, **metrics} if step is not None else metrics)
+        self.experiment.log(metrics, step=(step + self._step_offset) if step is not None else None)
 
     @property
     def save_dir(self) -> Optional[str]:
@@ -159,6 +160,10 @@ def version(self) -> Optional[str]:
         return self._experiment.id if self._experiment else self._id
 
     def finalize(self, status: str) -> None:
+        # offset future training logged on same W&B run
+        if self._experiment is not None:
+            self._step_offset = self._experiment.step
+
         # upload all checkpoints from saving dir
         if self._log_model:
             wandb.save(os.path.join(self.save_dir, "*.ckpt"))
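
Taken together, the two hunks implement a simple offset scheme: `log_metrics` shifts every incoming Trainer step by `_step_offset`, and `finalize` bumps the offset to the run's last step, so a later training logged to the same W&B run does not restart the step axis at zero. A minimal sketch of that behavior driving the logger directly, without a Trainer (the `'demo'` project name, metric names, and step values are placeholders; offline mode avoids needing W&B credentials):

```python
from pytorch_lightning.loggers import WandbLogger

# one logger, hence one W&B run, reused for two trainings (k-fold style)
logger = WandbLogger(project='demo', offline=True, anonymous=True)

# first training: offset is 0, so Trainer steps are forwarded unchanged
for step in range(5):
    logger.log_metrics({'loss': 1.0 / (step + 1)}, step=step)

# normally called by the Trainer at the end of fit(); it records the run's
# last step as the offset for anything logged afterwards
logger.finalize('success')

# second training: the Trainer restarts counting at 0, but the recorded
# offset shifts the W&B steps so they do not collapse back onto the start
for step in range(5):
    logger.log_metrics({'loss': 0.5 / (step + 1)}, step=step)
```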

tests/loggers/test_wandb.py

Lines changed: 8 additions & 2 deletions
@@ -29,11 +29,17 @@ def test_wandb_logger(wandb):
     logger = WandbLogger(anonymous=True, offline=True)
 
     logger.log_metrics({'acc': 1.0})
-    wandb.init().log.assert_called_once_with({'acc': 1.0})
+    wandb.init().log.assert_called_once_with({'acc': 1.0}, step=None)
 
     wandb.init().log.reset_mock()
     logger.log_metrics({'acc': 1.0}, step=3)
-    wandb.init().log.assert_called_once_with({'global_step': 3, 'acc': 1.0})
+    wandb.init().log.assert_called_once_with({'acc': 1.0}, step=3)
+
+    # continue training on same W&B run
+    wandb.init().step = 3
+    logger.finalize('success')
+    logger.log_metrics({'acc': 1.0}, step=3)
+    wandb.init().log.assert_called_with({'acc': 1.0}, step=6)
 
     logger.log_hyperparams({'test': None, 'nested': {'a': 1}, 'b': [2, 3, 4]})
     wandb.init().config.update.assert_called_once_with(
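
The hunk starts mid-function, so the setup that makes `wandb.init().log` assertable is not shown: the test's `wandb` argument is a mock standing in for the wandb module. The snippet below shows how such a mock is typically injected; the patch target and the trimmed-down test name are assumptions here, since they sit outside the diff context:

```python
from unittest import mock

from pytorch_lightning.loggers import WandbLogger


@mock.patch('pytorch_lightning.loggers.wandb.wandb')  # assumed patch target
def test_wandb_logger_step_passthrough(wandb):
    # hypothetical, trimmed-down version of the check added in the hunk above
    logger = WandbLogger(anonymous=True, offline=True)
    logger.log_metrics({'acc': 1.0}, step=3)
    wandb.init().log.assert_called_once_with({'acc': 1.0}, step=3)
```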
