diff --git a/docs/module_guides/deploying/agents/usage_pattern.md b/docs/module_guides/deploying/agents/usage_pattern.md
index 9283c97695ead6f091867075b46d95a7d63a4c79..9f06c0bff9bae9931f74bdc75202249eb905cc32 100644
--- a/docs/module_guides/deploying/agents/usage_pattern.md
+++ b/docs/module_guides/deploying/agents/usage_pattern.md
@@ -34,6 +34,14 @@ Example usage:
 agent.chat("What is 2123 * 215123")
 ```
 
+To automatically pick the best agent for a given LLM, you can use the `from_llm` method.
+
+```python
+from llama_index.agent import AgentRunner
+
+agent = AgentRunner.from_llm([multiply_tool], llm=llm, verbose=True)
+```
+
 ## Defining Tools
 
 ### Query Engine Tools
diff --git a/llama_index/agent/runner/base.py b/llama_index/agent/runner/base.py
index 0e33dc15ed5f1ec200bc4b185ac98e7b7d49d047..d9019ab857338bf4427a9ad5820522da6a69995e 100644
--- a/llama_index/agent/runner/base.py
+++ b/llama_index/agent/runner/base.py
@@ -27,6 +27,7 @@ from llama_index.llms.base import ChatMessage
 from llama_index.llms.llm import LLM
 from llama_index.memory import BaseMemory, ChatMemoryBuffer
 from llama_index.memory.types import BaseMemory
+from llama_index.tools.types import BaseTool
 
 
 class BaseAgentRunner(BaseAgent):
@@ -220,6 +221,32 @@ class AgentRunner(BaseAgentRunner):
         self.delete_task_on_finish = delete_task_on_finish
         self.default_tool_choice = default_tool_choice
 
+    @staticmethod
+    def from_llm(
+        tools: Optional[List[BaseTool]] = None,
+        llm: Optional[LLM] = None,
+        **kwargs: Any,
+    ) -> "AgentRunner":
+        from llama_index.llms.openai import OpenAI
+        from llama_index.llms.openai_utils import is_function_calling_model
+
+        if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
+            from llama_index.agent import OpenAIAgent
+
+            return OpenAIAgent.from_tools(
+                tools=tools,
+                llm=llm,
+                **kwargs,
+            )
+        else:
+            from llama_index.agent import ReActAgent
+
+            return ReActAgent.from_tools(
+                tools=tools,
+                llm=llm,
+                **kwargs,
+            )
+
     @property
     def chat_history(self) -> List[ChatMessage]:
         return self.memory.get_all()
diff --git a/llama_index/indices/base.py b/llama_index/indices/base.py
index 7028d9d5e403d066a1fb24742d77ae9658d1cde5..8cfd2e9f6566812dd028e812c1ceb3329795630d 100644
--- a/llama_index/indices/base.py
+++ b/llama_index/indices/base.py
@@ -8,8 +8,6 @@ from llama_index.core.base_query_engine import BaseQueryEngine
 from llama_index.core.base_retriever import BaseRetriever
 from llama_index.data_structs.data_structs import IndexStruct
 from llama_index.ingestion import run_transformations
-from llama_index.llms.openai import OpenAI
-from llama_index.llms.openai_utils import is_function_calling_model
 from llama_index.schema import BaseNode, Document, IndexNode
 from llama_index.service_context import ServiceContext
 from llama_index.storage.docstore.types import BaseDocumentStore, RefDocInfo
@@ -364,15 +362,20 @@ class BaseIndex(Generic[IS], ABC):
             kwargs["service_context"] = self._service_context
 
         # resolve chat mode
-        if chat_mode == ChatMode.BEST:
+        if chat_mode in [ChatMode.REACT, ChatMode.OPENAI, ChatMode.BEST]:
+            # use an agent with query engine tool in these chat modes
+            # NOTE: lazy import
+            from llama_index.agent import AgentRunner
+            from llama_index.tools.query_engine import QueryEngineTool
+
             # get LLM
             service_context = cast(ServiceContext, kwargs["service_context"])
             llm = service_context.llm
 
-            if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
-                chat_mode = ChatMode.OPENAI
-            else:
-                chat_mode = ChatMode.REACT
+            # convert query engine to tool
+            query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
+
+            return AgentRunner.from_llm(tools=[query_engine_tool], llm=llm, **kwargs)
 
         if chat_mode == ChatMode.CONDENSE_QUESTION:
             # NOTE: lazy import
@@ -398,32 +401,6 @@ class BaseIndex(Generic[IS], ABC):
                 **kwargs,
             )
-        elif chat_mode in [ChatMode.REACT, ChatMode.OPENAI]:
-            # NOTE: lazy import
-            from llama_index.agent import OpenAIAgent, ReActAgent
-            from llama_index.tools.query_engine import QueryEngineTool
-
-            # convert query engine to tool
-            query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
-
-            # get LLM
-            service_context = cast(ServiceContext, kwargs.pop("service_context"))
-            llm = service_context.llm
-
-            if chat_mode == ChatMode.REACT:
-                return ReActAgent.from_tools(
-                    tools=[query_engine_tool],
-                    llm=llm,
-                    **kwargs,
-                )
-            elif chat_mode == ChatMode.OPENAI:
-                return OpenAIAgent.from_tools(
-                    tools=[query_engine_tool],
-                    llm=llm,
-                    **kwargs,
-                )
-            else:
-                raise ValueError(f"Unknown chat mode: {chat_mode}")
         elif chat_mode == ChatMode.SIMPLE:
             from llama_index.chat_engine import SimpleChatEngine
diff --git a/tests/agent/runner/test_base.py b/tests/agent/runner/test_base.py
index 1c402c9c8c3738a4eb6abf24abf05bf8fee9b323..374889392e26c83dad7e772e018bdfb5a833027a 100644
--- a/tests/agent/runner/test_base.py
+++ b/tests/agent/runner/test_base.py
@@ -189,3 +189,16 @@ def test_dag_agent() -> None:
     assert step_outputs[0].is_last is True
     assert step_outputs[1].is_last is True
     assert len(agent_runner.state.task_dict[task.task_id].completed_steps) == 3
+
+
+def test_agent_from_llm() -> None:
+    from llama_index.agent import OpenAIAgent, ReActAgent
+    from llama_index.llms.mock import MockLLM
+    from llama_index.llms.openai import OpenAI
+
+    llm = OpenAI()
+    agent_runner = AgentRunner.from_llm(llm=llm)
+    assert isinstance(agent_runner, OpenAIAgent)
+    llm = MockLLM()
+    agent_runner = AgentRunner.from_llm(llm=llm)
+    assert isinstance(agent_runner, ReActAgent)
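
As a sanity check on the dispatch this patch introduces, here is a minimal end-to-end sketch of `AgentRunner.from_llm`, mirroring the new test above. It assumes the patched package is installed; `multiply` is a hypothetical tool function standing in for the `multiply_tool` referenced in the docs change.

```python
from llama_index.agent import AgentRunner, OpenAIAgent, ReActAgent
from llama_index.llms.mock import MockLLM
from llama_index.llms.openai import OpenAI
from llama_index.tools import FunctionTool


def multiply(a: int, b: int) -> int:
    """Multiply two integers and return the result."""
    return a * b


multiply_tool = FunctionTool.from_defaults(fn=multiply)

# An OpenAI LLM with a function-calling model routes to OpenAIAgent...
agent = AgentRunner.from_llm([multiply_tool], llm=OpenAI(model="gpt-3.5-turbo"))
assert isinstance(agent, OpenAIAgent)

# ...while any other LLM falls back to ReActAgent.
agent = AgentRunner.from_llm([multiply_tool], llm=MockLLM())
assert isinstance(agent, ReActAgent)
```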
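On the `as_chat_engine` side, the `react`, `openai`, and `best` chat modes now share a single code path: the index's query engine is wrapped in a `QueryEngineTool` and handed to `AgentRunner.from_llm`, which picks the agent implementation for the configured LLM. A usage sketch, with a hypothetical `data/` directory as the document source:

```python
from llama_index import SimpleDirectoryReader, VectorStoreIndex

documents = SimpleDirectoryReader("data").load_data()
index = VectorStoreIndex.from_documents(documents)

# "best" (like "react" and "openai") now resolves through AgentRunner.from_llm.
chat_engine = index.as_chat_engine(chat_mode="best", verbose=True)
print(chat_engine.chat("What is this document about?"))
```

One behavioral note on the refactor: the removed branch popped `service_context` from `kwargs` before forwarding them, while the new branch reads it with `kwargs["service_context"]` and returns directly, so `service_context` is still present in the `**kwargs` forwarded to the agents' `from_tools` — worth confirming those constructors tolerate the extra keyword.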