Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

remove extraneous default params for nims that expect conservative pa… #749

Merged
merged 11 commits into from
Jun 25, 2024
1 change: 1 addition & 0 deletions garak/generators/nim.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ class NVOpenAIChat(OpenAICompatible):
"top_p": 0.7,
"top_k": 0, # top_k is hard set to zero as of 24.04.30
"uri": "https://integrate.api.nvidia.com/v1/",
"suppressed_params": {"n", "frequency_penalty", "presence_penalty"},
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have a ref for these we can throw into a comment so we can keep it up to date?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is no official reference yet; these were identified empirically by observing HTTP 422 / 400 errors from the endpoint.

}
active = True
supports_multiple_generations = False
Expand Down
49 changes: 21 additions & 28 deletions garak/generators/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,9 @@ class OpenAICompatible(Generator):
"frequency_penalty": 0.0,
"presence_penalty": 0.0,
"stop": ["#", ";"],
"suppressed_params": set(),
}

temperature = 0.7
top_p = 1.0
frequency_penalty = 0.0
presence_penalty = 0.0
stop = ["#", ";"]

# avoid attempt to pickle the client attribute
def __getstate__(self) -> object:
self._clear_client()
Expand Down Expand Up @@ -162,6 +157,20 @@ def _call_model(
if self.client is None:
# reload client once when consuming the generator
self._load_client()

create_args = {
"model": self.name,
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"n": generations_this_call,
"top_p": self.top_p,
"frequency_penalty": self.frequency_penalty,
"presence_penalty": self.presence_penalty,
"stop": self.stop,
}

create_args = {k: v for k, v in create_args.items() if v is not None and k not in self.suppressed_params}

if self.generator == self.client.completions:
if not isinstance(prompt, str):
msg = (
Expand All @@ -172,17 +181,9 @@ def _call_model(
print(msg)
return list()

response = self.generator.create(
model=self.name,
prompt=prompt,
temperature=self.temperature,
max_tokens=self.max_tokens,
n=generations_this_call,
top_p=self.top_p,
frequency_penalty=self.frequency_penalty,
presence_penalty=self.presence_penalty,
stop=self.stop,
)
create_args["prompt"] = prompt

response = self.generator.create(**create_args)
return [c.text for c in response.choices]

elif self.generator == self.client.chat.completions:
Expand All @@ -199,17 +200,9 @@ def _call_model(
print(msg)
return list()
try:
response = self.generator.create(
model=self.name,
messages=messages,
temperature=self.temperature,
top_p=self.top_p,
n=generations_this_call,
stop=self.stop,
max_tokens=self.max_tokens,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
)
create_args["messages"] = messages
response = self.generator.create(**create_args)

return [c.message.content for c in response.choices]
except openai.BadRequestError:
msg = "Bad request: " + str(repr(prompt))
Expand Down
18 changes: 18 additions & 0 deletions tests/generators/test_nim.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,3 +52,21 @@ def test_nim_parallel_attempts():
def test_nim_hf_detector():
    # Smoke test: drive the full garak CLI against the nim module with a
    # Hugging Face-style model name (google/gemma-2b), one generation,
    # using the lmrc.Bullying probe. Passes if the run completes without
    # raising; no output content is inspected.
    garak.cli.main("-m nim -p lmrc.Bullying -g 1 -n google/gemma-2b".split())
    # No return value to check -- completion without an exception is the test.
    assert True


@pytest.mark.skipif(
    os.getenv(NVOpenAIChat.ENV_VAR, None) is None,
    reason=f"NIM API key is not set in {NVOpenAIChat.ENV_VAR}",
)
def test_nim_conservative_api():  # extraneous params can throw 422
    """Conservative NIM endpoints must work with suppressed params stripped.

    Some NIMs reject requests carrying extraneous parameters (e.g. ``n``,
    ``frequency_penalty``, ``presence_penalty``) with HTTP 422/400, so both
    the low-level ``_call_model`` path and the public ``generate()`` path
    are exercised against a live endpoint. Skipped when no API key is set.
    """
    g = NVOpenAIChat(name="nvidia/nemotron-4-340b-instruct")

    result = g._call_model("this is a test", generations_this_call=1)
    assert isinstance(result, list), "NIM _call_model should return a list"
    assert len(result) == 1, "NIM _call_model result list should have one item"
    # Fixed message: this assertion checks the item type, not the list type.
    assert isinstance(
        result[0], str
    ), "NIM _call_model should return a list of strings"

    result = g.generate("this is a test", generations_this_call=1)
    assert isinstance(result, list), "NIM generate() should return a list"
    assert (
        len(result) == 1
    ), "NIM generate() result list should have one item when generations_this_call=1"
    # Fixed message: this assertion checks the item type, not the list type.
    assert isinstance(
        result[0], str
    ), "NIM generate() should return a list of strings"
Loading