From 4fff8e7afd8b37bd6f99318d85800e5baba49044 Mon Sep 17 00:00:00 2001
From: Grant Doctor <gdoctor@users.noreply.github.com>
Date: Fri, 29 Mar 2024 10:01:12 -0700
Subject: [PATCH] QueryFusionRetriever sends incorrect QueryBundle type to it's
 retrievers (#12387)

---
 .../core/retrievers/fusion_retriever.py       | 27 +++++++++----------
 1 file changed, 13 insertions(+), 14 deletions(-)

diff --git a/llama-index-core/llama_index/core/retrievers/fusion_retriever.py b/llama-index-core/llama_index/core/retrievers/fusion_retriever.py
index 45fa1bae89..ae84a98e87 100644
--- a/llama-index-core/llama_index/core/retrievers/fusion_retriever.py
+++ b/llama-index-core/llama_index/core/retrievers/fusion_retriever.py
@@ -80,7 +80,7 @@ class QueryFusionRetriever(BaseRetriever):
                 PromptTemplate, prompts["query_gen_prompt"]
             ).template
 
-    def _get_queries(self, original_query: str) -> List[str]:
+    def _get_queries(self, original_query: str) -> List[QueryBundle]:
         prompt_str = self.query_gen_prompt.format(
             num_queries=self.num_queries - 1,
             query=original_query,
@@ -92,8 +92,9 @@ class QueryFusionRetriever(BaseRetriever):
         if self._verbose:
             queries_str = "\n".join(queries)
             print(f"Generated queries:\n{queries_str}")
+
         # The LLM often returns more queries than we asked for, so trim the list.
-        return response.text.split("\n")[: self.num_queries]
+        return [QueryBundle(q) for q in queries[: self.num_queries - 1]]
 
     def _reciprocal_rerank_fusion(
         self, results: Dict[Tuple[str, int], List[NodeWithScore]]
@@ -206,13 +207,13 @@ class QueryFusionRetriever(BaseRetriever):
         return sorted(all_nodes.values(), key=lambda x: x.score or 0.0, reverse=True)
 
     def _run_nested_async_queries(
-        self, queries: List[str]
+        self, queries: List[QueryBundle]
     ) -> Dict[Tuple[str, int], List[NodeWithScore]]:
         tasks, task_queries = [], []
         for query in queries:
             for i, retriever in enumerate(self._retrievers):
                 tasks.append(retriever.aretrieve(query))
-                task_queries.append((query, i))
+                task_queries.append((query.query_str, i))
 
         task_results = run_async_tasks(tasks)
 
@@ -223,13 +224,13 @@ class QueryFusionRetriever(BaseRetriever):
         return results
 
     async def _run_async_queries(
-        self, queries: List[str]
+        self, queries: List[QueryBundle]
     ) -> Dict[Tuple[str, int], List[NodeWithScore]]:
         tasks, task_queries = [], []
         for query in queries:
             for i, retriever in enumerate(self._retrievers):
                 tasks.append(retriever.aretrieve(query))
-                task_queries.append((query, i))
+                task_queries.append((query.query_str, i))
 
         task_results = await asyncio.gather(*tasks)
 
@@ -240,20 +241,19 @@ class QueryFusionRetriever(BaseRetriever):
         return results
 
     def _run_sync_queries(
-        self, queries: List[str]
+        self, queries: List[QueryBundle]
     ) -> Dict[Tuple[str, int], List[NodeWithScore]]:
         results = {}
         for query in queries:
             for i, retriever in enumerate(self._retrievers):
-                results[(query, i)] = retriever.retrieve(query)
+                results[(query.query_str, i)] = retriever.retrieve(query)
 
         return results
 
     def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
+        queries: List[QueryBundle] = [query_bundle]
         if self.num_queries > 1:
-            queries = self._get_queries(query_bundle.query_str)
-        else:
-            queries = [query_bundle.query_str]
+            queries.extend(self._get_queries(query_bundle.query_str))
 
         if self.use_async:
             results = self._run_nested_async_queries(queries)
@@ -274,10 +274,9 @@ class QueryFusionRetriever(BaseRetriever):
             raise ValueError(f"Invalid fusion mode: {self.mode}")
 
     async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
+        queries: List[QueryBundle] = [query_bundle]
         if self.num_queries > 1:
-            queries = self._get_queries(query_bundle.query_str)
-        else:
-            queries = [query_bundle.query_str]
+            queries.extend(self._get_queries(query_bundle.query_str))
 
         results = await self._run_async_queries(queries)
 
-- 
GitLab