diff --git a/aidial_interceptors_sdk/utils/_exceptions.py b/aidial_interceptors_sdk/utils/_exceptions.py index 350c457..9fad761 100644 --- a/aidial_interceptors_sdk/utils/_exceptions.py +++ b/aidial_interceptors_sdk/utils/_exceptions.py @@ -3,68 +3,39 @@ from typing import Dict from aidial_sdk.exceptions import HTTPException as DialException -from fastapi import HTTPException as FastAPIException -from fastapi.responses import JSONResponse as FastAPIResponse from openai import APIConnectionError, APIError, APIStatusError, APITimeoutError -from typing_extensions import override _log = logging.getLogger(__name__) -# TODO: support headers in DIAL SDK exception -class DialExceptionWithHeaders(DialException): - headers: Dict[str, str] | None = None - - def __init__(self, *, headers: Dict[str, str] | None = None, **kwargs): - super().__init__(**kwargs) - self.headers = headers - - @classmethod - def create( - cls, - status_code: int, - content: dict | str, - headers: Dict[str, str] | None = None, +def _parse_dial_exception( + status_code: int, + content: dict | str, + headers: Dict[str, str] | None = None, +): + if ( + isinstance(content, dict) + and (error := content.get("error")) + and isinstance(error, dict) ): - if ( - isinstance(content, dict) - and (error := content.get("error")) - and isinstance(error, dict) - ): - message = error.get("message") or "Unknown error" - code = error.get("code") - type = error.get("type") - param = error.get("param") - display_message = error.get("display_message") - else: - message = content - code = type = param = display_message = None - - return cls( - status_code=status_code, - message=message, - type=type, - param=param, - code=code, - display_message=display_message, - headers=headers, - ) + message = error.get("message") or "Unknown error" + code = error.get("code") + type = error.get("type") + param = error.get("param") + display_message = error.get("display_message") + else: + message = str(content) + code = type = param = display_message = None - @override - def to_fastapi_response(self) -> FastAPIResponse: - return FastAPIResponse( - status_code=self.status_code, - content=self.json_error(), - headers=self.headers, - ) - - @override - def to_fastapi_exception(self) -> FastAPIException: - return FastAPIException( - status_code=self.status_code, - detail=self.json_error(), - headers=self.headers, - ) + return DialException( + status_code=status_code, + message=message, + type=type, + param=param, + code=code, + display_message=display_message, + headers=headers, + ) def to_dial_exception(exc: Exception) -> DialException: @@ -95,7 +66,7 @@ def to_dial_exception(exc: Exception) -> DialException: except Exception: content = r.text - return DialExceptionWithHeaders.create( + return _parse_dial_exception( status_code=r.status_code, headers=plain_headers, content=content, @@ -110,7 +81,7 @@ def to_dial_exception(exc: Exception) -> DialException: except Exception: pass - return DialExceptionWithHeaders.create( + return _parse_dial_exception( status_code=status_code, headers={}, content={"error": exc.body or {}}, diff --git a/poetry.lock b/poetry.lock index cca6bb3..9777427 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2,41 +2,45 @@ [[package]] name = "aidial-sdk" -version = "0.15.0" +version = "0.16.0rc" description = "Framework to create applications and model adapters for AI DIAL" optional = false -python-versions = "<4.0,>=3.8.1" -files = [ - {file = "aidial_sdk-0.15.0-py3-none-any.whl", hash = "sha256:7b9b3e5ec9688be2919dcd7dd0312aac807dc7917393ee5f846332713ad2e26a"}, - {file = "aidial_sdk-0.15.0.tar.gz", hash = "sha256:6b47bb36e8c795300e0d4b61308c6a2f86b59abb97905390a02789b343460720"}, -] +python-versions = ">=3.8.1,<4.0" +files = [] +develop = false [package.dependencies] -aiohttp = ">=3.8.3,<4.0.0" +aiohttp = "^3.8.3" fastapi = ">=0.51,<1.0" httpx = ">=0.25.0,<1.0" -opentelemetry-api = {version = "1.20.0", optional = true, markers = "extra == \"telemetry\""} -opentelemetry-distro = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} -opentelemetry-exporter-otlp-proto-grpc = {version = "1.20.0", optional = true, markers = "extra == \"telemetry\""} -opentelemetry-exporter-prometheus = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} -opentelemetry-instrumentation = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} -opentelemetry-instrumentation-aiohttp-client = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} -opentelemetry-instrumentation-fastapi = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} -opentelemetry-instrumentation-httpx = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} -opentelemetry-instrumentation-logging = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} -opentelemetry-instrumentation-requests = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} -opentelemetry-instrumentation-system-metrics = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} -opentelemetry-instrumentation-urllib = {version = "0.41b0", optional = true, markers = "extra == \"telemetry\""} -opentelemetry-sdk = {version = "1.20.0", optional = true, markers = "extra == \"telemetry\""} -prometheus-client = {version = "0.17.1", optional = true, markers = "extra == \"telemetry\""} +opentelemetry-api = {version = "1.20.0", optional = true} +opentelemetry-distro = {version = "0.41b0", optional = true} +opentelemetry-exporter-otlp-proto-grpc = {version = "1.20.0", optional = true} +opentelemetry-exporter-prometheus = {version = "0.41b0", optional = true} +opentelemetry-instrumentation = {version = "0.41b0", optional = true} +opentelemetry-instrumentation-aiohttp-client = {version = "0.41b0", optional = true} +opentelemetry-instrumentation-fastapi = {version = "0.41b0", optional = true} +opentelemetry-instrumentation-httpx = {version = "0.41b0", optional = true} +opentelemetry-instrumentation-logging = {version = "0.41b0", optional = true} +opentelemetry-instrumentation-requests = {version = "0.41b0", optional = true} +opentelemetry-instrumentation-system-metrics = {version = "0.41b0", optional = true} +opentelemetry-instrumentation-urllib = {version = "0.41b0", optional = true} +opentelemetry-sdk = {version = "1.20.0", optional = true} +prometheus-client = {version = "0.17.1", optional = true} pydantic = ">=1.10,<3" -requests = ">=2.19,<3.0" +requests = "^2.19" uvicorn = ">=0.19,<1.0" -wrapt = ">=1.14,<2.0" +wrapt = "^1.14" [package.extras] telemetry = ["opentelemetry-api (==1.20.0)", "opentelemetry-distro (==0.41b0)", "opentelemetry-exporter-otlp-proto-grpc (==1.20.0)", "opentelemetry-exporter-prometheus (==0.41b0)", "opentelemetry-instrumentation (==0.41b0)", "opentelemetry-instrumentation-aiohttp-client (==0.41b0)", "opentelemetry-instrumentation-fastapi (==0.41b0)", "opentelemetry-instrumentation-httpx (==0.41b0)", "opentelemetry-instrumentation-logging (==0.41b0)", "opentelemetry-instrumentation-requests (==0.41b0)", "opentelemetry-instrumentation-system-metrics (==0.41b0)", "opentelemetry-instrumentation-urllib (==0.41b0)", "opentelemetry-sdk (==1.20.0)", "prometheus-client (==0.17.1)"] +[package.source] +type = "git" +url = "https://github.com/epam/ai-dial-sdk.git" +reference = "feat/support-headers-in-dial-exception" +resolved_reference = "9fdead4e938e3f81906059155cc41894ac9ff80c" + [[package]] name = "aiohappyeyeballs" version = "2.4.0" @@ -3030,4 +3034,4 @@ examples = ["aiostream", "en_core_web_sm", "numpy", "pillow", "spacy"] [metadata] lock-version = "2.0" python-versions = ">=3.11,<4.0" -content-hash = "479014a1cbff10400410cba9f2e0b408d3bf00d5d2e7670223e510f6ac696743" +content-hash = "01035d031a12fca7a68d1b9a4ae2e396e50991aec0f90ff395a7d7fea15bd241" diff --git a/pyproject.toml b/pyproject.toml index f1bc232..683188e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -25,7 +25,9 @@ python = ">=3.11,<4.0" fastapi = ">=0.51,<1.0" httpx = ">=0.25.0,<1.0" openai = ">=1.32.0,<2.0" -aidial-sdk = { version = "^0.15.0", extras = ["telemetry"] } +# FIXME: revert to a release version +# aidial-sdk = { version = "^0.15.0", extras = ["telemetry"] } +aidial-sdk = { git = "https://github.com/epam/ai-dial-sdk.git", branch = "feat/support-headers-in-dial-exception", extras = ["telemetry"] } # Extras for examples aiostream = { version = "^0.6.2", optional = true } diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 38402c3..f8b2965 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -6,13 +6,13 @@ from aidial_interceptors_sdk.chat_completion.base import ( ChatCompletionNoOpInterceptor, ) -from aidial_interceptors_sdk.utils._exceptions import DialExceptionWithHeaders +from aidial_interceptors_sdk.utils._exceptions import _parse_dial_exception from tests.utils.applications import create_broken_application from tests.utils.chunks import create_chunk_checker, create_sse_stream_checker from tests.utils.dial_app import create_httpx_client from tests.utils.json import has_type, match_objects, memorize -to_many_requests_error = DialExceptionWithHeaders.create( +to_many_requests_error = _parse_dial_exception( status_code=http.HTTPStatus.TOO_MANY_REQUESTS, content={"error": {"message": "Too many requests"}}, headers={"retry-after": "42"}, @@ -52,11 +52,10 @@ async def test_interceptor_errors(stream: bool, repeats: int): actual_headers = { k.decode(): v.decode() for k, v in response.headers.raw } - # FIXME: Retry-After should actually be propagated - # See https://github.com/epam/ai-dial-sdk/blob/45681f3763679e115d95bc5ce32cf382e0083420/aidial_sdk/_errors.py#L21C1-L26C6 assert match_objects( actual_headers, { + "retry-after": "42", "content-length": has_type(str), "content-type": "application/json", },