diff --git a/api/core/model_runtime/model_providers/localai/localai.yaml b/api/core/model_runtime/model_providers/localai/localai.yaml index 151f02ee6f1e09..864dd7a30c3a6f 100644 --- a/api/core/model_runtime/model_providers/localai/localai.yaml +++ b/api/core/model_runtime/model_providers/localai/localai.yaml @@ -15,6 +15,7 @@ help: supported_model_types: - llm - text-embedding + - rerank - speech2text configurate_methods: - customizable-model diff --git a/api/core/model_runtime/model_providers/localai/rerank/__init__.py b/api/core/model_runtime/model_providers/localai/rerank/__init__.py new file mode 100644 index 00000000000000..e69de29bb2d1d6 diff --git a/api/core/model_runtime/model_providers/localai/rerank/rerank.py b/api/core/model_runtime/model_providers/localai/rerank/rerank.py new file mode 100644 index 00000000000000..96087d06dc2774 --- /dev/null +++ b/api/core/model_runtime/model_providers/localai/rerank/rerank.py @@ -0,0 +1,120 @@ +from json import dumps +from typing import Optional + +import httpx +from requests import post +from yarl import URL + +from core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult +from core.model_runtime.errors.invoke import ( + InvokeAuthorizationError, + InvokeBadRequestError, + InvokeConnectionError, + InvokeError, + InvokeRateLimitError, + InvokeServerUnavailableError, +) +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.__base.rerank_model import RerankModel + + +class LocalaiRerankModel(RerankModel): + """ + LocalAI rerank model API is compatible with Jina rerank model API. So just copy the JinaRerankModel class code here. + """ + + def _invoke(self, model: str, credentials: dict, + query: str, docs: list[str], score_threshold: Optional[float] = None, top_n: Optional[int] = None, + user: Optional[str] = None) -> RerankResult: + """ + Invoke rerank model + + :param model: model name + :param credentials: model credentials + :param query: search query + :param docs: docs for reranking + :param score_threshold: score threshold + :param top_n: top n documents to return + :param user: unique user id + :return: rerank result + """ + if len(docs) == 0: + return RerankResult(model=model, docs=[]) + + server_url = credentials['server_url'] + model_name = model + + if not server_url: + raise CredentialsValidateFailedError('server_url is required') + if not model_name: + raise CredentialsValidateFailedError('model_name is required') + + url = server_url + headers = { + 'Authorization': f"Bearer {credentials.get('api_key')}", + 'Content-Type': 'application/json' + } + + data = { + "model": model_name, + "query": query, + "documents": docs, + "top_n": top_n + } + + try: + response = post(str(URL(url) / 'rerank'), headers=headers, data=dumps(data), timeout=10) + response.raise_for_status() + results = response.json() + + rerank_documents = [] + for result in results['results']: + rerank_document = RerankDocument( + index=result['index'], + text=result['document']['text'], + score=result['relevance_score'], + ) + if score_threshold is None or result['relevance_score'] >= score_threshold: + rerank_documents.append(rerank_document) + + return RerankResult(model=model, docs=rerank_documents) + except httpx.HTTPStatusError as e: + raise InvokeServerUnavailableError(str(e)) + + def validate_credentials(self, model: str, credentials: dict) -> None: + """ + Validate model credentials + + :param model: model name + :param credentials: model credentials + :return: + """ + try: + + self._invoke( + model=model, + credentials=credentials, + query="What is the capital of the United States?", + docs=[ + "Carson City is the capital city of the American state of Nevada. At the 2010 United States " + "Census, Carson City had a population of 55,274.", + "The Commonwealth of the Northern Mariana Islands is a group of islands in the Pacific Ocean that " + "are a political division controlled by the United States. Its capital is Saipan.", + ], + score_threshold=0.8 + ) + except Exception as ex: + raise CredentialsValidateFailedError(str(ex)) + + @property + def _invoke_error_mapping(self) -> dict[type[InvokeError], list[type[Exception]]]: + """ + Map model invoke error to unified error + """ + return { + InvokeConnectionError: [httpx.ConnectError], + InvokeServerUnavailableError: [httpx.RemoteProtocolError], + InvokeRateLimitError: [], + InvokeAuthorizationError: [httpx.HTTPStatusError], + InvokeBadRequestError: [httpx.RequestError] + } diff --git a/api/tests/integration_tests/model_runtime/localai/test_rerank.py b/api/tests/integration_tests/model_runtime/localai/test_rerank.py new file mode 100644 index 00000000000000..a75439337eb5c0 --- /dev/null +++ b/api/tests/integration_tests/model_runtime/localai/test_rerank.py @@ -0,0 +1,158 @@ +import os + +import pytest +from api.core.model_runtime.entities.rerank_entities import RerankResult + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.localai.rerank.rerank import LocalaiRerankModel + + +def test_validate_credentials_for_chat_model(): + model = LocalaiRerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model='bge-reranker-v2-m3', + credentials={ + 'server_url': 'hahahaha', + 'completion_type': 'completion', + } + ) + + model.validate_credentials( + model='bge-reranker-base', + credentials={ + 'server_url': os.environ.get('LOCALAI_SERVER_URL'), + 'completion_type': 'completion', + } + ) + +def test_invoke_rerank_model(): + model = LocalaiRerankModel() + + response = model.invoke( + model='bge-reranker-base', + credentials={ + 'server_url': os.environ.get('LOCALAI_SERVER_URL') + }, + query='Organic skincare products for sensitive skin', + docs=[ + "Eco-friendly kitchenware for modern homes", + "Biodegradable cleaning supplies for eco-conscious consumers", + "Organic cotton baby clothes for sensitive skin", + "Natural organic skincare range for sensitive skin", + "Tech gadgets for smart homes: 2024 edition", + "Sustainable gardening tools and compost solutions", + "Sensitive skin-friendly facial cleansers and toners", + "Organic food wraps and storage solutions", + "Yoga mats made from recycled materials" + ], + top_n=3, + score_threshold=0.75, + user="abc-123" + ) + + assert isinstance(response, RerankResult) + assert len(response.docs) == 3 +import os + +import pytest +from api.core.model_runtime.entities.rerank_entities import RerankDocument, RerankResult + +from core.model_runtime.errors.validate import CredentialsValidateFailedError +from core.model_runtime.model_providers.localai.rerank.rerank import LocalaiRerankModel + + +def test_validate_credentials_for_chat_model(): + model = LocalaiRerankModel() + + with pytest.raises(CredentialsValidateFailedError): + model.validate_credentials( + model='bge-reranker-v2-m3', + credentials={ + 'server_url': 'hahahaha', + 'completion_type': 'completion', + } + ) + + model.validate_credentials( + model='bge-reranker-base', + credentials={ + 'server_url': os.environ.get('LOCALAI_SERVER_URL'), + 'completion_type': 'completion', + } + ) + +def test_invoke_rerank_model(): + model = LocalaiRerankModel() + + response = model.invoke( + model='bge-reranker-base', + credentials={ + 'server_url': os.environ.get('LOCALAI_SERVER_URL') + }, + query='Organic skincare products for sensitive skin', + docs=[ + "Eco-friendly kitchenware for modern homes", + "Biodegradable cleaning supplies for eco-conscious consumers", + "Organic cotton baby clothes for sensitive skin", + "Natural organic skincare range for sensitive skin", + "Tech gadgets for smart homes: 2024 edition", + "Sustainable gardening tools and compost solutions", + "Sensitive skin-friendly facial cleansers and toners", + "Organic food wraps and storage solutions", + "Yoga mats made from recycled materials" + ], + top_n=3, + score_threshold=0.75, + user="abc-123" + ) + + assert isinstance(response, RerankResult) + assert len(response.docs) == 3 + +def test__invoke(): + model = LocalaiRerankModel() + + # Test case 1: Empty docs + result = model._invoke( + model='bge-reranker-base', + credentials={ + 'server_url': 'https://example.com', + 'api_key': '1234567890' + }, + query='Organic skincare products for sensitive skin', + docs=[], + top_n=3, + score_threshold=0.75, + user="abc-123" + ) + assert isinstance(result, RerankResult) + assert len(result.docs) == 0 + + # Test case 2: Valid invocation + result = model._invoke( + model='bge-reranker-base', + credentials={ + 'server_url': 'https://example.com', + 'api_key': '1234567890' + }, + query='Organic skincare products for sensitive skin', + docs=[ + "Eco-friendly kitchenware for modern homes", + "Biodegradable cleaning supplies for eco-conscious consumers", + "Organic cotton baby clothes for sensitive skin", + "Natural organic skincare range for sensitive skin", + "Tech gadgets for smart homes: 2024 edition", + "Sustainable gardening tools and compost solutions", + "Sensitive skin-friendly facial cleansers and toners", + "Organic food wraps and storage solutions", + "Yoga mats made from recycled materials" + ], + top_n=3, + score_threshold=0.75, + user="abc-123" + ) + assert isinstance(result, RerankResult) + assert len(result.docs) == 3 + assert all(isinstance(doc, RerankDocument) for doc in result.docs) \ No newline at end of file