From 7e03100af26ffc7764d8029f6e65199c0ce47a0c Mon Sep 17 00:00:00 2001
From: Paulus Schoutsen <balloob@gmail.com>
Date: Wed, 27 Nov 2024 00:51:21 -0500
Subject: [PATCH] Allow an LLM to see script response values (#131683)

---
 homeassistant/helpers/llm.py | 24 +++++---------
 tests/helpers/test_llm.py    | 64 +++++++++++++++++++++++++++---------
 2 files changed, 58 insertions(+), 30 deletions(-)

diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py
index d322810b0ef..49ae1455006 100644
--- a/homeassistant/helpers/llm.py
+++ b/homeassistant/helpers/llm.py
@@ -22,15 +22,13 @@ from homeassistant.components.conversation import (
 from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER
 from homeassistant.components.homeassistant import async_should_expose
 from homeassistant.components.intent import async_device_supports_timers
-from homeassistant.components.script import ATTR_VARIABLES, DOMAIN as SCRIPT_DOMAIN
+from homeassistant.components.script import DOMAIN as SCRIPT_DOMAIN
 from homeassistant.components.weather import INTENT_GET_WEATHER
 from homeassistant.const import (
     ATTR_DOMAIN,
-    ATTR_ENTITY_ID,
     ATTR_SERVICE,
     EVENT_HOMEASSISTANT_CLOSE,
     EVENT_SERVICE_REMOVED,
-    SERVICE_TURN_ON,
 )
 from homeassistant.core import Context, Event, HomeAssistant, callback, split_entity_id
 from homeassistant.exceptions import HomeAssistantError
@@ -416,9 +414,7 @@ class AssistAPI(API):
                 ):
                     continue
 
-                script_tool = ScriptTool(self.hass, state.entity_id)
-                if script_tool.parameters.schema:
-                    tools.append(script_tool)
+                tools.append(ScriptTool(self.hass, state.entity_id))
 
         return tools
 
@@ -702,10 +698,9 @@ class ScriptTool(Tool):
         script_entity_id: str,
     ) -> None:
         """Init the class."""
-        self.name = split_entity_id(script_entity_id)[1]
+        self._object_id = self.name = split_entity_id(script_entity_id)[1]
         if self.name[0].isdigit():
             self.name = "_" + self.name
-        self._entity_id = script_entity_id
 
         self.description, self.parameters = _get_cached_script_parameters(
             hass, script_entity_id
@@ -745,14 +740,13 @@ class ScriptTool(Tool):
                     floor = list(intent.find_floors(floor, floor_reg))[0].floor_id
                     tool_input.tool_args[field] = floor
 
-        await hass.services.async_call(
+        result = await hass.services.async_call(
             SCRIPT_DOMAIN,
-            SERVICE_TURN_ON,
-            {
-                ATTR_ENTITY_ID: self._entity_id,
-                ATTR_VARIABLES: tool_input.tool_args,
-            },
+            self._object_id,
+            tool_input.tool_args,
             context=llm_context.context,
+            blocking=True,
+            return_response=True,
         )
 
-        return {"success": True}
+        return {"success": True, "result": result}
diff --git a/tests/helpers/test_llm.py b/tests/helpers/test_llm.py
index 7174d77886a..4b2fc9e5fc1 100644
--- a/tests/helpers/test_llm.py
+++ b/tests/helpers/test_llm.py
@@ -656,7 +656,10 @@ async def test_script_tool(
             "script": {
                 "test_script": {
                     "description": "This is a test script",
-                    "sequence": [],
+                    "sequence": [
+                        {"variables": {"result": {"drinks": 2}}},
+                        {"stop": True, "response_variable": "result"},
+                    ],
                     "fields": {
                         "beer": {"description": "Number of beers", "required": True},
                         "wine": {"selector": {"number": {"min": 0, "max": 3}}},
@@ -692,7 +695,7 @@ async def test_script_tool(
     api = await llm.async_get_api(hass, "assist", llm_context)
 
     tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
-    assert len(tools) == 1
+    assert len(tools) == 2
 
     tool = tools[0]
     assert tool.name == "test_script"
@@ -719,6 +722,7 @@ async def test_script_tool(
         "script_with_no_fields": ("This is another test script", vol.Schema({})),
     }
 
+    # Test script with response
     tool_input = llm.ToolInput(
         tool_name="test_script",
         tool_args={
@@ -731,26 +735,56 @@ async def test_script_tool(
         },
     )
 
-    with patch("homeassistant.core.ServiceRegistry.async_call") as mock_service_call:
+    with patch(
+        "homeassistant.core.ServiceRegistry.async_call",
+        side_effect=hass.services.async_call,
+    ) as mock_service_call:
         response = await api.async_call_tool(tool_input)
 
     mock_service_call.assert_awaited_once_with(
         "script",
-        "turn_on",
+        "test_script",
         {
-            "entity_id": "script.test_script",
-            "variables": {
-                "beer": "3",
-                "wine": 0,
-                "where": area.id,
-                "area_list": [area.id],
-                "floor": floor.floor_id,
-                "floor_list": [floor.floor_id],
-            },
+            "beer": "3",
+            "wine": 0,
+            "where": area.id,
+            "area_list": [area.id],
+            "floor": floor.floor_id,
+            "floor_list": [floor.floor_id],
         },
         context=context,
+        blocking=True,
+        return_response=True,
+    )
+    assert response == {
+        "success": True,
+        "result": {"drinks": 2},
+    }
+
+    # Test script with no response
+    tool_input = llm.ToolInput(
+        tool_name="script_with_no_fields",
+        tool_args={},
+    )
+
+    with patch(
+        "homeassistant.core.ServiceRegistry.async_call",
+        side_effect=hass.services.async_call,
+    ) as mock_service_call:
+        response = await api.async_call_tool(tool_input)
+
+    mock_service_call.assert_awaited_once_with(
+        "script",
+        "script_with_no_fields",
+        {},
+        context=context,
+        blocking=True,
+        return_response=True,
     )
-    assert response == {"success": True}
+    assert response == {
+        "success": True,
+        "result": {},
+    }
 
     # Test reload script with new parameters
     config = {
@@ -782,7 +816,7 @@ async def test_script_tool(
     api = await llm.async_get_api(hass, "assist", llm_context)
 
     tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)]
-    assert len(tools) == 1
+    assert len(tools) == 2
 
     tool = tools[0]
     assert tool.name == "test_script"
-- 
GitLab