diff --git a/docs/docs/examples/agent/return_direct_agent.ipynb b/docs/docs/examples/agent/return_direct_agent.ipynb new file mode 100644 index 0000000000000000000000000000000000000000..f5ef5be32d9e08e7708570ef46ff7320183000a2 --- /dev/null +++ b/docs/docs/examples/agent/return_direct_agent.ipynb @@ -0,0 +1,361 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Controlling Agent Reasoning Loop with Return Direct Tools\n", + "\n", + "All tools have an option for `return_direct` -- if this is set to `True`, and the associated tool is called (without any other tools being called), the agent reasoning loop is ended and the tool output is returned directly.\n", + "\n", + "This can be useful for speeding up response times when you know the tool output is good enough, for avoiding the agent re-writing the response, and for ending the reasoning loop." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "This notebook walks through an example where an agent needs to gather information from a user in order to make a restaurant booking." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pip install llama-index-core llama-index-llms-anthropic" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"ANTHROPIC_API_KEY\"] = \"sk-ant-...\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Tools setup" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from typing import Optional\n", + "\n", + "from llama_index.core.tools import FunctionTool\n", + "from llama_index.core.bridge.pydantic import BaseModel\n", + "\n", + "# we will store bookings under random IDs\n", + "bookings = {}\n", + "\n", + "\n", + "# we will represent and track the state of a booking as a Pydantic model\n", + "class Booking(BaseModel):\n", + " name: Optional[str] = None\n", + " email: Optional[str] = None\n", + " phone: Optional[str] = None\n", + " date: Optional[str] = None\n", + " time: Optional[str] = None\n", + "\n", + "\n", + "def get_booking_state(user_id: str) -> str:\n", + " \"\"\"Get the current state of a booking for a given booking ID.\"\"\"\n", + " try:\n", + " return str(bookings[user_id].dict())\n", + " except:\n", + " return f\"Booking ID {user_id} not found\"\n", + "\n", + "\n", + "def update_booking(user_id: str, property: str, value: str) -> str:\n", + " \"\"\"Update a property of a booking for a given booking ID. Only enter details that are explicitly provided.\"\"\"\n", + " booking = bookings[user_id]\n", + " setattr(booking, property, value)\n", + " return f\"Booking ID {user_id} updated with {property} = {value}\"\n", + "\n", + "\n", + "def create_booking(user_id: str) -> str:\n", + " \"\"\"Create a new booking and return the booking ID.\"\"\"\n", + " bookings[user_id] = Booking()\n", + " return \"Booking created, but not yet confirmed.
Please provide your name, email, phone, date, and time.\"\n", + "\n", + "\n", + "def confirm_booking(user_id: str) -> str:\n", + " \"\"\"Confirm a booking for a given booking ID.\"\"\"\n", + " booking = bookings[user_id]\n", + "\n", + " if booking.name is None:\n", + " raise ValueError(\"Please provide your name.\")\n", + "\n", + " if booking.email is None:\n", + " raise ValueError(\"Please provide your email.\")\n", + "\n", + " if booking.phone is None:\n", + " raise ValueError(\"Please provide your phone number.\")\n", + "\n", + " if booking.date is None:\n", + " raise ValueError(\"Please provide the date of your booking.\")\n", + "\n", + " if booking.time is None:\n", + " raise ValueError(\"Please provide the time of your booking.\")\n", + "\n", + " return f\"Booking ID {user_id} confirmed!\"\n", + "\n", + "\n", + "# create tools for each function\n", + "get_booking_state_tool = FunctionTool.from_defaults(fn=get_booking_state)\n", + "update_booking_tool = FunctionTool.from_defaults(fn=update_booking)\n", + "create_booking_tool = FunctionTool.from_defaults(\n", + " fn=create_booking, return_direct=True\n", + ")\n", + "confirm_booking_tool = FunctionTool.from_defaults(\n", + " fn=confirm_booking, return_direct=True\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## A user has walked in! Let's help them make a booking" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index.llms.anthropic import Anthropic\n", + "from llama_index.core.llms import ChatMessage\n", + "from llama_index.core.agent import FunctionCallingAgentWorker, AgentRunner\n", + "\n", + "llm = Anthropic(model=\"claude-3-sonnet-20240229\", temperature=0.1)\n", + "\n", + "user = \"user123\"\n", + "prefix_messages = [\n", + " ChatMessage(\n", + " role=\"system\",\n", + " content=(\n", + " f\"You are now connected to the booking system and helping {user} with making a booking. \"\n", + " \"Only enter details that the user has explicitly provided. \"\n", + " \"Do not make up any details.\"\n", + " ),\n", + " )\n", + "]\n", + "\n", + "worker = FunctionCallingAgentWorker(\n", + " tools=[\n", + " get_booking_state_tool,\n", + " update_booking_tool,\n", + " create_booking_tool,\n", + " confirm_booking_tool,\n", + " ],\n", + " llm=llm,\n", + " prefix_messages=prefix_messages,\n", + " max_function_calls=10,\n", + " allow_parallel_tool_calls=False,\n", + " verbose=True,\n", + ")\n", + "\n", + "agent = AgentRunner(worker)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Added user message to memory: Hello! I would like to make a booking, around 5pm?\n", + "=== LLM Response ===\n", + "Okay, let's create a new booking for you. To do that, I'll invoke the create_booking tool with your user ID:\n", + "=== Calling Function ===\n", + "Calling function: create_booking with args: {\"user_id\": \"user123\"}\n", + "=== Function Output ===\n", + "Booking created, but not yet confirmed. Please provide your name, email, phone, date, and time.\n" + ] + } + ], + "source": [ + "response = agent.chat(\"Hello! I would like to make a booking, around 5pm?\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Booking created, but not yet confirmed.
Please provide your name, email, phone, date, and time.\n" + ] + } + ], + "source": [ + "print(str(response))" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Perfect, we can see the function output was returned directly, with no modification or final LLM call!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Added user message to memory: Sure! My name is Logan, and my email is test@gmail.com\n", + "=== LLM Response ===\n", + "Got it, thanks for providing your name and email. Let me update the booking with those details:\n", + "=== Calling Function ===\n", + "Calling function: update_booking with args: {\"user_id\": \"user123\", \"property\": \"name\", \"value\": \"Logan\"}\n", + "=== Function Output ===\n", + "Booking ID user123 updated with name = Logan\n", + "=== Calling Function ===\n", + "Calling function: update_booking with args: {\"user_id\": \"user123\", \"property\": \"email\", \"value\": \"test@gmail.com\"}\n", + "=== Function Output ===\n", + "Booking ID user123 updated with email = test@gmail.com\n", + "=== LLM Response ===\n", + "I still need your phone number, date, and preferred time for the booking. Please provide those details.\n" + ] + } + ], + "source": [ + "response = agent.chat(\"Sure! My name is Logan, and my email is test@gmail.com\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "assistant: I still need your phone number, date, and preferred time for the booking. Please provide those details.\n" + ] + } + ], + "source": [ + "print(str(response))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Added user message to memory: Right! My phone number is 1234567890, the date of the booking is April 5, at 5pm.\n", + "=== LLM Response ===\n", + "Great, thank you for providing the remaining details. Let me update the booking:\n", + "=== Calling Function ===\n", + "Calling function: update_booking with args: {\"user_id\": \"user123\", \"property\": \"phone\", \"value\": \"1234567890\"}\n", + "=== Function Output ===\n", + "Booking ID user123 updated with phone = 1234567890\n", + "=== Calling Function ===\n", + "Calling function: update_booking with args: {\"user_id\": \"user123\", \"property\": \"date\", \"value\": \"2023-04-05\"}\n", + "=== Function Output ===\n", + "Booking ID user123 updated with date = 2023-04-05\n", + "=== Calling Function ===\n", + "Calling function: update_booking with args: {\"user_id\": \"user123\", \"property\": \"time\", \"value\": \"17:00\"}\n", + "=== Function Output ===\n", + "Booking ID user123 updated with time = 17:00\n", + "=== LLM Response ===\n", + "Your booking for April 5th at 5pm is now created with all the provided details. To confirm it, I'll invoke:\n", + "=== Calling Function ===\n", + "Calling function: confirm_booking with args: {\"user_id\": \"user123\"}\n", + "=== Function Output ===\n", + "Booking ID user123 confirmed!\n" + ] + } + ], + "source": [ + "response = agent.chat(\n", + " \"Right!
My phone number is 1234567890, the date of the booking is April 5, at 5pm.\"\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Booking ID user123 confirmed!\n" + ] + } + ], + "source": [ + "print(str(response))" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "name='Logan' email='test@gmail.com' phone='1234567890' date='2023-04-05' time='17:00'\n" + ] + } + ], + "source": [ + "print(bookings[\"user123\"])" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "venv", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/docs/module_guides/deploying/agents/usage_pattern.md b/docs/docs/module_guides/deploying/agents/usage_pattern.md index d75209479ed0bd75ee8e626f43f9d712d38e0d45..e8ed4c4b1a594660d71a995437817e72221c6c3c 100644 --- a/docs/docs/module_guides/deploying/agents/usage_pattern.md +++ b/docs/docs/module_guides/deploying/agents/usage_pattern.md @@ -64,6 +64,7 @@ query_engine_tools = [ description="Provides information about Lyft financials for year 2021. " "Use a detailed plain text question as input to the tool.", ), + return_direct=False, ), QueryEngineTool( query_engine=uber_engine, @@ -72,6 +73,7 @@ query_engine_tools = [ description="Provides information about Uber financials for year 2021. " "Use a detailed plain text question as input to the tool.", ), + return_direct=False, ), ] @@ -106,6 +108,31 @@ query_engine_tools = [ outer_agent = ReActAgent.from_tools(query_engine_tools, llm=llm, verbose=True) ``` +### Return Direct + +You'll notice the option `return_direct` in the tool class constructor. If this is set to `True`, the response from the query engine is returned directly, without being interpreted and rewritten by the agent. This can be helpful for decreasing runtime, or designing/specifying tools that will end the agent reasoning loop. + +For example, say you specify a tool: + +```python +tool = QueryEngineTool.from_defaults( + query_engine, + name="<name>", + description="<description>", + return_direct=True, +) + +agent = OpenAIAgent.from_tools([tool]) + +response = agent.chat("<question that invokes tool>") +``` + +In the above example, the query engine tool would be invoked, and the response from that tool would be directly returned as the response, and the execution loop would end. + +If `return_direct=False` was used, then the agent would rewrite the response using the context of the chat history, or even make another tool call. + +We have also provided an [example notebook](../../../examples/agent/return_direct_agent.ipynb) of using `return_direct`. + ## Lower-Level API The OpenAIAgent and ReActAgent are simple wrappers on top of an `AgentRunner` interacting with an `AgentWorker`. 
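To make the `return_direct` option documented above concrete, here is a minimal sketch of the pattern from the new notebook and the "Return Direct" docs section. The tool, its function, and the question are made up for illustration, and the `OpenAIAgent` import path assumes the `llama-index-agent-openai` package touched by this PR; any function-calling agent from the examples above should behave the same way.

```python
from llama_index.core.tools import FunctionTool
from llama_index.agent.openai import OpenAIAgent


def lookup_order_status(order_id: str) -> str:
    """Look up the shipping status for an order ID (toy example)."""
    return f"Order {order_id} is out for delivery."


# return_direct=True: if this is the only tool called in a step,
# its raw output ends the reasoning loop and becomes the response.
status_tool = FunctionTool.from_defaults(
    fn=lookup_order_status, return_direct=True
)

agent = OpenAIAgent.from_tools([status_tool], verbose=True)
response = agent.chat("Where is order 42?")
print(response)  # the tool output verbatim, with no final LLM rewrite
```

Because exactly one tool call was made and that tool has `return_direct=True`, the agent skips the follow-up LLM call and returns the tool output as the final response; with `return_direct=False` the agent would rewrite it or call further tools.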
diff --git a/docs/mkdocs.yml b/docs/mkdocs.yml index 1e45a1b6b14a7400b975229ca3c02459637f48d4..770470736adcb0c06bfee60cee1086d20b899301 100644 --- a/docs/mkdocs.yml +++ b/docs/mkdocs.yml @@ -79,6 +79,7 @@ nav: - ./examples/agent/agent_runner/agent_around_query_pipeline_with_HyDE_for_PDFs.ipynb - ./examples/agent/mistral_agent.ipynb - ./examples/agent/openai_agent_tool_call_parser.ipynb + - ./examples/agent/return_direct_agent.ipynb - ./examples/agent/anthropic_agent.ipynb - Callbacks: - ./examples/callbacks/HoneyHiveLlamaIndexTracer.ipynb diff --git a/llama-index-core/llama_index/core/agent/function_calling/step.py b/llama-index-core/llama_index/core/agent/function_calling/step.py index f1122da240bd3e5b0e9871eee054a83698fd2a96..5e7c16805a7fd034d93ea6b9a1282fd5c78bf60b 100644 --- a/llama-index-core/llama_index/core/agent/function_calling/step.py +++ b/llama-index-core/llama_index/core/agent/function_calling/step.py @@ -170,14 +170,14 @@ class FunctionCallingAgentWorker(BaseAgentWorker): memory: BaseMemory, sources: List[ToolOutput], verbose: bool = False, - ) -> None: + ) -> bool: + tool = get_function_by_name(tools, tool_call.tool_name) + with self.callback_manager.event( CBEventType.FUNCTION_CALL, payload={ EventPayload.FUNCTION_CALL: json.dumps(tool_call.tool_kwargs), - EventPayload.TOOL: get_function_by_name( - tools, tool_call.tool_name - ).metadata, + EventPayload.TOOL: tool.metadata, }, ) as event: tool_output = call_tool_with_selection(tool_call, tools, verbose=verbose) @@ -194,6 +194,8 @@ class FunctionCallingAgentWorker(BaseAgentWorker): sources.append(tool_output) memory.put(function_message) + return tool.metadata.return_direct + async def _acall_function( self, tools: List[BaseTool], @@ -201,14 +203,14 @@ class FunctionCallingAgentWorker(BaseAgentWorker): memory: BaseMemory, sources: List[ToolOutput], verbose: bool = False, - ) -> None: + ) -> bool: + tool = get_function_by_name(tools, tool_call.tool_name) + with self.callback_manager.event( CBEventType.FUNCTION_CALL, payload={ EventPayload.FUNCTION_CALL: json.dumps(tool_call.tool_kwargs), - EventPayload.TOOL: get_function_by_name( - tools, tool_call.tool_name - ).metadata, + EventPayload.TOOL: tool.metadata, }, ) as event: tool_output = await acall_tool_with_selection( @@ -227,6 +229,8 @@ class FunctionCallingAgentWorker(BaseAgentWorker): sources.append(tool_output) memory.put(function_message) + return tool.metadata.return_direct + @trace_method("run_step") def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput: """Run step.""" @@ -248,6 +252,11 @@ class FunctionCallingAgentWorker(BaseAgentWorker): tool_calls = self._llm.get_tool_calls_from_response( response, error_on_no_tool_call=False ) + + if self._verbose and response.message.content: + print("=== LLM Response ===") + print(str(response.message.content)) + if not self.allow_parallel_tool_calls and len(tool_calls) > 1: raise ValueError( "Parallel tool calls not supported for synchronous function calling agent" @@ -264,24 +273,37 @@ class FunctionCallingAgentWorker(BaseAgentWorker): new_steps = [] else: is_done = False - for tool_call in tool_calls: + for i, tool_call in enumerate(tool_calls): # TODO: maybe execute this with multi-threading - self._call_function( + return_direct = self._call_function( tools, tool_call, task.extra_state["new_memory"], task.extra_state["sources"], verbose=self._verbose, ) + task.extra_state["n_function_calls"] += 1 + + # check if any of the tools return directly -- only works if there is one tool call + if i == 0 and 
return_direct: + is_done = True + response = task.extra_state["sources"][-1].content + break + # put tool output in sources and memory - new_steps = [ - step.get_next_step( - step_id=str(uuid.uuid4()), - # NOTE: input is unused - input=None, - ) - ] + new_steps = ( + [ + step.get_next_step( + step_id=str(uuid.uuid4()), + # NOTE: input is unused + input=None, + ) + ] + if not is_done + else [] + ) + agent_response = AgentChatResponse( response=str(response), sources=task.extra_state["sources"] ) @@ -316,6 +338,11 @@ class FunctionCallingAgentWorker(BaseAgentWorker): tool_calls = self._llm.get_tool_calls_from_response( response, error_on_no_tool_call=False ) + + if self._verbose and response.message.content: + print("=== LLM Response ===") + print(str(response.message.content)) + if not self.allow_parallel_tool_calls and len(tool_calls) > 1: raise ValueError( "Parallel tool calls not supported for synchronous function calling agent" @@ -342,16 +369,27 @@ class FunctionCallingAgentWorker(BaseAgentWorker): ) for tool_call in tool_calls ] - await asyncio.gather(*tasks) + return_directs = await asyncio.gather(*tasks) + + # check if any of the tools return directly -- only works if there is one tool call + if len(return_directs) == 1 and return_directs[0]: + is_done = True + response = task.extra_state["sources"][-1].content + task.extra_state["n_function_calls"] += len(tool_calls) # put tool output in sources and memory - new_steps = [ - step.get_next_step( - step_id=str(uuid.uuid4()), - # NOTE: input is unused - input=None, - ) - ] + new_steps = ( + [ + step.get_next_step( + step_id=str(uuid.uuid4()), + # NOTE: input is unused + input=None, + ) + ] + if not is_done + else [] + ) + agent_response = AgentChatResponse( response=str(response), sources=task.extra_state["sources"] ) diff --git a/llama-index-core/llama_index/core/agent/react/step.py b/llama-index-core/llama_index/core/agent/react/step.py index 8ba9027d9be478fbb4c5baf798d5b69aad3061b8..5062f88418aa0df681b5cd8eab3e7f6de289d397 100644 --- a/llama-index-core/llama_index/core/agent/react/step.py +++ b/llama-index-core/llama_index/core/agent/react/step.py @@ -260,6 +260,8 @@ class ReActAgentWorker(BaseAgentWorker): tools_dict: Dict[str, AsyncBaseTool] = { tool.metadata.get_name(): tool for tool in tools } + tool = None + try: _, current_reasoning, is_done = self._extract_reasoning_step( output, is_streaming @@ -290,6 +292,7 @@ class ReActAgentWorker(BaseAgentWorker): tool_name=tool.metadata.name, raw_input={"kwargs": reasoning_step.action_input}, raw_output=e, + is_error=True, ) event.on_end( payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)} @@ -299,11 +302,19 @@ class ReActAgentWorker(BaseAgentWorker): task.extra_state["sources"].append(tool_output) - observation_step = ObservationReasoningStep(observation=str(tool_output)) + observation_step = ObservationReasoningStep( + observation=str(tool_output), + return_direct=tool.metadata.return_direct and not tool_output.is_error + if tool + else False, + ) current_reasoning.append(observation_step) if self._verbose: print_text(f"{observation_step.get_content()}\n", color="blue") - return current_reasoning, False + return ( + current_reasoning, + tool.metadata.return_direct and not tool_output.is_error if tool else False, + ) async def _aprocess_actions( self, @@ -313,6 +324,8 @@ class ReActAgentWorker(BaseAgentWorker): is_streaming: bool = False, ) -> Tuple[List[BaseReasoningStep], bool]: tools_dict = {tool.metadata.name: tool for tool in tools} + tool = None + try: _, current_reasoning, 
is_done = self._extract_reasoning_step( output, is_streaming @@ -343,6 +356,7 @@ class ReActAgentWorker(BaseAgentWorker): tool_name=tool.metadata.name, raw_input={"kwargs": reasoning_step.action_input}, raw_output=e, + is_error=True, ) event.on_end( payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)} @@ -352,11 +366,19 @@ class ReActAgentWorker(BaseAgentWorker): task.extra_state["sources"].append(tool_output) - observation_step = ObservationReasoningStep(observation=str(tool_output)) + observation_step = ObservationReasoningStep( + observation=str(tool_output), + return_direct=tool.metadata.return_direct and not tool_output.is_error + if tool + else False, + ) current_reasoning.append(observation_step) if self._verbose: print_text(f"{observation_step.get_content()}\n", color="blue") - return current_reasoning, False + return ( + current_reasoning, + tool.metadata.return_direct and not tool_output.is_error if tool else False, + ) def _handle_nonexistent_tool_name(self, reasoning_step): # We still emit a `tool_output` object to the task, so that the LLM can know @@ -374,6 +396,7 @@ class ReActAgentWorker(BaseAgentWorker): tool_name=reasoning_step.action, raw_input={"kwargs": reasoning_step.action_input}, raw_output=content, + is_error=True, ) event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)}) return tool_output @@ -392,6 +415,11 @@ class ReActAgentWorker(BaseAgentWorker): if isinstance(current_reasoning[-1], ResponseReasoningStep): response_step = cast(ResponseReasoningStep, current_reasoning[-1]) response_str = response_step.response + elif ( + isinstance(current_reasoning[-1], ObservationReasoningStep) + and current_reasoning[-1].return_direct + ): + response_str = current_reasoning[-1].observation else: response_str = current_reasoning[-1].get_content() @@ -501,7 +529,6 @@ class ReActAgentWorker(BaseAgentWorker): ) # TODO: see if we want to do step-based inputs tools = self.get_tools(task.input) - input_chat = self._react_chat_formatter.format( tools, chat_history=task.memory.get() + task.extra_state["new_memory"].get_all(), @@ -600,7 +627,7 @@ class ReActAgentWorker(BaseAgentWorker): if not is_done: # given react prompt outputs, call tools or return response - reasoning_steps, _ = self._process_actions( + reasoning_steps, is_done = self._process_actions( task, tools=tools, output=full_response, is_streaming=True ) task.extra_state["current_reasoning"].extend(reasoning_steps) @@ -608,6 +635,13 @@ class ReActAgentWorker(BaseAgentWorker): agent_response: AGENT_CHAT_RESPONSE_TYPE = self._get_response( task.extra_state["current_reasoning"], task.extra_state["sources"] ) + if is_done: + agent_response.is_dummy_stream = True + task.extra_state["new_memory"].put( + ChatMessage( + content=agent_response.response, role=MessageRole.ASSISTANT + ) + ) else: # Get the response in a separate thread so we can yield the response response_stream = self._add_back_chunk_to_stream( @@ -664,7 +698,7 @@ class ReActAgentWorker(BaseAgentWorker): if not is_done: # given react prompt outputs, call tools or return response - reasoning_steps, _ = self._process_actions( + reasoning_steps, is_done = self._process_actions( task, tools=tools, output=full_response, is_streaming=True ) task.extra_state["current_reasoning"].extend(reasoning_steps) @@ -672,6 +706,14 @@ class ReActAgentWorker(BaseAgentWorker): agent_response: AGENT_CHAT_RESPONSE_TYPE = self._get_response( task.extra_state["current_reasoning"], task.extra_state["sources"] ) + + if is_done: + agent_response.is_dummy_stream = True + 
task.extra_state["new_memory"].put( + ChatMessage( + content=agent_response.response, role=MessageRole.ASSISTANT + ) + ) else: # Get the response in a separate thread so we can yield the response response_stream = self._async_add_back_chunk_to_stream( diff --git a/llama-index-core/llama_index/core/agent/react/types.py b/llama-index-core/llama_index/core/agent/react/types.py index 51d85bd26a2557f9be3a3bc30937d7708e6758bd..4ee1d8b00967d621cdf352a4a7d7237a06bccec7 100644 --- a/llama-index-core/llama_index/core/agent/react/types.py +++ b/llama-index-core/llama_index/core/agent/react/types.py @@ -43,6 +43,7 @@ class ObservationReasoningStep(BaseReasoningStep): """Observation reasoning step.""" observation: str + return_direct: bool = False def get_content(self) -> str: """Get content.""" @@ -51,7 +52,7 @@ class ObservationReasoningStep(BaseReasoningStep): @property def is_done(self) -> bool: """Is the reasoning step the last one.""" - return False + return self.return_direct class ResponseReasoningStep(BaseReasoningStep): diff --git a/llama-index-core/llama_index/core/agent/react_multimodal/step.py b/llama-index-core/llama_index/core/agent/react_multimodal/step.py index c09cc78e78def71db5313b8cdd03e816dabdc31c..20fc093d6d745a552aa5a7d8194cb6fe7dc9a200 100644 --- a/llama-index-core/llama_index/core/agent/react_multimodal/step.py +++ b/llama-index-core/llama_index/core/agent/react_multimodal/step.py @@ -294,11 +294,13 @@ class MultimodalReActAgentWorker(BaseAgentWorker): task.extra_state["sources"].append(tool_output) - observation_step = ObservationReasoningStep(observation=str(tool_output)) + observation_step = ObservationReasoningStep( + observation=str(tool_output), return_direct=tool.metadata.return_direct + ) current_reasoning.append(observation_step) if self._verbose: print_text(f"{observation_step.get_content()}\n", color="blue") - return current_reasoning, False + return current_reasoning, tool.metadata.return_direct async def _aprocess_actions( self, @@ -330,11 +332,13 @@ class MultimodalReActAgentWorker(BaseAgentWorker): task.extra_state["sources"].append(tool_output) - observation_step = ObservationReasoningStep(observation=str(tool_output)) + observation_step = ObservationReasoningStep( + observation=str(tool_output), return_direct=tool.metadata.return_direct + ) current_reasoning.append(observation_step) if self._verbose: print_text(f"{observation_step.get_content()}\n", color="blue") - return current_reasoning, False + return current_reasoning, tool.metadata.return_direct def _get_response( self, @@ -350,6 +354,11 @@ class MultimodalReActAgentWorker(BaseAgentWorker): if isinstance(current_reasoning[-1], ResponseReasoningStep): response_step = cast(ResponseReasoningStep, current_reasoning[-1]) response_str = response_step.response + elif ( + isinstance(current_reasoning[-1], ObservationReasoningStep) + and current_reasoning[-1].return_direct + ): + response_str = current_reasoning[-1].observation else: response_str = current_reasoning[-1].get_content() diff --git a/llama-index-core/llama_index/core/agent/runner/base.py b/llama-index-core/llama_index/core/agent/runner/base.py index 2f17a796d63b8799c17270577770465fe7e69333..81a21b144ea4748d6087a6f7ec3c55f41b9170fe 100644 --- a/llama-index-core/llama_index/core/agent/runner/base.py +++ b/llama-index-core/llama_index/core/agent/runner/base.py @@ -667,7 +667,10 @@ class AgentRunner(BaseAgentRunner): chat_response = self._chat( message, chat_history, tool_choice, mode=ChatResponseMode.STREAM ) - assert isinstance(chat_response, 
StreamingAgentChatResponse) + assert isinstance(chat_response, StreamingAgentChatResponse) or ( + isinstance(chat_response, AgentChatResponse) + and chat_response.is_dummy_stream + ) e.on_end(payload={EventPayload.RESPONSE: chat_response}) return chat_response @@ -689,7 +692,10 @@ class AgentRunner(BaseAgentRunner): chat_response = await self._achat( message, chat_history, tool_choice, mode=ChatResponseMode.STREAM ) - assert isinstance(chat_response, StreamingAgentChatResponse) + assert isinstance(chat_response, StreamingAgentChatResponse) or ( + isinstance(chat_response, AgentChatResponse) + and chat_response.is_dummy_stream + ) e.on_end(payload={EventPayload.RESPONSE: chat_response}) return chat_response diff --git a/llama-index-core/llama_index/core/chat_engine/types.py b/llama-index-core/llama_index/core/chat_engine/types.py index 72753cb92e29646f1f8f7c8e582a4cdb8e2e94a8..f84a156627db9e4583b21a617cf21714d4098083 100644 --- a/llama-index-core/llama_index/core/chat_engine/types.py +++ b/llama-index-core/llama_index/core/chat_engine/types.py @@ -50,6 +50,7 @@ class AgentChatResponse: response: str = "" sources: List[ToolOutput] = field(default_factory=list) source_nodes: List[NodeWithScore] = field(default_factory=list) + is_dummy_stream: bool = False def __post_init__(self) -> None: if self.sources and not self.source_nodes: @@ -60,6 +61,31 @@ class AgentChatResponse: def __str__(self) -> str: return self.response + @property + def response_gen(self) -> Generator[str, None, None]: + """Used for fake streaming, i.e. with tool outputs.""" + if not self.is_dummy_stream: + raise ValueError( + "response_gen is only available for streaming responses. " + "Set is_dummy_stream=True if you still want a generator." + ) + + for token in self.response.split(" "): + yield token + " " + time.sleep(0.1) + + async def async_response_gen(self) -> AsyncGenerator[str, None]: + """Used for fake streaming, i.e. with tool outputs.""" + if not self.is_dummy_stream: + raise ValueError( + "response_gen is only available for streaming responses. " + "Set is_dummy_stream=True if you still want a generator." 
+ ) + + for token in self.response.split(" "): + yield token + " " + await asyncio.sleep(0.1) + @dataclass class StreamingAgentChatResponse: diff --git a/llama-index-core/llama_index/core/tools/calling.py b/llama-index-core/llama_index/core/tools/calling.py index 76e9b89ced0780aab04ab3ff09a0749997d66da1..c442289177a70ab708ac6acdd1920c909775afc5 100644 --- a/llama-index-core/llama_index/core/tools/calling.py +++ b/llama-index-core/llama_index/core/tools/calling.py @@ -24,6 +24,7 @@ def call_tool(tool: BaseTool, arguments: dict) -> ToolOutput: tool_name=tool.metadata.name, raw_input=arguments, raw_output=str(e), + is_error=True, ) @@ -45,6 +46,7 @@ async def acall_tool(tool: BaseTool, arguments: dict) -> ToolOutput: tool_name=tool.metadata.name, raw_input=arguments, raw_output=str(e), + is_error=True, ) @@ -62,7 +64,13 @@ def call_tool_with_selection( print("=== Calling Function ===") print(f"Calling function: {name} with args: {arguments_str}") tool = tools_by_name[name] - return call_tool(tool, tool_call.tool_kwargs) + output = call_tool(tool, tool_call.tool_kwargs) + + if verbose: + print("=== Function Output ===") + print(output.content) + + return output async def acall_tool_with_selection( @@ -79,4 +87,10 @@ async def acall_tool_with_selection( print("=== Calling Function ===") print(f"Calling function: {name} with args: {arguments_str}") tool = tools_by_name[name] - return await acall_tool(tool, tool_call.tool_kwargs) + output = await acall_tool(tool, tool_call.tool_kwargs) + + if verbose: + print("=== Function Output ===") + print(output.content) + + return output diff --git a/llama-index-core/llama_index/core/tools/function_tool.py b/llama-index-core/llama_index/core/tools/function_tool.py index 736e71da13e93c20bed2d4bf8c588d601f1119bf..3f2800d585f3d239056f16037c6e075478068252 100644 --- a/llama-index-core/llama_index/core/tools/function_tool.py +++ b/llama-index-core/llama_index/core/tools/function_tool.py @@ -47,6 +47,7 @@ class FunctionTool(AsyncBaseTool): fn: Callable[..., Any], name: Optional[str] = None, description: Optional[str] = None, + return_direct: bool = False, fn_schema: Optional[Type[BaseModel]] = None, async_fn: Optional[AsyncCallable] = None, tool_metadata: Optional[ToolMetadata] = None, @@ -60,7 +61,10 @@ class FunctionTool(AsyncBaseTool): f"{name}", fn, additional_fields=None ) tool_metadata = ToolMetadata( - name=name, description=description, fn_schema=fn_schema + name=name, + description=description, + fn_schema=fn_schema, + return_direct=return_direct, ) return cls(fn=fn, metadata=tool_metadata, async_fn=async_fn) diff --git a/llama-index-core/llama_index/core/tools/ondemand_loader_tool.py b/llama-index-core/llama_index/core/tools/ondemand_loader_tool.py index 0a9ddf375527b5aa44ed1c16bb46059a6ad38156..8ded88e9365a85a86c6cc469e940d48e0ae017a6 100644 --- a/llama-index-core/llama_index/core/tools/ondemand_loader_tool.py +++ b/llama-index-core/llama_index/core/tools/ondemand_loader_tool.py @@ -93,6 +93,7 @@ class OnDemandLoaderTool(AsyncBaseTool): query_str_kwargs_key: str = "query_str", name: Optional[str] = None, description: Optional[str] = None, + return_direct: bool = False, fn_schema: Optional[Type[BaseModel]] = None, ) -> "OnDemandLoaderTool": """From defaults.""" @@ -106,7 +107,12 @@ class OnDemandLoaderTool(AsyncBaseTool): fn_schema = create_schema_from_function( name or "LoadData", tool._fn, [(query_str_kwargs_key, str, None)] ) - metadata = ToolMetadata(name=name, description=description, fn_schema=fn_schema) + metadata = ToolMetadata( + name=name, + 
description=description, + fn_schema=fn_schema, + return_direct=return_direct, + ) return cls( loader=tool._fn, index_cls=index_cls, diff --git a/llama-index-core/llama_index/core/tools/query_engine.py b/llama-index-core/llama_index/core/tools/query_engine.py index f5712b407106b31370e7e383db954fa5e7dcf3e2..13b69335985297c19f9b9c3126e50e5ea1ef6d2b 100644 --- a/llama-index-core/llama_index/core/tools/query_engine.py +++ b/llama-index-core/llama_index/core/tools/query_engine.py @@ -40,12 +40,15 @@ class QueryEngineTool(AsyncBaseTool): query_engine: BaseQueryEngine, name: Optional[str] = None, description: Optional[str] = None, + return_direct: bool = False, resolve_input_errors: bool = True, ) -> "QueryEngineTool": name = name or DEFAULT_NAME description = description or DEFAULT_DESCRIPTION - metadata = ToolMetadata(name=name, description=description) + metadata = ToolMetadata( + name=name, description=description, return_direct=return_direct + ) return cls( query_engine=query_engine, metadata=metadata, diff --git a/llama-index-core/llama_index/core/tools/types.py b/llama-index-core/llama_index/core/tools/types.py index bec2e752a2157d4e472644b91f476cfd852dacc1..66207ae44f3fd648216c76a7d95c9dc3572ec50e 100644 --- a/llama-index-core/llama_index/core/tools/types.py +++ b/llama-index-core/llama_index/core/tools/types.py @@ -20,6 +20,7 @@ class ToolMetadata: description: str name: Optional[str] = None fn_schema: Optional[Type[BaseModel]] = DefaultToolFnSchema + return_direct: bool = False def get_parameters_dict(self) -> dict: if self.fn_schema is None: @@ -86,6 +87,7 @@ class ToolOutput(BaseModel): tool_name: str raw_input: Dict[str, Any] raw_output: Any + is_error: bool = False def __str__(self) -> str: """String.""" diff --git a/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/step.py b/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/step.py index e84c5c52462a0628664e66ad72e8c3b673a8c711..f705d13789881dd361517b1ba90b9860d277b71d 100644 --- a/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/step.py +++ b/llama-index-integrations/agent/llama-index-agent-openai/llama_index/agent/openai/step.py @@ -162,6 +162,7 @@ def call_function( tool_name=name, raw_input={"args": arguments_str}, raw_output=error_message, + is_error=True, ), ) @@ -228,6 +229,7 @@ async def acall_function( tool_name=name, raw_input={"args": arguments_str}, raw_output=error_message, + is_error=True, ), ) @@ -445,20 +447,20 @@ class OpenAIAgentWorker(BaseAgentWorker): tool_call: OpenAIToolCall, memory: BaseMemory, sources: List[ToolOutput], - ) -> None: + ) -> bool: function_call = tool_call.function # validations to get passed mypy assert function_call is not None assert function_call.name is not None assert function_call.arguments is not None + tool = get_function_by_name(tools, function_call.name) + with self.callback_manager.event( CBEventType.FUNCTION_CALL, payload={ EventPayload.FUNCTION_CALL: function_call.arguments, - EventPayload.TOOL: get_function_by_name( - tools, function_call.name - ).metadata, + EventPayload.TOOL: tool.metadata, }, ) as event: function_message, tool_output = call_function( @@ -471,26 +473,28 @@ class OpenAIAgentWorker(BaseAgentWorker): sources.append(tool_output) memory.put(function_message) + return tool.metadata.return_direct and not tool_output.is_error + async def _acall_function( self, tools: List[BaseTool], tool_call: OpenAIToolCall, memory: BaseMemory, sources: List[ToolOutput], - ) -> None: + ) -> 
bool: function_call = tool_call.function # validations to get passed mypy assert function_call is not None assert function_call.name is not None assert function_call.arguments is not None + tool = get_function_by_name(tools, function_call.name) + with self.callback_manager.event( CBEventType.FUNCTION_CALL, payload={ EventPayload.FUNCTION_CALL: function_call.arguments, - EventPayload.TOOL: get_function_by_name( - tools, function_call.name - ).metadata, + EventPayload.TOOL: tool.metadata, }, ) as event: function_message, tool_output = await acall_function( @@ -503,6 +507,8 @@ class OpenAIAgentWorker(BaseAgentWorker): sources.append(tool_output) memory.put(function_message) + return tool.metadata.return_direct and not tool_output.is_error + def initialize_step(self, task: Task, **kwargs: Any) -> TaskStep: """Initialize step from task.""" sources: List[ToolOutput] = [] @@ -574,7 +580,7 @@ class OpenAIAgentWorker(BaseAgentWorker): if tool_call.type != "function": raise ValueError("Invalid tool type. Unsupported by OpenAI") # TODO: maybe execute this with multi-threading - self._call_function( + return_direct = self._call_function( tools, tool_call, task.extra_state["new_memory"], @@ -585,13 +591,32 @@ class OpenAIAgentWorker(BaseAgentWorker): if tool_choice not in ("auto", "none"): tool_choice = "auto" task.extra_state["n_function_calls"] += 1 - new_steps = [ - step.get_next_step( - step_id=str(uuid.uuid4()), - # NOTE: input is unused - input=None, - ) - ] + + if return_direct and len(latest_tool_calls) == 1: + is_done = True + response_str = task.extra_state["sources"][-1].content + chat_response = ChatResponse( + message=ChatMessage( + role=MessageRole.ASSISTANT, content=response_str + ) + ) + agent_chat_response = self._process_message(task, chat_response) + agent_chat_response.is_dummy_stream = ( + mode == ChatResponseMode.STREAM + ) + break + + new_steps = ( + [ + step.get_next_step( + step_id=str(uuid.uuid4()), + # NOTE: input is unused + input=None, + ) + ] + if not is_done + else [] + ) # attach next step to task @@ -641,7 +666,7 @@ class OpenAIAgentWorker(BaseAgentWorker): if tool_call.type != "function": raise ValueError("Invalid tool type. Unsupported by OpenAI") # TODO: maybe execute this with multi-threading - await self._acall_function( + return_direct = await self._acall_function( tools, tool_call, task.extra_state["new_memory"], @@ -653,6 +678,20 @@ class OpenAIAgentWorker(BaseAgentWorker): tool_choice = "auto" task.extra_state["n_function_calls"] += 1 + if return_direct and len(latest_tool_calls) == 1: + is_done = True + response_str = task.extra_state["sources"][-1].content + chat_response = ChatResponse( + message=ChatMessage( + role=MessageRole.ASSISTANT, content=response_str + ) + ) + agent_chat_response = self._process_message(task, chat_response) + agent_chat_response.is_dummy_stream = ( + mode == ChatResponseMode.STREAM + ) + break + # generate next step, append to task queue new_steps = ( [
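The `is_dummy_stream` and `response_gen` additions to `AgentChatResponse` above also cover the streaming path: when a `return_direct` tool ends the loop during `stream_chat`, the worker returns an `AgentChatResponse` flagged as a dummy stream instead of a `StreamingAgentChatResponse`, and `response_gen` fake-streams the tool output token by token. A rough usage sketch, reusing the hypothetical `agent` and tool from the earlier example:

```python
# stream_chat() normally yields a StreamingAgentChatResponse; with a
# return_direct tool it may yield an AgentChatResponse whose
# is_dummy_stream=True, so iterating response_gen works in both cases.
response = agent.stream_chat("Where is order 42?")
for token in response.response_gen:
    print(token, end="", flush=True)
```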