Skip to content

Commit f37bd46

Browse files
authored
Update mypy (#11096)
1 parent cc42aa9 commit f37bd46

File tree

4 files changed

+26
-20
lines changed

4 files changed

+26
-20
lines changed

pytorch_lightning/plugins/precision/precision_plugin.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -150,7 +150,7 @@ def optimizer_step(
150150
"""Hook to run the optimizer step."""
151151
if isinstance(model, pl.LightningModule):
152152
closure = partial(self._wrap_closure, model, optimizer, optimizer_idx, closure)
153-
optimizer.step(closure=closure, **kwargs) # type: ignore[call-arg]
153+
optimizer.step(closure=closure, **kwargs)
154154

155155
def _track_grad_norm(self, trainer: "pl.Trainer") -> None:
156156
if trainer.track_grad_norm == -1:

pytorch_lightning/trainer/connectors/signal_connector.py

Lines changed: 22 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,6 @@
33
import signal
44
import sys
55
import threading
6-
from signal import Signals
76
from subprocess import call
87
from types import FrameType
98
from typing import Any, Callable, Dict, List, Set, Union
@@ -12,33 +11,38 @@
1211
from pytorch_lightning.plugins.environments import SLURMEnvironment
1312
from pytorch_lightning.utilities.imports import _fault_tolerant_training, _IS_WINDOWS
1413

15-
log = logging.getLogger(__name__)
14+
# copied from signal.pyi
15+
_SIGNUM = Union[int, signal.Signals]
16+
_HANDLER = Union[Callable[[_SIGNUM, FrameType], Any], int, signal.Handlers, None]
1617

17-
_SIGNAL_HANDLER_DICT = Dict[Signals, Union[Callable[[Signals, FrameType], Any], int, None]]
18+
log = logging.getLogger(__name__)
1819

1920

2021
class HandlersCompose:
21-
def __init__(self, signal_handlers: Union[List[Callable], Callable]) -> None:
22+
def __init__(self, signal_handlers: Union[List[_HANDLER], _HANDLER]) -> None:
2223
if not isinstance(signal_handlers, list):
2324
signal_handlers = [signal_handlers]
2425
self.signal_handlers = signal_handlers
2526

26-
def __call__(self, signum: Signals, frame: FrameType) -> None:
27+
def __call__(self, signum: _SIGNUM, frame: FrameType) -> None:
2728
for signal_handler in self.signal_handlers:
28-
signal_handler(signum, frame)
29+
if isinstance(signal_handler, int):
30+
signal_handler = signal.getsignal(signal_handler)
31+
if callable(signal_handler):
32+
signal_handler(signum, frame)
2933

3034

3135
class SignalConnector:
3236
def __init__(self, trainer: "pl.Trainer") -> None:
3337
self.trainer = trainer
3438
self.trainer._terminate_gracefully = False
35-
self._original_handlers: _SIGNAL_HANDLER_DICT = {}
39+
self._original_handlers: Dict[_SIGNUM, _HANDLER] = {}
3640

3741
def register_signal_handlers(self) -> None:
3842
self._original_handlers = self._get_current_signal_handlers()
3943

40-
sigusr1_handlers: List[Callable] = []
41-
sigterm_handlers: List[Callable] = []
44+
sigusr1_handlers: List[_HANDLER] = []
45+
sigterm_handlers: List[_HANDLER] = []
4246

4347
if _fault_tolerant_training():
4448
sigterm_handlers.append(self.fault_tolerant_sigterm_handler_fn)
@@ -57,7 +61,7 @@ def register_signal_handlers(self) -> None:
5761
if sigterm_handlers and not self._has_already_handler(signal.SIGTERM):
5862
self._register_signal(signal.SIGTERM, HandlersCompose(sigterm_handlers))
5963

60-
def slurm_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None:
64+
def slurm_sigusr1_handler_fn(self, signum: _SIGNUM, frame: FrameType) -> None:
6165
if self.trainer.is_global_zero:
6266
# save weights
6367
log.info("handling SIGUSR1")
@@ -88,22 +92,22 @@ def slurm_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None:
8892
if self.trainer.logger:
8993
self.trainer.logger.finalize("finished")
9094

91-
def fault_tolerant_sigterm_handler_fn(self, signum: Signals, frame: FrameType) -> None:
95+
def fault_tolerant_sigterm_handler_fn(self, signum: _SIGNUM, frame: FrameType) -> None:
9296
log.info(f"Received signal {signum}. Saving a fault-tolerant checkpoint and terminating.")
9397
self.trainer._terminate_gracefully = True
9498

95-
def sigterm_handler_fn(self, signum: Signals, frame: FrameType) -> None:
99+
def sigterm_handler_fn(self, signum: _SIGNUM, frame: FrameType) -> None:
96100
log.info("bypassing sigterm")
97101

98102
def teardown(self) -> None:
99103
"""Restores the signals that were previsouly configured before :class:`SignalConnector` replaced them."""
100104
for signum, handler in self._original_handlers.items():
101105
if handler is not None:
102-
signal.signal(signum, handler)
106+
signal.signal(signum, handler) # type: ignore[arg-type]
103107
self._original_handlers = {}
104108

105109
@staticmethod
106-
def _get_current_signal_handlers() -> _SIGNAL_HANDLER_DICT:
110+
def _get_current_signal_handlers() -> Dict[_SIGNUM, _HANDLER]:
107111
"""Collects the currently assigned signal handlers."""
108112
valid_signals = SignalConnector._valid_signals()
109113
if not _IS_WINDOWS:
@@ -112,7 +116,7 @@ def _get_current_signal_handlers() -> _SIGNAL_HANDLER_DICT:
112116
return {signum: signal.getsignal(signum) for signum in valid_signals}
113117

114118
@staticmethod
115-
def _valid_signals() -> Set[Signals]:
119+
def _valid_signals() -> Set[signal.Signals]:
116120
"""Returns all valid signals supported on the current platform.
117121
118122
Behaves identically to :func:`signals.valid_signals` in Python 3.8+ and implements the equivalent behavior for
@@ -138,13 +142,13 @@ def _is_on_windows() -> bool:
138142
return sys.platform == "win32"
139143

140144
@staticmethod
141-
def _has_already_handler(signum: Signals) -> bool:
145+
def _has_already_handler(signum: _SIGNUM) -> bool:
142146
return signal.getsignal(signum) not in (None, signal.SIG_DFL)
143147

144148
@staticmethod
145-
def _register_signal(signum: Signals, handlers: HandlersCompose) -> None:
149+
def _register_signal(signum: _SIGNUM, handlers: _HANDLER) -> None:
146150
if threading.current_thread() is threading.main_thread():
147-
signal.signal(signum, handlers)
151+
signal.signal(signum, handlers) # type: ignore[arg-type]
148152

149153
def __getstate__(self) -> Dict:
150154
state = self.__dict__.copy()

pytorch_lightning/utilities/model_helpers.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,8 @@ def is_overridden(method_name: str, instance: Optional[object] = None, parent: O
3434
raise ValueError("Expected a parent")
3535

3636
instance_attr = getattr(instance, method_name, None)
37+
if instance_attr is None:
38+
return False
3739
# `functools.wraps()` support
3840
if hasattr(instance_attr, "__wrapped__"):
3941
instance_attr = instance_attr.__wrapped__

requirements/test.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ codecov>=2.1
33
pytest>=6.0
44
pytest-rerunfailures>=10.2
55
twine==3.2
6-
mypy==0.910
6+
mypy>=0.920
77
flake8>=3.9.2
88
pre-commit>=1.0
99

0 commit comments

Comments
 (0)