Skip to content

Commit

Permalink
[PR #8089/dc38630b backport][3.9] 💅 Propagate error causes via asynci…
Browse files Browse the repository at this point in the history
…o protocols

This is supposed to unify setting exceptions on the future objects,
allowing to also attach their causes whenever available. It'll make
possible for the end-users to see more detailed tracebacks.

It's also supposed to help with tracking down what's happening
with #4581.

PR #8089

Co-Authored-By: J. Nick Koston <[email protected]>
Co-Authored-By: Sam Bull <[email protected]>
(cherry picked from commit dc38630)
  • Loading branch information
webknjaz committed Feb 16, 2024
1 parent e45da11 commit 7af2680
Show file tree
Hide file tree
Showing 15 changed files with 178 additions and 66 deletions.
3 changes: 3 additions & 0 deletions CHANGES/8089.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
The asynchronous internals now set the underlying causes
when assigning exceptions to the future objects
-- by :user:`webknjaz`.
12 changes: 7 additions & 5 deletions aiohttp/_http_parser.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ from multidict import CIMultiDict as _CIMultiDict, CIMultiDictProxy as _CIMultiD
from yarl import URL as _URL

from aiohttp import hdrs
from aiohttp.helpers import DEBUG
from aiohttp.helpers import DEBUG, set_exception

from .http_exceptions import (
BadHttpMessage,
Expand Down Expand Up @@ -763,11 +763,13 @@ cdef int cb_on_body(cparser.llhttp_t* parser,
cdef bytes body = at[:length]
try:
pyparser._payload.feed_data(body, length)
except BaseException as exc:
except BaseException as underlying_exc:
reraised_exc = underlying_exc
if pyparser._payload_exception is not None:
pyparser._payload.set_exception(pyparser._payload_exception(str(exc)))
else:
pyparser._payload.set_exception(exc)
reraised_exc = pyparser._payload_exception(str(underlying_exc))

set_exception(pyparser._payload, reraised_exc, underlying_exc)

pyparser._payload_error = 1
return -1
else:
Expand Down
7 changes: 6 additions & 1 deletion aiohttp/base_protocol.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from typing import Optional, cast

from .helpers import set_exception
from .tcp_helpers import tcp_nodelay


Expand Down Expand Up @@ -76,7 +77,11 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:
if exc is None:
waiter.set_result(None)
else:
waiter.set_exception(exc)
set_exception(
waiter,
ConnectionError("Connection lost"),
exc,
)

async def _drain_helper(self) -> None:
if not self.connected:
Expand Down
67 changes: 50 additions & 17 deletions aiohttp/client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,20 @@

from .base_protocol import BaseProtocol
from .client_exceptions import (
ClientConnectionError,
ClientOSError,
ClientPayloadError,
ServerDisconnectedError,
ServerTimeoutError,
)
from .helpers import BaseTimerContext, status_code_must_be_empty_body
from .helpers import (
_EXC_SENTINEL,
BaseTimerContext,
set_exception,
status_code_must_be_empty_body,
)
from .http import HttpResponseParser, RawResponseMessage
from .http_exceptions import HttpProcessingError
from .streams import EMPTY_PAYLOAD, DataQueue, StreamReader


Expand Down Expand Up @@ -73,36 +80,58 @@ def is_connected(self) -> bool:
def connection_lost(self, exc: Optional[BaseException]) -> None:
self._drop_timeout()

original_connection_error = exc
reraised_exc = original_connection_error

connection_closed_cleanly = original_connection_error is None

if self._payload_parser is not None:
with suppress(Exception):
with suppress(Exception): # FIXME: log this somehow?
self._payload_parser.feed_eof()

uncompleted = None
if self._parser is not None:
try:
uncompleted = self._parser.feed_eof()
except Exception as e:
except Exception as underlying_exc:
if self._payload is not None:
exc = ClientPayloadError("Response payload is not completed")
exc.__cause__ = e
self._payload.set_exception(exc)
client_payload_exc_msg = (
f"Response payload is not completed: {underlying_exc !r}"
)
if not connection_closed_cleanly:
client_payload_exc_msg = (
f"{client_payload_exc_msg !s}. "
f"{original_connection_error !r}"
)
set_exception(
self._payload,
ClientPayloadError(client_payload_exc_msg),
underlying_exc,
)

if not self.is_eof():
if isinstance(exc, OSError):
exc = ClientOSError(*exc.args)
if exc is None:
exc = ServerDisconnectedError(uncompleted)
if isinstance(original_connection_error, OSError):
reraised_exc = ClientOSError(*original_connection_error.args)
if connection_closed_cleanly:
reraised_exc = ServerDisconnectedError(uncompleted)
# assigns self._should_close to True as side effect,
# we do it anyway below
self.set_exception(exc)
underlying_non_eof_exc = (
_EXC_SENTINEL
if connection_closed_cleanly
else original_connection_error
)
assert underlying_non_eof_exc is not None
assert reraised_exc is not None
self.set_exception(reraised_exc, underlying_non_eof_exc)

self._should_close = True
self._parser = None
self._payload = None
self._payload_parser = None
self._reading_paused = False

super().connection_lost(exc)
super().connection_lost(reraised_exc)

def eof_received(self) -> None:
# should call parser.feed_eof() most likely
Expand All @@ -116,10 +145,14 @@ def resume_reading(self) -> None:
super().resume_reading()
self._reschedule_timeout()

def set_exception(self, exc: BaseException) -> None:
def set_exception(
self,
exc: BaseException,
exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
self._should_close = True
self._drop_timeout()
super().set_exception(exc)
super().set_exception(exc, exc_cause)

def set_parser(self, parser: Any, payload: Any) -> None:
# TODO: actual types are:
Expand Down Expand Up @@ -196,7 +229,7 @@ def _on_read_timeout(self) -> None:
exc = ServerTimeoutError("Timeout on reading data from socket")
self.set_exception(exc)
if self._payload is not None:
self._payload.set_exception(exc)
set_exception(self._payload, exc)

def data_received(self, data: bytes) -> None:
self._reschedule_timeout()
Expand All @@ -222,14 +255,14 @@ def data_received(self, data: bytes) -> None:
# parse http messages
try:
messages, upgraded, tail = self._parser.feed_data(data)
except BaseException as exc:
except BaseException as underlying_exc:
if self.transport is not None:
# connection.release() could be called BEFORE
# data_received(), the transport is already
# closed in this case
self.transport.close()
# should_close is True after the call
self.set_exception(exc)
self.set_exception(HttpProcessingError(), underlying_exc)
return

self._upgraded = upgraded
Expand Down
34 changes: 22 additions & 12 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
netrc_from_env,
noop,
reify,
set_exception,
set_result,
)
from .http import (
Expand Down Expand Up @@ -630,20 +631,29 @@ async def write_bytes(

for chunk in self.body:
await writer.write(chunk) # type: ignore[arg-type]
except OSError as exc:
if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
protocol.set_exception(exc)
else:
new_exc = ClientOSError(
exc.errno, "Can not write request body for %s" % self.url
except OSError as underlying_exc:
reraised_exc = underlying_exc

exc_is_not_timeout = underlying_exc.errno is not None or not isinstance(
underlying_exc, asyncio.TimeoutError
)
if exc_is_not_timeout:
reraised_exc = ClientOSError(
underlying_exc.errno,
f"Can not write request body for {self.url !s}",
)
new_exc.__context__ = exc
new_exc.__cause__ = exc
protocol.set_exception(new_exc)

set_exception(protocol, reraised_exc, underlying_exc)
except asyncio.CancelledError:
await writer.write_eof()
except Exception as exc:
protocol.set_exception(exc)
except Exception as underlying_exc:
set_exception(
protocol,
ClientConnectionError(
f"Failed to send bytes into the underlying connection {conn !s}",
),
underlying_exc,
)
else:
await writer.write_eof()
protocol.start_timeout()
Expand Down Expand Up @@ -1086,7 +1096,7 @@ def _cleanup_writer(self) -> None:
def _notify_content(self) -> None:
content = self.content
if content and content.exception() is None:
content.set_exception(ClientConnectionError("Connection closed"))
set_exception(content, ClientConnectionError("Connection closed"))
self._released = True

async def wait_for_close(self) -> None:
Expand Down
36 changes: 33 additions & 3 deletions aiohttp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,9 +810,39 @@ def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
fut.set_result(result)


def set_exception(fut: "asyncio.Future[_T]", exc: BaseException) -> None:
if not fut.done():
fut.set_exception(exc)
_EXC_SENTINEL = BaseException()


class ErrorableProtocol(Protocol):
def set_exception(
self,
exc: BaseException,
exc_cause: BaseException = ...,
) -> None:
... # pragma: no cover


def set_exception(
fut: "asyncio.Future[_T] | ErrorableProtocol",
exc: BaseException,
exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
"""Set future exception.
If the future is marked as complete, this function is a no-op.
:param exc_cause: An exception that is a direct cause of ``exc``.
Only set if provided.
"""
if asyncio.isfuture(fut) and fut.done():
return

exc_is_sentinel = exc_cause is _EXC_SENTINEL
exc_causes_itself = exc is exc_cause
if not exc_is_sentinel and not exc_causes_itself:
exc.__cause__ = exc_cause

fut.set_exception(exc)


@functools.total_ordering
Expand Down
27 changes: 18 additions & 9 deletions aiohttp/http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@
from .base_protocol import BaseProtocol
from .compression_utils import HAS_BROTLI, BrotliDecompressor, ZLibDecompressor
from .helpers import (
_EXC_SENTINEL,
DEBUG,
NO_EXTENSIONS,
BaseTimerContext,
method_must_be_empty_body,
set_exception,
status_code_must_be_empty_body,
)
from .http_exceptions import (
Expand Down Expand Up @@ -446,13 +448,16 @@ def get_content_length() -> Optional[int]:
assert self._payload_parser is not None
try:
eof, data = self._payload_parser.feed_data(data[start_pos:], SEP)
except BaseException as exc:
except BaseException as underlying_exc:
reraised_exc = underlying_exc
if self.payload_exception is not None:
self._payload_parser.payload.set_exception(
self.payload_exception(str(exc))
)
else:
self._payload_parser.payload.set_exception(exc)
reraised_exc = self.payload_exception(str(underlying_exc))

set_exception(
self._payload_parser.payload,
reraised_exc,
underlying_exc,
)

eof = True
data = b""
Expand Down Expand Up @@ -834,7 +839,7 @@ def feed_data(
exc = TransferEncodingError(
chunk[:pos].decode("ascii", "surrogateescape")
)
self.payload.set_exception(exc)
set_exception(self.payload, exc)
raise exc
size = int(bytes(size_b), 16)

Expand Down Expand Up @@ -939,8 +944,12 @@ def __init__(self, out: StreamReader, encoding: Optional[str]) -> None:
else:
self.decompressor = ZLibDecompressor(encoding=encoding)

def set_exception(self, exc: BaseException) -> None:
self.out.set_exception(exc)
def set_exception(
self,
exc: BaseException,
exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
set_exception(self.out, exc, exc_cause)

def feed_data(self, chunk: bytes, size: int) -> None:
if not size:
Expand Down
4 changes: 2 additions & 2 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from .base_protocol import BaseProtocol
from .compression_utils import ZLibCompressor, ZLibDecompressor
from .helpers import NO_EXTENSIONS
from .helpers import NO_EXTENSIONS, set_exception
from .streams import DataQueue

__all__ = (
Expand Down Expand Up @@ -314,7 +314,7 @@ def feed_data(self, data: bytes) -> Tuple[bool, bytes]:
return self._feed_data(data)
except Exception as exc:
self._exc = exc
self.queue.set_exception(exc)
set_exception(self.queue, exc)
return True, b""

def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
Expand Down
Loading

0 comments on commit 7af2680

Please sign in to comment.