Skip to content

Commit

Permalink
fix: fixed propagation of exceptions from upstream endpoints (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
adubovik authored Dec 2, 2024
1 parent e45e651 commit 48ad4d9
Show file tree
Hide file tree
Showing 34 changed files with 1,243 additions and 1,052 deletions.
3 changes: 2 additions & 1 deletion .env.example
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
LOG_LEVEL=DEBUG
LOG_LEVEL=INFO
DIAL_SDK_LOG=WARNING
DIAL_URL=DIAL_URL
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,7 @@ The command will start the server on `http://localhost:5000` exposing endpoints
First clone the repository:

```sh
git clone https://github.com/epam/ai-dial-interceptors-sdk.git .
git clone https://github.com/epam/ai-dial-interceptors-sdk.git
cd ai-dial-interceptors-sdk
echo "DIAL_URL=URL" > .env
```
Expand Down
108 changes: 68 additions & 40 deletions aidial_interceptors_sdk/chat_completion/adapter.py
Original file line number Diff line number Diff line change
@@ -1,45 +1,57 @@
import json
import logging
from typing import Any, AsyncIterator, Type, cast
from typing import AsyncIterator, Type, cast

from aidial_sdk.chat_completion import ChatCompletion as DialChatCompletion
from aidial_sdk.chat_completion import Request as DialRequest
from aidial_sdk.chat_completion import Response as DialResponse
from aidial_sdk.chat_completion.chunks import DefaultChunk
from aidial_sdk.exceptions import HTTPException as DialException
from openai import AsyncStream
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk

from aidial_interceptors_sdk.chat_completion.annotated_value import (
AnnotatedException,
Annotation,
)
from aidial_interceptors_sdk.chat_completion.base import (
ChatCompletionInterceptor,
RequestDict,
)
from aidial_interceptors_sdk.dial_client import DialClient
from aidial_interceptors_sdk.error import EarlyStreamExit
from aidial_interceptors_sdk.utils._debug import debug_logging
from aidial_interceptors_sdk.utils._exceptions import dial_exception_decorator
from aidial_interceptors_sdk.utils._http_client import HTTPClientFactory
from aidial_interceptors_sdk.utils._reflection import call_with_extra_body
from aidial_interceptors_sdk.utils.streaming import (
block_response_to_streaming_chunk,
handle_streaming_errors,
map_stream,
materialize_streaming_errors,
singleton_stream,
)

_log = logging.getLogger(__name__)
_debug = _log.isEnabledFor(logging.DEBUG)


def interceptor_to_chat_completion(
cls: Type[ChatCompletionInterceptor],
dial_url: str,
client_factory: HTTPClientFactory,
) -> DialChatCompletion:
class Impl(DialChatCompletion):
@dial_exception_decorator
async def chat_completion(
self, request: DialRequest, response: DialResponse
) -> None:
dial_client = await DialClient.create(
dial_url=dial_url,
api_key=request.api_key,
api_version=request.api_version,
authorization=request.jwt,
headers=request.headers,
client_factory=client_factory,
)

interceptor = cls(
Expand All @@ -53,52 +65,68 @@ async def chat_completion(
interceptor.traverse_request
)(request_body)

async def call_upstream(
request: dict, call_context: Any | None
) -> AsyncIterator[dict]:
upstream_response = cast(
AsyncStream[ChatCompletionChunk] | ChatCompletion,
await call_with_extra_body(
dial_client.client.chat.completions.create, request
),
)

if isinstance(upstream_response, ChatCompletion):
resp = upstream_response.to_dict()
if _debug:
_log.debug(
f"upstream response[{call_context}]: {json.dumps(resp)}"
)

chunk = block_response_to_streaming_chunk(resp)
stream = singleton_stream(chunk)
else:

def on_upstream_chunk(chunk: ChatCompletionChunk) -> dict:
d = chunk.to_dict()
if _debug:
_log.debug(
f"upstream chunk[{call_context}]: {json.dumps(d)}"
)
return d

stream = map_stream(on_upstream_chunk, upstream_response)

return handle_streaming_errors(stream)

try:
await interceptor.on_stream_start()

async for chunk in await interceptor.call_upstreams(
def call_upstream(context: Annotation, request: dict):
return call_single_upstream(dial_client, context, request)

async for value in await interceptor.call_upstreams(
request_body, call_upstream
):
if "error" in chunk.chunk:
await interceptor.on_stream_error(chunk)
if isinstance(value, AnnotatedException):
await interceptor.on_stream_error(value)
else:
await interceptor.traverse_response_chunk(chunk)
await interceptor.traverse_response_chunk(value)

await interceptor.on_stream_end()
except EarlyStreamExit:
pass

return Impl()


async def call_single_upstream(
    dial_client: DialClient, context: Annotation, request: RequestDict
) -> AsyncIterator[dict | DialException]:
    """Send one chat completion request upstream and return its chunk stream.

    The upstream reply is either a complete ``ChatCompletion`` or an
    ``AsyncStream`` of ``ChatCompletionChunk`` objects. Both shapes are
    normalized to a stream of chunk dictionaries, which is then passed
    through ``materialize_streaming_errors``.

    Args:
        dial_client: Client whose ``client.chat.completions.create`` is called.
        context: Annotation of this upstream call; only used in debug logs.
        request: Request body forwarded through ``call_with_extra_body``.

    Returns:
        The result of ``materialize_streaming_errors`` applied to the
        normalized stream (annotated as yielding chunk dictionaries or
        ``DialException`` values).
    """
    create_completion = dial_client.client.chat.completions.create
    upstream_reply = cast(
        AsyncStream[ChatCompletionChunk] | ChatCompletion,
        await call_with_extra_body(create_completion, request),
    )

    if not isinstance(upstream_reply, ChatCompletion):
        # Streaming mode:
        # The default fields are left in place, because they get
        # overridden by the default fields generated by DIAL SDK
        # when each chunk is merged naively with a default chunk.

        def chunk_to_dict(chunk: ChatCompletionChunk) -> dict:
            chunk_dict = chunk.to_dict()
            # The log level is checked per chunk so that the (potentially
            # large) JSON dump is only built when it will be emitted.
            if _log.isEnabledFor(logging.DEBUG):
                _log.debug(
                    f"upstream chunk[{context}]: {json.dumps(chunk_dict)}"
                )
            return chunk_dict

        return materialize_streaming_errors(
            map_stream(chunk_to_dict, upstream_reply)
        )

    # Non-streaming mode: the whole response arrives as a single block.
    body = upstream_reply.to_dict()
    if _log.isEnabledFor(logging.DEBUG):
        _log.debug(f"upstream response[{context}]: {json.dumps(body)}")

    # Drop the default fields which DIAL SDK generates automatically.
    # Consequently these fields are not proxied from the upstream; they are
    # recreated on each interceptor call.
    # If they were kept, they would be merged recursively with the ones
    # generated by the SDK, ending up with values such as
    # "object": "chat.completionchat.completionchat.completion"
    for default_field in DefaultChunk.__annotations__:
        body.pop(default_field, None)

    single_chunk = block_response_to_streaming_chunk(body)
    return materialize_streaming_errors(singleton_stream(single_chunk))
8 changes: 0 additions & 8 deletions aidial_interceptors_sdk/chat_completion/annotated_chunk.py

This file was deleted.

25 changes: 25 additions & 0 deletions aidial_interceptors_sdk/chat_completion/annotated_value.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from abc import ABC
from typing import Any

from aidial_sdk.exceptions import HTTPException as DialException
from aidial_sdk.pydantic_v1 import BaseModel

Annotation = Any | None


class AnnotatedValueBase(BaseModel, ABC):
    """Common base for values that carry an optional annotation."""

    class Config:
        # Lets fields use types pydantic has no validator for.
        # NOTE(review): presumably required by subclasses holding exception
        # objects (e.g. a DialException field) - confirm.
        arbitrary_types_allowed = True

    # Annotation attached to the value; defaults to None when not provided.
    annotation: Annotation = None


class AnnotatedChunk(AnnotatedValueBase):
    """An annotated value whose payload is a chunk dictionary."""

    # The chunk payload, kept as a plain dictionary.
    chunk: dict


class AnnotatedException(AnnotatedValueBase):
    """An annotated value whose payload is a DIAL exception."""

    # The exception payload (aidial_sdk HTTPException, imported as DialException).
    error: DialException


AnnotatedValue = AnnotatedChunk | AnnotatedException
34 changes: 20 additions & 14 deletions aidial_interceptors_sdk/chat_completion/base.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
from typing import Any, AsyncIterator, Callable, Coroutine
from typing import AsyncIterator, Awaitable, Callable

from aidial_interceptors_sdk.chat_completion.annotated_chunk import (
AnnotatedChunk,
from aidial_sdk.exceptions import HTTPException as DialException

from aidial_interceptors_sdk.chat_completion.annotated_value import (
AnnotatedException,
AnnotatedValue,
Annotation,
)
from aidial_interceptors_sdk.chat_completion.request_handler import (
RequestHandler,
Expand All @@ -10,24 +14,26 @@
ResponseHandler,
)
from aidial_interceptors_sdk.dial_client import DialClient
from aidial_interceptors_sdk.utils.streaming import annotate_stream

RequestDict = dict


class ChatCompletionInterceptor(RequestHandler, ResponseHandler):
dial_client: DialClient

async def call_upstreams(
self,
request: dict,
request: RequestDict,
call_upstream: Callable[
[dict, Any | None], Coroutine[Any, Any, AsyncIterator[dict]]
[Annotation, RequestDict],
Awaitable[AsyncIterator[dict | DialException]],
],
) -> AsyncIterator[AnnotatedChunk]:
async def iterator():
call_context = None
async for chunk in await call_upstream(request, call_context):
yield AnnotatedChunk(chunk=chunk, annotation=call_context)

return iterator()
) -> AsyncIterator[AnnotatedValue]:
annotation = None
return annotate_stream(
annotation, await call_upstream(annotation, request)
)

async def on_stream_start(self) -> None:
# TODO: it's probably worth putting all the chunks
Expand All @@ -37,8 +43,8 @@ async def on_stream_start(self) -> None:
# its "assistant" role is reported.
pass

async def on_stream_error(self, error: AnnotatedChunk) -> None:
self.send_chunk(error.chunk)
async def on_stream_error(self, error: AnnotatedException) -> None:
raise error.error

async def on_stream_end(self) -> None:
# TODO: it's probably worth withholding the last chunk generated by
Expand Down
26 changes: 13 additions & 13 deletions aidial_interceptors_sdk/chat_completion/helpers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import Any, Callable, Coroutine, List, TypeVar, overload
from typing import Awaitable, Callable, List, TypeVar, overload

from aidial_interceptors_sdk.utils.not_given import NOT_GIVEN, NotGiven

Expand All @@ -13,7 +13,7 @@ async def traverse_dict_value(
key: str,
on_value: Callable[
[P, T | NotGiven | None],
Coroutine[Any, Any, T | NotGiven | None],
Awaitable[T | NotGiven | None],
],
) -> dict: ...

Expand All @@ -25,7 +25,7 @@ async def traverse_dict_value(
key: str,
on_value: Callable[
[P, T | NotGiven | None],
Coroutine[Any, Any, T | NotGiven | None],
Awaitable[T | NotGiven | None],
],
) -> NotGiven: ...

Expand All @@ -37,7 +37,7 @@ async def traverse_dict_value(
key: str,
on_value: Callable[
[P, T | NotGiven | None],
Coroutine[Any, Any, T | NotGiven | None],
Awaitable[T | NotGiven | None],
],
) -> None: ...

Expand All @@ -48,7 +48,7 @@ async def traverse_dict_value(
key: str,
on_value: Callable[
[P, T | NotGiven | None],
Coroutine[Any, Any, T | NotGiven | None],
Awaitable[T | NotGiven | None],
],
) -> dict | NotGiven | None:
if d is None or isinstance(d, NotGiven):
Expand All @@ -71,7 +71,7 @@ async def traverse_required_dict_value(
path: P,
d: None,
key: str,
on_value: Callable[[P, T], Coroutine[Any, Any, T]],
on_value: Callable[[P, T], Awaitable[T]],
) -> None: ...


Expand All @@ -80,7 +80,7 @@ async def traverse_required_dict_value(
path: P,
d: NotGiven,
key: str,
on_value: Callable[[P, T], Coroutine[Any, Any, T]],
on_value: Callable[[P, T], Awaitable[T]],
) -> NotGiven: ...


Expand All @@ -89,15 +89,15 @@ async def traverse_required_dict_value(
path: P,
d: dict,
key: str,
on_value: Callable[[P, T], Coroutine[Any, Any, T]],
on_value: Callable[[P, T], Awaitable[T]],
) -> dict: ...


async def traverse_required_dict_value(
path: P,
d: dict | NotGiven | None,
key: str,
on_value: Callable[[P, T], Coroutine[Any, Any, T]],
on_value: Callable[[P, T], Awaitable[T]],
) -> dict | NotGiven | None:
if d is None or isinstance(d, NotGiven):
return d
Expand All @@ -115,30 +115,30 @@ async def traverse_required_dict_value(
async def traverse_list(
create_elem_path: Callable[[int], P],
lst: NotGiven,
on_elem: Callable[[P, T], Coroutine[Any, Any, List[T] | T]],
on_elem: Callable[[P, T], Awaitable[List[T] | T]],
) -> NotGiven: ...


@overload
async def traverse_list(
create_elem_path: Callable[[int], P],
lst: None,
on_elem: Callable[[P, T], Coroutine[Any, Any, List[T] | T]],
on_elem: Callable[[P, T], Awaitable[List[T] | T]],
) -> None: ...


@overload
async def traverse_list(
create_elem_path: Callable[[int], P],
lst: List[T],
on_elem: Callable[[P, T], Coroutine[Any, Any, List[T] | T]],
on_elem: Callable[[P, T], Awaitable[List[T] | T]],
) -> List[T]: ...


async def traverse_list(
create_elem_path: Callable[[int], P],
lst: List[T] | NotGiven | None,
on_elem: Callable[[P, T], Coroutine[Any, Any, List[T] | T]],
on_elem: Callable[[P, T], Awaitable[List[T] | T]],
) -> List[T] | NotGiven | None:
if lst is None or isinstance(lst, NotGiven):
return lst
Expand Down
Loading

0 comments on commit 48ad4d9

Please sign in to comment.