import asyncio
import inspect
import threading as _threading
from abc import abstractmethod
from contextlib import suppress
from collections import deque
from concurrent.futures import Executor
from queue import Empty as QueueEmpty
from queue import Full as QueueFull
from queue import Queue
from types import TracebackType
from typing import (
Any,
AsyncIterator,
Awaitable,
Callable,
Deque,
Generator,
Generic,
NoReturn,
Optional,
Type,
)
from weakref import finalize
from .types import P, T
class _ImmediateAwaitable:
    """Non-coroutine awaitable that completes at once with ``None``.

    Dropping a coroutine object without awaiting it makes Python emit a
    'coroutine was never awaited' warning.  An instance of this class can
    be discarded unawaited without any such warning, because it is a plain
    object rather than a coroutine.

    :meth:`IteratorWrapper.close` returns it on its fallback path (for
    example when called from a GC finalizer outside the event-loop
    thread), where the return value is always discarded.
    """

    __slots__ = ()

    def __await__(self) -> Generator[None, None, None]:
        # An iterator that is already exhausted: ``await`` never suspends
        # and evaluates to ``None``.
        nothing: tuple = ()
        return iter(nothing)  # type: ignore[return-value]
class ChannelClosed(RuntimeError):
    """Raised when a channel is used after it has been closed.

    :class:`FromThreadChannel` raises it from ``put`` once the channel is
    closed, and from ``get`` once the channel is closed *and* empty.
    """
class ChannelTimeout(asyncio.TimeoutError):
    """Signals that waiting on a channel exceeded the allowed time.

    Being a subclass of ``asyncio.TimeoutError``, it is also caught by
    handlers written for the generic asyncio timeout.
    """
class QueueWrapperBase:
    """Minimal interface shared by the queue back-ends of the channel.

    Concrete wrappers must provide both :meth:`put` and :meth:`get`.

    Note: the class does not use ``ABCMeta``, so ``@abstractmethod`` only
    marks the methods as part of the required interface; calling an
    un-overridden method raises ``NotImplementedError``.
    """

    @abstractmethod
    def put(
        self, item: Any, *, block: bool = True, timeout: Optional[float] = None
    ) -> None:
        """Store *item* in the queue.

        Args:
            item: Object to enqueue.
            block: Whether the call may wait for free space.
            timeout: Upper bound, in seconds, for such a wait.
        """
        raise NotImplementedError

    @abstractmethod
    def get(self) -> Any:
        """Remove and return the next item from the queue."""
        raise NotImplementedError
class DequeWrapper(QueueWrapperBase):
    """Unbounded FIFO queue kept in a ``deque`` and guarded by a lock."""

    __slots__ = ("_lock", "queue")

    def __init__(self) -> None:
        self.queue: Deque[Any] = deque()
        self._lock = _threading.Lock()

    def put(
        self, item: Any, *, block: bool = True, timeout: Optional[float] = None
    ) -> None:  # noqa: ARG002
        """Append *item* to the right end of the deque.

        ``block`` and ``timeout`` are accepted for interface compatibility
        and are not used: there is no size limit to wait on.
        """
        with self._lock:
            self.queue.append(item)

    def get(self) -> Any:
        """Pop and return the oldest item.

        Raises:
            queue.Empty: If the deque holds no items.
        """
        with self._lock:
            if self.queue:
                return self.queue.popleft()
            raise QueueEmpty
class QueueWrapper(QueueWrapperBase):
    """Bounded queue back-end delegating to ``queue.Queue``."""

    __slots__ = ("queue",)

    def __init__(self, max_size: int) -> None:
        bounded: Queue = Queue(maxsize=max_size)
        self.queue: Queue = bounded

    def get(self) -> Any:
        """Return the next item without waiting.

        Raises:
            queue.Empty: If no item is immediately available.
        """
        # ``get(block=False)`` is what ``get_nowait()`` does.
        return self.queue.get(block=False)

    def put(
        self, item: Any, *, block: bool = True, timeout: Optional[float] = None
    ) -> None:
        """Enqueue *item*, forwarding ``block`` and ``timeout`` to the queue.

        Raises:
            queue.Full: If no slot became free under the given
                ``block``/``timeout`` settings.
        """
        outcome = self.queue.put(item, block=block, timeout=timeout)
        return outcome
