From debaaa91119509347104973be477a1dc00f29b44 Mon Sep 17 00:00:00 2001
From: Haotian Zhang <socool.king@gmail.com>
Date: Mon, 12 Feb 2024 15:47:15 -0500
Subject: [PATCH] Dedup logic for recursive retriever nodes (#10597)

---
 .../core/retrievers/recursive_retriever.py     | 18 ++++++++++++++++++
 1 file changed, 18 insertions(+)

diff --git a/llama-index-core/llama_index/core/retrievers/recursive_retriever.py b/llama-index-core/llama_index/core/retrievers/recursive_retriever.py
index 55603df8d5..2210bbb142 100644
--- a/llama-index-core/llama_index/core/retrievers/recursive_retriever.py
+++ b/llama-index-core/llama_index/core/retrievers/recursive_retriever.py
@@ -64,6 +64,21 @@ class RecursiveRetriever(BaseRetriever):
         self._query_response_tmpl = query_response_tmpl or DEFAULT_QUERY_RESPONSE_TMPL
         super().__init__(callback_manager, verbose=verbose)
 
+    def _deduplicate_nodes(
+        self, nodes_with_score: List[NodeWithScore]
+    ) -> List[NodeWithScore]:
+        """Deduplicate nodes according to node id.
+        Keep the node with the highest score/first returned.
+        """
+        node_ids = set()
+        deduplicate_nodes = []
+        for node_with_score in nodes_with_score:
+            node = node_with_score.node
+            if node.id_ not in node_ids:
+                node_ids.add(node.id_)
+                deduplicate_nodes.append(node_with_score)
+        return deduplicate_nodes
+
     def _query_retrieved_nodes(
         self, query_bundle: QueryBundle, nodes_with_score: List[NodeWithScore]
     ) -> Tuple[List[NodeWithScore], List[NodeWithScore]]:
@@ -116,6 +131,9 @@ class RecursiveRetriever(BaseRetriever):
             nodes_to_add.extend(cur_retrieved_nodes)
             additional_nodes.extend(cur_additional_nodes)
 
+        # dedup nodes in case some nodes could be retrieved from multiple sources
+        nodes_to_add = self._deduplicate_nodes(nodes_to_add)
+        additional_nodes = self._deduplicate_nodes(additional_nodes)
         return nodes_to_add, additional_nodes
 
     def _get_object(self, query_id: str) -> RQN_TYPE:
-- 
GitLab