diff --git a/llama-index-core/llama_index/core/agent/react/base.py b/llama-index-core/llama_index/core/agent/react/base.py index 7bf445e2f59244585e60b70b4655ff881416fb3b..e9ac6219ec0f06d8c29e7f9e14dd454b54b795e9 100644 --- a/llama-index-core/llama_index/core/agent/react/base.py +++ b/llama-index-core/llama_index/core/agent/react/base.py @@ -14,6 +14,7 @@ from typing import ( Optional, Sequence, Type, + Callable, ) from llama_index.core.agent.react.formatter import ReActChatFormatter @@ -29,7 +30,7 @@ from llama_index.core.memory.chat_memory_buffer import ChatMemoryBuffer from llama_index.core.memory.types import BaseMemory from llama_index.core.objects.base import ObjectRetriever from llama_index.core.settings import Settings -from llama_index.core.tools import BaseTool +from llama_index.core.tools import BaseTool, ToolOutput from llama_index.core.prompts.mixin import PromptMixinType @@ -57,6 +58,9 @@ class ReActAgent(AgentRunner): verbose: bool = False, tool_retriever: Optional[ObjectRetriever[BaseTool]] = None, context: Optional[str] = None, + handle_reasoning_failure_fn: Optional[ + Callable[[CallbackManager, Exception], ToolOutput] + ] = None, ) -> None: """Init params.""" callback_manager = callback_manager or llm.callback_manager @@ -74,6 +78,7 @@ class ReActAgent(AgentRunner): output_parser=output_parser, callback_manager=callback_manager, verbose=verbose, + handle_reasoning_failure_fn=handle_reasoning_failure_fn, ) super().__init__( step_engine, @@ -97,15 +102,26 @@ class ReActAgent(AgentRunner): callback_manager: Optional[CallbackManager] = None, verbose: bool = False, context: Optional[str] = None, + handle_reasoning_failure_fn: Optional[ + Callable[[CallbackManager, Exception], ToolOutput] + ] = None, **kwargs: Any, ) -> "ReActAgent": - """Convenience constructor method from set of of BaseTools (Optional). + """Convenience constructor method from set of BaseTools (Optional). NOTE: kwargs should have been exhausted by this point. 
In other words the various upstream components such as BaseSynthesizer (response synthesizer) or BaseRetriever should have picked up off their respective kwargs in their constructions. + If `handle_reasoning_failure_fn` is provided, when the LLM fails to follow the response templates specified in + the System Prompt, this function will be called. Its return value is provided to the Agent as a tool output, so that the Agent + can have a second chance to fix its mistakes. + To handle the exception yourself, you can provide a function that raises the `Exception`. + + Note: If you modified any response template in the System Prompt, you should override the method + `_extract_reasoning_step` in `ReActAgentWorker`. + Returns: ReActAgent """ @@ -126,6 +142,7 @@ class ReActAgent(AgentRunner): callback_manager=callback_manager, verbose=verbose, context=context, + handle_reasoning_failure_fn=handle_reasoning_failure_fn, ) def _get_prompt_modules(self) -> PromptMixinType: diff --git a/llama-index-core/llama_index/core/agent/react/step.py b/llama-index-core/llama_index/core/agent/react/step.py index d15fb4195c35a92733d0334b7b195eb4ec325bef..8ba9027d9be478fbb4c5baf798d5b69aad3061b8 100644 --- a/llama-index-core/llama_index/core/agent/react/step.py +++ b/llama-index-core/llama_index/core/agent/react/step.py @@ -15,6 +15,7 @@ from typing import ( Sequence, Tuple, cast, + Callable, ) from llama_index.core.agent.react.formatter import ReActChatFormatter @@ -74,6 +75,31 @@ def add_user_step_to_reasoning( print(f"Added user message to memory: {step.input}") + +def tell_llm_about_failure_in_extract_reasoning_step( + callback_manager: CallbackManager, _: ValueError +) -> ToolOutput: + """ + Default handler: when the LLM's response cannot be parsed as a reasoning step, + we emit a complaint as a Tool Output (prepared at initialization time) back to the LLM, so that + the LLM can be more cooperative in its next generation. + """ + message = "Error: Could not parse output. 
Please follow the thought-action-input format. Try again." + dummy_tool_output = ToolOutput( + content=message, + tool_name="unknown", + raw_input={}, + raw_output=message, + ) + with callback_manager.event( + CBEventType.FUNCTION_CALL, + payload={ + EventPayload.FUNCTION_CALL: "unknown", + }, + ) as event: + event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(dummy_tool_output)}) + return dummy_tool_output + + class ReActAgentWorker(BaseAgentWorker): """OpenAI Agent worker.""" @@ -87,6 +113,9 @@ class ReActAgentWorker(BaseAgentWorker): callback_manager: Optional[CallbackManager] = None, verbose: bool = False, tool_retriever: Optional[ObjectRetriever[BaseTool]] = None, + handle_reasoning_failure_fn: Optional[ + Callable[[CallbackManager, Exception], ToolOutput] + ] = None, ) -> None: self._llm = llm self.callback_manager = callback_manager or llm.callback_manager @@ -94,6 +123,10 @@ class ReActAgentWorker(BaseAgentWorker): self._react_chat_formatter = react_chat_formatter or ReActChatFormatter() self._output_parser = output_parser or ReActOutputParser() self._verbose = verbose + self._handle_reasoning_failure_fn = ( + handle_reasoning_failure_fn + or tell_llm_about_failure_in_extract_reasoning_step + ) if len(tools) > 0 and tool_retriever is not None: raise ValueError("Cannot specify both tools and tool_retriever") @@ -116,9 +149,12 @@ class ReActAgentWorker(BaseAgentWorker): output_parser: Optional[ReActOutputParser] = None, callback_manager: Optional[CallbackManager] = None, verbose: bool = False, + handle_reasoning_failure_fn: Optional[ + Callable[[CallbackManager, Exception], ToolOutput] + ] = None, **kwargs: Any, ) -> "ReActAgentWorker": - """Convenience constructor method from set of of BaseTools (Optional). + """Convenience constructor method from set of BaseTools (Optional). NOTE: kwargs should have been exhausted by this point. 
In other words the various upstream components such as BaseSynthesizer (response synthesizer) @@ -126,7 +162,7 @@ class ReActAgentWorker(BaseAgentWorker): constructions. Returns: - ReActAgent + ReActAgentWorker """ llm = llm or Settings.llm if callback_manager is not None: @@ -140,6 +176,7 @@ class ReActAgentWorker(BaseAgentWorker): output_parser=output_parser, callback_manager=callback_manager, verbose=verbose, + handle_reasoning_failure_fn=handle_reasoning_failure_fn, ) def _get_prompts(self) -> PromptDictType: @@ -223,36 +260,42 @@ class ReActAgentWorker(BaseAgentWorker): tools_dict: Dict[str, AsyncBaseTool] = { tool.metadata.get_name(): tool for tool in tools } - _, current_reasoning, is_done = self._extract_reasoning_step( - output, is_streaming - ) - - if is_done: - return current_reasoning, True - - # call tool with input - reasoning_step = cast(ActionReasoningStep, current_reasoning[-1]) - if reasoning_step.action in tools_dict: - tool = tools_dict[reasoning_step.action] - with self.callback_manager.event( - CBEventType.FUNCTION_CALL, - payload={ - EventPayload.FUNCTION_CALL: reasoning_step.action_input, - EventPayload.TOOL: tool.metadata, - }, - ) as event: - try: - tool_output = tool.call(**reasoning_step.action_input) - except Exception as e: - tool_output = ToolOutput( - content=f"Error: {e!s}", - tool_name=tool.metadata.name, - raw_input={"kwargs": reasoning_step.action_input}, - raw_output=e, - ) - event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)}) + try: + _, current_reasoning, is_done = self._extract_reasoning_step( + output, is_streaming + ) + except ValueError as exp: + current_reasoning = [] + tool_output = self._handle_reasoning_failure_fn(self.callback_manager, exp) else: - tool_output = self._handle_nonexistent_tool_name(reasoning_step) + if is_done: + return current_reasoning, True + + # call tool with input + reasoning_step = cast(ActionReasoningStep, current_reasoning[-1]) + if reasoning_step.action in tools_dict: + tool 
= tools_dict[reasoning_step.action] + with self.callback_manager.event( + CBEventType.FUNCTION_CALL, + payload={ + EventPayload.FUNCTION_CALL: reasoning_step.action_input, + EventPayload.TOOL: tool.metadata, + }, + ) as event: + try: + tool_output = tool.call(**reasoning_step.action_input) + except Exception as e: + tool_output = ToolOutput( + content=f"Error: {e!s}", + tool_name=tool.metadata.name, + raw_input={"kwargs": reasoning_step.action_input}, + raw_output=e, + ) + event.on_end( + payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)} + ) + else: + tool_output = self._handle_nonexistent_tool_name(reasoning_step) task.extra_state["sources"].append(tool_output) @@ -270,36 +313,42 @@ class ReActAgentWorker(BaseAgentWorker): is_streaming: bool = False, ) -> Tuple[List[BaseReasoningStep], bool]: tools_dict = {tool.metadata.name: tool for tool in tools} - _, current_reasoning, is_done = self._extract_reasoning_step( - output, is_streaming - ) - - if is_done: - return current_reasoning, True - - # call tool with input - reasoning_step = cast(ActionReasoningStep, current_reasoning[-1]) - if reasoning_step.action in tools_dict: - tool = tools_dict[reasoning_step.action] - with self.callback_manager.event( - CBEventType.FUNCTION_CALL, - payload={ - EventPayload.FUNCTION_CALL: reasoning_step.action_input, - EventPayload.TOOL: tool.metadata, - }, - ) as event: - try: - tool_output = await tool.acall(**reasoning_step.action_input) - except Exception as e: - tool_output = ToolOutput( - content=f"Error: {e!s}", - tool_name=tool.metadata.name, - raw_input={"kwargs": reasoning_step.action_input}, - raw_output=e, - ) - event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)}) + try: + _, current_reasoning, is_done = self._extract_reasoning_step( + output, is_streaming + ) + except ValueError as exp: + current_reasoning = [] + tool_output = self._handle_reasoning_failure_fn(self.callback_manager, exp) else: - tool_output = 
self._handle_nonexistent_tool_name(reasoning_step) + if is_done: + return current_reasoning, True + + # call tool with input + reasoning_step = cast(ActionReasoningStep, current_reasoning[-1]) + if reasoning_step.action in tools_dict: + tool = tools_dict[reasoning_step.action] + with self.callback_manager.event( + CBEventType.FUNCTION_CALL, + payload={ + EventPayload.FUNCTION_CALL: reasoning_step.action_input, + EventPayload.TOOL: tool.metadata, + }, + ) as event: + try: + tool_output = await tool.acall(**reasoning_step.action_input) + except Exception as e: + tool_output = ToolOutput( + content=f"Error: {e!s}", + tool_name=tool.metadata.name, + raw_input={"kwargs": reasoning_step.action_input}, + raw_output=e, + ) + event.on_end( + payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)} + ) + else: + tool_output = self._handle_nonexistent_tool_name(reasoning_step) task.extra_state["sources"].append(tool_output) diff --git a/llama-index-core/tests/agent/react/test_react_agent.py b/llama-index-core/tests/agent/react/test_react_agent.py index f26a3ddf876a9b30efd125b79a43aeabf5ad1837..f853847eba71d258047730b7e4a7407cf73493c8 100644 --- a/llama-index-core/tests/agent/react/test_react_agent.py +++ b/llama-index-core/tests/agent/react/test_react_agent.py @@ -305,6 +305,26 @@ def _get_observations(task: Task) -> List[str]: return [s.observation for s in obs_steps] +def test_complaint_when_no_reasoning_step(): + runner = ReActAgent.from_tools( + tools=[], + llm=MockLLM(), + ) + task = runner.create_task("lorem") + chat_response = ChatResponse( + message=ChatMessage( + content="Thought: ipsum\nAction: dolor", role=MessageRole.ASSISTANT + ) + ) + current_reasoning, is_done = runner.agent_worker._process_actions( + task, tools=[], output=chat_response + ) + assert ( + current_reasoning[0].get_content() + == "Observation: Error: Could not parse output. Please follow the thought-action-input format. Try again." + ) + + def test_add_step( add_tool: FunctionTool, ) -> None: