From 6eab5e3e14cb29c78f42e4d015b5f6be13b109d0 Mon Sep 17 00:00:00 2001
From: Michael Hansen <mike@rhasspy.org>
Date: Mon, 16 Sep 2024 22:08:39 -0500
Subject: [PATCH] Add ESPHome Assist satellite configuration (#126085)

* Basic implementation

* Add websocket commands

* Clean up

* Add callback to other signatures

* Remove unused constant

* Re-add callback

* Add callback to test

* Implement get/set configuration

* Add tests

* Re-add constant

* Bump aioesphomeapi

---------

Co-authored-by: Paulus Schoutsen <balloob@gmail.com>
---
 .../components/esphome/assist_satellite.py    | 35 ++++++++++++-
 .../components/esphome/manifest.json          |  2 +-
 requirements_all.txt                          |  2 +-
 requirements_test_all.txt                     |  2 +-
 .../esphome/test_assist_satellite.py          | 49 +++++++++++++++++++
 5 files changed, 85 insertions(+), 5 deletions(-)

diff --git a/homeassistant/components/esphome/assist_satellite.py b/homeassistant/components/esphome/assist_satellite.py
index 3c66c82a734..f8ed4c48651 100644
--- a/homeassistant/components/esphome/assist_satellite.py
+++ b/homeassistant/components/esphome/assist_satellite.py
@@ -79,6 +79,7 @@ _TIMER_EVENT_TYPES: EsphomeEnumMapper[VoiceAssistantTimerEventType, TimerEventTy
 )
 
 _ANNOUNCEMENT_TIMEOUT_SEC = 5 * 60  # 5 minutes
+_CONFIG_TIMEOUT_SEC = 5
 
 
 async def async_setup_entry(
@@ -128,6 +129,11 @@ class EsphomeAssistSatellite(
         self._tts_streaming_task: asyncio.Task | None = None
         self._udp_server: VoiceAssistantUDPServer | None = None
 
+        # Empty config. Updated when added to HA.
+        self._satellite_config = assist_satellite.AssistSatelliteConfiguration(
+            available_wake_words=[], active_wake_words=[], max_active_wake_words=0
+        )
+
     @property
     def pipeline_entity_id(self) -> str | None:
         """Return the entity ID of the pipeline to use for the next conversation."""
@@ -155,13 +161,33 @@ class EsphomeAssistSatellite(
         self,
     ) -> assist_satellite.AssistSatelliteConfiguration:
         """Get the current satellite configuration."""
-        raise NotImplementedError
+        return self._satellite_config
 
     async def async_set_configuration(
         self, config: assist_satellite.AssistSatelliteConfiguration
     ) -> None:
         """Set the current satellite configuration."""
-        raise NotImplementedError
+        await self.cli.set_voice_assistant_configuration(
+            active_wake_words=config.active_wake_words
+        )
+        _LOGGER.debug("Set active wake words: %s", config.active_wake_words)
+
+    async def _update_satellite_config(self) -> None:
+        """Get the latest satellite configuration from the device."""
+        config = await self.cli.get_voice_assistant_configuration(_CONFIG_TIMEOUT_SEC)
+
+        # Update available/active wake words
+        self._satellite_config.available_wake_words = [
+            assist_satellite.AssistSatelliteWakeWord(
+                id=model.id,
+                wake_word=model.wake_word,
+                trained_languages=list(model.trained_languages),
+            )
+            for model in config.available_wake_words
+        ]
+        self._satellite_config.active_wake_words = list(config.active_wake_words)
+        self._satellite_config.max_active_wake_words = config.max_active_wake_words
+        _LOGGER.debug("Received satellite configuration: %s", self._satellite_config)
 
     async def async_added_to_hass(self) -> None:
         """Run when entity about to be added to hass."""
@@ -214,6 +240,11 @@ class EsphomeAssistSatellite(
             # Will use media player for TTS/announcements
             self._update_tts_format()
 
+        # Fetch latest config in the background
+        self.config_entry.async_create_background_task(
+            self.hass, self._update_satellite_config(), "esphome_voice_assistant_config"
+        )
+
     async def async_will_remove_from_hass(self) -> None:
         """Run when entity will be removed from hass."""
         await super().async_will_remove_from_hass()
diff --git a/homeassistant/components/esphome/manifest.json b/homeassistant/components/esphome/manifest.json
index dbf51aafae4..aca92f976cc 100644
--- a/homeassistant/components/esphome/manifest.json
+++ b/homeassistant/components/esphome/manifest.json
@@ -17,7 +17,7 @@
   "mqtt": ["esphome/discover/#"],
   "quality_scale": "platinum",
   "requirements": [
-    "aioesphomeapi==26.0.0",
+    "aioesphomeapi==27.0.0",
     "esphome-dashboard-api==1.2.3",
     "bleak-esphome==1.0.0"
   ],
diff --git a/requirements_all.txt b/requirements_all.txt
index a314b6c51cb..a40b660b548 100644
--- a/requirements_all.txt
+++ b/requirements_all.txt
@@ -240,7 +240,7 @@ aioelectricitymaps==0.4.0
 aioemonitor==1.0.5
 
 # homeassistant.components.esphome
-aioesphomeapi==26.0.0
+aioesphomeapi==27.0.0
 
 # homeassistant.components.flo
 aioflo==2021.11.0
diff --git a/requirements_test_all.txt b/requirements_test_all.txt
index d0341c2502b..3fc8d2bd20e 100644
--- a/requirements_test_all.txt
+++ b/requirements_test_all.txt
@@ -228,7 +228,7 @@ aioelectricitymaps==0.4.0
 aioemonitor==1.0.5
 
 # homeassistant.components.esphome
-aioesphomeapi==26.0.0
+aioesphomeapi==27.0.0
 
 # homeassistant.components.flo
 aioflo==2021.11.0
diff --git a/tests/components/esphome/test_assist_satellite.py b/tests/components/esphome/test_assist_satellite.py
index 5136e160e89..03111c0d8d8 100644
--- a/tests/components/esphome/test_assist_satellite.py
+++ b/tests/components/esphome/test_assist_satellite.py
@@ -27,8 +27,10 @@ import pytest
 from homeassistant.components import assist_satellite, tts
 from homeassistant.components.assist_pipeline import PipelineEvent, PipelineEventType
 from homeassistant.components.assist_satellite import (
+    AssistSatelliteConfiguration,
     AssistSatelliteEntity,
     AssistSatelliteEntityFeature,
+    AssistSatelliteWakeWord,
 )
 
 # pylint: disable-next=hass-component-root-import
@@ -1380,3 +1382,50 @@ async def test_pipeline_abort(
 
             # Only first chunk
             assert chunks == [b"before-abort"]
+
+
+async def test_get_set_configuration(
+    hass: HomeAssistant,
+    mock_client: APIClient,
+    mock_esphome_device: Callable[
+        [APIClient, list[EntityInfo], list[UserService], list[EntityState]],
+        Awaitable[MockESPHomeDevice],
+    ],
+) -> None:
+    """Test getting and setting the satellite configuration."""
+    expected_config = AssistSatelliteConfiguration(
+        available_wake_words=[
+            AssistSatelliteWakeWord("1234", "okay nabu", ["en"]),
+            AssistSatelliteWakeWord("5678", "hey jarvis", ["en"]),
+        ],
+        active_wake_words=["1234"],
+        max_active_wake_words=1,
+    )
+    mock_client.get_voice_assistant_configuration.return_value = expected_config
+
+    mock_device: MockESPHomeDevice = await mock_esphome_device(
+        mock_client=mock_client,
+        entity_info=[],
+        user_service=[],
+        states=[],
+        device_info={
+            "voice_assistant_feature_flags": VoiceAssistantFeature.VOICE_ASSISTANT
+        },
+    )
+    await hass.async_block_till_done()
+
+    satellite = get_satellite_entity(hass, mock_device.device_info.mac_address)
+    assert satellite is not None
+
+    # HA should have been updated
+    actual_config = satellite.async_get_configuration()
+    assert actual_config == expected_config
+
+    # Change active wake words
+    actual_config.active_wake_words = ["5678"]
+    await satellite.async_set_configuration(actual_config)
+
+    # Device should have been updated
+    mock_client.set_voice_assistant_configuration.assert_called_once_with(
+        active_wake_words=["5678"]
+    )
-- 
GitLab