|
14 | 14 | import os |
15 | 15 | import pickle |
16 | 16 | from unittest import mock |
| 17 | +from argparse import ArgumentParser |
| 18 | +import types |
17 | 19 |
|
18 | 20 | from pytorch_lightning import Trainer |
19 | 21 | from pytorch_lightning.loggers import WandbLogger |
@@ -109,3 +111,30 @@ def test_wandb_logger_dirs_creation(wandb, tmpdir): |
109 | 111 |
|
110 | 112 | assert trainer.checkpoint_callback.dirpath == str(tmpdir / 'project' / version / 'checkpoints') |
111 | 113 | assert set(os.listdir(trainer.checkpoint_callback.dirpath)) == {'epoch=0.ckpt'} |
| 114 | + |
| 115 | + |
| 116 | +def test_wandb_sanitize_callable_params(tmpdir): |
| 117 | + """ |
| 118 | + Callback function are not serializiable. Therefore, we get them a chance to return |
| 119 | + something and if the returned type is not accepted, return None. |
| 120 | + """ |
| 121 | + opt = "--max_epochs 1".split(" ") |
| 122 | + parser = ArgumentParser() |
| 123 | + parser = Trainer.add_argparse_args(parent_parser=parser) |
| 124 | + params = parser.parse_args(opt) |
| 125 | + |
| 126 | + def return_something(): |
| 127 | + return "something" |
| 128 | + params.something = return_something |
| 129 | + |
| 130 | + def wrapper_something(): |
| 131 | + return return_something |
| 132 | + params.wrapper_something = wrapper_something |
| 133 | + |
| 134 | + assert isinstance(params.gpus, types.FunctionType) |
| 135 | + params = WandbLogger._convert_params(params) |
| 136 | + params = WandbLogger._flatten_dict(params) |
| 137 | + params = WandbLogger._sanitize_callable_params(params) |
| 138 | + assert params["gpus"] == '_gpus_arg_default' |
| 139 | + assert params["something"] == "something" |
| 140 | + assert params["wrapper_something"] == "wrapper_something" |
0 commit comments