Skip to content

Commit

Permalink
fix: langfuse talk to model (#3535)
Browse files Browse the repository at this point in the history
# Description

Please include a summary of the changes and the related issue. Please
also include relevant motivation and context.

## Checklist before requesting a review

Please delete options that are not relevant.

- [ ] My code follows the style guidelines of this project
- [ ] I have performed a self-review of my code
- [ ] I have commented hard-to-understand areas
- [ ] I have ideally added tests that prove my fix is effective or that
my feature works
- [ ] New and existing unit tests pass locally with my changes
- [ ] Any dependent changes have been merged

## Screenshots (if appropriate):

---------

Co-authored-by: Stan Girard <[email protected]>
  • Loading branch information
chloedia and StanGirard authored Jan 6, 2025
1 parent d835fc6 commit 9681a9e
Show file tree
Hide file tree
Showing 3 changed files with 23 additions and 12 deletions.
9 changes: 6 additions & 3 deletions core/quivr_core/rag/quivr_rag.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@
from langchain_core.runnables import RunnableLambda, RunnablePassthrough
from langchain_core.vectorstores import VectorStore

from quivr_core.llm import LLMEndpoint
from quivr_core.rag.entities.chat import ChatHistory
from quivr_core.rag.entities.config import RetrievalConfig
from quivr_core.llm import LLMEndpoint
from quivr_core.rag.entities.models import (
ParsedRAGChunkResponse,
ParsedRAGResponse,
Expand All @@ -24,6 +24,7 @@
)
from quivr_core.rag.prompts import custom_prompts
from quivr_core.rag.utils import (
LangfuseService,
combine_documents,
format_file_list,
get_chunk_metadata,
Expand All @@ -32,6 +33,8 @@
)

logger = logging.getLogger("quivr_core")
langfuse_service = LangfuseService()
langfuse_handler = langfuse_service.get_handler()


class IdempotentCompressor(BaseDocumentCompressor):
Expand Down Expand Up @@ -173,7 +176,7 @@ def answer(
"chat_history": history,
"custom_instructions": (self.retrieval_config.prompt),
},
config={"metadata": metadata},
config={"metadata": metadata, "callbacks": [langfuse_handler]},
)
response = parse_response(
raw_llm_response, self.retrieval_config.llm_config.model
Expand Down Expand Up @@ -206,7 +209,7 @@ async def answer_astream(
"chat_history": history,
"custom_personality": (self.retrieval_config.prompt),
},
config={"metadata": metadata},
config={"metadata": metadata, "callbacks": [langfuse_handler]},
):
# Could receive this anywhere so we need to save it for the last chunk
if "docs" in chunk:
Expand Down
17 changes: 8 additions & 9 deletions core/quivr_core/rag/quivr_rag_langgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,6 @@
from langgraph.types import Send
from pydantic import BaseModel, Field

from langfuse.callback import CallbackHandler

from quivr_core.llm import LLMEndpoint
from quivr_core.llm_tools.llm_tools import LLMToolFactory
from quivr_core.rag.entities.chat import ChatHistory
Expand All @@ -41,6 +39,7 @@
)
from quivr_core.rag.prompts import custom_prompts
from quivr_core.rag.utils import (
LangfuseService,
collect_tools,
combine_documents,
format_file_list,
Expand All @@ -50,8 +49,8 @@

logger = logging.getLogger("quivr_core")

# Initialize Langfuse CallbackHandler for Langchain (tracing)
langfuse_handler = CallbackHandler()
langfuse_service = LangfuseService()
langfuse_handler = langfuse_service.get_handler()


class SplittedInput(BaseModel):
Expand Down Expand Up @@ -502,7 +501,7 @@ async def rewrite(self, state: AgentState) -> AgentState:
task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []

# Replace each question with its condensed version
for response, task_id in zip(responses, task_ids):
for response, task_id in zip(responses, task_ids, strict=False):
tasks.set_definition(task_id, response.content)

return {**state, "tasks": tasks}
Expand Down Expand Up @@ -558,7 +557,7 @@ async def tool_routing(self, state: AgentState):
)
task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []

for response, task_id in zip(responses, task_ids):
for response, task_id in zip(responses, task_ids, strict=False):
tasks.set_completion(task_id, response.is_task_completable)
if not response.is_task_completable and response.tool:
tasks.set_tool(task_id, response.tool)
Expand Down Expand Up @@ -599,7 +598,7 @@ async def run_tool(self, state: AgentState) -> AgentState:
)
task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []

for response, task_id in zip(responses, task_ids):
for response, task_id in zip(responses, task_ids, strict=False):
_docs = tool_wrapper.format_output(response)
_docs = self.filter_chunks_by_relevance(_docs)
tasks.set_docs(task_id, _docs)
Expand Down Expand Up @@ -652,7 +651,7 @@ async def retrieve(self, state: AgentState) -> AgentState:
task_ids = [task[1] for task in async_jobs] if async_jobs else []

# Process responses and associate docs with tasks
for response, task_id in zip(responses, task_ids):
for response, task_id in zip(responses, task_ids, strict=False):
_docs = self.filter_chunks_by_relevance(response)
tasks.set_docs(task_id, _docs) # Associate docs with the specific task

Expand Down Expand Up @@ -715,7 +714,7 @@ async def dynamic_retrieve(self, state: AgentState) -> AgentState:
task_ids = [jobs[1] for jobs in async_jobs] if async_jobs else []

_n = []
for response, task_id in zip(responses, task_ids):
for response, task_id in zip(responses, task_ids, strict=False):
_docs = self.filter_chunks_by_relevance(response)
_n.append(len(_docs))
tasks.set_docs(task_id, _docs)
Expand Down
9 changes: 9 additions & 0 deletions core/quivr_core/rag/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from langchain_core.messages import AIMessage, BaseMessage, HumanMessage, SystemMessage
from langchain_core.messages.ai import AIMessageChunk
from langchain_core.prompts import format_document
from langfuse.callback import CallbackHandler

from quivr_core.rag.entities.config import WorkflowConfig
from quivr_core.rag.entities.models import (
Expand Down Expand Up @@ -195,3 +196,11 @@ def collect_tools(workflow_config: WorkflowConfig):
activated_tools += f"Tool {i+1} description: {tool.description}\n\n"

return validated_tools, activated_tools


class LangfuseService:
    """Thin wrapper that owns a single Langfuse ``CallbackHandler``.

    Centralizes handler creation so callers (e.g. ``quivr_rag`` and
    ``quivr_rag_langgraph``) obtain the LangChain tracing callback via
    ``get_handler()`` instead of constructing ``CallbackHandler`` directly.
    """

    def __init__(self):
        # One handler per service instance, created eagerly at construction.
        # NOTE(review): CallbackHandler() presumably reads Langfuse credentials
        # from the environment — not visible here, confirm against langfuse docs.
        self.langfuse_handler = CallbackHandler()

    def get_handler(self):
        """Return the ``CallbackHandler`` created at construction time.

        Returns:
            The same handler instance on every call (no re-creation).
        """
        return self.langfuse_handler

0 comments on commit 9681a9e

Please sign in to comment.