From 7e03100af26ffc7764d8029f6e65199c0ce47a0c Mon Sep 17 00:00:00 2001 From: Paulus Schoutsen <balloob@gmail.com> Date: Wed, 27 Nov 2024 00:51:21 -0500 Subject: [PATCH] Allow an LLM to see script response values (#131683) --- homeassistant/helpers/llm.py | 24 +++++--------- tests/helpers/test_llm.py | 64 +++++++++++++++++++++++++++--------- 2 files changed, 58 insertions(+), 30 deletions(-) diff --git a/homeassistant/helpers/llm.py b/homeassistant/helpers/llm.py index d322810b0ef..49ae1455006 100644 --- a/homeassistant/helpers/llm.py +++ b/homeassistant/helpers/llm.py @@ -22,15 +22,13 @@ from homeassistant.components.conversation import ( from homeassistant.components.cover import INTENT_CLOSE_COVER, INTENT_OPEN_COVER from homeassistant.components.homeassistant import async_should_expose from homeassistant.components.intent import async_device_supports_timers -from homeassistant.components.script import ATTR_VARIABLES, DOMAIN as SCRIPT_DOMAIN +from homeassistant.components.script import DOMAIN as SCRIPT_DOMAIN from homeassistant.components.weather import INTENT_GET_WEATHER from homeassistant.const import ( ATTR_DOMAIN, - ATTR_ENTITY_ID, ATTR_SERVICE, EVENT_HOMEASSISTANT_CLOSE, EVENT_SERVICE_REMOVED, - SERVICE_TURN_ON, ) from homeassistant.core import Context, Event, HomeAssistant, callback, split_entity_id from homeassistant.exceptions import HomeAssistantError @@ -416,9 +414,7 @@ class AssistAPI(API): ): continue - script_tool = ScriptTool(self.hass, state.entity_id) - if script_tool.parameters.schema: - tools.append(script_tool) + tools.append(ScriptTool(self.hass, state.entity_id)) return tools @@ -702,10 +698,9 @@ class ScriptTool(Tool): script_entity_id: str, ) -> None: """Init the class.""" - self.name = split_entity_id(script_entity_id)[1] + self._object_id = self.name = split_entity_id(script_entity_id)[1] if self.name[0].isdigit(): self.name = "_" + self.name - self._entity_id = script_entity_id self.description, self.parameters = _get_cached_script_parameters( hass, script_entity_id @@ -745,14 +740,13 @@ class ScriptTool(Tool): floor = list(intent.find_floors(floor, floor_reg))[0].floor_id tool_input.tool_args[field] = floor - await hass.services.async_call( + result = await hass.services.async_call( SCRIPT_DOMAIN, - SERVICE_TURN_ON, - { - ATTR_ENTITY_ID: self._entity_id, - ATTR_VARIABLES: tool_input.tool_args, - }, + self._object_id, + tool_input.tool_args, context=llm_context.context, + blocking=True, + return_response=True, ) - return {"success": True} + return {"success": True, "result": result} diff --git a/tests/helpers/test_llm.py b/tests/helpers/test_llm.py index 7174d77886a..4b2fc9e5fc1 100644 --- a/tests/helpers/test_llm.py +++ b/tests/helpers/test_llm.py @@ -656,7 +656,10 @@ async def test_script_tool( "script": { "test_script": { "description": "This is a test script", - "sequence": [], + "sequence": [ + {"variables": {"result": {"drinks": 2}}}, + {"stop": True, "response_variable": "result"}, + ], "fields": { "beer": {"description": "Number of beers", "required": True}, "wine": {"selector": {"number": {"min": 0, "max": 3}}}, @@ -692,7 +695,7 @@ async def test_script_tool( api = await llm.async_get_api(hass, "assist", llm_context) tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)] - assert len(tools) == 1 + assert len(tools) == 2 tool = tools[0] assert tool.name == "test_script" @@ -719,6 +722,7 @@ async def test_script_tool( "script_with_no_fields": ("This is another test script", vol.Schema({})), } + # Test script with response tool_input = llm.ToolInput( tool_name="test_script", tool_args={ @@ -731,26 +735,56 @@ async def test_script_tool( }, ) - with patch("homeassistant.core.ServiceRegistry.async_call") as mock_service_call: + with patch( + "homeassistant.core.ServiceRegistry.async_call", + side_effect=hass.services.async_call, + ) as mock_service_call: response = await api.async_call_tool(tool_input) mock_service_call.assert_awaited_once_with( "script", - "turn_on", + "test_script", { - "entity_id": "script.test_script", - "variables": { - "beer": "3", - "wine": 0, - "where": area.id, - "area_list": [area.id], - "floor": floor.floor_id, - "floor_list": [floor.floor_id], - }, + "beer": "3", + "wine": 0, + "where": area.id, + "area_list": [area.id], + "floor": floor.floor_id, + "floor_list": [floor.floor_id], }, context=context, + blocking=True, + return_response=True, + ) + assert response == { + "success": True, + "result": {"drinks": 2}, + } + + # Test script with no response + tool_input = llm.ToolInput( + tool_name="script_with_no_fields", + tool_args={}, + ) + + with patch( + "homeassistant.core.ServiceRegistry.async_call", + side_effect=hass.services.async_call, + ) as mock_service_call: + response = await api.async_call_tool(tool_input) + + mock_service_call.assert_awaited_once_with( + "script", + "script_with_no_fields", + {}, + context=context, + blocking=True, + return_response=True, ) - assert response == {"success": True} + assert response == { + "success": True, + "result": {}, + } # Test reload script with new parameters config = { @@ -782,7 +816,7 @@ async def test_script_tool( api = await llm.async_get_api(hass, "assist", llm_context) tools = [tool for tool in api.tools if isinstance(tool, llm.ScriptTool)] - assert len(tools) == 1 + assert len(tools) == 2 tool = tools[0] assert tool.name == "test_script" -- GitLab