Skip to content

Commit

Permalink
feat(postgres): add support for engine parameters (#15951)
Browse files Browse the repository at this point in the history
* feat(postgres): add support for engine parameters
- Introduced engine_params to support passing parameters to create_engine.
- Updated create_engine and create_async_engine calls to include engine_params.
- Initialized engine_params in the constructor.

* style(lint): reformat for readability

* refactor(postgres): rename engine_params to create_engine_kwargs

* refactor(postgres): rename engine_params to create_engine_kwargs

* chore: bump version to 0.2.3

* fix(postgres): rename engine_params to create_engine_kwargs
  • Loading branch information
armoucar-neon authored Sep 11, 2024
1 parent d15a2ce commit 8f5c44a
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ def __init__(
debug: bool = False,
use_jsonb: bool = False,
hnsw_kwargs: Optional[Dict[str, Any]] = None,
create_engine_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
"""Constructor.
Expand All @@ -191,6 +192,7 @@ def __init__(
hnsw_kwargs (Optional[Dict[str, Any]], optional): HNSW kwargs, a dict that
contains "hnsw_ef_construction", "hnsw_ef_search", "hnsw_m", and optionally "hnsw_dist_method". Defaults to None,
which turns off HNSW search.
create_engine_kwargs (Optional[Dict[str, Any]], optional): Engine parameters to pass to create_engine. Defaults to None.
"""
table_name = table_name.lower()
schema_name = schema_name.lower()
Expand Down Expand Up @@ -231,6 +233,8 @@ def __init__(
use_jsonb=use_jsonb,
)

self.create_engine_kwargs = create_engine_kwargs or {}

async def close(self) -> None:
if not self._is_initialized:
return
Expand Down Expand Up @@ -264,6 +268,7 @@ def from_params(
debug: bool = False,
use_jsonb: bool = False,
hnsw_kwargs: Optional[Dict[str, Any]] = None,
create_engine_kwargs: Optional[Dict[str, Any]] = None,
) -> "PGVectorStore":
"""Construct from params.
Expand All @@ -287,6 +292,7 @@ def from_params(
hnsw_kwargs (Optional[Dict[str, Any]], optional): HNSW kwargs, a dict that
contains "hnsw_ef_construction", "hnsw_ef_search", "hnsw_m", and optionally "hnsw_dist_method". Defaults to None,
which turns off HNSW search.
create_engine_kwargs (Optional[Dict[str, Any]], optional): Engine parameters to pass to create_engine. Defaults to None.
Returns:
PGVectorStore: Instance of PGVectorStore constructed from params.
Expand All @@ -311,6 +317,7 @@ def from_params(
debug=debug,
use_jsonb=use_jsonb,
hnsw_kwargs=hnsw_kwargs,
create_engine_kwargs=create_engine_kwargs,
)

@property
Expand All @@ -324,10 +331,14 @@ def _connect(self) -> Any:
from sqlalchemy.ext.asyncio import AsyncSession, create_async_engine
from sqlalchemy.orm import sessionmaker

self._engine = create_engine(self.connection_string, echo=self.debug)
self._engine = create_engine(
self.connection_string, echo=self.debug, **self.create_engine_kwargs
)
self._session = sessionmaker(self._engine)

self._async_engine = create_async_engine(self.async_connection_string)
self._async_engine = create_async_engine(
self.async_connection_string, **self.create_engine_kwargs
)
self._async_session = sessionmaker(self._async_engine, class_=AsyncSession) # type: ignore

def _create_schema_if_not_exists(self) -> None:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-vector-stores-postgres"
readme = "README.md"
version = "0.2.2"
version = "0.2.3"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
Expand Down

0 comments on commit 8f5c44a

Please sign in to comment.