import asyncio
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Awaitable, Callable, Coroutine
from contextlib import AbstractAsyncContextManager
from random import random
from typing import Any, Generic, NoReturn, TypeVar
from .compat import EventLoopMixin
from .utils import cancel_tasks
T = TypeVar("T", bound=Any)
Number = int | float
log = logging.getLogger(__name__)
class ContextManager(AbstractAsyncContextManager):
    """Single-use async context manager built from two coroutine functions.

    Entering awaits ``aenter()`` and returns its result; exiting awaits
    ``aexit(instance)`` with that same result. A given ContextManager may be
    entered only once; a second ``async with`` raises RuntimeError.
    """

    __slots__ = "__aenter", "__aexit", "__instance"

    # Marker meaning "not entered yet"; never equal to a real instance.
    sentinel = object()

    def __init__(
        self,
        aenter: Callable[..., Awaitable[T]],
        aexit: Callable[..., Awaitable[T]],
    ):
        self.__aenter = aenter
        self.__aexit = aexit
        self.__instance = self.sentinel

    async def __aenter__(self) -> T:
        """Obtain the instance via ``aenter()``.

        :raises RuntimeError: if this context manager was already entered.
        """
        if self.__instance is not self.sentinel:
            raise RuntimeError("Reuse of context manager is not acceptable")

        self.__instance = await self.__aenter()
        return self.__instance

    async def __aexit__(
        self, exc_type: Any, exc_value: Any, traceback: Any
    ) -> Any:
        """Hand the instance back via ``aexit(instance)``.

        Returns None, so exceptions raised inside the ``async with`` body
        propagate unchanged.
        """
        await self.__aexit(self.__instance)
