Commit 86341ba

gautierdag, carmocca, rohitgr7, and akihironitta authored and committed
fix mypy errors for loggers/wandb.py (#13483)
Co-authored-by: Carlos Mocholí <[email protected]>
Co-authored-by: Rohit Gupta <[email protected]>
Co-authored-by: Akihiro Nitta <[email protected]>
1 parent e36fd77 commit 86341ba

File tree: 8 files changed, +43 -35 lines

environment.yml
Lines changed: 1 addition & 1 deletion

@@ -50,5 +50,5 @@ dependencies:
 - test-tube>=0.7.5
 - mlflow>=1.0.0
 - comet_ml>=3.1.12
-- wandb>=0.8.21
+- wandb>=0.10.22
 - neptune-client>=0.10.0

pyproject.toml
Lines changed: 0 additions & 1 deletion

@@ -59,7 +59,6 @@ module = [
     "pytorch_lightning.demos.mnist_datamodule",
     "pytorch_lightning.loggers.comet",
     "pytorch_lightning.loggers.neptune",
-    "pytorch_lightning.loggers.wandb",
     "pytorch_lightning.profilers.advanced",
     "pytorch_lightning.profilers.base",
     "pytorch_lightning.profilers.pytorch",

requirements/pytorch/extra.txt
Lines changed: 1 addition & 1 deletion

@@ -3,6 +3,6 @@ matplotlib>3.1, <3.5.3
 torchtext>=0.10.*, <=0.12.0
 omegaconf>=2.0.5, <2.3.0
 hydra-core>=1.0.5, <1.3.0
-jsonargparse[signatures]>=4.10.0, <=4.10.0
+jsonargparse[signatures]>=4.10.2, <=4.10.2
 gcsfs>=2021.5.0, <2022.6.0
 rich>=10.14.0, !=10.15.0.a, <13.0.0

requirements/pytorch/loggers.txt
Lines changed: 1 addition & 1 deletion

@@ -4,4 +4,4 @@ neptune-client>=0.10.0, <0.16.4
 comet-ml>=3.1.12, <3.31.6
 mlflow>=1.0.0, <1.27.0
 test_tube>=0.7.5, <=0.7.5
-wandb>=0.8.21, <0.12.20
+wandb>=0.10.22, <0.12.20

src/pytorch_lightning/loggers/wandb.py
Lines changed: 18 additions & 17 deletions

@@ -32,10 +32,11 @@

 try:
     import wandb
+    from wandb.sdk.lib import RunDisabled
     from wandb.wandb_run import Run
 except ModuleNotFoundError:
     # needed for test mocks, these tests shall be updated
-    wandb, Run = None, None
+    wandb, Run, RunDisabled = None, None, None  # type: ignore


 class WandbLogger(Logger):

@@ -251,18 +252,18 @@ def __init__(
         self,
         name: Optional[str] = None,
         save_dir: Optional[str] = None,
-        offline: Optional[bool] = False,
+        offline: bool = False,
         id: Optional[str] = None,
         anonymous: Optional[bool] = None,
         version: Optional[str] = None,
         project: Optional[str] = None,
         log_model: Union[str, bool] = False,
-        experiment=None,
-        prefix: Optional[str] = "",
+        experiment: Union[Run, RunDisabled, None] = None,
+        prefix: str = "",
         agg_key_funcs: Optional[Mapping[str, Callable[[Sequence[float]], float]]] = None,
         agg_default_func: Optional[Callable[[Sequence[float]], float]] = None,
-        **kwargs,
-    ):
+        **kwargs: Any,
+    ) -> None:
         if wandb is None:
             raise ModuleNotFoundError(
                 "You want to use `wandb` logger which is not installed yet,"

@@ -288,17 +289,16 @@ def __init__(
         self._log_model = log_model
         self._prefix = prefix
         self._experiment = experiment
-        self._logged_model_time = {}
-        self._checkpoint_callback = None
+        self._logged_model_time: Dict[str, float] = {}
+        self._checkpoint_callback: Optional["ReferenceType[Checkpoint]"] = None
         # set wandb init arguments
-        anonymous_lut = {True: "allow", False: None}
-        self._wandb_init = dict(
+        self._wandb_init: Dict[str, Any] = dict(
            name=name or project,
            project=project,
            id=version or id,
            dir=save_dir,
            resume="allow",
-            anonymous=anonymous_lut.get(anonymous, anonymous),
+            anonymous=("allow" if anonymous else None),
         )
         self._wandb_init.update(**kwargs)
         # extract parameters

@@ -310,7 +310,7 @@ def __init__(
             wandb.require("service")
             _ = self.experiment

-    def __getstate__(self):
+    def __getstate__(self) -> Dict[str, Any]:
         state = self.__dict__.copy()
         # args needed to reload correct experiment
         if self._experiment is not None:

@@ -322,7 +322,7 @@ def __getstate__(self):
         state["_experiment"] = None
         return state

-    @property
+    @property  # type: ignore[misc]
     @rank_zero_experiment
     def experiment(self) -> Run:
         r"""

@@ -357,13 +357,14 @@ def experiment(self) -> Run:
             self._experiment = wandb.init(**self._wandb_init)

             # define default x-axis
-            if getattr(self._experiment, "define_metric", None):
+            if isinstance(self._experiment, Run) and getattr(self._experiment, "define_metric", None):
                 self._experiment.define_metric("trainer/global_step")
                 self._experiment.define_metric("*", step_metric="trainer/global_step", step_sync=True)

+        assert isinstance(self._experiment, Run)
         return self._experiment

-    def watch(self, model: nn.Module, log: str = "gradients", log_freq: int = 100, log_graph: bool = True):
+    def watch(self, model: nn.Module, log: str = "gradients", log_freq: int = 100, log_graph: bool = True) -> None:
         self.experiment.watch(model, log=log, log_freq=log_freq, log_graph=log_graph)

     @rank_zero_only

@@ -379,7 +380,7 @@ def log_metrics(self, metrics: Dict[str, float], step: Optional[int] = None) ->

         metrics = _add_prefix(metrics, self._prefix, self.LOGGER_JOIN_CHAR)
         if step is not None:
-            self.experiment.log({**metrics, "trainer/global_step": step})
+            self.experiment.log(dict(metrics, **{"trainer/global_step": step}))
         else:
             self.experiment.log(metrics)

@@ -417,7 +418,7 @@ def log_text(
         self.log_table(key, columns, data, dataframe, step)

     @rank_zero_only
-    def log_image(self, key: str, images: List[Any], step: Optional[int] = None, **kwargs: str) -> None:
+    def log_image(self, key: str, images: List[Any], step: Optional[int] = None, **kwargs: Any) -> None:
         """Log images (tensors, numpy arrays, PIL Images or file paths).

         Optional kwargs are lists passed to each image (ex: caption, masks, boxes).
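
The typing work above does not change the logger's public API. As a usage sketch (not part of the commit; the project name and prefix below are placeholder values), the tightened annotations mean offline is a plain bool and prefix a plain str:

# Usage sketch with placeholder values; assumes wandb>=0.10.22 is installed.
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import WandbLogger

# offline: bool and prefix: str (no longer Optional) after this commit
logger = WandbLogger(project="my-project", offline=True, prefix="train")
trainer = Trainer(logger=logger, max_epochs=1)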

src/pytorch_lightning/utilities/cli.py
Lines changed: 1 addition & 1 deletion

@@ -31,7 +31,7 @@
 from pytorch_lightning.utilities.model_helpers import is_overridden
 from pytorch_lightning.utilities.rank_zero import _warn, rank_zero_deprecation, rank_zero_warn

-_JSONARGPARSE_SIGNATURES_AVAILABLE = _RequirementAvailable("jsonargparse[signatures]>=4.10.0")
+_JSONARGPARSE_SIGNATURES_AVAILABLE = _RequirementAvailable("jsonargparse[signatures]>=4.10.2")

 if _JSONARGPARSE_SIGNATURES_AVAILABLE:
     import docstring_parser
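
For readers unfamiliar with this gate: the check appears to verify that the named requirement is installed at a satisfying version before the optional import runs. A rough standalone sketch of that idea (using pkg_resources; this is not Lightning's actual _RequirementAvailable implementation):

# Sketch only; not how _RequirementAvailable is implemented verbatim.
import pkg_resources


def requirement_available(requirement: str) -> bool:
    try:
        pkg_resources.require(requirement)  # raises if missing or too old
        return True
    except (pkg_resources.DistributionNotFound, pkg_resources.VersionConflict):
        return False


print(requirement_available("jsonargparse[signatures]>=4.10.2"))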

tests/tests_pytorch/loggers/test_all.py
Lines changed: 4 additions & 1 deletion

@@ -45,6 +45,7 @@
     mock.patch("pytorch_lightning.loggers.mlflow.MlflowClient"),
     mock.patch("pytorch_lightning.loggers.neptune.neptune", new_callable=create_neptune_mock),
     mock.patch("pytorch_lightning.loggers.wandb.wandb"),
+    mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock),
 )
 ALL_LOGGER_CLASSES = (
     CometLogger,

@@ -363,7 +364,9 @@ def test_logger_with_prefix_all(tmpdir, monkeypatch):
     logger.experiment.add_scalar.assert_called_once_with("tmp-test", 1.0, 0)

     # WandB
-    with mock.patch("pytorch_lightning.loggers.wandb.wandb") as wandb:
+    with mock.patch("pytorch_lightning.loggers.wandb.wandb") as wandb, mock.patch(
+        "pytorch_lightning.loggers.wandb.Run", new=mock.Mock
+    ):
         logger = _instantiate_logger(WandbLogger, save_dir=tmpdir, prefix=prefix)
         wandb.run = None
         wandb.init().step = 0
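
Why the extra Run patch is needed: after this commit the experiment property ends with assert isinstance(self._experiment, Run), so the tests swap Run for mock.Mock. A minimal standalone sketch (no wandb or Lightning required) of why that keeps the assertion passing:

from unittest import mock

Run = mock.Mock                # what mock.patch(..., new=mock.Mock) installs as Run
experiment = mock.MagicMock()  # stands in for the object a mocked wandb.init() returns
assert isinstance(experiment, Run)  # passes: MagicMock is a subclass of Mock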

tests/tests_pytorch/loggers/test_wandb.py
Lines changed: 17 additions & 12 deletions

@@ -24,6 +24,7 @@
 from tests_pytorch.helpers.utils import no_warning_call


+@mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock)
 @mock.patch("pytorch_lightning.loggers.wandb.wandb")
 def test_wandb_logger_init(wandb, monkeypatch):
     """Verify that basic functionality of wandb logger works.

@@ -111,20 +112,21 @@ class Experiment:
         def name(self):
             return "the_run_name"

-    wandb.run = None
-    wandb.init.return_value = Experiment()
-    logger = WandbLogger(id="the_id", offline=True)
+    with mock.patch("pytorch_lightning.loggers.wandb.Run", new=Experiment):
+        wandb.run = None
+        wandb.init.return_value = Experiment()
+        logger = WandbLogger(id="the_id", offline=True)

-    trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, logger=logger)
-    # Access the experiment to ensure it's created
-    assert trainer.logger.experiment, "missing experiment"
-    assert trainer.log_dir == logger.save_dir
-    pkl_bytes = pickle.dumps(trainer)
-    trainer2 = pickle.loads(pkl_bytes)
+        trainer = Trainer(default_root_dir=tmpdir, max_epochs=1, logger=logger)
+        # Access the experiment to ensure it's created
+        assert trainer.logger.experiment, "missing experiment"
+        assert trainer.log_dir == logger.save_dir
+        pkl_bytes = pickle.dumps(trainer)
+        trainer2 = pickle.loads(pkl_bytes)

-    assert os.environ["WANDB_MODE"] == "dryrun"
-    assert trainer2.logger.__class__.__name__ == WandbLogger.__name__
-    assert trainer2.logger.experiment, "missing experiment"
+        assert os.environ["WANDB_MODE"] == "dryrun"
+        assert trainer2.logger.__class__.__name__ == WandbLogger.__name__
+        assert trainer2.logger.experiment, "missing experiment"

     wandb.init.assert_called()
     assert "id" in wandb.init.call_args[1]

@@ -133,6 +135,7 @@ def name(self):
     del os.environ["WANDB_MODE"]


+@mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock)
 @mock.patch("pytorch_lightning.loggers.wandb.wandb")
 def test_wandb_logger_dirs_creation(wandb, monkeypatch, tmpdir):
     """Test that the logger creates the folders and files in the right place."""

@@ -169,6 +172,7 @@ def test_wandb_logger_dirs_creation(wandb, monkeypatch, tmpdir):
     assert trainer.log_dir == logger.save_dir


+@mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock)
 @mock.patch("pytorch_lightning.loggers.wandb.wandb")
 def test_wandb_log_model(wandb, monkeypatch, tmpdir):
     """Test that the logger creates the folders and files in the right place."""

@@ -234,6 +238,7 @@ def test_wandb_log_model(wandb, monkeypatch, tmpdir):
     )


+@mock.patch("pytorch_lightning.loggers.wandb.Run", new=mock.Mock)
 @mock.patch("pytorch_lightning.loggers.wandb.wandb")
 def test_wandb_log_media(wandb, tmpdir):
     """Test that the logger creates the folders and files in the right place."""
