fix(ai-plugins): add multi model support to ai-* plugin (#10782)
oowl authored Nov 22, 2024
1 parent c4824b8 commit f62a799
Showing 9 changed files with 72 additions and 13 deletions.
@@ -0,0 +1,4 @@
+message: |
+  "**ai-prompt-guard**: Fixed an issue where the *ai-prompt-guard* plugin could fail when handling requests with multiple models."
+type: bugfix
+scope: Plugin
@@ -0,0 +1,4 @@
+message: |
+  "**ai-semantic-cache**: Fixed an issue where the plugin failed when handling requests with multiple models."
+type: bugfix
+scope: Plugin
@@ -0,0 +1,4 @@
+message: |
+  "**ai-semantic-prompt-guard**: Fixed an issue where requests with multiple models caused failures."
+type: bugfix
+scope: Plugin
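
Context for the entries above: all three plugins buffer chat message content for matching, embedding, or caching, and previously assumed message.content was always a string. OpenAI-style multimodal requests instead send content as an array of typed parts, which is the shape this commit starts handling (and which the new fixture at the bottom of the diff exercises). A minimal sketch of the two shapes, with illustrative values only:

  -- Plain text message: content is a string.
  local text_message = {
    role = "user",
    content = "What is Pi?",
  }

  -- Multi-part message: content is an array of typed parts.
  -- Code that treats content as a string fails on this shape.
  local multi_part_message = {
    role = "user",
    content = {
      { type = "text", text = "What's in this image?" },
      { type = "image_url", image_url = { url = "https://example.com/photo.jpg" } },
    },
  }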
10 changes: 7 additions & 3 deletions kong/plugins/ai-prompt-guard/filters/guard-prompt.lua
@@ -8,6 +8,7 @@
 local buffer = require("string.buffer")
 local ngx_re_find = ngx.re.find
 local ai_plugin_ctx = require("kong.llm.plugin.ctx")
+local cjson = require("cjson")
 
 local _M = {
   NAME = "guard-prompt",
@@ -58,11 +59,14 @@ local execute do
       return nil, bad_format_error
     end
     if v.role == "user" or conf.match_all_roles then
-      if type(v.content) ~= "string" then
+      if type(v.content) == "string" then
+        buf:put(v.content)
+      elseif type(v.content) == "table" then
+        local content = cjson.encode(v.content)
+        buf:put(content)
+      else
         return nil, bad_format_error
       end
-      buf:put(v.content)
 
       if just_pick_latest then
         break
       end
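
The change above normalizes table-typed content to a JSON string before buffering it for regex matching; the ai-semantic-prompt-guard filter further down applies the same pattern before embedding. A standalone sketch of the new branching, with a hypothetical error string standing in for the plugin's actual bad_format_error:

  local cjson = require("cjson")

  -- Hypothetical stand-in for the plugin's actual bad_format_error value.
  local bad_format_error = "unsupported prompt format"

  -- Normalize one message's content to a string, as the filter now does.
  local function content_to_string(content)
    if type(content) == "string" then
      return content
    elseif type(content) == "table" then
      -- Multi-part content is serialized to JSON text; note that cjson
      -- does not guarantee object key order.
      return cjson.encode(content)
    else
      return nil, bad_format_error
    end
  end

  print(content_to_string("What is Pi?"))
  --> What is Pi?
  print(content_to_string({ { type = "text", text = "hi" } }))
  --> [{"type":"text","text":"hi"}]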
29 changes: 25 additions & 4 deletions kong/plugins/ai-semantic-cache/filters/search-cache.lua
@@ -81,18 +81,39 @@ local function format_chat(messages, countback, discard_system, discard_assistant
   end
 
   local buf = buffer.new()
+  local content
 
   for i = #messages, #messages - countback + 1, -1 do
     local message = messages[i]
     if message then
       if message.role == "system" and not discard_system then
-        buf:putf("%s: %s\n\n", message.role, message.content)
+        if type(message.content) == "table" then
+          content = cjson.encode(message.content)
+        else
+          content = message.content
+        end
+        buf:putf("%s: %s\n\n", message.role, content)
       elseif message.role == "assistant" and not discard_assistant then
-        buf:putf("%s: %s\n\n", message.role, message.content)
+        if type(message.content) == "table" then
+          content = cjson.encode(message.content)
+        else
+          content = message.content
+        end
+        buf:putf("%s: %s\n\n", message.role, content)
       elseif message.role == "user" then
-        buf:putf("%s\n\n", message.content)
+        if type(message.content) == "table" then
+          content = cjson.encode(message.content)
+        else
+          content = message.content
+        end
+        buf:putf("%s: %s\n\n", message.role, content)
       elseif message.role == "tool" and not discard_tool then
-        buf:putf("%s: %s\n\n", message.role, message.content)
+        if type(message.content) == "table" then
+          content = cjson.encode(message.content)
+        else
+          content = message.content
+        end
+        buf:putf("%s: %s\n\n", message.role, content)
       end
     end
   end
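
Besides serializing table-typed content, this rewrite changes the rendered cache text: user messages now get the same role prefix as every other role, where they were previously emitted bare. That is why the unit-test expectations further down change. Roughly:

  -- Before: user messages were formatted without a role prefix.
  --   buf:putf("%s\n\n", message.content)            --> "What is Pi?\n\n"
  -- After: every role shares one format.
  --   buf:putf("%s: %s\n\n", message.role, content)  --> "user: What is Pi?\n\n"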
10 changes: 7 additions & 3 deletions kong/plugins/ai-semantic-prompt-guard/filters/guard-prompt.lua
@@ -10,7 +10,7 @@ local buffer = require("string.buffer")
 local guard = require("kong.plugins.ai-semantic-prompt-guard.guard")
 local vectordb = require("kong.llm.vectordb")
 local ai_plugin_ctx = require("kong.llm.plugin.ctx")
-
+local cjson = require("cjson")
 
 local _M = {
   NAME = "semantic-guard-prompt",
@@ -54,10 +54,14 @@ local execute do
       return nil, bad_format_error
     end
     if v.role == "user" or conf.rules.match_all_roles then
-      if type(v.content) ~= "string" then
+      if type(v.content) == "string" then
+        buf:put(v.content)
+      elseif type(v.content) == "table" then
+        local content = cjson.encode(v.content)
+        buf:put(content)
+      else
         return nil, bad_format_error
       end
-      buf:put(v.content)
 
       if not conf.rules.match_all_conversation_history then
         break
6 changes: 3 additions & 3 deletions spec-ee/03-plugins/43-ai-semantic-cache/01-unit_spec.lua
@@ -78,15 +78,15 @@ describe(PLUGIN_NAME .. ": (unit)", function()
 
     -- test truncate to two messages
     output = search_cache._format_chat(samples["llm/v1/chat"]["valid"]["messages"], 2, false, false)
-    assert.same(output, 'What is 2π?\n\nassistant: Pi (π) is a mathematical constant that represents the ratio of a circle\'s circumference to its diameter. This ratio is constant for all circles and is approximately equal to 3.14159.\n\n')
+    assert.same(output, 'user: What is 2π?\n\nassistant: Pi (π) is a mathematical constant that represents the ratio of a circle\'s circumference to its diameter. This ratio is constant for all circles and is approximately equal to 3.14159.\n\n')
 
     -- test discard system messages
     output = search_cache._format_chat(samples["llm/v1/chat"]["valid"]["messages"], 20, true, false)
-    assert.same(output, 'What is 2π?\n\nassistant: Pi (π) is a mathematical constant that represents the ratio of a circle\'s circumference to its diameter. This ratio is constant for all circles and is approximately equal to 3.14159.\n\nWhat is Pi?\n\n')
+    assert.same(output, 'user: What is 2π?\n\nassistant: Pi (π) is a mathematical constant that represents the ratio of a circle\'s circumference to its diameter. This ratio is constant for all circles and is approximately equal to 3.14159.\n\nuser: What is Pi?\n\n')
 
     -- test discard assistant messages
     output = search_cache._format_chat(samples["llm/v1/chat"]["valid"]["messages"], 20, false, true)
-    assert.same(output, 'What is 2π?\n\nWhat is Pi?\n\nsystem: You are a mathematician.\n\n')
+    assert.same(output, 'user: What is 2π?\n\nuser: What is Pi?\n\nsystem: You are a mathematician.\n\n')
   end)
 
 end)
@@ -245,6 +245,7 @@ local TEST_SCANARIOS = {
   { id = "97a884ab-5b8f-442a-8011-89dce47a68b6", desc = "good caching", vector_config = "good", embeddings_config = "good", embeddings_response = "good", chat_request = "good", stop_on_failure = true, message_countback = 10, expect = 200 },
   { id = "97a884ab-5b8f-442a-8011-89dce47a68b8", desc = "good caching", vector_config = "good", embeddings_config = "good", embeddings_response = "good", chat_request = "good", stop_on_failure = true, message_countback = 10, expect = 200, model = "gpt-4-turbo" },
   { id = "97a884ab-5b8f-442a-8011-89dce47a68b3", desc = "good caching with ignore tool",vector_config = "good", embeddings_config = "good", embeddings_response = "good", chat_request = "good-with-tool", stop_on_failure = true, message_countback = 10, expect = 200, model = "gpt-4-turbo" },
+  { id = "97a884ab-5b8f-442a-8011-89dce47a68c3", desc = "good caching with multi model",vector_config = "good", embeddings_config = "good", embeddings_response = "good", chat_request = "good-with-mutli-model", stop_on_failure = true, message_countback = 10, expect = 200, model = "gpt-4-turbo" },
   { id = "97a884ab-5b8f-442a-8011-89dce47a68b1", desc = "good caching", vector_config = "good", embeddings_config = "good", embeddings_response = "good", chat_request = "good", stop_on_failure = true, message_countback = 10, expect = 200, enable_buffer_proxy = true },
   { id = "97a884ab-5b8f-442a-8011-816356521752", desc = "good caching with ai proxy", vector_config = "good", embeddings_config = "good", embeddings_response = "good", chat_request = "good", stop_on_failure = true, message_countback = 10, expect = 200, with_ai_proxy = true },
   { id = "4819bbfb-7669-4d7d-a7b8-1c60dc71d2a8", desc = "stream request rest response", vector_config = "good", embeddings_config = "good", embeddings_response = "good", chat_request = "good", stop_on_failure = true, message_countback = 10, expect = 200, stream_request = true },
17 changes: 17 additions & 0 deletions spec-ee/fixtures/ai-proxy/chat/request/good-with-mutli-model.json
@@ -0,0 +1,17 @@
+{
+  "messages": [
+    {
+      "role": "user",
+      "content": [
+        {"type": "text", "text": "What's in this image?"},
+        {
+          "type": "image_url",
+          "image_url": {
+            "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg"
+          }
+        }
+      ]
+    }
+  ]
+}
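For reference, this fixture's content array is exactly the shape the new table branches handle: decoding the request body and re-encoding content with cjson yields the single JSON string that the guard and cache filters now buffer. A sketch against an abbreviated copy of the fixture (image part omitted for brevity):

  local cjson = require("cjson")

  local fixture = [[{"messages":[{"role":"user","content":[{"type":"text","text":"What is in this image?"}]}]}]]
  local body = cjson.decode(fixture)

  -- Re-encode the table-typed content, as the filters now do.
  print(cjson.encode(body.messages[1].content))
  --> e.g. [{"type":"text","text":"What is in this image?"}] (key order may vary)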