From b386e24f5de32fdfc1ca2f77ff57fcf586c4362a Mon Sep 17 00:00:00 2001 From: Pranav Puranik <54378813+PranavPuranik@users.noreply.github.com> Date: Thu, 1 Aug 2024 14:15:28 -0500 Subject: [PATCH] Sets up metadata db for every llm class (#1401) --- embedchain/embedchain/app.py | 10 +--------- embedchain/embedchain/llm/base.py | 8 ++++++++ embedchain/poetry.lock | 8 ++++---- embedchain/tests/conftest.py | 2 +- embedchain/tests/llm/conftest.py | 8 ++++++++ 5 files changed, 22 insertions(+), 14 deletions(-) create mode 100644 embedchain/tests/llm/conftest.py diff --git a/embedchain/embedchain/app.py b/embedchain/embedchain/app.py index 610f9aad1a..ede9ec75e5 100644 --- a/embedchain/embedchain/app.py +++ b/embedchain/embedchain/app.py @@ -20,7 +20,7 @@ ) from embedchain.client import Client from embedchain.config import AppConfig, CacheConfig, ChunkerConfig, Mem0Config -from embedchain.core.db.database import get_session, init_db, setup_engine +from embedchain.core.db.database import get_session from embedchain.core.db.models import DataSource from embedchain.embedchain import EmbedChain from embedchain.embedder.base import BaseEmbedder @@ -89,10 +89,6 @@ def __init__( if name and config: raise Exception("Cannot provide both name and config. Please provide only one of them.") - # Initialize the metadata db for the app - setup_engine(database_uri=os.environ.get("EMBEDCHAIN_DB_URI")) - init_db() - self.auto_deploy = auto_deploy # Store the dict config as an attribute to be able to send it self.config_data = config_data if (config_data and validate_config(config_data)) else None @@ -389,10 +385,6 @@ def from_config( vector_db = VectorDBFactory.create(vector_db_provider, vector_db_config_data.get("config", {})) if llm_config_data: - # Initialize the metadata db for the app here since llmfactory needs it for initialization of - # the llm memory - setup_engine(database_uri=os.environ.get("EMBEDCHAIN_DB_URI")) - init_db() llm_provider = llm_config_data.get("provider", "openai") llm = LlmFactory.create(llm_provider, llm_config_data.get("config", {})) else: diff --git a/embedchain/embedchain/llm/base.py b/embedchain/embedchain/llm/base.py index 533ad20c5f..bcbed4c8fc 100644 --- a/embedchain/embedchain/llm/base.py +++ b/embedchain/embedchain/llm/base.py @@ -1,9 +1,11 @@ import logging +import os from collections.abc import Generator from typing import Any, Optional from langchain.schema import BaseMessage as LCBaseMessage +from embedchain.constants import SQLITE_PATH from embedchain.config import BaseLlmConfig from embedchain.config.llm.base import ( DEFAULT_PROMPT, @@ -11,6 +13,7 @@ DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE, DOCS_SITE_PROMPT_TEMPLATE, ) +from embedchain.core.db.database import init_db, setup_engine from embedchain.helpers.json_serializable import JSONSerializable from embedchain.memory.base import ChatHistory from embedchain.memory.message import ChatMessage @@ -30,6 +33,11 @@ def __init__(self, config: Optional[BaseLlmConfig] = None): else: self.config = config + # Initialize the metadata db for the app here since llmfactory needs it for initialization of + # the llm memory + setup_engine(database_uri=os.environ.get("EMBEDCHAIN_DB_URI", f"sqlite:///{SQLITE_PATH}")) + init_db() + self.memory = ChatHistory() self.is_docs_site_instance = False self.history: Any = None diff --git a/embedchain/poetry.lock b/embedchain/poetry.lock index f8d99d69e9..893867bc66 100644 --- a/embedchain/poetry.lock +++ b/embedchain/poetry.lock @@ -2391,18 +2391,18 @@ tenacity = ">=8.1.0,<8.4.0 || >8.4.0,<9.0.0" [[package]] name = "langchain-aws" -version = "0.1.10" +version = "0.1.13" description = "An integration package connecting AWS and LangChain" optional = true python-versions = "<4.0,>=3.8.1" files = [ - {file = "langchain_aws-0.1.10-py3-none-any.whl", hash = "sha256:2cba72efaa9f0dc406d8e06a1fbaa3762678d489cbc5147cf64a7012189c161c"}, - {file = "langchain_aws-0.1.10.tar.gz", hash = "sha256:7f01dacbf8345a28192cec4ef31018cc33a91de0b82122f913eec09a76d64fd5"}, + {file = "langchain_aws-0.1.13-py3-none-any.whl", hash = "sha256:c4db60c8a83b8ff3e66170e0bd646739176fcd1a20a9d0a10828a1e21339af1d"}, + {file = "langchain_aws-0.1.13.tar.gz", hash = "sha256:fda790732a72de4ccec3760dba24db5f9fa5cb8724dfd9676a7d5cf87a9f1a98"}, ] [package.dependencies] boto3 = ">=1.34.131,<1.35.0" -langchain-core = ">=0.2.6,<0.3" +langchain-core = ">=0.2.17,<0.3" numpy = ">=1,<2" [[package]] diff --git a/embedchain/tests/conftest.py b/embedchain/tests/conftest.py index 1675d329e3..0b5807a0d9 100644 --- a/embedchain/tests/conftest.py +++ b/embedchain/tests/conftest.py @@ -32,4 +32,4 @@ def clean_db(): def disable_telemetry(): os.environ["EC_TELEMETRY"] = "false" yield - del os.environ["EC_TELEMETRY"] + del os.environ["EC_TELEMETRY"] \ No newline at end of file diff --git a/embedchain/tests/llm/conftest.py b/embedchain/tests/llm/conftest.py new file mode 100644 index 0000000000..edfafb95e3 --- /dev/null +++ b/embedchain/tests/llm/conftest.py @@ -0,0 +1,8 @@ + +from unittest import mock +import pytest + +@pytest.fixture(autouse=True) +def mock_alembic_command_upgrade(): + with mock.patch("alembic.command.upgrade"): + yield