From 40a769ec377c0698c94acdedff21d969142499ae Mon Sep 17 00:00:00 2001 From: woutdenolf Date: Fri, 23 Jun 2023 15:52:52 +0200 Subject: [PATCH] Extract abstract async connection class (#2734) * make 'socket_timeout' and 'socket_connect_timeout' equivalent for TCP and UDS connections * abstract asynio connection in analogy with the synchronous connection --------- Co-authored-by: dvora-h <67596500+dvora-h@users.noreply.github.com> --- redis/asyncio/connection.py | 300 +++++++++++++++++------------------- redis/connection.py | 16 +- 2 files changed, 148 insertions(+), 168 deletions(-) diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index efe3a3e1b0..1bc3aa38a6 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -8,6 +8,7 @@ import sys import threading import weakref +from abc import abstractmethod from itertools import chain from types import MappingProxyType from typing import ( @@ -198,7 +199,7 @@ def parse_error(cls, response: str) -> ResponseError: def on_disconnect(self): raise NotImplementedError() - def on_connect(self, connection: "Connection"): + def on_connect(self, connection: "AbstractConnection"): raise NotImplementedError() async def can_read_destructive(self) -> bool: @@ -226,7 +227,7 @@ def _clear(self): self._buffer = b"" self._chunks.clear() - def on_connect(self, connection: "Connection"): + def on_connect(self, connection: "AbstractConnection"): """Called when the stream connects""" self._stream = connection._reader if self._stream is None: @@ -360,7 +361,7 @@ def __init__(self, socket_read_size: int): super().__init__(socket_read_size=socket_read_size) self._reader: Optional[hiredis.Reader] = None - def on_connect(self, connection: "Connection"): + def on_connect(self, connection: "AbstractConnection"): self._stream = connection._reader kwargs: _HiredisReaderArgs = { "protocolError": InvalidResponse, @@ -432,25 +433,23 @@ async def read_response( class ConnectCallbackProtocol(Protocol): - def __call__(self, connection: "Connection"): + def __call__(self, connection: "AbstractConnection"): ... class AsyncConnectCallbackProtocol(Protocol): - async def __call__(self, connection: "Connection"): + async def __call__(self, connection: "AbstractConnection"): ... ConnectCallbackT = Union[ConnectCallbackProtocol, AsyncConnectCallbackProtocol] -class Connection: - """Manages TCP communication to and from a Redis server""" +class AbstractConnection: + """Manages communication to and from a Redis server""" __slots__ = ( "pid", - "host", - "port", "db", "username", "client_name", @@ -458,9 +457,6 @@ class Connection: "password", "socket_timeout", "socket_connect_timeout", - "socket_keepalive", - "socket_keepalive_options", - "socket_type", "redis_connect_func", "retry_on_timeout", "retry_on_error", @@ -482,15 +478,10 @@ class Connection: def __init__( self, *, - host: str = "localhost", - port: Union[str, int] = 6379, db: Union[str, int] = 0, password: Optional[str] = None, socket_timeout: Optional[float] = None, socket_connect_timeout: Optional[float] = None, - socket_keepalive: bool = False, - socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, - socket_type: int = 0, retry_on_timeout: bool = False, retry_on_error: Union[list, _Sentinel] = SENTINEL, encoding: str = "utf-8", @@ -514,18 +505,15 @@ def __init__( "2. 'credential_provider'" ) self.pid = os.getpid() - self.host = host - self.port = int(port) self.db = db self.client_name = client_name self.credential_provider = credential_provider self.password = password self.username = username self.socket_timeout = socket_timeout - self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None - self.socket_keepalive = socket_keepalive - self.socket_keepalive_options = socket_keepalive_options or {} - self.socket_type = socket_type + if socket_connect_timeout is None: + socket_connect_timeout = socket_timeout + self.socket_connect_timeout = socket_connect_timeout self.retry_on_timeout = retry_on_timeout if retry_on_error is SENTINEL: retry_on_error = [] @@ -546,7 +534,6 @@ def __init__( self.retry = Retry(NoBackoff(), 0) self.health_check_interval = health_check_interval self.next_health_check: float = -1 - self.ssl_context: Optional[RedisSSLContext] = None self.encoder = encoder_class(encoding, encoding_errors, decode_responses) self.redis_connect_func = redis_connect_func self._reader: Optional[asyncio.StreamReader] = None @@ -560,11 +547,9 @@ def __repr__(self): repr_args = ",".join((f"{k}={v}" for k, v in self.repr_pieces())) return f"{self.__class__.__name__}<{repr_args}>" + @abstractmethod def repr_pieces(self): - pieces = [("host", self.host), ("port", self.port), ("db", self.db)] - if self.client_name: - pieces.append(("client_name", self.client_name)) - return pieces + pass @property def is_connected(self): @@ -623,51 +608,17 @@ async def connect(self): if task and inspect.isawaitable(task): await task + @abstractmethod async def _connect(self): - """Create a TCP socket connection""" - async with async_timeout(self.socket_connect_timeout): - reader, writer = await asyncio.open_connection( - host=self.host, - port=self.port, - ssl=self.ssl_context.get() if self.ssl_context else None, - ) - self._reader = reader - self._writer = writer - sock = writer.transport.get_extra_info("socket") - if sock: - sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - try: - # TCP_KEEPALIVE - if self.socket_keepalive: - sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) - for k, v in self.socket_keepalive_options.items(): - sock.setsockopt(socket.SOL_TCP, k, v) + pass - except (OSError, TypeError): - # `socket_keepalive_options` might contain invalid options - # causing an error. Do not leave the connection open. - writer.close() - raise + @abstractmethod + def _host_error(self) -> str: + pass - def _error_message(self, exception): - # args for socket.error can either be (errno, "message") - # or just "message" - if not exception.args: - # asyncio has a bug where on Connection reset by peer, the - # exception is not instanciated, so args is empty. This is the - # workaround. - # See: https://github.com/redis/redis-py/issues/2237 - # See: https://github.com/python/cpython/issues/94061 - return ( - f"Error connecting to {self.host}:{self.port}. Connection reset by peer" - ) - elif len(exception.args) == 1: - return f"Error connecting to {self.host}:{self.port}. {exception.args[0]}." - else: - return ( - f"Error {exception.args[0]} connecting to {self.host}:{self.port}. " - f"{exception.args[0]}." - ) + @abstractmethod + def _error_message(self, exception: BaseException) -> str: + pass async def on_connect(self) -> None: """Initialize the connection, authenticate and select a database""" @@ -808,9 +759,8 @@ async def can_read_destructive(self): return await self._parser.can_read_destructive() except OSError as e: await self.disconnect(nowait=True) - raise ConnectionError( - f"Error while reading from {self.host}:{self.port}: {e.args}" - ) + host_error = self._host_error() + raise ConnectionError(f"Error while reading from {host_error}: {e.args}") async def read_response( self, @@ -821,6 +771,7 @@ async def read_response( ): """Read the response from a previously sent command""" read_timeout = timeout if timeout is not None else self.socket_timeout + host_error = self._host_error() try: if read_timeout is not None: async with async_timeout(read_timeout): @@ -838,13 +789,11 @@ async def read_response( # it was a self.socket_timeout error. if disconnect_on_error: await self.disconnect(nowait=True) - raise TimeoutError(f"Timeout reading from {self.host}:{self.port}") + raise TimeoutError(f"Timeout reading from {host_error}") except OSError as e: if disconnect_on_error: await self.disconnect(nowait=True) - raise ConnectionError( - f"Error while reading from {self.host}:{self.port} : {e.args}" - ) + raise ConnectionError(f"Error while reading from {host_error} : {e.args}") except BaseException: # Also by default close in case of BaseException. A lot of code # relies on this behaviour when doing Command/Response pairs. @@ -938,7 +887,90 @@ def pack_commands(self, commands: Iterable[Iterable[EncodableT]]) -> List[bytes] return output +class Connection(AbstractConnection): + "Manages TCP communication to and from a Redis server" + + def __init__( + self, + *, + host: str = "localhost", + port: Union[str, int] = 6379, + socket_keepalive: bool = False, + socket_keepalive_options: Optional[Mapping[int, Union[int, bytes]]] = None, + socket_type: int = 0, + **kwargs, + ): + self.host = host + self.port = int(port) + self.socket_keepalive = socket_keepalive + self.socket_keepalive_options = socket_keepalive_options or {} + self.socket_type = socket_type + super().__init__(**kwargs) + + def repr_pieces(self): + pieces = [("host", self.host), ("port", self.port), ("db", self.db)] + if self.client_name: + pieces.append(("client_name", self.client_name)) + return pieces + + def _connection_arguments(self) -> Mapping: + return {"host": self.host, "port": self.port} + + async def _connect(self): + """Create a TCP socket connection""" + async with async_timeout(self.socket_connect_timeout): + reader, writer = await asyncio.open_connection( + **self._connection_arguments() + ) + self._reader = reader + self._writer = writer + sock = writer.transport.get_extra_info("socket") + if sock: + sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) + try: + # TCP_KEEPALIVE + if self.socket_keepalive: + sock.setsockopt(socket.SOL_SOCKET, socket.SO_KEEPALIVE, 1) + for k, v in self.socket_keepalive_options.items(): + sock.setsockopt(socket.SOL_TCP, k, v) + + except (OSError, TypeError): + # `socket_keepalive_options` might contain invalid options + # causing an error. Do not leave the connection open. + writer.close() + raise + + def _host_error(self) -> str: + return f"{self.host}:{self.port}" + + def _error_message(self, exception: BaseException) -> str: + # args for socket.error can either be (errno, "message") + # or just "message" + + host_error = self._host_error() + + if not exception.args: + # asyncio has a bug where on Connection reset by peer, the + # exception is not instanciated, so args is empty. This is the + # workaround. + # See: https://github.com/redis/redis-py/issues/2237 + # See: https://github.com/python/cpython/issues/94061 + return f"Error connecting to {host_error}. Connection reset by peer" + elif len(exception.args) == 1: + return f"Error connecting to {host_error}. {exception.args[0]}." + else: + return ( + f"Error {exception.args[0]} connecting to {host_error}. " + f"{exception.args[0]}." + ) + + class SSLConnection(Connection): + """Manages SSL connections to and from the Redis server(s). + This class extends the Connection class, adding SSL functionality, and making + use of ssl.SSLContext (https://docs.python.org/3/library/ssl.html#ssl.SSLContext) + """ + def __init__( self, ssl_keyfile: Optional[str] = None, @@ -949,7 +981,6 @@ def __init__( ssl_check_hostname: bool = False, **kwargs, ): - super().__init__(**kwargs) self.ssl_context: RedisSSLContext = RedisSSLContext( keyfile=ssl_keyfile, certfile=ssl_certfile, @@ -958,6 +989,12 @@ def __init__( ca_data=ssl_ca_data, check_hostname=ssl_check_hostname, ) + super().__init__(**kwargs) + + def _connection_arguments(self) -> Mapping: + kwargs = super()._connection_arguments() + kwargs["ssl"] = self.ssl_context.get() + return kwargs @property def keyfile(self): @@ -1037,77 +1074,12 @@ def get(self) -> ssl.SSLContext: return self.context -class UnixDomainSocketConnection(Connection): # lgtm [py/missing-call-to-init] - def __init__( - self, - *, - path: str = "", - db: Union[str, int] = 0, - username: Optional[str] = None, - password: Optional[str] = None, - socket_timeout: Optional[float] = None, - socket_connect_timeout: Optional[float] = None, - encoding: str = "utf-8", - encoding_errors: str = "strict", - decode_responses: bool = False, - retry_on_timeout: bool = False, - retry_on_error: Union[list, _Sentinel] = SENTINEL, - parser_class: Type[BaseParser] = DefaultParser, - socket_read_size: int = 65536, - health_check_interval: float = 0.0, - client_name: str = None, - retry: Optional[Retry] = None, - redis_connect_func=None, - credential_provider: Optional[CredentialProvider] = None, - ): - """ - Initialize a new UnixDomainSocketConnection. - To specify a retry policy, first set `retry_on_timeout` to `True` - then set `retry` to a valid `Retry` object - """ - if (username or password) and credential_provider is not None: - raise DataError( - "'username' and 'password' cannot be passed along with 'credential_" - "provider'. Please provide only one of the following arguments: \n" - "1. 'password' and (optional) 'username'\n" - "2. 'credential_provider'" - ) - self.pid = os.getpid() +class UnixDomainSocketConnection(AbstractConnection): + "Manages UDS communication to and from a Redis server" + + def __init__(self, *, path: str = "", **kwargs): self.path = path - self.db = db - self.client_name = client_name - self.credential_provider = credential_provider - self.password = password - self.username = username - self.socket_timeout = socket_timeout - self.socket_connect_timeout = socket_connect_timeout or socket_timeout or None - self.retry_on_timeout = retry_on_timeout - if retry_on_error is SENTINEL: - retry_on_error = [] - if retry_on_timeout: - retry_on_error.append(TimeoutError) - self.retry_on_error = retry_on_error - if retry_on_error: - if retry is None: - self.retry = Retry(NoBackoff(), 1) - else: - # deep-copy the Retry object as it is mutable - self.retry = copy.deepcopy(retry) - # Update the retry's supported errors with the specified errors - self.retry.update_supported_errors(retry_on_error) - else: - self.retry = Retry(NoBackoff(), 0) - self.health_check_interval = health_check_interval - self.next_health_check = -1 - self.redis_connect_func = redis_connect_func - self.encoder = Encoder(encoding, encoding_errors, decode_responses) - self._sock = None - self._reader = None - self._writer = None - self._socket_read_size = socket_read_size - self.set_parser(parser_class) - self._connect_callbacks = [] - self._buffer_cutoff = 6000 + super().__init__(**kwargs) def repr_pieces(self) -> Iterable[Tuple[str, Union[str, int]]]: pieces = [("path", self.path), ("db", self.db)] @@ -1122,15 +1094,21 @@ async def _connect(self): self._writer = writer await self.on_connect() - def _error_message(self, exception): + def _host_error(self) -> str: + return self.host + + def _error_message(self, exception: BaseException) -> str: # args for socket.error can either be (errno, "message") # or just "message" + host_error = self._host_error() if len(exception.args) == 1: - return f"Error connecting to unix socket: {self.path}. {exception.args[0]}." + return ( + f"Error connecting to unix socket: {host_error}. {exception.args[0]}." + ) else: return ( f"Error {exception.args[0]} connecting to unix socket: " - f"{self.path}. {exception.args[1]}." + f"{host_error}. {exception.args[1]}." ) @@ -1162,7 +1140,7 @@ def to_bool(value) -> Optional[bool]: class ConnectKwargs(TypedDict, total=False): username: str password: str - connection_class: Type[Connection] + connection_class: Type[AbstractConnection] host: str port: int db: int @@ -1284,7 +1262,7 @@ class initializer. In the case of conflicting arguments, querystring def __init__( self, - connection_class: Type[Connection] = Connection, + connection_class: Type[AbstractConnection] = Connection, max_connections: Optional[int] = None, **connection_kwargs, ): @@ -1307,8 +1285,8 @@ def __init__( self._fork_lock = threading.Lock() self._lock = asyncio.Lock() self._created_connections: int - self._available_connections: List[Connection] - self._in_use_connections: Set[Connection] + self._available_connections: List[AbstractConnection] + self._in_use_connections: Set[AbstractConnection] self.reset() # lgtm [py/init-calls-subclass] self.encoder_class = self.connection_kwargs.get("encoder_class", Encoder) @@ -1431,7 +1409,7 @@ def make_connection(self): self._created_connections += 1 return self.connection_class(**self.connection_kwargs) - async def release(self, connection: Connection): + async def release(self, connection: AbstractConnection): """Releases the connection back to the pool""" self._checkpid() async with self._lock: @@ -1452,7 +1430,7 @@ async def release(self, connection: Connection): await connection.disconnect() return - def owns_connection(self, connection: Connection): + def owns_connection(self, connection: AbstractConnection): return connection.pid == self.pid async def disconnect(self, inuse_connections: bool = True): @@ -1466,7 +1444,7 @@ async def disconnect(self, inuse_connections: bool = True): self._checkpid() async with self._lock: if inuse_connections: - connections: Iterable[Connection] = chain( + connections: Iterable[AbstractConnection] = chain( self._available_connections, self._in_use_connections ) else: @@ -1524,14 +1502,14 @@ def __init__( self, max_connections: int = 50, timeout: Optional[int] = 20, - connection_class: Type[Connection] = Connection, + connection_class: Type[AbstractConnection] = Connection, queue_class: Type[asyncio.Queue] = asyncio.LifoQueue, **connection_kwargs, ): self.queue_class = queue_class self.timeout = timeout - self._connections: List[Connection] + self._connections: List[AbstractConnection] super().__init__( connection_class=connection_class, max_connections=max_connections, @@ -1621,7 +1599,7 @@ async def get_connection(self, command_name, *keys, **options): return connection - async def release(self, connection: Connection): + async def release(self, connection: AbstractConnection): """Releases the connection back to the pool.""" # Make sure we haven't changed process. self._checkpid() diff --git a/redis/connection.py b/redis/connection.py index bf0d6dea80..bec456c9ce 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -594,6 +594,8 @@ def __init__( self, db=0, password=None, + socket_timeout=None, + socket_connect_timeout=None, retry_on_timeout=False, retry_on_error=SENTINEL, encoding="utf-8", @@ -629,6 +631,10 @@ def __init__( self.credential_provider = credential_provider self.password = password self.username = username + self.socket_timeout = socket_timeout + if socket_connect_timeout is None: + socket_connect_timeout = socket_timeout + self.socket_connect_timeout = socket_connect_timeout self.retry_on_timeout = retry_on_timeout if retry_on_error is SENTINEL: retry_on_error = [] @@ -941,8 +947,6 @@ def __init__( self, host="localhost", port=6379, - socket_timeout=None, - socket_connect_timeout=None, socket_keepalive=False, socket_keepalive_options=None, socket_type=0, @@ -950,8 +954,6 @@ def __init__( ): self.host = host self.port = int(port) - self.socket_timeout = socket_timeout - self.socket_connect_timeout = socket_connect_timeout or socket_timeout self.socket_keepalive = socket_keepalive self.socket_keepalive_options = socket_keepalive_options or {} self.socket_type = socket_type @@ -1172,9 +1174,8 @@ def _connect(self): class UnixDomainSocketConnection(AbstractConnection): "Manages UDS communication to and from a Redis server" - def __init__(self, path="", socket_timeout=None, **kwargs): + def __init__(self, path="", **kwargs): self.path = path - self.socket_timeout = socket_timeout super().__init__(**kwargs) def repr_pieces(self): @@ -1186,8 +1187,9 @@ def repr_pieces(self): def _connect(self): "Create a Unix domain socket connection" sock = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM) - sock.settimeout(self.socket_timeout) + sock.settimeout(self.socket_connect_timeout) sock.connect(self.path) + sock.settimeout(self.socket_timeout) return sock def _host_error(self):