import asyncio
import logging
from abc import ABC, abstractmethod
from collections import defaultdict
from random import random
from typing import (
Any, AsyncContextManager, Awaitable, Callable, Coroutine, DefaultDict,
Generic, NoReturn, Optional, Set, TypeVar, Union,
)
from .compat import EventLoopMixin
from .utils import cancel_tasks
# Generic placeholder for the value produced by an "enter" callable and for
# the kind of instance a pool hands out.
T = TypeVar("T", bound=Any)
# Numeric type accepted for timeouts (forwarded to ``asyncio.wait_for``).
Number = Union[int, float]
# Module-level logger used to report failures while checking, recycling or
# destroying instances.
log = logging.getLogger(__name__)
class ContextManager(AsyncContextManager):
    """Single-use asynchronous context manager built from two callables.

    On entry ``aenter`` is called without arguments and awaited; its result
    is remembered and handed to the ``async with`` body. On exit ``aexit``
    is called with that same result and awaited. Entering the same object
    a second time raises ``RuntimeError``.
    """

    __slots__ = "__aenter", "__aexit", "__instance"

    # Marker meaning "not entered yet". A dedicated object is used so that
    # any value returned by ``aenter`` (``None`` included) counts as entered.
    sentinel = object()

    def __init__(
        self, aenter: Callable[..., Awaitable[T]],
        aexit: Callable[..., Awaitable[T]],
    ):
        """Store the enter/exit callables.

        :param aenter: called with no arguments on entry; the awaited
            result becomes the managed value.
        :param aexit: called with the managed value on exit.
        """
        self.__aenter = aenter
        self.__aexit = aexit
        self.__instance = self.sentinel

    async def __aenter__(self) -> T:
        """Await ``aenter`` and return its result.

        :raises RuntimeError: if this manager has already been entered.
        """
        if self.__instance is not self.sentinel:
            raise RuntimeError("Reuse of context manager is not acceptable")

        self.__instance = await self.__aenter()
        return self.__instance

    async def __aexit__(
        self, exc_type: Any, exc_value: Any,
        traceback: Any,
    ) -> Any:
        """Await ``aexit`` with the managed value.

        The exception details are ignored and nothing is returned, so an
        exception raised inside the ``async with`` body is not suppressed.
        """
        await self.__aexit(self.__instance)
