import asyncio
import functools
import inspect
import logging
from asyncio import CancelledError, Event, Future, Lock, wait_for
from collections.abc import Callable, Coroutine, Iterable
from dataclasses import dataclass
from inspect import Parameter
from typing import Any, Generic, Protocol, TypeVar
from .compat import EventLoopMixin
from .counters import Statistic
# Module-level logger named after this module.
log = logging.getLogger(__name__)
# Type of a single argument value passed to an aggregated call.
V = TypeVar("V")
# Type of the result produced for a single aggregated call.
R = TypeVar("R")
[docs]
@dataclass(frozen=True)
class Arg(Generic[V, R]):
    """
    Immutable pairing of one aggregated call's argument with its future.

    Instances are what an asynchronous aggregate function receives as its
    positional arguments: it reads ``value`` and sets the outcome of the
    corresponding call on ``future``.
    """

    # Argument supplied by the individual caller.
    value: V
    # Future on which the result (or exception) for ``value`` is set.
    future: "Future[R]"
[docs]
class ResultNotSetError(Exception):
    """Set on a future that the aggregated function left without a result."""
[docs]
class AggregateAsyncFunc(Protocol, Generic[V, R]):
    """
    Structural type of a function usable with ``AggregatorAsync``.

    Such a function is awaited with one ``Arg`` per aggregated call and
    returns nothing; outcomes are delivered through each ``Arg.future``.
    """

    # Read for log messages, so implementations must expose a name.
    __name__: str

    async def __call__(self, *args: Arg[V, R]) -> None:
        ...
[docs]
class AggregateStatistic(Statistic):
    """Statistic fields recorded by an aggregator instance."""

    # Configured aggregation window, in milliseconds.
    leeway_ms: float
    # Configured batch size limit (0 when no limit was given).
    max_count: int
    # Number of aggregated executions that returned without raising.
    success: int
    # Number of aggregated executions that raised an ``Exception``.
    error: int
    # Number of aggregated executions that ended, whatever the outcome.
    done: int
