Source code for memcachio.commands

from __future__ import annotations

import abc
import asyncio
import copy
import dataclasses
import time
import weakref
from asyncio import Future
from collections import ChainMap
from collections.abc import Sequence
from io import BytesIO
from typing import AnyStr, ClassVar, Generic, TypeVar, cast

from .compat import Self
from .constants import LINE_END, Commands, Responses
from .errors import (
    AutoDiscoveryError,
    ClientError,
    MemcachedError,
    NotEnoughData,
    ServerError,
)
from .types import KeyT, MemcachedItem, SingleMemcachedInstanceEndpoint, TCPEndpoint, ValueT
from .utils import bytestr, decodedstr

R = TypeVar("R")


@dataclasses.dataclass
class Request(Generic[R]):
    """
    The serializable form of a command: a header line followed by
    zero or more body chunks.
    """

    #: Weak proxy to the command this request was built from; its ``name``
    #: and ``noreply`` attributes are read when serializing.
    command: weakref.ProxyType[Command[R]]
    #: Arguments placed after the command name on the first line
    header: bytes
    #: Payload chunks appended after the first line
    body: list[bytes] = dataclasses.field(default_factory=list)

    def __bytes__(self) -> bytes:
        """
        Serialize the request as ``<name>[ <header>][ noreply]`` terminated by
        ``LINE_END`` and followed by the body chunks joined with ``LINE_END``.
        """
        command = self.command
        name = command.name
        header_part = b" %b" % self.header if self.header else b""
        noreply_part = b" noreply" if command.noreply else b""
        first_line = name + header_part + noreply_part + LINE_END
        return first_line + LINE_END.join(self.body)
class Command(abc.ABC, Generic[R]):
    """
    The abstract generic command class used throughout the
    client -> connection pool -> connection lifecycle to manage
    dispatching a request, parsing the response (if not a ``noreply``
    request) and resolving the future to receive the eventual parsed
    response.

    This class is not meant for direct use, however its structure might
    be interesting if implementing a custom connection pool or using the
    connection classes directly.

    As an example the following snippet demonstrates creating a ``set``
    command that takes no flags or expiry and is dispatched directly to a
    connection without using :class:`~memcachio.Client` or
    a :class:`~memcachio.pool.Pool`::

        import asyncio
        import weakref
        from io import BytesIO

        import memcachio
        from memcachio.commands import Command, Request
        from memcachio.constants import Commands

        class MyCustomSetCommand(Command[bool]):
            name = Commands.SET

            def __init__(self, key: str, value: bytes, noreply: bool = False):
                self.value = value
                super().__init__(key, noreply=noreply)

            def build_request(self) -> memcachio.commands.Request[bool]:
                return Request(
                    weakref.proxy(self),
                    b"%b 0 0 %d" % (self.keys[0].encode(), len(self.value)),
                    [self.value + b"\\r\\n"],
                )

            def parse(self, data: BytesIO, endpoint: SingleMemcachedInstanceEndpoint) -> bool:
                response = data.readline()
                return response.rstrip() == b"STORED"

        async def example():
            command = MyCustomSetCommand("fubar", b"1234")
            connection = memcachio.TCPConnection(("localhost", 11211))
            await connection.connect()
            connection.create_request(command)
            assert await command.response

        asyncio.run(example())
    """

    __slots__ = (
        "__weakref__",
        "_keys",
        "noreply",
        "request_sent",
        "response",
        "response_time",
        "created_at",
    )

    #: The name of the command
    name: ClassVar[Commands]
    #: A future that should be set when the request has been written
    #: to the socket
    request_sent: Future[bool]
    #: A future that should be set when the request has received a response
    #: and parsed
    response: Future[R]
    #: Seconds elapsed between the creation of the command and the completion
    #: of :attr:`request_sent` (for ``noreply`` commands) or :attr:`response`
    response_time: float

    def __init__(self, *keys: KeyT, noreply: bool = False):
        """
        :param keys: The keys this command operates on
        :param noreply: Whether the server should send a response to the command.
        """
        self.noreply = noreply
        self.created_at = time.time()
        self.response_time = 0.0
        self._keys: list[str] = [decodedstr(key) for key in keys or []]
        self._create_futures()

    def _create_futures(self) -> None:
        """
        Create fresh ``request_sent`` / ``response`` futures on the running
        loop and arrange for ``response_time`` to be recorded on completion.
        """
        loop = asyncio.get_running_loop()
        self.request_sent = loop.create_future()
        self.response = loop.create_future()
        # A ``noreply`` command never resolves ``response``, so its timing
        # is taken from the moment the request is written instead.
        tracked = self.request_sent if self.noreply else self.response
        tracked.add_done_callback(lambda _: self._update_response_time())

    def merge(self, responses: list[R]) -> R:
        """
        Combine multiple parsed responses into one. The default
        implementation returns the first response.
        """
        return responses[0]

    @property
    def keys(self) -> list[str]:
        """
        The (decoded) keys this command operates on
        """
        return self._keys

    def clone(self, keys: Sequence[KeyT]) -> Self:
        """
        Return a shallow copy of this command that operates on ``keys``
        and owns its own ``request_sent`` / ``response`` futures.

        :param keys: The keys the copy should operate on
        """
        subset = copy.copy(self)
        subset._keys = [decodedstr(key) for key in keys]
        # The shallow copy shares the original's futures; replace them so
        # the clone resolves independently.
        subset._create_futures()
        return subset

    @abc.abstractmethod
    def build_request(self) -> Request[R]:
        """
        Build the header and (optional) payload for the request to send
        to the memcached server
        """
        ...

    @abc.abstractmethod
    def parse(self, data: BytesIO, endpoint: SingleMemcachedInstanceEndpoint) -> R:
        """
        Parse the response from the buffer assuming the current position
        of the buffer contains the response for this command.
        """
        ...

    def _update_response_time(self) -> None:
        """
        Record the time elapsed since the command was created.
        """
        self.response_time = time.time() - self.created_at

    def _check_header(self, header: bytes) -> None:
        """
        Validate the first line of a response.

        :param header: A line read from the response buffer
        :raises NotEnoughData: if the line is not terminated by ``LINE_END``
        :raises ClientError: if the line starts with the client error marker
        :raises ServerError: if the line starts with the server error marker
        :raises MemcachedError: if the line starts with the generic error marker
        """
        if not header.endswith(LINE_END):
            raise NotEnoughData(len(header))
        response = header.rstrip()
        # ``partition`` keeps everything after the *first* marker so that a
        # message which itself contains the marker text is not truncated.
        if response.startswith(Responses.CLIENT_ERROR):
            message = response.partition(Responses.CLIENT_ERROR)[2]
            raise ClientError(decodedstr(message).strip())
        if response.startswith(Responses.SERVER_ERROR):
            message = response.partition(Responses.SERVER_ERROR)[2]
            raise ServerError(decodedstr(message).strip())
        if response.startswith(Responses.ERROR):
            raise MemcachedError(decodedstr(response).strip())
        return None
