From 797975669235f2b5d6d2b0a8b5571647b1c0407c Mon Sep 17 00:00:00 2001
From: Christoph Burgdorf
Date: Thu, 20 Sep 2018 12:10:28 +0200
Subject: [PATCH] Ensure shared process plugins shut down cleanly

Fixes #1284
---
 setup.py                                      |   1 +
 trinity/extensibility/__init__.py             |   8 +-
 trinity/extensibility/exceptions.py           |  12 ++
 trinity/extensibility/plugin.py               |  81 ++++----
 trinity/extensibility/plugin_manager.py       | 190 +++++++++++++-----
 trinity/main.py                               |  44 ++--
 trinity/plugins/builtin/attach/plugin.py      |   4 +-
 .../builtin/fix_unclean_shutdown/plugin.py    |   4 +-
 trinity/plugins/builtin/json_rpc/plugin.py    |   4 +-
 .../builtin/light_peer_chain_bridge/plugin.py |  17 +-
 trinity/plugins/builtin/tx_pool/plugin.py     |  17 +-
 trinity/utils/shutdown.py                     |  32 ++-
 12 files changed, 283 insertions(+), 131 deletions(-)
 create mode 100644 trinity/extensibility/exceptions.py

diff --git a/setup.py b/setup.py
index 488c19758f..771eed0adf 100644
--- a/setup.py
+++ b/setup.py
@@ -36,6 +36,7 @@
         "upnpclient>=0.0.8,<1",
     ],
     'trinity': [
+        "async-generator==1.10",
         "bloom-filter==1.3",
         "cachetools>=2.1.0,<3.0.0",
         "coincurve>=8.0.0,<9.0.0",
diff --git a/trinity/extensibility/__init__.py b/trinity/extensibility/__init__.py
index 795dff62aa..2b03807f7b 100644
--- a/trinity/extensibility/__init__.py
+++ b/trinity/extensibility/__init__.py
@@ -2,15 +2,17 @@
     BaseEvent
 )
 from trinity.extensibility.plugin import (  # noqa: F401
-    BasePlugin,
+    BaseAsyncStopPlugin,
+    BaseMainProcessPlugin,
     BaseIsolatedPlugin,
+    BasePlugin,
+    BaseSyncStopPlugin,
     DebugPlugin,
     PluginContext,
-    PluginProcessScope,
 )
 from trinity.extensibility.plugin_manager import (  # noqa: F401
+    BaseManagerProcessScope,
     MainAndIsolatedProcessScope,
     PluginManager,
-    ManagerProcessScope,
     SharedProcessScope,
 )
diff --git a/trinity/extensibility/exceptions.py b/trinity/extensibility/exceptions.py
new file mode 100644
index 0000000000..5977435d69
--- /dev/null
+++ b/trinity/extensibility/exceptions.py
@@ -0,0 +1,12 @@
+from trinity.exceptions import (
+    BaseTrinityError,
+)
+
+
+class UnsuitableShutdownError(BaseTrinityError):
+    """
+    Raised when ``shutdown`` is called on a ``PluginManager`` instance that operates
+    in the ``MainAndIsolatedProcessScope``, or when ``shutdown_blocking`` is called on a
+    ``PluginManager`` instance that operates in the ``SharedProcessScope``.
+    """
+    pass
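
The two misuse cases the docstring describes are easiest to see in a usage
sketch (hypothetical wiring, not part of this patch; `endpoint` stands in for
an already-created lahja `Endpoint`):

    from trinity.extensibility import PluginManager, SharedProcessScope
    from trinity.extensibility.exceptions import UnsuitableShutdownError

    manager = PluginManager(SharedProcessScope(endpoint))

    try:
        manager.shutdown_blocking()  # wrong variant: this scope stops asynchronously
    except UnsuitableShutdownError:
        pass  # the suitable call for this scope is: await manager.shutdown()
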
- """ - return PluginProcessScope.SHARED - @property def logger(self) -> logging.Logger: return logging.getLogger('trinity.extensibility.plugin.BasePlugin#{0}'.format(self.name)) @@ -147,21 +122,42 @@ def start(self) -> None: """ pass + +class BaseSyncStopPlugin(BasePlugin): + """ + A ``BaseSyncStopPlugin`` unwinds synchronoulsy, hence blocks until shut down is done. + """ def stop(self) -> None: - """ - Called when the plugin gets stopped. Should be overwritten to perform cleanup - work in case the plugin set up external resources. - """ pass -class BaseIsolatedPlugin(BasePlugin): +class BaseAsyncStopPlugin(BasePlugin): + """ + A ``BaseAsyncStopPlugin`` unwinds asynchronoulsy, hence needs to be awaited. + """ - _process: Process = None + async def stop(self) -> None: + pass - @property - def process_scope(self) -> PluginProcessScope: - return PluginProcessScope.ISOLATED + +class BaseMainProcessPlugin(BasePlugin): + """ + A ``BaseMainProcessPlugin`` overtakes the whole main process before most of the Trinity boot + process had a chance to start. In that sense it redefines the whole meaning of the ``trinity`` + process. + """ + pass + + +class BaseIsolatedPlugin(BaseSyncStopPlugin): + """ + A ``BaseIsolatedPlugin`` runs in an isolated process and doesn't dictate whether its + implementation is based on non-blocking asyncio or synchronous calls. When an isolated + plugin is stopped it will first receive a SIGINT followed by a SIGTERM soon after. + It is up to the plugin to handle these signals accordingly. + """ + + _process: Process = None def _start(self) -> None: self._process = ctx.Process( @@ -182,7 +178,7 @@ def stop(self) -> None: kill_process_gracefully(self._process, self.logger) -class DebugPlugin(BasePlugin): +class DebugPlugin(BaseAsyncStopPlugin): """ This is a dummy plugin useful for demonstration and debugging purposes """ @@ -195,7 +191,7 @@ def configure_parser(self, arg_parser: ArgumentParser, subparser: _SubParsersAct arg_parser.add_argument("--debug-plugin", type=bool, required=False) def handle_event(self, activation_event: BaseEvent) -> None: - self.logger.info("Debug plugin: handle_event called: ", activation_event) + self.logger.info("Debug plugin: handle_event called: %s", activation_event) def should_start(self) -> bool: self.logger.info("Debug plugin: should_start called") @@ -203,3 +199,14 @@ def should_start(self) -> bool: def start(self) -> None: self.logger.info("Debug plugin: start called") + asyncio.ensure_future(self.count_forever()) + + async def count_forever(self) -> None: + i = 0 + while True: + self.logger.info(i) + i += 1 + await asyncio.sleep(1) + + async def stop(self) -> None: + self.logger.info("Debug plugin: stop called") diff --git a/trinity/extensibility/plugin_manager.py b/trinity/extensibility/plugin_manager.py index ef57d1a345..5b23ffbf48 100644 --- a/trinity/extensibility/plugin_manager.py +++ b/trinity/extensibility/plugin_manager.py @@ -1,14 +1,21 @@ +from abc import ( + ABC, + abstractmethod, +) from argparse import ( ArgumentParser, Namespace, _SubParsersAction, ) +import asyncio import logging from typing import ( Any, + Awaitable, Dict, Iterable, List, + Optional, Union, ) @@ -24,27 +31,96 @@ BaseEvent, PluginStartedEvent, ) +from trinity.extensibility.exceptions import ( + UnsuitableShutdownError, +) from trinity.extensibility.plugin import ( + BaseAsyncStopPlugin, + BaseIsolatedPlugin, + BaseMainProcessPlugin, BasePlugin, + BaseSyncStopPlugin, PluginContext, - PluginProcessScope, ) -class MainAndIsolatedProcessScope(): +class 
diff --git a/trinity/extensibility/plugin_manager.py b/trinity/extensibility/plugin_manager.py
index ef57d1a345..5b23ffbf48 100644
--- a/trinity/extensibility/plugin_manager.py
+++ b/trinity/extensibility/plugin_manager.py
@@ -1,14 +1,21 @@
+from abc import (
+    ABC,
+    abstractmethod,
+)
 from argparse import (
     ArgumentParser,
     Namespace,
     _SubParsersAction,
 )
+import asyncio
 import logging
 from typing import (
     Any,
+    Awaitable,
     Dict,
     Iterable,
     List,
+    Optional,
     Union,
 )
 
@@ -24,27 +31,96 @@
     BaseEvent,
     PluginStartedEvent,
 )
+from trinity.extensibility.exceptions import (
+    UnsuitableShutdownError,
+)
 from trinity.extensibility.plugin import (
+    BaseAsyncStopPlugin,
+    BaseIsolatedPlugin,
+    BaseMainProcessPlugin,
     BasePlugin,
+    BaseSyncStopPlugin,
     PluginContext,
-    PluginProcessScope,
 )
 
 
-class MainAndIsolatedProcessScope():
+class BaseManagerProcessScope(ABC):
+    """
+    Define the operational model under which a ``PluginManager`` runs.
+    """
+
+    endpoint: Endpoint
+
+    @abstractmethod
+    def is_responsible_for_plugin(self, plugin: BasePlugin) -> bool:
+        """
+        Define whether a ``PluginManager`` operating under this scope is responsible
+        for a given plugin or not.
+        """
+        raise NotImplementedError("Must be implemented by subclasses")
+
+    @abstractmethod
+    def create_plugin_context(self,
+                              plugin: BasePlugin,
+                              args: Namespace,
+                              chain_config: ChainConfig,
+                              boot_kwargs: Dict[str, Any]) -> PluginContext:
+        """
+        Create the ``PluginContext`` for a given plugin.
+        """
+        raise NotImplementedError("Must be implemented by subclasses")
+
+
+class MainAndIsolatedProcessScope(BaseManagerProcessScope):
 
     def __init__(self, event_bus: EventBus, main_proc_endpoint: Endpoint) -> None:
         self.event_bus = event_bus
         self.endpoint = main_proc_endpoint
 
+    def is_responsible_for_plugin(self, plugin: BasePlugin) -> bool:
+        return isinstance(plugin, (BaseIsolatedPlugin, BaseMainProcessPlugin))
+
+    def create_plugin_context(self,
+                              plugin: BasePlugin,
+                              args: Namespace,
+                              chain_config: ChainConfig,
+                              boot_kwargs: Dict[str, Any]) -> PluginContext:
+
+        if isinstance(plugin, BaseIsolatedPlugin):
+            # Isolated plugins get an entirely new endpoint to be passed into that new process
+            context = PluginContext(
+                self.event_bus.create_endpoint(plugin.name)
+            )
+            context.args = args
+            context.chain_config = chain_config
+            context.boot_kwargs = boot_kwargs
+            return context
+
+        # A plugin that takes over the main process never gets far enough to even get a context.
+        # For now it should be safe to just return `None`. Maybe reconsider in the future.
+        return None
 
-class SharedProcessScope():
+
+class SharedProcessScope(BaseManagerProcessScope):
 
     def __init__(self, shared_proc_endpoint: Endpoint) -> None:
         self.endpoint = shared_proc_endpoint
 
+    def is_responsible_for_plugin(self, plugin: BasePlugin) -> bool:
+        return isinstance(plugin, BaseAsyncStopPlugin)
 
-ManagerProcessScope = Union[SharedProcessScope, MainAndIsolatedProcessScope]
+    def create_plugin_context(self,
+                              plugin: BasePlugin,
+                              args: Namespace,
+                              chain_config: ChainConfig,
+                              boot_kwargs: Dict[str, Any]) -> PluginContext:
+
+        # Plugins that run in a shared process all share the endpoint of the plugin manager
+        context = PluginContext(self.endpoint)
+        context.args = args
+        context.chain_config = chain_config
+        context.boot_kwargs = boot_kwargs
+        return context
 
 
 class PluginManager:
@@ -57,10 +133,7 @@ class PluginManager:
 
     This API is very much in flux and is expected to change heavily.
""" - MAIN_AND_ISOLATED_SCOPES = {PluginProcessScope.MAIN, PluginProcessScope.ISOLATED} - MAIN_AND_SHARED_SCOPES = {PluginProcessScope.MAIN, PluginProcessScope.SHARED} - - def __init__(self, scope: ManagerProcessScope) -> None: + def __init__(self, scope: BaseManagerProcessScope) -> None: self._scope = scope self._plugin_store: List[BasePlugin] = [] self._started_plugins: List[BasePlugin] = [] @@ -103,7 +176,8 @@ def broadcast(self, event: BaseEvent, exclude: BasePlugin = None) -> None: """ for plugin in self._plugin_store: - if plugin is exclude or not self._is_responsible_for_plugin(plugin): + if plugin is exclude or not self._scope.is_responsible_for_plugin(plugin): + self._logger.debug("Skipping plugin %s (not responsible)", plugin.name) continue plugin.handle_event(event) @@ -116,65 +190,77 @@ def broadcast(self, event: BaseEvent, exclude: BasePlugin = None) -> None: plugin._start() self._started_plugins.append(plugin) - self._logger.info("Plugin started: {}".format(plugin.name)) + self._logger.info("Plugin started: %s", plugin.name) self.broadcast(PluginStartedEvent(plugin), plugin) def prepare(self, args: Namespace, chain_config: ChainConfig, boot_kwargs: Dict[str, Any] = None) -> None: + """ + Create a ``PluginContext`` for every plugin that this plugin manager instance + is responsible for. + """ for plugin in self._plugin_store: - if not self._is_responsible_for_plugin(plugin): + if not self._scope.is_responsible_for_plugin(plugin): continue - context = self._create_context_for_plugin(plugin, args, chain_config, boot_kwargs) + context = self._scope.create_plugin_context(plugin, args, chain_config, boot_kwargs) plugin.set_context(context) - def shutdown(self) -> None: + def shutdown_blocking(self) -> None: + """ + Synchronously shut down all started plugins. 
+ """ + + if isinstance(self._scope, SharedProcessScope): + raise UnsuitableShutdownError("Use `shutdown` for instances of this scope") + + self._logger.info("Shutting down PluginManager with scope %s", type(self._scope)) + for plugin in self._started_plugins: + + if not isinstance(plugin, BaseSyncStopPlugin): + continue + try: + self._logger.info("Stopping plugin: %s", plugin.name) plugin.stop() + self._logger.info("Successfully stopped plugin: %s", plugin.name) except Exception: self._logger.exception("Exception thrown while stopping plugin %s", plugin.name) - def _is_responsible_for_plugin(self, plugin: BasePlugin) -> bool: - - main_or_isolated_plugin = plugin.process_scope in self.MAIN_AND_ISOLATED_SCOPES - shared_plugin = not main_or_isolated_plugin - - manager_for_main_or_isolated = isinstance(self._scope, MainAndIsolatedProcessScope) - manager_for_shared = not manager_for_main_or_isolated - - return ((main_or_isolated_plugin and manager_for_main_or_isolated) or - (shared_plugin and manager_for_shared)) - - def _create_context_for_plugin(self, - plugin: BasePlugin, - args: Namespace, - chain_config: ChainConfig, - boot_kwargs: Dict[str, Any]) -> PluginContext: - - context: PluginContext = None - if plugin.process_scope in self.MAIN_AND_SHARED_SCOPES: - # A plugin that runs in a shared process as well as a plugin that overtakes the main - # process uses the endpoint of the PluginManager which will either be the main - # endpoint or the networking endpoint in the case of Trinity - context = PluginContext(self._scope.endpoint) - elif plugin.process_scope is PluginProcessScope.ISOLATED: - # A plugin that runs in it's own process gets a new endpoint created to get - # passed into that new process - - # mypy doesn't know it can only be that scope at this point. The `isinstance` - # check avoids adding an ignore - if isinstance(self._scope, MainAndIsolatedProcessScope): - endpoint = self._scope.event_bus.create_endpoint(plugin.name) - context = PluginContext(endpoint) - else: - Exception("Invariant: unreachable code path") - - context.args = args - context.chain_config = chain_config - context.boot_kwargs = boot_kwargs + async def shutdown(self) -> None: + """ + Asynchronously shut down all started plugins. 
+ """ - return context + if isinstance(self._scope, MainAndIsolatedProcessScope): + raise UnsuitableShutdownError("Use `shutdown_blocking` for instances of this scope") + + self._logger.info("Shutting down PluginManager with scope %s", type(self._scope)) + + async_plugins = [ + plugin for plugin in self._started_plugins + if isinstance(plugin, BaseAsyncStopPlugin) + ] + + stop_results = await asyncio.gather( + *self._stop_plugins(async_plugins), return_exceptions=True + ) + + for plugin, result in zip(async_plugins, stop_results): + if isinstance(result, Exception): + self._logger.error( + 'Exception thrown while stopping plugin %s: %s', plugin.name, result + ) + else: + self._logger.info("Successfully stopped plugin: %s", plugin.name) + + def _stop_plugins(self, + plugins: Iterable[BaseAsyncStopPlugin] + ) -> Iterable[Awaitable[Optional[Exception]]]: + for plugin in plugins: + self._logger.info("Stopping plugin: %s", plugin.name) + yield plugin.stop() diff --git a/trinity/main.py b/trinity/main.py index 0e11e4185d..b6511b2f7a 100644 --- a/trinity/main.py +++ b/trinity/main.py @@ -2,7 +2,6 @@ import asyncio import logging import signal -import time from typing import ( Any, Dict, @@ -49,9 +48,9 @@ ShutdownRequest ) from trinity.extensibility import ( - PluginManager, + BaseManagerProcessScope, MainAndIsolatedProcessScope, - ManagerProcessScope, + PluginManager, SharedProcessScope, ) from trinity.extensibility.events import ( @@ -77,7 +76,7 @@ setup_cprofiler, ) from trinity.utils.shutdown import ( - exit_on_signal + exit_signal_with_service, ) from trinity.utils.version import ( construct_trinity_client_identifier, @@ -306,16 +305,16 @@ def kill_trinity_gracefully(logger: logging.Logger, # simply uses 'kill' to send a signal to the main process, but also because they will # perform a non-gracefull shutdown if the process takes too long to terminate. logger.info('Keyboard Interrupt: Stopping') - plugin_manager.shutdown() + plugin_manager.shutdown_blocking() main_endpoint.stop() event_bus.stop() - kill_process_gracefully(database_server_process, logger) - logger.info('DB server process (pid=%d) terminated', database_server_process.pid) - # XXX: This short sleep here seems to avoid us hitting a deadlock when attempting to - # join() the networking subprocess: https://github.com/ethereum/py-evm/issues/940 - time.sleep(0.2) - kill_process_gracefully(networking_process, logger) - logger.info('Networking process (pid=%d) terminated', networking_process.pid) + for name, process in [("DB", database_server_process), ("Networking", networking_process)]: + # Our sub-processes will have received a SIGINT already (see comment above), so here we + # wait 2s for them to finish cleanly, and if they fail we kill them for real. + process.join(2) + if process.is_alive(): + kill_process_gracefully(process, logger) + logger.info('%s process (pid=%d) terminated', name, process.pid) # This is required to be within the `kill_trinity_gracefully` so that # plugins can trigger a shutdown of the trinity process. 
diff --git a/trinity/plugins/builtin/attach/plugin.py b/trinity/plugins/builtin/attach/plugin.py
index 9c099ad8c0..bdf506df47 100644
--- a/trinity/plugins/builtin/attach/plugin.py
+++ b/trinity/plugins/builtin/attach/plugin.py
@@ -9,7 +9,7 @@
     ChainConfig,
 )
 from trinity.extensibility import (
-    BasePlugin,
+    BaseMainProcessPlugin,
 )
 
 from trinity.plugins.builtin.attach.console import (
@@ -17,7 +17,7 @@
 )
 
 
-class AttachPlugin(BasePlugin):
+class AttachPlugin(BaseMainProcessPlugin):
 
     def __init__(self, use_ipython: bool = True) -> None:
         super().__init__()
diff --git a/trinity/plugins/builtin/fix_unclean_shutdown/plugin.py b/trinity/plugins/builtin/fix_unclean_shutdown/plugin.py
index c79a6eca63..9c5a03d3d7 100644
--- a/trinity/plugins/builtin/fix_unclean_shutdown/plugin.py
+++ b/trinity/plugins/builtin/fix_unclean_shutdown/plugin.py
@@ -9,14 +9,14 @@
     ChainConfig,
 )
 from trinity.extensibility import (
-    BasePlugin,
+    BaseMainProcessPlugin,
 )
 from trinity.utils.ipc import (
     kill_process_id_gracefully,
 )
 
 
-class FixUncleanShutdownPlugin(BasePlugin):
+class FixUncleanShutdownPlugin(BaseMainProcessPlugin):
 
     @property
     def name(self) -> str:
diff --git a/trinity/plugins/builtin/json_rpc/plugin.py b/trinity/plugins/builtin/json_rpc/plugin.py
index 722dff7f83..0665277cd4 100644
--- a/trinity/plugins/builtin/json_rpc/plugin.py
+++ b/trinity/plugins/builtin/json_rpc/plugin.py
@@ -23,7 +23,7 @@
     create_db_manager
 )
 from trinity.utils.shutdown import (
-    exit_on_signal
+    exit_with_service_and_endpoint,
 )
 
 
@@ -64,7 +64,7 @@ def start(self) -> None:
         ipc_server = IPCServer(rpc, self.context.chain_config.jsonrpc_ipc_path)
 
         loop = asyncio.get_event_loop()
-        asyncio.ensure_future(exit_on_signal(ipc_server, self.context.event_bus))
+        asyncio.ensure_future(exit_with_service_and_endpoint(ipc_server, self.context.event_bus))
         asyncio.ensure_future(ipc_server.run())
         loop.run_forever()
         loop.close()
diff --git a/trinity/plugins/builtin/light_peer_chain_bridge/plugin.py b/trinity/plugins/builtin/light_peer_chain_bridge/plugin.py
index f85ebea818..bdc1f6ba7b 100644
--- a/trinity/plugins/builtin/light_peer_chain_bridge/plugin.py
+++ b/trinity/plugins/builtin/light_peer_chain_bridge/plugin.py
@@ -12,7 +12,7 @@
 )
 from trinity.extensibility import (
     BaseEvent,
-    BasePlugin,
+    BaseAsyncStopPlugin,
 )
 from trinity.chains.light import (
     LightDispatchChain,
@@ -25,7 +25,7 @@
 )
 
 
-class LightPeerChainBridgePlugin(BasePlugin):
+class LightPeerChainBridgePlugin(BaseAsyncStopPlugin):
     """
     The ``LightPeerChainBridgePlugin`` runs in the ``networking`` process and acts as a bridge
     between other processes and the ``LightPeerChain``.
@@ -35,6 +35,7 @@ class LightPeerChainBridgePlugin(BasePlugin):
     """
 
     chain: BaseChain = None
+    handler: LightPeerChainEventBusHandler = None
 
     @property
     def name(self) -> str:
@@ -51,5 +52,13 @@ def handle_event(self, activation_event: BaseEvent) -> None:
     def start(self) -> None:
         self.logger.info('LightPeerChain Bridge started')
         chain = cast(LightDispatchChain, self.chain)
-        handler = LightPeerChainEventBusHandler(chain._peer_chain, self.context.event_bus)
-        asyncio.ensure_future(handler.run())
+        self.handler = LightPeerChainEventBusHandler(chain._peer_chain, self.context.event_bus)
+        asyncio.ensure_future(self.handler.run())
+
+    async def stop(self) -> None:
+        # This isn't really needed for the standard shutdown case as the LightPeerChain will
+        # automatically shut down whenever the `CancelToken` it was chained with is triggered.
+        # It may still be useful to stop the LightPeerChain Bridge plugin individually though.
+        if self.handler.is_operational:
+            await self.handler.cancel()
+        self.logger.info("Successfully stopped LightPeerChain Bridge")
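
The bridge plugin's stop pattern — keep a handle to the background work in
`start()` so that `stop()` can unwind it later — is the same one the TxPool
plugin below uses. Reduced to plain asyncio (illustrative only, no trinity
types):

    import asyncio


    class StoppableWorker:

        _task: asyncio.Future = None

        def start(self) -> None:
            # Keep the handle; without it there is nothing to cancel later
            self._task = asyncio.ensure_future(self._run())

        async def _run(self) -> None:
            while True:
                await asyncio.sleep(1)

        async def stop(self) -> None:
            if self._task is not None and not self._task.done():
                self._task.cancel()
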
diff --git a/trinity/plugins/builtin/tx_pool/plugin.py b/trinity/plugins/builtin/tx_pool/plugin.py
index ca366a7226..d2f563a6a5 100644
--- a/trinity/plugins/builtin/tx_pool/plugin.py
+++ b/trinity/plugins/builtin/tx_pool/plugin.py
@@ -23,7 +23,7 @@
 )
 from trinity.extensibility import (
     BaseEvent,
-    BasePlugin,
+    BaseAsyncStopPlugin,
 )
 from trinity.extensibility.events import (
     ResourceAvailableEvent,
@@ -38,11 +38,12 @@
 from trinity.protocol.eth.peer import ETHPeerPool
 
 
-class TxPlugin(BasePlugin):
+class TxPlugin(BaseAsyncStopPlugin):
     peer_pool: ETHPeerPool = None
     cancel_token: CancelToken = None
     chain: BaseChain = None
     is_enabled: bool = False
+    tx_pool: TxPool = None
 
     @property
     def name(self) -> str:
@@ -81,5 +82,13 @@ def start(self) -> None:
             # tx pool without tx validation in this case
             raise ValueError("The TxPool plugin only supports MainnetChain or RopstenChain")
 
-        tx_pool = TxPool(self.peer_pool, validator, self.cancel_token)
-        asyncio.ensure_future(tx_pool.run())
+        self.tx_pool = TxPool(self.peer_pool, validator, self.cancel_token)
+        asyncio.ensure_future(self.tx_pool.run())
+
+    async def stop(self) -> None:
+        # This isn't really needed for the standard shutdown case as the TxPool will
+        # automatically shut down whenever the `CancelToken` it was chained with is triggered.
+        # It may still be useful to stop the TxPool plugin individually though.
+        if self.tx_pool.is_operational:
+            await self.tx_pool.cancel()
+        self.logger.info("Successfully stopped TxPool")
diff --git a/trinity/utils/shutdown.py b/trinity/utils/shutdown.py
index 9911844d1d..3eb91d610a 100644
--- a/trinity/utils/shutdown.py
+++ b/trinity/utils/shutdown.py
@@ -1,5 +1,11 @@
 import asyncio
+from async_generator import (
+    asynccontextmanager,
+)
 import signal
+from typing import (
+    AsyncGenerator,
+)
 
 from lahja import (
     Endpoint,
@@ -10,8 +16,27 @@
 )
 
 
-async def exit_on_signal(service_to_exit: BaseService, endpoint: Endpoint = None) -> None:
+async def exit_with_service_and_endpoint(service_to_exit: BaseService, endpoint: Endpoint) -> None:
+    async with exit_signal_with_service(service_to_exit):
+        endpoint.stop()
+
+
+async def exit_with_service(service_to_exit: BaseService) -> None:
+    async with exit_signal_with_service(service_to_exit):
+        pass
+
+
+@asynccontextmanager
+async def exit_signal_with_service(service_to_exit: BaseService) -> AsyncGenerator[None, None]:
     loop = service_to_exit.get_event_loop()
+    async with exit_signal(loop):
+        await service_to_exit.cancel()
+        yield
+        service_to_exit._executor.shutdown(wait=True)
+
+
+@asynccontextmanager
+async def exit_signal(loop: asyncio.AbstractEventLoop) -> AsyncGenerator[None, None]:
     sigint_received = asyncio.Event()
     for sig in [signal.SIGINT, signal.SIGTERM]:
         # TODO also support Windows
         loop.add_signal_handler(sig, sigint_received.set)
 
     await sigint_received.wait()
     try:
-        await service_to_exit.cancel()
-        if endpoint is not None:
-            endpoint.stop()
-        service_to_exit._executor.shutdown(wait=True)
+        yield
     finally:
         loop.stop()
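
The layering in `shutdown.py` works because each level is an async context
manager: `exit_signal` owns the signal handling and the loop teardown, while
the wrappers add service and endpoint cleanup around the `yield`. The same
technique in a standalone sketch (not trinity code; POSIX only, and it uses
the `async-generator` dependency this patch adds to setup.py):

    import asyncio
    import signal

    from async_generator import asynccontextmanager


    @asynccontextmanager
    async def exit_on_sigint(loop):
        received = asyncio.Event()
        loop.add_signal_handler(signal.SIGINT, received.set)
        await received.wait()
        try:
            yield  # callers run their cleanup here
        finally:
            loop.stop()


    async def main(loop):
        async with exit_on_sigint(loop):
            print('cleaning up...')  # e.g. await service.cancel()


    loop = asyncio.get_event_loop()
    asyncio.ensure_future(main(loop))
    loop.run_forever()
    loop.close()
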