Skip to content
Snippets Groups Projects
Unverified Commit 0bcc8b50 authored by Marcus Schiesser's avatar Marcus Schiesser Committed by GitHub
Browse files

Refactor: add AgentRunner.from_llm method (#10452)

refactor: add factory method to create agent by llm
parent 595c567b
No related branches found
No related tags found
No related merge requests found
......@@ -34,6 +34,14 @@ Example usage:
agent.chat("What is 2123 * 215123")
```
To automatically pick the best agent for a given LLM, you can use the `from_llm` method to generate an agent.
```python
from llama_index.agent import AgentRunner
agent = AgentRunner.from_llm([multiply_tool], llm=llm, verbose=True)
```
## Defining Tools
### Query Engine Tools
......
......@@ -27,6 +27,7 @@ from llama_index.llms.base import ChatMessage
from llama_index.llms.llm import LLM
from llama_index.memory import BaseMemory, ChatMemoryBuffer
from llama_index.memory.types import BaseMemory
from llama_index.tools.types import BaseTool
class BaseAgentRunner(BaseAgent):
......@@ -220,6 +221,32 @@ class AgentRunner(BaseAgentRunner):
self.delete_task_on_finish = delete_task_on_finish
self.default_tool_choice = default_tool_choice
@staticmethod
def from_llm(
    tools: Optional[List[BaseTool]] = None,
    llm: Optional[LLM] = None,
    **kwargs: Any,
) -> "AgentRunner":
    """Build the agent implementation best suited to the given LLM.

    An OpenAI LLM whose model supports function calling yields an
    ``OpenAIAgent``; any other LLM falls back to a ``ReActAgent``.
    Extra keyword arguments are forwarded to the chosen agent's
    ``from_tools`` constructor.
    """
    # NOTE: imports are kept local (lazy) to avoid circular imports
    # and to keep module import time low.
    from llama_index.llms.openai import OpenAI
    from llama_index.llms.openai_utils import is_function_calling_model

    supports_function_calls = isinstance(llm, OpenAI) and is_function_calling_model(
        llm.model
    )
    if supports_function_calls:
        from llama_index.agent import OpenAIAgent

        agent_cls = OpenAIAgent
    else:
        from llama_index.agent import ReActAgent

        agent_cls = ReActAgent
    return agent_cls.from_tools(tools=tools, llm=llm, **kwargs)
@property
def chat_history(self) -> List[ChatMessage]:
    """Return the full conversation history stored in the agent's memory."""
    history = self.memory.get_all()
    return history
......
......@@ -8,8 +8,6 @@ from llama_index.core.base_query_engine import BaseQueryEngine
from llama_index.core.base_retriever import BaseRetriever
from llama_index.data_structs.data_structs import IndexStruct
from llama_index.ingestion import run_transformations
from llama_index.llms.openai import OpenAI
from llama_index.llms.openai_utils import is_function_calling_model
from llama_index.schema import BaseNode, Document, IndexNode
from llama_index.service_context import ServiceContext
from llama_index.storage.docstore.types import BaseDocumentStore, RefDocInfo
......@@ -364,15 +362,20 @@ class BaseIndex(Generic[IS], ABC):
kwargs["service_context"] = self._service_context
# resolve chat mode
if chat_mode == ChatMode.BEST:
if chat_mode in [ChatMode.REACT, ChatMode.OPENAI, ChatMode.BEST]:
# use an agent with query engine tool in these chat modes
# NOTE: lazy import
from llama_index.agent import AgentRunner
from llama_index.tools.query_engine import QueryEngineTool
# get LLM
service_context = cast(ServiceContext, kwargs["service_context"])
llm = service_context.llm
if isinstance(llm, OpenAI) and is_function_calling_model(llm.model):
chat_mode = ChatMode.OPENAI
else:
chat_mode = ChatMode.REACT
# convert query engine to tool
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
return AgentRunner.from_llm(tools=[query_engine_tool], llm=llm, **kwargs)
if chat_mode == ChatMode.CONDENSE_QUESTION:
# NOTE: lazy import
......@@ -398,32 +401,6 @@ class BaseIndex(Generic[IS], ABC):
**kwargs,
)
elif chat_mode in [ChatMode.REACT, ChatMode.OPENAI]:
# NOTE: lazy import
from llama_index.agent import OpenAIAgent, ReActAgent
from llama_index.tools.query_engine import QueryEngineTool
# convert query engine to tool
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
# get LLM
service_context = cast(ServiceContext, kwargs.pop("service_context"))
llm = service_context.llm
if chat_mode == ChatMode.REACT:
return ReActAgent.from_tools(
tools=[query_engine_tool],
llm=llm,
**kwargs,
)
elif chat_mode == ChatMode.OPENAI:
return OpenAIAgent.from_tools(
tools=[query_engine_tool],
llm=llm,
**kwargs,
)
else:
raise ValueError(f"Unknown chat mode: {chat_mode}")
elif chat_mode == ChatMode.SIMPLE:
from llama_index.chat_engine import SimpleChatEngine
......
......@@ -189,3 +189,16 @@ def test_dag_agent() -> None:
assert step_outputs[0].is_last is True
assert step_outputs[1].is_last is True
assert len(agent_runner.state.task_dict[task.task_id].completed_steps) == 3
def test_agent_from_llm() -> None:
    """AgentRunner.from_llm dispatches on the LLM type.

    A function-calling OpenAI LLM should produce an OpenAIAgent; any
    other LLM (here a MockLLM) should produce a ReActAgent.
    """
    from llama_index.agent import OpenAIAgent, ReActAgent
    from llama_index.llms.mock import MockLLM
    from llama_index.llms.openai import OpenAI

    openai_llm = OpenAI()
    assert isinstance(AgentRunner.from_llm(llm=openai_llm), OpenAIAgent)

    mock_llm = MockLLM()
    assert isinstance(AgentRunner.from_llm(llm=mock_llm), ReActAgent)
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment