From 3553f56f7698ddd65a1e693f4ceb0ba70fd67a25 Mon Sep 17 00:00:00 2001
From: Javier Torres <javierandrestorresreyes@gmail.com>
Date: Wed, 10 Apr 2024 12:16:33 -0500
Subject: [PATCH] Refactor kvdocstore delete methods (#12681)

* refactor docstore delete

* sanity check

* sanity check
---
 .../core/storage/docstore/keyval_docstore.py  | 136 ++++++++++--------
 .../storage/docstore/test_simple_docstore.py  |  93 +++++++++++-
 2 files changed, 166 insertions(+), 63 deletions(-)

diff --git a/llama-index-core/llama_index/core/storage/docstore/keyval_docstore.py b/llama-index-core/llama_index/core/storage/docstore/keyval_docstore.py
index 1972e26c92..f121340eb8 100644
--- a/llama-index-core/llama_index/core/storage/docstore/keyval_docstore.py
+++ b/llama-index-core/llama_index/core/storage/docstore/keyval_docstore.py
@@ -374,90 +374,94 @@ class KVDocumentStore(BaseDocumentStore):
         """Check if document exists."""
         return await self._kvstore.aget(doc_id, self._node_collection) is not None
 
-    def _remove_ref_doc_node(self, doc_id: str) -> None:
-        """Helper function to remove node doc_id from ref_doc_collection."""
+    def _get_ref_doc_id(self, doc_id: str) -> Optional[str]:
+        """Helper function to get ref_doc_info for a given doc_id."""
         metadata = self._kvstore.get(doc_id, collection=self._metadata_collection)
         if metadata is None:
-            return
+            return None
+
+        return metadata.get("ref_doc_id", None)
+
+    async def _aget_ref_doc_id(self, doc_id: str) -> Optional[str]:
+        """Helper function to get ref_doc_info for a given doc_id."""
+        metadata = await self._kvstore.aget(
+            doc_id, collection=self._metadata_collection
+        )
+        if metadata is None:
+            return None
 
-        ref_doc_id = metadata.get("ref_doc_id", None)
+        return metadata.get("ref_doc_id", None)
 
+    def _remove_from_ref_doc_node(self, doc_id: str) -> None:
+        """
+        Helper function to remove node doc_id from ref_doc_collection.
+        If ref_doc has no more doc_ids, delete it from the collection.
+        """
+        ref_doc_id = self._get_ref_doc_id(doc_id)
         if ref_doc_id is None:
             return
-
         ref_doc_info = self._kvstore.get(
             ref_doc_id, collection=self._ref_doc_collection
         )
-
-        if ref_doc_info is not None:
-            ref_doc_obj = RefDocInfo(**ref_doc_info)
-
+        if ref_doc_info is None:
+            return
+        ref_doc_obj = RefDocInfo(**ref_doc_info)
+        if doc_id in ref_doc_obj.node_ids:  # sanity check
             ref_doc_obj.node_ids.remove(doc_id)
-
-            # delete ref_doc from collection if it has no more doc_ids
-            if len(ref_doc_obj.node_ids) > 0:
-                self._kvstore.put(
-                    ref_doc_id,
-                    ref_doc_obj.to_dict(),
-                    collection=self._ref_doc_collection,
-                )
-
+        # delete ref_doc from collection if it has no more doc_ids
+        if len(ref_doc_obj.node_ids) > 0:
+            self._kvstore.put(
+                ref_doc_id,
+                ref_doc_obj.to_dict(),
+                collection=self._ref_doc_collection,
+            )
+        else:
             self._kvstore.delete(ref_doc_id, collection=self._metadata_collection)
+            self._kvstore.delete(ref_doc_id, collection=self._node_collection)
+            self._kvstore.delete(ref_doc_id, collection=self._ref_doc_collection)
 
-    async def _aremove_ref_doc_node(self, doc_id: str) -> None:
-        """Helper function to remove node doc_id from ref_doc_collection."""
-        metadata = await self._kvstore.aget(
-            doc_id, collection=self._metadata_collection
-        )
-        if metadata is None:
-            return
-
-        ref_doc_id = metadata.get("ref_doc_id", None)
-
+    async def _aremove_from_ref_doc_node(self, doc_id: str) -> None:
+        """
+        Helper function to remove node doc_id from ref_doc_collection.
+        If ref_doc has no more doc_ids, delete it from the collection.
+        """
+        ref_doc_id = await self._aget_ref_doc_id(doc_id)
         if ref_doc_id is None:
             return
-
         ref_doc_info = await self._kvstore.aget(
             ref_doc_id, collection=self._ref_doc_collection
         )
-
-        if ref_doc_info is not None:
-            ref_doc_obj = RefDocInfo(**ref_doc_info)
-
+        if ref_doc_info is None:
+            return
+        ref_doc_obj = RefDocInfo(**ref_doc_info)
+        if doc_id in ref_doc_obj.node_ids:  # sanity check
             ref_doc_obj.node_ids.remove(doc_id)
-
-            # delete ref_doc from collection if it has no more doc_ids
-            if len(ref_doc_obj.node_ids) > 0:
-                await self._kvstore.aput(
-                    ref_doc_id,
-                    ref_doc_obj.to_dict(),
-                    collection=self._ref_doc_collection,
-                )
-
+        # delete ref_doc from collection if it has no more doc_ids
+        if len(ref_doc_obj.node_ids) > 0:
+            await self._kvstore.aput(
+                ref_doc_id,
+                ref_doc_obj.to_dict(),
+                collection=self._ref_doc_collection,
+            )
+        else:
             await self._kvstore.adelete(
                 ref_doc_id, collection=self._metadata_collection
             )
+            await self._kvstore.adelete(ref_doc_id, collection=self._node_collection)
+            await self._kvstore.adelete(ref_doc_id, collection=self._ref_doc_collection)
 
-    def delete_document(
-        self, doc_id: str, raise_error: bool = True, remove_ref_doc_node: bool = True
-    ) -> None:
+    def delete_document(self, doc_id: str, raise_error: bool = True) -> None:
         """Delete a document from the store."""
-        if remove_ref_doc_node:
-            self._remove_ref_doc_node(doc_id)
-
+        self._remove_from_ref_doc_node(doc_id)
         delete_success = self._kvstore.delete(doc_id, collection=self._node_collection)
         _ = self._kvstore.delete(doc_id, collection=self._metadata_collection)
 
         if not delete_success and raise_error:
             raise ValueError(f"doc_id {doc_id} not found.")
 
-    async def adelete_document(
-        self, doc_id: str, raise_error: bool = True, remove_ref_doc_node: bool = True
-    ) -> None:
+    async def adelete_document(self, doc_id: str, raise_error: bool = True) -> None:
         """Delete a document from the store."""
-        if remove_ref_doc_node:
-            await self._aremove_ref_doc_node(doc_id)
-
+        await self._aremove_from_ref_doc_node(doc_id)
         delete_success = await self._kvstore.adelete(
             doc_id, collection=self._node_collection
         )
@@ -475,11 +479,16 @@ class KVDocumentStore(BaseDocumentStore):
             else:
                 return
 
-        for doc_id in ref_doc_info.node_ids:
-            self.delete_document(doc_id, raise_error=False, remove_ref_doc_node=False)
-            self._kvstore.delete(doc_id, collection=self._metadata_collection)
+        original_node_ids = (
+            ref_doc_info.node_ids.copy()
+        )  # copy to avoid mutation during iteration
+        for doc_id in original_node_ids:
+            self.delete_document(doc_id, raise_error=False)
 
+        # Deleting all the nodes should already delete the ref_doc, but just to be sure
         self._kvstore.delete(ref_doc_id, collection=self._ref_doc_collection)
+        self._kvstore.delete(ref_doc_id, collection=self._metadata_collection)
+        self._kvstore.delete(ref_doc_id, collection=self._node_collection)
 
     async def adelete_ref_doc(self, ref_doc_id: str, raise_error: bool = True) -> None:
         """Delete a ref_doc and all it's associated nodes."""
@@ -490,13 +499,16 @@ class KVDocumentStore(BaseDocumentStore):
             else:
                 return
 
-        for doc_id in ref_doc_info.node_ids:
-            await self.adelete_document(
-                doc_id, raise_error=False, remove_ref_doc_node=False
-            )
-            await self._kvstore.adelete(doc_id, collection=self._metadata_collection)
+        original_node_ids = (
+            ref_doc_info.node_ids.copy()
+        )  # copy to avoid mutation during iteration
+        for doc_id in original_node_ids:
+            await self.adelete_document(doc_id, raise_error=False)
 
+        # Deleting all the nodes should already delete the ref_doc, but just to be sure
         await self._kvstore.adelete(ref_doc_id, collection=self._ref_doc_collection)
+        await self._kvstore.adelete(ref_doc_id, collection=self._metadata_collection)
+        await self._kvstore.adelete(ref_doc_id, collection=self._node_collection)
 
     def set_document_hash(self, doc_id: str, doc_hash: str) -> None:
         """Set the hash for a given doc_id."""
diff --git a/llama-index-core/tests/storage/docstore/test_simple_docstore.py b/llama-index-core/tests/storage/docstore/test_simple_docstore.py
index 0a1085f36d..4abe9e5801 100644
--- a/llama-index-core/tests/storage/docstore/test_simple_docstore.py
+++ b/llama-index-core/tests/storage/docstore/test_simple_docstore.py
@@ -4,7 +4,7 @@
 from pathlib import Path
 
 import pytest
-from llama_index.core.schema import Document, TextNode
+from llama_index.core.schema import Document, TextNode, NodeRelationship
 from llama_index.core.storage.docstore import SimpleDocumentStore
 from llama_index.core.storage.kvstore.simple_kvstore import SimpleKVStore
 
@@ -62,3 +62,94 @@ def test_docstore_dict() -> None:
     assert gd1 == doc
     gd2 = new_docstore.get_document("d2")
     assert gd2 == node
+
+
+def test_docstore_delete_document() -> None:
+    doc = Document(text="hello world", id_="d1", metadata={"foo": "bar"})
+    node = TextNode(text="my node", id_="d2", metadata={"node": "info"})
+
+    docstore = SimpleDocumentStore()
+    docstore.add_documents([doc, node])
+    docstore.delete_document("d1")
+
+    assert docstore._kvstore.get("d1", docstore._node_collection) is None
+    assert docstore._kvstore.get("d1", docstore._metadata_collection) is None
+    assert docstore._kvstore.get("d1", docstore._ref_doc_collection) is None
+
+    assert docstore._kvstore.get("d2", docstore._node_collection) is not None
+    assert docstore._kvstore.get("d2", docstore._metadata_collection) is not None
+
+
+def test_docstore_delete_ref_doc() -> None:
+    ref_doc = Document(text="hello world", id_="d1", metadata={"foo": "bar"})
+    doc = Document(text="hello world", id_="d2", metadata={"foo": "bar"})
+    doc.relationships[NodeRelationship.SOURCE] = ref_doc.as_related_node_info()
+    node = TextNode(text="my node", id_="d3", metadata={"node": "info"})
+    node.relationships[NodeRelationship.SOURCE] = ref_doc.as_related_node_info()
+
+    docstore = SimpleDocumentStore()
+    docstore.add_documents([ref_doc, doc, node])
+    docstore.delete_ref_doc("d1")
+
+    assert docstore._kvstore.get("d1", docstore._node_collection) is None
+    assert docstore._kvstore.get("d1", docstore._metadata_collection) is None
+    assert docstore._kvstore.get("d1", docstore._ref_doc_collection) is None
+    assert docstore._kvstore.get("d2", docstore._node_collection) is None
+    assert docstore._kvstore.get("d2", docstore._metadata_collection) is None
+    assert docstore._kvstore.get("d2", docstore._ref_doc_collection) is None
+    assert docstore._kvstore.get("d3", docstore._node_collection) is None
+    assert docstore._kvstore.get("d3", docstore._metadata_collection) is None
+    assert docstore._kvstore.get("d3", docstore._ref_doc_collection) is None
+
+
+def test_docstore_delete_ref_doc_not_in_docstore() -> None:
+    ref_doc = Document(text="hello world", id_="d1", metadata={"foo": "bar"})
+    doc = Document(text="hello world", id_="d2", metadata={"foo": "bar"})
+    doc.relationships[NodeRelationship.SOURCE] = ref_doc.as_related_node_info()
+    node = TextNode(text="my node", id_="d3", metadata={"node": "info"})
+    node.relationships[NodeRelationship.SOURCE] = ref_doc.as_related_node_info()
+
+    docstore = SimpleDocumentStore()
+    docstore.add_documents([doc, node])
+    assert docstore._kvstore.get("d1", docstore._ref_doc_collection) is not None
+
+    docstore.delete_ref_doc("d1")
+
+    assert docstore._kvstore.get("d1", docstore._node_collection) is None
+    assert docstore._kvstore.get("d1", docstore._metadata_collection) is None
+    assert docstore._kvstore.get("d1", docstore._ref_doc_collection) is None
+    assert docstore._kvstore.get("d2", docstore._node_collection) is None
+    assert docstore._kvstore.get("d2", docstore._metadata_collection) is None
+    assert docstore._kvstore.get("d2", docstore._ref_doc_collection) is None
+    assert docstore._kvstore.get("d3", docstore._node_collection) is None
+    assert docstore._kvstore.get("d3", docstore._metadata_collection) is None
+    assert docstore._kvstore.get("d3", docstore._ref_doc_collection) is None
+
+
+def test_docstore_delete_all_ref_doc_nodes() -> None:
+    ref_doc = Document(text="hello world", id_="d1", metadata={"foo": "bar"})
+    doc = Document(text="hello world", id_="d2", metadata={"foo": "bar"})
+    doc.relationships[NodeRelationship.SOURCE] = ref_doc.as_related_node_info()
+    node = TextNode(text="my node", id_="d3", metadata={"node": "info"})
+    node.relationships[NodeRelationship.SOURCE] = ref_doc.as_related_node_info()
+
+    docstore = SimpleDocumentStore()
+    docstore.add_documents([ref_doc, doc, node])
+
+    assert docstore._kvstore.get("d1", docstore._ref_doc_collection)["node_ids"] == [
+        "d2",
+        "d3",
+    ]
+
+    docstore.delete_document("d2")
+    assert docstore._kvstore.get("d1", docstore._node_collection) is not None
+    assert docstore._kvstore.get("d1", docstore._metadata_collection) is not None
+    assert docstore._kvstore.get("d1", docstore._ref_doc_collection) is not None
+    assert docstore._kvstore.get("d1", docstore._ref_doc_collection)["node_ids"] == [
+        "d3"
+    ]
+
+    docstore.delete_document("d3")
+    assert docstore._kvstore.get("d1", docstore._node_collection) is None
+    assert docstore._kvstore.get("d1", docstore._metadata_collection) is None
+    assert docstore._kvstore.get("d1", docstore._ref_doc_collection) is None
-- 
GitLab