From 0bcc8b509a698a455f9b0868c18d57613f692b21 Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Wed, 7 Feb 2024 10:34:20 +0700
Subject: [PATCH] Refactor: add AgentRunner.from_llm method (#10452)

refactor: add factory method to create agent by llm
---
 .../deploying/agents/usage_pattern.md |  8 ++++
 llama_index/agent/runner/base.py      | 27 ++++++++++++
 llama_index/indices/base.py           | 43 +++++--------
 tests/agent/runner/test_base.py       | 13 ++++++
 4 files changed, 58 insertions(+), 33 deletions(-)

diff --git a/docs/module_guides/deploying/agents/usage_pattern.md b/docs/module_guides/deploying/agents/usage_pattern.md
index 9283c9769..9f06c0bff 100644
--- a/docs/module_guides/deploying/agents/usage_pattern.md
+++ b/docs/module_guides/deploying/agents/usage_pattern.md
@@ -34,6 +34,14 @@ Example usage:
 agent.chat("What is 2123 * 215123")
 ```
 
+To automatically pick the best agent for the given LLM, you can use the `from_llm` factory method to create an agent.
+
+```python
+from llama_index.agent import AgentRunner
+
+agent = AgentRunner.from_llm([multiply_tool], llm=llm, verbose=True)
+```
+
 ## Defining Tools
 
 ### Query Engine Tools
diff --git a/llama_index/agent/runner/base.py b/llama_index/agent/runner/base.py
index 0e33dc15e..d9019ab85 100644
--- a/llama_index/agent/runner/base.py
+++ b/llama_index/agent/runner/base.py
@@ -27,6 +27,7 @@ from llama_index.llms.base import ChatMessage
 from llama_index.llms.llm import LLM
 from llama_index.memory import BaseMemory, ChatMemoryBuffer
 from llama_index.memory.types import BaseMemory
+from llama_index.tools.types import BaseTool
 
 
 class BaseAgentRunner(BaseAgent):
@@ -220,6 +221,32 @@ class AgentRunner(BaseAgentRunner):
         self.delete_task_on_finish = delete_task_on_finish
         self.default_tool_choice = default_tool_choice
 
+    @staticmethod
+    def from_llm(
+        tools: Optional[List[BaseTool]] = None,
+        llm: Optional[LLM] = None,
+        **kwargs: Any,
+    ) -> "AgentRunner":
+        from llama_index.llms.openai import OpenAI
+        from llama_index.llms.openai_utils import is_function_calling_model
+
+        if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
+            from llama_index.agent import OpenAIAgent
+
+            return OpenAIAgent.from_tools(
+                tools=tools,
+                llm=llm,
+                **kwargs,
+            )
+        else:
+            from llama_index.agent import ReActAgent
+
+            return ReActAgent.from_tools(
+                tools=tools,
+                llm=llm,
+                **kwargs,
+            )
+
     @property
     def chat_history(self) -> List[ChatMessage]:
         return self.memory.get_all()
diff --git a/llama_index/indices/base.py b/llama_index/indices/base.py
index 7028d9d5e..8cfd2e9f6 100644
--- a/llama_index/indices/base.py
+++ b/llama_index/indices/base.py
@@ -8,8 +8,6 @@ from llama_index.core.base_query_engine import BaseQueryEngine
 from llama_index.core.base_retriever import BaseRetriever
 from llama_index.data_structs.data_structs import IndexStruct
 from llama_index.ingestion import run_transformations
-from llama_index.llms.openai import OpenAI
-from llama_index.llms.openai_utils import is_function_calling_model
 from llama_index.schema import BaseNode, Document, IndexNode
 from llama_index.service_context import ServiceContext
 from llama_index.storage.docstore.types import BaseDocumentStore, RefDocInfo
@@ -364,15 +362,20 @@ class BaseIndex(Generic[IS], ABC):
             kwargs["service_context"] = self._service_context
 
         # resolve chat mode
-        if chat_mode == ChatMode.BEST:
+        if chat_mode in [ChatMode.REACT, ChatMode.OPENAI, ChatMode.BEST]:
+            # use an agent with query engine tool in these chat modes
+            # NOTE: lazy import
+            from llama_index.agent import AgentRunner
+            from llama_index.tools.query_engine import QueryEngineTool
+
             # get LLM
             service_context = cast(ServiceContext, kwargs["service_context"])
             llm = service_context.llm
 
-            if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
-                chat_mode = ChatMode.OPENAI
-            else:
-                chat_mode = ChatMode.REACT
+            # convert query engine to tool
+            query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
+
+            return AgentRunner.from_llm(tools=[query_engine_tool], llm=llm, **kwargs)
 
         if chat_mode == ChatMode.CONDENSE_QUESTION:
             # NOTE: lazy import
@@ -398,32 +401,6 @@ class BaseIndex(Generic[IS], ABC):
                 **kwargs,
             )
 
-        elif chat_mode in [ChatMode.REACT, ChatMode.OPENAI]:
-            # NOTE: lazy import
-            from llama_index.agent import OpenAIAgent, ReActAgent
-            from llama_index.tools.query_engine import QueryEngineTool
-
-            # convert query engine to tool
-            query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
-
-            # get LLM
-            service_context = cast(ServiceContext, kwargs.pop("service_context"))
-            llm = service_context.llm
-
-            if chat_mode == ChatMode.REACT:
-                return ReActAgent.from_tools(
-                    tools=[query_engine_tool],
-                    llm=llm,
-                    **kwargs,
-                )
-            elif chat_mode == ChatMode.OPENAI:
-                return OpenAIAgent.from_tools(
-                    tools=[query_engine_tool],
-                    llm=llm,
-                    **kwargs,
-                )
-            else:
-                raise ValueError(f"Unknown chat mode: {chat_mode}")
-
         elif chat_mode == ChatMode.SIMPLE:
             from llama_index.chat_engine import SimpleChatEngine
diff --git a/tests/agent/runner/test_base.py b/tests/agent/runner/test_base.py
index 1c402c9c8..374889392 100644
--- a/tests/agent/runner/test_base.py
+++ b/tests/agent/runner/test_base.py
@@ -189,3 +189,16 @@ def test_dag_agent() -> None:
     assert step_outputs[0].is_last is True
     assert step_outputs[1].is_last is True
     assert len(agent_runner.state.task_dict[task.task_id].completed_steps) == 3
+
+
+def test_agent_from_llm() -> None:
+    from llama_index.agent import OpenAIAgent, ReActAgent
+    from llama_index.llms.mock import MockLLM
+    from llama_index.llms.openai import OpenAI
+
+    llm = OpenAI()
+    agent_runner = AgentRunner.from_llm(llm=llm)
+    assert isinstance(agent_runner, OpenAIAgent)
+    llm = MockLLM()
+    agent_runner = AgentRunner.from_llm(llm=llm)
+    assert isinstance(agent_runner, ReActAgent)
-- 
GitLab
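
Reviewer note (not part of the patch): below is a minimal end-to-end sketch of the dispatch behavior the new factory introduces, combining the docs example with `test_agent_from_llm` above. The `multiply` tool and the explicit model name are illustrative assumptions; import paths assume the same llama_index version this diff targets.

```python
# Illustrative sketch only: AgentRunner.from_llm picks OpenAIAgent for
# function-calling OpenAI models and falls back to ReActAgent otherwise.
# No OpenAI API call is made here; only agent construction is exercised
# (an OPENAI_API_KEY in the environment may still be needed to build the client).
from llama_index.agent import AgentRunner, OpenAIAgent, ReActAgent
from llama_index.llms.mock import MockLLM
from llama_index.llms.openai import OpenAI
from llama_index.tools import FunctionTool


def multiply(a: int, b: int) -> int:
    """Multiply two integers and return the result."""
    return a * b


multiply_tool = FunctionTool.from_defaults(fn=multiply)

# An OpenAI LLM whose model supports function calling dispatches to OpenAIAgent.
agent = AgentRunner.from_llm([multiply_tool], llm=OpenAI(model="gpt-3.5-turbo"), verbose=True)
assert isinstance(agent, OpenAIAgent)

# Any other LLM (here the built-in MockLLM) falls back to ReActAgent.
agent = AgentRunner.from_llm([multiply_tool], llm=MockLLM())
assert isinstance(agent, ReActAgent)
```

The design win is that this OpenAI-vs-ReAct choice, previously duplicated inside `BaseIndex.as_chat_engine`, now lives in one reusable place that both the chat-engine path and direct agent construction share.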