lazi-fy huggingface, langchain serve, litellm loading #717
Changes from all commits: d7eaf3b, a0845c7, 46b5363, 2d6385f, b2cf3bd, 1903802, ba28a06, 184cca5
First changed file (the Hugging Face generator module):
@@ -14,6 +14,7 @@
https://huggingface.co/docs/api-inference/quicktour
"""

+import importlib
import logging
import re
import os

@@ -22,8 +23,6 @@

import backoff
import torch
-from PIL import Image
-from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

from garak import _config
from garak.exception import ModelNameMissingError
@@ -579,14 +578,40 @@ class LLaVA(Generator):
        "llava-hf/llava-v1.6-mistral-7b-hf",
    ]

+    # avoid attempt to pickle the client attribute
+    def __getstate__(self) -> object:
+        self._clear_client()
+        return dict(self.__dict__)
+
+    # restore the client attribute
+    def __setstate__(self, d) -> object:
+        self.__dict__.update(d)
+        self._load_client()
+
+    def _load_client(self):
+        PIL = importlib.import_module("PIL")
+        self.Image = PIL.Image
+
+        transformers = importlib.import_module("transformers")
+        self.LlavaNextProcessor = transformers.LlavaNextProcessor
+        self.LlavaNextForConditionalGeneration = (
+            transformers.LlavaNextForConditionalGeneration
+        )
+
+    def _clear_client(self):
+        self.Image = None
+        self.LlavaNextProcessor = None
+        self.LlavaNextForConditionalGeneration = None
Review comment: Processor and Model should not be included in pickle.
Suggested change
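The suggestion body itself is not captured in this extract. A minimal sketch of the kind of change the reviewer appears to be asking for (an assumption, not the actual suggested hunk) would be to drop the heavyweight objects in `_clear_client()`, so that `__getstate__` never serializes them:

```python
# Hypothetical sketch only; the reviewer's actual suggestion is not shown above.
def _clear_client(self):
    self.Image = None
    self.LlavaNextProcessor = None
    self.LlavaNextForConditionalGeneration = None
    self.processor = None  # assumed: keep the HF processor out of the pickle
    self.model = None      # assumed: keep the model weights out of the pickle
```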

    def __init__(self, name="", generations=10, config_root=_config):
        super().__init__(name, generations=generations, config_root=config_root)
        if self.name not in self.supported_models:
            raise ModelNameMissingError(
                f"Invalid modal name {self.name}, current support: {self.supported_models}."
            )
-        self.processor = LlavaNextProcessor.from_pretrained(self.name)
-        self.model = LlavaNextForConditionalGeneration.from_pretrained(
+        self.processor = self.LlavaNextProcessor.from_pretrained(self.name)
+        self.model = self.LlavaNextForConditionalGeneration.from_pretrained(
Review comment on lines +613 to +614: This all the way down to …
            self.name,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=self.low_cpu_mem_usage,

@@ -597,15 +622,16 @@ def __init__(self, name="", generations=10, config_root=_config):
            raise RuntimeError(
                "CUDA is not supported on this device. Please make sure CUDA is installed and configured properly."
            )
+        self._load_client()

    def generate(
        self, prompt: str, generations_this_call: int = 1
    ) -> List[Union[str, None]]:
        text_prompt = prompt["text"]
        try:
-            image_prompt = Image.open(prompt["image"])
-        except FileNotFoundError:
-            raise FileNotFoundError(f"Cannot open image {prompt['image']}.")
+            image_prompt = self.Image.open(prompt["image"])
+        except FileNotFoundError as exc:
+            raise FileNotFoundError(f"Cannot open image {prompt['image']}.") from exc
        except Exception as e:
            raise Exception(e)
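One small change worth calling out is the switch to `raise ... from exc`. As a standalone illustration (generic Python, not garak code; the function and names below are illustrative only), chaining keeps the original error attached as `__cause__`, so the underlying failure still appears in the traceback:

```python
# Standalone sketch of exception chaining; names here are illustrative only.
def open_image(path: str):
    try:
        return open(path, "rb")
    except FileNotFoundError as exc:
        # "from exc" records the original exception as __cause__,
        # so the OS-level detail is preserved in the traceback.
        raise FileNotFoundError(f"Cannot open image {path}.") from exc
```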
Second changed file (the LiteLLM generator module):
@@ -31,23 +31,19 @@
```
"""

+import importlib
import logging

from os import getenv
from typing import List, Union

import backoff

-import litellm

from garak import _config
from garak.exception import APIKeyMissingError
from garak.generators.base import Generator

-# Fix issue with Ollama which does not support `presence_penalty`
-litellm.drop_params = True
-# Suppress log messages from LiteLLM
-litellm.verbose_logger.disabled = True
-# litellm.set_verbose = True

# Based on the param support matrix below:
# https://docs.litellm.ai/docs/completion/input
@@ -109,6 +105,26 @@ class LiteLLMGenerator(Generator):
        "stop",
    )

+    # avoid attempt to pickle the client attribute
+    def __getstate__(self) -> object:
+        self._clear_client()
+        return dict(self.__dict__)
+
+    # restore the client attribute
+    def __setstate__(self, d) -> object:
+        self.__dict__.update(d)
+        self._load_client()
+
+    def _load_client(self):
+        self.litellm = importlib.import_module("litellm")
+        # Fix issue with Ollama which does not support `presence_penalty`
+        self.litellm.drop_params = True
+        # Suppress log messages from LiteLLM
+        self.litellm.verbose_logger.disabled = True
+
+    def _clear_client(self):
+        self.litellm = None

    def __init__(self, name: str = "", generations: int = 10, config_root=_config):
        self.name = name
        self.api_base = None
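For context, the point of the `__getstate__`/`__setstate__` pair is that a live module object cannot be pickled, so the reference is dropped before serialization and re-imported afterwards. A rough usage sketch of round-tripping the generator (the model name and configuration below are placeholders, not taken from the diff):

```python
import pickle

# Placeholder construction; a real run still needs a provider / API key configured.
gen = LiteLLMGenerator(name="gpt-3.5-turbo", config_root=_config)

blob = pickle.dumps(gen)       # __getstate__ -> _clear_client(): drops self.litellm
restored = pickle.loads(blob)  # __setstate__ -> _load_client(): re-imports litellm

assert restored.litellm is not None  # the client is usable again after unpickling
```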

@@ -127,13 +143,10 @@ def __init__(self, name: str = "", generations: int = 10, config_root=_config):
            self.name, generations=self.generations, config_root=config_root
        )

-        if self.provider is None:
-            raise ValueError(
-                "litellm generator needs to have a provider value configured - see docs"
-            )
-        elif (
+        if (
            self.api_key is None
        ):  # TODO: special case where api_key is not always required
            # TODO: add other providers
            if self.provider == "openai":
                self.api_key = getenv(self.key_env_var, None)
                if self.api_key is None:

@@ -142,6 +155,8 @@ def __init__(self, name: str = "", generations: int = 10, config_root=_config):
                        " or in the configuration file"
                    )

+        self._load_client()
Review comment: Provider check is still needed and the key extraction should have been moved into a custom …
Suggested change

Review comment: So, uh, while dealing with this slow module, I started getting a failed test, but noticed that I had an openai env var key set, which meant the test actually ran. Have you ever seen the …

Review comment: I have never had them run and pass, as they both required keys; I had noted in the original PR that it seemed like a config file would be required to instantiate the class. Although there was a comment that said it was not required, the original embedded config parsing did require a provider. I intend to validate function as part of the testing here by setting up a local instance; however, there is another issue with this class, as the …

Review comment: After some testing, I have validated that … A future PR can also expand the testing to provide a mock config that would allow for mocking a response from the generator, similar to the mock openai responses recently incorporated.

Review comment: I think the "as implemented" is something to be flexible on (see example in #755). This would remove the enforcement requirement. Unfortunately I don't have a good pattern for validating the input. I'm OK with relaxing the provider constraint and letting …
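The first comment above is cut off before it names the hook it wants the key extraction moved into. Assuming it refers to overriding the generator's environment-variable validation step (the method name and base-class behaviour below are assumptions, not the reviewer's actual suggestion), a minimal sketch might be:

```python
# Hypothetical sketch only: not the reviewer's suggested change.
def _validate_env_var(self):
    # Only some providers require an API key; defer to the base class
    # (which raises when the key is missing) only for those that do.
    if self.provider == "openai":
        return super()._validate_env_var()
    # Other providers (e.g. a local Ollama endpoint) may not need a key at all.
```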

    @backoff.on_exception(backoff.fibo, Exception, max_value=70)
    def _call_model(
        self, prompt: str, generations_this_call: int = 1

@@ -159,7 +174,7 @@ def _call_model(
            print(msg)
            return []

-        response = litellm.completion(
+        response = self.litellm.completion(
            model=self.name,
            messages=prompt,
            temperature=self.temperature,
Review comment: Processor and Model need to be part of this method.