diff --git a/CHANGELOG.md b/CHANGELOG.md index 9ada4815bcf57..975f8bdc7c499 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -143,7 +143,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/). ### Fixed -- +- Fixed signals being registered within threads ([#10610](https://github.com/PyTorchLightning/pytorch-lightning/pull/10610)) - diff --git a/pytorch_lightning/trainer/connectors/signal_connector.py b/pytorch_lightning/trainer/connectors/signal_connector.py index 795145b5be6af..90d0f6928283f 100644 --- a/pytorch_lightning/trainer/connectors/signal_connector.py +++ b/pytorch_lightning/trainer/connectors/signal_connector.py @@ -2,6 +2,7 @@ import os import signal import sys +import threading from signal import Signals from subprocess import call from types import FrameType, FunctionType @@ -46,10 +47,10 @@ def register_signal_handlers(self) -> None: # signal.SIGUSR1 doesn't seem available on windows if not self._is_on_windows(): if sigusr1_handlers and not self._has_already_handler(signal.SIGUSR1): - signal.signal(signal.SIGUSR1, HandlersCompose(sigusr1_handlers)) + self._register_signal(signal.SIGUSR1, HandlersCompose(sigusr1_handlers)) if sigterm_handlers and not self._has_already_handler(signal.SIGTERM): - signal.signal(signal.SIGTERM, HandlersCompose(sigterm_handlers)) + self._register_signal(signal.SIGTERM, HandlersCompose(sigterm_handlers)) def slurm_sigusr1_handler_fn(self, signum: Signals, frame: FrameType) -> None: if self.trainer.is_global_zero: @@ -96,3 +97,8 @@ def _has_already_handler(self, signum: Signals) -> bool: return isinstance(signal.getsignal(signum), FunctionType) except AttributeError: return False + + @staticmethod + def _register_signal(signum: Signals, handlers: HandlersCompose) -> None: + if threading.current_thread() is threading.main_thread(): + signal.signal(signum, handlers) diff --git a/tests/trainer/connectors/test_signal_connector.py b/tests/trainer/connectors/test_signal_connector.py index 76dae5e07db35..fbfce158e3675 100644 --- a/tests/trainer/connectors/test_signal_connector.py +++ b/tests/trainer/connectors/test_signal_connector.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. +import concurrent.futures import os import signal from time import sleep @@ -87,3 +88,16 @@ def test_auto_requeue_flag(auto_requeue): # TODO: should this be done in SignalConnector teardown? signal.signal(signal.SIGTERM, sigterm_handler_default) signal.signal(signal.SIGUSR1, sigusr1_handler_default) + + +def _registering_signals(): + trainer = Trainer() + trainer.signal_connector.register_signal_handlers() + + +@RunIf(skip_windows=True) +@mock.patch.dict(os.environ, {"PL_FAULT_TOLERANT_TRAINING": "1"}) +def test_signal_connector_in_thread(): + with concurrent.futures.ThreadPoolExecutor(max_workers=1) as executor: + for future in concurrent.futures.as_completed([executor.submit(_registering_signals)]): + assert future.exception() is None