Source code for aiomisc.iterator_wrapper

import asyncio
import inspect
from abc import abstractmethod
from collections import deque
from concurrent.futures import Executor
from queue import Empty as QueueEmpty
from queue import Queue
from time import time
from types import TracebackType
from typing import (
    Any, AsyncIterator, Awaitable, Callable, Deque, Generator, NoReturn,
    Optional, Type, TypeVar, Union,
)
from weakref import finalize

from aiomisc.compat import EventLoopMixin
from aiomisc.counters import Statistic


T = TypeVar("T")
R = TypeVar("R")

GenType = Generator[T, R, None]
FuncType = Callable[[], GenType]


[docs]class ChannelClosed(RuntimeError): pass
[docs]class QueueWrapperBase:
[docs] @abstractmethod def put(self, item: Any) -> None: raise NotImplementedError
[docs] def get(self) -> Any: raise NotImplementedError
[docs]class DequeWrapper(QueueWrapperBase): __slots__ = "queue", def __init__(self) -> None: self.queue: Deque[Any] = deque()
[docs] def get(self) -> Any: if not self.queue: raise QueueEmpty return self.queue.popleft()
[docs] def put(self, item: Any) -> None: return self.queue.append(item)
[docs]class QueueWrapper(QueueWrapperBase): __slots__ = "queue", def __init__(self, max_size: int) -> None: self.queue: Queue = Queue(maxsize=max_size)
[docs] def put(self, item: Any) -> None: return self.queue.put(item)
[docs] def get(self) -> Any: return self.queue.get_nowait()
[docs]def make_queue(max_size: int = 0) -> QueueWrapperBase: if max_size > 0: return QueueWrapper(max_size) return DequeWrapper()
[docs]class FromThreadChannel: SLEEP_LOW_THRESHOLD = 0.0001 SLEEP_DIFFERENCE_DIVIDER = 10 __slots__ = ("queue", "__closed", "__last_received_item") def __init__(self, maxsize: int = 0): self.queue: QueueWrapperBase = make_queue(max_size=maxsize) self.__closed = False self.__last_received_item: float = time()
[docs] def close(self) -> None: self.__closed = True
@property def is_closed(self) -> bool: return self.__closed def __enter__(self) -> "FromThreadChannel": return self def __exit__( self, exc_type: Type[Exception], exc_val: Exception, exc_tb: TracebackType, ) -> None: self.close()
[docs] def put(self, item: Any) -> None: if self.is_closed: raise ChannelClosed self.queue.put(item) self.__last_received_item = time()
def _compute_sleep_time(self) -> Union[float, int]: if self.__last_received_item < 0: return 0 delta = time() - self.__last_received_item if delta > 1: return 1 sleep_time = delta / self.SLEEP_DIFFERENCE_DIVIDER if sleep_time < self.SLEEP_LOW_THRESHOLD: return 0 return sleep_time def __await__(self) -> Any: while True: try: res = self.queue.get() return res except QueueEmpty: if self.is_closed: raise ChannelClosed sleep_time = self._compute_sleep_time() yield from asyncio.sleep(sleep_time).__await__()
[docs] async def get(self) -> Any: return await self
[docs]class IteratorWrapperStatistic(Statistic): started: int queue_size: int queue_length: int yielded: int enqueued: int
[docs]class IteratorWrapper(AsyncIterator, EventLoopMixin): __slots__ = ( "__channel", "__close_event", "__gen_func", "__gen_task", "_statistic", "executor", ) + EventLoopMixin.__slots__ def __init__( self, gen_func: FuncType, loop: Optional[asyncio.AbstractEventLoop] = None, max_size: int = 0, executor: Optional[Executor] = None, statistic_name: Optional[str] = None, ): self._loop = loop self.executor = executor self.__close_event = asyncio.Event() self.__channel: FromThreadChannel = FromThreadChannel(maxsize=max_size) self.__gen_task: Optional[asyncio.Future] = None self.__gen_func: Callable = gen_func self._statistic = IteratorWrapperStatistic(statistic_name) self._statistic.queue_size = max_size @property def closed(self) -> bool: return self.__channel.is_closed @staticmethod def __throw(_: Any) -> NoReturn: raise def _in_thread(self) -> None: self._statistic.started += 1 with self.__channel: try: gen = iter(self.__gen_func()) throw = self.__throw if inspect.isgenerator(gen): throw = gen.throw # type: ignore while not self.closed: item = next(gen) try: self.__channel.put((item, False)) except Exception as e: throw(e) self.__channel.close() break finally: del item self._statistic.enqueued += 1 except StopIteration: return except Exception as e: if self.closed: return self.__channel.put((e, True)) finally: self._statistic.started -= 1 self.loop.call_soon_threadsafe(self.__close_event.set)
[docs] def close(self) -> Awaitable[None]: self.__channel.close() return asyncio.ensure_future(self.wait_closed())
[docs] async def wait_closed(self) -> None: await self.__close_event.wait() if self.__gen_task: await asyncio.gather(self.__gen_task, return_exceptions=True)
def _run(self) -> Any: return self.loop.run_in_executor( self.executor, self._in_thread, ) def __aiter__(self) -> AsyncIterator[Any]: if not self.loop.is_running(): raise RuntimeError("Event loop is not running") if self.__gen_task is None: gen_task = self._run() if gen_task is None: raise RuntimeError("Iterator task was not created") self.__gen_task = gen_task return IteratorProxy(self, self.close) async def __anext__(self) -> Awaitable[T]: try: item, is_exc = await self.__channel.get() except ChannelClosed: await self.wait_closed() raise StopAsyncIteration if is_exc: await self.close() raise item from item self._statistic.yielded += 1 return item async def __aenter__(self) -> "IteratorWrapper": return self async def __aexit__( self, exc_type: Any, exc_val: Any, exc_tb: Any, ) -> None: if self.closed: return await self.close()
[docs]class IteratorProxy(AsyncIterator): def __init__( self, iterator: AsyncIterator, finalizer: Callable[[], Any], ): self.__iterator = iterator finalize(self, finalizer) def __anext__(self) -> Awaitable[Any]: return self.__iterator.__anext__()