Skip to content

Commit

Permalink
add xai grok llm provider (#75)
Browse files Browse the repository at this point in the history
  • Loading branch information
djcopley authored Nov 29, 2024
1 parent cf620d5 commit a24b10d
Show file tree
Hide file tree
Showing 3 changed files with 81 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/shelloracle/providers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,9 @@ def _providers() -> dict[str, type[Provider]]:
from shelloracle.providers.localai import LocalAI
from shelloracle.providers.ollama import Ollama
from shelloracle.providers.openai import OpenAI
from shelloracle.providers.xai import XAI

return {Ollama.name: Ollama, OpenAI.name: OpenAI, LocalAI.name: LocalAI}
return {Ollama.name: Ollama, OpenAI.name: OpenAI, LocalAI.name: LocalAI, XAI.name: XAI}


def get_provider(name: str) -> type[Provider]:
Expand Down
38 changes: 38 additions & 0 deletions src/shelloracle/providers/xai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
from collections.abc import AsyncIterator

from openai import APIError, AsyncOpenAI

from shelloracle.providers import Provider, ProviderError, Setting, system_prompt


class XAI(Provider):
name = "XAI"

api_key = Setting(default="")
model = Setting(default="grok-beta")

def __init__(self):
if not self.api_key:
msg = "No API key provided"
raise ProviderError(msg)
self.client = AsyncOpenAI(
api_key=self.api_key,
base_url="https://api.x.ai/v1",
)

async def generate(self, prompt: str) -> AsyncIterator[str]:
try:
stream = await self.client.chat.completions.create(
model=self.model,
messages=[
{"role": "system", "content": system_prompt},
{"role": "user", "content": prompt},
],
stream=True,
)
async for chunk in stream:
if chunk.choices[0].delta.content is not None:
yield chunk.choices[0].delta.content
except APIError as e:
msg = f"Something went wrong while querying XAI: {e}"
raise ProviderError(msg) from e
41 changes: 41 additions & 0 deletions tests/providers/test_xai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import pytest

from shelloracle.providers.xai import XAI


class TestOpenAI:
@pytest.fixture
def xai_config(self, set_config):
config = {
"shelloracle": {"provider": "XAI"},
"provider": {
"XAI": {
"api_key": "xai-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx",
"model": "grok-beta",
}
},
}
set_config(config)

@pytest.fixture
def xai_instance(self, xai_config):
return XAI()

def test_name(self):
assert XAI.name == "XAI"

def test_api_key(self, xai_instance):
assert (
xai_instance.api_key
== "xai-xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
)

def test_model(self, xai_instance):
assert xai_instance.model == "grok-beta"

@pytest.mark.asyncio
async def test_generate(self, mock_asyncopenai, xai_instance):
result = ""
async for response in xai_instance.generate(""):
result += response
assert result == "head -c 100 /dev/urandom | hexdump -C"

0 comments on commit a24b10d

Please sign in to comment.