From 2710433c746acdefba1d3604486ab3766b2f42c0 Mon Sep 17 00:00:00 2001
From: Diego Colombo <dicolomb@microsoft.com>
Date: Wed, 3 Apr 2024 05:54:06 +0100
Subject: [PATCH] FLAREInstructQueryEngine : delegating retriever api if the
 query engine supports it (#12503)

* delegating retriever api if the query engine supports it

* add comments

* address lint errors
---
 .../core/query_engine/flare/base.py           | 24 +++++++++++++++++--
 1 file changed, 22 insertions(+), 2 deletions(-)

diff --git a/llama-index-core/llama_index/core/query_engine/flare/base.py b/llama-index-core/llama_index/core/query_engine/flare/base.py
index f7515fe558..e7860e1072 100644
--- a/llama-index-core/llama_index/core/query_engine/flare/base.py
+++ b/llama-index-core/llama_index/core/query_engine/flare/base.py
@@ -4,7 +4,7 @@ Active Retrieval Augmented Generation.
 
 """
 
-from typing import Any, Dict, Optional
+from typing import Any, Dict, Optional, List
 
 from llama_index.core.base.base_query_engine import BaseQueryEngine
 from llama_index.core.base.response.schema import RESPONSE_TYPE, Response
@@ -20,7 +20,9 @@ from llama_index.core.query_engine.flare.output_parser import (
     IsDoneOutputParser,
     QueryTaskOutputParser,
 )
-from llama_index.core.schema import QueryBundle
+
+from llama_index.core.query_engine.retriever_query_engine import RetrieverQueryEngine
+from llama_index.core.schema import QueryBundle, NodeWithScore
 from llama_index.core.service_context import ServiceContext
 from llama_index.core.settings import Settings, llm_from_settings_or_context
 from llama_index.core.utils import print_text
@@ -258,3 +260,21 @@ class FLAREInstructQueryEngine(BaseQueryEngine):
 
     async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE:
         return self._query(query_bundle)
+
+    def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
+        # if the query engine is a retriever, then use the retrieve method
+        if isinstance(self._query_engine, RetrieverQueryEngine):
+            return self._query_engine.retrieve(query_bundle)
+        else:
+            raise NotImplementedError(
+                "This query engine does not support retrieve, use query directly"
+            )
+
+    async def aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
+        # if the query engine is a retriever, then use the retrieve method
+        if isinstance(self._query_engine, RetrieverQueryEngine):
+            return await self._query_engine.aretrieve(query_bundle)
+        else:
+            raise NotImplementedError(
+                "This query engine does not support retrieve, use query directly"
+            )
-- 
GitLab