Source code for aiomisc.entrypoint

import asyncio
import logging
import os
import signal
import sys
import threading
from concurrent.futures import Executor
from typing import (
    Any, Callable, Coroutine, FrozenSet, MutableSet, Optional, Set, Tuple,
    TypeVar, Union,
)
from weakref import WeakSet

import aiomisc_log
from aiomisc_log import LogLevel

from ._context_vars import EVENT_LOOP, StrictContextVar
from .compat import event_loop_policy, final
from .context import Context, get_context
from .log import LogFormat, basic_config
from .service import Service
from .signal import Signal
from .utils import cancel_tasks, create_default_event_loop


# Alias for the executor base class, used in this module's annotations.
ExecutorType = Executor
# Generic type variable used by ``_get_env_convert`` and ``run``.
T = TypeVar("T")
# Module-level logger named after this module.
log = logging.getLogger(__name__)


# Module-level aliases for asyncio's task-introspection helpers.
# NOTE(review): presumably kept as an indirection left over from supporting
# older Python versions - confirm before removing.
asyncio_all_tasks = asyncio.all_tasks
asyncio_current_task = asyncio.current_task


def is_main_thread() -> bool:
    """Tell whether the calling thread is the interpreter's main thread."""
    caller = threading.current_thread()
    main = threading.main_thread()
    return caller == main


def _get_env_bool(name: str, default: str) -> bool:
    enable_variants = {"1", "enable", "enabled", "on", "true", "yes"}
    return os.getenv(name, default).lower() in enable_variants


def _get_env_convert(name: str, converter: Callable[..., T], default: T) -> T:
    value = os.getenv(name)
    if value is None:
        return default
    return converter(value)


class OSSignalHandler:
    """Install a callback for one OS signal and later undo the installation.

    The handler that is active for the signal when the object is created is
    remembered so that :meth:`restore` can put it back.
    """

    def __init__(
        self, sig: int, handler: Callable[[int], None],
    ):
        """
        :param sig: number of the signal to manage
        :param handler: callable invoked with the signal number when the
            signal is delivered
        """
        # ``signal.getsignal`` returns ``None`` when the current handler
        # was not installed from Python.
        self.default_handler = signal.getsignal(sig)
        self.signal = sig
        self.handler = handler

    def __callback(self, *_: Any) -> None:
        # The interpreter passes ``(signum, frame)``; both are ignored and
        # the stored signal number is forwarded instead.
        self.handler(self.signal)

    def apply(self) -> None:
        """Install the callback as the handler for the managed signal."""
        signal.signal(self.signal, self.__callback)

    def restore(self) -> None:
        """Reinstall the handler that was active at construction time."""
        previous = self.default_handler
        if previous is None:
            # A handler that was not installed from Python cannot be passed
            # back to ``signal.signal`` (it raises ``TypeError`` for
            # ``None``), so fall back to the default action instead of
            # failing in the middle of a shutdown.
            previous = signal.SIG_DFL
        signal.signal(self.signal, previous)


@final
class Entrypoint:
    """Context manager that starts services on enter and stops them on exit.

    It can be used both as a synchronous (``with``) and an asynchronous
    (``async with``) context manager. When no event loop is passed, one is
    created lazily on first access to :attr:`loop` and closed by this
    object. The class-level ``DEFAULT_*`` values below are read from the
    environment once, when the class body is executed.
    """

    DEFAULT_LOG_LEVEL: str = os.getenv(
        "AIOMISC_LOG_LEVEL", LogLevel.default(),
    )
    DEFAULT_LOG_FORMAT: str = os.getenv(
        "AIOMISC_LOG_FORMAT", LogFormat.default(),
    )
    DEFAULT_LOG_DATE_FORMAT: Optional[str] = os.getenv(
        "AIOMISC_LOG_DATE_FORMAT",
    )
    DEFAULT_AIOMISC_DEBUG: bool = _get_env_bool("AIOMISC_DEBUG", "0")
    DEFAULT_AIOMISC_LOG_CONFIG: bool = _get_env_bool(
        "AIOMISC_LOG_CONFIG", "1",
    )
    DEFAULT_AIOMISC_LOG_FLUSH: float = _get_env_convert(
        "AIOMISC_LOG_FLUSH", float, 0.2,
    )
    DEFAULT_AIOMISC_BUFFERING: bool = _get_env_bool(
        "AIOMISC_LOG_BUFFERING", "1",
    )
    DEFAULT_AIOMISC_BUFFER_SIZE: int = _get_env_convert(
        "AIOMISC_LOG_BUFFER", int, 1024,
    )
    DEFAULT_AIOMISC_POOL_SIZE: Optional[int] = _get_env_convert(
        "AIOMISC_POOL_SIZE", int, None,
    )

    AIOMISC_SHUTDOWN_TIMEOUT: float = _get_env_convert(
        "AIOMISC_SHUTDOWN_TIMEOUT", float, 60.,
    )

    # Class-level signals; every instance works on its own copies
    # (see ``__init__``).
    PRE_START = Signal()
    POST_STOP = Signal()
    POST_START = Signal()
    PRE_STOP = Signal()

    @classmethod
    def get_current(cls) -> "Entrypoint":
        """Return the value stored in ``CURRENT_ENTRYPOINT``."""
        return CURRENT_ENTRYPOINT.get()

    async def _start(self) -> None:
        """Configure logging, start the passed services, install handlers."""
        if self.log_config:
            basic_config(
                level=self.log_level,
                log_format=self.log_format,
                date_format=self.log_date_format,
                buffered=self.log_buffering,
                loop=self.loop,
                buffer_size=self.log_buffer_size,
                flush_interval=self.log_flush_interval,
            )

        CURRENT_ENTRYPOINT.set(self)
        EVENT_LOOP.set(self.loop)

        signals = (
            self.pre_start, self.post_stop,
            self.pre_stop, self.post_start,
        )
        for sig in signals:
            sig.freeze()

        await self.start_services(*self.__passed_services)
        # The initial service set is consumed exactly once.
        del self.__passed_services

        # OS signal handlers are installed only after the services started.
        for handler in self._signal_handlers:
            handler.apply()

    def __init__(
        self, *services: Service,
        loop: Optional[asyncio.AbstractEventLoop] = None,
        pool_size: Optional[int] = DEFAULT_AIOMISC_POOL_SIZE,
        log_level: Union[int, str] = DEFAULT_LOG_LEVEL,
        log_format: Union[str, LogFormat] = DEFAULT_LOG_FORMAT,
        log_buffering: bool = DEFAULT_AIOMISC_BUFFERING,
        log_buffer_size: int = DEFAULT_AIOMISC_BUFFER_SIZE,
        log_date_format: Optional[str] = DEFAULT_LOG_DATE_FORMAT,
        log_flush_interval: float = DEFAULT_AIOMISC_LOG_FLUSH,
        log_config: bool = DEFAULT_AIOMISC_LOG_CONFIG,
        policy: asyncio.AbstractEventLoopPolicy = event_loop_policy,
        debug: bool = DEFAULT_AIOMISC_DEBUG,
        catch_signals: Optional[Tuple[int, ...]] = None,
        shutdown_timeout: Union[int, float] = AIOMISC_SHUTDOWN_TIMEOUT,
    ):
        """
        Creates a new Entrypoint

        :param services: Service instances which will be started on enter
        :param loop: event loop to use; when omitted a loop is created on
                     first access to :attr:`loop` and closed by this object
        :param pool_size: thread pool size used when the loop is created
        :param log_level: Logging level which will be configured
        :param log_format: Logging format which will be configured
        :param log_buffering: passed as ``buffered`` when logging is
                              configured on start
        :param log_buffer_size: Buffer size for logging
        :param log_date_format: date format which will be configured
        :param log_flush_interval: interval in seconds for flushing logs
        :param log_config: if False do not configure logging
        :param policy: event loop policy used when the loop is created
        :param debug: set debug to event-loop
        :param catch_signals: Perform shutdown when these signals are
                              received; defaults to SIGINT and SIGTERM
                              when created in the main thread
        :param shutdown_timeout: Timeout in seconds for graceful shutdown
        """

        self.__passed_services: FrozenSet[Service] = frozenset(services)
        self._services: Set[Service] = set()
        self._debug = debug
        self._loop = loop
        self._loop_owner = False
        self._tasks: MutableSet[asyncio.Task] = WeakSet()
        self._thread_pool: Optional[ExecutorType] = None
        self._closing: Optional[asyncio.Event] = None

        if catch_signals is None and is_main_thread():
            # Apply the default signals only when the entrypoint is
            # created in the main thread.
            catch_signals = (signal.SIGINT, signal.SIGTERM)

        self._signal_handlers = [
            OSSignalHandler(sig, self._on_interrupt_callback)
            for sig in catch_signals or ()
        ]

        self.catch_signals = catch_signals
        self.shutdown_timeout = float(shutdown_timeout)
        self.ctx: Optional[Context] = None
        self.log_buffer_size = log_buffer_size
        self.log_buffering = log_buffering
        self.log_config = log_config
        self.log_date_format = log_date_format
        self.log_flush_interval = log_flush_interval
        self.log_format = log_format
        self.log_level = log_level
        self.policy = policy
        self.pool_size = pool_size

        # Per-instance copies of the class-level signals.
        self.pre_start = self.PRE_START.copy()
        self.post_start = self.POST_START.copy()
        self.pre_stop = self.PRE_STOP.copy()
        self.post_stop = self.POST_STOP.copy()

        if self.log_config:
            # Early logging setup; ``_start`` configures it again with the
            # buffering options once the loop is available.
            aiomisc_log.basic_config(
                level=self.log_level,
                log_format=self.log_format,
                date_format=log_date_format,
            )

        if self._loop is not None:
            EVENT_LOOP.set(self._loop)

        CURRENT_ENTRYPOINT.set(self)

    @property
    def services(self) -> Tuple[Service, ...]:
        """Snapshot of the services currently tracked by the entrypoint."""
        return tuple(self._services)

    async def closing(self) -> None:
        """Wait until the closing event is set by ``graceful_shutdown``."""
        # Lazy initialization because the event loop might not exist yet
        if self._closing is None:
            self._closing = asyncio.Event()

        await self._closing.wait()

    @property
    def loop(self) -> asyncio.AbstractEventLoop:
        """Event loop of the entrypoint, created on first access if needed."""
        if self._loop is None:
            self._loop, self._thread_pool = create_default_event_loop(
                pool_size=self.pool_size,
                policy=self.policy,
                debug=self._debug,
            )
            # Only a loop created here is closed by this object.
            self._loop_owner = True
            EVENT_LOOP.set(self._loop)
        return self._loop

    def __del__(self) -> None:
        if self._loop and self._loop.is_closed():
            return

        if self._loop_owner and self._loop is not None:
            self._loop.close()

    def __enter__(self) -> asyncio.AbstractEventLoop:
        self.loop.run_until_complete(self.__aenter__())
        return self.loop

    def __exit__(
        self, exc_type: Any, exc_val: Any, exc_tb: Any,
    ) -> None:
        loop = self.loop
        if loop.is_closed():
            return

        if self.log_config:
            # Switch to unbuffered logging before stopping.
            basic_config(
                level=self.log_level,
                log_format=self.log_format,
                date_format=self.log_date_format,
                buffered=False,
            )

        loop.run_until_complete(self.__aexit__(exc_type, exc_val, exc_tb))

        if self._loop_owner and self._loop is not None:
            loop.close()

    async def __aenter__(self) -> "Entrypoint":
        if self._loop is None:
            # When __aenter__ called without __enter__
            self._loop = asyncio.get_running_loop()

        self.ctx = Context(loop=self.loop)
        await self._start()
        return self

    async def __aexit__(
        self, exc_type: Any, exc_val: Any, exc_tb: Any,
    ) -> None:
        await self._stop(exc_val)

    if sys.version_info < (3, 9):
        async def __shutdown_thread_pool(
            self, loop: asyncio.AbstractEventLoop,
        ) -> None:
            # ``shutdown()`` may or may not return an awaitable here.
            result = self._thread_pool.shutdown()
            if hasattr(result, "__await__"):
                await result
    else:
        def __shutdown_thread_pool(
            self, loop: asyncio.AbstractEventLoop,
        ) -> Coroutine[Any, Any, None]:
            # Returns a coroutine, so callers can await either variant.
            return loop.shutdown_default_executor()

    async def _stop(self, exc: Optional[Exception]) -> None:
        """Restore signal handlers and shut everything down.

        :param exc: exception forwarded to the services being stopped;
            ``None`` when the context manager exits without an error
        """
        loop = self.loop

        for handler in self._signal_handlers:
            handler.restore()

        try:
            if loop.is_closed():
                return

            await self.graceful_shutdown(exc)
        finally:
            if self.ctx:
                self.ctx.close()

            if self._thread_pool:
                await self.__shutdown_thread_pool(loop)

    async def _start_service(
        self, svc: Service,
    ) -> None:
        """Start one service and wait until it reports being started."""
        svc.set_loop(self.loop)

        start_task, ev_task = map(
            asyncio.ensure_future, (
                svc.start(), svc.start_event.wait(),
            ),
        )

        self._services.add(svc)

        await asyncio.wait(
            (start_task, ev_task),
            return_when=asyncio.FIRST_COMPLETED,
        )

        # Ensure the start event gets set even when ``start()`` finished
        # without setting it, so the wait below cannot hang.
        self.loop.call_soon(svc.start_event.set)
        await ev_task

        if start_task.done():
            # raise an Exception when failed
            await start_task
            return
        else:
            # ``start()`` is still running; track it so that it can be
            # cancelled on shutdown.
            self._tasks.add(start_task)
        return None

    async def start_services(self, *svc: Service) -> None:
        """Start the given services concurrently.

        ``pre_start`` is called before, and ``post_start`` after the
        attempt, even when starting fails.
        """
        await self.pre_start.call(entrypoint=self, services=svc)

        try:
            await asyncio.gather(*[self._start_service(s) for s in svc])
        finally:
            await self.post_start.call(entrypoint=self, services=svc)

    async def stop_services(
        self, *svc: Service, exc: Optional[Exception] = None,
    ) -> None:
        """Stop the given services and stop tracking them.

        :param svc: services to stop
        :param exc: exception passed to every ``Service.stop`` call
        """
        await self.pre_stop.call(entrypoint=self, services=svc)

        if not svc:
            await self.post_stop.call(entrypoint=self, services=svc)
            return

        tasks = []
        try:
            for s in svc:
                try:
                    log.debug("Stopping service %r", s)
                    coro = s.stop(exc)
                    # Only awaitable results need to be scheduled.
                    if hasattr(coro, "__await__"):
                        tasks.append(self.loop.create_task(coro))
                except TypeError as e:
                    # Report the service that failed, not the whole tuple.
                    log.warning("Failed to stop service %r:\n%r", s, e)
                    log.debug("Service stop failed traceback", exc_info=True)
            await asyncio.gather(*tasks)
        finally:
            await cancel_tasks(tasks)

            for s in svc:
                self._services.discard(s)

            await self.post_stop.call(entrypoint=self, services=svc)

    async def _cancel_background_tasks(self) -> None:
        """Cancel every task of the loop except the one running this call."""
        loop = self.loop
        current_task = asyncio_current_task(loop)
        await cancel_tasks(
            filter(
                lambda x: x is not current_task,
                asyncio_all_tasks(loop),
            ),
        )

    async def graceful_shutdown(self, exception: Optional[Exception]) -> None:
        """Stop services and background work in order.

        :param exception: exception forwarded to the services being stopped
        """
        if self._closing:
            self._closing.set()

        await cancel_tasks(set(self._tasks))
        await self.stop_services(*self._services, exc=exception)

        # Foreign tasks are only cancelled on a loop created by this object.
        if self._loop_owner:
            await self._cancel_background_tasks()

        await self.loop.shutdown_asyncgens()

    def _on_interrupt_callback(self, _: Any) -> None:
        """OS signal callback: schedule the shutdown on the event loop."""
        loop = self.loop
        loop.call_soon_threadsafe(self._on_interrupt, loop)

    def _on_interrupt(self, loop: asyncio.AbstractEventLoop) -> None:
        """Start the shutdown task and bound it by ``shutdown_timeout``."""
        log.warning("Interrupt signal received, shutting down...")
        task = loop.create_task(
            self._stop(RuntimeError("Interrupt signal received")),
        )

        # Cancel the shutdown when it takes longer than the timeout.
        handle = loop.call_later(self.shutdown_timeout, task.cancel)

        def on_shutdown_finish(fut: asyncio.Future) -> None:
            if fut.cancelled():
                log.warning(
                    "Shutdown did not happen in %s seconds, aborting.",
                    self.shutdown_timeout,
                )

            handle.cancel()
            loop.stop()

            if fut.cancelled():
                # 70 from sysexits.h means "internal software error"
                raise SystemExit(70)

        task.add_done_callback(on_shutdown_finish)
# Context variable holding the entrypoint set by ``Entrypoint.__init__`` and
# ``Entrypoint._start``; read back by ``Entrypoint.get_current``.
# NOTE(review): presumably ``StrictContextVar.get()`` raises the given
# exception when no value is set - confirm against ``._context_vars``.
CURRENT_ENTRYPOINT: StrictContextVar[Entrypoint] = StrictContextVar(
    "CURRENT_ENTRYPOINT",
    RuntimeError("no current entrypoint is set"),
)

# Lower-case alias of the class, used by ``run`` and exported in ``__all__``.
entrypoint = Entrypoint
def run(
    coro: Coroutine[None, Any, T],
    *services: Service,
    **kwargs: Any,
) -> T:
    """Run *coro* to completion inside a freshly created entrypoint.

    :param coro: coroutine to execute on the entrypoint's event loop
    :param services: services passed to the entrypoint
    :param kwargs: keyword arguments passed to the entrypoint
    :return: the result of *coro*
    """
    manager = entrypoint(*services, **kwargs)
    with manager as loop:
        outcome = loop.run_until_complete(coro)
    return outcome
# Names exported by a star-import of this module. ``get_context`` is not
# defined here; it is imported from ``.context`` and re-exported.
__all__ = (
    "CURRENT_ENTRYPOINT",
    "Entrypoint",
    "entrypoint",
    "get_context",
    "run",
)