def make_queue(max_size: int = 0) -> QueueWrapperBase:
    """Build the queue back-end matching the requested capacity.

    Args:
        max_size: Capacity limit; a positive value selects the bounded
            :class:`QueueWrapper`, anything else the unbounded
            :class:`DequeWrapper`.

    Returns:
        The newly created queue wrapper.
    """
    return QueueWrapper(max_size) if max_size > 0 else DequeWrapper()
class FromThreadChannel:
    """
    A thread-safe channel for passing data from threads to async code.

    Uses asyncio.Event for efficient waiting instead of polling, which:

    - Eliminates CPU-wasting busy loops
    - Provides immediate wake-up when data is available
    - Supports optional timeout for bounded waiting

    Producers call :meth:`put`; the consumer awaits :meth:`get`.  The
    event and the event loop are bound lazily on the first :meth:`get`.

    Args:
        maxsize: Maximum queue size (0 = unbounded)
        timeout: Default timeout in seconds for get operations (0 = no timeout)
    """

    __slots__ = ("_lock", "queue", "_closed", "_data_event", "_loop", "_timeout")

    def __init__(self, maxsize: int = 0, timeout: float = 0):
        self._lock = _threading.Lock()
        self.queue: QueueWrapperBase = make_queue(max_size=maxsize)
        self._closed = False
        # Created on first use inside a running loop, see _get_event().
        self._data_event: Optional[asyncio.Event] = None
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._timeout = timeout

    def _get_event(self) -> asyncio.Event:
        """Get or create the asyncio.Event, ensuring it's bound to the right loop."""
        if self._data_event is None:
            # Store the loop first: _signal_data_available() needs both
            # attributes and must never see an event without its loop.
            self._loop = asyncio.get_running_loop()
            self._data_event = asyncio.Event()
        return self._data_event

    def _signal_data_available(self) -> None:
        """Signal that data is available. Thread-safe.

        Does nothing until a consumer has bound the event and loop.
        """
        loop = self._loop
        event = self._data_event
        if loop is None or event is None:
            return
        try:
            loop.call_soon_threadsafe(event.set)
        except RuntimeError:
            pass  # loop is closed, no waiters to wake

    def close(self) -> None:
        """Mark the channel closed and wake any waiting consumer."""
        with self._lock:
            self._closed = True
        # Wake up any waiters so they can see the channel is closed
        self._signal_data_available()

    @property
    def is_closed(self) -> bool:
        """Whether :meth:`close` has been called."""
        with self._lock:
            return self._closed

    def __enter__(self) -> "FromThreadChannel":
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        # Leaving the ``with`` block always closes the channel; exceptions
        # are not suppressed.
        self.close()

    def put(self, item: Any) -> None:
        """Put an item into the channel. Thread-safe.

        For bounded queues, ensure we periodically re-check whether the
        channel was closed while waiting for space so producer threads don't
        block forever when consumers are cancelled.

        Raises:
            ChannelClosed: If the channel is (or becomes) closed before the
                item could be stored.
        """
        if isinstance(self.queue, QueueWrapper):
            while True:
                with self._lock:
                    if self._closed:
                        raise ChannelClosed
                try:
                    # Short timeout so a close() is noticed promptly even
                    # while the queue stays full.
                    self.queue.put(item, timeout=0.1)
                    break
                except QueueFull:
                    if self.is_closed:
                        raise ChannelClosed
                    continue
        else:
            # Unbounded queue: the put cannot block, so it is done under
            # the lock together with the closed check.
            with self._lock:
                if self._closed:
                    raise ChannelClosed
                self.queue.put(item)
        self._signal_data_available()

    async def get(self, timeout: Optional[float] = None) -> Any:
        """
        Get an item from the channel.

        The timeout bounds the whole call: wake-ups that find the queue
        empty again do not restart the clock.

        Args:
            timeout: Timeout in seconds. None uses default, 0 disables timeout.

        Raises:
            ChannelClosed: If the channel is closed and empty.
            ChannelTimeout: If timeout is exceeded.
        """
        effective_timeout = timeout if timeout is not None else self._timeout
        event = self._get_event()
        loop = asyncio.get_running_loop()
        # One deadline for the entire call (None means wait indefinitely).
        deadline: Optional[float] = None
        if effective_timeout > 0:
            deadline = loop.time() + effective_timeout

        while True:
            try:
                return self.queue.get()
            except QueueEmpty:
                if self.is_closed:
                    raise ChannelClosed

            # Clear event before waiting (in case it was set)
            event.clear()

            # Double-check queue after clearing event to avoid race condition
            try:
                return self.queue.get()
            except QueueEmpty:
                pass

            # Wait for data with optional timeout
            if deadline is None:
                await event.wait()
                continue

            remaining = max(deadline - loop.time(), 0.0)
            try:
                await asyncio.wait_for(event.wait(), timeout=remaining)
            except asyncio.TimeoutError:
                raise ChannelTimeout(
                    f"Channel get timed out after {effective_timeout}s"
                )

    def __await__(self) -> Any:
        # ``await channel`` is shorthand for ``await channel.get()``.
        return self.get().__await__()
