From 107b37e878cf4ebb798b2fec0ad08439d0d717da Mon Sep 17 00:00:00 2001
From: Ming <tslmy@users.noreply.github.com>
Date: Sun, 31 Mar 2024 17:47:30 -0700
Subject: [PATCH] [ReAct][Robustness++] When LLM failed to follow the response
 template, tell it so (#12300)

---
 .../llama_index/core/agent/react/base.py      |  21 ++-
 .../llama_index/core/agent/react/step.py      | 169 +++++++++++-------
 .../tests/agent/react/test_react_agent.py     |  20 +++
 3 files changed, 148 insertions(+), 62 deletions(-)

diff --git a/llama-index-core/llama_index/core/agent/react/base.py b/llama-index-core/llama_index/core/agent/react/base.py
index 7bf445e2f5..e9ac6219ec 100644
--- a/llama-index-core/llama_index/core/agent/react/base.py
+++ b/llama-index-core/llama_index/core/agent/react/base.py
@@ -14,6 +14,7 @@ from typing import (
     Optional,
     Sequence,
     Type,
+    Callable,
 )
 
 from llama_index.core.agent.react.formatter import ReActChatFormatter
@@ -29,7 +30,7 @@ from llama_index.core.memory.chat_memory_buffer import ChatMemoryBuffer
 from llama_index.core.memory.types import BaseMemory
 from llama_index.core.objects.base import ObjectRetriever
 from llama_index.core.settings import Settings
-from llama_index.core.tools import BaseTool
+from llama_index.core.tools import BaseTool, ToolOutput
 from llama_index.core.prompts.mixin import PromptMixinType
 
 
@@ -57,6 +58,9 @@ class ReActAgent(AgentRunner):
         verbose: bool = False,
         tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
         context: Optional[str] = None,
+        handle_reasoning_failure_fn: Optional[
+            Callable[[CallbackManager, Exception], ToolOutput]
+        ] = None,
     ) -> None:
         """Init params."""
         callback_manager = callback_manager or llm.callback_manager
@@ -74,6 +78,7 @@ class ReActAgent(AgentRunner):
             output_parser=output_parser,
             callback_manager=callback_manager,
             verbose=verbose,
+            handle_reasoning_failure_fn=handle_reasoning_failure_fn,
         )
         super().__init__(
             step_engine,
@@ -97,15 +102,26 @@ class ReActAgent(AgentRunner):
         callback_manager: Optional[CallbackManager] = None,
         verbose: bool = False,
         context: Optional[str] = None,
+        handle_reasoning_failure_fn: Optional[
+            Callable[[CallbackManager, Exception], ToolOutput]
+        ] = None,
         **kwargs: Any,
     ) -> "ReActAgent":
-        """Convenience constructor method from set of of BaseTools (Optional).
+        """Convenience constructor method from set of BaseTools (Optional).
 
         NOTE: kwargs should have been exhausted by this point. In other words
         the various upstream components such as BaseSynthesizer (response synthesizer)
         or BaseRetriever should have picked up off their respective kwargs in their
         constructions.
 