def _has_variadic_positional(func: Callable[..., Any]) -> bool:
return any(
parameter.kind == Parameter.VAR_POSITIONAL
for parameter in inspect.signature(func).parameters.values()
)
[docs]
class AggregatorAsync(EventLoopMixin, Generic[V, R]):
    """
    Collect single-argument calls into batches and run ``func`` once per batch.

    Every :meth:`aggregate` call appends its argument and a fresh future to
    the current batch. A batch is closed either when it reaches ``max_count``
    entries or when ``leeway_ms`` has elapsed since its first call. One of the
    participating tasks then awaits ``func`` with one :class:`Arg` per call;
    ``func`` is expected to set a result or an exception on every future.
    """

    _func: AggregateAsyncFunc[V, R]
    _max_count: int | None
    _leeway: float
    _first_call_at: float | None
    _args: list
    _futures: "list[Future[R]]"
    _event: Event
    _lock: Lock

    def __init__(
        self,
        func: AggregateAsyncFunc[V, R],
        *,
        leeway_ms: float,
        max_count: int | None = None,
        statistic_name: str | None = None,
    ):
        """
        :param func: asynchronous function taking ``*args`` of :class:`Arg`.
        :param leeway_ms: aggregation window in milliseconds, counted from
            the first call of a batch; must be positive.
        :param max_count: maximum number of calls per batch, or ``None`` for
            no limit; must be positive when given.
        :param statistic_name: name passed to :class:`AggregateStatistic`.
        :raises ValueError: if ``func`` has no variadic positional parameter,
            or ``max_count`` / ``leeway_ms`` is not positive.
        """
        if not _has_variadic_positional(func):
            raise ValueError(
                "Function must accept variadic positional arguments"
            )
        if max_count is not None and max_count <= 0:
            raise ValueError("max_count must be positive int or None")
        if leeway_ms <= 0:
            raise ValueError("leeway_ms must be positive float")

        self._func = func
        self._max_count = max_count
        # The window is kept in seconds to match ``loop.time()``.
        self._leeway = leeway_ms / 1000
        self._clear()

        self._statistic = AggregateStatistic(statistic_name)
        self._statistic.leeway_ms = self.leeway_ms
        self._statistic.max_count = max_count or 0

    def _clear(self) -> None:
        """Start a new, empty batch with its own event and lock."""
        self._first_call_at = None
        self._args = []
        self._futures = []
        self._event = Event()
        self._lock = Lock()

    @property
    def max_count(self) -> int | None:
        """Maximum number of calls per batch, or ``None`` if unlimited."""
        return self._max_count

    @property
    def leeway_ms(self) -> float:
        """Aggregation window in milliseconds."""
        return self._leeway * 1000

    @property
    def count(self) -> int:
        """Number of arguments collected in the current batch."""
        return len(self._args)

    async def _execute(
        self, *, args: list[V], futures: "list[Future[R]]"
    ) -> None:
        """
        Await the aggregated function for one batch and settle its futures.

        :param args: argument values of the batch.
        :param futures: futures of the batch, in the same order as ``args``.
        :raises CancelledError: re-raised untouched when the executing task
            is cancelled, leaving the futures pending.
        """
        args_ = [
            Arg(value=arg, future=future) for arg, future in zip(args, futures)
        ]
        try:
            await self._func(*args_)
            self._statistic.success += 1
        except CancelledError:
            # Other waiting tasks can try to finish the job instead.
            raise
        except Exception as e:
            # Any other failure is delivered to every still-pending future.
            self._set_exception(e, futures)
            self._statistic.error += 1
            return
        finally:
            self._statistic.done += 1

        # Validate that all results/exceptions are set by the func
        for future in futures:
            if not future.done():
                future.set_exception(ResultNotSetError())

    def _set_exception(
        self, exc: Exception, futures: list["Future[R]"]
    ) -> None:
        """Set ``exc`` on every future in ``futures`` that is not done yet."""
        for future in futures:
            if not future.done():
                future.set_exception(exc)

    async def aggregate(self, arg: V) -> R:
        """
        Add ``arg`` to the current batch and wait for its individual result.

        :param arg: argument of this single call.
        :return: the result set on this call's future.
        :raises Exception: whatever exception was set on this call's future.
        """
        if self._first_call_at is None:
            self._first_call_at = self.loop.time()
        first_call_at = self._first_call_at

        # Keep local references to this batch's state: ``_clear()`` rebinds
        # the attributes to a new batch while this task still needs its own.
        args: list = self._args
        futures: list[Future[R]] = self._futures
        event: Event = self._event
        lock: Lock = self._lock

        args.append(arg)
        future: Future[R] = Future()
        futures.append(future)

        if self.count == self.max_count:
            # The batch is full: wake the waiters and start a new batch.
            event.set()
            self._clear()
        else:
            # Waiting for max_count requests or a timeout
            try:
                await wait_for(
                    event.wait(),
                    timeout=first_call_at + self._leeway - self.loop.time(),
                )
            # ``asyncio.TimeoutError`` is what ``wait_for`` raises on every
            # supported version; it only became an alias of the builtin
            # ``TimeoutError`` in Python 3.11.
            except asyncio.TimeoutError:
                log.debug(
                    "Aggregation timeout of %s for batch started at %.4f "
                    "with %d calls after %.2f ms",
                    self._func.__name__,
                    first_call_at,
                    len(futures),
                    (self.loop.time() - first_call_at) * 1000,
                )
                # Clear only if not cleared already
                if args is self._args:
                    self._clear()

        # Trying to acquire the lock to execute the aggregated function
        async with lock:
            # Only the first task to get here with a pending future runs the
            # batch; the others find their future already settled.
            if not future.done():
                await self._execute(args=args, futures=futures)

        await future
        return future.result()
# Contravariant type of the positional arguments accepted by ``AggregateFunc``.
S = TypeVar("S", contravariant=True)
# Covariant type of the items in the iterable returned by ``AggregateFunc``.
T = TypeVar("T", covariant=True)
[docs]
class AggregateFunc(Protocol, Generic[S, T]):
    """
    Structural type of a function usable with ``Aggregator``.

    Such a function is awaited with the collected argument values as
    positional arguments and returns an iterable of results.
    """

    __name__: str

    async def __call__(self, *args: S) -> Iterable[T]:
        ...
def _to_async_aggregate(func: AggregateFunc[V, R]) -> AggregateAsyncFunc[V, R]:
    """
    Adapt a value-returning aggregate function to the ``Arg``-based protocol.

    The returned coroutine function unwraps the ``value`` of every ``Arg``,
    awaits ``func`` with those values and sets each returned item as the
    result of the future at the same position. Futures that are already done
    are left untouched; pairing stops at the shorter of results and
    arguments.

    :param func: asynchronous function returning an iterable of results.
    :return: wrapper carrying ``func``'s metadata except its annotations.
    """
    # The wrapper's signature differs from ``func``'s, so annotation metadata
    # is not copied. ``__annotations__`` is the attribute listed by older
    # Python versions; ``__annotate__`` is presumably its replacement in
    # ``WRAPPER_ASSIGNMENTS`` on Python 3.14+ (PEP 749) — excluding a name
    # that is absent is harmless.
    skipped = ("__annotations__", "__annotate__")
    assigned = tuple(
        item for item in functools.WRAPPER_ASSIGNMENTS if item not in skipped
    )

    @functools.wraps(func, assigned=assigned)
    async def wrapper(*args: Arg[V, R]) -> None:
        args_ = [item.value for item in args]
        results = await func(*args_)
        for res, arg in zip(results, args):
            if not arg.future.done():
                arg.future.set_result(res)

    return wrapper
