From debaaa91119509347104973be477a1dc00f29b44 Mon Sep 17 00:00:00 2001 From: Haotian Zhang <socool.king@gmail.com> Date: Mon, 12 Feb 2024 15:47:15 -0500 Subject: [PATCH] Dedup logic for recursive retriever nodes (#10597) --- .../core/retrievers/recursive_retriever.py | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/llama-index-core/llama_index/core/retrievers/recursive_retriever.py b/llama-index-core/llama_index/core/retrievers/recursive_retriever.py index 55603df8d5..2210bbb142 100644 --- a/llama-index-core/llama_index/core/retrievers/recursive_retriever.py +++ b/llama-index-core/llama_index/core/retrievers/recursive_retriever.py @@ -64,6 +64,21 @@ class RecursiveRetriever(BaseRetriever): self._query_response_tmpl = query_response_tmpl or DEFAULT_QUERY_RESPONSE_TMPL super().__init__(callback_manager, verbose=verbose) + def _deduplicate_nodes( + self, nodes_with_score: List[NodeWithScore] + ) -> List[NodeWithScore]: + """Deduplicate nodes according to node id. + Keep the node with the highest score/first returned. + """ + node_ids = set() + deduplicate_nodes = [] + for node_with_score in nodes_with_score: + node = node_with_score.node + if node.id_ not in node_ids: + node_ids.add(node.id_) + deduplicate_nodes.append(node_with_score) + return deduplicate_nodes + def _query_retrieved_nodes( self, query_bundle: QueryBundle, nodes_with_score: List[NodeWithScore] ) -> Tuple[List[NodeWithScore], List[NodeWithScore]]: @@ -116,6 +131,9 @@ class RecursiveRetriever(BaseRetriever): nodes_to_add.extend(cur_retrieved_nodes) additional_nodes.extend(cur_additional_nodes) + # dedup nodes in case some nodes could be retrieved from multiple sources + nodes_to_add = self._deduplicate_nodes(nodes_to_add) + additional_nodes = self._deduplicate_nodes(additional_nodes) return nodes_to_add, additional_nodes def _get_object(self, query_id: str) -> RQN_TYPE: -- GitLab