class PoolBase(ABC, EventLoopMixin, Generic[T]):
    """Abstract asynchronous pool of reusable instances.

    Subclasses implement :meth:`_create_instance`,
    :meth:`_destroy_instance` and :meth:`_check_instance`. Users obtain an
    instance with ``async with pool.acquire() as instance: ...``.

    At most ``maxsize`` instances are alive at once (enforced with a
    semaphore). When ``recycle`` is set, every instance gets a deadline of
    ``loop.time() + recycle * (1 + random())``; an instance released after
    its deadline is destroyed instead of being returned to the pool.
    Destruction of recycled instances happens in a background task.
    """

    __slots__ = (
        "_create_lock",
        "_instances",
        "_recycle",
        "_recycle_bin",
        "_recycle_times",
        "_semaphore",
        "_tasks",
        "_len",
        "_used",
    ) + EventLoopMixin.__slots__

    _tasks: Set[Any]
    _used: Set[Any]
    _instances: asyncio.Queue
    _recycle_bin: asyncio.Queue

    def __init__(self, maxsize: int = 10, recycle: Optional[int] = None):
        """Set up the pool state and start the background recycler task.

        :param maxsize: maximum number of simultaneously alive instances.
        :param recycle: base lifetime of an instance, or ``None`` to never
            recycle by age. Must be positive when given.
        """
        assert (
            recycle is None or recycle > 0
        ), "recycle should be positive number or None"

        # Idle instances ready to be handed out.
        self._instances = asyncio.Queue()
        # Instances waiting to be destroyed by the recycler task.
        self._recycle_bin = asyncio.Queue()
        # One slot per alive instance.
        self._semaphore = asyncio.Semaphore(maxsize)
        self._len = 0
        self._recycle = recycle
        self._tasks = set()
        self._used = set()
        self._create_lock = asyncio.Lock()
        # Maps instance -> deadline on the loop clock. A missing key
        # defaults to "now", so ``+=`` yields "now + lifetime".
        self._recycle_times: DefaultDict[Any, float] = defaultdict(
            self.loop.time,
        )

        self.__create_task(self.__recycler())

    def __create_task(self, coro: Coroutine[Any, Any, Any]) -> asyncio.Task:
        """Schedule ``coro`` on the pool's loop and track the task.

        The task is kept in ``_tasks`` until it finishes so that
        :meth:`close` can cancel whatever is still running.
        """
        task = self.loop.create_task(coro)
        self._tasks.add(task)
        task.add_done_callback(self._tasks.discard)
        return task

    async def __recycler(self) -> NoReturn:
        """Destroy instances from the recycle bin forever.

        Errors raised by :meth:`_destroy_instance` are logged and do not
        stop the loop.
        """
        while True:
            instance = await self._recycle_bin.get()

            try:
                await self._destroy_instance(instance)
            except Exception:
                log.exception("Error when recycle instance %r", instance)
            finally:
                self._recycle_bin.task_done()

    @abstractmethod
    async def _create_instance(self) -> T:
        """Create and return a new instance for the pool."""
        pass

    @abstractmethod
    async def _destroy_instance(self, instance: Any) -> None:
        """Release every resource held by ``instance``."""
        pass

    # noinspection PyMethodMayBeStatic,PyUnusedLocal
    @abstractmethod
    async def _check_instance(self, instance: Any) -> bool:
        """Return whether ``instance`` may be handed out.

        A false result, or an exception, makes the pool recycle the
        instance and try another one.
        """
        return True

    def __len__(self) -> int:
        """Return the number of instances created and not yet recycled."""
        return self._len

    def __recycle_instance(self, instance: Any) -> None:
        """Drop ``instance`` from the pool and queue it for destruction.

        Frees the instance's semaphore slot immediately; the actual
        destruction is performed later by the recycler task.
        """
        self._len -= 1
        self._semaphore.release()

        self._recycle_times.pop(instance, None)
        self._used.discard(instance)

        self._recycle_bin.put_nowait(instance)

    async def __create_new_instance(self) -> None:
        """Reserve a slot, create an instance and put it in the idle queue.

        Waits while the pool already holds ``maxsize`` instances. Any
        exception from :meth:`_create_instance` is propagated.
        """
        await self._semaphore.acquire()

        try:
            instance: Any = await self._create_instance()
        except BaseException:
            # Creation failed or was cancelled: hand the slot back,
            # otherwise the pool capacity would shrink permanently.
            self._semaphore.release()
            raise

        self._len += 1

        if self._recycle:
            deadline = self._recycle * (1 + random())
            self._recycle_times[instance] += deadline

        await self._instances.put(instance)

    async def __acquire(self) -> T:
        """Return an instance that passed :meth:`_check_instance`.

        A new instance is created first whenever the semaphore still has
        free slots. Instances that fail the check, or whose check raises,
        are recycled and the acquisition is retried.
        """
        while True:
            if not self._semaphore.locked():
                await self.__create_new_instance()

            instance = await self._instances.get()

            try:
                usable = await self._check_instance(instance)
            except Exception:
                log.exception("Check instance %r failed", instance)
                # A check that blows up is treated like a failed check:
                # the instance is about to be destroyed and must never be
                # handed to the caller.
                usable = False

            if usable:
                self._used.add(instance)
                return instance

            self.__recycle_instance(instance)

    async def __release(self, instance: Any) -> None:
        """Return ``instance`` to the idle queue, or recycle it if expired."""
        self._used.discard(instance)

        if self._recycle and self._recycle_times[instance] < self.loop.time():
            self.__recycle_instance(instance)
            return

        self._instances.put_nowait(instance)

    def acquire(self) -> AsyncContextManager[T]:
        """Return a single-use context manager that borrows an instance.

        Entering it acquires an instance; leaving it releases the
        instance back to the pool.
        """
        return ContextManager(self.__acquire, self.__release)

    async def close(self, timeout: Optional[Number] = None) -> None:
        """Destroy every known instance and stop the background tasks.

        Both borrowed and idle instances are destroyed concurrently.
        Errors from :meth:`_destroy_instance` are logged, not raised.

        :param timeout: passed to ``asyncio.wait_for`` as the limit for
            the destruction phase; ``None`` waits without limit.
        :raises asyncio.TimeoutError: if destruction exceeds ``timeout``.
            The pool's tasks are cancelled in that case as well.
        """
        instances = list(self._used)
        self._used.clear()

        while True:
            try:
                instances.append(self._instances.get_nowait())
            except asyncio.QueueEmpty:
                break

        async def log_exception(coro: Awaitable[Any]) -> None:
            try:
                await coro
            except Exception:
                log.exception("Exception when task execution")

        try:
            await asyncio.wait_for(
                asyncio.gather(
                    *[
                        self.__create_task(
                            log_exception(self._destroy_instance(instance)),
                        )
                        for instance in instances
                    ],
                    return_exceptions=True,
                ),
                timeout=timeout,
            )
        finally:
            # Runs on timeout or cancellation too, so the recycler task
            # never outlives the pool.
            await cancel_tasks(self._tasks)