Skip to content

Commit

Permalink
Sets up metadata db for every llm class (#1401)
Browse files Browse the repository at this point in the history
  • Loading branch information
PranavPuranik authored Aug 1, 2024
1 parent 58b6887 commit b386e24
Show file tree
Hide file tree
Showing 5 changed files with 22 additions and 14 deletions.
10 changes: 1 addition & 9 deletions embedchain/embedchain/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
)
from embedchain.client import Client
from embedchain.config import AppConfig, CacheConfig, ChunkerConfig, Mem0Config
from embedchain.core.db.database import get_session, init_db, setup_engine
from embedchain.core.db.database import get_session
from embedchain.core.db.models import DataSource
from embedchain.embedchain import EmbedChain
from embedchain.embedder.base import BaseEmbedder
Expand Down Expand Up @@ -89,10 +89,6 @@ def __init__(
if name and config:
raise Exception("Cannot provide both name and config. Please provide only one of them.")

# Initialize the metadata db for the app
setup_engine(database_uri=os.environ.get("EMBEDCHAIN_DB_URI"))
init_db()

self.auto_deploy = auto_deploy
# Store the dict config as an attribute to be able to send it
self.config_data = config_data if (config_data and validate_config(config_data)) else None
Expand Down Expand Up @@ -389,10 +385,6 @@ def from_config(
vector_db = VectorDBFactory.create(vector_db_provider, vector_db_config_data.get("config", {}))

if llm_config_data:
# Initialize the metadata db for the app here since llmfactory needs it for initialization of
# the llm memory
setup_engine(database_uri=os.environ.get("EMBEDCHAIN_DB_URI"))
init_db()
llm_provider = llm_config_data.get("provider", "openai")
llm = LlmFactory.create(llm_provider, llm_config_data.get("config", {}))
else:
Expand Down
8 changes: 8 additions & 0 deletions embedchain/embedchain/llm/base.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
import logging
import os
from collections.abc import Generator
from typing import Any, Optional

from langchain.schema import BaseMessage as LCBaseMessage

from embedchain.constants import SQLITE_PATH
from embedchain.config import BaseLlmConfig
from embedchain.config.llm.base import (
DEFAULT_PROMPT,
DEFAULT_PROMPT_WITH_HISTORY_TEMPLATE,
DEFAULT_PROMPT_WITH_MEM0_MEMORY_TEMPLATE,
DOCS_SITE_PROMPT_TEMPLATE,
)
from embedchain.core.db.database import init_db, setup_engine
from embedchain.helpers.json_serializable import JSONSerializable
from embedchain.memory.base import ChatHistory
from embedchain.memory.message import ChatMessage
Expand All @@ -30,6 +33,11 @@ def __init__(self, config: Optional[BaseLlmConfig] = None):
else:
self.config = config

# Initialize the metadata db for the app here since llmfactory needs it for initialization of
# the llm memory
setup_engine(database_uri=os.environ.get("EMBEDCHAIN_DB_URI", f"sqlite:///{SQLITE_PATH}"))
init_db()

self.memory = ChatHistory()
self.is_docs_site_instance = False
self.history: Any = None
Expand Down
8 changes: 4 additions & 4 deletions embedchain/poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion embedchain/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,4 +32,4 @@ def clean_db():
def disable_telemetry():
os.environ["EC_TELEMETRY"] = "false"
yield
del os.environ["EC_TELEMETRY"]
del os.environ["EC_TELEMETRY"]
8 changes: 8 additions & 0 deletions embedchain/tests/llm/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,8 @@

from unittest import mock
import pytest

@pytest.fixture(autouse=True)
def mock_alembic_command_upgrade():
with mock.patch("alembic.command.upgrade"):
yield

0 comments on commit b386e24

Please sign in to comment.