"""Connection pools for memcachio (source code for ``memcachio.pool``)."""

from __future__ import annotations

import asyncio
import dataclasses
import enum
import logging
import time
import weakref
from abc import ABC, abstractmethod
from collections import defaultdict
from collections.abc import Callable, Iterable, Sequence
from contextlib import closing, suppress
from typing import Any, Literal, TypeVar, cast

from .commands import AWSAutoDiscoveryConfig, Command
from .compat import Unpack, asyncio_timeout
from .connection import (
    BaseConnection,
    ConnectionParams,
    TCPConnection,
    UnixSocketConnection,
)
from .defaults import (
    BLOCKING_TIMEOUT,
    CONNECT_TIMEOUT,
    IDLE_CONNECTION_TIMEOUT,
    MAX_CONNECTIONS,
    MAXIMUM_ERROR_COUNT_FOR_ENDPOINT_REMOVAL,
    MAXIMUM_RECOVERY_ATTEMPTS,
    MIN_CONNECTIONS,
    MONITOR_UNHEALTHY_ENDPOINTS,
    READ_TIMEOUT,
    REMOVE_UNHEALTHY_ENDPOINTS,
    RETRY_BACKOFF_POLICY,
)
from .errors import ConnectionNotAvailable, MemcachioConnectionError
from .routing import KeyRouter
from .types import (
    AWSAutoDiscoveryEndpoint,
    MemcachedEndpoint,
    SingleMemcachedInstanceEndpoint,
    UnixSocketEndpoint,
    normalize_endpoint,
    normalize_single_server_endpoint,
)

# Logger shared by every pool class defined in this module.
logger: logging.Logger = logging.getLogger(__name__)

#: Generic result type parameter used in ``Command[R]`` signatures below.
R = TypeVar("R")


[docs] class EndpointStatus(enum.IntEnum): """ Enumeration of endpoint statuses. Used by :meth:`~ClusterPool.update_endpoint_status` """ #: Mark the endpoint as up and usable UP = enum.auto() #: Mark the endpoint as down and not in use DOWN = enum.auto()
@dataclasses.dataclass
class EndpointHealthcheckConfig:
    """
    Settings that control how a :class:`ClusterPool` reacts to endpoints
    that accumulate connection errors.
    """

    #: Whether to remove unhealthy endpoints on connection errors
    remove_unhealthy_endpoints: bool = dataclasses.field(default=REMOVE_UNHEALTHY_ENDPOINTS)
    #: Maximum numbers of errors to tolerate before marking an endpoint
    #: as unhealthy
    maximum_error_count_for_removal: int = dataclasses.field(
        default=MAXIMUM_ERROR_COUNT_FOR_ENDPOINT_REMOVAL
    )
    #: Whether to monitor unhealthy endpoints after they have been
    #: removed and attempt to restore them if they recover
    monitor_unhealthy_endpoints: bool = dataclasses.field(default=MONITOR_UNHEALTHY_ENDPOINTS)
    #: Maximum attempts to make to recover unhealthy endpoints
    maximum_recovery_attempts: int = dataclasses.field(default=MAXIMUM_RECOVERY_ATTEMPTS)
    #: Retry backoff policy
    retry_backoff_policy: Literal["linear", "exponential"] = dataclasses.field(
        default=RETRY_BACKOFF_POLICY
    )
[docs] @dataclasses.dataclass class PoolMetrics: """ Tracks metrics for a connection pool. """ #: Timestamp when the pool was initialized. created_at: float | None = None #: Total number of successfully processed requests. requests_processed: int = 0 #: Total number of requests that failed. requests_failed: int = 0 #: Timestamp when the last connection was established last_connection_created: float = 0.0 #: Timestamp when the last request completed processing. last_request_processed: float = 0.0 #: Average time taken to process requests. average_response_time: float = 0.0 #: Number of requests currently pending. requests_pending: int = 0 #: Current connections current_connections: int = 0 #: Maximum connections created maximum_connections: int = 0 #: Total number of connection errors connection_errors: int = 0 #: Current endpoint status status: EndpointStatus | None = None #: Number of times the pool was marked down down_count: int = 0 #: Number of times the node was switched from down to up recovery_count: int = 0 def on_connection_created(self, connection: BaseConnection) -> None: self.last_connection_created = time.time() self.current_connections += 1 self.maximum_connections = max(self.maximum_connections, self.current_connections) if not self.status: self.status = EndpointStatus.UP def on_connection_error(self, connection: BaseConnection | None, exc: Exception) -> None: self.connection_errors += 1 if not self.status: self.status = EndpointStatus.UP def on_connection_terminated(self, connection: BaseConnection) -> None: self.current_connections -= 1 def on_command_dispatched(self, command: Command[Any]) -> None: self.requests_pending += 1 pass def on_status_update(self, status: EndpointStatus) -> None: if self.status != status: match status: case EndpointStatus.DOWN: self.down_count += 1 case EndpointStatus.UP: self.recovery_count += 1 self.status = status def on_command_completed(self, command: Command[Any]) -> None: self.requests_pending -= 1 
self.last_request_processed = time.time() if not command.noreply: if command.response.done() and not command.response.cancelled(): if not command.response.exception(): self.average_response_time = ( (self.requests_processed * self.average_response_time) + command.response_time ) / (self.requests_processed + 1) self.requests_processed += 1 else: self.requests_failed += 1 elif command.request_sent.done(): self.requests_processed += 1 @classmethod def merge(cls, metrics: Iterable[PoolMetrics]) -> PoolMetrics: if not metrics: return PoolMetrics() return PoolMetrics( created_at=min(metrics, key=lambda m: m.created_at or 0).created_at, requests_processed=sum([m.requests_processed for m in metrics]), requests_failed=sum([m.requests_failed for m in metrics]), requests_pending=sum([m.requests_pending for m in metrics]), last_connection_created=max( metrics, key=lambda m: m.last_connection_created ).last_connection_created, last_request_processed=max( metrics, key=lambda m: m.last_request_processed ).last_request_processed, average_response_time=sum( [m.requests_processed * m.average_response_time for m in metrics] ) / (sum([m.requests_processed for m in metrics]) or 1), current_connections=sum([m.current_connections for m in metrics]), maximum_connections=sum([m.maximum_connections for m in metrics]), connection_errors=sum([m.connection_errors for m in metrics]), recovery_count=sum([m.recovery_count for m in metrics]), down_count=sum([m.down_count for m in metrics]), status=EndpointStatus.UP if all([m.status == EndpointStatus.UP for m in metrics]) else EndpointStatus.DOWN, )
class Pool(ABC):
    """
    The abstract base class for a connection pool used by :class:`~memcachio.Client`
    """

    def __init__(
        self,
        endpoint: MemcachedEndpoint,
        min_connections: int = MIN_CONNECTIONS,
        max_connections: int = MAX_CONNECTIONS,
        blocking_timeout: float = BLOCKING_TIMEOUT,
        idle_connection_timeout: float = IDLE_CONNECTION_TIMEOUT,
        **connection_args: Unpack[ConnectionParams],
    ):
        """
        :param endpoint: The memcached server address(es)
        :param min_connections: The minimum number of connections to keep in the pool.
        :param max_connections: The maximum number of simultaneous connections to memcached.
        :param blocking_timeout: The timeout (in seconds) to wait for a connection to become
         available.
        :param idle_connection_timeout: The maximum time to allow a connection to remain idle in
         the pool before being disconnected
        :param connection_args: Arguments to pass to the constructor of
         :class:`~memcachio.BaseConnection`. refer to :class:`~memcachio.connection.ConnectionParams`
        """
        self.endpoint = normalize_endpoint(endpoint)
        self._max_connections = max_connections
        self._min_connections = min_connections
        self._blocking_timeout = blocking_timeout
        self._idle_connection_timeout = idle_connection_timeout
        self._connection_parameters: ConnectionParams = connection_args

    @property
    @abstractmethod
    def metrics(self) -> PoolMetrics:
        """
        Pool health metrics
        """
        ...

    @abstractmethod
    async def initialize(self) -> None:
        """
        Initialize the connection pool. The method can throw a connection
        error if the target server(s) can't be connected to.
        """
        ...

    @abstractmethod
    async def execute_command(self, command: Command[R]) -> None:
        """
        Dispatches a command to memcached. To receive the response the future
        pointed to by ``command.response`` should be awaited as it will be updated
        when the response (or exception) is available on the transport.
        """
        ...

    @abstractmethod
    def close(self) -> None:
        """
        Closes the connection pool and disconnects all active connections
        """
        ...

    def __del__(self) -> None:
        # Best-effort cleanup from the finalizer. RuntimeError was already
        # suppressed here (presumably for event-loop shutdown cases - unverified).
        # AttributeError covers subclasses whose __init__ raised before the state
        # that close() touches was assigned; it must not leak out of a finalizer.
        with suppress(RuntimeError, AttributeError):
            self.close()
class SingleServerPool(Pool):
    """
    Connection pool to manage connections to a single memcached server.

    Capacity is tracked with a LIFO queue of ``max_connections`` slots. Each
    slot holds either an established connection or ``None``, a placeholder
    that is turned into a new connection when it is checked out.
    """

    def __init__(
        self,
        endpoint: SingleMemcachedInstanceEndpoint,
        min_connections: int = MIN_CONNECTIONS,
        max_connections: int = MAX_CONNECTIONS,
        blocking_timeout: float = BLOCKING_TIMEOUT,
        idle_connection_timeout: float = IDLE_CONNECTION_TIMEOUT,
        **connection_args: Unpack[ConnectionParams],
    ) -> None:
        super().__init__(
            endpoint,
            min_connections=min_connections,
            max_connections=max_connections,
            blocking_timeout=blocking_timeout,
            idle_connection_timeout=idle_connection_timeout,
            **connection_args,
        )
        self.__server_endpoint = endpoint
        self.__connections: asyncio.Queue[BaseConnection | None] = asyncio.LifoQueue(
            self._max_connections
        )
        self.__pool_lock: asyncio.Lock = asyncio.Lock()
        self._connection_class: type[TCPConnection | UnixSocketConnection]
        self.__initialized = False
        # Build fresh callback lists instead of appending to the ones passed in.
        # Those lists belong to the caller, and when ClusterPool creates sub-pools
        # from one parameter dict they are shared by every sibling pool. Appending
        # in place would make each pool's callbacks fire for other pools' connections.
        self._connection_parameters["on_connect_callbacks"] = [
            *(self._connection_parameters.get("on_connect_callbacks") or []),
            self.__on_connection_created,
        ]
        self._connection_parameters["on_disconnect_callbacks"] = [
            *(self._connection_parameters.get("on_disconnect_callbacks") or []),
            self.__on_connection_disconnected,
        ]
        self._active_connections: weakref.WeakSet[BaseConnection] = weakref.WeakSet()
        self.__metrics = PoolMetrics()
        self.__fill_empty_slots()

    @property
    def metrics(self) -> PoolMetrics:
        return self.__metrics

    async def initialize(self) -> None:
        """
        Establish up to ``min_connections`` connections, once.

        :raises ConnectionError: if a connection cannot be established. All slots
         taken for initialization are returned to the pool before re-raising.
        """
        if self.__initialized:
            return
        async with self.__pool_lock:
            if self.__initialized:
                return
            if not self.metrics.created_at:
                self.metrics.created_at = time.time()
            # Check out every slot before filling any of them. The queue is LIFO, so
            # putting a new connection back and calling get_nowait() again would
            # return that same connection instead of the next empty slot.
            slots: list[BaseConnection | None] = []
            with suppress(asyncio.QueueEmpty):
                for _ in range(self._min_connections):
                    slots.append(self.__connections.get_nowait())
            try:
                for index, connection in enumerate(slots):
                    if not connection:
                        try:
                            slots[index] = await self.__create_connection()
                        except ConnectionError as e:
                            self.metrics.on_connection_error(connection, e)
                            raise
            finally:
                # Every slot goes back, whether newly connected, already connected
                # or still an empty placeholder.
                for connection in slots:
                    self.__release_slot(connection)
            self.__initialized = True

    async def execute_command(self, command: Command[R]) -> None:
        connection: BaseConnection | None = None
        release = False
        try:
            connection, release = await self.__get_connection_from_pool()
            await connection.connect()
            connection.create_request(command)
            self.metrics.on_command_dispatched(command)
            await command.request_sent
        except MemcachioConnectionError as e:
            self.metrics.on_connection_error(connection, e)
            if release:
                self.__release_slot(None)
            raise
        except BaseException:
            # Cancellation or any other unexpected failure while this call holds a
            # connection exclusively. Give the slot back and drop the half-used
            # connection, otherwise the pool would permanently lose capacity.
            if release:
                if connection:
                    connection.close()
                self.__release_slot(None)
            raise
        (command.request_sent if command.noreply else command.response).add_done_callback(
            lambda _: self.metrics.on_command_completed(command)
        )
        if release:
            if command.noreply:
                self.__release_slot(connection)
            else:
                command.response.add_done_callback(lambda _: self.__release_slot(connection))

    def close(self) -> None:
        # Drain the queue, closing every idle connection held in it.
        while True:
            try:
                connection = self.__connections.get_nowait()
            except asyncio.QueueEmpty:
                break
            if connection:
                connection.close()
        self.__fill_empty_slots()
        self.__initialized = False

    def __fill_empty_slots(self) -> None:
        """Top the queue up to its capacity with ``None`` placeholders."""
        # Bounded by maxsize: with a non-positive maxsize the queue is unbounded and
        # "put until QueueFull" would never terminate.
        for _ in range(self.__connections.maxsize - self.__connections.qsize()):
            self.__connections.put_nowait(None)

    def __release_slot(self, connection: BaseConnection | None) -> None:
        """
        Return a slot (a connection or ``None``) to the queue.

        If the queue is already full, because :meth:`close` refilled it while this
        connection was checked out, the connection is surplus. It is closed
        rather than letting :class:`asyncio.QueueFull` escape, possibly from a
        future done-callback.
        """
        try:
            self.__connections.put_nowait(connection)
        except asyncio.QueueFull:
            if connection:
                connection.close()

    async def __get_connection_from_pool(self) -> tuple[BaseConnection, bool]:
        """
        Check out a connection.

        :return: The connection, and whether the caller must return its slot.
         This is ``False`` for reusable connections, whose slot is put back
         immediately.
        :raises ConnectionNotAvailable: if no connection could be obtained within
         ``blocking_timeout``.
        """
        await self.initialize()
        try:
            async with asyncio_timeout(self._blocking_timeout):
                connection = await self.__connections.get()
                if connection and connection.reusable:
                    self.__release_slot(connection)
                    return connection, False
                if not connection:
                    try:
                        connection = await self.__create_connection()
                    except BaseException as e:
                        # The slot must go back even when connecting is cancelled or
                        # interrupted by the blocking timeout, not just on ConnectionError.
                        if isinstance(e, ConnectionError):
                            self.metrics.on_connection_error(connection, e)
                        self.__release_slot(None)
                        raise
                return connection, True
        except asyncio.TimeoutError:
            raise ConnectionNotAvailable(self.__server_endpoint, self._blocking_timeout)

    async def __create_connection(self) -> BaseConnection:
        connection: BaseConnection
        if isinstance(self.__server_endpoint, UnixSocketEndpoint):
            connection = UnixSocketConnection(self.__server_endpoint, **self._connection_parameters)
        else:
            connection = TCPConnection(self.__server_endpoint, **self._connection_parameters)
        if not connection.connected:
            await connection.connect()
        return connection

    def __check_connection_idle(self, connection: BaseConnection) -> None:
        # Close the connection if it has been idle too long, has nothing pending
        # and the pool would still keep more than min_connections. Otherwise
        # re-check later while it stays connected.
        if (
            time.time() - connection.metrics.last_read > self._idle_connection_timeout
            and connection.metrics.requests_pending == 0
            and len(self._active_connections) > self._min_connections
        ):
            connection.close()
            self._active_connections.discard(connection)
        elif connection.connected:
            asyncio.get_running_loop().call_later(
                self._idle_connection_timeout, self.__check_connection_idle, connection
            )

    def __on_connection_created(self, connection: BaseConnection) -> None:
        self._active_connections.add(connection)
        self.metrics.on_connection_created(connection)
        if self._idle_connection_timeout:
            asyncio.get_running_loop().call_later(
                self._idle_connection_timeout, self.__check_connection_idle, connection
            )

    def __on_connection_disconnected(self, connection: BaseConnection) -> None:
        self.metrics.on_connection_terminated(connection)
        self._active_connections.discard(connection)
