Source code for aiomisc.aggregate

import asyncio
import functools
import inspect
import logging
from asyncio import CancelledError, Event, Future, Lock, wait_for
from collections.abc import Callable, Coroutine, Iterable
from dataclasses import dataclass
from inspect import Parameter
from typing import Any, Generic, Protocol, TypeVar

from .compat import EventLoopMixin
from .counters import Statistic

# Module logger; used below to report aggregation timeouts at DEBUG level.
log = logging.getLogger(__name__)


# Type of a single value passed to one aggregated call.
V = TypeVar("V")
# Type of the result delivered back to one aggregated call.
R = TypeVar("R")


@dataclass(frozen=True)
class Arg(Generic[V, R]):
    """
    One collected call as seen by an aggregate-async function.

    Bundles the value passed to a single ``aggregate(...)`` call with the
    future that the caller awaits; the aggregated function reports the
    outcome for this call by resolving ``future``.
    """

    # Argument of one individual aggregated call.
    value: V
    # Future the caller awaits; receives this call's result or exception.
    future: "Future[R]"
class ResultNotSetError(Exception):
    """
    Set on an aggregated call's future when the aggregated function
    returned without resolving that future itself.
    """
class AggregateAsyncFunc(Protocol, Generic[V, R]):
    """
    Structural type of a function usable with ``AggregatorAsync``.

    It receives every collected call as a positional ``Arg`` and reports
    outcomes by resolving each ``Arg.future`` instead of returning values.
    """

    __name__: str

    async def __call__(self, *args: Arg[V, R]) -> None:
        ...
class AggregateStatistic(Statistic):
    """
    Counters maintained by ``AggregatorAsync`` instances.
    """

    # Configured aggregation window in milliseconds.
    leeway_ms: float
    # Configured batch size limit; 0 when no limit was given.
    max_count: int
    # Aggregated function runs that completed without raising.
    success: int
    # Aggregated function runs that raised an ``Exception``.
    error: int
    # Aggregated function runs that finished in any way (incl. cancellation).
    done: int
def _has_variadic_positional(func: Callable[..., Any]) -> bool:
    """
    Report whether ``func`` declares a ``*args``-style parameter.

    :param func: Any callable that ``inspect.signature`` can introspect.
    :return: ``True`` if one of its parameters is ``VAR_POSITIONAL``.
    """
    declared_kinds = {
        param.kind for param in inspect.signature(func).parameters.values()
    }
    return Parameter.VAR_POSITIONAL in declared_kinds
