Skip to content

Commit

Permalink
Fix mongodb hybrid search, also pass hybrid_top_k in vector retriever (
Browse files Browse the repository at this point in the history
  • Loading branch information
logan-markewich authored and raspawar committed Oct 7, 2024
1 parent 85dc133 commit 6c0bc92
Show file tree
Hide file tree
Showing 3 changed files with 13 additions and 10 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@ def __init__(
node_ids: Optional[List[str]] = None,
doc_ids: Optional[List[str]] = None,
sparse_top_k: Optional[int] = None,
hybrid_top_k: Optional[int] = None,
callback_manager: Optional[CallbackManager] = None,
object_map: Optional[dict] = None,
embed_model: Optional[BaseEmbedding] = None,
Expand All @@ -67,6 +68,7 @@ def __init__(
self._doc_ids = doc_ids
self._filters = filters
self._sparse_top_k = sparse_top_k
self._hybrid_top_k = hybrid_top_k
self._kwargs: Dict[str, Any] = kwargs.get("vector_store_kwargs", {})

callback_manager = callback_manager or CallbackManager()
Expand Down Expand Up @@ -126,6 +128,7 @@ def _build_vector_store_query(
alpha=self._alpha,
filters=self._filters,
sparse_top_k=self._sparse_top_k,
hybrid_top_k=self._hybrid_top_k,
)

def _build_node_list_from_query_result(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -229,6 +229,10 @@ def client(self) -> Any:
return self._mongodb_client

def _query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
hybrid_top_k = query.hybrid_top_k or query.similarity_top_k
sparse_top_k = query.sparse_top_k or query.similarity_top_k
dense_top_k = query.similarity_top_k

if query.mode == VectorStoreQueryMode.DEFAULT:
if not query.query_embedding:
raise ValueError("query_embedding in VectorStoreQueryMode.DEFAULT")
Expand All @@ -240,7 +244,7 @@ def _query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
query_vector=query.query_embedding,
search_field=self._embedding_key,
index_name=self._vector_index_name,
limit=query.similarity_top_k,
limit=dense_top_k,
filter=filter,
oversampling_factor=self._oversampling_factor,
),
Expand All @@ -259,15 +263,11 @@ def _query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
index_name=self._fulltext_index_name,
operator="text",
filter=filter,
limit=query.similarity_top_k,
limit=sparse_top_k,
)
pipeline.append({"$set": {"score": {"$meta": "searchScore"}}})

elif query.mode == VectorStoreQueryMode.HYBRID:
if query.hybrid_top_k is None:
raise ValueError(
f"hybrid_top_k not set. You must use this, not similarity_top_k in hybrid mode."
)
# Combines Vector and Full-Text searches with Reciprocal Rank Fusion weighting
logger.debug(f"Running {query.mode} mode query pipeline")
scores_fields = ["vector_score", "fulltext_score"]
Expand All @@ -280,7 +280,7 @@ def _query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
query_vector=query.query_embedding,
search_field=self._embedding_key,
index_name=self._vector_index_name,
limit=query.hybrid_top_k,
limit=dense_top_k,
filter=filter,
oversampling_factor=self._oversampling_factor,
)
Expand All @@ -296,7 +296,7 @@ def _query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
index_name=self._fulltext_index_name,
operator="text",
filter=filter,
limit=query.hybrid_top_k,
limit=sparse_top_k,
)
text_pipeline.extend(reciprocal_rank_stage("fulltext_score"))
combine_pipelines(pipeline, text_pipeline, self._collection.name)
Expand All @@ -306,7 +306,7 @@ def _query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
query.alpha or 0.5
) # If no alpha is given, equal weighting is applied
pipeline += final_hybrid_stage(
scores_fields=scores_fields, limit=query.hybrid_top_k, alpha=alpha
scores_fields=scores_fields, limit=hybrid_top_k, alpha=alpha
)

# Remove embeddings unless requested.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-vector-stores-mongodb"
readme = "README.md"
version = "0.2.1"
version = "0.3.0"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
Expand Down

0 comments on commit 6c0bc92

Please sign in to comment.