diff --git a/llama-index-core/llama_index/core/retrievers/fusion_retriever.py b/llama-index-core/llama_index/core/retrievers/fusion_retriever.py index 45fa1bae89ee75f399211e51c1a05923fc070148..ae84a98e875c173eb6a6c26308ed1b7f01bb5bbd 100644 --- a/llama-index-core/llama_index/core/retrievers/fusion_retriever.py +++ b/llama-index-core/llama_index/core/retrievers/fusion_retriever.py @@ -80,7 +80,7 @@ class QueryFusionRetriever(BaseRetriever): PromptTemplate, prompts["query_gen_prompt"] ).template - def _get_queries(self, original_query: str) -> List[str]: + def _get_queries(self, original_query: str) -> List[QueryBundle]: prompt_str = self.query_gen_prompt.format( num_queries=self.num_queries - 1, query=original_query, @@ -92,8 +92,9 @@ class QueryFusionRetriever(BaseRetriever): if self._verbose: queries_str = "\n".join(queries) print(f"Generated queries:\n{queries_str}") + # The LLM often returns more queries than we asked for, so trim the list. - return response.text.split("\n")[: self.num_queries] + return [QueryBundle(q) for q in queries[: self.num_queries - 1]] def _reciprocal_rerank_fusion( self, results: Dict[Tuple[str, int], List[NodeWithScore]] @@ -206,13 +207,13 @@ class QueryFusionRetriever(BaseRetriever): return sorted(all_nodes.values(), key=lambda x: x.score or 0.0, reverse=True) def _run_nested_async_queries( - self, queries: List[str] + self, queries: List[QueryBundle] ) -> Dict[Tuple[str, int], List[NodeWithScore]]: tasks, task_queries = [], [] for query in queries: for i, retriever in enumerate(self._retrievers): tasks.append(retriever.aretrieve(query)) - task_queries.append((query, i)) + task_queries.append((query.query_str, i)) task_results = run_async_tasks(tasks) @@ -223,13 +224,13 @@ class QueryFusionRetriever(BaseRetriever): return results async def _run_async_queries( - self, queries: List[str] + self, queries: List[QueryBundle] ) -> Dict[Tuple[str, int], List[NodeWithScore]]: tasks, task_queries = [], [] for query in queries: for i, retriever in enumerate(self._retrievers): tasks.append(retriever.aretrieve(query)) - task_queries.append((query, i)) + task_queries.append((query.query_str, i)) task_results = await asyncio.gather(*tasks) @@ -240,20 +241,19 @@ class QueryFusionRetriever(BaseRetriever): return results def _run_sync_queries( - self, queries: List[str] + self, queries: List[QueryBundle] ) -> Dict[Tuple[str, int], List[NodeWithScore]]: results = {} for query in queries: for i, retriever in enumerate(self._retrievers): - results[(query, i)] = retriever.retrieve(query) + results[(query.query_str, i)] = retriever.retrieve(query) return results def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + queries: List[QueryBundle] = [query_bundle] if self.num_queries > 1: - queries = self._get_queries(query_bundle.query_str) - else: - queries = [query_bundle.query_str] + queries.extend(self._get_queries(query_bundle.query_str)) if self.use_async: results = self._run_nested_async_queries(queries) @@ -274,10 +274,9 @@ class QueryFusionRetriever(BaseRetriever): raise ValueError(f"Invalid fusion mode: {self.mode}") async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + queries: List[QueryBundle] = [query_bundle] if self.num_queries > 1: - queries = self._get_queries(query_bundle.query_str) - else: - queries = [query_bundle.query_str] + queries.extend(self._get_queries(query_bundle.query_str)) results = await self._run_async_queries(queries)