class AggregatorAsync(EventLoopMixin, Generic[V, R]):
    """
    Batches single-argument calls into one call of ``func``.

    Calls to :meth:`aggregate` are collected until either ``max_count`` calls
    have accumulated or ``leeway_ms`` has elapsed since the first call of the
    batch. ``func`` is then awaited once with one ``Arg`` per collected call
    and must resolve every ``Arg.future``.
    """

    _func: AggregateAsyncFunc[V, R]
    _max_count: int | None
    _leeway: float
    _first_call_at: float | None
    _args: list
    _futures: "list[Future[R]]"
    _event: Event
    _lock: Lock

    def __init__(
        self,
        func: AggregateAsyncFunc[V, R],
        *,
        leeway_ms: float,
        max_count: int | None = None,
        statistic_name: str | None = None,
    ):
        """
        :param func: Coroutine function accepting variadic positional
            ``Arg`` objects.
        :param leeway_ms: Maximum delay (ms) between the first collected call
            of a batch and its execution. Must be positive.
        :param max_count: Execute as soon as this many calls are collected;
            ``None`` means no limit. Must be positive when given.
        :param statistic_name: Name passed to ``AggregateStatistic``.
        :raises ValueError: If ``func`` has no ``*args`` parameter, or
            ``max_count`` / ``leeway_ms`` is not positive.
        """
        if not _has_variadic_positional(func):
            raise ValueError(
                "Function must accept variadic positional arguments"
            )
        if max_count is not None and max_count <= 0:
            raise ValueError("max_count must be positive int or None")
        if leeway_ms <= 0:
            raise ValueError("leeway_ms must be positive float")

        self._func = func
        self._max_count = max_count
        self._leeway = leeway_ms / 1000
        self._clear()

        self._statistic = AggregateStatistic(statistic_name)
        self._statistic.leeway_ms = self.leeway_ms
        self._statistic.max_count = max_count or 0

    def _clear(self) -> None:
        """Start a fresh, empty batch with new synchronization primitives."""
        # Rebinding (not mutating) lets callers that already captured the
        # previous batch's list/event/lock keep working on that batch.
        self._first_call_at = None
        self._args = []
        self._futures = []
        self._event = Event()
        self._lock = Lock()

    @property
    def max_count(self) -> int | None:
        """Batch size limit, or ``None`` when unlimited."""
        return self._max_count

    @property
    def leeway_ms(self) -> float:
        """Aggregation window in milliseconds."""
        return self._leeway * 1000

    @property
    def count(self) -> int:
        """Number of calls collected in the current batch."""
        return len(self._args)

    async def _execute(
        self, *, args: list[V], futures: "list[Future[R]]"
    ) -> None:
        """
        Await ``self._func`` once for a collected batch.

        On ``Exception`` the error is set on every still-pending future.
        ``CancelledError`` is re-raised with the futures left pending so
        another task waiting on the batch lock can retry. After a successful
        run, any future ``func`` left pending gets ``ResultNotSetError``.
        """
        args_ = [
            Arg(value=arg, future=future)
            for arg, future in zip(args, futures)
        ]
        try:
            await self._func(*args_)
            self._statistic.success += 1
        except CancelledError:
            # Other waiting tasks can try to finish the job instead.
            raise
        except Exception as e:
            self._set_exception(e, futures)
            self._statistic.error += 1
            return
        finally:
            self._statistic.done += 1

        # Validate that all results/exceptions are set by the func
        for future in futures:
            if not future.done():
                future.set_exception(ResultNotSetError)

    def _set_exception(
        self, exc: Exception, futures: list["Future[R]"]
    ) -> None:
        """Set ``exc`` on every future that is not already done."""
        for future in futures:
            if not future.done():
                future.set_exception(exc)

    async def aggregate(self, arg: V) -> R:
        """
        Add ``arg`` to the current batch and return its individual result.

        The call that fills the batch up to ``max_count`` triggers execution
        immediately; otherwise the call waits until the batch's
        ``leeway_ms`` window has elapsed. Whichever waiting task acquires the
        batch lock first runs the aggregated function for everyone.

        :raises Exception: Whatever exception was set on this call's future.
        """
        if self._first_call_at is None:
            self._first_call_at = self.loop.time()
        first_call_at = self._first_call_at

        # Capture the current batch; ``_clear`` may rebind the attributes
        # while this call is waiting.
        args: list = self._args
        futures: list[Future[R]] = self._futures
        event: Event = self._event
        lock: Lock = self._lock
        args.append(arg)
        future: Future[R] = Future()
        futures.append(future)

        if self.count == self.max_count:
            event.set()
            self._clear()
        else:
            # Waiting for max_count requests or a timeout
            try:
                await wait_for(
                    event.wait(),
                    timeout=first_call_at + self._leeway - self.loop.time(),
                )
            # On Python 3.10 ``wait_for`` raises ``asyncio.TimeoutError``,
            # which is not the builtin ``TimeoutError``; on 3.11+ both names
            # refer to the same class, so this catches the timeout everywhere.
            except asyncio.TimeoutError:
                log.debug(
                    "Aggregation timeout of %s for batch started at %.4f "
                    "with %d calls after %.2f ms",
                    self._func.__name__,
                    first_call_at,
                    len(futures),
                    (self.loop.time() - first_call_at) * 1000,
                )

        # Clear only if not cleared already
        if args is self._args:
            self._clear()

        # Trying to acquire the lock to execute the aggregated function
        async with lock:
            if not future.done():
                await self._execute(args=args, futures=futures)

        return await future
# S only appears in argument position and T only in the result, hence the
# contravariant / covariant declarations.
S = TypeVar("S", contravariant=True)
T = TypeVar("T", covariant=True)


class AggregateFunc(Protocol, Generic[S, T]):
    """
    Structural type of a batch function usable with ``Aggregator``.

    It accepts any number of argument values positionally and returns an
    iterable of results, expected to hold one result per argument in the
    same order.
    """

    __name__: str

    async def __call__(self, *args: S) -> Iterable[T]:
        ...
def _to_async_aggregate(func: AggregateFunc[V, R]) -> AggregateAsyncFunc[V, R]:
    """
    Adapt a result-returning batch function to ``AggregateAsyncFunc``.

    The returned coroutine function unwraps each ``Arg.value``, awaits
    ``func`` with those values, and pairs the returned results with the
    ``Arg.future`` objects positionally (stopping at the shorter side).
    Futures that are already done are skipped.
    """
    # Copy name/doc/module metadata from ``func`` but keep the wrapper's own
    # annotations, since it takes ``Arg`` objects rather than plain values.
    copied_attrs = tuple(
        attr_name
        for attr_name in functools.WRAPPER_ASSIGNMENTS
        if attr_name != "__annotations__"
    )
    copy_metadata = functools.wraps(func, assigned=copied_attrs)

    async def wrapper(*args: Arg[V, R]) -> None:
        values = [item.value for item in args]
        outcomes = await func(*values)
        for outcome, item in zip(outcomes, args):
            if item.future.done():
                continue
            item.future.set_result(outcome)

    return copy_metadata(wrapper)