class BasicResponseCommand(Command[bool]):
    """
    Base class for commands whose response is a single status line
    that is compared against :attr:`success`.
    """

    #: The status line that indicates the command succeeded
    success: ClassVar[Responses]

    def parse(self, data: BytesIO, endpoint: SingleMemcachedInstanceEndpoint) -> bool:
        """
        Read one line and return whether it matches :attr:`success`.
        Error responses are raised by ``_check_header``.
        """
        response = data.readline()
        self._check_header(response)
        return response.rstrip() == self.success.value


class GetCommand(Command[dict[AnyStr, MemcachedItem[AnyStr]]]):
    """
    Retrieval command returning a mapping of key to item.
    """

    __slots__ = ("decode_responses", "encoding", "items")
    name = Commands.GET

    def __init__(self, *keys: KeyT, decode: bool = False, encoding: str = "utf-8") -> None:
        """
        :param keys: The keys to retrieve
        :param decode: Whether keys and values in the response are decoded
         to strings
        :param encoding: The encoding used when ``decode`` is set
        """
        self.items: list[MemcachedItem[AnyStr]] = []
        self.decode_responses = decode
        self.encoding = encoding
        super().__init__(*keys, noreply=False)

    def build_request(self) -> Request[dict[AnyStr, MemcachedItem[AnyStr]]]:
        """
        Build a request whose header is the space separated list of keys.
        """
        return Request(weakref.proxy(self), " ".join(self.keys).encode())

    def parse(
        self, data: BytesIO, endpoint: SingleMemcachedInstanceEndpoint
    ) -> dict[AnyStr, MemcachedItem[AnyStr]]:
        """
        Read ``VALUE <key> <flags> <size> [<cas>]`` headers, each followed by
        ``size`` bytes of data, until the end marker line is reached.

        :raises NotEnoughData: if a header line or value is incomplete
        :raises ValueError: if a header line is not a well formed value header
        """
        while True:
            header = data.readline()
            self._check_header(header)
            if header.lstrip() == (Responses.END + LINE_END):
                break
            parts = header.split()
            if len(parts) < 4 or parts[0] != Responses.VALUE:
                msg = f"Unexpected response header: {decodedstr(header)}"
                raise ValueError(msg)
            key = parts[1]
            flags = int(parts[2])
            size = int(parts[3])
            # The cas token is only present as a fifth field
            cas = int(parts[4]) if len(parts) > 4 else None
            value = data.read(size)
            if len(value) != size:
                raise NotEnoughData(len(value) + len(header))
            item = MemcachedItem[AnyStr](
                cast(AnyStr, decodedstr(key, self.encoding) if self.decode_responses else key),
                flags,
                size,
                cas,
                cast(AnyStr, decodedstr(value, self.encoding) if self.decode_responses else value),
            )
            # Skip the two bytes that follow the value (presumably the line
            # terminator) before reading the next header.
            data.read(2)
            self.items.append(item)
        return {i.key: i for i in self.items}

    def merge(
        self, results: list[dict[AnyStr, MemcachedItem[AnyStr]]]
    ) -> dict[AnyStr, MemcachedItem[AnyStr]]:
        """
        Combine the per-response mappings into one; for a repeated key the
        entry from the later response wins.
        """
        merged: dict[AnyStr, MemcachedItem[AnyStr]] = {}
        for res in results:
            merged.update(res)
        return merged


class GetsCommand(GetCommand[AnyStr]):
    """:class:`GetCommand` sent under the ``GETS`` command name."""

    name = Commands.GETS


class GatCommand(GetCommand[AnyStr]):
    """
    Retrieval command that additionally sends an expiry before the keys.
    """

    __slots__ = ("expiry",)
    name = Commands.GAT

    def __init__(
        self,
        *keys: KeyT,
        expiry: int = 0,
        decode: bool = False,
        encoding: str = "utf-8",
    ) -> None:
        """
        :param keys: The keys to retrieve
        :param expiry: The expiry sent ahead of the keys
        :param decode: Whether keys and values in the response are decoded
         to strings
        :param encoding: The encoding used when ``decode`` is set
        """
        self.expiry = expiry
        super().__init__(*keys, decode=decode, encoding=encoding)

    def build_request(self) -> Request[dict[AnyStr, MemcachedItem[AnyStr]]]:
        """
        Build a request whose header is the expiry followed by the keys.
        """
        return Request(weakref.proxy(self), f"{self.expiry} {' '.join(self.keys)}".encode())


class GatsCommand(GatCommand[AnyStr]):
    """:class:`GatCommand` sent under the ``GATS`` command name."""

    name = Commands.GATS


class GenericStoreCommand(BasicResponseCommand):
    """
    Base class for storage commands that send a single key, its
    flags, expiry and size (optionally a cas token) followed by the value.
    """

    __slots__ = ("cas", "encoding", "expiry", "flags", "value")

    def __init__(
        self,
        key: KeyT,
        value: ValueT,
        *,
        flags: int | None = None,
        expiry: int = 0,
        noreply: bool = False,
        cas: int | None = None,
        encoding: str = "utf-8",
    ) -> None:
        """
        :param key: The key to store
        :param value: The value to store; converted to bytes using ``encoding``
        :param flags: Flags sent with the item (``0`` is sent when unset)
        :param expiry: Expiry sent with the item
        :param noreply: Whether the server should send a response to the command.
        :param cas: If not ``None`` appended as the last field of the header
        :param encoding: The encoding used to convert ``value`` to bytes
        """
        self.encoding = encoding
        self.flags = flags
        self.expiry = expiry
        self.value = bytestr(value, self.encoding)
        self.cas = cas
        super().__init__(key, noreply=noreply)

    def build_request(self) -> Request[bool]:
        """
        Build ``<key> <flags> <expiry> <size>[ <cas>]`` as the header and the
        value terminated by ``LINE_END`` as the body.
        """
        fields = [
            f"{decodedstr(self.keys[0])}",
            f"{self.flags or 0}",
            f"{self.expiry}",
            f"{len(self.value)}",
        ]
        if self.cas is not None:
            fields.append(f"{self.cas}")
        return Request(weakref.proxy(self), " ".join(fields).encode(), [self.value + LINE_END])


