33import signal
44import sys
55import threading
6- from signal import Signals
76from subprocess import call
87from types import FrameType
98from typing import Any , Callable , Dict , List , Set , Union
1211from pytorch_lightning .plugins .environments import SLURMEnvironment
1312from pytorch_lightning .utilities .imports import _fault_tolerant_training , _IS_WINDOWS
1413
# Type aliases copied from signal.pyi (the typeshed stub for the ``signal`` module).
# A signal number is either a plain int or a member of ``signal.Signals``.
_SIGNUM = Union[int, signal.Signals]
# A handler is a ``(signum, frame)`` callable, an int, a ``signal.Handlers`` member, or None.
_HANDLER = Union[Callable[[_SIGNUM, FrameType], Any], int, signal.Handlers, None]

# Module-level logger named after this module.
log = logging.getLogger(__name__)
1819
1920
class HandlersCompose:
    """Bundle one or more signal handlers into a single callable.

    Calling the instance walks the stored handlers in order and invokes every
    entry that turns out to be callable with the received ``(signum, frame)``.
    """

    def __init__(self, signal_handlers: Union[List[_HANDLER], _HANDLER]) -> None:
        """Store the handlers, wrapping a single non-list handler in a list.

        Args:
            signal_handlers: One handler or a list of handlers.
        """
        if not isinstance(signal_handlers, list):
            signal_handlers = [signal_handlers]
        self.signal_handlers = signal_handlers

    def __call__(self, signum: _SIGNUM, frame: FrameType) -> None:
        """Dispatch ``signum`` and ``frame`` to each stored handler in order.

        Args:
            signum: The signal number being handled.
            frame: The stack frame passed in by the signal machinery.
        """
        for signal_handler in self.signal_handlers:
            # An int entry is passed to ``signal.getsignal`` and replaced by
            # whatever that call returns.
            if isinstance(signal_handler, int):
                signal_handler = signal.getsignal(signal_handler)
            # Entries that are not callable (e.g. None) are skipped.
            if callable(signal_handler):
                signal_handler(signum, frame)
2933
3034
3135class SignalConnector :
def __init__(self, trainer: "pl.Trainer") -> None:
    """Attach the connector to ``trainer`` and reset its bookkeeping.

    Args:
        trainer: The trainer this connector manages signals for.
    """
    self.trainer = trainer
    # Start with the graceful-termination flag cleared; the fault-tolerant
    # SIGTERM handler is what flips it to True.
    self.trainer._terminate_gracefully = False
    # Handlers that were installed before this connector registered its own;
    # populated by ``register_signal_handlers`` and restored by ``teardown``.
    self._original_handlers: Dict[_SIGNUM, _HANDLER] = {}
3640
3741 def register_signal_handlers (self ) -> None :
3842 self ._original_handlers = self ._get_current_signal_handlers ()
3943
40- sigusr1_handlers : List [Callable ] = []
41- sigterm_handlers : List [Callable ] = []
44+ sigusr1_handlers : List [_HANDLER ] = []
45+ sigterm_handlers : List [_HANDLER ] = []
4246
4347 if _fault_tolerant_training ():
4448 sigterm_handlers .append (self .fault_tolerant_sigterm_handler_fn )
@@ -57,7 +61,7 @@ def register_signal_handlers(self) -> None:
5761 if sigterm_handlers and not self ._has_already_handler (signal .SIGTERM ):
5862 self ._register_signal (signal .SIGTERM , HandlersCompose (sigterm_handlers ))
5963
60- def slurm_sigusr1_handler_fn (self , signum : Signals , frame : FrameType ) -> None :
64+ def slurm_sigusr1_handler_fn (self , signum : _SIGNUM , frame : FrameType ) -> None :
6165 if self .trainer .is_global_zero :
6266 # save weights
6367 log .info ("handling SIGUSR1" )
@@ -88,22 +92,22 @@ def slurm_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None:
8892 if self .trainer .logger :
8993 self .trainer .logger .finalize ("finished" )
9094
def fault_tolerant_sigterm_handler_fn(self, signum: _SIGNUM, frame: FrameType) -> None:
    """Log the received signal and flag the trainer for graceful termination.

    Args:
        signum: The signal number that was received.
        frame: The stack frame passed in by the signal machinery (unused).
    """
    log.info(f"Received signal {signum}. Saving a fault-tolerant checkpoint and terminating.")
    # Only the flag is set here; nothing else happens inside the handler.
    self.trainer._terminate_gracefully = True
9498
def sigterm_handler_fn(self, signum: _SIGNUM, frame: FrameType) -> None:
    """Log that SIGTERM is being bypassed; takes no other action.

    Args:
        signum: The signal number that was received (unused).
        frame: The stack frame passed in by the signal machinery (unused).
    """
    log.info("bypassing sigterm")
97101
98102 def teardown (self ) -> None :
99103 """Restores the signals that were previsouly configured before :class:`SignalConnector` replaced them."""
100104 for signum , handler in self ._original_handlers .items ():
101105 if handler is not None :
102- signal .signal (signum , handler )
106+ signal .signal (signum , handler ) # type: ignore[arg-type]
103107 self ._original_handlers = {}
104108
105109 @staticmethod
106- def _get_current_signal_handlers () -> _SIGNAL_HANDLER_DICT :
110+ def _get_current_signal_handlers () -> Dict [ _SIGNUM , _HANDLER ] :
107111 """Collects the currently assigned signal handlers."""
108112 valid_signals = SignalConnector ._valid_signals ()
109113 if not _IS_WINDOWS :
@@ -112,7 +116,7 @@ def _get_current_signal_handlers() -> _SIGNAL_HANDLER_DICT:
112116 return {signum : signal .getsignal (signum ) for signum in valid_signals }
113117
114118 @staticmethod
115- def _valid_signals () -> Set [Signals ]:
119+ def _valid_signals () -> Set [signal . Signals ]:
116120 """Returns all valid signals supported on the current platform.
117121
118122 Behaves identically to :func:`signals.valid_signals` in Python 3.8+ and implements the equivalent behavior for
@@ -138,13 +142,13 @@ def _is_on_windows() -> bool:
138142 return sys .platform == "win32"
139143
@staticmethod
def _has_already_handler(signum: _SIGNUM) -> bool:
    """Return whether ``signum`` currently has a non-default handler.

    Args:
        signum: The signal number to inspect.

    Returns:
        ``True`` when ``signal.getsignal(signum)`` is neither ``None`` nor
        ``signal.SIG_DFL``; ``False`` otherwise.
    """
    return signal.getsignal(signum) not in (None, signal.SIG_DFL)
143147
@staticmethod
def _register_signal(signum: _SIGNUM, handlers: _HANDLER) -> None:
    """Install ``handlers`` for ``signum`` when running on the main thread.

    Args:
        signum: The signal number to register the handler for.
        handlers: The handler to install.
    """
    # Registration is skipped silently on any thread other than the main one.
    if threading.current_thread() is threading.main_thread():
        signal.signal(signum, handlers)  # type: ignore[arg-type]
148152
149153 def __getstate__ (self ) -> Dict :
150154 state = self .__dict__ .copy ()
0 commit comments