From d8bcba9ef0a31383537c87d47ea8d58d12b2e18f Mon Sep 17 00:00:00 2001
From: Andrew Sayre <6730289+andrewsayre@users.noreply.github.com>
Date: Tue, 11 Mar 2025 13:00:43 -0500
Subject: [PATCH] Enable HEOS automatic failover (#140394)

Failover
---
 homeassistant/components/heos/coordinator.py | 18 +++++++++++++++---
 tests/components/heos/__init__.py            |  4 ++++
 tests/components/heos/test_init.py           | 19 +++++++++++++++++++
 3 files changed, 38 insertions(+), 3 deletions(-)

diff --git a/homeassistant/components/heos/coordinator.py b/homeassistant/components/heos/coordinator.py
index 93fe069d9be..0333c60ec21 100644
--- a/homeassistant/components/heos/coordinator.py
+++ b/homeassistant/components/heos/coordinator.py
@@ -43,7 +43,6 @@ class HeosCoordinator(DataUpdateCoordinator[None]):
 
     def __init__(self, hass: HomeAssistant, config_entry: HeosConfigEntry) -> None:
         """Set up the coordinator and set in config_entry."""
-        self.host: str = config_entry.data[CONF_HOST]
         credentials: Credentials | None = None
         if config_entry.options:
             credentials = Credentials(
@@ -53,9 +52,10 @@ class HeosCoordinator(DataUpdateCoordinator[None]):
         # media position update upon start of playback or when media changes
         self.heos = Heos(
             HeosOptions(
-                self.host,
+                config_entry.data[CONF_HOST],
                 all_progress_events=False,
                 auto_reconnect=True,
+                auto_failover=True,
                 credentials=credentials,
             )
         )
@@ -66,6 +66,11 @@ class HeosCoordinator(DataUpdateCoordinator[None]):
         self._inputs: Sequence[MediaItem] = []
         super().__init__(hass, _LOGGER, config_entry=config_entry, name=DOMAIN)
 
+    @property
+    def host(self) -> str:
+        """Get the host address of the device."""
+        return self.heos.current_host
+
     @property
     def inputs(self) -> Sequence[MediaItem]:
         """Get input sources across all devices."""
@@ -159,8 +164,15 @@ class HeosCoordinator(DataUpdateCoordinator[None]):
 
     async def _async_on_reconnected(self) -> None:
         """Handle when reconnected so resources are updated and entities marked available."""
+        assert self.config_entry is not None
+        if self.host != self.config_entry.data[CONF_HOST]:
+            self.hass.config_entries.async_update_entry(
+                self.config_entry, data={CONF_HOST: self.host}
+            )
+            _LOGGER.warning("Successfully failed over to HEOS host %s", self.host)
+        else:
+            _LOGGER.warning("Successfully reconnected to HEOS host %s", self.host)
         await self._async_update_sources()
-        _LOGGER.warning("Successfully reconnected to HEOS host %s", self.host)
         self.async_update_listeners()
 
     async def _async_on_controller_event(
diff --git a/tests/components/heos/__init__.py b/tests/components/heos/__init__.py
index 016cc7b3580..862b1e5ffab 100644
--- a/tests/components/heos/__init__.py
+++ b/tests/components/heos/__init__.py
@@ -64,3 +64,7 @@ class MockHeos(Heos):
     def mock_set_connection_state(self, connection_state: ConnectionState) -> None:
         """Set the connection state on the mock instance."""
         self._connection._state = connection_state
+
+    def mock_set_current_host(self, host: str) -> None:
+        """Set the current host on the mock instance."""
+        self._connection._host = host
diff --git a/tests/components/heos/test_init.py b/tests/components/heos/test_init.py
index b155abaf0e9..7bc232ad5a6 100644
--- a/tests/components/heos/test_init.py
+++ b/tests/components/heos/test_init.py
@@ -297,6 +297,25 @@ async def test_reconnected_new_entities_created(
     assert entity_registry.async_get_entity_id(MEDIA_PLAYER_DOMAIN, DOMAIN, "3")
 
 
+async def test_reconnected_failover_updates_host(
+    hass: HomeAssistant, config_entry: MockConfigEntry, controller: MockHeos
+) -> None:
+    """Test the config entry host is updated after failover."""
+    config_entry.add_to_hass(hass)
+    assert await hass.config_entries.async_setup(config_entry.entry_id)
+    assert config_entry.data[CONF_HOST] == "127.0.0.1"
+
+    # Simulate reconnection
+    controller.mock_set_current_host("127.0.0.2")
+    await controller.dispatcher.wait_send(
+        SignalType.HEOS_EVENT, SignalHeosEvent.CONNECTED
+    )
+    await hass.async_block_till_done()
+
+    # Assert config entry host updated
+    assert config_entry.data[CONF_HOST] == "127.0.0.2"
+
+
 async def test_players_changed_new_entities_created(
     hass: HomeAssistant,
     entity_registry: er.EntityRegistry,
-- 
GitLab