import asyncio
import hashlib
import hmac
import os
import socket
import sys
from inspect import Traceback
from multiprocessing import ProcessError
from os import chmod, urandom
from subprocess import PIPE, Popen
from tempfile import mktemp
from types import MappingProxyType, TracebackType
from typing import (
    Any, Callable, Coroutine, Dict, Mapping, Optional, Set, Tuple, Type,
)

from aiomisc.counters import Statistic
from aiomisc.thread_pool import threaded
from aiomisc.utils import (
    bind_socket, cancel_tasks, fast_uuid4, set_exception, shield,
)
from aiomisc_log import LOG_FORMAT, LOG_LEVEL

from aiomisc_worker import (
    COOKIE_SIZE, INET_AF, INT_SIGNAL, SIGNAL, AddressType, PacketTypes, T, log,
)
from aiomisc_worker.protocol import AsyncProtocol, FileIOProtocol
class WorkerPoolStatistic(Statistic):
    """Counters describing the state of a ``WorkerPool``.

    Each annotated attribute below is presumably turned into a counter by
    ``aiomisc.counters.Statistic`` (confirm there). As far as this file
    shows, ``WorkerPool`` updates only ``processes``, ``spawning`` and
    ``task_added``; the remaining counters are declared but not touched by
    it.
    """

    # Worker connections currently authenticated and being served
    # (incremented after AUTH_OK, decremented when the connection ends).
    processes: int
    # Incremented on every successful worker authentication; nothing in
    # this file decrements it.
    spawning: int
    queue_size: int
    submitted: int
    sum_time: float
    done: int
    success: int
    error: int
    bad_auth: int
    # Number of asyncio tasks the pool has created and tracked.
    task_added: int
