From 0bcc8b509a698a455f9b0868c18d57613f692b21 Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Wed, 7 Feb 2024 10:34:20 +0700
Subject: [PATCH] Refactor: add AgentRunner.from_llm method (#10452)

refactor: add factory method to create agent by llm
---
 .../deploying/agents/usage_pattern.md         |  8 ++++
 llama_index/agent/runner/base.py              | 27 ++++++++++++
 llama_index/indices/base.py                   | 43 +++++--------------
 tests/agent/runner/test_base.py               | 13 ++++++
 4 files changed, 58 insertions(+), 33 deletions(-)

diff --git a/docs/module_guides/deploying/agents/usage_pattern.md b/docs/module_guides/deploying/agents/usage_pattern.md
index 9283c9769..9f06c0bff 100644
--- a/docs/module_guides/deploying/agents/usage_pattern.md
+++ b/docs/module_guides/deploying/agents/usage_pattern.md
@@ -34,6 +34,14 @@ Example usage:
 agent.chat("What is 2123 * 215123")
 ```
 
+To automatically pick the best agent depending on the LLM, you can use the `from_llm` method to generate an agent.
+
+```python
+from llama_index.agent import AgentRunner
+
+agent = AgentRunner.from_llm([multiply_tool], llm=llm, verbose=True)
+```
+
 ## Defining Tools
 
 ### Query Engine Tools
diff --git a/llama_index/agent/runner/base.py b/llama_index/agent/runner/base.py
index 0e33dc15e..d9019ab85 100644
--- a/llama_index/agent/runner/base.py
+++ b/llama_index/agent/runner/base.py
@@ -27,6 +27,7 @@ from llama_index.llms.base import ChatMessage
 from llama_index.llms.llm import LLM
 from llama_index.memory import BaseMemory, ChatMemoryBuffer
 from llama_index.memory.types import BaseMemory
+from llama_index.tools.types import BaseTool
 
 
 class BaseAgentRunner(BaseAgent):
@@ -220,6 +221,32 @@ class AgentRunner(BaseAgentRunner):
         self.delete_task_on_finish = delete_task_on_finish
         self.default_tool_choice = default_tool_choice
 
+    @staticmethod
+    def from_llm(
+        tools: Optional[List[BaseTool]] = None,
+        llm: Optional[LLM] = None,
+        **kwargs: Any,
+    ) -> "AgentRunner":
+        from llama_index.llms.openai import OpenAI
+        from llama_index.llms.openai_utils import is_function_calling_model
+
+        if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
+            from llama_index.agent import OpenAIAgent
+
+            return OpenAIAgent.from_tools(
+                tools=tools,
+                llm=llm,
+                **kwargs,
+            )
+        else:
+            from llama_index.agent import ReActAgent
+
+            return ReActAgent.from_tools(
+                tools=tools,
+                llm=llm,
+                **kwargs,
+            )
+
     @property
     def chat_history(self) -> List[ChatMessage]:
         return self.memory.get_all()
diff --git a/llama_index/indices/base.py b/llama_index/indices/base.py
index 7028d9d5e..8cfd2e9f6 100644
--- a/llama_index/indices/base.py
+++ b/llama_index/indices/base.py
@@ -8,8 +8,6 @@ from llama_index.core.base_query_engine import BaseQueryEngine
 from llama_index.core.base_retriever import BaseRetriever
 from llama_index.data_structs.data_structs import IndexStruct
 from llama_index.ingestion import run_transformations
-from llama_index.llms.openai import OpenAI
-from llama_index.llms.openai_utils import is_function_calling_model
 from llama_index.schema import BaseNode, Document, IndexNode
 from llama_index.service_context import ServiceContext
 from llama_index.storage.docstore.types import BaseDocumentStore, RefDocInfo
@@ -364,15 +362,20 @@ class BaseIndex(Generic[IS], ABC):
             kwargs["service_context"] = self._service_context
 
         # resolve chat mode
-        if chat_mode == ChatMode.BEST:
+        if chat_mode in [ChatMode.REACT, ChatMode.OPENAI, ChatMode.BEST]:
+            # use an agent with query engine tool in these chat modes
+            # NOTE: lazy import
+            from llama_index.agent import AgentRunner
+            from llama_index.tools.query_engine import QueryEngineTool
+
             # get LLM
             service_context = cast(ServiceContext, kwargs["service_context"])
             llm = service_context.llm
 
-            if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
-                chat_mode = ChatMode.OPENAI
-            else:
-                chat_mode = ChatMode.REACT
+            # convert query engine to tool
+            query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
+
+            return AgentRunner.from_llm(tools=[query_engine_tool], llm=llm, **kwargs)
 
         if chat_mode == ChatMode.CONDENSE_QUESTION:
             # NOTE: lazy import
@@ -398,32 +401,6 @@ class BaseIndex(Generic[IS], ABC):
                 **kwargs,
             )
 
-        elif chat_mode in [ChatMode.REACT, ChatMode.OPENAI]:
-            # NOTE: lazy import
-            from llama_index.agent import OpenAIAgent, ReActAgent
-            from llama_index.tools.query_engine import QueryEngineTool
-
-            # convert query engine to tool
-            query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
-
-            # get LLM
-            service_context = cast(ServiceContext, kwargs.pop("service_context"))
-            llm = service_context.llm
-
-            if chat_mode == ChatMode.REACT:
-                return ReActAgent.from_tools(
-                    tools=[query_engine_tool],
-                    llm=llm,
-                    **kwargs,
-                )
-            elif chat_mode == ChatMode.OPENAI:
-                return OpenAIAgent.from_tools(
-                    tools=[query_engine_tool],
-                    llm=llm,
-                    **kwargs,
-                )
-            else:
-                raise ValueError(f"Unknown chat mode: {chat_mode}")
         elif chat_mode == ChatMode.SIMPLE:
             from llama_index.chat_engine import SimpleChatEngine
 
diff --git a/tests/agent/runner/test_base.py b/tests/agent/runner/test_base.py
index 1c402c9c8..374889392 100644
--- a/tests/agent/runner/test_base.py
+++ b/tests/agent/runner/test_base.py
@@ -189,3 +189,16 @@ def test_dag_agent() -> None:
     assert step_outputs[0].is_last is True
     assert step_outputs[1].is_last is True
     assert len(agent_runner.state.task_dict[task.task_id].completed_steps) == 3
+
+
+def test_agent_from_llm() -> None:
+    from llama_index.agent import OpenAIAgent, ReActAgent
+    from llama_index.llms.mock import MockLLM
+    from llama_index.llms.openai import OpenAI
+
+    llm = OpenAI()
+    agent_runner = AgentRunner.from_llm(llm=llm)
+    assert isinstance(agent_runner, OpenAIAgent)
+    llm = MockLLM()
+    agent_runner = AgentRunner.from_llm(llm=llm)
+    assert isinstance(agent_runner, ReActAgent)
-- 
GitLab