Skip to content

Commit

Permalink
[PR #8089/dc38630b backport][3.10] 💅 Propagate error causes via async…
Browse files Browse the repository at this point in the history
…io protocols (#8161)

**This is a backport of PR #8089 as merged into master
(dc38630).**

This is supposed to unify setting exceptions on the future objects,
allowing to also attach their causes whenever available. It'll make
possible for the end-users to see more detailed tracebacks.

It's also supposed to help with tracking down what's happening with
#4581.

PR #8089

Co-Authored-By: J. Nick Koston <[email protected]>
Co-Authored-By: Sam Bull <[email protected]>
(cherry picked from commit dc38630)
  • Loading branch information
webknjaz authored Feb 16, 2024
1 parent 6cb21d1 commit d4322e7
Show file tree
Hide file tree
Showing 15 changed files with 177 additions and 66 deletions.
3 changes: 3 additions & 0 deletions CHANGES/8089.bugfix.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
The asynchronous internals now set the underlying causes
when assigning exceptions to the future objects
-- by :user:`webknjaz`.
12 changes: 7 additions & 5 deletions aiohttp/_http_parser.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ from multidict import CIMultiDict as _CIMultiDict, CIMultiDictProxy as _CIMultiD
from yarl import URL as _URL

from aiohttp import hdrs
from aiohttp.helpers import DEBUG
from aiohttp.helpers import DEBUG, set_exception

from .http_exceptions import (
BadHttpMessage,
Expand Down Expand Up @@ -763,11 +763,13 @@ cdef int cb_on_body(cparser.llhttp_t* parser,
cdef bytes body = at[:length]
try:
pyparser._payload.feed_data(body, length)
except BaseException as exc:
except BaseException as underlying_exc:
reraised_exc = underlying_exc
if pyparser._payload_exception is not None:
pyparser._payload.set_exception(pyparser._payload_exception(str(exc)))
else:
pyparser._payload.set_exception(exc)
reraised_exc = pyparser._payload_exception(str(underlying_exc))

set_exception(pyparser._payload, reraised_exc, underlying_exc)

pyparser._payload_error = 1
return -1
else:
Expand Down
7 changes: 6 additions & 1 deletion aiohttp/base_protocol.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import asyncio
from typing import Optional, cast

from .helpers import set_exception
from .tcp_helpers import tcp_nodelay


Expand Down Expand Up @@ -76,7 +77,11 @@ def connection_lost(self, exc: Optional[BaseException]) -> None:
if exc is None:
waiter.set_result(None)
else:
waiter.set_exception(exc)
set_exception(
waiter,
ConnectionError("Connection lost"),
exc,
)

async def _drain_helper(self) -> None:
if not self.connected:
Expand Down
66 changes: 49 additions & 17 deletions aiohttp/client_proto.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,14 @@
ServerDisconnectedError,
SocketTimeoutError,
)
from .helpers import BaseTimerContext, status_code_must_be_empty_body
from .helpers import (
_EXC_SENTINEL,
BaseTimerContext,
set_exception,
status_code_must_be_empty_body,
)
from .http import HttpResponseParser, RawResponseMessage
from .http_exceptions import HttpProcessingError
from .streams import EMPTY_PAYLOAD, DataQueue, StreamReader


Expand Down Expand Up @@ -73,36 +79,58 @@ def is_connected(self) -> bool:
def connection_lost(self, exc: Optional[BaseException]) -> None:
self._drop_timeout()

original_connection_error = exc
reraised_exc = original_connection_error

connection_closed_cleanly = original_connection_error is None

if self._payload_parser is not None:
with suppress(Exception):
with suppress(Exception): # FIXME: log this somehow?
self._payload_parser.feed_eof()

uncompleted = None
if self._parser is not None:
try:
uncompleted = self._parser.feed_eof()
except Exception as e:
except Exception as underlying_exc:
if self._payload is not None:
exc = ClientPayloadError("Response payload is not completed")
exc.__cause__ = e
self._payload.set_exception(exc)
client_payload_exc_msg = (
f"Response payload is not completed: {underlying_exc !r}"
)
if not connection_closed_cleanly:
client_payload_exc_msg = (
f"{client_payload_exc_msg !s}. "
f"{original_connection_error !r}"
)
set_exception(
self._payload,
ClientPayloadError(client_payload_exc_msg),
underlying_exc,
)

if not self.is_eof():
if isinstance(exc, OSError):
exc = ClientOSError(*exc.args)
if exc is None:
exc = ServerDisconnectedError(uncompleted)
if isinstance(original_connection_error, OSError):
reraised_exc = ClientOSError(*original_connection_error.args)
if connection_closed_cleanly:
reraised_exc = ServerDisconnectedError(uncompleted)
# assigns self._should_close to True as side effect,
# we do it anyway below
self.set_exception(exc)
underlying_non_eof_exc = (
_EXC_SENTINEL
if connection_closed_cleanly
else original_connection_error
)
assert underlying_non_eof_exc is not None
assert reraised_exc is not None
self.set_exception(reraised_exc, underlying_non_eof_exc)

self._should_close = True
self._parser = None
self._payload = None
self._payload_parser = None
self._reading_paused = False

super().connection_lost(exc)
super().connection_lost(reraised_exc)

def eof_received(self) -> None:
# should call parser.feed_eof() most likely
Expand All @@ -116,10 +144,14 @@ def resume_reading(self) -> None:
super().resume_reading()
self._reschedule_timeout()

def set_exception(self, exc: BaseException) -> None:
def set_exception(
self,
exc: BaseException,
exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
self._should_close = True
self._drop_timeout()
super().set_exception(exc)
super().set_exception(exc, exc_cause)

def set_parser(self, parser: Any, payload: Any) -> None:
# TODO: actual types are:
Expand Down Expand Up @@ -196,7 +228,7 @@ def _on_read_timeout(self) -> None:
exc = SocketTimeoutError("Timeout on reading data from socket")
self.set_exception(exc)
if self._payload is not None:
self._payload.set_exception(exc)
set_exception(self._payload, exc)

def data_received(self, data: bytes) -> None:
self._reschedule_timeout()
Expand All @@ -222,14 +254,14 @@ def data_received(self, data: bytes) -> None:
# parse http messages
try:
messages, upgraded, tail = self._parser.feed_data(data)
except BaseException as exc:
except BaseException as underlying_exc:
if self.transport is not None:
# connection.release() could be called BEFORE
# data_received(), the transport is already
# closed in this case
self.transport.close()
# should_close is True after the call
self.set_exception(exc)
self.set_exception(HttpProcessingError(), underlying_exc)
return

self._upgraded = upgraded
Expand Down
34 changes: 22 additions & 12 deletions aiohttp/client_reqrep.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@
netrc_from_env,
noop,
reify,
set_exception,
set_result,
)
from .http import (
Expand Down Expand Up @@ -630,20 +631,29 @@ async def write_bytes(

for chunk in self.body:
await writer.write(chunk) # type: ignore[arg-type]
except OSError as exc:
if exc.errno is None and isinstance(exc, asyncio.TimeoutError):
protocol.set_exception(exc)
else:
new_exc = ClientOSError(
exc.errno, "Can not write request body for %s" % self.url
except OSError as underlying_exc:
reraised_exc = underlying_exc

exc_is_not_timeout = underlying_exc.errno is not None or not isinstance(
underlying_exc, asyncio.TimeoutError
)
if exc_is_not_timeout:
reraised_exc = ClientOSError(
underlying_exc.errno,
f"Can not write request body for {self.url !s}",
)
new_exc.__context__ = exc
new_exc.__cause__ = exc
protocol.set_exception(new_exc)

set_exception(protocol, reraised_exc, underlying_exc)
except asyncio.CancelledError:
await writer.write_eof()
except Exception as exc:
protocol.set_exception(exc)
except Exception as underlying_exc:
set_exception(
protocol,
ClientConnectionError(
f"Failed to send bytes into the underlying connection {conn !s}",
),
underlying_exc,
)
else:
await writer.write_eof()
protocol.start_timeout()
Expand Down Expand Up @@ -1086,7 +1096,7 @@ def _cleanup_writer(self) -> None:
def _notify_content(self) -> None:
content = self.content
if content and content.exception() is None:
content.set_exception(ClientConnectionError("Connection closed"))
set_exception(content, ClientConnectionError("Connection closed"))
self._released = True

async def wait_for_close(self) -> None:
Expand Down
36 changes: 33 additions & 3 deletions aiohttp/helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -810,9 +810,39 @@ def set_result(fut: "asyncio.Future[_T]", result: _T) -> None:
fut.set_result(result)


def set_exception(fut: "asyncio.Future[_T]", exc: BaseException) -> None:
if not fut.done():
fut.set_exception(exc)
_EXC_SENTINEL = BaseException()


class ErrorableProtocol(Protocol):
def set_exception(
self,
exc: BaseException,
exc_cause: BaseException = ...,
) -> None:
... # pragma: no cover


def set_exception(
fut: "asyncio.Future[_T] | ErrorableProtocol",
exc: BaseException,
exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
"""Set future exception.
If the future is marked as complete, this function is a no-op.
:param exc_cause: An exception that is a direct cause of ``exc``.
Only set if provided.
"""
if asyncio.isfuture(fut) and fut.done():
return

exc_is_sentinel = exc_cause is _EXC_SENTINEL
exc_causes_itself = exc is exc_cause
if not exc_is_sentinel and not exc_causes_itself:
exc.__cause__ = exc_cause

fut.set_exception(exc)


@functools.total_ordering
Expand Down
27 changes: 18 additions & 9 deletions aiohttp/http_parser.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,12 @@
from .base_protocol import BaseProtocol
from .compression_utils import HAS_BROTLI, BrotliDecompressor, ZLibDecompressor
from .helpers import (
_EXC_SENTINEL,
DEBUG,
NO_EXTENSIONS,
BaseTimerContext,
method_must_be_empty_body,
set_exception,
status_code_must_be_empty_body,
)
from .http_exceptions import (
Expand Down Expand Up @@ -446,13 +448,16 @@ def get_content_length() -> Optional[int]:
assert self._payload_parser is not None
try:
eof, data = self._payload_parser.feed_data(data[start_pos:], SEP)
except BaseException as exc:
except BaseException as underlying_exc:
reraised_exc = underlying_exc
if self.payload_exception is not None:
self._payload_parser.payload.set_exception(
self.payload_exception(str(exc))
)
else:
self._payload_parser.payload.set_exception(exc)
reraised_exc = self.payload_exception(str(underlying_exc))

set_exception(
self._payload_parser.payload,
reraised_exc,
underlying_exc,
)

eof = True
data = b""
Expand Down Expand Up @@ -834,7 +839,7 @@ def feed_data(
exc = TransferEncodingError(
chunk[:pos].decode("ascii", "surrogateescape")
)
self.payload.set_exception(exc)
set_exception(self.payload, exc)
raise exc
size = int(bytes(size_b), 16)

Expand Down Expand Up @@ -939,8 +944,12 @@ def __init__(self, out: StreamReader, encoding: Optional[str]) -> None:
else:
self.decompressor = ZLibDecompressor(encoding=encoding)

def set_exception(self, exc: BaseException) -> None:
self.out.set_exception(exc)
def set_exception(
self,
exc: BaseException,
exc_cause: BaseException = _EXC_SENTINEL,
) -> None:
set_exception(self.out, exc, exc_cause)

def feed_data(self, chunk: bytes, size: int) -> None:
if not size:
Expand Down
4 changes: 2 additions & 2 deletions aiohttp/http_websocket.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@

from .base_protocol import BaseProtocol
from .compression_utils import ZLibCompressor, ZLibDecompressor
from .helpers import NO_EXTENSIONS
from .helpers import NO_EXTENSIONS, set_exception
from .streams import DataQueue

__all__ = (
Expand Down Expand Up @@ -314,7 +314,7 @@ def feed_data(self, data: bytes) -> Tuple[bool, bytes]:
return self._feed_data(data)
except Exception as exc:
self._exc = exc
self.queue.set_exception(exc)
set_exception(self.queue, exc)
return True, b""

def _feed_data(self, data: bytes) -> Tuple[bool, bytes]:
Expand Down
Loading

0 comments on commit d4322e7

Please sign in to comment.