class PoolBase(ABC, EventLoopMixin, Generic[T]):
    """Abstract asynchronous pool of reusable instances.

    Subclasses implement :meth:`_create_instance`, :meth:`_destroy_instance`
    and :meth:`_check_instance`; callers use
    ``async with pool.acquire() as instance: ...``.

    :param maxsize: upper bound on the number of live instances, enforced by
        a semaphore (one slot per created instance).
    :param recycle: if set, each instance gets a lifetime of
        ``recycle * (1 + random())`` seconds of loop time (between ``recycle``
        and ``2 * recycle``); an instance past its deadline is destroyed when
        released instead of being returned to the pool.
    """

    __slots__ = (
        "_create_lock",
        "_instances",
        "_recycle",
        "_recycle_bin",
        "_recycle_times",
        "_semaphore",
        "_tasks",
        "_len",
        "_used",
    ) + EventLoopMixin.__slots__

    _tasks: set[Any]
    _used: set[Any]
    _instances: asyncio.Queue
    _recycle_bin: asyncio.Queue

    def __init__(self, maxsize: int = 10, recycle: int | None = None):
        assert recycle is None or recycle > 0, (
            "recycle should be positive number or None"
        )

        # Idle instances ready to be handed out.
        self._instances = asyncio.Queue()
        # Instances waiting to be destroyed by the background recycler.
        self._recycle_bin = asyncio.Queue()
        self._semaphore = asyncio.Semaphore(maxsize)
        self._len = 0
        self._recycle = recycle
        self._tasks = set()
        self._used = set()
        # NOTE(review): not used by the code visible in this class; kept for
        # compatibility with subclasses that may rely on it.
        self._create_lock = asyncio.Lock()
        # instance -> loop time after which it expires. The default factory
        # makes the first ``+=`` start counting from "now".
        self._recycle_times: defaultdict[Any, float] = defaultdict(
            self.loop.time
        )

        self.__create_task(self.__recycler())

    def __create_task(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task:
        """Schedule *coro* on the pool's loop and track it in ``_tasks``."""
        task = self.loop.create_task(coro)
        self._tasks.add(task)
        # Drop finished tasks so ``_tasks`` only holds unfinished ones.
        task.add_done_callback(self._tasks.discard)
        return task

    async def __recycler(self) -> NoReturn:
        """Destroy instances placed in the recycle bin, forever.

        Errors from ``_destroy_instance`` are logged and do not stop the loop.
        """
        while True:
            instance = await self._recycle_bin.get()
            try:
                await self._destroy_instance(instance)
            except Exception:
                log.exception("Error when recycle instance %r", instance)
            finally:
                self._recycle_bin.task_done()

    @abstractmethod
    async def _create_instance(self) -> T:
        """Create and return a new pooled instance."""

    @abstractmethod
    async def _destroy_instance(self, instance: Any) -> None:
        """Release the resources held by *instance*."""

    # noinspection PyMethodMayBeStatic,PyUnusedLocal
    @abstractmethod
    async def _check_instance(self, instance: Any) -> bool:
        """Return whether *instance* may be handed out.

        Called on every acquire. A falsy result, or an exception, causes the
        instance to be recycled and another one to be acquired. This default
        accepts everything; overrides may delegate to it via ``super()``.
        """
        return True

    def __len__(self) -> int:
        """Return the number of instances created and not yet recycled."""
        return self._len

    def __recycle_instance(self, instance: Any) -> None:
        """Forget *instance*, free its slot and queue it for destruction."""
        self._len -= 1
        self._semaphore.release()
        self._recycle_times.pop(instance, None)
        self._used.discard(instance)
        self._recycle_bin.put_nowait(instance)

    async def __create_new_instance(self) -> None:
        """Take a slot, create an instance and put it in the idle queue."""
        await self._semaphore.acquire()
        try:
            instance: Any = await self._create_instance()
        except BaseException:
            # Creation failed or was cancelled: give the slot back, otherwise
            # the pool permanently loses capacity.
            self._semaphore.release()
            raise

        self._len += 1

        if self._recycle:
            # Randomized lifetime in [recycle, 2 * recycle) so instances
            # created together do not all expire at the same moment.
            deadline = self._recycle * (1 + random())
            self._recycle_times[instance] += deadline

        await self._instances.put(instance)

    async def __acquire(self) -> T:
        """Return a checked instance, creating one while below ``maxsize``."""
        if not self._semaphore.locked():
            await self.__create_new_instance()

        instance = await self._instances.get()

        try:
            result = await self._check_instance(instance)
        except Exception:
            log.exception("Check instance %r failed", instance)
            # A failing check is treated like a negative one; the instance
            # must not be handed out after it has been recycled.
            result = False

        if not result:
            self.__recycle_instance(instance)
            return await self.__acquire()

        self._used.add(instance)
        return instance

    async def __release(self, instance: Any) -> None:
        """Return *instance* to the pool, or recycle it if it has expired."""
        self._used.discard(instance)
        if self._recycle and self._recycle_times[instance] < self.loop.time():
            # NOTE(review): recycling frees a semaphore slot but puts nothing
            # into ``_instances``; a coroutine already blocked on
            # ``_instances.get()`` in ``__acquire`` is not woken by this and
            # waits until another instance is put into the queue.
            self.__recycle_instance(instance)
            return

        self._instances.put_nowait(instance)

    def acquire(self) -> AbstractAsyncContextManager[T]:
        """Return a single-use async context manager yielding an instance."""
        return ContextManager(self.__acquire, self.__release)

    async def close(self, timeout: Number | None = None) -> None:
        """Destroy all idle and in-use instances and stop background tasks.

        :param timeout: seconds to wait for instance destruction (``None``
            waits indefinitely).
        :raises asyncio.TimeoutError: if destruction exceeds *timeout*;
            background tasks are cancelled in that case too.
        """
        instances = list(self._used)
        self._used.clear()

        while True:
            try:
                instances.append(self._instances.get_nowait())
            except asyncio.QueueEmpty:
                break

        async def log_exception(coro: Awaitable[Any]) -> None:
            try:
                await coro
            except Exception:
                log.exception("Exception when task execution")

        try:
            await asyncio.wait_for(
                asyncio.gather(
                    *[
                        self.__create_task(
                            log_exception(self._destroy_instance(instance))
                        )
                        for instance in instances
                    ],
                    return_exceptions=True,
                ),
                timeout=timeout,
            )
        finally:
            # Always stop the recycler and any leftover tasks, even when the
            # destruction above timed out or was cancelled.
            await cancel_tasks(self._tasks)