diff --git a/llama_index/constants.py b/llama_index/constants.py index 09861f4ca1a60452ff21ceb196a7318fce4295e8..cb3f8c1ca580b7ac826d301f1b7a40e727265e64 100644 --- a/llama_index/constants.py +++ b/llama_index/constants.py @@ -19,6 +19,7 @@ AI21_J2_CONTEXT_WINDOW = 8192 TYPE_KEY = "__type__" DATA_KEY = "__data__" VECTOR_STORE_KEY = "vector_store" +IMAGE_STORE_KEY = "image_store" GRAPH_STORE_KEY = "graph_store" INDEX_STORE_KEY = "index_store" DOC_STORE_KEY = "doc_store" diff --git a/llama_index/data_structs/data_structs.py b/llama_index/data_structs/data_structs.py index 355bc601d593e80188c17d17e7a162283fca336d..6d339b201c1e911507783ddd92dcd80769e1fe31 100644 --- a/llama_index/data_structs/data_structs.py +++ b/llama_index/data_structs/data_structs.py @@ -202,6 +202,16 @@ class IndexDict(IndexStruct): return IndexStructType.VECTOR_STORE +@dataclass +class MultiModelIndexDict(IndexDict): + """A simple dictionary of documents, but loads a MultiModelVectorStore.""" + + @classmethod + def get_type(cls) -> IndexStructType: + """Get type.""" + return IndexStructType.MULTIMODAL_VECTOR_STORE + + @dataclass class KG(IndexStruct): """A table of keywords mapping keywords to text chunks.""" diff --git a/llama_index/data_structs/registry.py b/llama_index/data_structs/registry.py index c85f05f874faa22b21b901dff92bfbc1962bda30..e4c834c61954a05f870c269e87c6a4ff9a530582 100644 --- a/llama_index/data_structs/registry.py +++ b/llama_index/data_structs/registry.py @@ -10,6 +10,7 @@ from llama_index.data_structs.data_structs import ( IndexList, IndexStruct, KeywordTable, + MultiModelIndexDict, ) from llama_index.data_structs.document_summary import IndexDocumentSummary from llama_index.data_structs.struct_type import IndexStructType @@ -25,4 +26,5 @@ INDEX_STRUCT_TYPE_TO_INDEX_STRUCT_CLASS: Dict[IndexStructType, Type[IndexStruct] IndexStructType.KG: KG, IndexStructType.EMPTY: EmptyIndexStruct, IndexStructType.DOCUMENT_SUMMARY: IndexDocumentSummary, + IndexStructType.MULTIMODAL_VECTOR_STORE: MultiModelIndexDict, } diff --git a/llama_index/data_structs/struct_type.py b/llama_index/data_structs/struct_type.py index 342bd4cf86d706d0991fbac2c47d7f2065f4ebdf..c3f2cbb919247aad5ea84d8e4615bcf81384045f 100644 --- a/llama_index/data_structs/struct_type.py +++ b/llama_index/data_structs/struct_type.py @@ -87,6 +87,8 @@ class IndexStructType(str, Enum): CHATGPT_RETRIEVAL_PLUGIN = "chatgpt_retrieval_plugin" DEEPLAKE = "deeplake" EPSILLA = "epsilla" + # multimodal + MULTIMODAL_VECTOR_STORE = "multimodal" # for SQL index SQL = "sql" # for KG index diff --git a/llama_index/embeddings/clip.py b/llama_index/embeddings/clip.py index ad1f5a0f8407ea53544f6a36da940e6f981bc4ab..bd94920f2127e6a42034e0547696d90f19e4be4c 100644 --- a/llama_index/embeddings/clip.py +++ b/llama_index/embeddings/clip.py @@ -7,6 +7,7 @@ from llama_index.embeddings.base import ( Embedding, ) from llama_index.embeddings.mutli_modal_base import MultiModalEmbedding +from llama_index.schema import ImageType logger = logging.getLogger(__name__) @@ -111,10 +112,10 @@ class ClipEmbedding(MultiModalEmbedding): # IMAGE EMBEDDINGS - async def _aget_image_embedding(self, img_file_path: str) -> Embedding: + async def _aget_image_embedding(self, img_file_path: ImageType) -> Embedding: return self._get_image_embedding(img_file_path) - def _get_image_embedding(self, img_file_path: str) -> Embedding: + def _get_image_embedding(self, img_file_path: ImageType) -> Embedding: try: import torch from PIL import Image diff --git a/llama_index/embeddings/mutli_modal_base.py b/llama_index/embeddings/mutli_modal_base.py index ccc5f71a6ec6dc5896130872affa5ed0866740dc..276063ca07a294f02012a29680b7a66b33de198e 100644 --- a/llama_index/embeddings/mutli_modal_base.py +++ b/llama_index/embeddings/mutli_modal_base.py @@ -9,6 +9,7 @@ from llama_index.embeddings.base import ( BaseEmbedding, Embedding, ) +from llama_index.schema import ImageType from llama_index.utils import get_tqdm_iterable @@ -16,7 +17,7 @@ class MultiModalEmbedding(BaseEmbedding): """Base class for Multi Modal embeddings.""" @abstractmethod - def _get_image_embedding(self, img_file_path: str) -> Embedding: + def _get_image_embedding(self, img_file_path: ImageType) -> Embedding: """ Embed the input image synchronously. @@ -25,7 +26,7 @@ class MultiModalEmbedding(BaseEmbedding): """ @abstractmethod - async def _aget_image_embedding(self, img_file_path: str) -> Embedding: + async def _aget_image_embedding(self, img_file_path: ImageType) -> Embedding: """ Embed the input image asynchronously. @@ -33,7 +34,7 @@ class MultiModalEmbedding(BaseEmbedding): docstring for more information. """ - def get_image_embedding(self, img_file_path: str) -> Embedding: + def get_image_embedding(self, img_file_path: ImageType) -> Embedding: """ Embed the input image. """ @@ -50,7 +51,7 @@ class MultiModalEmbedding(BaseEmbedding): ) return image_embedding - async def aget_image_embedding(self, img_file_path: str) -> Embedding: + async def aget_image_embedding(self, img_file_path: ImageType) -> Embedding: """Get image embedding.""" with self.callback_manager.event( CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()} @@ -65,7 +66,7 @@ class MultiModalEmbedding(BaseEmbedding): ) return image_embedding - def _get_image_embeddings(self, img_file_paths: List[str]) -> List[Embedding]: + def _get_image_embeddings(self, img_file_paths: List[ImageType]) -> List[Embedding]: """ Embed the input sequence of image synchronously. @@ -77,7 +78,7 @@ class MultiModalEmbedding(BaseEmbedding): ] async def _aget_image_embeddings( - self, img_file_paths: List[str] + self, img_file_paths: List[ImageType] ) -> List[Embedding]: """ Embed the input sequence of image asynchronously. @@ -92,10 +93,10 @@ class MultiModalEmbedding(BaseEmbedding): ) def get_image_embedding_batch( - self, img_file_paths: List[str], show_progress: bool = False + self, img_file_paths: List[ImageType], show_progress: bool = False ) -> List[Embedding]: """Get a list of image embeddings, with batching.""" - cur_batch: List[str] = [] + cur_batch: List[ImageType] = [] result_embeddings: List[Embedding] = [] queue_with_progress = enumerate( @@ -128,11 +129,11 @@ class MultiModalEmbedding(BaseEmbedding): return result_embeddings async def aget_image_embedding_batch( - self, img_file_paths: List[str], show_progress: bool = False + self, img_file_paths: List[ImageType], show_progress: bool = False ) -> List[Embedding]: """Asynchronously get a list of image embeddings, with batching.""" - cur_batch: List[str] = [] - callback_payloads: List[Tuple[str, List[str]]] = [] + cur_batch: List[ImageType] = [] + callback_payloads: List[Tuple[str, List[ImageType]]] = [] result_embeddings: List[Embedding] = [] embeddings_coroutines: List[Coroutine] = [] for idx, img_file_path in enumerate(img_file_paths): diff --git a/llama_index/embeddings/utils.py b/llama_index/embeddings/utils.py index cfb7dae69e2d0e2c6646d6d5f73b565eca504006..f5350afa7eee1c642ecd93cbdb523240caed0e5d 100644 --- a/llama_index/embeddings/utils.py +++ b/llama_index/embeddings/utils.py @@ -4,6 +4,7 @@ from typing import List, Optional, Union from llama_index.bridge.langchain import Embeddings as LCEmbeddings from llama_index.embeddings.base import BaseEmbedding +from llama_index.embeddings.clip import ClipEmbedding from llama_index.embeddings.huggingface import HuggingFaceEmbedding from llama_index.embeddings.huggingface_utils import ( DEFAULT_HUGGINGFACE_EMBEDDING_MODEL, @@ -50,6 +51,10 @@ def resolve_embed_model(embed_model: Optional[EmbedType] = None) -> BaseEmbeddin "\n******" ) + # for image embeddings + if embed_model == "clip": + embed_model = ClipEmbedding() + if isinstance(embed_model, str): splits = embed_model.split(":", 1) is_local = splits[0] diff --git a/llama_index/indices/multi_modal/__init__.py b/llama_index/indices/multi_modal/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..4843edbf7d009af16951b56e1044b43388aee1bd --- /dev/null +++ b/llama_index/indices/multi_modal/__init__.py @@ -0,0 +1,7 @@ +"""Vector-store based data structures.""" + +from llama_index.indices.multi_modal.base import MultiModalVectorStoreIndex + +__all__ = [ + "MultiModalVectorStoreIndex", +] diff --git a/llama_index/indices/multi_modal/base.py b/llama_index/indices/multi_modal/base.py new file mode 100644 index 0000000000000000000000000000000000000000..3e4d0dc098aff6cd96af6ff8a7031952ff8e457b --- /dev/null +++ b/llama_index/indices/multi_modal/base.py @@ -0,0 +1,331 @@ +"""Multi Modal Vector Store Index. + +An index that that is built on top of multiple vector stores for different modalities. + +""" +import logging +from typing import Any, List, Optional, Sequence + +from llama_index.data_structs.data_structs import IndexDict, MultiModelIndexDict +from llama_index.embeddings.mutli_modal_base import MultiModalEmbedding +from llama_index.embeddings.utils import EmbedType, resolve_embed_model +from llama_index.indices.base_retriever import BaseRetriever +from llama_index.indices.service_context import ServiceContext +from llama_index.indices.utils import ( + async_embed_image_nodes, + async_embed_nodes, + embed_image_nodes, + embed_nodes, +) +from llama_index.indices.vector_store.base import VectorStoreIndex +from llama_index.schema import BaseNode, ImageNode, IndexNode +from llama_index.storage.storage_context import StorageContext +from llama_index.vector_stores.simple import DEFAULT_VECTOR_STORE, SimpleVectorStore +from llama_index.vector_stores.types import VectorStore + +logger = logging.getLogger(__name__) + + +class MultiModalVectorStoreIndex(VectorStoreIndex): + """Multi-Modal Vector Store Index. + + Args: + use_async (bool): Whether to use asynchronous calls. Defaults to False. + show_progress (bool): Whether to show tqdm progress bars. Defaults to False. + store_nodes_override (bool): set to True to always store Node objects in index + store and document store even if vector store keeps text. Defaults to False + """ + + image_namespace = "image" + index_struct_cls = MultiModelIndexDict + + def __init__( + self, + nodes: Optional[Sequence[BaseNode]] = None, + index_struct: Optional[MultiModelIndexDict] = None, + service_context: Optional[ServiceContext] = None, + storage_context: Optional[StorageContext] = None, + use_async: bool = False, + store_nodes_override: bool = False, + show_progress: bool = False, + # Image-related kwargs + image_vector_store: Optional[VectorStore] = None, + image_embed_model: EmbedType = "clip", + **kwargs: Any, + ) -> None: + """Initialize params.""" + image_embed_model = resolve_embed_model(image_embed_model) + assert isinstance(image_embed_model, MultiModalEmbedding) + self._image_embed_model = image_embed_model + + storage_context = storage_context or StorageContext.from_defaults() + + if image_vector_store is not None: + storage_context.add_vector_store(image_vector_store, self.image_namespace) + + if self.image_namespace not in storage_context.vector_stores: + storage_context.add_vector_store(SimpleVectorStore(), self.image_namespace) + + super().__init__( + nodes=nodes, + index_struct=index_struct, + service_context=service_context, + storage_context=storage_context, + show_progress=show_progress, + use_async=use_async, + # force to true, since vector dbs don't store images + store_nodes_override=True, + **kwargs, + ) + + @classmethod + def from_vector_store( + cls, + vector_store: VectorStore, + service_context: Optional[ServiceContext] = None, + # Image-related kwargs + image_vector_store: Optional[VectorStore] = None, + image_embed_model: EmbedType = "clip", + **kwargs: Any, + ) -> "VectorStoreIndex": + if not vector_store.stores_text: + raise ValueError( + "Cannot initialize from a vector store that does not store text." + ) + + storage_context = StorageContext.from_defaults(vector_store=vector_store) + return cls( + nodes=[], + service_context=service_context, + storage_context=storage_context, + image_vector_store=image_vector_store, + image_embed_model=image_embed_model, + **kwargs, + ) + + def as_retriever(self, **kwargs: Any) -> BaseRetriever: + raise NotImplementedError("Retriever not yet implemented for MultiModalIndex.") + + def _get_node_with_embedding( + self, + nodes: Sequence[BaseNode], + show_progress: bool = False, + is_image: bool = False, + ) -> List[BaseNode]: + """Get tuples of id, node, and embedding. + + Allows us to store these nodes in a vector store. + Embeddings are called in batches. + + """ + if is_image: + id_to_embed_map = embed_image_nodes( + nodes, + embed_model=self._image_embed_model, + show_progress=show_progress, + ) + else: + id_to_embed_map = embed_nodes( + nodes, + embed_model=self._service_context.embed_model, + show_progress=show_progress, + ) + + results = [] + for node in nodes: + embedding = id_to_embed_map[node.node_id] + result = node.copy() + result.embedding = embedding + results.append(result) + return results + + async def _aget_node_with_embedding( + self, + nodes: Sequence[BaseNode], + show_progress: bool = False, + is_image: bool = False, + ) -> List[BaseNode]: + """Asynchronously get tuples of id, node, and embedding. + + Allows us to store these nodes in a vector store. + Embeddings are called in batches. + + """ + if is_image: + id_to_embed_map = await async_embed_image_nodes( + nodes, + embed_model=self._image_embed_model, + show_progress=show_progress, + ) + else: + id_to_embed_map = await async_embed_nodes( + nodes, + embed_model=self._service_context.embed_model, + show_progress=show_progress, + ) + + results = [] + for node in nodes: + embedding = id_to_embed_map[node.node_id] + result = node.copy() + result.embedding = embedding + results.append(result) + return results + + async def _async_add_nodes_to_index( + self, + index_struct: IndexDict, + nodes: Sequence[BaseNode], + show_progress: bool = False, + **insert_kwargs: Any, + ) -> None: + """Asynchronously add nodes to index.""" + if not nodes: + return + + image_nodes: List[ImageNode] = [] + text_nodes: List[BaseNode] = [] + + for node in nodes: + if isinstance(node, ImageNode): + image_nodes.append(node) + if node.text: + text_nodes.append(node) + + # embed all nodes as text - incclude image nodes that have text attached + text_nodes = await self._aget_node_with_embedding( + text_nodes, show_progress, is_image=False + ) + new_text_ids = await self.storage_context.vector_stores[ + DEFAULT_VECTOR_STORE + ].async_add(text_nodes, **insert_kwargs) + + # embed image nodes as images directly + image_nodes = await self._aget_node_with_embedding( + image_nodes, show_progress, is_image=True + ) + new_img_ids = await self.storage_context.vector_stores[ + self.image_namespace + ].async_add(image_nodes, **insert_kwargs) + + # TODO: can vector stores just store images directly? Then no need for docstore + # Maybe a fix for later + + # if the vector store doesn't store text, we need to add the nodes to the + # index struct and document store + all_nodes = text_nodes + image_nodes + all_new_ids = new_text_ids + new_img_ids + if not self._vector_store.stores_text or self._store_nodes_override: + for node, new_id in zip(all_nodes, all_new_ids): + # NOTE: remove embedding from node to avoid duplication + node_without_embedding = node.copy() + node_without_embedding.embedding = None + + index_struct.add_node(node_without_embedding, text_id=new_id) + self._docstore.add_documents( + [node_without_embedding], allow_update=True + ) + else: + # NOTE: if the vector store keeps text, + # we only need to add image and index nodes + for node, new_id in zip(all_nodes, all_new_ids): + if isinstance(node, (ImageNode, IndexNode)): + # NOTE: remove embedding from node to avoid duplication + node_without_embedding = node.copy() + node_without_embedding.embedding = None + + index_struct.add_node(node_without_embedding, text_id=new_id) + self._docstore.add_documents( + [node_without_embedding], allow_update=True + ) + + def _add_nodes_to_index( + self, + index_struct: IndexDict, + nodes: Sequence[BaseNode], + show_progress: bool = False, + **insert_kwargs: Any, + ) -> None: + """Add document to index.""" + if not nodes: + return + + image_nodes: List[ImageNode] = [] + text_nodes: List[BaseNode] = [] + + for node in nodes: + if isinstance(node, ImageNode): + image_nodes.append(node) + if node.text: + text_nodes.append(node) + + # embed all nodes as text - incclude image nodes that have text attached + text_nodes = self._get_node_with_embedding( + text_nodes, show_progress, is_image=False + ) + new_text_ids = self.storage_context.vector_stores[DEFAULT_VECTOR_STORE].add( + text_nodes, **insert_kwargs + ) + + # embed image nodes as images directly + image_nodes = self._get_node_with_embedding( + image_nodes, show_progress, is_image=True + ) + new_img_ids = self.storage_context.vector_stores[self.image_namespace].add( + image_nodes, **insert_kwargs + ) + + # TODO: can vector stores just store images directly? Then no need for docstore + # Maybe a fix for later + + # if the vector store doesn't store text, we need to add the nodes to the + # index struct and document store + all_nodes = text_nodes + image_nodes + all_new_ids = new_text_ids + new_img_ids + if not self._vector_store.stores_text or self._store_nodes_override: + for node, new_id in zip(all_nodes, all_new_ids): + # NOTE: remove embedding from node to avoid duplication + node_without_embedding = node.copy() + node_without_embedding.embedding = None + + index_struct.add_node(node_without_embedding, text_id=new_id) + self._docstore.add_documents( + [node_without_embedding], allow_update=True + ) + else: + # NOTE: if the vector store keeps text, + # we only need to add image and index nodes + for node, new_id in zip(all_nodes, all_new_ids): + if isinstance(node, (ImageNode, IndexNode)): + # NOTE: remove embedding from node to avoid duplication + node_without_embedding = node.copy() + node_without_embedding.embedding = None + + index_struct.add_node(node_without_embedding, text_id=new_id) + self._docstore.add_documents( + [node_without_embedding], allow_update=True + ) + + def delete_ref_doc( + self, ref_doc_id: str, delete_from_docstore: bool = False, **delete_kwargs: Any + ) -> None: + """Delete a document and it's nodes by using ref_doc_id.""" + # delete from all vector stores + for vector_store in self._storage_context.vector_stores.values(): + vector_store.delete(ref_doc_id) + + # delete from index_struct only if needed + if not self._vector_store.stores_text or self._store_nodes_override: + ref_doc_info = self._docstore.get_ref_doc_info(ref_doc_id) + if ref_doc_info is not None: + for node_id in ref_doc_info.node_ids: + self._index_struct.delete(node_id) + self._vector_store.delete(node_id) + + # delete from docstore only if needed + if ( + not self._vector_store.stores_text or self._store_nodes_override + ) and delete_from_docstore: + self._docstore.delete_ref_doc(ref_doc_id, raise_error=False) + + self._storage_context.index_store.add_index_struct(self._index_struct) diff --git a/llama_index/indices/registry.py b/llama_index/indices/registry.py index 924ad13527c88c681772b304b0c4756e3ddaea0d..6078b22b2211f8173f9b78c0c349b7f3e53f8418 100644 --- a/llama_index/indices/registry.py +++ b/llama_index/indices/registry.py @@ -9,6 +9,7 @@ from llama_index.indices.empty.base import EmptyIndex from llama_index.indices.keyword_table.base import KeywordTableIndex from llama_index.indices.knowledge_graph.base import KnowledgeGraphIndex from llama_index.indices.list.base import SummaryIndex +from llama_index.indices.multi_modal import MultiModalVectorStoreIndex from llama_index.indices.struct_store.pandas import PandasIndex from llama_index.indices.struct_store.sql import SQLStructStoreIndex from llama_index.indices.tree.base import TreeIndex @@ -24,4 +25,5 @@ INDEX_STRUCT_TYPE_TO_INDEX_CLASS: Dict[IndexStructType, Type[BaseIndex]] = { IndexStructType.KG: KnowledgeGraphIndex, IndexStructType.EMPTY: EmptyIndex, IndexStructType.DOCUMENT_SUMMARY: DocumentSummaryIndex, + IndexStructType.MULTIMODAL_VECTOR_STORE: MultiModalVectorStoreIndex, } diff --git a/llama_index/indices/utils.py b/llama_index/indices/utils.py index 80ef6a8fc54978b9d502d953b168962f8577b506..2ba3f2ae12153ba25f516a09bf47bc4e3fb1300d 100644 --- a/llama_index/indices/utils.py +++ b/llama_index/indices/utils.py @@ -4,7 +4,8 @@ import re from typing import Dict, List, Optional, Sequence, Set, Tuple from llama_index.embeddings.base import BaseEmbedding -from llama_index.schema import BaseNode, MetadataMode +from llama_index.embeddings.mutli_modal_base import MultiModalEmbedding +from llama_index.schema import BaseNode, ImageNode, MetadataMode from llama_index.utils import globals_helper, truncate_text from llama_index.vector_stores.types import VectorStoreQueryResult @@ -143,6 +144,42 @@ def embed_nodes( return id_to_embed_map +def embed_image_nodes( + nodes: Sequence[ImageNode], + embed_model: MultiModalEmbedding, + show_progress: bool = False, +) -> Dict[str, List[float]]: + """Get image embeddings of the given nodes, run image embedding model if necessary. + + Args: + nodes (Sequence[ImageNode]): The nodes to embed. + embed_model (MultiModalEmbedding): The embedding model to use. + show_progress (bool): Whether to show progress bar. + + Returns: + Dict[str, List[float]]: A map from node id to embedding. + """ + id_to_embed_map: Dict[str, List[float]] = {} + + images_to_embed = [] + ids_to_embed = [] + for node in nodes: + if node.embedding is None: + ids_to_embed.append(node.node_id) + images_to_embed.append(node.resolve_image()) + else: + id_to_embed_map[node.node_id] = node.embedding + + new_embeddings = embed_model.get_image_embedding_batch( + images_to_embed, show_progress=show_progress + ) + + for new_id, img_embedding in zip(ids_to_embed, new_embeddings): + id_to_embed_map[new_id] = img_embedding + + return id_to_embed_map + + async def async_embed_nodes( nodes: Sequence[BaseNode], embed_model: BaseEmbedding, show_progress: bool = False ) -> Dict[str, List[float]]: @@ -175,3 +212,39 @@ async def async_embed_nodes( id_to_embed_map[new_id] = text_embedding return id_to_embed_map + + +async def async_embed_image_nodes( + nodes: Sequence[ImageNode], + embed_model: MultiModalEmbedding, + show_progress: bool = False, +) -> Dict[str, List[float]]: + """Get image embeddings of the given nodes, run image embedding model if necessary. + + Args: + nodes (Sequence[ImageNode]): The nodes to embed. + embed_model (MultiModalEmbedding): The embedding model to use. + show_progress (bool): Whether to show progress bar. + + Returns: + Dict[str, List[float]]: A map from node id to embedding. + """ + id_to_embed_map: Dict[str, List[float]] = {} + + images_to_embed = [] + ids_to_embed = [] + for node in nodes: + if node.embedding is None: + ids_to_embed.append(node.node_id) + images_to_embed.append(node.resolve_image()) + else: + id_to_embed_map[node.node_id] = node.embedding + + new_embeddings = await embed_model.aget_image_embedding_batch( + images_to_embed, show_progress=show_progress + ) + + for new_id, img_embedding in zip(ids_to_embed, new_embeddings): + id_to_embed_map[new_id] = img_embedding + + return id_to_embed_map diff --git a/llama_index/multi_modal_llms/generic_utils.py b/llama_index/multi_modal_llms/generic_utils.py index dfdca8b31981b691f9726c4e3c1bc5f76a3214e8..0b60677ad05fa7aa5a213a7b51ef68da7488e6b6 100644 --- a/llama_index/multi_modal_llms/generic_utils.py +++ b/llama_index/multi_modal_llms/generic_utils.py @@ -7,8 +7,7 @@ def load_image_urls(image_urls: list[str]) -> list[ImageDocument]: # load remote image urls into image documents image_documents = [] for i in range(len(image_urls)): - new_image_document = ImageDocument() - new_image_document.metadata["image_url"] = image_urls[i] + new_image_document = ImageDocument(image_url=image_urls[i]) image_documents.append(new_image_document) return image_documents diff --git a/llama_index/node_parser/node_utils.py b/llama_index/node_parser/node_utils.py index 87acd0a2b0d0f8b1a4a1f0e2bd0caaf1f23dbebe..e8bd2308c3eb5e76f53f3c7642ca91e4d67f6264 100644 --- a/llama_index/node_parser/node_utils.py +++ b/llama_index/node_parser/node_utils.py @@ -43,6 +43,13 @@ def build_nodes_from_splits( embedding=document.embedding, metadata=node_metadata, image=document.image, + image_path=document.image_path, + image_url=document.image_url, + excluded_embed_metadata_keys=document.excluded_embed_metadata_keys, + excluded_llm_metadata_keys=document.excluded_llm_metadata_keys, + metadata_seperator=document.metadata_seperator, + metadata_template=document.metadata_template, + text_template=document.text_template, relationships={NodeRelationship.SOURCE: ref_doc.as_related_node_info()}, ) nodes.append(image_node) # type: ignore @@ -75,6 +82,28 @@ def build_nodes_from_splits( else: raise ValueError(f"Unknown document type: {type(document)}") + # account for pure image documents + if len(text_splits) == 0 and isinstance(document, ImageDocument): + node_metadata = {} + if include_metadata: + node_metadata = document.metadata + + image_node = ImageNode( + text="", + embedding=document.embedding, + metadata=node_metadata, + image=document.image, + image_path=document.image_path, + image_url=document.image_url, + excluded_embed_metadata_keys=document.excluded_embed_metadata_keys, + excluded_llm_metadata_keys=document.excluded_llm_metadata_keys, + metadata_seperator=document.metadata_seperator, + metadata_template=document.metadata_template, + text_template=document.text_template, + relationships={NodeRelationship.SOURCE: ref_doc.as_related_node_info()}, + ) + nodes.append(image_node) # type: ignore + # if include_prev_next_rel, then add prev/next relationships if include_prev_next_rel: for i, node in enumerate(nodes): diff --git a/llama_index/readers/file/image_caption_reader.py b/llama_index/readers/file/image_caption_reader.py index ce2f4c814499a30ede65fb8af86ba4900b0230e3..39b415494e6565fd18d447970819ac7b285082a3 100644 --- a/llama_index/readers/file/image_caption_reader.py +++ b/llama_index/readers/file/image_caption_reader.py @@ -92,6 +92,7 @@ class ImageCaptionReader(BaseReader): ImageDocument( text=text_str, image=image_str, + image_path=str(file), metadata=extra_info or {}, ) ] diff --git a/llama_index/readers/file/image_reader.py b/llama_index/readers/file/image_reader.py index 6137471217be04d601357061c8e9085b4a7f1a98..9997edd5e0d49453b425a10e55e4f24150bbbba8 100644 --- a/llama_index/readers/file/image_reader.py +++ b/llama_index/readers/file/image_reader.py @@ -109,5 +109,10 @@ class ImageReader(BaseReader): text_str = re.sub(r"<.*?>", "", sequence, count=1).strip() return [ - ImageDocument(text=text_str, image=image_str, metadata=extra_info or {}) + ImageDocument( + text=text_str, + image=image_str, + image_path=str(file), + metadata=extra_info or {}, + ) ] diff --git a/llama_index/readers/file/image_vision_llm_reader.py b/llama_index/readers/file/image_vision_llm_reader.py index bb772c5cc99b25d021fc54629cc5338907814306..02991e3f94d90532f0986ad903abe9373f48c097 100644 --- a/llama_index/readers/file/image_vision_llm_reader.py +++ b/llama_index/readers/file/image_vision_llm_reader.py @@ -87,6 +87,7 @@ class ImageVisionLLMReader(BaseReader): ImageDocument( text=text_str, image=image_str, + image_path=str(file), metadata=extra_info or {}, ) ] diff --git a/llama_index/schema.py b/llama_index/schema.py index a8142a7096e65d99eab453e3e366d65a92fad197..7df70594354b455b6c304bd16bd06f7a1d91aea5 100644 --- a/llama_index/schema.py +++ b/llama_index/schema.py @@ -5,6 +5,7 @@ import uuid from abc import abstractmethod from enum import Enum, auto from hashlib import sha256 +from io import BytesIO from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union from typing_extensions import Self @@ -25,6 +26,8 @@ DEFAULT_METADATA_TMPL = "{key}: {value}" TRUNCATE_LENGTH = 350 WRAP_WIDTH = 70 +ImageType = Union[str, BytesIO] + class BaseComponent(BaseModel): """Base component object to capture class names.""" @@ -387,6 +390,8 @@ class ImageNode(TextNode): # TODO: store reference instead of actual image # base64 encoded image str image: Optional[str] = None + image_path: Optional[str] = None + image_url: Optional[str] = None @classmethod def get_type(cls) -> str: @@ -396,6 +401,21 @@ class ImageNode(TextNode): def class_name(cls) -> str: return "ImageNode" + def resolve_image(self) -> ImageType: + """Resolve an image such that PIL can read it.""" + if self.image is not None: + return self.image + elif self.image_path is not None: + return self.image_path + elif self.image_url is not None: + # load image from URL + import requests + + response = requests.get(self.image_url) + return BytesIO(response.content) + else: + raise ValueError("No image found in node.") + class IndexNode(TextNode): """Node with reference to any object. @@ -613,12 +633,9 @@ class Document(TextNode): return "Document" -class ImageDocument(Document): +class ImageDocument(Document, ImageNode): """Data document containing an image.""" - # base64 encoded image str - image: Optional[str] = None - @classmethod def class_name(cls) -> str: return "ImageDocument" diff --git a/llama_index/storage/storage_context.py b/llama_index/storage/storage_context.py index a7cfe4ef7e88d183169e1be01bd37cf81022adb9..13236f752764afbbfa25cbd2f0d19abc95f80af5 100644 --- a/llama_index/storage/storage_context.py +++ b/llama_index/storage/storage_context.py @@ -32,6 +32,7 @@ from llama_index.vector_stores.simple import ( from llama_index.vector_stores.types import VectorStore DEFAULT_PERSIST_DIR = "./storage" +IMAGE_STORE_FNAME = "image_store.json" @dataclass @@ -116,6 +117,7 @@ class StorageContext: docstore_fname: str = DOCSTORE_FNAME, index_store_fname: str = INDEX_STORE_FNAME, vector_store_fname: str = VECTOR_STORE_FNAME, + image_store_fname: str = IMAGE_STORE_FNAME, graph_store_fname: str = GRAPH_STORE_FNAME, fs: Optional[fsspec.AbstractFileSystem] = None, ) -> None: diff --git a/poetry.lock b/poetry.lock index 704b1d3677fdd63cfcdef28682c07e88ec3882a4..79297309da2c97ad6d366a627ff83eb9e50eff39 100644 --- a/poetry.lock +++ b/poetry.lock @@ -2879,6 +2879,16 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, + {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -4728,6 +4738,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -4735,8 +4746,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -4753,6 +4771,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -4760,6 +4779,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"},