|
2 | 2 | import os |
3 | 3 | import signal |
4 | 4 | import sys |
| 5 | +import threading |
5 | 6 | from signal import Signals |
6 | 7 | from subprocess import call |
7 | 8 | from types import FrameType, FunctionType |
@@ -43,11 +44,11 @@ def register_signal_handlers(self) -> None: |
43 | 44 |
|
44 | 45 | # signal.SIGUSR1 doesn't seem available on windows |
45 | 46 | if not self._is_on_windows(): |
46 | | - if not self._has_already_handler(signal.SIGUSR1): |
47 | | - signal.signal(signal.SIGUSR1, HandlersCompose(sigusr1_handlers)) |
| 47 | + if sigusr1_handlers and not self._has_already_handler(signal.SIGUSR1): |
| 48 | + self._register_signal(signal.SIGUSR1, HandlersCompose(sigusr1_handlers)) |
48 | 49 |
|
49 | | - if not self._has_already_handler(signal.SIGTERM): |
50 | | - signal.signal(signal.SIGTERM, HandlersCompose(sigterm_handlers)) |
| 50 | + if sigterm_handlers and not self._has_already_handler(signal.SIGTERM): |
| 51 | + self._register_signal(signal.SIGTERM, HandlersCompose(sigterm_handlers)) |
51 | 52 |
|
52 | 53 | def slurm_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None: |
53 | 54 | if self.trainer.is_global_zero: |
@@ -107,3 +108,8 @@ def _has_already_handler(self, signum: Signals) -> bool: |
107 | 108 | return isinstance(signal.getsignal(signum), FunctionType) |
108 | 109 | except AttributeError: |
109 | 110 | return False |
| 111 | + |
| 112 | + @staticmethod |
| 113 | + def _register_signal(signum: Signals, handlers: HandlersCompose) -> None: |
| 114 | + if threading.current_thread() is threading.main_thread(): |
| 115 | + signal.signal(signum, handlers) |
0 commit comments