class Aggregator(AggregatorAsync[V, R], Generic[V, R]):
    """
    ``AggregatorAsync`` for plain batch functions that *return* their
    results (one per argument, in order) instead of resolving futures.
    """

    def __init__(
        self,
        func: AggregateFunc[V, R],
        *,
        leeway_ms: float,
        max_count: int | None = None,
        statistic_name: str | None = None,
    ) -> None:
        """
        :param func: Coroutine function with variadic positional arguments
            returning an iterable of results.
        :param leeway_ms: Aggregation window in milliseconds.
        :param max_count: Optional batch size limit.
        :param statistic_name: Name passed to the statistic counters.
        :raises ValueError: If ``func`` has no ``*args`` parameter, or the
            base class rejects ``leeway_ms`` / ``max_count``.
        """
        # The adapter created below always has ``*args``, so the base-class
        # signature check cannot catch a bad ``func``; check it here first.
        accepts_varargs = _has_variadic_positional(func)
        if not accepts_varargs:
            raise ValueError(
                "Function must accept variadic positional arguments"
            )
        adapted = _to_async_aggregate(func)
        super().__init__(
            adapted,
            leeway_ms=leeway_ms,
            max_count=max_count,
            statistic_name=statistic_name,
        )
def aggregate(
    leeway_ms: float,
    max_count: int | None = None,
    *,
    statistic_name: str | None = None,
) -> Callable[[AggregateFunc[V, R]], Callable[[V], Coroutine[Any, Any, R]]]:
    """
    Parametric decorator that aggregates multiple (but no more than
    ``max_count`` defaulting to ``None``) single-argument executions
    (``res1 = await func(arg1)``, ``res2 = await func(arg2)``, ...)
    of an asynchronous function with variadic positional arguments
    (``async def func(*args, pho=1, bo=2) -> Iterable``) into its single
    execution with multiple positional arguments
    (``res1, res2, ... = await func(arg1, arg2, ...)``) collected within a
    time window ``leeway_ms``.

    .. note::

        ``func`` must return a sequence of values of length equal to the
        number of arguments (and in the same order).

    .. note::

        if some unexpected error occurs, exception is propagated to each
        future; to set an individual error for each aggregated call refer
        to ``aggregate_async``.

    :param leeway_ms: The maximum approximate delay between the first
        collected argument and the aggregated execution.
    :param max_count: The maximum number of arguments to call decorated
        function with. Default ``None``.
    :param statistic_name: Optional name for the statistic counters of the
        underlying ``Aggregator``. Default ``None``.
    :return: A decorator that wraps ``func`` in an ``Aggregator`` and
        returns its ``aggregate`` coroutine method, which takes a single
        argument and resolves to that argument's result. Invalid
        ``leeway_ms`` / ``max_count`` or a ``func`` without ``*args`` raise
        ``ValueError`` when the decorator is applied.
    """

    def decorator(
        func: AggregateFunc[V, R],
    ) -> Callable[[V], Coroutine[Any, Any, R]]:
        aggregator = Aggregator(
            func,
            max_count=max_count,
            leeway_ms=leeway_ms,
            statistic_name=statistic_name,
        )
        return aggregator.aggregate

    return decorator
def aggregate_async(
    leeway_ms: float,
    max_count: int | None = None,
    *,
    statistic_name: str | None = None,
) -> Callable[
    [AggregateAsyncFunc[V, R]], Callable[[V], Coroutine[Any, Any, R]]
]:
    """
    Same as ``aggregate``, but with ``func`` arguments of type ``Arg``
    containing ``value`` and ``future`` attributes instead. In this setting
    ``func`` is responsible for setting individual results/exceptions for
    all of the futures or throwing an exception (it will propagate to
    futures automatically). If ``func`` mistakenly does not set a result of
    some future, then, ``ResultNotSetError`` exception is set.

    :param leeway_ms: The maximum approximate delay between the first
        collected argument and the aggregated execution.
    :param max_count: The maximum number of arguments to call decorated
        function with. Default ``None``.
    :param statistic_name: Optional name for the statistic counters of the
        underlying ``AggregatorAsync``. Default ``None``.
    :return: A decorator that wraps ``func`` in an ``AggregatorAsync`` and
        returns its ``aggregate`` coroutine method, which takes a single
        argument and resolves to that argument's result. Invalid
        ``leeway_ms`` / ``max_count`` or a ``func`` without ``*args`` raise
        ``ValueError`` when the decorator is applied.
    """

    def decorator(
        func: AggregateAsyncFunc[V, R],
    ) -> Callable[[V], Coroutine[Any, Any, R]]:
        aggregator = AggregatorAsync(
            func,
            max_count=max_count,
            leeway_ms=leeway_ms,
            statistic_name=statistic_name,
        )
        return aggregator.aggregate

    return decorator