lazi-fy huggingface, langchain serve, litellm loading #717
Conversation
garak/generators/huggingface.py
```python
from transformers import LlavaNextProcessor, LlavaNextForConditionalGeneration
# ...
PIL = importlib.import_module("PIL")
self.Image = PIL.Image
```
I am a little concerned about importing libs and storing them as attributes in the generator. This will introduce a similar issue with pickle during multiprocessing to the one that keeping the client on OpenAIGenerator produced.
These might be better served as methods that load the libraries when not already set, combined with a __getstate__() implementation similar to https://github.com/leondz/garak/blob/3833eefbb29ec10f22298963f668f2b0f99c4526/garak/generators/openai.py#L111-L114
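For illustration, here is a minimal sketch of that pattern (the class is hypothetical and trimmed down; the real generators carry much more state, and the linked openai.py lines are only being approximated, not quoted):

```python
import importlib


class LazyGeneratorSketch:
    """Hypothetical generator that defers heavy imports to _load_client()."""

    def __init__(self, name: str = ""):
        self.name = name
        self._load_client()

    def _load_client(self):
        # import heavy libraries only when needed, instead of at module import time
        transformers = importlib.import_module("transformers")
        self.LlavaNextProcessor = transformers.LlavaNextProcessor

    def _clear_client(self):
        # drop the unpicklable attributes before serialisation
        self.LlavaNextProcessor = None

    def __getstate__(self) -> object:
        # called by pickle during multiprocessing: strip the heavy attributes
        self._clear_client()
        return dict(self.__dict__)

    def __setstate__(self, d):
        # re-create the heavy attributes in the child process
        self.__dict__.update(d)
        self._load_client()
```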
I am doing some testing of this pattern in other huggingface changes, and while I still like the idea of protecting the class for safe pickle support, there might be more work needed in general around how we share generator instances when allowing multiprocessing. As is, shifting heavyweight objects to be created in each new process could be very expensive in terms of resources.
I can get this, it makes sense.
On the other hand - if one is doing multiprocessing with local models, the consumption of gigabytes or tens of gigabytes of GPU memory per instance seems to shrink the relative difficulty of storing libraries in the generator.
I am inclined to adopt a safer pattern here, but not to proactively support the general case of parallelisation with locally-run models. Instead, I'd implement a generators.base.Generator attribute specifying whether a generator is parallelisation-compatible, and set it to False for 🤗 classes such as Model and Pipeline. Running parallel local models seems like an edge case - the parallelisation is intended for lighter-weight generators.
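A rough sketch of what such a flag could look like (the attribute name, class stubs, and harness check here are placeholders, not a settled API):

```python
class Generator:
    # hypothetical base-class flag: is it safe to hand instances of this
    # generator to a multiprocessing pool?
    parallel_capable = True

    def generate(self, prompt: str) -> list:
        raise NotImplementedError


class Model(Generator):
    # a locally-loaded 🤗 model holds gigabytes of weights per process, so opt out
    parallel_capable = False


class Pipeline(Generator):
    parallel_capable = False


def run_attempts(generator: Generator, prompts: list) -> list:
    # harness-side check: only fan out across processes when the generator opts in
    if generator.parallel_capable:
        from multiprocessing import Pool

        with Pool(4) as pool:
            return pool.map(generator.generate, prompts)
    return [generator.generate(p) for p in prompts]
```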
> implementing a generators.base.Generator attribute specifying whether a generator is parallelisation-compatible and setting it False for 🤗 classes such as Model and Pipeline

This is the approach I am targeting first.
I am also considering looking for methods that can defer multiprocessing approaches for locally executing generators to something specifically provided by the generator. For instance, huggingface provides the Accelerate library to let the generator execute inference using GPUs efficiently, and it might be reasonable to expose a _call_model_with_dataset() or something of the sort that can receive a set of attempts and have the generator figure out how to parallelize them.
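Something along these lines, perhaps (the hook name comes from the comment above; the signature and default behaviour are guesses, not an existing garak interface):

```python
from typing import List


class LocalGeneratorSketch:
    """Sketch of a generator that owns its own batching strategy."""

    def _call_model(self, prompt: str) -> List[str]:
        raise NotImplementedError

    def _call_model_with_dataset(self, prompts: List[str]) -> List[List[str]]:
        # default: plain sequential calls; a 🤗-backed generator could override
        # this to batch prompts through Accelerate or pipeline batching instead
        # of the harness spawning one process per attempt
        return [self._call_model(prompt) for prompt in prompts]
```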
Oh, that could be really nice. I guess maybe worth doing after a pattern emerges for decoupling the attempt queue from running generation?
@jmartin-tech: heavy externals in these modules are now loaded/unloaded using the _load_client() / _clear_client() pattern
```python
        self.processor = self.LlavaNextProcessor.from_pretrained(self.name)
        self.model = self.LlavaNextForConditionalGeneration.from_pretrained(
```
This, all the way down to _load_client(), needs to move into _load_client(), as the processor and model will not transfer well in a pickle.
```python
    def _clear_client(self):
        self.Image = None
        self.LlavaNextProcessor = None
        self.LlavaNextForConditionalGeneration = None
```
Processor and Model should not be included in pickle.
```python
        self.LlavaNextForConditionalGeneration = None
        self.processor = None
        self.model = None
```
```python
        self.LlavaNextProcessor = transformers.LlavaNextProcessor
        self.LlavaNextForConditionalGeneration = (
            transformers.LlavaNextForConditionalGeneration
        )
```
Processor and Model need to be part of this method.
```python
        )
        self.processor = self.LlavaNextProcessor.from_pretrained(self.name)
        self.model = self.LlavaNextForConditionalGeneration.from_pretrained(
            self.name,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=self.low_cpu_mem_usage,
        )
        if torch.cuda.is_available():
            self.model.to(self.device_map)
        else:
            raise RuntimeError(
                "CUDA is not supported on this device. Please make sure CUDA is installed and configured properly."
            )
```
```
@@ -142,6 +155,8 @@ def __init__(self, name: str = "", generations: int = 10, config_root=_config):
                " or in the configuration file"
            )

        self._load_client()
```
The provider check is still needed, and the key extraction should have been moved into a custom _validate_env_var() implementation. This looks like something I missed in #711 that incorrectly enforces an API key as required for all provider values.
```python
        self._load_client()

    def _validate_env_var(self):
        if self.provider is None:
            raise ValueError(
                "litellm generator needs to have a provider value configured - see docs"
            )
        if self.provider == "openai":
            return super()._validate_env_var()
```
So, uh, while dealing with this slow module, I started getting a failed test, but noticed that I had an openai env var key set, which meant the test actually ran. Have you ever seen the litellm tests pass? Looking at the tests we have, and the basic code examples on their website (see e.g. the invocations on https://docs.litellm.ai/docs/), the provider check seems to block intended functionality.
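For context, this is roughly the style of call the litellm docs lead with: the model string (optionally prefixed with a provider, e.g. "ollama/...") is enough for litellm to route the request, which is why a mandatory standalone provider value gets in the way. The snippet below is adapted, not quoted verbatim, and needs the relevant API key or a running local backend:

```python
from litellm import completion

response = completion(
    model="gpt-3.5-turbo",  # or e.g. "ollama/llama2" for a local provider
    messages=[{"role": "user", "content": "Hello, how are you?"}],
)
print(response.choices[0].message.content)
```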
I have never had them run and pass, as they both required keys; I had noted in the original PR that it seemed like a config file would be required to instantiate the class. Although there was a comment that said it was not required, the original embedded config parsing did require a provider. If provider was not found in _config.plugins.generators["litellm.LiteLLMGenerator"], it would raise a ValueError.
I intend to validate the functionality as part of the testing here by setting up a local instance; however, there is another issue with this class, as the torch_dtype value cannot be accepted as a string. I have fixes for this in progress in the refactor branch I am working on. Short term, I was intending to manually patch the torch_dtype default value to allow testing of this change in isolation.
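One possible short-term shim for the torch_dtype issue, assuming the incoming value is a plain string like "float16" (this is only an illustration, not what the refactor branch actually does):

```python
import torch


def resolve_torch_dtype(value):
    # map config strings like "float16" onto torch dtype objects;
    # pass "auto" and real dtype objects straight through to from_pretrained()
    if isinstance(value, str) and value != "auto":
        return getattr(torch, value)
    return value


assert resolve_torch_dtype("float16") is torch.float16
assert resolve_torch_dtype(torch.bfloat16) is torch.bfloat16
```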
After some testing, I have validated that LiteLLM as implemented does require a provider from the config file. This can be enforced using something like the suggestion updated at the top of this thread, which would also supply a method to suppress the API key requirement when supplying something other than openai as the provider.
A future PR can also expand the testing to provide a mock config that would allow for mocking a response from the generator, similar to the mock openai responses recently incorporated.
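A very rough sketch of what such a mocked test might look like (the patch target and response shape are assumptions; garak's own test plumbing and the existing openai mocks are not reproduced here):

```python
from unittest.mock import MagicMock, patch


def test_litellm_generate_mocked():
    # build an OpenAI-style response object: resp.choices[0].message.content
    fake_response = MagicMock()
    fake_response.choices[0].message.content = "mocked output"

    with patch("litellm.completion", return_value=fake_response) as mocked:
        import litellm

        resp = litellm.completion(model="gpt-3.5-turbo", messages=[])
        assert resp.choices[0].message.content == "mocked output"
        mocked.assert_called_once()
```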
I think the "as implemented" is something to be flexible on (see example in #755 ). This would remove the enforcement requirement. Unfortunately I don't have a good pattern for validating the input. I'm OK with relaxing the provider constraint and letting litellm
do its own validation.
Yeah, that's my read too. I'm not sure the provider requirement aligns with
documented litellm use or the garak generator test as written.
may be resolved by #768, in which case this will close
no need to review/merge until #711 lands