[docs] class ClusterPool(Pool): """ Connection pool to manage connections to multiple memcached servers. For multi-key commands, rendezvous hashing is used to distribute the command to the appropriate endpoints. """ def __init__( self, endpoint: Sequence[SingleMemcachedInstanceEndpoint] | AWSAutoDiscoveryEndpoint, min_connections: int = MIN_CONNECTIONS, max_connections: int = MAX_CONNECTIONS, blocking_timeout: float = BLOCKING_TIMEOUT, idle_connection_timeout: float = IDLE_CONNECTION_TIMEOUT, hashing_function: Callable[[str], int] | None = None, endpoint_healthcheck_config: EndpointHealthcheckConfig | None = None, **connection_args: Unpack[ConnectionParams], ) -> None: """ :param endpoint: The memcached server address(es) :param min_connections: The minimum number of connections per endpoint to keep in the pool. :param max_connections: The maximum number of simultaneous connections per memcached endpoint. :param blocking_timeout: The timeout (in seconds) to wait for a connection to become available. :param idle_connection_timeout: The maximum time to allow a connection to remain idle in the pool before being disconnected :param hashing_function: A function to use for routing keys to endpoints for multi-key commands. If none is provided the default :func:`hashlib.md5` implementation from the standard library is used. :param endpoint_healthcheck_config: Configuration to control whether endpoints are automatically removed/recovered based on health checks. :param connection_args: Arguments to pass to the constructor of :class:`~memcachio.BaseConnection`. 
refer to :class:`~memcachio.connection.ConnectionParams` """ self._cluster_pools: dict[SingleMemcachedInstanceEndpoint, SingleServerPool] = {} self.__pool_lock = asyncio.Lock() self.__initialized = False super().__init__( endpoint, min_connections=min_connections, max_connections=max_connections, blocking_timeout=blocking_timeout, idle_connection_timeout=idle_connection_timeout, **connection_args, ) self.__autodiscovery_current_version: int = 0 self.__autodiscovery_endpoint: AWSAutoDiscoveryEndpoint | None = None self.__all_endpoints: set[SingleMemcachedInstanceEndpoint] = set() if isinstance(self.endpoint, AWSAutoDiscoveryEndpoint): self.__autodiscovery_endpoint = self.endpoint else: self.__all_endpoints = { normalize_single_server_endpoint(endpoint) for endpoint in cast(Iterable[SingleMemcachedInstanceEndpoint], self.endpoint) } self.__unhealthy_endpoints: set[SingleMemcachedInstanceEndpoint] = set() self._router = KeyRouter(self.__all_endpoints, hasher=hashing_function) self.__healthcheck_tasks: dict[SingleMemcachedInstanceEndpoint, asyncio.Task[None]] = {} self.__endpoint_healthcheck_config: EndpointHealthcheckConfig = ( endpoint_healthcheck_config or EndpointHealthcheckConfig() ) self.__autodiscovery_task: asyncio.Task[None] | None = None @property def metrics(self) -> PoolMetrics: """ Aggregate metrics obtained from the sub-pools for each endpoint that this cluster pool is configured against. 
""" return PoolMetrics.merge( [ self._cluster_pools[endpoint].metrics for endpoint in self.__all_endpoints if endpoint in self._cluster_pools ] ) @property def endpoints(self) -> set[SingleMemcachedInstanceEndpoint]: return self.__all_endpoints - self.__unhealthy_endpoints async def __autodiscovery_query(self) -> None: if self.__autodiscovery_endpoint: with closing( TCPConnection( (self.__autodiscovery_endpoint.host, self.__autodiscovery_endpoint.port), connect_timeout=self._connection_parameters.get( "connect_timeout", CONNECT_TIMEOUT ), read_timeout=self._connection_parameters.get("read_timeout", READ_TIMEOUT), ) ) as connection: await connection.connect() command = AWSAutoDiscoveryConfig() connection.create_request(command) autodiscovery_version, endpoints = await command.response if autodiscovery_version > self.__autodiscovery_current_version: new_endpoints = endpoints - self.__all_endpoints discarded_endpoints = self.__all_endpoints - endpoints for endpoint in new_endpoints: self.add_endpoint(endpoint) for endpoint in discarded_endpoints: self.remove_endpoint(endpoint) self.__autodiscovery_current_version = autodiscovery_version async def __refresh_autodiscovered_endpoints(self) -> None: if not self.__autodiscovery_endpoint: return while True: try: await self.__autodiscovery_query() await asyncio.sleep(self.__autodiscovery_endpoint.refresh_interval) except asyncio.CancelledError: break
[docs] async def initialize(self) -> None: if self.__initialized: return async with self.__pool_lock: if self.__initialized: return if self.__autodiscovery_endpoint: await self.__autodiscovery_query() self.__autodiscovery_task = asyncio.create_task( self.__refresh_autodiscovered_endpoints() ) else: for endpoint in self.endpoints: self.add_endpoint(endpoint) await asyncio.gather( *[self._cluster_pools[endpoint].initialize() for endpoint in self.endpoints] ) self.__initialized = True
[docs] async def execute_command(self, command: Command[R]) -> None: """ Dispatches a command to the appropriate memcached endpoint(s). To receive the response the future pointed to by ``command.response`` should be awaited as it will be updated when the response(s) (or exception) are available on the transport and merged (if it is a command that spans multiple endpoints). """ try: await self.initialize() if command.keys and len(command.keys) == 1: await self._cluster_pools[self._router.get_node(command.keys[0])].execute_command( command ) else: mapping = defaultdict(list) if command.keys: for key in command.keys: mapping[self._router.get_node(key)].append(key) endpoint_commands = { endpoint: command.clone(keys) for endpoint, keys in mapping.items() } else: endpoint_commands = { endpoint: command.clone(command.keys) for endpoint in self.endpoints } await asyncio.gather( *[ self._cluster_pools[endpoint].execute_command(endpoint_command) for endpoint, endpoint_command in endpoint_commands.items() ] ) if not command.noreply: command.response.set_result( command.merge( await asyncio.gather( *[command.response for command in endpoint_commands.values()] ) ) ) except MemcachioConnectionError as e: if self.__endpoint_healthcheck_config.remove_unhealthy_endpoints: if ( not (current_task := self.__healthcheck_tasks.get(e.endpoint, None)) or current_task.done() ): self.__healthcheck_tasks[e.endpoint] = asyncio.create_task( self.__check_endpoint_health(e.endpoint) ) raise
async def __check_endpoint_health(self, endpoint: SingleMemcachedInstanceEndpoint) -> None: attempt = 0 while True: try: try: if pool := self._cluster_pools.get(endpoint, None): if ( pool.metrics.connection_errors < self.__endpoint_healthcheck_config.maximum_error_count_for_removal ): return await pool.initialize() if pool.metrics.current_connections > 0: if self.__endpoint_healthcheck_config.monitor_unhealthy_endpoints: logger.info( f"Memcached server at {endpoint} has recovered after {2**attempt} seconds" ) self.update_endpoint_status(endpoint, EndpointStatus.UP) break else: pool.close() except MemcachioConnectionError: self.update_endpoint_status(endpoint, EndpointStatus.DOWN) if ( not self.__endpoint_healthcheck_config.monitor_unhealthy_endpoints or attempt == self.__endpoint_healthcheck_config.maximum_recovery_attempts ): logger.error(f"Memcached server at {endpoint} unreachable and marked down") break except Exception: logger.exception("Unknown error while checking endpoint health") break if ( endpoint in self.__unhealthy_endpoints and self.__endpoint_healthcheck_config.monitor_unhealthy_endpoints and attempt < self.__endpoint_healthcheck_config.maximum_recovery_attempts ): match self.__endpoint_healthcheck_config.retry_backoff_policy: case "linear": delay = attempt case "exponential": delay = 2**attempt logger.debug( f"Memcached server at {endpoint} still down, attempting recovery attempt in {delay} seconds" ) attempt += 1 await asyncio.sleep(delay) except asyncio.CancelledError: break
[docs] def close(self) -> None: for pool in self._cluster_pools.values(): pool.close() for task in self.__healthcheck_tasks.values(): task.cancel() self.__healthcheck_tasks.clear() self.__unhealthy_endpoints.clear() self.__initialized = False
[docs] def add_endpoint(self, endpoint: SingleMemcachedInstanceEndpoint) -> None: """ Add a new endpoint to this pool """ normalized_endpoint = normalize_single_server_endpoint(endpoint) self.__all_endpoints.add(normalized_endpoint) self._router.add_node(normalized_endpoint) if normalized_endpoint not in self._cluster_pools: self._cluster_pools[normalized_endpoint] = SingleServerPool( normalized_endpoint, min_connections=self._min_connections, max_connections=self._max_connections, blocking_timeout=self._blocking_timeout, idle_connection_timeout=self._idle_connection_timeout, **self._connection_parameters, )
[docs] def remove_endpoint(self, endpoint: SingleMemcachedInstanceEndpoint) -> None: """ Remove an endpoint from this pool. This will immediately also close all connections to that endpoint. """ normalized_endpoint = normalize_single_server_endpoint(endpoint) self.__all_endpoints.discard(normalized_endpoint) self._router.remove_node(normalized_endpoint) if pool := self._cluster_pools.pop(normalized_endpoint, None): pool.close()
[docs] def update_endpoint_status( self, endpoint: SingleMemcachedInstanceEndpoint, status: EndpointStatus ) -> None: """ Change the status of an endpoint in this pool. Marking an endpoint as :enum:`EndpointStatus.DOWN` will immediately stop routing requests to it, while marking it as :enum:`EndpointStatus.UP` will immediately start routing requests to it. """ normalized_endpoint = normalize_single_server_endpoint(endpoint) match status: case EndpointStatus.UP: self.__unhealthy_endpoints.discard(normalized_endpoint) self._router.add_node(normalized_endpoint) case EndpointStatus.DOWN: self.__unhealthy_endpoints.add(normalized_endpoint) self._router.remove_node(normalized_endpoint) self._cluster_pools[normalized_endpoint].metrics.on_status_update(status)