+        If `handle_reasoning_failure_fn` is provided, it will be called whenever the LLM fails to follow the
+        response templates specified in the System Prompt. This function should return a `ToolOutput` to give
+        back to the Agent, so that the Agent can have a second chance to fix its mistakes.
+        To handle the exception yourself, you can provide a function that re-raises the `Exception`.
+
+        Note: If you modify any response template in the System Prompt, you should also override the method
+        `_extract_reasoning_step` in `ReActAgentWorker`.
+
         Returns:
             ReActAgent
         """
@@ -126,6 +142,7 @@ class ReActAgent(AgentRunner):
             callback_manager=callback_manager,
             verbose=verbose,
             context=context,
+            handle_reasoning_failure_fn=handle_reasoning_failure_fn,
         )
 
     def _get_prompt_modules(self) -> PromptMixinType:
diff --git a/llama-index-core/llama_index/core/agent/react/step.py b/llama-index-core/llama_index/core/agent/react/step.py
index d15fb4195c..8ba9027d9b 100644
--- a/llama-index-core/llama_index/core/agent/react/step.py
+++ b/llama-index-core/llama_index/core/agent/react/step.py
@@ -15,6 +15,7 @@ from typing import (
     Sequence,
     Tuple,
     cast,
+    Callable,
 )
 
 from llama_index.core.agent.react.formatter import ReActChatFormatter
@@ -74,6 +75,31 @@ def add_user_step_to_reasoning(
             print(f"Added user message to memory: {step.input}")
 
 
+def tell_llm_about_failure_in_extract_reasoning_step(
+    callback_manager: CallbackManager, _: ValueError
+) -> ToolOutput:
+    """
+    Default handler for a failure to extract a reasoning step: emit a dummy Tool Output
+    that tells the LLM it did not follow the expected response format, so that the LLM
+    can be more cooperative in its next generation.
+    """
+    message = "Error: Could not parse output. Please follow the thought-action-input format. Try again."
+    dummy_tool_output = ToolOutput(
+        content=message,
+        tool_name="unknown",
+        raw_input={},
+        raw_output=message,
+    )
+    with callback_manager.event(
+        CBEventType.FUNCTION_CALL,
+        payload={
+            EventPayload.FUNCTION_CALL: "unknown",
+        },
+    ) as event:
+        event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(dummy_tool_output)})
+    return dummy_tool_output
+
+
 class ReActAgentWorker(BaseAgentWorker):
     """OpenAI Agent worker."""
 
@@ -87,6 +113,9 @@ class ReActAgentWorker(BaseAgentWorker):
         callback_manager: Optional[CallbackManager] = None,
         verbose: bool = False,
         tool_retriever: Optional[ObjectRetriever[BaseTool]] = None,
+        handle_reasoning_failure_fn: Optional[
+            Callable[[CallbackManager, Exception], ToolOutput]
+        ] = None,
     ) -> None:
         self._llm = llm
         self.callback_manager = callback_manager or llm.callback_manager
@@ -94,6 +123,10 @@ class ReActAgentWorker(BaseAgentWorker):
         self._react_chat_formatter = react_chat_formatter or ReActChatFormatter()
         self._output_parser = output_parser or ReActOutputParser()
         self._verbose = verbose
+        self._handle_reasoning_failure_fn = (
+            handle_reasoning_failure_fn
+            or tell_llm_about_failure_in_extract_reasoning_step
+        )
 
         if len(tools) > 0 and tool_retriever is not None:
             raise ValueError("Cannot specify both tools and tool_retriever")
@@ -116,9 +149,12 @@ class ReActAgentWorker(BaseAgentWorker):
         output_parser: Optional[ReActOutputParser] = None,
         callback_manager: Optional[CallbackManager] = None,
         verbose: bool = False,
+        handle_reasoning_failure_fn: Optional[
+            Callable[[CallbackManager, Exception], ToolOutput]
+        ] = None,
         **kwargs: Any,
     ) -> "ReActAgentWorker":
-        """Convenience constructor method from set of of BaseTools (Optional).
+        """Convenience constructor method from set of BaseTools (Optional).
 
         NOTE: kwargs should have been exhausted by this point. In other words
         the various upstream components such as BaseSynthesizer (response synthesizer)
@@ -126,7 +162,7 @@ class ReActAgentWorker(BaseAgentWorker):
         constructions.
 
         Returns:
-            ReActAgent
+            ReActAgentWorker
         """
         llm = llm or Settings.llm
         if callback_manager is not None:
@@ -140,6 +176,7 @@ class ReActAgentWorker(BaseAgentWorker):
             output_parser=output_parser,
             callback_manager=callback_manager,
             verbose=verbose,
+            handle_reasoning_failure_fn=handle_reasoning_failure_fn,
         )
 
     def _get_prompts(self) -> PromptDictType:
@@ -223,36 +260,42 @@ class ReActAgentWorker(BaseAgentWorker):
         tools_dict: Dict[str, AsyncBaseTool] = {
             tool.metadata.get_name(): tool for tool in tools
         }
-        _, current_reasoning, is_done = self._extract_reasoning_step(
-            output, is_streaming
-        )
-
-        if is_done:
-            return current_reasoning, True
-
-        # call tool with input
-        reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])
-        if reasoning_step.action in tools_dict:
-            tool = tools_dict[reasoning_step.action]
-            with self.callback_manager.event(
-                CBEventType.FUNCTION_CALL,
-                payload={
-                    EventPayload.FUNCTION_CALL: reasoning_step.action_input,
-                    EventPayload.TOOL: tool.metadata,
-                },
-            ) as event:
-                try:
-                    tool_output = tool.call(**reasoning_step.action_input)
-                except Exception as e:
-                    tool_output = ToolOutput(
-                        content=f"Error: {e!s}",
-                        tool_name=tool.metadata.name,
-                        raw_input={"kwargs": reasoning_step.action_input},
-                        raw_output=e,
-                    )
-                event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
+        try:
+            _, current_reasoning, is_done = self._extract_reasoning_step(
+                output, is_streaming
+            )
+        except ValueError as exp:
+            current_reasoning = []
+            tool_output = self._handle_reasoning_failure_fn(self.callback_manager, exp)
         else:
-            tool_output = self._handle_nonexistent_tool_name(reasoning_step)
+            if is_done:
+                return current_reasoning, True
+
+            # call tool with input
+            reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])
+            if reasoning_step.action in tools_dict:
+                tool = tools_dict[reasoning_step.action]
+                with self.callback_manager.event(
+                    CBEventType.FUNCTION_CALL,
+                    payload={
+                        EventPayload.FUNCTION_CALL: reasoning_step.action_input,
+                        EventPayload.TOOL: tool.metadata,
+                    },
+                ) as event:
+                    try:
+                        tool_output = tool.call(**reasoning_step.action_input)
+                    except Exception as e:
+                        tool_output = ToolOutput(
+                            content=f"Error: {e!s}",
+                            tool_name=tool.metadata.name,
+                            raw_input={"kwargs": reasoning_step.action_input},
+                            raw_output=e,
+                        )
+                    event.on_end(
+                        payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)}
+                    )
+            else:
+                tool_output = self._handle_nonexistent_tool_name(reasoning_step)
 
         task.extra_state["sources"].append(tool_output)
 
@@ -270,36 +313,42 @@ class ReActAgentWorker(BaseAgentWorker):
         is_streaming: bool = False,
     ) -> Tuple[List[BaseReasoningStep], bool]:
         tools_dict = {tool.metadata.name: tool for tool in tools}
-        _, current_reasoning, is_done = self._extract_reasoning_step(
-            output, is_streaming
-        )
-
-        if is_done:
-            return current_reasoning, True
-
-        # call tool with input
-        reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])
-        if reasoning_step.action in tools_dict:
-            tool = tools_dict[reasoning_step.action]
-            with self.callback_manager.event(
-                CBEventType.FUNCTION_CALL,
-                payload={
-                    EventPayload.FUNCTION_CALL: reasoning_step.action_input,
-                    EventPayload.TOOL: tool.metadata,
-                },
-            ) as event:
-                try:
-                    tool_output = await tool.acall(**reasoning_step.action_input)
-                except Exception as e:
-                    tool_output = ToolOutput(
-                        content=f"Error: {e!s}",
-                        tool_name=tool.metadata.name,
-                        raw_input={"kwargs": reasoning_step.action_input},
-                        raw_output=e,
-                    )
-                event.on_end(payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)})
+        try:
+            _, current_reasoning, is_done = self._extract_reasoning_step(
+                output, is_streaming
+            )
+        except ValueError as exp:
+            current_reasoning = []
+            tool_output = self._handle_reasoning_failure_fn(self.callback_manager, exp)
         else:
-            tool_output = self._handle_nonexistent_tool_name(reasoning_step)
+            if is_done:
+                return current_reasoning, True
+
+            # call tool with input
+            reasoning_step = cast(ActionReasoningStep, current_reasoning[-1])
+            if reasoning_step.action in tools_dict:
+                tool = tools_dict[reasoning_step.action]
+                with self.callback_manager.event(
+                    CBEventType.FUNCTION_CALL,
+                    payload={
+                        EventPayload.FUNCTION_CALL: reasoning_step.action_input,
+                        EventPayload.TOOL: tool.metadata,
+                    },
+                ) as event:
+                    try:
+                        tool_output = await tool.acall(**reasoning_step.action_input)
+                    except Exception as e:
+                        tool_output = ToolOutput(
+                            content=f"Error: {e!s}",
+                            tool_name=tool.metadata.name,
+                            raw_input={"kwargs": reasoning_step.action_input},
+                            raw_output=e,
+                        )
+                    event.on_end(
+                        payload={EventPayload.FUNCTION_OUTPUT: str(tool_output)}
+                    )
+            else:
+                tool_output = self._handle_nonexistent_tool_name(reasoning_step)
 
         task.extra_state["sources"].append(tool_output)
 
diff --git a/llama-index-core/tests/agent/react/test_react_agent.py b/llama-index-core/tests/agent/react/test_react_agent.py
index f26a3ddf87..f853847eba 100644
--- a/llama-index-core/tests/agent/react/test_react_agent.py
+++ b/llama-index-core/tests/agent/react/test_react_agent.py
@@ -305,6 +305,26 @@ def _get_observations(task: Task) -> List[str]:
     return [s.observation for s in obs_steps]
 
 
+def test_complaint_when_no_reasoning_step():
+    runner = ReActAgent.from_tools(
+        tools=[],
+        llm=MockLLM(),
+    )
+    task = runner.create_task("lorem")
+    chat_response = ChatResponse(
+        message=ChatMessage(
+            content="Thought: ipsum\nAction: dolor", role=MessageRole.ASSISTANT
+        )
+    )
+    current_reasoning, is_done = runner.agent_worker._process_actions(
+        task, tools=[], output=chat_response
+    )
+    assert (
+        current_reasoning[0].get_content()
+        == "Observation: Error: Could not parse output. Please follow the thought-action-input format. Try again."
+    )
+
+
 def test_add_step(
     add_tool: FunctionTool,
 ) -> None:
-- 
GitLab