lazi-fy huggingface, langchain serve, litellm loading #717

Closed
wants to merge 8 commits
40 changes: 33 additions & 7 deletions garak/generators/huggingface.py
@@ -14,6 +14,7 @@
https://huggingface.co/docs/api-inference/quicktour
"""

import importlib
import logging
import re
import os
@@ -22,8 +23,6 @@

import backoff
import torch
from PIL import Image
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration

from garak import _config
from garak.exception import ModelNameMissingError
@@ -579,14 +578,40 @@ class LLaVA(Generator):
"llava-hf/llava-v1.6-mistral-7b-hf",
]

# avoid attempt to pickle the client attribute
def __getstate__(self) -> object:
self._clear_client()
return dict(self.__dict__)

# restore the client attribute
def __setstate__(self, d) -> object:
self.__dict__.update(d)
self._load_client()

def _load_client(self):
PIL = importlib.import_module("PIL")
self.Image = PIL.Image

transformers = importlib.import_module("transformers")
self.LlavaNextProcessor = transformers.LlavaNextProcessor
self.LlavaNextForConditionalGeneration = (
transformers.LlavaNextForConditionalGeneration
)
Collaborator

Processor and Model need to be part of this method.

Suggested change
)
)
self.processor = self.LlavaNextProcessor.from_pretrained(self.name)
self.model = self.LlavaNextForConditionalGeneration.from_pretrained(
self.name,
torch_dtype=self.torch_dtype,
low_cpu_mem_usage=self.low_cpu_mem_usage,
)
if torch.cuda.is_available():
self.model.to(self.device_map)
else:
raise RuntimeError(
"CUDA is not supported on this device. Please make sure CUDA is installed and configured properly."
)


def _clear_client(self):
self.Image = None
self.LlavaNextProcessor = None
self.LlavaNextForConditionalGeneration = None
Collaborator

Processor and Model should not be included in pickle.

Suggested change
self.LlavaNextForConditionalGeneration = None
self.LlavaNextForConditionalGeneration = None
self.processor = None
self.model = None
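
Taken together, the two suggestions above keep every heavyweight or unpicklable attribute behind the _load_client() / _clear_client() pair. Below is a minimal, standalone sketch of that pattern — not the PR's final code. The class name, the constructor defaults (dtype, device string), the direct assignment of processor/model (rather than keeping the imported classes as attributes), and the explicit PIL.Image submodule import are all assumptions made for illustration.

```python
import importlib

import torch


class LazyLLaVAClient:
    """Standalone sketch of the lazy-load / pickle-safe pattern suggested above."""

    def __init__(self, name, torch_dtype=torch.float16, low_cpu_mem_usage=True,
                 device_map="cuda:0"):
        self.name = name
        self.torch_dtype = torch_dtype          # assumed default, for illustration
        self.low_cpu_mem_usage = low_cpu_mem_usage
        self.device_map = device_map            # assumed default, for illustration
        self._load_client()

    # avoid attempting to pickle the heavyweight attributes
    def __getstate__(self) -> dict:
        self._clear_client()
        return dict(self.__dict__)

    # restore them after unpickling
    def __setstate__(self, d) -> None:
        self.__dict__.update(d)
        self._load_client()

    def _load_client(self):
        # lazy imports: PIL and transformers are only pulled in when needed;
        # importing the PIL.Image submodule explicitly guarantees it is loaded
        self.Image = importlib.import_module("PIL.Image")
        transformers = importlib.import_module("transformers")

        self.processor = transformers.LlavaNextProcessor.from_pretrained(self.name)
        self.model = transformers.LlavaNextForConditionalGeneration.from_pretrained(
            self.name,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=self.low_cpu_mem_usage,
        )
        if torch.cuda.is_available():
            self.model.to(self.device_map)
        else:
            raise RuntimeError("CUDA is not available on this device.")

    def _clear_client(self):
        # drop everything that cannot, or should not, be pickled
        self.Image = None
        self.processor = None
        self.model = None
```

With this split, importing the module no longer pays the transformers/PIL import cost up front, and pickling only ever serializes plain configuration values.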


def __init__(self, name="", generations=10, config_root=_config):
super().__init__(name, generations=generations, config_root=config_root)
if self.name not in self.supported_models:
raise ModelNameMissingError(
f"Invalid modal name {self.name}, current support: {self.supported_models}."
)
self.processor = LlavaNextProcessor.from_pretrained(self.name)
self.model = LlavaNextForConditionalGeneration.from_pretrained(

self.processor = self.LlavaNextProcessor.from_pretrained(self.name)
self.model = self.LlavaNextForConditionalGeneration.from_pretrained(
Comment on lines +613 to +614
Collaborator
@jmartin-tech Jun 13, 2024

Everything from here down to _load_client() needs to move into _load_client(), as the processor and model will not transfer well in a pickle (see the pickle round-trip sketch after this file's diff).

self.name,
torch_dtype=self.torch_dtype,
low_cpu_mem_usage=self.low_cpu_mem_usage,
@@ -597,15 +622,16 @@ def __init__(self, name="", generations=10, config_root=_config):
raise RuntimeError(
"CUDA is not supported on this device. Please make sure CUDA is installed and configured properly."
)
self._load_client()

def generate(
self, prompt: str, generations_this_call: int = 1
) -> List[Union[str, None]]:
text_prompt = prompt["text"]
try:
image_prompt = Image.open(prompt["image"])
except FileNotFoundError:
raise FileNotFoundError(f"Cannot open image {prompt['image']}.")
image_prompt = self.Image.open(prompt["image"])
except FileNotFoundError as exc:
raise FileNotFoundError(f"Cannot open image {prompt['image']}.") from exc
except Exception as e:
raise Exception(e)

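As the comment on lines +613 to +614 notes, the point of pushing model construction into _load_client() is that garak pickles generator instances (for example, when running generations in parallel), and a loaded processor/model does not transfer well in a pickle. A small hypothetical round-trip showing the intended effect of the __getstate__/__setstate__ pair; LazyLLaVAClient refers to the sketch above, not to anything in the PR, and running this requires a CUDA device plus the model weights:

```python
import pickle

gen = LazyLLaVAClient("llava-hf/llava-v1.6-mistral-7b-hf")

# __getstate__ calls _clear_client() first, so only lightweight config
# attributes end up in the byte stream -- no processor or model weights.
# Note that, as written in the diff, this also clears those attributes on
# `gen` itself as a side effect.
data = pickle.dumps(gen)

# __setstate__ restores the attribute dict and calls _load_client(),
# rebuilding the processor and model in the receiving object.
restored = pickle.loads(data)
assert restored.processor is not None and restored.model is not None
```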
39 changes: 27 additions & 12 deletions garak/generators/litellm.py
@@ -31,23 +31,19 @@
```
"""

import importlib
import logging

from os import getenv
from typing import List, Union

import backoff

import litellm

from garak import _config
from garak.exception import APIKeyMissingError
from garak.generators.base import Generator

# Fix issue with Ollama which does not support `presence_penalty`
litellm.drop_params = True
# Suppress log messages from LiteLLM
litellm.verbose_logger.disabled = True
# litellm.set_verbose = True

# Based on the param support matrix below:
# https://docs.litellm.ai/docs/completion/input
Expand Down Expand Up @@ -109,6 +105,26 @@ class LiteLLMGenerator(Generator):
"stop",
)

# avoid attempt to pickle the client attribute
def __getstate__(self) -> object:
self._clear_client()
return dict(self.__dict__)

# restore the client attribute
def __setstate__(self, d) -> object:
self.__dict__.update(d)
self._load_client()

def _load_client(self):
self.litellm = importlib.import_module("litellm")
# Fix issue with Ollama which does not support `presence_penalty`
self.litellm.drop_params = True
# Suppress log messages from LiteLLM
self.litellm.verbose_logger.disabled = True

def _clear_client(self):
self.litellm = None

def __init__(self, name: str = "", generations: int = 10, config_root=_config):
self.name = name
self.api_base = None
@@ -127,13 +143,10 @@ def __init__(self, name: str = "", generations: int = 10, config_root=_config):
self.name, generations=self.generations, config_root=config_root
)

if self.provider is None:
raise ValueError(
"litellm generator needs to have a provider value configured - see docs"
)
elif (
if (
self.api_key is None
): # TODO: special case where api_key is not always required
# TODO: add other providers
if self.provider == "openai":
self.api_key = getenv(self.key_env_var, None)
if self.api_key is None:
Expand All @@ -142,6 +155,8 @@ def __init__(self, name: str = "", generations: int = 10, config_root=_config):
" or in the configuration file"
)

self._load_client()
Collaborator
@jmartin-tech Jun 13, 2024

The provider check is still needed, and the key extraction should have been moved into a custom _validate_env_var() implementation. This looks like something I missed in #711 that incorrectly enforces the API key as a requirement for all provider values.

Suggested change
self._load_client()
self._load_client()
def _validate_env_var(self):
if self.provider is None:
raise ValueError(
"litellm generator needs to have a provider value configured - see docs"
)
if self.provider == "openai":
return super()._validate_env_var()
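
For clarity, here is a self-contained sketch of the validation flow this suggestion implies. The base class is stubbed here as "read the key from key_env_var or raise"; the real garak Generator may behave differently, so treat this purely as an illustration of the control flow, not as garak's API:

```python
import os


class APIKeyMissingError(Exception):
    pass


class BaseGeneratorSketch:
    key_env_var = "OPENAI_API_KEY"  # assumed stand-in for the real base class

    def _validate_env_var(self):
        self.api_key = os.getenv(self.key_env_var, None)
        if self.api_key is None:
            raise APIKeyMissingError(
                f"Put the API key in the {self.key_env_var} environment variable"
            )


class LiteLLMSketch(BaseGeneratorSketch):
    def __init__(self, provider=None):
        self.provider = provider
        self.api_key = None
        self._validate_env_var()

    def _validate_env_var(self):
        # a provider is always required ...
        if self.provider is None:
            raise ValueError(
                "litellm generator needs to have a provider value configured - see docs"
            )
        # ... but only the openai provider needs a key from the environment
        if self.provider == "openai":
            return super()._validate_env_var()


LiteLLMSketch(provider="ollama")  # ok: no API key required
try:
    LiteLLMSketch(provider="openai")  # needs OPENAI_API_KEY in the environment
except APIKeyMissingError as exc:
    print(f"openai provider without a key: {exc}")
```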

Collaborator Author

So, uh, while dealing with this slow module, I started getting a failed test, but noticed that I had an openai env var key set, which meant the test actually ran. Have you ever seen the litellm tests pass? Looking at the tests we have, and the basic code examples on their website (see e.g. the invocations on https://docs.litellm.ai/docs/), the provider check seems to block intended functionality.

Collaborator

I have never had them run and pass, as they both required keys. I had noted in the original PR that it seemed like a config file would be required to instantiate the class. Although there was a comment that said it was not required, the original embedded config parsing did require a provider: if provider was not found in _config.plugins.generators["litellm.LiteLLMGenerator"], it would raise a ValueError.

I intend to validate the functionality as part of the testing here by setting up a local instance; however, there is another issue with this class, as the torch_dtype value cannot be accepted as a string. I have fixes for this in progress in the refactor branch I am working on. Short term, I was intending to manually patch the torch_dtype default value to allow testing of this change in isolation.

Collaborator

After some testing, I have validated that LiteLLM as implemented does require a provider from the config file. This can be enforced using something like the suggestion updated at the top of this thread, which would also supply a method to suppress the API key requirement when supplying something other than openai as the provider.

A future PR can also expand the testing to provide a mock config that would allow for mocking a response from the generator, similar to the mock openai responses recently incorporated.

Collaborator Author

I think the "as implemented" is something to be flexible on (see the example in #755). This would remove the enforcement requirement. Unfortunately I don't have a good pattern for validating the input. I'm OK with relaxing the provider constraint and letting litellm do its own validation.
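
To make the "intended functionality" point concrete: LiteLLM routes on the model-string prefix, and local backends need no API key at all. A hedged example of the kind of call an unconditional key check would block; the model name and endpoint are assumptions, and it only runs against a local Ollama server:

```python
import litellm

# No API key involved: LiteLLM infers the provider from the "ollama/" prefix.
response = litellm.completion(
    model="ollama/llama2",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
    api_base="http://localhost:11434",  # assumed local Ollama endpoint
)
print(response.choices[0].message.content)
```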


@backoff.on_exception(backoff.fibo, Exception, max_value=70)
def _call_model(
self, prompt: str, generations_this_call: int = 1
Expand All @@ -159,7 +174,7 @@ def _call_model(
print(msg)
return []

response = litellm.completion(
response = self.litellm.completion(
model=self.name,
messages=prompt,
temperature=self.temperature,