class IteratorWrapper(Generic[P, T], AsyncIterator):
    """Async iterator over a synchronous generator run in an executor thread.

    ``gen_func`` is called without arguments inside the worker thread.
    Each produced item is sent through a :class:`FromThreadChannel` as
    ``(item, False)``; an exception raised by the generator is sent as
    ``(exception, True)`` and re-raised from :meth:`__anext__`.

    Args:
        gen_func: Callable returning the iterable/generator to consume.
        loop: Event loop to use; when omitted, the running loop is looked
            up on first access of :attr:`loop`.
        max_size: Channel capacity, passed to ``FromThreadChannel`` as
            ``maxsize``.
        executor: Executor passed to ``loop.run_in_executor`` for the
            worker.
    """

    __slots__ = (
        "__channel",
        "__close_event",
        "__gen_func",
        "__gen_task",
        "_loop",
        "executor",
    )

    def __init__(
        self,
        gen_func: Callable[P, Generator[T, None, None]],
        loop: Optional[asyncio.AbstractEventLoop] = None,
        max_size: int = 0,
        executor: Optional[Executor] = None,
    ):
        self._loop = loop
        self.executor = executor
        # Set by the worker thread when it is done with the generator.
        self.__close_event = _threading.Event()
        self.__channel: FromThreadChannel = FromThreadChannel(maxsize=max_size)
        # Future of the worker; stays None until iteration starts.
        self.__gen_task: Optional[asyncio.Future] = None
        self.__gen_func: Callable = gen_func

    @property
    def loop(self) -> asyncio.AbstractEventLoop:
        """Event loop in use; resolved from the running loop if not given."""
        if self._loop is None:
            self._loop = asyncio.get_running_loop()
        return self._loop

    @property
    def closed(self) -> bool:
        """Whether the underlying channel has been closed."""
        return self.__channel.is_closed

    @staticmethod
    def __throw(e: BaseException) -> NoReturn:
        # Stand-in for ``gen.throw`` when the iterable is not a generator.
        raise e

    def _in_thread(self) -> None:
        """Worker body: drain the generator into the channel.

        Runs in the executor thread.  The channel is closed when the
        ``with`` block is left, and the close event is set once the
        generator has been closed.
        """
        gen: Optional[Generator[Any, None, None]] = None
        with self.__channel:
            try:
                gen = iter(self.__gen_func())

                throw = self.__throw
                if inspect.isgenerator(gen):
                    throw = gen.throw  # type: ignore

                while not self.closed:
                    item = next(gen)
                    try:
                        self.__channel.put((item, False))
                    except Exception as e:
                        # Let the generator see the failure (e.g. the
                        # channel was closed), then stop producing.
                        throw(e)
                        self.__channel.close()
                        break
                    finally:
                        del item
            except StopIteration:
                return
            except Exception as e:
                if self.closed:
                    return
                # Forward the error to the consumer side.
                self.__channel.put((e, True))
            finally:
                if gen is not None and inspect.isgenerator(gen):
                    with suppress(Exception):
                        gen.close()
                self.__close_event.set()

    def close(self) -> Awaitable[None]:
        """Close the channel and return an awaitable for worker shutdown.

        Returns:
            A future wrapping :meth:`wait_closed` when it can be scheduled
            from the current context, otherwise an immediately-completing
            awaitable.
        """
        self.__channel.close()
        coro = self.wait_closed()
        try:
            return asyncio.ensure_future(coro)
        except RuntimeError:
            # GC finalizer may call close() from a non-event-loop thread
            # where ensure_future fails. Close the abandoned coroutine and
            # fall back to scheduling cleanup via call_soon_threadsafe.
            coro.close()
            try:
                self.loop.call_soon_threadsafe(
                    lambda: asyncio.ensure_future(self.wait_closed()),
                )
            except (RuntimeError, AttributeError):
                pass
            return _ImmediateAwaitable()

    async def wait_closed(self) -> None:
        """Wait until the worker thread has finished.

        Returns immediately if iteration was never started.
        """
        if self.__gen_task is None:
            return
        loop = asyncio.get_running_loop()
        # The close event is a threading.Event, so the blocking wait is
        # pushed to the loop's default executor.
        await loop.run_in_executor(None, self.__close_event.wait)
        await asyncio.gather(self.__gen_task, return_exceptions=True)

    def _run(self) -> Any:
        """Start the worker in the executor and return its future."""
        return self.loop.run_in_executor(self.executor, self._in_thread)

    def __aiter__(self) -> AsyncIterator[T]:
        """Start the worker on first use and return an iterator proxy.

        Raises:
            RuntimeError: If the event loop is not running or the worker
                task could not be created.
        """
        if not self.loop.is_running():
            raise RuntimeError("Event loop is not running")

        if self.__gen_task is None:
            gen_task = self._run()
            if gen_task is None:
                raise RuntimeError("Iterator task was not created")
            self.__gen_task = gen_task

        # The proxy closes this wrapper when it is garbage-collected.
        return IteratorProxy(self, self.close)

    async def __anext__(self) -> T:
        """Return the next item, or re-raise the generator's exception.

        Raises:
            StopAsyncIteration: When the channel is closed and drained.
        """
        try:
            item, is_exc = await self.__channel.get()
        except ChannelClosed:
            await self.wait_closed()
            raise StopAsyncIteration

        if is_exc:
            await self.close()
            # NOTE(review): the exception is chained to itself here; kept
            # exactly as in the original code.
            raise item from item

        return item

    async def __aenter__(self) -> "IteratorWrapper":
        return self

    async def __aexit__(
        self,
        exc_type: Any,
        exc_val: Any,
        exc_tb: Any,
    ) -> None:
        # Nothing to do if the channel has already been closed.
        if self.closed:
            return
        await self.close()
class IteratorProxy(Generic[T], AsyncIterator):
    """Thin async-iterator facade that triggers a finalizer when collected.

    Every ``__anext__`` call is forwarded to the wrapped iterator.  The
    *finalizer* callable is registered with ``weakref.finalize`` on the
    proxy itself, so it is invoked once the proxy is garbage-collected.

    Args:
        iterator: Async iterator that actually produces the items.
        finalizer: Zero-argument callable to run on finalization.
    """

    def __init__(
        self,
        iterator: AsyncIterator[T],
        finalizer: Callable[[], Any],
    ):
        finalize(self, finalizer)
        self.__iterator = iterator

    def __anext__(self) -> Awaitable[T]:
        # Hand back the inner awaitable untouched; the caller awaits it.
        pending = self.__iterator.__anext__()
        return pending