Skip to content
Snippets Groups Projects
Unverified Commit debaaa91 authored by Haotian Zhang's avatar Haotian Zhang Committed by GitHub
Browse files

Dedup logic for recursive retriever nodes (#10597)

parent d2faa5d7
No related branches found
No related tags found
No related merge requests found
......@@ -64,6 +64,21 @@ class RecursiveRetriever(BaseRetriever):
self._query_response_tmpl = query_response_tmpl or DEFAULT_QUERY_RESPONSE_TMPL
super().__init__(callback_manager, verbose=verbose)
def _deduplicate_nodes(
self, nodes_with_score: List[NodeWithScore]
) -> List[NodeWithScore]:
"""Deduplicate nodes according to node id.
Keep the node with the highest score/first returned.
"""
node_ids = set()
deduplicate_nodes = []
for node_with_score in nodes_with_score:
node = node_with_score.node
if node.id_ not in node_ids:
node_ids.add(node.id_)
deduplicate_nodes.append(node_with_score)
return deduplicate_nodes
def _query_retrieved_nodes(
self, query_bundle: QueryBundle, nodes_with_score: List[NodeWithScore]
) -> Tuple[List[NodeWithScore], List[NodeWithScore]]:
......@@ -116,6 +131,9 @@ class RecursiveRetriever(BaseRetriever):
nodes_to_add.extend(cur_retrieved_nodes)
additional_nodes.extend(cur_additional_nodes)
# dedup nodes in case some nodes could be retrieved from multiple sources
nodes_to_add = self._deduplicate_nodes(nodes_to_add)
additional_nodes = self._deduplicate_nodes(additional_nodes)
return nodes_to_add, additional_nodes
def _get_object(self, query_id: str) -> RQN_TYPE:
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment