Skip to content

Commit 7d3ad5b

Browse files
authored
Don't register signal in thread (#10610)
1 parent 5788789 commit 7d3ad5b

File tree

3 files changed

+23
-3
lines changed

3 files changed

+23
-3
lines changed

CHANGELOG.md

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -147,7 +147,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).
147147
### Fixed
148148

149149

150-
-
150+
- Fixed signals being registered within threads ([#10610](https://github.com/PyTorchLightning/pytorch-lightning/pull/10610))
151151

152152

153153
-

pytorch_lightning/trainer/connectors/signal_connector.py

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,7 @@
22
import os
33
import signal
44
import sys
5+
import threading
56
from signal import Signals
67
from subprocess import call
78
from types import FrameType, FunctionType
@@ -46,10 +47,10 @@ def register_signal_handlers(self) -> None:
4647
# signal.SIGUSR1 doesn't seem available on windows
4748
if not self._is_on_windows():
4849
if sigusr1_handlers and not self._has_already_handler(signal.SIGUSR1):
49-
signal.signal(signal.SIGUSR1, HandlersCompose(sigusr1_handlers))
50+
self._register_signal(signal.SIGUSR1, HandlersCompose(sigusr1_handlers))
5051

5152
if sigterm_handlers and not self._has_already_handler(signal.SIGTERM):
52-
signal.signal(signal.SIGTERM, HandlersCompose(sigterm_handlers))
53+
self._register_signal(signal.SIGTERM, HandlersCompose(sigterm_handlers))
5354

5455
def slurm_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None:
5556
if self.trainer.is_global_zero:
@@ -96,3 +97,8 @@ def _has_already_handler(self, signum: Signals) -> bool:
9697
return isinstance(signal.getsignal(signum), FunctionType)
9798
except AttributeError:
9899
return False
100+
101+
@staticmethod
102+
def _register_signal(signum: Signals, handlers: HandlersCompose) -> None:
103+
if threading.current_thread() is threading.main_thread():
104+
signal.signal(signum, handlers)

tests/trainer/connectors/test_signal_connector.py

Lines changed: 14 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -11,6 +11,7 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14+
import concurrent.futures
1415
import os
1516
import signal
1617
from time import sleep
@@ -87,3 +88,16 @@ def test_auto_requeue_flag(auto_requeue):
8788
# TODO: should this be done in SignalConnector teardown?
8889
signal.signal(signal.SIGTERM, sigterm_handler_default)
8990
signal.signal(signal.SIGUSR1, sigusr1_handler_default)
91+
92+
93+
def _registering_signals():
94+
trainer = Trainer()
95+
trainer.signal_connector.register_signal_handlers()
96+
97+
98+
@RunIf(skip_windows=True)
99+
@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"})
100+
def test_signal_connector_in_thread():
101+
with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor:
102+
for future in concurrent.futures.as_completed([executor.submit(_registering_signals)]):
103+
assert future.exception() is None

0 commit comments

Comments
 (0)