Skip to content
Snippets Groups Projects
Unverified Commit 4fff8e7a authored by Grant Doctor's avatar Grant Doctor Committed by GitHub
Browse files

QueryFusionRetriever sends incorrect QueryBundle type to it's retrievers (#12387)

parent e8eae515
Branches
Tags
No related merge requests found
......@@ -80,7 +80,7 @@ class QueryFusionRetriever(BaseRetriever):
PromptTemplate, prompts["query_gen_prompt"]
).template
def _get_queries(self, original_query: str) -> List[str]:
def _get_queries(self, original_query: str) -> List[QueryBundle]:
prompt_str = self.query_gen_prompt.format(
num_queries=self.num_queries - 1,
query=original_query,
......@@ -92,8 +92,9 @@ class QueryFusionRetriever(BaseRetriever):
if self._verbose:
queries_str = "\n".join(queries)
print(f"Generated queries:\n{queries_str}")
# The LLM often returns more queries than we asked for, so trim the list.
return response.text.split("\n")[: self.num_queries]
return [QueryBundle(q) for q in queries[: self.num_queries - 1]]
def _reciprocal_rerank_fusion(
self, results: Dict[Tuple[str, int], List[NodeWithScore]]
......@@ -206,13 +207,13 @@ class QueryFusionRetriever(BaseRetriever):
return sorted(all_nodes.values(), key=lambda x: x.score or 0.0, reverse=True)
def _run_nested_async_queries(
self, queries: List[str]
self, queries: List[QueryBundle]
) -> Dict[Tuple[str, int], List[NodeWithScore]]:
tasks, task_queries = [], []
for query in queries:
for i, retriever in enumerate(self._retrievers):
tasks.append(retriever.aretrieve(query))
task_queries.append((query, i))
task_queries.append((query.query_str, i))
task_results = run_async_tasks(tasks)
......@@ -223,13 +224,13 @@ class QueryFusionRetriever(BaseRetriever):
return results
async def _run_async_queries(
self, queries: List[str]
self, queries: List[QueryBundle]
) -> Dict[Tuple[str, int], List[NodeWithScore]]:
tasks, task_queries = [], []
for query in queries:
for i, retriever in enumerate(self._retrievers):
tasks.append(retriever.aretrieve(query))
task_queries.append((query, i))
task_queries.append((query.query_str, i))
task_results = await asyncio.gather(*tasks)
......@@ -240,20 +241,19 @@ class QueryFusionRetriever(BaseRetriever):
return results
def _run_sync_queries(
self, queries: List[str]
self, queries: List[QueryBundle]
) -> Dict[Tuple[str, int], List[NodeWithScore]]:
results = {}
for query in queries:
for i, retriever in enumerate(self._retrievers):
results[(query, i)] = retriever.retrieve(query)
results[(query.query_str, i)] = retriever.retrieve(query)
return results
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
queries: List[QueryBundle] = [query_bundle]
if self.num_queries > 1:
queries = self._get_queries(query_bundle.query_str)
else:
queries = [query_bundle.query_str]
queries.extend(self._get_queries(query_bundle.query_str))
if self.use_async:
results = self._run_nested_async_queries(queries)
......@@ -274,10 +274,9 @@ class QueryFusionRetriever(BaseRetriever):
raise ValueError(f"Invalid fusion mode: {self.mode}")
async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
queries: List[QueryBundle] = [query_bundle]
if self.num_queries > 1:
queries = self._get_queries(query_bundle.query_str)
else:
queries = [query_bundle.query_str]
queries.extend(self._get_queries(query_bundle.query_str))
results = await self._run_async_queries(queries)
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment