From 2710433c746acdefba1d3604486ab3766b2f42c0 Mon Sep 17 00:00:00 2001 From: Diego Colombo <dicolomb@microsoft.com> Date: Wed, 3 Apr 2024 05:54:06 +0100 Subject: [PATCH] FLAREInstructQueryEngine : delegating retriever api if the query engine supports it (#12503) * delegating retriever api if the query engine supports it * add comments * address lint errors --- .../core/query_engine/flare/base.py | 24 +++++++++++++++++-- 1 file changed, 22 insertions(+), 2 deletions(-) diff --git a/llama-index-core/llama_index/core/query_engine/flare/base.py b/llama-index-core/llama_index/core/query_engine/flare/base.py index f7515fe558..e7860e1072 100644 --- a/llama-index-core/llama_index/core/query_engine/flare/base.py +++ b/llama-index-core/llama_index/core/query_engine/flare/base.py @@ -4,7 +4,7 @@ Active Retrieval Augmented Generation. """ -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, List from llama_index.core.base.base_query_engine import BaseQueryEngine from llama_index.core.base.response.schema import RESPONSE_TYPE, Response @@ -20,7 +20,9 @@ from llama_index.core.query_engine.flare.output_parser import ( IsDoneOutputParser, QueryTaskOutputParser, ) -from llama_index.core.schema import QueryBundle + +from llama_index.core.query_engine.retriever_query_engine import RetrieverQueryEngine +from llama_index.core.schema import QueryBundle, NodeWithScore from llama_index.core.service_context import ServiceContext from llama_index.core.settings import Settings, llm_from_settings_or_context from llama_index.core.utils import print_text @@ -258,3 +260,21 @@ class FLAREInstructQueryEngine(BaseQueryEngine): async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE: return self._query(query_bundle) + + def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + # if the query engine is a retriever, then use the retrieve method + if isinstance(self._query_engine, RetrieverQueryEngine): + return self._query_engine.retrieve(query_bundle) + else: + raise NotImplementedError( + "This query engine does not support retrieve, use query directly" + ) + + async def aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + # if the query engine is a retriever, then use the retrieve method + if isinstance(self._query_engine, RetrieverQueryEngine): + return await self._query_engine.aretrieve(query_bundle) + else: + raise NotImplementedError( + "This query engine does not support retrieve, use query directly" + ) -- GitLab