Source code for aiothreads.threads

import asyncio
import contextvars
import inspect
import logging
import os
from abc import ABC, abstractmethod
from concurrent.futures import ThreadPoolExecutor as ThreadPoolExecutorBase
from functools import partial
from types import MappingProxyType
from typing import Any, Awaitable, Callable, Concatenate, Generator, Generic, MutableMapping, Optional, overload
from weakref import WeakKeyDictionary

from .iterator_wrapper import IteratorWrapper
from .new_thread import run_in_new_thread
from .threaded_iterable import threaded_iterable
from .types import BP, EVENT_LOOP, P, S, T


log = logging.getLogger(__name__)

THREADED_ITERABLE_DEFAULT_MAX_SIZE = int(os.getenv("THREADED_ITERABLE_DEFAULT_MAX_SIZE", 1024))


class ThreadPoolException(RuntimeError):
    pass


def run_in_executor(
    func: Callable[..., T],
    executor: Optional[ThreadPoolExecutorBase] = None,
    args: Any = (),
    kwargs: Any = MappingProxyType({}),
) -> Awaitable[T]:
    try:
        loop = asyncio.get_running_loop()
        context = contextvars.copy_context()

        def in_thread() -> T:
            nonlocal func, args, kwargs, loop, context
            EVENT_LOOP.set(loop)
            return context.run(func, *args, **kwargs)

        return loop.run_in_executor(executor, in_thread)
    except RuntimeError:
        # In case the event loop is not running right now is
        # returning coroutine to avoid DeprecationWarning in Python 3.10
        async def lazy_wrapper() -> T:
            loop = asyncio.get_running_loop()
            return await loop.run_in_executor(executor, partial(func, *args, **kwargs))

        return lazy_wrapper()


async def _awaiter(future: asyncio.Future) -> T:
    try:
        result = await future
        return result
    except asyncio.CancelledError as e:
        if not future.done():
            future.set_exception(e)
        raise


class ThreadedBase(Generic[P, T], ABC):
    func: Callable[P, T]

    @abstractmethod
    def __init__(self, *args: Any, **kwargs: Any) -> None:
        ...

    def sync_call(self, *args: P.args, **kwargs: P.kwargs) -> T:
        return self.func(*args, **kwargs)

    def async_call(self, *args: P.args, **kwargs: P.kwargs) -> Awaitable[T]:
        EVENT_LOOP.set(asyncio.get_running_loop())
        return run_in_executor(func=self.func, args=args, kwargs=kwargs)

    def __repr__(self) -> str:
        f = getattr(self.func, "func", self.func)
        name = getattr(f, "__name__", f.__class__.__name__)
        return f"<{self.__class__.__name__} {name} at {id(self):#x}>"

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> Awaitable[T]:
        return self.async_call(*args, **kwargs)


class Threaded(ThreadedBase[P, T]):
    func_type: type

    def __init__(self, func: Callable[P, T]) -> None:
        self.__cache: MutableMapping[Any, Any] = WeakKeyDictionary()

        if isinstance(func, staticmethod):
            self.func_type = staticmethod
            self.func = func.__func__
        elif isinstance(func, classmethod):
            self.func_type = classmethod
            self.func = func.__func__
        else:
            self.func_type = type(func)
            self.func = func

        if asyncio.iscoroutinefunction(self.func):
            raise TypeError("Can not wrap coroutine")
        if inspect.isgeneratorfunction(self.func):
            raise TypeError("Can not wrap generator function")

    @overload
    def __get__(
        self: "Threaded[Concatenate[S, BP], T]",
        instance: S,
        owner: Optional[type] = ...,
    ) -> "BoundThreaded[BP, T]":
        ...

    @overload
    def __get__(
        self: "Threaded[P, T]",
        instance: None,
        owner: Optional[type] = ...,
    ) -> "Threaded[P, T]":
        ...

    def __get__(
        self,
        instance: Any,
        owner: Optional[type] = None,
    ) -> "Threaded[P, T] | BoundThreaded[Any, T]":
        key = instance
        result: Any
        if key in self.__cache:
            return self.__cache[key]
        if self.func_type is staticmethod:
            result = self
        elif self.func_type is classmethod:
            cls = owner if instance is None else type(instance)
            result = BoundThreaded(self.func, cls)
        elif instance is not None:
            result = BoundThreaded(self.func, instance)
        else:
            result = self

        self.__cache[key] = result
        return result


class BoundThreaded(ThreadedBase[P, T]):
    __instance: Any

    def __init__(self, func: Callable[..., T], instance: Any) -> None:
        self.__instance = instance
        self.func = lambda *args, **kwargs: func(instance, *args, **kwargs)


@overload
def threaded(func: Callable[P, T]) -> Threaded[P, T]:
    ...


@overload
def threaded(
    func: Callable[P, Generator[T, None, None]],
) -> Callable[P, IteratorWrapper[P, T]]:
    ...


[docs] def threaded( func: Callable[P, T] | Callable[P, Generator[T, None, None]], ) -> Threaded[P, T] | Callable[P, IteratorWrapper[P, T]]: if inspect.isgeneratorfunction(func): return threaded_iterable( func, max_size=THREADED_ITERABLE_DEFAULT_MAX_SIZE, ) return Threaded(func) # type: ignore
class ThreadedSeparate(Threaded[P, T]): __slots__ = Threaded.__slots__ + ("detach",) def __init__(self, func: Callable[P, T], detach: bool = True) -> None: super().__init__(func) self.detach = detach def async_call(self, *args: P.args, **kwargs: P.kwargs) -> Awaitable[T]: return run_in_new_thread(self.func, args=args, kwargs=kwargs, detach=self.detach)
[docs] def threaded_separate( func: Callable[P, T], detach: bool = True, ) -> ThreadedSeparate[P, T]: if isinstance(func, bool): # noinspection PyTypeChecker return partial(threaded_separate, detach=detach) if asyncio.iscoroutinefunction(func): raise TypeError("Can not wrap coroutine") return ThreadedSeparate(func, detach=detach)