diff --git a/llama-index-core/llama_index/core/indices/vector_store/retrievers/retriever.py b/llama-index-core/llama_index/core/indices/vector_store/retrievers/retriever.py
index 9eb68ae496b549..3a429245494945 100644
--- a/llama-index-core/llama_index/core/indices/vector_store/retrievers/retriever.py
+++ b/llama-index-core/llama_index/core/indices/vector_store/retrievers/retriever.py
@@ -48,6 +48,7 @@ def __init__(
         node_ids: Optional[List[str]] = None,
         doc_ids: Optional[List[str]] = None,
         sparse_top_k: Optional[int] = None,
+        hybrid_top_k: Optional[int] = None,
         callback_manager: Optional[CallbackManager] = None,
         object_map: Optional[dict] = None,
         embed_model: Optional[BaseEmbedding] = None,
@@ -67,6 +68,7 @@ def __init__(
         self._doc_ids = doc_ids
         self._filters = filters
         self._sparse_top_k = sparse_top_k
+        self._hybrid_top_k = hybrid_top_k
         self._kwargs: Dict[str, Any] = kwargs.get("vector_store_kwargs", {})
 
         callback_manager = callback_manager or CallbackManager()
@@ -126,6 +128,7 @@ def _build_vector_store_query(
             alpha=self._alpha,
             filters=self._filters,
             sparse_top_k=self._sparse_top_k,
+            hybrid_top_k=self._hybrid_top_k,
         )
 
     def _build_node_list_from_query_result(
diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-mongodb/llama_index/vector_stores/mongodb/base.py b/llama-index-integrations/vector_stores/llama-index-vector-stores-mongodb/llama_index/vector_stores/mongodb/base.py
index 3cfe7606d1ed60..99ed76f2d95433 100644
--- a/llama-index-integrations/vector_stores/llama-index-vector-stores-mongodb/llama_index/vector_stores/mongodb/base.py
+++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-mongodb/llama_index/vector_stores/mongodb/base.py
@@ -229,6 +229,10 @@ def client(self) -> Any:
         return self._mongodb_client
 
     def _query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
+        hybrid_top_k = query.hybrid_top_k or query.similarity_top_k
+        sparse_top_k = query.sparse_top_k or query.similarity_top_k
+        dense_top_k = query.similarity_top_k
+
         if query.mode == VectorStoreQueryMode.DEFAULT:
             if not query.query_embedding:
                 raise ValueError("query_embedding in VectorStoreQueryMode.DEFAULT")
@@ -240,7 +244,7 @@ def _query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
                     query_vector=query.query_embedding,
                     search_field=self._embedding_key,
                     index_name=self._vector_index_name,
-                    limit=query.similarity_top_k,
+                    limit=dense_top_k,
                     filter=filter,
                     oversampling_factor=self._oversampling_factor,
                 ),
@@ -259,15 +263,11 @@ def _query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
                 index_name=self._fulltext_index_name,
                 operator="text",
                 filter=filter,
-                limit=query.similarity_top_k,
+                limit=sparse_top_k,
             )
             pipeline.append({"$set": {"score": {"$meta": "searchScore"}}})
 
         elif query.mode == VectorStoreQueryMode.HYBRID:
-            if query.hybrid_top_k is None:
-                raise ValueError(
-                    f"hybrid_top_k not set. You must use this, not similarity_top_k in hybrid mode."
-                )
             # Combines Vector and Full-Text searches with Reciprocal Rank Fusion weighting
             logger.debug(f"Running {query.mode} mode query pipeline")
             scores_fields = ["vector_score", "fulltext_score"]
@@ -280,7 +280,7 @@ def _query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
                         query_vector=query.query_embedding,
                         search_field=self._embedding_key,
                         index_name=self._vector_index_name,
-                        limit=query.hybrid_top_k,
+                        limit=dense_top_k,
                         filter=filter,
                         oversampling_factor=self._oversampling_factor,
                     )
@@ -296,7 +296,7 @@ def _query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
                     index_name=self._fulltext_index_name,
                     operator="text",
                     filter=filter,
-                    limit=query.hybrid_top_k,
+                    limit=sparse_top_k,
                 )
                 text_pipeline.extend(reciprocal_rank_stage("fulltext_score"))
                 combine_pipelines(pipeline, text_pipeline, self._collection.name)
@@ -306,7 +306,7 @@ def _query(self, query: VectorStoreQuery) -> VectorStoreQueryResult:
                 query.alpha or 0.5
             )  # If no alpha is given, equal weighting is applied
             pipeline += final_hybrid_stage(
-                scores_fields=scores_fields, limit=query.hybrid_top_k, alpha=alpha
+                scores_fields=scores_fields, limit=hybrid_top_k, alpha=alpha
             )
 
             # Remove embeddings unless requested.
diff --git a/llama-index-integrations/vector_stores/llama-index-vector-stores-mongodb/pyproject.toml b/llama-index-integrations/vector_stores/llama-index-vector-stores-mongodb/pyproject.toml
index 817f918190493c..1fd3e6c1c454b7 100644
--- a/llama-index-integrations/vector_stores/llama-index-vector-stores-mongodb/pyproject.toml
+++ b/llama-index-integrations/vector_stores/llama-index-vector-stores-mongodb/pyproject.toml
@@ -29,7 +29,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-vector-stores-mongodb"
 readme = "README.md"
-version = "0.2.1"
+version = "0.3.0"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
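For reference, below is a minimal usage sketch (not part of the diff) of the new hybrid_top_k parameter as it flows from the retriever into MongoDBAtlasVectorSearch. It assumes an existing Atlas collection with both a vector and a full-text search index and an embedding model configured via Settings; the connection string, database, collection, and index names are placeholders, and constructor argument names are taken from this version of the integration.

```python
# Illustrative sketch only: URI, database, collection, and index names are placeholders.
from pymongo import MongoClient

from llama_index.core import VectorStoreIndex
from llama_index.core.vector_stores.types import VectorStoreQueryMode
from llama_index.vector_stores.mongodb import MongoDBAtlasVectorSearch

mongodb_client = MongoClient("<ATLAS_CONNECTION_STRING>")  # placeholder URI
vector_store = MongoDBAtlasVectorSearch(
    mongodb_client=mongodb_client,
    db_name="my_db",  # placeholder
    collection_name="my_collection",  # placeholder
    vector_index_name="vector_index",  # placeholder Atlas Vector Search index
    fulltext_index_name="fulltext_index",  # placeholder Atlas Search index
)
index = VectorStoreIndex.from_vector_store(vector_store)

# similarity_top_k caps the dense (vector) arm, sparse_top_k the full-text arm,
# and hybrid_top_k the fused result list; with this change, any value left unset
# falls back to similarity_top_k inside MongoDBAtlasVectorSearch._query instead
# of raising an error.
retriever = index.as_retriever(
    vector_store_query_mode=VectorStoreQueryMode.HYBRID,
    similarity_top_k=4,
    sparse_top_k=6,
    hybrid_top_k=5,
)
nodes = retriever.retrieve("How does hybrid retrieval rank results?")
```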