Skip to content

Commit

Permalink
OpenAI upgrade (#477)
Browse files Browse the repository at this point in the history
* Update gitignore. Upgrade openai to version 1.12.0 and refactor OpenAIGenerator to handle new calling conventions. Add test_openai.py.

* black, update pyproject.toml

---------

Co-authored-by: Leon Derczynski <[email protected]>
  • Loading branch information
erickgalinkin and leondz authored Feb 15, 2024
1 parent 286a7b1 commit df0c8c4
Show file tree
Hide file tree
Showing 5 changed files with 91 additions and 24 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,4 @@ garak.log
hitlog.*.jsonl
.vscode
runs/
logs/
66 changes: 44 additions & 22 deletions garak/generators/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@

import os
import re
from typing import List
import logging
from typing import List, Union

import openai
import backoff
Expand All @@ -39,10 +40,7 @@
"gpt-4",
"gpt-4-32k",
"gpt-3.5-turbo",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo-16k",
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-16k-0613",
"gpt-3.5-turbo-0301",
)

Expand All @@ -65,24 +63,26 @@ def __init__(self, name, generations=10):

super().__init__(name, generations=generations)

openai.api_key = os.getenv("OPENAI_API_KEY", default=None)
if openai.api_key is None:
api_key = os.getenv("OPENAI_API_KEY", default=None)
if api_key is None:
raise ValueError(
'Put the OpenAI API key in the OPENAI_API_KEY environment variable (this was empty)\n \
e.g.: export OPENAI_API_KEY="sk-123XXXXXXXXXXXX"'
)

self.client = openai.OpenAI(api_key=api_key)

if self.name in completion_models:
self.generator = openai.Completion
self.generator = self.client.completions
elif self.name in chat_models:
self.generator = openai.ChatCompletion
self.generator = self.client.chat.completions
elif "-".join(self.name.split("-")[:-1]) in chat_models and re.match(
r"^.+-[01][0-9][0-3][0-9]$", self.name
): # handle model names -MMDDish suffix
self.generator = openai.ChatCompletion
self.generator = self.client.completions

elif self.name == "":
openai_model_list = sorted([m["id"] for m in openai.Model().list()["data"]])
openai_model_list = sorted([m.id for m in self.client.models.list().data])
raise ValueError(
"Model name is required for OpenAI, use --model_name\n"
+ " API returns following available models: ▶️ "
Expand All @@ -95,19 +95,29 @@ def __init__(self, name, generations=10):
f"No OpenAI API defined for '{self.name}' in generators/openai.py - please add one!"
)

# noinspection PyArgumentList
@backoff.on_exception(
backoff.fibo,
(
openai.error.RateLimitError,
openai.error.ServiceUnavailableError,
openai.error.APIError,
openai.error.Timeout,
openai.error.APIConnectionError,
openai.RateLimitError,
openai.InternalServerError,
openai.APIError,
openai.APITimeoutError,
openai.APIConnectionError,
),
max_value=70,
)
def _call_model(self, prompt: str) -> List[str]:
if self.generator == openai.Completion:
def _call_model(self, prompt: Union[str, list[dict]]) -> List[str]:
if self.generator == self.client.completions:
if not isinstance(prompt, str):
msg = (
f"Expected a string for OpenAI completions model {self.name}, but got {type(prompt)}. "
f"Returning nothing!"
)
logging.error(msg)
print(msg)
return list()

response = self.generator.create(
model=self.name,
prompt=prompt,
Expand All @@ -119,12 +129,24 @@ def _call_model(self, prompt: str) -> List[str]:
presence_penalty=self.presence_penalty,
stop=self.stop,
)
return [c["text"] for c in response["choices"]]

elif self.generator == openai.ChatCompletion:
return [c.text for c in response.choices]

elif self.generator == self.client.chat.completions:
if isinstance(prompt, str):
messages = [{"role": "user", "content": prompt}]
elif isinstance(prompt, list):
messages = prompt
else:
msg = (
f"Expected a list of dicts for OpenAI Chat model {self.name}, but got {type(prompt)} instead. "
f"Returning nothing!"
)
logging.error(msg)
print(msg)
return list()
response = self.generator.create(
model=self.name,
messages=[{"role": "user", "content": prompt}],
messages=messages,
temperature=self.temperature,
top_p=self.top_p,
n=self.generations,
Expand All @@ -133,7 +155,7 @@ def _call_model(self, prompt: str) -> List[str]:
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
)
return [c["message"]["content"] for c in response["choices"]]
return [c.message.content for c in response.choices]

else:
raise ValueError(
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ dependencies = [
"colorama>=0.4.3",
"tqdm>=4.64.0",
"cohere>=4.5.1",
"openai>=0.27.7,<1.0.0",
"openai==1.12.0",
"replicate>=0.8.3",
"pytest>=8.0",
"google-api-python-client>=2.0",
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ datasets>=2
colorama>=0.4.3
tqdm>=4.64.0
cohere>=4.5.1
openai>=0.27.7,<1.0.0
openai==1.12.0
replicate>=0.8.3
pytest>=8.0
google-api-python-client>=2.0
Expand Down
44 changes: 44 additions & 0 deletions tests/generators/test_openai.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
from garak.generators.openai import OpenAIGenerator

DEFAULT_GENERATIONS_QTY = 10


def test_completion():
    """Smoke-test OpenAIGenerator against a legacy completions model.

    Makes live OpenAI API calls: requires OPENAI_API_KEY in the environment.
    Checks default config, that max_tokens / temperature are settable, and
    that generate() returns DEFAULT_GENERATIONS_QTY string outputs.
    """
    generator = OpenAIGenerator(name="gpt-3.5-turbo-instruct")
    assert generator.name == "gpt-3.5-turbo-instruct"
    assert generator.generations == DEFAULT_GENERATIONS_QTY
    assert isinstance(generator.max_tokens, int)
    generator.max_tokens = 99
    assert generator.max_tokens == 99
    generator.temperature = 0.5
    assert generator.temperature == 0.5
    output = generator.generate("How could I possibly ")
    assert len(output) == DEFAULT_GENERATIONS_QTY
    for item in output:
        # completions endpoint should yield plain-text strings
        assert isinstance(item, str)


def test_chat():
    """Smoke-test OpenAIGenerator against a chat model.

    Makes live OpenAI API calls: requires OPENAI_API_KEY in the environment.
    Exercises both supported prompt forms: a bare string (wrapped into a
    single user message by the generator) and an explicit list of
    role/content message dicts.
    """
    generator = OpenAIGenerator(name="gpt-3.5-turbo")
    assert generator.name == "gpt-3.5-turbo"
    assert generator.generations == DEFAULT_GENERATIONS_QTY
    assert isinstance(generator.max_tokens, int)
    generator.max_tokens = 99
    assert generator.max_tokens == 99
    generator.temperature = 0.5
    assert generator.temperature == 0.5
    # string prompt form
    output = generator.generate("Hello OpenAI!")
    assert len(output) == DEFAULT_GENERATIONS_QTY
    for item in output:
        assert isinstance(item, str)
    # list-of-messages prompt form (multi-turn conversation)
    messages = [
        {"role": "user", "content": "Hello OpenAI!"},
        {"role": "assistant", "content": "Hello! How can I help you today?"},
        {"role": "user", "content": "How do I write a sonnet?"},
    ]
    output = generator.generate(messages)
    assert len(output) == DEFAULT_GENERATIONS_QTY
    for item in output:
        assert isinstance(item, str)

0 comments on commit df0c8c4

Please sign in to comment.