Skip to content

Commit f116c2f

Browse files
fix mypy typing errors in pytorch_lightning/strategies/single_device.py (#13532)
* fix typing in strategies/single_device.py * Make assert statement more explicit Co-authored-by: Justus Schock <[email protected]>
1 parent cf189bd commit f116c2f

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

pyproject.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,6 @@ module = [
7474
"pytorch_lightning.strategies.parallel",
7575
"pytorch_lightning.strategies.sharded",
7676
"pytorch_lightning.strategies.sharded_spawn",
77-
"pytorch_lightning.strategies.single_device",
7877
"pytorch_lightning.strategies.single_tpu",
7978
"pytorch_lightning.strategies.tpu_spawn",
8079
"pytorch_lightning.strategies.strategy",

src/pytorch_lightning/strategies/single_device.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import pytorch_lightning as pl
2222
from pytorch_lightning.plugins.io.checkpoint_plugin import CheckpointIO
2323
from pytorch_lightning.plugins.precision import PrecisionPlugin
24-
from pytorch_lightning.strategies.strategy import Strategy
24+
from pytorch_lightning.strategies.strategy import Strategy, TBroadcast
2525
from pytorch_lightning.utilities.types import _DEVICE
2626

2727

@@ -66,6 +66,7 @@ def root_device(self) -> torch.device:
6666
return self._root_device
6767

6868
def model_to_device(self) -> None:
69+
assert self.model is not None, "self.model must be set before self.model.to()"
6970
self.model.to(self.root_device)
7071

7172
def setup(self, trainer: pl.Trainer) -> None:
@@ -76,10 +77,10 @@ def setup(self, trainer: pl.Trainer) -> None:
7677
def is_global_zero(self) -> bool:
7778
return True
7879

79-
def barrier(self, *args, **kwargs) -> None:
80+
def barrier(self, *args: Any, **kwargs: Any) -> None:
8081
pass
8182

82-
def broadcast(self, obj: object, src: int = 0) -> object:
83+
def broadcast(self, obj: TBroadcast, src: int = 0) -> TBroadcast:
8384
return obj
8485

8586
@classmethod

0 commit comments

Comments
 (0)