diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-api/llama_index/embeddings/huggingface_api/base.py b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-api/llama_index/embeddings/huggingface_api/base.py index 012abd0a74f9f..8aa503ccdd008 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-api/llama_index/embeddings/huggingface_api/base.py +++ b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-api/llama_index/embeddings/huggingface_api/base.py @@ -61,7 +61,7 @@ class HuggingFaceInferenceAPIEmbedding(BaseEmbedding): # type: ignore[misc] " Defaults to None, meaning it will loop until the server is available." ), ) - headers: Dict[str, str] = Field( + headers: Optional[Dict[str, str]] = Field( default=None, description=( "Additional headers to send to the server. By default only the" @@ -69,7 +69,7 @@ class HuggingFaceInferenceAPIEmbedding(BaseEmbedding): # type: ignore[misc] " will override the default values." ), ) - cookies: Dict[str, str] = Field( + cookies: Optional[Dict[str, str]] = Field( default=None, description="Additional cookies to send to the server." ) task: Optional[str] = Field( diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-api/pyproject.toml b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-api/pyproject.toml index 175c1c95aa042..e2b0457a453af 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-api/pyproject.toml +++ b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-api/pyproject.toml @@ -27,7 +27,7 @@ exclude = ["**/BUILD"] license = "MIT" name = "llama-index-embeddings-huggingface-api" readme = "README.md" -version = "0.2.0" +version = "0.2.1" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" diff --git a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-api/tests/test_hf_inference.py b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-api/tests/test_hf_inference.py index a78c2389bb4d4..ea66835ecf4b3 100644 --- a/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-api/tests/test_hf_inference.py +++ b/llama-index-integrations/embeddings/llama-index-embeddings-huggingface-api/tests/test_hf_inference.py @@ -106,3 +106,10 @@ def test_serialization( assert serialized["model_name"] == STUB_MODEL_NAME # Check Hugging Face Inference API Embeddings derived class specifics assert serialized["pooling"] == Pooling.CLS + + def test_serde( + self, hf_inference_api_embedding: HuggingFaceInferenceAPIEmbedding + ) -> None: + serialized = hf_inference_api_embedding.model_dump() + deserialized = HuggingFaceInferenceAPIEmbedding.model_validate(serialized) + assert deserialized.headers == hf_inference_api_embedding.headers