Commit 65e1256

Merge branch 'master' into master
2 parents 016c10a + f07ee33

File tree

5 files changed (+71, -0 lines)

.gitignore

Lines changed: 1 addition & 0 deletions
@@ -138,3 +138,4 @@ mlruns/
 *.ckpt
 pytorch\ lightning
 test-reports/
+wandb

CHANGELOG.md

Lines changed: 15 additions & 0 deletions
@@ -11,30 +11,45 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
 
 - Added `dirpath` and `filename` parameter in `ModelCheckpoint` ([#4213](https://github.com/PyTorchLightning/pytorch-lightning/pull/4213))
 
+
 - Added plugins docs and DDPPlugin to customize ddp across all accelerators([#4258](https://github.com/PyTorchLightning/pytorch-lightning/pull/4285))
 
+
 - Added `strict` option to the scheduler dictionary ([#3586](https://github.com/PyTorchLightning/pytorch-lightning/pull/3586))
 
 - Added autogenerated helptext to `Trainer.add_argparse_args`. ([#4344](https://github.com/PyTorchLightning/pytorch-lightning/pull/4344))
 
 - Added `fsspec` support for profilers ([#4162](https://github.com/PyTorchLightning/pytorch-lightning/pull/4162))
 
+
 ### Changed
 
+
 - Improved error messages for invalid `configure_optimizers` returns ([#3587](https://github.com/PyTorchLightning/pytorch-lightning/pull/3587))
 
+
 - Allow changing the logged step value in `validation_step` ([#4130](https://github.com/PyTorchLightning/pytorch-lightning/pull/4130))
+
+
 - Allow setting `replace_sampler_ddp=True` with a distributed sampler already added ([#4273](https://github.com/PyTorchLightning/pytorch-lightning/pull/4273))
 
+
+- Fixed santized parameters for `WandbLogger.log_hyperparams` ([#4320](https://github.com/PyTorchLightning/pytorch-lightning/pull/4320))
+
+
 ### Deprecated
 
+
 - Deprecated `filepath` in `ModelCheckpoint` ([#4213](https://github.com/PyTorchLightning/pytorch-lightning/pull/4213))
 
+
 - Deprecated `reorder` parameter of the `auc` metric ([#4237](https://github.com/PyTorchLightning/pytorch-lightning/pull/4237))
 
+
 ### Removed
 
 
+
 ### Fixed
 
 - Fixed setting device ids in DDP ([#4297](https://github.com/PyTorchLightning/pytorch-lightning/pull/4297))

pytorch_lightning/loggers/base.py

Lines changed: 25 additions & 0 deletions
@@ -168,6 +168,31 @@ def _convert_params(params: Union[Dict[str, Any], Namespace]) -> Dict[str, Any]:
 
         return params
 
+    @staticmethod
+    def _sanitize_callable_params(params: Dict[str, Any]) -> Dict[str, Any]:
+        """
+        Sanitize callable params dict, e.g. ``{'a': <function_**** at 0x****>} -> {'a': 'function_****'}``.
+
+        Args:
+            params: Dictionary containing the hyperparameters
+
+        Returns:
+            dictionary with all callables sanitized
+        """
+        def _sanitize_callable(val):
+            # Give the callable one chance to return a value; don't go down a rabbit hole of recursive calls
+            if isinstance(val, Callable):
+                try:
+                    _val = val()
+                    if isinstance(_val, Callable):
+                        return val.__name__
+                    return _val
+                except Exception:
+                    return val.__name__
+            return val
+
+        return {key: _sanitize_callable(val) for key, val in params.items()}
+
     @staticmethod
     def _flatten_dict(params: Dict[str, Any], delimiter: str = '/') -> Dict[str, Any]:
         """

pytorch_lightning/loggers/wandb.py

Lines changed: 1 addition & 0 deletions
@@ -135,6 +135,7 @@ def watch(self, model: nn.Module, log: str = 'gradients', log_freq: int = 100):
     def log_hyperparams(self, params: Union[Dict[str, Any], Namespace]) -> None:
         params = self._convert_params(params)
         params = self._flatten_dict(params)
+        params = self._sanitize_callable_params(params)
         self.experiment.config.update(params, allow_val_change=True)
 
     @rank_zero_only
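A hedged usage sketch of the patched `log_hyperparams` (not from the commit; the project name and `schedule` field are made up, and `offline=True` is used so no W&B account is needed). With the sanitization step in place, hyperparameters containing callables no longer break `experiment.config.update`:

from argparse import Namespace
from pytorch_lightning.loggers import WandbLogger

logger = WandbLogger(project="demo", offline=True)
# a callable hyperparameter is now sanitized to the value it returns
hparams = Namespace(lr=0.01, schedule=lambda: "cosine")
logger.log_hyperparams(hparams)  # `schedule` is logged as the string "cosine"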

tests/loggers/test_wandb.py

Lines changed: 29 additions & 0 deletions
@@ -14,6 +14,8 @@
 import os
 import pickle
 from unittest import mock
+from argparse import ArgumentParser
+import types
 
 from pytorch_lightning import Trainer
 from pytorch_lightning.loggers import WandbLogger
@@ -109,3 +111,30 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir):
 
     assert trainer.checkpoint_callback.dirpath == str(tmpdir / 'project' / version / 'checkpoints')
     assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'}
+
+
+def test_wandb_sanitize_callable_params(tmpdir):
+    """
+    Callback functions are not serializable. Therefore, we give them one chance to
+    return a value; if the returned type is not accepted, the function name is used instead.
+    """
+    opt = "--max_epochs 1".split(" ")
+    parser = ArgumentParser()
+    parser = Trainer.add_argparse_args(parent_parser=parser)
+    params = parser.parse_args(opt)
+
+    def return_something():
+        return "something"
+    params.something = return_something
+
+    def wrapper_something():
+        return return_something
+    params.wrapper_something = wrapper_something
+
+    assert isinstance(params.gpus, types.FunctionType)
+    params = WandbLogger._convert_params(params)
+    params = WandbLogger._flatten_dict(params)
+    params = WandbLogger._sanitize_callable_params(params)
+    assert params["gpus"] == '_gpus_arg_default'
+    assert params["something"] == "something"
+    assert params["wrapper_something"] == "wrapper_something"
