From dde989685c38a52c4bf220c3cc9c728c3e761c24 Mon Sep 17 00:00:00 2001 From: Michael Hansen <mike@rhasspy.org> Date: Mon, 16 Sep 2024 21:34:07 -0500 Subject: [PATCH] Add Assist satellite configuration (#126063) * Basic implementation * Add websocket commands * Clean up * Add callback to other signatures * Remove unused constant * Re-add callback * Add callback to test --- .../components/assist_satellite/__init__.py | 9 +- .../components/assist_satellite/entity.py | 40 +++++++ .../assist_satellite/websocket_api.py | 84 +++++++++++++ .../components/esphome/assist_satellite.py | 15 ++- .../components/voip/assist_satellite.py | 14 +++ tests/components/assist_satellite/conftest.py | 29 ++++- .../assist_satellite/test_websocket_api.py | 112 ++++++++++++++++++ 7 files changed, 300 insertions(+), 3 deletions(-) diff --git a/homeassistant/components/assist_satellite/__init__.py b/homeassistant/components/assist_satellite/__init__.py index 3d6e04bcc75..2d4459ffd8c 100644 --- a/homeassistant/components/assist_satellite/__init__.py +++ b/homeassistant/components/assist_satellite/__init__.py @@ -11,15 +11,22 @@ from homeassistant.helpers.entity_component import EntityComponent from homeassistant.helpers.typing import ConfigType from .const import DOMAIN, AssistSatelliteEntityFeature -from .entity import AssistSatelliteEntity, AssistSatelliteEntityDescription +from .entity import ( + AssistSatelliteConfiguration, + AssistSatelliteEntity, + AssistSatelliteEntityDescription, + AssistSatelliteWakeWord, +) from .errors import SatelliteBusyError from .websocket_api import async_register_websocket_api __all__ = [ "DOMAIN", "AssistSatelliteEntity", + "AssistSatelliteConfiguration", "AssistSatelliteEntityDescription", "AssistSatelliteEntityFeature", + "AssistSatelliteWakeWord", "SatelliteBusyError", ] diff --git a/homeassistant/components/assist_satellite/entity.py b/homeassistant/components/assist_satellite/entity.py index c00cb26cb63..079d3ae2948 100644 --- a/homeassistant/components/assist_satellite/entity.py +++ b/homeassistant/components/assist_satellite/entity.py @@ -4,6 +4,7 @@ from abc import abstractmethod import asyncio from collections.abc import AsyncIterable import contextlib +from dataclasses import dataclass from enum import StrEnum import logging import time @@ -57,6 +58,34 @@ class AssistSatelliteEntityDescription(EntityDescription, frozen_or_thawed=True) """A class that describes Assist satellite entities.""" +@dataclass(frozen=True) +class AssistSatelliteWakeWord: + """Available wake word model.""" + + id: str + """Unique id for wake word model.""" + + wake_word: str + """Wake word phrase.""" + + trained_languages: list[str] + """List of languages that the wake word was trained on.""" + + +@dataclass +class AssistSatelliteConfiguration: + """Satellite configuration.""" + + available_wake_words: list[AssistSatelliteWakeWord] + """List of available available wake word models.""" + + active_wake_words: list[str] + """List of active wake word ids.""" + + max_active_wake_words: int + """Maximum number of simultaneous wake words allowed (0 for no limit).""" + + class AssistSatelliteEntity(entity.Entity): """Entity encapsulating the state and functionality of an Assist satellite.""" @@ -98,6 +127,17 @@ class AssistSatelliteEntity(entity.Entity): """Options passed for text-to-speech.""" return self._attr_tts_options + @callback + @abstractmethod + def async_get_configuration(self) -> AssistSatelliteConfiguration: + """Get the current satellite configuration.""" + + @abstractmethod + async def async_set_configuration( + self, config: AssistSatelliteConfiguration + ) -> None: + """Set the current satellite configuration.""" + async def async_intercept_wake_word(self) -> str | None: """Intercept the next wake word from the satellite. diff --git a/homeassistant/components/assist_satellite/websocket_api.py b/homeassistant/components/assist_satellite/websocket_api.py index 8de10c8a9de..0d7a434dba5 100644 --- a/homeassistant/components/assist_satellite/websocket_api.py +++ b/homeassistant/components/assist_satellite/websocket_api.py @@ -1,5 +1,6 @@ """Assist satellite Websocket API.""" +from dataclasses import asdict, replace from typing import Any import voluptuous as vol @@ -18,6 +19,8 @@ from .entity import AssistSatelliteEntity def async_register_websocket_api(hass: HomeAssistant) -> None: """Register the websocket API.""" websocket_api.async_register_command(hass, websocket_intercept_wake_word) + websocket_api.async_register_command(hass, websocket_get_configuration) + websocket_api.async_register_command(hass, websocket_set_wake_words) @callback @@ -59,3 +62,84 @@ async def websocket_intercept_wake_word( task = hass.async_create_task(intercept_wake_word(), "intercept_wake_word") connection.subscriptions[msg["id"]] = task.cancel connection.send_message(websocket_api.result_message(msg["id"])) + + +@callback +@websocket_api.websocket_command( + { + vol.Required("type"): "assist_satellite/get_configuration", + vol.Required("entity_id"): cv.entity_domain(DOMAIN), + } +) +def websocket_get_configuration( + hass: HomeAssistant, + connection: websocket_api.connection.ActiveConnection, + msg: dict[str, Any], +) -> None: + """Get the current satellite configuration.""" + component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN] + satellite = component.get_entity(msg["entity_id"]) + if satellite is None: + connection.send_error( + msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found" + ) + return + + config_dict = asdict(satellite.async_get_configuration()) + config_dict["pipeline_entity_id"] = satellite.pipeline_entity_id + config_dict["vad_entity_id"] = satellite.vad_sensitivity_entity_id + + connection.send_result(msg["id"], config_dict) + + +@callback +@websocket_api.websocket_command( + { + vol.Required("type"): "assist_satellite/set_wake_words", + vol.Required("entity_id"): cv.entity_domain(DOMAIN), + vol.Required("wake_word_ids"): [str], + } +) +@websocket_api.require_admin +@websocket_api.async_response +async def websocket_set_wake_words( + hass: HomeAssistant, + connection: websocket_api.connection.ActiveConnection, + msg: dict[str, Any], +) -> None: + """Set the active wake words for the satellite.""" + component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN] + satellite = component.get_entity(msg["entity_id"]) + if satellite is None: + connection.send_error( + msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found" + ) + return + + config = satellite.async_get_configuration() + + # Don't set too many active wake words + actual_ids = msg["wake_word_ids"] + if len(actual_ids) > config.max_active_wake_words: + connection.send_error( + msg["id"], + websocket_api.ERR_NOT_SUPPORTED, + f"Maximum number of active wake words is {config.max_active_wake_words}", + ) + return + + # Verify all ids are available + available_ids = {ww.id for ww in config.available_wake_words} + for ww_id in actual_ids: + if ww_id not in available_ids: + connection.send_error( + msg["id"], + websocket_api.ERR_NOT_SUPPORTED, + f"Wake word id is not supported: {ww_id}", + ) + return + + await satellite.async_set_configuration( + replace(config, active_wake_words=actual_ids) + ) + connection.send_result(msg["id"]) diff --git a/homeassistant/components/esphome/assist_satellite.py b/homeassistant/components/esphome/assist_satellite.py index 7ce46fab64b..3c66c82a734 100644 --- a/homeassistant/components/esphome/assist_satellite.py +++ b/homeassistant/components/esphome/assist_satellite.py @@ -36,7 +36,7 @@ from homeassistant.components.intent import ( from homeassistant.components.media_player import async_process_play_media_url from homeassistant.config_entries import ConfigEntry from homeassistant.const import EntityCategory, Platform -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, callback from homeassistant.helpers import entity_registry as er from homeassistant.helpers.entity_platform import AddEntitiesCallback @@ -150,6 +150,19 @@ class EsphomeAssistSatellite( f"{self.entry_data.device_info.mac_address}-vad_sensitivity", ) + @callback + def async_get_configuration( + self, + ) -> assist_satellite.AssistSatelliteConfiguration: + """Get the current satellite configuration.""" + raise NotImplementedError + + async def async_set_configuration( + self, config: assist_satellite.AssistSatelliteConfiguration + ) -> None: + """Set the current satellite configuration.""" + raise NotImplementedError + async def async_added_to_hass(self) -> None: """Run when entity about to be added to hass.""" await super().async_added_to_hass() diff --git a/homeassistant/components/voip/assist_satellite.py b/homeassistant/components/voip/assist_satellite.py index f75f65a08ea..2f37a8a63e1 100644 --- a/homeassistant/components/voip/assist_satellite.py +++ b/homeassistant/components/voip/assist_satellite.py @@ -20,6 +20,7 @@ from homeassistant.components.assist_pipeline import ( PipelineNotFound, ) from homeassistant.components.assist_satellite import ( + AssistSatelliteConfiguration, AssistSatelliteEntity, AssistSatelliteEntityDescription, ) @@ -141,6 +142,19 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol assert self.voip_device.protocol == self self.voip_device.protocol = None + @callback + def async_get_configuration( + self, + ) -> AssistSatelliteConfiguration: + """Get the current satellite configuration.""" + raise NotImplementedError + + async def async_set_configuration( + self, config: AssistSatelliteConfiguration + ) -> None: + """Set the current satellite configuration.""" + raise NotImplementedError + # ------------------------------------------------------------------------- # VoIP # ------------------------------------------------------------------------- diff --git a/tests/components/assist_satellite/conftest.py b/tests/components/assist_satellite/conftest.py index a14e9e9452b..3a374b312cc 100644 --- a/tests/components/assist_satellite/conftest.py +++ b/tests/components/assist_satellite/conftest.py @@ -8,11 +8,13 @@ import pytest from homeassistant.components.assist_pipeline import PipelineEvent from homeassistant.components.assist_satellite import ( DOMAIN as AS_DOMAIN, + AssistSatelliteConfiguration, AssistSatelliteEntity, AssistSatelliteEntityFeature, + AssistSatelliteWakeWord, ) from homeassistant.config_entries import ConfigEntry, ConfigFlow -from homeassistant.core import HomeAssistant +from homeassistant.core import HomeAssistant, callback from homeassistant.setup import async_setup_component from tests.common import ( @@ -42,6 +44,20 @@ class MockAssistSatellite(AssistSatelliteEntity): """Initialize the mock entity.""" self.events = [] self.announcements = [] + self.config = AssistSatelliteConfiguration( + available_wake_words=[ + AssistSatelliteWakeWord( + id="1234", wake_word="okay nabu", trained_languages=["en"] + ), + AssistSatelliteWakeWord( + id="5678", + wake_word="hey jarvis", + trained_languages=["en"], + ), + ], + active_wake_words=["1234"], + max_active_wake_words=1, + ) def on_pipeline_event(self, event: PipelineEvent) -> None: """Handle pipeline events.""" @@ -51,6 +67,17 @@ class MockAssistSatellite(AssistSatelliteEntity): """Announce media on a device.""" self.announcements.append((message, media_id)) + @callback + def async_get_configuration(self) -> AssistSatelliteConfiguration: + """Get the current satellite configuration.""" + return self.config + + async def async_set_configuration( + self, config: AssistSatelliteConfiguration + ) -> None: + """Set the current satellite configuration.""" + self.config = config + @pytest.fixture def entity() -> MockAssistSatellite: diff --git a/tests/components/assist_satellite/test_websocket_api.py b/tests/components/assist_satellite/test_websocket_api.py index 7895ea2555a..709005e38cf 100644 --- a/tests/components/assist_satellite/test_websocket_api.py +++ b/tests/components/assist_satellite/test_websocket_api.py @@ -273,3 +273,115 @@ async def test_intercept_wake_word_unsubscribe( # Wake word should not be intercepted mock_pipeline_from_audio_stream.assert_called_once() + + +async def test_get_configuration( + hass: HomeAssistant, + init_components: ConfigEntry, + entity: MockAssistSatellite, + hass_ws_client: WebSocketGenerator, +) -> None: + """Test getting satellite configuration.""" + ws_client = await hass_ws_client(hass) + + with ( + patch.object(entity, "_attr_pipeline_entity_id", "select.test_pipeline"), + patch.object(entity, "_attr_vad_sensitivity_entity_id", "select.test_vad"), + ): + await ws_client.send_json_auto_id( + { + "type": "assist_satellite/get_configuration", + "entity_id": ENTITY_ID, + } + ) + msg = await ws_client.receive_json() + assert msg["success"] + assert msg["result"] == { + "active_wake_words": ["1234"], + "available_wake_words": [ + {"id": "1234", "trained_languages": ["en"], "wake_word": "okay nabu"}, + {"id": "5678", "trained_languages": ["en"], "wake_word": "hey jarvis"}, + ], + "max_active_wake_words": 1, + "pipeline_entity_id": "select.test_pipeline", + "vad_entity_id": "select.test_vad", + } + + +async def test_set_wake_words( + hass: HomeAssistant, + init_components: ConfigEntry, + entity: MockAssistSatellite, + hass_ws_client: WebSocketGenerator, +) -> None: + """Test setting active wake words.""" + ws_client = await hass_ws_client(hass) + + await ws_client.send_json_auto_id( + { + "type": "assist_satellite/set_wake_words", + "entity_id": ENTITY_ID, + "wake_word_ids": ["5678"], + } + ) + msg = await ws_client.receive_json() + assert msg["success"] + + # Verify change + await ws_client.send_json_auto_id( + { + "type": "assist_satellite/get_configuration", + "entity_id": ENTITY_ID, + } + ) + msg = await ws_client.receive_json() + assert msg["success"] + assert msg["result"].get("active_wake_words") == ["5678"] + + +async def test_set_wake_words_exceed_maximum( + hass: HomeAssistant, + init_components: ConfigEntry, + entity: MockAssistSatellite, + hass_ws_client: WebSocketGenerator, +) -> None: + """Test setting too many active wake words.""" + ws_client = await hass_ws_client(hass) + + await ws_client.send_json_auto_id( + { + "type": "assist_satellite/set_wake_words", + "entity_id": ENTITY_ID, + "wake_word_ids": ["1234", "5678"], # max of 1 + } + ) + msg = await ws_client.receive_json() + assert not msg["success"] + assert msg["error"] == { + "code": "not_supported", + "message": "Maximum number of active wake words is 1", + } + + +async def test_set_wake_words_bad_id( + hass: HomeAssistant, + init_components: ConfigEntry, + entity: MockAssistSatellite, + hass_ws_client: WebSocketGenerator, +) -> None: + """Test setting active wake words with a bad id.""" + ws_client = await hass_ws_client(hass) + + await ws_client.send_json_auto_id( + { + "type": "assist_satellite/set_wake_words", + "entity_id": ENTITY_ID, + "wake_word_ids": ["abcd"], # not an available id + } + ) + msg = await ws_client.receive_json() + assert not msg["success"] + assert msg["error"] == { + "code": "not_supported", + "message": "Wake word id is not supported: abcd", + } -- GitLab