diff --git a/llama_index/retrievers/fusion_retriever.py b/llama_index/retrievers/fusion_retriever.py index 99275fe8786a09f216356bd1300a8d7beb66fd82..a61f92734a4dc6e6eb17abe6223d201da7458f67 100644 --- a/llama_index/retrievers/fusion_retriever.py +++ b/llama_index/retrievers/fusion_retriever.py @@ -1,3 +1,4 @@ +import asyncio from enum import Enum from typing import Dict, List, Optional, Tuple @@ -125,6 +126,22 @@ class QueryFusionRetriever(BaseRetriever): return results + async def _run_async_queries_async( + self, queries: List[str] + ) -> Dict[Tuple[str, int], List[NodeWithScore]]: + tasks = [] + for query in queries: + for i, retriever in enumerate(self._retrievers): + tasks.append(retriever.aretrieve(query)) + + task_results = await asyncio.gather(*tasks) + + results = {} + for i, (query, query_result) in enumerate(zip(queries, task_results)): + results[(query, i)] = query_result + + return results + def _run_sync_queries( self, queries: List[str] ) -> Dict[Tuple[str, int], List[NodeWithScore]]: @@ -152,3 +169,18 @@ class QueryFusionRetriever(BaseRetriever): return self._simple_fusion(results)[: self.similarity_top_k] else: raise ValueError(f"Invalid fusion mode: {self.mode}") + + async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + if self.num_queries > 1: + queries = self._get_queries(query_bundle.query_str) + else: + queries = [query_bundle.query_str] + + results = await self._run_async_queries_async(queries) + + if self.mode == FUSION_MODES.RECIPROCAL_RANK: + return self._reciprocal_rerank_fusion(results)[: self.similarity_top_k] + elif self.mode == FUSION_MODES.SIMPLE: + return self._simple_fusion(results)[: self.similarity_top_k] + else: + raise ValueError(f"Invalid fusion mode: {self.mode}")