From dde989685c38a52c4bf220c3cc9c728c3e761c24 Mon Sep 17 00:00:00 2001
From: Michael Hansen <mike@rhasspy.org>
Date: Mon, 16 Sep 2024 21:34:07 -0500
Subject: [PATCH] Add Assist satellite configuration (#126063)

* Basic implementation

* Add websocket commands

* Clean up

* Add callback to other signatures

* Remove unused constant

* Re-add callback

* Add callback to test
---
 .../components/assist_satellite/__init__.py   |   9 +-
 .../components/assist_satellite/entity.py     |  40 +++++++
 .../assist_satellite/websocket_api.py         |  84 +++++++++++++
 .../components/esphome/assist_satellite.py    |  15 ++-
 .../components/voip/assist_satellite.py       |  14 +++
 tests/components/assist_satellite/conftest.py |  29 ++++-
 .../assist_satellite/test_websocket_api.py    | 112 ++++++++++++++++++
 7 files changed, 300 insertions(+), 3 deletions(-)

diff --git a/homeassistant/components/assist_satellite/__init__.py b/homeassistant/components/assist_satellite/__init__.py
index 3d6e04bcc75..2d4459ffd8c 100644
--- a/homeassistant/components/assist_satellite/__init__.py
+++ b/homeassistant/components/assist_satellite/__init__.py
@@ -11,15 +11,22 @@ from homeassistant.helpers.entity_component import EntityComponent
 from homeassistant.helpers.typing import ConfigType
 
 from .const import DOMAIN, AssistSatelliteEntityFeature
-from .entity import AssistSatelliteEntity, AssistSatelliteEntityDescription
+from .entity import (
+    AssistSatelliteConfiguration,
+    AssistSatelliteEntity,
+    AssistSatelliteEntityDescription,
+    AssistSatelliteWakeWord,
+)
 from .errors import SatelliteBusyError
 from .websocket_api import async_register_websocket_api
 
 __all__ = [
     "DOMAIN",
     "AssistSatelliteEntity",
+    "AssistSatelliteConfiguration",
     "AssistSatelliteEntityDescription",
     "AssistSatelliteEntityFeature",
+    "AssistSatelliteWakeWord",
     "SatelliteBusyError",
 ]
 
diff --git a/homeassistant/components/assist_satellite/entity.py b/homeassistant/components/assist_satellite/entity.py
index c00cb26cb63..079d3ae2948 100644
--- a/homeassistant/components/assist_satellite/entity.py
+++ b/homeassistant/components/assist_satellite/entity.py
@@ -4,6 +4,7 @@ from abc import abstractmethod
 import asyncio
 from collections.abc import AsyncIterable
 import contextlib
+from dataclasses import dataclass
 from enum import StrEnum
 import logging
 import time
@@ -57,6 +58,34 @@ class AssistSatelliteEntityDescription(EntityDescription, frozen_or_thawed=True)
     """A class that describes Assist satellite entities."""
 
 
+@dataclass(frozen=True)
+class AssistSatelliteWakeWord:
+    """Available wake word model."""
+
+    id: str
+    """Unique id for wake word model."""
+
+    wake_word: str
+    """Wake word phrase."""
+
+    trained_languages: list[str]
+    """List of languages that the wake word was trained on."""
+
+
+@dataclass
+class AssistSatelliteConfiguration:
+    """Satellite configuration."""
+
+    available_wake_words: list[AssistSatelliteWakeWord]
+    """List of available available wake word models."""
+
+    active_wake_words: list[str]
+    """List of active wake word ids."""
+
+    max_active_wake_words: int
+    """Maximum number of simultaneous wake words allowed (0 for no limit)."""
+
+
 class AssistSatelliteEntity(entity.Entity):
     """Entity encapsulating the state and functionality of an Assist satellite."""
 
@@ -98,6 +127,17 @@ class AssistSatelliteEntity(entity.Entity):
         """Options passed for text-to-speech."""
         return self._attr_tts_options
 
+    @callback
+    @abstractmethod
+    def async_get_configuration(self) -> AssistSatelliteConfiguration:
+        """Get the current satellite configuration."""
+
+    @abstractmethod
+    async def async_set_configuration(
+        self, config: AssistSatelliteConfiguration
+    ) -> None:
+        """Set the current satellite configuration."""
+
     async def async_intercept_wake_word(self) -> str | None:
         """Intercept the next wake word from the satellite.
 
diff --git a/homeassistant/components/assist_satellite/websocket_api.py b/homeassistant/components/assist_satellite/websocket_api.py
index 8de10c8a9de..0d7a434dba5 100644
--- a/homeassistant/components/assist_satellite/websocket_api.py
+++ b/homeassistant/components/assist_satellite/websocket_api.py
@@ -1,5 +1,6 @@
 """Assist satellite Websocket API."""
 
+from dataclasses import asdict, replace
 from typing import Any
 
 import voluptuous as vol
@@ -18,6 +19,8 @@ from .entity import AssistSatelliteEntity
 def async_register_websocket_api(hass: HomeAssistant) -> None:
     """Register the websocket API."""
     websocket_api.async_register_command(hass, websocket_intercept_wake_word)
+    websocket_api.async_register_command(hass, websocket_get_configuration)
+    websocket_api.async_register_command(hass, websocket_set_wake_words)
 
 
 @callback
@@ -59,3 +62,84 @@ async def websocket_intercept_wake_word(
     task = hass.async_create_task(intercept_wake_word(), "intercept_wake_word")
     connection.subscriptions[msg["id"]] = task.cancel
     connection.send_message(websocket_api.result_message(msg["id"]))
+
+
+@callback
+@websocket_api.websocket_command(
+    {
+        vol.Required("type"): "assist_satellite/get_configuration",
+        vol.Required("entity_id"): cv.entity_domain(DOMAIN),
+    }
+)
+def websocket_get_configuration(
+    hass: HomeAssistant,
+    connection: websocket_api.connection.ActiveConnection,
+    msg: dict[str, Any],
+) -> None:
+    """Get the current satellite configuration."""
+    component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
+    satellite = component.get_entity(msg["entity_id"])
+    if satellite is None:
+        connection.send_error(
+            msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found"
+        )
+        return
+
+    config_dict = asdict(satellite.async_get_configuration())
+    config_dict["pipeline_entity_id"] = satellite.pipeline_entity_id
+    config_dict["vad_entity_id"] = satellite.vad_sensitivity_entity_id
+
+    connection.send_result(msg["id"], config_dict)
+
+
+@callback
+@websocket_api.websocket_command(
+    {
+        vol.Required("type"): "assist_satellite/set_wake_words",
+        vol.Required("entity_id"): cv.entity_domain(DOMAIN),
+        vol.Required("wake_word_ids"): [str],
+    }
+)
+@websocket_api.require_admin
+@websocket_api.async_response
+async def websocket_set_wake_words(
+    hass: HomeAssistant,
+    connection: websocket_api.connection.ActiveConnection,
+    msg: dict[str, Any],
+) -> None:
+    """Set the active wake words for the satellite."""
+    component: EntityComponent[AssistSatelliteEntity] = hass.data[DOMAIN]
+    satellite = component.get_entity(msg["entity_id"])
+    if satellite is None:
+        connection.send_error(
+            msg["id"], websocket_api.ERR_NOT_FOUND, "Entity not found"
+        )
+        return
+
+    config = satellite.async_get_configuration()
+
+    # Don't set too many active wake words
+    actual_ids = msg["wake_word_ids"]
+    if len(actual_ids) > config.max_active_wake_words:
+        connection.send_error(
+            msg["id"],
+            websocket_api.ERR_NOT_SUPPORTED,
+            f"Maximum number of active wake words is {config.max_active_wake_words}",
+        )
+        return
+
+    # Verify all ids are available
+    available_ids = {ww.id for ww in config.available_wake_words}
+    for ww_id in actual_ids:
+        if ww_id not in available_ids:
+            connection.send_error(
+                msg["id"],
+                websocket_api.ERR_NOT_SUPPORTED,
+                f"Wake word id is not supported: {ww_id}",
+            )
+            return
+
+    await satellite.async_set_configuration(
+        replace(config, active_wake_words=actual_ids)
+    )
+    connection.send_result(msg["id"])
diff --git a/homeassistant/components/esphome/assist_satellite.py b/homeassistant/components/esphome/assist_satellite.py
index 7ce46fab64b..3c66c82a734 100644
--- a/homeassistant/components/esphome/assist_satellite.py
+++ b/homeassistant/components/esphome/assist_satellite.py
@@ -36,7 +36,7 @@ from homeassistant.components.intent import (
 from homeassistant.components.media_player import async_process_play_media_url
 from homeassistant.config_entries import ConfigEntry
 from homeassistant.const import EntityCategory, Platform
-from homeassistant.core import HomeAssistant
+from homeassistant.core import HomeAssistant, callback
 from homeassistant.helpers import entity_registry as er
 from homeassistant.helpers.entity_platform import AddEntitiesCallback
 
@@ -150,6 +150,19 @@ class EsphomeAssistSatellite(
             f"{self.entry_data.device_info.mac_address}-vad_sensitivity",
         )
 
+    @callback
+    def async_get_configuration(
+        self,
+    ) -> assist_satellite.AssistSatelliteConfiguration:
+        """Get the current satellite configuration."""
+        raise NotImplementedError
+
+    async def async_set_configuration(
+        self, config: assist_satellite.AssistSatelliteConfiguration
+    ) -> None:
+        """Set the current satellite configuration."""
+        raise NotImplementedError
+
     async def async_added_to_hass(self) -> None:
         """Run when entity about to be added to hass."""
         await super().async_added_to_hass()
diff --git a/homeassistant/components/voip/assist_satellite.py b/homeassistant/components/voip/assist_satellite.py
index f75f65a08ea..2f37a8a63e1 100644
--- a/homeassistant/components/voip/assist_satellite.py
+++ b/homeassistant/components/voip/assist_satellite.py
@@ -20,6 +20,7 @@ from homeassistant.components.assist_pipeline import (
     PipelineNotFound,
 )
 from homeassistant.components.assist_satellite import (
+    AssistSatelliteConfiguration,
     AssistSatelliteEntity,
     AssistSatelliteEntityDescription,
 )
@@ -141,6 +142,19 @@ class VoipAssistSatellite(VoIPEntity, AssistSatelliteEntity, RtpDatagramProtocol
         assert self.voip_device.protocol == self
         self.voip_device.protocol = None
 
+    @callback
+    def async_get_configuration(
+        self,
+    ) -> AssistSatelliteConfiguration:
+        """Get the current satellite configuration."""
+        raise NotImplementedError
+
+    async def async_set_configuration(
+        self, config: AssistSatelliteConfiguration
+    ) -> None:
+        """Set the current satellite configuration."""
+        raise NotImplementedError
+
     # -------------------------------------------------------------------------
     # VoIP
     # -------------------------------------------------------------------------
diff --git a/tests/components/assist_satellite/conftest.py b/tests/components/assist_satellite/conftest.py
index a14e9e9452b..3a374b312cc 100644
--- a/tests/components/assist_satellite/conftest.py
+++ b/tests/components/assist_satellite/conftest.py
@@ -8,11 +8,13 @@ import pytest
 from homeassistant.components.assist_pipeline import PipelineEvent
 from homeassistant.components.assist_satellite import (
     DOMAIN as AS_DOMAIN,
+    AssistSatelliteConfiguration,
     AssistSatelliteEntity,
     AssistSatelliteEntityFeature,
+    AssistSatelliteWakeWord,
 )
 from homeassistant.config_entries import ConfigEntry, ConfigFlow
-from homeassistant.core import HomeAssistant
+from homeassistant.core import HomeAssistant, callback
 from homeassistant.setup import async_setup_component
 
 from tests.common import (
@@ -42,6 +44,20 @@ class MockAssistSatellite(AssistSatelliteEntity):
         """Initialize the mock entity."""
         self.events = []
         self.announcements = []
+        self.config = AssistSatelliteConfiguration(
+            available_wake_words=[
+                AssistSatelliteWakeWord(
+                    id="1234", wake_word="okay nabu", trained_languages=["en"]
+                ),
+                AssistSatelliteWakeWord(
+                    id="5678",
+                    wake_word="hey jarvis",
+                    trained_languages=["en"],
+                ),
+            ],
+            active_wake_words=["1234"],
+            max_active_wake_words=1,
+        )
 
     def on_pipeline_event(self, event: PipelineEvent) -> None:
         """Handle pipeline events."""
@@ -51,6 +67,17 @@ class MockAssistSatellite(AssistSatelliteEntity):
         """Announce media on a device."""
         self.announcements.append((message, media_id))
 
+    @callback
+    def async_get_configuration(self) -> AssistSatelliteConfiguration:
+        """Get the current satellite configuration."""
+        return self.config
+
+    async def async_set_configuration(
+        self, config: AssistSatelliteConfiguration
+    ) -> None:
+        """Set the current satellite configuration."""
+        self.config = config
+
 
 @pytest.fixture
 def entity() -> MockAssistSatellite:
diff --git a/tests/components/assist_satellite/test_websocket_api.py b/tests/components/assist_satellite/test_websocket_api.py
index 7895ea2555a..709005e38cf 100644
--- a/tests/components/assist_satellite/test_websocket_api.py
+++ b/tests/components/assist_satellite/test_websocket_api.py
@@ -273,3 +273,115 @@ async def test_intercept_wake_word_unsubscribe(
 
         # Wake word should not be intercepted
         mock_pipeline_from_audio_stream.assert_called_once()
+
+
+async def test_get_configuration(
+    hass: HomeAssistant,
+    init_components: ConfigEntry,
+    entity: MockAssistSatellite,
+    hass_ws_client: WebSocketGenerator,
+) -> None:
+    """Test getting satellite configuration."""
+    ws_client = await hass_ws_client(hass)
+
+    with (
+        patch.object(entity, "_attr_pipeline_entity_id", "select.test_pipeline"),
+        patch.object(entity, "_attr_vad_sensitivity_entity_id", "select.test_vad"),
+    ):
+        await ws_client.send_json_auto_id(
+            {
+                "type": "assist_satellite/get_configuration",
+                "entity_id": ENTITY_ID,
+            }
+        )
+        msg = await ws_client.receive_json()
+        assert msg["success"]
+        assert msg["result"] == {
+            "active_wake_words": ["1234"],
+            "available_wake_words": [
+                {"id": "1234", "trained_languages": ["en"], "wake_word": "okay nabu"},
+                {"id": "5678", "trained_languages": ["en"], "wake_word": "hey jarvis"},
+            ],
+            "max_active_wake_words": 1,
+            "pipeline_entity_id": "select.test_pipeline",
+            "vad_entity_id": "select.test_vad",
+        }
+
+
+async def test_set_wake_words(
+    hass: HomeAssistant,
+    init_components: ConfigEntry,
+    entity: MockAssistSatellite,
+    hass_ws_client: WebSocketGenerator,
+) -> None:
+    """Test setting active wake words."""
+    ws_client = await hass_ws_client(hass)
+
+    await ws_client.send_json_auto_id(
+        {
+            "type": "assist_satellite/set_wake_words",
+            "entity_id": ENTITY_ID,
+            "wake_word_ids": ["5678"],
+        }
+    )
+    msg = await ws_client.receive_json()
+    assert msg["success"]
+
+    # Verify change
+    await ws_client.send_json_auto_id(
+        {
+            "type": "assist_satellite/get_configuration",
+            "entity_id": ENTITY_ID,
+        }
+    )
+    msg = await ws_client.receive_json()
+    assert msg["success"]
+    assert msg["result"].get("active_wake_words") == ["5678"]
+
+
+async def test_set_wake_words_exceed_maximum(
+    hass: HomeAssistant,
+    init_components: ConfigEntry,
+    entity: MockAssistSatellite,
+    hass_ws_client: WebSocketGenerator,
+) -> None:
+    """Test setting too many active wake words."""
+    ws_client = await hass_ws_client(hass)
+
+    await ws_client.send_json_auto_id(
+        {
+            "type": "assist_satellite/set_wake_words",
+            "entity_id": ENTITY_ID,
+            "wake_word_ids": ["1234", "5678"],  # max of 1
+        }
+    )
+    msg = await ws_client.receive_json()
+    assert not msg["success"]
+    assert msg["error"] == {
+        "code": "not_supported",
+        "message": "Maximum number of active wake words is 1",
+    }
+
+
+async def test_set_wake_words_bad_id(
+    hass: HomeAssistant,
+    init_components: ConfigEntry,
+    entity: MockAssistSatellite,
+    hass_ws_client: WebSocketGenerator,
+) -> None:
+    """Test setting active wake words with a bad id."""
+    ws_client = await hass_ws_client(hass)
+
+    await ws_client.send_json_auto_id(
+        {
+            "type": "assist_satellite/set_wake_words",
+            "entity_id": ENTITY_ID,
+            "wake_word_ids": ["abcd"],  # not an available id
+        }
+    )
+    msg = await ws_client.receive_json()
+    assert not msg["success"]
+    assert msg["error"] == {
+        "code": "not_supported",
+        "message": "Wake word id is not supported: abcd",
+    }
-- 
GitLab