From 545fa9dad56415863fe2378d0191c554cec01fd1 Mon Sep 17 00:00:00 2001
From: Jerry Liu
Date: Sat, 14 Sep 2024 10:09:31 -0700
Subject: [PATCH] update llamacloud index with image nodes (#15996)

---
 llama-index-core/llama_index/core/schema.py  |  11 ++
 llama-index-core/tests/test_schema.py        |  14 +-
 .../indices/managed/llama_cloud/retriever.py | 127 +++++++++++++++++-
 .../pyproject.toml                           |   2 +-
 4 files changed, 145 insertions(+), 9 deletions(-)

diff --git a/llama-index-core/llama_index/core/schema.py b/llama-index-core/llama_index/core/schema.py
index c7a3caa19306e1..9b5d2b7e87052e 100644
--- a/llama-index-core/llama_index/core/schema.py
+++ b/llama-index-core/llama_index/core/schema.py
@@ -526,6 +526,17 @@ def resolve_image(self) -> ImageType:
         else:
             raise ValueError("No image found in node.")
 
+    @property
+    def hash(self) -> str:
+        """Get hash of node."""
+        # doc identity depends on if image, image_path, or image_url is set
+        image_str = self.image or "None"
+        image_path_str = self.image_path or "None"
+        image_url_str = self.image_url or "None"
+        image_text = self.text or "None"
+        doc_identity = f"{image_str}-{image_path_str}-{image_url_str}-{image_text}"
+        return str(sha256(doc_identity.encode("utf-8", "surrogatepass")).hexdigest())
+
 
 class IndexNode(TextNode):
     """Node with reference to any object.
diff --git a/llama-index-core/tests/test_schema.py b/llama-index-core/tests/test_schema.py
index c5dc2f7892f334..a32c3a2e425611 100644
--- a/llama-index-core/tests/test_schema.py
+++ b/llama-index-core/tests/test_schema.py
@@ -1,5 +1,5 @@
 import pytest
-from llama_index.core.schema import NodeWithScore, TextNode
+from llama_index.core.schema import NodeWithScore, TextNode, ImageNode
 
 
 @pytest.fixture()
@@ -48,3 +48,15 @@ def test_text_node_hash() -> None:
     assert node2.hash == node.hash
     node3 = TextNode(text="new", metadata={"foo": "baz"})
     assert node3.hash != node.hash
+
+
+def test_image_node_hash() -> None:
+    """Test hash for ImageNode."""
+    node = ImageNode(image="base64", image_path="path")
+    node2 = ImageNode(image="base64", image_path="path2")
+    assert node.hash != node2.hash
+
+    # id's don't count as part of the hash
+    node3 = ImageNode(image_url="base64", id_="id")
+    node4 = ImageNode(image_url="base64", id_="id2")
+    assert node3.hash == node4.hash
diff --git a/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/llama_index/indices/managed/llama_cloud/retriever.py b/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/llama_index/indices/managed/llama_cloud/retriever.py
index 55060ae248ac91..b2efc73ded101b 100644
--- a/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/llama_index/indices/managed/llama_cloud/retriever.py
+++ b/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/llama_index/indices/managed/llama_cloud/retriever.py
@@ -1,13 +1,61 @@
 from typing import Any, List, Optional
 
-from llama_cloud import TextNodeWithScore
+from llama_cloud import TextNodeWithScore, PageScreenshotNodeWithScore
 from llama_cloud.resources.pipelines.client import OMIT, PipelineType
+from llama_cloud.client import LlamaCloud, AsyncLlamaCloud
+from llama_cloud.core import remove_none_from_dict
+from llama_cloud.core.api_error import ApiError
 
 from llama_index.core.base.base_retriever import BaseRetriever
 from llama_index.core.constants import DEFAULT_PROJECT_NAME
 from llama_index.core.ingestion.api_utils import get_aclient, get_client
-from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode
+from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode, ImageNode
 from llama_index.core.vector_stores.types import MetadataFilters
+import asyncio
+import urllib.parse
+import base64
+
+
+def _get_page_screenshot(
+    client: LlamaCloud, file_id: str, page_index: int, project_id: str
+) -> str:
+    """Get the page screenshot."""
+    # TODO: this currently uses requests, should be replaced with the client
+    _response = client._client_wrapper.httpx_client.request(
+        "GET",
+        urllib.parse.urljoin(
+            f"{client._client_wrapper.get_base_url()}/",
+            f"api/v1/files/{file_id}/page_screenshots/{page_index}",
+        ),
+        params=remove_none_from_dict({"project_id": project_id}),
+        headers=client._client_wrapper.get_headers(),
+        timeout=60,
+    )
+    if 200 <= _response.status_code < 300:
+        return _response.content
+    else:
+        raise ApiError(status_code=_response.status_code, body=_response.text)
+
+
+async def _aget_page_screenshot(
+    client: AsyncLlamaCloud, file_id: str, page_index: int, project_id: str
+) -> str:
+    """Get the page screenshot."""
+    # TODO: this currently uses requests, should be replaced with the client
+    _response = await client._client_wrapper.httpx_client.request(
+        "GET",
+        urllib.parse.urljoin(
+            f"{client._client_wrapper.get_base_url()}/",
+            f"api/v1/files/{file_id}/page_screenshots/{page_index}",
+        ),
+        params=remove_none_from_dict({"project_id": project_id}),
+        headers=client._client_wrapper.get_headers(),
+        timeout=60,
+    )
+    if 200 <= _response.status_code < 300:
+        return _response.content
+    else:
+        raise ApiError(status_code=_response.status_code, body=_response.text)
 
 
 class LlamaCloudRetriever(BaseRetriever):
@@ -28,6 +76,7 @@ def __init__(
         timeout: int = 60,
         retrieval_mode: Optional[str] = None,
         files_top_k: Optional[int] = None,
+        retrieve_image_nodes: Optional[bool] = None,
         **kwargs: Any,
     ) -> None:
         """Initialize the Platform Retriever."""
@@ -58,6 +107,9 @@ def __init__(
         self._filters = filters if filters is not None else OMIT
         self._retrieval_mode = retrieval_mode if retrieval_mode is not None else OMIT
         self._files_top_k = files_top_k if files_top_k is not None else OMIT
+        self._retrieve_image_nodes = (
+            retrieve_image_nodes if retrieve_image_nodes is not None else OMIT
+        )
 
         super().__init__(
             callback_manager=kwargs.get("callback_manager", None),
@@ -74,6 +126,62 @@ def _result_nodes_to_node_with_score(
 
         return nodes
 
+    def _image_nodes_to_node_with_score(
+        self, raw_image_nodes: List[PageScreenshotNodeWithScore]
+    ) -> List[NodeWithScore]:
+        image_nodes = []
+        if self._retrieve_image_nodes:
+            for raw_image_node in raw_image_nodes:
+                # TODO: this is a hack to use requests, should be replaced with the client
+                image_bytes = _get_page_screenshot(
+                    client=self._client,
+                    file_id=raw_image_node.node.file_id,
+                    page_index=raw_image_node.node.page_index,
+                    project_id=self.project_id,
+                )
+                # Convert image bytes to base64 encoded string
+                image_base64 = base64.b64encode(image_bytes).decode("utf-8")
+                image_node_with_score = NodeWithScore(
+                    node=ImageNode(image=image_base64), score=raw_image_node.score
+                )
+                image_nodes.append(image_node_with_score)
+        else:
+            if len(image_nodes) > 0:
+                raise ValueError(
+                    "Image nodes were retrieved but `retrieve_image_nodes` was set to False."
+                )
+        return image_nodes
+
+    async def _aimage_nodes_to_node_with_score(
+        self, raw_image_nodes: List[PageScreenshotNodeWithScore]
+    ) -> List[NodeWithScore]:
+        image_nodes = []
+        if self._retrieve_image_nodes:
+            tasks = [
+                _aget_page_screenshot(
+                    client=self._aclient,
+                    file_id=raw_image_node.node.file_id,
+                    page_index=raw_image_node.node.page_index,
+                    project_id=self.project_id,
+                )
+                for raw_image_node in raw_image_nodes
+            ]
+
+            image_bytes_list = await asyncio.gather(*tasks)
+            for image_bytes, raw_image_node in zip(image_bytes_list, raw_image_nodes):
+                # Convert image bytes to base64 encoded string
+                image_base64 = base64.b64encode(image_bytes).decode("utf-8")
+                image_node_with_score = NodeWithScore(
+                    node=ImageNode(image=image_base64), score=raw_image_node.score
+                )
+                image_nodes.append(image_node_with_score)
+        else:
+            if len(image_nodes) > 0:
+                raise ValueError(
+                    "Image nodes were retrieved but `retrieve_image_nodes` was set to False."
+                )
+        return image_nodes
+
     def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
         """Retrieve from the platform."""
         pipelines = self._client.pipelines.search_pipelines(
@@ -109,11 +217,13 @@ def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
             search_filters=self._filters,
             files_top_k=self._files_top_k,
             retrieval_mode=self._retrieval_mode,
+            retrieve_image_nodes=self._retrieve_image_nodes,
         )
 
-        result_nodes = results.retrieval_nodes
+        result_nodes = self._result_nodes_to_node_with_score(results.retrieval_nodes)
+        result_nodes.extend(self._image_nodes_to_node_with_score(results.image_nodes))
 
-        return self._result_nodes_to_node_with_score(result_nodes)
+        return result_nodes
 
     async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
         """Asynchronously retrieve from the platform."""
@@ -150,8 +260,11 @@ async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
             search_filters=self._filters,
             files_top_k=self._files_top_k,
             retrieval_mode=self._retrieval_mode,
+            retrieve_image_nodes=self._retrieve_image_nodes,
         )
 
-        result_nodes = results.retrieval_nodes
-
-        return self._result_nodes_to_node_with_score(result_nodes)
+        result_nodes = self._result_nodes_to_node_with_score(results.retrieval_nodes)
+        result_nodes.extend(
+            await self._aimage_nodes_to_node_with_score(results.image_nodes)
+        )
+        return result_nodes
diff --git a/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/pyproject.toml b/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/pyproject.toml
index fc39fefd639bb3..c68dbeb2ead238 100644
--- a/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/pyproject.toml
+++ b/llama-index-integrations/indices/llama-index-indices-managed-llama-cloud/pyproject.toml
@@ -30,7 +30,7 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-indices-managed-llama-cloud"
 readme = "README.md"
-version = "0.3.0"
+version = "0.3.1"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
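
Usage note (not part of the patch): a minimal sketch of the new retrieve_image_nodes
flag added to LlamaCloudRetriever above. The pipeline name, project name, and API key
are placeholders, and the name/project_name/api_key constructor arguments are assumed
from existing usage rather than shown in this diff.

    from llama_index.core.schema import ImageNode
    from llama_index.indices.managed.llama_cloud.retriever import LlamaCloudRetriever

    # Enable page-screenshot retrieval alongside the usual text chunks.
    retriever = LlamaCloudRetriever(
        name="my-pipeline",  # placeholder pipeline name (assumed argument)
        project_name="Default",  # placeholder project (assumed argument)
        api_key="llx-...",  # placeholder API key (assumed argument)
        retrieve_image_nodes=True,  # new flag introduced by this patch
    )

    nodes = retriever.retrieve("What does the chart on page 3 show?")
    for n in nodes:
        if isinstance(n.node, ImageNode):
            # ImageNode.image holds the base64-encoded page screenshot.
            print("image node, score:", n.score)
        else:
            print("text node, score:", n.score)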