Skip to content

Commit

Permalink
Fix serde issue for huggingface inference API embedding (run-llama#16053
Browse files Browse the repository at this point in the history
)

* wip

* wip
  • Loading branch information
Disiok authored and raspawar committed Oct 7, 2024
1 parent 4349d8c commit 2c21689
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -61,15 +61,15 @@ class HuggingFaceInferenceAPIEmbedding(BaseEmbedding): # type: ignore[misc]
" Defaults to None, meaning it will loop until the server is available."
),
)
headers: Dict[str, str] = Field(
headers: Optional[Dict[str, str]] = Field(
default=None,
description=(
"Additional headers to send to the server. By default only the"
" authorization and user-agent headers are sent. Values in this dictionary"
" will override the default values."
),
)
cookies: Dict[str, str] = Field(
cookies: Optional[Dict[str, str]] = Field(
default=None, description="Additional cookies to send to the server."
)
task: Optional[str] = Field(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-embeddings-huggingface-api"
readme = "README.md"
version = "0.2.0"
version = "0.2.1"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,10 @@ def test_serialization(
assert serialized["model_name"] == STUB_MODEL_NAME
# Check Hugging Face Inference API Embeddings derived class specifics
assert serialized["pooling"] == Pooling.CLS

def test_serde(
self, hf_inference_api_embedding: HuggingFaceInferenceAPIEmbedding
) -> None:
serialized = hf_inference_api_embedding.model_dump()
deserialized = HuggingFaceInferenceAPIEmbedding.model_validate(serialized)
assert deserialized.headers == hf_inference_api_embedding.headers

0 comments on commit 2c21689

Please sign in to comment.