class WorkerPool:
    """Run callables in a pool of separate worker processes.

    ``start()`` binds a listening socket, generates ``workers`` random
    worker ids and spawns one supervisor process (``python -m
    aiomisc_worker``) that is given the socket address, a random per-pool
    cookie and the ids over its stdin. Every incoming connection must
    authenticate with ``HMAC-SHA256(cookie, worker_id)`` for a known id;
    authenticated connections then take jobs from :attr:`tasks` one at a
    time. ``create_task()`` enqueues a job and awaits its result.

    Use as ``async with WorkerPool(n) as pool: ...`` or call ``start()`` and
    ``close()`` explicitly.
    """

    tasks: asyncio.Queue
    server: asyncio.AbstractServer
    address: AddressType
    initializer: Optional[Callable[[], Any]]
    initializer_args: Tuple[Any, ...]
    initializer_kwargs: Mapping[str, Any]
    _supervisor: Popen
    worker_ids: Tuple[bytes, ...]
    pids: Set[int]

    # Seconds close() waits for the listening server to finish closing.
    SERVER_CLOSE_TIMEOUT = 1

    if hasattr(socket, "AF_UNIX"):
        def _create_socket(self) -> None:
            """Bind a UNIX stream socket in the temp dir; store its path."""
            # NOTE(review): ``mktemp`` only picks a free name and the 0o600
            # mode is applied after bind, so there is a short window in
            # which another local user could race for or connect to the
            # path. A private ``mkdtemp()`` directory would close it.
            path = mktemp(suffix=".sock", prefix="worker-")
            self.socket = bind_socket(
                socket.AF_UNIX,
                socket.SOCK_STREAM,
                address=path,
            )
            self.address = path
            chmod(path, 0o600)
    else:
        def _create_socket(self) -> None:
            """Bind a TCP socket on localhost (no UNIX sockets available)."""
            self.socket = bind_socket(
                INET_AF,
                socket.SOCK_STREAM,
                address="localhost",
                reuse_addr=False,
                reuse_port=False,
            )
            # (host, port) of the bound socket, sent to the supervisor.
            self.address = self.socket.getsockname()[:2]

    @staticmethod
    def _kill_process(process: Popen) -> None:
        """Kill *process* via ``Popen.kill()`` unless it already exited."""
        # ``returncode`` is only refreshed by poll()/wait(); poll so the
        # check reflects the process's current state.
        if process.poll() is not None:
            return None
        log.debug("Terminating worker pool process PID: %s", process.pid)
        process.kill()

    @threaded
    def __create_supervisor(self, *identity: bytes) -> Popen:
        """Spawn the supervisor process and send it its configuration.

        The child's stdin receives, in order: ``(log_level, log_format)``,
        the socket address, the auth cookie, the worker ids in *identity*
        and ``(initializer, initializer_args, initializer_kwargs)``; stdin
        is then closed. Decorated with ``@threaded`` and awaited by
        ``start()``, so these blocking calls presumably run off the event
        loop thread.

        :raises RuntimeError: if the pool is already closing.
        """
        if self.__closing:
            raise RuntimeError("Pool closed")

        env = dict(os.environ)
        # Presumably disables aiomisc plugin autoloading in the child.
        env["AIOMISC_NO_PLUGINS"] = ""
        process = Popen(
            [sys.executable, "-m", "aiomisc_worker"], stdin=PIPE, env=env,
        )
        assert process.stdin

        log_level = (
            log.getEffectiveLevel()
            if LOG_LEVEL is None
            else LOG_LEVEL.get()
        )
        log_format = "color" if LOG_FORMAT is None else LOG_FORMAT.get()

        try:
            proto_stdin = FileIOProtocol(process.stdin)
            proto_stdin.send((log_level, log_format))
            proto_stdin.send(self.address)
            proto_stdin.send(self.__cookie)
            proto_stdin.send(identity)
            proto_stdin.send((
                self.initializer,
                self.initializer_args,
                dict(self.initializer_kwargs),
            ))
            process.stdin.close()
        except BaseException:
            # The configuration could not be delivered (e.g. the child died
            # early or the initializer could not be serialized): do not
            # leave an unconfigured child process running.
            self._kill_process(process)
            raise
        return process

    def __init__(
        self, workers: int, max_overflow: int = 0, *,
        initializer: Optional[Callable[..., Any]] = None,
        initializer_args: Tuple[Any, ...] = (),
        initializer_kwargs: Mapping[str, Any] = MappingProxyType({}),
        statistic_name: Optional[str] = None,
    ):
        """Configure the pool and bind its listening socket.

        :param workers: number of worker ids generated by ``start()``.
        :param max_overflow: ``maxsize`` of the pending-job queue created in
            ``start()``; ``0`` (the default) means unbounded.
        :param initializer: optional callable sent, with its arguments, to
            the supervisor at spawn time (presumably run in each worker —
            confirm in ``aiomisc_worker``).
        :param initializer_args: positional arguments for *initializer*.
        :param initializer_kwargs: keyword arguments for *initializer*.
        :param statistic_name: name of this pool's ``WorkerPoolStatistic``.
        """
        self._create_socket()
        self.__cookie = urandom(COOKIE_SIZE)
        self.__loop: Optional[asyncio.AbstractEventLoop] = None
        # Futures handed out by create_task(); cancelled by close().
        self.__futures: Set[asyncio.Future] = set()
        # Connection-handler tasks; cancelled by close().
        self.__task_store: Set[asyncio.Task] = set()
        self.__closing = False
        self.__closing_lock = asyncio.Lock()
        self._statistic = WorkerPoolStatistic(name=statistic_name)
        self.__max_overflow = max_overflow
        self.workers = workers
        self.pids = set()
        self.initializer = initializer
        self.initializer_args = initializer_args
        self.initializer_kwargs = initializer_kwargs

    @property
    def loop(self) -> asyncio.AbstractEventLoop:
        """Event loop, captured from the running loop on first access."""
        if self.__loop is None:
            self.__loop = asyncio.get_running_loop()
        return self.__loop

    @staticmethod
    def __reject_connection(writer: asyncio.StreamWriter) -> None:
        """Half-close *writer* when the transport supports it, then close."""
        if writer.can_write_eof():
            writer.write_eof()
        writer.close()

    async def __handle_client(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter,
    ) -> None:
        """Serve one connection: authenticate it, then feed it jobs.

        The first packet is ``(packet_type, worker_id, digest, pid)``.

        * ``BAD_INITIALIZER`` is followed by an exception that is set on
          every pending future, after which the pool is closed.
        * ``AUTH`` is accepted only for a known worker id whose digest
          equals ``HMAC-SHA256(cookie, worker_id)``; anything else is
          rejected and the connection closed.

        After ``AUTH_OK`` the handler loops until EOF: take a job from
        :attr:`tasks`, publish this worker's pid via the job's
        ``process_future``, send ``REQUEST`` and resolve ``result_future``
        from the reply.
        """
        proto = AsyncProtocol(reader, writer)
        # NOTE(review): this packet is decoded before the peer has
        # authenticated. AsyncProtocol transports callables, so it is
        # presumably pickle-based; if so, any local user able to reach the
        # socket can execute code here regardless of the HMAC check below.
        packet_type, worker_id, digest, pid = await proto.receive()

        async with self.__closing_lock:
            if self.__closing:
                # Shutting down: drop the connection instead of registering
                # a worker that would then wait on the task queue.
                proto.close()
                return

        if packet_type == PacketTypes.BAD_INITIALIZER:
            # Presumably the pool initializer failed in the worker; the
            # follow-up packet carries the exception.
            packet_type, exc = await proto.receive()
            if packet_type != PacketTypes.EXCEPTION:
                await proto.send(PacketTypes.BAD_PACKET)
            else:
                set_exception(self.__futures, exc)
            await self.close()
            return

        if packet_type != PacketTypes.AUTH:
            await proto.send(PacketTypes.BAD_PACKET)
            self.__reject_connection(writer)
            return

        if worker_id not in self.worker_ids:
            log.error("Unknown worker with id %r", worker_id)
            self.__reject_connection(writer)
            return

        expected_digest = hmac.HMAC(
            self.__cookie,
            worker_id,
            digestmod=hashlib.sha256,
        ).digest()
        # Constant-time comparison so response timing does not reveal how
        # much of the digest matched; ``digest`` comes from the peer and is
        # not guaranteed to be bytes.
        if not (
            isinstance(digest, bytes)
            and hmac.compare_digest(expected_digest, digest)
        ):
            await proto.send(PacketTypes.AUTH_FAIL)
            self.__reject_connection(writer)
            # The expected digest is a valid credential: never log it.
            log.debug("Bad digest %r for worker %r", digest, worker_id)
            return

        await proto.send(PacketTypes.AUTH_OK)
        self._statistic.processes += 1
        self._statistic.spawning += 1
        self.pids.add(pid)

        try:
            while not reader.at_eof():
                func: Callable
                args: Tuple[Any, ...]
                kwargs: Dict[str, Any]
                result_future: asyncio.Future
                process_future: asyncio.Future

                (
                    func, args, kwargs, result_future, process_future,
                ) = await self.tasks.get()

                # Skip jobs whose caller already gave up, or that were
                # failed by close() or a bad initializer.
                if process_future.done() or result_future.done():
                    continue

                try:
                    # Tell create_task() which process runs the job so it
                    # can signal that pid if it is cancelled.
                    process_future.set_result(pid)
                    await proto.send((PacketTypes.REQUEST, func, args, kwargs))
                    packet_type, payload = await proto.receive()

                    if result_future.done():
                        log.debug(
                            "Result future %r already done, skipping",
                            result_future,
                        )
                        continue

                    if packet_type == PacketTypes.RESULT:
                        result_future.set_result(payload)
                    elif packet_type in (
                        PacketTypes.EXCEPTION, PacketTypes.CANCELLED,
                    ):
                        result_future.set_exception(payload)

                    # Drop references so the payload is not kept alive while
                    # waiting for the next job.
                    del packet_type, payload
                except (asyncio.IncompleteReadError, ConnectionError):
                    if not result_future.done():
                        result_future.set_exception(
                            ProcessError(f"Process {pid!r} unexpected exited"),
                        )
                    break
                except Exception as e:
                    if not result_future.done():
                        result_future.set_exception(e)
                    if not writer.is_closing():
                        self.__reject_connection(writer)
                    raise
        finally:
            self._statistic.processes -= 1
            self.pids.discard(pid)
            # Release the transport however the loop ended (EOF, lost
            # connection, or cancellation by close()).
            if not writer.is_closing():
                writer.close()

    def __start_handler(
        self, reader: asyncio.StreamReader, writer: asyncio.StreamWriter,
    ) -> asyncio.Task:
        """Server callback: run ``__handle_client`` as a tracked task."""
        return self.__task(self.__handle_client(reader, writer))

    def __task_add(self, task: asyncio.Task) -> None:
        """Track *task* until it finishes so close() can cancel it."""
        self._statistic.task_added += 1
        task.add_done_callback(self.__task_store.discard)
        self.__task_store.add(task)

    def __task(self, coroutine: Coroutine) -> asyncio.Task:
        """Schedule *coroutine* on the pool's loop as a tracked task."""
        task = self.loop.create_task(coroutine)
        self.__task_add(task)
        return task

    async def start(self) -> None:
        """Start listening, generate worker ids and spawn the supervisor."""
        self.tasks = asyncio.Queue(maxsize=self.__max_overflow)
        self.server = await asyncio.start_server(
            self.__start_handler,
            sock=self.socket,
        )
        # The server owns the listening socket from now on.
        del self.socket
        self.worker_ids = tuple(
            fast_uuid4().bytes for _ in range(self.workers)
        )
        self._supervisor = await self.__create_supervisor(*self.worker_ids)

    def __create_future(self) -> asyncio.Future:
        """Create a future tracked in ``__futures`` until it is done."""
        future = self.loop.create_future()
        self.__futures.add(future)
        future.add_done_callback(self.__futures.discard)
        return future

    def __reject_futures(self) -> None:
        """Fail every pending future with ``RuntimeError("Pool closed")``."""
        set_exception(self.__futures, RuntimeError("Pool closed"))

    @shield
    async def close(self) -> None:
        """Shut the pool down; calling it again is a no-op.

        Signals the supervisor, stops the server (waiting at most
        ``SERVER_CLOSE_TIMEOUT`` seconds for it to close), then cancels all
        connection-handler tasks and all pending futures. Safe to call
        before ``start()``.
        """
        async with self.__closing_lock:
            if self.__closing:
                return

            self._kill_supervisor()
            self.__closing = True

            # ``server`` only exists once start() has run.
            server: Optional[asyncio.AbstractServer] = getattr(
                self, "server", None,
            )
            if server is None:
                # start() never ran, so the pre-bound socket is still ours.
                sock: Optional[socket.socket] = getattr(self, "socket", None)
                if sock is not None:
                    sock.close()
            else:
                server.close()
                await asyncio.gather(
                    asyncio.wait_for(
                        server.wait_closed(),
                        timeout=self.SERVER_CLOSE_TIMEOUT,
                    ),
                    return_exceptions=True,
                )

            await cancel_tasks(tuple(self.__task_store))
            await cancel_tasks(tuple(self.__futures))

    def _kill_supervisor(self) -> None:
        """Send ``INT_SIGNAL`` to the supervisor if it is still running."""
        supervisor: Optional[Popen] = getattr(self, "_supervisor", None)
        if supervisor is None or supervisor.poll() is not None:
            return
        log.debug(
            "Sending %r to supervisor process PID: %d. Workers: %r",
            INT_SIGNAL, supervisor.pid, self.pids,
        )
        try:
            os.kill(supervisor.pid, INT_SIGNAL)
        except ProcessLookupError:
            # Gone between poll() and kill() (e.g. reaped elsewhere).
            pass

    def __del__(self) -> None:
        """Best effort: signal the supervisor if the pool is collected."""
        self._kill_supervisor()

    async def create_task(
        self, func: Callable[..., T],
        *args: Any, **kwargs: Any,
    ) -> T:
        """Run ``func(*args, **kwargs)`` in a worker and return its result.

        Waits to enqueue the job (bounded by ``max_overflow``), then for a
        worker to pick it up. An exception reported by the worker is raised
        here. If this coroutine is cancelled while the job runs, the worker
        process is sent ``SIGNAL``; if it is cancelled before a worker took
        the job, the job is dropped.

        :raises ProcessError: if the worker connection is lost mid-job.
        """
        result_future = self.__create_future()
        process_future = self.__create_future()

        try:
            await self.tasks.put((
                func, args, kwargs, result_future, process_future,
            ))
            pid: int = await process_future
        except asyncio.CancelledError:
            # Cancelled while enqueueing or waiting for a worker: finish
            # both futures so the job is skipped when dequeued and neither
            # lingers in the tracking set until close().
            process_future.cancel()
            result_future.cancel()
            raise

        try:
            return await result_future
        except asyncio.CancelledError:
            log.debug("Sending %r to worker PID: %d", SIGNAL, pid)
            os.kill(pid, SIGNAL)
            raise

    async def __aenter__(self) -> "WorkerPool":
        """Start the pool on entering an ``async with`` block."""
        await self.start()
        return self

    async def __aexit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_val: Optional[BaseException],
        exc_tb: Optional[TracebackType],
    ) -> None:
        """Close the pool on leaving an ``async with`` block."""
        await self.close()