Source code for aiomisc.circuit_breaker

import asyncio
import logging
import threading
from collections import Counter, deque
from contextlib import contextmanager
from enum import IntEnum, unique
from functools import wraps
from random import random
from typing import Any, Awaitable, Callable
from typing import Counter as CounterType
from typing import (
    Deque, Generator, Iterable, Optional, Tuple, Type, TypeVar, Union,
)

from aiomisc.compat import EventLoopMixin
from aiomisc.counters import Statistic
from aiomisc.utils import awaitable


log = logging.getLogger(__name__)
Number = Union[int, float]
StatisticType = Deque[Tuple[int, Counter]]
T = TypeVar("T")
ExceptionInspectorType = Optional[Callable[[Exception], bool]]


[docs]@unique class CounterKey(IntEnum): FAIL = 0 OK = 1 TOTAL = 2
[docs]@unique class CircuitBreakerStates(IntEnum): PASSING = 0 BROKEN = 1 RECOVERING = 2
[docs]class CircuitBreakerStatistic(Statistic): call_count: int error_ratio: float error_ratio_threshold: float call_passing: int call_broken: int call_recovering: int call_recovering_ok: int call_recovering_failed: int
[docs]class CircuitBroken(Exception): __slots__ = ("last_exception",) def __init__(self, last_exception: Optional[Exception]): self.last_exception = last_exception def __repr__(self) -> str: return f"<CircuitBroken: {self.last_exception!r}>"
[docs]class CircuitBreaker(EventLoopMixin): __slots__ = ( "_broken_time", "_counters", "_error_ratio", "_exceptions", "_exception_inspector", "_last_exception", "_lock", "_passing_time", "_recovery_at", "_recovery_ratio", "_recovery_time", "_response_time", "_state", "_statistic", "_stuck_until", ) + EventLoopMixin.__slots__ BUCKET_COUNT = 10 # Thresholds when state will be changed # * RECOVER state will be changed to BROKEN # when error ratio will be greater or equal then # RECOVER_BROKEN_THRESHOLD RECOVER_BROKEN_THRESHOLD = 0.5 # * PASSING state will be changed to BROKEN # when error ratio will be greater or equal then # PASSING_BROKEN_THRESHOLD PASSING_BROKEN_THRESHOLD = 1 _stuck_until: Number _recovery_at: Number _last_exception: Optional[Exception] def __init__( self, error_ratio: float, response_time: Number, exceptions: Iterable[Type[Exception]] = (Exception,), recovery_time: Optional[Number] = None, broken_time: Optional[Number] = None, passing_time: Optional[Number] = None, exception_inspector: Optional[ExceptionInspectorType] = None, statistic_name: Optional[str] = None, ): """ Circuit Breaker pattern implementation. The class instance collects call statistics through the ``call`` or ``call async`` methods. The state machine has three states: * ``CircuitBreakerStates.PASSING`` * ``CircuitBreakerStates.BROKEN`` * ``CircuitBreakerStates.RECOVERING`` In passing mode all results or exceptions will be returned as is. Statistic collects for each call. In broken mode returns exception ``CircuitBroken`` for each call. Statistic doesn't collecting. In recovering mode the part of calls is real function calls and remainings raises ``CircuitBroken``. The count of real calls grows exponentially in this case but when 20% (by default) will be failed the state returns to broken state. :param error_ratio: Failed to success calls ratio. The state might be changed if ratio will reach given value within ``response time`` (in seconds). Value between 0.0 and 1.0. :param response_time: Time window to collect statistics (seconds) :param exceptions: Only this exceptions will affect ratio. Base class ``Exception`` used by default. :param recovery_time: minimal time in recovery state (seconds) :param broken_time: minimal time in broken state (seconds) :param passing_time: minimum time in passing state (seconds) """ if response_time <= 0: raise ValueError("Response time must be greater then zero") if 0. > error_ratio >= 1.: raise ValueError( "Error ratio must be between 0 and 1 not %r" % error_ratio, ) self._statistic: StatisticType = deque( maxlen=self.BUCKET_COUNT, ) self._lock = threading.RLock() self._error_ratio = error_ratio self._state = CircuitBreakerStates.PASSING self._response_time = response_time self._stuck_until = 0 self._recovery_at = 0 self._exceptions = tuple(frozenset(exceptions)) self._exception_inspector = exception_inspector self._passing_time = passing_time or self._response_time self._broken_time = broken_time or self._response_time self._recovery_time = recovery_time or self._response_time self._last_exception = None self._counters = CircuitBreakerStatistic(statistic_name) self._counters.error_ratio_threshold = error_ratio @property def response_time(self) -> Number: return self._response_time @property def state(self) -> CircuitBreakerStates: return self._state def _get_time(self) -> float: return self.loop.time()
[docs] def bucket(self) -> int: ts = self._get_time() * self.BUCKET_COUNT return int(ts - (ts % self._response_time))
[docs] def counter(self) -> Counter: with self._lock: current = self.bucket() if not self._statistic: # Empty statistic just return a new counter counter: CounterType[int] = Counter() self._statistic.append((current, counter)) return counter bucket, counter = self._statistic[-1] if current != bucket: # Append Counter to statistic or shift when maxsize reached counter = Counter() self._statistic.append((current, counter)) return counter
def __gen_statistic(self) -> Generator[Counter, None, None]: """ Generator which returns only buckets Counters not before current_time """ not_before = self.bucket() - (self._response_time * self.BUCKET_COUNT) for idx in range(len(self._statistic) - 1, -1, -1): bucket, counter = self._statistic[idx] if bucket < not_before: break yield counter
[docs] def get_state_delay(self) -> Number: delay = self._stuck_until - self._get_time() return max(delay, 0)
def _inspect_exception(self, e: Exception) -> int: if not self._exception_inspector: return 1 # noinspection PyBroadException try: return 1 if self._exception_inspector(e) else 0 except Exception: log.exception( "Unhandled exception in %r", self._exception_inspector, ) return 1 def _on_passing( self, counter: CounterType[int], ) -> Generator[Any, Any, Any]: try: yield counter[CounterKey.OK] += 1 self._last_exception = None except self._exceptions as e: self._last_exception = e counter[CounterKey.FAIL] += self._inspect_exception(e) raise finally: counter[CounterKey.TOTAL] += 1 def _on_recover( self, counter: CounterType[int], ) -> Generator[Any, Any, Any]: current_time = self._get_time() condition = (random() + 1) < ( 2 ** ((current_time - self._recovery_at) / self._recovery_time) ) if not condition: self._counters.call_recovering_failed += 1 raise CircuitBroken(self._last_exception) self._counters.call_recovering_ok += 1 yield from self._on_passing(counter) @property def recovery_ratio(self) -> Number: total_count = 0 upper_count = 0 for counter in self.__gen_statistic(): total_count += 1 if not counter[CounterKey.TOTAL]: continue fail_ratio = counter[CounterKey.FAIL] / counter[CounterKey.TOTAL] if fail_ratio >= self._error_ratio: upper_count += 1 if not total_count: return 0 return upper_count / total_count def _compute_state(self) -> None: current_time = self._get_time() if current_time < self._stuck_until: # Skip state changing until return if self._state is CircuitBreakerStates.BROKEN: self._state = CircuitBreakerStates.RECOVERING self._recovery_at = current_time return # Do not compute when not enough statistic if ( self._state is CircuitBreakerStates.PASSING and len(self._statistic) < self.BUCKET_COUNT ): return recovery_ratio = self.recovery_ratio if self._state is CircuitBreakerStates.PASSING: if recovery_ratio >= self.PASSING_BROKEN_THRESHOLD: self._stuck_until = current_time + self._broken_time self._state = CircuitBreakerStates.BROKEN self._statistic.clear() return if self._state is not CircuitBreakerStates.RECOVERING: return if recovery_ratio >= ( self.RECOVER_BROKEN_THRESHOLD * self._error_ratio ): self._stuck_until = current_time + self._broken_time self._state = CircuitBreakerStates.BROKEN self._statistic.clear() return recovery_length = current_time - self._recovery_at if recovery_length >= self._recovery_time: self._stuck_until = current_time + self._passing_time self._state = CircuitBreakerStates.PASSING return
[docs] @contextmanager def context(self) -> Generator[Any, Any, Any]: counter = self.counter() self._compute_state() self._counters.call_count += 1 if self._state is CircuitBreakerStates.PASSING: self._counters.call_passing += 1 yield from self._on_passing(counter) return elif self._state is CircuitBreakerStates.BROKEN: self._counters.call_broken += 1 raise CircuitBroken(self._last_exception) elif self._state is CircuitBreakerStates.RECOVERING: self._counters.call_recovering += 1 yield from self._on_recover(counter) return raise NotImplementedError(self._state)
[docs] def call( self, func: Callable[..., T], *args: Any, **kwargs: Any, ) -> T: with self.context(): return func(*args, **kwargs)
[docs] async def call_async( self, func: Callable[..., Awaitable[T]], *args: Any, **kwargs: Any, ) -> T: with self.context(): return await awaitable(func)(*args, **kwargs) # type: ignore
def __repr__(self) -> str: return "<{}: state={!r} recovery_ratio={!s}>".format( self.__class__.__name__, self._state, self.recovery_ratio, )
CutoutFuncType = Union[Callable[..., T], Callable[..., Awaitable[T]]] CutoutDecoratorReturnType = Callable[..., Union[T, Awaitable[T]]] CutoutReturnType = Callable[[CutoutFuncType], CutoutDecoratorReturnType]
[docs]def cutout( ratio: float, response_time: Union[int, float], *exceptions: Type[Exception], **kwargs: Any, ) -> CutoutReturnType: circuit_breaker = CircuitBreaker( error_ratio=ratio, response_time=response_time, exceptions=exceptions, **kwargs, ) def decorator(func: CutoutFuncType) -> CutoutDecoratorReturnType: @wraps(func) async def async_wrapper( *args: Any, **kw: Any, ) -> T: return await circuit_breaker.call_async(func, *args, **kw) @wraps(func) def wrapper(*args: Any, **kw: Any) -> Any: return circuit_breaker.call(func, *args, **kw) if asyncio.iscoroutinefunction(func): return async_wrapper return wrapper return decorator