diff --git a/llama-index-core/llama_index/core/query_engine/flare/base.py b/llama-index-core/llama_index/core/query_engine/flare/base.py index f7515fe558f2198c4638f87ba75f114883151b59..e7860e10722d0d7b9ec1a5551d19f594e9398287 100644 --- a/llama-index-core/llama_index/core/query_engine/flare/base.py +++ b/llama-index-core/llama_index/core/query_engine/flare/base.py @@ -4,7 +4,7 @@ Active Retrieval Augmented Generation. """ -from typing import Any, Dict, Optional +from typing import Any, Dict, Optional, List from llama_index.core.base.base_query_engine import BaseQueryEngine from llama_index.core.base.response.schema import RESPONSE_TYPE, Response @@ -20,7 +20,9 @@ from llama_index.core.query_engine.flare.output_parser import ( IsDoneOutputParser, QueryTaskOutputParser, ) -from llama_index.core.schema import QueryBundle + +from llama_index.core.query_engine.retriever_query_engine import RetrieverQueryEngine +from llama_index.core.schema import QueryBundle, NodeWithScore from llama_index.core.service_context import ServiceContext from llama_index.core.settings import Settings, llm_from_settings_or_context from llama_index.core.utils import print_text @@ -258,3 +260,21 @@ class FLAREInstructQueryEngine(BaseQueryEngine): async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE: return self._query(query_bundle) + + def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + # if the query engine is a retriever, then use the retrieve method + if isinstance(self._query_engine, RetrieverQueryEngine): + return self._query_engine.retrieve(query_bundle) + else: + raise NotImplementedError( + "This query engine does not support retrieve, use query directly" + ) + + async def aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + # if the query engine is a retriever, then use the retrieve method + if isinstance(self._query_engine, RetrieverQueryEngine): + return await self._query_engine.aretrieve(query_bundle) + else: + raise NotImplementedError( + "This query engine does not support retrieve, use query directly" + )