Skip to content

Commit

Permalink
update llamacloud index with image nodes (run-llama#15996)
Browse files Browse the repository at this point in the history
  • Loading branch information
jerryjliu authored and raspawar committed Oct 7, 2024
1 parent ce6b5a0 commit 545fa9d
Show file tree
Hide file tree
Showing 4 changed files with 145 additions and 9 deletions.
11 changes: 11 additions & 0 deletions llama-index-core/llama_index/core/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -526,6 +526,17 @@ def resolve_image(self) -> ImageType:
else:
raise ValueError("No image found in node.")

@property
def hash(self) -> str:
    """Return a SHA-256 hex digest identifying this node's content.

    The digest is computed over the node's ``image``, ``image_path``,
    ``image_url`` and ``text`` attributes (in that order), joined with
    ``-``. Any attribute that is falsy (``None`` or empty) contributes the
    literal string ``"None"``.
    """
    # Identity depends on which of the image sources (and text) are set.
    identity_fields = (self.image, self.image_path, self.image_url, self.text)
    doc_identity = "-".join(f"{value or 'None'}" for value in identity_fields)
    digest = sha256(doc_identity.encode("utf-8", "surrogatepass"))
    return str(digest.hexdigest())


class IndexNode(TextNode):
"""Node with reference to any object.
Expand Down
14 changes: 13 additions & 1 deletion llama-index-core/tests/test_schema.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import pytest
from llama_index.core.schema import NodeWithScore, TextNode
from llama_index.core.schema import NodeWithScore, TextNode, ImageNode


@pytest.fixture()
Expand Down Expand Up @@ -48,3 +48,15 @@ def test_text_node_hash() -> None:
assert node2.hash == node.hash
node3 = TextNode(text="new", metadata={"foo": "baz"})
assert node3.hash != node.hash


def test_image_node_hash() -> None:
    """ImageNode.hash reflects image fields but ignores the node id."""
    # Same image payload, different image_path -> different hashes.
    first = ImageNode(image="base64", image_path="path")
    second = ImageNode(image="base64", image_path="path2")
    assert first.hash != second.hash

    # Node ids are not part of the hash: identical content, different ids.
    with_id_a = ImageNode(image_url="base64", id_="id")
    with_id_b = ImageNode(image_url="base64", id_="id2")
    assert with_id_a.hash == with_id_b.hash
Original file line number Diff line number Diff line change
@@ -1,13 +1,61 @@
from typing import Any, List, Optional

from llama_cloud import TextNodeWithScore
from llama_cloud import TextNodeWithScore, PageScreenshotNodeWithScore
from llama_cloud.resources.pipelines.client import OMIT, PipelineType
from llama_cloud.client import LlamaCloud, AsyncLlamaCloud
from llama_cloud.core import remove_none_from_dict
from llama_cloud.core.api_error import ApiError

from llama_index.core.base.base_retriever import BaseRetriever
from llama_index.core.constants import DEFAULT_PROJECT_NAME
from llama_index.core.ingestion.api_utils import get_aclient, get_client
from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode
from llama_index.core.schema import NodeWithScore, QueryBundle, TextNode, ImageNode
from llama_index.core.vector_stores.types import MetadataFilters
import asyncio
import urllib.parse
import base64


def _get_page_screenshot(
    client: LlamaCloud, file_id: str, page_index: int, project_id: str
) -> bytes:
    """Fetch the screenshot of one page of a file.

    Issues a ``GET`` to ``api/v1/files/{file_id}/page_screenshots/{page_index}``
    relative to the client's base URL, passing ``project_id`` as a query
    parameter.

    Args:
        client: Synchronous LlamaCloud client whose wrapped httpx client,
            base URL and headers are used for the request.
        file_id: Identifier of the file, interpolated into the URL path.
        page_index: Page index, interpolated into the URL path.
        project_id: Project identifier sent as the ``project_id`` query param.

    Returns:
        The raw response body (``_response.content``) as bytes.

    Raises:
        ApiError: If the response status code is not in the 2xx range.
    """
    # TODO: this issues a raw request through the client's underlying httpx
    # client; it should be replaced with a proper client method.
    _response = client._client_wrapper.httpx_client.request(
        "GET",
        urllib.parse.urljoin(
            f"{client._client_wrapper.get_base_url()}/",
            f"api/v1/files/{file_id}/page_screenshots/{page_index}",
        ),
        params=remove_none_from_dict({"project_id": project_id}),
        headers=client._client_wrapper.get_headers(),
        timeout=60,
    )
    if 200 <= _response.status_code < 300:
        return _response.content
    raise ApiError(status_code=_response.status_code, body=_response.text)


async def _aget_page_screenshot(
    client: AsyncLlamaCloud, file_id: str, page_index: int, project_id: str
) -> bytes:
    """Asynchronously fetch the screenshot of one page of a file.

    Issues a ``GET`` to ``api/v1/files/{file_id}/page_screenshots/{page_index}``
    relative to the client's base URL, passing ``project_id`` as a query
    parameter.

    Args:
        client: Asynchronous LlamaCloud client whose wrapped httpx client,
            base URL and headers are used for the request.
        file_id: Identifier of the file, interpolated into the URL path.
        page_index: Page index, interpolated into the URL path.
        project_id: Project identifier sent as the ``project_id`` query param.

    Returns:
        The raw response body (``_response.content``) as bytes.

    Raises:
        ApiError: If the response status code is not in the 2xx range.
    """
    # TODO: this issues a raw request through the client's underlying httpx
    # client; it should be replaced with a proper client method.
    _response = await client._client_wrapper.httpx_client.request(
        "GET",
        urllib.parse.urljoin(
            f"{client._client_wrapper.get_base_url()}/",
            f"api/v1/files/{file_id}/page_screenshots/{page_index}",
        ),
        params=remove_none_from_dict({"project_id": project_id}),
        headers=client._client_wrapper.get_headers(),
        timeout=60,
    )
    if 200 <= _response.status_code < 300:
        return _response.content
    raise ApiError(status_code=_response.status_code, body=_response.text)


class LlamaCloudRetriever(BaseRetriever):
Expand All @@ -28,6 +76,7 @@ def __init__(
timeout: int = 60,
retrieval_mode: Optional[str] = None,
files_top_k: Optional[int] = None,
retrieve_image_nodes: Optional[bool] = None,
**kwargs: Any,
) -> None:
"""Initialize the Platform Retriever."""
Expand Down Expand Up @@ -58,6 +107,9 @@ def __init__(
self._filters = filters if filters is not None else OMIT
self._retrieval_mode = retrieval_mode if retrieval_mode is not None else OMIT
self._files_top_k = files_top_k if files_top_k is not None else OMIT
self._retrieve_image_nodes = (
retrieve_image_nodes if retrieve_image_nodes is not None else OMIT
)

super().__init__(
callback_manager=kwargs.get("callback_manager", None),
Expand All @@ -74,6 +126,62 @@ def _result_nodes_to_node_with_score(

return nodes

def _image_nodes_to_node_with_score(
    self, raw_image_nodes: List[PageScreenshotNodeWithScore]
) -> List[NodeWithScore]:
    """Convert raw page-screenshot results into scored ``ImageNode`` entries.

    When ``self._retrieve_image_nodes`` is truthy, each raw node's screenshot
    is downloaded with ``_get_page_screenshot`` (one request per node),
    base64-encoded, and wrapped in ``NodeWithScore(ImageNode(...))`` carrying
    the raw node's score.

    Args:
        raw_image_nodes: Page-screenshot nodes returned by the retrieval call.

    Returns:
        One ``NodeWithScore`` per raw node, in input order; an empty list
        when there is nothing to convert.

    Raises:
        ValueError: If raw image nodes are present although
            ``self._retrieve_image_nodes`` is falsy.
    """
    # NOTE(review): tolerate a missing list; confirm whether the API can
    # actually return ``None`` for image nodes.
    raw_image_nodes = raw_image_nodes or []

    if not self._retrieve_image_nodes:
        # The guard must inspect the *input*: the output list is always
        # empty at this point, so checking it could never trigger.
        if raw_image_nodes:
            raise ValueError(
                "Image nodes were retrieved but `retrieve_image_nodes` was set to False."
            )
        return []

    image_nodes = []
    for raw_image_node in raw_image_nodes:
        # TODO: this is a hack that bypasses the typed client API and
        # should be replaced with a proper client method.
        image_bytes = _get_page_screenshot(
            client=self._client,
            file_id=raw_image_node.node.file_id,
            page_index=raw_image_node.node.page_index,
            project_id=self.project_id,
        )
        # Convert image bytes to a base64-encoded string.
        image_base64 = base64.b64encode(image_bytes).decode("utf-8")
        image_nodes.append(
            NodeWithScore(
                node=ImageNode(image=image_base64), score=raw_image_node.score
            )
        )
    return image_nodes

async def _aimage_nodes_to_node_with_score(
    self, raw_image_nodes: List[PageScreenshotNodeWithScore]
) -> List[NodeWithScore]:
    """Asynchronously convert raw page-screenshot results into scored ``ImageNode`` entries.

    When ``self._retrieve_image_nodes`` is truthy, the screenshots for all
    raw nodes are requested concurrently via ``_aget_page_screenshot`` and
    ``asyncio.gather``; each result is base64-encoded and wrapped in
    ``NodeWithScore(ImageNode(...))`` carrying the raw node's score.

    Args:
        raw_image_nodes: Page-screenshot nodes returned by the retrieval call.

    Returns:
        One ``NodeWithScore`` per raw node, in input order; an empty list
        when there is nothing to convert.

    Raises:
        ValueError: If raw image nodes are present although
            ``self._retrieve_image_nodes`` is falsy.
    """
    # NOTE(review): tolerate a missing list; confirm whether the API can
    # actually return ``None`` for image nodes.
    raw_image_nodes = raw_image_nodes or []

    if not self._retrieve_image_nodes:
        # The guard must inspect the *input*: the output list is always
        # empty at this point, so checking it could never trigger.
        if raw_image_nodes:
            raise ValueError(
                "Image nodes were retrieved but `retrieve_image_nodes` was set to False."
            )
        return []

    tasks = [
        _aget_page_screenshot(
            client=self._aclient,
            file_id=raw_image_node.node.file_id,
            page_index=raw_image_node.node.page_index,
            project_id=self.project_id,
        )
        for raw_image_node in raw_image_nodes
    ]
    # gather() returns results in task order, so zip() pairs each
    # screenshot with the raw node it was requested for.
    image_bytes_list = await asyncio.gather(*tasks)

    return [
        NodeWithScore(
            # Convert image bytes to a base64-encoded string.
            node=ImageNode(image=base64.b64encode(image_bytes).decode("utf-8")),
            score=raw_image_node.score,
        )
        for image_bytes, raw_image_node in zip(image_bytes_list, raw_image_nodes)
    ]

def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
"""Retrieve from the platform."""
pipelines = self._client.pipelines.search_pipelines(
Expand Down Expand Up @@ -109,11 +217,13 @@ def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
search_filters=self._filters,
files_top_k=self._files_top_k,
retrieval_mode=self._retrieval_mode,
retrieve_image_nodes=self._retrieve_image_nodes,
)

result_nodes = results.retrieval_nodes
result_nodes = self._result_nodes_to_node_with_score(results.retrieval_nodes)
result_nodes.extend(self._image_nodes_to_node_with_score(results.image_nodes))

return self._result_nodes_to_node_with_score(result_nodes)
return result_nodes

async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
"""Asynchronously retrieve from the platform."""
Expand Down Expand Up @@ -150,8 +260,11 @@ async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
search_filters=self._filters,
files_top_k=self._files_top_k,
retrieval_mode=self._retrieval_mode,
retrieve_image_nodes=self._retrieve_image_nodes,
)

result_nodes = results.retrieval_nodes

return self._result_nodes_to_node_with_score(result_nodes)
result_nodes = self._result_nodes_to_node_with_score(results.retrieval_nodes)
result_nodes.extend(
await self._aimage_nodes_to_node_with_score(results.image_nodes)
)
return result_nodes
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-indices-managed-llama-cloud"
readme = "README.md"
version = "0.3.0"
version = "0.3.1"

[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
Expand Down

0 comments on commit 545fa9d

Please sign in to comment.