[docs]
class Aggregator(AggregatorAsync[V, R], Generic[V, R]):
    """
    Aggregator for functions that return an iterable of results.

    ``func`` is adapted so that each item it returns becomes the result of
    the call at the same position in the batch.
    """

    def __init__(
        self,
        func: AggregateFunc[V, R],
        *,
        leeway_ms: float,
        max_count: int | None = None,
        statistic_name: str | None = None,
    ) -> None:
        """
        :param func: asynchronous function taking ``*args`` and returning an
            iterable of results.
        :param leeway_ms: aggregation window in milliseconds.
        :param max_count: maximum number of calls per batch, or ``None``.
        :param statistic_name: name passed on to the parent constructor.
        :raises ValueError: if ``func`` has no variadic positional parameter.
        """
        # The adapter built below always accepts ``*args``, so the original
        # function has to be validated before it is wrapped.
        accepts_var_args = _has_variadic_positional(func)
        if not accepts_var_args:
            raise ValueError(
                "Function must accept variadic positional arguments"
            )

        adapted = _to_async_aggregate(func)
        super().__init__(
            adapted,
            leeway_ms=leeway_ms,
            max_count=max_count,
            statistic_name=statistic_name,
        )
[docs]
def aggregate(
    leeway_ms: float, max_count: int | None = None
) -> Callable[[AggregateFunc[V, R]], Callable[[V], Coroutine[Any, Any, R]]]:
    """
    Parametric decorator turning many single-argument awaits into one call.

    Individual executions (``res1 = await func(arg1)``,
    ``res2 = await func(arg2)``, ...) made within a ``leeway_ms`` window, and
    numbering at most ``max_count``, are merged into a single execution of
    the decorated function with all collected arguments
    (``res1, res2, ... = await func(arg1, arg2, ...)``). The decorated
    function must be asynchronous and accept variadic positional arguments
    (``async def func(*args) -> Iterable``).

    .. note::

        ``func`` must return a sequence of values of length equal to the
        number of arguments (and in the same order).

    .. note::

        If ``func`` raises, the exception is propagated to every pending
        call of the batch; to set an individual error for each aggregated
        call refer to ``aggregate_async``.

    :param leeway_ms: the maximum approximate delay between the first
        collected argument and the aggregated execution.
    :param max_count: the maximum number of arguments to call the decorated
        function with. Default ``None``.
    :return: a decorator that replaces ``func`` with a single-argument
        coroutine function.
    """

    def decorator(
        func: AggregateFunc[V, R],
    ) -> Callable[[V], Coroutine[Any, Any, R]]:
        # One aggregator is created per decorated function; its bound
        # ``aggregate`` method is what callers actually await.
        return Aggregator(
            func, max_count=max_count, leeway_ms=leeway_ms
        ).aggregate

    return decorator
[docs]
def aggregate_async(
    leeway_ms: float, max_count: int | None = None
) -> Callable[
    [AggregateAsyncFunc[V, R]], Callable[[V], Coroutine[Any, Any, R]]
]:
    """
    Variant of ``aggregate`` whose ``func`` receives ``Arg`` objects.

    Each positional argument passed to ``func`` carries a ``value`` and a
    ``future``. ``func`` is responsible for setting an individual result or
    exception on every future, or for raising an exception, which is then
    propagated to the pending futures automatically. Any future that
    ``func`` leaves unset receives a ``ResultNotSetError``.

    :param leeway_ms: the maximum approximate delay between the first
        collected argument and the aggregated execution.
    :param max_count: the maximum number of arguments to call the decorated
        function with. Default ``None``.
    :return: a decorator that replaces ``func`` with a single-argument
        coroutine function.
    """

    def decorator(
        func: AggregateAsyncFunc[V, R],
    ) -> Callable[[V], Coroutine[Any, Any, R]]:
        # One aggregator is created per decorated function; its bound
        # ``aggregate`` method is what callers actually await.
        return AggregatorAsync(
            func, max_count=max_count, leeway_ms=leeway_ms
        ).aggregate

    return decorator