From 0e39c36990c156a82702c2039e7d6cf433b36892 Mon Sep 17 00:00:00 2001
From: Kirushikesh DB <49152921+Kirushikesh@users.noreply.github.com>
Date: Fri, 1 Mar 2024 22:44:31 +0530
Subject: [PATCH] Updated the simple fusion to handle duplicate nodes (#11542)

---
 .../llama_index/core/retrievers/fusion_retriever.py         | 6 +++++-
 .../llama_index/legacy/retrievers/fusion_retriever.py       | 6 +++++-
 2 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/llama-index-core/llama_index/core/retrievers/fusion_retriever.py b/llama-index-core/llama_index/core/retrievers/fusion_retriever.py
index d5de8373b1..d1425fcc9a 100644
--- a/llama-index-core/llama_index/core/retrievers/fusion_retriever.py
+++ b/llama-index-core/llama_index/core/retrievers/fusion_retriever.py
@@ -130,7 +130,11 @@ class QueryFusionRetriever(BaseRetriever):
         for nodes_with_scores in results.values():
             for node_with_score in nodes_with_scores:
                 text = node_with_score.node.get_content()
-                all_nodes[text] = node_with_score
+                # Keep one entry per distinct text, remembering its best score.
+                # A None score counts as 0.0, matching the sort key below.
+                kept = all_nodes.setdefault(text, node_with_score)
+                if kept is not node_with_score:
+                    kept.score = max(kept.score or 0.0, node_with_score.score or 0.0)
 
         return sorted(all_nodes.values(), key=lambda x: x.score or 0.0, reverse=True)
 
diff --git a/llama-index-legacy/llama_index/legacy/retrievers/fusion_retriever.py b/llama-index-legacy/llama_index/legacy/retrievers/fusion_retriever.py
index 8b9f285930..caa900bacf 100644
--- a/llama-index-legacy/llama_index/legacy/retrievers/fusion_retriever.py
+++ b/llama-index-legacy/llama_index/legacy/retrievers/fusion_retriever.py
@@ -127,7 +127,11 @@ class QueryFusionRetriever(BaseRetriever):
         for nodes_with_scores in results.values():
             for node_with_score in nodes_with_scores:
                 text = node_with_score.node.get_content()
-                all_nodes[text] = node_with_score
+                # Keep one entry per distinct text, remembering its best score.
+                # A None score counts as 0.0, matching the sort key below.
+                kept = all_nodes.setdefault(text, node_with_score)
+                if kept is not node_with_score:
+                    kept.score = max(kept.score or 0.0, node_with_score.score or 0.0)
 
         return sorted(all_nodes.values(), key=lambda x: x.score or 0.0, reverse=True)
 
-- 
GitLab