From 4fff8e7afd8b37bd6f99318d85800e5baba49044 Mon Sep 17 00:00:00 2001 From: Grant Doctor <gdoctor@users.noreply.github.com> Date: Fri, 29 Mar 2024 10:01:12 -0700 Subject: [PATCH] QueryFusionRetriever sends incorrect QueryBundle type to its retrievers (#12387) --- .../core/retrievers/fusion_retriever.py | 27 +++++++++---------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/llama-index-core/llama_index/core/retrievers/fusion_retriever.py b/llama-index-core/llama_index/core/retrievers/fusion_retriever.py index 45fa1bae89..ae84a98e87 100644 --- a/llama-index-core/llama_index/core/retrievers/fusion_retriever.py +++ b/llama-index-core/llama_index/core/retrievers/fusion_retriever.py @@ -80,7 +80,7 @@ class QueryFusionRetriever(BaseRetriever): PromptTemplate, prompts["query_gen_prompt"] ).template - def _get_queries(self, original_query: str) -> List[str]: + def _get_queries(self, original_query: str) -> List[QueryBundle]: prompt_str = self.query_gen_prompt.format( num_queries=self.num_queries - 1, query=original_query, @@ -92,8 +92,9 @@ class QueryFusionRetriever(BaseRetriever): if self._verbose: queries_str = "\n".join(queries) print(f"Generated queries:\n{queries_str}") + # The LLM often returns more queries than we asked for, so trim the list. 
- return response.text.split("\n")[: self.num_queries] + return [QueryBundle(q) for q in queries[: self.num_queries - 1]] def _reciprocal_rerank_fusion( self, results: Dict[Tuple[str, int], List[NodeWithScore]] @@ -206,13 +207,13 @@ class QueryFusionRetriever(BaseRetriever): return sorted(all_nodes.values(), key=lambda x: x.score or 0.0, reverse=True) def _run_nested_async_queries( - self, queries: List[str] + self, queries: List[QueryBundle] ) -> Dict[Tuple[str, int], List[NodeWithScore]]: tasks, task_queries = [], [] for query in queries: for i, retriever in enumerate(self._retrievers): tasks.append(retriever.aretrieve(query)) - task_queries.append((query, i)) + task_queries.append((query.query_str, i)) task_results = run_async_tasks(tasks) @@ -223,13 +224,13 @@ class QueryFusionRetriever(BaseRetriever): return results async def _run_async_queries( - self, queries: List[str] + self, queries: List[QueryBundle] ) -> Dict[Tuple[str, int], List[NodeWithScore]]: tasks, task_queries = [], [] for query in queries: for i, retriever in enumerate(self._retrievers): tasks.append(retriever.aretrieve(query)) - task_queries.append((query, i)) + task_queries.append((query.query_str, i)) task_results = await asyncio.gather(*tasks) @@ -240,20 +241,19 @@ class QueryFusionRetriever(BaseRetriever): return results def _run_sync_queries( - self, queries: List[str] + self, queries: List[QueryBundle] ) -> Dict[Tuple[str, int], List[NodeWithScore]]: results = {} for query in queries: for i, retriever in enumerate(self._retrievers): - results[(query, i)] = retriever.retrieve(query) + results[(query.query_str, i)] = retriever.retrieve(query) return results def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + queries: List[QueryBundle] = [query_bundle] if self.num_queries > 1: - queries = self._get_queries(query_bundle.query_str) - else: - queries = [query_bundle.query_str] + queries.extend(self._get_queries(query_bundle.query_str)) if self.use_async: results = 
self._run_nested_async_queries(queries) @@ -274,10 +274,9 @@ class QueryFusionRetriever(BaseRetriever): raise ValueError(f"Invalid fusion mode: {self.mode}") async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + queries: List[QueryBundle] = [query_bundle] if self.num_queries > 1: - queries = self._get_queries(query_bundle.query_str) - else: - queries = [query_bundle.query_str] + queries.extend(self._get_queries(query_bundle.query_str)) results = await self._run_async_queries(queries) -- GitLab