From 3546490384634be72d1a9293dde3ab50931da2d3 Mon Sep 17 00:00:00 2001
From: Logan <logan.markewich@live.com>
Date: Fri, 16 Feb 2024 17:16:54 -0600
Subject: [PATCH] wip improved object retrieval (#10513)

---
 llama-index-core/llama_index/core/schema.py   | 44 ++++++++++++++++++-
 .../llama_index/core/vector_stores/utils.py   |  6 +--
 .../llms/llama-index-llms-litellm/BUILD       |  4 ++
 .../llama-index-llms-litellm/pyproject.toml   |  2 +-
 .../llms/llama-index-llms-litellm/tests/BUILD |  4 +-
 .../llama_index/legacy/core/base_retriever.py |  3 +-
 .../llama_index/legacy/indices/base.py        |  6 ++-
 7 files changed, 61 insertions(+), 8 deletions(-)

diff --git a/llama-index-core/llama_index/core/schema.py b/llama-index-core/llama_index/core/schema.py
index 9eac9bef74..4ecda5b947 100644
--- a/llama-index-core/llama_index/core/schema.py
+++ b/llama-index-core/llama_index/core/schema.py
@@ -1,4 +1,5 @@
 """Base schema for data structures."""
+
 import json
 import textwrap
 import uuid
@@ -499,7 +500,26 @@ class IndexNode(TextNode):
     """
 
     index_id: str
-    obj: Any = Field(exclude=True)
+    obj: Any = None
+
+    def dict(self, **kwargs: Any) -> Dict[str, Any]:
+        from llama_index.core.storage.docstore.utils import doc_to_json
+
+        data = super().dict(**kwargs)
+
+        try:
+            if self.obj is None:
+                data["obj"] = None
+            elif isinstance(self.obj, BaseNode):
+                data["obj"] = doc_to_json(self.obj)
+            elif isinstance(self.obj, BaseModel):
+                data["obj"] = self.obj.dict()
+            else:
+                data["obj"] = json.dumps(self.obj)
+        except Exception:
+            raise ValueError("IndexNode obj is not serializable: " + str(self.obj))
+
+        return data
 
     @classmethod
     def from_text_node(
@@ -514,6 +534,28 @@ class IndexNode(TextNode):
             index_id=index_id,
         )
 
+    # TODO: return type here not supported by current mypy version
+    @classmethod
+    def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self:  # type: ignore
+        output = super().from_dict(data, **kwargs)
+
+        obj = data.get("obj", None)
+        parsed_obj = None
+        if isinstance(obj, str):
+            parsed_obj = TextNode(text=obj)
+        elif isinstance(obj, dict):
+            from llama_index.core.storage.docstore.utils import json_to_doc
+
+            # check if it's a node, else assume it is stringable
+            try:
+                parsed_obj = json_to_doc(obj)
+            except Exception:
+                parsed_obj = TextNode(text=str(obj))
+
+        output.obj = parsed_obj
+
+        return output
+
     @classmethod
     def get_type(cls) -> str:
         return ObjectType.INDEX
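
A minimal sketch of the round-trip the `schema.py` change above enables, assuming the `llama_index.core.schema` import path used in this patch. `IndexNode.dict()` now serializes a node stored in `obj` via `doc_to_json` instead of excluding it, and the new `from_dict()` rebuilds it, falling back to `TextNode(text=str(obj))` for non-node payloads:

```python
from llama_index.core.schema import IndexNode, TextNode

# Wrap a retrievable TextNode behind an IndexNode pointer.
inner = TextNode(text="hello world")
index_node = IndexNode(text="pointer to inner", index_id="inner-1", obj=inner)

# dict() now emits obj as a doc_to_json payload instead of dropping it.
data = index_node.dict()
assert isinstance(data["obj"], dict)

# from_dict() parses the payload back into a node via json_to_doc,
# falling back to TextNode(text=str(obj)) for unrecognized payloads.
restored = IndexNode.from_dict(data)
assert isinstance(restored.obj, TextNode)
assert restored.obj.text == "hello world"
```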
diff --git a/llama-index-core/llama_index/core/vector_stores/utils.py b/llama-index-core/llama_index/core/vector_stores/utils.py
index 13f04531c4..66c42570be 100644
--- a/llama-index-core/llama_index/core/vector_stores/utils.py
+++ b/llama-index-core/llama_index/core/vector_stores/utils.py
@@ -71,11 +71,11 @@ def metadata_dict_to_node(metadata: dict, text: Optional[str] = None) -> BaseNod
 
     node: BaseNode
     if node_type == IndexNode.class_name():
-        node = IndexNode.parse_raw(node_json)
+        node = IndexNode.from_json(node_json)
     elif node_type == ImageNode.class_name():
-        node = ImageNode.parse_raw(node_json)
+        node = ImageNode.from_json(node_json)
     else:
-        node = TextNode.parse_raw(node_json)
+        node = TextNode.from_json(node_json)
 
     if text is not None:
         node.set_content(text)
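
For context on the `parse_raw` → `from_json` swap above: `parse_raw` is plain pydantic parsing, while `from_json` routes through `from_dict` and therefore picks up the new `IndexNode.obj` handling. A rough sketch, assuming the same core schema import path:

```python
import json

from llama_index.core.schema import IndexNode, TextNode

node = IndexNode(text="ref", index_id="doc-1", obj=TextNode(text="payload"))
node_json = json.dumps(node.dict())

# from_json() goes through IndexNode.from_dict(), so the serialized obj
# payload comes back as a node rather than as a raw dict.
restored = IndexNode.from_json(node_json)
assert isinstance(restored.obj, TextNode)
```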
diff --git a/llama-index-integrations/llms/llama-index-llms-litellm/BUILD b/llama-index-integrations/llms/llama-index-llms-litellm/BUILD
index 0896ca890d..a8f4940ed6 100644
--- a/llama-index-integrations/llms/llama-index-llms-litellm/BUILD
+++ b/llama-index-integrations/llms/llama-index-llms-litellm/BUILD
@@ -1,3 +1,7 @@
 poetry_requirements(
     name="poetry",
 )
+
+python_sources(
+    interpreter_constraints=["==3.9.*", "==3.10.*"],
+)
diff --git a/llama-index-integrations/llms/llama-index-llms-litellm/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-litellm/pyproject.toml
index 8165327c55..fec49c9e66 100644
--- a/llama-index-integrations/llms/llama-index-llms-litellm/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-litellm/pyproject.toml
@@ -27,7 +27,7 @@ readme = "README.md"
 version = "0.1.1"
 
 [tool.poetry.dependencies]
-python = ">=3.8.1,<3.12"
+python = ">=3.9,<3.12"
 llama-index-core = "^0.10.1"
 litellm = "^1.18.13"
 
diff --git a/llama-index-integrations/llms/llama-index-llms-litellm/tests/BUILD b/llama-index-integrations/llms/llama-index-llms-litellm/tests/BUILD
index dabf212d7e..5cd7615688 100644
--- a/llama-index-integrations/llms/llama-index-llms-litellm/tests/BUILD
+++ b/llama-index-integrations/llms/llama-index-llms-litellm/tests/BUILD
@@ -1 +1,3 @@
-python_tests()
+python_tests(
+  interpreter_constraints=["==3.9.*", "==3.10.*"],
+)
diff --git a/llama-index-legacy/llama_index/legacy/core/base_retriever.py b/llama-index-legacy/llama_index/legacy/core/base_retriever.py
index 2b69aab840..9cdfd8b6ec 100644
--- a/llama-index-legacy/llama_index/legacy/core/base_retriever.py
+++ b/llama-index-legacy/llama_index/legacy/core/base_retriever.py
@@ -77,6 +77,7 @@ class BaseRetriever(ChainableMixin, PromptMixin):
                 f"Retrieving from object {obj.__class__.__name__} with query {query_bundle.query_str}\n",
                 color="llama_pink",
             )
+
         if isinstance(obj, NodeWithScore):
             return [obj]
         elif isinstance(obj, BaseNode):
@@ -149,7 +150,7 @@ class BaseRetriever(ChainableMixin, PromptMixin):
             node = n.node
             score = n.score or 1.0
             if isinstance(node, IndexNode):
-                obj = self.object_map.get(node.index_id, None)
+                obj = node.obj or self.object_map.get(node.index_id, None)
                 if obj is not None:
                     if self._verbose:
                         print_text(
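
The one-line change above flips the lookup order: an object carried directly on the `IndexNode` now wins, with the retriever's `object_map` as the fallback. A standalone illustration of that precedence (the helper below is hypothetical, not part of the library):

```python
from typing import Any, Dict, Optional

from llama_index.legacy.schema import IndexNode  # legacy import path assumed


def resolve_object(node: IndexNode, object_map: Dict[str, Any]) -> Optional[Any]:
    """Hypothetical helper mirroring the new lookup order in the patch."""
    # Prefer the object attached to the node itself (e.g. restored by
    # IndexNode.from_dict), then fall back to the retriever's object_map.
    return node.obj or object_map.get(node.index_id, None)
```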
diff --git a/llama-index-legacy/llama_index/legacy/indices/base.py b/llama-index-legacy/llama_index/legacy/indices/base.py
index 3482ec35d9..416b5f1881 100644
--- a/llama-index-legacy/llama_index/legacy/indices/base.py
+++ b/llama-index-legacy/llama_index/legacy/indices/base.py
@@ -67,7 +67,11 @@ class BaseIndex(Generic[IS], ABC):
         self._graph_store = self._storage_context.graph_store
 
         objects = objects or []
-        self._object_map = {obj.index_id: obj.obj for obj in objects}
+        self._object_map = {}
+        for obj in objects:
+            self._object_map[obj.index_id] = obj.obj
+            obj.obj = None  # clear the object to avoid serialization issues
+
         with self._service_context.callback_manager.as_trace("index_construction"):
             if index_struct is None:
                 nodes = nodes or []
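
The `BaseIndex` change above moves objects into the in-memory object map and strips them off the nodes so the nodes themselves stay serializable. A rough standalone sketch of that behavior (the function name and legacy import path are assumptions, not library code):

```python
from typing import Any, Dict, List

from llama_index.legacy.schema import IndexNode  # legacy import path assumed


def build_object_map(objects: List[IndexNode]) -> Dict[str, Any]:
    """Hypothetical sketch of the object_map handling in BaseIndex.__init__."""
    object_map: Dict[str, Any] = {}
    for obj in objects:
        object_map[obj.index_id] = obj.obj
        obj.obj = None  # clear the object to avoid serialization issues
    return object_map
```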
-- 
GitLab