From 6372cc64ac09f217817911c1f114623a16047df5 Mon Sep 17 00:00:00 2001
From: Ming <tslmy@users.noreply.github.com>
Date: Mon, 29 Jan 2024 18:47:59 -0800
Subject: [PATCH] Improve `agent.react.output_parser`'s ability in parsing
 JSONs from Action Inputs (#10323)

---
 llama_index/agent/react/output_parser.py      | 36 +++++++++++--------
 poetry.lock                                   | 13 ++++++-
 pyproject.toml                                |  1 +
 tests/agent/react/test_react_output_parser.py | 26 ++++++++++++++
 4 files changed, 61 insertions(+), 15 deletions(-)

diff --git a/llama_index/agent/react/output_parser.py b/llama_index/agent/react/output_parser.py
index 7db41634c6..eb66e02e38 100644
--- a/llama_index/agent/react/output_parser.py
+++ b/llama_index/agent/react/output_parser.py
@@ -1,7 +1,6 @@
 """ReAct output parser."""
 
 
-import json
 import re
 from typing import Tuple
 
@@ -16,7 +15,7 @@ from llama_index.types import BaseOutputParser
 
 def extract_tool_use(input_text: str) -> Tuple[str, str, str]:
     pattern = (
-        r"\s*Thought: (.*?)\nAction: ([a-zA-Z0-9_]+).*?\nAction Input: .*?(\{.*?\})"
+        r"\s*Thought: (.*?)\nAction: ([a-zA-Z0-9_]+).*?\nAction Input: .*?(\{.*\})"
     )
 
     match = re.search(pattern, input_text, re.DOTALL)
@@ -50,6 +49,26 @@ def extract_final_response(input_text: str) -> Tuple[str, str]:
     return thought, answer
 
 
+def parse_action_reasoning_step(output: str) -> ActionReasoningStep:
+    """
+    Parse an action reasoning step from the LLM output.
+    """
+    # Weaker LLMs may generate ReActAgent steps whose Action Input are horrible JSON strings.
+    # `dirtyjson` is more lenient than `json` in parsing JSON strings.
+    import dirtyjson as json
+
+    thought, action, action_input = extract_tool_use(output)
+    json_str = extract_json_str(action_input)
+    # First we try json, if this fails we use ast
+    try:
+        action_input_dict = json.loads(json_str)
+    except Exception:
+        action_input_dict = action_input_parser(json_str)
+    return ActionReasoningStep(
+        thought=thought, action=action, action_input=action_input_dict
+    )
+
+
 class ReActOutputParser(BaseOutputParser):
     """ReAct Output parser."""
 
@@ -85,18 +104,7 @@ class ReActOutputParser(BaseOutputParser):
             )
 
         if "Action:" in output:
-            thought, action, action_input = extract_tool_use(output)
-            json_str = extract_json_str(action_input)
-
-            # First we try json, if this fails we use ast
-            try:
-                action_input_dict = json.loads(json_str)
-            except json.JSONDecodeError:
-                action_input_dict = action_input_parser(json_str)
-
-            return ActionReasoningStep(
-                thought=thought, action=action, action_input=action_input_dict
-            )
+            return parse_action_reasoning_step(output)
 
         raise ValueError(f"Could not parse output: {output}")
 
diff --git a/poetry.lock b/poetry.lock
index f8385d6077..d906ea9fec 100644
--- a/poetry.lock
+++ b/poetry.lock
@@ -1177,6 +1177,17 @@ files = [
 [package.extras]
 graph = ["objgraph (>=1.7.2)"]
 
+[[package]]
+name = "dirtyjson"
+version = "1.0.8"
+description = "JSON decoder for Python that can extract data from the muck"
+optional = false
+python-versions = "*"
+files = [
+    {file = "dirtyjson-1.0.8-py3-none-any.whl", hash = "sha256:125e27248435a58acace26d5c2c4c11a1c0de0a9c5124c5a94ba78e517d74f53"},
+    {file = "dirtyjson-1.0.8.tar.gz", hash = "sha256:90ca4a18f3ff30ce849d100dcf4a003953c79d3a2348ef056f1d9c22231a25fd"},
+]
+
 [[package]]
 name = "diskcache"
 version = "5.6.3"
@@ -7812,4 +7823,4 @@ query-tools = ["guidance", "jsonpath-ng", "lm-format-enforcer", "rank-bm25", "sc
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "1ccf014ec22186ffcbec1325d4876089a318f0bec886beadd05b39a145e6a86a"
+content-hash = "6f230242ed4fe799ae5b98f5b64e5f388ad54c304198c6de55bc056be873397c"
diff --git a/pyproject.toml b/pyproject.toml
index 1d30bc327f..7ce3394a07 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -78,6 +78,7 @@ spacy = {optional = true, version = "^3.7.1"}
 aiohttp = "^3.8.6"
 networkx = ">=3.0"
 psycopg2-binary = {optional = true, version = "^2.9.9"}
+dirtyjson = "^1.0.8"
 
 [tool.poetry.extras]
 gradientai = [
diff --git a/tests/agent/react/test_react_output_parser.py b/tests/agent/react/test_react_output_parser.py
index 8d95218ea3..c9534048ae 100644
--- a/tests/agent/react/test_react_output_parser.py
+++ b/tests/agent/react/test_react_output_parser.py
@@ -1,9 +1,23 @@
 from llama_index.agent.react.output_parser import (
     extract_final_response,
     extract_tool_use,
+    parse_action_reasoning_step,
 )
 
 
+def test_parse_action_reasoning_step() -> None:
+    mock_input_text = """\
+Thought: Gotta use a tool.
+Action: tool
+Action Input: {'pages': ['coffee'] /* comment */, 'load_kwargs': {}, 'query_str': ''}, along those lines.
+"""
+    assert parse_action_reasoning_step(mock_input_text).action_input == {
+        "pages": ["coffee"],
+        "load_kwargs": {},
+        "query_str": "",
+    }
+
+
 def test_extract_tool_use() -> None:
     mock_input_text = """\
 Thought: I need to use a tool to help me answer the question.
@@ -16,6 +30,18 @@ Action Input: {"a": 1, "b": 1}
     assert action_input == '{"a": 1, "b": 1}'
 
 
+def test_extract_tool_use_with_nested_dicts() -> None:
+    mock_input_text = """\
+Thought: Gotta use a tool.
+Action: tool
+Action Input: {"a": 1, "b": {}}
+"""
+    thought, action, action_input = extract_tool_use(mock_input_text)
+    assert thought == "Gotta use a tool."
+    assert action == "tool"
+    assert action_input == '{"a": 1, "b": {}}'
+
+
 def test_extract_tool_use_() -> None:
     mock_input_text = """\
 Thought: I need to use a tool to help me answer the question.
-- 
GitLab