diff --git a/docs/examples/agent/custom_agent.ipynb b/docs/examples/agent/custom_agent.ipynb index bceaf60b9cc24cb0988952e0e3b123de838762b5..14ee8ab7798b27c7b132d5c5375d150a358575ed 100644 --- a/docs/examples/agent/custom_agent.ipynb +++ b/docs/examples/agent/custom_agent.ipynb @@ -64,7 +64,7 @@ "outputs": [], "source": [ "from llama_index.agent import CustomSimpleAgentWorker, Task, AgentChatResponse\n", - "from typing import Dict, Any, List, Tuple\n", + "from typing import Dict, Any, List, Tuple, Optional\n", "from llama_index.tools import BaseTool, QueryEngineTool\n", "from llama_index.program import LLMTextCompletionProgram\n", "from llama_index.output_parsers import PydanticOutputParser\n", @@ -185,7 +185,7 @@ " return {\"count\": 0, \"current_reasoning\": []}\n", "\n", " def _run_step(\n", - " self, state: Dict[str, Any], task: Task\n", + " self, state: Dict[str, Any], task: Task, input: Optional[str] = None\n", " ) -> Tuple[AgentChatResponse, bool]:\n", " \"\"\"Run step.\n", "\n", @@ -193,11 +193,17 @@ " Tuple of (agent_response, is_done)\n", "\n", " \"\"\"\n", - " if \"new_input\" not in state:\n", + " if input is not None:\n", + " # if input is specified, override input\n", + " new_input = input\n", + " elif \"new_input\" not in state:\n", " new_input = task.input\n", " else:\n", " new_input = state[\"new_input\"]\n", "\n", + " if self.verbose:\n", + " print(f\"> Current Input: {new_input}\")\n", + "\n", " # first run router query engine\n", " response = self._router_query_engine.query(new_input)\n", "\n", @@ -441,7 +447,9 @@ "id": "c0d6798c-a33c-455a-b4e4-fb0bb4208435", "metadata": {}, "source": [ - "## Try Out Some Queries" + "## Try Out Some Queries\n", + "\n", + "Let's run some e2e queries through `agent.chat`." ] }, { @@ -454,18 +462,14 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[1;3;38;5;200mSelecting query engine 0: The choice is about translating a natural language query into a SQL query over a table containing city_stats, which likely includes information about the country of each city..\n", - "\u001b[0m> Table desc str: Table 'city_stats' has columns: city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16)), and foreign keys: .\n", - "> Predicted SQL query: SELECT city_name, country FROM city_stats\n", - "> Question: Which countries are each city from?\n", + "> Current Input: Which countries are each city from?\n", + "\u001b[1;3;38;5;200mSelecting query engine 0: The first choice is the most relevant because it mentions a table containing city_stats, which likely includes information about the country of each city..\n", + "\u001b[0m> Question: Which countries are each city from?\n", "> Response: The city of Toronto is from Canada, Tokyo is from Japan, and Berlin is from Germany.\n", - "> Response eval: {'has_error': True, 'new_question': 'Which country is each of the following cities from: Toronto, Tokyo, Berlin?', 'explanation': 'The original question was too vague as it did not specify which cities the question was referring to. The new question provides specific cities for which the country of origin is being asked.'}\n", + "> Response eval: {'has_error': True, 'new_question': 'Which country is each of the following cities from: Toronto, Tokyo, Berlin?', 'explanation': 'The original question was too vague and did not specify which cities the user was interested in. The new question provides specific cities for the system to provide information on.'}\n", + "> Current Input: Which country is each of the following cities from: Toronto, Tokyo, Berlin?\n", "\u001b[1;3;38;5;200mSelecting query engine 0: This choice is relevant because it mentions a table containing city_stats, which likely includes information about the country of each city..\n", - "\u001b[0m> Table desc str: Table 'city_stats' has columns: city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16)), and foreign keys: .\n", - "> Predicted SQL query: SELECT city_name, country\n", - "FROM city_stats\n", - "WHERE city_name IN ('Toronto', 'Tokyo', 'Berlin')\n", - "> Question: Which country is each of the following cities from: Toronto, Tokyo, Berlin?\n", + "\u001b[0m> Question: Which country is each of the following cities from: Toronto, Tokyo, Berlin?\n", "> Response: Toronto is from Canada, Tokyo is from Japan, and Berlin is from Germany.\n", "> Response eval: {'has_error': False, 'new_question': '', 'explanation': ''}\n", "Toronto is from Canada, Tokyo is from Japan, and Berlin is from Germany.\n" @@ -487,37 +491,27 @@ "name": "stdout", "output_type": "stream", "text": [ - "\u001b[1;3;38;5;200mSelecting query engine 0: The question is asking about the top modes of transportation for the city with the highest population. Choice (1) is the most relevant because it mentions a table containing city_stats, which likely includes information about the population of each city..\n", - "\u001b[0m> Table desc str: Table 'city_stats' has columns: city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16)), and foreign keys: .\n", - "> Predicted SQL query: SELECT city_name, population, mode_of_transportation\n", - "FROM city_stats\n", - "WHERE population = (SELECT MAX(population) FROM city_stats)\n", - "ORDER BY mode_of_transportation ASC\n", - "LIMIT 5;\n", - "> Question: What are the top modes of transporation fo the city with the higehest population?\n", - "> Response: I'm sorry, but there was an error in retrieving the information. Please try again later.\n", - "> Response eval: {'has_error': True, 'new_question': 'What are the top modes of transportation for the city with the highest population?', 'explanation': 'The original question had spelling errors which might have caused the system to not understand the question correctly. The corrected question should now be clear and understandable for the system.'}\n", - "\u001b[1;3;38;5;200mSelecting query engine 0: The first choice is the most relevant because it mentions translating a natural language query into a SQL query over a table containing city_stats, which likely includes information about the population of each city..\n", - "\u001b[0m> Table desc str: Table 'city_stats' has columns: city_name (VARCHAR(16)), population (INTEGER), country (VARCHAR(16)), and foreign keys: .\n", - "> Predicted SQL query: SELECT city_name, population, country\n", - "FROM city_stats\n", - "WHERE population = (SELECT MAX(population) FROM city_stats)\n", - "> Question: What are the top modes of transportation for the city with the highest population?\n", - "> Response: The city with the highest population is Tokyo, Japan with a population of 13,960,000.\n", - "> Response eval: {'has_error': True, 'new_question': 'What are the top modes of transportation for Tokyo, Japan?', 'explanation': 'The assistant failed to answer the original question correctly. The response was about the city with the highest population, but it did not mention anything about the top modes of transportation in that city. The new question directly asks about the top modes of transportation in Tokyo, Japan, which is the city with the highest population.'}\n", - "\u001b[1;3;38;5;200mSelecting query engine 3: The question specifically asks about Tokyo, and choice (4) is about answering semantic questions about Tokyo..\n", - "\u001b[0m> Question: What are the top modes of transportation for Tokyo, Japan?\n", - "> Response: The top modes of transportation for Tokyo, Japan are trains and subways, which are considered clean and efficient. Tokyo has an extensive network of electric train lines and over 900 train stations. Buses, monorails, and trams also play a secondary role in the public transportation system. Additionally, expressways connect Tokyo to other points in the Greater Tokyo Area and beyond. Taxis and long-distance ferries are also available for transportation within the city and to the surrounding islands.\n", - "> Response eval: {'has_error': True, 'new_question': 'What are the top modes of transportation for Tokyo, Japan?', 'explanation': 'The original question was not answered correctly because the assistant did not provide information on the top modes of transportation for the city with the highest population. The new question directly asks for the top modes of transportation for Tokyo, Japan, which is the city with the highest population.'}\n", - "\u001b[1;3;38;5;200mSelecting query engine 3: Tokyo is mentioned in choice 4.\n", - "\u001b[0m> Question: What are the top modes of transportation for Tokyo, Japan?\n", - "> Response: The top modes of transportation for Tokyo, Japan are trains and subways, which are considered clean and efficient. Tokyo has an extensive network of electric train lines and over 900 train stations. Buses, monorails, and trams also play a secondary role in public transportation within the city. Additionally, Tokyo has two major airports, Narita International Airport and Haneda Airport, which offer domestic and international flights. Expressways and taxis are also available for transportation within the city.\n", - "> Response eval: {'has_error': True, 'new_question': 'What are the top modes of transportation for Tokyo, Japan?', 'explanation': 'The response is erroneous because it does not answer the question asked. The question asks for the top modes of transportation in the city with the highest population, but the response only provides the population of the city. The new question directly asks for the top modes of transportation in Tokyo, Japan, which is the city with the highest population.'}\n", + "> Current Input: What are the top modes of transporation fo the city with the higehest population?\n", + "\u001b[1;3;38;5;200mSelecting query engine 0: The question is asking about the top modes of transportation for the city with the highest population, which requires translating a natural language query into a SQL query over a table containing city statistics..\n", + "\u001b[0m> Question: What are the top modes of transporation fo the city with the higehest population?\n", + "> Response: I'm sorry, but there seems to be an error in the SQL query. Please check the syntax and try again.\n", + "> Response eval: {'has_error': True, 'new_question': 'What are the top modes of transportation for the city with the highest population?', 'explanation': 'The original question had spelling errors which might have caused confusion. The corrected question now clearly asks for the top modes of transportation in the city with the highest population.'}\n", + "> Current Input: What are the top modes of transportation for the city with the highest population?\n", + "\u001b[1;3;38;5;200mSelecting query engine 0: This choice is relevant because it mentions translating a natural language query into a SQL query over a table containing city statistics, which could include information about the population of cities..\n", + "\u001b[0m> Question: What are the top modes of transportation for the city with the highest population?\n", + "> Response: The city with the highest population is Tokyo, Japan.\n", + "> Response eval: {'has_error': True, 'new_question': 'What are the top modes of transportation in Tokyo, Japan?', 'explanation': 'The assistant did not answer the original question correctly. The question asked for the top modes of transportation in the city with the highest population, but the assistant only provided the city with the highest population. The new question directly asks for the top modes of transportation in Tokyo, Japan, which is the city with the highest population.'}\n", + "> Current Input: What are the top modes of transportation in Tokyo, Japan?\n", + "\u001b[1;3;38;5;200mSelecting query engine 3: Tokyo is mentioned in choice 4, which is specifically about answering semantic questions about Tokyo..\n", + "\u001b[0m> Question: What are the top modes of transportation in Tokyo, Japan?\n", + "> Response: The top modes of transportation in Tokyo, Japan are trains and subways, which are considered clean and efficient. There are also buses, monorails, and trams that play a secondary role in the transportation system. Additionally, there are expressways that connect Tokyo to other points in the Greater Tokyo Area and beyond. Taxis and long-distance ferries are also available for transportation within the city and to other islands.\n", + "> Response eval: {'has_error': True, 'new_question': 'What are the top modes of transportation in Tokyo, Japan?', 'explanation': 'The original question was not answered correctly. The assistant provided information about the city with the highest population, but did not answer the question about the top modes of transportation in that city. The new question directly asks about the modes of transportation in Tokyo, Japan, which is the city with the highest population.'}\n", + "> Current Input: What are the top modes of transportation in Tokyo, Japan?\n", "\u001b[1;3;38;5;200mSelecting query engine 3: The question specifically asks about Tokyo, and choice 4 is about answering semantic questions about Tokyo..\n", - "\u001b[0m> Question: What are the top modes of transportation for Tokyo, Japan?\n", - "> Response: The top modes of transportation for Tokyo, Japan are trains and subways, which are considered clean and efficient. Tokyo has an extensive network of electric train lines and over 900 train stations. Buses, monorails, and trams also play a secondary role in public transportation within the city. Additionally, Tokyo has two major airports, Narita International Airport and Haneda Airport, which offer domestic and international flights. Expressways and taxis are also available for transportation within the city.\n", + "\u001b[0m> Question: What are the top modes of transportation in Tokyo, Japan?\n", + "> Response: The top modes of transportation in Tokyo, Japan are trains and subways. Tokyo has an extensive network of clean and efficient trains and subways operated by various operators. There are up to 62 electric train lines and more than 900 train stations in Tokyo. Buses, monorails, and trams also play a secondary role in public transportation within the city. Additionally, Tokyo has expressways, taxis, and long-distance ferries as other means of transportation.\n", "> Response eval: {'has_error': False, 'new_question': '', 'explanation': ''}\n", - "The top modes of transportation for Tokyo, Japan are trains and subways, which are considered clean and efficient. Tokyo has an extensive network of electric train lines and over 900 train stations. Buses, monorails, and trams also play a secondary role in public transportation within the city. Additionally, Tokyo has two major airports, Narita International Airport and Haneda Airport, which offer domestic and international flights. Expressways and taxis are also available for transportation within the city.\n" + "The top modes of transportation in Tokyo, Japan are trains and subways. Tokyo has an extensive network of clean and efficient trains and subways operated by various operators. There are up to 62 electric train lines and more than 900 train stations in Tokyo. Buses, monorails, and trams also play a secondary role in public transportation within the city. Additionally, Tokyo has expressways, taxis, and long-distance ferries as other means of transportation.\n" ] } ], @@ -531,46 +525,146 @@ { "cell_type": "code", "execution_count": null, - "id": "6b98fa9c-123d-4347-a696-81148d48bc4c", + "id": "926b79ba-1868-46ca-bb7e-2bd3c907773c", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "The top modes of transportation for Tokyo, Japan are trains and subways, which are considered clean and efficient. Tokyo has an extensive network of electric train lines and over 900 train stations. Buses, monorails, and trams also play a secondary role in public transportation within the city. Additionally, Tokyo has two major airports, Narita International Airport and Haneda Airport, which offer domestic and international flights. Expressways and taxis are also available for transportation within the city.\n" + "> Current Input: What are the sports teams of each city in Asia?\n", + "\u001b[1;3;38;5;200mSelecting query engine 3: The question is asking about sports teams in Asia, and Tokyo is a city in Asia..\n", + "\u001b[0m> Question: What are the sports teams of each city in Asia?\n", + "> Response: I'm sorry, but the context information does not provide a comprehensive list of sports teams in each city in Asia. It only mentions some of the sports teams in Tokyo, Japan. To answer your question, I would need more specific information or a different source that provides a broader overview of sports teams in cities across Asia.\n", + "> Response eval: {'has_error': True, 'new_question': 'What are some popular sports teams in Tokyo, Japan?', 'explanation': 'The original question is too broad and the system does not have enough information to provide a comprehensive list of sports teams in each city in Asia. The new question is more specific and focuses on sports teams in Tokyo, Japan, which the system has information on.'}\n", + "> Current Input: What are some popular sports teams in Tokyo, Japan?\n", + "\u001b[1;3;38;5;200mSelecting query engine 3: Tokyo is mentioned in choice 4, which is about answering semantic questions about Tokyo..\n", + "\u001b[0m> Question: What are some popular sports teams in Tokyo, Japan?\n", + "> Response: Some popular sports teams in Tokyo, Japan include the Yomiuri Giants and Tokyo Yakult Swallows in baseball, F.C. Tokyo and Tokyo Verdy 1969 in soccer, and the Hitachi SunRockers, Toyota Alvark Tokyo, and Tokyo Excellence in basketball. Tokyo is also known for its sumo wrestling tournaments held at the RyÅgoku Kokugikan sumo arena.\n", + "> Response eval: {'has_error': False, 'new_question': '', 'explanation': ''}\n", + "Some popular sports teams in Tokyo, Japan include the Yomiuri Giants and Tokyo Yakult Swallows in baseball, F.C. Tokyo and Tokyo Verdy 1969 in soccer, and the Hitachi SunRockers, Toyota Alvark Tokyo, and Tokyo Excellence in basketball. Tokyo is also known for its sumo wrestling tournaments held at the RyÅgoku Kokugikan sumo arena.\n" ] } ], "source": [ + "response = agent.chat(\"What are the sports teams of each city in Asia?\")\n", "print(str(response))" ] }, + { + "cell_type": "markdown", + "id": "a89391a4-8836-432a-a42f-9c831710f2e4", + "metadata": {}, + "source": [ + "## Step-wise Queries\n", + "\n", + "We can also try some step-wise queries. This allows us to inject user feedback in the middle of a task execution to guide responses towards the correct state faster." + ] + }, { "cell_type": "code", "execution_count": null, - "id": "926b79ba-1868-46ca-bb7e-2bd3c907773c", + "id": "04a4b1a2-5c7f-494f-bb35-a472dde47eab", + "metadata": {}, + "outputs": [], + "source": [ + "agent_worker = RetryAgentWorker.from_tools(\n", + " query_engine_tools,\n", + " llm=llm,\n", + " verbose=True,\n", + " callback_manager=callback_manager,\n", + ")\n", + "agent = AgentRunner(agent_worker, callback_manager=callback_manager)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e9e51266-9b29-48e7-b40f-cc45e7f8dd30", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "\u001b[1;3;38;5;200mSelecting query engine 3: The question is asking about sports teams in Asia, and Tokyo is located in Asia..\n", - "\u001b[0m> Question: What are the sports teams of each city in Asia?\n", - "> Response: I'm sorry, but the context information does not provide a comprehensive list of sports teams in each city in Asia. It only mentions some sports teams in Tokyo, Japan. To get a complete list of sports teams in each city in Asia, you would need to consult a reliable source or conduct further research.\n", - "> Response eval: {'has_error': True, 'new_question': 'What are some popular sports teams in Tokyo, Japan?', 'explanation': 'The original question is too broad and requires extensive data that the system may not possess. The new question is more specific and focuses on a single city, making it more likely to receive a correct and comprehensive answer.'}\n", - "\u001b[1;3;38;5;200mSelecting query engine 3: The question specifically asks about Tokyo, and choice 4 is about answering semantic questions about Tokyo..\n", - "\u001b[0m> Question: What are some popular sports teams in Tokyo, Japan?\n", - "> Response: Some popular sports teams in Tokyo, Japan include the Yomiuri Giants and Tokyo Yakult Swallows in baseball, F.C. Tokyo and Tokyo Verdy 1969 in soccer, and Hitachi SunRockers, Toyota Alvark Tokyo, and Tokyo Excellence in basketball. Tokyo is also known for its sumo wrestling tournaments held at the RyÅgoku Kokugikan sumo arena.\n", - "> Response eval: {'has_error': False, 'new_question': '', 'explanation': ''}\n", - "Some popular sports teams in Tokyo, Japan include the Yomiuri Giants and Tokyo Yakult Swallows in baseball, F.C. Tokyo and Tokyo Verdy 1969 in soccer, and Hitachi SunRockers, Toyota Alvark Tokyo, and Tokyo Excellence in basketball. Tokyo is also known for its sumo wrestling tournaments held at the RyÅgoku Kokugikan sumo arena.\n" + "> Current Input: Which countries are each city from?\n", + "\u001b[1;3;38;5;200mSelecting query engine 0: The first choice is the most relevant because it mentions a table containing city_stats, which likely includes information about the country of each city..\n", + "\u001b[0m> Question: Which countries are each city from?\n", + "> Response: The city of Toronto is from Canada, Tokyo is from Japan, and Berlin is from Germany.\n", + "> Response eval: {'has_error': True, 'new_question': 'Can you tell me the country of origin for each of these cities: Toronto, Tokyo, Berlin?', 'explanation': 'The original question was too vague and did not specify which cities the user was interested in. The new question provides specific cities for the system to provide information on.'}\n" + ] + }, + { + "data": { + "text/plain": [ + "AgentChatResponse(response='The city of Toronto is from Canada, Tokyo is from Japan, and Berlin is from Germany.', sources=[], source_nodes=[])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "task = agent.create_task(\"Which countries are each city from?\")\n", + "step_output = agent.run_step(task.task_id)\n", + "step_output.output" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3fe5c76f-b183-43c8-b965-5ce576a70be3", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "> Current Input: Which country is each of the following cities from: Toronto, Tokyo, Berlin?\n", + "\u001b[1;3;38;5;200mSelecting query engine 0: This choice is relevant because it mentions a table containing city_stats, which likely includes information about the country of each city..\n", + "\u001b[0m> Question: Which country is each of the following cities from: Toronto, Tokyo, Berlin?\n", + "> Response: Toronto is from Canada, Tokyo is from Japan, and Berlin is from Germany.\n", + "> Response eval: {'has_error': False, 'new_question': '', 'explanation': ''}\n" ] + }, + { + "data": { + "text/plain": [ + "AgentChatResponse(response='Toronto is from Canada, Tokyo is from Japan, and Berlin is from Germany.', sources=[], source_nodes=[])" + ] + }, + "execution_count": null, + "metadata": {}, + "output_type": "execute_result" } ], "source": [ - "response = agent.chat(\"What are the sports teams of each city in Asia?\")\n", - "print(str(response))" + "step_output = agent.run_step(\n", + " task.task_id,\n", + " input=\"Which country is each of the following cities from: Toronto, Tokyo, Berlin?\",\n", + ")\n", + "step_output.output" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3bb30f0a-5947-456b-ab05-5db6d76bca7a", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Toronto is from Canada, Tokyo is from Japan, and Berlin is from Germany.\n" + ] + } + ], + "source": [ + "if step_output.is_last:\n", + " response = agent.finalize_response(task.task_id)\n", + " print(str(response))" ] } ], diff --git a/docs/examples/agent/openai_agent.ipynb b/docs/examples/agent/openai_agent.ipynb index 0ad817ff33e3225af2299b70a87fbd73b2e7ee20..e5f8891d9964076032562fe6c174a64c09112d6e 100644 --- a/docs/examples/agent/openai_agent.ipynb +++ b/docs/examples/agent/openai_agent.ipynb @@ -190,17 +190,18 @@ " return ai_message.content\n", "\n", " def _call_function(self, tool_call: dict) -> ChatMessage:\n", - " id_ = tool_call[\"id\"]\n", - " function_call = tool_call[\"function\"]\n", - " tool = self._tools[function_call[\"name\"]]\n", - " output = tool(**json.loads(function_call[\"arguments\"]))\n", + " id_ = tool_call.id\n", + " function_call = tool_call.function\n", + " tool = self._tools[function_call.name]\n", + " output = tool(**json.loads(function_call.arguments))\n", + " print(f\"> Calling tool: {function_call.name}\")\n", " return ChatMessage(\n", - " name=function_call[\"name\"],\n", + " name=function_call.name,\n", " content=str(output),\n", " role=\"tool\",\n", " additional_kwargs={\n", " \"tool_call_id\": id_,\n", - " \"name\": function_call[\"name\"],\n", + " \"name\": function_call.name,\n", " },\n", " )" ] diff --git a/llama_index/agent/custom/simple.py b/llama_index/agent/custom/simple.py index 53b44290103b875d61dbb7ff45649beeea73619b..e735ac66b97613e80fca38ed64f3fdb85d3f02c4 100644 --- a/llama_index/agent/custom/simple.py +++ b/llama_index/agent/custom/simple.py @@ -180,9 +180,7 @@ class CustomSimpleAgentWorker(BaseModel, BaseAgentWorker): @abstractmethod def _run_step( - self, - state: Dict[str, Any], - task: Task, + self, state: Dict[str, Any], task: Task, input: Optional[str] = None ) -> Tuple[AgentChatResponse, bool]: """Run step. @@ -192,9 +190,7 @@ class CustomSimpleAgentWorker(BaseModel, BaseAgentWorker): """ async def _arun_step( - self, - state: Dict[str, Any], - task: Task, + self, state: Dict[str, Any], task: Task, input: Optional[str] = None ) -> Tuple[AgentChatResponse, bool]: """Run step (async). @@ -211,7 +207,9 @@ class CustomSimpleAgentWorker(BaseModel, BaseAgentWorker): @trace_method("run_step") def run_step(self, step: TaskStep, task: Task, **kwargs: Any) -> TaskStepOutput: """Run step.""" - agent_response, is_done = self._run_step(step.step_state, task) + agent_response, is_done = self._run_step( + step.step_state, task, input=step.input + ) response = self._get_task_step_response(agent_response, step, is_done) # sync step state with task state task.extra_state.update(step.step_state) @@ -222,7 +220,9 @@ class CustomSimpleAgentWorker(BaseModel, BaseAgentWorker): self, step: TaskStep, task: Task, **kwargs: Any ) -> TaskStepOutput: """Run step (async).""" - agent_response, is_done = await self._arun_step(step.step_state, task) + agent_response, is_done = await self._arun_step( + step.step_state, task, input=step.input + ) response = self._get_task_step_response(agent_response, step, is_done) task.extra_state.update(step.step_state) return response diff --git a/llama_index/agent/runner/base.py b/llama_index/agent/runner/base.py index d9019ab857338bf4427a9ad5820522da6a69995e..eaf81e39f12b7563503e49ac8256a5565d0bcbae 100644 --- a/llama_index/agent/runner/base.py +++ b/llama_index/agent/runner/base.py @@ -1,4 +1,3 @@ -import uuid from abc import abstractmethod from collections import deque from typing import Any, Deque, Dict, List, Optional, Union, cast @@ -140,13 +139,7 @@ def validate_step_from_args( raise ValueError(f"step must be TaskStep: {step}") return step else: - return ( - None - if input is None - else TaskStep( - task_id=task_id, step_id=str(uuid.uuid4()), input=input, **kwargs - ) - ) + return None class TaskState(BaseModel): @@ -316,6 +309,7 @@ class AgentRunner(BaseAgentRunner): self, task_id: str, step: Optional[TaskStep] = None, + input: Optional[str] = None, mode: ChatResponseMode = ChatResponseMode.WAIT, **kwargs: Any, ) -> TaskStepOutput: @@ -323,6 +317,8 @@ class AgentRunner(BaseAgentRunner): task = self.state.get_task(task_id) step_queue = self.state.get_step_queue(task_id) step = step or step_queue.popleft() + if input is not None: + step.input = input # TODO: figure out if you can dynamically swap in different step executors # not clear when you would do that by theoretically possible @@ -347,6 +343,7 @@ class AgentRunner(BaseAgentRunner): self, task_id: str, step: Optional[TaskStep] = None, + input: Optional[str] = None, mode: ChatResponseMode = ChatResponseMode.WAIT, **kwargs: Any, ) -> TaskStepOutput: @@ -354,6 +351,8 @@ class AgentRunner(BaseAgentRunner): task = self.state.get_task(task_id) step_queue = self.state.get_step_queue(task_id) step = step or step_queue.popleft() + if input is not None: + step.input = input # TODO: figure out if you can dynamically swap in different step executors # not clear when you would do that by theoretically possible @@ -382,7 +381,9 @@ class AgentRunner(BaseAgentRunner): ) -> TaskStepOutput: """Run step.""" step = validate_step_from_args(task_id, input, step, **kwargs) - return self._run_step(task_id, step, mode=ChatResponseMode.WAIT, **kwargs) + return self._run_step( + task_id, step, input=input, mode=ChatResponseMode.WAIT, **kwargs + ) async def arun_step( self, @@ -394,7 +395,7 @@ class AgentRunner(BaseAgentRunner): """Run step (async).""" step = validate_step_from_args(task_id, input, step, **kwargs) return await self._arun_step( - task_id, step, mode=ChatResponseMode.WAIT, **kwargs + task_id, step, input=input, mode=ChatResponseMode.WAIT, **kwargs ) def stream_step( @@ -406,7 +407,9 @@ class AgentRunner(BaseAgentRunner): ) -> TaskStepOutput: """Run step (stream).""" step = validate_step_from_args(task_id, input, step, **kwargs) - return self._run_step(task_id, step, mode=ChatResponseMode.STREAM, **kwargs) + return self._run_step( + task_id, step, input=input, mode=ChatResponseMode.STREAM, **kwargs + ) async def astream_step( self, @@ -418,7 +421,7 @@ class AgentRunner(BaseAgentRunner): """Run step (async stream).""" step = validate_step_from_args(task_id, input, step, **kwargs) return await self._arun_step( - task_id, step, mode=ChatResponseMode.STREAM, **kwargs + task_id, step, input=input, mode=ChatResponseMode.STREAM, **kwargs ) def finalize_response( diff --git a/tests/agent/openai/test_openai_agent.py b/tests/agent/openai/test_openai_agent.py index 7f12b1e083c12a40b1de8aee501ef5760c9f8972..8cbffb783f34b45f37872ff73364a4ea78314345 100644 --- a/tests/agent/openai/test_openai_agent.py +++ b/tests/agent/openai/test_openai_agent.py @@ -266,10 +266,14 @@ def test_add_step( tools=[add_tool], llm=llm, ) + ## NOTE: can only take a single step before finishing, + # since mocked chat output does not call any tools task = agent.create_task("What is 1 + 1?") - # first step step_output = agent.run_step(task.task_id) + assert str(step_output) == "\n\nThis is a test!" + # add human input (not used but should be in memory) + task = agent.create_task("What is 1 + 1?") step_output = agent.run_step(task.task_id, input="tmp") chat_history: List[ChatMessage] = task.extra_state["new_memory"].get_all() assert "tmp" in [m.content for m in chat_history] @@ -307,6 +311,7 @@ async def test_async_add_step( mock_instance.chat.completions.create.return_value = mock_achat_completion() step_output = await agent.arun_step(task.task_id) # add human input (not used but should be in memory) + task = agent.create_task("What is 1 + 1?") mock_instance.chat.completions.create.return_value = mock_achat_completion() step_output = await agent.arun_step(task.task_id, input="tmp") chat_history: List[ChatMessage] = task.extra_state["new_memory"].get_all() @@ -322,6 +327,7 @@ async def test_async_add_step( mock_instance.chat.completions.create.side_effect = mock_achat_stream step_output = await agent.astream_step(task.task_id) # add human input (not used but should be in memory) + task = agent.create_task("What is 1 + 1?") mock_instance.chat.completions.create.side_effect = mock_achat_stream step_output = await agent.astream_step(task.task_id, input="tmp") chat_history = task.extra_state["new_memory"].get_all() diff --git a/tests/agent/react/test_react_agent.py b/tests/agent/react/test_react_agent.py index af45634993967227393c8f97d2ef6161ab67d776..e09e345c1639ce21d50f2fea37da50b34c2e8af2 100644 --- a/tests/agent/react/test_react_agent.py +++ b/tests/agent/react/test_react_agent.py @@ -257,19 +257,36 @@ async def test_astream_chat_basic( ] -def _get_agent(tools: List[BaseTool]) -> ReActAgent: - mock_llm = MockChatLLM( - responses=[ - ChatMessage( - content=MOCK_ACTION_RESPONSE, - role=MessageRole.ASSISTANT, - ), - ChatMessage( - content=MOCK_FINAL_RESPONSE, - role=MessageRole.ASSISTANT, - ), - ] - ) +def _get_agent( + tools: List[BaseTool], + streaming: bool = False, +) -> ReActAgent: + if streaming: + mock_llm = MockStreamChatLLM( + responses=[ + ChatMessage( + content=MOCK_ACTION_RESPONSE, + role=MessageRole.ASSISTANT, + ), + ChatMessage( + content=MOCK_STREAM_FINAL_RESPONSE, + role=MessageRole.ASSISTANT, + ), + ] + ) + else: + mock_llm = MockChatLLM( + responses=[ + ChatMessage( + content=MOCK_ACTION_RESPONSE, + role=MessageRole.ASSISTANT, + ), + ChatMessage( + content=MOCK_FINAL_RESPONSE, + role=MessageRole.ASSISTANT, + ), + ] + ) return ReActAgent.from_tools( tools=tools, llm=mock_llm, @@ -299,7 +316,7 @@ def test_add_step( assert "tmp" in observations # stream_step - agent = _get_agent([add_tool]) + agent = _get_agent([add_tool], streaming=True) task = agent.create_task("What is 1 + 1?") # first step step_output = agent.stream_step(task.task_id) @@ -324,7 +341,7 @@ async def test_async_add_step( assert "tmp" in observations # async stream step - agent = _get_agent([add_tool]) + agent = _get_agent([add_tool], streaming=True) task = agent.create_task("What is 1 + 1?") # first step step_output = await agent.astream_step(task.task_id)