class SetCommand(GenericStoreCommand):
    name = Commands.SET
    success = Responses.STORED


class CheckAndSetCommand(GenericStoreCommand):
    name = Commands.CAS
    success = Responses.STORED


class AddCommand(GenericStoreCommand):
    name = Commands.ADD
    success = Responses.STORED


class AppendCommand(GenericStoreCommand):
    name = Commands.APPEND
    success = Responses.STORED


class PrependCommand(GenericStoreCommand):
    name = Commands.PREPEND
    success = Responses.STORED


class ReplaceCommand(GenericStoreCommand):
    name = Commands.REPLACE
    success = Responses.STORED


class ArithmenticCommand(Command[int | None]):
    """
    Base class for commands that send a key and an amount and
    receive either an integer or a not-found status.
    """

    __slots__ = ("amount",)

    def __init__(self, key: KeyT, amount: int, noreply: bool) -> None:
        """
        :param key: The key to operate on
        :param amount: The amount sent after the key
        :param noreply: Whether the server should send a response to the command.
        """
        self.amount = amount
        super().__init__(key, noreply=noreply)

    def build_request(self) -> Request[int | None]:
        """
        Build a request whose header is ``<key> <amount>``.
        """
        request = f"{decodedstr(self.keys[0])} {self.amount}"
        return Request(weakref.proxy(self), request.encode())

    def parse(self, data: BytesIO, endpoint: SingleMemcachedInstanceEndpoint) -> int | None:
        """
        Return the integer on the response line, or ``None`` if the
        response is the not-found status.
        """
        response = data.readline()
        self._check_header(response)
        response = response.rstrip()
        if response == Responses.NOT_FOUND:
            return None
        return int(response)


class IncrCommand(ArithmenticCommand):
    name = Commands.INCR


class DecrCommand(ArithmenticCommand):
    name = Commands.DECR


class DeleteCommand(BasicResponseCommand):
    """Command sending a single key and expecting the deleted status."""

    name = Commands.DELETE
    success = Responses.DELETED

    def build_request(self) -> Request[bool]:
        """
        Build a request whose header is the key.
        """
        return Request(weakref.proxy(self), bytestr(self.keys[0]))


class TouchCommand(BasicResponseCommand):
    """Command sending a key and an expiry and expecting the touched status."""

    __slots__ = ("expiry",)
    name = Commands.TOUCH
    success = Responses.TOUCHED

    def __init__(self, key: KeyT, *, expiry: int, noreply: bool = False) -> None:
        """
        :param key: The key to touch
        :param expiry: The expiry sent after the key
        :param noreply: Whether the server should send a response to the command.
        """
        self.expiry = expiry
        super().__init__(key, noreply=noreply)

    def build_request(self) -> Request[bool]:
        """
        Build a request whose header is ``<key> <expiry>``.
        """
        return Request(weakref.proxy(self), f"{self.keys[0]} {self.expiry}".encode())


class FlushAllCommand(BasicResponseCommand):
    """Key-less command sending an expiry and expecting the ok status."""

    __slots__ = ("expiry",)
    name = Commands.FLUSH_ALL
    success = Responses.OK

    def __init__(self, expiry: int) -> None:
        """
        :param expiry: The value sent as the only argument of the command
        """
        self.expiry = expiry
        super().__init__()

    def build_request(self) -> Request[bool]:
        """
        Build a request whose header is the expiry.
        """
        return Request(weakref.proxy(self), bytestr(self.expiry))

    def merge(self, results: list[bool]) -> bool:
        """
        Succeed only if every individual response succeeded.
        """
        return all(results)


class VersionCommand(Command[dict[SingleMemcachedInstanceEndpoint, str]]):
    """Key-less command returning the version string keyed by endpoint."""

    name = Commands.VERSION

    def build_request(self) -> Request[dict[SingleMemcachedInstanceEndpoint, str]]:
        """
        Build a request with an empty header.
        """
        return Request(weakref.proxy(self), b"")

    def parse(
        self, data: BytesIO, endpoint: SingleMemcachedInstanceEndpoint
    ) -> dict[SingleMemcachedInstanceEndpoint, str]:
        """
        Return the text following the version marker on the response line,
        keyed by the endpoint that produced it.
        """
        response = data.readline()
        self._check_header(response)
        return {endpoint: response.partition(Responses.VERSION)[-1].strip().decode("utf-8")}

    def merge(
        self, responses: list[dict[SingleMemcachedInstanceEndpoint, str]]
    ) -> dict[SingleMemcachedInstanceEndpoint, str]:
        """
        Combine the per-endpoint mappings into a single dictionary.
        """
        return dict(ChainMap(*responses))


class StatsCommand(Command[dict[SingleMemcachedInstanceEndpoint, dict[AnyStr, AnyStr]]]):
    """Key-less command returning the reported statistics keyed by endpoint."""

    name = Commands.STATS

    def __init__(
        self, arg: str | None = None, *, decode_responses: bool = False, encoding: str = "utf-8"
    ) -> None:
        """
        :param arg: Optional argument sent after the command name
        :param decode_responses: Whether stat names and values are decoded
         to strings
        :param encoding: The encoding used when ``decode_responses`` is set
        """
        self.arg = arg
        self.decode_responses = decode_responses
        self.encoding = encoding
        super().__init__(noreply=False)

    def build_request(
        self,
    ) -> Request[dict[SingleMemcachedInstanceEndpoint, dict[AnyStr, AnyStr]]]:
        """
        Build a request whose header is ``arg`` (or empty if unset).
        """
        return Request(weakref.proxy(self), b"" if not self.arg else bytestr(self.arg))

    def parse(
        self, data: BytesIO, endpoint: SingleMemcachedInstanceEndpoint
    ) -> dict[SingleMemcachedInstanceEndpoint, dict[AnyStr, AnyStr]]:
        """
        Collect ``STAT <name> <value>`` lines until the end marker line
        is reached. Lines that are neither are ignored.

        :raises NotEnoughData: if a line is incomplete
        """
        stats: dict[AnyStr, AnyStr] = {}
        while True:
            section = data.readline()
            self._check_header(section)
            if section.startswith(Responses.END):
                break
            if not section.startswith(Responses.STAT):
                continue
            # ``removeprefix`` drops exactly the marker; ``lstrip`` would
            # treat it as a set of characters to strip.
            payload = section.removeprefix(Responses.STAT).strip()
            text = decodedstr(payload, self.encoding) if self.decode_responses else payload
            # Split once so a value containing whitespace stays intact.
            fields = text.split(None, 1)
            if not fields:
                continue
            item = fields[0]
            # ``text[:0]`` is an empty value of the same type (str or bytes)
            value = fields[1] if len(fields) > 1 else text[:0]
            stats[cast(AnyStr, item)] = cast(AnyStr, value)
        return {endpoint: stats}

    def merge(
        self, responses: list[dict[SingleMemcachedInstanceEndpoint, dict[AnyStr, AnyStr]]]
    ) -> dict[SingleMemcachedInstanceEndpoint, dict[AnyStr, AnyStr]]:
        """
        Combine the per-endpoint mappings into a single dictionary.
        """
        return dict(ChainMap(*responses))


class AWSAutoDiscoveryConfig(Command[tuple[int, set[SingleMemcachedInstanceEndpoint]]]):
    """
    Sends ``get cluster`` under the config command name and returns the
    reported version along with the set of endpoints.
    """

    name = Commands.CONFIG

    def build_request(self) -> Request[tuple[int, set[SingleMemcachedInstanceEndpoint]]]:
        """
        Build a request whose header is ``get cluster``.
        """
        return Request(weakref.proxy(self), b"get cluster")

    def parse(
        self, data: BytesIO, endpoint: SingleMemcachedInstanceEndpoint
    ) -> tuple[int, set[SingleMemcachedInstanceEndpoint]]:
        """
        Read the header line, a version line, a line of space separated
        ``|`` delimited endpoint entries and finally the end marker.

        :raises AutoDiscoveryError: if the line after the endpoints is not
         the end marker
        """
        header = data.readline()
        self._check_header(header)
        # NOTE(review): only the header line is checked for completeness; a
        # truncated version/endpoint line surfaces as ValueError/IndexError
        # rather than NotEnoughData - confirm whether callers rely on that.
        version = int(data.readline())
        parsed_endpoints = [host.split(b"|") for host in data.readline().split(b" ")]
        if (data.readline().strip()) != Responses.END:
            raise AutoDiscoveryError("Malformed response")
        # Fields 1 and 2 of each entry are used as the host and port.
        return version, {TCPEndpoint(host[1].decode(), int(host[2])) for host in parsed_endpoints}